Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)
[ML] Add support for multi:softmax|softprob XGBClassifier
This commit is contained in:
parent 5c901e8f1b
commit efb9e3b4c4
@@ -19,7 +19,7 @@ import base64
 import gzip
 import json
 from abc import ABC
-from typing import Sequence, Dict, Any, Optional
+from typing import Sequence, Dict, Any, Optional, List


 def add_if_exists(d: Dict[str, Any], k: str, v: Any) -> None:
@@ -69,7 +69,7 @@ class TreeNode:
         right_child: Optional[int] = None,
         split_feature: Optional[int] = None,
         threshold: Optional[float] = None,
-        leaf_value: Optional[float] = None,
+        leaf_value: Optional[List[float]] = None,
     ):
         self._node_idx = node_idx
         self._decision_type = decision_type
@@ -60,7 +60,17 @@ class ImportedMLModel(MLModel):
     - sklearn.ensemble.RandomForestRegressor
     - sklearn.ensemble.RandomForestClassifier
     - xgboost.XGBClassifier
+        - only the following operators are supported:
+            - "binary:logistic"
+            - "binary:hinge"
+            - "multi:softmax"
+            - "multi:softprob"
     - xgboost.XGBRegressor
+        - only the following operators are supported:
+            - "reg:squarederror"
+            - "reg:linear"
+            - "reg:squaredlogerror"
+            - "reg:logistic"

     feature_names: List[str]
         Names of the features (required)
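With the operator lists above, importing a multi-class XGBoost model looks roughly like the following. This is a minimal sketch, not part of the commit: the Elasticsearch host and model_id are placeholders, and the call assumes the (es_client, model_id, model, feature_names, overwrite) calling convention exercised by this repository's tests.

from sklearn import datasets
from xgboost import XGBClassifier
from eland.ml import ImportedMLModel

# Train a three-class model with one of the newly supported objectives.
X, y = datasets.make_classification(n_features=5, n_classes=3, n_informative=3)
classifier = XGBClassifier(booster="gbtree", objective="multi:softmax")
classifier.fit(X, y)

# Serialize the trained model into Elasticsearch under the given model_id.
es_model = ImportedMLModel(
    "localhost",               # placeholder Elasticsearch client/host
    "xgb-multiclass-example",  # placeholder model_id
    classifier,
    feature_names=["f0", "f1", "f2", "f3", "f4"],
    overwrite=True,
)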
@@ -79,10 +79,10 @@ class SKLearnTransformer(ModelTransformer):
         if (
             value.shape[1] == 1
         ):  # classification requires more than one value, so assume regression
-            leaf_value = float(value[0][0])
+            leaf_value = [float(value[0][0])]
         else:
             # the classification value, which is the index of the largest value
-            leaf_value = int(np.argmax(value))
+            leaf_value = [float(np.argmax(value))]
         return TreeNode(
             node_index,
             decision_type=self._node_decision_type,
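Both branches now wrap the scalar in a single-element list, so regression and binary-classification leaves use the same vector representation that multi-class leaves need. A small illustration of the convention, using made-up sklearn-style node value arrays of shape (1, n_outputs):

import numpy as np

regression_value = np.array([[0.75]])                # one output column
classification_value = np.array([[2.0, 5.0, 1.0]])   # per-class counts at a leaf

# One column: treat as regression and keep the raw value, as a vector.
assert [float(regression_value[0][0])] == [0.75]
# Several columns: store the index of the winning class, as a vector.
assert [float(np.argmax(classification_value))] == [1.0]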
@@ -49,6 +49,7 @@ class XGBoostForestTransformer(ModelTransformer):
         self._node_decision_type = "lt"
         self._base_score = base_score
         self._objective = objective
+        self._feature_dict = dict(zip(feature_names, range(len(feature_names))))

     def get_feature_id(self, feature_id: str) -> int:
         if feature_id[0] == "f":
@@ -56,6 +57,9 @@ class XGBoostForestTransformer(ModelTransformer):
                 return int(feature_id[1:])
             except ValueError:
                 raise RuntimeError(f"Unable to interpret '{feature_id}'")
+        f_id = self._feature_dict.get(feature_id)
+        if f_id is not None:
+            return f_id
         else:
             try:
                 return int(feature_id)
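The new _feature_dict lets get_feature_id resolve real feature names in addition to XGBoost's default "f0"-style ids and bare numeric indices. A self-contained sketch of the lookup order, with invented feature names (the `is not None` check matters so that the feature at index 0 still resolves):

feature_names = ["age", "height", "weight"]
feature_dict = dict(zip(feature_names, range(len(feature_names))))

def resolve(feature_id: str) -> int:
    # 1. XGBoost's generated ids: "f0", "f1", ...
    if feature_id[0] == "f":
        try:
            return int(feature_id[1:])
        except ValueError:
            raise RuntimeError(f"Unable to interpret '{feature_id}'")
    # 2. A real feature name supplied by the caller.
    f_id = feature_dict.get(feature_id)
    if f_id is not None:
        return f_id
    # 3. A bare numeric index such as "2".
    try:
        return int(feature_id)
    except ValueError:
        raise RuntimeError(f"Unable to interpret '{feature_id}'")

assert resolve("f1") == 1
assert resolve("age") == 0
assert resolve("weight") == 2
assert resolve("1") == 1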
@@ -81,10 +85,13 @@ class XGBoostForestTransformer(ModelTransformer):
                 f"cannot determine node index or tree from '{node_id}' for tree {curr_tree}"
             )

+    def build_leaf_node(self, row: pd.Series, curr_tree: int) -> TreeNode:
+        return TreeNode(node_idx=row["Node"], leaf_value=[float(row["Gain"])])
+
     def build_tree_node(self, row: pd.Series, curr_tree: int) -> TreeNode:
         node_index = row["Node"]
         if row["Feature"] == "Leaf":
-            return TreeNode(node_idx=node_index, leaf_value=float(row["Gain"]))
+            return self.build_leaf_node(row, curr_tree)
         else:
             return TreeNode(
                 node_idx=node_index,
@@ -96,12 +103,16 @@ class XGBoostForestTransformer(ModelTransformer):
             )

     def build_tree(self, nodes: List[TreeNode]) -> Tree:
-        return Tree(feature_names=self._feature_names, tree_structure=nodes)
+        return Tree(
+            feature_names=self._feature_names,
+            tree_structure=nodes,
+            target_type=self.determine_target_type(),
+        )

     def build_base_score_stump(self) -> Tree:
         return Tree(
             feature_names=self._feature_names,
-            tree_structure=[TreeNode(0, leaf_value=self._base_score)],
+            tree_structure=[TreeNode(0, leaf_value=[self._base_score])],
         )

     def build_forest(self) -> List[Tree]:
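The base-score stump deserves a note: at the raw-score level, before any link function, an XGBoost ensemble's prediction is base_score plus the sum of all tree outputs, so emitting one extra zero-split tree whose only leaf holds base_score (now wrapped in a list like every other leaf) preserves that constant term. A worked sketch with made-up numbers:

base_score = 0.5
tree_outputs = [0.12, -0.03, 0.07]   # leaf values hit in the real trees

# The stump contributes base_score to the ensemble's sum on every input.
prediction = base_score + sum(tree_outputs)
assert abs(prediction - 0.66) < 1e-9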
@@ -209,12 +220,30 @@ class XGBoostClassifierTransformer(XGBoostForestTransformer):
             model.objective,
             classification_labels,
         )
+        if model.classes_ is None:
+            n_estimators = model.get_params()["n_estimators"]
+            num_trees = model.get_booster().trees_to_dataframe()["Tree"].max() + 1
+            self._num_classes = num_trees // n_estimators
+        else:
+            self._num_classes = len(model.classes_)
+
+    def build_leaf_node(self, row: pd.Series, curr_tree: int) -> TreeNode:
+        if self._num_classes <= 2:
+            return super().build_leaf_node(row, curr_tree)
+        leaf_val = [0.0] * self._num_classes
+        leaf_val[curr_tree % self._num_classes] = float(row["Gain"])
+        return TreeNode(node_idx=row["Node"], leaf_value=leaf_val)

     def determine_target_type(self) -> str:
         return "classification"

     def is_objective_supported(self) -> bool:
-        return self._objective in {"binary:logistic", "binary:hinge"}
+        return self._objective in {
+            "binary:logistic",
+            "binary:hinge",
+            "multi:softmax",
+            "multi:softprob",
+        }

     def build_aggregator_output(self) -> Dict[str, Any]:
         return {"logistic_regression": {}}
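These two additions encode XGBoost's multi-class tree layout: each boosting round trains one tree per class, interleaved, so a model with n_estimators rounds and k classes holds n_estimators * k trees, and tree i scores class i % k. That is why num_trees // n_estimators recovers the class count when model.classes_ is unavailable, and why build_leaf_node drops each leaf's Gain into the slot for its tree's class. The arithmetic in miniature (numbers are illustrative):

n_estimators = 10
num_trees = 30   # e.g. trees_to_dataframe()["Tree"].max() + 1
num_classes = num_trees // n_estimators
assert num_classes == 3

# Tree 7 belongs to class 7 % 3 == 1, so its leaf contributes only there.
curr_tree, gain = 7, 0.42
leaf_value = [0.0] * num_classes
leaf_value[curr_tree % num_classes] = gain
assert leaf_value == [0.0, 0.42, 0.0]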
@@ -226,10 +226,19 @@ class TestImportedMLModel:
     @requires_xgboost
     @pytest.mark.parametrize("compress_model_definition", [True, False])
-    def test_xgb_classifier(self, compress_model_definition):
+    @pytest.mark.parametrize("multi_class", [True, False])
+    def test_xgb_classifier(self, compress_model_definition, multi_class):
+        # test both multiclass and binary classification
+        if multi_class:
+            training_data = datasets.make_classification(
+                n_features=5, n_classes=3, n_informative=3
+            )
+            classifier = XGBClassifier(booster="gbtree", objective="multi:softmax")
+        else:
+            training_data = datasets.make_classification(n_features=5)
+            classifier = XGBClassifier(booster="gbtree")

         # Train model
-        training_data = datasets.make_classification(n_features=5)
-        classifier = XGBClassifier(booster="gbtree")
         classifier.fit(training_data[0], training_data[1])

         # Get some test results