Add BertJapaneseTokenizer support with bert_ja tokenization configuration (#534)

See elasticsearch#95546
Dai Sugimori 2023-06-23 16:14:27 +09:00 committed by GitHub
parent 5fd1221815
commit bf3b092ed4
2 changed files with 43 additions and 6 deletions
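
For context, bert_ja is the tokenization type these options are sent under in the Elasticsearch trained-model configuration (see elasticsearch#95546). A rough sketch of the expected shape, with illustrative values only; the field names follow the constructor arguments added in the diff below:

# Illustrative only: the tokenization section a bert_ja config is
# expected to produce in the model's inference config.
tokenization = {
    "bert_ja": {
        "do_lower_case": False,
        "with_special_tokens": True,
        "max_sequence_length": 512,
        "truncate": "first",
    }
}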

eland/ml/pytorch/nlp_ml_model.py

@@ -108,6 +108,28 @@ class NlpBertTokenizationConfig(NlpTokenizationConfig):
         self.do_lower_case = do_lower_case


+class NlpBertJapaneseTokenizationConfig(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        do_lower_case: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(
+            configuration_type="bert_ja",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
+        self.do_lower_case = do_lower_case
+
+
 class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
     def __init__(
         self,
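
A minimal usage sketch of the new class, assuming the base class stores the configuration type and, like eland's other tokenization configs, serializes the non-None options under it via to_dict():

from eland.ml.pytorch.nlp_ml_model import NlpBertJapaneseTokenizationConfig

config = NlpBertJapaneseTokenizationConfig(
    do_lower_case=False,
    with_special_tokens=True,
    max_sequence_length=512,
)
# Assumed serialization shape: options nested under the "bert_ja" key
# passed to the parent as configuration_type.
print(config.to_dict())  # {'bert_ja': {'do_lower_case': False, ...}}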

eland/ml/pytorch/transformers.py

@@ -43,6 +43,7 @@ from transformers import (
 from eland.ml.pytorch.nlp_ml_model import (
     FillMaskInferenceOptions,
     NerInferenceOptions,
+    NlpBertJapaneseTokenizationConfig,
     NlpBertTokenizationConfig,
     NlpMPNetTokenizationConfig,
     NlpRobertaTokenizationConfig,
@@ -99,6 +100,7 @@ TASK_TYPE_TO_INFERENCE_CONFIG = {
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
     transformers.BertTokenizer,
+    transformers.BertJapaneseTokenizer,
     transformers.MPNetTokenizer,
     transformers.DPRContextEncoderTokenizer,
     transformers.DPRQuestionEncoderTokenizer,
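
One detail worth noting: in current transformers releases BertJapaneseTokenizer is its own PreTrainedTokenizer subclass rather than a subclass of BertTokenizer, so it has to be listed explicitly to pass the isinstance gate eland applies before converting a model. A sketch of that gate, with the tuple abbreviated:

import transformers

# Abbreviated copy of the tuple above; the real check runs when a
# Hugging Face model is loaded for conversion.
SUPPORTED = (
    transformers.BertTokenizer,
    transformers.BertJapaneseTokenizer,
)

def is_supported(tokenizer) -> bool:
    return isinstance(tokenizer, SUPPORTED)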
@@ -684,12 +686,25 @@ class TransformerModel:
                 ).get(self._model_id),
             )
         else:
-            return NlpBertTokenizationConfig(
-                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
-            )
+            japanese_morphological_tokenizers = ["mecab"]
+            if (
+                hasattr(self._tokenizer, "word_tokenizer_type")
+                and self._tokenizer.word_tokenizer_type
+                in japanese_morphological_tokenizers
+            ):
+                return NlpBertJapaneseTokenizationConfig(
+                    do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+                    max_sequence_length=getattr(
+                        self._tokenizer, "max_model_input_sizes", dict()
+                    ).get(self._model_id),
+                )
+            else:
+                return NlpBertTokenizationConfig(
+                    do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+                    max_sequence_length=getattr(
+                        self._tokenizer, "max_model_input_sizes", dict()
+                    ).get(self._model_id),
+                )

     def _create_config(
         self, es_version: Optional[Tuple[int, int, int]]
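
A standalone sketch of the detection rule above: MeCab-based BertJapaneseTokenizer instances carry a word_tokenizer_type attribute, while plain BERT tokenizers do not, so the hasattr probe separates the two. The checkpoint name is only an example, and MeCab-based tokenizers need the fugashi/unidic extras installed to load:

import transformers

# Example Japanese BERT checkpoint; any MeCab-based model behaves the same.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "cl-tohoku/bert-base-japanese-v2"
)
# BertJapaneseTokenizer records the word-level tokenizer it wraps;
# BertTokenizer has no such attribute.
word_tokenizer_type = getattr(tokenizer, "word_tokenizer_type", None)
print(word_tokenizer_type)               # "mecab" for this checkpoint
print(word_tokenizer_type in ["mecab"])  # True -> bert_ja config is emitted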