mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Add prefix_string config option to the import model hub script (#642)
This commit is contained in:
parent
0a6e3db157
commit
64216d44fb
@ -128,6 +128,19 @@ def get_arg_parser():
|
||||
"--ca-certs", required=False, default=DEFAULT, help="Path to CA bundle"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ingest-prefix",
|
||||
required=False,
|
||||
default=None,
|
||||
help="String to prepend to model input at ingest",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search-prefix",
|
||||
required=False,
|
||||
default=None,
|
||||
help="String to prepend to model input at search",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -244,6 +257,8 @@ def main():
|
||||
task_type=args.task_type,
|
||||
es_version=cluster_version,
|
||||
quantize=args.quantize,
|
||||
ingest_prefix=args.ingest_prefix,
|
||||
search_prefix=args.search_prefix,
|
||||
)
|
||||
model_path, config, vocab_path = tm.save(tmp_dir)
|
||||
except TaskTypeError as err:
|
||||
|
@ -308,6 +308,23 @@ class TrainedModelInput:
|
||||
return self.__dict__
|
||||
|
||||
|
||||
class PrefixStrings:
    """Optional strings to prepend to model input at ingest and at search.

    Both prefixes are independent; either may be left unset.
    """

    def __init__(
        self,
        *,
        ingest_prefix: t.Optional[str] = None,
        search_prefix: t.Optional[str] = None,
    ):
        """
        Parameters
        ----------
        ingest_prefix: Optional[str], default None
            String to prepend to model input at ingest.

        search_prefix: Optional[str], default None
            String to prepend to model input at search.
        """
        self.ingest_prefix = ingest_prefix
        self.search_prefix = search_prefix

    def to_dict(self) -> t.Dict[str, t.Any]:
        """Return the prefixes that are set, keyed by ``"ingest"`` / ``"search"``.

        A prefix that is ``None`` is omitted entirely rather than emitted
        as a null value; an empty string is still considered set.
        """
        config: t.Dict[str, t.Any] = {}
        if self.ingest_prefix is not None:
            config["ingest"] = self.ingest_prefix
        if self.search_prefix is not None:
            config["search"] = self.search_prefix

        return config
