diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index c25b00a..ed047a7 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -311,7 +311,7 @@ class _SentenceTransformerWrapperModule(nn.Module): # type: ignore ( transformers.BartTokenizer, transformers.MPNetTokenizer, - transformers.RobertaConfig, + transformers.RobertaTokenizer, transformers.XLMRobertaTokenizer, ), ): diff --git a/tests/ml/pytorch/test_pytorch_model_config_pytest.py b/tests/ml/pytorch/test_pytorch_model_config_pytest.py index 09e632b..9c10f4b 100644 --- a/tests/ml/pytorch/test_pytorch_model_config_pytest.py +++ b/tests/ml/pytorch/test_pytorch_model_config_pytest.py @@ -77,6 +77,14 @@ pytestmark = [ # have been imported if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS: MODEL_CONFIGURATIONS = [ + ( + "sentence-transformers/all-distilroberta-v1", + "text_embedding", + TextEmbeddingInferenceOptions, + NlpRobertaTokenizationConfig, + 512, + 768, + ), ( "intfloat/multilingual-e5-small", "text_embedding",