diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index e8f61ac..6d9f46b 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -22,6 +22,7 @@ libraries such as sentence-transformers.
 
 import json
 import os.path
+import re
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
@@ -740,8 +741,7 @@ class TransformerModel:
             )
 
     def elasticsearch_model_id(self) -> str:
-        # Elasticsearch model IDs need to be a specific format: no special chars, all lowercase, max 64 chars
-        return self._model_id.replace("/", "__").lower()[:64]
+        return elasticsearch_model_id(self._model_id)
 
     def save(self, path: str) -> Tuple[str, NlpTrainedModelConfig, str]:
         # save traced model
@@ -753,3 +753,25 @@ class TransformerModel:
             json.dump(self._vocab, outfile)
 
         return model_path, self._config, vocab_path
+
+
+def elasticsearch_model_id(model_id: str) -> str:
+    """
+    Convert a model name or file path into an Elasticsearch model Id.
+
+    Elasticsearch model IDs need to be a specific format:
+    no special chars, all lowercase, max 64 chars. Whitespace and
+    path separators are replaced with "__". If the Id is longer
+    than 64 characters the last 64 are kept; in the case where
+    the Id is a long file path this captures the model name.
+
+    Ids starting with an underscore are not valid Elasticsearch Ids.
+    This might be the case if model_id is a file path, or if the
+    truncation cuts a "__" separator in half, so all leading
+    underscores are removed.
+    """
+
+    es_id = re.sub(r"[\s\\/]", "__", model_id).lower()[-64:]
+    # lstrip rather than str.removeprefix: it does not need Python 3.9+ and
+    # also drops the lone "_" left when truncation splits a separator.
+    return es_id.lstrip("_")
diff --git a/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py b/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py
index 4a81be9..c394ae7 100644
--- a/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py
+++ b/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py
@@ -34,7 +34,8 @@ except ImportError:
 try:
     import torch  # noqa: F401
     from torch import Tensor, nn  # noqa: F401
     from transformers import PretrainedConfig  # noqa: F401
 
+    from eland.ml.pytorch.transformers import elasticsearch_model_id  # noqa: F401
     from eland.ml.pytorch import (  # noqa: F401
         NlpBertTokenizationConfig,
@@ -338,3 +339,25 @@ class TestPytorchModelUpload:
             )
         )
         assert task_type_from_model_config(model_config=config) == expected_task
+
+    def test_elasticsearch_model_id(self):
+        model_id_from_path = elasticsearch_model_id(r"/foo/bar")
+        assert model_id_from_path == "foo__bar"
+
+        model_id_from_path = elasticsearch_model_id(r"\BAZ\with space")
+        assert model_id_from_path == "baz__with__space"
+
+        model_id_from_path = elasticsearch_model_id("/foo/with tab\tback\\slash")
+        assert model_id_from_path == "foo__with__tab__back__slash"
+
+        long_id_left_truncated = elasticsearch_model_id(
+            "/foo/bar/64charactersoftext64charactersoftext64charactersoftext64charofte"
+        )
+        assert (
+            long_id_left_truncated
+            == "64charactersoftext64charactersoftext64charactersoftext64charofte"
+        )
+
+        # Truncation that splits a "__" separator must not leave a leading "_"
+        split_separator = elasticsearch_model_id("x/" + "a" * 63)
+        assert split_separator == "a" * 63