Simplify embedding model support and loading (#569)

We were attempting to load SentenceTransformers by looking at the model
prefix, however SentenceTransformers can also be loaded from other
orgs in the model hub, as well as from local disk. This prefix checking
failed in those two cases. To simplify the loading logic and the choice
of wrapper, we’ve removed support for loading a plain Transformer for
text_embedding tasks. We now only support DPR embedding models and
SentenceTransformer embedding models. If you try to load a plain
Transformer model, it will be loaded by SentenceTransformers and a mean
pooling layer will automatically be added by the SentenceTransformer
library. Since we no longer automatically support non-DPR and
non-SentenceTransformers, we should include somewhere example code for
how to load a custom model without DPR or SentenceTransformers. 

See: https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/SentenceTransformer.py#L801

Resolves #531
This commit is contained in:
Josh Devins 2023-07-31 18:18:46 +02:00 committed by GitHub
parent 7ad1f430e4
commit f26fb8a430
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -294,22 +294,19 @@ class _SentenceTransformerWrapperModule(nn.Module): # type: ignore
def from_pretrained(
    model_id: str, output_key: str = DEFAULT_OUTPUT_KEY
) -> Optional[Any]:
    """
    Load an embedding model and wrap it for tracing as a
    SentenceTransformer-style model.

    Parameters
    ----------
    model_id : str
        Hugging Face hub model ID (any org) or a local path; loading is
        no longer gated on the "sentence-transformers/" prefix.
    output_key : str
        Key under which the wrapper exposes the embedding output.

    Returns
    -------
    Optional[Any]
        For MPNet/XLM-RoBERTa/RoBERTa/BART configs, a
        `_TwoParameterSentenceTransformerWrapper` (presumably because
        those architectures' forward() takes two inputs — confirm);
        otherwise a `_SentenceTransformerWrapper`.
    """
    # NOTE(review): the rendered diff interleaved the old and new bodies
    # here; this is the post-change version, which always loads the model
    # (no prefix check, no `return None` fallthrough).
    model = AutoModel.from_pretrained(model_id, torchscript=True)
    if isinstance(
        model.config,
        (
            transformers.MPNetConfig,
            transformers.XLMRobertaConfig,
            transformers.RobertaConfig,
            transformers.BartConfig,
        ),
    ):
        return _TwoParameterSentenceTransformerWrapper(model, output_key)
    else:
        return _SentenceTransformerWrapper(model, output_key)
def _remove_pooling_layer(self) -> None:
"""
@ -790,12 +787,10 @@ class TransformerModel:
return _TraceableTextClassificationModel(self._tokenizer, model)
elif self._task_type == "text_embedding":
model = _SentenceTransformerWrapperModule.from_pretrained(self._model_id)
model = _DPREncoderWrapper.from_pretrained(self._model_id)
if not model:
model = _DPREncoderWrapper.from_pretrained(self._model_id)
if not model:
model = transformers.AutoModel.from_pretrained(
self._model_id, torchscript=True
model = _SentenceTransformerWrapperModule.from_pretrained(
self._model_id
)
return _TraceableTextEmbeddingModel(self._tokenizer, model)
@ -805,20 +800,24 @@ class TransformerModel:
)
model = _DistilBertWrapper.try_wrapping(model)
return _TraceableZeroShotClassificationModel(self._tokenizer, model)
elif self._task_type == "question_answering":
model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
return _TraceableQuestionAnsweringModel(self._tokenizer, model)
elif self._task_type == "text_similarity":
model = transformers.AutoModelForSequenceClassification.from_pretrained(
self._model_id, torchscript=True
)
model = _DistilBertWrapper.try_wrapping(model)
return _TraceableTextSimilarityModel(self._tokenizer, model)
elif self._task_type == "pass_through":
model = transformers.AutoModel.from_pretrained(
self._model_id, torchscript=True
)
return _TraceablePassThroughModel(self._tokenizer, model)
else:
raise TypeError(
f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"