Mirror of https://github.com/elastic/eland.git, synced 2025-07-11 00:02:14 +08:00
Accept LTR inference config when creating model (#645)
* Support for supplying inference_config
* Fix linting errors
* Add unit test
* Add LTR type, throw exception on predict, refine test
* Add search step to LTR test
* Fix linter errors
* Update rescoring assertion in test + type defs
* Fix linting error
* Remove failing assertion
This commit is contained in: parent 05c5859b8a, commit 840871f9d9
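Taken together, these changes let callers attach a learning-to-rank (LTR) configuration when importing a scikit-learn model. A minimal sketch of the new call, closely following the test added at the end of this diff; the client URL, model id and field name are placeholders, not part of this commit:

from elasticsearch import Elasticsearch
from sklearn import datasets
from sklearn.tree import DecisionTreeRegressor

from eland.ml import MLModel

es = Elasticsearch("http://localhost:9200")  # placeholder connection

# Train any regressor; the LTR model is stored as a regression-style tree model.
X, y = datasets.make_regression(n_features=2)
regressor = DecisionTreeRegressor().fit(X, y)

# The top-level key must match the inference task type ("learning_to_rank").
inference_config = {
    "learning_to_rank": {
        "feature_extractors": [
            {
                "query_extractor": {
                    "feature_name": "title_bm25",
                    "query": {"match": {"title": "{{query_string}}"}},
                }
            }
        ]
    }
}

es_model = MLModel.import_model(
    es,
    "my_ltr_model",  # placeholder model id
    regressor,
    ["title_bm25"],  # feature names should match the extractors
    es_if_exists="replace",
    inference_config=inference_config,
)
# es_model.delete_model() removes it again when no longer needed.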
@@ -16,4 +16,5 @@
 # under the License.
 
 TYPE_CLASSIFICATION = "classification"
+TYPE_LEARNING_TO_RANK = "learning_to_rank"
 TYPE_REGRESSION = "regression"
@@ -23,7 +23,7 @@ import numpy as np
 from eland.common import ensure_es_client, es_version
 from eland.utils import deprecated_api
 
-from .common import TYPE_CLASSIFICATION, TYPE_REGRESSION
+from .common import TYPE_CLASSIFICATION, TYPE_LEARNING_TO_RANK, TYPE_REGRESSION
 from .transformers import get_model_transformer
 
 if TYPE_CHECKING:
@@ -130,6 +130,11 @@ class MLModel:
         >>> # Delete model from Elasticsearch
         >>> es_model.delete_model()
         """
+        if self.model_type not in (TYPE_CLASSIFICATION, TYPE_REGRESSION):
+            raise NotImplementedError(
+                f"Prediction for type {self.model_type} is not supported."
+            )
+
         docs: List[Mapping[str, Any]] = []
         if isinstance(X, np.ndarray):
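The guard above means predict() only supports classification and regression; any other stored type, including the new learning-to-rank type, raises immediately. A short illustrative check, assuming Elasticsearch is reachable and a model with id "test_learning_to_rank" has already been imported with an LTR config, as the test at the bottom of this diff does:

import pytest

from eland.ml import MLModel
from tests import ES_TEST_CLIENT

# Assumption: the LTR model from the new test already exists in the cluster.
ltr_model = MLModel(ES_TEST_CLIENT, "test_learning_to_rank")

with pytest.raises(NotImplementedError):
    ltr_model.predict([0])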
@@ -215,6 +220,8 @@ class MLModel:
         inference_config = self._trained_model_config["inference_config"]
         if "classification" in inference_config:
             return TYPE_CLASSIFICATION
+        elif "learning_to_rank" in inference_config:
+            return TYPE_LEARNING_TO_RANK
         elif "regression" in inference_config:
             return TYPE_REGRESSION
         raise ValueError("Unable to determine 'model_type' for MLModel")
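Model type detection keys off the top-level property of the stored inference_config. A standalone rendering of the same lookup, for illustration only and not the library's code:

def detect_model_type(inference_config: dict) -> str:
    # Mirrors the branch order in the hunk above.
    if "classification" in inference_config:
        return "classification"
    elif "learning_to_rank" in inference_config:
        return "learning_to_rank"
    elif "regression" in inference_config:
        return "regression"
    raise ValueError("Unable to determine 'model_type' for MLModel")


assert detect_model_type({"learning_to_rank": {"feature_extractors": []}}) == "learning_to_rank"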
@@ -254,6 +261,7 @@
         classification_weights: Optional[List[float]] = None,
         es_if_exists: Optional[str] = None,
         es_compress_model_definition: bool = True,
+        inference_config: Optional[Mapping[str, Mapping[str, Any]]] = None,
     ) -> "MLModel":
         """
         Transform and serialize a trained 3rd party model into Elasticsearch.
@@ -324,6 +332,10 @@
             JSON instead of raw JSON to reduce the amount of data sent
             over the wire in HTTP requests. Defaults to 'True'.
 
+        inference_config: Mapping[str, Mapping[str, Any]]
+            Model inference configuration. Must contain a top-level property whose name is the same as the inference
+            task type.
+
         Examples
         --------
         >>> from sklearn import datasets
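As the docstring above says, the top-level key of inference_config must name the inference task type. A couple of sketched shapes; the learning-to-rank example mirrors the feature extractors used in the new test, while the field names themselves are illustrative:

# Regression and classification configs may simply name the task,
# matching the {model_type: {}} default used further down.
regression_config = {"regression": {}}
classification_config = {"classification": {}}

# A learning-to-rank config additionally carries feature extractors.
ltr_config = {
    "learning_to_rank": {
        "feature_extractors": [
            {
                "query_extractor": {
                    "feature_name": "title_bm25",
                    "query": {"match": {"title": "{{query_string}}"}},
                }
            }
        ]
    }
}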
@@ -367,6 +379,7 @@
         )
         serializer = transformer.transform()
         model_type = transformer.model_type
+        default_inference_config: Mapping[str, Mapping[str, Any]] = {model_type: {}}
 
         if es_if_exists is None:
             es_if_exists = "fail"
@@ -389,14 +402,14 @@
             ml_model._client.ml.put_trained_model(
                 model_id=model_id,
                 input={"field_names": feature_names},
-                inference_config={model_type: {}},
+                inference_config=inference_config or default_inference_config,
                 compressed_definition=serializer.serialize_and_compress_model(),
             )
         else:
             ml_model._client.ml.put_trained_model(
                 model_id=model_id,
                 input={"field_names": feature_names},
-                inference_config={model_type: {}},
+                inference_config=inference_config or default_inference_config,
                 definition=serializer.serialize_model(),
             )
 
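When the caller passes no inference_config, the `or` expression above falls back to an empty config keyed by the detected model type. A quick sketch of how that resolution behaves:

model_type = "regression"
default_inference_config = {model_type: {}}

# Caller did not supply a config: the default is sent.
assert (None or default_inference_config) == {"regression": {}}

# Caller supplied an LTR config: it is sent unchanged.
supplied = {"learning_to_rank": {"feature_extractors": []}}
assert (supplied or default_inference_config) == supplied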
@@ -16,13 +16,19 @@
 # under the License.
 
 from operator import itemgetter
+from typing import Tuple
 
 import numpy as np
 import pytest
 
 import eland as ed
 from eland.ml import MLModel
-from tests import ES_TEST_CLIENT, ES_VERSION, FLIGHTS_SMALL_INDEX_NAME
+from tests import (
+    ES_TEST_CLIENT,
+    ES_VERSION,
+    FLIGHTS_SMALL_INDEX_NAME,
+    MOVIES_INDEX_NAME,
+)
 
 try:
     from sklearn import datasets
@@ -70,10 +76,17 @@ requires_no_ml_extras = pytest.mark.skipif(
 )
 
 requires_lightgbm = pytest.mark.skipif(
-    not HAS_LIGHTGBM, reason="This test requires 'lightgbm' package to run"
+    not HAS_LIGHTGBM, reason="This test requires 'lightgbm' package to run."
 )
 
 
+def requires_elasticsearch_version(minimum_version: Tuple[int, int, int]):
+    return pytest.mark.skipif(
+        ES_VERSION < minimum_version,
+        reason=f"This test requires Elasticsearch version {'.'.join(str(v) for v in minimum_version)} or later.",
+    )
+
+
 def skip_if_multiclass_classifition():
     if ES_VERSION < (7, 7):
         raise pytest.skip(
@@ -306,6 +319,96 @@ class TestMLModel:
         # Clean up
         es_model.delete_model()
 
+    @requires_elasticsearch_version((8, 12))
+    @requires_sklearn
+    @pytest.mark.parametrize("compress_model_definition", [True, False])
+    def test_learning_to_rank(self, compress_model_definition):
+        # Train model
+        training_data = datasets.make_regression(n_features=2)
+        regressor = DecisionTreeRegressor()
+        regressor.fit(training_data[0], training_data[1])
+
+        # Serialise the models to Elasticsearch
+        model_id = "test_learning_to_rank"
+        feature_extractors = [
+            {
+                "query_extractor": {
+                    "feature_name": "title_bm25",
+                    "query": {"match": {"title": "{{query_string}}"}},
+                }
+            },
+            {
+                "query_extractor": {
+                    "feature_name": "imdb_rating",
+                    "query": {
+                        "script_score": {
+                            "query": {"exists": {"field": "imdbRating"}},
+                            "script": {"source": 'return doc["imdbRating"].value;'},
+                        }
+                    },
+                }
+            },
+        ]
+        feature_names = [
+            extractor["query_extractor"]["feature_name"]
+            for extractor in feature_extractors
+        ]
+        inference_config = {
+            "learning_to_rank": {"feature_extractors": feature_extractors}
+        }
+
+        es_model = MLModel.import_model(
+            ES_TEST_CLIENT,
+            model_id,
+            regressor,
+            feature_names,
+            es_if_exists="replace",
+            es_compress_model_definition=compress_model_definition,
+            inference_config=inference_config,
+        )
+
+        # Verify the saved inference config contains the passed LTR config
+        response = ES_TEST_CLIENT.ml.get_trained_models(model_id=model_id)
+        assert response.meta.status == 200
+        assert response.body["count"] == 1
+        saved_inference_config = response.body["trained_model_configs"][0][
+            "inference_config"
+        ]
+        assert "learning_to_rank" in saved_inference_config
+        saved_ltr_config = saved_inference_config["learning_to_rank"]
+        assert all(
+            item in saved_ltr_config.items()
+            for item in inference_config["learning_to_rank"].items()
+        )
+
+        # Execute search with rescoring
+        search_result = ES_TEST_CLIENT.search(
+            index=MOVIES_INDEX_NAME,
+            query={"terms": {"_id": ["tt1318514", "tt0071562"]}},
+            rescore={
+                "learning_to_rank": {
+                    "model_id": model_id,
+                    "params": {"query_string": "planet of the apes"},
+                }
+            },
+        )
+
+        # Assert that:
+        # - all documents from the query are present
+        # - all documents have been rescored (score != 1.0)
+        doc_scores = [hit["_score"] for hit in search_result["hits"]["hits"]]
+        assert len(search_result["hits"]["hits"]) == 2
+        assert all(score != float(1) for score in doc_scores)
+
+        # Verify prediction is not supported for LTR
+        try:
+            es_model.predict([0])
+        except NotImplementedError:
+            pass
+
+        # Clean up
+        es_model.delete_model()
+
     @requires_sklearn
     @pytest.mark.parametrize("compress_model_definition", [True, False])
     def test_random_forest_classifier(self, compress_model_definition):
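In the rescore call above, the query_string value under params fills the {{query_string}} placeholder declared in the title_bm25 feature extractor. Elasticsearch performs that templating server side; the snippet below is only a rough stand-in for the substitution, not the actual mechanism:

import json

extractor_query = {"match": {"title": "{{query_string}}"}}
params = {"query_string": "planet of the apes"}

# Render the placeholder the way a mustache-style template would resolve it.
rendered = json.loads(
    json.dumps(extractor_query).replace("{{query_string}}", params["query_string"])
)
assert rendered == {"match": {"title": "planet of the apes"}}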