Accept LTR inference config when creating model (#645)

* Support for supplying inference_config

* Fix linting errors

* Add unit test

* Add LTR type, throw exception on predict, refine test

* Add search step to LTR test

* Fix linter errors

* Update rescoring assertion in test + type defs

* Fix linting error

* Remove failing assertion
This commit is contained in:
Adam Demjen 2024-01-08 09:19:03 -05:00 committed by GitHub
parent 05c5859b8a
commit 840871f9d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 122 additions and 5 deletions

View File

@ -16,4 +16,5 @@
# under the License. # under the License.
TYPE_CLASSIFICATION = "classification" TYPE_CLASSIFICATION = "classification"
TYPE_LEARNING_TO_RANK = "learning_to_rank"
TYPE_REGRESSION = "regression" TYPE_REGRESSION = "regression"

View File

@ -23,7 +23,7 @@ import numpy as np
from eland.common import ensure_es_client, es_version from eland.common import ensure_es_client, es_version
from eland.utils import deprecated_api from eland.utils import deprecated_api
from .common import TYPE_CLASSIFICATION, TYPE_REGRESSION from .common import TYPE_CLASSIFICATION, TYPE_LEARNING_TO_RANK, TYPE_REGRESSION
from .transformers import get_model_transformer from .transformers import get_model_transformer
if TYPE_CHECKING: if TYPE_CHECKING:
@ -130,6 +130,11 @@ class MLModel:
>>> # Delete model from Elasticsearch >>> # Delete model from Elasticsearch
>>> es_model.delete_model() >>> es_model.delete_model()
""" """
if self.model_type not in (TYPE_CLASSIFICATION, TYPE_REGRESSION):
raise NotImplementedError(
f"Prediction for type {self.model_type} is not supported."
)
docs: List[Mapping[str, Any]] = [] docs: List[Mapping[str, Any]] = []
if isinstance(X, np.ndarray): if isinstance(X, np.ndarray):
@ -215,6 +220,8 @@ class MLModel:
inference_config = self._trained_model_config["inference_config"] inference_config = self._trained_model_config["inference_config"]
if "classification" in inference_config: if "classification" in inference_config:
return TYPE_CLASSIFICATION return TYPE_CLASSIFICATION
elif "learning_to_rank" in inference_config:
return TYPE_LEARNING_TO_RANK
elif "regression" in inference_config: elif "regression" in inference_config:
return TYPE_REGRESSION return TYPE_REGRESSION
raise ValueError("Unable to determine 'model_type' for MLModel") raise ValueError("Unable to determine 'model_type' for MLModel")
@ -254,6 +261,7 @@ class MLModel:
classification_weights: Optional[List[float]] = None, classification_weights: Optional[List[float]] = None,
es_if_exists: Optional[str] = None, es_if_exists: Optional[str] = None,
es_compress_model_definition: bool = True, es_compress_model_definition: bool = True,
inference_config: Optional[Mapping[str, Mapping[str, Any]]] = None,
) -> "MLModel": ) -> "MLModel":
""" """
Transform and serialize a trained 3rd party model into Elasticsearch. Transform and serialize a trained 3rd party model into Elasticsearch.
@ -324,6 +332,10 @@ class MLModel:
JSON instead of raw JSON to reduce the amount of data sent JSON instead of raw JSON to reduce the amount of data sent
over the wire in HTTP requests. Defaults to 'True'. over the wire in HTTP requests. Defaults to 'True'.
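inference_config: Optional[Mapping[str, Mapping[str, Any]]]
Model inference configuration. Must contain a top-level property whose name is the same as the inference
task type (for example 'learning_to_rank'). If omitted or empty, defaults to an empty configuration keyed by
the model type, i.e. ``{model_type: {}}``.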
Examples Examples
-------- --------
>>> from sklearn import datasets >>> from sklearn import datasets
@ -367,6 +379,7 @@ class MLModel:
) )
serializer = transformer.transform() serializer = transformer.transform()
model_type = transformer.model_type model_type = transformer.model_type
default_inference_config: Mapping[str, Mapping[str, Any]] = {model_type: {}}
if es_if_exists is None: if es_if_exists is None:
es_if_exists = "fail" es_if_exists = "fail"
@ -389,14 +402,14 @@ class MLModel:
ml_model._client.ml.put_trained_model( ml_model._client.ml.put_trained_model(
model_id=model_id, model_id=model_id,
input={"field_names": feature_names}, input={"field_names": feature_names},
inference_config={model_type: {}}, inference_config=inference_config or default_inference_config,
compressed_definition=serializer.serialize_and_compress_model(), compressed_definition=serializer.serialize_and_compress_model(),
) )
else: else:
ml_model._client.ml.put_trained_model( ml_model._client.ml.put_trained_model(
model_id=model_id, model_id=model_id,
input={"field_names": feature_names}, input={"field_names": feature_names},
inference_config={model_type: {}}, inference_config=inference_config or default_inference_config,
definition=serializer.serialize_model(), definition=serializer.serialize_model(),
) )

View File

@ -16,13 +16,19 @@
# under the License. # under the License.
from operator import itemgetter from operator import itemgetter
from typing import Tuple
import numpy as np import numpy as np
import pytest import pytest
import eland as ed import eland as ed
from eland.ml import MLModel from eland.ml import MLModel
from tests import ES_TEST_CLIENT, ES_VERSION, FLIGHTS_SMALL_INDEX_NAME from tests import (
ES_TEST_CLIENT,
ES_VERSION,
FLIGHTS_SMALL_INDEX_NAME,
MOVIES_INDEX_NAME,
)
try: try:
from sklearn import datasets from sklearn import datasets
@ -70,10 +76,17 @@ requires_no_ml_extras = pytest.mark.skipif(
) )
requires_lightgbm = pytest.mark.skipif( requires_lightgbm = pytest.mark.skipif(
not HAS_LIGHTGBM, reason="This test requires 'lightgbm' package to run" not HAS_LIGHTGBM, reason="This test requires 'lightgbm' package to run."
) )
def requires_elasticsearch_version(minimum_version: Tuple[int, int, int]):
    """Return a pytest marker that skips a test on older Elasticsearch versions.

    The test is skipped when ``ES_VERSION`` compares lower than
    ``minimum_version``; the skip reason shows the required version in
    dotted form.
    """
    # Render the version tuple as a dotted string for the skip message.
    version_label = ".".join(map(str, minimum_version))
    cluster_too_old = ES_VERSION < minimum_version
    return pytest.mark.skipif(
        cluster_too_old,
        reason=f"This test requires Elasticsearch version {version_label} or later.",
    )
def skip_if_multiclass_classifition(): def skip_if_multiclass_classifition():
if ES_VERSION < (7, 7): if ES_VERSION < (7, 7):
raise pytest.skip( raise pytest.skip(
@ -306,6 +319,96 @@ class TestMLModel:
# Clean up # Clean up
es_model.delete_model() es_model.delete_model()
@requires_elasticsearch_version((8, 12))
@requires_sklearn
@pytest.mark.parametrize("compress_model_definition", [True, False])
def test_learning_to_rank(self, compress_model_definition):
    """Import a model with a learning-to-rank inference config and use it.

    Checks that the supplied ``inference_config`` is stored with the
    trained model, that the model can be used as a rescorer in a search,
    and that ``predict`` is rejected for the LTR model type.
    """
    # Train model
    training_data = datasets.make_regression(n_features=2)
    regressor = DecisionTreeRegressor()
    regressor.fit(training_data[0], training_data[1])

    # Serialise the models to Elasticsearch
    model_id = "test_learning_to_rank"
    feature_extractors = [
        {
            "query_extractor": {
                "feature_name": "title_bm25",
                "query": {"match": {"title": "{{query_string}}"}},
            }
        },
        {
            "query_extractor": {
                "feature_name": "imdb_rating",
                "query": {
                    "script_score": {
                        "query": {"exists": {"field": "imdbRating"}},
                        "script": {"source": 'return doc["imdbRating"].value;'},
                    }
                },
            }
        },
    ]
    # The model's input field names must match the extractor feature names.
    feature_names = [
        extractor["query_extractor"]["feature_name"]
        for extractor in feature_extractors
    ]
    inference_config = {
        "learning_to_rank": {"feature_extractors": feature_extractors}
    }

    es_model = MLModel.import_model(
        ES_TEST_CLIENT,
        model_id,
        regressor,
        feature_names,
        es_if_exists="replace",
        es_compress_model_definition=compress_model_definition,
        inference_config=inference_config,
    )

    try:
        # Verify the saved inference config contains the passed LTR config
        response = ES_TEST_CLIENT.ml.get_trained_models(model_id=model_id)
        assert response.meta.status == 200
        assert response.body["count"] == 1
        saved_inference_config = response.body["trained_model_configs"][0][
            "inference_config"
        ]

        assert "learning_to_rank" in saved_inference_config
        saved_ltr_config = saved_inference_config["learning_to_rank"]

        # Subset check: the server may add extra keys to the saved config.
        assert all(
            item in saved_ltr_config.items()
            for item in inference_config["learning_to_rank"].items()
        )

        # Execute search with rescoring
        search_result = ES_TEST_CLIENT.search(
            index=MOVIES_INDEX_NAME,
            query={"terms": {"_id": ["tt1318514", "tt0071562"]}},
            rescore={
                "learning_to_rank": {
                    "model_id": model_id,
                    "params": {"query_string": "planet of the apes"},
                }
            },
        )

        # Assert that:
        # - all documents from the query are present
        # - all documents have been rescored (score != 1.0)
        doc_scores = [hit["_score"] for hit in search_result["hits"]["hits"]]
        assert len(search_result["hits"]["hits"]) == 2
        assert all(score != float(1) for score in doc_scores)

        # Verify prediction is not supported for LTR. pytest.raises fails
        # the test if no exception is raised, unlike a bare try/except/pass.
        with pytest.raises(NotImplementedError):
            es_model.predict([0])
    finally:
        # Clean up even when an assertion above fails, so the model does
        # not linger in the test cluster.
        es_model.delete_model()
@requires_sklearn @requires_sklearn
@pytest.mark.parametrize("compress_model_definition", [True, False]) @pytest.mark.parametrize("compress_model_definition", [True, False])
def test_random_forest_classifier(self, compress_model_definition): def test_random_forest_classifier(self, compress_model_definition):