From e1cacead44185b7a6b25c0aaca2cd2b6ce9e6ee7 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Tue, 14 Apr 2020 07:51:50 -0500 Subject: [PATCH] Add 'inference_config' on ES >=7.8 --- .ci/test-matrix.yml | 1 + eland/common.py | 18 +++++++++++++++++- eland/ml/_model_serializer.py | 4 +++- eland/ml/imported_ml_model.py | 19 +++++++++++-------- eland/ml/ml_model.py | 2 +- 5 files changed, 33 insertions(+), 11 deletions(-) diff --git a/.ci/test-matrix.yml b/.ci/test-matrix.yml index de847ed..b50a16c 100755 --- a/.ci/test-matrix.yml +++ b/.ci/test-matrix.yml @@ -2,6 +2,7 @@ ELASTICSEARCH_VERSION: - 8.0.0-SNAPSHOT + - 7.x-SNAPSHOT - 7.6-SNAPSHOT TEST_SUITE: diff --git a/eland/common.py b/eland/common.py index 6b367c3..fe19c09 100644 --- a/eland/common.py +++ b/eland/common.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re import warnings from enum import Enum from typing import Union, List, Tuple @@ -266,5 +267,20 @@ def ensure_es_client( es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch] ) -> Elasticsearch: if not isinstance(es_client, Elasticsearch): - return Elasticsearch(es_client) + es_client = Elasticsearch(es_client) return es_client + + +def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]: + """Tags the current ES client with a cached '_eland_es_version' + property if one doesn't exist yet for the current Elasticsearch version. + """ + if not hasattr(es_client, "_eland_es_version"): + major, minor, patch = [ + int(x) + for x in re.match( + r"^(\d+)\.(\d+)\.(\d+)", es_client.info()["version"]["number"] + ).groups() + ] + es_client._eland_es_version = (major, minor, patch) + return es_client._eland_es_version diff --git a/eland/ml/_model_serializer.py b/eland/ml/_model_serializer.py index e7a66b7..8b138f2 100644 --- a/eland/ml/_model_serializer.py +++ b/eland/ml/_model_serializer.py @@ -51,7 +51,9 @@ class ModelSerializer(ABC): json_string = json.dumps( {"trained_model": self.to_dict()}, separators=(",", ":") ) - return base64.b64encode(gzip.compress(bytes(json_string, "utf-8"))) + return base64.b64encode(gzip.compress(json_string.encode("utf-8"))).decode( + "ascii" + ) class TreeNode: diff --git a/eland/ml/imported_ml_model.py b/eland/ml/imported_ml_model.py index 6d67ec4..642a940 100644 --- a/eland/ml/imported_ml_model.py +++ b/eland/ml/imported_ml_model.py @@ -16,6 +16,7 @@ from typing import Union, List import numpy as np +from eland.common import es_version from eland.ml._model_transformers import ( SKLearnDecisionTreeTransformer, SKLearnForestRegressorTransformer, @@ -157,15 +158,17 @@ class ImportedMLModel(MLModel): if overwrite: self.delete_model() - serialized_model = str(serializer.serialize_and_compress_model())[ - 2:-1 - ] # remove `b` and str quotes + serialized_model = serializer.serialize_and_compress_model() + body = { + "compressed_definition": serialized_model, + "input": {"field_names": feature_names}, + } + # 'inference_config' is required in 7.8+ but isn't available in <=7.7 + if es_version(self._client) >= (7, 8): + body["inference_config"] = {self._model_type: {}} + self._client.ml.put_trained_model( - model_id=self._model_id, - body={ - "input": {"field_names": feature_names}, - "compressed_definition": serialized_model, - }, + model_id=self._model_id, body=body, ) def predict(self, X): diff --git a/eland/ml/ml_model.py b/eland/ml/ml_model.py index c1ebce8..32ae1a0 100644 --- a/eland/ml/ml_model.py +++ b/eland/ml/ml_model.py @@ -51,6 +51,6 @@ class MLModel: If model doesn't exist, ignore failure. """ try: - self._client.ml.delete_trained_model(model_id=self._model_id) + self._client.ml.delete_trained_model(model_id=self._model_id, ignore=(404,)) except elasticsearch.NotFoundError: pass