Add number_samples to LightGBM MLModel and leaf_count to leaf nodes

* Add number_samples to LightGBM MLModel

* Add leaf_count for leaf nodes
This commit is contained in:
P. Sai Vinay 2021-09-30 18:43:44 +05:30 committed by GitHub
parent dabb327b8b
commit 995f2432b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 1 deletion

View File

@ -70,6 +70,7 @@ class TreeNode:
split_feature: Optional[int] = None,
threshold: Optional[float] = None,
leaf_value: Optional[List[float]] = None,
number_samples: Optional[int] = None,
):
self._node_idx = node_idx
self._decision_type = decision_type
@ -79,6 +80,7 @@ class TreeNode:
self._threshold = threshold
self._leaf_value = leaf_value
self._default_left = default_left
self._number_samples = number_samples
@property
def node_idx(self) -> int:
@ -93,6 +95,7 @@ class TreeNode:
add_if_exists(d, "right_child", self._right_child)
add_if_exists(d, "split_feature", self._split_feature)
add_if_exists(d, "threshold", self._threshold)
add_if_exists(d, "number_samples", self._number_samples)
else:
if len(self._leaf_value) == 1:
# Support Elasticsearch 7.6 which only

View File

@ -88,6 +88,7 @@ class LGBMForestTransformer(ModelTransformer):
decision_type=transform_decider(tree_node_json_obj["decision_type"]),
left_child=left_child,
right_child=right_child,
number_samples=int(tree_node_json_obj["internal_count"]),
)
def make_leaf_node(
@ -96,6 +97,9 @@ class LGBMForestTransformer(ModelTransformer):
return TreeNode(
node_idx=node_id,
leaf_value=[float(tree_node_json_obj["leaf_value"])],
number_samples=int(tree_node_json_obj["leaf_count"])
if "leaf_count" in tree_node_json_obj
else None,
)
def build_tree(self, tree_id: int, tree_json_obj: Dict[str, Any]) -> Tree:
@ -228,7 +232,13 @@ class LGBMClassifierTransformer(LGBMForestTransformer):
return super().make_leaf_node(tree_id, node_id, tree_node_json_obj)
leaf_val = [0.0] * self.n_classes
leaf_val[tree_id % self.n_classes] = float(tree_node_json_obj["leaf_value"])
return TreeNode(node_idx=node_id, leaf_value=leaf_val)
return TreeNode(
node_idx=node_id,
leaf_value=leaf_val,
number_samples=int(tree_node_json_obj["leaf_count"])
if "leaf_count" in tree_node_json_obj
else None,
)
def check_model_booster(self) -> None:
if self._model.params["boosting_type"] not in {"gbdt", "rf", "dart", "goss"}: