mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Remove input fields from exported LTR models (#708)
This commit is contained in:
parent
f18aa35e8e
commit
bee6d0e1f7
@ -32,6 +32,6 @@ steps:
|
|||||||
- '3.9'
|
- '3.9'
|
||||||
- '3.8'
|
- '3.8'
|
||||||
stack:
|
stack:
|
||||||
- '8.13.0-SNAPSHOT'
|
- '8.15.0-SNAPSHOT'
|
||||||
- '8.12.2'
|
- '8.14.1'
|
||||||
command: ./.buildkite/run-tests
|
command: ./.buildkite/run-tests
|
||||||
|
@ -33,7 +33,7 @@ from elastic_transport.client_utils import DEFAULT
|
|||||||
from elasticsearch import AuthenticationException, Elasticsearch
|
from elasticsearch import AuthenticationException, Elasticsearch
|
||||||
|
|
||||||
from eland._version import __version__
|
from eland._version import __version__
|
||||||
from eland.common import parse_es_version
|
from eland.common import is_serverless_es, parse_es_version
|
||||||
|
|
||||||
MODEL_HUB_URL = "https://huggingface.co"
|
MODEL_HUB_URL = "https://huggingface.co"
|
||||||
|
|
||||||
@ -197,10 +197,7 @@ def get_es_client(cli_args, logger):
|
|||||||
def check_cluster_version(es_client, logger):
|
def check_cluster_version(es_client, logger):
|
||||||
es_info = es_client.info()
|
es_info = es_client.info()
|
||||||
|
|
||||||
if (
|
if is_serverless_es(es_client):
|
||||||
"build_flavor" in es_info["version"]
|
|
||||||
and es_info["version"]["build_flavor"] == "serverless"
|
|
||||||
):
|
|
||||||
logger.info(f"Connected to serverless cluster '{es_info['cluster_name']}'")
|
logger.info(f"Connected to serverless cluster '{es_info['cluster_name']}'")
|
||||||
# Serverless is compatible
|
# Serverless is compatible
|
||||||
# Return the latest known semantic version, i.e. this version
|
# Return the latest known semantic version, i.e. this version
|
||||||
|
@ -344,6 +344,17 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
|
|||||||
return eland_es_version
|
return eland_es_version
|
||||||
|
|
||||||
|
|
||||||
|
def is_serverless_es(es_client: Elasticsearch) -> bool:
    """
    Returns true if the client is connected to a serverless instance of Elasticsearch.

    Parameters
    ----------
    es_client: Elasticsearch
        Client whose ``info()`` response is inspected.

    Returns
    -------
    bool
        True only when the ``version`` section of the info response contains
        a ``build_flavor`` entry equal to ``"serverless"``; False otherwise,
        including when ``build_flavor`` is absent.
    """
    version_info = es_client.info()["version"]
    # One .get() replaces the membership test plus second lookup: a missing
    # "build_flavor" yields None, which never equals "serverless".
    return version_info.get("build_flavor") == "serverless"
|
||||||
|
|
||||||
|
|
||||||
def parse_es_version(version: str) -> Tuple[int, int, int]:
|
def parse_es_version(version: str) -> Tuple[int, int, int]:
|
||||||
"""
|
"""
|
||||||
Parse the semantic version from a string e.g. '8.8.0'
|
Parse the semantic version from a string e.g. '8.8.0'
|
||||||
|
@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Uni
|
|||||||
import elasticsearch
|
import elasticsearch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from eland.common import ensure_es_client, es_version
|
from eland.common import ensure_es_client, es_version, is_serverless_es
|
||||||
from eland.utils import deprecated_api
|
from eland.utils import deprecated_api
|
||||||
|
|
||||||
from .common import TYPE_CLASSIFICATION, TYPE_LEARNING_TO_RANK, TYPE_REGRESSION
|
from .common import TYPE_CLASSIFICATION, TYPE_LEARNING_TO_RANK, TYPE_REGRESSION
|
||||||
@ -504,7 +504,9 @@ class MLModel:
|
|||||||
)
|
)
|
||||||
serializer = transformer.transform()
|
serializer = transformer.transform()
|
||||||
model_type = transformer.model_type
|
model_type = transformer.model_type
|
||||||
default_inference_config: Mapping[str, Mapping[str, Any]] = {model_type: {}}
|
|
||||||
|
if inference_config is None:
|
||||||
|
inference_config = {model_type: {}}
|
||||||
|
|
||||||
if es_if_exists is None:
|
if es_if_exists is None:
|
||||||
es_if_exists = "fail"
|
es_if_exists = "fail"
|
||||||
@ -523,18 +525,25 @@ class MLModel:
|
|||||||
elif es_if_exists == "replace":
|
elif es_if_exists == "replace":
|
||||||
ml_model.delete_model()
|
ml_model.delete_model()
|
||||||
|
|
||||||
|
# `input` is only populated when it is still needed: always for non-LTR
# models, and for LTR models only on clusters older than 8.15 that are not
# serverless. Otherwise it stays None.
trained_model_input = None
# Compare the first config key by value. An identity check (`is`) only
# succeeds when the key is the very same str object as the constant, which
# does not hold for equal strings created at runtime (e.g. parsed from JSON).
is_ltr = next(iter(inference_config)) == TYPE_LEARNING_TO_RANK
if not is_ltr or (
    es_version(es_client) < (8, 15) and not is_serverless_es(es_client)
):
    trained_model_input = {"field_names": feature_names}
