Add prefix_string config option to the import model hub script (#642)

This commit is contained in:
David Kyle 2024-01-19 08:06:57 +00:00 committed by GitHub
parent 0a6e3db157
commit 64216d44fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 90 additions and 3 deletions

View File

@ -128,6 +128,19 @@ def get_arg_parser():
"--ca-certs", required=False, default=DEFAULT, help="Path to CA bundle" "--ca-certs", required=False, default=DEFAULT, help="Path to CA bundle"
) )
parser.add_argument(
"--ingest-prefix",
required=False,
default=None,
help="String to prepend to model input at ingest",
)
parser.add_argument(
"--search-prefix",
required=False,
default=None,
help="String to prepend to model input at search",
)
return parser return parser
@ -244,6 +257,8 @@ def main():
task_type=args.task_type, task_type=args.task_type,
es_version=cluster_version, es_version=cluster_version,
quantize=args.quantize, quantize=args.quantize,
ingest_prefix=args.ingest_prefix,
search_prefix=args.search_prefix,
) )
model_path, config, vocab_path = tm.save(tmp_dir) model_path, config, vocab_path = tm.save(tmp_dir)
except TaskTypeError as err: except TaskTypeError as err:

View File

@ -308,6 +308,23 @@ class TrainedModelInput:
return self.__dict__ return self.__dict__
class PrefixStrings:
    """Optional prefix strings prepended to model input at ingest and search time."""

    def __init__(
        self, *, ingest_prefix: t.Optional[str], search_prefix: t.Optional[str]
    ):
        self.ingest_prefix = ingest_prefix
        self.search_prefix = search_prefix

    def to_dict(self) -> t.Dict[str, t.Any]:
        """Serialize to the config dict, omitting any prefix that is unset."""
        config: t.Dict[str, t.Any] = {}
        for key, value in (
            ("ingest", self.ingest_prefix),
            ("search", self.search_prefix),
        ):
            if value is not None:
                config[key] = value
        return config
class NlpTrainedModelConfig: class NlpTrainedModelConfig:
def __init__( def __init__(
self, self,
@ -318,6 +335,7 @@ class NlpTrainedModelConfig:
metadata: t.Optional[dict] = None, metadata: t.Optional[dict] = None,
model_type: t.Union["t.Literal['pytorch']", str] = "pytorch", model_type: t.Union["t.Literal['pytorch']", str] = "pytorch",
tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None, tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
prefix_strings: t.Optional[PrefixStrings],
): ):
self.tags = tags self.tags = tags
self.description = description self.description = description
@ -325,6 +343,7 @@ class NlpTrainedModelConfig:
self.input = input self.input = input
self.metadata = metadata self.metadata = metadata
self.model_type = model_type self.model_type = model_type
self.prefix_strings = prefix_strings
def to_dict(self) -> t.Dict[str, t.Any]: def to_dict(self) -> t.Dict[str, t.Any]:
return { return {

View File

@ -53,6 +53,7 @@ from eland.ml.pytorch.nlp_ml_model import (
NlpTrainedModelConfig, NlpTrainedModelConfig,
NlpXLMRobertaTokenizationConfig, NlpXLMRobertaTokenizationConfig,
PassThroughInferenceOptions, PassThroughInferenceOptions,
PrefixStrings,
QuestionAnsweringInferenceOptions, QuestionAnsweringInferenceOptions,
TextClassificationInferenceOptions, TextClassificationInferenceOptions,
TextEmbeddingInferenceOptions, TextEmbeddingInferenceOptions,
@ -596,6 +597,8 @@ class TransformerModel:
es_version: Optional[Tuple[int, int, int]] = None, es_version: Optional[Tuple[int, int, int]] = None,
quantize: bool = False, quantize: bool = False,
access_token: Optional[str] = None, access_token: Optional[str] = None,
ingest_prefix: Optional[str] = None,
search_prefix: Optional[str] = None,
): ):
""" """
Loads a model from the Hugging Face repository or local file and creates Loads a model from the Hugging Face repository or local file and creates
@ -618,11 +621,22 @@ class TransformerModel:
quantize: bool, default False quantize: bool, default False
Quantize the model. Quantize the model.
access_token: Optional[str]
For the HuggingFace Hub private model access
ingest_prefix: Optional[str]
Prefix string to prepend to input at ingest
search_prefix: Optional[str]
Prefix string to prepend to input at search
""" """
self._model_id = model_id self._model_id = model_id
self._access_token = access_token self._access_token = access_token
self._task_type = task_type.replace("-", "_") self._task_type = task_type.replace("-", "_")
self._ingest_prefix = ingest_prefix
self._search_prefix = search_prefix
# load Hugging Face model and tokenizer # load Hugging Face model and tokenizer
# use padding in the tokenizer to ensure max length sequences are used for tracing (at call time) # use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
@ -783,6 +797,19 @@ class TransformerModel:
"per_allocation_memory_bytes": per_allocation_memory_bytes, "per_allocation_memory_bytes": per_allocation_memory_bytes,
} }
prefix_strings = (
PrefixStrings(
ingest_prefix=self._ingest_prefix, search_prefix=self._search_prefix
)
if self._ingest_prefix or self._search_prefix
else None
)
prefix_strings_supported = es_version is None or es_version >= (8, 12, 0)
if not prefix_strings_supported and prefix_strings:
raise Exception(
f"The Elasticsearch cluster version {es_version} does not support prefix strings. Support was added in version 8.12.0"
)
return NlpTrainedModelConfig( return NlpTrainedModelConfig(
description=f"Model {self._model_id} for task type '{self._task_type}'", description=f"Model {self._model_id} for task type '{self._task_type}'",
model_type="pytorch", model_type="pytorch",
@ -791,6 +818,7 @@ class TransformerModel:
field_names=["text_field"], field_names=["text_field"],
), ),
metadata=metadata, metadata=metadata,
prefix_strings=prefix_strings,
) )
def _get_per_deployment_memory(self) -> float: def _get_per_deployment_memory(self) -> float:

