mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Add support for DeBERTa-V2 tokenizer (#717)
This commit is contained in:
parent
a45c7bc357
commit
06b65e211e
@ -86,6 +86,27 @@ class NlpXLMRobertaTokenizationConfig(NlpTokenizationConfig):
|
||||
)
|
||||
|
||||
|
||||
class DebertaV2Config(NlpTokenizationConfig):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
do_lower_case: t.Optional[bool] = None,
|
||||
with_special_tokens: t.Optional[bool] = None,
|
||||
max_sequence_length: t.Optional[int] = None,
|
||||
truncate: t.Optional[
|
||||
t.Union["t.Literal['first', 'none', 'second']", str]
|
||||
] = None,
|
||||
span: t.Optional[int] = None,
|
||||
):
|
||||
super().__init__(
|
||||
configuration_type="deberta_v2",
|
||||
with_special_tokens=with_special_tokens,
|
||||
max_sequence_length=max_sequence_length,
|
||||
truncate=truncate,
|
||||
span=span,
|
||||
)
|
||||
|
||||
|
||||
class NlpBertTokenizationConfig(NlpTokenizationConfig):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -44,6 +44,7 @@ from transformers import (
|
||||
)
|
||||
|
||||
from eland.ml.pytorch.nlp_ml_model import (
|
||||
DebertaV2Config,
|
||||
FillMaskInferenceOptions,
|
||||
NerInferenceOptions,
|
||||
NlpBertJapaneseTokenizationConfig,
|
||||
@ -116,6 +117,7 @@ SUPPORTED_TOKENIZERS = (
|
||||
transformers.BartTokenizer,
|
||||
transformers.SqueezeBertTokenizer,
|
||||
transformers.XLMRobertaTokenizer,
|
||||
transformers.DebertaV2Tokenizer,
|
||||
)
|
||||
SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
|
||||
|
||||
@ -319,6 +321,7 @@ class _SentenceTransformerWrapperModule(nn.Module): # type: ignore
|
||||
transformers.MPNetTokenizer,
|
||||
transformers.RobertaTokenizer,
|
||||
transformers.XLMRobertaTokenizer,
|
||||
transformers.DebertaV2Tokenizer,
|
||||
),
|
||||
):
|
||||
return _TwoParameterSentenceTransformerWrapper(model, output_key)
|
||||
@ -486,6 +489,7 @@ class _TransformerTraceableModel(TraceableModel):
|
||||
transformers.MPNetTokenizer,
|
||||
transformers.RobertaTokenizer,
|
||||
transformers.XLMRobertaTokenizer,
|
||||
transformers.DebertaV2Tokenizer,
|
||||
),
|
||||
):
|
||||
del inputs["token_type_ids"]
|
||||
@ -719,6 +723,11 @@ class TransformerModel:
|
||||
return NlpXLMRobertaTokenizationConfig(
|
||||
max_sequence_length=_max_sequence_length
|
||||
)
|
||||
elif isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
|
||||
return DebertaV2Config(
|
||||
max_sequence_length=_max_sequence_length,
|
||||
do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
|
||||
)
|
||||
else:
|
||||
japanese_morphological_tokenizers = ["mecab"]
|
||||
if (
|
||||
|
Loading…
x
Reference in New Issue
Block a user