mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Add 'inference_config' on ES >=7.8
This commit is contained in:
parent
448770df78
commit
e1cacead44
@ -2,6 +2,7 @@
|
||||
|
||||
ELASTICSEARCH_VERSION:
|
||||
- 8.0.0-SNAPSHOT
|
||||
- 7.x-SNAPSHOT
|
||||
- 7.6-SNAPSHOT
|
||||
|
||||
TEST_SUITE:
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import Union, List, Tuple
|
||||
@ -266,5 +267,20 @@ def ensure_es_client(
|
||||
es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch]
|
||||
) -> Elasticsearch:
|
||||
if not isinstance(es_client, Elasticsearch):
|
||||
return Elasticsearch(es_client)
|
||||
es_client = Elasticsearch(es_client)
|
||||
return es_client
|
||||
|
||||
|
||||
def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
|
||||
"""Tags the current ES client with a cached '_eland_es_version'
|
||||
property if one doesn't exist yet for the current Elasticsearch version.
|
||||
"""
|
||||
if not hasattr(es_client, "_eland_es_version"):
|
||||
major, minor, patch = [
|
||||
int(x)
|
||||
for x in re.match(
|
||||
r"^(\d+)\.(\d+)\.(\d+)", es_client.info()["version"]["number"]
|
||||
).groups()
|
||||
]
|
||||
es_client._eland_es_version = (major, minor, patch)
|
||||
return es_client._eland_es_version
|
||||
|
@ -51,7 +51,9 @@ class ModelSerializer(ABC):
|
||||
json_string = json.dumps(
|
||||
{"trained_model": self.to_dict()}, separators=(",", ":")
|
||||
)
|
||||
return base64.b64encode(gzip.compress(bytes(json_string, "utf-8")))
|
||||
return base64.b64encode(gzip.compress(json_string.encode("utf-8"))).decode(
|
||||
"ascii"
|
||||
)
|
||||
|
||||
|
||||
class TreeNode:
|
||||
|
@ -16,6 +16,7 @@ from typing import Union, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from eland.common import es_version
|
||||
from eland.ml._model_transformers import (
|
||||
SKLearnDecisionTreeTransformer,
|
||||
SKLearnForestRegressorTransformer,
|
||||
@ -157,15 +158,17 @@ class ImportedMLModel(MLModel):
|
||||
if overwrite:
|
||||
self.delete_model()
|
||||
|
||||
serialized_model = str(serializer.serialize_and_compress_model())[
|
||||
2:-1
|
||||
] # remove `b` and str quotes
|
||||
self._client.ml.put_trained_model(
|
||||
model_id=self._model_id,
|
||||
body={
|
||||
"input": {"field_names": feature_names},
|
||||
serialized_model = serializer.serialize_and_compress_model()
|
||||
body = {
|
||||
"compressed_definition": serialized_model,
|
||||
},
|
||||
"input": {"field_names": feature_names},
|
||||
}
|
||||
# 'inference_config' is required in 7.8+ but isn't available in <=7.7
|
||||
if es_version(self._client) >= (7, 8):
|
||||
body["inference_config"] = {self._model_type: {}}
|
||||
|
||||
self._client.ml.put_trained_model(
|
||||
model_id=self._model_id, body=body,
|
||||
)
|
||||
|
||||
def predict(self, X):
|
||||
|
@ -51,6 +51,6 @@ class MLModel:
|
||||
If model doesn't exist, ignore failure.
|
||||
"""
|
||||
try:
|
||||
self._client.ml.delete_trained_model(model_id=self._model_id)
|
||||
self._client.ml.delete_trained_model(model_id=self._model_id, ignore=(404,))
|
||||
except elasticsearch.NotFoundError:
|
||||
pass
|
||||
|
Loading…
x
Reference in New Issue
Block a user