Raise an error when MLModel.predict fails, add es_compress_model_definition

Seth Michael Larson 2020-07-08 14:31:27 -05:00 committed by GitHub
parent 5d0df757cf
commit de9c836c5e
3 changed files with 122 additions and 25 deletions
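
In short: a model can now be uploaded either as gzipped JSON ('compressed_definition', the default) or as raw JSON ('definition'), and errors that Elasticsearch embeds in an otherwise-successful ingest simulate response are now raised as a RuntimeError from predict(). A minimal usage sketch, assuming a local Elasticsearch cluster and scikit-learn; the model ID and training data here are illustrative:

    from elasticsearch import Elasticsearch
    from sklearn import datasets
    from sklearn.tree import DecisionTreeClassifier
    from eland.ml import ImportedMLModel

    es_client = Elasticsearch("http://localhost:9200")

    # Train a small model locally (illustrative data).
    X, y = datasets.make_classification(n_features=5)
    classifier = DecisionTreeClassifier().fit(X, y)

    # Upload the model. Pass es_compress_model_definition=False to send
    # the raw JSON 'definition' instead of 'compressed_definition'.
    es_model = ImportedMLModel(
        es_client,
        "example_decision_tree_classifier",
        classifier,
        feature_names=["f0", "f1", "f2", "f3", "f4"],
        overwrite=True,
        es_compress_model_definition=True,
    )

    # predict() now raises RuntimeError if any simulated ingest document
    # comes back with an 'error' entry instead of a prediction.
    try:
        print(es_model.predict([[0.1, 0.2, 0.3, -0.5, 1.0]]))
    except RuntimeError as err:
        print("Prediction failed:", err)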

View File

@@ -2,7 +2,7 @@
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
# See the LICENSE file in the project root for more information
from typing import Union, List, Optional, Tuple, TYPE_CHECKING, cast
from typing import Union, List, Optional, Tuple, TYPE_CHECKING, cast, Dict, Any
import numpy as np # type: ignore
@@ -61,6 +61,11 @@ class ImportedMLModel(MLModel):
overwrite: bool
Delete and overwrite existing model (if exists)
es_compress_model_definition: bool
If True, will use 'compressed_definition' which uses gzipped
JSON instead of raw JSON to reduce the amount of data sent
over the wire in HTTP requests. Defaults to True.
Examples
--------
>>> from sklearn import datasets
@@ -107,6 +112,7 @@ class ImportedMLModel(MLModel):
classification_labels: Optional[List[str]] = None,
classification_weights: Optional[List[float]] = None,
overwrite: bool = False,
es_compress_model_definition: bool = True,
):
super().__init__(es_client, model_id)
@@ -124,15 +130,18 @@ class ImportedMLModel(MLModel):
if overwrite:
self.delete_model()
serialized_model = serializer.serialize_and_compress_model()
body = {
"compressed_definition": serialized_model,
body: Dict[str, Any] = {
"input": {"field_names": feature_names},
}
# 'inference_config' is required in 7.8+ but isn't available in <=7.7
if es_version(self._client) >= (7, 8):
body["inference_config"] = {self._model_type: {}}
if es_compress_model_definition:
body["compressed_definition"] = serializer.serialize_and_compress_model()
else:
body["definition"] = serializer.serialize_model()
self._client.ml.put_trained_model(
model_id=self._model_id, body=body,
)
@@ -216,10 +225,17 @@
}
)
y = [
doc["doc"]["_source"]["ml"]["inference"]["predicted_value"]
for doc in results["docs"]
]
# Unpack results into an array. Errors can be present
# within the response even though the HTTP status code is 2XX.
y = []
for doc in results["docs"]:
if "error" in doc:
raise RuntimeError(
f"Failed to run prediction for model ID {self._model_id!r}",
doc["error"],
)
y.append(doc["doc"]["_source"]["ml"]["inference"]["predicted_value"])
# Return results as np.ndarray of float32 or int (consistent with sklearn/xgboost)
if self._model_type == MLModel.TYPE_CLASSIFICATION:
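
For reference, the put trained model API accepts the model body either as plain JSON under 'definition' or as a base64-encoded gzip blob under 'compressed_definition'; the latter is what keeps the upload small on the wire. A rough sketch of that encoding, shown only to illustrate the payload format rather than eland's actual serializer:

    import base64
    import gzip
    import json

    def compress_model_definition(definition: dict) -> str:
        # JSON-encode the definition, gzip it, then base64-encode the bytes
        # so the result can travel as a plain string in the request body.
        raw = json.dumps(definition).encode("utf-8")
        return base64.b64encode(gzip.compress(raw)).decode("ascii")

Elasticsearch decodes and decompresses the blob back into the same JSON that the uncompressed 'definition' path would have sent.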

View File

@@ -4,6 +4,7 @@
import pytest
import numpy as np
from elasticsearch import ElasticsearchException
from eland.ml import ImportedMLModel
from eland.tests import ES_TEST_CLIENT
@@ -38,14 +39,58 @@ requires_no_ml_extras = pytest.mark.skipif(
)
@requires_no_ml_extras
def test_import_ml_model_when_dependencies_are_not_available():
from eland.ml import MLModel, ImportedMLModel # noqa: F401
class TestImportedMLModel:
@requires_no_ml_extras
def test_import_ml_model_when_dependencies_are_not_available(self):
from eland.ml import MLModel, ImportedMLModel # noqa: F401
@requires_sklearn
def test_decision_tree_classifier(self):
def test_unpack_and_raise_errors_in_ingest_simulate(self, mocker):
# Train model
training_data = datasets.make_classification(n_features=5)
classifier = DecisionTreeClassifier()
classifier.fit(training_data[0], training_data[1])
# Serialise the model to Elasticsearch
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_decision_tree_classifier"
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
es_model = ImportedMLModel(
ES_TEST_CLIENT,
model_id,
classifier,
feature_names,
overwrite=True,
es_compress_model_definition=True,
)
# Mock the ingest.simulate API to return an error within {'docs': [...]}
mock = mocker.patch.object(ES_TEST_CLIENT.ingest, "simulate")
mock.return_value = {
"docs": [
{
"error": {
"type": "x_content_parse_exception",
"reason": "[1:1052] [inference_model_definition] failed to parse field [trained_model]",
}
}
]
}
with pytest.raises(RuntimeError) as err:
es_model.predict(test_data)
assert repr(err.value) == (
'RuntimeError("Failed to run prediction for model ID '
"'test_decision_tree_classifier'\", {'type': 'x_content_parse_exception', "
"'reason': '[1:1052] [inference_model_definition] failed to parse "
"field [trained_model]'})"
)
@requires_sklearn
@pytest.mark.parametrize("compress_model_definition", [True, False])
def test_decision_tree_classifier(self, compress_model_definition):
# Train model
training_data = datasets.make_classification(n_features=5)
classifier = DecisionTreeClassifier()
@@ -60,7 +105,12 @@ class TestImportedMLModel:
model_id = "test_decision_tree_classifier"
es_model = ImportedMLModel(
ES_TEST_CLIENT, model_id, classifier, feature_names, overwrite=True
ES_TEST_CLIENT,
model_id,
classifier,
feature_names,
overwrite=True,
es_compress_model_definition=compress_model_definition,
)
es_results = es_model.predict(test_data)
@@ -70,7 +120,8 @@ class TestImportedMLModel:
es_model.delete_model()
@requires_sklearn
def test_decision_tree_regressor(self):
@pytest.mark.parametrize("compress_model_definition", [True, False])
def test_decision_tree_regressor(self, compress_model_definition):
# Train model
training_data = datasets.make_regression(n_features=5)
regressor = DecisionTreeRegressor()
@@ -85,7 +136,12 @@ class TestImportedMLModel:
model_id = "test_decision_tree_regressor"
es_model = ImportedMLModel(
ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True
ES_TEST_CLIENT,
model_id,
regressor,
feature_names,
overwrite=True,
es_compress_model_definition=compress_model_definition,
)
es_results = es_model.predict(test_data)
@@ -95,7 +151,8 @@ class TestImportedMLModel:
es_model.delete_model()
@requires_sklearn
def test_random_forest_classifier(self):
@pytest.mark.parametrize("compress_model_definition", [True, False])
def test_random_forest_classifier(self, compress_model_definition):
# Train model
training_data = datasets.make_classification(n_features=5)
classifier = RandomForestClassifier()
@@ -110,7 +167,12 @@ class TestImportedMLModel:
model_id = "test_random_forest_classifier"
es_model = ImportedMLModel(
ES_TEST_CLIENT, model_id, classifier, feature_names, overwrite=True
ES_TEST_CLIENT,
model_id,
classifier,
feature_names,
overwrite=True,
es_compress_model_definition=compress_model_definition,
)
es_results = es_model.predict(test_data)
@@ -120,7 +182,8 @@ class TestImportedMLModel:
es_model.delete_model()
@requires_sklearn
def test_random_forest_regressor(self):
@pytest.mark.parametrize("compress_model_definition", [True, False])
def test_random_forest_regressor(self, compress_model_definition):
# Train model
training_data = datasets.make_regression(n_features=5)
regressor = RandomForestRegressor()
@@ -135,7 +198,12 @@ class TestImportedMLModel:
model_id = "test_random_forest_regressor"
es_model = ImportedMLModel(
ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True
ES_TEST_CLIENT,
model_id,
regressor,
feature_names,
overwrite=True,
es_compress_model_definition=compress_model_definition,
)
es_results = es_model.predict(test_data)
@@ -145,7 +213,8 @@ class TestImportedMLModel:
es_model.delete_model()
@requires_xgboost
def test_xgb_classifier(self):
@pytest.mark.parametrize("compress_model_definition", [True, False])
def test_xgb_classifier(self, compress_model_definition):
# Train model
training_data = datasets.make_classification(n_features=5)
classifier = XGBClassifier(booster="gbtree")
@@ -160,7 +229,12 @@ class TestImportedMLModel:
model_id = "test_xgb_classifier"
es_model = ImportedMLModel(
ES_TEST_CLIENT, model_id, classifier, feature_names, overwrite=True
ES_TEST_CLIENT,
model_id,
classifier,
feature_names,
overwrite=True,
es_compress_model_definition=compress_model_definition,
)
es_results = es_model.predict(test_data)
@@ -170,7 +244,8 @@ class TestImportedMLModel:
es_model.delete_model()
@requires_xgboost
def test_xgb_regressor(self):
@pytest.mark.parametrize("compress_model_definition", [True, False])
def test_xgb_regressor(self, compress_model_definition):
# Train model
training_data = datasets.make_regression(n_features=5)
regressor = XGBRegressor()
@@ -185,7 +260,12 @@ class TestImportedMLModel:
model_id = "test_xgb_regressor"
es_model = ImportedMLModel(
ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True
ES_TEST_CLIENT,
model_id,
regressor,
feature_names,
overwrite=True,
es_compress_model_definition=compress_model_definition,
)
es_results = es_model.predict(test_data)
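
The new error-handling test patches ES_TEST_CLIENT.ingest.simulate via the mocker fixture from pytest-mock (added to the dev requirements below); the fixture installs a MagicMock for the duration of the test and restores the real method afterwards. A self-contained sketch of the same pattern against a stand-in object rather than a real client:

    from types import SimpleNamespace

    def test_simulate_is_mocked(mocker):
        # Stand-in for an Elasticsearch client; only the attribute shape matters.
        client = SimpleNamespace(
            ingest=SimpleNamespace(simulate=lambda **kwargs: {"docs": ["real"]})
        )
        mock = mocker.patch.object(client.ingest, "simulate")
        mock.return_value = {"docs": [{"error": {"type": "x_content_parse_exception"}}]}
        # Callers now receive the canned error response...
        assert "error" in client.ingest.simulate(body={})["docs"][0]
        # ...and pytest-mock undoes the patch automatically when the test ends.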

View File

@@ -2,6 +2,7 @@ elasticsearch>=7.7
pandas>=1
matplotlib
pytest>=5.2.1
pytest-mock
nbval
numpydoc>=0.9.0
scikit-learn>=0.22.1