|
||||
|
||||
|
||||
class NlpTrainedModelConfig:
|
||||
def __init__(
|
||||
self,
|
||||
@ -318,6 +335,7 @@ class NlpTrainedModelConfig:
|
||||
metadata: t.Optional[dict] = None,
|
||||
model_type: t.Union["t.Literal['pytorch']", str] = "pytorch",
|
||||
tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
|
||||
prefix_strings: t.Optional[PrefixStrings],
|
||||
):
|
||||
self.tags = tags
|
||||
self.description = description
|
||||
@ -325,6 +343,7 @@ class NlpTrainedModelConfig:
|
||||
self.input = input
|
||||
self.metadata = metadata
|
||||
self.model_type = model_type
|
||||
self.prefix_strings = prefix_strings
|
||||
|
||||
def to_dict(self) -> t.Dict[str, t.Any]:
|
||||
return {
|
||||
|
@ -53,6 +53,7 @@ from eland.ml.pytorch.nlp_ml_model import (
|
||||
NlpTrainedModelConfig,
|
||||
NlpXLMRobertaTokenizationConfig,
|
||||
PassThroughInferenceOptions,
|
||||
PrefixStrings,
|
||||
QuestionAnsweringInferenceOptions,
|
||||
TextClassificationInferenceOptions,
|
||||
TextEmbeddingInferenceOptions,
|
||||
@ -596,6 +597,8 @@ class TransformerModel:
|
||||
es_version: Optional[Tuple[int, int, int]] = None,
|
||||
quantize: bool = False,
|
||||
access_token: Optional[str] = None,
|
||||
ingest_prefix: Optional[str] = None,
|
||||
search_prefix: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Loads a model from the Hugging Face repository or local file and creates
|
||||
@ -618,11 +621,22 @@ class TransformerModel:
|
||||
|
||||
quantize: bool, default False
|
||||
Quantize the model.
|
||||
|
||||
access_token: Optional[str]
|
||||
For the HuggingFace Hub private model access
|
||||
|
||||
ingest_prefix: Optional[str]
|
||||
Prefix string to prepend to input at ingest
|
||||
|
||||
search_prefix: Optional[str]
|
||||
Prefix string to prepend to input at search
|
||||
"""
|
||||
|
||||
self._model_id = model_id
|
||||
self._access_token = access_token
|
||||
self._task_type = task_type.replace("-", "_")
|
||||
self._ingest_prefix = ingest_prefix
|
||||
self._search_prefix = search_prefix
|
||||
|
||||
# load Hugging Face model and tokenizer
|
||||
# use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
|
||||
@ -783,6 +797,19 @@ class TransformerModel:
|
||||
"per_allocation_memory_bytes": per_allocation_memory_bytes,
|
||||
}
|
||||
|
||||
prefix_strings = (
|
||||
PrefixStrings(
|
||||
ingest_prefix=self._ingest_prefix, search_prefix=self._search_prefix
|
||||
)
|
||||
if self._ingest_prefix or self._search_prefix
|
||||
else None
|
||||
)
|
||||
prefix_strings_supported = es_version is None or es_version >= (8, 12, 0)
|
||||
if not prefix_strings_supported and prefix_strings:
|
||||
raise Exception(
|
||||
f"The Elasticsearch cluster version {es_version} does not support prefix strings. Support was added in version 8.12.0"
|
||||
)
|
||||
|
||||
return NlpTrainedModelConfig(
|
||||
description=f"Model {self._model_id} for task type '{self._task_type}'",
|
||||
model_type="pytorch",
|
||||
@ -791,6 +818,7 @@ class TransformerModel:
|
||||
field_names=["text_field"],
|
||||
),
|
||||
metadata=metadata,
|
||||
prefix_strings=prefix_strings,
|
||||
)
|
||||
|
||||
def _get_per_deployment_memory(self) -> float:
|
||||
|
@ -154,13 +154,13 @@ else:
|
||||
MODEL_CONFIGURATIONS = []
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="https://github.com/elastic/eland/issues/633")
|
||||
class TestModelConfguration:
|
||||
@pytest.mark.skip(reason="https://github.com/elastic/eland/issues/633")
|
||||
@pytest.mark.parametrize(
|
||||
"model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size",
|
||||
MODEL_CONFIGURATIONS,
|
||||
)
|
||||
def test_text_prediction(
|
||||
def test_model_config(
|
||||
self,
|
||||
model_id,
|
||||
task_type,
|
||||
@ -170,7 +170,6 @@ class TestModelConfguration:
|
||||
embedding_size,
|
||||
):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
print("loading model " + model_id)
|
||||
tm = TransformerModel(
|
||||
model_id=model_id,
|
||||
task_type=task_type,
|
||||
@ -183,6 +182,7 @@ class TestModelConfguration:
|
||||
assert isinstance(config.inference_config, config_type)
|
||||
tokenization = config.inference_config.tokenization
|
||||
assert isinstance(config.metadata, dict)
|
||||
assert config.prefix_strings is None
|
||||
assert (
|
||||
"per_deployment_memory_bytes" in config.metadata
|
||||
and config.metadata["per_deployment_memory_bytes"] > 0
|
||||
@ -210,3 +210,28 @@ class TestModelConfguration:
|
||||
assert len(config.inference_config.classification_labels) > 0
|
||||
|
||||
del tm
|
||||
|
||||
def test_model_config_with_prefix_string(self):
    """Prefixes passed to TransformerModel must appear in the saved config."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        model = TransformerModel(
            model_id="sentence-transformers/all-distilroberta-v1",
            task_type="text_embedding",
            es_version=(8, 12, 0),
            quantize=False,
            ingest_prefix="INGEST:",
            search_prefix="SEARCH:",
        )
        # save() yields three items; only the middle one (the config) is checked
        _model_path, saved_config, _vocab_path = model.save(tmp_dir)

        # serialise once and inspect both keys of the resulting dict
        prefixes = saved_config.prefix_strings.to_dict()
        assert prefixes["ingest"] == "INGEST:"
        assert prefixes["search"] == "SEARCH:"
|
||||
|
||||
def test_model_config_with_prefix_string_not_supported(self):
    """Prefix strings combined with a pre-8.12.0 cluster version must be rejected."""
    # Match on the error text so that an unrelated failure while building
    # the model (for instance a failed download) cannot satisfy this test.
    with pytest.raises(Exception, match="does not support prefix strings"):
        TransformerModel(
            model_id="sentence-transformers/all-distilroberta-v1",
            task_type="text_embedding",
            es_version=(8, 11, 0),
            quantize=False,
            ingest_prefix="INGEST:",
            search_prefix="SEARCH:",
        )
|
||||
|
Loading…
x
Reference in New Issue
Block a user