Add support for DeBERTa-V2 tokenizer (#717)

Author: Max Hniebergall, 2024-10-03 14:04:19 -04:00 (committed by GitHub)
parent a45c7bc357
commit 06b65e211e
2 changed files with 30 additions and 0 deletions

eland/ml/pytorch/nlp_ml_model.py

@@ -86,6 +86,27 @@ class NlpXLMRobertaTokenizationConfig(NlpTokenizationConfig):
         )
+
+
+class DebertaV2Config(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        do_lower_case: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(
+            configuration_type="deberta_v2",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
 
 
 class NlpBertTokenizationConfig(NlpTokenizationConfig):
     def __init__(
         self,
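
For reference, a minimal sketch of constructing the new config class directly (the argument values are illustrative, not taken from the commit):

    from eland.ml.pytorch.nlp_ml_model import DebertaV2Config

    # configuration_type is fixed to "deberta_v2" by the constructor
    config = DebertaV2Config(
        max_sequence_length=512,
        with_special_tokens=True,
        truncate="first",
    )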

eland/ml/pytorch/transformers.py

@@ -44,6 +44,7 @@ from transformers import (
 )
 from eland.ml.pytorch.nlp_ml_model import (
+    DebertaV2Config,
     FillMaskInferenceOptions,
     NerInferenceOptions,
     NlpBertJapaneseTokenizationConfig,
@@ -116,6 +117,7 @@ SUPPORTED_TOKENIZERS = (
     transformers.BartTokenizer,
     transformers.SqueezeBertTokenizer,
     transformers.XLMRobertaTokenizer,
+    transformers.DebertaV2Tokenizer,
 )
 SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
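
A hedged sketch of the membership check this tuple enables (the model id is illustrative; use_fast=False is an assumption so that the slow DebertaV2Tokenizer class listed above, rather than its fast variant, is returned):

    import transformers

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "microsoft/deberta-v3-base", use_fast=False  # DeBERTa-V3 reuses DebertaV2Tokenizer
    )
    if not isinstance(tokenizer, SUPPORTED_TOKENIZERS):
        raise TypeError(
            f"Tokenizer {type(tokenizer)} not supported, expected one of: {SUPPORTED_TOKENIZERS_NAMES}"
        )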
@@ -319,6 +321,7 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
                 transformers.MPNetTokenizer,
                 transformers.RobertaTokenizer,
                 transformers.XLMRobertaTokenizer,
+                transformers.DebertaV2Tokenizer,
             ),
         ):
             return _TwoParameterSentenceTransformerWrapper(model, output_key)
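
The two-parameter wrapper selected here presumably forwards only input_ids and attention_mask to the model; a hypothetical stand-in illustrating that pattern (the class body is an assumption, not the commit's code):

    import torch.nn as nn

    class TwoParameterWrapperSketch(nn.Module):  # hypothetical stand-in
        def __init__(self, model: nn.Module, output_key: str):
            super().__init__()
            self._model = model
            self._output_key = output_key

        def forward(self, input_ids, attention_mask):
            # no token_type_ids: DeBERTa-V2 is handled like RoBERTa-style tokenizers here
            outputs = self._model(input_ids=input_ids, attention_mask=attention_mask)
            return outputs[self._output_key]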
@@ -486,6 +489,7 @@ class _TransformerTraceableModel(TraceableModel):
                 transformers.MPNetTokenizer,
                 transformers.RobertaTokenizer,
                 transformers.XLMRobertaTokenizer,
+                transformers.DebertaV2Tokenizer,
             ),
         ):
             del inputs["token_type_ids"]
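
The deletion matters because the traced forward signature for these tokenizers accepts only input_ids and attention_mask; a hedged sketch of the effect (tokenizer and wrapped_model are assumed names, not the commit's):

    import torch

    inputs = tokenizer("example text", return_tensors="pt")
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]  # drop segment ids the traced signature does not take
    traced = torch.jit.trace(
        wrapped_model, (inputs["input_ids"], inputs["attention_mask"])
    )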
@@ -719,6 +723,11 @@ class TransformerModel:
             return NlpXLMRobertaTokenizationConfig(
                 max_sequence_length=_max_sequence_length
             )
+        elif isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
+            return DebertaV2Config(
+                max_sequence_length=_max_sequence_length,
+                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+            )
         else:
             japanese_morphological_tokenizers = ["mecab"]
             if (
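
Taken together, these changes let a DeBERTa-V2/V3 checkpoint flow through TransformerModel like the other supported architectures. A hedged sketch of the tokenization config that now results for such a tokenizer (model id illustrative; eland derives _max_sequence_length more carefully than the shortcut below):

    import transformers
    from eland.ml.pytorch.nlp_ml_model import DebertaV2Config

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "microsoft/deberta-v3-base", use_fast=False
    )
    assert isinstance(tokenizer, transformers.DebertaV2Tokenizer)
    config = DebertaV2Config(
        max_sequence_length=min(tokenizer.model_max_length, 512),  # cap the sentinel value
        do_lower_case=getattr(tokenizer, "do_lower_case", None),
    )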