[ML] Add support for MPNet PyTorch models

Author: Benjamin Trent (2022-01-10 12:21:30 -05:00), committed by GitHub
parent 64daa07a65
commit 72856e2c3f
2 changed files with 53 additions and 11 deletions
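
This change lets MPNet models be converted and uploaded to Elasticsearch the same way as the already-supported BERT variants. A minimal sketch of the resulting workflow, assuming eland's module layout (eland.ml.pytorch), which this commit page does not show, and using sentence-transformers/all-mpnet-base-v2 as an example hub model; the calls mirror the updated tests at the end of this commit:

# Sketch only: module paths, the client setup, and the example model id are
# assumptions, not part of this commit; the API calls match the tests below.
import tempfile

from elasticsearch import Elasticsearch
from eland.ml.pytorch import PyTorchModel
from eland.ml.pytorch.transformers import TransformerModel

es = Elasticsearch("http://localhost:9200")  # assumed local cluster
model_id = "sentence-transformers/all-mpnet-base-v2"

with tempfile.TemporaryDirectory() as tmp_dir:
    # Trace the MPNet model to TorchScript and write model, config, and vocab files.
    tm = TransformerModel(model_id, "text_embedding", quantize=False)
    model_path, config_path, vocab_path = tm.save(tmp_dir)

    # Upload and deploy under an id derived from the hub name.
    ptm = PyTorchModel(es, tm.elasticsearch_model_id())
    ptm.import_model(model_path, config_path, vocab_path)
    ptm.start()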


@@ -48,6 +48,7 @@ SUPPORTED_TASK_TYPES = {
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
     transformers.BertTokenizer,
+    transformers.MPNetTokenizer,
     transformers.DPRContextEncoderTokenizer,
     transformers.DPRQuestionEncoderTokenizer,
     transformers.DistilBertTokenizer,
@@ -96,7 +97,7 @@ class _DistilBertWrapper(nn.Module):  # type: ignore
         return self._model(input_ids=input_ids, attention_mask=attention_mask)


-class _SentenceTransformerWrapper(nn.Module):  # type: ignore
+class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
     """
     A wrapper around sentence-transformer models to provide pooling,
     normalization and other graph layers that are not defined in the base
@@ -108,6 +109,7 @@ class _SentenceTransformerWrapper(nn.Module):  # type: ignore
         self._hf_model = model
         self._st_model = SentenceTransformer(model.config.name_or_path)
         self._output_key = output_key
+        self.config = model.config
         self._remove_pooling_layer()
         self._replace_transformer_layer()
@@ -118,7 +120,10 @@ class _SentenceTransformerWrapper(nn.Module):  # type: ignore
     ) -> Optional[Any]:
         if model_id.startswith("sentence-transformers/"):
             model = AutoModel.from_pretrained(model_id, torchscript=True)
-            return _SentenceTransformerWrapper(model, output_key)
+            if isinstance(model.config, transformers.MPNetConfig):
+                return _MPNetSentenceTransformerWrapper(model, output_key)
+            else:
+                return _SentenceTransformerWrapper(model, output_key)
         else:
             return None
@@ -144,6 +149,11 @@ class _SentenceTransformerWrapper(nn.Module):  # type: ignore
         self._st_model._modules["0"].auto_model = self._hf_model


+class _SentenceTransformerWrapper(_SentenceTransformerWrapperModule):
+    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
+        super().__init__(model=model, output_key=output_key)
+
     def forward(
         self,
         input_ids: Tensor,
@@ -167,6 +177,19 @@ class _SentenceTransformerWrapper(nn.Module):  # type: ignore
         return self._st_model(inputs)[self._output_key]


+class _MPNetSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
+    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
+        super().__init__(model=model, output_key=output_key)
+
+    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        return self._st_model(inputs)[self._output_key]
+
+
 class _DPREncoderWrapper(nn.Module):  # type: ignore
     """
     AutoModel loading does not work for DPRContextEncoders, this only exists as
@@ -186,6 +209,7 @@ class _DPREncoderWrapper(nn.Module):  # type: ignore
     ):
         super().__init__()
         self._model = model
+        self.config = model.config

     @staticmethod
     def from_pretrained(model_id: str) -> Optional[Any]:
@@ -232,7 +256,7 @@ class _TraceableModel(ABC):
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         model: Union[
             PreTrainedModel,
-            _SentenceTransformerWrapper,
+            _SentenceTransformerWrapperModule,
             _DPREncoderWrapper,
             _DistilBertWrapper,
         ],
@@ -250,13 +274,19 @@
         self._model.eval()
         inputs = self._prepare_inputs()
-        position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
         # Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
         if "token_type_ids" not in inputs:
             inputs["token_type_ids"] = torch.zeros(
                 inputs["input_ids"].size(1), dtype=torch.long
             )
+        if isinstance(self._model.config, transformers.MPNetConfig):
+            return torch.jit.trace(
+                self._model,
+                (inputs["input_ids"], inputs["attention_mask"]),
+            )
+
+        position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
         return torch.jit.trace(
             self._model,
@@ -369,10 +399,15 @@ class TransformerModel:
         }

     def _create_config(self) -> Dict[str, Any]:
+        tokenizer_type = (
+            "mpnet"
+            if isinstance(self._tokenizer, transformers.MPNetTokenizer)
+            else "bert"
+        )
         inference_config: Dict[str, Dict[str, Any]] = {
             self._task_type: {
                 "tokenization": {
-                    "bert": {
+                    tokenizer_type: {
                         "do_lower_case": getattr(
                             self._tokenizer, "do_lower_case", False
                         ),
@@ -386,7 +421,7 @@
             self._model_id
         )
         if max_sequence_length:
-            inference_config[self._task_type]["tokenization"]["bert"][
+            inference_config[self._task_type]["tokenization"][tokenizer_type][
                 "max_sequence_length"
             ] = max_sequence_length
@@ -427,7 +462,7 @@
             return _TraceableTextClassificationModel(self._tokenizer, model)
         elif self._task_type == "text_embedding":
-            model = _SentenceTransformerWrapper.from_pretrained(self._model_id)
+            model = _SentenceTransformerWrapperModule.from_pretrained(self._model_id)
             if not model:
                 model = _DPREncoderWrapper.from_pretrained(self._model_id)
             if not model:
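
Note on the two-argument trace above: unlike BERT, MPNetModel.forward accepts no token_type_ids argument, so the MPNet branch traces with only input_ids and attention_mask instead of the four-tensor BERT-style example input. A standalone sketch of the same idea, again using an assumed example hub model:

# Sketch only: demonstrates the two-tensor trace the MPNet branch performs.
import torch
from transformers import AutoModel, AutoTokenizer

model_id = "sentence-transformers/all-mpnet-base-v2"  # example model, an assumption
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, torchscript=True)
model.eval()

# MPNet tokenizers emit only input_ids and attention_mask; no token_type_ids.
inputs = tokenizer("This is a test sentence.", return_tensors="pt")
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))

The test changes below harden setup and teardown so that stopping or deleting a model that was never deployed no longer fails the run.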


@@ -17,6 +17,7 @@
 import tempfile

 import pytest
+from elasticsearch import NotFoundError

 try:
     import sklearn  # noqa: F401
@@ -68,8 +69,11 @@ def setup_and_tear_down():
     yield
     for model_id, _, _, _ in TEXT_PREDICTION_MODELS:
         model = PyTorchModel(ES_TEST_CLIENT, model_id.replace("/", "__").lower()[:64])
-        model.stop()
-        model.delete()
+        try:
+            model.stop()
+            model.delete()
+        except NotFoundError:
+            pass


 def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
@@ -77,8 +81,11 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
     tm = TransformerModel(model_id, task, quantize)
     model_path, config_path, vocab_path = tm.save(tmp_dir)
     ptm = PyTorchModel(ES_TEST_CLIENT, tm.elasticsearch_model_id())
-    ptm.stop()
-    ptm.delete()
+    try:
+        ptm.stop()
+        ptm.delete()
+    except NotFoundError:
+        pass
     print(f"Importing model: {ptm.model_id}")
     ptm.import_model(model_path, config_path, vocab_path)
     ptm.start()