diff --git a/eland/ml/_model_serializer.py b/eland/ml/_model_serializer.py index 52ae12b..5ffd968 100644 --- a/eland/ml/_model_serializer.py +++ b/eland/ml/_model_serializer.py @@ -19,7 +19,7 @@ import base64 import gzip import json from abc import ABC -from typing import Sequence, Dict, Any, Optional +from typing import Sequence, Dict, Any, Optional, List def add_if_exists(d: Dict[str, Any], k: str, v: Any) -> None: @@ -69,7 +69,7 @@ class TreeNode: right_child: Optional[int] = None, split_feature: Optional[int] = None, threshold: Optional[float] = None, - leaf_value: Optional[float] = None, + leaf_value: Optional[List[float]] = None, ): self._node_idx = node_idx self._decision_type = decision_type diff --git a/eland/ml/imported_ml_model.py b/eland/ml/imported_ml_model.py index 00c50f7..477d079 100644 --- a/eland/ml/imported_ml_model.py +++ b/eland/ml/imported_ml_model.py @@ -60,7 +60,17 @@ class ImportedMLModel(MLModel): - sklearn.ensemble.RandomForestRegressor - sklearn.ensemble.RandomForestClassifier - xgboost.XGBClassifier + - only the following operators are supported: + - "binary:logistic" + - "binary:hinge" + - "multi:softmax" + - "multi:softprob" - xgboost.XGBRegressor + - only the following operators are supported: + - "reg:squarederror" + - "reg:linear" + - "reg:squaredlogerror" + - "reg:logistic" feature_names: List[str] Names of the features (required) diff --git a/eland/ml/transformers/sklearn.py b/eland/ml/transformers/sklearn.py index 78ead81..3bf842a 100644 --- a/eland/ml/transformers/sklearn.py +++ b/eland/ml/transformers/sklearn.py @@ -79,10 +79,10 @@ class SKLearnTransformer(ModelTransformer): if ( value.shape[1] == 1 ): # classification requires more than one value, so assume regression - leaf_value = float(value[0][0]) + leaf_value = [float(value[0][0])] else: # the classification value, which is the index of the largest value - leaf_value = int(np.argmax(value)) + leaf_value = [float(np.argmax(value))] return TreeNode( node_index,
decision_type=self._node_decision_type, diff --git a/eland/ml/transformers/xgboost.py b/eland/ml/transformers/xgboost.py index c036306..ff3c404 100644 --- a/eland/ml/transformers/xgboost.py +++ b/eland/ml/transformers/xgboost.py @@ -49,6 +49,7 @@ class XGBoostForestTransformer(ModelTransformer): self._node_decision_type = "lt" self._base_score = base_score self._objective = objective + self._feature_dict = dict(zip(feature_names, range(len(feature_names)))) def get_feature_id(self, feature_id: str) -> int: if feature_id[0] == "f": @@ -56,6 +57,9 @@ return int(feature_id[1:]) except ValueError: raise RuntimeError(f"Unable to interpret '{feature_id}'") + f_id = self._feature_dict.get(feature_id) + if f_id is not None: + return f_id else: try: return int(feature_id) @@ -81,10 +85,13 @@ f"cannot determine node index or tree from '{node_id}' for tree {curr_tree}" ) + def build_leaf_node(self, row: pd.Series, curr_tree: int) -> TreeNode: + return TreeNode(node_idx=row["Node"], leaf_value=[float(row["Gain"])]) + def build_tree_node(self, row: pd.Series, curr_tree: int) -> TreeNode: node_index = row["Node"] if row["Feature"] == "Leaf": - return TreeNode(node_idx=node_index, leaf_value=float(row["Gain"])) + return self.build_leaf_node(row, curr_tree) else: return TreeNode( node_idx=node_index, @@ -96,12 +103,16 @@ ) def build_tree(self, nodes: List[TreeNode]) -> Tree: - return Tree(feature_names=self._feature_names, tree_structure=nodes) + return Tree( + feature_names=self._feature_names, + tree_structure=nodes, + target_type=self.determine_target_type(), + ) def build_base_score_stump(self) -> Tree: return Tree( feature_names=self._feature_names, - tree_structure=[TreeNode(0, leaf_value=self._base_score)], + tree_structure=[TreeNode(0, leaf_value=[self._base_score])], ) def build_forest(self) -> List[Tree]: @@ -209,12 +220,30 @@ class
XGBoostClassifierTransformer(XGBoostForestTransformer): model.objective, classification_labels, ) + if model.classes_ is None: + n_estimators = model.get_params()["n_estimators"] + num_trees = model.get_booster().trees_to_dataframe()["Tree"].max() + 1 + self._num_classes = num_trees // n_estimators + else: + self._num_classes = len(model.classes_) + + def build_leaf_node(self, row: pd.Series, curr_tree: int) -> TreeNode: + if self._num_classes <= 2: + return super().build_leaf_node(row, curr_tree) + leaf_val = [0.0] * self._num_classes + leaf_val[curr_tree % self._num_classes] = float(row["Gain"]) + return TreeNode(node_idx=row["Node"], leaf_value=leaf_val) def determine_target_type(self) -> str: return "classification" def is_objective_supported(self) -> bool: - return self._objective in {"binary:logistic", "binary:hinge"} + return self._objective in { + "binary:logistic", + "binary:hinge", + "multi:softmax", + "multi:softprob", + } def build_aggregator_output(self) -> Dict[str, Any]: return {"logistic_regression": {}} diff --git a/eland/tests/ml/test_imported_ml_model_pytest.py b/eland/tests/ml/test_imported_ml_model_pytest.py index 88fa4c9..451b712 100644 --- a/eland/tests/ml/test_imported_ml_model_pytest.py +++ b/eland/tests/ml/test_imported_ml_model_pytest.py @@ -226,10 +226,19 @@ class TestImportedMLModel: @requires_xgboost @pytest.mark.parametrize("compress_model_definition", [True, False]) - def test_xgb_classifier(self, compress_model_definition): + @pytest.mark.parametrize("multi_class", [True, False]) + def test_xgb_classifier(self, compress_model_definition, multi_class): + # test both multiple and binary classification + if multi_class: + training_data = datasets.make_classification( + n_features=5, n_classes=3, n_informative=3 + ) + classifier = XGBClassifier(booster="gbtree", objective="multi:softmax") + else: + training_data = datasets.make_classification(n_features=5) + classifier = XGBClassifier(booster="gbtree") + # Train model - training_data = 
datasets.make_classification(n_features=5) - classifier = XGBClassifier(booster="gbtree") classifier.fit(training_data[0], training_data[1]) # Get some test results