mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Add prefix_string config option to the import model hub script (#642)
This commit is contained in:
parent
0a6e3db157
commit
64216d44fb
@ -128,6 +128,19 @@ def get_arg_parser():
|
|||||||
"--ca-certs", required=False, default=DEFAULT, help="Path to CA bundle"
|
"--ca-certs", required=False, default=DEFAULT, help="Path to CA bundle"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ingest-prefix",
|
||||||
|
required=False,
|
||||||
|
default=None,
|
||||||
|
help="String to prepend to model input at ingest",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--search-prefix",
|
||||||
|
required=False,
|
||||||
|
default=None,
|
||||||
|
help="String to prepend to model input at search",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -244,6 +257,8 @@ def main():
|
|||||||
task_type=args.task_type,
|
task_type=args.task_type,
|
||||||
es_version=cluster_version,
|
es_version=cluster_version,
|
||||||
quantize=args.quantize,
|
quantize=args.quantize,
|
||||||
|
ingest_prefix=args.ingest_prefix,
|
||||||
|
search_prefix=args.search_prefix,
|
||||||
)
|
)
|
||||||
model_path, config, vocab_path = tm.save(tmp_dir)
|
model_path, config, vocab_path = tm.save(tmp_dir)
|
||||||
except TaskTypeError as err:
|
except TaskTypeError as err:
|
||||||
|
@ -308,6 +308,23 @@ class TrainedModelInput:
|
|||||||
return self.__dict__
|
return self.__dict__
|
||||||
|
|
||||||
|
|
||||||
|
class PrefixStrings:
    """Optional strings to prepend to model input at ingest and at search.

    Either prefix may be ``None``, meaning no prefix is configured for that
    context. Both arguments are keyword-only.
    """

    def __init__(
        self,
        *,
        ingest_prefix: t.Optional[str] = None,
        search_prefix: t.Optional[str] = None,
    ):
        self.ingest_prefix = ingest_prefix
        self.search_prefix = search_prefix

    def to_dict(self) -> t.Dict[str, t.Any]:
        """Return the configured prefixes keyed by ``"ingest"`` / ``"search"``.

        A prefix that is ``None`` is left out of the result. The check is
        ``is not None`` rather than truthiness, so an empty string is kept.
        """
        config: t.Dict[str, t.Any] = {}
        if self.ingest_prefix is not None:
            config["ingest"] = self.ingest_prefix
        if self.search_prefix is not None:
            config["search"] = self.search_prefix
        return config
