Fix tokeniser for DeBERTa models (#769)

This commit is contained in:
David Kyle 2025-04-23 09:10:02 +01:00 committed by GitHub
parent 87380ef716
commit a9c36927f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 377 additions and 291 deletions

View File

@ -126,6 +126,7 @@ class PyTorchModel:
def infer(
self,
docs: List[Mapping[str, str]],
inference_config: Optional[Mapping[str, Any]] = None,
timeout: str = DEFAULT_TIMEOUT,
) -> Any:
if docs is None:
@ -133,6 +134,8 @@ class PyTorchModel:
__body: Dict[str, Any] = {}
__body["docs"] = docs
if inference_config is not None:
__body["inference_config"] = inference_config
__path = f"/_ml/trained_models/{_quote(self.model_id)}/_infer"
__query: Dict[str, Any] = {}

View File

@ -86,7 +86,7 @@ class NlpXLMRobertaTokenizationConfig(NlpTokenizationConfig):
)
class DebertaV2Config(NlpTokenizationConfig):
class NlpDebertaV2TokenizationConfig(NlpTokenizationConfig):
def __init__(
self,
*,

View File

@ -25,17 +25,13 @@ import os.path
import random
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Dict, List, Optional, Set, Tuple, Union
import torch # type: ignore
import transformers # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from torch import Tensor, nn
from torch import Tensor
from torch.profiler import profile # type: ignore
from transformers import (
AutoConfig,
AutoModel,
AutoModelForQuestionAnswering,
BertTokenizer,
PretrainedConfig,
PreTrainedModel,
@ -44,11 +40,11 @@ from transformers import (
)
from eland.ml.pytorch.nlp_ml_model import (
DebertaV2Config,
FillMaskInferenceOptions,
NerInferenceOptions,
NlpBertJapaneseTokenizationConfig,
NlpBertTokenizationConfig,
NlpDebertaV2TokenizationConfig,
NlpMPNetTokenizationConfig,
NlpRobertaTokenizationConfig,
NlpTokenizationConfig,
@ -65,8 +61,13 @@ from eland.ml.pytorch.nlp_ml_model import (
ZeroShotClassificationInferenceOptions,
)
from eland.ml.pytorch.traceable_model import TraceableModel
from eland.ml.pytorch.wrappers import (
_DistilBertWrapper,
_DPREncoderWrapper,
_QuestionAnsweringWrapperModule,
_SentenceTransformerWrapperModule,
)
DEFAULT_OUTPUT_KEY = "sentence_embedding"
SUPPORTED_TASK_TYPES = {
"fill_mask",
"ner",
@ -172,284 +173,6 @@ def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]
return potential_task_types.pop()
class _QuestionAnsweringWrapperModule(nn.Module): # type: ignore
"""
A wrapper around a question answering model.
Our inference engine only takes the first tuple if the inference response
is a tuple.
This wrapper transforms the output to be a stacked tensor if its a tuple.
Otherwise it passes it through
"""
def __init__(self, model: PreTrainedModel):
super().__init__()
self._hf_model = model
self.config = model.config
@staticmethod
def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
model = AutoModelForQuestionAnswering.from_pretrained(
model_id, token=token, torchscript=True
)
if isinstance(
model.config,
(
transformers.MPNetConfig,
transformers.XLMRobertaConfig,
transformers.RobertaConfig,
transformers.BartConfig,
),
):
return _TwoParameterQuestionAnsweringWrapper(model)
else:
return _QuestionAnsweringWrapper(model)
class _QuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
def __init__(self, model: PreTrainedModel):
super().__init__(model=model)
def forward(
self,
input_ids: Tensor,
attention_mask: Tensor,
token_type_ids: Tensor,
position_ids: Tensor,
) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"position_ids": position_ids,
}
# remove inputs for specific model types
if isinstance(self._hf_model.config, transformers.DistilBertConfig):
del inputs["token_type_ids"]
del inputs["position_ids"]
response = self._hf_model(**inputs)
if isinstance(response, tuple):
return torch.stack(list(response), dim=0)
return response
class _TwoParameterQuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
def __init__(self, model: PreTrainedModel):
super().__init__(model=model)
def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
response = self._hf_model(**inputs)
if isinstance(response, tuple):
return torch.stack(list(response), dim=0)
return response
class _DistilBertWrapper(nn.Module): # type: ignore
"""
In Elasticsearch the BERT tokenizer is used for DistilBERT models but
the BERT tokenizer produces 4 inputs where DistilBERT models expect 2.
Wrap the model's forward function in a method that accepts the 4
arguments passed to a BERT model then discard the token_type_ids
and the position_ids to match the wrapped DistilBERT model forward
function
"""
def __init__(self, model: transformers.PreTrainedModel):
super().__init__()
self._model = model
self.config = model.config
@staticmethod
def try_wrapping(model: PreTrainedModel) -> Optional[Any]:
if isinstance(model.config, transformers.DistilBertConfig):
return _DistilBertWrapper(model)
else:
return model
def forward(
self,
input_ids: Tensor,
attention_mask: Tensor,
_token_type_ids: Tensor = None,
_position_ids: Tensor = None,
) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
return self._model(input_ids=input_ids, attention_mask=attention_mask)
class _SentenceTransformerWrapperModule(nn.Module): # type: ignore
"""
A wrapper around sentence-transformer models to provide pooling,
normalization and other graph layers that are not defined in the base
HuggingFace transformer model.
"""
def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
super().__init__()
self._hf_model = model
self._st_model = SentenceTransformer(model.config.name_or_path)
self._output_key = output_key
self.config = model.config
self._remove_pooling_layer()
self._replace_transformer_layer()
@staticmethod
def from_pretrained(
model_id: str,
tokenizer: PreTrainedTokenizer,
*,
token: Optional[str] = None,
output_key: str = DEFAULT_OUTPUT_KEY,
) -> Optional[Any]:
model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
if isinstance(
tokenizer,
(
transformers.BartTokenizer,
transformers.MPNetTokenizer,
transformers.RobertaTokenizer,
transformers.XLMRobertaTokenizer,
transformers.DebertaV2Tokenizer,
),
):
return _TwoParameterSentenceTransformerWrapper(model, output_key)
else:
return _SentenceTransformerWrapper(model, output_key)
def _remove_pooling_layer(self) -> None:
"""
Removes any last pooling layer which is not used to create embeddings.
Leaving this layer in will cause it to return a NoneType which in turn
will fail to load in libtorch. Alternatively, we can just use the output
of the pooling layer as a dummy but this also affects (if only in a
minor way) the performance of inference, so we're better off removing
the layer if we can.
"""
if hasattr(self._hf_model, "pooler"):
self._hf_model.pooler = None
def _replace_transformer_layer(self) -> None:
"""
Replaces the HuggingFace Transformer layer in the SentenceTransformer
modules so we can set it with one that has pooling layer removed and
was loaded ready for TorchScript export.
"""
self._st_model._modules["0"].auto_model = self._hf_model
class _SentenceTransformerWrapper(_SentenceTransformerWrapperModule):
def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
super().__init__(model=model, output_key=output_key)
def forward(
self,
input_ids: Tensor,
attention_mask: Tensor,
token_type_ids: Tensor,
position_ids: Tensor,
) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"position_ids": position_ids,
}
# remove inputs for specific model types
if isinstance(self._hf_model.config, transformers.DistilBertConfig):
del inputs["token_type_ids"]
return self._st_model(inputs)[self._output_key]
class _TwoParameterSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
super().__init__(model=model, output_key=output_key)
def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return self._st_model(inputs)[self._output_key]
class _DPREncoderWrapper(nn.Module): # type: ignore
"""
AutoModel loading does not work for DPRContextEncoders, this only exists as
a workaround. This may never be fixed so this is likely permanent.
See: https://github.com/huggingface/transformers/issues/13670
"""
_SUPPORTED_MODELS = {
transformers.DPRContextEncoder,
transformers.DPRQuestionEncoder,
}
_SUPPORTED_MODELS_NAMES = set([x.__name__ for x in _SUPPORTED_MODELS])
def __init__(
self,
model: Union[transformers.DPRContextEncoder, transformers.DPRQuestionEncoder],
):
super().__init__()
self._model = model
self.config = model.config
@staticmethod
def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
config = AutoConfig.from_pretrained(model_id, token=token)
def is_compatible() -> bool:
is_dpr_model = config.model_type == "dpr"
has_architectures = (
config.architectures is not None and len(config.architectures) == 1
)
is_supported_architecture = has_architectures and (
config.architectures[0] in _DPREncoderWrapper._SUPPORTED_MODELS_NAMES
)
return is_dpr_model and is_supported_architecture
if is_compatible():
model = getattr(transformers, config.architectures[0]).from_pretrained(
model_id, torchscript=True
)
return _DPREncoderWrapper(model)
else:
return None
def forward(
self,
input_ids: Tensor,
attention_mask: Tensor,
token_type_ids: Tensor,
_position_ids: Tensor,
) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
return self._model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
class _TransformerTraceableModel(TraceableModel):
"""A base class representing a HuggingFace transformer model that can be traced."""
@ -489,12 +212,17 @@ class _TransformerTraceableModel(TraceableModel):
transformers.MPNetTokenizer,
transformers.RobertaTokenizer,
transformers.XLMRobertaTokenizer,
transformers.DebertaV2Tokenizer,
),
):
del inputs["token_type_ids"]
return (inputs["input_ids"], inputs["attention_mask"])
if isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
return (
inputs["input_ids"],
inputs["attention_mask"],
inputs["token_type_ids"],
)
position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
inputs["position_ids"] = position_ids
return (
@ -694,7 +422,12 @@ class TransformerModel:
" ".join(m) for m, _ in sorted(ranks.items(), key=lambda kv: kv[1])
]
vocab_obj["merges"] = merges
if isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
sp_model = self._tokenizer._tokenizer.spm
else:
sp_model = getattr(self._tokenizer, "sp_model", None)
if sp_model:
id_correction = getattr(self._tokenizer, "fairseq_offset", 0)
scores = []
@ -733,7 +466,7 @@ class TransformerModel:
max_sequence_length=_max_sequence_length
)
elif isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
return DebertaV2Config(
return NlpDebertaV2TokenizationConfig(
max_sequence_length=_max_sequence_length,
do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
)

View File

@ -0,0 +1,317 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains the wrapper classes for the Hugging Face models.
Wrapping is necessary to ensure that the forward method of the model
is called with the same arguments the ml-cpp pytorch_inference process
uses.
"""
from typing import Any, Optional, Union
import torch # type: ignore
import transformers # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from torch import Tensor, nn
from transformers import (
AutoConfig,
AutoModel,
AutoModelForQuestionAnswering,
PreTrainedModel,
PreTrainedTokenizer,
)
DEFAULT_OUTPUT_KEY = "sentence_embedding"
class _QuestionAnsweringWrapperModule(nn.Module): # type: ignore
"""
A wrapper around a question answering model.
Our inference engine only takes the first tuple if the inference response
is a tuple.
This wrapper transforms the output to be a stacked tensor if its a tuple.
Otherwise it passes it through
"""
def __init__(self, model: PreTrainedModel):
super().__init__()
self._hf_model = model
self.config = model.config
@staticmethod
def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
model = AutoModelForQuestionAnswering.from_pretrained(
model_id, token=token, torchscript=True
)
if isinstance(
model.config,
(
transformers.MPNetConfig,
transformers.XLMRobertaConfig,
transformers.RobertaConfig,
transformers.BartConfig,
),
):
return _TwoParameterQuestionAnsweringWrapper(model)
else:
return _QuestionAnsweringWrapper(model)
class _QuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
def __init__(self, model: PreTrainedModel):
super().__init__(model=model)
def forward(
self,
input_ids: Tensor,
attention_mask: Tensor,
token_type_ids: Tensor,
position_ids: Tensor,
) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"position_ids": position_ids,
}
# remove inputs for specific model types
if isinstance(self._hf_model.config, transformers.DistilBertConfig):
del inputs["token_type_ids"]
del inputs["position_ids"]
response = self._hf_model(**inputs)
if isinstance(response, tuple):
return torch.stack(list(response), dim=0)
return response
class _TwoParameterQuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
def __init__(self, model: PreTrainedModel):
super().__init__(model=model)
def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
response = self._hf_model(**inputs)
if isinstance(response, tuple):
return torch.stack(list(response), dim=0)
return response
class _DistilBertWrapper(nn.Module): # type: ignore
"""
In Elasticsearch the BERT tokenizer is used for DistilBERT models but
the BERT tokenizer produces 4 inputs where DistilBERT models expect 2.
Wrap the model's forward function in a method that accepts the 4
arguments passed to a BERT model then discard the token_type_ids
and the position_ids to match the wrapped DistilBERT model forward
function
"""
def __init__(self, model: transformers.PreTrainedModel):
super().__init__()
self._model = model
self.config = model.config
@staticmethod
def try_wrapping(model: PreTrainedModel) -> Optional[Any]:
if isinstance(model.config, transformers.DistilBertConfig):
return _DistilBertWrapper(model)
else:
return model
def forward(
self,
input_ids: Tensor,
attention_mask: Tensor,
_token_type_ids: Tensor = None,
_position_ids: Tensor = None,
) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
return self._model(input_ids=input_ids, attention_mask=attention_mask)
class _SentenceTransformerWrapperModule(nn.Module): # type: ignore
"""
A wrapper around sentence-transformer models to provide pooling,
normalization and other graph layers that are not defined in the base
HuggingFace transformer model.
"""
def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
super().__init__()
self._hf_model = model
self._st_model = SentenceTransformer(model.config.name_or_path)
self._output_key = output_key
self.config = model.config
self._remove_pooling_layer()
self._replace_transformer_layer()
@staticmethod
def from_pretrained(
model_id: str,
tokenizer: PreTrainedTokenizer,
*,
token: Optional[str] = None,
output_key: str = DEFAULT_OUTPUT_KEY,
) -> Optional[Any]:
model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
if isinstance(
tokenizer,
(
transformers.BartTokenizer,
transformers.MPNetTokenizer,
transformers.RobertaTokenizer,
transformers.XLMRobertaTokenizer,
transformers.DebertaV2Tokenizer,
),
):
return _TwoParameterSentenceTransformerWrapper(model, output_key)
else:
return _SentenceTransformerWrapper(model, output_key)
def _remove_pooling_layer(self) -> None:
"""
Removes any last pooling layer which is not used to create embeddings.
Leaving this layer in will cause it to return a NoneType which in turn
will fail to load in libtorch. Alternatively, we can just use the output
of the pooling layer as a dummy but this also affects (if only in a
minor way) the performance of inference, so we're better off removing
the layer if we can.
"""
if hasattr(self._hf_model, "pooler"):
self._hf_model.pooler = None
def _replace_transformer_layer(self) -> None:
"""
Replaces the HuggingFace Transformer layer in the SentenceTransformer
modules so we can set it with one that has pooling layer removed and
was loaded ready for TorchScript export.
"""
self._st_model._modules["0"].auto_model = self._hf_model
class _SentenceTransformerWrapper(_SentenceTransformerWrapperModule):
def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
super().__init__(model=model, output_key=output_key)
def forward(
self,
input_ids: Tensor,
attention_mask: Tensor,
token_type_ids: Tensor,
position_ids: Tensor,
) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"position_ids": position_ids,
}
# remove inputs for specific model types
if isinstance(self._hf_model.config, transformers.DistilBertConfig):
del inputs["token_type_ids"]
return self._st_model(inputs)[self._output_key]
class _TwoParameterSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
super().__init__(model=model, output_key=output_key)
def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return self._st_model(inputs)[self._output_key]
class _DPREncoderWrapper(nn.Module): # type: ignore
"""
AutoModel loading does not work for DPRContextEncoders, this only exists as
a workaround. This may never be fixed so this is likely permanent.
See: https://github.com/huggingface/transformers/issues/13670
"""
_SUPPORTED_MODELS = {
transformers.DPRContextEncoder,
transformers.DPRQuestionEncoder,
}
_SUPPORTED_MODELS_NAMES = set([x.__name__ for x in _SUPPORTED_MODELS])
def __init__(
self,
model: Union[transformers.DPRContextEncoder, transformers.DPRQuestionEncoder],
):
super().__init__()
self._model = model
self.config = model.config
@staticmethod
def from_pretrained(model_id: str, *, token: Optional[str] = None) -> Optional[Any]:
config = AutoConfig.from_pretrained(model_id, token=token)
def is_compatible() -> bool:
is_dpr_model = config.model_type == "dpr"
has_architectures = (
config.architectures is not None and len(config.architectures) == 1
)
is_supported_architecture = has_architectures and (
config.architectures[0] in _DPREncoderWrapper._SUPPORTED_MODELS_NAMES
)
return is_dpr_model and is_supported_architecture
if is_compatible():
model = getattr(transformers, config.architectures[0]).from_pretrained(
model_id, torchscript=True
)
return _DPREncoderWrapper(model)
else:
return None
def forward(
self,
input_ids: Tensor,
attention_mask: Tensor,
token_type_ids: Tensor,
_position_ids: Tensor,
) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
return self._model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)

View File

@ -39,6 +39,7 @@ try:
from eland.ml.pytorch import (
FillMaskInferenceOptions,
NlpBertTokenizationConfig,
NlpDebertaV2TokenizationConfig,
NlpMPNetTokenizationConfig,
NlpRobertaTokenizationConfig,
NlpXLMRobertaTokenizationConfig,
@ -149,6 +150,14 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
1024,
None,
),
(
"microsoft/deberta-v3-xsmall",
"fill_mask",
FillMaskInferenceOptions,
NlpDebertaV2TokenizationConfig,
512,
None,
),
]
else:
MODEL_CONFIGURATIONS = []

View File

@ -67,6 +67,8 @@ TEXT_EMBEDDING_MODELS = [
)
]
TEXT_SIMILARITY_MODELS = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
@pytest.fixture(scope="function", autouse=True)
def setup_and_tear_down():
@ -135,3 +137,25 @@ class TestPytorchModel:
)
> 0
)
@pytest.mark.skipif(
ES_VERSION < (8, 16, 0), reason="requires 8.16.0 for DeBERTa models"
)
@pytest.mark.parametrize("model_id", TEXT_SIMILARITY_MODELS)
def test_text_similarity(self, model_id):
with tempfile.TemporaryDirectory() as tmp_dir:
ptm = download_model_and_start_deployment(
tmp_dir, False, model_id, "text_similarity"
)
result = ptm.infer(
docs=[
{
"text_field": "The Amazon rainforest covers most of the Amazon basin in South America"
},
{"text_field": "Paris is the capital of France"},
],
inference_config={"text_similarity": {"text": "France"}},
)
assert result.body["inference_results"][0]["predicted_value"] < 0
assert result.body["inference_results"][1]["predicted_value"] > 0