Improve LTR (#651)

* Ensure the feature logger uses NaN for non-matching query feature extractors (consistent with ES).

* Default score for query feature extractors is now None instead of 0.

* LTR model import API improvements.

* Fix feature logger tests.

* Fix export in eland.ml.ltr

* Apply suggestions from code review

Co-authored-by: Adam Demjen <demjened@gmail.com>

* Fix supported models for LTR

---------

Co-authored-by: Adam Demjen <demjened@gmail.com>
Authored by Aurélien FOUCRET on 2024-01-17 10:01:47 +01:00, committed by GitHub
parent d2291889f8
commit 5169cc926a
6 changed files with 164 additions and 66 deletions
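
Taken together, the import API changes in this commit are easiest to see as code. Below is a minimal usage sketch of the new import path, assuming a locally running cluster; the index, field, model id and training data are illustrative, and the extractor definitions mirror the test changes at the bottom of this commit:

from sklearn.tree import DecisionTreeRegressor

from eland.ml import MLModel
from eland.ml.ltr import LTRModelConfig, QueryFeatureExtractor

# Feature extractors are named Elasticsearch queries; their scores become model features.
ltr_config = LTRModelConfig(
    feature_extractors=[
        QueryFeatureExtractor(
            feature_name="title_bm25",
            query={"match": {"title": "{{query_string}}"}},
        ),
        QueryFeatureExtractor(
            feature_name="visitors",
            query={
                "script_score": {
                    "query": {"exists": {"field": "visitors"}},
                    "script": {"source": 'return doc["visitors"].value;'},
                }
            },
        ),
    ]
)

# Train any supported regressor on (title_bm25, visitors) feature vectors.
# The training data here is made up purely for illustration.
regressor = DecisionTreeRegressor()
regressor.fit([[12.5, 1000.0], [0.0, 50.0]], [1.0, 0.0])

es_model = MLModel.import_ltr_model(
    "http://localhost:9200",  # assumed local cluster; any es_client argument works
    "my-ltr-model",           # hypothetical model_id
    regressor,
    ltr_config,
    es_if_exists="replace",
)

Compared with the old import_model(..., feature_names=..., inference_config=...) call, both the feature names and the learning_to_rank inference config are now derived from the LTRModelConfig.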


@@ -15,11 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
-from eland.ml.ltr.feature_logger import FeatureLogger
-from eland.ml.ltr.ltr_model_config import (
-    FeatureExtractor,
-    LTRModelConfig,
-    QueryFeatureExtractor,
-)
+from .feature_logger import FeatureLogger
+from .ltr_model_config import FeatureExtractor, LTRModelConfig, QueryFeatureExtractor
 
-__all__ = [LTRModelConfig, QueryFeatureExtractor, FeatureExtractor, FeatureLogger]
+__all__ = [
+    "LTRModelConfig",
+    "QueryFeatureExtractor",
+    "FeatureExtractor",
+    "FeatureLogger",
+]


@@ -95,7 +95,7 @@ class FeatureLogger:
         """
 
         doc_features = {
-            doc_id: [float(0)] * len(self._model_config.feature_extractors)
+            doc_id: [float("nan")] * len(self._model_config.feature_extractors)
             for doc_id in doc_ids
         }
@@ -141,28 +141,6 @@ class FeatureLogger:
     def _extract_query_features(
         self, query_params: Mapping[str, Any], doc_ids: List[str]
     ):
-        default_query_scores = dict(
-            (extractor.feature_name, extractor.default_score)
-            for extractor in self._model_config.query_feature_extractors
-        )
-
-        matched_queries = self._execute_search_template_request(
-            script_source=self._script_source,
-            template_params={
-                **query_params,
-                "__doc_ids": doc_ids,
-                "__size": len(doc_ids),
-            },
-        )
-
-        return {
-            hit_id: {**default_query_scores, **matched_queries_scores}
-            for hit_id, matched_queries_scores in matched_queries.items()
-        }
-
-    def _execute_search_template_request(
-        self, script_source: str, template_params: Mapping[str, any]
-    ):
         # When support for include_named_queries_score will be added,
         # this will be replaced by the call to the client search_template method.
@@ -171,7 +149,10 @@ class FeatureLogger:
         __path = f"/{_quote(self._index_name)}/_search/template"
         __query = {"include_named_queries_score": True}
         __headers = {"accept": "application/json", "content-type": "application/json"}
-        __body = {"source": script_source, "params": template_params}
+        __body = {
+            "source": self._script_source,
+            "params": {**query_params, "__doc_ids": doc_ids, "__size": len(doc_ids)},
+        }
 
         return {
             hit["_id"]: hit["matched_queries"] if "matched_queries" in hit else {}
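
The hunks above change what the feature logger reports for documents that a query extractor did not match: the per-document feature vector now starts out as NaN (consistent with Elasticsearch) instead of 0.0, and only scores returned by the named queries overwrite those NaN slots. A rough sketch of the resulting behaviour, assuming the FeatureLogger(es_client, index_name, ltr_model_config) constructor and extract_features(query_params, doc_ids) method exercised by the test file further down; the cluster URL, index name and document ids are illustrative:

import math

from elasticsearch import Elasticsearch

from eland.ml.ltr import FeatureLogger, LTRModelConfig, QueryFeatureExtractor

es = Elasticsearch("http://localhost:9200")  # assumed local cluster

config = LTRModelConfig(
    feature_extractors=[
        QueryFeatureExtractor(
            feature_name="title_bm25",
            query={"match": {"title": "{{query_string}}"}},
        ),
        # default_score now defaults to None; passing a value declares an explicit
        # fallback score for documents this query does not match.
        QueryFeatureExtractor(
            feature_name="description_bm25",
            query={"match": {"description": "{{query_string}}"}},
            default_score=0.0,
        ),
    ]
)

# "national_parks" and the document ids below are illustrative.
feature_logger = FeatureLogger(es, "national_parks", config)
doc_features = feature_logger.extract_features(
    {"query_string": "yosemite"}, ["park_yosemite", "park_death-valley"]
)

for doc_id, features in doc_features.items():
    # A NaN slot means the corresponding named query produced no score for this document.
    print(doc_id, [None if math.isnan(value) else value for value in features])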


@@ -59,7 +59,7 @@ class QueryFeatureExtractor(FeatureExtractor):
         self,
         feature_name: str,
         query: Mapping[str, Any],
-        default_score: Optional[float] = float(0),
+        default_score: Optional[float] = None,
     ):
         """
         Parameters


@@ -24,6 +24,7 @@ from eland.common import ensure_es_client, es_version
 from eland.utils import deprecated_api
 
 from .common import TYPE_CLASSIFICATION, TYPE_LEARNING_TO_RANK, TYPE_REGRESSION
+from .ltr import LTRModelConfig
 from .transformers import get_model_transformer
 
 if TYPE_CHECKING:
@@ -266,7 +267,6 @@ class MLModel:
         classification_weights: Optional[List[float]] = None,
         es_if_exists: Optional[str] = None,
         es_compress_model_definition: bool = True,
-        inference_config: Optional[Mapping[str, Mapping[str, Any]]] = None,
     ) -> "MLModel":
         """
         Transform and serialize a trained 3rd party model into Elasticsearch.
@@ -342,10 +342,6 @@
             JSON instead of raw JSON to reduce the amount of data sent
             over the wire in HTTP requests. Defaults to 'True'.
-        inference_config: Mapping[str, Mapping[str, Any]]
-            Model inference configuration. Must contain a top-level property whose name is the same as the inference
-            task type.
-
         Examples
         --------
         >>> from sklearn import datasets
@@ -380,6 +376,125 @@
         >>> # Delete model from Elasticsearch
         >>> es_model.delete_model()
         """
+        return cls._import_model(
+            es_client=es_client,
+            model_id=model_id,
+            model=model,
+            feature_names=feature_names,
+            classification_labels=classification_labels,
+            classification_weights=classification_weights,
+            es_if_exists=es_if_exists,
+            es_compress_model_definition=es_compress_model_definition,
+        )
+
+    @classmethod
+    def import_ltr_model(
+        cls,
+        es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
+        model_id: str,
+        model: Union[
+            "DecisionTreeRegressor",
+            "RandomForestRegressor",
+            "XGBRanker",
+            "XGBRegressor",
+            "LGBMRegressor",
+        ],
+        ltr_model_config: LTRModelConfig,
+        es_if_exists: Optional[str] = None,
+        es_compress_model_definition: bool = True,
+    ) -> "MLModel":
+        """
+        Transform and serialize a trained 3rd party model into Elasticsearch.
+        This model can then be used as a learning_to_rank rescorer in the Elastic Stack.
+
+        Parameters
+        ----------
+        es_client: Elasticsearch client argument(s)
+            - elasticsearch-py parameters or
+            - elasticsearch-py instance
+        model_id: str
+            The unique identifier of the trained inference model in Elasticsearch.
+        model: An instance of a supported python model. We support the following model types for LTR prediction:
+            - sklearn.tree.DecisionTreeRegressor
+            - sklearn.ensemble.RandomForestRegressor
+            - xgboost.XGBRanker
+                - only the following objectives are supported:
+                    - "rank:map"
+                    - "rank:ndcg"
+                    - "rank:pairwise"
+            - xgboost.XGBRegressor
+                - only the following objectives are supported:
+                    - "reg:squarederror"
+                    - "reg:linear"
+                    - "reg:squaredlogerror"
+                    - "reg:logistic"
+                    - "reg:pseudohubererror"
+            - lightgbm.LGBMRegressor
+                - Categorical fields are expected to already be processed
+                - Only the following objectives are supported
+                    - "regression"
+                    - "regression_l1"
+                    - "huber"
+                    - "fair"
+                    - "quantile"
+                    - "mape"
+        ltr_model_config: LTRModelConfig
+            The LTR model configuration is used to configure feature extractors for the LTR model.
+            Feature names are automatically inferred from the feature extractors.
+        es_if_exists: {'fail', 'replace'} default 'fail'
+            How to behave if model already exists
+
+            - fail: Raise a Value Error
+            - replace: Overwrite existing model
+        es_compress_model_definition: bool
+            If True will use 'compressed_definition' which uses gzipped
+            JSON instead of raw JSON to reduce the amount of data sent
+            over the wire in HTTP requests. Defaults to 'True'.
+        """
+        return cls._import_model(
+            es_client=es_client,
+            model_id=model_id,
+            model=model,
+            feature_names=ltr_model_config.feature_names,
+            inference_config=ltr_model_config.to_dict(),
+            es_if_exists=es_if_exists,
+            es_compress_model_definition=es_compress_model_definition,
+        )
+
+    @classmethod
+    def _import_model(
+        cls,
+        es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
+        model_id: str,
+        model: Union[
+            "DecisionTreeClassifier",
+            "DecisionTreeRegressor",
+            "RandomForestRegressor",
+            "RandomForestClassifier",
+            "XGBClassifier",
+            "XGBRanker",
+            "XGBRegressor",
+            "LGBMRegressor",
+            "LGBMClassifier",
+        ],
+        feature_names: List[str],
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
+        es_if_exists: Optional[str] = None,
+        es_compress_model_definition: bool = True,
+        inference_config: Optional[Mapping[str, Mapping[str, Any]]] = None,
+    ) -> "MLModel":
+        """
+        Actual implementation of model import used by public API methods.
+        """
         es_client = ensure_es_client(es_client)
         transformer = get_model_transformer(
             model,


@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import math
+
 from eland.ml.ltr import FeatureLogger, LTRModelConfig, QueryFeatureExtractor
 from tests import ES_TEST_CLIENT, NATIONAL_PARKS_INDEX_NAME
@@ -53,12 +55,12 @@ class TestFeatureLogger:
         # "park_hawaii-volcanoes" document does not match for title but is a world heritage site
         assert (
-            doc_features["park_hawaii-volcanoes"][0] == 0
+            math.isnan(doc_features["park_hawaii-volcanoes"][0])
            and doc_features["park_hawaii-volcanoes"][1] > 1
         )
 
         # "park_death-valley" document does not match for title and is not a world heritage site
-        assert doc_features["park_death-valley"] == [0, 0]
+        assert all(math.isnan(feature) for feature in doc_features["park_death-valley"])
 
     def _ltr_model_config(self):
         # Returns an LTR config with 2 query feature extractors:


@@ -23,6 +23,7 @@ import pytest
 
 import eland as ed
 from eland.ml import MLModel
+from eland.ml.ltr import LTRModelConfig, QueryFeatureExtractor
 from tests import (
     ES_TEST_CLIENT,
     ES_VERSION,
@@ -330,41 +331,35 @@ class TestMLModel:
         # Serialise the models to Elasticsearch
         model_id = "test_learning_to_rank"
-        feature_extractors = [
-            {
-                "query_extractor": {
-                    "feature_name": "title_bm25",
-                    "query": {"match": {"title": "{{query_string}}"}},
-                }
-            },
-            {
-                "query_extractor": {
-                    "feature_name": "visitors",
-                    "query": {
+        ltr_model_config = LTRModelConfig(
+            feature_extractors=[
+                QueryFeatureExtractor(
+                    feature_name="title_bm25",
+                    query={"match": {"title": "{{query_string}}"}},
+                ),
+                QueryFeatureExtractor(
+                    feature_name="description_bm25",
+                    query={"match": {"description_bm25": "{{query_string}}"}},
+                ),
+                QueryFeatureExtractor(
+                    feature_name="visitors",
+                    query={
                         "script_score": {
                             "query": {"exists": {"field": "visitors"}},
                             "script": {"source": 'return doc["visitors"].value;'},
                         }
                     },
-                }
-            },
-        ]
-        feature_names = [
-            extractor["query_extractor"]["feature_name"]
-            for extractor in feature_extractors
-        ]
-        inference_config = {
-            "learning_to_rank": {"feature_extractors": feature_extractors}
-        }
+                ),
+            ]
+        )
 
-        es_model = MLModel.import_model(
+        es_model = MLModel.import_ltr_model(
             ES_TEST_CLIENT,
             model_id,
             regressor,
-            feature_names,
+            ltr_model_config,
             es_if_exists="replace",
             es_compress_model_definition=compress_model_definition,
-            inference_config=inference_config,
         )
 
         # Verify the saved inference config contains the passed LTR config
@@ -375,10 +370,14 @@
             "inference_config"
         ]
         assert "learning_to_rank" in saved_inference_config
-        saved_ltr_config = saved_inference_config["learning_to_rank"]
+        assert "feature_extractors" in saved_inference_config["learning_to_rank"]
+        saved_feature_extractors = saved_inference_config["learning_to_rank"][
+            "feature_extractors"
+        ]
+
         assert all(
-            item in saved_ltr_config.items()
-            for item in inference_config["learning_to_rank"].items()
+            feature_extractor.to_dict() in saved_feature_extractors
+            for feature_extractor in ltr_model_config.feature_extractors
         )
 
         # Execute search with rescoring
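
The truncated hunk above ends where the test issues a search that uses the imported model as a learning_to_rank rescorer, as mentioned in the import_ltr_model docstring. A rough sketch of what such a search looks like with elasticsearch-py, assuming the Elasticsearch 8.12+ rescorer syntax and reusing the illustrative index and model names from the sketch at the top of this commit:

from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")  # assumed local cluster

response = es.search(
    index="national_parks",                   # illustrative index name
    query={"match": {"title": "yosemite"}},   # first-pass BM25 retrieval
    rescore={
        "window_size": 50,
        "learning_to_rank": {
            "model_id": "my-ltr-model",       # the model imported with import_ltr_model
            # Template parameters referenced by the feature extractors, e.g. {{query_string}}.
            "params": {"query_string": "yosemite"},
        },
    },
    size=10,
)

for hit in response["hits"]["hits"]:
    print(hit["_id"], hit["_score"])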