Accept LTR inference config when creating model (#645)

* Support for supplying inference_config

* Fix linting errors

* Add unit test

* Add LTR type, throw exception on predict, refine test

* Add search step to LTR test

* Fix linter errors

* Update rescoring assertion in test + type defs

* Fix linting error

* Remove failing assertion
This commit is contained in:
Adam Demjen 2024-01-08 09:19:03 -05:00 committed by GitHub
parent 05c5859b8a
commit 840871f9d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 122 additions and 5 deletions

View File

@ -16,4 +16,5 @@
# under the License.
# Identifiers for the supported model task types. Each value is the string
# used as the top-level task key of a trained model's ``inference_config``.
TYPE_CLASSIFICATION = "classification"
TYPE_LEARNING_TO_RANK = "learning_to_rank"
TYPE_REGRESSION = "regression"

View File

@ -23,7 +23,7 @@ import numpy as np
from eland.common import ensure_es_client, es_version
from eland.utils import deprecated_api
from .common import TYPE_CLASSIFICATION, TYPE_REGRESSION
from .common import TYPE_CLASSIFICATION, TYPE_LEARNING_TO_RANK, TYPE_REGRESSION
from .transformers import get_model_transformer
if TYPE_CHECKING:
@ -130,6 +130,11 @@ class MLModel:
>>> # Delete model from Elasticsearch
>>> es_model.delete_model()
"""
if self.model_type not in (TYPE_CLASSIFICATION, TYPE_REGRESSION):
raise NotImplementedError(
f"Prediction for type {self.model_type} is not supported."
)
docs: List[Mapping[str, Any]] = []
if isinstance(X, np.ndarray):
@ -215,6 +220,8 @@ class MLModel:
inference_config = self._trained_model_config["inference_config"]
if "classification" in inference_config:
return TYPE_CLASSIFICATION
elif "learning_to_rank" in inference_config:
return TYPE_LEARNING_TO_RANK
elif "regression" in inference_config:
return TYPE_REGRESSION
raise ValueError("Unable to determine 'model_type' for MLModel")
@ -254,6 +261,7 @@ class MLModel:
classification_weights: Optional[List[float]] = None,
es_if_exists: Optional[str] = None,
es_compress_model_definition: bool = True,
inference_config: Optional[Mapping[str, Mapping[str, Any]]] = None,
) -> "MLModel":
"""
Transform and serialize a trained 3rd party model into Elasticsearch.
@ -324,6 +332,10 @@ class MLModel:
JSON instead of raw JSON to reduce the amount of data sent
over the wire in HTTP requests. Defaults to 'True'.
inference_config: Mapping[str, Mapping[str, Any]]
Model inference configuration. Must contain a top-level property whose name is the same as the inference
task type.
Examples
--------
>>> from sklearn import datasets
@ -367,6 +379,7 @@ class MLModel:
)
serializer = transformer.transform()
model_type = transformer.model_type
default_inference_config: Mapping[str, Mapping[str, Any]] = {model_type: {}}
if es_if_exists is None:
es_if_exists = "fail"
@ -389,14 +402,14 @@ class MLModel:
ml_model._client.ml.put_trained_model(
model_id=model_id,
input={"field_names": feature_names},
inference_config={model_type: {}},
inference_config=inference_config or default_inference_config,
compressed_definition=serializer.serialize_and_compress_model(),
)
else:
ml_model._client.ml.put_trained_model(
model_id=model_id,
input={"field_names": feature_names},
inference_config={model_type: {}},
inference_config=inference_config or default_inference_config,
definition=serializer.serialize_model(),
)

View File

@ -16,13 +16,19 @@
# under the License.
from operator import itemgetter
from typing import Tuple
import numpy as np
import pytest
import eland as ed
from eland.ml import MLModel
from tests import ES_TEST_CLIENT, ES_VERSION, FLIGHTS_SMALL_INDEX_NAME
from tests import (
ES_TEST_CLIENT,
ES_VERSION,
FLIGHTS_SMALL_INDEX_NAME,
MOVIES_INDEX_NAME,
)
try:
from sklearn import datasets
@ -70,10 +76,17 @@ requires_no_ml_extras = pytest.mark.skipif(
)
requires_lightgbm = pytest.mark.skipif(
not HAS_LIGHTGBM, reason="This test requires 'lightgbm' package to run"
not HAS_LIGHTGBM, reason="This test requires 'lightgbm' package to run."
)
def requires_elasticsearch_version(minimum_version: Tuple[int, int, int]):
    """Return a ``pytest.mark.skipif`` marker gated on the Elasticsearch version.

    The marker skips the decorated test when ``ES_VERSION`` compares lower
    than ``minimum_version``.

    Parameters
    ----------
    minimum_version: Tuple[int, ...]
        Lowest Elasticsearch version the test can run against; its components
        are joined with dots to build the skip reason.
    """
    too_old = ES_VERSION < minimum_version
    version_label = ".".join(str(part) for part in minimum_version)
    skip_reason = f"This test requires Elasticsearch version {version_label} or later."
    return pytest.mark.skipif(too_old, reason=skip_reason)