|
||||||
|
|
||||||
if es_compress_model_definition:
|
if es_compress_model_definition:
|
||||||
ml_model._client.ml.put_trained_model(
|
ml_model._client.ml.put_trained_model(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
input={"field_names": feature_names},
|
inference_config=inference_config,
|
||||||
inference_config=inference_config or default_inference_config,
|
input=trained_model_input,
|
||||||
compressed_definition=serializer.serialize_and_compress_model(),
|
compressed_definition=serializer.serialize_and_compress_model(),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ml_model._client.ml.put_trained_model(
|
ml_model._client.ml.put_trained_model(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
input={"field_names": feature_names},
|
inference_config=inference_config,
|
||||||
inference_config=inference_config or default_inference_config,
|
input=trained_model_input,
|
||||||
definition=serializer.serialize_model(),
|
definition=serializer.serialize_model(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ import os
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from elasticsearch import Elasticsearch
|
from elasticsearch import Elasticsearch
|
||||||
|
|
||||||
from eland.common import es_version
|
from eland.common import es_version, is_serverless_es
|
||||||
|
|
||||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
@ -33,6 +33,7 @@ ELASTICSEARCH_HOST = os.environ.get(
|
|||||||
ES_TEST_CLIENT = Elasticsearch(ELASTICSEARCH_HOST)
|
ES_TEST_CLIENT = Elasticsearch(ELASTICSEARCH_HOST)
|
||||||
|
|
||||||
ES_VERSION = es_version(ES_TEST_CLIENT)
|
ES_VERSION = es_version(ES_TEST_CLIENT)
|
||||||
|
ES_IS_SERVERLESS = is_serverless_es(ES_TEST_CLIENT)
|
||||||
|
|
||||||
FLIGHTS_INDEX_NAME = "flights"
|
FLIGHTS_INDEX_NAME = "flights"
|
||||||
FLIGHTS_MAPPING = {
|
FLIGHTS_MAPPING = {
|
||||||
|
@ -25,6 +25,7 @@ import eland as ed
|
|||||||
from eland.ml import MLModel
|
from eland.ml import MLModel
|
||||||
from eland.ml.ltr import FeatureLogger, LTRModelConfig, QueryFeatureExtractor
|
from eland.ml.ltr import FeatureLogger, LTRModelConfig, QueryFeatureExtractor
|
||||||
from tests import (
|
from tests import (
|
||||||
|
ES_IS_SERVERLESS,
|
||||||
ES_TEST_CLIENT,
|
ES_TEST_CLIENT,
|
||||||
ES_VERSION,
|
ES_VERSION,
|
||||||
FLIGHTS_SMALL_INDEX_NAME,
|
FLIGHTS_SMALL_INDEX_NAME,
|
||||||
@ -379,7 +380,6 @@ class TestMLModel:
|
|||||||
model_id,
|
model_id,
|
||||||
ranker,
|
ranker,
|
||||||
ltr_model_config,
|
ltr_model_config,
|
||||||
es_if_exists="replace",
|
|
||||||
es_compress_model_definition=compress_model_definition,
|
es_compress_model_definition=compress_model_definition,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -387,9 +387,19 @@ class TestMLModel:
|
|||||||
response = ES_TEST_CLIENT.ml.get_trained_models(model_id=model_id)
|
response = ES_TEST_CLIENT.ml.get_trained_models(model_id=model_id)
|
||||||
assert response.meta.status == 200
|
assert response.meta.status == 200
|
||||||
assert response.body["count"] == 1
|
assert response.body["count"] == 1
|
||||||
saved_inference_config = response.body["trained_model_configs"][0][
|
|
||||||
"inference_config"
|
saved_trained_model_config = response.body["trained_model_configs"][0]
|
||||||
]
|
|
||||||
|
assert "input" in saved_trained_model_config
|
||||||
|
assert "field_names" in saved_trained_model_config["input"]
|
||||||
|
|
||||||
|
if not ES_IS_SERVERLESS and ES_VERSION < (8, 15):
|
||||||
|
assert len(saved_trained_model_config["input"]["field_names"]) == 3
|
||||||
|
else:
|
||||||
|
assert not len(saved_trained_model_config["input"]["field_names"])
|
||||||
|
|
||||||
|
saved_inference_config = saved_trained_model_config["inference_config"]
|
||||||
|
|
||||||
assert "learning_to_rank" in saved_inference_config
|
assert "learning_to_rank" in saved_inference_config
|
||||||
assert "feature_extractors" in saved_inference_config["learning_to_rank"]
|
assert "feature_extractors" in saved_inference_config["learning_to_rank"]
|
||||||
saved_feature_extractors = saved_inference_config["learning_to_rank"][
|
saved_feature_extractors = saved_inference_config["learning_to_rank"][
|
||||||
@ -438,6 +448,9 @@ class TestMLModel:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
|
ES_TEST_CLIENT.cluster.health(
|
||||||
|
index=".ml-*", wait_for_active_shards="all"
|
||||||
|
) # Added to prevent flakiness in the test
|
||||||
es_model.delete_model()
|
es_model.delete_model()
|
||||||
|
|
||||||
@requires_sklearn
|
@requires_sklearn
|
||||||
@ -466,6 +479,7 @@ class TestMLModel:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
|
|
||||||
es_model.delete_model()
|
es_model.delete_model()
|
||||||
|
|
||||||
@requires_sklearn
|
@requires_sklearn
|
||||||
|
Loading…
x
Reference in New Issue
Block a user