diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py
index 739ca45..dc5e055 100644
--- a/eland/ml/pytorch/nlp_ml_model.py
+++ b/eland/ml/pytorch/nlp_ml_model.py
@@ -108,6 +108,28 @@ class NlpBertTokenizationConfig(NlpTokenizationConfig):
         self.do_lower_case = do_lower_case
 
 
+class NlpBertJapaneseTokenizationConfig(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        do_lower_case: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(
+            configuration_type="bert_ja",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
+        self.do_lower_case = do_lower_case
+
+
 class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
     def __init__(
         self,
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index 742e42c..4f1e312 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -43,6 +43,7 @@ from transformers import (
 from eland.ml.pytorch.nlp_ml_model import (
     FillMaskInferenceOptions,
     NerInferenceOptions,
+    NlpBertJapaneseTokenizationConfig,
     NlpBertTokenizationConfig,
     NlpMPNetTokenizationConfig,
     NlpRobertaTokenizationConfig,
@@ -99,6 +100,7 @@ TASK_TYPE_TO_INFERENCE_CONFIG = {
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
     transformers.BertTokenizer,
+    transformers.BertJapaneseTokenizer,
     transformers.MPNetTokenizer,
     transformers.DPRContextEncoderTokenizer,
     transformers.DPRQuestionEncoderTokenizer,
@@ -684,12 +686,25 @@ class TransformerModel:
                 ).get(self._model_id),
             )
         else:
-            return NlpBertTokenizationConfig(
-                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
-            )
+            japanese_morphological_tokenizers = ["mecab"]
+            if (
+                hasattr(self._tokenizer, "word_tokenizer_type")
+                and self._tokenizer.word_tokenizer_type
+                in japanese_morphological_tokenizers
+            ):
+                return NlpBertJapaneseTokenizationConfig(
+                    do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+                    max_sequence_length=getattr(
+                        self._tokenizer, "max_model_input_sizes", dict()
+                    ).get(self._model_id),
+                )
+            else:
+                return NlpBertTokenizationConfig(
+                    do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+                    max_sequence_length=getattr(
+                        self._tokenizer, "max_model_input_sizes", dict()
+                    ).get(self._model_id),
+                )
 
     def _create_config(
         self, es_version: Optional[Tuple[int, int, int]]
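Usage sketch (illustrative only, not part of the patch): the snippet below shows when the new bert_ja branch would be selected. The model id cl-tohoku/bert-base-japanese and the AutoTokenizer call are assumptions for the example (loading it needs the optional MeCab dependencies such as fugashi); only the word_tokenizer_type check and the "bert_ja" configuration type come from the diff itself.

# Example only; demonstrates the condition the new branch checks before
# emitting an NlpBertJapaneseTokenizationConfig (configuration_type="bert_ja").
from transformers import AutoTokenizer

# Assumed example checkpoint whose tokenizer is a BertJapaneseTokenizer
# configured with the MeCab word tokenizer.
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")

print(type(tokenizer).__name__)       # BertJapaneseTokenizer
print(tokenizer.word_tokenizer_type)  # "mecab" -> Japanese morphological branch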