diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py
index 26222f3..012b080 100644
--- a/eland/ml/pytorch/nlp_ml_model.py
+++ b/eland/ml/pytorch/nlp_ml_model.py
@@ -86,6 +86,29 @@ class NlpXLMRobertaTokenizationConfig(NlpTokenizationConfig):
         )
 
 
+class DebertaV2Config(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        do_lower_case: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(
+            configuration_type="deberta_v2",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
+        # Keep do_lower_case rather than silently dropping the argument;
+        # transformers.py passes it in when building this config.
+        self.do_lower_case = do_lower_case
+
+
 class NlpBertTokenizationConfig(NlpTokenizationConfig):
     def __init__(
         self,
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index 271a243..83faaf8 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -44,6 +44,7 @@ from transformers import (
 )
 
 from eland.ml.pytorch.nlp_ml_model import (
+    DebertaV2Config,
     FillMaskInferenceOptions,
     NerInferenceOptions,
     NlpBertJapaneseTokenizationConfig,
@@ -116,6 +117,7 @@ SUPPORTED_TOKENIZERS = (
     transformers.BartTokenizer,
     transformers.SqueezeBertTokenizer,
     transformers.XLMRobertaTokenizer,
+    transformers.DebertaV2Tokenizer,
 )
 SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
 
@@ -319,6 +321,7 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
                 transformers.MPNetTokenizer,
                 transformers.RobertaTokenizer,
                 transformers.XLMRobertaTokenizer,
+                transformers.DebertaV2Tokenizer,
             ),
         ):
             return _TwoParameterSentenceTransformerWrapper(model, output_key)
@@ -486,6 +489,7 @@ class _TransformerTraceableModel(TraceableModel):
                 transformers.MPNetTokenizer,
                 transformers.RobertaTokenizer,
                 transformers.XLMRobertaTokenizer,
+                transformers.DebertaV2Tokenizer,
             ),
         ):
             del inputs["token_type_ids"]
@@ -719,6 +723,11 @@ class TransformerModel:
             return NlpXLMRobertaTokenizationConfig(
                 max_sequence_length=_max_sequence_length
            )
+        elif isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
+            return DebertaV2Config(
+                max_sequence_length=_max_sequence_length,
+                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+            )
         else:
             japanese_morphological_tokenizers = ["mecab"]
             if (
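
A minimal sketch of how the new config class is constructed once the diff above is applied. The constructor call mirrors the one the diff adds to TransformerModel._create_tokenization_config; the argument values here are illustrative stand-ins for the tokenizer-derived ones, not fixed defaults.

    from eland.ml.pytorch.nlp_ml_model import DebertaV2Config

    # Build the tokenization config the same way the new elif branch does,
    # with example values in place of _max_sequence_length and the
    # tokenizer's do_lower_case attribute.
    config = DebertaV2Config(
        max_sequence_length=512,
        do_lower_case=False,
    )

With this change, any model whose tokenizer is a transformers.DebertaV2Tokenizer (e.g. DeBERTa-v2/v3 checkpoints) passes the SUPPORTED_TOKENIZERS check and takes this branch when eland builds the Elasticsearch tokenization configuration.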