diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index 7b624d9..4731659 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -294,22 +294,19 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
     def from_pretrained(
         model_id: str, output_key: str = DEFAULT_OUTPUT_KEY
     ) -> Optional[Any]:
-        if model_id.startswith("sentence-transformers/"):
-            model = AutoModel.from_pretrained(model_id, torchscript=True)
-            if isinstance(
-                model.config,
-                (
-                    transformers.MPNetConfig,
-                    transformers.XLMRobertaConfig,
-                    transformers.RobertaConfig,
-                    transformers.BartConfig,
-                ),
-            ):
-                return _TwoParameterSentenceTransformerWrapper(model, output_key)
-            else:
-                return _SentenceTransformerWrapper(model, output_key)
+        model = AutoModel.from_pretrained(model_id, torchscript=True)
+        if isinstance(
+            model.config,
+            (
+                transformers.MPNetConfig,
+                transformers.XLMRobertaConfig,
+                transformers.RobertaConfig,
+                transformers.BartConfig,
+            ),
+        ):
+            return _TwoParameterSentenceTransformerWrapper(model, output_key)
         else:
-            return None
+            return _SentenceTransformerWrapper(model, output_key)
 
     def _remove_pooling_layer(self) -> None:
         """
@@ -790,12 +787,10 @@ class TransformerModel:
             return _TraceableTextClassificationModel(self._tokenizer, model)
 
         elif self._task_type == "text_embedding":
-            model = _SentenceTransformerWrapperModule.from_pretrained(self._model_id)
+            model = _DPREncoderWrapper.from_pretrained(self._model_id)
             if not model:
-                model = _DPREncoderWrapper.from_pretrained(self._model_id)
-            if not model:
-                model = transformers.AutoModel.from_pretrained(
-                    self._model_id, torchscript=True
+                model = _SentenceTransformerWrapperModule.from_pretrained(
+                    self._model_id
                 )
             return _TraceableTextEmbeddingModel(self._tokenizer, model)
 
@@ -805,20 +800,24 @@ class TransformerModel:
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableZeroShotClassificationModel(self._tokenizer, model)
+
         elif self._task_type == "question_answering":
             model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
             return _TraceableQuestionAnsweringModel(self._tokenizer, model)
+
         elif self._task_type == "text_similarity":
             model = transformers.AutoModelForSequenceClassification.from_pretrained(
                 self._model_id, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableTextSimilarityModel(self._tokenizer, model)
+
         elif self._task_type == "pass_through":
             model = transformers.AutoModel.from_pretrained(
                 self._model_id, torchscript=True
             )
             return _TraceablePassThroughModel(self._tokenizer, model)
+
         else:
             raise TypeError(
                 f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
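
Note (not part of the patch): below is a minimal, self-contained sketch of the new fallback order for the "text_embedding" task, using hypothetical stand-in loaders instead of the real eland wrappers. The DPR wrapper is now tried first and the sentence-transformer wrapper is the fallback; the old AutoModel fallback is dropped because _SentenceTransformerWrapperModule.from_pretrained no longer returns None for model ids outside the "sentence-transformers/" namespace.

from typing import Any, Optional


def load_dpr_encoder(model_id: str) -> Optional[Any]:
    # Hypothetical stand-in for _DPREncoderWrapper.from_pretrained:
    # returns None unless the model id looks like a DPR encoder.
    return "dpr-wrapper" if "dpr" in model_id else None


def load_sentence_transformer(model_id: str) -> Any:
    # Hypothetical stand-in for _SentenceTransformerWrapperModule.from_pretrained:
    # after this change it wraps any model id, not only "sentence-transformers/..." ones.
    return "sentence-transformer-wrapper"


def select_text_embedding_model(model_id: str) -> Any:
    # Mirrors the new text_embedding branch: prefer the DPR wrapper,
    # fall back to the sentence-transformer wrapper.
    model = load_dpr_encoder(model_id)
    if not model:
        model = load_sentence_transformer(model_id)
    return model


print(select_text_embedding_model("facebook/dpr-ctx_encoder-single-nq-base"))  # dpr-wrapper
print(select_text_embedding_model("sentence-transformers/all-MiniLM-L6-v2"))   # sentence-transformer-wrapper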