mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Replace MLModel(overwrite) with es_if_exists
This commit is contained in:
parent
5bf205a1e0
commit
66b24f9e8a
@ -151,7 +151,8 @@ currently using a minimum version of PyCharm 2019.2.4.
|
||||
Tools\'-\>\'Docstring format\' to `numpy`
|
||||
- Install development requirements. Open terminal in virtual
|
||||
environment and run `pip install -r requirements-dev.txt`
|
||||
- Setup Elasticsearch instance (assumes `localhost:9200`), and run
|
||||
- Setup Elasticsearch instance with docker `ELASTICSEARCH_VERSION=elasticsearch:7.x-SNAPSHOT .ci/run-elasticsearch.sh` and check `http://localhost:9200`
|
||||
- Run
|
||||
`python -m eland.tests.setup_tests` to setup test environment -*note
|
||||
this modifies Elasticsearch indices*
|
||||
- Install local `eland` module (required to execute notebook tests)
|
||||
|
@ -140,7 +140,7 @@ dtype: int64
|
||||
13057 20819.488281
|
||||
13058 18315.431274
|
||||
Length: 13059, dtype: float64
|
||||
>>> print(s.info_es())
|
||||
>>> print(s.es_info())
|
||||
index_pattern: flights
|
||||
Index:
|
||||
index_field: _id
|
||||
|
@ -22,6 +22,7 @@ import numpy as np # type: ignore
|
||||
from .ml_model import MLModel
|
||||
from .transformers import get_model_transformer
|
||||
from ..common import es_version
|
||||
import warnings
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -100,7 +101,13 @@ class ImportedMLModel(MLModel):
|
||||
classification_weights: List[str]
|
||||
Weights of the classification targets
|
||||
|
||||
overwrite: bool
|
||||
es_if_exists: {'fail', 'replace'} default 'fail'
|
||||
How to behave if model already exists
|
||||
|
||||
- fail: Raise a Value Error
|
||||
- replace: Overwrite existing model
|
||||
|
||||
overwrite: **DEPRECATED** - bool
|
||||
Delete and overwrite existing model (if exists)
|
||||
|
||||
es_compress_model_definition: bool
|
||||
@ -127,7 +134,7 @@ class ImportedMLModel(MLModel):
|
||||
>>> # Serialise the model to Elasticsearch
|
||||
>>> feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||
>>> model_id = "test_decision_tree_classifier"
|
||||
>>> es_model = ImportedMLModel('localhost', model_id, classifier, feature_names, overwrite=True)
|
||||
>>> es_model = ImportedMLModel('localhost', model_id, classifier, feature_names, es_if_exists='replace')
|
||||
|
||||
>>> # Get some test results from Elasticsearch model
|
||||
>>> es_model.predict(test_data)
|
||||
@ -155,7 +162,8 @@ class ImportedMLModel(MLModel):
|
||||
feature_names: List[str],
|
||||
classification_labels: Optional[List[str]] = None,
|
||||
classification_weights: Optional[List[float]] = None,
|
||||
overwrite: bool = False,
|
||||
es_if_exists: Optional[str] = None,
|
||||
overwrite: Optional[bool] = None,
|
||||
es_compress_model_definition: bool = True,
|
||||
):
|
||||
super().__init__(es_client, model_id)
|
||||
@ -171,7 +179,30 @@ class ImportedMLModel(MLModel):
|
||||
self._model_type = transformer.model_type
|
||||
serializer = transformer.transform()
|
||||
|
||||
if overwrite:
|
||||
# Verify if both parameters are given
|
||||
if overwrite is not None and es_if_exists is not None:
|
||||
raise ValueError(
|
||||
"Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
|
||||
)
|
||||
|
||||
if overwrite is not None:
|
||||
warnings.warn(
|
||||
"'overwrite' parameter is deprecated, use 'es_if_exists' instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
es_if_exists = "replace" if overwrite else "fail"
|
||||
elif es_if_exists is None:
|
||||
es_if_exists = "fail"
|
||||
|
||||
if es_if_exists not in ("fail", "replace"):
|
||||
raise ValueError("'es_if_exists' must be either 'fail' or 'replace'")
|
||||
elif es_if_exists == "fail":
|
||||
if self.check_existing_model():
|
||||
raise ValueError(
|
||||
f"Trained machine learning model {model_id} already exists"
|
||||
)
|
||||
elif es_if_exists == "replace":
|
||||
self.delete_model()
|
||||
|
||||
body: Dict[str, Any] = {
|
||||
@ -224,7 +255,7 @@ class ImportedMLModel(MLModel):
|
||||
>>> # Serialise the model to Elasticsearch
|
||||
>>> feature_names = ["f0", "f1", "f2", "f3", "f4", "f5"]
|
||||
>>> model_id = "test_xgb_regressor"
|
||||
>>> es_model = ImportedMLModel('localhost', model_id, regressor, feature_names, overwrite=True)
|
||||
>>> es_model = ImportedMLModel('localhost', model_id, regressor, feature_names, es_if_exists='replace')
|
||||
|
||||
>>> # Get some test results from Elasticsearch model
|
||||
>>> es_model.predict(test_data)
|
||||
|
@ -57,3 +57,15 @@ class MLModel:
|
||||
self._client.ml.delete_trained_model(model_id=self._model_id, ignore=(404,))
|
||||
except elasticsearch.NotFoundError:
|
||||
pass
|
||||
|
||||
def check_existing_model(self) -> bool:
|
||||
"""
|
||||
Check If model exists in Elastic
|
||||
"""
|
||||
try:
|
||||
self._client.ml.get_trained_models(
|
||||
model_id=self._model_id, include_model_definition=False
|
||||
)
|
||||
except elasticsearch.NotFoundError:
|
||||
return False
|
||||
return True
|
||||
|
@ -103,7 +103,7 @@ class TestImportedMLModel:
|
||||
model_id,
|
||||
classifier,
|
||||
feature_names,
|
||||
overwrite=True,
|
||||
es_if_exists="replace",
|
||||
es_compress_model_definition=True,
|
||||
)
|
||||
|
||||
@ -147,7 +147,7 @@ class TestImportedMLModel:
|
||||
model_id,
|
||||
classifier,
|
||||
feature_names,
|
||||
overwrite=True,
|
||||
es_if_exists="replace",
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
|
||||
@ -176,7 +176,7 @@ class TestImportedMLModel:
|
||||
model_id,
|
||||
regressor,
|
||||
feature_names,
|
||||
overwrite=True,
|
||||
es_if_exists="replace",
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
# Get some test results
|
||||
@ -204,7 +204,7 @@ class TestImportedMLModel:
|
||||
model_id,
|
||||
classifier,
|
||||
feature_names,
|
||||
overwrite=True,
|
||||
es_if_exists="replace",
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
# Get some test results
|
||||
@ -232,7 +232,7 @@ class TestImportedMLModel:
|
||||
model_id,
|
||||
regressor,
|
||||
feature_names,
|
||||
overwrite=True,
|
||||
es_if_exists="replace",
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
# Get some test results
|
||||
@ -270,7 +270,7 @@ class TestImportedMLModel:
|
||||
model_id,
|
||||
classifier,
|
||||
feature_names,
|
||||
overwrite=True,
|
||||
es_if_exists="replace",
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
# Get some test results
|
||||
@ -306,7 +306,7 @@ class TestImportedMLModel:
|
||||
model_id = "test_xgb_classifier"
|
||||
|
||||
es_model = ImportedMLModel(
|
||||
ES_TEST_CLIENT, model_id, classifier, feature_names, overwrite=True
|
||||
ES_TEST_CLIENT, model_id, classifier, feature_names, es_if_exists="replace"
|
||||
)
|
||||
# Get some test results
|
||||
check_prediction_equality(
|
||||
@ -342,7 +342,7 @@ class TestImportedMLModel:
|
||||
model_id,
|
||||
regressor,
|
||||
feature_names,
|
||||
overwrite=True,
|
||||
es_if_exists="replace",
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
# Get some test results
|
||||
@ -369,7 +369,7 @@ class TestImportedMLModel:
|
||||
model_id = "test_xgb_regressor"
|
||||
|
||||
es_model = ImportedMLModel(
|
||||
ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True
|
||||
ES_TEST_CLIENT, model_id, regressor, feature_names, es_if_exists="replace"
|
||||
)
|
||||
|
||||
# Single feature
|
||||
@ -410,7 +410,7 @@ class TestImportedMLModel:
|
||||
model_id,
|
||||
regressor,
|
||||
feature_names,
|
||||
overwrite=True,
|
||||
es_if_exists="replace",
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
# Get some test results
|
||||
@ -451,7 +451,7 @@ class TestImportedMLModel:
|
||||
model_id,
|
||||
classifier,
|
||||
feature_names,
|
||||
overwrite=True,
|
||||
es_if_exists="replace",
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
|
||||
@ -461,3 +461,110 @@ class TestImportedMLModel:
|
||||
|
||||
# Clean up
|
||||
es_model.delete_model()
|
||||
|
||||
# If both overwrite and es_if_exists is given.
|
||||
@requires_sklearn
|
||||
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
||||
@pytest.mark.parametrize("es_if_exists", ["fail", "replace"])
|
||||
@pytest.mark.parametrize("overwrite", [True, False])
|
||||
def test_imported_mlmodel_bothparams(
|
||||
self, compress_model_definition, es_if_exists, overwrite
|
||||
):
|
||||
# Train model
|
||||
training_data = datasets.make_regression(n_features=5)
|
||||
regressor = RandomForestRegressor()
|
||||
regressor.fit(training_data[0], training_data[1])
|
||||
|
||||
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||
model_id = "test_random_forest_regressor"
|
||||
|
||||
match = "Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
|
||||
with pytest.raises(ValueError, match=match):
|
||||
ImportedMLModel(
|
||||
ES_TEST_CLIENT,
|
||||
model_id,
|
||||
regressor,
|
||||
feature_names,
|
||||
es_if_exists=es_if_exists,
|
||||
overwrite=overwrite,
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
|
||||
# Deprecation warning for overwrite parameter
|
||||
@requires_sklearn
|
||||
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
||||
@pytest.mark.parametrize("overwrite", [True])
|
||||
def test_imported_mlmodel_overwrite_true(
|
||||
self, compress_model_definition, overwrite
|
||||
):
|
||||
# Train model
|
||||
training_data = datasets.make_regression(n_features=5)
|
||||
regressor = RandomForestRegressor()
|
||||
regressor.fit(training_data[0], training_data[1])
|
||||
|
||||
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||
model_id = "test_random_forest_regressor"
|
||||
|
||||
match = "'overwrite' parameter is deprecated, use 'es_if_exists' instead"
|
||||
with pytest.warns(DeprecationWarning, match=match):
|
||||
ImportedMLModel(
|
||||
ES_TEST_CLIENT,
|
||||
model_id,
|
||||
regressor,
|
||||
feature_names,
|
||||
overwrite=overwrite,
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
|
||||
@requires_sklearn
|
||||
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
||||
@pytest.mark.parametrize("overwrite", [False])
|
||||
def test_imported_mlmodel_overwrite_false(
|
||||
self, compress_model_definition, overwrite
|
||||
):
|
||||
# Train model
|
||||
training_data = datasets.make_regression(n_features=5)
|
||||
regressor = RandomForestRegressor()
|
||||
regressor.fit(training_data[0], training_data[1])
|
||||
|
||||
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||
model_id = "test_random_forest_regressor"
|
||||
|
||||
match_error = f"Trained machine learning model {model_id} already exists"
|
||||
match_warning = (
|
||||
"'overwrite' parameter is deprecated, use 'es_if_exists' instead"
|
||||
)
|
||||
with pytest.raises(ValueError, match=match_error):
|
||||
with pytest.warns(DeprecationWarning, match=match_warning):
|
||||
ImportedMLModel(
|
||||
ES_TEST_CLIENT,
|
||||
model_id,
|
||||
regressor,
|
||||
feature_names,
|
||||
overwrite=overwrite,
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
|
||||
# Raise ValueError if Model exists when es_if_exists = 'fail'
|
||||
@requires_sklearn
|
||||
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
||||
def test_es_if_exists_fail(self, compress_model_definition):
|
||||
# Train model
|
||||
training_data = datasets.make_regression(n_features=5)
|
||||
regressor = RandomForestRegressor()
|
||||
regressor.fit(training_data[0], training_data[1])
|
||||
|
||||
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||
model_id = "test_random_forest_regressor"
|
||||
|
||||
# If both overwrite and es_if_exists is given.
|
||||
match = f"Trained machine learning model {model_id} already exists"
|
||||
with pytest.raises(ValueError, match=match):
|
||||
ImportedMLModel(
|
||||
ES_TEST_CLIENT,
|
||||
model_id,
|
||||
regressor,
|
||||
feature_names,
|
||||
es_if_exists="fail",
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
|
@ -36,6 +36,7 @@ def deprecated_api(
|
||||
warnings.warn(
|
||||
f"{f.__name__} is deprecated, use {replace_with} instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return f(*args, **kwargs)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user