mirror of https://github.com/elastic/eland.git
[ML] add ability to upload xlm-roberta tokenized models (#518)
This allows XLM-RoBERTa models to be uploaded to Elasticsearch. Blocked by: elastic/elasticsearch#94089
parent 68a22a8001
commit 8b327f60b8
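For orientation, a minimal end-to-end sketch of how a model that uses an XLM-RoBERTa tokenizer could be uploaded once this change (and the linked Elasticsearch change) is available. The model id is only an example, and the exact TransformerModel/PyTorchModel signatures may vary between eland versions:

    from pathlib import Path

    from elasticsearch import Elasticsearch

    from eland.ml.pytorch import PyTorchModel
    from eland.ml.pytorch.transformers import TransformerModel

    es = Elasticsearch("http://localhost:9200")  # assumes a local cluster with ML enabled

    # Any Hugging Face model built on an XLM-RoBERTa tokenizer; this id is just an example.
    tm = TransformerModel(
        model_id="cardiffnlp/twitter-xlm-roberta-base-sentiment",
        task_type="text_classification",
    )

    # Trace the model to TorchScript and write model, config, and vocabulary to disk.
    tmp_path = "models"
    Path(tmp_path).mkdir(parents=True, exist_ok=True)
    model_path, config, vocab_path = tm.save(tmp_path)

    # Upload the artifacts into Elasticsearch.
    ptm = PyTorchModel(es, tm.elasticsearch_model_id())
    ptm.import_model(
        model_path=model_path, config_path=None, vocab_path=vocab_path, config=config
    )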
@@ -31,6 +31,7 @@ __all__ = [
     "NlpTrainedModelConfig",
     "NlpBertTokenizationConfig",
     "NlpRobertaTokenizationConfig",
+    "NlpXLMRobertaTokenizationConfig",
     "NlpMPNetTokenizationConfig",
     "task_type_from_model_config",
 ]
@@ -66,6 +66,26 @@ class NlpRobertaTokenizationConfig(NlpTokenizationConfig):
         self.add_prefix_space = add_prefix_space
 
 
+class NlpXLMRobertaTokenizationConfig(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(
+            configuration_type="xlm_roberta",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
+
+
 class NlpBertTokenizationConfig(NlpTokenizationConfig):
     def __init__(
         self,
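A minimal sketch (not part of the diff) of how the new config class is meant to be used; the field values are illustrative, and the exact JSON accepted by the Elasticsearch create trained model API is defined by elastic/elasticsearch#94089:

    # Construct the tokenization config eland attaches to the trained model config.
    cfg = NlpXLMRobertaTokenizationConfig(max_sequence_length=512, truncate="first")

    # Because configuration_type is "xlm_roberta", this should serialize under an
    # "xlm_roberta" key, roughly:
    #   "tokenization": {"xlm_roberta": {"max_sequence_length": 512, "truncate": "first"}}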
@@ -48,6 +48,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     NlpRobertaTokenizationConfig,
     NlpTokenizationConfig,
     NlpTrainedModelConfig,
+    NlpXLMRobertaTokenizationConfig,
     PassThroughInferenceOptions,
     QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
@@ -108,6 +109,7 @@ SUPPORTED_TOKENIZERS = (
     transformers.RobertaTokenizer,
     transformers.BartTokenizer,
     transformers.SqueezeBertTokenizer,
+    transformers.XLMRobertaTokenizer,
 )
 SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
 
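SUPPORTED_TOKENIZERS is the allow-list eland checks before uploading, so with this change models whose tokenizer is an XLMRobertaTokenizer pass the gate. Roughly (a sketch only; the exact check and error message in eland may differ):

    if not isinstance(tokenizer, SUPPORTED_TOKENIZERS):
        raise TypeError(
            f"Tokenizer type {type(tokenizer)} not supported, "
            f"use one of: {SUPPORTED_TOKENIZERS_NAMES}"
        )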
@@ -183,6 +185,7 @@ class _QuestionAnsweringWrapperModule(nn.Module):  # type: ignore
             model.config,
             (
                 transformers.MPNetConfig,
+                transformers.XLMRobertaConfig,
                 transformers.RobertaConfig,
                 transformers.BartConfig,
             ),
@@ -295,6 +298,7 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
             model.config,
             (
                 transformers.MPNetConfig,
+                transformers.XLMRobertaConfig,
                 transformers.RobertaConfig,
                 transformers.BartConfig,
             ),
@@ -463,6 +467,7 @@ class _TransformerTraceableModel(TraceableModel):
             self._model.config,
             (
                 transformers.MPNetConfig,
+                transformers.XLMRobertaConfig,
                 transformers.RobertaConfig,
                 transformers.BartConfig,
             ),
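The three hunks above add transformers.XLMRobertaConfig to the same config tuple in the question answering, sentence transformer, and generic traceable-model wrappers. The practical effect, sketched below under the assumption that these tuples choose the traced call signature, is that XLM-RoBERTa models are traced RoBERTa-style, i.e. without token_type_ids:

    # Simplified sketch; the real wrappers also handle position_ids and
    # task-specific outputs.
    if isinstance(
        model.config,
        (
            transformers.MPNetConfig,
            transformers.XLMRobertaConfig,
            transformers.RobertaConfig,
            transformers.BartConfig,
        ),
    ):
        # RoBERTa-family models, XLM-RoBERTa included, do not use token_type_ids.
        output = model(input_ids=input_ids, attention_mask=attention_mask)
    else:
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )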
@@ -639,6 +644,20 @@ class TransformerModel:
                 " ".join(m) for m, _ in sorted(ranks.items(), key=lambda kv: kv[1])
             ]
             vocab_obj["merges"] = merges
+        sp_model = getattr(self._tokenizer, "sp_model", None)
+        if sp_model:
+            id_correction = getattr(self._tokenizer, "fairseq_offset", 0)
+            scores = []
+            for _ in range(0, id_correction):
+                scores.append(0.0)
+            for token_id in range(id_correction, len(vocabulary)):
+                try:
+                    scores.append(sp_model.get_score(token_id - id_correction))
+                except IndexError:
+                    scores.append(0.0)
+                    pass
+            if len(scores) > 0:
+                vocab_obj["scores"] = scores
         return vocab_obj
 
     def _create_tokenization_config(self) -> NlpTokenizationConfig:
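This block exports SentencePiece piece scores alongside the vocabulary, which the XLM-RoBERTa tokenizer in Elasticsearch needs; fairseq_offset accounts for the special tokens Hugging Face reserves ahead of the raw SentencePiece ids, so those positions receive a placeholder score of 0.0. An illustrative sketch of that offset (model id and values are examples only):

    from transformers import XLMRobertaTokenizer

    tok = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
    sp = tok.sp_model            # the underlying SentencePieceProcessor
    offset = tok.fairseq_offset  # number of ids reserved ahead of the SentencePiece pieces

    # Hugging Face vocabulary id i maps to SentencePiece piece id i - offset,
    # whose score is what ends up in vocab_obj["scores"][i].
    token_id = offset + 10
    print(token_id, tok.convert_ids_to_tokens(token_id), sp.get_score(token_id - offset))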
@@ -658,6 +677,12 @@ class TransformerModel:
                     self._tokenizer, "max_model_input_sizes", dict()
                 ).get(self._model_id),
             )
+        elif isinstance(self._tokenizer, transformers.XLMRobertaTokenizer):
+            return NlpXLMRobertaTokenizationConfig(
+                max_sequence_length=getattr(
+                    self._tokenizer, "max_model_input_sizes", dict()
+                ).get(self._model_id),
+            )
         else:
             return NlpBertTokenizationConfig(
                 do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
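The new branch only sets max_sequence_length, looked up from the tokenizer's max_model_input_sizes mapping for the given model id; with_special_tokens, truncate, and span are left unset and fall back to defaults. A hedged sketch of that lookup (the model id is an example, and the mapping may not contain every id, in which case the result is None):

    from transformers import XLMRobertaTokenizer

    tok = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
    # Typically 512 for the base/large checkpoints; None if the id is not registered.
    print(getattr(tok, "max_model_input_sizes", dict()).get("xlm-roberta-base"))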