[ML] Add support for LGBMRegressor models

This commit is contained in:
Benjamin Trent 2020-08-11 08:42:59 -04:00 committed by GitHub
parent efb9e3b4c4
commit 6ee282e19f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 268 additions and 3 deletions

View File

@ -80,6 +80,10 @@ class TreeNode:
self._leaf_value = leaf_value
self._default_left = default_left
@property
def node_idx(self) -> int:
return self._node_idx
def to_dict(self) -> Dict[str, Any]:
d: Dict[str, Any] = {}
add_if_exists(d, "node_index", self._node_idx)

View File

@ -38,6 +38,10 @@ if TYPE_CHECKING:
from xgboost import XGBRegressor, XGBClassifier # type: ignore # noqa: F401
except ImportError:
pass
try:
from lightgbm import LGBMRegressor # type: ignore # noqa: f401
except ImportError:
pass
class ImportedMLModel(MLModel):
@ -59,14 +63,23 @@ class ImportedMLModel(MLModel):
- sklearn.tree.DecisionTreeRegressor
- sklearn.ensemble.RandomForestRegressor
- sklearn.ensemble.RandomForestClassifier
- lightgbm.LGBMRegressor
- Categorical fields are expected to already be processed
- Only the following objectives are supported
- "regression"
- "regression_l1"
- "huber"
- "fair"
- "quantile"
- "mape"
- xgboost.XGBClassifier
- only the following operators are supported:
- only the following objectives are supported:
- "binary:logistic"
- "binary:hinge"
- "multi:softmax"
- "multi:softprob"
- xgboost.XGBRegressor
- only the following operators are supportd:
- only the following objectives are supported:
- "reg:squarederror"
- "reg:linear"
- "reg:squaredlogerror"
@ -130,6 +143,7 @@ class ImportedMLModel(MLModel):
"RandomForestClassifier",
"XGBClassifier",
"XGBRegressor",
"LGBMRegressor",
],
feature_names: List[str],
classification_labels: Optional[List[str]] = None,

View File

@ -82,3 +82,16 @@ try:
_MODEL_TRANSFORMERS.update(_XGBOOST_MODEL_TRANSFORMERS)
except ImportError:
pass
try:
from .lightgbm import (
LGBMRegressor,
LGBMForestTransformer,
LGBMRegressorTransformer,
_MODEL_TRANSFORMERS as _LIGHTGBM_MODEL_TRANSFORMERS,
)
__all__ += ["LGBMRegressor", "LGBMForestTransformer", "LGBMRegressorTransformer"]
_MODEL_TRANSFORMERS.update(_LIGHTGBM_MODEL_TRANSFORMERS)
except ImportError:
pass

View File

@ -0,0 +1,188 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional, List, Dict, Any, Type
from .base import ModelTransformer
from .._model_serializer import Ensemble, Tree, TreeNode
from ..ml_model import MLModel
from .._optional import import_optional_dependency
import_optional_dependency("lightgbm", on_version="warn")
from lightgbm import Booster, LGBMRegressor # type: ignore
def transform_decider(decider: str) -> str:
if decider == "<=":
return "lte"
if decider == "<":
return "lt"
if decider == ">":
return "gt"
if decider == ">=":
return "gte"
raise ValueError(
"Unsupported splitting decider: %s. Only <=, <, >=, and > are allowed."
)
class Counter:
def __init__(self, start: int = 0):
self._value = start
def inc(self) -> "Counter":
self._value += 1
return self
def value(self) -> int:
return self._value
class LGBMForestTransformer(ModelTransformer):
"""
Base class for transforming LightGBM models into ensemble models supported by Elasticsearch
warning: do not use directly. Use a derived classes instead
"""
def __init__(
self,
model: Booster,
feature_names: List[str],
classification_labels: Optional[List[str]] = None,
classification_weights: Optional[List[float]] = None,
):
super().__init__(
model, feature_names, classification_labels, classification_weights
)
self._node_decision_type = "lte"
self._objective = model.params["objective"]
def build_tree(self, tree_json_obj: Dict[str, Any]) -> Tree:
tree_nodes = list()
next_id = Counter()
def add_tree_node(tree_node_json_obj: Dict[str, Any], counter: Counter) -> int:
curr_id = counter.value()
if "leaf_value" in tree_node_json_obj:
tree_nodes.append(
TreeNode(
node_idx=curr_id,
leaf_value=[float(tree_node_json_obj["leaf_value"])],
)
)
return curr_id
left_id = add_tree_node(tree_node_json_obj["left_child"], counter.inc())
right_id = add_tree_node(tree_node_json_obj["right_child"], counter.inc())
tree_nodes.append(
TreeNode(
node_idx=curr_id,
default_left=tree_node_json_obj["default_left"],
split_feature=tree_node_json_obj["split_feature"],
threshold=float(tree_node_json_obj["threshold"]),
decision_type=transform_decider(
tree_node_json_obj["decision_type"]
),
left_child=left_id,
right_child=right_id,
)
)
return curr_id
add_tree_node(tree_json_obj["tree_structure"], next_id)
tree_nodes.sort(key=lambda n: n.node_idx)
return Tree(
feature_names=self._feature_names,
target_type=self.determine_target_type(),
tree_structure=tree_nodes,
)
def build_forest(self) -> List[Tree]:
"""
This builds out the forest of trees as described by LightGBM into a format
supported by Elasticsearch
:return: A list of Tree objects
"""
self.check_model_booster()
json_dump = self._model.dump_model()
return [self.build_tree(t) for t in json_dump["tree_info"]]
def build_aggregator_output(self) -> Dict[str, Any]:
raise NotImplementedError("build_aggregator_output must be implemented")
def determine_target_type(self) -> str:
raise NotImplementedError("determine_target_type must be implemented")
def is_objective_supported(self) -> bool:
return False
def check_model_booster(self) -> None:
raise NotImplementedError("check_model_booster must be implemented")
def transform(self) -> Ensemble:
self.check_model_booster()
if not self.is_objective_supported():
raise ValueError(f"Unsupported objective '{self._objective}'")
forest = self.build_forest()
return Ensemble(
feature_names=self._feature_names,
trained_models=forest,
output_aggregator=self.build_aggregator_output(),
classification_labels=self._classification_labels,
classification_weights=self._classification_weights,
target_type=self.determine_target_type(),
)
class LGBMRegressorTransformer(LGBMForestTransformer):
def __init__(self, model: LGBMRegressor, feature_names: List[str]):
super().__init__(model.booster_, feature_names)
def is_objective_supported(self) -> bool:
return self._objective in {
"regression",
"regression_l1",
"huber",
"fair",
"quantile",
"mape",
}
def check_model_booster(self) -> None:
if self._model.params["boosting_type"] not in {"gbdt", "rf", "dart", "goss"}:
raise ValueError(
f"boosting type must exist and be of type 'gbdt', 'rf', 'dart', or 'goss'"
f", was {self._model.params['boosting_type']!r}"
)
def determine_target_type(self) -> str:
return "regression"
def build_aggregator_output(self) -> Dict[str, Any]:
return {"weighted_sum": {}}
@property
def model_type(self) -> str:
return MLModel.TYPE_REGRESSION
_MODEL_TRANSFORMERS: Dict[type, Type[ModelTransformer]] = {
LGBMRegressor: LGBMRegressorTransformer
}

