diff --git a/.ci/test-matrix.yml b/.ci/test-matrix.yml
old mode 100755
new mode 100644
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index fba4f70..f61d8d0 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -191,7 +191,7 @@ currently using a minimum version of PyCharm 2019.2.4.
   ``` bash
   > import eland as ed
-  > ed_df = ed.DataFrame('localhost', 'flights')
+  > ed_df = ed.DataFrame('http://localhost:9200', 'flights')
   ```
 
 * To run the automatic formatter and check for lint issues run
diff --git a/eland/ml/exporters/__init__.py b/eland/ml/exporters/__init__.py
new file mode 100644
index 0000000..2a87d18
--- /dev/null
+++ b/eland/ml/exporters/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/eland/ml/exporters/_sklearn_deserializers.py b/eland/ml/exporters/_sklearn_deserializers.py
new file mode 100644
index 0000000..086038b
--- /dev/null
+++ b/eland/ml/exporters/_sklearn_deserializers.py
@@ -0,0 +1,217 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Any, Dict
+
+import numpy as np
+
+from .._optional import import_optional_dependency
+
+import_optional_dependency("sklearn", on_version="warn")
+
+import sklearn
+from sklearn.preprocessing import FunctionTransformer
+
+
+class Tree:
+    """Wrapper to create sklearn Tree objects from the Elastic ML tree
+    description in JSON format.
+    """
+
+    def __init__(
+        self,
+        json_tree: Dict[str, Any],
+        feature_names_map: Dict[str, int],
+    ):
+        tree_leaf = -1
+
+        node_count = len(json_tree["tree_structure"])
+        children_left = np.ones((node_count,), dtype=int) * tree_leaf
+        children_right = np.ones((node_count,), dtype=int) * tree_leaf
+        feature = np.ones((node_count,), dtype=int) * -2
+        threshold = np.ones((node_count,), dtype=float) * -2
+        impurity = np.zeros((node_count,), dtype=float)
+        # value works only for regression and binary classification
+        value = np.zeros((node_count, 1, 1), dtype="<f8")
+        n_node_samples = np.zeros((node_count,), dtype=int)
+
+        # parse the Elastic ML tree nodes into the flat arrays
+        feature_names = json_tree["feature_names"]
+        for json_node in json_tree["tree_structure"]:
+            node_id = json_node["node_index"]
+            if "number_samples" in json_node:
+                n_node_samples[node_id] = json_node["number_samples"]
+            if "leaf_value" not in json_node:
+                children_left[node_id] = json_node["left_child"]
+                children_right[node_id] = json_node["right_child"]
+                feature[node_id] = feature_names_map[
+                    feature_names[json_node["split_feature"]]
+                ]
+                threshold[node_id] = json_node["threshold"]
+                if "split_gain" in json_node:
+                    impurity[node_id] = json_node["split_gain"]
+            else:
+                value[node_id, 0, 0] = json_node["leaf_value"]
+
+        # walk the tree to compute the max depth and the expected node values
+        self.max_depth = Tree._compute_expectations(
+            children_left, children_right, n_node_samples, value, 0
+        )
+        self.n_outputs = value.shape[1]
+
+        # populate an sklearn Tree object from the parsed arrays
+        self.tree = sklearn.tree._tree.Tree(
+            len(feature_names_map), np.array([1], dtype=np.intp), 1
+        )
+        node_state = np.array(
+            [
+                (
+                    children_left[i],
+                    children_right[i],
+                    feature[i],
+                    threshold[i],
+                    impurity[i],
+                    n_node_samples[i],
+                    float(n_node_samples[i]),
+                )
+                for i in range(node_count)
+            ],
+            dtype=[
+                ("left_child", "<i8"),
+                ("right_child", "<i8"),
+                ("feature", "<i8"),
+                ("threshold", "<f8"),
+                ("impurity", "<f8"),
+                ("n_node_samples", "<i8"),
+                ("weighted_n_node_samples", "<f8"),
+            ],
+        )
+        self.tree.__setstate__(
+            {
+                "max_depth": self.max_depth,
+                "node_count": node_count,
+                "nodes": node_state,
+                "values": value,
+            }
+        )
+
+    @staticmethod
+    def _compute_expectations(
+        children_left, children_right, node_sample_weight, values, node_index
+    ) -> int:
+        if children_right[node_index] == -1:
+            return 0
+
+        left_index = children_left[node_index]
+        right_index = children_right[node_index]
+        depth_left = Tree._compute_expectations(
+            children_left, children_right, node_sample_weight, values, left_index
+        )
+        depth_right = Tree._compute_expectations(
+            children_left, children_right, node_sample_weight, values, right_index
+        )
+        left_weight = node_sample_weight[left_index]
+        right_weight = node_sample_weight[right_index]
+
+        v = (
+            (
+                left_weight * values[left_index, :]
+                + right_weight * values[right_index, :]
+            )
+            / (left_weight + right_weight)
+            if left_weight + right_weight > 0
+            else 0
+        )
+        values[node_index, :] = v
+        return max(depth_left, depth_right) + 1
+
+
+class TargetMeanEncoder(FunctionTransformer):
+    """FunctionTransformer implementation of the target mean encoder, which is
+    deserialized from the Elastic ML preprocessor description in JSON format.
+    """
+
+    def __init__(self, preprocessor: Dict[str, Any]):
+        self.preprocessor = preprocessor
+        target_map = self.preprocessor["target_mean_encoding"]["target_map"]
+        feature_name_out = self.preprocessor["target_mean_encoding"]["feature_name"]
+        self.field_name_in = self.preprocessor["target_mean_encoding"]["field"]
+        fallback_value = self.preprocessor["target_mean_encoding"]["default_value"]
+
+        def func(column):
+            return np.array(
+                [
+                    target_map[str(category)]
+                    if category in target_map
+                    else fallback_value
+                    for category in column
+                ]
+            ).reshape(-1, 1)
+
+        def feature_names_out(ft, carr):
+            return [feature_name_out if c == self.field_name_in else c for c in carr]
+
+        super().__init__(func=func, feature_names_out=feature_names_out)
+
+
+class FrequencyEncoder(FunctionTransformer):
+    """FunctionTransformer implementation of the frequency encoder, which is
+    deserialized from the Elastic ML preprocessor description in JSON format.
+    """
+
+    def __init__(self, preprocessor: Dict[str, Any]):
+        self.preprocessor = preprocessor
+        frequency_map = self.preprocessor["frequency_encoding"]["frequency_map"]
+        feature_name_out = self.preprocessor["frequency_encoding"]["feature_name"]
+        self.field_name_in = self.preprocessor["frequency_encoding"]["field"]
+        fallback_value = 0.0
+
+        def func(column):
+            return np.array(
+                [
+                    frequency_map[str(category)]
+                    if category in frequency_map
+                    else fallback_value
+                    for category in column
+                ]
+            ).reshape(-1, 1)
+
+        def feature_names_out(ft, carr):
+            return [feature_name_out if c == self.field_name_in else c for c in carr]
+
+        super().__init__(func=func, feature_names_out=feature_names_out)
+
+
+class OneHotEncoder(sklearn.preprocessing.OneHotEncoder):
+    """Wrapper for the sklearn one-hot encoder, which is deserialized from the
+    Elastic ML preprocessor description in JSON format.
+    """
+
+    def __init__(self, preprocessor: Dict[str, Any]):
+        self.preprocessor = preprocessor
+        self.field_name_in = self.preprocessor["one_hot_encoding"]["field"]
+        self.cats = [list(self.preprocessor["one_hot_encoding"]["hot_map"].keys())]
+        super().__init__(categories=self.cats, handle_unknown="ignore")
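To make the deserializers concrete, here is a minimal sketch (not part of the patch) of how one of them behaves once constructed from an Elastic ML preprocessor definition; the field and category names are illustrative:

```python
import numpy as np

from eland.ml.exporters._sklearn_deserializers import FrequencyEncoder

# The JSON shape mirrors the Elastic ML "frequency_encoding" preprocessor above.
preprocessor = {
    "frequency_encoding": {
        "field": "OriginWeather",
        "feature_name": "OriginWeather_frequency",
        "frequency_map": {"Sunny": 0.6, "Rain": 0.4},
    }
}
encoder = FrequencyEncoder(preprocessor)

# Categories missing from the frequency map fall back to 0.0;
# the output is a single feature column.
print(encoder.transform(np.array(["Sunny", "Rain", "Cloudy"])))
# [[0.6]
#  [0.4]
#  [0. ]]
```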
diff --git a/eland/ml/exporters/common.py b/eland/ml/exporters/common.py
new file mode 100644
index 0000000..02c72ec
--- /dev/null
+++ b/eland/ml/exporters/common.py
@@ -0,0 +1,46 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import eland
+
+
+class ModelDefinitionKeyError(Exception):
+    """
+    This exception is raised when a key is not found in the model definition.
+
+    Attributes:
+        missed_key (str): The key that was not found in the model definition.
+
+    Examples:
+        model_definition = {"key1": "value1", "key2": "value2"}
+        try:
+            model_definition["key3"]
+        except KeyError as ex:
+            raise ModelDefinitionKeyError(ex) from ex
+    """
+
+    def __init__(self, ex: KeyError):
+        self.missed_key = ex.args[0]
+
+    def __str__(self):
+        return (
+            f'Key "{self.missed_key}" is not available. '
+            + "The model definition may have changed. "
+            + "Make sure you are using an Elasticsearch version compatible "
+            + f"with Eland {eland.__version__}."
+        )
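A short sketch (illustrative values) of what this error path looks like to a caller:

```python
from eland.ml.exporters.common import ModelDefinitionKeyError

definition = {"preprocessors": [], "trained_model": {}}
try:
    definition["trained_model"]["ensemble"]  # schema drift: expected key is missing
except KeyError as ex:
    err = ModelDefinitionKeyError(ex)

print(err)
# Key "ensemble" is not available. The model definition may have changed.
# Make sure you are using an Elasticsearch version compatible with Eland <version>.
```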
diff --git a/eland/ml/exporters/es_gb_models.py b/eland/ml/exporters/es_gb_models.py
new file mode 100644
index 0000000..5802b05
--- /dev/null
+++ b/eland/ml/exporters/es_gb_models.py
@@ -0,0 +1,472 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from abc import ABC
+from typing import Any, List, Literal, Mapping, Optional, Set, Tuple, Union
+
+import numpy as np
+from elasticsearch import Elasticsearch
+from numpy.typing import ArrayLike
+
+from .._optional import import_optional_dependency
+
+import_optional_dependency("sklearn", on_version="warn")
+
+from sklearn.dummy import DummyClassifier, DummyRegressor
+from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
+from sklearn.ensemble._gb_losses import (
+    BinomialDeviance,
+    HuberLossFunction,
+    LeastSquaresError,
+)
+from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
+from sklearn.utils.validation import check_array
+
+from eland.common import ensure_es_client
+from eland.ml.common import TYPE_CLASSIFICATION, TYPE_REGRESSION
+
+from ._sklearn_deserializers import Tree
+from .common import ModelDefinitionKeyError
+
+
+class ESGradientBoostingModel(ABC):
+    """
+    Abstract base class for converting an Elastic ML model into an sklearn Pipeline.
+    """
+
+    def __init__(
+        self,
+        es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
+        model_id: str,
+    ) -> None:
+        """
+        Parameters
+        ----------
+        es_client : Elasticsearch client argument(s)
+            - elasticsearch-py parameters or
+            - elasticsearch-py instance
+        model_id : str
+            The unique identifier of the trained inference model in Elasticsearch.
+
+        Raises
+        ------
+        RuntimeError
+            On failure to retrieve trained model information for the specified model ID.
+        ValueError
+            The model is expected to be trained in the Elastic Stack. Models initially imported
+            from xgboost, lgbm, or sklearn are not supported.
+        """
+        self.es_client: Elasticsearch = ensure_es_client(es_client)
+        self.model_id = model_id
+
+        self._trained_model_result = self.es_client.ml.get_trained_models(
+            model_id=self.model_id,
+            decompress_definition=True,
+            include=["hyperparameters", "definition"],
+        )
+
+        if (
+            "trained_model_configs" not in self._trained_model_result
+            or len(self._trained_model_result["trained_model_configs"]) == 0
+        ):
+            raise RuntimeError(
+                f"Failed to retrieve the trained model for model ID {self.model_id!r}"
+            )
+
+        if "metadata" not in self._trained_model_result["trained_model_configs"][0]:
+            raise ValueError(
+                "Error initializing sklearn classifier. Incorrect prior class probability. "
+                + "Note: only export of models trained in the Elastic Stack is supported."
+            )
+
+        preprocessors = []
+        if "preprocessors" in self._definition:
+            preprocessors = self._definition["preprocessors"]
+        (
+            self.feature_names_in_,
+            self.input_field_names,
+        ) = ESGradientBoostingModel._get_feature_names_in_(
+            preprocessors,
+            self._definition["trained_model"]["ensemble"]["feature_names"],
+            self._trained_model_result["trained_model_configs"][0]["input"][
+                "field_names"
+            ],
+        )
+
+        feature_names_map = {name: i for i, name in enumerate(self.feature_names_in_)}
+
+        trained_models = self._definition["trained_model"]["ensemble"]["trained_models"]
+        self._trees = []
+        for trained_model in trained_models:
+            self._trees.append(Tree(trained_model["tree"], feature_names_map))
+
+        # the 0th tree is the constant estimator
+        self.n_estimators = len(trained_models) - 1
+
+    def _initialize_estimators(self, decision_tree_type) -> None:
+        self.estimators_ = np.ndarray(
+            (len(self._trees) - 1, 1), dtype=decision_tree_type
+        )
+        self.n_estimators_ = self.estimators_.shape[0]
+
+        for i in range(self.n_estimators_):
+            estimator = decision_tree_type()
+            estimator.tree_ = self._trees[i + 1].tree
+            estimator.n_features_in_ = self.n_features_in_
+            estimator.max_depth = self._max_depth
+            estimator.max_features_ = self.max_features_
+            self.estimators_[i, 0] = estimator
+
+    def _extract_common_parameters(self) -> None:
+        self.n_features_in_ = len(self.feature_names_in_)
+        self.max_features_ = self.n_features_in_
+
+    @property
+    def _max_depth(self) -> int:
+        return max(map(lambda x: x.max_depth, self._trees))
+
+    @property
+    def _n_outputs(self) -> int:
+        return self._trees[0].n_outputs
+
+    @property
+    def _definition(self) -> Mapping[Union[str, int], Any]:
+        return self._trained_model_result["trained_model_configs"][0]["definition"]
+
+    @staticmethod
+    def _get_feature_names_in_(
+        preprocessors, feature_names, field_names
+    ) -> Tuple[List[str], Set[str]]:
+        input_field_names = set()
+
+        def add_input_field_name(preprocessor_type: str, feature_name: str) -> None:
+            if feature_name in feature_names:
+                input_field_names.add(preprocessor[preprocessor_type]["field"])
+
+        for preprocessor in preprocessors:
+            if "target_mean_encoding" in preprocessor:
+                add_input_field_name(
+                    "target_mean_encoding",
+                    preprocessor["target_mean_encoding"]["feature_name"],
+                )
+            elif "frequency_encoding" in preprocessor:
+                add_input_field_name(
+                    "frequency_encoding",
+                    preprocessor["frequency_encoding"]["feature_name"],
+                )
+            elif "one_hot_encoding" in preprocessor:
+                for feature_name in preprocessor["one_hot_encoding"][
+                    "hot_map"
+                ].values():
+                    add_input_field_name("one_hot_encoding", feature_name)
+
+        for field_name in field_names:
+            if field_name in feature_names and field_name not in input_field_names:
+                input_field_names.add(field_name)
+
+        return feature_names, input_field_names
+
+    @property
+    def preprocessors(self) -> List[Any]:
+        """
+        Returns the list of preprocessor JSON definitions.
+
+        Returns
+        -------
+        List[Any]
+            List of preprocessor definitions or [].
+        """
+        if "preprocessors" in self._definition:
+            return self._definition["preprocessors"]
+        return []
+
+    def fit(self, X, y, sample_weight=None, monitor=None) -> None:
+        """
+        Override of the sklearn fit() method. It is a no-op, since Elastic ML models
+        are either trained in the Elastic Stack or imported.
+        """
+        # Do nothing; the model is fitted using the Elasticsearch API.
+        pass
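For context, a sketch of the raw API call this constructor wraps (the cluster URL and model ID are illustrative); the decompressed definition is the JSON structure the exporter walks:

```python
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")
resp = es.ml.get_trained_models(
    model_id="test-flights-regression-abcd",
    decompress_definition=True,
    include=["hyperparameters", "definition"],
)

config = resp["trained_model_configs"][0]
ensemble = config["definition"]["trained_model"]["ensemble"]
print(ensemble["feature_names"])        # features the trees split on
print(len(ensemble["trained_models"]))  # tree count, incl. the constant 0th tree
```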
+
+
+class ESGradientBoostingClassifier(ESGradientBoostingModel, GradientBoostingClassifier):
+    """
+    Elastic ML model wrapper compatible with sklearn GradientBoostingClassifier.
+    """
+
+    def __init__(
+        self,
+        es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
+        model_id: str,
+    ) -> None:
+        """
+        Parameters
+        ----------
+        es_client : Elasticsearch client argument(s)
+            - elasticsearch-py parameters or
+            - elasticsearch-py instance
+        model_id : str
+            The unique identifier of the trained inference model in Elasticsearch.
+
+        Raises
+        ------
+        NotImplementedError
+            Multi-class classification is not supported at the moment.
+        ValueError
+            The classifier should be defined for at least 2 classes.
+        ModelDefinitionKeyError
+            If required data cannot be extracted from the model definition due to a schema change.
+        """
+
+        try:
+            ESGradientBoostingModel.__init__(self, es_client, model_id)
+            self._extract_common_parameters()
+            GradientBoostingClassifier.__init__(
+                self,
+                learning_rate=1.0,
+                n_estimators=self.n_estimators,
+                max_depth=self._max_depth,
+            )
+
+            if "classification_labels" in self._definition["trained_model"]["ensemble"]:
+                self.classes_ = np.array(
+                    self._definition["trained_model"]["ensemble"][
+                        "classification_labels"
+                    ]
+                )
+            else:
+                self.classes_ = None
+
+            self.n_outputs = self._n_outputs
+            if self.classes_ is not None:
+                self.n_classes_ = len(self.classes_)
+            elif self.n_outputs <= 2:
+                self.n_classes_ = 2
+            else:
+                self.n_classes_ = self.n_outputs
+
+            if self.n_classes_ == 2:
+                self._loss = BinomialDeviance(self.n_classes_)
+            elif self.n_classes_ > 2:
+                raise NotImplementedError("Only binary classification is implemented.")
+            else:
+                raise ValueError(f"At least 2 classes required, got {self.n_classes_}.")
+
+            self.init_ = self._initialize_init_()
+            self._initialize_estimators(DecisionTreeClassifier)
+        except KeyError as ex:
+            raise ModelDefinitionKeyError(ex) from ex
+
+    @property
+    def analysis_type(self) -> Literal["classification"]:
+        return TYPE_CLASSIFICATION
+
+    def _initialize_init_(self) -> DummyClassifier:
+        estimator = DummyClassifier(strategy="prior")
+
+        estimator.n_classes_ = self.n_classes_
+        estimator.n_outputs_ = self.n_outputs
+        estimator.classes_ = np.arange(self.n_classes_)
+        estimator._strategy = estimator.strategy
+
+        if self.n_classes_ == 2:
+            log_odds = self._trees[0].tree.value.flatten()[0]
+            if np.isnan(log_odds):
+                raise ValueError(
+                    "Error initializing sklearn classifier. Incorrect prior class probability. "
+                    + "Note: only export of models trained in the Elastic Stack is supported."
+                )
+            class_prior = 1 / (1 + np.exp(-log_odds))
+            estimator.class_prior_ = np.array([1 - class_prior, class_prior])
+        else:
+            raise NotImplementedError("Only binary classification is implemented.")
+
+        return estimator
+
+    def predict_proba(
+        self, X, feature_names_in: Optional[Union["ArrayLike", List[str]]] = None
+    ) -> "ArrayLike":
+        """Predict class probabilities for X.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            The input samples.
+        feature_names_in : {array of string, list of string} of length n_features.
+            Feature names of the corresponding columns in X. Important, since the column
+            list can be extended by the ColumnTransformer in the pipeline. By default None.
+
+        Returns
+        -------
+        ArrayLike of shape (n_samples, n_classes)
+            The class probabilities of the input samples. The order of the
+            classes corresponds to that in the attribute :term:`classes_`.
+        """
+        if feature_names_in is not None:
+            if X.shape[1] != len(feature_names_in):
+                raise ValueError(
+                    f"Dimension mismatch: X has {X.shape[1]} columns, but feature_names_in has {len(feature_names_in)} entries."
+                )
+            if isinstance(feature_names_in, np.ndarray):
+                feature_names_in = feature_names_in.tolist()
+            # select columns used by the model in the correct order
+            X = X[:, [feature_names_in.index(fn) for fn in self.feature_names_in_]]
+
+        X = check_array(X)
+        return GradientBoostingClassifier.predict_proba(self, X)
+
+    def predict(
+        self,
+        X: "ArrayLike",
+        feature_names_in: Optional[Union["ArrayLike", List[str]]] = None,
+    ) -> "ArrayLike":
+        """Predict class for X.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            The input samples.
+        feature_names_in : {array of string, list of string} of length n_features.
+            Feature names of the corresponding columns in X. Important, since the column
+            list can be extended by the ColumnTransformer in the pipeline. By default None.
+
+        Returns
+        -------
+        ArrayLike of shape (n_samples,)
+            The predicted values.
+        """
+        if feature_names_in is not None:
+            if X.shape[1] != len(feature_names_in):
+                raise ValueError(
+                    f"Dimension mismatch: X has {X.shape[1]} columns, but feature_names_in has {len(feature_names_in)} entries."
+                )
+            if isinstance(feature_names_in, np.ndarray):
+                feature_names_in = feature_names_in.tolist()
+            # select columns used by the model in the correct order
+            X = X[:, [feature_names_in.index(fn) for fn in self.feature_names_in_]]
+
+        X = check_array(X)
+        return GradientBoostingClassifier.predict(self, X)
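The `feature_names_in` re-selection above is the subtle part of these predict methods; a standalone sketch (illustrative names) of what it does:

```python
import numpy as np

# Columns as produced by the ColumnTransformer: more columns than the
# model needs, in an arbitrary order.
feature_names_in = ["OriginWeather_frequency", "AvgTicketPrice", "DestRegion"]
# Features the trees were trained on, in the order the model expects.
model_feature_names = ["AvgTicketPrice", "OriginWeather_frequency"]

X = np.array([[0.6, 250.0, 1.0]])
X = X[:, [feature_names_in.index(fn) for fn in model_feature_names]]
print(X)  # [[250.    0.6]] -- realigned to the model's column order
```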
+
+
+class ESGradientBoostingRegressor(ESGradientBoostingModel, GradientBoostingRegressor):
+    """
+    Elastic ML model wrapper compatible with sklearn GradientBoostingRegressor.
+    """
+
+    def __init__(
+        self,
+        es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
+        model_id: str,
+    ) -> None:
+        """
+        Parameters
+        ----------
+        es_client : Elasticsearch client argument(s)
+            - elasticsearch-py parameters or
+            - elasticsearch-py instance
+        model_id : str
+            The unique identifier of the trained inference model in Elasticsearch.
+
+        Raises
+        ------
+        NotImplementedError
+            Only MSE, MSLE, and Huber loss functions are supported.
+        ModelDefinitionKeyError
+            If required data cannot be extracted from the model definition due to a schema change.
+        """
+        try:
+            ESGradientBoostingModel.__init__(self, es_client, model_id)
+            self._extract_common_parameters()
+            GradientBoostingRegressor.__init__(
+                self,
+                learning_rate=1.0,
+                n_estimators=self.n_estimators,
+                max_depth=self._max_depth,
+            )
+
+            self.n_outputs = 1
+            loss_function = self._trained_model_result["trained_model_configs"][0][
+                "metadata"
+            ]["analytics_config"]["analysis"][self.analysis_type]["loss_function"]
+            if loss_function == "mse" or loss_function == "msle":
+                self.criterion = "squared_error"
+                self._loss = LeastSquaresError()
+            elif loss_function == "huber":
+                loss_parameter = self._trained_model_result[
+                    "trained_model_configs"
+                ][0]["metadata"]["analytics_config"]["analysis"][self.analysis_type][
+                    "loss_function_parameter"
+                ]
+                self.criterion = "huber"
+                self._loss = HuberLossFunction(loss_parameter)
+            else:
+                raise NotImplementedError(
+                    "Only MSE, MSLE, and Huber loss functions are supported."
+                )
+
+            self.init_ = self._initialize_init_()
+            self._initialize_estimators(DecisionTreeRegressor)
+        except KeyError as ex:
+            raise ModelDefinitionKeyError(ex) from ex
+
+    @property
+    def analysis_type(self) -> Literal["regression"]:
+        return TYPE_REGRESSION
+
+    def _initialize_init_(self) -> DummyRegressor:
+        constant = self._trees[0].tree.value[0]
+        estimator = DummyRegressor(
+            strategy="constant",
+            constant=constant,
+        )
+        estimator.constant_ = np.array([constant])
+        estimator.n_outputs_ = 1
+        return estimator
+
+    def predict(
+        self,
+        X: "ArrayLike",
+        feature_names_in: Optional[Union["ArrayLike", List[str]]] = None,
+    ) -> "ArrayLike":
+        """Predict targets for X.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            The input samples.
+        feature_names_in : {array of string, list of string} of length n_features.
+            Feature names of the corresponding columns in X. Important, since the column
+            list can be extended by the ColumnTransformer in the pipeline. By default None.
+
+        Returns
+        -------
+        ArrayLike of shape (n_samples,)
+            The predicted values.
+        """
+        if feature_names_in is not None:
+            if X.shape[1] != len(feature_names_in):
+                raise ValueError(
+                    f"Dimension mismatch: X has {X.shape[1]} columns, but feature_names_in has {len(feature_names_in)} entries."
+                )
+            if isinstance(feature_names_in, np.ndarray):
+                feature_names_in = feature_names_in.tolist()
+            # select columns used by the model in the correct order
+            X = X[:, [feature_names_in.index(fn) for fn in self.feature_names_in_]]
+
+        X = check_array(X)
+        return GradientBoostingRegressor.predict(self, X)
diff --git a/eland/ml/ml_model.py b/eland/ml/ml_model.py
index c1333dc..fb26b9d 100644
--- a/eland/ml/ml_model.py
+++ b/eland/ml/ml_model.py
@@ -37,6 +37,7 @@ if TYPE_CHECKING:
         RandomForestClassifier,
         RandomForestRegressor,
     )
+    from sklearn.pipeline import Pipeline  # type: ignore # noqa: F401
     from sklearn.tree import (  # type: ignore # noqa: F401
         DecisionTreeClassifier,
         DecisionTreeRegressor,
@@ -424,6 +425,83 @@ class MLModel:
                 return False
         return True
 
+    def export_model(self) -> "Pipeline":
+        """Export an Elastic ML model as an sklearn Pipeline.
+
+        Returns
+        -------
+        sklearn.pipeline.Pipeline
+            The Elastic ML model exported as an sklearn Pipeline of its
+            preprocessors followed by the gradient boosting estimator.
+
+        Raises
+        ------
+        AssertionError
+            If the preprocessors JSON definition has an unexpected schema.
+        ValueError
+            The model is expected to be trained in the Elastic Stack. Models initially imported
+            from xgboost, lgbm, or sklearn are not supported.
+        ValueError
+            If an unexpected categorical encoding is found in the list of preprocessors.
+        NotImplementedError
+            Only regression and binary classification models are supported currently.
+        """
+        from sklearn.compose import ColumnTransformer  # type: ignore # noqa: F401
+        from sklearn.pipeline import Pipeline
+
+        from .exporters._sklearn_deserializers import (
+            FrequencyEncoder,
+            OneHotEncoder,
+            TargetMeanEncoder,
+        )
+        from .exporters.es_gb_models import (
+            ESGradientBoostingClassifier,
+            ESGradientBoostingRegressor,
+        )
+
+        if self.model_type == TYPE_CLASSIFICATION:
+            model = ESGradientBoostingClassifier(
+                es_client=self._client, model_id=self._model_id
+            )
+        elif self.model_type == TYPE_REGRESSION:
+            model = ESGradientBoostingRegressor(
+                es_client=self._client, model_id=self._model_id
+            )
+        else:
+            raise NotImplementedError(
+                "Only regression and binary classification models are supported currently."
+            )
+
+        transformers = []
+        for p in model.preprocessors:
+            assert (
+                len(p) == 1
+            ), f"Unexpected preprocessor data structure: {p}. One-key mapping expected."
+            encoding_type = list(p.keys())[0]
+            field = p[encoding_type]["field"]
+            if encoding_type == "frequency_encoding":
+                transform = FrequencyEncoder(p)
+                transformers.append((f"{field}_{encoding_type}", transform, field))
+            elif encoding_type == "target_mean_encoding":
+                transform = TargetMeanEncoder(p)
+                transformers.append((f"{field}_{encoding_type}", transform, field))
+            elif encoding_type == "one_hot_encoding":
+                transform = OneHotEncoder(p)
+                transformers.append((f"{field}_{encoding_type}", transform, [field]))
+            else:
+                raise ValueError(
+                    f"Unexpected categorical encoding type {encoding_type} found. "
+                    + "Expected encodings: frequency_encoding, target_mean_encoding, one_hot_encoding."
+                )
+        preprocessor = ColumnTransformer(
+            transformers=transformers,
+            remainder="passthrough",
+            verbose_feature_names_out=False,
+        )
+
+        pipeline = Pipeline(steps=[("preprocessor", preprocessor), ("es_model", model)])
+
+        return pipeline
+
     @property
     def _trained_model_config(self) -> Dict[str, Any]:
         """Lazily loads an ML models 'trained_model_config' information"""
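End to end, the new API looks roughly like this (a sketch; it assumes a local cluster with the flights dataset and a model trained in the Elastic Stack, and both names are illustrative):

```python
import eland as ed
from eland.ml import MLModel

X = ed.eland_to_pandas(ed.DataFrame("http://localhost:9200", "flights")).head(10)

model = MLModel(es_client="http://localhost:9200", model_id="flights-regression")
pipeline = model.export_model()
pipeline.fit(X)  # a no-op, but satisfies sklearn's fitted-state checks

predictions = pipeline.predict(
    X, feature_names_in=pipeline["preprocessor"].get_feature_names_out()
)
```

Because the final step subclasses the sklearn gradient boosting estimators, the exported model also works with tools such as `shap.TreeExplainer`, which the new tests below exercise.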
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index 4b662c0..1037e8f 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -125,7 +125,7 @@ def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]
         return None
     potential_task_types: Set[str] = set()
     for architecture in model_config.architectures:
-        for (substr, task_type) in ARCHITECTURE_TO_TASK_TYPE.items():
+        for substr, task_type in ARCHITECTURE_TO_TASK_TYPE.items():
             if substr in architecture:
                 for t in task_type:
                     potential_task_types.add(t)
@@ -384,7 +384,6 @@ class _DPREncoderWrapper(nn.Module):  # type: ignore
 
     @staticmethod
     def from_pretrained(model_id: str) -> Optional[Any]:
-
         config = AutoConfig.from_pretrained(model_id)
 
         def is_compatible() -> bool:
diff --git a/eland/operations.py b/eland/operations.py
index f5dd064..a6b20f2 100644
--- a/eland/operations.py
+++ b/eland/operations.py
@@ -210,7 +210,6 @@ class Operations:
     def idx(
         self, query_compiler: "QueryCompiler", axis: int, sort_order: str
     ) -> pd.Series:
-
         if axis == 1:
             # Fetch idx on Columns
             raise NotImplementedError(
@@ -279,7 +278,6 @@ class Operations:
         numeric_only: bool = False,
         dropna: bool = True,
     ) -> Union[pd.DataFrame, pd.Series]:
-
         results = self._metric_aggs(
             query_compiler,
             pd_aggs=pd_aggs,
@@ -530,7 +528,6 @@ class Operations:
         # weights = [10066., 263., 386., 264., 273., 390., 324., 438., 261., 252., 142.]
         # So sum last 2 buckets
         for field in numeric_source_fields:
-
             # in case of series let plotting.ed_hist_series thrown an exception
             if not response.get("aggregations"):
                 continue
@@ -771,7 +768,6 @@ class Operations:
         is_dataframe: bool = True,
         numeric_only: Optional[bool] = True,
     ) -> Union[pd.DataFrame, pd.Series]:
-
         percentiles = [
             quantile_to_percentile(x)
             for x in (
@@ -801,7 +797,6 @@ class Operations:
         return df if is_dataframe else df.transpose().iloc[0]
 
     def unique(self, query_compiler: "QueryCompiler") -> pd.Series:
-
         query_params, _ = self._resolve_tasks(query_compiler)
 
         body = Query(query_params.query)
@@ -1052,7 +1047,6 @@ class Operations:
         buckets: Sequence[Dict[str, Any]] = composite_buckets["buckets"]
 
         if after_key:
-
             # yield the bucket which contains the result
             yield buckets
 
@@ -1227,7 +1221,6 @@ class Operations:
     def to_pandas(
         self, query_compiler: "QueryCompiler", show_progress: bool = False
     ) -> pd.DataFrame:
-
         df_list: List[pd.DataFrame] = []
         i = 0
         for df in self.search_yield_pandas_dataframes(query_compiler=query_compiler):
diff --git a/eland/query.py b/eland/query.py
index be20d1f..21f7cb4 100644
--- a/eland/query.py
+++ b/eland/query.py
@@ -170,7 +170,6 @@ class Query:
         sort_order: str,
         size: int = 1,
     ) -> None:
-
         top_hits: Any = {}
         if sort_order:
             top_hits["sort"] = [{i: {"order": sort_order}} for i in source_columns]
diff --git a/eland/query_compiler.py b/eland/query_compiler.py
index 6dc032d..6af60e5 100644
--- a/eland/query_compiler.py
+++ b/eland/query_compiler.py
@@ -246,7 +246,6 @@ class QueryCompiler:
 
         i = 0
         for i, hit in enumerate(results, 1):
-
             if "_source" in hit:
                 row = hit["_source"]
             else:
diff --git a/noxfile.py b/noxfile.py
index a1fbdec..06bc74f 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -61,7 +61,7 @@ def format(session):
     session.install("black", "isort", "flynt")
     session.run("python", "utils/license-headers.py", "fix", *SOURCE_FILES)
     session.run("flynt", *SOURCE_FILES)
-    session.run("black", "--target-version=py37", *SOURCE_FILES)
+    session.run("black", "--target-version=py38", *SOURCE_FILES)
     session.run("isort", "--profile=black", *SOURCE_FILES)
     lint(session)
 
@@ -73,7 +73,7 @@ def lint(session):
     session.install("black", "flake8", "mypy", "isort", "numpy")
     session.install("--pre", "elasticsearch>=8.3,<9")
     session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
-    session.run("black", "--check", "--target-version=py37", *SOURCE_FILES)
+    session.run("black", "--check", "--target-version=py38", *SOURCE_FILES)
     session.run("isort", "--check", "--profile=black", *SOURCE_FILES)
     session.run("flake8", "--ignore=E501,W503,E402,E712,E203", *SOURCE_FILES)
@@ -138,6 +138,7 @@ def test(session, pandas_version: str):
         "scikit-learn",
         "xgboost",
         "lightgbm",
+        "shap",
     )
     session.run("pytest", "tests/ml/")
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 10efae6..0d21e02 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -29,6 +29,7 @@ pytest>=5.2.1
 pytest-mock
 pytest-cov
 nbval
+shap==0.41.0
 
 #
 # Docs
diff --git a/tests/dataframe/test_count_pytest.py b/tests/dataframe/test_count_pytest.py
index e57db99..7cadee0 100644
--- a/tests/dataframe/test_count_pytest.py
+++ b/tests/dataframe/test_count_pytest.py
@@ -35,7 +35,6 @@ class TestDataFrameCount(TestData):
             df.count()
 
     def test_count_flights(self):
-
         pd_flights = self.pd_flights().filter(self.filter_data)
         ed_flights = self.ed_flights().filter(self.filter_data)
 
diff --git a/tests/dataframe/test_metrics_pytest.py b/tests/dataframe/test_metrics_pytest.py
index 79b958e..9477629 100644
--- a/tests/dataframe/test_metrics_pytest.py
+++ b/tests/dataframe/test_metrics_pytest.py
@@ -419,7 +419,6 @@ class TestDataFrameMetrics(TestData):
         assert calculated_values.shape == (2,)
 
     def test_aggs_count(self):
-
         pd_flights = self.pd_flights().filter(self.filter_data)
         ed_flights = self.ed_flights().filter(self.filter_data)
 
diff --git a/tests/dataframe/test_to_csv_pytest.py b/tests/dataframe/test_to_csv_pytest.py
index 49c05c7..1971d77 100644
--- a/tests/dataframe/test_to_csv_pytest.py
+++ b/tests/dataframe/test_to_csv_pytest.py
@@ -102,7 +102,6 @@ class TestDataFrameToCSV(TestData):
         ES_TEST_CLIENT.indices.delete(index=test_index)
 
     def test_pd_to_csv_without_filepath(self):
-
         ed_flights = self.ed_flights()
         pd_flights = self.pd_flights()
 
diff --git a/tests/dataframe/test_utils_pytest.py b/tests/dataframe/test_utils_pytest.py
index d14c332..5cdb591 100644
--- a/tests/dataframe/test_utils_pytest.py
+++ b/tests/dataframe/test_utils_pytest.py
@@ -147,7 +147,6 @@ class TestDataFrameUtils(TestData):
         # assert_pandas_eland_frame_equal(pd_df, self.ed_flights())
 
     def test_es_type_override_error(self):
-
         df = self.pd_flights().filter(
             ["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"]
         )
diff --git a/tests/ml/test_ml_model_pytest.py b/tests/ml/test_ml_model_pytest.py
index 7d28018..07250c6 100644
--- a/tests/ml/test_ml_model_pytest.py
+++ b/tests/ml/test_ml_model_pytest.py
@@ -15,11 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from operator import itemgetter
+
 import numpy as np
 import pytest
 
+import eland as ed
 from eland.ml import MLModel
-from tests import ES_TEST_CLIENT, ES_VERSION
+from tests import ES_TEST_CLIENT, ES_VERSION, FLIGHTS_SMALL_INDEX_NAME
 
 try:
     from sklearn import datasets
@@ -44,16 +47,26 @@ try:
 except ImportError:
     HAS_LIGHTGBM = False
 
+try:
+    import shap
+
+    HAS_SHAP = True
+except ImportError:
+    HAS_SHAP = False
+
 requires_sklearn = pytest.mark.skipif(
-    not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
+    not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run."
 )
 requires_xgboost = pytest.mark.skipif(
-    not HAS_XGBOOST, reason="This test requires 'xgboost' package to run"
+    not HAS_XGBOOST, reason="This test requires 'xgboost' package to run."
+)
+requires_shap = pytest.mark.skipif(
+    not HAS_SHAP, reason="This test requires 'shap' package to run."
 )
 requires_no_ml_extras = pytest.mark.skipif(
     HAS_SKLEARN or HAS_XGBOOST,
-    reason="This test requires 'scikit-learn' and 'xgboost' to not be installed",
+    reason="This test requires 'scikit-learn' and 'xgboost' to not be installed.",
 )
 
 requires_lightgbm = pytest.mark.skipif(
@@ -80,6 +93,102 @@ def check_prediction_equality(es_model: MLModel, py_model, test_data):
     np.testing.assert_almost_equal(test_results, es_results, decimal=2)
 
 
+def yield_model_id(analysis, analyzed_fields):
+    import random
+    import string
+    import time
+
+    suffix = "".join(random.choices(string.ascii_lowercase, k=4))
+    job_id = "test-flights-regression-" + suffix
+    dest = job_id + "-dest"
+
+    response = ES_TEST_CLIENT.ml.put_data_frame_analytics(
+        id=job_id,
+        analysis=analysis,
+        dest={"index": dest},
+        source={"index": [FLIGHTS_SMALL_INDEX_NAME]},
+        analyzed_fields=analyzed_fields,
+    )
+    assert response.meta.status == 200
+    response = ES_TEST_CLIENT.ml.start_data_frame_analytics(id=job_id)
+    assert response.meta.status == 200
+
+    time.sleep(2)
+    response = ES_TEST_CLIENT.ml.get_trained_models(model_id=job_id + "*")
+    assert response.meta.status == 200
+    assert response.body["count"] == 1
+    model_id = response.body["trained_model_configs"][0]["model_id"]
+
+    yield model_id
+
+    ES_TEST_CLIENT.ml.delete_data_frame_analytics(id=job_id)
+    ES_TEST_CLIENT.indices.delete(index=dest)
+    ES_TEST_CLIENT.ml.delete_trained_model(model_id=model_id)
+
+
+@pytest.fixture(params=[[0, 4], [0, 1], range(5)])
+def regression_model_id(request):
+    analysis = {
+        "regression": {
+            "dependent_variable": "FlightDelayMin",
+            "max_trees": 3,
+            "num_top_feature_importance_values": 0,
+            "max_optimization_rounds_per_hyperparameter": 1,
+            "prediction_field_name": "FlightDelayMin_prediction",
+            "training_percent": 30,
+            "randomize_seed": 1000,
+            "loss_function": "mse",
+            "early_stopping_enabled": True,
+        }
+    }
+    all_includes = [
+        "FlightDelayMin",
+        "FlightDelayType",
+        "FlightTimeMin",
+        "DistanceMiles",
+        "OriginAirportID",
+    ]
+    includes = [all_includes[i] for i in request.param]
+    analyzed_fields = {
+        "includes": includes,
+        "excludes": [],
+    }
+    yield from yield_model_id(analysis=analysis, analyzed_fields=analyzed_fields)
+
+
+@pytest.fixture(params=[[0, 6], [5, 6], range(7)])
+def classification_model_id(request):
+    analysis = {
+        "classification": {
+            "dependent_variable": "Cancelled",
+            "max_trees": 5,
+            "num_top_feature_importance_values": 0,
+            "max_optimization_rounds_per_hyperparameter": 1,
+            "prediction_field_name": "Cancelled_prediction",
+            "training_percent": 50,
+            "randomize_seed": 1000,
+            "num_top_classes": -1,
+            "class_assignment_objective": "maximize_accuracy",
+            "early_stopping_enabled": True,
+        }
+    }
+    all_includes = [
+        "OriginWeather",
+        "OriginAirportID",
+        "DestCityName",
+        "DestWeather",
+        "DestRegion",
+        "AvgTicketPrice",
+        "Cancelled",
+    ]
+    includes = [all_includes[i] for i in request.param]
+    analyzed_fields = {
+        "includes": includes,
+        "excludes": [],
+    }
+    yield from yield_model_id(analysis=analysis, analyzed_fields=analyzed_fields)
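The fixed `time.sleep(2)` can race against slow analytics jobs; a more robust variant (a sketch, not part of this change) would poll the job state before fetching the trained model:

```python
import time

def wait_for_job_to_stop(es_client, job_id, timeout=60):
    # Poll the data frame analytics job until Elasticsearch reports it stopped.
    deadline = time.time() + timeout
    while time.time() < deadline:
        stats = es_client.ml.get_data_frame_analytics_stats(id=job_id)
        if stats["data_frame_analytics"][0]["state"] == "stopped":
            return
        time.sleep(1)
    raise TimeoutError(f"Analytics job {job_id!r} did not stop within {timeout}s")
```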
+
+
 class TestMLModel:
     @requires_no_ml_extras
     def test_import_ml_model_when_dependencies_are_not_available(self):
@@ -494,3 +603,172 @@ class TestMLModel:
 
         # Clean up
         es_model.delete_model()
+
+    @requires_sklearn
+    @requires_shap
+    def test_export_regressor(self, regression_model_id):
+        ed_flights = ed.DataFrame(ES_TEST_CLIENT, FLIGHTS_SMALL_INDEX_NAME).head(10)
+        types = dict(ed_flights.dtypes)
+        X = ed_flights.to_pandas().astype(types)
+
+        model = MLModel(es_client=ES_TEST_CLIENT, model_id=regression_model_id)
+        pipeline = model.export_model()
+        pipeline.fit(X)
+
+        predictions_sklearn = pipeline.predict(
+            X, feature_names_in=pipeline["preprocessor"].get_feature_names_out()
+        )
+        response = ES_TEST_CLIENT.ml.infer_trained_model(
+            model_id=regression_model_id,
+            docs=X[pipeline["es_model"].input_field_names].to_dict("records"),
+        )
+        predictions_es = np.array(
+            list(
+                map(
+                    itemgetter("FlightDelayMin_prediction"),
+                    response.body["inference_results"],
+                )
+            )
+        )
+        np.testing.assert_array_almost_equal(predictions_sklearn, predictions_es)
+
+        import pandas as pd
+
+        X_transformed = pipeline["preprocessor"].transform(X=X)
+        X_transformed = pd.DataFrame(
+            X_transformed, columns=pipeline["preprocessor"].get_feature_names_out()
+        )
+        explainer = shap.TreeExplainer(pipeline["es_model"])
+        shap_values = explainer.shap_values(
+            X_transformed[pipeline["es_model"].feature_names_in_]
+        )
+        np.testing.assert_array_almost_equal(
+            predictions_sklearn, shap_values.sum(axis=1) + explainer.expected_value
+        )
+
+    @requires_sklearn
+    def test_export_classification(self, classification_model_id):
+        ed_flights = ed.DataFrame(ES_TEST_CLIENT, FLIGHTS_SMALL_INDEX_NAME).head(10)
+        X = ed.eland_to_pandas(ed_flights)
+
+        model = MLModel(es_client=ES_TEST_CLIENT, model_id=classification_model_id)
+        pipeline = model.export_model()
+        pipeline.fit(X)
+
+        predictions_sklearn = pipeline.predict(
+            X, feature_names_in=pipeline["preprocessor"].get_feature_names_out()
+        )
+        prediction_proba_sklearn = pipeline.predict_proba(
+            X, feature_names_in=pipeline["preprocessor"].get_feature_names_out()
+        ).max(axis=1)
+
+        response = ES_TEST_CLIENT.ml.infer_trained_model(
+            model_id=classification_model_id,
+            docs=X[pipeline["es_model"].input_field_names].to_dict("records"),
+        )
+        predictions_es = np.array(
+            list(
+                map(
+                    lambda x: str(int(x["Cancelled_prediction"])),
+                    response.body["inference_results"],
+                )
+            )
+        )
+        prediction_proba_es = np.array(
+            list(
+                map(
+                    itemgetter("prediction_probability"),
+                    response.body["inference_results"],
+                )
+            )
+        )
+        np.testing.assert_array_almost_equal(
+            prediction_proba_sklearn, prediction_proba_es
+        )
+        np.testing.assert_array_equal(predictions_sklearn, predictions_es)
+
+        import pandas as pd
+
+        X_transformed = pipeline["preprocessor"].transform(X=X)
+        X_transformed = pd.DataFrame(
+            X_transformed, columns=pipeline["preprocessor"].get_feature_names_out()
+        )
+        explainer = shap.TreeExplainer(pipeline["es_model"])
+        shap_values = explainer.shap_values(
+            X_transformed[pipeline["es_model"].feature_names_in_]
+        )
+        log_odds = shap_values.sum(axis=1) + explainer.expected_value
+        prediction_proba_shap = 1 / (1 + np.exp(-log_odds))
+        # use the probability of the predicted class
+        prediction_proba_shap[prediction_proba_shap < 0.5] = (
+            1 - prediction_proba_shap[prediction_proba_shap < 0.5]
+        )
+        np.testing.assert_array_almost_equal(
+            prediction_proba_sklearn, prediction_proba_shap
+        )
+
+    @requires_xgboost
+    @requires_sklearn
+    @pytest.mark.parametrize("objective", ["binary:logistic", "reg:squarederror"])
+    def test_xgb_import_export(self, objective):
+        booster = "gbtree"
+
+        if objective.startswith("binary:"):
+            training_data = datasets.make_classification(n_features=5)
+            xgb_model = XGBClassifier(
+                booster=booster, objective=objective, use_label_encoder=False
+            )
+        else:
+            training_data = datasets.make_regression(n_features=5)
+            xgb_model = XGBRegressor(
+                booster=booster, objective=objective, use_label_encoder=False
+            )
+
+        # Train model
+        xgb_model.fit(training_data[0], training_data[1])
+
+        # Serialise the model to Elasticsearch
+        feature_names = ["feature0", "feature1", "feature2", "feature3", "feature4"]
+        model_id = "test_xgb_model"
+
+        es_model = MLModel.import_model(
+            ES_TEST_CLIENT, model_id, xgb_model, feature_names, es_if_exists="replace"
+        )
+
+        # Export is supposed to fail for imported models
+        with pytest.raises(ValueError) as ex:
+            es_model.export_model()
+        assert ex.match("Error initializing sklearn classifier.")
+
+        # Clean up
+        es_model.delete_model()
+
+    @requires_lightgbm
+    @pytest.mark.parametrize("objective", ["regression", "binary"])
+    def test_lgbm_import_export(self, objective):
+        booster = "gbdt"
+        if objective == "binary":
+            training_data = datasets.make_classification(n_features=5)
+            lgbm_model = LGBMClassifier(boosting_type=booster, objective=objective)
+        else:
+            training_data = datasets.make_regression(n_features=5)
+            lgbm_model = LGBMRegressor(boosting_type=booster, objective=objective)
+
+        # Train model
+        lgbm_model.fit(training_data[0], training_data[1])
+
+        # Serialise the model to Elasticsearch
+        feature_names = ["feature0", "feature1", "feature2", "feature3", "feature4"]
+        model_id = "test_lgbm_model"
+
+        es_model = MLModel.import_model(
+            ES_TEST_CLIENT, model_id, lgbm_model, feature_names, es_if_exists="replace"
+        )
+
+        # Export is supposed to fail for imported models
+        with pytest.raises(ValueError) as ex:
+            es_model.export_model()
+        assert ex.match("Error initializing sklearn classifier.")
+
+        # Clean up
+        es_model.delete_model()