mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Generate valid NLP model id from file path (#541)
The eland_import_hub_model script supports uploading a local model, in which case the --hub-model-id argument is a file path. If the --es-model-id option is not used, the model id is generated from the hub model id; when that is a file path, the path must be converted to a valid Elasticsearch model id.
This commit is contained in:
parent
7820a31256
commit
1e6f48f8f4
@ -22,6 +22,7 @@ libraries such as sentence-transformers.
|
||||
|
||||
import json
|
||||
import os.path
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
@ -740,8 +741,7 @@ class TransformerModel:
|
||||
)
|
||||
|
||||
def elasticsearch_model_id(self) -> str:
    """Return the Elasticsearch model id derived from this model's id.

    Delegates to the module-level ``elasticsearch_model_id`` function so
    that ids which are file paths (containing slashes, backslashes or
    whitespace) are normalised the same way everywhere, instead of the
    earlier inline ``replace("/", "__")`` logic that only handled
    forward slashes.
    """
    # Inside a method body this name resolves to the module-level
    # function, not to this method.
    return elasticsearch_model_id(self._model_id)
|
||||
|
||||
def save(self, path: str) -> Tuple[str, NlpTrainedModelConfig, str]:
|
||||
# save traced model
|
||||
@ -753,3 +753,21 @@ class TransformerModel:
|
||||
json.dump(self._vocab, outfile)
|
||||
|
||||
return model_path, self._config, vocab_path
|
||||
|
||||
|
||||
def elasticsearch_model_id(model_id: str) -> str:
    """
    Convert a hub model id or a local file path into an Elasticsearch model id.

    Elasticsearch model IDs need to be a specific format: no special
    chars, all lowercase, max 64 chars. Every whitespace character,
    backslash and forward slash in ``model_id`` is replaced with "__"
    and the result is lowercased.

    If the id is longer than 64 characters the last 64 are kept: in the
    case where the id is a long file path this captures the model name.

    Ids starting with an underscore are not valid Elasticsearch ids,
    which might be the case if ``model_id`` is a file path, so all
    leading underscores are removed. Trailing underscores (for example
    from a path that ends with a separator) are removed as well, since
    ids are expected to start and end with an alphanumeric character.

    :param model_id: hub model id or file path to convert
    :return: the normalised id, at most 64 characters long
    """
    normalized = re.sub(r"[\s\\/]", "__", model_id).lower()

    # Drop trailing separators before truncating so they do not consume
    # part of the 64 character budget.
    normalized = normalized.rstrip("_")

    # Strip after truncating: the cut may land in the middle of a "__"
    # separator and leave a single leading underscore behind.
    # str.lstrip is used rather than str.removeprefix, which only exists
    # on Python 3.9+ and would remove just one "__".
    return normalized[-64:].lstrip("_")
|
||||
|
@ -34,7 +34,7 @@ except ImportError:
|
||||
try:
|
||||
import torch # noqa: F401
|
||||
from torch import Tensor, nn # noqa: F401
|
||||
from transformers import PretrainedConfig # noqa: F401
|
||||
from transformers import PretrainedConfig, elasticsearch_model_id # noqa: F401
|
||||
|
||||
from eland.ml.pytorch import ( # noqa: F401
|
||||
NlpBertTokenizationConfig,
|
||||
@ -338,3 +338,21 @@ class TestPytorchModelUpload:
|
||||
)
|
||||
)
|
||||
assert task_type_from_model_config(model_config=config) == expected_task
|
||||
|
||||
def test_elasticsearch_model_id(self):
    """Check that file paths are converted to valid Elasticsearch model ids."""
    # Each path separator becomes a double underscore and the leading
    # one is stripped.
    model_id_from_path = elasticsearch_model_id(r"/foo/bar")
    assert model_id_from_path == "foo__bar"

    # Backslashes and spaces are replaced too, and the id is lowercased.
    model_id_from_path = elasticsearch_model_id(r"\BAZ\with space")
    assert model_id_from_path == "baz__with__space"

    # This is a raw string, so "\t" is a backslash followed by "t"
    # rather than a tab character.
    model_id_from_path = elasticsearch_model_id(r"/foo/with tab\tback\slash")
    assert model_id_from_path == "foo__with__tab__tback__slash"

    # A real tab character is whitespace and is replaced like a space.
    model_id_from_path = elasticsearch_model_id("/foo/with tab\tback\\slash")
    assert model_id_from_path == "foo__with__tab__back__slash"

    # Ids longer than 64 characters keep only their last 64 characters.
    long_id_left_truncated = elasticsearch_model_id(
        "/foo/bar/64charactersoftext64charactersoftext64charactersoftext64charofte"
    )
    assert (
        long_id_left_truncated
        == "64charactersoftext64charactersoftext64charactersoftext64charofte"
    )
|
||||
|
Loading…
x
Reference in New Issue
Block a user