Mirror of https://github.com/elastic/eland.git, synced 2025-07-11 00:02:14 +08:00
Add BertJapaneseTokenizer support with bert_ja tokenization configuration (#534)
See elasticsearch#95546
parent 5fd1221815
commit bf3b092ed4
@@ -108,6 +108,28 @@ class NlpBertTokenizationConfig(NlpTokenizationConfig):
         self.do_lower_case = do_lower_case
+
+
+class NlpBertJapaneseTokenizationConfig(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        do_lower_case: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(
+            configuration_type="bert_ja",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
+        self.do_lower_case = do_lower_case
 
 
 class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
     def __init__(
         self,
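For illustration (not part of the commit): the new class is constructed like its bert counterpart. A minimal sketch, where the field values and the serialization note are assumptions based on the sibling NlpTokenizationConfig subclasses:

from eland.ml.pytorch.nlp_ml_model import NlpBertJapaneseTokenizationConfig

# Build a bert_ja tokenization config as the importer would for a
# MeCab-backed Japanese BERT model (the values here are assumptions).
config = NlpBertJapaneseTokenizationConfig(
    do_lower_case=False,
    max_sequence_length=512,
)

# Assumption: like the sibling config classes, this serializes under a
# "bert_ja" key in the trained-model configuration sent to Elasticsearch,
# roughly: "tokenization": {"bert_ja": {"do_lower_case": false, ...}}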
@@ -43,6 +43,7 @@ from transformers import (
 from eland.ml.pytorch.nlp_ml_model import (
     FillMaskInferenceOptions,
     NerInferenceOptions,
+    NlpBertJapaneseTokenizationConfig,
     NlpBertTokenizationConfig,
     NlpMPNetTokenizationConfig,
     NlpRobertaTokenizationConfig,
@@ -99,6 +100,7 @@ TASK_TYPE_TO_INFERENCE_CONFIG = {
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
     transformers.BertTokenizer,
+    transformers.BertJapaneseTokenizer,
     transformers.MPNetTokenizer,
     transformers.DPRContextEncoderTokenizer,
     transformers.DPRQuestionEncoderTokenizer,
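A quick sketch of why this entry is needed (not in the commit): in recent transformers releases BertJapaneseTokenizer is not a subclass of BertTokenizer, so without it the importer's isinstance gate would reject Japanese BERT checkpoints. The model id below is an assumed example of a MeCab-backed checkpoint:

import transformers

# The two relevant entries of the tuple extended above.
SUPPORTED_TOKENIZERS = (
    transformers.BertTokenizer,
    transformers.BertJapaneseTokenizer,
)

# Assumed example checkpoint; loading it also needs the MeCab extras
# (e.g. fugashi and an ipadic dictionary package) installed.
tokenizer = transformers.BertJapaneseTokenizer.from_pretrained(
    "cl-tohoku/bert-base-japanese"
)
assert isinstance(tokenizer, SUPPORTED_TOKENIZERS)  # passes with the new entry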
@@ -684,12 +686,25 @@ class TransformerModel:
                 ).get(self._model_id),
             )
         else:
-            return NlpBertTokenizationConfig(
-                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
-            )
+            japanese_morphological_tokenizers = ["mecab"]
+            if (
+                hasattr(self._tokenizer, "word_tokenizer_type")
+                and self._tokenizer.word_tokenizer_type
+                in japanese_morphological_tokenizers
+            ):
+                return NlpBertJapaneseTokenizationConfig(
+                    do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+                    max_sequence_length=getattr(
+                        self._tokenizer, "max_model_input_sizes", dict()
+                    ).get(self._model_id),
+                )
+            else:
+                return NlpBertTokenizationConfig(
+                    do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+                    max_sequence_length=getattr(
+                        self._tokenizer, "max_model_input_sizes", dict()
+                    ).get(self._model_id),
+                )
 
     def _create_config(
         self, es_version: Optional[Tuple[int, int, int]]
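For reference, the branch added above keys off the word_tokenizer_type attribute that transformers.BertJapaneseTokenizer exposes; plain BERT tokenizers lack the attribute and fall through to the existing bert config. A minimal sketch of the same check in isolation, again assuming cl-tohoku/bert-base-japanese as a MeCab-backed example:

import transformers

tokenizer = transformers.BertJapaneseTokenizer.from_pretrained(
    "cl-tohoku/bert-base-japanese"  # assumed example with a MeCab word tokenizer
)

# Mirrors the detection added above: MeCab-based Japanese tokenizers select
# the bert_ja configuration, everything else keeps the bert configuration.
if getattr(tokenizer, "word_tokenizer_type", None) in ["mecab"]:
    print("bert_ja tokenization")
else:
    print("bert tokenization")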