mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
[ML] Fix XGBoost model import for xgboost>=1.6
This commit is contained in:
parent
cb839a9ac9
commit
8294224e34
@ -27,7 +27,7 @@ from .base import ModelTransformer
|
|||||||
|
|
||||||
import_optional_dependency("xgboost", on_version="warn")
|
import_optional_dependency("xgboost", on_version="warn")
|
||||||
|
|
||||||
from xgboost import Booster, XGBClassifier, XGBRegressor # type: ignore
|
from xgboost import Booster, XGBClassifier, XGBModel, XGBRegressor # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class XGBoostForestTransformer(ModelTransformer):
|
class XGBoostForestTransformer(ModelTransformer):
|
||||||
@ -125,8 +125,6 @@ class XGBoostForestTransformer(ModelTransformer):
|
|||||||
|
|
||||||
:return: A list of Tree objects
|
:return: A list of Tree objects
|
||||||
"""
|
"""
|
||||||
self.check_model_booster()
|
|
||||||
|
|
||||||
tree_table: pd.DataFrame = self._model.trees_to_dataframe()
|
tree_table: pd.DataFrame = self._model.trees_to_dataframe()
|
||||||
transformed_trees = []
|
transformed_trees = []
|
||||||
curr_tree: Optional[Any] = None
|
curr_tree: Optional[Any] = None
|
||||||
@ -155,17 +153,21 @@ class XGBoostForestTransformer(ModelTransformer):
|
|||||||
def is_objective_supported(self) -> bool:
|
def is_objective_supported(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def check_model_booster(self) -> None:
|
@staticmethod
|
||||||
|
def check_model_booster(model: XGBModel) -> None:
|
||||||
# xgboost v1 made booster default to 'None' meaning 'gbtree'
|
# xgboost v1 made booster default to 'None' meaning 'gbtree'
|
||||||
if self._model.booster not in {"dart", "gbtree", None}:
|
booster = (
|
||||||
|
model.get_booster().booster
|
||||||
|
if hasattr(model.get_booster(), "booster")
|
||||||
|
else model.booster
|
||||||
|
)
|
||||||
|
if booster not in {"dart", "gbtree", None}:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"booster must exist and be of type 'dart' or "
|
f"booster must exist and be of type 'dart' or "
|
||||||
f"'gbtree', was {self._model.booster!r}"
|
f"'gbtree', was {booster!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def transform(self) -> Ensemble:
|
def transform(self) -> Ensemble:
|
||||||
self.check_model_booster()
|
|
||||||
|
|
||||||
if not self.is_objective_supported():
|
if not self.is_objective_supported():
|
||||||
raise ValueError(f"Unsupported objective '{self._objective}'")
|
raise ValueError(f"Unsupported objective '{self._objective}'")
|
||||||
|
|
||||||
@ -189,6 +191,7 @@ class XGBoostRegressorTransformer(XGBoostForestTransformer):
|
|||||||
super().__init__(
|
super().__init__(
|
||||||
model.get_booster(), feature_names, base_score, model.objective
|
model.get_booster(), feature_names, base_score, model.objective
|
||||||
)
|
)
|
||||||
|
XGBoostForestTransformer.check_model_booster(model)
|
||||||
|
|
||||||
def determine_target_type(self) -> str:
|
def determine_target_type(self) -> str:
|
||||||
return "regression"
|
return "regression"
|
||||||
@ -226,6 +229,7 @@ class XGBoostClassifierTransformer(XGBoostForestTransformer):
|
|||||||
model.objective,
|
model.objective,
|
||||||
classification_labels,
|
classification_labels,
|
||||||
)
|
)
|
||||||
|
XGBoostForestTransformer.check_model_booster(model)
|
||||||
if model.classes_ is None:
|
if model.classes_ is None:
|
||||||
n_estimators = model.get_params()["n_estimators"]
|
n_estimators = model.get_params()["n_estimators"]
|
||||||
num_trees = model.get_booster().trees_to_dataframe()["Tree"].max() + 1
|
num_trees = model.get_booster().trees_to_dataframe()["Tree"].max() + 1
|
||||||
|
Loading…
x
Reference in New Issue
Block a user