mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Deprecate ImportedMLModel in favor of MLModel.import_model()
This commit is contained in:
parent
1a8a301cd6
commit
c86371733d
@ -2495,10 +2495,10 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from eland.ml import ImportedMLModel\n",
|
"from eland.ml import MLModel\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Serialize the scikit-learn model into Elasticsearch\n",
|
"# Serialize the scikit-learn model into Elasticsearch\n",
|
||||||
"ed_classifier = ImportedMLModel(\n",
|
"ed_classifier = MLModel.import_model(\n",
|
||||||
" es_client=es,\n",
|
" es_client=es,\n",
|
||||||
" model_id=\"wine-classifier\",\n",
|
" model_id=\"wine-classifier\",\n",
|
||||||
" model=sk_classifier,\n",
|
" model=sk_classifier,\n",
|
||||||
|
@ -1,6 +0,0 @@
|
|||||||
eland.ml.ImportedMLModel.predict
|
|
||||||
================================
|
|
||||||
|
|
||||||
.. currentmodule:: eland.ml
|
|
||||||
|
|
||||||
.. automethod:: ImportedMLModel.predict
|
|
@ -1,6 +0,0 @@
|
|||||||
eland.ml.ImportedMLModel
|
|
||||||
========================
|
|
||||||
|
|
||||||
.. currentmodule:: eland.ml
|
|
||||||
|
|
||||||
.. autoclass:: ImportedMLModel
|
|
@ -0,0 +1,6 @@
|
|||||||
|
eland.ml.MLModel.import_model
|
||||||
|
=============================
|
||||||
|
|
||||||
|
.. currentmodule:: eland.ml
|
||||||
|
|
||||||
|
.. automethod:: MLModel.import_model
|
6
docs/source/reference/api/eland.ml.MLModel.predict.rst
Normal file
6
docs/source/reference/api/eland.ml.MLModel.predict.rst
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
eland.ml.MLModel.predict
|
||||||
|
========================
|
||||||
|
|
||||||
|
.. currentmodule:: eland.ml
|
||||||
|
|
||||||
|
.. automethod:: MLModel.predict
|
@ -27,19 +27,20 @@ Constructor
|
|||||||
.. autosummary::
|
.. autosummary::
|
||||||
:toctree: api/
|
:toctree: api/
|
||||||
|
|
||||||
ImportedMLModel
|
MLModel
|
||||||
|
|
||||||
Predictions
|
Predictions
|
||||||
^^^^^^^^^^^
|
^^^^^^^^^^^
|
||||||
.. autosummary::
|
.. autosummary::
|
||||||
:toctree: api/
|
:toctree: api/
|
||||||
|
|
||||||
ImportedMLModel.predict
|
MLModel.predict
|
||||||
|
|
||||||
Manage Models
|
Manage Models
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
.. autosummary::
|
.. autosummary::
|
||||||
:toctree: api/
|
:toctree: api/
|
||||||
|
|
||||||
|
MLModel.import_model
|
||||||
MLModel.exists_model
|
MLModel.exists_model
|
||||||
MLModel.delete_model
|
MLModel.delete_model
|
||||||
|
@ -15,8 +15,7 @@
|
|||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
from eland.ml.ml_model import MLModel
|
from eland.ml.ml_model import MLModel, ImportedMLModel
|
||||||
from eland.ml.imported_ml_model import ImportedMLModel
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MLModel",
|
"MLModel",
|
||||||
|
19
eland/ml/common.py
Normal file
19
eland/ml/common.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# Licensed to Elasticsearch B.V. under one or more contributor
|
||||||
|
# license agreements. See the NOTICE file distributed with
|
||||||
|
# this work for additional information regarding copyright
|
||||||
|
# ownership. Elasticsearch B.V. licenses this file to you under
|
||||||
|
# the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
# not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
TYPE_CLASSIFICATION = "classification"
|
||||||
|
TYPE_REGRESSION = "regression"
|
@ -1,321 +0,0 @@
|
|||||||
# Licensed to Elasticsearch B.V. under one or more contributor
|
|
||||||
# license agreements. See the NOTICE file distributed with
|
|
||||||
# this work for additional information regarding copyright
|
|
||||||
# ownership. Elasticsearch B.V. licenses this file to you under
|
|
||||||
# the Apache License, Version 2.0 (the "License"); you may
|
|
||||||
# not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
|
|
||||||
from typing import Union, List, Optional, Tuple, TYPE_CHECKING, cast, Dict, Any
|
|
||||||
|
|
||||||
import numpy as np # type: ignore
|
|
||||||
|
|
||||||
from .ml_model import MLModel
|
|
||||||
from .transformers import get_model_transformer
|
|
||||||
from ..common import es_version
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from elasticsearch import Elasticsearch # type: ignore # noqa: F401
|
|
||||||
|
|
||||||
# Try importing each ML lib separately so mypy users don't have to
|
|
||||||
# have both installed to use type-checking.
|
|
||||||
try:
|
|
||||||
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor # type: ignore # noqa: F401
|
|
||||||
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor # type: ignore # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
from xgboost import XGBRegressor, XGBClassifier # type: ignore # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
from lightgbm import LGBMRegressor, LGBMClassifier # type: ignore # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ImportedMLModel(MLModel):
|
|
||||||
"""
|
|
||||||
Transform and serialize a trained 3rd party model into Elasticsearch.
|
|
||||||
This model can then be used for inference in the Elastic Stack.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
es_client: Elasticsearch client argument(s)
|
|
||||||
- elasticsearch-py parameters or
|
|
||||||
- elasticsearch-py instance
|
|
||||||
|
|
||||||
model_id: str
|
|
||||||
The unique identifier of the trained inference model in Elasticsearch.
|
|
||||||
|
|
||||||
model: An instance of a supported python model. We support the following model types:
|
|
||||||
- sklearn.tree.DecisionTreeClassifier
|
|
||||||
- sklearn.tree.DecisionTreeRegressor
|
|
||||||
- sklearn.ensemble.RandomForestRegressor
|
|
||||||
- sklearn.ensemble.RandomForestClassifier
|
|
||||||
- lightgbm.LGBMRegressor
|
|
||||||
- Categorical fields are expected to already be processed
|
|
||||||
- Only the following objectives are supported
|
|
||||||
- "regression"
|
|
||||||
- "regression_l1"
|
|
||||||
- "huber"
|
|
||||||
- "fair"
|
|
||||||
- "quantile"
|
|
||||||
- "mape"
|
|
||||||
- lightgbm.LGBMClassifier
|
|
||||||
- Categorical fields are expected to already be processed
|
|
||||||
- Only the following objectives are supported
|
|
||||||
- "binary"
|
|
||||||
- "multiclass"
|
|
||||||
- "multiclassova"
|
|
||||||
- xgboost.XGBClassifier
|
|
||||||
- only the following objectives are supported:
|
|
||||||
- "binary:logistic"
|
|
||||||
- "multi:softmax"
|
|
||||||
- "multi:softprob"
|
|
||||||
- xgboost.XGBRegressor
|
|
||||||
- only the following objectives are supported:
|
|
||||||
- "reg:squarederror"
|
|
||||||
- "reg:linear"
|
|
||||||
- "reg:squaredlogerror"
|
|
||||||
- "reg:logistic"
|
|
||||||
- "reg:pseudohubererror"
|
|
||||||
|
|
||||||
feature_names: List[str]
|
|
||||||
Names of the features (required)
|
|
||||||
|
|
||||||
classification_labels: List[str]
|
|
||||||
Labels of the classification targets
|
|
||||||
|
|
||||||
classification_weights: List[str]
|
|
||||||
Weights of the classification targets
|
|
||||||
|
|
||||||
es_if_exists: {'fail', 'replace'} default 'fail'
|
|
||||||
How to behave if model already exists
|
|
||||||
|
|
||||||
- fail: Raise a Value Error
|
|
||||||
- replace: Overwrite existing model
|
|
||||||
|
|
||||||
overwrite: **DEPRECATED** - bool
|
|
||||||
Delete and overwrite existing model (if exists)
|
|
||||||
|
|
||||||
es_compress_model_definition: bool
|
|
||||||
If True will use 'compressed_definition' which uses gzipped
|
|
||||||
JSON instead of raw JSON to reduce the amount of data sent
|
|
||||||
over the wire in HTTP requests. Defaults to 'True'.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
>>> from sklearn import datasets
|
|
||||||
>>> from sklearn.tree import DecisionTreeClassifier
|
|
||||||
>>> from eland.ml import ImportedMLModel
|
|
||||||
|
|
||||||
>>> # Train model
|
|
||||||
>>> training_data = datasets.make_classification(n_features=5, random_state=0)
|
|
||||||
>>> test_data = [[-50.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
|
|
||||||
>>> classifier = DecisionTreeClassifier()
|
|
||||||
>>> classifier = classifier.fit(training_data[0], training_data[1])
|
|
||||||
|
|
||||||
>>> # Get some test results
|
|
||||||
>>> classifier.predict(test_data)
|
|
||||||
array([0, 1])
|
|
||||||
|
|
||||||
>>> # Serialise the model to Elasticsearch
|
|
||||||
>>> feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
|
||||||
>>> model_id = "test_decision_tree_classifier"
|
|
||||||
>>> es_model = ImportedMLModel('localhost', model_id, classifier, feature_names, es_if_exists='replace')
|
|
||||||
|
|
||||||
>>> # Get some test results from Elasticsearch model
|
|
||||||
>>> es_model.predict(test_data)
|
|
||||||
array([0, 1])
|
|
||||||
|
|
||||||
>>> # Delete model from Elasticsearch
|
|
||||||
>>> es_model.delete_model()
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
|
|
||||||
model_id: str,
|
|
||||||
model: Union[
|
|
||||||
"DecisionTreeClassifier",
|
|
||||||
"DecisionTreeRegressor",
|
|
||||||
"RandomForestRegressor",
|
|
||||||
"RandomForestClassifier",
|
|
||||||
"XGBClassifier",
|
|
||||||
"XGBRegressor",
|
|
||||||
"LGBMRegressor",
|
|
||||||
"LGBMClassifier",
|
|
||||||
],
|
|
||||||
feature_names: List[str],
|
|
||||||
classification_labels: Optional[List[str]] = None,
|
|
||||||
classification_weights: Optional[List[float]] = None,
|
|
||||||
es_if_exists: Optional[str] = None,
|
|
||||||
overwrite: Optional[bool] = None,
|
|
||||||
es_compress_model_definition: bool = True,
|
|
||||||
):
|
|
||||||
super().__init__(es_client, model_id)
|
|
||||||
|
|
||||||
self._feature_names = feature_names
|
|
||||||
|
|
||||||
transformer = get_model_transformer(
|
|
||||||
model,
|
|
||||||
feature_names=feature_names,
|
|
||||||
classification_labels=classification_labels,
|
|
||||||
classification_weights=classification_weights,
|
|
||||||
)
|
|
||||||
self._model_type = transformer.model_type
|
|
||||||
serializer = transformer.transform()
|
|
||||||
|
|
||||||
# Verify if both parameters are given
|
|
||||||
if overwrite is not None and es_if_exists is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
|
|
||||||
)
|
|
||||||
|
|
||||||
if overwrite is not None:
|
|
||||||
warnings.warn(
|
|
||||||
"'overwrite' parameter is deprecated, use 'es_if_exists' instead",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
es_if_exists = "replace" if overwrite else "fail"
|
|
||||||
elif es_if_exists is None:
|
|
||||||
es_if_exists = "fail"
|
|
||||||
|
|
||||||
if es_if_exists not in ("fail", "replace"):
|
|
||||||
raise ValueError("'es_if_exists' must be either 'fail' or 'replace'")
|
|
||||||
elif es_if_exists == "fail":
|
|
||||||
if self.exists_model():
|
|
||||||
raise ValueError(
|
|
||||||
f"Trained machine learning model {model_id} already exists"
|
|
||||||
)
|
|
||||||
elif es_if_exists == "replace":
|
|
||||||
self.delete_model()
|
|
||||||
|
|
||||||
body: Dict[str, Any] = {
|
|
||||||
"input": {"field_names": feature_names},
|
|
||||||
}
|
|
||||||
# 'inference_config' is required in 7.8+ but isn't available in <=7.7
|
|
||||||
if es_version(self._client) >= (7, 8):
|
|
||||||
body["inference_config"] = {self._model_type: {}}
|
|
||||||
|
|
||||||
if es_compress_model_definition:
|
|
||||||
body["compressed_definition"] = serializer.serialize_and_compress_model()
|
|
||||||
else:
|
|
||||||
body["definition"] = serializer.serialize_model()
|
|
||||||
|
|
||||||
self._client.ml.put_trained_model(
|
|
||||||
model_id=self._model_id,
|
|
||||||
body=body,
|
|
||||||
)
|
|
||||||
|
|
||||||
def predict(self, X: Union[List[float], List[List[float]]]) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Make a prediction using a trained model stored in Elasticsearch.
|
|
||||||
|
|
||||||
Parameters for this method are not yet fully compatible with standard sklearn.predict.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
X: list or list of lists of type float
|
|
||||||
Input feature vector - TODO support DataFrame and other formats
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
y: np.ndarray of dtype float for regressors or int for classifiers
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
>>> from sklearn import datasets
|
|
||||||
>>> from xgboost import XGBRegressor
|
|
||||||
>>> from eland.ml import ImportedMLModel
|
|
||||||
|
|
||||||
>>> # Train model
|
|
||||||
>>> training_data = datasets.make_classification(n_features=6, random_state=0)
|
|
||||||
>>> test_data = [[-1, -2, -3, -4, -5, -6], [10, 20, 30, 40, 50, 60]]
|
|
||||||
>>> regressor = XGBRegressor(objective='reg:squarederror')
|
|
||||||
>>> regressor = regressor.fit(training_data[0], training_data[1])
|
|
||||||
|
|
||||||
>>> # Get some test results
|
|
||||||
>>> regressor.predict(np.array(test_data))
|
|
||||||
array([0.06062475, 0.9990102 ], dtype=float32)
|
|
||||||
|
|
||||||
>>> # Serialise the model to Elasticsearch
|
|
||||||
>>> feature_names = ["f0", "f1", "f2", "f3", "f4", "f5"]
|
|
||||||
>>> model_id = "test_xgb_regressor"
|
|
||||||
>>> es_model = ImportedMLModel('localhost', model_id, regressor, feature_names, es_if_exists='replace')
|
|
||||||
|
|
||||||
>>> # Get some test results from Elasticsearch model
|
|
||||||
>>> es_model.predict(test_data)
|
|
||||||
array([0.0606248 , 0.99901026], dtype=float32)
|
|
||||||
|
|
||||||
>>> # Delete model from Elasticsearch
|
|
||||||
>>> es_model.delete_model()
|
|
||||||
|
|
||||||
"""
|
|
||||||
docs = []
|
|
||||||
if isinstance(X, list):
|
|
||||||
# Is it a list of floats?
|
|
||||||
if all(isinstance(i, float) for i in X):
|
|
||||||
features = cast(List[List[float]], [X])
|
|
||||||
else:
|
|
||||||
features = cast(List[List[float]], X)
|
|
||||||
for i in features:
|
|
||||||
doc = {"_source": dict(zip(self._feature_names, i))}
|
|
||||||
docs.append(doc)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Prediction for type {type(X)}, not supported")
|
|
||||||
|
|
||||||
# field_mappings -> field_map in ES 7.7
|
|
||||||
field_map_name = (
|
|
||||||
"field_map" if es_version(self._client) >= (7, 7) else "field_mappings"
|
|
||||||
)
|
|
||||||
|
|
||||||
results = self._client.ingest.simulate(
|
|
||||||
body={
|
|
||||||
"pipeline": {
|
|
||||||
"processors": [
|
|
||||||
{
|
|
||||||
"inference": {
|
|
||||||
"model_id": self._model_id,
|
|
||||||
"inference_config": {self._model_type: {}},
|
|
||||||
field_map_name: {},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"docs": docs,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Unpack results into an array. Errors can be present
|
|
||||||
# within the response without a non-2XX HTTP status code.
|
|
||||||
y = []
|
|
||||||
for res in results["docs"]:
|
|
||||||
if "error" in res:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Failed to run prediction for model ID {self._model_id!r}",
|
|
||||||
res["error"],
|
|
||||||
)
|
|
||||||
|
|
||||||
y.append(res["doc"]["_source"]["ml"]["inference"]["predicted_value"])
|
|
||||||
|
|
||||||
# Return results as np.ndarray of float32 or int (consistent with sklearn/xgboost)
|
|
||||||
if self._model_type == MLModel.TYPE_CLASSIFICATION:
|
|
||||||
dt = np.int
|
|
||||||
else:
|
|
||||||
dt = np.float32
|
|
||||||
return np.asarray(y, dtype=dt)
|
|
@ -15,8 +15,33 @@
|
|||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import elasticsearch
|
from typing import List, Union, cast, Optional, Dict, TYPE_CHECKING, Any, Tuple
|
||||||
from eland.common import ensure_es_client
|
import warnings
|
||||||
|
import numpy as np # type: ignore
|
||||||
|
import elasticsearch # type: ignore
|
||||||
|
from .common import TYPE_REGRESSION, TYPE_CLASSIFICATION
|
||||||
|
from .transformers import get_model_transformer
|
||||||
|
from eland.common import ensure_es_client, es_version
|
||||||
|
from eland.utils import deprecated_api
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from elasticsearch import Elasticsearch # noqa: F401
|
||||||
|
|
||||||
|
# Try importing each ML lib separately so mypy users don't have to
|
||||||
|
# have both installed to use type-checking.
|
||||||
|
try:
|
||||||
|
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor # type: ignore # noqa: F401
|
||||||
|
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor # type: ignore # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
from xgboost import XGBRegressor, XGBClassifier # type: ignore # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
from lightgbm import LGBMRegressor, LGBMClassifier # type: ignore # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MLModel:
|
class MLModel:
|
||||||
@ -24,16 +49,17 @@ class MLModel:
|
|||||||
A machine learning model managed by Elasticsearch.
|
A machine learning model managed by Elasticsearch.
|
||||||
(See https://www.elastic.co/guide/en/elasticsearch/reference/master/put-inference.html)
|
(See https://www.elastic.co/guide/en/elasticsearch/reference/master/put-inference.html)
|
||||||
|
|
||||||
These models can be created by Elastic ML, or transformed from supported python formats such as scikit-learn or
|
These models can be created by Elastic ML, or transformed from supported Python formats
|
||||||
xgboost and imported into Elasticsearch.
|
such as scikit-learn or xgboost and imported into Elasticsearch.
|
||||||
|
|
||||||
The methods for this class attempt to mirror standard python classes.
|
The methods for this class attempt to mirror standard Python classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
TYPE_CLASSIFICATION = "classification"
|
def __init__(
|
||||||
TYPE_REGRESSION = "regression"
|
self,
|
||||||
|
es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
|
||||||
def __init__(self, es_client, model_id: str):
|
model_id: str,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -46,6 +72,331 @@ class MLModel:
|
|||||||
"""
|
"""
|
||||||
self._client = ensure_es_client(es_client)
|
self._client = ensure_es_client(es_client)
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
|
self._trained_model_config_cache: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
self, X: Union[np.ndarray, List[float], List[List[float]]]
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Make a prediction using a trained model stored in Elasticsearch.
|
||||||
|
|
||||||
|
Parameters for this method are not yet fully compatible with standard sklearn.predict.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X: Input feature vector.
|
||||||
|
Must be either a numpy ndarray or a list or list of lists
|
||||||
|
of type float. TODO: support DataFrame and other formats
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
y: np.ndarray of dtype float for regressors or int for classifiers
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> from sklearn import datasets
|
||||||
|
>>> from xgboost import XGBRegressor
|
||||||
|
>>> from eland.ml import MLModel
|
||||||
|
|
||||||
|
>>> # Train model
|
||||||
|
>>> training_data = datasets.make_classification(n_features=6, random_state=0)
|
||||||
|
>>> test_data = [[-1, -2, -3, -4, -5, -6], [10, 20, 30, 40, 50, 60]]
|
||||||
|
>>> regressor = XGBRegressor(objective='reg:squarederror')
|
||||||
|
>>> regressor = regressor.fit(training_data[0], training_data[1])
|
||||||
|
|
||||||
|
>>> # Get some test results
|
||||||
|
>>> regressor.predict(np.array(test_data))
|
||||||
|
array([0.06062475, 0.9990102 ], dtype=float32)
|
||||||
|
|
||||||
|
>>> # Serialise the model to Elasticsearch
|
||||||
|
>>> feature_names = ["f0", "f1", "f2", "f3", "f4", "f5"]
|
||||||
|
>>> model_id = "test_xgb_regressor"
|
||||||
|
>>> es_model = MLModel.import_model('localhost', model_id, regressor, feature_names, es_if_exists='replace')
|
||||||
|
|
||||||
|
>>> # Get some test results from Elasticsearch model
|
||||||
|
>>> es_model.predict(test_data)
|
||||||
|
array([0.0606248 , 0.99901026], dtype=float32)
|
||||||
|
|
||||||
|
>>> # Delete model from Elasticsearch
|
||||||
|
>>> es_model.delete_model()
|
||||||
|
"""
|
||||||
|
docs = []
|
||||||
|
if isinstance(X, np.ndarray):
|
||||||
|
|
||||||
|
def to_list_or_float(x: Any) -> Union[List[Any], float]:
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
return [to_list_or_float(i) for i in x.tolist()]
|
||||||
|
elif isinstance(x, list):
|
||||||
|
return [to_list_or_float(i) for i in x]
|
||||||
|
return float(x)
|
||||||
|
|
||||||
|
X = to_list_or_float(X)
|
||||||
|
|
||||||
|
# Is it a list of floats?
|
||||||
|
if isinstance(X, list) and all(isinstance(i, (float, int)) for i in X):
|
||||||
|
features = cast(List[List[float]], [X])
|
||||||
|
# If not a list of lists of floats then we error out.
|
||||||
|
elif isinstance(X, list) and all(
|
||||||
|
[
|
||||||
|
isinstance(i, list) and all([isinstance(ix, (float, int)) for ix in i])
|
||||||
|
for i in X
|
||||||
|
]
|
||||||
|
):
|
||||||
|
features = cast(List[List[float]], X)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Prediction for type {type(X)}, not supported: {X!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in features:
|
||||||
|
doc = {"_source": dict(zip(self.feature_names, i))}
|
||||||
|
docs.append(doc)
|
||||||
|
|
||||||
|
# field_mappings -> field_map in ES 7.7
|
||||||
|
field_map_name = (
|
||||||
|
"field_map" if es_version(self._client) >= (7, 7) else "field_mappings"
|
||||||
|
)
|
||||||
|
|
||||||
|
results = self._client.ingest.simulate(
|
||||||
|
body={
|
||||||
|
"pipeline": {
|
||||||
|
"processors": [
|
||||||
|
{
|
||||||
|
"inference": {
|
||||||
|
"model_id": self._model_id,
|
||||||
|
"inference_config": {self.model_type: {}},
|
||||||
|
field_map_name: {},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"docs": docs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Unpack results into an array. Errors can be present
|
||||||
|
# within the response without a non-2XX HTTP status code.
|
||||||
|
y = []
|
||||||
|
for res in results["docs"]:
|
||||||
|
if "error" in res:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to run prediction for model ID {self._model_id!r}",
|
||||||
|
res["error"],
|
||||||
|
)
|
||||||
|
|
||||||
|
y.append(res["doc"]["_source"]["ml"]["inference"][self.results_field])
|
||||||
|
|
||||||
|
# Return results as np.ndarray of float32 or int (consistent with sklearn/xgboost)
|
||||||
|
if self.model_type == TYPE_CLASSIFICATION:
|
||||||
|
dt = np.int
|
||||||
|
else:
|
||||||
|
dt = np.float32
|
||||||
|
return np.asarray(y, dtype=dt)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_type(self) -> str:
|
||||||
|
inference_config = self._trained_model_config["inference_config"]
|
||||||
|
if "classification" in inference_config:
|
||||||
|
return TYPE_CLASSIFICATION
|
||||||
|
elif "regression" in inference_config:
|
||||||
|
return TYPE_REGRESSION
|
||||||
|
raise ValueError("Unable to determine 'model_type' for MLModel")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def feature_names(self) -> List[str]:
|
||||||
|
return list(self._trained_model_config["input"]["field_names"])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def results_field(self) -> str:
|
||||||
|
return cast(
|
||||||
|
str,
|
||||||
|
self._trained_model_config["inference_config"][self.model_type][
|
||||||
|
"results_field"
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def import_model(
|
||||||
|
cls,
|
||||||
|
es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
|
||||||
|
model_id: str,
|
||||||
|
model: Union[
|
||||||
|
"DecisionTreeClassifier",
|
||||||
|
"DecisionTreeRegressor",
|
||||||
|
"RandomForestRegressor",
|
||||||
|
"RandomForestClassifier",
|
||||||
|
"XGBClassifier",
|
||||||
|
"XGBRegressor",
|
||||||
|
"LGBMRegressor",
|
||||||
|
"LGBMClassifier",
|
||||||
|
],
|
||||||
|
feature_names: List[str],
|
||||||
|
classification_labels: Optional[List[str]] = None,
|
||||||
|
classification_weights: Optional[List[float]] = None,
|
||||||
|
es_if_exists: Optional[str] = None,
|
||||||
|
overwrite: Optional[bool] = None,
|
||||||
|
es_compress_model_definition: bool = True,
|
||||||
|
) -> "MLModel":
|
||||||
|
"""
|
||||||
|
Transform and serialize a trained 3rd party model into Elasticsearch.
|
||||||
|
This model can then be used for inference in the Elastic Stack.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
es_client: Elasticsearch client argument(s)
|
||||||
|
- elasticsearch-py parameters or
|
||||||
|
- elasticsearch-py instance
|
||||||
|
|
||||||
|
model_id: str
|
||||||
|
The unique identifier of the trained inference model in Elasticsearch.
|
||||||
|
|
||||||
|
model: An instance of a supported Python model. We support the following model types:
|
||||||
|
- sklearn.tree.DecisionTreeClassifier
|
||||||
|
- sklearn.tree.DecisionTreeRegressor
|
||||||
|
- sklearn.ensemble.RandomForestRegressor
|
||||||
|
- sklearn.ensemble.RandomForestClassifier
|
||||||
|
- lightgbm.LGBMRegressor
|
||||||
|
- Categorical fields are expected to already be processed
|
||||||
|
- Only the following objectives are supported
|
||||||
|
- "regression"
|
||||||
|
- "regression_l1"
|
||||||
|
- "huber"
|
||||||
|
- "fair"
|
||||||
|
- "quantile"
|
||||||
|
- "mape"
|
||||||
|
- lightgbm.LGBMClassifier
|
||||||
|
- Categorical fields are expected to already be processed
|
||||||
|
- Only the following objectives are supported
|
||||||
|
- "binary"
|
||||||
|
- "multiclass"
|
||||||
|
- "multiclassova"
|
||||||
|
- xgboost.XGBClassifier
|
||||||
|
- only the following objectives are supported:
|
||||||
|
- "binary:logistic"
|
||||||
|
- "multi:softmax"
|
||||||
|
- "multi:softprob"
|
||||||
|
- xgboost.XGBRegressor
|
||||||
|
- only the following objectives are supported:
|
||||||
|
- "reg:squarederror"
|
||||||
|
- "reg:linear"
|
||||||
|
- "reg:squaredlogerror"
|
||||||
|
- "reg:logistic"
|
||||||
|
- "reg:pseudohubererror"
|
||||||
|
|
||||||
|
feature_names: List[str]
|
||||||
|
Names of the features (required)
|
||||||
|
|
||||||
|
classification_labels: List[str]
|
||||||
|
Labels of the classification targets
|
||||||
|
|
||||||
|
classification_weights: List[float]
|
||||||
|
Weights of the classification targets
|
||||||
|
|
||||||
|
es_if_exists: {'fail', 'replace'} default 'fail'
|
||||||
|
How to behave if model already exists
|
||||||
|
|
||||||
|
- fail: Raise a ValueError
|
||||||
|
- replace: Overwrite existing model
|
||||||
|
|
||||||
|
overwrite: **DEPRECATED** - bool
|
||||||
|
Delete and overwrite existing model (if exists)
|
||||||
|
|
||||||
|
es_compress_model_definition: bool
|
||||||
|
If True will use 'compressed_definition' which uses gzipped
|
||||||
|
JSON instead of raw JSON to reduce the amount of data sent
|
||||||
|
over the wire in HTTP requests. Defaults to 'True'.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> from sklearn import datasets
|
||||||
|
>>> from sklearn.tree import DecisionTreeClassifier
|
||||||
|
>>> from eland.ml import MLModel
|
||||||
|
|
||||||
|
>>> # Train model
|
||||||
|
>>> training_data = datasets.make_classification(n_features=5, random_state=0)
|
||||||
|
>>> test_data = [[-50.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
|
||||||
|
>>> classifier = DecisionTreeClassifier()
|
||||||
|
>>> classifier = classifier.fit(training_data[0], training_data[1])
|
||||||
|
|
||||||
|
>>> # Get some test results
|
||||||
|
>>> classifier.predict(test_data)
|
||||||
|
array([0, 1])
|
||||||
|
|
||||||
|
>>> # Serialise the model to Elasticsearch
|
||||||
|
>>> feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||||
|
>>> model_id = "test_decision_tree_classifier"
|
||||||
|
>>> es_model = MLModel.import_model(
|
||||||
|
... 'localhost',
|
||||||
|
... model_id=model_id,
|
||||||
|
... model=classifier,
|
||||||
|
... feature_names=feature_names,
|
||||||
|
... es_if_exists='replace'
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> # Get some test results from Elasticsearch model
|
||||||
|
>>> es_model.predict(test_data)
|
||||||
|
array([0, 1])
|
||||||
|
|
||||||
|
>>> # Delete model from Elasticsearch
|
||||||
|
>>> es_model.delete_model()
|
||||||
|
"""
|
||||||
|
es_client = ensure_es_client(es_client)
|
||||||
|
transformer = get_model_transformer(
|
||||||
|
model,
|
||||||
|
feature_names=feature_names,
|
||||||
|
classification_labels=classification_labels,
|
||||||
|
classification_weights=classification_weights,
|
||||||
|
)
|
||||||
|
serializer = transformer.transform()
|
||||||
|
model_type = transformer.model_type
|
||||||
|
|
||||||
|
# Verify if both parameters are given
|
||||||
|
if overwrite is not None and es_if_exists is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
|
||||||
|
)
|
||||||
|
|
||||||
|
if overwrite is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"'overwrite' parameter is deprecated, use 'es_if_exists' instead",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
es_if_exists = "replace" if overwrite else "fail"
|
||||||
|
elif es_if_exists is None:
|
||||||
|
es_if_exists = "fail"
|
||||||
|
|
||||||
|
ml_model = MLModel(
|
||||||
|
es_client=es_client,
|
||||||
|
model_id=model_id,
|
||||||
|
)
|
||||||
|
if es_if_exists not in ("fail", "replace"):
|
||||||
|
raise ValueError("'es_if_exists' must be either 'fail' or 'replace'")
|
||||||
|
elif es_if_exists == "fail":
|
||||||
|
if ml_model.exists_model():
|
||||||
|
raise ValueError(
|
||||||
|
f"Trained machine learning model {model_id} already exists"
|
||||||
|
)
|
||||||
|
elif es_if_exists == "replace":
|
||||||
|
ml_model.delete_model()
|
||||||
|
|
||||||
|
body: Dict[str, Any] = {
|
||||||
|
"input": {"field_names": feature_names},
|
||||||
|
}
|
||||||
|
# 'inference_config' is required in 7.8+ but isn't available in <=7.7
|
||||||
|
if es_version(es_client) >= (7, 8):
|
||||||
|
body["inference_config"] = {model_type: {}}
|
||||||
|
|
||||||
|
if es_compress_model_definition:
|
||||||
|
body["compressed_definition"] = serializer.serialize_and_compress_model()
|
||||||
|
else:
|
||||||
|
body["definition"] = serializer.serialize_model()
|
||||||
|
|
||||||
|
ml_model._client.ml.put_trained_model(
|
||||||
|
model_id=model_id,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
return ml_model
|
||||||
|
|
||||||
def delete_model(self) -> None:
|
def delete_model(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -69,3 +420,18 @@ class MLModel:
|
|||||||
except elasticsearch.NotFoundError:
|
except elasticsearch.NotFoundError:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _trained_model_config(self) -> Dict[str, Any]:
|
||||||
|
"""Lazily loads an ML model's 'trained_model_config' information"""
|
||||||
|
if self._trained_model_config_cache is None:
|
||||||
|
resp = self._client.ml.get_trained_models(model_id=self._model_id)
|
||||||
|
if resp["count"] > 1:
|
||||||
|
raise ValueError(f"Model ID {self._model_id!r} wasn't unambiguous")
|
||||||
|
elif resp["count"] == 0:
|
||||||
|
raise ValueError(f"Model with Model ID {self._model_id!r} wasn't found")
|
||||||
|
self._trained_model_config_cache = resp["trained_model_configs"][0]
|
||||||
|
return self._trained_model_config_cache
|
||||||
|
|
||||||
|
|
||||||
|
ImportedMLModel = deprecated_api("MLModel.import_model()")(MLModel.import_model)
|
||||||
|
@ -37,7 +37,7 @@ def get_model_transformer(model: Any, **kwargs: Any) -> ModelTransformer:
|
|||||||
return transformer(model, **kwargs)
|
return transformer(model, **kwargs)
|
||||||
|
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"ML model of type {type(model)}, not currently implemented"
|
f"Importing ML models of type {type(model)}, not currently implemented"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
from typing import Optional, List, Dict, Any, Type
|
from typing import Optional, List, Dict, Any, Type
|
||||||
from .base import ModelTransformer
|
from .base import ModelTransformer
|
||||||
from .._model_serializer import Ensemble, Tree, TreeNode
|
from .._model_serializer import Ensemble, Tree, TreeNode
|
||||||
from ..ml_model import MLModel
|
from ..common import TYPE_CLASSIFICATION, TYPE_REGRESSION
|
||||||
from .._optional import import_optional_dependency
|
from .._optional import import_optional_dependency
|
||||||
|
|
||||||
import_optional_dependency("lightgbm", on_version="warn")
|
import_optional_dependency("lightgbm", on_version="warn")
|
||||||
@ -201,7 +201,7 @@ class LGBMRegressorTransformer(LGBMForestTransformer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def model_type(self) -> str:
|
def model_type(self) -> str:
|
||||||
return MLModel.TYPE_REGRESSION
|
return TYPE_REGRESSION
|
||||||
|
|
||||||
|
|
||||||
class LGBMClassifierTransformer(LGBMForestTransformer):
|
class LGBMClassifierTransformer(LGBMForestTransformer):
|
||||||
@ -244,7 +244,7 @@ class LGBMClassifierTransformer(LGBMForestTransformer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def model_type(self) -> str:
|
def model_type(self) -> str:
|
||||||
return MLModel.TYPE_CLASSIFICATION
|
return TYPE_CLASSIFICATION
|
||||||
|
|
||||||
def is_objective_supported(self) -> bool:
|
def is_objective_supported(self) -> bool:
|
||||||
return self._objective in {
|
return self._objective in {
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
import numpy as np # type: ignore
|
import numpy as np # type: ignore
|
||||||
from typing import Optional, Sequence, Union, Dict, Any, Type, Tuple
|
from typing import Optional, Sequence, Union, Dict, Any, Type, Tuple
|
||||||
from .base import ModelTransformer
|
from .base import ModelTransformer
|
||||||
from ..ml_model import MLModel
|
from ..common import TYPE_CLASSIFICATION, TYPE_REGRESSION
|
||||||
from .._optional import import_optional_dependency
|
from .._optional import import_optional_dependency
|
||||||
from .._model_serializer import Ensemble, Tree, TreeNode
|
from .._model_serializer import Ensemble, Tree, TreeNode
|
||||||
|
|
||||||
@ -148,9 +148,9 @@ class SKLearnDecisionTreeTransformer(SKLearnTransformer):
|
|||||||
@property
|
@property
|
||||||
def model_type(self) -> str:
|
def model_type(self) -> str:
|
||||||
return (
|
return (
|
||||||
MLModel.TYPE_REGRESSION
|
TYPE_REGRESSION
|
||||||
if isinstance(self._model, DecisionTreeRegressor)
|
if isinstance(self._model, DecisionTreeRegressor)
|
||||||
else MLModel.TYPE_CLASSIFICATION
|
else TYPE_CLASSIFICATION
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -223,7 +223,7 @@ class SKLearnForestRegressorTransformer(SKLearnForestTransformer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def model_type(self) -> str:
|
def model_type(self) -> str:
|
||||||
return MLModel.TYPE_REGRESSION
|
return TYPE_REGRESSION
|
||||||
|
|
||||||
|
|
||||||
class SKLearnForestClassifierTransformer(SKLearnForestTransformer):
|
class SKLearnForestClassifierTransformer(SKLearnForestTransformer):
|
||||||
@ -247,7 +247,7 @@ class SKLearnForestClassifierTransformer(SKLearnForestTransformer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def model_type(self) -> str:
|
def model_type(self) -> str:
|
||||||
return MLModel.TYPE_CLASSIFICATION
|
return TYPE_CLASSIFICATION
|
||||||
|
|
||||||
|
|
||||||
_MODEL_TRANSFORMERS: Dict[type, Type[ModelTransformer]] = {
|
_MODEL_TRANSFORMERS: Dict[type, Type[ModelTransformer]] = {
|
||||||
|
@ -20,7 +20,7 @@ from typing import Optional, List, Dict, Any, Type
|
|||||||
from .base import ModelTransformer
|
from .base import ModelTransformer
|
||||||
import pandas as pd # type: ignore
|
import pandas as pd # type: ignore
|
||||||
from .._model_serializer import Ensemble, Tree, TreeNode
|
from .._model_serializer import Ensemble, Tree, TreeNode
|
||||||
from ..ml_model import MLModel
|
from ..common import TYPE_CLASSIFICATION, TYPE_REGRESSION
|
||||||
from .._optional import import_optional_dependency
|
from .._optional import import_optional_dependency
|
||||||
|
|
||||||
import_optional_dependency("xgboost", on_version="warn")
|
import_optional_dependency("xgboost", on_version="warn")
|
||||||
@ -207,7 +207,7 @@ class XGBoostRegressorTransformer(XGBoostForestTransformer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def model_type(self) -> str:
|
def model_type(self) -> str:
|
||||||
return MLModel.TYPE_REGRESSION
|
return TYPE_REGRESSION
|
||||||
|
|
||||||
|
|
||||||
class XGBoostClassifierTransformer(XGBoostForestTransformer):
|
class XGBoostClassifierTransformer(XGBoostForestTransformer):
|
||||||
@ -253,7 +253,7 @@ class XGBoostClassifierTransformer(XGBoostForestTransformer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def model_type(self) -> str:
|
def model_type(self) -> str:
|
||||||
return MLModel.TYPE_CLASSIFICATION
|
return TYPE_CLASSIFICATION
|
||||||
|
|
||||||
|
|
||||||
_MODEL_TRANSFORMERS: Dict[type, Type[ModelTransformer]] = {
|
_MODEL_TRANSFORMERS: Dict[type, Type[ModelTransformer]] = {
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from eland.ml import ImportedMLModel
|
from eland.ml import MLModel
|
||||||
from eland.tests import ES_TEST_CLIENT, ES_VERSION
|
from eland.tests import ES_TEST_CLIENT, ES_VERSION
|
||||||
|
|
||||||
|
|
||||||
@ -71,7 +71,7 @@ def skip_if_multiclass_classifition():
|
|||||||
|
|
||||||
|
|
||||||
def random_rows(data, size):
|
def random_rows(data, size):
|
||||||
return data[np.random.randint(data.shape[0], size=size), :].tolist()
|
return data[np.random.randint(data.shape[0], size=size), :]
|
||||||
|
|
||||||
|
|
||||||
def check_prediction_equality(es_model, py_model, test_data):
|
def check_prediction_equality(es_model, py_model, test_data):
|
||||||
@ -84,7 +84,7 @@ def check_prediction_equality(es_model, py_model, test_data):
|
|||||||
class TestImportedMLModel:
|
class TestImportedMLModel:
|
||||||
@requires_no_ml_extras
|
@requires_no_ml_extras
|
||||||
def test_import_ml_model_when_dependencies_are_not_available(self):
|
def test_import_ml_model_when_dependencies_are_not_available(self):
|
||||||
from eland.ml import MLModel, ImportedMLModel # noqa: F401
|
from eland.ml import MLModel # noqa: F401
|
||||||
|
|
||||||
@requires_sklearn
|
@requires_sklearn
|
||||||
def test_unpack_and_raise_errors_in_ingest_simulate(self, mocker):
|
def test_unpack_and_raise_errors_in_ingest_simulate(self, mocker):
|
||||||
@ -98,7 +98,7 @@ class TestImportedMLModel:
|
|||||||
model_id = "test_decision_tree_classifier"
|
model_id = "test_decision_tree_classifier"
|
||||||
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
|
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
|
||||||
|
|
||||||
es_model = ImportedMLModel(
|
es_model = MLModel.import_model(
|
||||||
ES_TEST_CLIENT,
|
ES_TEST_CLIENT,
|
||||||
model_id,
|
model_id,
|
||||||
classifier,
|
classifier,
|
||||||
@ -142,7 +142,7 @@ class TestImportedMLModel:
|
|||||||
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||||
model_id = "test_decision_tree_classifier"
|
model_id = "test_decision_tree_classifier"
|
||||||
|
|
||||||
es_model = ImportedMLModel(
|
es_model = MLModel.import_model(
|
||||||
ES_TEST_CLIENT,
|
ES_TEST_CLIENT,
|
||||||
model_id,
|
model_id,
|
||||||
classifier,
|
classifier,
|
||||||
@ -171,7 +171,7 @@ class TestImportedMLModel:
|
|||||||
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||||
model_id = "test_decision_tree_regressor"
|
model_id = "test_decision_tree_regressor"
|
||||||
|
|
||||||
es_model = ImportedMLModel(
|
es_model = MLModel.import_model(
|
||||||
ES_TEST_CLIENT,
|
ES_TEST_CLIENT,
|
||||||
model_id,
|
model_id,
|
||||||
regressor,
|
regressor,
|
||||||
@ -199,7 +199,7 @@ class TestImportedMLModel:
|
|||||||
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||||
model_id = "test_random_forest_classifier"
|
model_id = "test_random_forest_classifier"
|
||||||
|
|
||||||
es_model = ImportedMLModel(
|
es_model = MLModel.import_model(
|
||||||
ES_TEST_CLIENT,
|
ES_TEST_CLIENT,
|
||||||
model_id,
|
model_id,
|
||||||
classifier,
|
classifier,
|
||||||
@ -227,7 +227,7 @@ class TestImportedMLModel:
|
|||||||
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||||
model_id = "test_random_forest_regressor"
|
model_id = "test_random_forest_regressor"
|
||||||
|
|
||||||
es_model = ImportedMLModel(
|
es_model = MLModel.import_model(
|
||||||
ES_TEST_CLIENT,
|
ES_TEST_CLIENT,
|
||||||
model_id,
|
model_id,
|
||||||
regressor,
|
regressor,
|
||||||
@ -265,7 +265,7 @@ class TestImportedMLModel:
|
|||||||
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||||
model_id = "test_xgb_classifier"
|
model_id = "test_xgb_classifier"
|
||||||
|
|
||||||
es_model = ImportedMLModel(
|
es_model = MLModel.import_model(
|
||||||
ES_TEST_CLIENT,
|
ES_TEST_CLIENT,
|
||||||
model_id,
|
model_id,
|
||||||
classifier,
|
classifier,
|
||||||
@ -305,7 +305,7 @@ class TestImportedMLModel:
|
|||||||
feature_names = ["feature0", "feature1", "feature2", "feature3", "feature4"]
|
feature_names = ["feature0", "feature1", "feature2", "feature3", "feature4"]
|
||||||
model_id = "test_xgb_classifier"
|
model_id = "test_xgb_classifier"
|
||||||
|
|
||||||
es_model = ImportedMLModel(
|
es_model = MLModel.import_model(
|
||||||
ES_TEST_CLIENT, model_id, classifier, feature_names, es_if_exists="replace"
|
ES_TEST_CLIENT, model_id, classifier, feature_names, es_if_exists="replace"
|
||||||
)
|
)
|
||||||
# Get some test results
|
# Get some test results
|
||||||
@ -337,7 +337,7 @@ class TestImportedMLModel:
|
|||||||
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||||
model_id = "test_xgb_regressor"
|
model_id = "test_xgb_regressor"
|
||||||
|
|
||||||
es_model = ImportedMLModel(
|
es_model = MLModel.import_model(
|
||||||
ES_TEST_CLIENT,
|
ES_TEST_CLIENT,
|
||||||
model_id,
|
model_id,
|
||||||
regressor,
|
regressor,
|
||||||
@ -368,7 +368,7 @@ class TestImportedMLModel:
|
|||||||
feature_names = ["f0"]
|
feature_names = ["f0"]
|
||||||
model_id = "test_xgb_regressor"
|
model_id = "test_xgb_regressor"
|
||||||
|
|
||||||
es_model = ImportedMLModel(
|
es_model = MLModel.import_model(
|
||||||
ES_TEST_CLIENT, model_id, regressor, feature_names, es_if_exists="replace"
|
ES_TEST_CLIENT, model_id, regressor, feature_names, es_if_exists="replace"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -405,7 +405,7 @@ class TestImportedMLModel:
|
|||||||
feature_names = ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"]
|
feature_names = ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"]
|
||||||
model_id = "test_lgbm_regressor"
|
model_id = "test_lgbm_regressor"
|
||||||
|
|
||||||
es_model = ImportedMLModel(
|
es_model = MLModel.import_model(
|
||||||
ES_TEST_CLIENT,
|
ES_TEST_CLIENT,
|
||||||
model_id,
|
model_id,
|
||||||
regressor,
|
regressor,
|
||||||
@ -446,7 +446,7 @@ class TestImportedMLModel:
|
|||||||
feature_names = ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"]
|
feature_names = ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"]
|
||||||
model_id = "test_lgbm_classifier"
|
model_id = "test_lgbm_classifier"
|
||||||
|
|
||||||
es_model = ImportedMLModel(
|
es_model = MLModel.import_model(
|
||||||
ES_TEST_CLIENT,
|
ES_TEST_CLIENT,
|
||||||
model_id,
|
model_id,
|
||||||
classifier,
|
classifier,
|
||||||
@ -480,7 +480,7 @@ class TestImportedMLModel:
|
|||||||
|
|
||||||
match = "Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
|
match = "Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
|
||||||
with pytest.raises(ValueError, match=match):
|
with pytest.raises(ValueError, match=match):
|
||||||
ImportedMLModel(
|
MLModel.import_model(
|
||||||
ES_TEST_CLIENT,
|
ES_TEST_CLIENT,
|
||||||
model_id,
|
model_id,
|
||||||
regressor,
|
regressor,
|
||||||
@ -507,7 +507,7 @@ class TestImportedMLModel:
|
|||||||
|
|
||||||
match = "'overwrite' parameter is deprecated, use 'es_if_exists' instead"
|
match = "'overwrite' parameter is deprecated, use 'es_if_exists' instead"
|
||||||
with pytest.warns(DeprecationWarning, match=match):
|
with pytest.warns(DeprecationWarning, match=match):
|
||||||
ImportedMLModel(
|
MLModel.import_model(
|
||||||
ES_TEST_CLIENT,
|
ES_TEST_CLIENT,
|
||||||
model_id,
|
model_id,
|
||||||
regressor,
|
regressor,
|
||||||
@ -536,7 +536,7 @@ class TestImportedMLModel:
|
|||||||
)
|
)
|
||||||
with pytest.raises(ValueError, match=match_error):
|
with pytest.raises(ValueError, match=match_error):
|
||||||
with pytest.warns(DeprecationWarning, match=match_warning):
|
with pytest.warns(DeprecationWarning, match=match_warning):
|
||||||
ImportedMLModel(
|
MLModel.import_model(
|
||||||
ES_TEST_CLIENT,
|
ES_TEST_CLIENT,
|
||||||
model_id,
|
model_id,
|
||||||
regressor,
|
regressor,
|
||||||
@ -560,7 +560,7 @@ class TestImportedMLModel:
|
|||||||
# If both overwrite and es_if_exists are given.
|
# If both overwrite and es_if_exists are given.
|
||||||
match = f"Trained machine learning model {model_id} already exists"
|
match = f"Trained machine learning model {model_id} already exists"
|
||||||
with pytest.raises(ValueError, match=match):
|
with pytest.raises(ValueError, match=match):
|
||||||
ImportedMLModel(
|
MLModel.import_model(
|
||||||
ES_TEST_CLIENT,
|
ES_TEST_CLIENT,
|
||||||
model_id,
|
model_id,
|
||||||
regressor,
|
regressor,
|
||||||
|
@ -34,7 +34,7 @@ SOURCE_FILES = (
|
|||||||
# Whenever type-hints are completed on a file it should
|
# Whenever type-hints are completed on a file it should
|
||||||
# be added here so that this file will continue to be checked
|
# be added here so that this file will continue to be checked
|
||||||
# by mypy. Errors from other files are ignored.
|
# by mypy. Errors from other files are ignored.
|
||||||
TYPED_FILES = {
|
TYPED_FILES = (
|
||||||
"eland/actions.py",
|
"eland/actions.py",
|
||||||
"eland/arithmetics.py",
|
"eland/arithmetics.py",
|
||||||
"eland/common.py",
|
"eland/common.py",
|
||||||
@ -46,13 +46,13 @@ TYPED_FILES = {
|
|||||||
"eland/utils.py",
|
"eland/utils.py",
|
||||||
"eland/ml/__init__.py",
|
"eland/ml/__init__.py",
|
||||||
"eland/ml/_model_serializer.py",
|
"eland/ml/_model_serializer.py",
|
||||||
"eland/ml/imported_ml_model.py",
|
"eland/ml/ml_model.py",
|
||||||
"eland/ml/transformers/__init__.py",
|
"eland/ml/transformers/__init__.py",
|
||||||
"eland/ml/transformers/base.py",
|
"eland/ml/transformers/base.py",
|
||||||
"eland/ml/transformers/lightgbm.py",
|
"eland/ml/transformers/lightgbm.py",
|
||||||
"eland/ml/transformers/sklearn.py",
|
"eland/ml/transformers/sklearn.py",
|
||||||
"eland/ml/transformers/xgboost.py",
|
"eland/ml/transformers/xgboost.py",
|
||||||
}
|
)
|
||||||
|
|
||||||
|
|
||||||
@nox.session(reuse_venv=True)
|
@nox.session(reuse_venv=True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user