Add 'inference_config' on ES >=7.8

This commit is contained in:
Seth Michael Larson 2020-04-14 07:51:50 -05:00 committed by GitHub
parent 448770df78
commit e1cacead44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 33 additions and 11 deletions

View File

@ -2,6 +2,7 @@
ELASTICSEARCH_VERSION:
- 8.0.0-SNAPSHOT
- 7.x-SNAPSHOT
- 7.6-SNAPSHOT
TEST_SUITE:

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import warnings
from enum import Enum
from typing import Union, List, Tuple
@ -266,5 +267,20 @@ def ensure_es_client(
es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch]
) -> Elasticsearch:
if not isinstance(es_client, Elasticsearch):
return Elasticsearch(es_client)
es_client = Elasticsearch(es_client)
return es_client
def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
    """Return the Elasticsearch server version as ``(major, minor, patch)``.

    On the first call for a given client the version is read from
    ``es_client.info()["version"]["number"]`` and cached on the client as
    the ``_eland_es_version`` attribute; later calls return the cached
    tuple without calling ``info()`` again.

    Only the leading ``X.Y.Z`` part of the version number is parsed, so a
    trailing qualifier such as ``-SNAPSHOT`` is ignored.

    Raises:
        ValueError: if the reported version number does not start with
            three dot-separated integers.
    """
    if not hasattr(es_client, "_eland_es_version"):
        version_number = es_client.info()["version"]["number"]
        match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_number)
        if match is None:
            # Without this guard the failure surfaces as an opaque
            # "'NoneType' object has no attribute 'groups'".
            raise ValueError(
                f"Unable to determine Elasticsearch version from "
                f"{version_number!r}"
            )
        major, minor, patch = [int(x) for x in match.groups()]
        es_client._eland_es_version = (major, minor, patch)
    return es_client._eland_es_version

View File

@ -51,7 +51,9 @@ class ModelSerializer(ABC):
json_string = json.dumps(
{"trained_model": self.to_dict()}, separators=(",", ":")
)
return base64.b64encode(gzip.compress(bytes(json_string, "utf-8")))
return base64.b64encode(gzip.compress(json_string.encode("utf-8"))).decode(
"ascii"
)
class TreeNode:

View File

@ -16,6 +16,7 @@ from typing import Union, List
import numpy as np
from eland.common import es_version
from eland.ml._model_transformers import (
SKLearnDecisionTreeTransformer,
SKLearnForestRegressorTransformer,
@ -157,15 +158,17 @@ class ImportedMLModel(MLModel):
if overwrite:
self.delete_model()
serialized_model = str(serializer.serialize_and_compress_model())[
2:-1
] # remove `b` and str quotes
self._client.ml.put_trained_model(
model_id=self._model_id,
serialized_model = serializer.serialize_and_compress_model()
body = {
"input": {"field_names": feature_names},
"compressed_definition": serialized_model,
},
"input": {"field_names": feature_names},
}
# 'inference_config' is required in 7.8+ but isn't available in <=7.7
if es_version(self._client) >= (7, 8):
body["inference_config"] = {self._model_type: {}}
self._client.ml.put_trained_model(
model_id=self._model_id, body=body,
)
def predict(self, X):

View File

@ -51,6 +51,6 @@ class MLModel:
If model doesn't exist, ignore failure.
"""
try:
self._client.ml.delete_trained_model(model_id=self._model_id)
self._client.ml.delete_trained_model(model_id=self._model_id, ignore=(404,))
except elasticsearch.NotFoundError:
pass