From f26fb8a4302eb1a26b56888ecd03bc5ae498ce4c Mon Sep 17 00:00:00 2001 From: Josh Devins Date: Mon, 31 Jul 2023 18:18:46 +0200 Subject: [PATCH] Simplify embedding model support and loading (#569) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We were attempting to load SentenceTransformers by looking at the model prefix; however, SentenceTransformers can also be loaded from other orgs in the model hub, as well as from local disk. This prefix checking failed in those two cases. To simplify the loading logic and the choice of which wrapper to use, we’ve removed support for loading a plain Transformer for text_embedding tasks. We now only support DPR embedding models and SentenceTransformer embedding models. If you try to load a plain Transformer model, it will be loaded by SentenceTransformers and a mean pooling layer will automatically be added by the SentenceTransformer library. Since we no longer automatically support models that are neither DPR nor SentenceTransformer models, we should include example code somewhere showing how to load a custom model without DPR or SentenceTransformers. 
See: https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/SentenceTransformer.py#L801 Resolves #531 --- eland/ml/pytorch/transformers.py | 39 ++++++++++++++++---------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index 7b624d9..4731659 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -294,22 +294,19 @@ class _SentenceTransformerWrapperModule(nn.Module): # type: ignore def from_pretrained( model_id: str, output_key: str = DEFAULT_OUTPUT_KEY ) -> Optional[Any]: - if model_id.startswith("sentence-transformers/"): - model = AutoModel.from_pretrained(model_id, torchscript=True) - if isinstance( - model.config, - ( - transformers.MPNetConfig, - transformers.XLMRobertaConfig, - transformers.RobertaConfig, - transformers.BartConfig, - ), - ): - return _TwoParameterSentenceTransformerWrapper(model, output_key) - else: - return _SentenceTransformerWrapper(model, output_key) + model = AutoModel.from_pretrained(model_id, torchscript=True) + if isinstance( + model.config, + ( + transformers.MPNetConfig, + transformers.XLMRobertaConfig, + transformers.RobertaConfig, + transformers.BartConfig, + ), + ): + return _TwoParameterSentenceTransformerWrapper(model, output_key) else: - return None + return _SentenceTransformerWrapper(model, output_key) def _remove_pooling_layer(self) -> None: """ @@ -790,12 +787,10 @@ class TransformerModel: return _TraceableTextClassificationModel(self._tokenizer, model) elif self._task_type == "text_embedding": - model = _SentenceTransformerWrapperModule.from_pretrained(self._model_id) + model = _DPREncoderWrapper.from_pretrained(self._model_id) if not model: - model = _DPREncoderWrapper.from_pretrained(self._model_id) - if not model: - model = transformers.AutoModel.from_pretrained( - self._model_id, torchscript=True + model = _SentenceTransformerWrapperModule.from_pretrained( + self._model_id 
) return _TraceableTextEmbeddingModel(self._tokenizer, model) @@ -805,20 +800,24 @@ class TransformerModel: ) model = _DistilBertWrapper.try_wrapping(model) return _TraceableZeroShotClassificationModel(self._tokenizer, model) + elif self._task_type == "question_answering": model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id) return _TraceableQuestionAnsweringModel(self._tokenizer, model) + elif self._task_type == "text_similarity": model = transformers.AutoModelForSequenceClassification.from_pretrained( self._model_id, torchscript=True ) model = _DistilBertWrapper.try_wrapping(model) return _TraceableTextSimilarityModel(self._tokenizer, model) + elif self._task_type == "pass_through": model = transformers.AutoModel.from_pretrained( self._model_id, torchscript=True ) return _TraceablePassThroughModel(self._tokenizer, model) + else: raise TypeError( f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"