[ML] Add support for MPNet PyTorch models

Benjamin Trent, 2022-01-10 12:21:30 -05:00, committed by GitHub
parent 64daa07a65
commit 72856e2c3f
2 changed files with 53 additions and 11 deletions

File 1 of 2

@@ -48,6 +48,7 @@ SUPPORTED_TASK_TYPES = {
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
     transformers.BertTokenizer,
+    transformers.MPNetTokenizer,
     transformers.DPRContextEncoderTokenizer,
     transformers.DPRQuestionEncoderTokenizer,
     transformers.DistilBertTokenizer,
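
Aside: SUPPORTED_TOKENIZERS is a plain tuple of classes, so membership can be enforced with a single isinstance check. A minimal sketch of such a guard (ensure_supported_tokenizer is a hypothetical helper for illustration, not part of this commit):

import transformers

def ensure_supported_tokenizer(tokenizer):  # hypothetical helper, not in the commit
    # isinstance accepts a tuple of classes, so no loop is needed
    supported = (
        transformers.BertTokenizer,
        transformers.MPNetTokenizer,
        transformers.DistilBertTokenizer,
    )
    if not isinstance(tokenizer, supported):
        raise TypeError(f"Unsupported tokenizer type: {type(tokenizer).__name__}")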
@@ -96,7 +97,7 @@ class _DistilBertWrapper(nn.Module):  # type: ignore
         return self._model(input_ids=input_ids, attention_mask=attention_mask)
 
 
-class _SentenceTransformerWrapper(nn.Module):  # type: ignore
+class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
     """
     A wrapper around sentence-transformer models to provide pooling,
     normalization and other graph layers that are not defined in the base
@@ -108,6 +109,7 @@ class _SentenceTransformerWrapper(nn.Module):  # type: ignore
         self._hf_model = model
         self._st_model = SentenceTransformer(model.config.name_or_path)
         self._output_key = output_key
+        self.config = model.config
 
         self._remove_pooling_layer()
         self._replace_transformer_layer()
@@ -118,6 +120,9 @@ class _SentenceTransformerWrapper(nn.Module):  # type: ignore
     ) -> Optional[Any]:
         if model_id.startswith("sentence-transformers/"):
             model = AutoModel.from_pretrained(model_id, torchscript=True)
-            return _SentenceTransformerWrapper(model, output_key)
+            if isinstance(model.config, transformers.MPNetConfig):
+                return _MPNetSentenceTransformerWrapper(model, output_key)
+            else:
+                return _SentenceTransformerWrapper(model, output_key)
         else:
             return None
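
The dispatch works because AutoModel.from_pretrained returns a model whose config subclass identifies the architecture, so MPNet checkpoints can be routed to the dedicated wrapper without parsing the model id. The same check in isolation (a sketch; the checkpoint name is illustrative, not taken from this commit):

from transformers import AutoModel, MPNetConfig

# Illustrative sentence-transformers MPNet checkpoint
model = AutoModel.from_pretrained(
    "sentence-transformers/all-mpnet-base-v2", torchscript=True
)
print(isinstance(model.config, MPNetConfig))  # True for MPNet architectures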
@@ -144,6 +149,11 @@ class _SentenceTransformerWrapper(nn.Module):  # type: ignore
         self._st_model._modules["0"].auto_model = self._hf_model
 
+
+class _SentenceTransformerWrapper(_SentenceTransformerWrapperModule):
+    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
+        super().__init__(model=model, output_key=output_key)
+
     def forward(
         self,
         input_ids: Tensor,
@@ -167,6 +177,19 @@ class _SentenceTransformerWrapper(nn.Module):  # type: ignore
         return self._st_model(inputs)[self._output_key]
 
 
+class _MPNetSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
+    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
+        super().__init__(model=model, output_key=output_key)
+
+    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        return self._st_model(inputs)[self._output_key]
+
+
 class _DPREncoderWrapper(nn.Module):  # type: ignore
     """
     AutoModel loading does not work for DPRContextEncoders, this only exists as
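
Design note on the MPNet wrapper above: MPNet has no segment (token type) embeddings, so transformers.MPNetModel.forward takes no token_type_ids parameter, and the wrapper accordingly narrows the traced interface to input_ids and attention_mask only. This can be confirmed without downloading a model (a small sketch, assuming transformers is installed):

import inspect
import transformers

# MPNet's forward signature has no token_type_ids, unlike BERT's
print("token_type_ids" in inspect.signature(transformers.MPNetModel.forward).parameters)  # False
print("token_type_ids" in inspect.signature(transformers.BertModel.forward).parameters)   # True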
@@ -186,6 +209,7 @@ class _DPREncoderWrapper(nn.Module):  # type: ignore
     ):
         super().__init__()
         self._model = model
+        self.config = model.config
 
     @staticmethod
     def from_pretrained(model_id: str) -> Optional[Any]:
@@ -232,7 +256,7 @@ class _TraceableModel(ABC):
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         model: Union[
             PreTrainedModel,
-            _SentenceTransformerWrapper,
+            _SentenceTransformerWrapperModule,
             _DPREncoderWrapper,
             _DistilBertWrapper,
         ],
@@ -250,13 +274,19 @@ class _TraceableModel(ABC):
         self._model.eval()
 
         inputs = self._prepare_inputs()
-        position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
 
         # Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
         if "token_type_ids" not in inputs:
             inputs["token_type_ids"] = torch.zeros(
                 inputs["input_ids"].size(1), dtype=torch.long
             )
+        if isinstance(self._model.config, transformers.MPNetConfig):
+            return torch.jit.trace(
+                self._model,
+                (inputs["input_ids"], inputs["attention_mask"]),
+            )
+
+        position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
 
         return torch.jit.trace(
             self._model,
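
Because torch.jit.trace binds its example inputs positionally, the MPNet branch records a two-argument graph, while the fall-through path keeps the BERT-style trace that also passes token_type_ids and position_ids. The same idea on a toy module (a self-contained sketch, not the eland wrapper):

import torch
from torch import Tensor, nn

class TwoInputModule(nn.Module):
    # Stand-in for an MPNet-style wrapper: exactly two tensor inputs
    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        return (input_ids * attention_mask).float().mean(dim=1)

ids = torch.randint(0, 100, (1, 8))
mask = torch.ones_like(ids)
traced = torch.jit.trace(TwoInputModule(), (ids, mask))
print(traced(ids, mask).shape)  # torch.Size([1])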
@@ -369,10 +399,15 @@ class TransformerModel:
         }
 
     def _create_config(self) -> Dict[str, Any]:
+        tokenizer_type = (
+            "mpnet"
+            if isinstance(self._tokenizer, transformers.MPNetTokenizer)
+            else "bert"
+        )
         inference_config: Dict[str, Dict[str, Any]] = {
             self._task_type: {
                 "tokenization": {
-                    "bert": {
+                    tokenizer_type: {
                         "do_lower_case": getattr(
                             self._tokenizer, "do_lower_case", False
                         ),
@@ -386,7 +421,7 @@
             self._model_id
         )
         if max_sequence_length:
-            inference_config[self._task_type]["tokenization"]["bert"][
+            inference_config[self._task_type]["tokenization"][tokenizer_type][
                 "max_sequence_length"
             ] = max_sequence_length
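
With an MPNet tokenizer the tokenization options are therefore keyed under "mpnet" rather than "bert" throughout the generated config. An abbreviated, illustrative result for a text_embedding task (field values are examples, not taken from this commit):

inference_config = {
    "text_embedding": {
        "tokenization": {
            "mpnet": {
                "do_lower_case": False,      # example value
                "max_sequence_length": 512,  # example value
            }
        }
    }
}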
@@ -427,7 +462,7 @@
             return _TraceableTextClassificationModel(self._tokenizer, model)
         elif self._task_type == "text_embedding":
-            model = _SentenceTransformerWrapper.from_pretrained(self._model_id)
+            model = _SentenceTransformerWrapperModule.from_pretrained(self._model_id)
             if not model:
                 model = _DPREncoderWrapper.from_pretrained(self._model_id)
             if not model:

File 2 of 2

@@ -17,6 +17,7 @@
 import tempfile
 
 import pytest
+from elasticsearch import NotFoundError
 
 try:
     import sklearn  # noqa: F401
@@ -68,8 +69,11 @@ def setup_and_tear_down():
     yield
 
     for model_id, _, _, _ in TEXT_PREDICTION_MODELS:
         model = PyTorchModel(ES_TEST_CLIENT, model_id.replace("/", "__").lower()[:64])
-        model.stop()
-        model.delete()
+        try:
+            model.stop()
+            model.delete()
+        except NotFoundError:
+            pass
 
 
 def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
@@ -77,8 +81,11 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
     tm = TransformerModel(model_id, task, quantize)
     model_path, config_path, vocab_path = tm.save(tmp_dir)
     ptm = PyTorchModel(ES_TEST_CLIENT, tm.elasticsearch_model_id())
-    ptm.stop()
-    ptm.delete()
+    try:
+        ptm.stop()
+        ptm.delete()
+    except NotFoundError:
+        pass
     print(f"Importing model: {ptm.model_id}")
     ptm.import_model(model_path, config_path, vocab_path)
     ptm.start()
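
Wrapping stop()/delete() in try/except NotFoundError makes the test cleanup idempotent: a model that was never deployed, or was already removed, no longer fails setup or teardown. The pattern in isolation (a sketch; cleanup is a hypothetical helper and ptm stands for any PyTorchModel instance):

from elasticsearch import NotFoundError

def cleanup(ptm):
    # Best-effort cleanup: ignore models that do not exist
    try:
        ptm.stop()
        ptm.delete()
    except NotFoundError:
        pass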