mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
[ML] Add support for LGBMClassifier models
This commit is contained in:
parent
701a8008ad
commit
f58634dc6e
@ -39,7 +39,7 @@ if TYPE_CHECKING:
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from lightgbm import LGBMRegressor # type: ignore # noqa: f401
|
||||
from lightgbm import LGBMRegressor, LGBMClassifier # type: ignore # noqa: f401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@ -72,6 +72,12 @@ class ImportedMLModel(MLModel):
|
||||
- "fair"
|
||||
- "quantile"
|
||||
- "mape"
|
||||
- lightgbm.LGBMClassifier
|
||||
- Categorical fields are expected to already be processed
|
||||
- Only the following objectives are supported
|
||||
- "binary"
|
||||
- "multiclass"
|
||||
- "multiclassova"
|
||||
- xgboost.XGBClassifier
|
||||
- only the following objectives are supported:
|
||||
- "binary:logistic"
|
||||
@ -144,6 +150,7 @@ class ImportedMLModel(MLModel):
|
||||
"XGBClassifier",
|
||||
"XGBRegressor",
|
||||
"LGBMRegressor",
|
||||
"LGBMClassifier",
|
||||
],
|
||||
feature_names: List[str],
|
||||
classification_labels: Optional[List[str]] = None,
|
||||
|
@ -86,12 +86,20 @@ except ImportError:
|
||||
try:
|
||||
from .lightgbm import (
|
||||
LGBMRegressor,
|
||||
LGBMClassifier,
|
||||
LGBMForestTransformer,
|
||||
LGBMRegressorTransformer,
|
||||
LGBMClassifierTransformer,
|
||||
_MODEL_TRANSFORMERS as _LIGHTGBM_MODEL_TRANSFORMERS,
|
||||
)
|
||||
|
||||
__all__ += ["LGBMRegressor", "LGBMForestTransformer", "LGBMRegressorTransformer"]
|
||||
__all__ += [
|
||||
"LGBMRegressor",
|
||||
"LGBMClassifier",
|
||||
"LGBMForestTransformer",
|
||||
"LGBMRegressorTransformer",
|
||||
"LGBMClassifierTransformer",
|
||||
]
|
||||
_MODEL_TRANSFORMERS.update(_LIGHTGBM_MODEL_TRANSFORMERS)
|
||||
except ImportError:
|
||||
pass
|
||||
|
@ -23,7 +23,7 @@ from .._optional import import_optional_dependency
|
||||
|
||||
import_optional_dependency("lightgbm", on_version="warn")
|
||||
|
||||
from lightgbm import Booster, LGBMRegressor # type: ignore
|
||||
from lightgbm import Booster, LGBMRegressor, LGBMClassifier # type: ignore
|
||||
|
||||
|
||||
def transform_decider(decider: str) -> str:
|
||||
@ -69,10 +69,34 @@ class LGBMForestTransformer(ModelTransformer):
|
||||
super().__init__(
|
||||
model, feature_names, classification_labels, classification_weights
|
||||
)
|
||||
self._node_decision_type = "lte"
|
||||
self._objective = model.params["objective"]
|
||||
|
||||
def build_tree(self, tree_json_obj: Dict[str, Any]) -> Tree:
|
||||
def make_inner_node(
|
||||
self,
|
||||
tree_id: int,
|
||||
node_id: int,
|
||||
tree_node_json_obj: Dict[str, Any],
|
||||
left_child: int,
|
||||
right_child: int,
|
||||
) -> TreeNode:
|
||||
return TreeNode(
|
||||
node_idx=node_id,
|
||||
default_left=tree_node_json_obj["default_left"],
|
||||
split_feature=int(tree_node_json_obj["split_feature"]),
|
||||
threshold=float(tree_node_json_obj["threshold"]),
|
||||
decision_type=transform_decider(tree_node_json_obj["decision_type"]),
|
||||
left_child=left_child,
|
||||
right_child=right_child,
|
||||
)
|
||||
|
||||
def make_leaf_node(
|
||||
self, tree_id: int, node_id: int, tree_node_json_obj: Dict[str, Any]
|
||||
) -> TreeNode:
|
||||
return TreeNode(
|
||||
node_idx=node_id, leaf_value=[float(tree_node_json_obj["leaf_value"])],
|
||||
)
|
||||
|
||||
def build_tree(self, tree_id: int, tree_json_obj: Dict[str, Any]) -> Tree:
|
||||
tree_nodes = list()
|
||||
next_id = Counter()
|
||||
|
||||
@ -80,25 +104,14 @@ class LGBMForestTransformer(ModelTransformer):
|
||||
curr_id = counter.value()
|
||||
if "leaf_value" in tree_node_json_obj:
|
||||
tree_nodes.append(
|
||||
TreeNode(
|
||||
node_idx=curr_id,
|
||||
leaf_value=[float(tree_node_json_obj["leaf_value"])],
|
||||
)
|
||||
self.make_leaf_node(tree_id, curr_id, tree_node_json_obj)
|
||||
)
|
||||
return curr_id
|
||||
left_id = add_tree_node(tree_node_json_obj["left_child"], counter.inc())
|
||||
right_id = add_tree_node(tree_node_json_obj["right_child"], counter.inc())
|
||||
tree_nodes.append(
|
||||
TreeNode(
|
||||
node_idx=curr_id,
|
||||
default_left=tree_node_json_obj["default_left"],
|
||||
split_feature=tree_node_json_obj["split_feature"],
|
||||
threshold=float(tree_node_json_obj["threshold"]),
|
||||
decision_type=transform_decider(
|
||||
tree_node_json_obj["decision_type"]
|
||||
),
|
||||
left_child=left_id,
|
||||
right_child=right_id,
|
||||
self.make_inner_node(
|
||||
tree_id, curr_id, tree_node_json_obj, left_id, right_id
|
||||
)
|
||||
)
|
||||
return curr_id
|
||||
@ -120,7 +133,7 @@ class LGBMForestTransformer(ModelTransformer):
|
||||
"""
|
||||
self.check_model_booster()
|
||||
json_dump = self._model.dump_model()
|
||||
return [self.build_tree(t) for t in json_dump["tree_info"]]
|
||||
return [self.build_tree(i, t) for i, t in enumerate(json_dump["tree_info"])]
|
||||
|
||||
def build_aggregator_output(self) -> Dict[str, Any]:
|
||||
raise NotImplementedError("build_aggregator_output must be implemented")
|
||||
@ -190,6 +203,57 @@ class LGBMRegressorTransformer(LGBMForestTransformer):
|
||||
return MLModel.TYPE_REGRESSION
|
||||
|
||||
|
||||
class LGBMClassifierTransformer(LGBMForestTransformer):
|
||||
def __init__(
|
||||
self,
|
||||
model: LGBMClassifier,
|
||||
feature_names: List[str],
|
||||
classification_labels: List[str],
|
||||
classification_weights: List[float],
|
||||
):
|
||||
super().__init__(
|
||||
model.booster_, feature_names, classification_labels, classification_weights
|
||||
)
|
||||
self.n_estimators = int(model.n_estimators)
|
||||
self.n_classes = int(model.n_classes_)
|
||||
if not classification_labels:
|
||||
self._classification_labels = [str(x) for x in model.classes_]
|
||||
|
||||
def make_leaf_node(
|
||||
self, tree_id: int, node_id: int, tree_node_json_obj: Dict[str, Any]
|
||||
) -> TreeNode:
|
||||
if self._objective == "binary":
|
||||
return super().make_leaf_node(tree_id, node_id, tree_node_json_obj)
|
||||
leaf_val = [0.0] * self.n_classes
|
||||
leaf_val[tree_id % self.n_classes] = float(tree_node_json_obj["leaf_value"])
|
||||
return TreeNode(node_idx=node_id, leaf_value=leaf_val)
|
||||
|
||||
def check_model_booster(self) -> None:
|
||||
if self._model.params["boosting_type"] not in {"gbdt", "rf", "dart", "goss"}:
|
||||
raise ValueError(
|
||||
f"boosting type must exist and be of type 'gbdt', 'rf', 'dart', or 'goss'"
|
||||
f", was {self._model.params['boosting_type']!r}"
|
||||
)
|
||||
|
||||
def determine_target_type(self) -> str:
|
||||
return "classification"
|
||||
|
||||
def build_aggregator_output(self) -> Dict[str, Any]:
|
||||
return {"logistic_regression": {}}
|
||||
|
||||
@property
|
||||
def model_type(self) -> str:
|
||||
return MLModel.TYPE_CLASSIFICATION
|
||||
|
||||
def is_objective_supported(self) -> bool:
|
||||
return self._objective in {
|
||||
"binary",
|
||||
"multiclass",
|
||||
"multiclassova",
|
||||
}
|
||||
|
||||
|
||||
_MODEL_TRANSFORMERS: Dict[type, Type[ModelTransformer]] = {
|
||||
LGBMRegressor: LGBMRegressorTransformer
|
||||
LGBMRegressor: LGBMRegressorTransformer,
|
||||
LGBMClassifier: LGBMClassifierTransformer,
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ except ImportError:
|
||||
HAS_XGBOOST = False
|
||||
|
||||
try:
|
||||
from lightgbm import LGBMRegressor
|
||||
from lightgbm import LGBMRegressor, LGBMClassifier
|
||||
|
||||
HAS_LIGHTGBM = True
|
||||
except ImportError:
|
||||
@ -62,6 +62,10 @@ requires_lightgbm = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
|
||||
def random_rows(data, size):
|
||||
return data[np.random.randint(data.shape[0], size=size), :].tolist()
|
||||
|
||||
|
||||
def check_prediction_equality(es_model, py_model, test_data):
|
||||
# Get some test results
|
||||
test_results = py_model.predict(np.asarray(test_data))
|
||||
@ -140,8 +144,9 @@ class TestImportedMLModel:
|
||||
)
|
||||
|
||||
# Get some test results
|
||||
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
|
||||
check_prediction_equality(es_model, classifier, test_data)
|
||||
check_prediction_equality(
|
||||
es_model, classifier, random_rows(training_data[0], 20)
|
||||
)
|
||||
|
||||
# Clean up
|
||||
es_model.delete_model()
|
||||
@ -167,8 +172,9 @@ class TestImportedMLModel:
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
# Get some test results
|
||||
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
|
||||
check_prediction_equality(es_model, regressor, test_data)
|
||||
check_prediction_equality(
|
||||
es_model, regressor, random_rows(training_data[0], 20)
|
||||
)
|
||||
|
||||
# Clean up
|
||||
es_model.delete_model()
|
||||
@ -194,8 +200,9 @@ class TestImportedMLModel:
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
# Get some test results
|
||||
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
|
||||
check_prediction_equality(es_model, classifier, test_data)
|
||||
check_prediction_equality(
|
||||
es_model, classifier, random_rows(training_data[0], 20)
|
||||
)
|
||||
|
||||
# Clean up
|
||||
es_model.delete_model()
|
||||
@ -221,8 +228,9 @@ class TestImportedMLModel:
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
# Get some test results
|
||||
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
|
||||
check_prediction_equality(es_model, regressor, test_data)
|
||||
check_prediction_equality(
|
||||
es_model, regressor, random_rows(training_data[0], 20)
|
||||
)
|
||||
|
||||
# Clean up
|
||||
es_model.delete_model()
|
||||
@ -257,8 +265,9 @@ class TestImportedMLModel:
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
# Get some test results
|
||||
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
|
||||
check_prediction_equality(es_model, classifier, test_data)
|
||||
check_prediction_equality(
|
||||
es_model, classifier, random_rows(training_data[0], 20)
|
||||
)
|
||||
|
||||
# Clean up
|
||||
es_model.delete_model()
|
||||
@ -290,8 +299,9 @@ class TestImportedMLModel:
|
||||
ES_TEST_CLIENT, model_id, classifier, feature_names, overwrite=True
|
||||
)
|
||||
# Get some test results
|
||||
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
|
||||
check_prediction_equality(es_model, classifier, test_data)
|
||||
check_prediction_equality(
|
||||
es_model, classifier, random_rows(training_data[0], 20)
|
||||
)
|
||||
|
||||
# Clean up
|
||||
es_model.delete_model()
|
||||
@ -326,8 +336,9 @@ class TestImportedMLModel:
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
# Get some test results
|
||||
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
|
||||
check_prediction_equality(es_model, regressor, test_data)
|
||||
check_prediction_equality(
|
||||
es_model, regressor, random_rows(training_data[0], 20)
|
||||
)
|
||||
|
||||
# Clean up
|
||||
es_model.delete_model()
|
||||
@ -393,8 +404,49 @@ class TestImportedMLModel:
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
# Get some test results
|
||||
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
|
||||
check_prediction_equality(es_model, regressor, test_data)
|
||||
check_prediction_equality(
|
||||
es_model, regressor, random_rows(training_data[0], 20)
|
||||
)
|
||||
|
||||
# Clean up
|
||||
es_model.delete_model()
|
||||
|
||||
@requires_lightgbm
|
||||
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
||||
@pytest.mark.parametrize("objective", ["binary", "multiclass", "multiclassova"])
|
||||
@pytest.mark.parametrize("booster", ["gbdt", "dart", "goss"])
|
||||
def test_lgbm_classifier_objectives_and_booster(
|
||||
self, compress_model_definition, objective, booster
|
||||
):
|
||||
# test both multiple and binary classification
|
||||
if objective.startswith("multi"):
|
||||
training_data = datasets.make_classification(
|
||||
n_features=5, n_classes=3, n_informative=3
|
||||
)
|
||||
classifier = LGBMClassifier(boosting_type=booster, objective=objective)
|
||||
else:
|
||||
training_data = datasets.make_classification(n_features=5)
|
||||
classifier = LGBMClassifier(boosting_type=booster, objective=objective)
|
||||
|
||||
# Train model
|
||||
classifier.fit(training_data[0], training_data[1])
|
||||
|
||||
# Serialise the models to Elasticsearch
|
||||
feature_names = ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"]
|
||||
model_id = "test_lgbm_classifier"
|
||||
|
||||
es_model = ImportedMLModel(
|
||||
ES_TEST_CLIENT,
|
||||
model_id,
|
||||
classifier,
|
||||
feature_names,
|
||||
overwrite=True,
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
|
||||
check_prediction_equality(
|
||||
es_model, classifier, random_rows(training_data[0], 20)
|
||||
)
|
||||
|
||||
# Clean up
|
||||
es_model.delete_model()
|
||||
|
Loading…
x
Reference in New Issue
Block a user