mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Replace MLModel(overwrite) with es_if_exists
This commit is contained in:
parent
5bf205a1e0
commit
66b24f9e8a
@ -151,7 +151,8 @@ currently using a minimum version of PyCharm 2019.2.4.
|
|||||||
Tools\'-\>\'Docstring format\' to `numpy`
|
Tools\'-\>\'Docstring format\' to `numpy`
|
||||||
- Install development requirements. Open terminal in virtual
|
- Install development requirements. Open terminal in virtual
|
||||||
environment and run `pip install -r requirements-dev.txt`
|
environment and run `pip install -r requirements-dev.txt`
|
||||||
- Setup Elasticsearch instance (assumes `localhost:9200`), and run
|
- Set up an Elasticsearch instance with Docker by running `ELASTICSEARCH_VERSION=elasticsearch:7.x-SNAPSHOT .ci/run-elasticsearch.sh`, then check it at `http://localhost:9200`
|
||||||
|
- Run
|
||||||
`python -m eland.tests.setup_tests` to setup test environment -*note
|
`python -m eland.tests.setup_tests` to setup test environment -*note
|
||||||
this modifies Elasticsearch indices*
|
this modifies Elasticsearch indices*
|
||||||
- Install local `eland` module (required to execute notebook tests)
|
- Install local `eland` module (required to execute notebook tests)
|
||||||
|
@ -140,7 +140,7 @@ dtype: int64
|
|||||||
13057 20819.488281
|
13057 20819.488281
|
||||||
13058 18315.431274
|
13058 18315.431274
|
||||||
Length: 13059, dtype: float64
|
Length: 13059, dtype: float64
|
||||||
>>> print(s.info_es())
|
>>> print(s.es_info())
|
||||||
index_pattern: flights
|
index_pattern: flights
|
||||||
Index:
|
Index:
|
||||||
index_field: _id
|
index_field: _id
|
||||||
|
@ -22,6 +22,7 @@ import numpy as np # type: ignore
|
|||||||
from .ml_model import MLModel
|
from .ml_model import MLModel
|
||||||
from .transformers import get_model_transformer
|
from .transformers import get_model_transformer
|
||||||
from ..common import es_version
|
from ..common import es_version
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -100,7 +101,13 @@ class ImportedMLModel(MLModel):
|
|||||||
classification_weights: List[str]
|
classification_weights: List[str]
|
||||||
Weights of the classification targets
|
Weights of the classification targets
|
||||||
|
|
||||||
overwrite: bool
|
es_if_exists: {'fail', 'replace'} default 'fail'
|
||||||
|
How to behave if model already exists
|
||||||
|
|
||||||
|
- fail: Raise a ValueError
|
||||||
|
- replace: Overwrite existing model
|
||||||
|
|
||||||
|
overwrite: **DEPRECATED** - bool
|
||||||
Delete and overwrite existing model (if exists)
|
Delete and overwrite existing model (if exists)
|
||||||
|
|
||||||
es_compress_model_definition: bool
|
es_compress_model_definition: bool
|
||||||
@ -127,7 +134,7 @@ class ImportedMLModel(MLModel):
|
|||||||
>>> # Serialise the model to Elasticsearch
|
>>> # Serialise the model to Elasticsearch
|
||||||
>>> feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
>>> feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||||
>>> model_id = "test_decision_tree_classifier"
|
>>> model_id = "test_decision_tree_classifier"
|
||||||
>>> es_model = ImportedMLModel('localhost', model_id, classifier, feature_names, overwrite=True)
|
>>> es_model = ImportedMLModel('localhost', model_id, classifier, feature_names, es_if_exists='replace')
|
||||||
|
|
||||||
>>> # Get some test results from Elasticsearch model
|
>>> # Get some test results from Elasticsearch model
|
||||||
>>> es_model.predict(test_data)
|
>>> es_model.predict(test_data)
|
||||||
@ -155,7 +162,8 @@ class ImportedMLModel(MLModel):
|
|||||||
feature_names: List[str],
|
feature_names: List[str],
|
||||||
classification_labels: Optional[List[str]] = None,
|
classification_labels: Optional[List[str]] = None,
|
||||||
classification_weights: Optional[List[float]] = None,
|
classification_weights: Optional[List[float]] = None,
|
||||||
overwrite: bool = False,
|
es_if_exists: Optional[str] = None,
|
||||||
|
overwrite: Optional[bool] = None,
|
||||||
es_compress_model_definition: bool = True,
|
es_compress_model_definition: bool = True,
|
||||||
):
|
):
|
||||||
super().__init__(es_client, model_id)
|
super().__init__(es_client, model_id)
|
||||||
@ -171,7 +179,30 @@ class ImportedMLModel(MLModel):
|
|||||||
self._model_type = transformer.model_type
|
self._model_type = transformer.model_type
|
||||||
serializer = transformer.transform()
|
serializer = transformer.transform()
|
||||||
|
|
||||||
if overwrite:
|
# Verify if both parameters are given
|
||||||
|
if overwrite is not None and es_if_exists is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
|
||||||
|
)
|
||||||
|
|
||||||
|
if overwrite is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"'overwrite' parameter is deprecated, use 'es_if_exists' instead",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
es_if_exists = "replace" if overwrite else "fail"
|
||||||
|
elif es_if_exists is None:
|
||||||
|
es_if_exists = "fail"
|
||||||
|
|
||||||
|
if es_if_exists not in ("fail", "replace"):
|
||||||
|
raise ValueError("'es_if_exists' must be either 'fail' or 'replace'")
|
||||||
|
elif es_if_exists == "fail":
|
||||||
|
if self.check_existing_model():
|
||||||
|
raise ValueError(
|
||||||
|
f"Trained machine learning model {model_id} already exists"
|
||||||
|
)
|
||||||
|
elif es_if_exists == "replace":
|
||||||
self.delete_model()
|
self.delete_model()
|
||||||
|
|
||||||
body: Dict[str, Any] = {
|
body: Dict[str, Any] = {
|
||||||
@ -224,7 +255,7 @@ class ImportedMLModel(MLModel):
|
|||||||
>>> # Serialise the model to Elasticsearch
|
>>> # Serialise the model to Elasticsearch
|
||||||
>>> feature_names = ["f0", "f1", "f2", "f3", "f4", "f5"]
|
>>> feature_names = ["f0", "f1", "f2", "f3", "f4", "f5"]
|
||||||
>>> model_id = "test_xgb_regressor"
|
>>> model_id = "test_xgb_regressor"
|
||||||
>>> es_model = ImportedMLModel('localhost', model_id, regressor, feature_names, overwrite=True)
|
>>> es_model = ImportedMLModel('localhost', model_id, regressor, feature_names, es_if_exists='replace')
|
||||||
|
|
||||||
>>> # Get some test results from Elasticsearch model
|
>>> # Get some test results from Elasticsearch model
|
||||||
>>> es_model.predict(test_data)
|
>>> es_model.predict(test_data)
|
||||||
|
@ -57,3 +57,15 @@ class MLModel:
|
|||||||
self._client.ml.delete_trained_model(model_id=self._model_id, ignore=(404,))
|
self._client.ml.delete_trained_model(model_id=self._model_id, ignore=(404,))
|
||||||
except elasticsearch.NotFoundError:
|
except elasticsearch.NotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def check_existing_model(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the model exists in Elasticsearch
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._client.ml.get_trained_models(
|
||||||
|
model_id=self._model_id, include_model_definition=False
|
||||||
|
)
|
||||||
|
except elasticsearch.NotFoundError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
@ -103,7 +103,7 @@ class TestImportedMLModel:
|
|||||||
model_id,
|
model_id,
|
||||||
classifier,
|
classifier,
|
||||||
feature_names,
|
feature_names,
|
||||||
overwrite=True,
|
es_if_exists="replace",
|
||||||
es_compress_model_definition=True,
|
es_compress_model_definition=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -147,7 +147,7 @@ class TestImportedMLModel:
|
|||||||
model_id,
|
model_id,
|
||||||
classifier,
|
classifier,
|
||||||
feature_names,
|
feature_names,
|
||||||
overwrite=True,
|
es_if_exists="replace",
|
||||||
es_compress_model_definition=compress_model_definition,
|
es_compress_model_definition=compress_model_definition,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -176,7 +176,7 @@ class TestImportedMLModel:
|
|||||||
model_id,
|
model_id,
|
||||||
regressor,
|
regressor,
|
||||||
feature_names,
|
feature_names,
|
||||||
overwrite=True,
|
es_if_exists="replace",
|
||||||
es_compress_model_definition=compress_model_definition,
|
es_compress_model_definition=compress_model_definition,
|
||||||
)
|
)
|
||||||
# Get some test results
|
# Get some test results
|
||||||
@ -204,7 +204,7 @@ class TestImportedMLModel:
|
|||||||
model_id,
|
model_id,
|
||||||
classifier,
|
classifier,
|
||||||
feature_names,
|
feature_names,
|
||||||
overwrite=True,
|
es_if_exists="replace",
|
||||||
es_compress_model_definition=compress_model_definition,
|
es_compress_model_definition=compress_model_definition,
|
||||||
)
|
)
|
||||||
# Get some test results
|
# Get some test results
|
||||||
@ -232,7 +232,7 @@ class TestImportedMLModel:
|
|||||||
model_id,
|
model_id,
|
||||||
regressor,
|
regressor,
|
||||||
feature_names,
|
feature_names,
|
||||||
overwrite=True,
|
es_if_exists="replace",
|
||||||
es_compress_model_definition=compress_model_definition,
|
es_compress_model_definition=compress_model_definition,
|
||||||
)
|
)
|
||||||
# Get some test results
|
# Get some test results
|
||||||
@ -270,7 +270,7 @@ class TestImportedMLModel:
|
|||||||
model_id,
|
model_id,
|
||||||
classifier,
|
classifier,
|
||||||
feature_names,
|
feature_names,
|
||||||
overwrite=True,
|
es_if_exists="replace",
|
||||||
es_compress_model_definition=compress_model_definition,
|
es_compress_model_definition=compress_model_definition,
|
||||||
)
|
)
|
||||||
# Get some test results
|
# Get some test results
|
||||||
@ -306,7 +306,7 @@ class TestImportedMLModel:
|
|||||||
model_id = "test_xgb_classifier"
|
model_id = "test_xgb_classifier"
|
||||||
|
|
||||||
es_model = ImportedMLModel(
|
es_model = ImportedMLModel(
|
||||||
ES_TEST_CLIENT, model_id, classifier, feature_names, overwrite=True
|
ES_TEST_CLIENT, model_id, classifier, feature_names, es_if_exists="replace"
|
||||||
)
|
)
|
||||||
# Get some test results
|
# Get some test results
|
||||||
check_prediction_equality(
|
check_prediction_equality(
|
||||||
@ -342,7 +342,7 @@ class TestImportedMLModel:
|
|||||||
model_id,
|
model_id,
|
||||||
regressor,
|
regressor,
|
||||||
feature_names,
|
feature_names,
|
||||||
overwrite=True,
|
es_if_exists="replace",
|
||||||
es_compress_model_definition=compress_model_definition,
|
es_compress_model_definition=compress_model_definition,
|
||||||
)
|
)
|
||||||
# Get some test results
|
# Get some test results
|
||||||
@ -369,7 +369,7 @@ class TestImportedMLModel:
|
|||||||
model_id = "test_xgb_regressor"
|
model_id = "test_xgb_regressor"
|
||||||
|
|
||||||
es_model = ImportedMLModel(
|
es_model = ImportedMLModel(
|
||||||
ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True
|
ES_TEST_CLIENT, model_id, regressor, feature_names, es_if_exists="replace"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Single feature
|
# Single feature
|
||||||
@ -410,7 +410,7 @@ class TestImportedMLModel:
|
|||||||
model_id,
|
model_id,
|
||||||
regressor,
|
regressor,
|
||||||
feature_names,
|
feature_names,
|
||||||
overwrite=True,
|
es_if_exists="replace",
|
||||||
es_compress_model_definition=compress_model_definition,
|
es_compress_model_definition=compress_model_definition,
|
||||||
)
|
)
|
||||||
# Get some test results
|
# Get some test results
|
||||||
@ -451,7 +451,7 @@ class TestImportedMLModel:
|
|||||||
model_id,
|
model_id,
|
||||||
classifier,
|
classifier,
|
||||||
feature_names,
|
feature_names,
|
||||||
overwrite=True,
|
es_if_exists="replace",
|
||||||
es_compress_model_definition=compress_model_definition,
|
es_compress_model_definition=compress_model_definition,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -461,3 +461,110 @@ class TestImportedMLModel:
|
|||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
es_model.delete_model()
|
es_model.delete_model()
|
||||||
|
|
||||||
|
# Passing both 'overwrite' and 'es_if_exists' must raise a ValueError.
|
||||||
|
@requires_sklearn
|
||||||
|
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
||||||
|
@pytest.mark.parametrize("es_if_exists", ["fail", "replace"])
|
||||||
|
@pytest.mark.parametrize("overwrite", [True, False])
|
||||||
|
def test_imported_mlmodel_bothparams(
|
||||||
|
self, compress_model_definition, es_if_exists, overwrite
|
||||||
|
):
|
||||||
|
# Train model
|
||||||
|
training_data = datasets.make_regression(n_features=5)
|
||||||
|
regressor = RandomForestRegressor()
|
||||||
|
regressor.fit(training_data[0], training_data[1])
|
||||||
|
|
||||||
|
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||||
|
model_id = "test_random_forest_regressor"
|
||||||
|
|
||||||
|
match = "Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
|
||||||
|
with pytest.raises(ValueError, match=match):
|
||||||
|
ImportedMLModel(
|
||||||
|
ES_TEST_CLIENT,
|
||||||
|
model_id,
|
||||||
|
regressor,
|
||||||
|
feature_names,
|
||||||
|
es_if_exists=es_if_exists,
|
||||||
|
overwrite=overwrite,
|
||||||
|
es_compress_model_definition=compress_model_definition,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Deprecation warning for overwrite parameter
|
||||||
|
@requires_sklearn
|
||||||
|
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
||||||
|
@pytest.mark.parametrize("overwrite", [True])
|
||||||
|
def test_imported_mlmodel_overwrite_true(
|
||||||
|
self, compress_model_definition, overwrite
|
||||||
|
):
|
||||||
|
# Train model
|
||||||
|
training_data = datasets.make_regression(n_features=5)
|
||||||
|
regressor = RandomForestRegressor()
|
||||||
|
regressor.fit(training_data[0], training_data[1])
|
||||||
|
|
||||||
|
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||||
|
model_id = "test_random_forest_regressor"
|
||||||
|
|
||||||
|
match = "'overwrite' parameter is deprecated, use 'es_if_exists' instead"
|
||||||
|
with pytest.warns(DeprecationWarning, match=match):
|
||||||
|
ImportedMLModel(
|
||||||
|
ES_TEST_CLIENT,
|
||||||
|
model_id,
|
||||||
|
regressor,
|
||||||
|
feature_names,
|
||||||
|
overwrite=overwrite,
|
||||||
|
es_compress_model_definition=compress_model_definition,
|
||||||
|
)
|
||||||
|
|
||||||
|
@requires_sklearn
|
||||||
|
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
||||||
|
@pytest.mark.parametrize("overwrite", [False])
|
||||||
|
def test_imported_mlmodel_overwrite_false(
|
||||||
|
self, compress_model_definition, overwrite
|
||||||
|
):
|
||||||
|
# Train model
|
||||||
|
training_data = datasets.make_regression(n_features=5)
|
||||||
|
regressor = RandomForestRegressor()
|
||||||
|
regressor.fit(training_data[0], training_data[1])
|
||||||
|
|
||||||
|
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||||
|
model_id = "test_random_forest_regressor"
|
||||||
|
|
||||||
|
match_error = f"Trained machine learning model {model_id} already exists"
|
||||||
|
match_warning = (
|
||||||
|
"'overwrite' parameter is deprecated, use 'es_if_exists' instead"
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match=match_error):
|
||||||
|
with pytest.warns(DeprecationWarning, match=match_warning):
|
||||||
|
ImportedMLModel(
|
||||||
|
ES_TEST_CLIENT,
|
||||||
|
model_id,
|
||||||
|
regressor,
|
||||||
|
feature_names,
|
||||||
|
overwrite=overwrite,
|
||||||
|
es_compress_model_definition=compress_model_definition,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Raise ValueError if Model exists when es_if_exists = 'fail'
|
||||||
|
@requires_sklearn
|
||||||
|
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
||||||
|
def test_es_if_exists_fail(self, compress_model_definition):
|
||||||
|
# Train model
|
||||||
|
training_data = datasets.make_regression(n_features=5)
|
||||||
|
regressor = RandomForestRegressor()
|
||||||
|
regressor.fit(training_data[0], training_data[1])
|
||||||
|
|
||||||
|
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||||
|
model_id = "test_random_forest_regressor"
|
||||||
|
|
||||||
|
# Expect a ValueError: the model is assumed to already exist and es_if_exists='fail'.
|
||||||
|
match = f"Trained machine learning model {model_id} already exists"
|
||||||
|
with pytest.raises(ValueError, match=match):
|
||||||
|
ImportedMLModel(
|
||||||
|
ES_TEST_CLIENT,
|
||||||
|
model_id,
|
||||||
|
regressor,
|
||||||
|
feature_names,
|
||||||
|
es_if_exists="fail",
|
||||||
|
es_compress_model_definition=compress_model_definition,
|
||||||
|
)
|
||||||
|
@ -36,6 +36,7 @@ def deprecated_api(
|
|||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"{f.__name__} is deprecated, use {replace_with} instead",
|
f"{f.__name__} is deprecated, use {replace_with} instead",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user