From a9c36927f6d611b5b926db3d9756f4bc209f7568 Mon Sep 17 00:00:00 2001
From: David Kyle
Date: Wed, 23 Apr 2025 09:10:02 +0100
Subject: [PATCH] Fix tokeniser for DeBERTa models (#769)

---
 eland/ml/pytorch/_pytorch_model.py            |   3 +
 eland/ml/pytorch/nlp_ml_model.py              |   2 +-
 eland/ml/pytorch/transformers.py              | 313 ++---------------
 eland/ml/pytorch/wrappers.py                  | 317 ++++++++++++++++++
 .../test_pytorch_model_config_pytest.py       |   9 +
 .../test_pytorch_model_upload_pytest.py       |  24 ++
 6 files changed, 377 insertions(+), 291 deletions(-)
 create mode 100644 eland/ml/pytorch/wrappers.py

diff --git a/eland/ml/pytorch/_pytorch_model.py b/eland/ml/pytorch/_pytorch_model.py
index 35b0554..4db312d 100644
--- a/eland/ml/pytorch/_pytorch_model.py
+++ b/eland/ml/pytorch/_pytorch_model.py
@@ -126,6 +126,7 @@ class PyTorchModel:
     def infer(
         self,
         docs: List[Mapping[str, str]],
+        inference_config: Optional[Mapping[str, Any]] = None,
         timeout: str = DEFAULT_TIMEOUT,
     ) -> Any:
         if docs is None:
@@ -133,6 +134,8 @@ class PyTorchModel:
 
         __body: Dict[str, Any] = {}
         __body["docs"] = docs
+        if inference_config is not None:
+            __body["inference_config"] = inference_config
 
         __path = f"/_ml/trained_models/{_quote(self.model_id)}/_infer"
         __query: Dict[str, Any] = {}
diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py
index 012b080..eddd39b 100644
--- a/eland/ml/pytorch/nlp_ml_model.py
+++ b/eland/ml/pytorch/nlp_ml_model.py
@@ -86,7 +86,7 @@ class NlpXLMRobertaTokenizationConfig(NlpTokenizationConfig):
         )
 
 
-class DebertaV2Config(NlpTokenizationConfig):
+class NlpDebertaV2TokenizationConfig(NlpTokenizationConfig):
     def __init__(
         self,
         *,
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index 40e5650..04d4ba8 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -25,17 +25,13 @@ import os.path
 import random
 import re
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import torch  # type: ignore
 import transformers  # type: ignore
-from sentence_transformers import SentenceTransformer  # type: ignore
-from torch import Tensor, nn
+from torch import Tensor
 from torch.profiler import profile  # type: ignore
 from transformers import (
-    AutoConfig,
-    AutoModel,
-    AutoModelForQuestionAnswering,
     BertTokenizer,
     PretrainedConfig,
     PreTrainedModel,
@@ -44,11 +40,11 @@ from transformers import (
 )
 
 from eland.ml.pytorch.nlp_ml_model import (
-    DebertaV2Config,
     FillMaskInferenceOptions,
     NerInferenceOptions,
     NlpBertJapaneseTokenizationConfig,
     NlpBertTokenizationConfig,
+    NlpDebertaV2TokenizationConfig,
     NlpMPNetTokenizationConfig,
     NlpRobertaTokenizationConfig,
     NlpTokenizationConfig,
@@ -65,8 +61,13 @@ from eland.ml.pytorch.nlp_ml_model import (
     ZeroShotClassificationInferenceOptions,
 )
 from eland.ml.pytorch.traceable_model import TraceableModel
+from eland.ml.pytorch.wrappers import (
+    _DistilBertWrapper,
+    _DPREncoderWrapper,
+    _QuestionAnsweringWrapperModule,
+    _SentenceTransformerWrapperModule,
+)
 
-DEFAULT_OUTPUT_KEY = "sentence_embedding"
 SUPPORTED_TASK_TYPES = {
     "fill_mask",
     "ner",
@@ -172,284 +173,6 @@ def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]
     return potential_task_types.pop()
 
 
-class _QuestionAnsweringWrapperModule(nn.Module):  # type: ignore
-    """
-    A wrapper around a question answering model.
-    Our inference engine only takes the first tuple if the inference response
-    is a tuple.
-
-    This wrapper transforms the output to be a stacked tensor if its a tuple.
-
-    Otherwise it passes it through
-    """
-
-    def __init__(self, model: PreTrainedModel):
-        super().__init__()
-        self._hf_model = model
-        self.config = model.config
-
-    @staticmethod
-    def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
-        model = AutoModelForQuestionAnswering.from_pretrained(
-            model_id, token=token, torchscript=True
-        )
-        if isinstance(
-            model.config,
-            (
-                transformers.MPNetConfig,
-                transformers.XLMRobertaConfig,
-                transformers.RobertaConfig,
-                transformers.BartConfig,
-            ),
-        ):
-            return _TwoParameterQuestionAnsweringWrapper(model)
-        else:
-            return _QuestionAnsweringWrapper(model)
-
-
-class _QuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
-    def __init__(self, model: PreTrainedModel):
-        super().__init__(model=model)
-
-    def forward(
-        self,
-        input_ids: Tensor,
-        attention_mask: Tensor,
-        token_type_ids: Tensor,
-        position_ids: Tensor,
-    ) -> Tensor:
-        """Wrap the input and output to conform to the native process interface."""
-
-        inputs = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-            "position_ids": position_ids,
-        }
-
-        # remove inputs for specific model types
-        if isinstance(self._hf_model.config, transformers.DistilBertConfig):
-            del inputs["token_type_ids"]
-            del inputs["position_ids"]
-        response = self._hf_model(**inputs)
-        if isinstance(response, tuple):
-            return torch.stack(list(response), dim=0)
-        return response
-
-
-class _TwoParameterQuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
-    def __init__(self, model: PreTrainedModel):
-        super().__init__(model=model)
-
-    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
-        """Wrap the input and output to conform to the native process interface."""
-        inputs = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-        }
-        response = self._hf_model(**inputs)
-        if isinstance(response, tuple):
-            return torch.stack(list(response), dim=0)
-        return response
-
-
-class _DistilBertWrapper(nn.Module):  # type: ignore
-    """
-    In Elasticsearch the BERT tokenizer is used for DistilBERT models but
-    the BERT tokenizer produces 4 inputs where DistilBERT models expect 2.
-
-    Wrap the model's forward function in a method that accepts the 4
-    arguments passed to a BERT model then discard the token_type_ids
-    and the position_ids to match the wrapped DistilBERT model forward
-    function
-    """
-
-    def __init__(self, model: transformers.PreTrainedModel):
-        super().__init__()
-        self._model = model
-        self.config = model.config
-
-    @staticmethod
-    def try_wrapping(model: PreTrainedModel) -> Optional[Any]:
-        if isinstance(model.config, transformers.DistilBertConfig):
-            return _DistilBertWrapper(model)
-        else:
-            return model
-
-    def forward(
-        self,
-        input_ids: Tensor,
-        attention_mask: Tensor,
-        _token_type_ids: Tensor = None,
-        _position_ids: Tensor = None,
-    ) -> Tensor:
-        """Wrap the input and output to conform to the native process interface."""
-
-        return self._model(input_ids=input_ids, attention_mask=attention_mask)
-
-
-class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
-    """
-    A wrapper around sentence-transformer models to provide pooling,
-    normalization and other graph layers that are not defined in the base
-    HuggingFace transformer model.
-    """
-
-    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
-        super().__init__()
-        self._hf_model = model
-        self._st_model = SentenceTransformer(model.config.name_or_path)
-        self._output_key = output_key
-        self.config = model.config
-
-        self._remove_pooling_layer()
-        self._replace_transformer_layer()
-
-    @staticmethod
-    def from_pretrained(
-        model_id: str,
-        tokenizer: PreTrainedTokenizer,
-        *,
-        token: Optional[str] = None,
-        output_key: str = DEFAULT_OUTPUT_KEY,
-    ) -> Optional[Any]:
-        model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
-        if isinstance(
-            tokenizer,
-            (
-                transformers.BartTokenizer,
-                transformers.MPNetTokenizer,
-                transformers.RobertaTokenizer,
-                transformers.XLMRobertaTokenizer,
-                transformers.DebertaV2Tokenizer,
-            ),
-        ):
-            return _TwoParameterSentenceTransformerWrapper(model, output_key)
-        else:
-            return _SentenceTransformerWrapper(model, output_key)
-
-    def _remove_pooling_layer(self) -> None:
-        """
-        Removes any last pooling layer which is not used to create embeddings.
-        Leaving this layer in will cause it to return a NoneType which in turn
-        will fail to load in libtorch. Alternatively, we can just use the output
-        of the pooling layer as a dummy but this also affects (if only in a
-        minor way) the performance of inference, so we're better off removing
-        the layer if we can.
-        """
-
-        if hasattr(self._hf_model, "pooler"):
-            self._hf_model.pooler = None
-
-    def _replace_transformer_layer(self) -> None:
-        """
-        Replaces the HuggingFace Transformer layer in the SentenceTransformer
-        modules so we can set it with one that has pooling layer removed and
-        was loaded ready for TorchScript export.
-        """
-
-        self._st_model._modules["0"].auto_model = self._hf_model
-
-
-class _SentenceTransformerWrapper(_SentenceTransformerWrapperModule):
-    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
-        super().__init__(model=model, output_key=output_key)
-
-    def forward(
-        self,
-        input_ids: Tensor,
-        attention_mask: Tensor,
-        token_type_ids: Tensor,
-        position_ids: Tensor,
-    ) -> Tensor:
-        """Wrap the input and output to conform to the native process interface."""
-
-        inputs = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-            "position_ids": position_ids,
-        }
-
-        # remove inputs for specific model types
-        if isinstance(self._hf_model.config, transformers.DistilBertConfig):
-            del inputs["token_type_ids"]
-
-        return self._st_model(inputs)[self._output_key]
-
-
-class _TwoParameterSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
-    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
-        super().__init__(model=model, output_key=output_key)
-
-    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
-        """Wrap the input and output to conform to the native process interface."""
-        inputs = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-        }
-        return self._st_model(inputs)[self._output_key]
-
-
-class _DPREncoderWrapper(nn.Module):  # type: ignore
-    """
-    AutoModel loading does not work for DPRContextEncoders, this only exists as
-    a workaround. This may never be fixed so this is likely permanent.
-    See: https://github.com/huggingface/transformers/issues/13670
-    """
-
-    _SUPPORTED_MODELS = {
-        transformers.DPRContextEncoder,
-        transformers.DPRQuestionEncoder,
-    }
-    _SUPPORTED_MODELS_NAMES = set([x.__name__ for x in _SUPPORTED_MODELS])
-
-    def __init__(
-        self,
-        model: Union[transformers.DPRContextEncoder, transformers.DPRQuestionEncoder],
-    ):
-        super().__init__()
-        self._model = model
-        self.config = model.config
-
-    @staticmethod
-    def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
-        config = AutoConfig.from_pretrained(model_id, token=token)
-
-        def is_compatible() -> bool:
-            is_dpr_model = config.model_type == "dpr"
-            has_architectures = (
-                config.architectures is not None and len(config.architectures) == 1
-            )
-            is_supported_architecture = has_architectures and (
-                config.architectures[0] in _DPREncoderWrapper._SUPPORTED_MODELS_NAMES
-            )
-            return is_dpr_model and is_supported_architecture
-
-        if is_compatible():
-            model = getattr(transformers, config.architectures[0]).from_pretrained(
-                model_id, torchscript=True
-            )
-            return _DPREncoderWrapper(model)
-        else:
-            return None
-
-    def forward(
-        self,
-        input_ids: Tensor,
-        attention_mask: Tensor,
-        token_type_ids: Tensor,
-        _position_ids: Tensor,
-    ) -> Tensor:
-        """Wrap the input and output to conform to the native process interface."""
-
-        return self._model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-        )
-
-
 class _TransformerTraceableModel(TraceableModel):
     """A base class representing a HuggingFace transformer model that can be traced."""
 
@@ -489,12 +212,17 @@ class _TransformerTraceableModel(TraceableModel):
                 transformers.MPNetTokenizer,
                 transformers.RobertaTokenizer,
                 transformers.XLMRobertaTokenizer,
-                transformers.DebertaV2Tokenizer,
             ),
         ):
-            del inputs["token_type_ids"]
             return (inputs["input_ids"], inputs["attention_mask"])
 
+        if isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
+            return (
+                inputs["input_ids"],
+                inputs["attention_mask"],
+                inputs["token_type_ids"],
+            )
+
         position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
         inputs["position_ids"] = position_ids
         return (
@@ -694,7 +422,12 @@ class TransformerModel:
                 " ".join(m) for m, _ in sorted(ranks.items(), key=lambda kv: kv[1])
             ]
             vocab_obj["merges"] = merges
-        sp_model = getattr(self._tokenizer, "sp_model", None)
+
+        if isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
+            sp_model = self._tokenizer._tokenizer.spm
+        else:
+            sp_model = getattr(self._tokenizer, "sp_model", None)
+
         if sp_model:
             id_correction = getattr(self._tokenizer, "fairseq_offset", 0)
             scores = []
@@ -733,7 +466,7 @@ class TransformerModel:
                 max_sequence_length=_max_sequence_length
             )
         elif isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
-            return DebertaV2Config(
+            return NlpDebertaV2TokenizationConfig(
                 max_sequence_length=_max_sequence_length,
                 do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
             )
diff --git a/eland/ml/pytorch/wrappers.py b/eland/ml/pytorch/wrappers.py
new file mode 100644
index 0000000..62750a0
--- /dev/null
+++ b/eland/ml/pytorch/wrappers.py
@@ -0,0 +1,317 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+This module contains the wrapper classes for the Hugging Face models.
+Wrapping is necessary to ensure that the forward method of the model
+is called with the same arguments the ml-cpp pytorch_inference process
+uses.
+"""
+
+from typing import Any, Optional, Union
+
+import torch  # type: ignore
+import transformers  # type: ignore
+from sentence_transformers import SentenceTransformer  # type: ignore
+from torch import Tensor, nn
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForQuestionAnswering,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+)
+
+DEFAULT_OUTPUT_KEY = "sentence_embedding"
+
+
+class _QuestionAnsweringWrapperModule(nn.Module):  # type: ignore
+    """
+    A wrapper around a question answering model.
+    Our inference engine only takes the first tuple if the inference response
+    is a tuple.
+
+    This wrapper transforms the output to be a stacked tensor if its a tuple.
+
+    Otherwise it passes it through
+    """
+
+    def __init__(self, model: PreTrainedModel):
+        super().__init__()
+        self._hf_model = model
+        self.config = model.config
+
+    @staticmethod
+    def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
+        model = AutoModelForQuestionAnswering.from_pretrained(
+            model_id, token=token, torchscript=True
+        )
+        if isinstance(
+            model.config,
+            (
+                transformers.MPNetConfig,
+                transformers.XLMRobertaConfig,
+                transformers.RobertaConfig,
+                transformers.BartConfig,
+            ),
+        ):
+            return _TwoParameterQuestionAnsweringWrapper(model)
+        else:
+            return _QuestionAnsweringWrapper(model)
+
+
+class _QuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
+    def __init__(self, model: PreTrainedModel):
+        super().__init__(model=model)
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        attention_mask: Tensor,
+        token_type_ids: Tensor,
+        position_ids: Tensor,
+    ) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": token_type_ids,
+            "position_ids": position_ids,
+        }
+
+        # remove inputs for specific model types
+        if isinstance(self._hf_model.config, transformers.DistilBertConfig):
+            del inputs["token_type_ids"]
+            del inputs["position_ids"]
+        response = self._hf_model(**inputs)
+        if isinstance(response, tuple):
+            return torch.stack(list(response), dim=0)
+        return response
+
+
+class _TwoParameterQuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
+    def __init__(self, model: PreTrainedModel):
+        super().__init__(model=model)
+
+    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        response = self._hf_model(**inputs)
+        if isinstance(response, tuple):
+            return torch.stack(list(response), dim=0)
+        return response
+
+
+class _DistilBertWrapper(nn.Module):  # type: ignore
+    """
+    In Elasticsearch the BERT tokenizer is used for DistilBERT models but
+    the BERT tokenizer produces 4 inputs where DistilBERT models expect 2.
+
+    Wrap the model's forward function in a method that accepts the 4
+    arguments passed to a BERT model then discard the token_type_ids
+    and the position_ids to match the wrapped DistilBERT model forward
+    function
+    """
+
+    def __init__(self, model: transformers.PreTrainedModel):
+        super().__init__()
+        self._model = model
+        self.config = model.config
+
+    @staticmethod
+    def try_wrapping(model: PreTrainedModel) -> Optional[Any]:
+        if isinstance(model.config, transformers.DistilBertConfig):
+            return _DistilBertWrapper(model)
+        else:
+            return model
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        attention_mask: Tensor,
+        _token_type_ids: Tensor = None,
+        _position_ids: Tensor = None,
+    ) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+
+        return self._model(input_ids=input_ids, attention_mask=attention_mask)
+
+
+class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
+    """
+    A wrapper around sentence-transformer models to provide pooling,
+    normalization and other graph layers that are not defined in the base
+    HuggingFace transformer model.
+    """
+
+    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
+        super().__init__()
+        self._hf_model = model
+        self._st_model = SentenceTransformer(model.config.name_or_path)
+        self._output_key = output_key
+        self.config = model.config
+
+        self._remove_pooling_layer()
+        self._replace_transformer_layer()
+
+    @staticmethod
+    def from_pretrained(
+        model_id: str,
+        tokenizer: PreTrainedTokenizer,
+        *,
+        token: Optional[str] = None,
+        output_key: str = DEFAULT_OUTPUT_KEY,
+    ) -> Optional[Any]:
+        model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
+        if isinstance(
+            tokenizer,
+            (
+                transformers.BartTokenizer,
+                transformers.MPNetTokenizer,
+                transformers.RobertaTokenizer,
+                transformers.XLMRobertaTokenizer,
+                transformers.DebertaV2Tokenizer,
+            ),
+        ):
+            return _TwoParameterSentenceTransformerWrapper(model, output_key)
+        else:
+            return _SentenceTransformerWrapper(model, output_key)
+
+    def _remove_pooling_layer(self) -> None:
+        """
+        Removes any last pooling layer which is not used to create embeddings.
+        Leaving this layer in will cause it to return a NoneType which in turn
+        will fail to load in libtorch. Alternatively, we can just use the output
+        of the pooling layer as a dummy but this also affects (if only in a
+        minor way) the performance of inference, so we're better off removing
+        the layer if we can.
+        """
+
+        if hasattr(self._hf_model, "pooler"):
+            self._hf_model.pooler = None
+
+    def _replace_transformer_layer(self) -> None:
+        """
+        Replaces the HuggingFace Transformer layer in the SentenceTransformer
+        modules so we can set it with one that has pooling layer removed and
+        was loaded ready for TorchScript export.
+        """
+
+        self._st_model._modules["0"].auto_model = self._hf_model
+
+
+class _SentenceTransformerWrapper(_SentenceTransformerWrapperModule):
+    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
+        super().__init__(model=model, output_key=output_key)
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        attention_mask: Tensor,
+        token_type_ids: Tensor,
+        position_ids: Tensor,
+    ) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": token_type_ids,
+            "position_ids": position_ids,
+        }
+
+        # remove inputs for specific model types
+        if isinstance(self._hf_model.config, transformers.DistilBertConfig):
+            del inputs["token_type_ids"]
+
+        return self._st_model(inputs)[self._output_key]
+
+
+class _TwoParameterSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
+    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
+        super().__init__(model=model, output_key=output_key)
+
+    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        return self._st_model(inputs)[self._output_key]
+
+
+class _DPREncoderWrapper(nn.Module):  # type: ignore
+    """
+    AutoModel loading does not work for DPRContextEncoders, this only exists as
+    a workaround. This may never be fixed so this is likely permanent.
+    See: https://github.com/huggingface/transformers/issues/13670
+    """
+
+    _SUPPORTED_MODELS = {
+        transformers.DPRContextEncoder,
+        transformers.DPRQuestionEncoder,
+    }
+    _SUPPORTED_MODELS_NAMES = set([x.__name__ for x in _SUPPORTED_MODELS])
+
+    def __init__(
+        self,
+        model: Union[transformers.DPRContextEncoder, transformers.DPRQuestionEncoder],
+    ):
+        super().__init__()
+        self._model = model
+        self.config = model.config
+
+    @staticmethod
+    def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
+        config = AutoConfig.from_pretrained(model_id, token=token)
+
+        def is_compatible() -> bool:
+            is_dpr_model = config.model_type == "dpr"
+            has_architectures = (
+                config.architectures is not None and len(config.architectures) == 1
+            )
+            is_supported_architecture = has_architectures and (
+                config.architectures[0] in _DPREncoderWrapper._SUPPORTED_MODELS_NAMES
+            )
+            return is_dpr_model and is_supported_architecture
+
+        if is_compatible():
+            model = getattr(transformers, config.architectures[0]).from_pretrained(
+                model_id, torchscript=True
+            )
+            return _DPREncoderWrapper(model)
+        else:
+            return None
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        attention_mask: Tensor,
+        token_type_ids: Tensor,
+        _position_ids: Tensor,
+    ) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+
+        return self._model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+        )
diff --git a/tests/ml/pytorch/test_pytorch_model_config_pytest.py b/tests/ml/pytorch/test_pytorch_model_config_pytest.py
index c12be3a..6adc885 100644
--- a/tests/ml/pytorch/test_pytorch_model_config_pytest.py
+++ b/tests/ml/pytorch/test_pytorch_model_config_pytest.py
@@ -39,6 +39,7 @@ try:
     from eland.ml.pytorch import (
         FillMaskInferenceOptions,
         NlpBertTokenizationConfig,
+        NlpDebertaV2TokenizationConfig,
         NlpMPNetTokenizationConfig,
         NlpRobertaTokenizationConfig,
         NlpXLMRobertaTokenizationConfig,
@@ -149,6 +150,14 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
             1024,
             None,
         ),
+        (
+            "microsoft/deberta-v3-xsmall",
+            "fill_mask",
+            FillMaskInferenceOptions,
+            NlpDebertaV2TokenizationConfig,
+            512,
+            None,
+        ),
     ]
 else:
     MODEL_CONFIGURATIONS = []
diff --git a/tests/ml/pytorch/test_pytorch_model_upload_pytest.py b/tests/ml/pytorch/test_pytorch_model_upload_pytest.py
index c84a77e..09fa439 100644
--- a/tests/ml/pytorch/test_pytorch_model_upload_pytest.py
+++ b/tests/ml/pytorch/test_pytorch_model_upload_pytest.py
@@ -67,6 +67,8 @@ TEXT_EMBEDDING_MODELS = [
     )
 ]
 
+TEXT_SIMILARITY_MODELS = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
+
 
 @pytest.fixture(scope="function", autouse=True)
 def setup_and_tear_down():
@@ -135,3 +137,25 @@ class TestPytorchModel:
             )
             > 0
         )
+
+    @pytest.mark.skipif(
+        ES_VERSION < (8, 16, 0), reason="requires 8.16.0 for DeBERTa models"
+    )
+    @pytest.mark.parametrize("model_id", TEXT_SIMILARITY_MODELS)
+    def test_text_similarity(self, model_id):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            ptm = download_model_and_start_deployment(
+                tmp_dir, False, model_id, "text_similarity"
+            )
+            result = ptm.infer(
+                docs=[
+                    {
+                        "text_field": "The Amazon rainforest covers most of the Amazon basin in South America"
+                    },
+                    {"text_field": "Paris is the capital of France"},
+                ],
+                inference_config={"text_similarity": {"text": "France"}},
+            )
+
+            assert result.body["inference_results"][0]["predicted_value"] < 0
+            assert result.body["inference_results"][1]["predicted_value"] > 0
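
Reviewer note, not part of the patch: the new optional inference_config argument to PyTorchModel.infer is passed straight through as the inference_config field of the _infer request body, which is what test_text_similarity above relies on. A minimal sketch of a caller after this change; the client setup and the deployed model id are illustrative assumptions, not something this diff creates:

    from elasticsearch import Elasticsearch

    from eland.ml.pytorch import PyTorchModel

    # Hypothetical setup: assumes a text_similarity (rerank) model has already
    # been imported and deployed, e.g. with eland_import_hub_model; the id
    # below is an example, not one created by this patch.
    es = Elasticsearch("http://localhost:9200")
    model = PyTorchModel(es, "mixedbread-ai__mxbai-rerank-xsmall-v1")

    # inference_config is optional, so existing callers that omit it send the
    # same request body as before this patch.
    result = model.infer(
        docs=[{"text_field": "Paris is the capital of France"}],
        inference_config={"text_similarity": {"text": "France"}},
    )
    print(result.body["inference_results"][0]["predicted_value"])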
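
Reviewer note, not part of the patch: the tokeniser fix itself rests on DeBERTa-v2 tokenisers emitting token_type_ids while the traced model takes no position_ids, which is why transformers.py now returns a three-tensor tuple for DebertaV2Tokenizer instead of routing it through the two-parameter RoBERTa-family branch. A quick, self-contained way to check that contract (assumes the sentencepiece package is installed so the slow tokenizer class can load):

    import transformers

    # microsoft/deberta-v3-xsmall is the model exercised by the new config test.
    # use_fast=False yields transformers.DebertaV2Tokenizer, the class the new
    # isinstance checks in this patch match on.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "microsoft/deberta-v3-xsmall", use_fast=False
    )
    assert isinstance(tokenizer, transformers.DebertaV2Tokenizer)

    inputs = tokenizer("This is a test", return_tensors="pt")
    print(sorted(inputs.keys()))
    # ['attention_mask', 'input_ids', 'token_type_ids'] -- three inputs and no
    # position_ids, matching the tuple the tracing code now produces.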