Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)
[NLP] Support E5 small multi-lingual (#625)
Although E5 small is a BERT-based model, its forward function takes 2 parameters, not 4. Use the tokenizer type to decide the number of parameters.
parent ab6e44f430
commit 5b3a83e7f2
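
The heart of the change: a BERT tokenizer produces token_type_ids in addition to input_ids and attention_mask (and Elasticsearch's inference interface also supplies position_ids), while the XLM-RoBERTa tokenizer that multilingual E5 uses produces only two inputs. A minimal sketch of the difference, assuming the transformers package and the public Hub model ids:

from transformers import AutoTokenizer

# A BERT tokenizer emits token_type_ids; the XLM-RoBERTa tokenizer used by
# multilingual E5 small does not, so the traced model's forward function
# must accept 2 inputs rather than 4.
for model_id in ("bert-base-uncased", "intfloat/multilingual-e5-small"):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    print(model_id, sorted(tokenizer("hello world").keys()))

# bert-base-uncased               ['attention_mask', 'input_ids', 'token_type_ids']
# intfloat/multilingual-e5-small  ['attention_mask', 'input_ids']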
eland/ml/pytorch/transformers.py
@@ -23,6 +23,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     NlpMPNetTokenizationConfig,
     NlpRobertaTokenizationConfig,
     NlpTrainedModelConfig,
+    NlpXLMRobertaTokenizationConfig,
     QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
     TextEmbeddingInferenceOptions,
@@ -245,8 +245,13 @@ class _TwoParameterQuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
 
 class _DistilBertWrapper(nn.Module):  # type: ignore
     """
-    A simple wrapper around DistilBERT model which makes the model inputs
-    conform to Elasticsearch's native inference processor interface.
+    In Elasticsearch the BERT tokenizer is used for DistilBERT models but
+    the BERT tokenizer produces 4 inputs where DistilBERT models expect 2.
+
+    Wrap the model's forward function in a method that accepts the 4
+    arguments passed to a BERT model then discard the token_type_ids
+    and the position_ids to match the wrapped DistilBERT model forward
+    function
     """
 
     def __init__(self, model: transformers.PreTrainedModel):
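
The wrapper pattern this docstring describes, as a self-contained sketch (the class name FourToTwoInputWrapper is hypothetical, not eland's):

import torch
from torch import nn

class FourToTwoInputWrapper(nn.Module):
    # Accept the 4 BERT-style arguments Elasticsearch passes, then forward
    # only the 2 that a DistilBERT-style model understands.
    def __init__(self, model: nn.Module):
        super().__init__()
        self._model = model

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        # token_type_ids and position_ids are deliberately discarded
        return self._model(input_ids=input_ids, attention_mask=attention_mask)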
@@ -293,18 +298,19 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
     @staticmethod
     def from_pretrained(
         model_id: str,
+        tokenizer: PreTrainedTokenizer,
         *,
         token: Optional[str] = None,
         output_key: str = DEFAULT_OUTPUT_KEY,
     ) -> Optional[Any]:
         model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
         if isinstance(
-            model.config,
+            tokenizer,
             (
-                transformers.MPNetConfig,
-                transformers.XLMRobertaConfig,
-                transformers.RobertaConfig,
-                transformers.BartConfig,
+                transformers.BartTokenizer,
+                transformers.MPNetTokenizer,
+                transformers.RobertaTokenizer,
+                transformers.XLMRobertaTokenizer,
             ),
         ):
             return _TwoParameterSentenceTransformerWrapper(model, output_key)
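
How the new dispatch behaves for the model this PR targets, as a sketch; it assumes the tokenizer is loaded as a slow class (the fast variants such as XLMRobertaTokenizerFast are separate classes and would not match this tuple):

import transformers
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "intfloat/multilingual-e5-small", use_fast=False
)
two_inputs = isinstance(
    tokenizer,
    (
        transformers.BartTokenizer,
        transformers.MPNetTokenizer,
        transformers.RobertaTokenizer,
        transformers.XLMRobertaTokenizer,
    ),
)
print(two_inputs)  # True, so _TwoParameterSentenceTransformerWrapper is chosen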
@@ -466,12 +472,12 @@ class _TransformerTraceableModel(TraceableModel):
                 inputs["input_ids"].size(1), dtype=torch.long
             )
         if isinstance(
-            self._model.config,
+            self._tokenizer,
             (
-                transformers.MPNetConfig,
-                transformers.XLMRobertaConfig,
-                transformers.RobertaConfig,
-                transformers.BartConfig,
+                transformers.BartTokenizer,
+                transformers.MPNetTokenizer,
+                transformers.RobertaTokenizer,
+                transformers.XLMRobertaTokenizer,
             ),
         ):
             del inputs["token_type_ids"]
@@ -812,7 +818,7 @@ class TransformerModel:
             )
             if not model:
                 model = _SentenceTransformerWrapperModule.from_pretrained(
-                    self._model_id, token=self._access_token
+                    self._model_id, self._tokenizer, token=self._access_token
                 )
             return _TraceableTextEmbeddingModel(self._tokenizer, model)
 
tests/ml/pytorch/test_pytorch_model_config.py
@@ -38,10 +38,10 @@ try:
 
     from eland.ml.pytorch import (
         FillMaskInferenceOptions,
         NerInferenceOptions,
         NlpBertTokenizationConfig,
-        NlpMPNetTokenizationConfig,
         NlpRobertaTokenizationConfig,
+        NlpXLMRobertaTokenizationConfig,
         QuestionAnsweringInferenceOptions,
         TextClassificationInferenceOptions,
         TextEmbeddingInferenceOptions,
@@ -78,10 +78,10 @@ pytestmark = [
 if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
     MODEL_CONFIGURATIONS = [
         (
-            "intfloat/e5-small-v2",
+            "intfloat/multilingual-e5-small",
             "text_embedding",
             TextEmbeddingInferenceOptions,
-            NlpBertTokenizationConfig,
+            NlpXLMRobertaTokenizationConfig,
             512,
             384,
         ),
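
The two numbers in the updated test row can be sanity-checked against the model's own configuration; a sketch assuming the transformers package:

from transformers import AutoConfig, AutoTokenizer

# multilingual-e5-small: 512 max sequence length, 384-dimensional embeddings
config = AutoConfig.from_pretrained("intfloat/multilingual-e5-small")
tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-small")
print(tokenizer.model_max_length)  # 512
print(config.hidden_size)          # 384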
@@ -93,14 +93,6 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
             512,
             768,
         ),
-        (
-            "sentence-transformers/all-MiniLM-L12-v2",
-            "text_embedding",
-            TextEmbeddingInferenceOptions,
-            NlpBertTokenizationConfig,
-            512,
-            384,
-        ),
         (
             "facebook/dpr-ctx_encoder-multiset-base",
             "text_embedding",
@@ -117,22 +109,6 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
             512,
             None,
         ),
-        (
-            "bert-base-uncased",
-            "fill_mask",
-            FillMaskInferenceOptions,
-            NlpBertTokenizationConfig,
-            512,
-            None,
-        ),
-        (
-            "elastic/distilbert-base-uncased-finetuned-conll03-english",
-            "ner",
-            NerInferenceOptions,
-            NlpBertTokenizationConfig,
-            512,
-            None,
-        ),
         (
             "SamLowe/roberta-base-go_emotions",
             "text_classification",
@@ -214,3 +190,5 @@ class TestModelConfguration:
         if task_type == "zero_shot_classification":
             assert isinstance(config.inference_config.classification_labels, list)
             assert len(config.inference_config.classification_labels) > 0
+
+        del tm