From 5b3a83e7f21317997433c1da9f46e91bc3304f23 Mon Sep 17 00:00:00 2001
From: David Kyle
Date: Tue, 31 Oct 2023 17:49:43 +0000
Subject: [PATCH] [NLP] Support E5 small multi-lingual (#625)

Although E5 small is a BERT-based model, its forward function takes 2
parameters, not 4. Use the tokenizer type to decide the number of
parameters.
---
 eland/ml/pytorch/__init__.py                  |  1 +
 eland/ml/pytorch/transformers.py              | 30 ++++++++++-------
 ...py => test_pytorch_model_config_pytest.py} | 32 +++----------------
 3 files changed, 24 insertions(+), 39 deletions(-)
 rename tests/ml/pytorch/{test_pytorch_model_config.py => test_pytorch_model_config_pytest.py} (88%)

diff --git a/eland/ml/pytorch/__init__.py b/eland/ml/pytorch/__init__.py
index 465e80f..4a7ce3d 100644
--- a/eland/ml/pytorch/__init__.py
+++ b/eland/ml/pytorch/__init__.py
@@ -23,6 +23,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     NlpMPNetTokenizationConfig,
     NlpRobertaTokenizationConfig,
     NlpTrainedModelConfig,
+    NlpXLMRobertaTokenizationConfig,
     QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
     TextEmbeddingInferenceOptions,
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index fb15bec..5f90350 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -245,8 +245,13 @@ class _TwoParameterQuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
 
 class _DistilBertWrapper(nn.Module):  # type: ignore
     """
-    A simple wrapper around DistilBERT model which makes the model inputs
-    conform to Elasticsearch's native inference processor interface.
+    In Elasticsearch the BERT tokenizer is used for DistilBERT models, but
+    the BERT tokenizer produces 4 inputs where DistilBERT models expect 2.
+
+    Wrap the model's forward function in a method that accepts the 4
+    arguments passed to a BERT model, then discards the token_type_ids
+    and the position_ids to match the wrapped DistilBERT model's forward
+    function.
     """
 
     def __init__(self, model: transformers.PreTrainedModel):
@@ -293,18 +298,19 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
     @staticmethod
     def from_pretrained(
         model_id: str,
+        tokenizer: PreTrainedTokenizer,
         *,
         token: Optional[str] = None,
         output_key: str = DEFAULT_OUTPUT_KEY,
     ) -> Optional[Any]:
         model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
         if isinstance(
-            model.config,
+            tokenizer,
             (
-                transformers.MPNetConfig,
-                transformers.XLMRobertaConfig,
-                transformers.RobertaConfig,
-                transformers.BartConfig,
+                transformers.BartTokenizer,
+                transformers.MPNetTokenizer,
+                transformers.RobertaTokenizer,
+                transformers.XLMRobertaTokenizer,
             ),
         ):
             return _TwoParameterSentenceTransformerWrapper(model, output_key)
@@ -466,12 +472,12 @@ class _TransformerTraceableModel(TraceableModel):
                 inputs["input_ids"].size(1), dtype=torch.long
             )
         if isinstance(
-            self._model.config,
+            self._tokenizer,
             (
-                transformers.MPNetConfig,
-                transformers.XLMRobertaConfig,
-                transformers.RobertaConfig,
-                transformers.BartConfig,
+                transformers.BartTokenizer,
+                transformers.MPNetTokenizer,
+                transformers.RobertaTokenizer,
+                transformers.XLMRobertaTokenizer,
             ),
         ):
             del inputs["token_type_ids"]
@@ -812,7 +818,7 @@ class TransformerModel:
             )
             if not model:
                 model = _SentenceTransformerWrapperModule.from_pretrained(
-                    self._model_id, token=self._access_token
+                    self._model_id, self._tokenizer, token=self._access_token
                 )
             return _TraceableTextEmbeddingModel(self._tokenizer, model)
 
diff --git a/tests/ml/pytorch/test_pytorch_model_config.py b/tests/ml/pytorch/test_pytorch_model_config_pytest.py
similarity index 88%
rename from tests/ml/pytorch/test_pytorch_model_config.py
rename to tests/ml/pytorch/test_pytorch_model_config_pytest.py
index beac170..b826baa 100644
--- a/tests/ml/pytorch/test_pytorch_model_config.py
+++ b/tests/ml/pytorch/test_pytorch_model_config_pytest.py
@@ -38,10 +38,10 @@ try:
 
     from eland.ml.pytorch import (
         FillMaskInferenceOptions,
-        NerInferenceOptions,
         NlpBertTokenizationConfig,
         NlpMPNetTokenizationConfig,
         NlpRobertaTokenizationConfig,
+        NlpXLMRobertaTokenizationConfig,
         QuestionAnsweringInferenceOptions,
         TextClassificationInferenceOptions,
         TextEmbeddingInferenceOptions,
@@ -78,10 +78,10 @@ pytestmark = [
 if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
     MODEL_CONFIGURATIONS = [
         (
-            "intfloat/e5-small-v2",
+            "intfloat/multilingual-e5-small",
             "text_embedding",
             TextEmbeddingInferenceOptions,
-            NlpBertTokenizationConfig,
+            NlpXLMRobertaTokenizationConfig,
             512,
             384,
         ),
@@ -93,14 +93,6 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
             512,
             768,
         ),
-        (
-            "sentence-transformers/all-MiniLM-L12-v2",
-            "text_embedding",
-            TextEmbeddingInferenceOptions,
-            NlpBertTokenizationConfig,
-            512,
-            384,
-        ),
         (
             "facebook/dpr-ctx_encoder-multiset-base",
             "text_embedding",
@@ -117,22 +109,6 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
             512,
             None,
         ),
-        (
-            "bert-base-uncased",
-            "fill_mask",
-            FillMaskInferenceOptions,
-            NlpBertTokenizationConfig,
-            512,
-            None,
-        ),
-        (
-            "elastic/distilbert-base-uncased-finetuned-conll03-english",
-            "ner",
-            NerInferenceOptions,
-            NlpBertTokenizationConfig,
-            512,
-            None,
-        ),
         (
             "SamLowe/roberta-base-go_emotions",
             "text_classification",
@@ -214,3 +190,5 @@ class TestModelConfguration:
         if task_type == "zero_shot_classification":
             assert isinstance(config.inference_config.classification_labels, list)
             assert len(config.inference_config.classification_labels) > 0
+
+        del tm
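
The core idea of the patch, restated outside the diff as a minimal sketch: wrapper selection and input trimming now key off the tokenizer class rather than the model config. The tokenizer classes and the model id below come from the patch itself; the helper name expected_forward_arg_count is hypothetical (not part of eland), and loading with use_fast=False is an assumption made so the slow tokenizer classes named in the patch are the ones returned.

# Minimal sketch of the dispatch rule the patch introduces: the tokenizer
# class, not the model config, decides how many inputs the traced forward
# function receives. The tuple mirrors the isinstance check added in
# eland/ml/pytorch/transformers.py.
import transformers

TWO_PARAMETER_TOKENIZERS = (
    transformers.BartTokenizer,
    transformers.MPNetTokenizer,
    transformers.RobertaTokenizer,
    transformers.XLMRobertaTokenizer,
)


def expected_forward_arg_count(tokenizer: transformers.PreTrainedTokenizer) -> int:
    # Hypothetical helper, not part of eland: 2 inputs (input_ids,
    # attention_mask) for RoBERTa-family tokenizers, otherwise the 4
    # BERT-style inputs that also include token_type_ids and position_ids.
    return 2 if isinstance(tokenizer, TWO_PARAMETER_TOKENIZERS) else 4


# multilingual-e5-small ships an XLM-RoBERTa tokenizer, so even though the
# architecture is BERT-like its traced model is called with 2 inputs.
# use_fast=False (assumed) loads the slow tokenizer classes checked above.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "intfloat/multilingual-e5-small", use_fast=False
)
assert expected_forward_arg_count(tokenizer) == 2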