From 72856e2c3f827a0b71d140323009a7a9a3df6e1d Mon Sep 17 00:00:00 2001
From: Benjamin Trent
Date: Mon, 10 Jan 2022 12:21:30 -0500
Subject: [PATCH] [ML] Add support for MPNet PyTorch models

---
 eland/ml/pytorch/transformers.py              | 49 ++++++++++++++++---
 tests/ml/pytorch/test_pytorch_model_pytest.py | 15 ++++--
 2 files changed, 53 insertions(+), 11 deletions(-)

diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index 4106967..d24bffe 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -48,6 +48,7 @@ SUPPORTED_TASK_TYPES = {
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
     transformers.BertTokenizer,
+    transformers.MPNetTokenizer,
     transformers.DPRContextEncoderTokenizer,
     transformers.DPRQuestionEncoderTokenizer,
     transformers.DistilBertTokenizer,
@@ -96,7 +97,7 @@ class _DistilBertWrapper(nn.Module):  # type: ignore
         return self._model(input_ids=input_ids, attention_mask=attention_mask)
 
 
-class _SentenceTransformerWrapper(nn.Module):  # type: ignore
+class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
     """
     A wrapper around sentence-transformer models to provide pooling,
     normalization and other graph layers that are not defined in the base
@@ -108,6 +109,7 @@ class _SentenceTransformerWrapper(nn.Module):  # type: ignore
         self._hf_model = model
         self._st_model = SentenceTransformer(model.config.name_or_path)
         self._output_key = output_key
+        self.config = model.config
 
         self._remove_pooling_layer()
         self._replace_transformer_layer()
@@ -118,7 +120,10 @@
     ) -> Optional[Any]:
         if model_id.startswith("sentence-transformers/"):
             model = AutoModel.from_pretrained(model_id, torchscript=True)
-            return _SentenceTransformerWrapper(model, output_key)
+            if isinstance(model.config, transformers.MPNetConfig):
+                return _MPNetSentenceTransformerWrapper(model, output_key)
+            else:
+                return _SentenceTransformerWrapper(model, output_key)
         else:
             return None
 
@@ -144,6 +149,11 @@
 
         self._st_model._modules["0"].auto_model = self._hf_model
 
+
+class _SentenceTransformerWrapper(_SentenceTransformerWrapperModule):
+    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
+        super().__init__(model=model, output_key=output_key)
+
     def forward(
         self,
         input_ids: Tensor,
@@ -167,6 +177,19 @@
         return self._st_model(inputs)[self._output_key]
 
 
+class _MPNetSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
+    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
+        super().__init__(model=model, output_key=output_key)
+
+    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        return self._st_model(inputs)[self._output_key]
+
+
 class _DPREncoderWrapper(nn.Module):  # type: ignore
     """
     AutoModel loading does not work for DPRContextEncoders, this only exists as
@@ -186,6 +209,7 @@
     ):
         super().__init__()
         self._model = model
+        self.config = model.config
 
     @staticmethod
     def from_pretrained(model_id: str) -> Optional[Any]:
@@ -232,7 +256,7 @@ class _TraceableModel(ABC):
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         model: Union[
             PreTrainedModel,
-            _SentenceTransformerWrapper,
+            _SentenceTransformerWrapperModule,
             _DPREncoderWrapper,
             _DistilBertWrapper,
         ],
@@ -250,13 +274,19 @@ class _TraceableModel(ABC):
         self._model.eval()
 
         inputs = self._prepare_inputs()
-        position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
 
         # Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
         if "token_type_ids" not in inputs:
             inputs["token_type_ids"] = torch.zeros(
                 inputs["input_ids"].size(1), dtype=torch.long
             )
+        if isinstance(self._model.config, transformers.MPNetConfig):
+            return torch.jit.trace(
+                self._model,
+                (inputs["input_ids"], inputs["attention_mask"]),
+            )
+
+        position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
 
         return torch.jit.trace(
             self._model,
@@ -369,10 +399,15 @@ class TransformerModel:
         }
 
     def _create_config(self) -> Dict[str, Any]:
+        tokenizer_type = (
+            "mpnet"
+            if isinstance(self._tokenizer, transformers.MPNetTokenizer)
+            else "bert"
+        )
         inference_config: Dict[str, Dict[str, Any]] = {
             self._task_type: {
                 "tokenization": {
-                    "bert": {
+                    tokenizer_type: {
                         "do_lower_case": getattr(
                             self._tokenizer, "do_lower_case", False
                         ),
@@ -386,7 +421,7 @@
             self._model_id
         )
         if max_sequence_length:
-            inference_config[self._task_type]["tokenization"]["bert"][
+            inference_config[self._task_type]["tokenization"][tokenizer_type][
                 "max_sequence_length"
             ] = max_sequence_length
 
@@ -427,7 +462,7 @@
             return _TraceableTextClassificationModel(self._tokenizer, model)
         elif self._task_type == "text_embedding":
-            model = _SentenceTransformerWrapper.from_pretrained(self._model_id)
+            model = _SentenceTransformerWrapperModule.from_pretrained(self._model_id)
             if not model:
                 model = _DPREncoderWrapper.from_pretrained(self._model_id)
             if not model:
diff --git a/tests/ml/pytorch/test_pytorch_model_pytest.py b/tests/ml/pytorch/test_pytorch_model_pytest.py
index 39676e5..5cbd3ad 100644
--- a/tests/ml/pytorch/test_pytorch_model_pytest.py
+++ b/tests/ml/pytorch/test_pytorch_model_pytest.py
@@ -17,6 +17,7 @@
 import tempfile
 
 import pytest
+from elasticsearch import NotFoundError
 
 try:
     import sklearn  # noqa: F401
@@ -68,8 +69,11 @@ def setup_and_tear_down():
     yield
     for model_id, _, _, _ in TEXT_PREDICTION_MODELS:
         model = PyTorchModel(ES_TEST_CLIENT, model_id.replace("/", "__").lower()[:64])
-        model.stop()
-        model.delete()
+        try:
+            model.stop()
+            model.delete()
+        except NotFoundError:
+            pass
 
 
 def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
@@ -77,8 +81,11 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
     tm = TransformerModel(model_id, task, quantize)
     model_path, config_path, vocab_path = tm.save(tmp_dir)
     ptm = PyTorchModel(ES_TEST_CLIENT, tm.elasticsearch_model_id())
-    ptm.stop()
-    ptm.delete()
+    try:
+        ptm.stop()
+        ptm.delete()
+    except NotFoundError:
+        pass
     print(f"Importing model: {ptm.model_id}")
     ptm.import_model(model_path, config_path, vocab_path)
     ptm.start()
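
Reviewer note (trailing commentary after the diffs, which `git am` ignores): the dedicated MPNet branch in `_TraceableModel.trace()` exists because MPNet models accept no `token_type_ids` input, so the four-tensor BERT-style trace cannot be applied to them. A minimal standalone sketch of the two-tensor trace path; the model id and example sentence are illustrative assumptions, not taken from the patch:

# Sketch of the two-tensor MPNet trace path added in this patch.
import torch
import transformers

model_id = "sentence-transformers/all-mpnet-base-v2"  # assumed example checkpoint
tokenizer = transformers.MPNetTokenizer.from_pretrained(model_id)
model = transformers.AutoModel.from_pretrained(model_id, torchscript=True)
model.eval()

# MPNetTokenizer emits only input_ids and attention_mask, and MPNet's
# forward() has no token_type_ids parameter, hence the two-tensor tuple.
inputs = tokenizer("This is a test", return_tensors="pt")
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))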
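
And an end-to-end sketch mirroring download_model_and_start_deployment() in the updated test, showing how an MPNet text_embedding model would exercise the new wrapper and the "mpnet" tokenization config; the cluster URL is an assumption, and the imports follow the eland test layout:

# Sketch: import an MPNet text_embedding model with eland, as the test does.
import tempfile

from elasticsearch import Elasticsearch

from eland.ml.pytorch import PyTorchModel
from eland.ml.pytorch.transformers import TransformerModel

es = Elasticsearch("http://localhost:9200")  # assumed local cluster

# An MPNet checkpoint is routed through _MPNetSentenceTransformerWrapper and
# gets an "mpnet" (rather than "bert") tokenization section in its config.
tm = TransformerModel("sentence-transformers/all-mpnet-base-v2", "text_embedding", False)

with tempfile.TemporaryDirectory() as tmp_dir:
    model_path, config_path, vocab_path = tm.save(tmp_dir)
    ptm = PyTorchModel(es, tm.elasticsearch_model_id())
    ptm.import_model(model_path, config_path, vocab_path)
    ptm.start()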