From 06b65e211e9a29ae490a0e92d447d575909c8423 Mon Sep 17 00:00:00 2001
From: Max Hniebergall <137079448+maxhniebergall@users.noreply.github.com>
Date: Thu, 3 Oct 2024 14:04:19 -0400
Subject: [PATCH] Add support for DeBERTa-V2 tokenizer (#717)

---
 eland/ml/pytorch/nlp_ml_model.py | 21 +++++++++++++++++++++
 eland/ml/pytorch/transformers.py |  9 +++++++++
 2 files changed, 30 insertions(+)

diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py
index 26222f3..012b080 100644
--- a/eland/ml/pytorch/nlp_ml_model.py
+++ b/eland/ml/pytorch/nlp_ml_model.py
@@ -86,6 +86,27 @@ class NlpXLMRobertaTokenizationConfig(NlpTokenizationConfig):
         )
 
 
+class DebertaV2Config(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        do_lower_case: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(
+            configuration_type="deberta_v2",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
+
+
 class NlpBertTokenizationConfig(NlpTokenizationConfig):
     def __init__(
         self,
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index 271a243..83faaf8 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -44,6 +44,7 @@ from transformers import (
 )
 
 from eland.ml.pytorch.nlp_ml_model import (
+    DebertaV2Config,
     FillMaskInferenceOptions,
     NerInferenceOptions,
     NlpBertJapaneseTokenizationConfig,
@@ -116,6 +117,7 @@ SUPPORTED_TOKENIZERS = (
     transformers.BartTokenizer,
     transformers.SqueezeBertTokenizer,
     transformers.XLMRobertaTokenizer,
+    transformers.DebertaV2Tokenizer,
 )
 SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
 
@@ -319,6 +321,7 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
                 transformers.MPNetTokenizer,
                 transformers.RobertaTokenizer,
                 transformers.XLMRobertaTokenizer,
+                transformers.DebertaV2Tokenizer,
             ),
         ):
             return _TwoParameterSentenceTransformerWrapper(model, output_key)
@@ -486,6 +489,7 @@ class _TransformerTraceableModel(TraceableModel):
                 transformers.MPNetTokenizer,
                 transformers.RobertaTokenizer,
                 transformers.XLMRobertaTokenizer,
+                transformers.DebertaV2Tokenizer,
             ),
         ):
             del inputs["token_type_ids"]
@@ -719,6 +723,11 @@ class TransformerModel:
             return NlpXLMRobertaTokenizationConfig(
                 max_sequence_length=_max_sequence_length
             )
+        elif isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
+            return DebertaV2Config(
+                max_sequence_length=_max_sequence_length,
+                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+            )
         else:
             japanese_morphological_tokenizers = ["mecab"]
             if (
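
Reviewer note: a minimal usage sketch of the config class this patch introduces, for anyone trying it out. Only the DebertaV2Config signature and the "deberta_v2" configuration_type come from the diff above; the argument values are illustrative assumptions, not part of the patch.

    # Sketch: constructing the tokenization config added by this patch.
    # Argument values are illustrative; the signature is from the diff above.
    from eland.ml.pytorch.nlp_ml_model import DebertaV2Config

    config = DebertaV2Config(
        with_special_tokens=True,
        max_sequence_length=512,
        truncate="first",  # per the signature: "first", "none", or "second"
    )

    # TransformerModel._create_tokenization_config now returns this type
    # whenever the loaded tokenizer is a transformers.DebertaV2Tokenizer;
    # DeBERTa-v3 checkpoints ship the same tokenizer class, so they take
    # this branch too.

With this in place, importing such a checkpoint via eland_import_hub_model should pass the SUPPORTED_TOKENIZERS check rather than being rejected, route text-embedding models through _TwoParameterSentenceTransformerWrapper, and drop token_type_ids from the traced inputs, as the three transformers.py hunks above wire up.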