mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Add XGBRanker and transformer (#649)
* Add XGBRanker and transformer * Map XGBoostRegressorTransformer to XGBRanker * Add unit tests * Remove unused import * Revert addition of type * Update function comment * Distinguish objective based on model class
This commit is contained in:
parent
840871f9d9
commit
926f0b9b5c
@ -45,7 +45,11 @@ if TYPE_CHECKING:
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from xgboost import XGBClassifier, XGBRegressor # type: ignore # noqa: F401
|
||||
from xgboost import ( # type: ignore # noqa: F401
|
||||
XGBClassifier,
|
||||
XGBRanker,
|
||||
XGBRegressor,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
@ -252,6 +256,7 @@ class MLModel:
|
||||
"RandomForestRegressor",
|
||||
"RandomForestClassifier",
|
||||
"XGBClassifier",
|
||||
"XGBRanker",
|
||||
"XGBRegressor",
|
||||
"LGBMRegressor",
|
||||
"LGBMClassifier",
|
||||
@ -304,6 +309,11 @@ class MLModel:
|
||||
- "binary:logistic"
|
||||
- "multi:softmax"
|
||||
- "multi:softprob"
|
||||
- xgboost.XGBRanker
|
||||
- only the following objectives are supported:
|
||||
- "rank:map"
|
||||
- "rank:ndcg"
|
||||
- "rank:pairwise"
|
||||
- xgboost.XGBRegressor
|
||||
- only the following objectives are supported:
|
||||
- "reg:squarederror"
|
||||
|
@ -27,7 +27,13 @@ from .base import ModelTransformer
|
||||
|
||||
import_optional_dependency("xgboost", on_version="warn")
|
||||
|
||||
from xgboost import Booster, XGBClassifier, XGBModel, XGBRegressor # type: ignore
|
||||
from xgboost import ( # type: ignore
|
||||
Booster,
|
||||
XGBClassifier,
|
||||
XGBModel,
|
||||
XGBRanker,
|
||||
XGBRegressor,
|
||||
)
|
||||
|
||||
|
||||
class XGBoostForestTransformer(ModelTransformer):
|
||||
@ -140,7 +146,7 @@ class XGBoostForestTransformer(ModelTransformer):
|
||||
if len(tree_nodes) > 0:
|
||||
transformed_trees.append(self.build_tree(tree_nodes))
|
||||
# We add this stump as XGBoost adds the base_score to the regression outputs
|
||||
if self._objective.partition(":")[0] == "reg":
|
||||
if self._objective.partition(":")[0] in ["reg", "rank"]:
|
||||
transformed_trees.append(self.build_base_score_stump())
|
||||
return transformed_trees
|
||||
|
||||
@ -184,6 +190,7 @@ class XGBoostForestTransformer(ModelTransformer):
|
||||
|
||||
class XGBoostRegressorTransformer(XGBoostForestTransformer):
|
||||
def __init__(self, model: XGBRegressor, feature_names: List[str]):
|
||||
self._regressor_model = model
|
||||
# XGBRegressor.base_score defaults to 0.5.
|
||||
base_score = model.base_score
|
||||
if base_score is None:
|
||||
@ -197,6 +204,13 @@ class XGBoostRegressorTransformer(XGBoostForestTransformer):
|
||||
return "regression"
|
||||
|
||||
def is_objective_supported(self) -> bool:
|
||||
if isinstance(self._regressor_model, XGBRanker):
|
||||
return self._objective in {
|
||||
"rank:pairwise",
|
||||
"rank:ndcg",
|
||||
"rank:map",
|
||||
}
|
||||
|
||||
return self._objective in {
|
||||
"reg:squarederror",
|
||||
"reg:squaredlogerror",
|
||||
@ -264,5 +278,6 @@ class XGBoostClassifierTransformer(XGBoostForestTransformer):
|
||||
|
||||
_MODEL_TRANSFORMERS: Dict[type, Type[ModelTransformer]] = {
|
||||
XGBRegressor: XGBoostRegressorTransformer,
|
||||
XGBRanker: XGBoostRegressorTransformer,
|
||||
XGBClassifier: XGBoostClassifierTransformer,
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ except ImportError:
|
||||
HAS_SKLEARN = False
|
||||
|
||||
try:
|
||||
from xgboost import XGBClassifier, XGBRegressor
|
||||
from xgboost import XGBClassifier, XGBRanker, XGBRegressor
|
||||
|
||||
HAS_XGBOOST = True
|
||||
except ImportError:
|
||||
@ -555,6 +555,45 @@ class TestMLModel:
|
||||
# Clean up
|
||||
es_model.delete_model()
|
||||
|
||||
@requires_xgboost
|
||||
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"objective",
|
||||
["rank:ndcg", "rank:map", "rank:pairwise"],
|
||||
)
|
||||
def test_xgb_ranker(self, compress_model_definition, objective):
|
||||
X, y = datasets.make_classification(n_features=5)
|
||||
rng = np.random.default_rng()
|
||||
qid = rng.integers(0, 3, size=X.shape[0])
|
||||
|
||||
# Sort the inputs based on query index
|
||||
sorted_idx = np.argsort(qid)
|
||||
X = X[sorted_idx, :]
|
||||
y = y[sorted_idx]
|
||||
qid = qid[sorted_idx]
|
||||
|
||||
ranker = XGBRanker(objective=objective)
|
||||
ranker.fit(X, y, qid=qid)
|
||||
|
||||
# Serialise the models to Elasticsearch
|
||||
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||
model_id = "test_xgb_ranker"
|
||||
|
||||
es_model = MLModel.import_model(
|
||||
ES_TEST_CLIENT,
|
||||
model_id,
|
||||
ranker,
|
||||
feature_names,
|
||||
es_if_exists="replace",
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
|
||||
# Get some test results
|
||||
check_prediction_equality(es_model, ranker, random_rows(X, 20))
|
||||
|
||||
# Clean up
|
||||
es_model.delete_model()
|
||||
|
||||
@requires_xgboost
|
||||
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
|
Loading…
x
Reference in New Issue
Block a user