View File

@ -154,13 +154,13 @@ else:
MODEL_CONFIGURATIONS = [] MODEL_CONFIGURATIONS = []
@pytest.mark.skip(reason="https://github.com/elastic/eland/issues/633")
class TestModelConfguration: class TestModelConfguration:
@pytest.mark.skip(reason="https://github.com/elastic/eland/issues/633")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size", "model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size",
MODEL_CONFIGURATIONS, MODEL_CONFIGURATIONS,
) )
def test_text_prediction( def test_model_config(
self, self,
model_id, model_id,
task_type, task_type,
@ -170,7 +170,6 @@ class TestModelConfguration:
embedding_size, embedding_size,
): ):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
print("loading model " + model_id)
tm = TransformerModel( tm = TransformerModel(
model_id=model_id, model_id=model_id,
task_type=task_type, task_type=task_type,
@ -183,6 +182,7 @@ class TestModelConfguration:
assert isinstance(config.inference_config, config_type) assert isinstance(config.inference_config, config_type)
tokenization = config.inference_config.tokenization tokenization = config.inference_config.tokenization
assert isinstance(config.metadata, dict) assert isinstance(config.metadata, dict)
assert config.prefix_strings is None
assert ( assert (
"per_deployment_memory_bytes" in config.metadata "per_deployment_memory_bytes" in config.metadata
and config.metadata["per_deployment_memory_bytes"] > 0 and config.metadata["per_deployment_memory_bytes"] > 0
@ -210,3 +210,28 @@ class TestModelConfguration:
assert len(config.inference_config.classification_labels) > 0 assert len(config.inference_config.classification_labels) > 0
del tm del tm
def test_model_config_with_prefix_string(self):
    """Prefix strings passed to TransformerModel must round-trip into the saved config."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tm = TransformerModel(
            model_id="sentence-transformers/all-distilroberta-v1",
            task_type="text_embedding",
            es_version=(8, 12, 0),
            quantize=False,
            ingest_prefix="INGEST:",
            search_prefix="SEARCH:",
        )
        _, config, _ = tm.save(tmp_dir)
        # Serialize once; the original called to_dict() twice for no reason.
        prefixes = config.prefix_strings.to_dict()
        assert prefixes["ingest"] == "INGEST:"
        assert prefixes["search"] == "SEARCH:"
def test_model_config_with_prefix_string_not_supported(self):
    """Clusters older than 8.12.0 must reject prefix strings at construction time."""
    # Match on the message so an unrelated failure (e.g. model download error)
    # cannot make this test pass spuriously via a bare Exception match.
    with pytest.raises(Exception, match="does not support prefix strings"):
        TransformerModel(
            model_id="sentence-transformers/all-distilroberta-v1",
            task_type="text_embedding",
            es_version=(8, 11, 0),
            quantize=False,
            ingest_prefix="INGEST:",
            search_prefix="SEARCH:",
        )