diff --git a/eland/ml/_model_serializer.py b/eland/ml/_model_serializer.py index da45a00..49d461e 100644 --- a/eland/ml/_model_serializer.py +++ b/eland/ml/_model_serializer.py @@ -70,6 +70,7 @@ class TreeNode: split_feature: Optional[int] = None, threshold: Optional[float] = None, leaf_value: Optional[List[float]] = None, + number_samples: Optional[int] = None, ): self._node_idx = node_idx self._decision_type = decision_type @@ -79,6 +80,7 @@ class TreeNode: self._threshold = threshold self._leaf_value = leaf_value self._default_left = default_left + self._number_samples = number_samples @property def node_idx(self) -> int: @@ -93,6 +95,7 @@ class TreeNode: add_if_exists(d, "right_child", self._right_child) add_if_exists(d, "split_feature", self._split_feature) add_if_exists(d, "threshold", self._threshold) + add_if_exists(d, "number_samples", self._number_samples) else: if len(self._leaf_value) == 1: # Support Elasticsearch 7.6 which only diff --git a/eland/ml/transformers/lightgbm.py b/eland/ml/transformers/lightgbm.py index 8e96957..f6293d6 100644 --- a/eland/ml/transformers/lightgbm.py +++ b/eland/ml/transformers/lightgbm.py @@ -88,6 +88,7 @@ class LGBMForestTransformer(ModelTransformer): decision_type=transform_decider(tree_node_json_obj["decision_type"]), left_child=left_child, right_child=right_child, + number_samples=int(tree_node_json_obj["internal_count"]), ) def make_leaf_node( @@ -96,6 +97,9 @@ class LGBMForestTransformer(ModelTransformer): return TreeNode( node_idx=node_id, leaf_value=[float(tree_node_json_obj["leaf_value"])], + number_samples=int(tree_node_json_obj["leaf_count"]) + if "leaf_count" in tree_node_json_obj + else None, ) def build_tree(self, tree_id: int, tree_json_obj: Dict[str, Any]) -> Tree: @@ -228,7 +232,13 @@ class LGBMClassifierTransformer(LGBMForestTransformer): return super().make_leaf_node(tree_id, node_id, tree_node_json_obj) leaf_val = [0.0] * self.n_classes leaf_val[tree_id % self.n_classes] = float(tree_node_json_obj["leaf_value"]) - return TreeNode(node_idx=node_id, leaf_value=leaf_val) + return TreeNode( + node_idx=node_id, + leaf_value=leaf_val, + number_samples=int(tree_node_json_obj["leaf_count"]) + if "leaf_count" in tree_node_json_obj + else None, + ) def check_model_booster(self) -> None: if self._model.params["boosting_type"] not in {"gbdt", "rf", "dart", "goss"}: