[ML] Export ML model as sklearn Pipeline (#509)

Closes #503

Note: I also had to pin the Sphinx version to 5.3.0 since, starting with 6.0, Sphinx hits a TypeError bug that causes a CI failure.
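
For context, a minimal sketch of how the new export API is meant to be used. The host, index name, and model ID below are placeholders, and the model is assumed to have been trained with data frame analytics in the Elastic Stack:

```python
import eland as ed
from eland.ml import MLModel

# Hypothetical regression model trained in the Elastic Stack
model = MLModel(es_client="http://localhost:9200", model_id="flights-regression-model")

# New in this PR: convert the Elastic ML model into an sklearn Pipeline
pipeline = model.export_model()

# fit() is a no-op for the model itself but initializes the ColumnTransformer step
X = ed.eland_to_pandas(ed.DataFrame("http://localhost:9200", "flights").head(10))
pipeline.fit(X)

predictions = pipeline.predict(
    X, feature_names_in=pipeline["preprocessor"].get_feature_names_out()
)
```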
Valeriy Khakhutskyy 2023-02-01 16:17:06 +01:00 committed by GitHub
parent 2ea96322b3
commit 0576114a1d
18 changed files with 1117 additions and 22 deletions

.ci/test-matrix.yml Executable file → Normal file
View File

View File

@ -191,7 +191,7 @@ currently using a minimum version of PyCharm 2019.2.4.
``` bash
> import eland as ed
> ed_df = ed.DataFrame('localhost', 'flights')
> ed_df = ed.DataFrame('http://localhost:9200', 'flights')
```
* To run the automatic formatter and check for lint issues run

View File

@ -0,0 +1,16 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,217 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict
import numpy as np
from .._optional import import_optional_dependency
import_optional_dependency("sklearn", on_version="warn")
import sklearn
from sklearn.preprocessing import FunctionTransformer
class Tree:
"""Wrapper to create sklearn Tree objects from Elastic ML tree
description in JSON format.
"""
def __init__(
self,
json_tree: Dict[str, Any],
feature_names_map: Dict[str, int],
):
tree_leaf = -1
node_count = len(json_tree["tree_structure"])
children_left = np.ones((node_count,), dtype=int) * tree_leaf
children_right = np.ones((node_count,), dtype=int) * tree_leaf
feature = np.ones((node_count,), dtype=int) * -2
threshold = np.ones((node_count,), dtype=float) * -2
impurity = np.zeros((node_count,), dtype=float)
# value works only for regression and binary classification
value = np.zeros((node_count, 1, 1), dtype="<f8")
n_node_samples = np.zeros((node_count,), dtype=int)
# parse values from the JSON tree
feature_names = json_tree["feature_names"]
for json_node in json_tree["tree_structure"]:
node_id = json_node["node_index"]
if "number_samples" in json_node:
n_node_samples[node_id] = json_node["number_samples"]
else:
n_node_samples[node_id] = 0
if "leaf_value" not in json_node:
children_left[node_id] = json_node["left_child"]
children_right[node_id] = json_node["right_child"]
feature[node_id] = feature_names_map[
feature_names[json_node["split_feature"]]
]
threshold[node_id] = json_node["threshold"]
if "split_gain" in json_node:
impurity[node_id] = json_node["split_gain"]
else:
impurity[node_id] = -1
else:
value[node_id, 0, 0] = json_node["leaf_value"]
# iterate through tree to get max depth and expected values
weighted_n_node_samples = n_node_samples.copy()
self.max_depth = Tree._compute_expectations(
children_left=children_left,
children_right=children_right,
node_sample_weight=weighted_n_node_samples,
values=value,
node_index=0,
)
self.n_outputs = value.shape[-1]
# initialize the sklearn tree
self.tree = sklearn.tree._tree.Tree(
len(feature_names), np.array([1], dtype=int), 1
)
node_state = np.array(
[
(
children_left[i],
children_right[i],
feature[i],
threshold[i],
impurity[i],
n_node_samples[i],
weighted_n_node_samples[i],
)
for i in range(node_count)
],
dtype=[
("left_child", "<i8"),
("right_child", "<i8"),
("feature", "<i8"),
("threshold", "<f8"),
("impurity", "<f8"),
("n_node_samples", "<i8"),
("weighted_n_node_samples", "<f8"),
],
)
state = {
"max_depth": self.max_depth,
"node_count": node_count,
"nodes": node_state,
"values": value,
}
self.tree.__setstate__(state)
@staticmethod
def _compute_expectations(
children_left, children_right, node_sample_weight, values, node_index
) -> int:
if children_right[node_index] == -1:
return 0
left_index = children_left[node_index]
right_index = children_right[node_index]
depth_left = Tree._compute_expectations(
children_left, children_right, node_sample_weight, values, left_index
)
depth_right = Tree._compute_expectations(
children_left, children_right, node_sample_weight, values, right_index
)
left_weight = node_sample_weight[left_index]
right_weight = node_sample_weight[right_index]
v = (
(
left_weight * values[left_index, :]
+ right_weight * values[right_index, :]
)
/ (left_weight + right_weight)
if left_weight + right_weight > 0
else 0
)
values[node_index, :] = v
return max(depth_left, depth_right) + 1
class TargetMeanEncoder(FunctionTransformer):
"""FunctionTransformer implementation of the target mean encoder, which is
deserialized from the Elastic ML preprocessor description in JSON format.
"""
def __init__(self, preprocessor: Dict[str, Any]):
self.preprocessor = preprocessor
target_map = self.preprocessor["target_mean_encoding"]["target_map"]
feature_name_out = self.preprocessor["target_mean_encoding"]["feature_name"]
self.field_name_in = self.preprocessor["target_mean_encoding"]["field"]
fallback_value = self.preprocessor["target_mean_encoding"]["default_value"]
def func(column):
return np.array(
[
target_map[str(category)]
if category in target_map
else fallback_value
for category in column
]
).reshape(-1, 1)
def feature_names_out(ft, carr):
return [feature_name_out if c == self.field_name_in else c for c in carr]
super().__init__(func=func, feature_names_out=feature_names_out)
class FrequencyEncoder(FunctionTransformer):
"""FunctionTransformer implementation of the frequency encoder, which is
deserialized from the Elastic ML preprocessor description in JSON format.
"""
def __init__(self, preprocessor: Dict[str, Any]):
self.preprocessor = preprocessor
frequency_map = self.preprocessor["frequency_encoding"]["frequency_map"]
feature_name_out = self.preprocessor["frequency_encoding"]["feature_name"]
self.field_name_in = self.preprocessor["frequency_encoding"]["field"]
fallback_value = 0.0
def func(column):
return np.array(
[
frequency_map[str(category)]
if category in frequency_map
else fallback_value
for category in column
]
).reshape(-1, 1)
def feature_names_out(ft, carr):
return [feature_name_out if c == self.field_name_in else c for c in carr]
super().__init__(func=func, feature_names_out=feature_names_out)
class OneHotEncoder(sklearn.preprocessing.OneHotEncoder):
"""Wrapper for sklearn one-hot encoder, which is deserialized from the
Elastic ML preprocessor description in JSON format.
"""
def __init__(self, preprocessor: Dict[str, Any]):
self.preprocessor = preprocessor
self.field_name_in = self.preprocessor["one_hot_encoding"]["field"]
self.cats = [list(self.preprocessor["one_hot_encoding"]["hot_map"].keys())]
super().__init__(categories=self.cats, handle_unknown="ignore")
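
As an aside, a small runnable sketch of the preprocessor JSON shape the encoders above consume; the field, categories, and frequencies are invented for illustration:

```python
import numpy as np

from eland.ml.exporters._sklearn_deserializers import FrequencyEncoder

# Illustrative Elastic ML preprocessor definition; keys mirror those read above
preprocessor = {
    "frequency_encoding": {
        "field": "OriginWeather",
        "feature_name": "OriginWeather_frequency",
        "frequency_map": {"Sunny": 0.6, "Rain": 0.3, "Hail": 0.1},
    }
}

encoder = FrequencyEncoder(preprocessor)
# Categories missing from frequency_map fall back to 0.0
print(encoder.transform(np.array(["Sunny", "Hail", "Fog"])))
# [[0.6]
#  [0.1]
#  [0. ]]
```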

View File

@ -0,0 +1,46 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import eland
class ModelDefinitionKeyError(Exception):
"""
This exception is raised when a key is not found in the model definition.
Attributes:
missed_key (str): The key that was not found in the model definition.
available_keys (List[str]): The list of keys that are available in the model definition.
Examples:
model_definition = {"key1": "value1", "key2": "value2"}
try:
model_definition["key3"]
except KeyError as ex:
raise ModelDefinitionKeyError(ex) from ex
"""
def __init__(self, ex: KeyError):
self.missed_key = ex.args[0]
def __str__(self):
return (
f'Key "{self.missed_key}" is not available. '
+ "The model definition may have changed. "
+ "Make sure you are using an Elasticsearch version compatible "
+ f"with Eland {eland.__version__}."
)

View File

@ -0,0 +1,472 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from abc import ABC
from typing import Any, List, Literal, Mapping, Optional, Set, Tuple, Union
import numpy as np
from elasticsearch import Elasticsearch
from numpy.typing import ArrayLike
from .._optional import import_optional_dependency
import_optional_dependency("sklearn", on_version="warn")
from sklearn.dummy import DummyClassifier, DummyRegressor
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from sklearn.ensemble._gb_losses import (
BinomialDeviance,
HuberLossFunction,
LeastSquaresError,
)
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils.validation import check_array
from eland.common import ensure_es_client
from eland.ml.common import TYPE_CLASSIFICATION, TYPE_REGRESSION
from ._sklearn_deserializers import Tree
from .common import ModelDefinitionKeyError
class ESGradientBoostingModel(ABC):
"""
Abstract class for converting Elastic ML model into sklearn Pipeline.
"""
def __init__(
self,
es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
model_id: str,
) -> None:
"""
Parameters
----------
es_client : Elasticsearch client argument(s)
- elasticsearch-py parameters or
- elasticsearch-py instance
model_id : str
The unique identifier of the trained inference model in Elasticsearch.
Raises
------
RuntimeError
On failure to retrieve trained model information for the specified model ID.
ValueError
The model is expected to be trained in the Elastic Stack. Models initially imported
from xgboost, lgbm, or sklearn are not supported.
"""
self.es_client: Elasticsearch = ensure_es_client(es_client)
self.model_id = model_id
self._trained_model_result = self.es_client.ml.get_trained_models(
model_id=self.model_id,
decompress_definition=True,
include=["hyperparameters", "definition"],
)
if (
"trained_model_configs" not in self._trained_model_result
or len(self._trained_model_result["trained_model_configs"]) == 0
):
raise RuntimeError(
f"Failed to retrieve the trained model for model ID {self.model_id!r}"
)
if "metadata" not in self._trained_model_result["trained_model_configs"][0]:
raise ValueError(
"Error initializing sklearn classifier. Incorrect prior class probability. "
+ "Note: only export of models trained in the Elastic Stack is supported."
)
preprocessors = []
if "preprocessors" in self._definition:
preprocessors = self._definition["preprocessors"]
(
self.feature_names_in_,
self.input_field_names,
) = ESGradientBoostingModel._get_feature_names_in_(
preprocessors,
self._definition["trained_model"]["ensemble"]["feature_names"],
self._trained_model_result["trained_model_configs"][0]["input"][
"field_names"
],
)
feature_names_map = {name: i for i, name in enumerate(self.feature_names_in_)}
trained_models = self._definition["trained_model"]["ensemble"]["trained_models"]
self._trees = []
for trained_model in trained_models:
self._trees.append(Tree(trained_model["tree"], feature_names_map))
# the 0th tree is the constant estimator
self.n_estimators = len(trained_models) - 1
def _initialize_estimators(self, decision_tree_type) -> None:
self.estimators_ = np.ndarray(
(len(self._trees) - 1, 1), dtype=decision_tree_type
)
self.n_estimators_ = self.estimators_.shape[0]
for i in range(self.n_estimators_):
estimator = decision_tree_type()
estimator.tree_ = self._trees[i + 1].tree
estimator.n_features_in_ = self.n_features_in_
estimator.max_depth = self._max_depth
estimator.max_features_ = self.max_features_
self.estimators_[i, 0] = estimator
def _extract_common_parameters(self) -> None:
self.n_features_in_ = len(self.feature_names_in_)
self.max_features_ = self.n_features_in_
@property
def _max_depth(self) -> int:
return max(map(lambda x: x.max_depth, self._trees))
@property
def _n_outputs(self) -> int:
return self._trees[0].n_outputs
@property
def _definition(self) -> Mapping[Union[str, int], Any]:
return self._trained_model_result["trained_model_configs"][0]["definition"]
@staticmethod
def _get_feature_names_in_(
preprocessors, feature_names, field_names
) -> Tuple[List[str], Set[str]]:
input_field_names = set()
def add_input_field_name(preprocessor_type: str, feature_name: str) -> None:
if feature_name in feature_names:
input_field_names.add(preprocessor[preprocessor_type]["field"])
for preprocessor in preprocessors:
if "target_mean_encoding" in preprocessor:
add_input_field_name(
"target_mean_encoding",
preprocessor["target_mean_encoding"]["feature_name"],
)
elif "frequency_encoding" in preprocessor:
add_input_field_name(
"frequency_encoding",
preprocessor["frequency_encoding"]["feature_name"],
)
elif "one_hot_encoding" in preprocessor:
for feature_name in preprocessor["one_hot_encoding"][
"hot_map"
].values():
add_input_field_name("one_hot_encoding", feature_name)
for field_name in field_names:
if field_name in feature_names and field_name not in input_field_names:
input_field_names.add(field_name)
return feature_names, input_field_names
@property
def preprocessors(self) -> List[Any]:
"""
Returns the list of preprocessor JSON definitions.
Returns
-------
List[Any]
List of preprocessor definitions or [].
"""
if "preprocessors" in self._definition:
return self._definition["preprocessors"]
return []
def fit(self, X, y, sample_weight=None, monitor=None) -> None:
"""
Override of the sklearn fit() method. It does nothing since Elastic ML models are
trained in the Elastic Stack or imported.
"""
# Do nothing; the model is fitted using the Elasticsearch API
pass
class ESGradientBoostingClassifier(ESGradientBoostingModel, GradientBoostingClassifier):
"""
Elastic ML model wrapper compatible with sklearn GradientBoostingClassifier.
"""
def __init__(
self,
es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
model_id: str,
) -> None:
"""
Parameters
----------
es_client : Elasticsearch client argument(s)
- elasticsearch-py parameters or
- elasticsearch-py instance
model_id : str
The unique identifier of the trained inference model in Elasticsearch.
Raises
------
NotImplementedError
Multi-class classification is not supported at the moment.
ValueError
The classifier should be defined for at least 2 classes.
ModelDefinitionKeyError
If required data cannot be extracted from the model definition due to a schema change.
"""
try:
ESGradientBoostingModel.__init__(self, es_client, model_id)
self._extract_common_parameters()
GradientBoostingClassifier.__init__(
self,
learning_rate=1.0,
n_estimators=self.n_estimators,
max_depth=self._max_depth,
)
if "classification_labels" in self._definition["trained_model"]["ensemble"]:
self.classes_ = np.array(
self._definition["trained_model"]["ensemble"][
"classification_labels"
]
)
else:
self.classes_ = None
self.n_outputs = self._n_outputs
if self.classes_ is not None:
self.n_classes_ = len(self.classes_)
elif self.n_outputs <= 2:
self.n_classes_ = 2
else:
self.n_classes_ = self.n_outputs
if self.n_classes_ == 2:
self._loss = BinomialDeviance(self.n_classes_)
# self.n_outputs = 1
elif self.n_classes_ > 2:
raise NotImplementedError("Only binary classification is implemented.")
else:
raise ValueError(f"At least 2 classes required. got {self.n_classes_}.")
self.init_ = self._initialize_init_()
self._initialize_estimators(DecisionTreeClassifier)
except KeyError as ex:
raise ModelDefinitionKeyError(ex) from ex
@property
def analysis_type(self) -> Literal["classification"]:
return TYPE_CLASSIFICATION
def _initialize_init_(self) -> DummyClassifier:
estimator = DummyClassifier(strategy="prior")
estimator.n_classes_ = self.n_classes_
estimator.n_outputs_ = self.n_outputs
estimator.classes_ = np.arange(self.n_classes_)
estimator._strategy = estimator.strategy
if self.n_classes_ == 2:
log_odds = self._trees[0].tree.value.flatten()[0]
if np.isnan(log_odds):
raise ValueError(
"Error initializing sklearn classifier. Incorrect prior class probability. "
+ "Note: only export of models trained in the Elastic Stack is supported."
)
class_prior = 1 / (1 + np.exp(-log_odds))
estimator.class_prior_ = np.array([1 - class_prior, class_prior])
else:
raise NotImplementedError("Only binary classification is implemented.")
return estimator
def predict_proba(
self, X, feature_names_in: Optional[Union["ArrayLike", List[str]]] = None
) -> "ArrayLike":
"""Predict class probabilities for X.
Parameters
----------
X : array-like of shape (n_samples, n_features)
The input samples.
feature_names_in : {array of string, list of string} of length n_features.
Feature names of the corresponding columns in X. Important, since the column list
can be extended by ColumnTransformer through the pipeline. By default None.
Returns
-------
ArrayLike of shape (n_samples, n_classes)
The class probabilities of the input samples. The order of the
classes corresponds to that in the attribute :term:`classes_`.
"""
if feature_names_in is not None:
if X.shape[1] != len(feature_names_in):
raise ValueError(
f"Dimension mismatch: X with {X.shape[1]} columns has to be the same size as feature_names_in with {len(feature_names_in)}."
)
if isinstance(feature_names_in, np.ndarray):
feature_names_in = feature_names_in.tolist()
# select columns used by the model in the correct order
X = X[:, [feature_names_in.index(fn) for fn in self.feature_names_in_]]
X = check_array(X)
return GradientBoostingClassifier.predict_proba(self, X)
def predict(
self,
X: "ArrayLike",
feature_names_in: Optional[Union["ArrayLike", List[str]]] = None,
) -> "ArrayLike":
"""Predict class for X.
Parameters
----------
X : array-like of shape (n_samples, n_features)
The input samples.
feature_names_in : {array of string, list of string} of length n_features.
Feature names of the corresponding columns in X. Important, since the column list
can be extended by ColumnTransformer through the pipeline. By default None.
Returns
-------
ArrayLike of shape (n_samples,)
The predicted values.
"""
if feature_names_in is not None:
if X.shape[1] != len(feature_names_in):
raise ValueError(
f"Dimension mismatch: X with {X.shape[1]} columns has to be the same size as feature_names_in with {len(feature_names_in)}."
)
if isinstance(feature_names_in, np.ndarray):
feature_names_in = feature_names_in.tolist()
# select columns used by the model in the correct order
X = X[:, [feature_names_in.index(fn) for fn in self.feature_names_in_]]
X = check_array(X)
return GradientBoostingClassifier.predict(self, X)
class ESGradientBoostingRegressor(ESGradientBoostingModel, GradientBoostingRegressor):
"""
Elastic ML model wrapper compatible with sklearn GradientBoostingRegressor.
"""
def __init__(
self,
es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
model_id: str,
) -> None:
"""
Parameters
----------
es_client : Elasticsearch client argument(s)
- elasticsearch-py parameters or
- elasticsearch-py instance
model_id : str
The unique identifier of the trained inference model in Elasticsearch.
Raises
------
NotImplementedError
Only MSE, MSLE, and Huber loss functions are supported.
ModelDefinitionKeyError
If required data cannot be extracted from the model definition due to a schema change.
"""
try:
ESGradientBoostingModel.__init__(self, es_client, model_id)
self._extract_common_parameters()
GradientBoostingRegressor.__init__(
self,
learning_rate=1.0,
n_estimators=self.n_estimators,
max_depth=self._max_depth,
)
self.n_outputs = 1
loss_function = self._trained_model_result["trained_model_configs"][0][
"metadata"
]["analytics_config"]["analysis"][self.analysis_type]["loss_function"]
if loss_function == "mse" or loss_function == "msle":
self.criterion = "squared_error"
self._loss = LeastSquaresError()
elif loss_function == "huber":
loss_parameter = self._trained_model_result[
"trained_model_configs"
][0]["metadata"]["analytics_config"]["analysis"][self.analysis_type][
"loss_function_parameter"
]
self.criterion = "huber"
self._loss = HuberLossFunction(loss_parameter)
else:
raise NotImplementedError(
"Only MSE, MSLE and Huber loss functions are supported."
)
self.init_ = self._initialize_init_()
self._initialize_estimators(DecisionTreeRegressor)
except KeyError as ex:
raise ModelDefinitionKeyError(ex) from ex
@property
def analysis_type(self) -> Literal["regression"]:
return TYPE_REGRESSION
def _initialize_init_(self) -> DummyRegressor:
constant = self._trees[0].tree.value[0]
estimator = DummyRegressor(
strategy="constant",
constant=constant,
)
estimator.constant_ = np.array([constant])
estimator.n_outputs_ = 1
return estimator
def predict(
self,
X: "ArrayLike",
feature_names_in: Optional[Union["ArrayLike", List[str]]] = None,
) -> "ArrayLike":
"""Predict targets for X.
Parameters
----------
X : array-like of shape (n_samples, n_features)
The input samples.
feature_names_in : {array of string, list of string} of length n_features.
Feature names of the corresponding columns in X. Important, since the column list
can be extended by ColumnTransformer through the pipeline. By default None.
Returns
-------
ArrayLike of shape (n_samples,)
The predicted values.
"""
if feature_names_in is not None:
if X.shape[1] != len(feature_names_in):
raise ValueError(
f"Dimension mismatch: X with {X.shape[1]} columns has to be the same size as feature_names_in with {len(feature_names_in)}."
)
if isinstance(feature_names_in, np.ndarray):
feature_names_in = feature_names_in.tolist()
# select columns used by the model in the correct order
X = X[:, [feature_names_in.index(fn) for fn in self.feature_names_in_]]
X = check_array(X)
return GradientBoostingRegressor.predict(self, X)
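
A usage sketch for the wrappers above (not part of the diff): it assumes a model trained in the Elastic Stack on exactly the two named features; the host and model ID are placeholders:

```python
import numpy as np

from eland.ml.exporters.es_gb_models import ESGradientBoostingRegressor

reg = ESGradientBoostingRegressor(
    es_client="http://localhost:9200", model_id="flights-regression-model"
)

# Columns may arrive in a different order than the model expects;
# feature_names_in tells predict() how to map them onto the model's features.
X = np.array([[30.0, 500.0], [45.0, 800.0]])
predictions = reg.predict(X, feature_names_in=["FlightTimeMin", "DistanceMiles"])
```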

View File

@ -37,6 +37,7 @@ if TYPE_CHECKING:
RandomForestClassifier,
RandomForestRegressor,
)
from sklearn.pipeline import Pipeline # type: ignore # noqa: F401
from sklearn.tree import ( # type: ignore # noqa: F401
DecisionTreeClassifier,
DecisionTreeRegressor,
@ -424,6 +425,83 @@ class MLModel:
return False
return True
def export_model(self) -> "Pipeline":
"""Export Elastic ML model as sklearn Pipeline.
Returns
-------
sklearn.pipeline.Pipeline
A pipeline of the model's preprocessors (as a ColumnTransformer) followed by the Elastic ML model wrapper.
Raises
------
AssertionError
If the preprocessors JSON definition has an unexpected schema.
ValueError
The model is expected to be trained in the Elastic Stack. Models initially imported
from xgboost, lgbm, or sklearn are not supported.
ValueError
If unexpected categorical encoding is found in the list of preprocessors.
NotImplementedError
Only regression and binary classification models are supported currently.
"""
from sklearn.compose import ColumnTransformer # type: ignore # noqa: F401
from sklearn.pipeline import Pipeline
from .exporters._sklearn_deserializers import (
FrequencyEncoder,
OneHotEncoder,
TargetMeanEncoder,
)
from .exporters.es_gb_models import (
ESGradientBoostingClassifier,
ESGradientBoostingRegressor,
)
if self.model_type == TYPE_CLASSIFICATION:
model = ESGradientBoostingClassifier(
es_client=self._client, model_id=self._model_id
)
elif self.model_type == TYPE_REGRESSION:
model = ESGradientBoostingRegressor(
es_client=self._client, model_id=self._model_id
)
else:
raise NotImplementedError(
"Only regression and binary classification models are supported currently."
)
transformers = []
for p in model.preprocessors:
assert (
len(p) == 1
), f"Unexpected preprocessor data structure: {p}. One-key mapping expected."
encoding_type = list(p.keys())[0]
field = p[encoding_type]["field"]
if encoding_type == "frequency_encoding":
transform = FrequencyEncoder(p)
transformers.append((f"{field}_{encoding_type}", transform, field))
elif encoding_type == "target_mean_encoding":
transform = TargetMeanEncoder(p)
transformers.append((f"{field}_{encoding_type}", transform, field))
elif encoding_type == "one_hot_encoding":
transform = OneHotEncoder(p)
transformers.append((f"{field}_{encoding_type}", transform, [field]))
else:
raise ValueError(
f"Unexpected categorical encoding type {encoding_type} found. "
+ "Expected encodings: frequency_encoding, target_mean_encoding, one_hot_encoding."
)
preprocessor = ColumnTransformer(
transformers=transformers,
remainder="passthrough",
verbose_feature_names_out=False,
)
pipeline = Pipeline(steps=[("preprocessor", preprocessor), ("es_model", model)])
return pipeline
@property
def _trained_model_config(self) -> Dict[str, Any]:
"""Lazily loads an ML models 'trained_model_config' information"""

View File

@ -125,7 +125,7 @@ def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]
return None
potential_task_types: Set[str] = set()
for architecture in model_config.architectures:
for (substr, task_type) in ARCHITECTURE_TO_TASK_TYPE.items():
for substr, task_type in ARCHITECTURE_TO_TASK_TYPE.items():
if substr in architecture:
for t in task_type:
potential_task_types.add(t)
@ -384,7 +384,6 @@ class _DPREncoderWrapper(nn.Module): # type: ignore
@staticmethod
def from_pretrained(model_id: str) -> Optional[Any]:
config = AutoConfig.from_pretrained(model_id)
def is_compatible() -> bool:

View File

@ -210,7 +210,6 @@ class Operations:
def idx(
self, query_compiler: "QueryCompiler", axis: int, sort_order: str
) -> pd.Series:
if axis == 1:
# Fetch idx on Columns
raise NotImplementedError(
@ -279,7 +278,6 @@ class Operations:
numeric_only: bool = False,
dropna: bool = True,
) -> Union[pd.DataFrame, pd.Series]:
results = self._metric_aggs(
query_compiler,
pd_aggs=pd_aggs,
@ -530,7 +528,6 @@ class Operations:
# weights = [10066., 263., 386., 264., 273., 390., 324., 438., 261., 252., 142.]
# So sum last 2 buckets
for field in numeric_source_fields:
# in case of series let plotting.ed_hist_series throw an exception
if not response.get("aggregations"):
continue
@ -771,7 +768,6 @@ class Operations:
is_dataframe: bool = True,
numeric_only: Optional[bool] = True,
) -> Union[pd.DataFrame, pd.Series]:
percentiles = [
quantile_to_percentile(x)
for x in (
@ -801,7 +797,6 @@ class Operations:
return df if is_dataframe else df.transpose().iloc[0]
def unique(self, query_compiler: "QueryCompiler") -> pd.Series:
query_params, _ = self._resolve_tasks(query_compiler)
body = Query(query_params.query)
@ -1052,7 +1047,6 @@ class Operations:
buckets: Sequence[Dict[str, Any]] = composite_buckets["buckets"]
if after_key:
# yield the bucket which contains the result
yield buckets
@ -1227,7 +1221,6 @@ class Operations:
def to_pandas(
self, query_compiler: "QueryCompiler", show_progress: bool = False
) -> pd.DataFrame:
df_list: List[pd.DataFrame] = []
i = 0
for df in self.search_yield_pandas_dataframes(query_compiler=query_compiler):

View File

@ -170,7 +170,6 @@ class Query:
sort_order: str,
size: int = 1,
) -> None:
top_hits: Any = {}
if sort_order:
top_hits["sort"] = [{i: {"order": sort_order}} for i in source_columns]

View File

@ -246,7 +246,6 @@ class QueryCompiler:
i = 0
for i, hit in enumerate(results, 1):
if "_source" in hit:
row = hit["_source"]
else:

View File

@ -61,7 +61,7 @@ def format(session):
session.install("black", "isort", "flynt")
session.run("python", "utils/license-headers.py", "fix", *SOURCE_FILES)
session.run("flynt", *SOURCE_FILES)
session.run("black", "--target-version=py37", *SOURCE_FILES)
session.run("black", "--target-version=py38", *SOURCE_FILES)
session.run("isort", "--profile=black", *SOURCE_FILES)
lint(session)
@ -73,7 +73,7 @@ def lint(session):
session.install("black", "flake8", "mypy", "isort", "numpy")
session.install("--pre", "elasticsearch>=8.3,<9")
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
session.run("black", "--check", "--target-version=py37", *SOURCE_FILES)
session.run("black", "--check", "--target-version=py38", *SOURCE_FILES)
session.run("isort", "--check", "--profile=black", *SOURCE_FILES)
session.run("flake8", "--ignore=E501,W503,E402,E712,E203", *SOURCE_FILES)
@ -138,6 +138,7 @@ def test(session, pandas_version: str):
"scikit-learn",
"xgboost",
"lightgbm",
"shap",
)
session.run("pytest", "tests/ml/")

View File

@ -29,6 +29,7 @@ pytest>=5.2.1
pytest-mock
pytest-cov
nbval
shap==0.41.0
#
# Docs

View File

@ -35,7 +35,6 @@ class TestDataFrameCount(TestData):
df.count()
def test_count_flights(self):
pd_flights = self.pd_flights().filter(self.filter_data)
ed_flights = self.ed_flights().filter(self.filter_data)

View File

@ -419,7 +419,6 @@ class TestDataFrameMetrics(TestData):
assert calculated_values.shape == (2,)
def test_aggs_count(self):
pd_flights = self.pd_flights().filter(self.filter_data)
ed_flights = self.ed_flights().filter(self.filter_data)

View File

@ -102,7 +102,6 @@ class TestDataFrameToCSV(TestData):
ES_TEST_CLIENT.indices.delete(index=test_index)
def test_pd_to_csv_without_filepath(self):
ed_flights = self.ed_flights()
pd_flights = self.pd_flights()

View File

@ -147,7 +147,6 @@ class TestDataFrameUtils(TestData):
# assert_pandas_eland_frame_equal(pd_df, self.ed_flights())
def test_es_type_override_error(self):
df = self.pd_flights().filter(
["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]
)

View File

@ -15,11 +15,14 @@
# specific language governing permissions and limitations
# under the License.
from operator import itemgetter
import numpy as np
import pytest
import eland as ed
from eland.ml import MLModel
from tests import ES_TEST_CLIENT, ES_VERSION
from tests import ES_TEST_CLIENT, ES_VERSION, FLIGHTS_SMALL_INDEX_NAME
try:
from sklearn import datasets
@ -44,16 +47,26 @@ try:
except ImportError:
HAS_LIGHTGBM = False
try:
import shap
HAS_SHAP = True
except ImportError:
HAS_SHAP = False
requires_sklearn = pytest.mark.skipif(
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run."
)
requires_xgboost = pytest.mark.skipif(
not HAS_XGBOOST, reason="This test requires 'xgboost' package to run"
not HAS_XGBOOST, reason="This test requires 'xgboost' package to run."
)
requires_shap = pytest.mark.skipif(
not HAS_SHAP, reason="This test requires 'shap' package to run."
)
requires_no_ml_extras = pytest.mark.skipif(
HAS_SKLEARN or HAS_XGBOOST,
reason="This test requires 'scikit-learn' and 'xgboost' to not be installed",
reason="This test requires 'scikit-learn' and 'xgboost' to not be installed.",
)
requires_lightgbm = pytest.mark.skipif(
@ -80,6 +93,102 @@ def check_prediction_equality(es_model: MLModel, py_model, test_data):
np.testing.assert_almost_equal(test_results, es_results, decimal=2)
def yield_model_id(analysis, analyzed_fields):
import random
import string
import time
suffix = "".join(random.choices(string.ascii_lowercase, k=4))
job_id = "test-flights-regression-" + suffix
dest = job_id + "-dest"
response = ES_TEST_CLIENT.ml.put_data_frame_analytics(
id=job_id,
analysis=analysis,
dest={"index": dest},
source={"index": [FLIGHTS_SMALL_INDEX_NAME]},
analyzed_fields=analyzed_fields,
)
assert response.meta.status == 200
response = ES_TEST_CLIENT.ml.start_data_frame_analytics(id=job_id)
assert response.meta.status == 200
time.sleep(2)
response = ES_TEST_CLIENT.ml.get_trained_models(model_id=job_id + "*")
assert response.meta.status == 200
assert response.body["count"] == 1
model_id = response.body["trained_model_configs"][0]["model_id"]
yield model_id
ES_TEST_CLIENT.ml.delete_data_frame_analytics(id=job_id)
ES_TEST_CLIENT.indices.delete(index=dest)
ES_TEST_CLIENT.ml.delete_trained_model(model_id=model_id)
@pytest.fixture(params=[[0, 4], [0, 1], range(5)])
def regression_model_id(request):
analysis = {
"regression": {
"dependent_variable": "FlightDelayMin",
"max_trees": 3,
"num_top_feature_importance_values": 0,
"max_optimization_rounds_per_hyperparameter": 1,
"prediction_field_name": "FlightDelayMin_prediction",
"training_percent": 30,
"randomize_seed": 1000,
"loss_function": "mse",
"early_stopping_enabled": True,
}
}
all_includes = [
"FlightDelayMin",
"FlightDelayType",
"FlightTimeMin",
"DistanceMiles",
"OriginAirportID",
]
includes = [all_includes[i] for i in request.param]
analyzed_fields = {
"includes": includes,
"excludes": [],
}
yield from yield_model_id(analysis=analysis, analyzed_fields=analyzed_fields)
@pytest.fixture(params=[[0, 6], [5, 6], range(7)])
def classification_model_id(request):
analysis = {
"classification": {
"dependent_variable": "Cancelled",
"max_trees": 5,
"num_top_feature_importance_values": 0,
"max_optimization_rounds_per_hyperparameter": 1,
"prediction_field_name": "Cancelled_prediction",
"training_percent": 50,
"randomize_seed": 1000,
"num_top_classes": -1,
"class_assignment_objective": "maximize_accuracy",
"early_stopping_enabled": True,
}
}
all_includes = [
"OriginWeather",
"OriginAirportID",
"DestCityName",
"DestWeather",
"DestRegion",
"AvgTicketPrice",
"Cancelled",
]
includes = [all_includes[i] for i in request.param]
analyzed_fields = {
"includes": includes,
"excludes": [],
}
yield from yield_model_id(analysis=analysis, analyzed_fields=analyzed_fields)
class TestMLModel:
@requires_no_ml_extras
def test_import_ml_model_when_dependencies_are_not_available(self):
@ -494,3 +603,172 @@ class TestMLModel:
# Clean up
es_model.delete_model()
@requires_sklearn
@requires_shap
def test_export_regressor(self, regression_model_id):
ed_flights = ed.DataFrame(ES_TEST_CLIENT, FLIGHTS_SMALL_INDEX_NAME).head(10)
types = dict(ed_flights.dtypes)
X = ed_flights.to_pandas().astype(types)
model = MLModel(es_client=ES_TEST_CLIENT, model_id=regression_model_id)
pipeline = model.export_model()
pipeline.fit(X)
predictions_sklearn = pipeline.predict(
X, feature_names_in=pipeline["preprocessor"].get_feature_names_out()
)
response = ES_TEST_CLIENT.ml.infer_trained_model(
model_id=regression_model_id,
docs=X[pipeline["es_model"].input_field_names].to_dict("records"),
)
predictions_es = np.array(
list(
map(
itemgetter("FlightDelayMin_prediction"),
response.body["inference_results"],
)
)
)
np.testing.assert_array_almost_equal(predictions_sklearn, predictions_es)
import pandas as pd
X_transformed = pipeline["preprocessor"].transform(X=X)
X_transformed = pd.DataFrame(
X_transformed, columns=pipeline["preprocessor"].get_feature_names_out()
)
explainer = shap.TreeExplainer(pipeline["es_model"])
shap_values = explainer.shap_values(
X_transformed[pipeline["es_model"].feature_names_in_]
)
np.testing.assert_array_almost_equal(
predictions_sklearn, shap_values.sum(axis=1) + explainer.expected_value
)
@requires_sklearn
def test_export_classification(self, classification_model_id):
ed_flights = ed.DataFrame(ES_TEST_CLIENT, FLIGHTS_SMALL_INDEX_NAME).head(10)
X = ed.eland_to_pandas(ed_flights)
model = MLModel(es_client=ES_TEST_CLIENT, model_id=classification_model_id)
pipeline = model.export_model()
pipeline.fit(X)
predictions_sklearn = pipeline.predict(
X, feature_names_in=pipeline["preprocessor"].get_feature_names_out()
)
prediction_proba_sklearn = pipeline.predict_proba(
X, feature_names_in=pipeline["preprocessor"].get_feature_names_out()
).max(axis=1)
response = ES_TEST_CLIENT.ml.infer_trained_model(
model_id=classification_model_id,
docs=X[pipeline["es_model"].input_field_names].to_dict("records"),
)
predictions_es = np.array(
list(
map(
lambda x: str(int(x["Cancelled_prediction"])),
response.body["inference_results"],
)
)
)
prediction_proba_es = np.array(
list(
map(
itemgetter("prediction_probability"),
response.body["inference_results"],
)
)
)
np.testing.assert_array_almost_equal(
prediction_proba_sklearn, prediction_proba_es
)
np.testing.assert_array_equal(predictions_sklearn, predictions_es)
import pandas as pd
X_transformed = pipeline["preprocessor"].transform(X=X)
X_transformed = pd.DataFrame(
X_transformed, columns=pipeline["preprocessor"].get_feature_names_out()
)
explainer = shap.TreeExplainer(pipeline["es_model"])
shap_values = explainer.shap_values(
X_transformed[pipeline["es_model"].feature_names_in_]
)
log_odds = shap_values.sum(axis=1) + explainer.expected_value
prediction_proba_shap = 1 / (1 + np.exp(-log_odds))
# use probability of the predicted class
prediction_proba_shap[prediction_proba_shap < 0.5] = (
1 - prediction_proba_shap[prediction_proba_shap < 0.5]
)
np.testing.assert_array_almost_equal(
prediction_proba_sklearn, prediction_proba_shap
)
@requires_xgboost
@requires_sklearn
@pytest.mark.parametrize("objective", ["binary:logistic", "reg:squarederror"])
def test_xgb_import_export(self, objective):
booster = "gbtree"
if objective.startswith("binary:"):
training_data = datasets.make_classification(n_features=5)
xgb_model = XGBClassifier(
booster=booster, objective=objective, use_label_encoder=False
)
else:
training_data = datasets.make_regression(n_features=5)
xgb_model = XGBRegressor(
booster=booster, objective=objective, use_label_encoder=False
)
# Train model
xgb_model.fit(training_data[0], training_data[1])
# Serialise the models to Elasticsearch
feature_names = ["feature0", "feature1", "feature2", "feature3", "feature4"]
model_id = "test_xgb_model"
es_model = MLModel.import_model(
ES_TEST_CLIENT, model_id, xgb_model, feature_names, es_if_exists="replace"
)
# Export is supposed to fail
with pytest.raises(ValueError) as ex:
es_model.export_model()
assert ex.match("Error initializing sklearn classifier.")
# Clean up
es_model.delete_model()
@requires_lightgbm
@pytest.mark.parametrize("objective", ["regression", "binary"])
def test_lgbm_import_export(self, objective):
booster = "gbdt"
if objective == "binary":
training_data = datasets.make_classification(n_features=5)
lgbm_model = LGBMClassifier(boosting_type=booster, objective=objective)
else:
training_data = datasets.make_regression(n_features=5)
lgbm_model = LGBMRegressor(boosting_type=booster, objective=objective)
# Train model
lgbm_model.fit(training_data[0], training_data[1])
# Serialise the models to Elasticsearch
feature_names = ["feature0", "feature1", "feature2", "feature3", "feature4"]
model_id = "test_lgbm_model"
es_model = MLModel.import_model(
ES_TEST_CLIENT, model_id, lgbm_model, feature_names, es_if_exists="replace"
)
# Export is supposed to fail
with pytest.raises(ValueError) as ex:
es_model.export_model()
assert ex.match("Error initializing sklearn classifier.")
# Clean up
es_model.delete_model()