Mirror of https://github.com/elastic/eland.git
Improve LTR (#651)

* Ensure the feature logger uses NaN for non-matching query feature extractors (consistent with ES).
* Default score is None instead of 0.
* LTR model import API improvements.
* Fix feature logger tests.
* Fix export in eland.ml.ltr.
* Apply suggestions from code review.
* Fix supported models for LTR.

Co-authored-by: Adam Demjen <demjened@gmail.com>
Parent: d2291889f8
Commit: 5169cc926a
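Taken together, the headline change is the new `MLModel.import_ltr_model` entry point driven by an `LTRModelConfig`. Below is a minimal sketch of the new call shape, pieced together from the updated tests and docstrings in this diff; the toy training data, the model id, and the localhost cluster URL are placeholders, not part of the commit:

import numpy as np
from sklearn.tree import DecisionTreeRegressor

from eland.ml import MLModel
from eland.ml.ltr import LTRModelConfig, QueryFeatureExtractor

# One extractor per model feature; the template parameter is filled in at query time.
ltr_config = LTRModelConfig(
    feature_extractors=[
        QueryFeatureExtractor(
            feature_name="title_bm25",
            query={"match": {"title": "{{query_string}}"}},
        ),
    ]
)

# Toy regressor trained on random data, one column per configured feature.
X = np.random.rand(32, len(ltr_config.feature_names))
y = np.random.rand(32)
regressor = DecisionTreeRegressor().fit(X, y)

es_model = MLModel.import_ltr_model(
    "http://localhost:9200",  # assumed local cluster; an Elasticsearch instance also works
    "my-ltr-model",           # hypothetical model id
    regressor,
    ltr_config,
    es_if_exists="replace",
)

Feature names and the stored inference configuration are both derived from the config object, so callers no longer assemble them by hand.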
@@ -15,11 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from eland.ml.ltr.feature_logger import FeatureLogger
-from eland.ml.ltr.ltr_model_config import (
-    FeatureExtractor,
-    LTRModelConfig,
-    QueryFeatureExtractor,
-)
+from .feature_logger import FeatureLogger
+from .ltr_model_config import FeatureExtractor, LTRModelConfig, QueryFeatureExtractor
 
-__all__ = [LTRModelConfig, QueryFeatureExtractor, FeatureExtractor, FeatureLogger]
+__all__ = [
+    "LTRModelConfig",
+    "QueryFeatureExtractor",
+    "FeatureExtractor",
+    "FeatureLogger",
+]
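With relative imports and a string-valued `__all__`, the public LTR helpers are exported cleanly from the package root. A trivial check, assuming this version of eland is installed:

from eland.ml.ltr import FeatureLogger, LTRModelConfig, QueryFeatureExtractor

# The names resolve to the classes defined in the submodules above.
print(FeatureLogger, LTRModelConfig, QueryFeatureExtractor)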
@@ -95,7 +95,7 @@ class FeatureLogger:
         """
 
         doc_features = {
-            doc_id: [float(0)] * len(self._model_config.feature_extractors)
+            doc_id: [float("nan")] * len(self._model_config.feature_extractors)
            for doc_id in doc_ids
        }
 
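Because every feature row now starts out as NaN instead of 0, a score of 0 is no longer ambiguous between "matched with score 0" and "did not match". Downstream code that turns logged features into a training matrix has to fill those gaps explicitly. A minimal sketch, assuming a `doc_features` mapping shaped like the one built above; the 0.0 fill value is an illustrative choice, not something this commit prescribes:

import math

# Shape of the mapping built by the feature logger: one row per document,
# one value per configured feature extractor.
doc_features = {
    "doc_1": [12.3, float("nan")],  # second extractor did not match this document
    "doc_2": [float("nan"), 4.2],
}

def fill_missing(features, fill_value=0.0):
    # Replace NaN (extractor did not match) with an explicit training-time default.
    return [fill_value if math.isnan(value) else value for value in features]

training_rows = {doc_id: fill_missing(row) for doc_id, row in doc_features.items()}
print(training_rows)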
@@ -141,28 +141,6 @@ class FeatureLogger:
 
     def _extract_query_features(
         self, query_params: Mapping[str, Any], doc_ids: List[str]
     ):
-        default_query_scores = dict(
-            (extractor.feature_name, extractor.default_score)
-            for extractor in self._model_config.query_feature_extractors
-        )
-
-        matched_queries = self._execute_search_template_request(
-            script_source=self._script_source,
-            template_params={
-                **query_params,
-                "__doc_ids": doc_ids,
-                "__size": len(doc_ids),
-            },
-        )
-
-        return {
-            hit_id: {**default_query_scores, **matched_queries_scores}
-            for hit_id, matched_queries_scores in matched_queries.items()
-        }
-
-    def _execute_search_template_request(
-        self, script_source: str, template_params: Mapping[str, any]
-    ):
         # When support for include_named_queries_score will be added,
        # this will be replaced by the call to the client search_template method.
@@ -171,7 +149,10 @@ class FeatureLogger:
         __path = f"/{_quote(self._index_name)}/_search/template"
         __query = {"include_named_queries_score": True}
         __headers = {"accept": "application/json", "content-type": "application/json"}
-        __body = {"source": script_source, "params": template_params}
+        __body = {
+            "source": self._script_source,
+            "params": {**query_params, "__doc_ids": doc_ids, "__size": len(doc_ids)},
+        }
 
         return {
             hit["_id"]: hit["matched_queries"] if "matched_queries" in hit else {}
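For orientation, the request assembled above is a plain search-template call with `include_named_queries_score` enabled, issued through the client's low-level transport until the high-level `search_template` method supports that flag. The sketch below rebuilds an equivalent request with elasticsearch-py's `perform_request`; the index name, template source, and parameters are placeholders, and the exact low-level call is an assumption inferred from the path/query/body in the diff, not a copy of eland's code:

from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")  # assumed local cluster

# Mirrors the __path, __query, __headers and __body values built above.
response = es.perform_request(
    "POST",
    "/national_parks/_search/template",  # hypothetical index name
    params={"include_named_queries_score": True},
    headers={"accept": "application/json", "content-type": "application/json"},
    body={
        # Placeholder template: the real one is rendered from the LTR config.
        "source": '{"query": {"match": {"title": "{{query_string}}"}}}',
        "params": {"query_string": "yosemite", "__doc_ids": ["park_yosemite"], "__size": 1},
    },
)

# With include_named_queries_score, each hit carries per-query scores.
for hit in response["hits"]["hits"]:
    print(hit["_id"], hit.get("matched_queries", {}))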
@@ -59,7 +59,7 @@ class QueryFeatureExtractor(FeatureExtractor):
         self,
         feature_name: str,
         query: Mapping[str, Any],
-        default_score: Optional[float] = float(0),
+        default_score: Optional[float] = None,
     ):
         """
         Parameters
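Since the default is now `None`, an extractor only falls back to a fixed score when one is given explicitly; otherwise non-matching documents show up as NaN in the logged features. A small sketch of both variants, assuming the extractors are used inside an `LTRModelConfig` as in the tests further down; field and feature names are illustrative:

from eland.ml.ltr import LTRModelConfig, QueryFeatureExtractor

ltr_config = LTRModelConfig(
    feature_extractors=[
        # No default_score: a document that does not match is logged as NaN.
        QueryFeatureExtractor(
            feature_name="title_bm25",
            query={"match": {"title": "{{query_string}}"}},
        ),
        # Explicit default_score: a document that does not match falls back to 0.0.
        QueryFeatureExtractor(
            feature_name="description_bm25",
            query={"match": {"description": "{{query_string}}"}},
            default_score=0.0,
        ),
    ]
)
print(ltr_config.feature_names)  # ['title_bm25', 'description_bm25']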
@@ -24,6 +24,7 @@ from eland.common import ensure_es_client, es_version
 from eland.utils import deprecated_api
 
 from .common import TYPE_CLASSIFICATION, TYPE_LEARNING_TO_RANK, TYPE_REGRESSION
+from .ltr import LTRModelConfig
 from .transformers import get_model_transformer
 
 if TYPE_CHECKING:
@@ -266,7 +267,6 @@ class MLModel:
         classification_weights: Optional[List[float]] = None,
         es_if_exists: Optional[str] = None,
         es_compress_model_definition: bool = True,
-        inference_config: Optional[Mapping[str, Mapping[str, Any]]] = None,
     ) -> "MLModel":
         """
         Transform and serialize a trained 3rd party model into Elasticsearch.
@@ -342,10 +342,6 @@ class MLModel:
             JSON instead of raw JSON to reduce the amount of data sent
             over the wire in HTTP requests. Defaults to 'True'.
 
-        inference_config: Mapping[str, Mapping[str, Any]]
-            Model inference configuration. Must contain a top-level property whose name is the same as the inference
-            task type.
-
         Examples
         --------
         >>> from sklearn import datasets
@@ -380,6 +376,125 @@ class MLModel:
         >>> # Delete model from Elasticsearch
         >>> es_model.delete_model()
         """
+
+        return cls._import_model(
+            es_client=es_client,
+            model_id=model_id,
+            model=model,
+            feature_names=feature_names,
+            classification_labels=classification_labels,
+            classification_weights=classification_weights,
+            es_if_exists=es_if_exists,
+            es_compress_model_definition=es_compress_model_definition,
+        )
+
+    @classmethod
+    def import_ltr_model(
+        cls,
+        es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
+        model_id: str,
+        model: Union[
+            "DecisionTreeRegressor",
+            "RandomForestRegressor",
+            "XGBRanker",
+            "XGBRegressor",
+            "LGBMRegressor",
+        ],
+        ltr_model_config: LTRModelConfig,
+        es_if_exists: Optional[str] = None,
+        es_compress_model_definition: bool = True,
+    ) -> "MLModel":
+        """
+        Transform and serialize a trained 3rd party model into Elasticsearch.
+        This model can then be used as a learning_to_rank rescorer in the Elastic Stack.
+
+        Parameters
+        ----------
+        es_client: Elasticsearch client argument(s)
+            - elasticsearch-py parameters or
+            - elasticsearch-py instance
+
+        model_id: str
+            The unique identifier of the trained inference model in Elasticsearch.
+
+        model: An instance of a supported python model. We support the following model types for LTR prediction:
+            - sklearn.tree.DecisionTreeRegressor
+            - sklearn.ensemble.RandomForestRegressor
+            - xgboost.XGBRanker
+                - only the following objectives are supported:
+                    - "rank:map"
+                    - "rank:ndcg"
+                    - "rank:pairwise"
+            - xgboost.XGBRegressor
+                - only the following objectives are supported:
+                    - "reg:squarederror"
+                    - "reg:linear"
+                    - "reg:squaredlogerror"
+                    - "reg:logistic"
+                    - "reg:pseudohubererror"
+            - lightgbm.LGBMRegressor
+                - Categorical fields are expected to already be processed
+                - Only the following objectives are supported
+                    - "regression"
+                    - "regression_l1"
+                    - "huber"
+                    - "fair"
+                    - "quantile"
+                    - "mape"
+
+        ltr_model_config: LTRModelConfig
+            The LTR model configuration is used to configure feature extractors for the LTR model.
+            Feature names are automatically inferred from the feature extractors.
+
+        es_if_exists: {'fail', 'replace'} default 'fail'
+            How to behave if model already exists
+
+            - fail: Raise a Value Error
+            - replace: Overwrite existing model
+
+        es_compress_model_definition: bool
+            If True will use 'compressed_definition' which uses gzipped
+            JSON instead of raw JSON to reduce the amount of data sent
+            over the wire in HTTP requests. Defaults to 'True'.
+        """
+
+        return cls._import_model(
+            es_client=es_client,
+            model_id=model_id,
+            model=model,
+            feature_names=ltr_model_config.feature_names,
+            inference_config=ltr_model_config.to_dict(),
+            es_if_exists=es_if_exists,
+            es_compress_model_definition=es_compress_model_definition,
+        )
+
+    @classmethod
+    def _import_model(
+        cls,
+        es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
+        model_id: str,
+        model: Union[
+            "DecisionTreeClassifier",
+            "DecisionTreeRegressor",
+            "RandomForestRegressor",
+            "RandomForestClassifier",
+            "XGBClassifier",
+            "XGBRanker",
+            "XGBRegressor",
+            "LGBMRegressor",
+            "LGBMClassifier",
+        ],
+        feature_names: List[str],
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
+        es_if_exists: Optional[str] = None,
+        es_compress_model_definition: bool = True,
+        inference_config: Optional[Mapping[str, Mapping[str, Any]]] = None,
+    ) -> "MLModel":
+        """
+        Actual implementation of model import used by public API methods.
+        """
+
+        es_client = ensure_es_client(es_client)
         transformer = get_model_transformer(
             model,
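To close the loop on the "learning_to_rank rescorer" mentioned in the docstring above, here is a hedged sketch of querying with the imported model. The index, model id, and template parameter name are placeholders, and the rescore block follows the Elasticsearch learning-to-rank rescorer syntax (8.12+) rather than anything defined in this diff, so verify the field names against your Elasticsearch version:

from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")  # assumed local cluster

response = es.search(
    index="national_parks",               # hypothetical index
    query={"match": {"title": "tower"}},  # first-pass retrieval query
    rescore={
        "window_size": 50,
        "learning_to_rank": {
            "model_id": "my-ltr-model",   # id used when calling import_ltr_model
            "params": {"query_string": "tower"},
        },
    },
    size=10,
)

for hit in response["hits"]["hits"]:
    print(hit["_id"], hit["_score"])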
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import math
+
 from eland.ml.ltr import FeatureLogger, LTRModelConfig, QueryFeatureExtractor
 from tests import ES_TEST_CLIENT, NATIONAL_PARKS_INDEX_NAME
 
@@ -53,12 +55,12 @@ class TestFeatureLogger:
 
         # "park_hawaii-volcanoes" document does not matches for title but is a world heritage site
         assert (
-            doc_features["park_hawaii-volcanoes"][0] == 0
+            math.isnan(doc_features["park_hawaii-volcanoes"][0])
            and doc_features["park_hawaii-volcanoes"][1] > 1
        )
 
         # "park_hawaii-volcanoes" document does not matches for title and is not a world heritage site
-        assert doc_features["park_death-valley"] == [0, 0]
+        assert all(math.isnan(feature) for feature in doc_features["park_death-valley"])
 
    def _ltr_model_config(self):
        # Returns an LTR config with 2 query feature extractors:
@@ -23,6 +23,7 @@ import pytest
 
 import eland as ed
 from eland.ml import MLModel
+from eland.ml.ltr import LTRModelConfig, QueryFeatureExtractor
 from tests import (
     ES_TEST_CLIENT,
     ES_VERSION,
@@ -330,41 +331,35 @@ class TestMLModel:
 
         # Serialise the models to Elasticsearch
         model_id = "test_learning_to_rank"
-        feature_extractors = [
-            {
-                "query_extractor": {
-                    "feature_name": "title_bm25",
-                    "query": {"match": {"title": "{{query_string}}"}},
-                }
-            },
-            {
-                "query_extractor": {
-                    "feature_name": "visitors",
-                    "query": {
+        ltr_model_config = LTRModelConfig(
+            feature_extractors=[
+                QueryFeatureExtractor(
+                    feature_name="title_bm25",
+                    query={"match": {"title": "{{query_string}}"}},
+                ),
+                QueryFeatureExtractor(
+                    feature_name="description_bm25",
+                    query={"match": {"description_bm25": "{{query_string}}"}},
+                ),
+                QueryFeatureExtractor(
+                    feature_name="visitors",
+                    query={
                         "script_score": {
                             "query": {"exists": {"field": "visitors"}},
                             "script": {"source": 'return doc["visitors"].value;'},
                         }
                     },
-                }
-            },
-        ]
-        feature_names = [
-            extractor["query_extractor"]["feature_name"]
-            for extractor in feature_extractors
-        ]
-        inference_config = {
-            "learning_to_rank": {"feature_extractors": feature_extractors}
-        }
+                ),
+            ]
+        )
 
-        es_model = MLModel.import_model(
+        es_model = MLModel.import_ltr_model(
             ES_TEST_CLIENT,
             model_id,
             regressor,
-            feature_names,
+            ltr_model_config,
             es_if_exists="replace",
             es_compress_model_definition=compress_model_definition,
-            inference_config=inference_config,
         )
 
         # Verify the saved inference config contains the passed LTR config
@@ -375,10 +370,14 @@ class TestMLModel:
             "inference_config"
         ]
         assert "learning_to_rank" in saved_inference_config
-        saved_ltr_config = saved_inference_config["learning_to_rank"]
+        assert "feature_extractors" in saved_inference_config["learning_to_rank"]
+        saved_feature_extractors = saved_inference_config["learning_to_rank"][
+            "feature_extractors"
+        ]
+
         assert all(
-            item in saved_ltr_config.items()
-            for item in inference_config["learning_to_rank"].items()
+            feature_extractor.to_dict() in saved_feature_extractors
+            for feature_extractor in ltr_model_config.feature_extractors
        )
 
         # Execute search with rescoring
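The assertion above compares each extractor's `to_dict()` output against the stored trained-model configuration. Based on the test code removed in this commit and on `import_ltr_model` passing `ltr_model_config.to_dict()` as the inference config, the saved `learning_to_rank` section has roughly this shape (a sketch for orientation, not an authoritative schema):

# Approximate structure of the persisted inference config for an LTR model.
saved_inference_config = {
    "learning_to_rank": {
        "feature_extractors": [
            {
                "query_extractor": {
                    "feature_name": "title_bm25",
                    "query": {"match": {"title": "{{query_string}}"}},
                }
            },
            # ... one "query_extractor" entry per QueryFeatureExtractor in the config
        ]
    }
}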