diff --git a/eland/ml/pytorch/__init__.py b/eland/ml/pytorch/__init__.py
index 5cd49b5..4862028 100644
--- a/eland/ml/pytorch/__init__.py
+++ b/eland/ml/pytorch/__init__.py
@@ -31,6 +31,7 @@ __all__ = [
     "NlpTrainedModelConfig",
     "NlpBertTokenizationConfig",
     "NlpRobertaTokenizationConfig",
+    "NlpXLMRobertaTokenizationConfig",
     "NlpMPNetTokenizationConfig",
     "task_type_from_model_config",
 ]
diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py
index d21b9f4..739ca45 100644
--- a/eland/ml/pytorch/nlp_ml_model.py
+++ b/eland/ml/pytorch/nlp_ml_model.py
@@ -66,6 +66,26 @@ class NlpRobertaTokenizationConfig(NlpTokenizationConfig):
         self.add_prefix_space = add_prefix_space
 
 
+class NlpXLMRobertaTokenizationConfig(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(
+            configuration_type="xlm_roberta",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
+
+
 class NlpBertTokenizationConfig(NlpTokenizationConfig):
     def __init__(
         self,
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index 9b2b78e..742e42c 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -48,6 +48,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     NlpRobertaTokenizationConfig,
     NlpTokenizationConfig,
     NlpTrainedModelConfig,
+    NlpXLMRobertaTokenizationConfig,
     PassThroughInferenceOptions,
     QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
@@ -108,6 +109,7 @@ SUPPORTED_TOKENIZERS = (
     transformers.RobertaTokenizer,
     transformers.BartTokenizer,
     transformers.SqueezeBertTokenizer,
+    transformers.XLMRobertaTokenizer,
 )
 SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
 
@@ -183,6 +185,7 @@ class _QuestionAnsweringWrapperModule(nn.Module):  # type: ignore
             model.config,
             (
                 transformers.MPNetConfig,
+                transformers.XLMRobertaConfig,
                 transformers.RobertaConfig,
                 transformers.BartConfig,
             ),
@@ -295,6 +298,7 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
             model.config,
             (
                 transformers.MPNetConfig,
+                transformers.XLMRobertaConfig,
                 transformers.RobertaConfig,
                 transformers.BartConfig,
             ),
@@ -463,6 +467,7 @@ class _TransformerTraceableModel(TraceableModel):
             self._model.config,
             (
                 transformers.MPNetConfig,
+                transformers.XLMRobertaConfig,
                 transformers.RobertaConfig,
                 transformers.BartConfig,
             ),
@@ -639,6 +644,20 @@ class TransformerModel:
                 " ".join(m) for m, _ in sorted(ranks.items(), key=lambda kv: kv[1])
             ]
             vocab_obj["merges"] = merges
+        sp_model = getattr(self._tokenizer, "sp_model", None)
+        if sp_model:
+            id_correction = getattr(self._tokenizer, "fairseq_offset", 0)
+            scores = []
+            for _ in range(0, id_correction):
+                scores.append(0.0)
+            for token_id in range(id_correction, len(vocabulary)):
+                try:
+                    scores.append(sp_model.get_score(token_id - id_correction))
+                except IndexError:
+                    scores.append(0.0)
+                    pass
+            if len(scores) > 0:
+                vocab_obj["scores"] = scores
         return vocab_obj
 
     def _create_tokenization_config(self) -> NlpTokenizationConfig:
@@ -658,6 +677,12 @@ class TransformerModel:
                     self._tokenizer, "max_model_input_sizes", dict()
                 ).get(self._model_id),
             )
+        elif isinstance(self._tokenizer, transformers.XLMRobertaTokenizer):
+            return NlpXLMRobertaTokenizationConfig(
+                max_sequence_length=getattr(
+                    self._tokenizer, "max_model_input_sizes", dict()
+                ).get(self._model_id),
+            )
         else:
             return NlpBertTokenizationConfig(
                 do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
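
A minimal usage sketch (not part of the patch) for the NlpXLMRobertaTokenizationConfig class added above; the sequence length and truncate value are illustrative assumptions, while the parameter names come straight from the constructor in nlp_ml_model.py:

    from eland.ml.pytorch import NlpXLMRobertaTokenizationConfig

    # The superclass call pins configuration_type to "xlm_roberta" and forwards
    # the remaining options; in transformers.py eland derives max_sequence_length
    # from the tokenizer's max_model_input_sizes entry for the model id.
    config = NlpXLMRobertaTokenizationConfig(
        max_sequence_length=512,   # assumed value for illustration
        with_special_tokens=True,
        truncate="first",
    )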