|
||||||
|
|
||||||
|
|
||||||
class NlpTrainedModelConfig:
|
class NlpTrainedModelConfig:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -318,6 +335,7 @@ class NlpTrainedModelConfig:
|
|||||||
metadata: t.Optional[dict] = None,
|
metadata: t.Optional[dict] = None,
|
||||||
model_type: t.Union["t.Literal['pytorch']", str] = "pytorch",
|
model_type: t.Union["t.Literal['pytorch']", str] = "pytorch",
|
||||||
tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
|
tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
|
||||||
|
prefix_strings: t.Optional[PrefixStrings],
|
||||||
):
|
):
|
||||||
self.tags = tags
|
self.tags = tags
|
||||||
self.description = description
|
self.description = description
|
||||||
@ -325,6 +343,7 @@ class NlpTrainedModelConfig:
|
|||||||
self.input = input
|
self.input = input
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
|
self.prefix_strings = prefix_strings
|
||||||
|
|
||||||
def to_dict(self) -> t.Dict[str, t.Any]:
|
def to_dict(self) -> t.Dict[str, t.Any]:
|
||||||
return {
|
return {
|
||||||
|
@ -53,6 +53,7 @@ from eland.ml.pytorch.nlp_ml_model import (
|
|||||||
NlpTrainedModelConfig,
|
NlpTrainedModelConfig,
|
||||||
NlpXLMRobertaTokenizationConfig,
|
NlpXLMRobertaTokenizationConfig,
|
||||||
PassThroughInferenceOptions,
|
PassThroughInferenceOptions,
|
||||||
|
PrefixStrings,
|
||||||
QuestionAnsweringInferenceOptions,
|
QuestionAnsweringInferenceOptions,
|
||||||
TextClassificationInferenceOptions,
|
TextClassificationInferenceOptions,
|
||||||
TextEmbeddingInferenceOptions,
|
TextEmbeddingInferenceOptions,
|
||||||
@ -596,6 +597,8 @@ class TransformerModel:
|
|||||||
es_version: Optional[Tuple[int, int, int]] = None,
|
es_version: Optional[Tuple[int, int, int]] = None,
|
||||||
quantize: bool = False,
|
quantize: bool = False,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
|
ingest_prefix: Optional[str] = None,
|
||||||
|
search_prefix: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Loads a model from the Hugging Face repository or local file and creates
|
Loads a model from the Hugging Face repository or local file and creates
|
||||||
@ -618,11 +621,22 @@ class TransformerModel:
|
|||||||
|
|
||||||
quantize: bool, default False
|
quantize: bool, default False
|
||||||
Quantize the model.
|
Quantize the model.
|
||||||
|
|
||||||
|
access_token: Optional[str]
|
||||||
|
For the HuggingFace Hub private model access
|
||||||
|
|
||||||
|
ingest_prefix: Optional[str]
|
||||||
|
Prefix string to prepend to input at ingest
|
||||||
|
|
||||||
|
search_prefix: Optional[str]
|
||||||
|
Prefix string to prepend to input at search
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
self._access_token = access_token
|
self._access_token = access_token
|
||||||
self._task_type = task_type.replace("-", "_")
|
self._task_type = task_type.replace("-", "_")
|
||||||
|
self._ingest_prefix = ingest_prefix
|
||||||
|
self._search_prefix = search_prefix
|
||||||
|
|
||||||
# load Hugging Face model and tokenizer
|
# load Hugging Face model and tokenizer
|
||||||
# use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
|
# use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
|
||||||
@ -783,6 +797,19 @@ class TransformerModel:
|
|||||||
"per_allocation_memory_bytes": per_allocation_memory_bytes,
|
"per_allocation_memory_bytes": per_allocation_memory_bytes,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prefix_strings = (
|
||||||
|
PrefixStrings(
|
||||||
|
ingest_prefix=self._ingest_prefix, search_prefix=self._search_prefix
|
||||||
|
)
|
||||||
|
if self._ingest_prefix or self._search_prefix
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
prefix_strings_supported = es_version is None or es_version >= (8, 12, 0)
|
||||||
|
if not prefix_strings_supported and prefix_strings:
|
||||||
|
raise Exception(
|
||||||
|
f"The Elasticsearch cluster version {es_version} does not support prefix strings. Support was added in version 8.12.0"
|
||||||
|
)
|
||||||
|
|
||||||
return NlpTrainedModelConfig(
|
return NlpTrainedModelConfig(
|
||||||
description=f"Model {self._model_id} for task type '{self._task_type}'",
|
description=f"Model {self._model_id} for task type '{self._task_type}'",
|
||||||
model_type="pytorch",
|
model_type="pytorch",
|
||||||
@ -791,6 +818,7 @@ class TransformerModel:
|
|||||||
field_names=["text_field"],
|
field_names=["text_field"],
|
||||||
),
|
),
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
|
prefix_strings=prefix_strings,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_per_deployment_memory(self) -> float:
|
def _get_per_deployment_memory(self) -> float:
|
||||||
|
@ -154,13 +154,13 @@ else:
|
|||||||
MODEL_CONFIGURATIONS = []
|
MODEL_CONFIGURATIONS = []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="https://github.com/elastic/eland/issues/633")
|
|
||||||
class TestModelConfguration:
|
class TestModelConfguration:
|
||||||
|
@pytest.mark.skip(reason="https://github.com/elastic/eland/issues/633")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size",
|
"model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size",
|
||||||
MODEL_CONFIGURATIONS,
|
MODEL_CONFIGURATIONS,
|
||||||
)
|
)
|
||||||
def test_text_prediction(
|
def test_model_config(
|
||||||
self,
|
self,
|
||||||
model_id,
|
model_id,
|
||||||
task_type,
|
task_type,
|
||||||
@ -170,7 +170,6 @@ class TestModelConfguration:
|
|||||||
embedding_size,
|
embedding_size,
|
||||||
):
|
):
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
print("loading model " + model_id)
|
|
||||||
tm = TransformerModel(
|
tm = TransformerModel(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
task_type=task_type,
|
task_type=task_type,
|
||||||
@ -183,6 +182,7 @@ class TestModelConfguration:
|
|||||||
assert isinstance(config.inference_config, config_type)
|
assert isinstance(config.inference_config, config_type)
|
||||||
tokenization = config.inference_config.tokenization
|
tokenization = config.inference_config.tokenization
|
||||||
assert isinstance(config.metadata, dict)
|
assert isinstance(config.metadata, dict)
|
||||||
|
assert config.prefix_strings is None
|
||||||
assert (
|
assert (
|
||||||
"per_deployment_memory_bytes" in config.metadata
|
"per_deployment_memory_bytes" in config.metadata
|
||||||
and config.metadata["per_deployment_memory_bytes"] > 0
|
and config.metadata["per_deployment_memory_bytes"] > 0
|
||||||
@ -210,3 +210,28 @@ class TestModelConfguration:
|
|||||||
assert len(config.inference_config.classification_labels) > 0
|
assert len(config.inference_config.classification_labels) > 0
|
||||||
|
|
||||||
del tm
|
del tm
|
||||||
|
|
||||||
|
def test_model_config_with_prefix_string(self):
    """Prefixes passed to TransformerModel end up in the saved config."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tm = TransformerModel(
            model_id="sentence-transformers/all-distilroberta-v1",
            task_type="text_embedding",
            es_version=(8, 12, 0),
            quantize=False,
            ingest_prefix="INGEST:",
            search_prefix="SEARCH:",
        )
        _, config, _ = tm.save(tmp_dir)
        # Compare the whole serialized dict so an unexpected extra key
        # fails the test as well as a wrong value.
        assert config.prefix_strings.to_dict() == {
            "ingest": "INGEST:",
            "search": "SEARCH:",
        }
|
||||||
|
|
||||||
|
def test_model_config_with_prefix_string_not_supported(self):
    """Prefix strings with a pre-8.12.0 cluster version must be rejected."""
    # The raised type is the bare Exception, so match on the message:
    # otherwise any unrelated failure while building the model would
    # also satisfy this test.
    with pytest.raises(Exception, match="does not support prefix strings"):
        TransformerModel(
            model_id="sentence-transformers/all-distilroberta-v1",
            task_type="text_embedding",
            es_version=(8, 11, 0),
            quantize=False,
            ingest_prefix="INGEST:",
            search_prefix="SEARCH:",
        )
|
||||||
|
Loading…
x
Reference in New Issue
Block a user