Add support for DeBERTa-V2 tokenizer (#717)

This commit is contained in:
Max Hniebergall 2024-10-03 14:04:19 -04:00 committed by GitHub
parent a45c7bc357
commit 06b65e211e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 0 deletions

View File

@ -86,6 +86,27 @@ class NlpXLMRobertaTokenizationConfig(NlpTokenizationConfig):
)
class DebertaV2Config(NlpTokenizationConfig):
def __init__(
self,
*,
do_lower_case: t.Optional[bool] = None,
with_special_tokens: t.Optional[bool] = None,
max_sequence_length: t.Optional[int] = None,
truncate: t.Optional[
t.Union["t.Literal['first', 'none', 'second']", str]
] = None,
span: t.Optional[int] = None,
):
super().__init__(
configuration_type="deberta_v2",
with_special_tokens=with_special_tokens,
max_sequence_length=max_sequence_length,
truncate=truncate,
span=span,
)
class NlpBertTokenizationConfig(NlpTokenizationConfig):
def __init__(
self,

View File

@ -44,6 +44,7 @@ from transformers import (
)
from eland.ml.pytorch.nlp_ml_model import (
DebertaV2Config,
FillMaskInferenceOptions,
NerInferenceOptions,
NlpBertJapaneseTokenizationConfig,
@ -116,6 +117,7 @@ SUPPORTED_TOKENIZERS = (
transformers.BartTokenizer,
transformers.SqueezeBertTokenizer,
transformers.XLMRobertaTokenizer,
transformers.DebertaV2Tokenizer,
)
SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
@ -319,6 +321,7 @@ class _SentenceTransformerWrapperModule(nn.Module): # type: ignore
transformers.MPNetTokenizer,
transformers.RobertaTokenizer,
transformers.XLMRobertaTokenizer,
transformers.DebertaV2Tokenizer,
),
):
return _TwoParameterSentenceTransformerWrapper(model, output_key)
@ -486,6 +489,7 @@ class _TransformerTraceableModel(TraceableModel):
transformers.MPNetTokenizer,
transformers.RobertaTokenizer,
transformers.XLMRobertaTokenizer,
transformers.DebertaV2Tokenizer,
),
):
del inputs["token_type_ids"]
@ -719,6 +723,11 @@ class TransformerModel:
return NlpXLMRobertaTokenizationConfig(
max_sequence_length=_max_sequence_length
)
elif isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
return DebertaV2Config(
max_sequence_length=_max_sequence_length,
do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
)
else:
japanese_morphological_tokenizers = ["mecab"]
if (