mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Tolerate different model output formats when measuring embedding size (#535)
Only add the embedding_size config option if the target Elasticsearch cluster version supports it
This commit is contained in:
parent
7ca8376f68
commit
32ab988eb6
@ -251,11 +251,15 @@ hardware.
|
||||
```python
|
||||
>>> import elasticsearch
|
||||
>>> from pathlib import Path
|
||||
>>> from eland.common import es_version
|
||||
>>> from eland.ml.pytorch import PyTorchModel
|
||||
>>> from eland.ml.pytorch.transformers import TransformerModel
|
||||
|
||||
>>> es = elasticsearch.Elasticsearch("http://elastic:mlqa_admin@localhost:9200")
|
||||
>>> es_cluster_version = es_version(es)
|
||||
|
||||
# Load a Hugging Face transformers model directly from the model hub
|
||||
>>> tm = TransformerModel("elastic/distilbert-base-cased-finetuned-conll03-english", "ner")
|
||||
>>> tm = TransformerModel(model_id="elastic/distilbert-base-cased-finetuned-conll03-english", task_type="ner", es_version=es_cluster_version)
|
||||
Downloading: 100%|██████████| 257/257 [00:00<00:00, 108kB/s]
|
||||
Downloading: 100%|██████████| 954/954 [00:00<00:00, 372kB/s]
|
||||
Downloading: 100%|██████████| 208k/208k [00:00<00:00, 668kB/s]
|
||||
@ -268,7 +272,6 @@ Downloading: 100%|██████████| 249M/249M [00:23<00:00, 11.2MB
|
||||
>>> model_path, config, vocab_path = tm.save(tmp_path)
|
||||
|
||||
# Import model into Elasticsearch
|
||||
>>> es = elasticsearch.Elasticsearch("http://elastic:mlqa_admin@localhost:9200", timeout=300) # 5 minute timeout
|
||||
>>> ptm = PyTorchModel(es, tm.elasticsearch_model_id())
|
||||
>>> ptm.import_model(model_path=model_path, config_path=None, vocab_path=vocab_path, config=config)
|
||||
100%|██████████| 63/63 [00:12<00:00, 5.02it/s]
|
||||
|
@ -35,6 +35,8 @@ import torch
|
||||
from elastic_transport.client_utils import DEFAULT
|
||||
from elasticsearch import AuthenticationException, Elasticsearch
|
||||
|
||||
from eland.common import parse_es_version
|
||||
|
||||
MODEL_HUB_URL = "https://huggingface.co"
|
||||
|
||||
|
||||
@ -156,12 +158,14 @@ def get_es_client(cli_args):
|
||||
logger.error(e)
|
||||
exit(1)
|
||||
|
||||
def check_cluster_version(es_client):
|
||||
|
||||
def check_cluster_version(es_client):
|
||||
es_info = es_client.info()
|
||||
logger.info(f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})")
|
||||
|
||||
major_version = int(es_info['version']['number'].split(".")[0])
|
||||
minor_version = int(es_info['version']['number'].split(".")[1])
|
||||
sem_ver = parse_es_version(es_info['version']['number'])
|
||||
major_version = sem_ver[0]
|
||||
minor_version = sem_ver[1]
|
||||
|
||||
# NLP models added in 8
|
||||
if major_version < 8:
|
||||
@ -174,6 +178,7 @@ def check_cluster_version(es_client):
|
||||
logger.error(f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.7. Please upgrade Elasticsearch to at least version 8.7")
|
||||
exit(1)
|
||||
|
||||
return sem_ver
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Configure logging
|
||||
@ -204,14 +209,14 @@ if __name__ == "__main__":
|
||||
# Connect to ES
|
||||
logger.info("Establishing connection to Elasticsearch")
|
||||
es = get_es_client(args)
|
||||
check_cluster_version(es)
|
||||
cluster_version = check_cluster_version(es)
|
||||
|
||||
# Trace and save model, then upload it from temp file
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'")
|
||||
|
||||
try:
|
||||
tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
|
||||
tm = TransformerModel(model_id=args.hub_model_id, task_type=args.task_type, es_version=cluster_version, quantize=args.quantize)
|
||||
model_path, config, vocab_path = tm.save(tmp_dir)
|
||||
except TaskTypeError as err:
|
||||
logger.error(f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}")
|
||||
|
@ -322,15 +322,7 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
|
||||
eland_es_version: Tuple[int, int, int]
|
||||
if not hasattr(es_client, "_eland_es_version"):
|
||||
version_info = es_client.info()["version"]["number"]
|
||||
match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_info)
|
||||
if match is None:
|
||||
raise ValueError(
|
||||
f"Unable to determine Elasticsearch version. "
|
||||
f"Received: {version_info}"
|
||||
)
|
||||
eland_es_version = cast(
|
||||
Tuple[int, int, int], tuple(int(x) for x in match.groups())
|
||||
)
|
||||
eland_es_version = parse_es_version(version_info)
|
||||
es_client._eland_es_version = eland_es_version # type: ignore
|
||||
|
||||
# Raise a warning if the major version of the library doesn't match the
|
||||
@ -347,3 +339,16 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
|
||||
else:
|
||||
eland_es_version = es_client._eland_es_version
|
||||
return eland_es_version
|
||||
|
||||
|
||||
def parse_es_version(version: str) -> Tuple[int, int, int]:
|
||||
"""
|
||||
Parse the semantic version from a string e.g. '8.8.0'
|
||||
Extensions such as '-SNAPSHOT' are ignored
|
||||
"""
|
||||
match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version)
|
||||
if match is None:
|
||||
raise ValueError(
|
||||
f"Unable to determine Elasticsearch version. " f"Received: {version}"
|
||||
)
|
||||
return cast(Tuple[int, int, int], tuple(int(x) for x in match.groups()))
|
||||
|
@ -573,7 +573,37 @@ class _TraceableTextSimilarityModel(_TransformerTraceableModel):
|
||||
|
||||
|
||||
class TransformerModel:
|
||||
def __init__(self, model_id: str, task_type: str, quantize: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
task_type: str,
|
||||
*,
|
||||
es_version: Optional[Tuple[int, int, int]],
|
||||
quantize: bool = False,
|
||||
):
|
||||
"""
|
||||
Loads a model from the Hugging Face repository or local file and creates
|
||||
the configuration for upload to Elasticsearch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_id: str
|
||||
A Hugging Face model Id or a file path to the directory containing
|
||||
the model files.
|
||||
|
||||
task_type: str
|
||||
One of the supported task types.
|
||||
|
||||
es_version: Optional[Tuple[int, int, int]]
|
||||
The Elasticsearch cluster version.
|
||||
Certain features are created only if the target cluster is
|
||||
a high enough version to support them. If not set only
|
||||
universally supported features are added.
|
||||
|
||||
quantize: bool, default False
|
||||
Quantize the model.
|
||||
"""
|
||||
|
||||
self._model_id = model_id
|
||||
self._task_type = task_type.replace("-", "_")
|
||||
|
||||
@ -595,7 +625,7 @@ class TransformerModel:
|
||||
if quantize:
|
||||
self._traceable_model.quantize()
|
||||
self._vocab = self._load_vocab()
|
||||
self._config = self._create_config()
|
||||
self._config = self._create_config(es_version)
|
||||
|
||||
def _load_vocab(self) -> Dict[str, List[str]]:
|
||||
vocab_items = self._tokenizer.get_vocab().items()
|
||||
@ -636,7 +666,9 @@ class TransformerModel:
|
||||
).get(self._model_id),
|
||||
)
|
||||
|
||||
def _create_config(self) -> NlpTrainedModelConfig:
|
||||
def _create_config(
|
||||
self, es_version: Optional[Tuple[int, int, int]]
|
||||
) -> NlpTrainedModelConfig:
|
||||
tokenization_config = self._create_tokenization_config()
|
||||
|
||||
# Set squad well known defaults
|
||||
@ -651,12 +683,24 @@ class TransformerModel:
|
||||
classification_labels=self._traceable_model.classification_labels(),
|
||||
)
|
||||
elif self._task_type == "text_embedding":
|
||||
sample_embedding, _ = self._traceable_model.sample_output()
|
||||
embedding_size = sample_embedding.size(-1)
|
||||
inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
|
||||
tokenization=tokenization_config,
|
||||
embedding_size=embedding_size,
|
||||
)
|
||||
# The embedding_size paramater was added in Elasticsearch 8.8
|
||||
# If the version is not known use the basic config
|
||||
if es_version is None or (es_version[0] <= 8 and es_version[1] < 8):
|
||||
inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
|
||||
tokenization=tokenization_config
|
||||
)
|
||||
else:
|
||||
sample_embedding = self._traceable_model.sample_output()
|
||||
if type(sample_embedding) is tuple:
|
||||
text_embedding, _ = sample_embedding
|
||||
else:
|
||||
text_embedding = sample_embedding
|
||||
|
||||
embedding_size = text_embedding.size(-1)
|
||||
inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
|
||||
tokenization=tokenization_config,
|
||||
embedding_size=embedding_size,
|
||||
)
|
||||
else:
|
||||
inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
|
||||
tokenization=tokenization_config
|
||||
|
@ -58,6 +58,14 @@ TEXT_PREDICTION_MODELS = [
|
||||
)
|
||||
]
|
||||
|
||||
TEXT_EMBEDDING_MODELS = [
|
||||
(
|
||||
"sentence-transformers/all-MiniLM-L6-v2",
|
||||
"text_embedding",
|
||||
"Paris is the capital of France.",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def setup_and_tear_down():
|
||||
@ -76,7 +84,9 @@ def setup_and_tear_down():
|
||||
|
||||
def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
|
||||
print("Loading HuggingFace transformer tokenizer and model")
|
||||
tm = TransformerModel(model_id, task, quantize)
|
||||
tm = TransformerModel(
|
||||
model_id=model_id, task_type=task, es_version=ES_VERSION, quantize=quantize
|
||||
)
|
||||
model_path, config, vocab_path = tm.save(tmp_dir)
|
||||
ptm = PyTorchModel(ES_TEST_CLIENT, tm.elasticsearch_model_id())
|
||||
try:
|
||||
@ -94,8 +104,25 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
|
||||
|
||||
class TestPytorchModel:
|
||||
@pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
|
||||
def test_text_classification(self, model_id, task, text_input, value):
|
||||
def test_text_prediction(self, model_id, task, text_input, value):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
ptm = download_model_and_start_deployment(tmp_dir, True, model_id, task)
|
||||
result = ptm.infer(docs=[{"text_field": text_input}])
|
||||
assert result["predicted_value"] == value
|
||||
|
||||
@pytest.mark.parametrize("model_id,task,text_input", TEXT_EMBEDDING_MODELS)
|
||||
def test_text_embedding(self, model_id, task, text_input):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
ptm = download_model_and_start_deployment(tmp_dir, True, model_id, task)
|
||||
ptm.infer(docs=[{"text_field": text_input}])
|
||||
|
||||
if ES_VERSION >= (8, 8, 0):
|
||||
configs = ES_TEST_CLIENT.ml.get_trained_models(model_id=ptm.model_id)
|
||||
assert (
|
||||
int(
|
||||
configs["trained_model_configs"][0]["inference_config"][
|
||||
"text_embedding"
|
||||
]["embedding_size"]
|
||||
)
|
||||
> 0
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user