mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
[ML] Add support for LGBMRegressor models
This commit is contained in:
parent
efb9e3b4c4
commit
6ee282e19f
@ -80,6 +80,10 @@ class TreeNode:
|
|||||||
self._leaf_value = leaf_value
|
self._leaf_value = leaf_value
|
||||||
self._default_left = default_left
|
self._default_left = default_left
|
||||||
|
|
||||||
|
@property
|
||||||
|
def node_idx(self) -> int:
|
||||||
|
return self._node_idx
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
d: Dict[str, Any] = {}
|
d: Dict[str, Any] = {}
|
||||||
add_if_exists(d, "node_index", self._node_idx)
|
add_if_exists(d, "node_index", self._node_idx)
|
||||||
|
@ -38,6 +38,10 @@ if TYPE_CHECKING:
|
|||||||
from xgboost import XGBRegressor, XGBClassifier # type: ignore # noqa: F401
|
from xgboost import XGBRegressor, XGBClassifier # type: ignore # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
from lightgbm import LGBMRegressor # type: ignore # noqa: f401
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ImportedMLModel(MLModel):
|
class ImportedMLModel(MLModel):
|
||||||
@ -59,14 +63,23 @@ class ImportedMLModel(MLModel):
|
|||||||
- sklearn.tree.DecisionTreeRegressor
|
- sklearn.tree.DecisionTreeRegressor
|
||||||
- sklearn.ensemble.RandomForestRegressor
|
- sklearn.ensemble.RandomForestRegressor
|
||||||
- sklearn.ensemble.RandomForestClassifier
|
- sklearn.ensemble.RandomForestClassifier
|
||||||
|
- lightgbm.LGBMRegressor
|
||||||
|
- Categorical fields are expected to already be processed
|
||||||
|
- Only the following objectives are supported
|
||||||
|
- "regression"
|
||||||
|
- "regression_l1"
|
||||||
|
- "huber"
|
||||||
|
- "fair"
|
||||||
|
- "quantile"
|
||||||
|
- "mape"
|
||||||
- xgboost.XGBClassifier
|
- xgboost.XGBClassifier
|
||||||
- only the following operators are supported:
|
- only the following objectives are supported:
|
||||||
- "binary:logistic"
|
- "binary:logistic"
|
||||||
- "binary:hinge"
|
- "binary:hinge"
|
||||||
- "multi:softmax"
|
- "multi:softmax"
|
||||||
- "multi:softprob"
|
- "multi:softprob"
|
||||||
- xgboost.XGBRegressor
|
- xgboost.XGBRegressor
|
||||||
- only the following operators are supportd:
|
- only the following objectives are supported:
|
||||||
- "reg:squarederror"
|
- "reg:squarederror"
|
||||||
- "reg:linear"
|
- "reg:linear"
|
||||||
- "reg:squaredlogerror"
|
- "reg:squaredlogerror"
|
||||||
@ -130,6 +143,7 @@ class ImportedMLModel(MLModel):
|
|||||||
"RandomForestClassifier",
|
"RandomForestClassifier",
|
||||||
"XGBClassifier",
|
"XGBClassifier",
|
||||||
"XGBRegressor",
|
"XGBRegressor",
|
||||||
|
"LGBMRegressor",
|
||||||
],
|
],
|
||||||
feature_names: List[str],
|
feature_names: List[str],
|
||||||
classification_labels: Optional[List[str]] = None,
|
classification_labels: Optional[List[str]] = None,
|
||||||
|
@ -82,3 +82,16 @@ try:
|
|||||||
_MODEL_TRANSFORMERS.update(_XGBOOST_MODEL_TRANSFORMERS)
|
_MODEL_TRANSFORMERS.update(_XGBOOST_MODEL_TRANSFORMERS)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .lightgbm import (
|
||||||
|
LGBMRegressor,
|
||||||
|
LGBMForestTransformer,
|
||||||
|
LGBMRegressorTransformer,
|
||||||
|
_MODEL_TRANSFORMERS as _LIGHTGBM_MODEL_TRANSFORMERS,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ += ["LGBMRegressor", "LGBMForestTransformer", "LGBMRegressorTransformer"]
|
||||||
|
_MODEL_TRANSFORMERS.update(_LIGHTGBM_MODEL_TRANSFORMERS)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
188
eland/ml/transformers/lightgbm.py
Normal file
188
eland/ml/transformers/lightgbm.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
# Licensed to Elasticsearch B.V. under one or more contributor
|
||||||
|
# license agreements. See the NOTICE file distributed with
|
||||||
|
# this work for additional information regarding copyright
|
||||||
|
# ownership. Elasticsearch B.V. licenses this file to you under
|
||||||
|
# the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
# not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
from typing import Optional, List, Dict, Any, Type
|
||||||
|
from .base import ModelTransformer
|
||||||
|
from .._model_serializer import Ensemble, Tree, TreeNode
|
||||||
|
from ..ml_model import MLModel
|
||||||
|
from .._optional import import_optional_dependency
|
||||||
|
|
||||||
|
import_optional_dependency("lightgbm", on_version="warn")
|
||||||
|
|
||||||
|
from lightgbm import Booster, LGBMRegressor # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def transform_decider(decider: str) -> str:
    """Translate a comparison operator string into its short decision-type name.

    :param decider: One of ``"<="``, ``"<"``, ``">"`` or ``">="``.
    :return: ``"lte"``, ``"lt"``, ``"gt"`` or ``"gte"`` respectively.
    :raises ValueError: If ``decider`` is not one of the four supported
        operators. The message includes the offending value.
    """
    deciders = {"<=": "lte", "<": "lt", ">": "gt", ">=": "gte"}
    try:
        return deciders[decider]
    except (KeyError, TypeError):
        # TypeError covers unhashable input so callers always see ValueError.
        # The original message carried a literal "%s" that was never
        # interpolated; include the actual value so the failure is diagnosable.
        raise ValueError(
            f"Unsupported splitting decider: {decider}. "
            "Only <=, <, >=, and > are allowed."
        ) from None
|
||||||
|
|
||||||
|
|
||||||
|
class Counter:
    """Minimal mutable integer counter.

    A single instance can be passed through recursive calls so that every
    call observes increments made by the others.
    """

    def __init__(self, start: int = 0):
        # Current count; advanced only through inc().
        self._value = start

    def inc(self) -> "Counter":
        """Advance the counter by one and return this same instance."""
        self._value = self._value + 1
        return self

    def value(self) -> int:
        """Return the current count."""
        return self._value
|
||||||
|
|
||||||
|
|
||||||
|
class LGBMForestTransformer(ModelTransformer):
    """
    Base class for transforming LightGBM models into ensemble models supported by Elasticsearch

    warning: do not use directly. Use a derived class instead
    """

    def __init__(
        self,
        model: Booster,
        feature_names: List[str],
        classification_labels: Optional[List[str]] = None,
        classification_weights: Optional[List[float]] = None,
    ):
        super().__init__(
            model, feature_names, classification_labels, classification_weights
        )
        self._node_decision_type = "lte"
        # Objective the booster was trained with; subclasses validate it
        # through is_objective_supported().
        self._objective = model.params["objective"]

    def build_tree(self, tree_json_obj: Dict[str, Any]) -> Tree:
        """
        Convert one entry of the dumped model's "tree_info" list into a Tree.

        Nodes are numbered in pre-order (parent, left subtree, right subtree)
        starting at 0, and the resulting node list is ordered by that index.

        :param tree_json_obj: Dict holding the nested tree under "tree_structure"
        :return: The equivalent Tree object
        """
        collected: List[TreeNode] = []

        def visit(node_json: Dict[str, Any], counter: Counter) -> int:
            # Claim this node's index before descending so parents always get
            # a smaller index than their children.
            node_id = counter.value()

            if "leaf_value" in node_json:
                leaf = TreeNode(
                    node_idx=node_id,
                    leaf_value=[float(node_json["leaf_value"])],
                )
                collected.append(leaf)
                return node_id

            # The shared counter is advanced once per child; the left subtree
            # is numbered completely before the right one starts.
            left_id = visit(node_json["left_child"], counter.inc())
            right_id = visit(node_json["right_child"], counter.inc())

            split = TreeNode(
                node_idx=node_id,
                default_left=node_json["default_left"],
                split_feature=node_json["split_feature"],
                threshold=float(node_json["threshold"]),
                decision_type=transform_decider(node_json["decision_type"]),
                left_child=left_id,
                right_child=right_id,
            )
            collected.append(split)
            return node_id

        visit(tree_json_obj["tree_structure"], Counter())
        # Nodes were appended children-first; restore index order.
        collected.sort(key=lambda node: node.node_idx)

        return Tree(
            feature_names=self._feature_names,
            target_type=self.determine_target_type(),
            tree_structure=collected,
        )

    def build_forest(self) -> List[Tree]:
        """
        This builds out the forest of trees as described by LightGBM into a format
        supported by Elasticsearch

        :return: A list of Tree objects
        """
        self.check_model_booster()
        dumped_trees = self._model.dump_model()["tree_info"]
        return [self.build_tree(tree_json) for tree_json in dumped_trees]

    def build_aggregator_output(self) -> Dict[str, Any]:
        """Return the ensemble's output aggregator; subclasses must override."""
        raise NotImplementedError("build_aggregator_output must be implemented")

    def determine_target_type(self) -> str:
        """Return the target type string; subclasses must override."""
        raise NotImplementedError("determine_target_type must be implemented")

    def is_objective_supported(self) -> bool:
        """Report whether the model's objective is supported; False in this base class."""
        return False

    def check_model_booster(self) -> None:
        """Validate the booster configuration; subclasses must override."""
        raise NotImplementedError("check_model_booster must be implemented")

    def transform(self) -> Ensemble:
        """
        Validate the model and convert it into an Ensemble.

        :return: Ensemble built from the forest and the subclass-provided aggregator
        :raises ValueError: if the model's objective is not supported
        """
        self.check_model_booster()

        if not self.is_objective_supported():
            raise ValueError(f"Unsupported objective '{self._objective}'")

        trees = self.build_forest()
        return Ensemble(
            feature_names=self._feature_names,
            trained_models=trees,
            output_aggregator=self.build_aggregator_output(),
            classification_labels=self._classification_labels,
            classification_weights=self._classification_weights,
            target_type=self.determine_target_type(),
        )
|
||||||
|
|
||||||
|
|
||||||
|
class LGBMRegressorTransformer(LGBMForestTransformer):
    """
    Transformer for fitted lightgbm.LGBMRegressor models.

    Produces a regression ensemble whose trees are combined with the
    "weighted_sum" output aggregator.
    """

    def __init__(self, model: LGBMRegressor, feature_names: List[str]):
        """
        :param model: Fitted LGBMRegressor; its underlying ``booster_`` is transformed
        :param feature_names: Names of the features
        """
        super().__init__(model.booster_, feature_names)

    def is_objective_supported(self) -> bool:
        """Return True when the booster's objective is one of the supported regression objectives."""
        return self._objective in {
            "regression",
            "regression_l1",
            "huber",
            "fair",
            "quantile",
            "mape",
        }

    def check_model_booster(self) -> None:
        """
        Ensure the booster declares a supported boosting type.

        :raises ValueError: if "boosting_type" is missing from the booster
            params or is not one of 'gbdt', 'rf', 'dart', 'goss'
        """
        # Use .get so a missing key yields the documented ValueError
        # ("must exist") instead of an unrelated KeyError.
        boosting_type = self._model.params.get("boosting_type")
        if boosting_type not in {"gbdt", "rf", "dart", "goss"}:
            raise ValueError(
                "boosting type must exist and be of type 'gbdt', 'rf', 'dart', or 'goss'"
                f", was {boosting_type!r}"
            )

    def determine_target_type(self) -> str:
        """Return the target type for this transformer: "regression"."""
        return "regression"

    def build_aggregator_output(self) -> Dict[str, Any]:
        """Return the output aggregator definition: an unparameterised weighted sum."""
        return {"weighted_sum": {}}

    @property
    def model_type(self) -> str:
        """Model type constant taken from MLModel."""
        return MLModel.TYPE_REGRESSION
|
||||||
|
|
||||||
|
|
||||||
|
_MODEL_TRANSFORMERS: Dict[type, Type[ModelTransformer]] = {
|
||||||
|
LGBMRegressor: LGBMRegressorTransformer
|
||||||
|
}
|
@ -38,6 +38,13 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
HAS_XGBOOST = False
|
HAS_XGBOOST = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from lightgbm import LGBMRegressor
|
||||||
|
|
||||||
|
HAS_LIGHTGBM = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_LIGHTGBM = False
|
||||||
|
|
||||||
|
|
||||||
requires_sklearn = pytest.mark.skipif(
|
requires_sklearn = pytest.mark.skipif(
|
||||||
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
|
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
|
||||||
@ -50,6 +57,10 @@ requires_no_ml_extras = pytest.mark.skipif(
|
|||||||
reason="This test requires 'scikit-learn' and 'xgboost' to not be installed",
|
reason="This test requires 'scikit-learn' and 'xgboost' to not be installed",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
requires_lightgbm = pytest.mark.skipif(
|
||||||
|
not HAS_LIGHTGBM, reason="This test requires 'lightgbm' package to run"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestImportedMLModel:
|
class TestImportedMLModel:
|
||||||
@requires_no_ml_extras
|
@requires_no_ml_extras
|
||||||
@ -322,3 +333,35 @@ class TestImportedMLModel:
|
|||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
es_model.delete_model()
|
es_model.delete_model()
|
||||||
|
|
||||||
|
@requires_lightgbm
@pytest.mark.parametrize("compress_model_definition", [True, False])
def test_lgbm_regressor(self, compress_model_definition):
    """Predictions from the model imported into Elasticsearch match LightGBM's own."""
    # Train model
    features, targets = datasets.make_regression(n_features=5)
    regressor = LGBMRegressor()
    regressor.fit(features, targets)

    # Get some test results
    test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
    test_results = regressor.predict(np.asarray(test_data))

    # Serialise the models to Elasticsearch
    feature_names = ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"]
    model_id = "test_lgbm_regressor"

    es_model = ImportedMLModel(
        ES_TEST_CLIENT,
        model_id,
        regressor,
        feature_names,
        overwrite=True,
        es_compress_model_definition=compress_model_definition,
    )

    try:
        es_results = es_model.predict(test_data)

        np.testing.assert_almost_equal(test_results, es_results, decimal=2)
    finally:
        # Clean up even when the comparison fails so the uploaded model
        # does not linger in the test cluster.
        es_model.delete_model()
|
||||||
|
@ -49,6 +49,7 @@ TYPED_FILES = {
|
|||||||
"eland/ml/imported_ml_model.py",
|
"eland/ml/imported_ml_model.py",
|
||||||
"eland/ml/transformers/__init__.py",
|
"eland/ml/transformers/__init__.py",
|
||||||
"eland/ml/transformers/base.py",
|
"eland/ml/transformers/base.py",
|
||||||
|
"eland/ml/transformers/lightgbm.py",
|
||||||
"eland/ml/transformers/sklearn.py",
|
"eland/ml/transformers/sklearn.py",
|
||||||
"eland/ml/transformers/xgboost.py",
|
"eland/ml/transformers/xgboost.py",
|
||||||
}
|
}
|
||||||
@ -109,7 +110,7 @@ def test_ml_deps(session):
|
|||||||
session.install("-r", "requirements-dev.txt")
|
session.install("-r", "requirements-dev.txt")
|
||||||
session.run("python", "-m", "eland.tests.setup_tests")
|
session.run("python", "-m", "eland.tests.setup_tests")
|
||||||
|
|
||||||
session_uninstall("xgboost", "scikit-learn")
|
session_uninstall("xgboost", "scikit-learn", "lightgbm")
|
||||||
session.run("pytest", *(session.posargs or ("eland/tests/ml/",)))
|
session.run("pytest", *(session.posargs or ("eland/tests/ml/",)))
|
||||||
|
|
||||||
session.install(".[scikit-learn]")
|
session.install(".[scikit-learn]")
|
||||||
|
@ -8,3 +8,4 @@ numpydoc>=0.9.0
|
|||||||
scikit-learn>=0.22.1
|
scikit-learn>=0.22.1
|
||||||
xgboost>=1
|
xgboost>=1
|
||||||
nox
|
nox
|
||||||
|
lightgbm>=2.3.0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user