Remove input fields from exported LTR models (#708)

This commit is contained in:
Aurélien FOUCRET 2024-07-05 14:31:22 +02:00 committed by GitHub
parent f18aa35e8e
commit bee6d0e1f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 50 additions and 18 deletions

View File

@ -32,6 +32,6 @@ steps:
- '3.9'
- '3.8'
stack:
- '8.13.0-SNAPSHOT'
- '8.12.2'
- '8.15.0-SNAPSHOT'
- '8.14.1'
command: ./.buildkite/run-tests

View File

@ -33,7 +33,7 @@ from elastic_transport.client_utils import DEFAULT
from elasticsearch import AuthenticationException, Elasticsearch
from eland._version import __version__
from eland.common import parse_es_version
from eland.common import is_serverless_es, parse_es_version
MODEL_HUB_URL = "https://huggingface.co"
@ -197,10 +197,7 @@ def get_es_client(cli_args, logger):
def check_cluster_version(es_client, logger):
es_info = es_client.info()
if (
"build_flavor" in es_info["version"]
and es_info["version"]["build_flavor"] == "serverless"
):
if is_serverless_es(es_client):
logger.info(f"Connected to serverless cluster '{es_info['cluster_name']}'")
# Serverless is compatible
# Return the latest known semantic version, i.e. this version

View File

@ -344,6 +344,17 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
return eland_es_version
def is_serverless_es(es_client: Elasticsearch) -> bool:
    """
    Tell whether the given client talks to a serverless Elasticsearch.

    The decision is based on the ``version`` section of the response
    returned by ``es_client.info()``: the cluster is considered serverless
    only when that section has a ``build_flavor`` entry equal to
    ``"serverless"``.
    """
    version_info = es_client.info()["version"]
    # No flavor reported at all: cannot be serverless.
    if "build_flavor" not in version_info:
        return False
    return version_info["build_flavor"] == "serverless"
def parse_es_version(version: str) -> Tuple[int, int, int]:
"""
Parse the semantic version from a string e.g. '8.8.0'

View File

@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Uni
import elasticsearch
import numpy as np
from eland.common import ensure_es_client, es_version
from eland.common import ensure_es_client, es_version, is_serverless_es
from eland.utils import deprecated_api
from .common import TYPE_CLASSIFICATION, TYPE_LEARNING_TO_RANK, TYPE_REGRESSION
@ -504,7 +504,9 @@ class MLModel:
)
serializer = transformer.transform()
model_type = transformer.model_type
default_inference_config: Mapping[str, Mapping[str, Any]] = {model_type: {}}
if inference_config is None:
inference_config = {model_type: {}}
if es_if_exists is None:
es_if_exists = "fail"
@ -523,18 +525,25 @@ class MLModel:
elif es_if_exists == "replace":
ml_model.delete_model()
# By default no input section is sent along with the trained model.
trained_model_input = None
# The model type is the first key of the inference config. Compare it by
# value: `is` checks object identity and only matches when both strings
# happen to be the very same object, which is not guaranteed for equal
# strings (e.g. a key coming from parsed JSON).
is_ltr = next(iter(inference_config)) == TYPE_LEARNING_TO_RANK
# Field names are still sent for every non-LTR model, and for LTR models
# when the cluster is older than 8.15 and is not serverless.
if not is_ltr or (
    es_version(es_client) < (8, 15) and not is_serverless_es(es_client)
):
    trained_model_input = {"field_names": feature_names}
if es_compress_model_definition:
ml_model._client.ml.put_trained_model(
model_id=model_id,
input={"field_names": feature_names},
inference_config=inference_config or default_inference_config,
inference_config=inference_config,
input=trained_model_input,
compressed_definition=serializer.serialize_and_compress_model(),
)
else:
ml_model._client.ml.put_trained_model(
model_id=model_id,
input={"field_names": feature_names},
inference_config=inference_config or default_inference_config,
inference_config=inference_config,
input=trained_model_input,
definition=serializer.serialize_model(),
)

View File

@ -20,7 +20,7 @@ import os
import pandas as pd
from elasticsearch import Elasticsearch
from eland.common import es_version
from eland.common import es_version, is_serverless_es
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
@ -33,6 +33,7 @@ ELASTICSEARCH_HOST = os.environ.get(
ES_TEST_CLIENT = Elasticsearch(ELASTICSEARCH_HOST)
ES_VERSION = es_version(ES_TEST_CLIENT)
ES_IS_SERVERLESS = is_serverless_es(ES_TEST_CLIENT)
FLIGHTS_INDEX_NAME = "flights"
FLIGHTS_MAPPING = {

View File

@ -25,6 +25,7 @@ import eland as ed
from eland.ml import MLModel
from eland.ml.ltr import FeatureLogger, LTRModelConfig, QueryFeatureExtractor
from tests import (
ES_IS_SERVERLESS,
ES_TEST_CLIENT,
ES_VERSION,
FLIGHTS_SMALL_INDEX_NAME,
@ -379,7 +380,6 @@ class TestMLModel:
model_id,
ranker,
ltr_model_config,
es_if_exists="replace",
es_compress_model_definition=compress_model_definition,
)
@ -387,9 +387,19 @@ class TestMLModel:
response = ES_TEST_CLIENT.ml.get_trained_models(model_id=model_id)
assert response.meta.status == 200
assert response.body["count"] == 1
saved_inference_config = response.body["trained_model_configs"][0][
"inference_config"
]
saved_trained_model_config = response.body["trained_model_configs"][0]
assert "input" in saved_trained_model_config
assert "field_names" in saved_trained_model_config["input"]
if not ES_IS_SERVERLESS and ES_VERSION < (8, 15):
assert len(saved_trained_model_config["input"]["field_names"]) == 3
else:
assert not len(saved_trained_model_config["input"]["field_names"])
saved_inference_config = saved_trained_model_config["inference_config"]
assert "learning_to_rank" in saved_inference_config
assert "feature_extractors" in saved_inference_config["learning_to_rank"]
saved_feature_extractors = saved_inference_config["learning_to_rank"][
@ -438,6 +448,9 @@ class TestMLModel:
pass
# Clean up
ES_TEST_CLIENT.cluster.health(
index=".ml-*", wait_for_active_shards="all"
) # Added to prevent flakiness in the test
es_model.delete_model()
@requires_sklearn
@ -466,6 +479,7 @@ class TestMLModel:
)
# Clean up
es_model.delete_model()
@requires_sklearn