Deprecate ImportedMLModel in favor of MLModel.import_model()

This commit is contained in:
Seth Michael Larson 2020-09-03 09:06:59 -05:00 committed by GitHub
parent 1a8a301cd6
commit c86371733d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 445 additions and 381 deletions

View File

@ -2495,10 +2495,10 @@
}
],
"source": [
"from eland.ml import ImportedMLModel\n",
"from eland.ml import MLModel\n",
"\n",
"# Serialize the scikit-learn model into Elasticsearch\n",
"ed_classifier = ImportedMLModel(\n",
"ed_classifier = MLModel.import_model(\n",
" es_client=es,\n",
" model_id=\"wine-classifier\",\n",
" model=sk_classifier,\n",

View File

@ -1,6 +0,0 @@
eland.ml.ImportedMLModel.predict
================================
.. currentmodule:: eland.ml
.. automethod:: ImportedMLModel.predict

View File

@ -1,6 +0,0 @@
eland.ml.ImportedMLModel
========================
.. currentmodule:: eland.ml
.. autoclass:: ImportedMLModel

View File

@ -0,0 +1,6 @@
eland.ml.MLModel.import_model
=============================
.. currentmodule:: eland.ml
.. automethod:: MLModel.import_model

View File

@ -0,0 +1,6 @@
eland.ml.MLModel.predict
========================
.. currentmodule:: eland.ml
.. automethod:: MLModel.predict

View File

@ -27,19 +27,20 @@ Constructor
.. autosummary::
:toctree: api/
ImportedMLModel
MLModel
Predictions
^^^^^^^^^^^
.. autosummary::
:toctree: api/
ImportedMLModel.predict
MLModel.predict
Manage Models
^^^^^^^^^^^^^
.. autosummary::
:toctree: api/
MLModel.import_model
MLModel.exists_model
MLModel.delete_model

View File

@ -15,8 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from eland.ml.ml_model import MLModel
from eland.ml.imported_ml_model import ImportedMLModel
from eland.ml.ml_model import MLModel, ImportedMLModel
__all__ = [
"MLModel",

19
eland/ml/common.py Normal file
View File

@ -0,0 +1,19 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Model type identifiers shared by MLModel and the model transformers.
# The same strings are used as the key placed under 'inference_config'
# when a trained model is stored in Elasticsearch.
TYPE_CLASSIFICATION = "classification"
TYPE_REGRESSION = "regression"

View File

@ -1,321 +0,0 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Union, List, Optional, Tuple, TYPE_CHECKING, cast, Dict, Any
import numpy as np # type: ignore
from .ml_model import MLModel
from .transformers import get_model_transformer
from ..common import es_version
import warnings
if TYPE_CHECKING:
from elasticsearch import Elasticsearch # type: ignore # noqa: F401
# Try importing each ML lib separately so mypy users don't have to
# have both installed to use type-checking.
try:
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor # type: ignore # noqa: F401
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor # type: ignore # noqa: F401
except ImportError:
pass
try:
from xgboost import XGBRegressor, XGBClassifier # type: ignore # noqa: F401
except ImportError:
pass
try:
from lightgbm import LGBMRegressor, LGBMClassifier # type: ignore # noqa: F401
except ImportError:
pass
class ImportedMLModel(MLModel):
"""
Transform and serialize a trained 3rd party model into Elasticsearch.
This model can then be used for inference in the Elastic Stack.
Parameters
----------
es_client: Elasticsearch client argument(s)
- elasticsearch-py parameters or
- elasticsearch-py instance
model_id: str
The unique identifier of the trained inference model in Elasticsearch.
model: An instance of a supported python model. We support the following model types:
- sklearn.tree.DecisionTreeClassifier
- sklearn.tree.DecisionTreeRegressor
- sklearn.ensemble.RandomForestRegressor
- sklearn.ensemble.RandomForestClassifier
- lightgbm.LGBMRegressor
- Categorical fields are expected to already be processed
- Only the following objectives are supported
- "regression"
- "regression_l1"
- "huber"
- "fair"
- "quantile"
- "mape"
- lightgbm.LGBMClassifier
- Categorical fields are expected to already be processed
- Only the following objectives are supported
- "binary"
- "multiclass"
- "multiclassova"
- xgboost.XGBClassifier
- only the following objectives are supported:
- "binary:logistic"
- "multi:softmax"
- "multi:softprob"
- xgboost.XGBRegressor
- only the following objectives are supported:
- "reg:squarederror"
- "reg:linear"
- "reg:squaredlogerror"
- "reg:logistic"
- "reg:pseudohubererror"
feature_names: List[str]
Names of the features (required)
classification_labels: List[str]
Labels of the classification targets
classification_weights: List[str]
Weights of the classification targets
es_if_exists: {'fail', 'replace'} default 'fail'
How to behave if model already exists
- fail: Raise a Value Error
- replace: Overwrite existing model
overwrite: **DEPRECATED** - bool
Delete and overwrite existing model (if exists)
es_compress_model_definition: bool
If True will use 'compressed_definition' which uses gzipped
JSON instead of raw JSON to reduce the amount of data sent
over the wire in HTTP requests. Defaults to 'True'.
Examples
--------
>>> from sklearn import datasets
>>> from sklearn.tree import DecisionTreeClassifier
>>> from eland.ml import ImportedMLModel
>>> # Train model
>>> training_data = datasets.make_classification(n_features=5, random_state=0)
>>> test_data = [[-50.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
>>> classifier = DecisionTreeClassifier()
>>> classifier = classifier.fit(training_data[0], training_data[1])
>>> # Get some test results
>>> classifier.predict(test_data)
array([0, 1])
>>> # Serialise the model to Elasticsearch
>>> feature_names = ["f0", "f1", "f2", "f3", "f4"]
>>> model_id = "test_decision_tree_classifier"
>>> es_model = ImportedMLModel('localhost', model_id, classifier, feature_names, es_if_exists='replace')
>>> # Get some test results from Elasticsearch model
>>> es_model.predict(test_data)
array([0, 1])
>>> # Delete model from Elasticsearch
>>> es_model.delete_model()
"""
def __init__(
self,
es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
model_id: str,
model: Union[
"DecisionTreeClassifier",
"DecisionTreeRegressor",
"RandomForestRegressor",
"RandomForestClassifier",
"XGBClassifier",
"XGBRegressor",
"LGBMRegressor",
"LGBMClassifier",
],
feature_names: List[str],
classification_labels: Optional[List[str]] = None,
classification_weights: Optional[List[float]] = None,
es_if_exists: Optional[str] = None,
overwrite: Optional[bool] = None,
es_compress_model_definition: bool = True,
):
super().__init__(es_client, model_id)
self._feature_names = feature_names
transformer = get_model_transformer(
model,
feature_names=feature_names,
classification_labels=classification_labels,
classification_weights=classification_weights,
)
self._model_type = transformer.model_type
serializer = transformer.transform()
# Verify if both parameters are given
if overwrite is not None and es_if_exists is not None:
raise ValueError(
"Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
)
if overwrite is not None:
warnings.warn(
"'overwrite' parameter is deprecated, use 'es_if_exists' instead",
DeprecationWarning,
stacklevel=2,
)
es_if_exists = "replace" if overwrite else "fail"
elif es_if_exists is None:
es_if_exists = "fail"
if es_if_exists not in ("fail", "replace"):
raise ValueError("'es_if_exists' must be either 'fail' or 'replace'")
elif es_if_exists == "fail":
if self.exists_model():
raise ValueError(
f"Trained machine learning model {model_id} already exists"
)
elif es_if_exists == "replace":
self.delete_model()
body: Dict[str, Any] = {
"input": {"field_names": feature_names},
}
# 'inference_config' is required in 7.8+ but isn't available in <=7.7
if es_version(self._client) >= (7, 8):
body["inference_config"] = {self._model_type: {}}
if es_compress_model_definition:
body["compressed_definition"] = serializer.serialize_and_compress_model()
else:
body["definition"] = serializer.serialize_model()
self._client.ml.put_trained_model(
model_id=self._model_id,
body=body,
)
def predict(self, X: Union[List[float], List[List[float]]]) -> np.ndarray:
"""
Make a prediction using a trained model stored in Elasticsearch.
Parameters for this method are not yet fully compatible with standard sklearn.predict.
Parameters
----------
X: list or list of lists of type float
Input feature vector - TODO support DataFrame and other formats
Returns
-------
y: np.ndarray of dtype float for regressors or int for classifiers
Examples
--------
>>> from sklearn import datasets
>>> from xgboost import XGBRegressor
>>> from eland.ml import ImportedMLModel
>>> # Train model
>>> training_data = datasets.make_classification(n_features=6, random_state=0)
>>> test_data = [[-1, -2, -3, -4, -5, -6], [10, 20, 30, 40, 50, 60]]
>>> regressor = XGBRegressor(objective='reg:squarederror')
>>> regressor = regressor.fit(training_data[0], training_data[1])
>>> # Get some test results
>>> regressor.predict(np.array(test_data))
array([0.06062475, 0.9990102 ], dtype=float32)
>>> # Serialise the model to Elasticsearch
>>> feature_names = ["f0", "f1", "f2", "f3", "f4", "f5"]
>>> model_id = "test_xgb_regressor"
>>> es_model = ImportedMLModel('localhost', model_id, regressor, feature_names, es_if_exists='replace')
>>> # Get some test results from Elasticsearch model
>>> es_model.predict(test_data)
array([0.0606248 , 0.99901026], dtype=float32)
>>> # Delete model from Elasticsearch
>>> es_model.delete_model()
"""
docs = []
if isinstance(X, list):
# Is it a list of floats?
if all(isinstance(i, float) for i in X):
features = cast(List[List[float]], [X])
else:
features = cast(List[List[float]], X)
for i in features:
doc = {"_source": dict(zip(self._feature_names, i))}
docs.append(doc)
else:
raise NotImplementedError(f"Prediction for type {type(X)}, not supported")
# field_mappings -> field_map in ES 7.7
field_map_name = (
"field_map" if es_version(self._client) >= (7, 7) else "field_mappings"
)
results = self._client.ingest.simulate(
body={
"pipeline": {
"processors": [
{
"inference": {
"model_id": self._model_id,
"inference_config": {self._model_type: {}},
field_map_name: {},
}
}
]
},
"docs": docs,
}
)
# Unpack results into an array. Errors can be present
# within the response without a non-2XX HTTP status code.
y = []
for res in results["docs"]:
if "error" in res:
raise RuntimeError(
f"Failed to run prediction for model ID {self._model_id!r}",
res["error"],
)
y.append(res["doc"]["_source"]["ml"]["inference"]["predicted_value"])
# Return results as np.ndarray of float32 or int (consistent with sklearn/xgboost)
if self._model_type == MLModel.TYPE_CLASSIFICATION:
dt = np.int
else:
dt = np.float32
return np.asarray(y, dtype=dt)

View File

@ -15,8 +15,33 @@
# specific language governing permissions and limitations
# under the License.
import elasticsearch
from eland.common import ensure_es_client
from typing import List, Union, cast, Optional, Dict, TYPE_CHECKING, Any, Tuple
import warnings
import numpy as np # type: ignore
import elasticsearch # type: ignore
from .common import TYPE_REGRESSION, TYPE_CLASSIFICATION
from .transformers import get_model_transformer
from eland.common import ensure_es_client, es_version
from eland.utils import deprecated_api
if TYPE_CHECKING:
from elasticsearch import Elasticsearch # noqa: F401
# Try importing each ML lib separately so mypy users don't have to
# have all of them installed to use type-checking.
try:
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor # type: ignore # noqa: F401
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor # type: ignore # noqa: F401
except ImportError:
pass
try:
from xgboost import XGBRegressor, XGBClassifier # type: ignore # noqa: F401
except ImportError:
pass
try:
from lightgbm import LGBMRegressor, LGBMClassifier # type: ignore # noqa: F401
except ImportError:
pass
class MLModel:
@ -24,16 +49,17 @@ class MLModel:
A machine learning model managed by Elasticsearch.
(See https://www.elastic.co/guide/en/elasticsearch/reference/master/put-inference.html)
These models can be created by Elastic ML, or transformed from supported python formats such as scikit-learn or
xgboost and imported into Elasticsearch.
These models can be created by Elastic ML, or transformed from supported Python formats
such as scikit-learn or xgboost and imported into Elasticsearch.
The methods for this class attempt to mirror standard python classes.
The methods for this class attempt to mirror standard Python classes.
"""
TYPE_CLASSIFICATION = "classification"
TYPE_REGRESSION = "regression"
def __init__(self, es_client, model_id: str):
def __init__(
self,
es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
model_id: str,
):
"""
Parameters
----------
@ -46,6 +72,331 @@ class MLModel:
"""
self._client = ensure_es_client(es_client)
self._model_id = model_id
self._trained_model_config_cache: Optional[Dict[str, Any]] = None
def predict(
    self, X: Union[np.ndarray, List[float], List[List[float]]]
) -> np.ndarray:
    """
    Make a prediction using a trained model stored in Elasticsearch.

    Parameters for this method are not yet fully compatible with standard sklearn.predict.

    Parameters
    ----------
    X: Input feature vector.
        Must be either a numpy ndarray or a list or list of lists
        of type float. TODO: support DataFrame and other formats

    Returns
    -------
    y: np.ndarray of dtype float for regressors or int for classifiers

    Raises
    ------
    NotImplementedError
        If ``X`` is neither a flat list of numbers nor a list of lists
        of numbers (after converting an ndarray to lists).
    RuntimeError
        If Elasticsearch reports an error for any simulated document.

    Examples
    --------
    >>> from sklearn import datasets
    >>> from xgboost import XGBRegressor
    >>> from eland.ml import MLModel

    >>> # Train model
    >>> training_data = datasets.make_classification(n_features=6, random_state=0)
    >>> test_data = [[-1, -2, -3, -4, -5, -6], [10, 20, 30, 40, 50, 60]]
    >>> regressor = XGBRegressor(objective='reg:squarederror')
    >>> regressor = regressor.fit(training_data[0], training_data[1])

    >>> # Get some test results
    >>> regressor.predict(np.array(test_data))
    array([0.06062475, 0.9990102 ], dtype=float32)

    >>> # Serialise the model to Elasticsearch
    >>> feature_names = ["f0", "f1", "f2", "f3", "f4", "f5"]
    >>> model_id = "test_xgb_regressor"
    >>> es_model = MLModel.import_model('localhost', model_id, regressor, feature_names, es_if_exists='replace')

    >>> # Get some test results from Elasticsearch model
    >>> es_model.predict(test_data)
    array([0.0606248 , 0.99901026], dtype=float32)

    >>> # Delete model from Elasticsearch
    >>> es_model.delete_model()
    """
    if isinstance(X, np.ndarray):

        def to_list_or_float(x: Any) -> Union[List[Any], float]:
            # Recursively turn arrays/lists into nested lists of Python floats
            # so the values can be serialized into the request body.
            if isinstance(x, np.ndarray):
                return [to_list_or_float(i) for i in x.tolist()]
            elif isinstance(x, list):
                return [to_list_or_float(i) for i in x]
            return float(x)

        X = to_list_or_float(X)

    def is_number_list(values: Any) -> bool:
        return isinstance(values, list) and all(
            isinstance(value, (float, int)) for value in values
        )

    # A flat list of numbers is treated as a single feature vector.
    if is_number_list(X):
        features = cast(List[List[float]], [X])
    # If not a list of lists of numbers then we error out.
    elif isinstance(X, list) and all(is_number_list(row) for row in X):
        features = cast(List[List[float]], X)
    else:
        raise NotImplementedError(
            f"Prediction for type {type(X)}, not supported: {X!r}"
        )

    # Look the model configuration up once instead of once per row.
    feature_names = self.feature_names
    model_type = self.model_type
    docs = [{"_source": dict(zip(feature_names, row))} for row in features]

    # field_mappings -> field_map in ES 7.7
    field_map_name = (
        "field_map" if es_version(self._client) >= (7, 7) else "field_mappings"
    )

    results = self._client.ingest.simulate(
        body={
            "pipeline": {
                "processors": [
                    {
                        "inference": {
                            "model_id": self._model_id,
                            "inference_config": {model_type: {}},
                            field_map_name: {},
                        }
                    }
                ]
            },
            "docs": docs,
        }
    )

    # Unpack results into an array. Errors can be present
    # within the response without a non-2XX HTTP status code.
    y = []
    for res in results["docs"]:
        if "error" in res:
            raise RuntimeError(
                f"Failed to run prediction for model ID {self._model_id!r}",
                res["error"],
            )
        y.append(res["doc"]["_source"]["ml"]["inference"][self.results_field])

    # Return results as np.ndarray of float32 or int (consistent with sklearn/xgboost).
    # The builtin 'int' is used because the 'np.int' alias is deprecated in NumPy.
    dt = int if model_type == TYPE_CLASSIFICATION else np.float32
    return np.asarray(y, dtype=dt)
@property
def model_type(self) -> str:
    """Return the model type recorded in the stored 'inference_config'.

    Raises ValueError when the configuration names neither
    'classification' nor 'regression'.
    """
    configured = self._trained_model_config["inference_config"]
    is_classification = "classification" in configured
    if not is_classification and "regression" not in configured:
        raise ValueError("Unable to determine 'model_type' for MLModel")
    # Classification takes precedence when both keys are present.
    return TYPE_CLASSIFICATION if is_classification else TYPE_REGRESSION
@property
def feature_names(self) -> List[str]:
    """Return a new list of the input field names from the stored model config."""
    input_section = self._trained_model_config["input"]
    return [*input_section["field_names"]]
@property
def results_field(self) -> str:
    """Return the 'results_field' configured for this model's type."""
    inference_config = self._trained_model_config["inference_config"]
    type_config = inference_config[self.model_type]
    return cast(str, type_config["results_field"])
@classmethod
def import_model(
    cls,
    es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
    model_id: str,
    model: Union[
        "DecisionTreeClassifier",
        "DecisionTreeRegressor",
        "RandomForestRegressor",
        "RandomForestClassifier",
        "XGBClassifier",
        "XGBRegressor",
        "LGBMRegressor",
        "LGBMClassifier",
    ],
    feature_names: List[str],
    classification_labels: Optional[List[str]] = None,
    classification_weights: Optional[List[float]] = None,
    es_if_exists: Optional[str] = None,
    overwrite: Optional[bool] = None,
    es_compress_model_definition: bool = True,
) -> "MLModel":
    """
    Transform and serialize a trained 3rd party model into Elasticsearch.
    This model can then be used for inference in the Elastic Stack.

    Parameters
    ----------
    es_client: Elasticsearch client argument(s)
        - elasticsearch-py parameters or
        - elasticsearch-py instance

    model_id: str
        The unique identifier of the trained inference model in Elasticsearch.

    model: An instance of a supported python model. We support the following model types:
        - sklearn.tree.DecisionTreeClassifier
        - sklearn.tree.DecisionTreeRegressor
        - sklearn.ensemble.RandomForestRegressor
        - sklearn.ensemble.RandomForestClassifier
        - lightgbm.LGBMRegressor
            - Categorical fields are expected to already be processed
            - Only the following objectives are supported
                - "regression"
                - "regression_l1"
                - "huber"
                - "fair"
                - "quantile"
                - "mape"
        - lightgbm.LGBMClassifier
            - Categorical fields are expected to already be processed
            - Only the following objectives are supported
                - "binary"
                - "multiclass"
                - "multiclassova"
        - xgboost.XGBClassifier
            - only the following objectives are supported:
                - "binary:logistic"
                - "multi:softmax"
                - "multi:softprob"
        - xgboost.XGBRegressor
            - only the following objectives are supported:
                - "reg:squarederror"
                - "reg:linear"
                - "reg:squaredlogerror"
                - "reg:logistic"
                - "reg:pseudohubererror"

    feature_names: List[str]
        Names of the features (required)

    classification_labels: List[str]
        Labels of the classification targets

    classification_weights: List[float]
        Weights of the classification targets

    es_if_exists: {'fail', 'replace'} default 'fail'
        How to behave if model already exists

        - fail: Raise a Value Error
        - replace: Overwrite existing model

    overwrite: **DEPRECATED** - bool
        Delete and overwrite existing model (if exists)

    es_compress_model_definition: bool
        If True will use 'compressed_definition' which uses gzipped
        JSON instead of raw JSON to reduce the amount of data sent
        over the wire in HTTP requests. Defaults to 'True'.

    Returns
    -------
    MLModel
        A handle to the newly stored model (an instance of the class the
        method was called on).

    Raises
    ------
    ValueError
        If both 'overwrite' and 'es_if_exists' are given, if 'es_if_exists'
        is not 'fail' or 'replace', or if the model already exists and
        'es_if_exists' is 'fail'.

    Examples
    --------
    >>> from sklearn import datasets
    >>> from sklearn.tree import DecisionTreeClassifier
    >>> from eland.ml import MLModel

    >>> # Train model
    >>> training_data = datasets.make_classification(n_features=5, random_state=0)
    >>> test_data = [[-50.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
    >>> classifier = DecisionTreeClassifier()
    >>> classifier = classifier.fit(training_data[0], training_data[1])

    >>> # Get some test results
    >>> classifier.predict(test_data)
    array([0, 1])

    >>> # Serialise the model to Elasticsearch
    >>> feature_names = ["f0", "f1", "f2", "f3", "f4"]
    >>> model_id = "test_decision_tree_classifier"
    >>> es_model = MLModel.import_model(
    ...   'localhost',
    ...   model_id=model_id,
    ...   model=classifier,
    ...   feature_names=feature_names,
    ...   es_if_exists='replace'
    ... )

    >>> # Get some test results from Elasticsearch model
    >>> es_model.predict(test_data)
    array([0, 1])

    >>> # Delete model from Elasticsearch
    >>> es_model.delete_model()
    """
    # Validate the arguments first so that invalid combinations fail
    # before the model is transformed and serialized.
    if overwrite is not None and es_if_exists is not None:
        raise ValueError(
            "Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
        )
    if overwrite is not None:
        # Kept inline (not in a helper) so stacklevel=2 points at the caller.
        warnings.warn(
            "'overwrite' parameter is deprecated, use 'es_if_exists' instead",
            DeprecationWarning,
            stacklevel=2,
        )
        es_if_exists = "replace" if overwrite else "fail"
    elif es_if_exists is None:
        es_if_exists = "fail"

    if es_if_exists not in ("fail", "replace"):
        raise ValueError("'es_if_exists' must be either 'fail' or 'replace'")

    es_client = ensure_es_client(es_client)
    transformer = get_model_transformer(
        model,
        feature_names=feature_names,
        classification_labels=classification_labels,
        classification_weights=classification_weights,
    )
    serializer = transformer.transform()
    model_type = transformer.model_type

    # Use 'cls' so that subclasses get an instance of their own type.
    ml_model = cls(
        es_client=es_client,
        model_id=model_id,
    )
    if es_if_exists == "fail":
        if ml_model.exists_model():
            raise ValueError(
                f"Trained machine learning model {model_id} already exists"
            )
    else:  # es_if_exists == "replace"
        ml_model.delete_model()

    body: Dict[str, Any] = {
        "input": {"field_names": feature_names},
    }
    # 'inference_config' is required in 7.8+ but isn't available in <=7.7
    if es_version(es_client) >= (7, 8):
        body["inference_config"] = {model_type: {}}

    if es_compress_model_definition:
        body["compressed_definition"] = serializer.serialize_and_compress_model()
    else:
        body["definition"] = serializer.serialize_model()

    ml_model._client.ml.put_trained_model(
        model_id=model_id,
        body=body,
    )
    return ml_model
def delete_model(self) -> None:
"""
@ -69,3 +420,18 @@ class MLModel:
except elasticsearch.NotFoundError:
return False
return True
@property
def _trained_model_config(self) -> Dict[str, Any]:
    """Return this model's 'trained_model_config', fetching it on first use.

    The response is cached on the instance, so Elasticsearch is only
    queried until a configuration has been loaded successfully.
    """
    cached = self._trained_model_config_cache
    if cached is not None:
        return cached

    response = self._client.ml.get_trained_models(model_id=self._model_id)
    # Exactly one matching configuration is required.
    if response["count"] > 1:
        raise ValueError(f"Model ID {self._model_id!r} wasn't unambiguous")
    if response["count"] == 0:
        raise ValueError(f"Model with Model ID {self._model_id!r} wasn't found")

    self._trained_model_config_cache = response["trained_model_configs"][0]
    return self._trained_model_config_cache
# Backwards-compatible alias: the old 'ImportedMLModel' name now forwards to
# MLModel.import_model() through the deprecated_api() wrapper from eland.utils
# (presumably emitting a deprecation notice on use; confirm in eland.utils).
# NOTE(review): the alias is bound to a method rather than a class here, so
# isinstance() checks against ImportedMLModel look like they will no longer
# work; confirm this is acceptable for existing users.
ImportedMLModel = deprecated_api("MLModel.import_model()")(MLModel.import_model)

View File

@ -37,7 +37,7 @@ def get_model_transformer(model: Any, **kwargs: Any) -> ModelTransformer:
return transformer(model, **kwargs)
raise NotImplementedError(
f"ML model of type {type(model)}, not currently implemented"
f"Importing ML models of type {type(model)}, not currently implemented"
)

View File

@ -18,7 +18,7 @@
from typing import Optional, List, Dict, Any, Type
from .base import ModelTransformer
from .._model_serializer import Ensemble, Tree, TreeNode
from ..ml_model import MLModel
from ..common import TYPE_CLASSIFICATION, TYPE_REGRESSION
from .._optional import import_optional_dependency
import_optional_dependency("lightgbm", on_version="warn")
@ -201,7 +201,7 @@ class LGBMRegressorTransformer(LGBMForestTransformer):
@property
def model_type(self) -> str:
return MLModel.TYPE_REGRESSION
return TYPE_REGRESSION
class LGBMClassifierTransformer(LGBMForestTransformer):
@ -244,7 +244,7 @@ class LGBMClassifierTransformer(LGBMForestTransformer):
@property
def model_type(self) -> str:
return MLModel.TYPE_CLASSIFICATION
return TYPE_CLASSIFICATION
def is_objective_supported(self) -> bool:
return self._objective in {

View File

@ -18,7 +18,7 @@
import numpy as np # type: ignore
from typing import Optional, Sequence, Union, Dict, Any, Type, Tuple
from .base import ModelTransformer
from ..ml_model import MLModel
from ..common import TYPE_CLASSIFICATION, TYPE_REGRESSION
from .._optional import import_optional_dependency
from .._model_serializer import Ensemble, Tree, TreeNode
@ -148,9 +148,9 @@ class SKLearnDecisionTreeTransformer(SKLearnTransformer):
@property
def model_type(self) -> str:
return (
MLModel.TYPE_REGRESSION
TYPE_REGRESSION
if isinstance(self._model, DecisionTreeRegressor)
else MLModel.TYPE_CLASSIFICATION
else TYPE_CLASSIFICATION
)
@ -223,7 +223,7 @@ class SKLearnForestRegressorTransformer(SKLearnForestTransformer):
@property
def model_type(self) -> str:
return MLModel.TYPE_REGRESSION
return TYPE_REGRESSION
class SKLearnForestClassifierTransformer(SKLearnForestTransformer):
@ -247,7 +247,7 @@ class SKLearnForestClassifierTransformer(SKLearnForestTransformer):
@property
def model_type(self) -> str:
return MLModel.TYPE_CLASSIFICATION
return TYPE_CLASSIFICATION
_MODEL_TRANSFORMERS: Dict[type, Type[ModelTransformer]] = {

View File

@ -20,7 +20,7 @@ from typing import Optional, List, Dict, Any, Type
from .base import ModelTransformer
import pandas as pd # type: ignore
from .._model_serializer import Ensemble, Tree, TreeNode
from ..ml_model import MLModel
from ..common import TYPE_CLASSIFICATION, TYPE_REGRESSION
from .._optional import import_optional_dependency
import_optional_dependency("xgboost", on_version="warn")
@ -207,7 +207,7 @@ class XGBoostRegressorTransformer(XGBoostForestTransformer):
@property
def model_type(self) -> str:
return MLModel.TYPE_REGRESSION
return TYPE_REGRESSION
class XGBoostClassifierTransformer(XGBoostForestTransformer):
@ -253,7 +253,7 @@ class XGBoostClassifierTransformer(XGBoostForestTransformer):
@property
def model_type(self) -> str:
return MLModel.TYPE_CLASSIFICATION
return TYPE_CLASSIFICATION
_MODEL_TRANSFORMERS: Dict[type, Type[ModelTransformer]] = {

View File

@ -18,7 +18,7 @@
import pytest
import numpy as np
from eland.ml import ImportedMLModel
from eland.ml import MLModel
from eland.tests import ES_TEST_CLIENT, ES_VERSION
@ -71,7 +71,7 @@ def skip_if_multiclass_classifition():
def random_rows(data, size):
return data[np.random.randint(data.shape[0], size=size), :].tolist()
return data[np.random.randint(data.shape[0], size=size), :]
def check_prediction_equality(es_model, py_model, test_data):
@ -84,7 +84,7 @@ def check_prediction_equality(es_model, py_model, test_data):
class TestImportedMLModel:
@requires_no_ml_extras
def test_import_ml_model_when_dependencies_are_not_available(self):
from eland.ml import MLModel, ImportedMLModel # noqa: F401
from eland.ml import MLModel # noqa: F401
@requires_sklearn
def test_unpack_and_raise_errors_in_ingest_simulate(self, mocker):
@ -98,7 +98,7 @@ class TestImportedMLModel:
model_id = "test_decision_tree_classifier"
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
es_model = ImportedMLModel(
es_model = MLModel.import_model(
ES_TEST_CLIENT,
model_id,
classifier,
@ -142,7 +142,7 @@ class TestImportedMLModel:
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_decision_tree_classifier"
es_model = ImportedMLModel(
es_model = MLModel.import_model(
ES_TEST_CLIENT,
model_id,
classifier,
@ -171,7 +171,7 @@ class TestImportedMLModel:
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_decision_tree_regressor"
es_model = ImportedMLModel(
es_model = MLModel.import_model(
ES_TEST_CLIENT,
model_id,
regressor,
@ -199,7 +199,7 @@ class TestImportedMLModel:
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_random_forest_classifier"
es_model = ImportedMLModel(
es_model = MLModel.import_model(
ES_TEST_CLIENT,
model_id,
classifier,
@ -227,7 +227,7 @@ class TestImportedMLModel:
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_random_forest_regressor"
es_model = ImportedMLModel(
es_model = MLModel.import_model(
ES_TEST_CLIENT,
model_id,
regressor,
@ -265,7 +265,7 @@ class TestImportedMLModel:
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_xgb_classifier"
es_model = ImportedMLModel(
es_model = MLModel.import_model(
ES_TEST_CLIENT,
model_id,
classifier,
@ -305,7 +305,7 @@ class TestImportedMLModel:
feature_names = ["feature0", "feature1", "feature2", "feature3", "feature4"]
model_id = "test_xgb_classifier"
es_model = ImportedMLModel(
es_model = MLModel.import_model(
ES_TEST_CLIENT, model_id, classifier, feature_names, es_if_exists="replace"
)
# Get some test results
@ -337,7 +337,7 @@ class TestImportedMLModel:
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_xgb_regressor"
es_model = ImportedMLModel(
es_model = MLModel.import_model(
ES_TEST_CLIENT,
model_id,
regressor,
@ -368,7 +368,7 @@ class TestImportedMLModel:
feature_names = ["f0"]
model_id = "test_xgb_regressor"
es_model = ImportedMLModel(
es_model = MLModel.import_model(
ES_TEST_CLIENT, model_id, regressor, feature_names, es_if_exists="replace"
)
@ -405,7 +405,7 @@ class TestImportedMLModel:
feature_names = ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"]
model_id = "test_lgbm_regressor"
es_model = ImportedMLModel(
es_model = MLModel.import_model(
ES_TEST_CLIENT,
model_id,
regressor,
@ -446,7 +446,7 @@ class TestImportedMLModel:
feature_names = ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"]
model_id = "test_lgbm_classifier"
es_model = ImportedMLModel(
es_model = MLModel.import_model(
ES_TEST_CLIENT,
model_id,
classifier,
@ -480,7 +480,7 @@ class TestImportedMLModel:
match = "Using 'overwrite' and 'es_if_exists' together is invalid, use only 'es_if_exists'"
with pytest.raises(ValueError, match=match):
ImportedMLModel(
MLModel.import_model(
ES_TEST_CLIENT,
model_id,
regressor,
@ -507,7 +507,7 @@ class TestImportedMLModel:
match = "'overwrite' parameter is deprecated, use 'es_if_exists' instead"
with pytest.warns(DeprecationWarning, match=match):
ImportedMLModel(
MLModel.import_model(
ES_TEST_CLIENT,
model_id,
regressor,
@ -536,7 +536,7 @@ class TestImportedMLModel:
)
with pytest.raises(ValueError, match=match_error):
with pytest.warns(DeprecationWarning, match=match_warning):
ImportedMLModel(
MLModel.import_model(
ES_TEST_CLIENT,
model_id,
regressor,
@ -560,7 +560,7 @@ class TestImportedMLModel:
# If both overwrite and es_if_exists is given.
match = f"Trained machine learning model {model_id} already exists"
with pytest.raises(ValueError, match=match):
ImportedMLModel(
MLModel.import_model(
ES_TEST_CLIENT,
model_id,
regressor,

View File

@ -34,7 +34,7 @@ SOURCE_FILES = (
# Whenever type-hints are completed on a file it should
# be added here so that this file will continue to be checked
# by mypy. Errors from other files are ignored.
TYPED_FILES = {
TYPED_FILES = (
"eland/actions.py",
"eland/arithmetics.py",
"eland/common.py",
@ -46,13 +46,13 @@ TYPED_FILES = {
"eland/utils.py",
"eland/ml/__init__.py",
"eland/ml/_model_serializer.py",
"eland/ml/imported_ml_model.py",
"eland/ml/ml_model.py",
"eland/ml/transformers/__init__.py",
"eland/ml/transformers/base.py",
"eland/ml/transformers/lightgbm.py",
"eland/ml/transformers/sklearn.py",
"eland/ml/transformers/xgboost.py",
}
)
@nox.session(reuse_venv=True)