Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)
Add BertJapaneseTokenizer support with bert_ja tokenization configuration (#534)
See elasticsearch#95546
This commit is contained in:
parent 5fd1221815
commit bf3b092ed4
@@ -108,6 +108,28 @@ class NlpBertTokenizationConfig(NlpTokenizationConfig):
         self.do_lower_case = do_lower_case
+
+
+class NlpBertJapaneseTokenizationConfig(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        do_lower_case: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(
+            configuration_type="bert_ja",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
+        self.do_lower_case = do_lower_case
 
 
 class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
     def __init__(
         self,
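For illustration, a minimal sketch of constructing the new config class directly. This snippet is not part of the commit; the argument values are assumptions chosen to show the keyword-only signature:

    from eland.ml.pytorch.nlp_ml_model import NlpBertJapaneseTokenizationConfig

    # All arguments are keyword-only and optional; unset fields stay None.
    config = NlpBertJapaneseTokenizationConfig(
        do_lower_case=False,
        with_special_tokens=True,
        max_sequence_length=512,
        truncate="first",
    )
    # Per the diff above, the superclass call fixes configuration_type to
    # "bert_ja", the name Elasticsearch uses for this tokenization config.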
@@ -43,6 +43,7 @@ from transformers import (
 from eland.ml.pytorch.nlp_ml_model import (
     FillMaskInferenceOptions,
     NerInferenceOptions,
+    NlpBertJapaneseTokenizationConfig,
     NlpBertTokenizationConfig,
     NlpMPNetTokenizationConfig,
     NlpRobertaTokenizationConfig,
@@ -99,6 +100,7 @@ TASK_TYPE_TO_INFERENCE_CONFIG = {
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
     transformers.BertTokenizer,
+    transformers.BertJapaneseTokenizer,
     transformers.MPNetTokenizer,
     transformers.DPRContextEncoderTokenizer,
     transformers.DPRQuestionEncoderTokenizer,
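Because SUPPORTED_TOKENIZERS is a plain tuple of classes, it can be passed directly to isinstance() as a guard. A sketch of that pattern, assuming the surrounding check and error message (the model id is illustrative, not from the commit):

    import transformers

    # Illustrative Japanese BERT checkpoint; loading its tokenizer
    # requires the MeCab extras (fugashi, ipadic) to be installed.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "cl-tohoku/bert-base-japanese"
    )
    # With BertJapaneseTokenizer added to the tuple, such tokenizers
    # now pass this kind of guard instead of being rejected.
    if not isinstance(tokenizer, SUPPORTED_TOKENIZERS):
        raise TypeError(f"Unsupported tokenizer type {type(tokenizer).__name__}")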
@@ -684,12 +686,25 @@ class TransformerModel:
                 ).get(self._model_id),
             )
         else:
-            return NlpBertTokenizationConfig(
-                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                max_sequence_length=getattr(
-                    self._tokenizer, "max_model_input_sizes", dict()
-                ).get(self._model_id),
-            )
+            japanese_morphological_tokenizers = ["mecab"]
+            if (
+                hasattr(self._tokenizer, "word_tokenizer_type")
+                and self._tokenizer.word_tokenizer_type
+                in japanese_morphological_tokenizers
+            ):
+                return NlpBertJapaneseTokenizationConfig(
+                    do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+                    max_sequence_length=getattr(
+                        self._tokenizer, "max_model_input_sizes", dict()
+                    ).get(self._model_id),
+                )
+            else:
+                return NlpBertTokenizationConfig(
+                    do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+                    max_sequence_length=getattr(
+                        self._tokenizer, "max_model_input_sizes", dict()
+                    ).get(self._model_id),
+                )
 
     def _create_config(
         self, es_version: Optional[Tuple[int, int, int]]
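The new branch keys off word_tokenizer_type, the attribute transformers' BertJapaneseTokenizer uses to name its word-level segmenter (e.g. "mecab"). A quick sketch of that detection in isolation; the checkpoint name is an illustrative assumption:

    import transformers

    tokenizer = transformers.BertJapaneseTokenizer.from_pretrained(
        "cl-tohoku/bert-base-japanese"  # illustrative MeCab-based model
    )
    # Plain BertTokenizer instances lack this attribute, hence the
    # hasattr() check in the diff; only MeCab-segmented tokenizers are
    # mapped to the bert_ja config for now.
    if getattr(tokenizer, "word_tokenizer_type", None) == "mecab":
        print("bert_ja tokenization config will be used")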