mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
[ML] Add support for LGBMRegressor models
This commit is contained in:
parent
efb9e3b4c4
commit
6ee282e19f
@ -80,6 +80,10 @@ class TreeNode:
|
||||
self._leaf_value = leaf_value
|
||||
self._default_left = default_left
|
||||
|
||||
@property
|
||||
def node_idx(self) -> int:
|
||||
return self._node_idx
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
d: Dict[str, Any] = {}
|
||||
add_if_exists(d, "node_index", self._node_idx)
|
||||
|
@ -38,6 +38,10 @@ if TYPE_CHECKING:
|
||||
from xgboost import XGBRegressor, XGBClassifier # type: ignore # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from lightgbm import LGBMRegressor # type: ignore # noqa: f401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class ImportedMLModel(MLModel):
|
||||
@ -59,14 +63,23 @@ class ImportedMLModel(MLModel):
|
||||
- sklearn.tree.DecisionTreeRegressor
|
||||
- sklearn.ensemble.RandomForestRegressor
|
||||
- sklearn.ensemble.RandomForestClassifier
|
||||
- lightgbm.LGBMRegressor
|
||||
- Categorical fields are expected to already be processed
|
||||
- Only the following objectives are supported
|
||||
- "regression"
|
||||
- "regression_l1"
|
||||
- "huber"
|
||||
- "fair"
|
||||
- "quantile"
|
||||
- "mape"
|
||||
- xgboost.XGBClassifier
|
||||
- only the following operators are supported:
|
||||
- only the following objectives are supported:
|
||||
- "binary:logistic"
|
||||
- "binary:hinge"
|
||||
- "multi:softmax"
|
||||
- "multi:softprob"
|
||||
- xgboost.XGBRegressor
|
||||
- only the following operators are supportd:
|
||||
- only the following objectives are supported:
|
||||
- "reg:squarederror"
|
||||
- "reg:linear"
|
||||
- "reg:squaredlogerror"
|
||||
@ -130,6 +143,7 @@ class ImportedMLModel(MLModel):
|
||||
"RandomForestClassifier",
|
||||
"XGBClassifier",
|
||||
"XGBRegressor",
|
||||
"LGBMRegressor",
|
||||
],
|
||||
feature_names: List[str],
|
||||
classification_labels: Optional[List[str]] = None,
|
||||
|
@ -82,3 +82,16 @@ try:
|
||||
_MODEL_TRANSFORMERS.update(_XGBOOST_MODEL_TRANSFORMERS)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from .lightgbm import (
|
||||
LGBMRegressor,
|
||||
LGBMForestTransformer,
|
||||
LGBMRegressorTransformer,
|
||||
_MODEL_TRANSFORMERS as _LIGHTGBM_MODEL_TRANSFORMERS,
|
||||
)
|
||||
|
||||
__all__ += ["LGBMRegressor", "LGBMForestTransformer", "LGBMRegressorTransformer"]
|
||||
_MODEL_TRANSFORMERS.update(_LIGHTGBM_MODEL_TRANSFORMERS)
|
||||
except ImportError:
|
||||
pass
|
||||
|
188
eland/ml/transformers/lightgbm.py
Normal file
188
eland/ml/transformers/lightgbm.py
Normal file
@ -0,0 +1,188 @@
|
||||
# Licensed to Elasticsearch B.V. under one or more contributor
|
||||
# license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright
|
||||
# ownership. Elasticsearch B.V. licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Optional, List, Dict, Any, Type
|
||||
from .base import ModelTransformer
|
||||
from .._model_serializer import Ensemble, Tree, TreeNode
|
||||
from ..ml_model import MLModel
|
||||
from .._optional import import_optional_dependency
|
||||
|
||||
import_optional_dependency("lightgbm", on_version="warn")
|
||||
|
||||
from lightgbm import Booster, LGBMRegressor # type: ignore
|
||||
|
||||
|
||||
def transform_decider(decider: str) -> str:
|
||||
if decider == "<=":
|
||||
return "lte"
|
||||
if decider == "<":
|
||||
return "lt"
|
||||
if decider == ">":
|
||||
return "gt"
|
||||
if decider == ">=":
|
||||
return "gte"
|
||||
raise ValueError(
|
||||
"Unsupported splitting decider: %s. Only <=, <, >=, and > are allowed."
|
||||
)
|
||||
|
||||
|
||||
class Counter:
|
||||
def __init__(self, start: int = 0):
|
||||
self._value = start
|
||||
|
||||
def inc(self) -> "Counter":
|
||||
self._value += 1
|
||||
return self
|
||||
|
||||
def value(self) -> int:
|
||||
return self._value
|
||||
|
||||
|
||||
class LGBMForestTransformer(ModelTransformer):
|
||||
"""
|
||||
Base class for transforming LightGBM models into ensemble models supported by Elasticsearch
|
||||
|
||||
warning: do not use directly. Use a derived classes instead
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Booster,
|
||||
feature_names: List[str],
|
||||
classification_labels: Optional[List[str]] = None,
|
||||
classification_weights: Optional[List[float]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
model, feature_names, classification_labels, classification_weights
|
||||
)
|
||||
self._node_decision_type = "lte"
|
||||
self._objective = model.params["objective"]
|
||||
|
||||
def build_tree(self, tree_json_obj: Dict[str, Any]) -> Tree:
|
||||
tree_nodes = list()
|
||||
next_id = Counter()
|
||||
|
||||
def add_tree_node(tree_node_json_obj: Dict[str, Any], counter: Counter) -> int:
|
||||
curr_id = counter.value()
|
||||
if "leaf_value" in tree_node_json_obj:
|
||||
tree_nodes.append(
|
||||
TreeNode(
|
||||
node_idx=curr_id,
|
||||
leaf_value=[float(tree_node_json_obj["leaf_value"])],
|
||||
)
|
||||
)
|
||||
return curr_id
|
||||
left_id = add_tree_node(tree_node_json_obj["left_child"], counter.inc())
|
||||
right_id = add_tree_node(tree_node_json_obj["right_child"], counter.inc())
|
||||
tree_nodes.append(
|
||||
TreeNode(
|
||||
node_idx=curr_id,
|
||||
default_left=tree_node_json_obj["default_left"],
|
||||
split_feature=tree_node_json_obj["split_feature"],
|
||||
threshold=float(tree_node_json_obj["threshold"]),
|
||||
decision_type=transform_decider(
|
||||
tree_node_json_obj["decision_type"]
|
||||
),
|
||||
left_child=left_id,
|
||||
right_child=right_id,
|
||||
)
|
||||
)
|
||||
return curr_id
|
||||
|
||||
add_tree_node(tree_json_obj["tree_structure"], next_id)
|
||||
tree_nodes.sort(key=lambda n: n.node_idx)
|
||||
return Tree(
|
||||
feature_names=self._feature_names,
|
||||
target_type=self.determine_target_type(),
|
||||
tree_structure=tree_nodes,
|
||||
)
|
||||
|
||||
def build_forest(self) -> List[Tree]:
|
||||
"""
|
||||
This builds out the forest of trees as described by LightGBM into a format
|
||||
supported by Elasticsearch
|
||||
|
||||
:return: A list of Tree objects
|
||||
"""
|
||||
self.check_model_booster()
|
||||
json_dump = self._model.dump_model()
|
||||
return [self.build_tree(t) for t in json_dump["tree_info"]]
|
||||
|
||||
def build_aggregator_output(self) -> Dict[str, Any]:
|
||||
raise NotImplementedError("build_aggregator_output must be implemented")
|
||||
|
||||
def determine_target_type(self) -> str:
|
||||
raise NotImplementedError("determine_target_type must be implemented")
|
||||
|
||||
def is_objective_supported(self) -> bool:
|
||||
return False
|
||||
|
||||
def check_model_booster(self) -> None:
|
||||
raise NotImplementedError("check_model_booster must be implemented")
|
||||
|
||||
def transform(self) -> Ensemble:
|
||||
self.check_model_booster()
|
||||
|
||||
if not self.is_objective_supported():
|
||||
raise ValueError(f"Unsupported objective '{self._objective}'")
|
||||
|
||||
forest = self.build_forest()
|
||||
return Ensemble(
|
||||
feature_names=self._feature_names,
|
||||
trained_models=forest,
|
||||
output_aggregator=self.build_aggregator_output(),
|
||||
classification_labels=self._classification_labels,
|
||||
classification_weights=self._classification_weights,
|
||||
target_type=self.determine_target_type(),
|
||||
)
|
||||
|
||||
|
||||
class LGBMRegressorTransformer(LGBMForestTransformer):
|
||||
def __init__(self, model: LGBMRegressor, feature_names: List[str]):
|
||||
super().__init__(model.booster_, feature_names)
|
||||
|
||||
def is_objective_supported(self) -> bool:
|
||||
return self._objective in {
|
||||
"regression",
|
||||
"regression_l1",
|
||||
"huber",
|
||||
"fair",
|
||||
"quantile",
|
||||
"mape",
|
||||
}
|
||||
|
||||
def check_model_booster(self) -> None:
|
||||
if self._model.params["boosting_type"] not in {"gbdt", "rf", "dart", "goss"}:
|
||||
raise ValueError(
|
||||
f"boosting type must exist and be of type 'gbdt', 'rf', 'dart', or 'goss'"
|
||||
f", was {self._model.params['boosting_type']!r}"
|
||||
)
|
||||
|
||||
def determine_target_type(self) -> str:
|
||||
return "regression"
|
||||
|
||||
def build_aggregator_output(self) -> Dict[str, Any]:
|
||||
return {"weighted_sum": {}}
|
||||
|
||||
@property
|
||||
def model_type(self) -> str:
|
||||
return MLModel.TYPE_REGRESSION
|
||||
|
||||
|
||||
_MODEL_TRANSFORMERS: Dict[type, Type[ModelTransformer]] = {
|
||||
LGBMRegressor: LGBMRegressorTransformer
|
||||
}
|
@ -38,6 +38,13 @@ try:
|
||||
except ImportError:
|
||||
HAS_XGBOOST = False
|
||||
|
||||
try:
|
||||
from lightgbm import LGBMRegressor
|
||||
|
||||
HAS_LIGHTGBM = True
|
||||
except ImportError:
|
||||
HAS_LIGHTGBM = False
|
||||
|
||||
|
||||
requires_sklearn = pytest.mark.skipif(
|
||||
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
|
||||
@ -50,6 +57,10 @@ requires_no_ml_extras = pytest.mark.skipif(
|
||||
reason="This test requires 'scikit-learn' and 'xgboost' to not be installed",
|
||||
)
|
||||
|
||||
requires_lightgbm = pytest.mark.skipif(
|
||||
not HAS_LIGHTGBM, reason="This test requires 'lightgbm' package to run"
|
||||
)
|
||||
|
||||
|
||||
class TestImportedMLModel:
|
||||
@requires_no_ml_extras
|
||||
@ -322,3 +333,35 @@ class TestImportedMLModel:
|
||||
|
||||
# Clean up
|
||||
es_model.delete_model()
|
||||
|
||||
@requires_lightgbm
|
||||
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
||||
def test_lgbm_regressor(self, compress_model_definition):
|
||||
# Train model
|
||||
training_data = datasets.make_regression(n_features=5)
|
||||
regressor = LGBMRegressor()
|
||||
regressor.fit(training_data[0], training_data[1])
|
||||
|
||||
# Get some test results
|
||||
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
|
||||
test_results = regressor.predict(np.asarray(test_data))
|
||||
|
||||
# Serialise the models to Elasticsearch
|
||||
feature_names = ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"]
|
||||
model_id = "test_lgbm_regressor"
|
||||
|
||||
es_model = ImportedMLModel(
|
||||
ES_TEST_CLIENT,
|
||||
model_id,
|
||||
regressor,
|
||||
feature_names,
|
||||
overwrite=True,
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
|
||||
es_results = es_model.predict(test_data)
|
||||
|
||||
np.testing.assert_almost_equal(test_results, es_results, decimal=2)
|
||||
|
||||
# Clean up
|
||||
es_model.delete_model()
|
||||
|
@ -49,6 +49,7 @@ TYPED_FILES = {
|
||||
"eland/ml/imported_ml_model.py",
|
||||
"eland/ml/transformers/__init__.py",
|
||||
"eland/ml/transformers/base.py",
|
||||
"eland/ml/transformers/lightgbm.py",
|
||||
"eland/ml/transformers/sklearn.py",
|
||||
"eland/ml/transformers/xgboost.py",
|
||||
}
|
||||
@ -109,7 +110,7 @@ def test_ml_deps(session):
|
||||
session.install("-r", "requirements-dev.txt")
|
||||
session.run("python", "-m", "eland.tests.setup_tests")
|
||||
|
||||
session_uninstall("xgboost", "scikit-learn")
|
||||
session_uninstall("xgboost", "scikit-learn", "lightgbm")
|
||||
session.run("pytest", *(session.posargs or ("eland/tests/ml/",)))
|
||||
|
||||
session.install(".[scikit-learn]")
|
||||
|
@ -8,3 +8,4 @@ numpydoc>=0.9.0
|
||||
scikit-learn>=0.22.1
|
||||
xgboost>=1
|
||||
nox
|
||||
lightgbm>=2.3.0
|
||||
|
Loading…
x
Reference in New Issue
Block a user