[ML] Fix XGBoost model import for xgboost>=1.6

This commit is contained in:
Benjamin Trent 2022-04-20 10:20:50 -04:00 committed by GitHub
parent cb839a9ac9
commit 8294224e34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -27,7 +27,7 @@ from .base import ModelTransformer
import_optional_dependency("xgboost", on_version="warn") import_optional_dependency("xgboost", on_version="warn")
from xgboost import Booster, XGBClassifier, XGBRegressor # type: ignore from xgboost import Booster, XGBClassifier, XGBModel, XGBRegressor # type: ignore
class XGBoostForestTransformer(ModelTransformer): class XGBoostForestTransformer(ModelTransformer):
@ -125,8 +125,6 @@ class XGBoostForestTransformer(ModelTransformer):
:return: A list of Tree objects :return: A list of Tree objects
""" """
self.check_model_booster()
tree_table: pd.DataFrame = self._model.trees_to_dataframe() tree_table: pd.DataFrame = self._model.trees_to_dataframe()
transformed_trees = [] transformed_trees = []
curr_tree: Optional[Any] = None curr_tree: Optional[Any] = None
@ -155,17 +153,21 @@ class XGBoostForestTransformer(ModelTransformer):
def is_objective_supported(self) -> bool: def is_objective_supported(self) -> bool:
return False return False
def check_model_booster(self) -> None: @staticmethod
def check_model_booster(model: XGBModel) -> None:
# xgboost v1 made booster default to 'None' meaning 'gbtree' # xgboost v1 made booster default to 'None' meaning 'gbtree'
if self._model.booster not in {"dart", "gbtree", None}: booster = (
model.get_booster().booster
if hasattr(model.get_booster(), "booster")
else model.booster
)
if booster not in {"dart", "gbtree", None}:
raise ValueError( raise ValueError(
f"booster must exist and be of type 'dart' or " f"booster must exist and be of type 'dart' or "
f"'gbtree', was {self._model.booster!r}" f"'gbtree', was {booster!r}"
) )
def transform(self) -> Ensemble: def transform(self) -> Ensemble:
self.check_model_booster()
if not self.is_objective_supported(): if not self.is_objective_supported():
raise ValueError(f"Unsupported objective '{self._objective}'") raise ValueError(f"Unsupported objective '{self._objective}'")
@ -189,6 +191,7 @@ class XGBoostRegressorTransformer(XGBoostForestTransformer):
super().__init__( super().__init__(
model.get_booster(), feature_names, base_score, model.objective model.get_booster(), feature_names, base_score, model.objective
) )
XGBoostForestTransformer.check_model_booster(model)
def determine_target_type(self) -> str: def determine_target_type(self) -> str:
return "regression" return "regression"
@ -226,6 +229,7 @@ class XGBoostClassifierTransformer(XGBoostForestTransformer):
model.objective, model.objective,
classification_labels, classification_labels,
) )
XGBoostForestTransformer.check_model_booster(model)
if model.classes_ is None: if model.classes_ is None:
n_estimators = model.get_params()["n_estimators"] n_estimators = model.get_params()["n_estimators"]
num_trees = model.get_booster().trees_to_dataframe()["Tree"].max() + 1 num_trees = model.get_booster().trees_to_dataframe()["Tree"].max() + 1