Add BertJapaneseTokenizer support with bert_ja tokenization configuration (#534)

See elasticsearch#95546
Author: Dai Sugimori, 2023-06-23 16:14:27 +09:00 (committed by GitHub)
parent 5fd1221815
commit bf3b092ed4
2 changed files with 43 additions and 6 deletions


@@ -108,6 +108,28 @@ class NlpBertTokenizationConfig(NlpTokenizationConfig):
         self.do_lower_case = do_lower_case
 
 
+class NlpBertJapaneseTokenizationConfig(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        do_lower_case: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(
+            configuration_type="bert_ja",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
+        self.do_lower_case = do_lower_case
+
+
 class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
     def __init__(
         self,
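For orientation, a minimal sketch of constructing the new class directly. The argument values below are illustrative, not taken from the commit; the only thing fixed by the class itself is that super().__init__ receives configuration_type="bert_ja".

from eland.ml.pytorch.nlp_ml_model import NlpBertJapaneseTokenizationConfig

# Illustrative values; Japanese BERT models are typically cased, hence do_lower_case=False.
config = NlpBertJapaneseTokenizationConfig(
    do_lower_case=False,
    with_special_tokens=True,
    max_sequence_length=512,
    truncate="first",
)
# The superclass call pins the configuration type to "bert_ja" for this object.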


@@ -43,6 +43,7 @@ from transformers import (
 from eland.ml.pytorch.nlp_ml_model import (
     FillMaskInferenceOptions,
     NerInferenceOptions,
+    NlpBertJapaneseTokenizationConfig,
     NlpBertTokenizationConfig,
     NlpMPNetTokenizationConfig,
     NlpRobertaTokenizationConfig,
@@ -99,6 +100,7 @@ TASK_TYPE_TO_INFERENCE_CONFIG = {
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
     transformers.BertTokenizer,
+    transformers.BertJapaneseTokenizer,
     transformers.MPNetTokenizer,
     transformers.DPRContextEncoderTokenizer,
     transformers.DPRQuestionEncoderTokenizer,
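The new SUPPORTED_TOKENIZERS entry works together with the detection logic in the next hunk, which keys off the tokenizer's word_tokenizer_type attribute. A quick hedged check of that attribute, assuming a MeCab-based Hugging Face model such as cl-tohoku/bert-base-japanese-v2 (the model name and the optional fugashi/unidic-lite dependencies are assumptions, not part of this commit):

from transformers import AutoTokenizer, BertJapaneseTokenizer

# BertJapaneseTokenizer has no fast variant, so AutoTokenizer returns the slow class.
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v2")

print(isinstance(tokenizer, BertJapaneseTokenizer))     # True
print(getattr(tokenizer, "word_tokenizer_type", None))  # "mecab" for this model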
@@ -684,12 +686,25 @@ class TransformerModel:
                 ).get(self._model_id),
             )
         else:
-            return NlpBertTokenizationConfig(
-                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
-            )
+            japanese_morphological_tokenizers = ["mecab"]
+            if (
+                hasattr(self._tokenizer, "word_tokenizer_type")
+                and self._tokenizer.word_tokenizer_type
+                in japanese_morphological_tokenizers
+            ):
+                return NlpBertJapaneseTokenizationConfig(
+                    do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+                    max_sequence_length=getattr(
+                        self._tokenizer, "max_model_input_sizes", dict()
+                    ).get(self._model_id),
+                )
+            else:
+                return NlpBertTokenizationConfig(
+                    do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+                    max_sequence_length=getattr(
+                        self._tokenizer, "max_model_input_sizes", dict()
+                    ).get(self._model_id),
+                )
 
     def _create_config(
         self, es_version: Optional[Tuple[int, int, int]]
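Taken together, the branch above routes MeCab-based BertJapaneseTokenizer instances to the bert_ja configuration and leaves every other BERT tokenizer on the plain bert configuration. A standalone sketch of the same selection logic outside the TransformerModel class; the helper name and wiring are hypothetical, only the branch itself mirrors the commit:

from typing import Optional, Union

from eland.ml.pytorch.nlp_ml_model import (
    NlpBertJapaneseTokenizationConfig,
    NlpBertTokenizationConfig,
)


def select_bert_tokenization_config(
    tokenizer, model_id: str
) -> Union[NlpBertJapaneseTokenizationConfig, NlpBertTokenizationConfig]:
    """Use bert_ja only for tokenizers that segment words with MeCab."""
    japanese_morphological_tokenizers = ["mecab"]
    do_lower_case: Optional[bool] = getattr(tokenizer, "do_lower_case", None)
    max_sequence_length: Optional[int] = getattr(
        tokenizer, "max_model_input_sizes", dict()
    ).get(model_id)

    if (
        getattr(tokenizer, "word_tokenizer_type", None)
        in japanese_morphological_tokenizers
    ):
        return NlpBertJapaneseTokenizationConfig(
            do_lower_case=do_lower_case,
            max_sequence_length=max_sequence_length,
        )
    return NlpBertTokenizationConfig(
        do_lower_case=do_lower_case,
        max_sequence_length=max_sequence_length,
    )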