Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)
[NLP] Support E5 small multi-lingual (#625)
Although E5 small is a BERT-based model, its forward function takes 2 parameters, not 4. Use the tokenizer type to decide the number of parameters.
This commit is contained in:
parent ab6e44f430
commit 5b3a83e7f2
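In short: multilingual E5 small uses an XLM-RoBERTa style tokenizer, which emits only input_ids and attention_mask, so the traced forward call must accept 2 arguments even though the architecture is BERT-like. The following is a minimal standalone sketch of that dispatch idea, not the eland implementation; the names _TwoParameterWrapper and wrap_for_tracing are illustrative only, while the real selection happens in _SentenceTransformerWrapperModule.from_pretrained as shown in the diff below.

# Sketch only: pick the traced-model wrapper from the tokenizer type
# instead of the model config, mirroring the approach in this commit.
from typing import Any, Optional

import torch
import torch.nn as nn
import transformers


class _TwoParameterWrapper(nn.Module):
    """Forward only input_ids and attention_mask to the wrapped model."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self._model = model

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Any:
        # token_type_ids and position_ids are accepted but discarded; models
        # such as multilingual E5 small take only 2 forward arguments.
        return self._model(input_ids=input_ids, attention_mask=attention_mask)


def wrap_for_tracing(model: nn.Module, tokenizer: Any) -> nn.Module:
    # BART, MPNet, RoBERTa and XLM-RoBERTa tokenizers produce 2 inputs, so
    # the 2-parameter wrapper is chosen; BERT-style tokenizers produce 4.
    two_input_tokenizers = (
        transformers.BartTokenizer,
        transformers.MPNetTokenizer,
        transformers.RobertaTokenizer,
        transformers.XLMRobertaTokenizer,
    )
    if isinstance(tokenizer, two_input_tokenizers):
        return _TwoParameterWrapper(model)
    return model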
@@ -23,6 +23,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     NlpMPNetTokenizationConfig,
     NlpRobertaTokenizationConfig,
     NlpTrainedModelConfig,
+    NlpXLMRobertaTokenizationConfig,
     QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
     TextEmbeddingInferenceOptions,
@@ -245,8 +245,13 @@ class _TwoParameterQuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
 
 class _DistilBertWrapper(nn.Module):  # type: ignore
     """
-    A simple wrapper around DistilBERT model which makes the model inputs
-    conform to Elasticsearch's native inference processor interface.
+    In Elasticsearch the BERT tokenizer is used for DistilBERT models but
+    the BERT tokenizer produces 4 inputs where DistilBERT models expect 2.
+
+    Wrap the model's forward function in a method that accepts the 4
+    arguments passed to a BERT model then discard the token_type_ids
+    and the position_ids to match the wrapped DistilBERT model forward
+    function
     """
 
     def __init__(self, model: transformers.PreTrainedModel):
@@ -293,18 +298,19 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
     @staticmethod
     def from_pretrained(
         model_id: str,
+        tokenizer: PreTrainedTokenizer,
         *,
         token: Optional[str] = None,
         output_key: str = DEFAULT_OUTPUT_KEY,
     ) -> Optional[Any]:
         model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
         if isinstance(
-            model.config,
+            tokenizer,
             (
-                transformers.MPNetConfig,
-                transformers.XLMRobertaConfig,
+                transformers.BartTokenizer,
+                transformers.MPNetTokenizer,
                 transformers.RobertaConfig,
-                transformers.BartConfig,
+                transformers.XLMRobertaTokenizer,
             ),
         ):
             return _TwoParameterSentenceTransformerWrapper(model, output_key)
@@ -466,12 +472,12 @@ class _TransformerTraceableModel(TraceableModel):
                 inputs["input_ids"].size(1), dtype=torch.long
             )
         if isinstance(
-            self._model.config,
+            self._tokenizer,
             (
-                transformers.MPNetConfig,
-                transformers.XLMRobertaConfig,
-                transformers.RobertaConfig,
-                transformers.BartConfig,
+                transformers.BartTokenizer,
+                transformers.MPNetTokenizer,
+                transformers.RobertaTokenizer,
+                transformers.XLMRobertaTokenizer,
             ),
         ):
             del inputs["token_type_ids"]
@@ -812,7 +818,7 @@ class TransformerModel:
             )
             if not model:
                 model = _SentenceTransformerWrapperModule.from_pretrained(
-                    self._model_id, token=self._access_token
+                    self._model_id, self._tokenizer, token=self._access_token
                 )
             return _TraceableTextEmbeddingModel(self._tokenizer, model)
 
@@ -38,10 +38,10 @@ try:
 
     from eland.ml.pytorch import (
         FillMaskInferenceOptions,
-        NerInferenceOptions,
         NlpBertTokenizationConfig,
         NlpMPNetTokenizationConfig,
         NlpRobertaTokenizationConfig,
+        NlpXLMRobertaTokenizationConfig,
         QuestionAnsweringInferenceOptions,
         TextClassificationInferenceOptions,
         TextEmbeddingInferenceOptions,
@@ -78,10 +78,10 @@ pytestmark = [
 if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
     MODEL_CONFIGURATIONS = [
         (
-            "intfloat/e5-small-v2",
+            "intfloat/multilingual-e5-small",
             "text_embedding",
             TextEmbeddingInferenceOptions,
-            NlpBertTokenizationConfig,
+            NlpXLMRobertaTokenizationConfig,
             512,
             384,
         ),
@@ -93,14 +93,6 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
             512,
             768,
         ),
-        (
-            "sentence-transformers/all-MiniLM-L12-v2",
-            "text_embedding",
-            TextEmbeddingInferenceOptions,
-            NlpBertTokenizationConfig,
-            512,
-            384,
-        ),
         (
             "facebook/dpr-ctx_encoder-multiset-base",
             "text_embedding",
@@ -117,22 +109,6 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
             512,
             None,
         ),
-        (
-            "bert-base-uncased",
-            "fill_mask",
-            FillMaskInferenceOptions,
-            NlpBertTokenizationConfig,
-            512,
-            None,
-        ),
-        (
-            "elastic/distilbert-base-uncased-finetuned-conll03-english",
-            "ner",
-            NerInferenceOptions,
-            NlpBertTokenizationConfig,
-            512,
-            None,
-        ),
         (
             "SamLowe/roberta-base-go_emotions",
             "text_classification",
@@ -214,3 +190,5 @@ class TestModelConfguration:
         if task_type == "zero_shot_classification":
             assert isinstance(config.inference_config.classification_labels, list)
             assert len(config.inference_config.classification_labels) > 0
+
+        del tm
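As an optional sanity check of the premise behind this change (not part of the commit, and it assumes network access to the Hugging Face hub), one can inspect what the multilingual E5 small tokenizer actually emits; an XLM-RoBERTa style tokenizer should return only input_ids and attention_mask, which is why the tokenizer class is a reliable signal for the 2-parameter forward.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-small")
enc = tok("query: how many inputs does this tokenizer produce?", return_tensors="pt")

# Expected (if the premise holds): an XLM-RoBERTa tokenizer class and no
# token_type_ids in the encoding, i.e. ['input_ids', 'attention_mask'].
print(type(tok).__name__)
print(list(enc.keys()))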