Replace MLModel(overwrite) with es_if_exists

This commit is contained in:
P. Sai Vinay 2020-08-17 22:40:27 +05:30 committed by GitHub
parent 5bf205a1e0
commit 66b24f9e8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 170 additions and 18 deletions

View File

@ -151,7 +151,8 @@ currently using a minimum version of PyCharm 2019.2.4.
Tools\'-\>\'Docstring format\' to `numpy`
- Install development requirements. Open terminal in virtual
environment and run `pip install -r requirements-dev.txt`
- Setup Elasticsearch instance (assumes `localhost:9200`), and run
- Setup Elasticsearch instance with docker `ELASTICSEARCH_VERSION=elasticsearch:7.x-SNAPSHOT .ci/run-elasticsearch.sh` and check `http://localhost:9200`
- Run
`python -m eland.tests.setup_tests` to setup test environment -*note
this modifies Elasticsearch indices*
- Install local `eland` module (required to execute notebook tests)

View File

@ -140,7 +140,7 @@ dtype: int64
13057 20819.488281
13058 18315.431274
Length: 13059, dtype: float64
>>> print(s.info_es())
>>> print(s.es_info())
index_pattern: flights
Index:
index_field: _id

View File

@ -22,6 +22,7 @@ import numpy as np # type: ignore
from .ml_model import MLModel
from .transformers import get_model_transformer
from ..common import es_version
import warnings
if TYPE_CHECKING:
@ -100,7 +101,13 @@ class ImportedMLModel(MLModel):
classification_weights: List[str]
Weights of the classification targets
overwrite: bool
es_if_exists: {'fail', 'replace'} default 'fail'
How to behave if model already exists
- fail: Raise a Value Error
- replace: Overwrite existing model
overwrite: **DEPRECATED** - bool
Delete and overwrite existing model (if exists)
es_compress_model_definition: bool
@ -127,7 +134,7 @@ class ImportedMLModel(MLModel):
>>> # Serialise the model to Elasticsearch
>>> feature_names = ["f0", "f1", "f2", "f3", "f4"]
>>> model_id = "test_decision_tree_classifier"
>>> es_model = ImportedMLModel('localhost', model_id, classifier, feature_names, overwrite=True)
>>> es_model = ImportedMLModel('localhost', model_id, classifier, feature_names, es_if_exists='replace')
>>> # Get some test results from Elasticsearch model
>>> es_model.predict(test_data)
@ -155,7 +162,8 @@ class ImportedMLModel(MLModel):
feature_names: List[str],
classification_labels: Optional[List[str]] = None,
classification_weights: Optional[List[float]] = None,
overwrite: bool = False,
es_if_exists: Optional[str] = None,
overwrite: Optional[bool] = None,
es_compress_model_definition: bool = True,
):
super().__init__(es_client, model_id)
@ -171,7 +179,30 @@ class ImportedMLModel(MLModel):
self._model_type = transformer.model_type
serializer = transformer.transform()
if overwrite:
# Verify if both parameters are given
if overwrite is not None and es_if_exists is not None:
raise ValueError(
"Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
)
if overwrite is not None:
warnings.warn(
"'overwrite' parameter is deprecated, use 'es_if_exists' instead",
DeprecationWarning,
stacklevel=2,
)
es_if_exists = "replace" if overwrite else "fail"
elif es_if_exists is None:
es_if_exists = "fail"
if es_if_exists not in ("fail", "replace"):
raise ValueError("'es_if_exists' must be either 'fail' or 'replace'")
elif es_if_exists == "fail":
if self.check_existing_model():
raise ValueError(
f"Trained machine learning model {model_id} already exists"
)
elif es_if_exists == "replace":
self.delete_model()
body: Dict[str, Any] = {
@ -224,7 +255,7 @@ class ImportedMLModel(MLModel):
>>> # Serialise the model to Elasticsearch
>>> feature_names = ["f0", "f1", "f2", "f3", "f4", "f5"]
>>> model_id = "test_xgb_regressor"
>>> es_model = ImportedMLModel('localhost', model_id, regressor, feature_names, overwrite=True)
>>> es_model = ImportedMLModel('localhost', model_id, regressor, feature_names, es_if_exists='replace')
>>> # Get some test results from Elasticsearch model
>>> es_model.predict(test_data)

View File

@ -57,3 +57,15 @@ class MLModel:
self._client.ml.delete_trained_model(model_id=self._model_id, ignore=(404,))
except elasticsearch.NotFoundError:
pass
def check_existing_model(self) -> bool:
"""
Check If model exists in Elastic
"""
try:
self._client.ml.get_trained_models(
model_id=self._model_id, include_model_definition=False
)
except elasticsearch.NotFoundError:
return False
return True

View File

@ -103,7 +103,7 @@ class TestImportedMLModel:
model_id,
classifier,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=True,
)
@ -147,7 +147,7 @@ class TestImportedMLModel:
model_id,
classifier,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
@ -176,7 +176,7 @@ class TestImportedMLModel:
model_id,
regressor,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
# Get some test results
@ -204,7 +204,7 @@ class TestImportedMLModel:
model_id,
classifier,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
# Get some test results
@ -232,7 +232,7 @@ class TestImportedMLModel:
model_id,
regressor,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
# Get some test results
@ -270,7 +270,7 @@ class TestImportedMLModel:
model_id,
classifier,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
# Get some test results
@ -306,7 +306,7 @@ class TestImportedMLModel:
model_id = "test_xgb_classifier"
es_model = ImportedMLModel(
ES_TEST_CLIENT, model_id, classifier, feature_names, overwrite=True
ES_TEST_CLIENT, model_id, classifier, feature_names, es_if_exists="replace"
)
# Get some test results
check_prediction_equality(
@ -342,7 +342,7 @@ class TestImportedMLModel:
model_id,
regressor,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
# Get some test results
@ -369,7 +369,7 @@ class TestImportedMLModel:
model_id = "test_xgb_regressor"
es_model = ImportedMLModel(
ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True
ES_TEST_CLIENT, model_id, regressor, feature_names, es_if_exists="replace"
)
# Single feature
@ -410,7 +410,7 @@ class TestImportedMLModel:
model_id,
regressor,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
# Get some test results
@ -451,7 +451,7 @@ class TestImportedMLModel:
model_id,
classifier,
feature_names,
overwrite=True,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
@ -461,3 +461,110 @@ class TestImportedMLModel:
# Clean up
es_model.delete_model()
# If both overwrite and es_if_exists is given.
@requires_sklearn
@pytest.mark.parametrize("compress_model_definition", [True, False])
@pytest.mark.parametrize("es_if_exists", ["fail", "replace"])
@pytest.mark.parametrize("overwrite", [True, False])
def test_imported_mlmodel_bothparams(
self, compress_model_definition, es_if_exists, overwrite
):
# Train model
training_data = datasets.make_regression(n_features=5)
regressor = RandomForestRegressor()
regressor.fit(training_data[0], training_data[1])
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_random_forest_regressor"
match = "Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
with pytest.raises(ValueError, match=match):
ImportedMLModel(
ES_TEST_CLIENT,
model_id,
regressor,
feature_names,
es_if_exists=es_if_exists,
overwrite=overwrite,
es_compress_model_definition=compress_model_definition,
)
# Deprecation warning for overwrite parameter
@requires_sklearn
@pytest.mark.parametrize("compress_model_definition", [True, False])
@pytest.mark.parametrize("overwrite", [True])
def test_imported_mlmodel_overwrite_true(
self, compress_model_definition, overwrite
):
# Train model
training_data = datasets.make_regression(n_features=5)
regressor = RandomForestRegressor()
regressor.fit(training_data[0], training_data[1])
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_random_forest_regressor"
match = "'overwrite' parameter is deprecated, use 'es_if_exists' instead"
with pytest.warns(DeprecationWarning, match=match):
ImportedMLModel(
ES_TEST_CLIENT,
model_id,
regressor,
feature_names,
overwrite=overwrite,
es_compress_model_definition=compress_model_definition,
)
@requires_sklearn
@pytest.mark.parametrize("compress_model_definition", [True, False])
@pytest.mark.parametrize("overwrite", [False])
def test_imported_mlmodel_overwrite_false(
self, compress_model_definition, overwrite
):
# Train model
training_data = datasets.make_regression(n_features=5)
regressor = RandomForestRegressor()
regressor.fit(training_data[0], training_data[1])
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_random_forest_regressor"
match_error = f"Trained machine learning model {model_id} already exists"
match_warning = (
"'overwrite' parameter is deprecated, use 'es_if_exists' instead"
)
with pytest.raises(ValueError, match=match_error):
with pytest.warns(DeprecationWarning, match=match_warning):
ImportedMLModel(
ES_TEST_CLIENT,
model_id,
regressor,
feature_names,
overwrite=overwrite,
es_compress_model_definition=compress_model_definition,
)
# Raise ValueError if Model exists when es_if_exists = 'fail'
@requires_sklearn
@pytest.mark.parametrize("compress_model_definition", [True, False])
def test_es_if_exists_fail(self, compress_model_definition):
# Train model
training_data = datasets.make_regression(n_features=5)
regressor = RandomForestRegressor()
regressor.fit(training_data[0], training_data[1])
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_random_forest_regressor"
# If both overwrite and es_if_exists is given.
match = f"Trained machine learning model {model_id} already exists"
with pytest.raises(ValueError, match=match):
ImportedMLModel(
ES_TEST_CLIENT,
model_id,
regressor,
feature_names,
es_if_exists="fail",
es_compress_model_definition=compress_model_definition,
)

View File

@ -36,6 +36,7 @@ def deprecated_api(
warnings.warn(
f"{f.__name__} is deprecated, use {replace_with} instead",
DeprecationWarning,
stacklevel=2,
)
return f(*args, **kwargs)