Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)
[ML] add ability to upload xlm-roberta tokenized models (#518)
This allows XLM-RoBERTa-tokenized models to be uploaded to Elasticsearch. Blocked by: elastic/elasticsearch#94089
parent 68a22a8001
commit 8b327f60b8
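For context, this is roughly how an XLM-RoBERTa model would be uploaded once this change (and the Elasticsearch side tracked in elastic/elasticsearch#94089) landed. A hedged sketch against the eland Python API of that era; the model id, task type, and connection details are illustrative assumptions, not part of this diff.

```python
import tempfile

from elasticsearch import Elasticsearch
from eland.ml.pytorch import PyTorchModel
from eland.ml.pytorch.transformers import TransformerModel

# Trace the Hugging Face model and serialize it for Elasticsearch.
# "xlm-roberta-base" and "fill_mask" are illustrative choices.
tm = TransformerModel("xlm-roberta-base", "fill_mask")
with tempfile.TemporaryDirectory() as tmp_dir:
    model_path, config, vocab_path = tm.save(tmp_dir)

    # Upload the traced model, its config, and the vocabulary
    # (which now includes SentencePiece scores; see below).
    es = Elasticsearch("http://localhost:9200")  # assumed local cluster
    ptm = PyTorchModel(es, tm.elasticsearch_model_id())
    ptm.import_model(model_path=model_path, config=config, vocab_path=vocab_path)
```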
@@ -31,6 +31,7 @@ __all__ = [
     "NlpTrainedModelConfig",
     "NlpBertTokenizationConfig",
     "NlpRobertaTokenizationConfig",
+    "NlpXLMRobertaTokenizationConfig",
     "NlpMPNetTokenizationConfig",
     "task_type_from_model_config",
 ]
@@ -66,6 +66,26 @@ class NlpRobertaTokenizationConfig(NlpTokenizationConfig):
         self.add_prefix_space = add_prefix_space
 
 
+class NlpXLMRobertaTokenizationConfig(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(
+            configuration_type="xlm_roberta",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
+
+
 class NlpBertTokenizationConfig(NlpTokenizationConfig):
     def __init__(
         self,
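The new config class only pins configuration_type to "xlm_roberta"; the base class does the rest. The base class isn't shown in this diff, so the following is a minimal hypothetical sketch of the serialization contract it presumably implements, mirroring the tokenization object in the Elasticsearch create-trained-model API:

```python
import typing as t


class NlpTokenizationConfig:
    # Hypothetical stand-in for the real base class in
    # eland/ml/pytorch/nlp_ml_model.py; names and behavior are assumed.
    def __init__(self, *, configuration_type: str, **options: t.Any):
        self.name = configuration_type
        # Drop unset options so the API body stays minimal.
        self._options = {k: v for k, v in options.items() if v is not None}

    def to_dict(self) -> t.Dict[str, t.Any]:
        # Keyed by tokenizer type, matching Elasticsearch's
        # inference_config.<task>.tokenization object.
        return {self.name: self._options}


cfg = NlpTokenizationConfig(
    configuration_type="xlm_roberta",
    with_special_tokens=True,
    max_sequence_length=512,
)
print(cfg.to_dict())
# {'xlm_roberta': {'with_special_tokens': True, 'max_sequence_length': 512}}
```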
@@ -48,6 +48,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     NlpRobertaTokenizationConfig,
     NlpTokenizationConfig,
     NlpTrainedModelConfig,
+    NlpXLMRobertaTokenizationConfig,
     PassThroughInferenceOptions,
     QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
@@ -108,6 +109,7 @@ SUPPORTED_TOKENIZERS = (
     transformers.RobertaTokenizer,
     transformers.BartTokenizer,
     transformers.SqueezeBertTokenizer,
+    transformers.XLMRobertaTokenizer,
 )
 SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
 
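Note that only the slow, SentencePiece-backed transformers.XLMRobertaTokenizer is allow-listed, not the fast variant. Given the SUPPORTED_TOKENIZERS tuple above, the gate it feeds looks roughly like this; the loading call and error message are assumptions, not part of this diff:

```python
import transformers

# use_fast=False yields transformers.XLMRobertaTokenizer, the
# SentencePiece-backed class added to SUPPORTED_TOKENIZERS above.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "xlm-roberta-base", use_fast=False
)
if not isinstance(tokenizer, SUPPORTED_TOKENIZERS):  # a tuple works with isinstance
    raise TypeError(
        f"Tokenizer {type(tokenizer).__name__} not supported, "
        f"must be one of: {SUPPORTED_TOKENIZERS_NAMES}"
    )
```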
@@ -183,6 +185,7 @@ class _QuestionAnsweringWrapperModule(nn.Module):  # type: ignore
             model.config,
             (
                 transformers.MPNetConfig,
+                transformers.XLMRobertaConfig,
                 transformers.RobertaConfig,
                 transformers.BartConfig,
             ),
@@ -295,6 +298,7 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
             model.config,
             (
                 transformers.MPNetConfig,
+                transformers.XLMRobertaConfig,
                 transformers.RobertaConfig,
                 transformers.BartConfig,
             ),
@@ -463,6 +467,7 @@ class _TransformerTraceableModel(TraceableModel):
             self._model.config,
             (
                 transformers.MPNetConfig,
+                transformers.XLMRobertaConfig,
                 transformers.RobertaConfig,
                 transformers.BartConfig,
             ),
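All three hunks above extend the same isinstance gate: in eland's wrapper modules this tuple decides whether the traced forward pass receives token_type_ids, which RoBERTa-family models (XLM-RoBERTa included) don't use. A hedged sketch of the pattern, with names simplified from the surrounding code:

```python
import transformers


def _forward(model, input_ids, attention_mask, token_type_ids):
    # RoBERTa-family models have no token-type embeddings, so the
    # traced call must omit token_type_ids for them.
    if isinstance(
        model.config,
        (
            transformers.MPNetConfig,
            transformers.XLMRobertaConfig,
            transformers.RobertaConfig,
            transformers.BartConfig,
        ),
    ):
        return model(input_ids=input_ids, attention_mask=attention_mask)
    return model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )
```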
@@ -639,6 +644,20 @@ class TransformerModel:
                 " ".join(m) for m, _ in sorted(ranks.items(), key=lambda kv: kv[1])
             ]
             vocab_obj["merges"] = merges
+        sp_model = getattr(self._tokenizer, "sp_model", None)
+        if sp_model:
+            id_correction = getattr(self._tokenizer, "fairseq_offset", 0)
+            scores = []
+            for _ in range(0, id_correction):
+                scores.append(0.0)
+            for token_id in range(id_correction, len(vocabulary)):
+                try:
+                    scores.append(sp_model.get_score(token_id - id_correction))
+                except IndexError:
+                    scores.append(0.0)
+                    pass
+            if len(scores) > 0:
+                vocab_obj["scores"] = scores
         return vocab_obj
 
     def _create_tokenization_config(self) -> NlpTokenizationConfig:
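This is the substance of the change: XLM-RoBERTa's vocabulary is a SentencePiece unigram model, so the upload must carry each piece's log-probability score for Elasticsearch to reproduce the tokenization server-side. The tokenizer's ids are shifted by fairseq_offset relative to the raw SentencePiece model (1 for XLM-RoBERTa), so the positions below the offset are padded with 0.0, and ids past the end of the SentencePiece model (such as the appended <mask> token) also fall back to 0.0. A standalone sketch of the same extraction run against the real tokenizer; the model choice is illustrative:

```python
from transformers import XLMRobertaTokenizer

tok = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
sp = tok.sp_model            # a sentencepiece.SentencePieceProcessor
offset = tok.fairseq_offset  # 1 for XLM-RoBERTa

scores = [0.0] * offset      # ids below the offset have no sp score
for token_id in range(offset, len(tok)):
    try:
        scores.append(sp.get_score(token_id - offset))
    except IndexError:
        scores.append(0.0)   # e.g. <mask>, added after the sp model

assert len(scores) == len(tok)
```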
@@ -658,6 +677,12 @@ class TransformerModel:
                     self._tokenizer, "max_model_input_sizes", dict()
                 ).get(self._model_id),
             )
+        elif isinstance(self._tokenizer, transformers.XLMRobertaTokenizer):
+            return NlpXLMRobertaTokenizationConfig(
+                max_sequence_length=getattr(
+                    self._tokenizer, "max_model_input_sizes", dict()
+                ).get(self._model_id),
+            )
         else:
             return NlpBertTokenizationConfig(
                 do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
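The max_sequence_length passed here comes from the tokenizer class's max_model_input_sizes table, keyed by model id; the table is a class attribute in transformers releases of that era, so treat its presence as version-dependent:

```python
from transformers import XLMRobertaTokenizer

# For the stock checkpoints this table maps to 512; unknown model ids
# yield None, leaving max_sequence_length unset in the config.
print(XLMRobertaTokenizer.max_model_input_sizes.get("xlm-roberta-base"))   # 512
print(XLMRobertaTokenizer.max_model_input_sizes.get("my-finetuned-xlmr"))  # None
```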