mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Add number_samples to LightGBM MLModel and leaf_count to leaf nodes
* Add number_samples to lightgbm ML Model * Add leaf_count for leaf nodes
This commit is contained in:
parent
dabb327b8b
commit
995f2432b6
@ -70,6 +70,7 @@ class TreeNode:
|
||||
split_feature: Optional[int] = None,
|
||||
threshold: Optional[float] = None,
|
||||
leaf_value: Optional[List[float]] = None,
|
||||
number_samples: Optional[int] = None,
|
||||
):
|
||||
self._node_idx = node_idx
|
||||
self._decision_type = decision_type
|
||||
@ -79,6 +80,7 @@ class TreeNode:
|
||||
self._threshold = threshold
|
||||
self._leaf_value = leaf_value
|
||||
self._default_left = default_left
|
||||
self._number_samples = number_samples
|
||||
|
||||
@property
|
||||
def node_idx(self) -> int:
|
||||
@ -93,6 +95,7 @@ class TreeNode:
|
||||
add_if_exists(d, "right_child", self._right_child)
|
||||
add_if_exists(d, "split_feature", self._split_feature)
|
||||
add_if_exists(d, "threshold", self._threshold)
|
||||
add_if_exists(d, "number_samples", self._number_samples)
|
||||
else:
|
||||
if len(self._leaf_value) == 1:
|
||||
# Support Elasticsearch 7.6 which only
|
||||
|
@ -88,6 +88,7 @@ class LGBMForestTransformer(ModelTransformer):
|
||||
decision_type=transform_decider(tree_node_json_obj["decision_type"]),
|
||||
left_child=left_child,
|
||||
right_child=right_child,
|
||||
number_samples=int(tree_node_json_obj["internal_count"]),
|
||||
)
|
||||
|
||||
def make_leaf_node(
|
||||
@ -96,6 +97,9 @@ class LGBMForestTransformer(ModelTransformer):
|
||||
return TreeNode(
|
||||
node_idx=node_id,
|
||||
leaf_value=[float(tree_node_json_obj["leaf_value"])],
|
||||
number_samples=int(tree_node_json_obj["leaf_count"])
|
||||
if "leaf_count" in tree_node_json_obj
|
||||
else None,
|
||||
)
|
||||
|
||||
def build_tree(self, tree_id: int, tree_json_obj: Dict[str, Any]) -> Tree:
|
||||
@ -228,7 +232,13 @@ class LGBMClassifierTransformer(LGBMForestTransformer):
|
||||
return super().make_leaf_node(tree_id, node_id, tree_node_json_obj)
|
||||
leaf_val = [0.0] * self.n_classes
|
||||
leaf_val[tree_id % self.n_classes] = float(tree_node_json_obj["leaf_value"])
|
||||
return TreeNode(node_idx=node_id, leaf_value=leaf_val)
|
||||
return TreeNode(
|
||||
node_idx=node_id,
|
||||
leaf_value=leaf_val,
|
||||
number_samples=int(tree_node_json_obj["leaf_count"])
|
||||
if "leaf_count" in tree_node_json_obj
|
||||
else None,
|
||||
)
|
||||
|
||||
def check_model_booster(self) -> None:
|
||||
if self._model.params["boosting_type"] not in {"gbdt", "rf", "dart", "goss"}:
|
||||
|
Loading…
x
Reference in New Issue
Block a user