Generate valid NLP model id from file path (#541)

The eland_import_hub_model script supports uploading a local file where
the --hub-model-id argument is a file path. If the --es-model-id option is
not used the model Id is generated from the hub model id and when that 
is a file path the path must be converted to a valid elasticsearch model id.
This commit is contained in:
David Kyle 2023-05-22 15:37:36 +01:00 committed by GitHub
parent 7820a31256
commit 1e6f48f8f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 3 deletions

View File

@ -22,6 +22,7 @@ libraries such as sentence-transformers.
import json
import os.path
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Union
@ -740,8 +741,7 @@ class TransformerModel:
)
def elasticsearch_model_id(self) -> str:
# Elasticsearch model IDs need to be a specific format: no special chars, all lowercase, max 64 chars
return self._model_id.replace("/", "__").lower()[:64]
return elasticsearch_model_id(self._model_id)
def save(self, path: str) -> Tuple[str, NlpTrainedModelConfig, str]:
# save traced model
@ -753,3 +753,21 @@ class TransformerModel:
json.dump(self._vocab, outfile)
return model_path, self._config, vocab_path
def elasticsearch_model_id(model_id: str) -> str:
"""
Elasticsearch model IDs need to be a specific format:
no special chars, all lowercase, max 64 chars. If the
Id is longer than 64 charaters take the last 64- in the
case where the id is long file path this captures the
model name.
Ids starting with __ are not valid elasticsearch Ids,
# this might be the case if model_id is a file path
"""
id = re.sub(r"[\s\\/]", "__", model_id).lower()[-64:]
if id.startswith("__"):
id = id.removeprefix("__")
return id

View File

@ -34,7 +34,7 @@ except ImportError:
try:
import torch # noqa: F401
from torch import Tensor, nn # noqa: F401
from transformers import PretrainedConfig # noqa: F401
from transformers import PretrainedConfig, elasticsearch_model_id # noqa: F401
from eland.ml.pytorch import ( # noqa: F401
NlpBertTokenizationConfig,
@ -338,3 +338,21 @@ class TestPytorchModelUpload:
)
)
assert task_type_from_model_config(model_config=config) == expected_task
def test_elasticsearch_model_id(self):
model_id_from_path = elasticsearch_model_id(r"/foo/bar")
assert model_id_from_path == "foo_bar"
model_id_from_path = elasticsearch_model_id(r"\BAZ\with space")
assert model_id_from_path == "baz__with__space"
model_id_from_path = elasticsearch_model_id(r"/foo/with tab\tback\slash")
assert model_id_from_path == "foo__with__space__back__slash"
long_id_left_truncated = elasticsearch_model_id(
"/foo/bar/64charactersoftext64charactersoftext64charactersoftext64charofte"
)
assert (
long_id_left_truncated
== "64charactersoftext64charactersoftext64charactersoftext64charofte"
)