Mirror of https://github.com/elastic/eland.git, synced 2025-07-11 00:02:14 +08:00
Add support for DeBERTa-V2 tokenizer (#717)
commit 06b65e211e (parent a45c7bc357)
eland/ml/pytorch/nlp_ml_model.py

@@ -86,6 +86,27 @@ class NlpXLMRobertaTokenizationConfig(NlpTokenizationConfig):
         )
 
 
+class DebertaV2Config(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        do_lower_case: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(
+            configuration_type="deberta_v2",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
+
+
 class NlpBertTokenizationConfig(NlpTokenizationConfig):
     def __init__(
         self,
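For reference, a minimal sketch of constructing the new config class directly. The parameter values are illustrative, not taken from the commit; note that the constructor pins configuration_type to "deberta_v2".

    from eland.ml.pytorch.nlp_ml_model import DebertaV2Config

    # Illustrative values; truncate accepts "first", "none", or "second".
    config = DebertaV2Config(
        do_lower_case=False,
        with_special_tokens=True,
        max_sequence_length=512,
        truncate="first",
    )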
eland/ml/pytorch/transformers.py

@@ -44,6 +44,7 @@ from transformers import (
 )
 
 from eland.ml.pytorch.nlp_ml_model import (
+    DebertaV2Config,
     FillMaskInferenceOptions,
     NerInferenceOptions,
     NlpBertJapaneseTokenizationConfig,
@@ -116,6 +117,7 @@ SUPPORTED_TOKENIZERS = (
     transformers.BartTokenizer,
     transformers.SqueezeBertTokenizer,
     transformers.XLMRobertaTokenizer,
+    transformers.DebertaV2Tokenizer,
 )
 SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
 
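With transformers.DebertaV2Tokenizer added to SUPPORTED_TOKENIZERS, slow DeBERTa-V2/V3 tokenizers now pass eland's support check. A hedged sketch of that gate from the caller's side (the model name and error message are illustrative, not from the commit):

    import transformers

    from eland.ml.pytorch.transformers import (
        SUPPORTED_TOKENIZERS,
        SUPPORTED_TOKENIZERS_NAMES,
    )

    # use_fast=False yields the slow DebertaV2Tokenizer class listed above.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "microsoft/deberta-v3-base", use_fast=False
    )
    if not isinstance(tokenizer, SUPPORTED_TOKENIZERS):
        raise TypeError(
            f"Unsupported tokenizer; expected one of: {SUPPORTED_TOKENIZERS_NAMES}"
        )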
@@ -319,6 +321,7 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
                 transformers.MPNetTokenizer,
                 transformers.RobertaTokenizer,
                 transformers.XLMRobertaTokenizer,
+                transformers.DebertaV2Tokenizer,
             ),
         ):
             return _TwoParameterSentenceTransformerWrapper(model, output_key)
@@ -486,6 +489,7 @@ class _TransformerTraceableModel(TraceableModel):
                 transformers.MPNetTokenizer,
                 transformers.RobertaTokenizer,
                 transformers.XLMRobertaTokenizer,
+                transformers.DebertaV2Tokenizer,
             ),
         ):
             del inputs["token_type_ids"]
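As with the other tokenizers in this tuple, the traceable model is now called without token_type_ids for DeBERTa-V2. An illustrative sketch of that pruning step (model name and input text are placeholders):

    import transformers

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "microsoft/deberta-v3-base", use_fast=False
    )
    inputs = dict(tokenizer("example text", return_tensors="pt"))
    # Mirrors `del inputs["token_type_ids"]` above: drop the key if present.
    inputs.pop("token_type_ids", None)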
@@ -719,6 +723,11 @@ class TransformerModel:
             return NlpXLMRobertaTokenizationConfig(
                 max_sequence_length=_max_sequence_length
             )
+        elif isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
+            return DebertaV2Config(
+                max_sequence_length=_max_sequence_length,
+                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+            )
         else:
             japanese_morphological_tokenizers = ["mecab"]
             if (
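End to end, the new elif branch means a DeBERTa-V2 tokenizer maps to a DebertaV2Config when eland derives the tokenization config. A hedged sketch of the same dispatch outside the class (the model name and sequence length stand in for the instance state):

    import transformers

    from eland.ml.pytorch.nlp_ml_model import DebertaV2Config

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "microsoft/deberta-v3-base", use_fast=False
    )
    assert isinstance(tokenizer, transformers.DebertaV2Tokenizer)

    config = DebertaV2Config(
        max_sequence_length=512,  # stands in for _max_sequence_length
        do_lower_case=getattr(tokenizer, "do_lower_case", None),
    )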