Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)
[ML] Add support for LGBMClassifier models
This commit is contained in:
parent: 701a8008ad
commit: f58634dc6e
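
In short: a trained lightgbm.LGBMClassifier can now be serialised into Elasticsearch through eland's ImportedMLModel, alongside the existing LGBMRegressor support. Below is a minimal usage sketch modelled on the test code in this commit; the Elasticsearch client, cluster address, and model id are placeholder assumptions, not part of the change:

    from elasticsearch import Elasticsearch
    from lightgbm import LGBMClassifier
    from sklearn import datasets

    from eland.ml import ImportedMLModel

    # Train a small binary classifier on synthetic data
    X, y = datasets.make_classification(n_features=5)
    classifier = LGBMClassifier(objective="binary")
    classifier.fit(X, y)

    # Placeholder client/address for this sketch
    es_client = Elasticsearch("http://localhost:9200")

    # Serialise the trained model into Elasticsearch for inference there;
    # "my_lgbm_classifier" is a hypothetical model id
    feature_names = ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"]
    es_model = ImportedMLModel(
        es_client, "my_lgbm_classifier", classifier, feature_names, overwrite=True
    )
    predictions = es_model.predict(X[:10].tolist())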
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
 except ImportError:
     pass
 try:
-    from lightgbm import LGBMRegressor  # type: ignore # noqa: f401
+    from lightgbm import LGBMRegressor, LGBMClassifier  # type: ignore # noqa: f401
 except ImportError:
     pass

@@ -72,6 +72,12 @@ class ImportedMLModel(MLModel):
             - "fair"
             - "quantile"
             - "mape"
+        - lightgbm.LGBMClassifier
+            - Categorical fields are expected to already be processed
+            - Only the following objectives are supported
+                - "binary"
+                - "multiclass"
+                - "multiclassova"
         - xgboost.XGBClassifier
             - only the following objectives are supported:
                 - "binary:logistic"

@@ -144,6 +150,7 @@ class ImportedMLModel(MLModel):
             "XGBClassifier",
             "XGBRegressor",
             "LGBMRegressor",
+            "LGBMClassifier",
         ],
         feature_names: List[str],
         classification_labels: Optional[List[str]] = None,

@@ -86,12 +86,20 @@ except ImportError:
 try:
     from .lightgbm import (
         LGBMRegressor,
+        LGBMClassifier,
         LGBMForestTransformer,
         LGBMRegressorTransformer,
+        LGBMClassifierTransformer,
         _MODEL_TRANSFORMERS as _LIGHTGBM_MODEL_TRANSFORMERS,
     )

-    __all__ += ["LGBMRegressor", "LGBMForestTransformer", "LGBMRegressorTransformer"]
+    __all__ += [
+        "LGBMRegressor",
+        "LGBMClassifier",
+        "LGBMForestTransformer",
+        "LGBMRegressorTransformer",
+        "LGBMClassifierTransformer",
+    ]
     _MODEL_TRANSFORMERS.update(_LIGHTGBM_MODEL_TRANSFORMERS)
 except ImportError:
     pass

@@ -23,7 +23,7 @@ from .._optional import import_optional_dependency

 import_optional_dependency("lightgbm", on_version="warn")

-from lightgbm import Booster, LGBMRegressor  # type: ignore
+from lightgbm import Booster, LGBMRegressor, LGBMClassifier  # type: ignore


 def transform_decider(decider: str) -> str:

@@ -69,10 +69,34 @@ class LGBMForestTransformer(ModelTransformer):
         super().__init__(
             model, feature_names, classification_labels, classification_weights
         )
-        self._node_decision_type = "lte"
         self._objective = model.params["objective"]

-    def build_tree(self, tree_json_obj: Dict[str, Any]) -> Tree:
+    def make_inner_node(
+        self,
+        tree_id: int,
+        node_id: int,
+        tree_node_json_obj: Dict[str, Any],
+        left_child: int,
+        right_child: int,
+    ) -> TreeNode:
+        return TreeNode(
+            node_idx=node_id,
+            default_left=tree_node_json_obj["default_left"],
+            split_feature=int(tree_node_json_obj["split_feature"]),
+            threshold=float(tree_node_json_obj["threshold"]),
+            decision_type=transform_decider(tree_node_json_obj["decision_type"]),
+            left_child=left_child,
+            right_child=right_child,
+        )
+
+    def make_leaf_node(
+        self, tree_id: int, node_id: int, tree_node_json_obj: Dict[str, Any]
+    ) -> TreeNode:
+        return TreeNode(
+            node_idx=node_id, leaf_value=[float(tree_node_json_obj["leaf_value"])],
+        )
+
+    def build_tree(self, tree_id: int, tree_json_obj: Dict[str, Any]) -> Tree:
         tree_nodes = list()
         next_id = Counter()

@@ -80,25 +104,14 @@ class LGBMForestTransformer(ModelTransformer):
             curr_id = counter.value()
             if "leaf_value" in tree_node_json_obj:
                 tree_nodes.append(
-                    TreeNode(
-                        node_idx=curr_id,
-                        leaf_value=[float(tree_node_json_obj["leaf_value"])],
-                    )
+                    self.make_leaf_node(tree_id, curr_id, tree_node_json_obj)
                 )
                 return curr_id
             left_id = add_tree_node(tree_node_json_obj["left_child"], counter.inc())
             right_id = add_tree_node(tree_node_json_obj["right_child"], counter.inc())
             tree_nodes.append(
-                TreeNode(
-                    node_idx=curr_id,
-                    default_left=tree_node_json_obj["default_left"],
-                    split_feature=tree_node_json_obj["split_feature"],
-                    threshold=float(tree_node_json_obj["threshold"]),
-                    decision_type=transform_decider(
-                        tree_node_json_obj["decision_type"]
-                    ),
-                    left_child=left_id,
-                    right_child=right_id,
-                )
+                self.make_inner_node(
+                    tree_id, curr_id, tree_node_json_obj, left_id, right_id
+                )
             )
             return curr_id
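
For readers following the traversal above: add_tree_node assigns each node a flat integer id because the Elasticsearch tree model stores its nodes in an array and references children by index. Below is a simplified, runnable stand-in for that numbering scheme; this Counter is approximated from its usage here, not copied from the eland source:

    from typing import Any, Dict

    class Counter:
        def __init__(self) -> None:
            self._value = 0

        def value(self) -> int:
            return self._value

        def inc(self) -> "Counter":
            self._value += 1
            return self

    def assign_ids(node: Dict[str, Any], counter: Counter) -> int:
        # Leaves and inner nodes alike take the counter's current value;
        # the counter only advances when descending to a child.
        curr_id = counter.value()
        if "leaf_value" in node:
            return curr_id
        node["left_id"] = assign_ids(node["left_child"], counter.inc())
        node["right_id"] = assign_ids(node["right_child"], counter.inc())
        return curr_id

    tree = {"left_child": {"leaf_value": 0.1}, "right_child": {"leaf_value": 0.2}}
    assert assign_ids(tree, Counter()) == 0
    assert (tree["left_id"], tree["right_id"]) == (1, 2)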
@@ -120,7 +133,7 @@ class LGBMForestTransformer(ModelTransformer):
         """
         self.check_model_booster()
         json_dump = self._model.dump_model()
-        return [self.build_tree(t) for t in json_dump["tree_info"]]
+        return [self.build_tree(i, t) for i, t in enumerate(json_dump["tree_info"])]

     def build_aggregator_output(self) -> Dict[str, Any]:
         raise NotImplementedError("build_aggregator_output must be implemented")

@@ -190,6 +203,57 @@ class LGBMRegressorTransformer(LGBMForestTransformer):
         return MLModel.TYPE_REGRESSION


+class LGBMClassifierTransformer(LGBMForestTransformer):
+    def __init__(
+        self,
+        model: LGBMClassifier,
+        feature_names: List[str],
+        classification_labels: List[str],
+        classification_weights: List[float],
+    ):
+        super().__init__(
+            model.booster_, feature_names, classification_labels, classification_weights
+        )
+        self.n_estimators = int(model.n_estimators)
+        self.n_classes = int(model.n_classes_)
+        if not classification_labels:
+            self._classification_labels = [str(x) for x in model.classes_]
+
+    def make_leaf_node(
+        self, tree_id: int, node_id: int, tree_node_json_obj: Dict[str, Any]
+    ) -> TreeNode:
+        if self._objective == "binary":
+            return super().make_leaf_node(tree_id, node_id, tree_node_json_obj)
+        leaf_val = [0.0] * self.n_classes
+        leaf_val[tree_id % self.n_classes] = float(tree_node_json_obj["leaf_value"])
+        return TreeNode(node_idx=node_id, leaf_value=leaf_val)
+
+    def check_model_booster(self) -> None:
+        if self._model.params["boosting_type"] not in {"gbdt", "rf", "dart", "goss"}:
+            raise ValueError(
+                f"boosting type must exist and be of type 'gbdt', 'rf', 'dart', or 'goss'"
+                f", was {self._model.params['boosting_type']!r}"
+            )
+
+    def determine_target_type(self) -> str:
+        return "classification"
+
+    def build_aggregator_output(self) -> Dict[str, Any]:
+        return {"logistic_regression": {}}
+
+    @property
+    def model_type(self) -> str:
+        return MLModel.TYPE_CLASSIFICATION
+
+    def is_objective_supported(self) -> bool:
+        return self._objective in {
+            "binary",
+            "multiclass",
+            "multiclassova",
+        }
+
+
 _MODEL_TRANSFORMERS: Dict[type, Type[ModelTransformer]] = {
-    LGBMRegressor: LGBMRegressorTransformer
+    LGBMRegressor: LGBMRegressorTransformer,
+    LGBMClassifier: LGBMClassifierTransformer,
 }
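
A note on the multiclass handling added above: for multiclass objectives LightGBM trains one tree per class per boosting round, so tree i in dump_model()["tree_info"] contributes to class i % n_classes. That is why build_tree now threads tree_id through to make_leaf_node, which expands each scalar leaf into a per-class vector. A self-contained illustration with hypothetical values:

    # Hypothetical values: tree 7 of a 3-class model carries a scalar
    # leaf score that belongs to class 7 % 3 == 1.
    n_classes = 3
    tree_id = 7
    leaf_value = 0.42

    leaf_val = [0.0] * n_classes
    leaf_val[tree_id % n_classes] = leaf_value
    assert leaf_val == [0.0, 0.42, 0.0]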
@@ -39,7 +39,7 @@ except ImportError:
     HAS_XGBOOST = False

 try:
-    from lightgbm import LGBMRegressor
+    from lightgbm import LGBMRegressor, LGBMClassifier

     HAS_LIGHTGBM = True
 except ImportError:

@@ -62,6 +62,10 @@ requires_lightgbm = pytest.mark.skipif(
 )


+def random_rows(data, size):
+    return data[np.random.randint(data.shape[0], size=size), :].tolist()
+
+
 def check_prediction_equality(es_model, py_model, test_data):
     # Get some test results
     test_results = py_model.predict(np.asarray(test_data))

@@ -140,8 +144,9 @@ class TestImportedMLModel:
         )

         # Get some test results
-        test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
-        check_prediction_equality(es_model, classifier, test_data)
+        check_prediction_equality(
+            es_model, classifier, random_rows(training_data[0], 20)
+        )

         # Clean up
         es_model.delete_model()

@@ -167,8 +172,9 @@ class TestImportedMLModel:
             es_compress_model_definition=compress_model_definition,
         )
         # Get some test results
-        test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
-        check_prediction_equality(es_model, regressor, test_data)
+        check_prediction_equality(
+            es_model, regressor, random_rows(training_data[0], 20)
+        )

         # Clean up
         es_model.delete_model()

@@ -194,8 +200,9 @@ class TestImportedMLModel:
             es_compress_model_definition=compress_model_definition,
         )
         # Get some test results
-        test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
-        check_prediction_equality(es_model, classifier, test_data)
+        check_prediction_equality(
+            es_model, classifier, random_rows(training_data[0], 20)
+        )

         # Clean up
         es_model.delete_model()

@@ -221,8 +228,9 @@ class TestImportedMLModel:
             es_compress_model_definition=compress_model_definition,
         )
         # Get some test results
-        test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
-        check_prediction_equality(es_model, regressor, test_data)
+        check_prediction_equality(
+            es_model, regressor, random_rows(training_data[0], 20)
+        )

         # Clean up
         es_model.delete_model()

@@ -257,8 +265,9 @@ class TestImportedMLModel:
             es_compress_model_definition=compress_model_definition,
         )
         # Get some test results
-        test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
-        check_prediction_equality(es_model, classifier, test_data)
+        check_prediction_equality(
+            es_model, classifier, random_rows(training_data[0], 20)
+        )

         # Clean up
         es_model.delete_model()

@@ -290,8 +299,9 @@ class TestImportedMLModel:
             ES_TEST_CLIENT, model_id, classifier, feature_names, overwrite=True
         )
         # Get some test results
-        test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
-        check_prediction_equality(es_model, classifier, test_data)
+        check_prediction_equality(
+            es_model, classifier, random_rows(training_data[0], 20)
+        )

         # Clean up
         es_model.delete_model()

@@ -326,8 +336,9 @@ class TestImportedMLModel:
             es_compress_model_definition=compress_model_definition,
         )
         # Get some test results
-        test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
-        check_prediction_equality(es_model, regressor, test_data)
+        check_prediction_equality(
+            es_model, regressor, random_rows(training_data[0], 20)
+        )

         # Clean up
         es_model.delete_model()

@@ -393,8 +404,49 @@ class TestImportedMLModel:
             es_compress_model_definition=compress_model_definition,
         )
         # Get some test results
-        test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
-        check_prediction_equality(es_model, regressor, test_data)
+        check_prediction_equality(
+            es_model, regressor, random_rows(training_data[0], 20)
+        )
+
+        # Clean up
+        es_model.delete_model()
+
+    @requires_lightgbm
+    @pytest.mark.parametrize("compress_model_definition", [True, False])
+    @pytest.mark.parametrize("objective", ["binary", "multiclass", "multiclassova"])
+    @pytest.mark.parametrize("booster", ["gbdt", "dart", "goss"])
+    def test_lgbm_classifier_objectives_and_booster(
+        self, compress_model_definition, objective, booster
+    ):
+        # test both multiple and binary classification
+        if objective.startswith("multi"):
+            training_data = datasets.make_classification(
+                n_features=5, n_classes=3, n_informative=3
+            )
+            classifier = LGBMClassifier(boosting_type=booster, objective=objective)
+        else:
+            training_data = datasets.make_classification(n_features=5)
+            classifier = LGBMClassifier(boosting_type=booster, objective=objective)
+
+        # Train model
+        classifier.fit(training_data[0], training_data[1])
+
+        # Serialise the models to Elasticsearch
+        feature_names = ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"]
+        model_id = "test_lgbm_classifier"
+
+        es_model = ImportedMLModel(
+            ES_TEST_CLIENT,
+            model_id,
+            classifier,
+            feature_names,
+            overwrite=True,
+            es_compress_model_definition=compress_model_definition,
+        )
+
+        check_prediction_equality(
+            es_model, classifier, random_rows(training_data[0], 20)
+        )

         # Clean up
         es_model.delete_model()
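
Taken together, the new classifier test covers 2 compression settings x 3 objectives x 3 boosting types, i.e. 18 parameter combinations per run, since stacked pytest.mark.parametrize decorators multiply. A quick sanity check of that grid:

    from itertools import product

    compress = [True, False]
    objectives = ["binary", "multiclass", "multiclassova"]
    boosters = ["gbdt", "dart", "goss"]

    # Stacked parametrize decorators produce the Cartesian product
    assert len(list(product(compress, objectives, boosters))) == 18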