mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Simplify embedding model support and loading (#569)
We were attempting to detect SentenceTransformers models by looking at the model ID prefix; however, SentenceTransformers models can also be loaded from other organizations on the model hub, as well as from local disk, and the prefix check failed in both of those cases. To simplify the loading logic and the choice of wrapper, we have removed support for loading a plain Transformer for text_embedding tasks. We now support only DPR embedding models and SentenceTransformer embedding models. If you try to load a plain Transformer model, it will be loaded by SentenceTransformers, and the SentenceTransformer library will automatically add a mean pooling layer.

Since we no longer automatically support models that are neither DPR nor SentenceTransformers, we should include example code somewhere showing how to load such a custom model (see the sketch below). See: https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/SentenceTransformer.py#L801

Resolves #531
parent 7ad1f430e4
commit f26fb8a430
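The example code the message asks for could look roughly like the following. This is a minimal sketch, assuming the sentence-transformers v2.2.2 API linked above; "bert-base-uncased" is only a placeholder for a plain Transformer encoder:

from sentence_transformers import SentenceTransformer, models

# Placeholder model ID: any plain Transformer encoder, from the hub or local disk.
word_embedding_model = models.Transformer("bert-base-uncased")

# Mean pooling over token embeddings, the same layer SentenceTransformer's
# _load_auto_model fallback adds automatically.
pooling = models.Pooling(
    word_embedding_model.get_word_embedding_dimension(),
    pooling_mode="mean",
)

model = SentenceTransformer(modules=[word_embedding_model, pooling])
embeddings = model.encode(["an example sentence"])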
@@ -294,22 +294,19 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
     def from_pretrained(
         model_id: str, output_key: str = DEFAULT_OUTPUT_KEY
     ) -> Optional[Any]:
-        if model_id.startswith("sentence-transformers/"):
-            model = AutoModel.from_pretrained(model_id, torchscript=True)
-            if isinstance(
-                model.config,
-                (
-                    transformers.MPNetConfig,
-                    transformers.XLMRobertaConfig,
-                    transformers.RobertaConfig,
-                    transformers.BartConfig,
-                ),
-            ):
-                return _TwoParameterSentenceTransformerWrapper(model, output_key)
-            else:
-                return _SentenceTransformerWrapper(model, output_key)
+        model = AutoModel.from_pretrained(model_id, torchscript=True)
+        if isinstance(
+            model.config,
+            (
+                transformers.MPNetConfig,
+                transformers.XLMRobertaConfig,
+                transformers.RobertaConfig,
+                transformers.BartConfig,
+            ),
+        ):
+            return _TwoParameterSentenceTransformerWrapper(model, output_key)
         else:
-            return None
+            return _SentenceTransformerWrapper(model, output_key)
 
     def _remove_pooling_layer(self) -> None:
         """
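The isinstance check on the config in this hunk reflects a real difference in model signatures: MPNet-, XLM-RoBERTa-, RoBERTa- and BART-style models accept no token_type_ids, so their traced forward takes one fewer tensor. A rough sketch of the idea behind the two-parameter wrapper (illustrative only, not eland's actual wrapper implementation):

import torch

class TwoParameterWrapperSketch(torch.nn.Module):
    # For configs without token_type_ids: trace with exactly
    # (input_ids, attention_mask).
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self._model = model

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        # Models loaded with torchscript=True return tuples; element 0
        # holds the token embeddings to pool.
        return self._model(input_ids=input_ids, attention_mask=attention_mask)[0]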
@@ -790,12 +787,10 @@ class TransformerModel:
             return _TraceableTextClassificationModel(self._tokenizer, model)
 
         elif self._task_type == "text_embedding":
-            model = _SentenceTransformerWrapperModule.from_pretrained(self._model_id)
+            model = _DPREncoderWrapper.from_pretrained(self._model_id)
             if not model:
-                model = _DPREncoderWrapper.from_pretrained(self._model_id)
-            if not model:
-                model = transformers.AutoModel.from_pretrained(
-                    self._model_id, torchscript=True
+                model = _SentenceTransformerWrapperModule.from_pretrained(
+                    self._model_id
                 )
             return _TraceableTextEmbeddingModel(self._tokenizer, model)
 
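This reordering relies on _DPREncoderWrapper.from_pretrained returning None for non-DPR models, so that the SentenceTransformer wrapper becomes the catch-all. A hedged sketch of one way such a gate could work, assuming detection via the Hugging Face config's model_type (the wrapper's real check may differ):

from transformers import AutoConfig

model_id = "facebook/dpr-ctx_encoder-single-nq-base"  # example DPR model ID
config = AutoConfig.from_pretrained(model_id)

# DPR encoders report model_type == "dpr"; anything else would fall
# through to the SentenceTransformer path.
is_dpr = config.model_type == "dpr"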
@@ -805,20 +800,24 @@ class TransformerModel:
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableZeroShotClassificationModel(self._tokenizer, model)
 
+        elif self._task_type == "question_answering":
+            model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
+            return _TraceableQuestionAnsweringModel(self._tokenizer, model)
+
         elif self._task_type == "text_similarity":
             model = transformers.AutoModelForSequenceClassification.from_pretrained(
                 self._model_id, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableTextSimilarityModel(self._tokenizer, model)
 
         elif self._task_type == "pass_through":
             model = transformers.AutoModel.from_pretrained(
                 self._model_id, torchscript=True
             )
             return _TraceablePassThroughModel(self._tokenizer, model)
 
         else:
             raise TypeError(
                 f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
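Several branches above pass torchscript=True, which makes Hugging Face models return plain tuples instead of ModelOutput objects so that torch.jit.trace can handle them. A standalone sketch of that tracing step, with distilbert-base-uncased as an arbitrary example model:

import torch
import transformers

model = transformers.AutoModel.from_pretrained(
    "distilbert-base-uncased", torchscript=True
)
tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased")
model.eval()

inputs = tokenizer("an example sentence", return_tensors="pt")
# Positional example inputs must match the model's forward signature.
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))
traced.save("traced_model.pt")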