LTR feature logger (#648)

This commit is contained in:
Aurélien FOUCRET 2024-01-12 13:52:04 +01:00 committed by GitHub
parent 926f0b9b5c
commit d3ed669a5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 466 additions and 52 deletions

View File

@ -50,3 +50,6 @@ Permission is hereby granted, free of charge, to any person obtaining a copy of
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--
This product contains a adapted version of the "us-national-parks" dataset, https://data.world/kevinnayar/us-national-parks, by Kevin Nayar, https://data.world/kevinnayar, is licensed under CC BY, https://creativecommons.org/licenses/by/4.0/legalcode

25
eland/ml/ltr/__init__.py Normal file
View File

@ -0,0 +1,25 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from eland.ml.ltr.feature_logger import FeatureLogger
from eland.ml.ltr.ltr_model_config import (
FeatureExtractor,
LTRModelConfig,
QueryFeatureExtractor,
)
__all__ = [LTRModelConfig, QueryFeatureExtractor, FeatureExtractor, FeatureLogger]

View File

@ -0,0 +1,181 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from functools import cached_property
from typing import TYPE_CHECKING, Any, List, Mapping, Tuple, Union
from eland.common import ensure_es_client
from eland.ml.ltr.ltr_model_config import LTRModelConfig
if TYPE_CHECKING:
from elasticsearch import Elasticsearch
class FeatureLogger:
"""
A class that is used during model training to extract features from the judgment list.
"""
def __init__(
self,
es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
es_index: str,
ltr_model_config: LTRModelConfig,
):
"""
Parameters
----------
es_client: Elasticsearch client argument(s)
- elasticsearch-py parameters or
- elasticsearch-py instance
es_index: str
Name of the Elastcsearch index used for features extractions.
ltr_model_config: LTRModelConfig
LTR model config used to extract feature.
"""
self._model_config = ltr_model_config
self._client: Elasticsearch = ensure_es_client(es_client)
self._index_name = es_index
def extract_features(
self, query_params: Mapping[str, Any], doc_ids: List[str]
) -> Mapping[str, List[float]]:
"""
Extract document features.
Parameters
----------
query_params: Mapping[str, Any]
List of templates params used during features extraction.
doc_ids: List[str]
List of doc ids.
Example
-------
>>> from eland.ml.ltr import FeatureLogger, LTRModelConfig, QueryFeatureExtractor
>>> ltr_model_config=LTRModelConfig(
... feature_extractors=[
... QueryFeatureExtractor(
... feature_name='title_bm25',
... query={ "match": { "title": "{{query}}" } }
... ),
... QueryFeatureExtractor(
... feature_name='descritption_bm25',
... query={ "match": { "description": "{{query}}" } }
... )
... ]
... )
>>> feature_logger = FeatureLogger(
... es_client='http://localhost:9200',
... es_index='national_parks',
... ltr_model_config=ltr_model_config
... )
>>> doc_features = feature_logger.extract_features(query_params={"query": "yosemite"}, doc_ids=["park-yosemite", "park-everglade"])
"""
doc_features = {
doc_id: [float(0)] * len(self._model_config.feature_extractors)
for doc_id in doc_ids
}
for doc_id, query_features in self._extract_query_features(
query_params, doc_ids
).items():
for feature_name, feature_value in query_features.items():
doc_features[doc_id][
self._model_config.feature_index(feature_name)
] = feature_value
return doc_features
def _to_named_query(
self, query: Mapping[str, Mapping[str, any]], query_name: str
) -> Mapping[str, Mapping[str, any]]:
return {"bool": {"must": query, "_name": query_name}}
@cached_property
def _script_source(self) -> str:
query_extractors = self._model_config.query_feature_extractors
queries = [
self._to_named_query(extractor.query, extractor.feature_name)
for extractor in query_extractors
]
return (
json.dumps(
{
"query": {
"bool": {
"should": queries,
"filter": {"ids": {"values": "##DOC_IDS_JSON##"}},
}
},
"size": "##DOC_IDS_SIZE##",
"_source": False,
}
)
.replace('"##DOC_IDS_JSON##"', "{{#toJson}}__doc_ids{{/toJson}}")
.replace('"##DOC_IDS_SIZE##"', "{{__size}}")
)
def _extract_query_features(
self, query_params: Mapping[str, Any], doc_ids: List[str]
):
default_query_scores = dict(
(extractor.feature_name, extractor.default_score)
for extractor in self._model_config.query_feature_extractors
)
matched_queries = self._execute_search_template_request(
script_source=self._script_source,
template_params={
**query_params,
"__doc_ids": doc_ids,
"__size": len(doc_ids),
},
)
return {
hit_id: {**default_query_scores, **matched_queries_scores}
for hit_id, matched_queries_scores in matched_queries.items()
}
def _execute_search_template_request(
self, script_source: str, template_params: Mapping[str, any]
):
# When support for include_named_queries_score will be added,
# this will be replaced by the call to the client search_template method.
from elasticsearch._sync.client import _quote
__path = f"/{_quote(self._index_name)}/_search/template"
__query = {"include_named_queries_score": True}
__headers = {"accept": "application/json", "content-type": "application/json"}
__body = {"source": script_source, "params": template_params}
return {
hit["_id"]: hit["matched_queries"] if "matched_queries" in hit else {}
for hit in self._client.perform_request(
"GET", __path, params=__query, headers=__headers, body=__body
)["hits"]["hits"]
}

View File

@ -0,0 +1,156 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from functools import cached_property
from typing import Any, Dict, List, Mapping, Optional
from eland.ml.common import TYPE_LEARNING_TO_RANK
class FeatureExtractor:
"""
A base class representing a generic feature extractor.
"""
def __init__(self, type: str, feature_name: str):
"""
Parameters
----------
type: str
Type of the feature extractor.
feature_name: str
Name of the extracted features.
"""
self.feature_name = feature_name
self.type = type
def to_dict(self) -> Dict[str, Any]:
"""Convert the feature extractor into a dict that can be send to ES as part of the inference config."""
return {
self.type: {
k: v.to_dict() if hasattr(v, "to_dict") else v
for k, v in self.__dict__.items()
if v is not None and k != "type"
}
}
class QueryFeatureExtractor(FeatureExtractor):
"""
A class that allows to define a query feature extractor.
"""
def __init__(
self,
feature_name: str,
query: Mapping[str, Any],
default_score: Optional[float] = float(0),
):
"""
Parameters
----------
feature_name: str
Name of the extracted features.
query: Mapping[str, Any]
Templated query used to extract the feature.
default_score: str
Scored used by default when the doc is not matching the query.
Examples
--------
>>> from eland.ml.ltr import QueryFeatureExtractor
>>> query_feature_extractor = QueryFeatureExtractor(
... feature_name='title_bm25',
... query={ "match": { "title": "{{query}}" } }
... )
"""
super().__init__(feature_name=feature_name, type="query_extractor")
self.query = query
self.default_score = default_score
class LTRModelConfig:
"""
A class representing LTR model configuration.
"""
def __init__(self, feature_extractors: List[FeatureExtractor]):
"""
Parameters
----------
feature_extractors: List[FeatureExtractor]
List of the feature extractors for the LTR model.
Examples
--------
>>> from eland.ml.ltr import LTRModelConfig, QueryFeatureExtractor
>>> ltr_model_config = LTRModelConfig(
... feature_extractors=[
... QueryFeatureExtractor(
... feature_name='title_bm25',
... query={ "match": { "title": "{{query}}" } }
... ),
... QueryFeatureExtractor(
... feature_name='descritption_bm25',
... query={ "match": { "description": "{{query}}" } }
... )
... ]
... )
"""
self.feature_extractors = feature_extractors
def to_dict(self) -> Mapping[str, Any]:
"""
Convert the into a dict that can be send to ES as an inference config.
"""
return {
TYPE_LEARNING_TO_RANK: {
"feature_extractors": [
feature_extractor.to_dict()
for feature_extractor in self.feature_extractors
]
}
}
@cached_property
def feature_names(self) -> List[str]:
"""
List of the feature names for the model.
"""
return [extractor.feature_name for extractor in self.feature_extractors]
@cached_property
def query_feature_extractors(self) -> List[QueryFeatureExtractor]:
"""
List of query feature extractors for the model.
"""
return [
extractor
for extractor in self.feature_extractors
if isinstance(extractor, QueryFeatureExtractor)
]
def feature_index(self, feature_name: str) -> int:
"Returns the index of the feature in the feature lists."
return self.feature_names.index(feature_name)

View File

@ -163,55 +163,27 @@ ECOMMERCE_MAPPING = {
ECOMMERCE_FILE_NAME = ROOT_DIR + "/ecommerce.json.gz" ECOMMERCE_FILE_NAME = ROOT_DIR + "/ecommerce.json.gz"
ECOMMERCE_DF_FILE_NAME = ROOT_DIR + "/ecommerce_df.json.gz" ECOMMERCE_DF_FILE_NAME = ROOT_DIR + "/ecommerce_df.json.gz"
MOVIES_INDEX_NAME = "movies" NATIONAL_PARKS_INDEX_NAME = "national_parks"
MOVIES_FILE_NAME = ROOT_DIR + "/movies.json.gz" NATIONAL_PARKS_FILE_NAME = ROOT_DIR + "/national-parks.json.gz"
MOVIES_MAPPING = { NATIONAL_PARKS_MAPPING = {
"mappings": { "mappings": {
"properties": { "properties": {
"type": {"type": "keyword"},
"title": {"type": "text"},
"year": {"type": "integer"},
"rated": {"type": "keyword"},
"released": {"type": "date"},
"plot": {"type": "text"},
"awards": {"type": "text"},
"poster": {"type": "keyword"},
"id": {"type": "keyword"}, "id": {"type": "keyword"},
"metascore": {"type": "float"}, "title": {"type": "text"},
"imdbRating": {"type": "float"}, "description": {"type": "text"},
"imdbVotes": {"type": "integer"}, "nps_link": {"type": "text", "index": False},
"language": {"type": "keyword"}, "date_established": {"type": "date"},
"runtime": {"type": "integer"}, "location": {"type": "geo_point"},
"genres": { "states": {
"type": "text",
"fields": {
"keyword": {"type": "keyword"},
},
},
"directors": {
"type": "text",
"fields": {
"keyword": {"type": "keyword"},
},
},
"writers": {
"type": "text",
"fields": {
"keyword": {"type": "keyword"},
},
},
"actors": {
"type": "text",
"fields": {
"keyword": {"type": "keyword"},
},
},
"country": {
"type": "text", "type": "text",
"fields": { "fields": {
"keyword": {"type": "keyword"}, "keyword": {"type": "keyword"},
}, },
}, },
"visitors": {"type": "integer"},
"world_heritage_site": {"type": "boolean"},
"acres": {"type": "float"},
"square_km": {"type": "float"},
} }
} }
} }

View File

@ -0,0 +1,77 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from eland.ml.ltr import FeatureLogger, LTRModelConfig, QueryFeatureExtractor
from tests import ES_TEST_CLIENT, NATIONAL_PARKS_INDEX_NAME
class TestFeatureLogger:
def test_extract_feature(self):
# Create the feature logger and some document extract features for a query.
ltr_model_config = self._ltr_model_config()
feature_logger = FeatureLogger(
ES_TEST_CLIENT, NATIONAL_PARKS_INDEX_NAME, ltr_model_config
)
doc_ids = ["park_yosemite", "park_hawaii-volcanoes", "park_death-valley"]
doc_features = feature_logger.extract_features(
query_params={"query": "yosemite"}, doc_ids=doc_ids
)
# Assert all docs are presents.
assert len(doc_features) == len(doc_ids) and all(
doc_id in doc_ids for doc_id in doc_features.keys()
)
# Check all features are extracted for all docs
assert all(
len(features) == len(ltr_model_config.feature_extractors)
for features in doc_features.values()
)
print(doc_features)
# "park_yosemite" document matches for title and is a world heritage site
assert (
doc_features["park_yosemite"][0] > 0
and doc_features["park_yosemite"][1] > 1
)
# "park_hawaii-volcanoes" document does not matches for title but is a world heritage site
assert (
doc_features["park_hawaii-volcanoes"][0] == 0
and doc_features["park_hawaii-volcanoes"][1] > 1
)
# "park_hawaii-volcanoes" document does not matches for title and is not a world heritage site
assert doc_features["park_death-valley"] == [0, 0]
def _ltr_model_config(self):
# Returns an LTR config with 2 query feature extractors:
# - title_bm25: BM25 score of the match query on the title field
# - popularity: Value of the popularity field
return LTRModelConfig(
[
QueryFeatureExtractor(
feature_name="title_bm25", query={"match": {"title": "{{query}}"}}
),
QueryFeatureExtractor(
feature_name="world_heritage_site",
query={"term": {"world_heritage_site": True}},
),
]
)

View File

@ -27,7 +27,7 @@ from tests import (
ES_TEST_CLIENT, ES_TEST_CLIENT,
ES_VERSION, ES_VERSION,
FLIGHTS_SMALL_INDEX_NAME, FLIGHTS_SMALL_INDEX_NAME,
MOVIES_INDEX_NAME, NATIONAL_PARKS_INDEX_NAME,
) )
try: try:
@ -339,11 +339,11 @@ class TestMLModel:
}, },
{ {
"query_extractor": { "query_extractor": {
"feature_name": "imdb_rating", "feature_name": "visitors",
"query": { "query": {
"script_score": { "script_score": {
"query": {"exists": {"field": "imdbRating"}}, "query": {"exists": {"field": "visitors"}},
"script": {"source": 'return doc["imdbRating"].value;'}, "script": {"source": 'return doc["visitors"].value;'},
} }
}, },
} }
@ -383,12 +383,12 @@ class TestMLModel:
# Execute search with rescoring # Execute search with rescoring
search_result = ES_TEST_CLIENT.search( search_result = ES_TEST_CLIENT.search(
index=MOVIES_INDEX_NAME, index=NATIONAL_PARKS_INDEX_NAME,
query={"terms": {"_id": ["tt1318514", "tt0071562"]}}, query={"terms": {"_id": ["park_yosemite", "park_everglades"]}},
rescore={ rescore={
"learning_to_rank": { "learning_to_rank": {
"model_id": model_id, "model_id": model_id,
"params": {"query_string": "planet of the apes"}, "params": {"query_string": "yosemite"},
} }
}, },
) )

Binary file not shown.

Binary file not shown.

View File

@ -30,9 +30,9 @@ from tests import (
FLIGHTS_MAPPING, FLIGHTS_MAPPING,
FLIGHTS_SMALL_FILE_NAME, FLIGHTS_SMALL_FILE_NAME,
FLIGHTS_SMALL_INDEX_NAME, FLIGHTS_SMALL_INDEX_NAME,
MOVIES_FILE_NAME, NATIONAL_PARKS_FILE_NAME,
MOVIES_INDEX_NAME, NATIONAL_PARKS_INDEX_NAME,
MOVIES_MAPPING, NATIONAL_PARKS_MAPPING,
TEST_MAPPING1, TEST_MAPPING1,
TEST_MAPPING1_INDEX_NAME, TEST_MAPPING1_INDEX_NAME,
TEST_NESTED_USER_GROUP_DOCS, TEST_NESTED_USER_GROUP_DOCS,
@ -44,7 +44,7 @@ DATA_LIST = [
(FLIGHTS_FILE_NAME, FLIGHTS_INDEX_NAME, FLIGHTS_MAPPING), (FLIGHTS_FILE_NAME, FLIGHTS_INDEX_NAME, FLIGHTS_MAPPING),
(FLIGHTS_SMALL_FILE_NAME, FLIGHTS_SMALL_INDEX_NAME, FLIGHTS_MAPPING), (FLIGHTS_SMALL_FILE_NAME, FLIGHTS_SMALL_INDEX_NAME, FLIGHTS_MAPPING),
(ECOMMERCE_FILE_NAME, ECOMMERCE_INDEX_NAME, ECOMMERCE_MAPPING), (ECOMMERCE_FILE_NAME, ECOMMERCE_INDEX_NAME, ECOMMERCE_MAPPING),
(MOVIES_FILE_NAME, MOVIES_INDEX_NAME, MOVIES_MAPPING), (NATIONAL_PARKS_FILE_NAME, NATIONAL_PARKS_INDEX_NAME, NATIONAL_PARKS_MAPPING),
] ]