mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Remove input fields from exported LTR models (#708)
This commit is contained in:
parent
f18aa35e8e
commit
bee6d0e1f7
@ -32,6 +32,6 @@ steps:
|
||||
- '3.9'
|
||||
- '3.8'
|
||||
stack:
|
||||
- '8.13.0-SNAPSHOT'
|
||||
- '8.12.2'
|
||||
- '8.15.0-SNAPSHOT'
|
||||
- '8.14.1'
|
||||
command: ./.buildkite/run-tests
|
||||
|
@ -33,7 +33,7 @@ from elastic_transport.client_utils import DEFAULT
|
||||
from elasticsearch import AuthenticationException, Elasticsearch
|
||||
|
||||
from eland._version import __version__
|
||||
from eland.common import parse_es_version
|
||||
from eland.common import is_serverless_es, parse_es_version
|
||||
|
||||
MODEL_HUB_URL = "https://huggingface.co"
|
||||
|
||||
@ -197,10 +197,7 @@ def get_es_client(cli_args, logger):
|
||||
def check_cluster_version(es_client, logger):
|
||||
es_info = es_client.info()
|
||||
|
||||
if (
|
||||
"build_flavor" in es_info["version"]
|
||||
and es_info["version"]["build_flavor"] == "serverless"
|
||||
):
|
||||
if is_serverless_es(es_client):
|
||||
logger.info(f"Connected to serverless cluster '{es_info['cluster_name']}'")
|
||||
# Serverless is compatible
|
||||
# Return the latest known semantic version, i.e. this version
|
||||
|
@ -344,6 +344,17 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
|
||||
return eland_es_version
|
||||
|
||||
|
||||
def is_serverless_es(es_client: Elasticsearch) -> bool:
    """
    Returns true if the client is connected to a serverless instance of Elasticsearch.

    Calls ``es_client.info()`` and inspects the ``build_flavor`` entry of the
    ``version`` section of the response. A response without that entry is
    reported as not serverless.
    """
    es_info = es_client.info()
    # A single lookup: `.get` yields None when "build_flavor" is absent,
    # and None never compares equal to "serverless".
    return es_info["version"].get("build_flavor") == "serverless"
|
||||
|
||||
|
||||
def parse_es_version(version: str) -> Tuple[int, int, int]:
|
||||
"""
|
||||
Parse the semantic version from a string e.g. '8.8.0'
|
||||
|
@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Uni
|
||||
import elasticsearch
|
||||
import numpy as np
|
||||
|
||||
from eland.common import ensure_es_client, es_version
|
||||
from eland.common import ensure_es_client, es_version, is_serverless_es
|
||||
from eland.utils import deprecated_api
|
||||
|
||||
from .common import TYPE_CLASSIFICATION, TYPE_LEARNING_TO_RANK, TYPE_REGRESSION
|
||||
@ -504,7 +504,9 @@ class MLModel:
|
||||
)
|
||||
serializer = transformer.transform()
|
||||
model_type = transformer.model_type
|
||||
default_inference_config: Mapping[str, Mapping[str, Any]] = {model_type: {}}
|
||||
|
||||
if inference_config is None:
|
||||
inference_config = {model_type: {}}
|
||||
|
||||
if es_if_exists is None:
|
||||
es_if_exists = "fail"
|
||||
@ -523,18 +525,25 @@ class MLModel:
|
||||
elif es_if_exists == "replace":
|
||||
ml_model.delete_model()
|
||||
|
||||
trained_model_input = None
|
||||
is_ltr = next(iter(inference_config)) is TYPE_LEARNING_TO_RANK
|
||||
if not is_ltr or (
|
||||
es_version(es_client) < (8, 15) and not is_serverless_es(es_client)
|
||||
):
|
||||
trained_model_input = {"field_names": feature_names}
|
||||
|
||||
if es_compress_model_definition:
|
||||
ml_model._client.ml.put_trained_model(
|
||||
model_id=model_id,
|
||||
input={"field_names": feature_names},
|
||||
inference_config=inference_config or default_inference_config,
|
||||
inference_config=inference_config,
|
||||
input=trained_model_input,
|
||||
compressed_definition=serializer.serialize_and_compress_model(),
|
||||
)
|
||||
else:
|
||||
ml_model._client.ml.put_trained_model(
|
||||
model_id=model_id,
|
||||
input={"field_names": feature_names},
|
||||
inference_config=inference_config or default_inference_config,
|
||||
inference_config=inference_config,
|
||||
input=trained_model_input,
|
||||
definition=serializer.serialize_model(),
|
||||
)
|
||||
|
||||
|
@ -20,7 +20,7 @@ import os
|
||||
import pandas as pd
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from eland.common import es_version
|
||||
from eland.common import es_version, is_serverless_es
|
||||
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
@ -33,6 +33,7 @@ ELASTICSEARCH_HOST = os.environ.get(
|
||||
ES_TEST_CLIENT = Elasticsearch(ELASTICSEARCH_HOST)
|
||||
|
||||
ES_VERSION = es_version(ES_TEST_CLIENT)
|
||||
ES_IS_SERVERLESS = is_serverless_es(ES_TEST_CLIENT)
|
||||
|
||||
FLIGHTS_INDEX_NAME = "flights"
|
||||
FLIGHTS_MAPPING = {
|
||||
|
@ -25,6 +25,7 @@ import eland as ed
|
||||
from eland.ml import MLModel
|
||||
from eland.ml.ltr import FeatureLogger, LTRModelConfig, QueryFeatureExtractor
|
||||
from tests import (
|
||||
ES_IS_SERVERLESS,
|
||||
ES_TEST_CLIENT,
|
||||
ES_VERSION,
|
||||
FLIGHTS_SMALL_INDEX_NAME,
|
||||
@ -379,7 +380,6 @@ class TestMLModel:
|
||||
model_id,
|
||||
ranker,
|
||||
ltr_model_config,
|
||||
es_if_exists="replace",
|
||||
es_compress_model_definition=compress_model_definition,
|
||||
)
|
||||
|
||||
@ -387,9 +387,19 @@ class TestMLModel:
|
||||
response = ES_TEST_CLIENT.ml.get_trained_models(model_id=model_id)
|
||||
assert response.meta.status == 200
|
||||
assert response.body["count"] == 1
|
||||
saved_inference_config = response.body["trained_model_configs"][0][
|
||||
"inference_config"
|
||||
]
|
||||
|
||||
saved_trained_model_config = response.body["trained_model_configs"][0]
|
||||
|
||||
assert "input" in saved_trained_model_config
|
||||
assert "field_names" in saved_trained_model_config["input"]
|
||||
|
||||
if not ES_IS_SERVERLESS and ES_VERSION < (8, 15):
|
||||
assert len(saved_trained_model_config["input"]["field_names"]) == 3
|
||||
else:
|
||||
assert not len(saved_trained_model_config["input"]["field_names"])
|
||||
|
||||
saved_inference_config = saved_trained_model_config["inference_config"]
|
||||
|
||||
assert "learning_to_rank" in saved_inference_config
|
||||
assert "feature_extractors" in saved_inference_config["learning_to_rank"]
|
||||
saved_feature_extractors = saved_inference_config["learning_to_rank"][
|
||||
@ -438,6 +448,9 @@ class TestMLModel:
|
||||
pass
|
||||
|
||||
# Clean up
|
||||
ES_TEST_CLIENT.cluster.health(
|
||||
index=".ml-*", wait_for_active_shards="all"
|
||||
) # Added to prevent flakiness in the test
|
||||
es_model.delete_model()
|
||||
|
||||
@requires_sklearn
|
||||
@ -466,6 +479,7 @@ class TestMLModel:
|
||||
)
|
||||
|
||||
# Clean up
|
||||
|
||||
es_model.delete_model()
|
||||
|
||||
@requires_sklearn
|
||||
|
Loading…
x
Reference in New Issue
Block a user