def skip_if_multiclass_classifition():
if ES_VERSION < (7, 7):
raise pytest.skip(
@ -306,6 +319,96 @@ class TestMLModel:
# Clean up
es_model.delete_model()
@requires_elasticsearch_version((8, 12))
@requires_sklearn
@pytest.mark.parametrize("compress_model_definition", [True, False])
def test_learning_to_rank(self, compress_model_definition):
    """Import a model with a learning-to-rank inference config and rescore with it.

    Steps:
      1. Train a small decision tree regressor on synthetic two-feature data.
      2. Import it with an explicit ``learning_to_rank`` inference config
         whose feature extractors provide the two feature names.
      3. Check that the stored inference config contains the supplied one.
      4. Run a search with a ``learning_to_rank`` rescorer and check that
         both documents are returned with a score different from 1.0.
      5. Check that ``predict`` raises ``NotImplementedError`` for this model.

    Parameters
    ----------
    compress_model_definition: bool
        Forwarded as ``es_compress_model_definition`` to
        ``MLModel.import_model``.
    """
    # Train model
    training_data = datasets.make_regression(n_features=2)
    regressor = DecisionTreeRegressor()
    regressor.fit(training_data[0], training_data[1])

    # Serialise the models to Elasticsearch
    model_id = "test_learning_to_rank"
    feature_extractors = [
        {
            "query_extractor": {
                "feature_name": "title_bm25",
                "query": {"match": {"title": "{{query_string}}"}},
            }
        },
        {
            "query_extractor": {
                "feature_name": "imdb_rating",
                "query": {
                    "script_score": {
                        "query": {"exists": {"field": "imdbRating"}},
                        "script": {"source": 'return doc["imdbRating"].value;'},
                    }
                },
            }
        },
    ]
    # The model's feature names are taken from the extractors so the two
    # lists cannot drift apart.
    feature_names = [
        extractor["query_extractor"]["feature_name"]
        for extractor in feature_extractors
    ]
    inference_config = {
        "learning_to_rank": {"feature_extractors": feature_extractors}
    }

    es_model = MLModel.import_model(
        ES_TEST_CLIENT,
        model_id,
        regressor,
        feature_names,
        es_if_exists="replace",
        es_compress_model_definition=compress_model_definition,
        inference_config=inference_config,
    )

    # Verify the saved inference config contains the passed LTR config
    response = ES_TEST_CLIENT.ml.get_trained_models(model_id=model_id)
    assert response.meta.status == 200
    assert response.body["count"] == 1
    saved_inference_config = response.body["trained_model_configs"][0][
        "inference_config"
    ]

    assert "learning_to_rank" in saved_inference_config
    saved_ltr_config = saved_inference_config["learning_to_rank"]
    # Subset check: the saved config may hold extra keys beyond those passed in.
    assert all(
        item in saved_ltr_config.items()
        for item in inference_config["learning_to_rank"].items()
    )

    # Execute search with rescoring
    search_result = ES_TEST_CLIENT.search(
        index=MOVIES_INDEX_NAME,
        query={"terms": {"_id": ["tt1318514", "tt0071562"]}},
        rescore={
            "learning_to_rank": {
                "model_id": model_id,
                "params": {"query_string": "planet of the apes"},
            }
        },
    )

    # Assert that:
    # - all documents from the query are present
    # - all documents have been rescored (score != 1.0)
    doc_scores = [hit["_score"] for hit in search_result["hits"]["hits"]]
    assert len(search_result["hits"]["hits"]) == 2
    assert all(score != float(1) for score in doc_scores)

    # Verify prediction is not supported for LTR. pytest.raises fails the test
    # when no exception is raised; a try/except that only swallows the error
    # would also pass if predict() unexpectedly succeeded.
    with pytest.raises(NotImplementedError):
        es_model.predict([0])

    # Clean up
    es_model.delete_model()
@requires_sklearn
@pytest.mark.parametrize("compress_model_definition", [True, False])
def test_random_forest_classifier(self, compress_model_definition):