Generate valid NLP model id from file path (#541)

The eland_import_hub_model script supports uploading a local file where
the --hub-model-id argument is a file path. If the --es-model-id option is
not used the model Id is generated from the hub model id and when that 
is a file path the path must be converted to a valid elasticsearch model id.
This commit is contained in:
David Kyle 2023-05-22 15:37:36 +01:00 committed by GitHub
parent 7820a31256
commit 1e6f48f8f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 3 deletions

View File

@ -22,6 +22,7 @@ libraries such as sentence-transformers.
import json import json
import os.path import os.path
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Union from typing import Any, Dict, List, Optional, Set, Tuple, Union
@ -740,8 +741,7 @@ class TransformerModel:
) )
def elasticsearch_model_id(self) -> str: def elasticsearch_model_id(self) -> str:
# Elasticsearch model IDs need to be a specific format: no special chars, all lowercase, max 64 chars return elasticsearch_model_id(self._model_id)
return self._model_id.replace("/", "__").lower()[:64]
def save(self, path: str) -> Tuple[str, NlpTrainedModelConfig, str]: def save(self, path: str) -> Tuple[str, NlpTrainedModelConfig, str]:
# save traced model # save traced model
@ -753,3 +753,21 @@ class TransformerModel:
json.dump(self._vocab, outfile) json.dump(self._vocab, outfile)
return model_path, self._config, vocab_path return model_path, self._config, vocab_path
def elasticsearch_model_id(model_id: str) -> str:
"""
Elasticsearch model IDs need to be a specific format:
no special chars, all lowercase, max 64 chars. If the
Id is longer than 64 charaters take the last 64- in the
case where the id is long file path this captures the
model name.
Ids starting with __ are not valid elasticsearch Ids,
# this might be the case if model_id is a file path
"""
id = re.sub(r"[\s\\/]", "__", model_id).lower()[-64:]
if id.startswith("__"):
id = id.removeprefix("__")
return id

View File

@ -34,7 +34,7 @@ except ImportError:
try: try:
import torch # noqa: F401 import torch # noqa: F401
from torch import Tensor, nn # noqa: F401 from torch import Tensor, nn # noqa: F401
from transformers import PretrainedConfig # noqa: F401 from transformers import PretrainedConfig, elasticsearch_model_id # noqa: F401
from eland.ml.pytorch import ( # noqa: F401 from eland.ml.pytorch import ( # noqa: F401
NlpBertTokenizationConfig, NlpBertTokenizationConfig,
@ -338,3 +338,21 @@ class TestPytorchModelUpload:
) )
) )
assert task_type_from_model_config(model_config=config) == expected_task assert task_type_from_model_config(model_config=config) == expected_task
def test_elasticsearch_model_id(self):
model_id_from_path = elasticsearch_model_id(r"/foo/bar")
assert model_id_from_path == "foo_bar"
model_id_from_path = elasticsearch_model_id(r"\BAZ\with space")
assert model_id_from_path == "baz__with__space"
model_id_from_path = elasticsearch_model_id(r"/foo/with tab\tback\slash")
assert model_id_from_path == "foo__with__space__back__slash"
long_id_left_truncated = elasticsearch_model_id(
"/foo/bar/64charactersoftext64charactersoftext64charactersoftext64charofte"
)
assert (
long_id_left_truncated
== "64charactersoftext64charactersoftext64charactersoftext64charofte"
)