[NLP] Support E5 small multi-lingual (#625)

Although E5 small is a BERT-based model, its forward function takes 2 parameters,
not 4. Use the tokenizer type, rather than the model config, to decide the number of parameters.
David Kyle 2023-10-31 17:49:43 +00:00 committed by GitHub
parent ab6e44f430
commit 5b3a83e7f2
3 changed files with 24 additions and 39 deletions
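
A minimal sketch of the idea behind this change, not eland's actual implementation: choose between the two-argument and four-argument forward() signatures from the tokenizer class rather than from the model config, so a model such as intfloat/multilingual-e5-small (BERT architecture, XLM-RoBERTa tokenizer) is traced with 2 inputs. The names forward_args_for and TWO_ARGUMENT_TOKENIZERS are made up for illustration, and the isinstance check assumes the slow (non-fast) tokenizer classes, mirroring the tuple used in the diff below.

# Sketch only; illustrative names, not part of eland.
import torch
import transformers

TWO_ARGUMENT_TOKENIZERS = (
    transformers.BartTokenizer,
    transformers.MPNetTokenizer,
    transformers.RobertaTokenizer,
    transformers.XLMRobertaTokenizer,
)


def forward_args_for(tokenizer: transformers.PreTrainedTokenizer, text: str):
    """Build the positional arguments used to call the traced model's forward()."""
    inputs = tokenizer(text, return_tensors="pt")
    if isinstance(tokenizer, TWO_ARGUMENT_TOKENIZERS):
        # These tokenizers do not emit token_type_ids and the matching models
        # accept only input_ids and attention_mask.
        return inputs["input_ids"], inputs["attention_mask"]
    # BERT-style interface: supply token_type_ids and position_ids as well.
    token_type_ids = inputs.get(
        "token_type_ids", torch.zeros_like(inputs["input_ids"])
    )
    position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
    return (
        inputs["input_ids"],
        inputs["attention_mask"],
        token_type_ids,
        position_ids,
    )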

View File

@@ -23,6 +23,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     NlpMPNetTokenizationConfig,
     NlpRobertaTokenizationConfig,
     NlpTrainedModelConfig,
+    NlpXLMRobertaTokenizationConfig,
     QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
     TextEmbeddingInferenceOptions,

View File

@@ -245,8 +245,13 @@ class _TwoParameterQuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):

 class _DistilBertWrapper(nn.Module):  # type: ignore
     """
-    A simple wrapper around DistilBERT model which makes the model inputs
-    conform to Elasticsearch's native inference processor interface.
+    In Elasticsearch the BERT tokenizer is used for DistilBERT models but
+    the BERT tokenizer produces 4 inputs where DistilBERT models expect 2.
+
+    Wrap the model's forward function in a method that accepts the 4
+    arguments passed to a BERT model then discard the token_type_ids
+    and the position_ids to match the wrapped DistilBERT model forward
+    function
     """

     def __init__(self, model: transformers.PreTrainedModel):
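
The updated docstring above describes wrapping forward() so it accepts the four BERT-style inputs and silently drops the two a DistilBERT-style model does not take. A hypothetical, self-contained sketch of that pattern (the class name is illustrative; this is not eland's _DistilBertWrapper):

# Hypothetical sketch of the 4-to-2 argument wrapping pattern.
import torch
from torch import nn


class FourToTwoArgumentWrapper(nn.Module):  # illustrative name
    def __init__(self, model: nn.Module):
        super().__init__()
        self._model = model

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,  # accepted, then discarded
        position_ids: torch.Tensor,  # accepted, then discarded
    ):
        return self._model(input_ids=input_ids, attention_mask=attention_mask)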
@@ -293,18 +298,19 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
     @staticmethod
     def from_pretrained(
         model_id: str,
+        tokenizer: PreTrainedTokenizer,
         *,
         token: Optional[str] = None,
         output_key: str = DEFAULT_OUTPUT_KEY,
     ) -> Optional[Any]:
         model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
         if isinstance(
-            model.config,
+            tokenizer,
             (
-                transformers.MPNetConfig,
-                transformers.XLMRobertaConfig,
-                transformers.RobertaConfig,
-                transformers.BartConfig,
+                transformers.BartTokenizer,
+                transformers.MPNetTokenizer,
+                transformers.RobertaTokenizer,
+                transformers.XLMRobertaTokenizer,
             ),
         ):
             return _TwoParameterSentenceTransformerWrapper(model, output_key)
@@ -466,12 +472,12 @@ class _TransformerTraceableModel(TraceableModel):
                 inputs["input_ids"].size(1), dtype=torch.long
             )
         if isinstance(
-            self._model.config,
+            self._tokenizer,
             (
-                transformers.MPNetConfig,
-                transformers.XLMRobertaConfig,
-                transformers.RobertaConfig,
-                transformers.BartConfig,
+                transformers.BartTokenizer,
+                transformers.MPNetTokenizer,
+                transformers.RobertaTokenizer,
+                transformers.XLMRobertaTokenizer,
             ),
         ):
             del inputs["token_type_ids"]
@@ -812,7 +818,7 @@ class TransformerModel:
             )
             if not model:
                 model = _SentenceTransformerWrapperModule.from_pretrained(
-                    self._model_id, token=self._access_token
+                    self._model_id, self._tokenizer, token=self._access_token
                 )
             return _TraceableTextEmbeddingModel(self._tokenizer, model)

View File

@@ -38,10 +38,10 @@ try:
     from eland.ml.pytorch import (
         FillMaskInferenceOptions,
-        NerInferenceOptions,
         NlpBertTokenizationConfig,
         NlpMPNetTokenizationConfig,
         NlpRobertaTokenizationConfig,
+        NlpXLMRobertaTokenizationConfig,
         QuestionAnsweringInferenceOptions,
         TextClassificationInferenceOptions,
         TextEmbeddingInferenceOptions,
@@ -78,10 +78,10 @@ pytestmark = [
 if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
     MODEL_CONFIGURATIONS = [
         (
-            "intfloat/e5-small-v2",
+            "intfloat/multilingual-e5-small",
             "text_embedding",
             TextEmbeddingInferenceOptions,
-            NlpBertTokenizationConfig,
+            NlpXLMRobertaTokenizationConfig,
             512,
             384,
         ),
@@ -93,14 +93,6 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
             512,
             768,
         ),
-        (
-            "sentence-transformers/all-MiniLM-L12-v2",
-            "text_embedding",
-            TextEmbeddingInferenceOptions,
-            NlpBertTokenizationConfig,
-            512,
-            384,
-        ),
         (
             "facebook/dpr-ctx_encoder-multiset-base",
             "text_embedding",
@@ -117,22 +109,6 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
             512,
             None,
         ),
-        (
-            "bert-base-uncased",
-            "fill_mask",
-            FillMaskInferenceOptions,
-            NlpBertTokenizationConfig,
-            512,
-            None,
-        ),
-        (
-            "elastic/distilbert-base-uncased-finetuned-conll03-english",
-            "ner",
-            NerInferenceOptions,
-            NlpBertTokenizationConfig,
-            512,
-            None,
-        ),
         (
             "SamLowe/roberta-base-go_emotions",
             "text_classification",
@@ -214,3 +190,5 @@ class TestModelConfguration:
         if task_type == "zero_shot_classification":
             assert isinstance(config.inference_config.classification_labels, list)
             assert len(config.inference_config.classification_labels) > 0
+
+        del tm