View File

@ -38,6 +38,13 @@ try:
except ImportError:
HAS_XGBOOST = False
try:
from lightgbm import LGBMRegressor
HAS_LIGHTGBM = True
except ImportError:
HAS_LIGHTGBM = False
requires_sklearn = pytest.mark.skipif(
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
@ -50,6 +57,10 @@ requires_no_ml_extras = pytest.mark.skipif(
reason="This test requires 'scikit-learn' and 'xgboost' to not be installed",
)
requires_lightgbm = pytest.mark.skipif(
not HAS_LIGHTGBM, reason="This test requires 'lightgbm' package to run"
)
class TestImportedMLModel:
@requires_no_ml_extras
@ -322,3 +333,35 @@ class TestImportedMLModel:
# Clean up
es_model.delete_model()
@requires_lightgbm
@pytest.mark.parametrize("compress_model_definition", [True, False])
def test_lgbm_regressor(self, compress_model_definition):
# Train model
training_data = datasets.make_regression(n_features=5)
regressor = LGBMRegressor()
regressor.fit(training_data[0], training_data[1])
# Get some test results
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
test_results = regressor.predict(np.asarray(test_data))
# Serialise the models to Elasticsearch
feature_names = ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"]
model_id = "test_lgbm_regressor"
es_model = ImportedMLModel(
ES_TEST_CLIENT,
model_id,
regressor,
feature_names,
overwrite=True,
es_compress_model_definition=compress_model_definition,
)
es_results = es_model.predict(test_data)
np.testing.assert_almost_equal(test_results, es_results, decimal=2)
# Clean up
es_model.delete_model()

View File

@ -49,6 +49,7 @@ TYPED_FILES = {
"eland/ml/imported_ml_model.py",
"eland/ml/transformers/__init__.py",
"eland/ml/transformers/base.py",
"eland/ml/transformers/lightgbm.py",
"eland/ml/transformers/sklearn.py",
"eland/ml/transformers/xgboost.py",
}
@ -109,7 +110,7 @@ def test_ml_deps(session):
session.install("-r", "requirements-dev.txt")
session.run("python", "-m", "eland.tests.setup_tests")
session_uninstall("xgboost", "scikit-learn")
session_uninstall("xgboost", "scikit-learn", "lightgbm")
session.run("pytest", *(session.posargs or ("eland/tests/ml/",)))
session.install(".[scikit-learn]")

View File

@ -8,3 +8,4 @@ numpydoc>=0.9.0
scikit-learn>=0.22.1
xgboost>=1
nox
lightgbm>=2.3.0

View File

@ -193,5 +193,6 @@ setup(
extras_require={
"xgboost": ["xgboost>=0.90,<2"],
"scikit-learn": ["scikit-learn>=0.22.1,<1"],
"lightgbm": ["lightgbm>=2,<2.4"],
},
)