[ML] add ability to upload xlm-roberta tokenized models (#518)

This allows XLM-RoBERTa tokenized models to be uploaded to Elasticsearch.

blocked by: elastic/elasticsearch#94089
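
As a usage reference, here is a minimal sketch of tracing and uploading an XLM-RoBERTa model with eland's PyTorch helpers; the hub model id, task type, target path, and exact constructor arguments are illustrative assumptions and may differ between eland versions:

# Hedged sketch: trace a Hugging Face XLM-RoBERTa checkpoint locally, then
# import the TorchScript model, vocabulary, and config into Elasticsearch.
from elasticsearch import Elasticsearch
from eland.ml.pytorch import PyTorchModel
from eland.ml.pytorch.transformers import TransformerModel

es = Elasticsearch("http://localhost:9200")

tm = TransformerModel(model_id="xlm-roberta-base", task_type="fill_mask")
model_path, config, vocab_path = tm.save("/tmp/models")

ptm = PyTorchModel(es, tm.elasticsearch_model_id())
ptm.import_model(
    model_path=model_path,
    config_path=None,
    vocab_path=vocab_path,
    config=config,
)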
Benjamin Trent authored 2023-06-14 07:59:28 -04:00, committed by GitHub
parent 68a22a8001
commit 8b327f60b8
3 changed files with 46 additions and 0 deletions


@@ -31,6 +31,7 @@ __all__ = [
"NlpTrainedModelConfig",
"NlpBertTokenizationConfig",
"NlpRobertaTokenizationConfig",
"NlpXLMRobertaTokenizationConfig",
"NlpMPNetTokenizationConfig",
"task_type_from_model_config",
]


@@ -66,6 +66,26 @@ class NlpRobertaTokenizationConfig(NlpTokenizationConfig):
self.add_prefix_space = add_prefix_space
class NlpXLMRobertaTokenizationConfig(NlpTokenizationConfig):
def __init__(
self,
*,
with_special_tokens: t.Optional[bool] = None,
max_sequence_length: t.Optional[int] = None,
truncate: t.Optional[
t.Union["t.Literal['first', 'none', 'second']", str]
] = None,
span: t.Optional[int] = None,
):
super().__init__(
configuration_type="xlm_roberta",
with_special_tokens=with_special_tokens,
max_sequence_length=max_sequence_length,
truncate=truncate,
span=span,
)
class NlpBertTokenizationConfig(NlpTokenizationConfig):
def __init__(
self,
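
The constructor above maps directly onto the xlm_roberta tokenization options in Elasticsearch's create trained model API; a hedged sketch of building the config (the 512 limit is an illustrative value, not taken from this commit):

from eland.ml.pytorch.nlp_ml_model import NlpXLMRobertaTokenizationConfig

# configuration_type="xlm_roberta" is set by the class itself, so these options
# end up under the xlm_roberta key of the trained model's tokenization config.
tok_config = NlpXLMRobertaTokenizationConfig(
    with_special_tokens=True,
    max_sequence_length=512,
)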


@@ -48,6 +48,7 @@ from eland.ml.pytorch.nlp_ml_model import (
NlpRobertaTokenizationConfig,
NlpTokenizationConfig,
NlpTrainedModelConfig,
NlpXLMRobertaTokenizationConfig,
PassThroughInferenceOptions,
QuestionAnsweringInferenceOptions,
TextClassificationInferenceOptions,
@@ -108,6 +109,7 @@ SUPPORTED_TOKENIZERS = (
transformers.RobertaTokenizer,
transformers.BartTokenizer,
transformers.SqueezeBertTokenizer,
transformers.XLMRobertaTokenizer,
)
SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
@@ -183,6 +185,7 @@ class _QuestionAnsweringWrapperModule(nn.Module): # type: ignore
model.config,
(
transformers.MPNetConfig,
transformers.XLMRobertaConfig,
transformers.RobertaConfig,
transformers.BartConfig,
),
@@ -295,6 +298,7 @@ class _SentenceTransformerWrapperModule(nn.Module): # type: ignore
model.config,
(
transformers.MPNetConfig,
transformers.XLMRobertaConfig,
transformers.RobertaConfig,
transformers.BartConfig,
),
@@ -463,6 +467,7 @@ class _TransformerTraceableModel(TraceableModel):
self._model.config,
(
transformers.MPNetConfig,
transformers.XLMRobertaConfig,
transformers.RobertaConfig,
transformers.BartConfig,
),
@@ -639,6 +644,20 @@ class TransformerModel:
" ".join(m) for m, _ in sorted(ranks.items(), key=lambda kv: kv[1])
]
vocab_obj["merges"] = merges
sp_model = getattr(self._tokenizer, "sp_model", None)
if sp_model:
id_correction = getattr(self._tokenizer, "fairseq_offset", 0)
scores = []
for _ in range(0, id_correction):
scores.append(0.0)
for token_id in range(id_correction, len(vocabulary)):
try:
scores.append(sp_model.get_score(token_id - id_correction))
except IndexError:
scores.append(0.0)
if len(scores) > 0:
vocab_obj["scores"] = scores
return vocab_obj
def _create_tokenization_config(self) -> NlpTokenizationConfig:
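
The scores block above reads per-piece log probabilities from the tokenizer's underlying sentencepiece model, padding the first fairseq_offset entries with 0.0 for the special tokens the tokenizer prepends; a hedged illustration of the attributes involved ("xlm-roberta-base" is only an example id):

import transformers

# XLMRobertaTokenizer wraps a sentencepiece model; fairseq_offset (typically 1)
# shifts sentencepiece ids to make room for the fairseq-style special tokens.
tok = transformers.XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
print(tok.fairseq_offset)          # expected: 1
print(tok.sp_model.get_score(5))   # log probability of sentencepiece piece 5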
@@ -658,6 +677,12 @@ class TransformerModel:
self._tokenizer, "max_model_input_sizes", dict()
).get(self._model_id),
)
elif isinstance(self._tokenizer, transformers.XLMRobertaTokenizer):
return NlpXLMRobertaTokenizationConfig(
max_sequence_length=getattr(
self._tokenizer, "max_model_input_sizes", dict()
).get(self._model_id),
)
else:
return NlpBertTokenizationConfig(
do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
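
With this dispatch in place, tracing a model whose tokenizer is transformers.XLMRobertaTokenizer produces an NlpXLMRobertaTokenizationConfig; a hedged check of where max_sequence_length comes from (the model id and the 512 value depend on the transformers version):

import transformers

# _create_tokenization_config reads the limit from the tokenizer's
# max_model_input_sizes mapping, keyed by the hub model id.
tok = transformers.XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
print(getattr(tok, "max_model_input_sizes", {}).get("xlm-roberta-base"))  # e.g. 512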