mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Fix missing value support for XGBRanker. (#654)
* Fix missing value support for XGBRanker. * lint * Sort expected scores * lint
This commit is contained in:
parent
1190364abb
commit
2a6a4b1f06
@ -96,6 +96,7 @@ class TreeNode:
|
||||
add_if_exists(d, "split_feature", self._split_feature)
|
||||
add_if_exists(d, "threshold", self._threshold)
|
||||
add_if_exists(d, "number_samples", self._number_samples)
|
||||
add_if_exists(d, "default_left", self._default_left)
|
||||
else:
|
||||
if len(self._leaf_value) == 1:
|
||||
# Support Elasticsearch 7.6 which only
|
||||
|
@ -107,6 +107,7 @@ class XGBoostForestTransformer(ModelTransformer):
|
||||
decision_type=self._node_decision_type,
|
||||
left_child=self.extract_node_id(row["Yes"], curr_tree),
|
||||
right_child=self.extract_node_id(row["No"], curr_tree),
|
||||
default_left=row["Yes"] == row["Missing"],
|
||||
threshold=float(row["Split"]),
|
||||
split_feature=self.get_feature_id(row["Feature"]),
|
||||
)
|
||||
|
@ -23,7 +23,7 @@ import pytest
|
||||
|
||||
import eland as ed
|
||||
from eland.ml import MLModel
|
||||
from eland.ml.ltr import LTRModelConfig, QueryFeatureExtractor
|
||||
from eland.ml.ltr import FeatureLogger, LTRModelConfig, QueryFeatureExtractor
|
||||
from tests import (
|
||||
ES_TEST_CLIENT,
|
||||
ES_VERSION,
|
||||
@ -321,13 +321,27 @@ class TestMLModel:
|
||||
es_model.delete_model()
|
||||
|
||||
@requires_elasticsearch_version((8, 12))
|
||||
@requires_sklearn
|
||||
@requires_xgboost
|
||||
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
||||
def test_learning_to_rank(self, compress_model_definition):
|
||||
# Train model
|
||||
training_data = datasets.make_regression(n_features=2)
|
||||
regressor = DecisionTreeRegressor()
|
||||
regressor.fit(training_data[0], training_data[1])
|
||||
@pytest.mark.parametrize(
|
||||
"objective",
|
||||
["rank:ndcg", "rank:map", "rank:pairwise"],
|
||||
)
|
||||
def test_learning_to_rank(self, objective, compress_model_definition):
|
||||
X, y = datasets.make_classification(
|
||||
n_features=3, n_informative=2, n_redundant=1
|
||||
)
|
||||
rng = np.random.default_rng()
|
||||
qid = rng.integers(0, 3, size=X.shape[0])
|
||||
|
||||
# Sort the inputs based on query index
|
||||
sorted_idx = np.argsort(qid)
|
||||
X = X[sorted_idx, :]
|
||||
y = y[sorted_idx]
|
||||
qid = qid[sorted_idx]
|
||||
|
||||
ranker = XGBRanker(objective=objective)
|
||||
ranker.fit(X, y, qid=qid)
|
||||
|
||||
# Serialise the models to Elasticsearch
|
||||
model_id = "test_learning_to_rank"
|
||||
@ -356,7 +370,7 @@ class TestMLModel:
|
||||
es_model = MLModel.import_ltr_model(
|
||||
ES_TEST_CLIENT,
|
||||
model_id,
|
||||
regressor,
|
||||
ranker,
|
||||
ltr_model_config,
|
||||
es_if_exists="replace",
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
@ -388,16 +402,27 @@ class TestMLModel:
|
||||
"learning_to_rank": {
|
||||
"model_id": model_id,
|
||||
"params": {"query_string": "yosemite"},
|
||||
}
|
||||
},
|
||||
"window_size": 2,
|
||||
},
|
||||
)
|
||||
|
||||
# Assert that:
|
||||
# - all documents from the query are present
|
||||
# - all documents have been rescored (score != 1.0)
|
||||
# Assert that the rescored search results match the ranker's predictions.
|
||||
doc_scores = [hit["_score"] for hit in search_result["hits"]["hits"]]
|
||||
assert len(search_result["hits"]["hits"]) == 2
|
||||
assert all(score != float(1) for score in doc_scores)
|
||||
|
||||
feature_logger = FeatureLogger(
|
||||
ES_TEST_CLIENT, NATIONAL_PARKS_INDEX_NAME, ltr_model_config
|
||||
)
|
||||
expected_scores = sorted(
|
||||
[
|
||||
ranker.predict(np.asarray([doc_features]))[0]
|
||||
for _, doc_features in feature_logger.extract_features(
|
||||
{"query_string": "yosemite"}, ["park_yosemite", "park_everglades"]
|
||||
).items()
|
||||
],
|
||||
reverse=True,
|
||||
)
|
||||
np.testing.assert_almost_equal(expected_scores, doc_scores, decimal=2)
|
||||
|
||||
# Verify prediction is not supported for LTR
|
||||
try:
|
||||
|
Loading…
x
Reference in New Issue
Block a user