From 15a300728876022b206161d71055c67b500a0192 Mon Sep 17 00:00:00 2001
From: Benjamin Trent
Date: Mon, 14 Mar 2022 12:26:12 -0400
Subject: [PATCH] [ML] add roberta bart transformer upload support (#443)

Related to: https://github.com/elastic/elasticsearch/pull/84777

This allows BART and RoBERTa models to be uploaded to Elasticsearch for
our currently defined NLP tasks.
---
 eland/ml/pytorch/_pytorch_model.py |  7 +++-
 eland/ml/pytorch/transformers.py   | 65 +++++++++++++++++++++---------
 2 files changed, 51 insertions(+), 21 deletions(-)

diff --git a/eland/ml/pytorch/_pytorch_model.py b/eland/ml/pytorch/_pytorch_model.py
index 4cf2a42..f762394 100644
--- a/eland/ml/pytorch/_pytorch_model.py
+++ b/eland/ml/pytorch/_pytorch_model.py
@@ -57,8 +57,11 @@ class PyTorchModel:
     def put_vocab(self, path: str) -> None:
         with open(path) as f:
             vocab = json.load(f)
-        self._client.ml.put_trained_model_vocabulary(
-            model_id=self.model_id, vocabulary=vocab["vocabulary"]
+        self._client.perform_request(
+            method="PUT",
+            path=f"/_ml/trained_models/{self.model_id}/vocabulary",
+            headers={"accept": "application/json", "content-type": "application/json"},
+            body=vocab,
         )
 
     def put_model(self, model_path: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> None:
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index d24bffe..ed6f002 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -55,6 +55,8 @@ SUPPORTED_TOKENIZERS = (
     transformers.ElectraTokenizer,
     transformers.MobileBertTokenizer,
     transformers.RetriBertTokenizer,
+    transformers.RobertaTokenizer,
+    transformers.BartTokenizer,
     transformers.SqueezeBertTokenizer,
 )
 SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
@@ -120,8 +122,15 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
     ) -> Optional[Any]:
         if model_id.startswith("sentence-transformers/"):
             model = AutoModel.from_pretrained(model_id, torchscript=True)
-            if isinstance(model.config, transformers.MPNetConfig):
-                return _MPNetSentenceTransformerWrapper(model, output_key)
+            if isinstance(
+                model.config,
+                (
+                    transformers.MPNetConfig,
+                    transformers.RobertaConfig,
+                    transformers.BartConfig,
+                ),
+            ):
+                return _TwoParameterSentenceTransformerWrapper(model, output_key)
             else:
                 return _SentenceTransformerWrapper(model, output_key)
         else:
@@ -177,7 +186,7 @@ class _SentenceTransformerWrapper(_SentenceTransformerWrapperModule):
         return self._st_model(inputs)[self._output_key]
 
 
-class _MPNetSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
+class _TwoParameterSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
     def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
         super().__init__(model=model, output_key=output_key)
 
@@ -280,7 +289,14 @@ class _TraceableModel(ABC):
             inputs["token_type_ids"] = torch.zeros(
                 inputs["input_ids"].size(1), dtype=torch.long
             )
-        if isinstance(self._model.config, transformers.MPNetConfig):
+        if isinstance(
+            self._model.config,
+            (
+                transformers.MPNetConfig,
+                transformers.RobertaConfig,
+                transformers.BartConfig,
+            ),
+        ):
             return torch.jit.trace(
                 self._model,
                 (inputs["input_ids"], inputs["attention_mask"]),
@@ -394,26 +410,37 @@ class TransformerModel:
     def _load_vocab(self) -> Dict[str, List[str]]:
         vocab_items = self._tokenizer.get_vocab().items()
         vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]  # type: ignore
-        return {
+        vocab_obj = {
             "vocabulary": vocabulary,
         }
+        ranks = getattr(self._tokenizer, "bpe_ranks", {})
+        if len(ranks) > 0:
+            merges = [
+                " ".join(m) for m, _ in sorted(ranks.items(), key=lambda kv: kv[1])
+            ]
+            vocab_obj["merges"] = merges
+        return vocab_obj
 
     def _create_config(self) -> Dict[str, Any]:
-        tokenizer_type = (
-            "mpnet"
-            if isinstance(self._tokenizer, transformers.MPNetTokenizer)
-            else "bert"
-        )
-        inference_config: Dict[str, Dict[str, Any]] = {
-            self._task_type: {
-                "tokenization": {
-                    tokenizer_type: {
-                        "do_lower_case": getattr(
-                            self._tokenizer, "do_lower_case", False
-                        ),
-                    }
-                }
+        if isinstance(self._tokenizer, transformers.MPNetTokenizer):
+            tokenizer_type = "mpnet"
+            tokenizer_obj = {
+                "do_lower_case": getattr(self._tokenizer, "do_lower_case", False)
             }
+        elif isinstance(
+            self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer)
+        ):
+            tokenizer_type = "roberta"
+            tokenizer_obj = {
+                "add_prefix_space": getattr(self._tokenizer, "add_prefix_space", False)
+            }
+        else:
+            tokenizer_type = "bert"
+            tokenizer_obj = {
+                "do_lower_case": getattr(self._tokenizer, "do_lower_case", False)
+            }
+        inference_config: Dict[str, Dict[str, Any]] = {
+            self._task_type: {"tokenization": {tokenizer_type: tokenizer_obj}}
         }
 
         if hasattr(self._tokenizer, "max_model_input_sizes"):