mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Add number_samples to LightGBM MLModel and leaf_count to leaf nodes
* Add number_samples to LightGBM ML model
* Add leaf_count for leaf nodes
This commit is contained in:
parent
dabb327b8b
commit
995f2432b6
@ -70,6 +70,7 @@ class TreeNode:
|
|||||||
split_feature: Optional[int] = None,
|
split_feature: Optional[int] = None,
|
||||||
threshold: Optional[float] = None,
|
threshold: Optional[float] = None,
|
||||||
leaf_value: Optional[List[float]] = None,
|
leaf_value: Optional[List[float]] = None,
|
||||||
|
number_samples: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self._node_idx = node_idx
|
self._node_idx = node_idx
|
||||||
self._decision_type = decision_type
|
self._decision_type = decision_type
|
||||||
@ -79,6 +80,7 @@ class TreeNode:
|
|||||||
self._threshold = threshold
|
self._threshold = threshold
|
||||||
self._leaf_value = leaf_value
|
self._leaf_value = leaf_value
|
||||||
self._default_left = default_left
|
self._default_left = default_left
|
||||||
|
self._number_samples = number_samples
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def node_idx(self) -> int:
|
def node_idx(self) -> int:
|
||||||
@ -93,6 +95,7 @@ class TreeNode:
|
|||||||
add_if_exists(d, "right_child", self._right_child)
|
add_if_exists(d, "right_child", self._right_child)
|
||||||
add_if_exists(d, "split_feature", self._split_feature)
|
add_if_exists(d, "split_feature", self._split_feature)
|
||||||
add_if_exists(d, "threshold", self._threshold)
|
add_if_exists(d, "threshold", self._threshold)
|
||||||
|
add_if_exists(d, "number_samples", self._number_samples)
|
||||||
else:
|
else:
|
||||||
if len(self._leaf_value) == 1:
|
if len(self._leaf_value) == 1:
|
||||||
# Support Elasticsearch 7.6 which only
|
# Support Elasticsearch 7.6 which only
|
||||||
|
@ -88,6 +88,7 @@ class LGBMForestTransformer(ModelTransformer):
|
|||||||
decision_type=transform_decider(tree_node_json_obj["decision_type"]),
|
decision_type=transform_decider(tree_node_json_obj["decision_type"]),
|
||||||
left_child=left_child,
|
left_child=left_child,
|
||||||
right_child=right_child,
|
right_child=right_child,
|
||||||
|
number_samples=int(tree_node_json_obj["internal_count"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
def make_leaf_node(
|
def make_leaf_node(
|
||||||
@ -96,6 +97,9 @@ class LGBMForestTransformer(ModelTransformer):
|
|||||||
return TreeNode(
|
return TreeNode(
|
||||||
node_idx=node_id,
|
node_idx=node_id,
|
||||||
leaf_value=[float(tree_node_json_obj["leaf_value"])],
|
leaf_value=[float(tree_node_json_obj["leaf_value"])],
|
||||||
|
number_samples=int(tree_node_json_obj["leaf_count"])
|
||||||
|
if "leaf_count" in tree_node_json_obj
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def build_tree(self, tree_id: int, tree_json_obj: Dict[str, Any]) -> Tree:
|
def build_tree(self, tree_id: int, tree_json_obj: Dict[str, Any]) -> Tree:
|
||||||
@ -228,7 +232,13 @@ class LGBMClassifierTransformer(LGBMForestTransformer):
|
|||||||
return super().make_leaf_node(tree_id, node_id, tree_node_json_obj)
|
return super().make_leaf_node(tree_id, node_id, tree_node_json_obj)
|
||||||
leaf_val = [0.0] * self.n_classes
|
leaf_val = [0.0] * self.n_classes
|
||||||
leaf_val[tree_id % self.n_classes] = float(tree_node_json_obj["leaf_value"])
|
leaf_val[tree_id % self.n_classes] = float(tree_node_json_obj["leaf_value"])
|
||||||
return TreeNode(node_idx=node_id, leaf_value=leaf_val)
|
return TreeNode(
|
||||||
|
node_idx=node_id,
|
||||||
|
leaf_value=leaf_val,
|
||||||
|
number_samples=int(tree_node_json_obj["leaf_count"])
|
||||||
|
if "leaf_count" in tree_node_json_obj
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
|
||||||
def check_model_booster(self) -> None:
|
def check_model_booster(self) -> None:
|
||||||
if self._model.params["boosting_type"] not in {"gbdt", "rf", "dart", "goss"}:
|
if self._model.params["boosting_type"] not in {"gbdt", "rf", "dart", "goss"}:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user