Add XGBRanker and transformer (#649)
* Add XGBRanker and transformer
* Map XGBoostRegressorTransformer to XGBRanker
* Add unit tests
* Remove unused import
* Revert addition of type
* Update function comment
* Distinguish objective based on model class
This commit is contained in:
parent 840871f9d9
commit 926f0b9b5c
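
For context, a minimal end-to-end sketch of what this change enables: training an
xgboost.XGBRanker and importing it into Elasticsearch through eland. The cluster
URL and model_id below are illustrative assumptions; the call shape mirrors the
test added at the end of this commit.

    import numpy as np
    from elasticsearch import Elasticsearch
    from sklearn import datasets
    from xgboost import XGBRanker

    from eland.ml import MLModel

    # Synthetic learning-to-rank data: qid assigns each row to a query group,
    # and xgboost expects rows of the same group to be contiguous.
    X, y = datasets.make_classification(n_features=5)
    qid = np.sort(np.random.default_rng().integers(0, 3, size=X.shape[0]))

    ranker = XGBRanker(objective="rank:ndcg")
    ranker.fit(X, y, qid=qid)

    es_model = MLModel.import_model(
        Elasticsearch("http://localhost:9200"),  # assumed cluster address
        model_id="my-xgb-ranker",                # illustrative model id
        model=ranker,
        feature_names=["f0", "f1", "f2", "f3", "f4"],
        es_if_exists="replace",
    )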
@@ -45,7 +45,11 @@ if TYPE_CHECKING:
 except ImportError:
     pass
 try:
-    from xgboost import XGBClassifier, XGBRegressor  # type: ignore # noqa: F401
+    from xgboost import (  # type: ignore # noqa: F401
+        XGBClassifier,
+        XGBRanker,
+        XGBRegressor,
+    )
 except ImportError:
     pass
 try:
@@ -252,6 +256,7 @@ class MLModel:
         "RandomForestRegressor",
         "RandomForestClassifier",
         "XGBClassifier",
+        "XGBRanker",
         "XGBRegressor",
         "LGBMRegressor",
         "LGBMClassifier",
@@ -304,6 +309,11 @@ class MLModel:
             - "binary:logistic"
             - "multi:softmax"
             - "multi:softprob"
+        - xgboost.XGBRanker
+            - only the following objectives are supported:
+                - "rank:map"
+                - "rank:ndcg"
+                - "rank:pairwise"
         - xgboost.XGBRegressor
             - only the following objectives are supported:
                 - "reg:squarederror"
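
As a quick illustration of the whitelist documented above, a hypothetical
pre-flight check (the helper name is ours, not part of eland) that fails fast
before attempting an import; get_xgb_params() is the standard XGBModel accessor:

    from xgboost import XGBModel

    SUPPORTED_RANK_OBJECTIVES = {"rank:map", "rank:ndcg", "rank:pairwise"}

    def check_ranker_objective(model: XGBModel) -> None:
        # Mirror the docstring above: only these rank objectives survive
        # the round trip into an Elasticsearch inference model.
        objective = model.get_xgb_params().get("objective")
        if objective not in SUPPORTED_RANK_OBJECTIVES:
            raise ValueError(f"unsupported ranking objective: {objective!r}")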
@@ -27,7 +27,13 @@ from .base import ModelTransformer

 import_optional_dependency("xgboost", on_version="warn")

-from xgboost import Booster, XGBClassifier, XGBModel, XGBRegressor  # type: ignore
+from xgboost import (  # type: ignore
+    Booster,
+    XGBClassifier,
+    XGBModel,
+    XGBRanker,
+    XGBRegressor,
+)


 class XGBoostForestTransformer(ModelTransformer):
@@ -140,7 +146,7 @@ class XGBoostForestTransformer(ModelTransformer):
             if len(tree_nodes) > 0:
                 transformed_trees.append(self.build_tree(tree_nodes))
         # We add this stump as XGBoost adds the base_score to the regression outputs
-        if self._objective.partition(":")[0] == "reg":
+        if self._objective.partition(":")[0] in ["reg", "rank"]:
             transformed_trees.append(self.build_base_score_stump())
         return transformed_trees

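The widened condition matters because XGBoost adds the model's base_score to the
raw output of ranking models just as it does for regressors, so the serialized
ensemble compensates with a constant stump tree. A minimal sketch showing that
constant at work, assuming the synthetic X, y, qid from the earlier example:

    from xgboost import XGBRanker

    # With a zero learning rate every leaf weight is 0, so the raw margin
    # returned by predict() is exactly the base_score constant that the
    # stump added above must reproduce on the Elasticsearch side.
    flat = XGBRanker(objective="rank:ndcg", n_estimators=1, learning_rate=0.0)
    flat.fit(X, y, qid=qid)
    print(flat.predict(X[:3]))  # three identical values: the model's base_score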
@@ -184,6 +190,7 @@ class XGBoostForestTransformer(ModelTransformer):

 class XGBoostRegressorTransformer(XGBoostForestTransformer):
     def __init__(self, model: XGBRegressor, feature_names: List[str]):
+        self._regressor_model = model
         # XGBRegressor.base_score defaults to 0.5.
         base_score = model.base_score
         if base_score is None:
@@ -197,6 +204,13 @@ class XGBoostRegressorTransformer(XGBoostForestTransformer):
         return "regression"

     def is_objective_supported(self) -> bool:
+        if isinstance(self._regressor_model, XGBRanker):
+            return self._objective in {
+                "rank:pairwise",
+                "rank:ndcg",
+                "rank:map",
+            }
+
         return self._objective in {
             "reg:squarederror",
             "reg:squaredlogerror",
@@ -264,5 +278,6 @@ class XGBoostClassifierTransformer(XGBoostForestTransformer):

 _MODEL_TRANSFORMERS: Dict[type, Type[ModelTransformer]] = {
     XGBRegressor: XGBoostRegressorTransformer,
+    XGBRanker: XGBoostRegressorTransformer,
     XGBClassifier: XGBoostClassifierTransformer,
 }
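Pointing XGBRanker at XGBoostRegressorTransformer is deliberate reuse rather
than an accident: a boosted ranking model serializes exactly like a regression
ensemble (a raw additive tree score), and only the set of accepted objectives
differs, which the isinstance branch above handles. A sketch of the dispatch,
assuming the module path eland.ml.transformers.xgboost:

    # The trained model's concrete class keys the transformer table, so an
    # XGBRanker instance is routed through the regressor code path.
    from xgboost import XGBRanker

    from eland.ml.transformers.xgboost import (  # assumed module path
        _MODEL_TRANSFORMERS,
        XGBoostRegressorTransformer,
    )

    assert _MODEL_TRANSFORMERS[XGBRanker] is XGBoostRegressorTransformer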
@@ -40,7 +40,7 @@ except ImportError:
     HAS_SKLEARN = False

 try:
-    from xgboost import XGBClassifier, XGBRegressor
+    from xgboost import XGBClassifier, XGBRanker, XGBRegressor

     HAS_XGBOOST = True
 except ImportError:
@@ -555,6 +555,45 @@ class TestMLModel:
         # Clean up
         es_model.delete_model()

+    @requires_xgboost
+    @pytest.mark.parametrize("compress_model_definition", [True, False])
+    @pytest.mark.parametrize(
+        "objective",
+        ["rank:ndcg", "rank:map", "rank:pairwise"],
+    )
+    def test_xgb_ranker(self, compress_model_definition, objective):
+        X, y = datasets.make_classification(n_features=5)
+        rng = np.random.default_rng()
+        qid = rng.integers(0, 3, size=X.shape[0])
+
+        # Sort the inputs based on query index
+        sorted_idx = np.argsort(qid)
+        X = X[sorted_idx, :]
+        y = y[sorted_idx]
+        qid = qid[sorted_idx]
+
+        ranker = XGBRanker(objective=objective)
+        ranker.fit(X, y, qid=qid)
+
+        # Serialise the models to Elasticsearch
+        feature_names = ["f0", "f1", "f2", "f3", "f4"]
+        model_id = "test_xgb_ranker"
+
+        es_model = MLModel.import_model(
+            ES_TEST_CLIENT,
+            model_id,
+            ranker,
+            feature_names,
+            es_if_exists="replace",
+            es_compress_model_definition=compress_model_definition,
+        )
+
+        # Get some test results
+        check_prediction_equality(es_model, ranker, random_rows(X, 20))
+
+        # Clean up
+        es_model.delete_model()
+
     @requires_xgboost
     @pytest.mark.parametrize("compress_model_definition", [True, False])
     @pytest.mark.parametrize(
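Note the sort before fitting: xgboost requires all rows of a query group to be
contiguous when fitting with qid. After the import, check_prediction_equality
asserts that the hosted model scores rows the same way the local one does; a
hand-rolled version of that parity check might look like this (the tolerance is
an assumption, and es_model/ranker/X are the objects from the test above):

    import numpy as np

    # Hypothetical manual equivalent of check_prediction_equality: score the
    # same rows locally and in Elasticsearch, then compare within a tolerance.
    sample = X[:20]
    local_scores = ranker.predict(sample)
    remote_scores = es_model.predict(sample.tolist())
    np.testing.assert_allclose(local_scores, remote_scores, rtol=1e-3)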