mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Add 'inference_config' on ES >=7.8
This commit is contained in:
parent
448770df78
commit
e1cacead44
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
ELASTICSEARCH_VERSION:
|
ELASTICSEARCH_VERSION:
|
||||||
- 8.0.0-SNAPSHOT
|
- 8.0.0-SNAPSHOT
|
||||||
|
- 7.x-SNAPSHOT
|
||||||
- 7.6-SNAPSHOT
|
- 7.6-SNAPSHOT
|
||||||
|
|
||||||
TEST_SUITE:
|
TEST_SUITE:
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Union, List, Tuple
|
from typing import Union, List, Tuple
|
||||||
@ -266,5 +267,20 @@ def ensure_es_client(
|
|||||||
es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch]
|
es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch]
|
||||||
) -> Elasticsearch:
|
) -> Elasticsearch:
|
||||||
if not isinstance(es_client, Elasticsearch):
|
if not isinstance(es_client, Elasticsearch):
|
||||||
return Elasticsearch(es_client)
|
es_client = Elasticsearch(es_client)
|
||||||
return es_client
|
return es_client
|
||||||
|
|
||||||
|
|
||||||
|
def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
|
||||||
|
"""Tags the current ES client with a cached '_eland_es_version'
|
||||||
|
property if one doesn't exist yet for the current Elasticsearch version.
|
||||||
|
"""
|
||||||
|
if not hasattr(es_client, "_eland_es_version"):
|
||||||
|
major, minor, patch = [
|
||||||
|
int(x)
|
||||||
|
for x in re.match(
|
||||||
|
r"^(\d+)\.(\d+)\.(\d+)", es_client.info()["version"]["number"]
|
||||||
|
).groups()
|
||||||
|
]
|
||||||
|
es_client._eland_es_version = (major, minor, patch)
|
||||||
|
return es_client._eland_es_version
|
||||||
|
@ -51,7 +51,9 @@ class ModelSerializer(ABC):
|
|||||||
json_string = json.dumps(
|
json_string = json.dumps(
|
||||||
{"trained_model": self.to_dict()}, separators=(",", ":")
|
{"trained_model": self.to_dict()}, separators=(",", ":")
|
||||||
)
|
)
|
||||||
return base64.b64encode(gzip.compress(bytes(json_string, "utf-8")))
|
return base64.b64encode(gzip.compress(json_string.encode("utf-8"))).decode(
|
||||||
|
"ascii"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TreeNode:
|
class TreeNode:
|
||||||
|
@ -16,6 +16,7 @@ from typing import Union, List
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from eland.common import es_version
|
||||||
from eland.ml._model_transformers import (
|
from eland.ml._model_transformers import (
|
||||||
SKLearnDecisionTreeTransformer,
|
SKLearnDecisionTreeTransformer,
|
||||||
SKLearnForestRegressorTransformer,
|
SKLearnForestRegressorTransformer,
|
||||||
@ -157,15 +158,17 @@ class ImportedMLModel(MLModel):
|
|||||||
if overwrite:
|
if overwrite:
|
||||||
self.delete_model()
|
self.delete_model()
|
||||||
|
|
||||||
serialized_model = str(serializer.serialize_and_compress_model())[
|
serialized_model = serializer.serialize_and_compress_model()
|
||||||
2:-1
|
body = {
|
||||||
] # remove `b` and str quotes
|
"compressed_definition": serialized_model,
|
||||||
|
"input": {"field_names": feature_names},
|
||||||
|
}
|
||||||
|
# 'inference_config' is required in 7.8+ but isn't available in <=7.7
|
||||||
|
if es_version(self._client) >= (7, 8):
|
||||||
|
body["inference_config"] = {self._model_type: {}}
|
||||||
|
|
||||||
self._client.ml.put_trained_model(
|
self._client.ml.put_trained_model(
|
||||||
model_id=self._model_id,
|
model_id=self._model_id, body=body,
|
||||||
body={
|
|
||||||
"input": {"field_names": feature_names},
|
|
||||||
"compressed_definition": serialized_model,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def predict(self, X):
|
def predict(self, X):
|
||||||
|
@ -51,6 +51,6 @@ class MLModel:
|
|||||||
If model doesn't exist, ignore failure.
|
If model doesn't exist, ignore failure.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self._client.ml.delete_trained_model(model_id=self._model_id)
|
self._client.ml.delete_trained_model(model_id=self._model_id, ignore=(404,))
|
||||||
except elasticsearch.NotFoundError:
|
except elasticsearch.NotFoundError:
|
||||||
pass
|
pass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user