Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)

Add support for xgboost v1

commit 3d81def5cc (parent df2a21ffd4)
@@ -3,6 +3,7 @@
 ELASTICSEARCH_VERSION:
   - 8.0.0-SNAPSHOT
   - 7.x-SNAPSHOT
+  - 7.7-SNAPSHOT
   - 7.6-SNAPSHOT

 TEST_SUITE:
@@ -1,9 +1,2 @@
-elasticsearch==7.7.0a2
-pandas>=1
-matplotlib
-pytest>=5.2.1
+-r ../requirements-dev.txt
 git+https://github.com/pandas-dev/pydata-sphinx-theme.git@master
-numpydoc>=0.9.0
-nbsphinx
-scikit-learn
-xgboost==0.90
@@ -6,41 +6,41 @@ import base64
 import gzip
 import json
 from abc import ABC
-from typing import List
+from typing import List, Dict, Any, Optional


-def add_if_exists(d: dict, k: str, v) -> dict:
+def add_if_exists(d: Dict[str, Any], k: str, v: Any) -> None:
     if v is not None:
         d[k] = v
-    return d


 class ModelSerializer(ABC):
     def __init__(
         self,
         feature_names: List[str],
-        target_type: str = None,
-        classification_labels: List[str] = None,
+        target_type: Optional[str] = None,
+        classification_labels: Optional[List[str]] = None,
     ):
         self._target_type = target_type
         self._feature_names = feature_names
         self._classification_labels = classification_labels

-    def to_dict(self):
-        d = dict()
+    def to_dict(self) -> Dict[str, Any]:
+        d: Dict[str, Any] = {}
         add_if_exists(d, "target_type", self._target_type)
         add_if_exists(d, "feature_names", self._feature_names)
         add_if_exists(d, "classification_labels", self._classification_labels)
         return d

     @property
-    def feature_names(self):
+    def feature_names(self) -> List[str]:
         return self._feature_names

+    def serialize_model(self) -> Dict[str, Any]:
+        return {"trained_model": self.to_dict()}
+
     def serialize_and_compress_model(self) -> str:
-        json_string = json.dumps(
-            {"trained_model": self.to_dict()}, separators=(",", ":")
-        )
+        json_string = json.dumps(self.serialize_model(), separators=(",", ":"))
         return base64.b64encode(gzip.compress(json_string.encode("utf-8"))).decode(
             "ascii"
         )
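The serialization scheme is worth seeing end to end: to_dict() drops unset fields via add_if_exists(), the new serialize_model() wraps the result under "trained_model", and serialize_and_compress_model() emits compact JSON, gzips it, and base64-encodes the bytes for transport. A minimal stdlib-only sketch of that pipeline and its round-trip (the payload is a stand-in, not a real trained model):

    import base64
    import gzip
    import json

    # Stand-in for ModelSerializer.to_dict() output
    model = {"trained_model": {"tree": {"target_type": "regression", "feature_names": ["f0"]}}}

    # Compact JSON -> gzip -> base64/ascii, as in serialize_and_compress_model()
    encoded = base64.b64encode(
        gzip.compress(json.dumps(model, separators=(",", ":")).encode("utf-8"))
    ).decode("ascii")

    # The encoding round-trips losslessly on the receiving side
    assert json.loads(gzip.decompress(base64.b64decode(encoded))) == model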
@@ -50,13 +50,13 @@ class TreeNode:
     def __init__(
         self,
         node_idx: int,
-        default_left: bool = None,
-        decision_type: str = None,
-        left_child: int = None,
-        right_child: int = None,
-        split_feature: int = None,
-        threshold: float = None,
-        leaf_value: float = None,
+        default_left: Optional[bool] = None,
+        decision_type: Optional[str] = None,
+        left_child: Optional[int] = None,
+        right_child: Optional[int] = None,
+        split_feature: Optional[int] = None,
+        threshold: Optional[float] = None,
+        leaf_value: Optional[float] = None,
     ):
         self._node_idx = node_idx
         self._decision_type = decision_type
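These Optional[...] rewrites (here and in the hunks that follow) are what mypy's strict mode demands: a parameter that defaults to None must say so in its annotation rather than relying on implicit-Optional. A minimal sketch of the two spellings; both run fine, but only the second passes `mypy --strict`:

    from typing import Optional

    def implicit(threshold: float = None):  # flagged under mypy --strict
        return threshold

    def explicit(threshold: Optional[float] = None) -> Optional[float]:  # accepted
        return threshold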
@@ -67,8 +67,8 @@ class TreeNode:
         self._leaf_value = leaf_value
         self._default_left = default_left

-    def to_dict(self):
-        d = dict()
+    def to_dict(self) -> Dict[str, Any]:
+        d: Dict[str, Any] = {}
         add_if_exists(d, "node_index", self._node_idx)
         add_if_exists(d, "decision_type", self._decision_type)
         if self._leaf_value is None:
@@ -85,9 +85,9 @@ class Tree(ModelSerializer):
     def __init__(
         self,
         feature_names: List[str],
-        target_type: str = None,
-        tree_structure: List[TreeNode] = [],
-        classification_labels: List[str] = None,
+        target_type: Optional[str] = None,
+        tree_structure: Optional[List[TreeNode]] = None,
+        classification_labels: Optional[List[str]] = None,
     ):
         super().__init__(
             feature_names=feature_names,
@@ -96,9 +96,9 @@ class Tree(ModelSerializer):
         )
         if target_type == "regression" and classification_labels:
             raise ValueError("regression does not support classification_labels")
-        self._tree_structure = tree_structure
+        self._tree_structure = tree_structure or []

-    def to_dict(self):
+    def to_dict(self) -> Dict[str, Any]:
         d = super().to_dict()
         add_if_exists(d, "tree_structure", [t.to_dict() for t in self._tree_structure])
         return {"tree": d}
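Besides satisfying the type checker, replacing `tree_structure: List[TreeNode] = []` with `Optional[...] = None` plus `tree_structure or []` removes a classic Python pitfall: a mutable default is evaluated once at function definition and shared by every call. A small illustration of the difference:

    def broken(node, structure=[]):  # one list object, created at def-time
        structure.append(node)
        return structure

    def fixed(node, structure=None):  # fresh list per call, as Tree now does
        structure = structure or []
        structure.append(node)
        return structure

    assert broken("a") == ["a"]
    assert broken("b") == ["a", "b"]  # state leaked from the previous call
    assert fixed("a") == ["a"]
    assert fixed("b") == ["b"]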
@@ -109,10 +109,10 @@ class Ensemble(ModelSerializer):
         self,
         feature_names: List[str],
         trained_models: List[ModelSerializer],
-        output_aggregator: dict,
-        target_type: str = None,
-        classification_labels: List[str] = None,
-        classification_weights: List[float] = None,
+        output_aggregator: Dict[str, Any],
+        target_type: Optional[str] = None,
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
     ):
         super().__init__(
             feature_names=feature_names,
@@ -123,7 +123,7 @@ class Ensemble(ModelSerializer):
         self._classification_weights = classification_weights
         self._output_aggregator = output_aggregator

-    def to_dict(self):
+    def to_dict(self) -> Dict[str, Any]:
         d = super().to_dict()
         trained_models = None
         if self._trained_models:
@@ -2,7 +2,7 @@
 # Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
 # See the LICENSE file in the project root for more information

-from typing import List, Union
+from typing import List, Union, Optional

 import numpy as np

@@ -23,8 +23,8 @@ class ModelTransformer:
         self,
         model,
         feature_names: List[str],
-        classification_labels: List[str] = None,
-        classification_weights: List[float] = None,
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
     ):
         self._feature_names = feature_names
         self._model = model
@@ -56,8 +56,8 @@ class SKLearnTransformer(ModelTransformer):
         self,
         model,
         feature_names: List[str],
-        classification_labels: List[str] = None,
-        classification_weights: List[float] = None,
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
     ):
         """
         Base class for SKLearn transformations
@@ -120,7 +120,7 @@ class SKLearnDecisionTreeTransformer(SKLearnTransformer):
         self,
         model: Union[DecisionTreeRegressor, DecisionTreeClassifier],
         feature_names: List[str],
-        classification_labels: List[str] = None,
+        classification_labels: Optional[List[str]] = None,
     ):
         """
         Transforms a Decision Tree model (Regressor|Classifier) into a ES Supported Tree format
@@ -148,7 +148,7 @@ class SKLearnDecisionTreeTransformer(SKLearnTransformer):
             check_is_fitted(self._model, ["classes_"])
             if tree_classes is None:
                 tree_classes = [str(c) for c in self._model.classes_]
-        nodes = list()
+        nodes = []
         tree_state = self._model.tree_.__getstate__()
         for i in range(len(tree_state["nodes"])):
             nodes.append(
@@ -169,8 +169,8 @@ class SKLearnForestTransformer(SKLearnTransformer):
         self,
         model: Union[RandomForestClassifier, RandomForestRegressor],
         feature_names: List[str],
-        classification_labels: List[str] = None,
-        classification_weights: List[float] = None,
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
     ):
         super().__init__(
             model, feature_names, classification_labels, classification_weights
@@ -235,7 +235,7 @@ class SKLearnForestClassifierTransformer(SKLearnForestTransformer):
         self,
         model: RandomForestClassifier,
         feature_names: List[str],
-        classification_labels: List[str] = None,
+        classification_labels: Optional[List[str]] = None,
     ):
         super().__init__(model, feature_names, classification_labels)

@@ -259,8 +259,8 @@ class XGBoostForestTransformer(ModelTransformer):
         feature_names: List[str],
         base_score: float = 0.5,
         objective: str = "reg:squarederror",
-        classification_labels: List[str] = None,
-        classification_weights: List[float] = None,
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
     ):
         super().__init__(
             model, feature_names, classification_labels, classification_weights
@@ -330,25 +330,24 @@ class XGBoostForestTransformer(ModelTransformer):

         :return: A list of Tree objects
         """
-        if self._model.booster not in {"dart", "gbtree"}:
-            raise ValueError("booster must exist and be of type dart or gbtree")
+        self.check_model_booster()

         tree_table = self._model.trees_to_dataframe()
-        transformed_trees = list()
+        transformed_trees = []
         curr_tree = None
-        tree_nodes = list()
+        tree_nodes = []
         for _, row in tree_table.iterrows():
             if row["Tree"] != curr_tree:
                 if len(tree_nodes) > 0:
                     transformed_trees.append(self.build_tree(tree_nodes))
                 curr_tree = row["Tree"]
-                tree_nodes = list()
+                tree_nodes = []
             tree_nodes.append(self.build_tree_node(row, curr_tree))
         # add last tree
         if len(tree_nodes) > 0:
             transformed_trees.append(self.build_tree(tree_nodes))
         # We add this stump as XGBoost adds the base_score to the regression outputs
-        if self._objective.startswith("reg"):
+        if self._objective.partition(":")[0] == "reg":
             transformed_trees.append(self.build_base_score_stump())
         return transformed_trees

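The regression check is also tightened. `startswith("reg")` matches on raw characters, whereas `partition(":")[0] == "reg"` compares the entire objective family before the colon. The difference only shows up for a hypothetical family whose name happens to begin with those letters:

    objective = "reg:squarederror"

    # Old check: raw prefix match
    assert objective.startswith("reg")
    # New check: compare the whole family name before the colon
    assert objective.partition(":")[0] == "reg"

    # A hypothetical objective family beginning with the letters "reg":
    hypothetical = "regrank:pairwise"
    assert hypothetical.startswith("reg")            # old check: false positive
    assert hypothetical.partition(":")[0] != "reg"   # new check: correctly rejected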
@@ -361,9 +360,16 @@ class XGBoostForestTransformer(ModelTransformer):
     def is_objective_supported(self) -> bool:
         return False

+    def check_model_booster(self):
+        # xgboost v1 made booster default to 'None' meaning 'gbtree'
+        if self._model.booster not in {"dart", "gbtree", None}:
+            raise ValueError(
+                f"booster must exist and be of type 'dart' or "
+                f"'gbtree', was {self._model.booster!r}"
+            )
+
     def transform(self) -> Ensemble:
-        if self._model.booster not in {"dart", "gbtree"}:
-            raise ValueError("booster must exist and be of type dart or gbtree")
+        self.check_model_booster()

         if not self.is_objective_supported():
             raise ValueError(f"Unsupported objective '{self._objective}'")
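This is the core xgboost v1 compatibility fix: a Booster whose booster parameter was never set explicitly now reports it as None, which xgboost treats as the 'gbtree' default. A free-standing sketch of the new guard, with check_booster standing in for the method:

    def check_booster(booster):
        # None is accepted because xgboost v1 reports an unset booster
        # param as None, which the library treats as "gbtree".
        if booster not in ("dart", "gbtree", None):
            raise ValueError(
                f"booster must exist and be of type 'dart' or 'gbtree', was {booster!r}"
            )

    check_booster("gbtree")  # ok
    check_booster(None)      # ok under xgboost v1 semantics
    try:
        check_booster("gblinear")
    except ValueError as err:
        print(err)  # booster must exist and be of type 'dart' or 'gbtree', was 'gblinear'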
@@ -381,8 +387,12 @@ class XGBoostForestTransformer(ModelTransformer):

 class XGBoostRegressorTransformer(XGBoostForestTransformer):
     def __init__(self, model: XGBRegressor, feature_names: List[str]):
+        # XGBRegressor.base_score defaults to 0.5.
+        base_score = model.base_score
+        if base_score is None:
+            base_score = 0.5
         super().__init__(
-            model.get_booster(), feature_names, model.base_score, model.objective
+            model.get_booster(), feature_names, base_score, model.objective
         )

     def determine_target_type(self) -> str:
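The same unset-parameter behaviour applies to base_score: under xgboost v1 an XGBRegressor that was never given one reports None, so the transformer substitutes 0.5, xgboost's documented default, before delegating. In sketch form:

    # Assuming xgboost v1 semantics, where an unset param is reported as None:
    base_score = None  # what model.base_score may return
    base_score = 0.5 if base_score is None else base_score
    assert base_score == 0.5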
@@ -405,7 +415,7 @@ class XGBoostClassifierTransformer(XGBoostForestTransformer):
         self,
         model: XGBClassifier,
         feature_names: List[str],
-        classification_labels: List[str] = None,
+        classification_labels: Optional[List[str]] = None,
     ):
         super().__init__(
             model.get_booster(),
@@ -2,9 +2,9 @@
 # Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
 # See the LICENSE file in the project root for more information

-from typing import Union, List
+from typing import Union, List, Optional, Tuple, TYPE_CHECKING, cast

-import numpy as np
+import numpy as np  # type: ignore

 from eland.common import es_version
 from eland.ml._model_transformers import (
@@ -14,15 +14,20 @@ from eland.ml._model_transformers import (
     XGBoostRegressorTransformer,
     XGBoostClassifierTransformer,
 )
+from eland.ml._model_serializer import ModelSerializer
 from eland.ml._optional import import_optional_dependency
 from eland.ml.ml_model import MLModel

 sklearn = import_optional_dependency("sklearn")
 xgboost = import_optional_dependency("xgboost")

-from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
-from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
-from xgboost import XGBRegressor, XGBClassifier
+from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor  # type: ignore
+from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor  # type: ignore
+from xgboost import XGBRegressor, XGBClassifier  # type: ignore
+
+if TYPE_CHECKING:
+    from elasticsearch import Elasticsearch  # type: ignore # noqa: F401


 class ImportedMLModel(MLModel):
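The TYPE_CHECKING import is the standard way to use a class in annotations without importing it at runtime: the guarded block is only evaluated by the type checker, and quoted annotations defer resolution. A generic sketch of the pattern, assuming the elasticsearch client is available when type checking:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Resolved only by the type checker, never at runtime, so there
        # is no import cost or hard dependency outside of annotations.
        from elasticsearch import Elasticsearch

    def ping(client: "Elasticsearch") -> bool:
        # The quoted annotation stays valid even though the import
        # above never runs.
        return client.ping()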
@@ -91,7 +96,7 @@ class ImportedMLModel(MLModel):

     def __init__(
         self,
-        es_client,
+        es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
         model_id: str,
         model: Union[
             DecisionTreeClassifier,
@@ -102,15 +107,16 @@ class ImportedMLModel(MLModel):
             XGBRegressor,
         ],
         feature_names: List[str],
-        classification_labels: List[str] = None,
-        classification_weights: List[float] = None,
-        overwrite=False,
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
+        overwrite: bool = False,
     ):
         super().__init__(es_client, model_id)

         self._feature_names = feature_names
         self._model_type = None

+        serializer: ModelSerializer  # type def
         # Transform model
         if isinstance(model, DecisionTreeRegressor):
             serializer = SKLearnDecisionTreeTransformer(
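The bare `serializer: ModelSerializer` annotation pre-declares the variable's type before the isinstance ladder assigns different transformer subclasses to it; without it, mypy would pin the variable to whichever concrete type the first branch assigns. A toy illustration, with Base/A/B as stand-ins:

    class Base: ...
    class A(Base): ...
    class B(Base): ...

    def choose(flag: bool) -> Base:
        # Pre-declaring the type lets every branch assign a different
        # subclass; otherwise mypy infers `obj: A` from the first
        # assignment and rejects `obj = B()`.
        obj: Base
        if flag:
            obj = A()
        else:
            obj = B()
        return obj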
@@ -161,7 +167,7 @@ class ImportedMLModel(MLModel):
             model_id=self._model_id, body=body,
         )

-    def predict(self, X):
+    def predict(self, X: Union[List[float], List[List[float]]]) -> np.ndarray:
         """
         Make a prediction using a trained model stored in Elasticsearch.

@@ -190,7 +196,7 @@ class ImportedMLModel(MLModel):

         >>> # Get some test results
         >>> regressor.predict(np.array(test_data))
-        array([0.23733574, 1.1897984 ], dtype=float32)
+        array([0.06062475, 0.9990102 ], dtype=float32)

         >>> # Serialise the model to Elasticsearch
         >>> feature_names = ["f0", "f1", "f2", "f3", "f4", "f5"]
@@ -199,7 +205,7 @@ class ImportedMLModel(MLModel):

         >>> # Get some test results from Elasticsearch model
         >>> es_model.predict(test_data)
-        array([0.2373357, 1.1897984], dtype=float32)
+        array([0.0606248 , 0.99901026], dtype=float32)

         >>> # Delete model from Elasticsearch
         >>> es_model.delete_model()
@@ -207,20 +213,22 @@ class ImportedMLModel(MLModel):
         """
         docs = []
         if isinstance(X, list):
-            # Is it a list of lists?
-            if all(isinstance(i, list) for i in X):
-                for i in X:
-                    doc = dict()
-                    doc["_source"] = dict(zip(self._feature_names, i))
-                    docs.append(doc)
-            else:  # single feature vector1
-                doc = dict()
-                doc["_source"] = dict(zip(self._feature_names, i))
+            # Is it a list of floats?
+            if all(isinstance(i, float) for i in X):
+                features = cast(List[List[float]], [X])
+            else:
+                features = cast(List[List[float]], X)
+            for i in features:
+                doc = {"_source": dict(zip(self._feature_names, i))}
                 docs.append(doc)
         else:
             raise NotImplementedError(f"Prediction for type {type(X)}, not supported")

+        # field_mappings -> field_map in ES 7.7
+        field_map_name = (
+            "field_map" if es_version(self._client) >= (7, 7) else "field_mappings"
+        )
+
         results = self._client.ingest.simulate(
             body={
                 "pipeline": {
@@ -229,7 +237,7 @@ class ImportedMLModel(MLModel):
                             "inference": {
                                 "model_id": self._model_id,
                                 "inference_config": {self._model_type: {}},
-                                "field_mappings": {},
+                                field_map_name: {},
                             }
                         }
                     ]
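Two behaviours in the rewritten predict() are worth pulling out: a bare feature vector is promoted to a one-row batch before the simulated-ingest docs are built, and the inference processor's mapping key is chosen by server version (field_map from Elasticsearch 7.7 onward, field_mappings before). A self-contained sketch of the normalisation step, reusing the names from the diff:

    from typing import List, Union, cast

    def to_feature_matrix(X: Union[List[float], List[List[float]]]) -> List[List[float]]:
        # [0.1, 0.2] (one sample) becomes [[0.1, 0.2]] (a one-row batch)
        if all(isinstance(i, float) for i in X):
            return cast(List[List[float]], [X])
        return cast(List[List[float]], X)

    feature_names = ["f0", "f1"]
    docs = [{"_source": dict(zip(feature_names, row))} for row in to_feature_matrix([0.1, 0.2])]
    assert docs == [{"_source": {"f0": 0.1, "f1": 0.2}}]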
@@ -34,7 +34,7 @@ class MLModel:
         self._client = ensure_es_client(es_client)
         self._model_id = model_id

-    def delete_model(self):
+    def delete_model(self) -> None:
         """
         Delete an inference model saved in Elasticsearch

@@ -112,7 +112,7 @@ class TestImportedMLModel:
     def test_xgb_classifier(self):
         # Train model
         training_data = datasets.make_classification(n_features=5)
-        classifier = XGBClassifier()
+        classifier = XGBClassifier(booster="gbtree")
         classifier.fit(training_data[0], training_data[1])

         # Get some test results
|
|||||||
es_model = ImportedMLModel(
|
es_model = ImportedMLModel(
|
||||||
ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True
|
ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True
|
||||||
)
|
)
|
||||||
|
|
||||||
es_results = es_model.predict(test_data)
|
es_results = es_model.predict(test_data)
|
||||||
|
|
||||||
np.testing.assert_almost_equal(test_results, es_results, decimal=2)
|
np.testing.assert_almost_equal(test_results, es_results, decimal=2)
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
es_model.delete_model()
|
es_model.delete_model()
|
||||||
|
|
||||||
|
def test_predict_single_feature_vector(self):
|
||||||
|
# Train model
|
||||||
|
training_data = datasets.make_regression(n_features=1)
|
||||||
|
regressor = XGBRegressor()
|
||||||
|
regressor.fit(training_data[0], training_data[1])
|
||||||
|
|
||||||
|
# Get some test results
|
||||||
|
test_data = [[0.1]]
|
||||||
|
test_results = regressor.predict(np.asarray(test_data))
|
||||||
|
|
||||||
|
# Serialise the models to Elasticsearch
|
||||||
|
feature_names = ["f0"]
|
||||||
|
model_id = "test_xgb_regressor"
|
||||||
|
|
||||||
|
es_model = ImportedMLModel(
|
||||||
|
ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Single feature
|
||||||
|
es_results = es_model.predict(test_data[0])
|
||||||
|
|
||||||
|
np.testing.assert_almost_equal(test_results, es_results, decimal=2)
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
es_model.delete_model()
|
||||||
|
@@ -29,6 +29,8 @@ TYPED_FILES = {
     "eland/index.py",
     "eland/query.py",
     "eland/tasks.py",
+    "eland/ml/_model_serializer.py",
+    "eland/ml/imported_ml_model.py",
 }

|
|||||||
# TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/")
|
# TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/")
|
||||||
session.log("mypy --strict eland/")
|
session.log("mypy --strict eland/")
|
||||||
for typed_file in TYPED_FILES:
|
for typed_file in TYPED_FILES:
|
||||||
|
if not os.path.isfile(typed_file):
|
||||||
|
session.error(f"The file {typed_file!r} couldn't be found")
|
||||||
popen = subprocess.Popen(
|
popen = subprocess.Popen(
|
||||||
f"mypy --strict {typed_file}",
|
f"mypy --strict {typed_file}",
|
||||||
shell=True,
|
shell=True,
|
||||||
|
@@ -5,5 +5,5 @@ pytest>=5.2.1
 nbval
 numpydoc>=0.9.0
 scikit-learn>=0.22.1
-xgboost==0.90
+xgboost>=1
 nox
setup.py
@@ -177,4 +177,8 @@ setup(
     packages=find_packages(include=["eland", "eland.*"]),
     install_requires=["elasticsearch==7.7.0a2", "pandas>=1", "matplotlib", "numpy"],
     python_requires=">=3.6",
+    extras_require={
+        "xgboost": ["xgboost>=0.90,<2"],
+        "scikit-learn": ["scikit-learn>=0.22.1,<1"],
+    },
 )
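With these extras published, users can opt in to the ML dependencies at install time, e.g. `python -m pip install 'eland[xgboost]'` (assuming installation from a released package). The upper bounds keep xgboost below 2 and scikit-learn below 1, the version ranges the transformers above target.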