diff --git a/eland/cli/eland_import_hub_model.py b/eland/cli/eland_import_hub_model.py
index b8d40b4..8cf950f 100755
--- a/eland/cli/eland_import_hub_model.py
+++ b/eland/cli/eland_import_hub_model.py
@@ -128,6 +128,19 @@ def get_arg_parser():
         "--ca-certs", required=False, default=DEFAULT, help="Path to CA bundle"
     )
 
+    parser.add_argument(
+        "--ingest-prefix",
+        required=False,
+        default=None,
+        help="String to prepend to model input at ingest",
+    )
+    parser.add_argument(
+        "--search-prefix",
+        required=False,
+        default=None,
+        help="String to prepend to model input at search",
+    )
+
     return parser
 
 
@@ -244,6 +257,8 @@ def main():
                 task_type=args.task_type,
                 es_version=cluster_version,
                 quantize=args.quantize,
+                ingest_prefix=args.ingest_prefix,
+                search_prefix=args.search_prefix,
             )
             model_path, config, vocab_path = tm.save(tmp_dir)
         except TaskTypeError as err:
diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py
index 4a7284d..26222f3 100644
--- a/eland/ml/pytorch/nlp_ml_model.py
+++ b/eland/ml/pytorch/nlp_ml_model.py
@@ -308,6 +308,26 @@ class TrainedModelInput:
         return self.__dict__
 
 
+class PrefixStrings:
+    """Strings to prepend to model input at ingest and at search."""
+
+    def __init__(
+        self, *, ingest_prefix: t.Optional[str], search_prefix: t.Optional[str]
+    ):
+        self.ingest_prefix = ingest_prefix
+        self.search_prefix = search_prefix
+
+    def to_dict(self) -> t.Dict[str, t.Any]:
+        """Map "ingest"/"search" to their prefix, skipping unset (None) ones."""
+        config = {}
+        if self.ingest_prefix is not None:
+            config["ingest"] = self.ingest_prefix
+        if self.search_prefix is not None:
+            config["search"] = self.search_prefix
+
+        return config
+
+
 class NlpTrainedModelConfig:
     def __init__(
         self,
@@ -318,6 +338,7 @@ class NlpTrainedModelConfig:
         metadata: t.Optional[dict] = None,
         model_type: t.Union["t.Literal['pytorch']", str] = "pytorch",
         tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
+        prefix_strings: t.Optional[PrefixStrings] = None,
     ):
         self.tags = tags
         self.description = description
@@ -325,6 +346,7 @@ class NlpTrainedModelConfig:
         self.input = input
         self.metadata = metadata
         self.model_type = model_type
+        self.prefix_strings = prefix_strings
 
     def to_dict(self) -> t.Dict[str, t.Any]:
         return {
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index ed047a7..ff41870 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -53,6 +53,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     NlpTrainedModelConfig,
     NlpXLMRobertaTokenizationConfig,
     PassThroughInferenceOptions,
+    PrefixStrings,
     QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
     TextEmbeddingInferenceOptions,
@@ -596,6 +597,8 @@ class TransformerModel:
         es_version: Optional[Tuple[int, int, int]] = None,
         quantize: bool = False,
         access_token: Optional[str] = None,
+        ingest_prefix: Optional[str] = None,
+        search_prefix: Optional[str] = None,
     ):
         """
         Loads a model from the Hugging Face repository or local file and creates
@@ -618,11 +621,22 @@ class TransformerModel:
 
         quantize: bool, default False
             Quantize the model.
+
+        access_token: Optional[str]
+            For the HuggingFace Hub private model access
+
+        ingest_prefix: Optional[str]
+            Prefix string to prepend to input at ingest
+
+        search_prefix: Optional[str]
+            Prefix string to prepend to input at search
         """
 
         self._model_id = model_id
         self._access_token = access_token
         self._task_type = task_type.replace("-", "_")
+        self._ingest_prefix = ingest_prefix
+        self._search_prefix = search_prefix
 
         # load Hugging Face model and tokenizer
         # use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
@@ -783,6 +797,19 @@ class TransformerModel:
             "per_allocation_memory_bytes": per_allocation_memory_bytes,
         }
 
+        prefix_strings = (
+            PrefixStrings(
+                ingest_prefix=self._ingest_prefix, search_prefix=self._search_prefix
+            )
+            if self._ingest_prefix or self._search_prefix
+            else None
+        )
+        prefix_strings_supported = es_version is None or es_version >= (8, 12, 0)
+        if not prefix_strings_supported and prefix_strings:
+            raise ValueError(
+                f"The Elasticsearch cluster version {es_version} does not support prefix strings. Support was added in version 8.12.0"
+            )
+
         return NlpTrainedModelConfig(
             description=f"Model {self._model_id} for task type '{self._task_type}'",
             model_type="pytorch",
@@ -791,6 +818,7 @@ class TransformerModel:
                 field_names=["text_field"],
             ),
             metadata=metadata,
+            prefix_strings=prefix_strings,
         )
 
     def _get_per_deployment_memory(self) -> float:
diff --git a/tests/ml/pytorch/test_pytorch_model_config_pytest.py b/tests/ml/pytorch/test_pytorch_model_config_pytest.py
index 9c10f4b..8f25574 100644
--- a/tests/ml/pytorch/test_pytorch_model_config_pytest.py
+++ b/tests/ml/pytorch/test_pytorch_model_config_pytest.py
@@ -154,13 +154,13 @@ else:
     MODEL_CONFIGURATIONS = []
 
 
-@pytest.mark.skip(reason="https://github.com/elastic/eland/issues/633")
 class TestModelConfguration:
+    @pytest.mark.skip(reason="https://github.com/elastic/eland/issues/633")
     @pytest.mark.parametrize(
         "model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size",
         MODEL_CONFIGURATIONS,
     )
-    def test_text_prediction(
+    def test_model_config(
         self,
         model_id,
         task_type,
@@ -170,7 +170,6 @@ class TestModelConfguration:
         embedding_size,
     ):
         with tempfile.TemporaryDirectory() as tmp_dir:
-            print("loading model " + model_id)
             tm = TransformerModel(
                 model_id=model_id,
                 task_type=task_type,
@@ -183,6 +182,7 @@ class TestModelConfguration:
             assert isinstance(config.inference_config, config_type)
             tokenization = config.inference_config.tokenization
             assert isinstance(config.metadata, dict)
+            assert config.prefix_strings is None
             assert (
                 "per_deployment_memory_bytes" in config.metadata
                 and config.metadata["per_deployment_memory_bytes"] > 0
@@ -210,3 +210,28 @@ class TestModelConfguration:
                 assert len(config.inference_config.classification_labels) > 0
 
             del tm
+
+    def test_model_config_with_prefix_string(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tm = TransformerModel(
+                model_id="sentence-transformers/all-distilroberta-v1",
+                task_type="text_embedding",
+                es_version=(8, 12, 0),
+                quantize=False,
+                ingest_prefix="INGEST:",
+                search_prefix="SEARCH:",
+            )
+            _, config, _ = tm.save(tmp_dir)
+            assert config.prefix_strings.to_dict()["ingest"] == "INGEST:"
+            assert config.prefix_strings.to_dict()["search"] == "SEARCH:"
+
+    def test_model_config_with_prefix_string_not_supported(self):
+        with pytest.raises(ValueError, match="does not support prefix strings"):
+            TransformerModel(
+                model_id="sentence-transformers/all-distilroberta-v1",
+                task_type="text_embedding",
+                es_version=(8, 11, 0),
+                quantize=False,
+                ingest_prefix="INGEST:",
+                search_prefix="SEARCH:",
+            )