Add support for xgboost v1
This commit is contained in:
parent df2a21ffd4
commit 3d81def5cc
@@ -3,6 +3,7 @@
 ELASTICSEARCH_VERSION:
 - 8.0.0-SNAPSHOT
 - 7.x-SNAPSHOT
+- 7.7-SNAPSHOT
 - 7.6-SNAPSHOT
 
 TEST_SUITE:
@@ -1,9 +1,2 @@
-elasticsearch==7.7.0a2
-pandas>=1
-matplotlib
-pytest>=5.2.1
+-r ../requirements-dev.txt
 git+https://github.com/pandas-dev/pydata-sphinx-theme.git@master
-numpydoc>=0.9.0
-nbsphinx
-scikit-learn
-xgboost==0.90
@@ -6,41 +6,41 @@ import base64
 import gzip
 import json
 from abc import ABC
-from typing import List
+from typing import List, Dict, Any, Optional
 
 
-def add_if_exists(d: dict, k: str, v) -> dict:
+def add_if_exists(d: Dict[str, Any], k: str, v: Any) -> None:
     if v is not None:
         d[k] = v
-    return d
 
 
 class ModelSerializer(ABC):
     def __init__(
         self,
         feature_names: List[str],
-        target_type: str = None,
-        classification_labels: List[str] = None,
+        target_type: Optional[str] = None,
+        classification_labels: Optional[List[str]] = None,
     ):
         self._target_type = target_type
         self._feature_names = feature_names
         self._classification_labels = classification_labels
 
-    def to_dict(self):
-        d = dict()
+    def to_dict(self) -> Dict[str, Any]:
+        d: Dict[str, Any] = {}
         add_if_exists(d, "target_type", self._target_type)
         add_if_exists(d, "feature_names", self._feature_names)
         add_if_exists(d, "classification_labels", self._classification_labels)
         return d
 
     @property
-    def feature_names(self):
+    def feature_names(self) -> List[str]:
         return self._feature_names
 
+    def serialize_model(self) -> Dict[str, Any]:
+        return {"trained_model": self.to_dict()}
+
     def serialize_and_compress_model(self) -> str:
-        json_string = json.dumps(
-            {"trained_model": self.to_dict()}, separators=(",", ":")
-        )
+        json_string = json.dumps(self.serialize_model(), separators=(",", ":"))
         return base64.b64encode(gzip.compress(json_string.encode("utf-8"))).decode(
             "ascii"
         )
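The new serialize_model()/serialize_and_compress_model() split means the compressed payload is just gzipped, base64-encoded JSON. A minimal standalone sketch of the reverse trip (illustrative only, not part of the commit):

import base64
import gzip
import json

def decompress_model(b64_model: str) -> dict:
    # Reverse of serialize_and_compress_model: base64 -> gunzip -> JSON
    return json.loads(gzip.decompress(base64.b64decode(b64_model)))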
@@ -50,13 +50,13 @@ class TreeNode:
     def __init__(
         self,
         node_idx: int,
-        default_left: bool = None,
-        decision_type: str = None,
-        left_child: int = None,
-        right_child: int = None,
-        split_feature: int = None,
-        threshold: float = None,
-        leaf_value: float = None,
+        default_left: Optional[bool] = None,
+        decision_type: Optional[str] = None,
+        left_child: Optional[int] = None,
+        right_child: Optional[int] = None,
+        split_feature: Optional[int] = None,
+        threshold: Optional[float] = None,
+        leaf_value: Optional[float] = None,
     ):
         self._node_idx = node_idx
         self._decision_type = decision_type
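These Optional[...] rewrites, repeated across the whole commit, are what mypy --strict requires: strict mode disables implicit Optional, so a parameter whose default is None must say so in its annotation. A two-line sketch of the rule (illustrative):

from typing import Optional

def ok(x: Optional[str] = None) -> None: ...   # accepted under --strict
def bad(x: str = None) -> None: ...            # rejected: implicit Optional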
@@ -67,8 +67,8 @@ class TreeNode:
         self._leaf_value = leaf_value
         self._default_left = default_left
 
-    def to_dict(self):
-        d = dict()
+    def to_dict(self) -> Dict[str, Any]:
+        d: Dict[str, Any] = {}
         add_if_exists(d, "node_index", self._node_idx)
         add_if_exists(d, "decision_type", self._decision_type)
         if self._leaf_value is None:
@@ -85,9 +85,9 @@ class Tree(ModelSerializer):
     def __init__(
         self,
         feature_names: List[str],
-        target_type: str = None,
-        tree_structure: List[TreeNode] = [],
-        classification_labels: List[str] = None,
+        target_type: Optional[str] = None,
+        tree_structure: Optional[List[TreeNode]] = None,
+        classification_labels: Optional[List[str]] = None,
     ):
         super().__init__(
             feature_names=feature_names,
@@ -96,9 +96,9 @@ class Tree(ModelSerializer):
         )
         if target_type == "regression" and classification_labels:
             raise ValueError("regression does not support classification_labels")
-        self._tree_structure = tree_structure
+        self._tree_structure = tree_structure or []
 
-    def to_dict(self):
+    def to_dict(self) -> Dict[str, Any]:
         d = super().to_dict()
         add_if_exists(d, "tree_structure", [t.to_dict() for t in self._tree_structure])
         return {"tree": d}
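Switching tree_structure from a [] default to None (resolved with "tree_structure or []") also removes Python's shared-mutable-default pitfall: a list literal in a def line is created once and reused by every call. A standalone sketch (illustrative):

def buggy(items=[]):        # one list object shared across calls
    items.append(1)
    return items

def fixed(items=None):      # fresh list per call, as Tree now does
    items = items or []
    items.append(1)
    return items

buggy()
print(buggy())  # [1, 1] - state leaked from the first call
fixed()
print(fixed())  # [1]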
@@ -109,10 +109,10 @@ class Ensemble(ModelSerializer):
         self,
         feature_names: List[str],
         trained_models: List[ModelSerializer],
-        output_aggregator: dict,
-        target_type: str = None,
-        classification_labels: List[str] = None,
-        classification_weights: List[float] = None,
+        output_aggregator: Dict[str, Any],
+        target_type: Optional[str] = None,
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
     ):
         super().__init__(
             feature_names=feature_names,
@@ -123,7 +123,7 @@ class Ensemble(ModelSerializer):
         self._classification_weights = classification_weights
         self._output_aggregator = output_aggregator
 
-    def to_dict(self):
+    def to_dict(self) -> Dict[str, Any]:
         d = super().to_dict()
         trained_models = None
         if self._trained_models:
@@ -2,7 +2,7 @@
 # Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
 # See the LICENSE file in the project root for more information
 
-from typing import List, Union
+from typing import List, Union, Optional
 
 import numpy as np
 
@@ -23,8 +23,8 @@ class ModelTransformer:
         self,
         model,
         feature_names: List[str],
-        classification_labels: List[str] = None,
-        classification_weights: List[float] = None,
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
     ):
         self._feature_names = feature_names
         self._model = model
@@ -56,8 +56,8 @@ class SKLearnTransformer(ModelTransformer):
         self,
         model,
         feature_names: List[str],
-        classification_labels: List[str] = None,
-        classification_weights: List[float] = None,
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
     ):
         """
         Base class for SKLearn transformations
@@ -120,7 +120,7 @@ class SKLearnDecisionTreeTransformer(SKLearnTransformer):
         self,
         model: Union[DecisionTreeRegressor, DecisionTreeClassifier],
         feature_names: List[str],
-        classification_labels: List[str] = None,
+        classification_labels: Optional[List[str]] = None,
     ):
         """
         Transforms a Decision Tree model (Regressor|Classifier) into a ES Supported Tree format
@@ -148,7 +148,7 @@ class SKLearnDecisionTreeTransformer(SKLearnTransformer):
             check_is_fitted(self._model, ["classes_"])
             if tree_classes is None:
                 tree_classes = [str(c) for c in self._model.classes_]
-        nodes = list()
+        nodes = []
         tree_state = self._model.tree_.__getstate__()
         for i in range(len(tree_state["nodes"])):
             nodes.append(
@@ -169,8 +169,8 @@ class SKLearnForestTransformer(SKLearnTransformer):
         self,
         model: Union[RandomForestClassifier, RandomForestRegressor],
         feature_names: List[str],
-        classification_labels: List[str] = None,
-        classification_weights: List[float] = None,
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
     ):
         super().__init__(
             model, feature_names, classification_labels, classification_weights
@@ -235,7 +235,7 @@ class SKLearnForestClassifierTransformer(SKLearnForestTransformer):
         self,
         model: RandomForestClassifier,
         feature_names: List[str],
-        classification_labels: List[str] = None,
+        classification_labels: Optional[List[str]] = None,
     ):
         super().__init__(model, feature_names, classification_labels)
 
@@ -259,8 +259,8 @@ class XGBoostForestTransformer(ModelTransformer):
         feature_names: List[str],
         base_score: float = 0.5,
         objective: str = "reg:squarederror",
-        classification_labels: List[str] = None,
-        classification_weights: List[float] = None,
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
     ):
         super().__init__(
             model, feature_names, classification_labels, classification_weights
@@ -330,25 +330,24 @@ class XGBoostForestTransformer(ModelTransformer):
 
         :return: A list of Tree objects
         """
-        if self._model.booster not in {"dart", "gbtree"}:
-            raise ValueError("booster must exist and be of type dart or gbtree")
+        self.check_model_booster()
 
         tree_table = self._model.trees_to_dataframe()
-        transformed_trees = list()
+        transformed_trees = []
         curr_tree = None
-        tree_nodes = list()
+        tree_nodes = []
         for _, row in tree_table.iterrows():
             if row["Tree"] != curr_tree:
                 if len(tree_nodes) > 0:
                     transformed_trees.append(self.build_tree(tree_nodes))
                 curr_tree = row["Tree"]
-                tree_nodes = list()
+                tree_nodes = []
             tree_nodes.append(self.build_tree_node(row, curr_tree))
         # add last tree
         if len(tree_nodes) > 0:
             transformed_trees.append(self.build_tree(tree_nodes))
         # We add this stump as XGBoost adds the base_score to the regression outputs
-        if self._objective.startswith("reg"):
+        if self._objective.partition(":")[0] == "reg":
             transformed_trees.append(self.build_base_score_stump())
         return transformed_trees
 
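The objective check is also tightened: startswith("reg") accepts any objective whose name merely begins with those letters, while partition(":")[0] compares the objective family exactly. A standalone sketch (the last objective name is hypothetical, purely to show where the two predicates differ):

def is_regression(objective: str) -> bool:
    # exact family match, as the code above now does
    return objective.partition(":")[0] == "reg"

print(is_regression("reg:squarederror"))    # True
print(is_regression("binary:logistic"))     # False
print("regdummy:foo".startswith("reg"))     # True - the old check would misfire
print(is_regression("regdummy:foo"))        # False (hypothetical name)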
@@ -361,9 +360,16 @@ class XGBoostForestTransformer(ModelTransformer):
     def is_objective_supported(self) -> bool:
         return False
 
+    def check_model_booster(self):
+        # xgboost v1 made booster default to 'None' meaning 'gbtree'
+        if self._model.booster not in {"dart", "gbtree", None}:
+            raise ValueError(
+                f"booster must exist and be of type 'dart' or "
+                f"'gbtree', was {self._model.booster!r}"
+            )
+
     def transform(self) -> Ensemble:
-        if self._model.booster not in {"dart", "gbtree"}:
-            raise ValueError("booster must exist and be of type dart or gbtree")
+        self.check_model_booster()
 
         if not self.is_objective_supported():
             raise ValueError(f"Unsupported objective '{self._objective}'")
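The new guard accepts None because, per the comment in the helper above, xgboost v1 changed the booster parameter's default to None (still meaning gbtree), so models built with default settings now report None. A sketch of the accepted values (illustrative):

for booster in ("gbtree", "dart", None, "gblinear"):
    verdict = "accepted" if booster in {"dart", "gbtree", None} else "rejected"
    print(f"booster={booster!r}: {verdict}")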
@@ -381,8 +387,12 @@ class XGBoostForestTransformer(ModelTransformer):
 
 class XGBoostRegressorTransformer(XGBoostForestTransformer):
     def __init__(self, model: XGBRegressor, feature_names: List[str]):
+        # XGBRegressor.base_score defaults to 0.5.
+        base_score = model.base_score
+        if base_score is None:
+            base_score = 0.5
         super().__init__(
-            model.get_booster(), feature_names, model.base_score, model.objective
+            model.get_booster(), feature_names, base_score, model.objective
         )
 
     def determine_target_type(self) -> str:
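Same xgboost v1 quirk for base_score: it can now come back as None, and the serialized Elasticsearch model needs the effective numeric value, so the commit falls back to the documented default of 0.5. The fallback in isolation (a sketch; model is anything exposing a base_score attribute):

def effective_base_score(model) -> float:
    # None means "use xgboost's default" in v1; that default is 0.5
    return 0.5 if model.base_score is None else model.base_score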
@@ -405,7 +415,7 @@ class XGBoostClassifierTransformer(XGBoostForestTransformer):
         self,
         model: XGBClassifier,
         feature_names: List[str],
-        classification_labels: List[str] = None,
+        classification_labels: Optional[List[str]] = None,
     ):
         super().__init__(
             model.get_booster(),
@@ -2,9 +2,9 @@
 # Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
 # See the LICENSE file in the project root for more information
 
-from typing import Union, List
+from typing import Union, List, Optional, Tuple, TYPE_CHECKING, cast
 
-import numpy as np
+import numpy as np  # type: ignore
 
 from eland.common import es_version
 from eland.ml._model_transformers import (
@@ -14,15 +14,20 @@ from eland.ml._model_transformers import (
     XGBoostRegressorTransformer,
     XGBoostClassifierTransformer,
 )
+from eland.ml._model_serializer import ModelSerializer
 from eland.ml._optional import import_optional_dependency
 from eland.ml.ml_model import MLModel
 
 sklearn = import_optional_dependency("sklearn")
 xgboost = import_optional_dependency("xgboost")
 
-from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
-from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
-from xgboost import XGBRegressor, XGBClassifier
+from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor  # type: ignore
+from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor  # type: ignore
+from xgboost import XGBRegressor, XGBClassifier  # type: ignore
 
 
+if TYPE_CHECKING:
+    from elasticsearch import Elasticsearch  # type: ignore # noqa: F401
+
+
 class ImportedMLModel(MLModel):
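The TYPE_CHECKING block lets the annotations below name the Elasticsearch class without making elasticsearch a hard runtime import: type checkers evaluate the block, the interpreter never does, and the quoted "Elasticsearch" annotation defers resolution. A minimal sketch of the pattern (the ping helper is hypothetical):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # seen by mypy only, never executed
    from elasticsearch import Elasticsearch

def ping(client: "Elasticsearch") -> bool:
    # hypothetical helper; the quoted annotation keeps this module import-free
    return client.ping()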
@@ -91,7 +96,7 @@ class ImportedMLModel(MLModel):
 
     def __init__(
         self,
-        es_client,
+        es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
         model_id: str,
         model: Union[
             DecisionTreeClassifier,
@@ -102,15 +107,16 @@ class ImportedMLModel(MLModel):
             XGBRegressor,
         ],
         feature_names: List[str],
-        classification_labels: List[str] = None,
-        classification_weights: List[float] = None,
-        overwrite=False,
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
+        overwrite: bool = False,
     ):
         super().__init__(es_client, model_id)
 
         self._feature_names = feature_names
         self._model_type = None
 
+        serializer: ModelSerializer  # type def
         # Transform model
         if isinstance(model, DecisionTreeRegressor):
             serializer = SKLearnDecisionTreeTransformer(
@@ -161,7 +167,7 @@ class ImportedMLModel(MLModel):
             model_id=self._model_id, body=body,
         )
 
-    def predict(self, X):
+    def predict(self, X: Union[List[float], List[List[float]]]) -> np.ndarray:
         """
         Make a prediction using a trained model stored in Elasticsearch.
 
@@ -190,7 +196,7 @@ class ImportedMLModel(MLModel):
 
         >>> # Get some test results
         >>> regressor.predict(np.array(test_data))
-        array([0.23733574, 1.1897984 ], dtype=float32)
+        array([0.06062475, 0.9990102 ], dtype=float32)
 
         >>> # Serialise the model to Elasticsearch
         >>> feature_names = ["f0", "f1", "f2", "f3", "f4", "f5"]
@@ -199,7 +205,7 @@ class ImportedMLModel(MLModel):
 
         >>> # Get some test results from Elasticsearch model
         >>> es_model.predict(test_data)
-        array([0.2373357, 1.1897984], dtype=float32)
+        array([0.0606248 , 0.99901026], dtype=float32)
 
         >>> # Delete model from Elasticsearch
         >>> es_model.delete_model()
@@ -207,20 +213,22 @@ class ImportedMLModel(MLModel):
         """
         docs = []
         if isinstance(X, list):
-            # Is it a list of lists?
-            if all(isinstance(i, list) for i in X):
-                for i in X:
-                    doc = dict()
-                    doc["_source"] = dict(zip(self._feature_names, i))
-                    docs.append(doc)
-
-            else:  # single feature vector1
-                doc = dict()
-                doc["_source"] = dict(zip(self._feature_names, i))
+            # Is it a list of floats?
+            if all(isinstance(i, float) for i in X):
+                features = cast(List[List[float]], [X])
+            else:
+                features = cast(List[List[float]], X)
+            for i in features:
+                doc = {"_source": dict(zip(self._feature_names, i))}
                 docs.append(doc)
         else:
             raise NotImplementedError(f"Prediction for type {type(X)}, not supported")
 
+        # field_mappings -> field_map in ES 7.7
+        field_map_name = (
+            "field_map" if es_version(self._client) >= (7, 7) else "field_mappings"
+        )
+
         results = self._client.ingest.simulate(
             body={
                 "pipeline": {
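The rewritten branch normalizes both accepted shapes to a list of rows before building the ingest documents, and in passing it fixes the old single-vector path, which referenced the loop variable i outside any loop. The normalization on its own (a sketch with hypothetical feature names):

feature_names = ["f0", "f1"]

def build_docs(X):
    # a flat list of floats is a single feature vector: wrap it
    features = [X] if all(isinstance(i, float) for i in X) else X
    return [{"_source": dict(zip(feature_names, row))} for row in features]

print(build_docs([0.1, 0.2]))                 # one doc
print(build_docs([[0.1, 0.2], [0.3, 0.4]]))   # two docs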
@@ -229,7 +237,7 @@ class ImportedMLModel(MLModel):
                         "inference": {
                             "model_id": self._model_id,
                             "inference_config": {self._model_type: {}},
-                            "field_mappings": {},
+                            field_map_name: {},
                         }
                     }
                 ]
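Put together, predict() simulates an ingest pipeline whose inference processor carries the version-appropriate field-map key. Roughly what the request body looks like against a 7.7+ cluster (a sketch; the model id, inference type, and docs are placeholders, and the surrounding keys are inferred from the context lines above):

body = {
    "pipeline": {
        "processors": [
            {
                "inference": {
                    "model_id": "my_model",                 # placeholder
                    "inference_config": {"regression": {}},
                    "field_map": {},  # "field_mappings" on clusters before 7.7
                }
            }
        ]
    },
    "docs": [{"_source": {"f0": 0.1, "f1": 0.2}}],          # placeholder docs
}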
@@ -34,7 +34,7 @@ class MLModel:
         self._client = ensure_es_client(es_client)
         self._model_id = model_id
 
-    def delete_model(self):
+    def delete_model(self) -> None:
         """
         Delete an inference model saved in Elasticsearch
 
@@ -112,7 +112,7 @@ class TestImportedMLModel:
     def test_xgb_classifier(self):
         # Train model
        training_data = datasets.make_classification(n_features=5)
-        classifier = XGBClassifier()
+        classifier = XGBClassifier(booster="gbtree")
         classifier.fit(training_data[0], training_data[1])
 
         # Get some test results
@@ -150,9 +150,36 @@ class TestImportedMLModel:
         es_model = ImportedMLModel(
             ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True
         )
+
         es_results = es_model.predict(test_data)
 
         np.testing.assert_almost_equal(test_results, es_results, decimal=2)
 
         # Clean up
         es_model.delete_model()
+
+    def test_predict_single_feature_vector(self):
+        # Train model
+        training_data = datasets.make_regression(n_features=1)
+        regressor = XGBRegressor()
+        regressor.fit(training_data[0], training_data[1])
+
+        # Get some test results
+        test_data = [[0.1]]
+        test_results = regressor.predict(np.asarray(test_data))
+
+        # Serialise the models to Elasticsearch
+        feature_names = ["f0"]
+        model_id = "test_xgb_regressor"
+
+        es_model = ImportedMLModel(
+            ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True
+        )
+
+        # Single feature
+        es_results = es_model.predict(test_data[0])
+
+        np.testing.assert_almost_equal(test_results, es_results, decimal=2)
+
+        # Clean up
+        es_model.delete_model()
@@ -29,6 +29,8 @@ TYPED_FILES = {
     "eland/index.py",
     "eland/query.py",
     "eland/tasks.py",
+    "eland/ml/_model_serializer.py",
+    "eland/ml/imported_ml_model.py",
 }
 
 
@@ -50,6 +52,8 @@ def lint(session):
     # TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/")
     session.log("mypy --strict eland/")
     for typed_file in TYPED_FILES:
+        if not os.path.isfile(typed_file):
+            session.error(f"The file {typed_file!r} couldn't be found")
         popen = subprocess.Popen(
             f"mypy --strict {typed_file}",
             shell=True,
@@ -5,5 +5,5 @@ pytest>=5.2.1
 nbval
 numpydoc>=0.9.0
 scikit-learn>=0.22.1
-xgboost==0.90
+xgboost>=1
 nox
setup.py (4 additions)
@@ -177,4 +177,8 @@ setup(
     packages=find_packages(include=["eland", "eland.*"]),
     install_requires=["elasticsearch==7.7.0a2", "pandas>=1", "matplotlib", "numpy"],
     python_requires=">=3.6",
+    extras_require={
+        "xgboost": ["xgboost>=0.90,<2"],
+        "scikit-learn": ["scikit-learn>=0.22.1,<1"],
+    },
 )
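With extras_require in place, the optional ML backends can be requested at install time, e.g. python -m pip install "eland[xgboost]" or "eland[scikit-learn]", while a bare pip install eland keeps working because both libraries are loaded through import_optional_dependency.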