From 081250cdec4e54a2660a680b8b1b683ae03f86ae Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 21 Nov 2023 12:53:43 +0000 Subject: [PATCH] Fix failed import of ST RoBERTa models (#637) Fixes an error uploading the sentence-transformers/all-distilroberta-v1 model, which failed with "missing 2 required positional arguments: 'token_type_ids' and 'position_ids'". The cause was that the tokenizer type was not recognised due to a typo: transformers.RobertaConfig was listed where transformers.RobertaTokenizer was intended. --- eland/ml/pytorch/transformers.py | 2 +- tests/ml/pytorch/test_pytorch_model_config_pytest.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index c25b00a..ed047a7 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -311,7 +311,7 @@ class _SentenceTransformerWrapperModule(nn.Module): # type: ignore ( transformers.BartTokenizer, transformers.MPNetTokenizer, - transformers.RobertaConfig, + transformers.RobertaTokenizer, transformers.XLMRobertaTokenizer, ), ): diff --git a/tests/ml/pytorch/test_pytorch_model_config_pytest.py b/tests/ml/pytorch/test_pytorch_model_config_pytest.py index 09e632b..9c10f4b 100644 --- a/tests/ml/pytorch/test_pytorch_model_config_pytest.py +++ b/tests/ml/pytorch/test_pytorch_model_config_pytest.py @@ -77,6 +77,14 @@ pytestmark = [ # have been imported if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS: MODEL_CONFIGURATIONS = [ + ( + "sentence-transformers/all-distilroberta-v1", + "text_embedding", + TextEmbeddingInferenceOptions, + NlpRobertaTokenizationConfig, + 512, + 768, + ), ( "intfloat/multilingual-e5-small", "text_embedding",