Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)

[ML] Add support for MPNet PyTorch models

parent 64daa07a65
commit 72856e2c3f
@@ -48,6 +48,7 @@ SUPPORTED_TASK_TYPES = {
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
     transformers.BertTokenizer,
+    transformers.MPNetTokenizer,
     transformers.DPRContextEncoderTokenizer,
     transformers.DPRQuestionEncoderTokenizer,
     transformers.DistilBertTokenizer,
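Note on the hunk above: with transformers.MPNetTokenizer added to SUPPORTED_TOKENIZERS, MPNet checkpoints pass eland's tokenizer check. A standalone sketch of that check follows; the tuple is an abridged copy of the one above, and the model id is just one public MPNet checkpoint, not something this commit ships.

    import transformers

    # Abridged copy of SUPPORTED_TOKENIZERS, just to make the check self-contained.
    SUPPORTED_TOKENIZERS = (
        transformers.BertTokenizer,
        transformers.MPNetTokenizer,
        transformers.DistilBertTokenizer,
    )

    # use_fast=False so AutoTokenizer returns the slow MPNetTokenizer class that
    # the tuple lists, rather than MPNetTokenizerFast.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "sentence-transformers/all-mpnet-base-v2", use_fast=False
    )
    print(isinstance(tokenizer, SUPPORTED_TOKENIZERS))  # expected: True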
@@ -96,7 +97,7 @@ class _DistilBertWrapper(nn.Module):  # type: ignore
         return self._model(input_ids=input_ids, attention_mask=attention_mask)


-class _SentenceTransformerWrapper(nn.Module):  # type: ignore
+class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
     """
     A wrapper around sentence-transformer models to provide pooling,
     normalization and other graph layers that are not defined in the base
@@ -108,6 +109,7 @@ class _SentenceTransformerWrapper(nn.Module):  # type: ignore
         self._hf_model = model
         self._st_model = SentenceTransformer(model.config.name_or_path)
         self._output_key = output_key
+        self.config = model.config

         self._remove_pooling_layer()
         self._replace_transformer_layer()
@@ -118,7 +120,10 @@ class _SentenceTransformerWrapper(nn.Module):  # type: ignore
     ) -> Optional[Any]:
         if model_id.startswith("sentence-transformers/"):
             model = AutoModel.from_pretrained(model_id, torchscript=True)
-            return _SentenceTransformerWrapper(model, output_key)
+            if isinstance(model.config, transformers.MPNetConfig):
+                return _MPNetSentenceTransformerWrapper(model, output_key)
+            else:
+                return _SentenceTransformerWrapper(model, output_key)
         else:
             return None

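Note on the hunk above: the factory now dispatches on the Hugging Face config class rather than on the model id string, so any sentence-transformers checkpoint whose config is an MPNetConfig gets the two-input wrapper. The same check can be reproduced outside eland roughly like this (illustrative; the checkpoint name is only an example):

    import transformers
    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "sentence-transformers/all-mpnet-base-v2", torchscript=True
    )
    if isinstance(model.config, transformers.MPNetConfig):
        print("MPNet: use the two-input MPNet wrapper")
    else:
        print("BERT-style: use the default sentence-transformer wrapper")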
@@ -144,6 +149,11 @@ class _SentenceTransformerWrapper(nn.Module):  # type: ignore

         self._st_model._modules["0"].auto_model = self._hf_model


+class _SentenceTransformerWrapper(_SentenceTransformerWrapperModule):
+    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
+        super().__init__(model=model, output_key=output_key)
+
     def forward(
         self,
         input_ids: Tensor,
@@ -167,6 +177,19 @@ class _SentenceTransformerWrapper(nn.Module):  # type: ignore
         return self._st_model(inputs)[self._output_key]


+class _MPNetSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
+    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
+        super().__init__(model=model, output_key=output_key)
+
+    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        return self._st_model(inputs)[self._output_key]
+
+
 class _DPREncoderWrapper(nn.Module):  # type: ignore
     """
     AutoModel loading does not work for DPRContextEncoders, this only exists as
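The reason MPNet gets its own wrapper is the input signature: MPNet models in transformers take no token_type_ids, so the traced forward only needs input_ids and attention_mask. A hedged sketch that makes this visible from the tokenizer side (the checkpoint name is just an example, and the exact keys can vary with the transformers version):

    import transformers

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "sentence-transformers/all-mpnet-base-v2", use_fast=False
    )
    encoded = tokenizer("Elasticsearch stores embeddings.", return_tensors="pt")
    # Expected: ['attention_mask', 'input_ids'] with no 'token_type_ids',
    # matching the two-tensor forward of _MPNetSentenceTransformerWrapper above.
    print(sorted(encoded.keys()))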
@@ -186,6 +209,7 @@ class _DPREncoderWrapper(nn.Module):  # type: ignore
     ):
         super().__init__()
         self._model = model
+        self.config = model.config

     @staticmethod
     def from_pretrained(model_id: str) -> Optional[Any]:
@@ -232,7 +256,7 @@ class _TraceableModel(ABC):
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         model: Union[
             PreTrainedModel,
-            _SentenceTransformerWrapper,
+            _SentenceTransformerWrapperModule,
             _DPREncoderWrapper,
             _DistilBertWrapper,
         ],
@@ -250,13 +274,19 @@ class _TraceableModel(ABC):
         self._model.eval()

         inputs = self._prepare_inputs()
-        position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)

         # Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
         if "token_type_ids" not in inputs:
             inputs["token_type_ids"] = torch.zeros(
                 inputs["input_ids"].size(1), dtype=torch.long
             )
+        if isinstance(self._model.config, transformers.MPNetConfig):
+            return torch.jit.trace(
+                self._model,
+                (inputs["input_ids"], inputs["attention_mask"]),
+            )
+
+        position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)

         return torch.jit.trace(
             self._model,
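Note on the tracing hunk above: for MPNet the example inputs handed to torch.jit.trace are just the (input_ids, attention_mask) pair, matching the two-argument forward, while other models keep the BERT-style interface with token_type_ids and position_ids. A self-contained sketch of the same call pattern on a toy module (nothing here is eland code):

    import torch
    from torch import Tensor, nn

    class ToyTwoInputModel(nn.Module):
        """Minimal module with an MPNet-style (input_ids, attention_mask) forward."""

        def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
            # Placeholder computation: masked mean over the token ids.
            mask = attention_mask.float()
            return (input_ids.float() * mask).sum(dim=1) / mask.sum(dim=1)

    example_ids = torch.randint(0, 30000, (1, 8), dtype=torch.long)  # toy vocab size
    example_mask = torch.ones_like(example_ids)

    # The example-inputs tuple has one entry per positional argument of forward().
    traced = torch.jit.trace(ToyTwoInputModel(), (example_ids, example_mask))
    print(traced(example_ids, example_mask).shape)  # torch.Size([1])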
@@ -369,10 +399,15 @@ class TransformerModel:
         }

     def _create_config(self) -> Dict[str, Any]:
+        tokenizer_type = (
+            "mpnet"
+            if isinstance(self._tokenizer, transformers.MPNetTokenizer)
+            else "bert"
+        )
         inference_config: Dict[str, Dict[str, Any]] = {
             self._task_type: {
                 "tokenization": {
-                    "bert": {
+                    tokenizer_type: {
                         "do_lower_case": getattr(
                             self._tokenizer, "do_lower_case", False
                         ),
@@ -386,7 +421,7 @@ class TransformerModel:
             self._model_id
         )
         if max_sequence_length:
-            inference_config[self._task_type]["tokenization"]["bert"][
+            inference_config[self._task_type]["tokenization"][tokenizer_type][
                 "max_sequence_length"
            ] = max_sequence_length

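With tokenizer_type selected in the previous hunk, an MPNet text-embedding model gets its tokenization options keyed under "mpnet" rather than "bert", including the optional max_sequence_length set above. The generated inference config would look roughly like this (values are illustrative, not taken from the commit):

    # Approximate shape of inference_config for an MPNet text_embedding model.
    inference_config = {
        "text_embedding": {
            "tokenization": {
                "mpnet": {
                    "do_lower_case": False,      # getattr(tokenizer, "do_lower_case", False)
                    "max_sequence_length": 512,  # only set when the model reports a max length
                }
            }
        }
    }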
@@ -427,7 +462,7 @@ class TransformerModel:
             return _TraceableTextClassificationModel(self._tokenizer, model)

         elif self._task_type == "text_embedding":
-            model = _SentenceTransformerWrapper.from_pretrained(self._model_id)
+            model = _SentenceTransformerWrapperModule.from_pretrained(self._model_id)
             if not model:
                 model = _DPREncoderWrapper.from_pretrained(self._model_id)
             if not model:
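Note on the hunk above: for the text_embedding task the loader now tries _SentenceTransformerWrapperModule first and falls back to the DPR wrapper when it returns None; the hunk is truncated before the remaining fallbacks. A condensed, hypothetical sketch of that try-each-wrapper pattern, kept generic so it is runnable on its own:

    from typing import Any, Callable, Optional, Sequence

    def resolve_model(
        model_id: str, loaders: Sequence[Callable[[str], Optional[Any]]]
    ) -> Optional[Any]:
        """Hypothetical helper: try each wrapper's from_pretrained in order."""
        for load in loaders:
            model = load(model_id)
            if model:
                return model
        return None  # the real method continues with further fallbacks not shown here

    # In the code above the order would be _SentenceTransformerWrapperModule.from_pretrained,
    # then _DPREncoderWrapper.from_pretrained (comment only; those classes are module-private).

The remaining hunks below touch the test suite.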
@@ -17,6 +17,7 @@
 import tempfile

 import pytest
+from elasticsearch import NotFoundError

 try:
     import sklearn  # noqa: F401
@@ -68,8 +69,11 @@ def setup_and_tear_down():
     yield
     for model_id, _, _, _ in TEXT_PREDICTION_MODELS:
         model = PyTorchModel(ES_TEST_CLIENT, model_id.replace("/", "__").lower()[:64])
-        model.stop()
-        model.delete()
+        try:
+            model.stop()
+            model.delete()
+        except NotFoundError:
+            pass


 def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
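Wrapping stop()/delete() in try/except NotFoundError makes the teardown idempotent: cleanup no longer fails when a model was never deployed or has already been removed. The same pattern in isolation (the client and model id are placeholders, and the import paths assume the eland package layout these tests use):

    from elasticsearch import Elasticsearch, NotFoundError
    from eland.ml.pytorch import PyTorchModel

    def cleanup_model(es_client: Elasticsearch, model_id: str) -> None:
        """Best-effort cleanup that ignores models that were never imported."""
        model = PyTorchModel(es_client, model_id)
        try:
            model.stop()
            model.delete()
        except NotFoundError:
            pass  # nothing to clean up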
@@ -77,8 +81,11 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
     tm = TransformerModel(model_id, task, quantize)
     model_path, config_path, vocab_path = tm.save(tmp_dir)
     ptm = PyTorchModel(ES_TEST_CLIENT, tm.elasticsearch_model_id())
-    ptm.stop()
-    ptm.delete()
+    try:
+        ptm.stop()
+        ptm.delete()
+    except NotFoundError:
+        pass
     print(f"Importing model: {ptm.model_id}")
     ptm.import_model(model_path, config_path, vocab_path)
     ptm.start()