diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index e6cdb30..7f47bf3 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -32,6 +32,6 @@ steps: - '3.9' - '3.8' stack: - - '8.13.0-SNAPSHOT' - - '8.12.2' + - '8.15.0-SNAPSHOT' + - '8.14.1' command: ./.buildkite/run-tests diff --git a/eland/cli/eland_import_hub_model.py b/eland/cli/eland_import_hub_model.py index 5f9a3ac..7980a3c 100755 --- a/eland/cli/eland_import_hub_model.py +++ b/eland/cli/eland_import_hub_model.py @@ -33,7 +33,7 @@ from elastic_transport.client_utils import DEFAULT from elasticsearch import AuthenticationException, Elasticsearch from eland._version import __version__ -from eland.common import parse_es_version +from eland.common import is_serverless_es, parse_es_version MODEL_HUB_URL = "https://huggingface.co" @@ -197,10 +197,7 @@ def get_es_client(cli_args, logger): def check_cluster_version(es_client, logger): es_info = es_client.info() - if ( - "build_flavor" in es_info["version"] - and es_info["version"]["build_flavor"] == "serverless" - ): + if is_serverless_es(es_client): logger.info(f"Connected to serverless cluster '{es_info['cluster_name']}'") # Serverless is compatible # Return the latest known semantic version, i.e. this version diff --git a/eland/common.py b/eland/common.py index d582fba..219ec8b 100644 --- a/eland/common.py +++ b/eland/common.py @@ -344,6 +344,17 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]: return eland_es_version +def is_serverless_es(es_client: Elasticsearch) -> bool: + """ + Returns true if the client is connected to a serverless instance of Elasticsearch. + """ + es_info = es_client.info() + return ( + "build_flavor" in es_info["version"] + and es_info["version"]["build_flavor"] == "serverless" + ) + + def parse_es_version(version: str) -> Tuple[int, int, int]: """ Parse the semantic version from a string e.g. '8.8.0' diff --git a/eland/ml/ml_model.py b/eland/ml/ml_model.py index d13c2a6..d7f7c53 100644 --- a/eland/ml/ml_model.py +++ b/eland/ml/ml_model.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Uni import elasticsearch import numpy as np -from eland.common import ensure_es_client, es_version +from eland.common import ensure_es_client, es_version, is_serverless_es from eland.utils import deprecated_api from .common import TYPE_CLASSIFICATION, TYPE_LEARNING_TO_RANK, TYPE_REGRESSION @@ -504,7 +504,9 @@ class MLModel: ) serializer = transformer.transform() model_type = transformer.model_type - default_inference_config: Mapping[str, Mapping[str, Any]] = {model_type: {}} + + if inference_config is None: + inference_config = {model_type: {}} if es_if_exists is None: es_if_exists = "fail" @@ -523,18 +525,25 @@ class MLModel: elif es_if_exists == "replace": ml_model.delete_model() + trained_model_input = None + is_ltr = next(iter(inference_config)) is TYPE_LEARNING_TO_RANK + if not is_ltr or ( + es_version(es_client) < (8, 15) and not is_serverless_es(es_client) + ): + trained_model_input = {"field_names": feature_names} + if es_compress_model_definition: ml_model._client.ml.put_trained_model( model_id=model_id, - input={"field_names": feature_names}, - inference_config=inference_config or default_inference_config, + inference_config=inference_config, + input=trained_model_input, compressed_definition=serializer.serialize_and_compress_model(), ) else: ml_model._client.ml.put_trained_model( model_id=model_id, - input={"field_names": feature_names}, - inference_config=inference_config or default_inference_config, + inference_config=inference_config, + input=trained_model_input, definition=serializer.serialize_model(), ) diff --git a/tests/__init__.py b/tests/__init__.py index cea05c4..0efa5fa 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -20,7 +20,7 @@ import os import pandas as pd from elasticsearch import Elasticsearch -from eland.common import es_version +from eland.common import es_version, is_serverless_es ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -33,6 +33,7 @@ ELASTICSEARCH_HOST = os.environ.get( ES_TEST_CLIENT = Elasticsearch(ELASTICSEARCH_HOST) ES_VERSION = es_version(ES_TEST_CLIENT) +ES_IS_SERVERLESS = is_serverless_es(ES_TEST_CLIENT) FLIGHTS_INDEX_NAME = "flights" FLIGHTS_MAPPING = { diff --git a/tests/ml/test_ml_model_pytest.py b/tests/ml/test_ml_model_pytest.py index 228ba79..1094503 100644 --- a/tests/ml/test_ml_model_pytest.py +++ b/tests/ml/test_ml_model_pytest.py @@ -25,6 +25,7 @@ import eland as ed from eland.ml import MLModel from eland.ml.ltr import FeatureLogger, LTRModelConfig, QueryFeatureExtractor from tests import ( + ES_IS_SERVERLESS, ES_TEST_CLIENT, ES_VERSION, FLIGHTS_SMALL_INDEX_NAME, @@ -379,7 +380,6 @@ class TestMLModel: model_id, ranker, ltr_model_config, - es_if_exists="replace", es_compress_model_definition=compress_model_definition, ) @@ -387,9 +387,19 @@ class TestMLModel: response = ES_TEST_CLIENT.ml.get_trained_models(model_id=model_id) assert response.meta.status == 200 assert response.body["count"] == 1 - saved_inference_config = response.body["trained_model_configs"][0][ - "inference_config" - ] + + saved_trained_model_config = response.body["trained_model_configs"][0] + + assert "input" in saved_trained_model_config + assert "field_names" in saved_trained_model_config["input"] + + if not ES_IS_SERVERLESS and ES_VERSION < (8, 15): + assert len(saved_trained_model_config["input"]["field_names"]) == 3 + else: + assert not len(saved_trained_model_config["input"]["field_names"]) + + saved_inference_config = saved_trained_model_config["inference_config"] + assert "learning_to_rank" in saved_inference_config assert "feature_extractors" in saved_inference_config["learning_to_rank"] saved_feature_extractors = saved_inference_config["learning_to_rank"][ @@ -438,6 +448,9 @@ class TestMLModel: pass # Clean up + ES_TEST_CLIENT.cluster.health( + index=".ml-*", wait_for_active_shards="all" + ) # Added to prevent flakiness in the test es_model.delete_model() @requires_sklearn @@ -466,6 +479,7 @@ class TestMLModel: ) # Clean up + es_model.delete_model() @requires_sklearn