Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)
Improve LTR (#651)
* Ensure the feature logger is using NaN for non matching query feature extractors (consistent with ES).
* Default score is None instead of 0.
* LTR model import API improvements.
* Fix feature logger tests.
* Fix export in eland.ml.ltr
* Apply suggestions from code review
  Co-authored-by: Adam Demjen <demjened@gmail.com>
* Fix supported models for LTR

Co-authored-by: Adam Demjen <demjened@gmail.com>
commit 5169cc926a, parent d2291889f8
@@ -15,11 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from eland.ml.ltr.feature_logger import FeatureLogger
-from eland.ml.ltr.ltr_model_config import (
-    FeatureExtractor,
-    LTRModelConfig,
-    QueryFeatureExtractor,
-)
+from .feature_logger import FeatureLogger
+from .ltr_model_config import FeatureExtractor, LTRModelConfig, QueryFeatureExtractor
 
-__all__ = [LTRModelConfig, QueryFeatureExtractor, FeatureExtractor, FeatureLogger]
+__all__ = [
+    "LTRModelConfig",
+    "QueryFeatureExtractor",
+    "FeatureExtractor",
+    "FeatureLogger",
+]
@@ -95,7 +95,7 @@ class FeatureLogger:
         """
 
         doc_features = {
-            doc_id: [float(0)] * len(self._model_config.feature_extractors)
+            doc_id: [float("nan")] * len(self._model_config.feature_extractors)
            for doc_id in doc_ids
         }
 
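A quick illustration of why this default matters (not part of the diff): 0.0 can be a legitimate feature value, whereas NaN unambiguously marks "this extractor's query did not match the document", which per the commit message is how Elasticsearch itself reports it.

import math

features = [float("nan"), 1.7]  # extractor 1 did not match; extractor 2 scored 1.7
assert math.isnan(features[0])  # explicit "no match", distinguishable from a real 0.0 score
assert features[1] > 0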
@@ -141,28 +141,6 @@ class FeatureLogger:
 
     def _extract_query_features(
         self, query_params: Mapping[str, Any], doc_ids: List[str]
-    ):
-        default_query_scores = dict(
-            (extractor.feature_name, extractor.default_score)
-            for extractor in self._model_config.query_feature_extractors
-        )
-
-        matched_queries = self._execute_search_template_request(
-            script_source=self._script_source,
-            template_params={
-                **query_params,
-                "__doc_ids": doc_ids,
-                "__size": len(doc_ids),
-            },
-        )
-
-        return {
-            hit_id: {**default_query_scores, **matched_queries_scores}
-            for hit_id, matched_queries_scores in matched_queries.items()
-        }
-
-    def _execute_search_template_request(
-        self, script_source: str, template_params: Mapping[str, any]
     ):
         # When support for include_named_queries_score will be added,
         # this will be replaced by the call to the client search_template method.
@@ -171,7 +149,10 @@ class FeatureLogger:
         __path = f"/{_quote(self._index_name)}/_search/template"
         __query = {"include_named_queries_score": True}
         __headers = {"accept": "application/json", "content-type": "application/json"}
-        __body = {"source": script_source, "params": template_params}
+        __body = {
+            "source": self._script_source,
+            "params": {**query_params, "__doc_ids": doc_ids, "__size": len(doc_ids)},
+        }
 
         return {
             hit["_id"]: hit["matched_queries"] if "matched_queries" in hit else {}
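For reference, _extract_query_features now issues a single search template request with include_named_queries_score enabled and builds the body inline. A rough sketch of that request (not part of the diff; the index name and parameter values are illustrative):

path = "/national_parks/_search/template"     # f"/{index_name}/_search/template"
query_string_args = {"include_named_queries_score": True}
body = {
    "source": "<the rendered search template>",  # self._script_source
    "params": {
        "query_string": "tourist attraction",    # caller-supplied query params
        "__doc_ids": ["park_yosemite", "park_everglades"],
        "__size": 2,
    },
}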
@@ -59,7 +59,7 @@ class QueryFeatureExtractor(FeatureExtractor):
         self,
         feature_name: str,
         query: Mapping[str, Any],
-        default_score: Optional[float] = float(0),
+        default_score: Optional[float] = None,
     ):
         """
         Parameters
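Because default_score now defaults to None, a document that does not match an extractor's query surfaces as NaN in the logged features rather than silently becoming 0.0; an explicit numeric fallback can presumably still be supplied per extractor. A short sketch (not part of the diff):

from eland.ml.ltr import QueryFeatureExtractor

# New default: no default_score, so non-matching documents surface as NaN.
title_feature = QueryFeatureExtractor(
    feature_name="title_bm25",
    query={"match": {"title": "{{query_string}}"}},
)

# An explicit fallback value can still be set per extractor if one is wanted.
title_feature_with_fallback = QueryFeatureExtractor(
    feature_name="title_bm25",
    query={"match": {"title": "{{query_string}}"}},
    default_score=0.0,
)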
@@ -24,6 +24,7 @@ from eland.common import ensure_es_client, es_version
 from eland.utils import deprecated_api
 
 from .common import TYPE_CLASSIFICATION, TYPE_LEARNING_TO_RANK, TYPE_REGRESSION
+from .ltr import LTRModelConfig
 from .transformers import get_model_transformer
 
 if TYPE_CHECKING:
@@ -266,7 +267,6 @@ class MLModel:
         classification_weights: Optional[List[float]] = None,
         es_if_exists: Optional[str] = None,
         es_compress_model_definition: bool = True,
-        inference_config: Optional[Mapping[str, Mapping[str, Any]]] = None,
     ) -> "MLModel":
         """
         Transform and serialize a trained 3rd party model into Elasticsearch.
@@ -342,10 +342,6 @@ class MLModel:
             JSON instead of raw JSON to reduce the amount of data sent
             over the wire in HTTP requests. Defaults to 'True'.
 
-        inference_config: Mapping[str, Mapping[str, Any]]
-            Model inference configuration. Must contain a top-level property whose name is the same as the inference
-            task type.
-
         Examples
         --------
         >>> from sklearn import datasets
@@ -380,6 +376,125 @@ class MLModel:
         >>> # Delete model from Elasticsearch
         >>> es_model.delete_model()
         """
+
+        return cls._import_model(
+            es_client=es_client,
+            model_id=model_id,
+            model=model,
+            feature_names=feature_names,
+            classification_labels=classification_labels,
+            classification_weights=classification_weights,
+            es_if_exists=es_if_exists,
+            es_compress_model_definition=es_compress_model_definition,
+        )
+
+    @classmethod
+    def import_ltr_model(
+        cls,
+        es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
+        model_id: str,
+        model: Union[
+            "DecisionTreeRegressor",
+            "RandomForestRegressor",
+            "XGBRanker",
+            "XGBRegressor",
+            "LGBMRegressor",
+        ],
+        ltr_model_config: LTRModelConfig,
+        es_if_exists: Optional[str] = None,
+        es_compress_model_definition: bool = True,
+    ) -> "MLModel":
+        """
+        Transform and serialize a trained 3rd party model into Elasticsearch.
+        This model can then be used as a learning_to_rank rescorer in the Elastic Stack.
+
+        Parameters
+        ----------
+        es_client: Elasticsearch client argument(s)
+            - elasticsearch-py parameters or
+            - elasticsearch-py instance
+
+        model_id: str
+            The unique identifier of the trained inference model in Elasticsearch.
+
+        model: An instance of a supported python model. We support the following model types for LTR prediction:
+            - sklearn.tree.DecisionTreeRegressor
+            - sklearn.ensemble.RandomForestRegressor
+            - xgboost.XGBRanker
+                - only the following objectives are supported:
+                    - "rank:map"
+                    - "rank:ndcg"
+                    - "rank:pairwise"
+            - xgboost.XGBRegressor
+                - only the following objectives are supported:
+                    - "reg:squarederror"
+                    - "reg:linear"
+                    - "reg:squaredlogerror"
+                    - "reg:logistic"
+                    - "reg:pseudohubererror"
+            - lightgbm.LGBMRegressor
+                - Categorical fields are expected to already be processed
+                - Only the following objectives are supported
+                    - "regression"
+                    - "regression_l1"
+                    - "huber"
+                    - "fair"
+                    - "quantile"
+                    - "mape"
+
+        ltr_model_config: LTRModelConfig
+            The LTR model configuration is used to configure feature extractors for the LTR model.
+            Feature names are automatically inferred from the feature extractors.
+
+        es_if_exists: {'fail', 'replace'} default 'fail'
+            How to behave if model already exists
+
+            - fail: Raise a Value Error
+            - replace: Overwrite existing model
+
+        es_compress_model_definition: bool
+            If True will use 'compressed_definition' which uses gzipped
+            JSON instead of raw JSON to reduce the amount of data sent
+            over the wire in HTTP requests. Defaults to 'True'.
+        """
+
+        return cls._import_model(
+            es_client=es_client,
+            model_id=model_id,
+            model=model,
+            feature_names=ltr_model_config.feature_names,
+            inference_config=ltr_model_config.to_dict(),
+            es_if_exists=es_if_exists,
+            es_compress_model_definition=es_compress_model_definition,
+        )
+
+    @classmethod
+    def _import_model(
+        cls,
+        es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
+        model_id: str,
+        model: Union[
+            "DecisionTreeClassifier",
+            "DecisionTreeRegressor",
+            "RandomForestRegressor",
+            "RandomForestClassifier",
+            "XGBClassifier",
+            "XGBRanker",
+            "XGBRegressor",
+            "LGBMRegressor",
+            "LGBMClassifier",
+        ],
+        feature_names: List[str],
+        classification_labels: Optional[List[str]] = None,
+        classification_weights: Optional[List[float]] = None,
+        es_if_exists: Optional[str] = None,
+        es_compress_model_definition: bool = True,
+        inference_config: Optional[Mapping[str, Mapping[str, Any]]] = None,
+    ) -> "MLModel":
+        """
+        Actual implementation of model import used by public API methods.
+        """
+
         es_client = ensure_es_client(es_client)
         transformer = get_model_transformer(
             model,
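Putting the new pieces together, importing a learning-to-rank model now takes an LTRModelConfig instead of separate feature names and a hand-built inference config. A minimal end-to-end sketch (not part of the diff); the Elasticsearch endpoint, model id, and training data are placeholders, and the extractors mirror the ones used in the tests:

import numpy as np
from sklearn.tree import DecisionTreeRegressor

from eland.ml import MLModel
from eland.ml.ltr import LTRModelConfig, QueryFeatureExtractor

# Toy training data: one row per (query, document) pair, one column per feature
# (title_bm25 score, visitors count); the target is a relevance grade.
X = np.array([[2.3, 1_000.0], [0.0, 250_000.0], [1.1, 40_000.0]])
y = np.array([0.0, 2.0, 1.0])
regressor = DecisionTreeRegressor().fit(X, y)

ltr_model_config = LTRModelConfig(
    feature_extractors=[
        QueryFeatureExtractor(
            feature_name="title_bm25",
            query={"match": {"title": "{{query_string}}"}},
        ),
        QueryFeatureExtractor(
            feature_name="visitors",
            query={
                "script_score": {
                    "query": {"exists": {"field": "visitors"}},
                    "script": {"source": 'return doc["visitors"].value;'},
                }
            },
        ),
    ]
)

# Feature names and the learning_to_rank inference config are derived from
# ltr_model_config, so they are no longer passed separately.
MLModel.import_ltr_model(
    es_client="http://localhost:9200",  # placeholder endpoint
    model_id="my-ltr-model",            # placeholder model id
    model=regressor,
    ltr_model_config=ltr_model_config,
    es_if_exists="replace",
)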
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import math
+
 from eland.ml.ltr import FeatureLogger, LTRModelConfig, QueryFeatureExtractor
 from tests import ES_TEST_CLIENT, NATIONAL_PARKS_INDEX_NAME
 
@@ -53,12 +55,12 @@ class TestFeatureLogger:
 
         # "park_hawaii-volcanoes" document does not matches for title but is a world heritage site
         assert (
-            doc_features["park_hawaii-volcanoes"][0] == 0
+            math.isnan(doc_features["park_hawaii-volcanoes"][0])
             and doc_features["park_hawaii-volcanoes"][1] > 1
         )
 
         # "park_hawaii-volcanoes" document does not matches for title and is not a world heritage site
-        assert doc_features["park_death-valley"] == [0, 0]
+        assert all(math.isnan(feature) for feature in doc_features["park_death-valley"])
 
     def _ltr_model_config(self):
         # Returns an LTR config with 2 query feature extractors:
@@ -23,6 +23,7 @@ import pytest
 
 import eland as ed
 from eland.ml import MLModel
+from eland.ml.ltr import LTRModelConfig, QueryFeatureExtractor
 from tests import (
     ES_TEST_CLIENT,
     ES_VERSION,
@@ -330,41 +331,35 @@ class TestMLModel:
 
         # Serialise the models to Elasticsearch
         model_id = "test_learning_to_rank"
-        feature_extractors = [
-            {
-                "query_extractor": {
-                    "feature_name": "title_bm25",
-                    "query": {"match": {"title": "{{query_string}}"}},
-                }
-            },
-            {
-                "query_extractor": {
-                    "feature_name": "visitors",
-                    "query": {
+        ltr_model_config = LTRModelConfig(
+            feature_extractors=[
+                QueryFeatureExtractor(
+                    feature_name="title_bm25",
+                    query={"match": {"title": "{{query_string}}"}},
+                ),
+                QueryFeatureExtractor(
+                    feature_name="description_bm25",
+                    query={"match": {"description_bm25": "{{query_string}}"}},
+                ),
+                QueryFeatureExtractor(
+                    feature_name="visitors",
+                    query={
                         "script_score": {
                             "query": {"exists": {"field": "visitors"}},
                             "script": {"source": 'return doc["visitors"].value;'},
                         }
                     },
-                }
-            },
-        ]
-        feature_names = [
-            extractor["query_extractor"]["feature_name"]
-            for extractor in feature_extractors
-        ]
-        inference_config = {
-            "learning_to_rank": {"feature_extractors": feature_extractors}
-        }
+                ),
+            ]
+        )
 
-        es_model = MLModel.import_model(
+        es_model = MLModel.import_ltr_model(
             ES_TEST_CLIENT,
             model_id,
             regressor,
-            feature_names,
+            ltr_model_config,
             es_if_exists="replace",
             es_compress_model_definition=compress_model_definition,
-            inference_config=inference_config,
         )
 
         # Verify the saved inference config contains the passed LTR config
@@ -375,10 +370,14 @@ class TestMLModel:
             "inference_config"
         ]
         assert "learning_to_rank" in saved_inference_config
-        saved_ltr_config = saved_inference_config["learning_to_rank"]
+        assert "feature_extractors" in saved_inference_config["learning_to_rank"]
+        saved_feature_extractors = saved_inference_config["learning_to_rank"][
+            "feature_extractors"
+        ]
+
         assert all(
-            item in saved_ltr_config.items()
-            for item in inference_config["learning_to_rank"].items()
+            feature_extractor.to_dict() in saved_feature_extractors
+            for feature_extractor in ltr_model_config.feature_extractors
         )
 
         # Execute search with rescoring
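The verification above reads the stored model configuration back from Elasticsearch. Roughly sketched (not part of the diff; assumes the elasticsearch-py ML client used elsewhere in the tests):

# model_id and ES_TEST_CLIENT are the ones used for the import above.
response = ES_TEST_CLIENT.ml.get_trained_models(model_id=model_id)
saved_inference_config = response["trained_model_configs"][0]["inference_config"]

assert "learning_to_rank" in saved_inference_config
assert "feature_extractors" in saved_inference_config["learning_to_rank"]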