mirror of
https://github.com/elastic/eland.git
synced 2025-07-24 00:00:39 +08:00
Accept LTR inference config when creating model (#645)
* Support for supplying inference_config
* Fix linting errors
* Add unit test
* Add LTR type, throw exception on predict, refine test
* Add search step to LTR test
* Fix linter errors
* Update rescoring assertion in test + type defs
* Fix linting error
* Remove failing assertion
This commit is contained in:
parent
05c5859b8a
commit
840871f9d9
@ -16,4 +16,5 @@
|
|||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
TYPE_CLASSIFICATION = "classification"
|
TYPE_CLASSIFICATION = "classification"
|
||||||
|
TYPE_LEARNING_TO_RANK = "learning_to_rank"
|
||||||
TYPE_REGRESSION = "regression"
|
TYPE_REGRESSION = "regression"
|
||||||
|
@ -23,7 +23,7 @@ import numpy as np
|
|||||||
from eland.common import ensure_es_client, es_version
|
from eland.common import ensure_es_client, es_version
|
||||||
from eland.utils import deprecated_api
|
from eland.utils import deprecated_api
|
||||||
|
|
||||||
from .common import TYPE_CLASSIFICATION, TYPE_REGRESSION
|
from .common import TYPE_CLASSIFICATION, TYPE_LEARNING_TO_RANK, TYPE_REGRESSION
|
||||||
from .transformers import get_model_transformer
|
from .transformers import get_model_transformer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -130,6 +130,11 @@ class MLModel:
|
|||||||
>>> # Delete model from Elasticsearch
|
>>> # Delete model from Elasticsearch
|
||||||
>>> es_model.delete_model()
|
>>> es_model.delete_model()
|
||||||
"""
|
"""
|
||||||
|
if self.model_type not in (TYPE_CLASSIFICATION, TYPE_REGRESSION):
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Prediction for type {self.model_type} is not supported."
|
||||||
|
)
|
||||||
|
|
||||||
docs: List[Mapping[str, Any]] = []
|
docs: List[Mapping[str, Any]] = []
|
||||||
if isinstance(X, np.ndarray):
|
if isinstance(X, np.ndarray):
|
||||||
|
|
||||||
@ -215,6 +220,8 @@ class MLModel:
|
|||||||
inference_config = self._trained_model_config["inference_config"]
|
inference_config = self._trained_model_config["inference_config"]
|
||||||
if "classification" in inference_config:
|
if "classification" in inference_config:
|
||||||
return TYPE_CLASSIFICATION
|
return TYPE_CLASSIFICATION
|
||||||
|
elif "learning_to_rank" in inference_config:
|
||||||
|
return TYPE_LEARNING_TO_RANK
|
||||||
elif "regression" in inference_config:
|
elif "regression" in inference_config:
|
||||||
return TYPE_REGRESSION
|
return TYPE_REGRESSION
|
||||||
raise ValueError("Unable to determine 'model_type' for MLModel")
|
raise ValueError("Unable to determine 'model_type' for MLModel")
|
||||||
@ -254,6 +261,7 @@ class MLModel:
|
|||||||
classification_weights: Optional[List[float]] = None,
|
classification_weights: Optional[List[float]] = None,
|
||||||
es_if_exists: Optional[str] = None,
|
es_if_exists: Optional[str] = None,
|
||||||
es_compress_model_definition: bool = True,
|
es_compress_model_definition: bool = True,
|
||||||
|
inference_config: Optional[Mapping[str, Mapping[str, Any]]] = None,
|
||||||
) -> "MLModel":
|
) -> "MLModel":
|
||||||
"""
|
"""
|
||||||
Transform and serialize a trained 3rd party model into Elasticsearch.
|
Transform and serialize a trained 3rd party model into Elasticsearch.
|
||||||
@ -324,6 +332,10 @@ class MLModel:
|
|||||||
JSON instead of raw JSON to reduce the amount of data sent
|
JSON instead of raw JSON to reduce the amount of data sent
|
||||||
over the wire in HTTP requests. Defaults to 'True'.
|
over the wire in HTTP requests. Defaults to 'True'.
|
||||||
|
|
||||||
|
inference_config: Mapping[str, Mapping[str, Any]]
|
||||||
|
Model inference configuration. Must contain a top-level property whose name is the same as the inference
|
||||||
|
task type.
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
>>> from sklearn import datasets
|
>>> from sklearn import datasets
|
||||||
@ -367,6 +379,7 @@ class MLModel:
|
|||||||
)
|
)
|
||||||
serializer = transformer.transform()
|
serializer = transformer.transform()
|
||||||
model_type = transformer.model_type
|
model_type = transformer.model_type
|
||||||
|
default_inference_config: Mapping[str, Mapping[str, Any]] = {model_type: {}}
|
||||||
|
|
||||||
if es_if_exists is None:
|
if es_if_exists is None:
|
||||||
es_if_exists = "fail"
|
es_if_exists = "fail"
|
||||||
@ -389,14 +402,14 @@ class MLModel:
|
|||||||
ml_model._client.ml.put_trained_model(
|
ml_model._client.ml.put_trained_model(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
input={"field_names": feature_names},
|
input={"field_names": feature_names},
|
||||||
inference_config={model_type: {}},
|
inference_config=inference_config or default_inference_config,
|
||||||
compressed_definition=serializer.serialize_and_compress_model(),
|
compressed_definition=serializer.serialize_and_compress_model(),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ml_model._client.ml.put_trained_model(
|
ml_model._client.ml.put_trained_model(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
input={"field_names": feature_names},
|
input={"field_names": feature_names},
|
||||||
inference_config={model_type: {}},
|
inference_config=inference_config or default_inference_config,
|
||||||
definition=serializer.serialize_model(),
|
definition=serializer.serialize_model(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,13 +16,19 @@
|
|||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import eland as ed
|
import eland as ed
|
||||||
from eland.ml import MLModel
|
from eland.ml import MLModel
|
||||||
from tests import ES_TEST_CLIENT, ES_VERSION, FLIGHTS_SMALL_INDEX_NAME
|
from tests import (
|
||||||
|
ES_TEST_CLIENT,
|
||||||
|
ES_VERSION,
|
||||||
|
FLIGHTS_SMALL_INDEX_NAME,
|
||||||
|
MOVIES_INDEX_NAME,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from sklearn import datasets
|
from sklearn import datasets
|
||||||
@ -70,10 +76,17 @@ requires_no_ml_extras = pytest.mark.skipif(
|
|||||||
)
|
)
|
||||||
|
|
||||||
requires_lightgbm = pytest.mark.skipif(
|
requires_lightgbm = pytest.mark.skipif(
|
||||||
not HAS_LIGHTGBM, reason="This test requires 'lightgbm' package to run"
|
not HAS_LIGHTGBM, reason="This test requires 'lightgbm' package to run."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def requires_elasticsearch_version(minimum_version: Tuple[int, ...]):
    """Build a pytest marker that skips a test on older Elasticsearch versions.

    Parameters
    ----------
    minimum_version: Tuple[int, ...]
        Lowest Elasticsearch version the decorated test supports, given as a
        tuple of version components such as ``(8, 12)`` or ``(8, 12, 0)``.
        It is compared against ``ES_VERSION`` using ordinary tuple ordering,
        so it may have fewer components than a full version number.

    Returns
    -------
    A ``pytest.mark.skipif`` marker whose condition is
    ``ES_VERSION < minimum_version``.
    """
    # Render the tuple as a dotted version string for the skip message.
    required = ".".join(str(part) for part in minimum_version)
    return pytest.mark.skipif(
        ES_VERSION < minimum_version,
        reason=f"This test requires Elasticsearch version {required} or later.",
    )
|
||||||
|
|
||||||
|
|
||||||
def skip_if_multiclass_classifition():
|
def skip_if_multiclass_classifition():
|
||||||
if ES_VERSION < (7, 7):
|
if ES_VERSION < (7, 7):
|
||||||
raise pytest.skip(
|
raise pytest.skip(
|
||||||
@ -306,6 +319,96 @@ class TestMLModel:
|
|||||||
# Clean up
|
# Clean up
|
||||||
es_model.delete_model()
|
es_model.delete_model()
|
||||||
|
|
||||||
|
@requires_elasticsearch_version((8, 12))
@requires_sklearn
@pytest.mark.parametrize("compress_model_definition", [True, False])
def test_learning_to_rank(self, compress_model_definition):
    """Import a model with a learning-to-rank inference config and use it.

    Checks that the supplied ``inference_config`` is stored with the trained
    model, that the model can be used in a ``learning_to_rank`` rescore
    search, and that ``predict`` raises ``NotImplementedError`` for this
    model type.
    """
    # Train model
    features, targets = datasets.make_regression(n_features=2)
    regressor = DecisionTreeRegressor()
    regressor.fit(features, targets)

    # Serialise the models to Elasticsearch
    model_id = "test_learning_to_rank"
    feature_extractors = [
        {
            "query_extractor": {
                "feature_name": "title_bm25",
                "query": {"match": {"title": "{{query_string}}"}},
            }
        },
        {
            "query_extractor": {
                "feature_name": "imdb_rating",
                "query": {
                    "script_score": {
                        "query": {"exists": {"field": "imdbRating"}},
                        "script": {"source": 'return doc["imdbRating"].value;'},
                    }
                },
            }
        },
    ]
    feature_names = [
        extractor["query_extractor"]["feature_name"]
        for extractor in feature_extractors
    ]
    inference_config = {
        "learning_to_rank": {"feature_extractors": feature_extractors}
    }

    es_model = MLModel.import_model(
        ES_TEST_CLIENT,
        model_id,
        regressor,
        feature_names,
        es_if_exists="replace",
        es_compress_model_definition=compress_model_definition,
        inference_config=inference_config,
    )

    # Everything after the import runs under try/finally so the model is
    # removed from the cluster even when one of the assertions fails.
    try:
        # Verify the saved inference config contains the passed LTR config
        response = ES_TEST_CLIENT.ml.get_trained_models(model_id=model_id)
        assert response.meta.status == 200
        assert response.body["count"] == 1
        saved_inference_config = response.body["trained_model_configs"][0][
            "inference_config"
        ]
        assert "learning_to_rank" in saved_inference_config
        saved_ltr_config = saved_inference_config["learning_to_rank"]
        assert all(
            item in saved_ltr_config.items()
            for item in inference_config["learning_to_rank"].items()
        )

        # Execute search with rescoring
        search_result = ES_TEST_CLIENT.search(
            index=MOVIES_INDEX_NAME,
            query={"terms": {"_id": ["tt1318514", "tt0071562"]}},
            rescore={
                "learning_to_rank": {
                    "model_id": model_id,
                    "params": {"query_string": "planet of the apes"},
                }
            },
        )

        # Assert that:
        # - all documents from the query are present
        # - all documents have been rescored (score != 1.0)
        hits = search_result["hits"]["hits"]
        doc_scores = [hit["_score"] for hit in hits]
        assert len(hits) == 2
        assert all(score != float(1) for score in doc_scores)

        # Verify prediction is not supported for LTR. pytest.raises fails
        # the test if no exception is raised, which a bare
        # try/except-pass would silently accept.
        with pytest.raises(NotImplementedError):
            es_model.predict([0])
    finally:
        # Clean up
        es_model.delete_model()
|
||||||
|
|
||||||
@requires_sklearn
|
@requires_sklearn
|
||||||
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
@pytest.mark.parametrize("compress_model_definition", [True, False])
|
||||||
def test_random_forest_classifier(self, compress_model_definition):
|
def test_random_forest_classifier(self, compress_model_definition):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user