[ML] add roberta bart transformer upload support (#443)

Related to: https://github.com/elastic/elasticsearch/pull/84777

This allows BART and RoBERTa models to be uploaded to Elasticsearch for our currently defined NLP tasks.
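For context, a rough sketch of how a RoBERTa model could be pushed into Elasticsearch with eland once this lands. The class names come from the two files changed below, but the constructor and method signatures shown here are assumptions rather than the exact API, so treat this as illustrative:

    from elasticsearch import Elasticsearch
    from eland.ml.pytorch import PyTorchModel
    from eland.ml.pytorch.transformers import TransformerModel

    es = Elasticsearch("http://localhost:9200")

    # Trace a RoBERTa model from the Hugging Face hub for one of the supported NLP tasks.
    # (Assumed call shape; check the eland docs for the exact signature.)
    tm = TransformerModel("roberta-base", "text_classification")
    model_path, config_path, vocab_path = tm.save("/tmp/roberta-trace")

    # Upload the traced model, its config, and its vocabulary (which may now include BPE merges).
    ptm = PyTorchModel(es, tm.elasticsearch_model_id())
    ptm.import_model(model_path, config_path, vocab_path)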
2 changed files with 51 additions and 21 deletions


@@ -57,8 +57,11 @@ class PyTorchModel:
     def put_vocab(self, path: str) -> None:
         with open(path) as f:
             vocab = json.load(f)
-        self._client.ml.put_trained_model_vocabulary(
-            model_id=self.model_id, vocabulary=vocab["vocabulary"]
+        self._client.perform_request(
+            method="PUT",
+            path=f"/_ml/trained_models/{self.model_id}/vocabulary",
+            headers={"accept": "application/json", "content-type": "application/json"},
+            body=vocab,
         )

     def put_model(self, model_path: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> None:
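With BPE-based tokenizers (RoBERTa, BART) the vocabulary file can now also contain a "merges" list (see _load_vocab in the second file), so put_vocab sends the whole parsed JSON object through a raw PUT to the vocabulary endpoint, presumably because the typed put_trained_model_vocabulary helper only accepted the vocabulary field at this point. A purely illustrative body:

    # Illustrative only: real vocabularies and merge lists contain tens of thousands of entries.
    vocab = {
        "vocabulary": ["<s>", "<pad>", "</s>", "<unk>", "Ġthe", "Ġof"],
        "merges": ["Ġ t", "Ġ a", "h e", "i n"],
    }
    # This object becomes the body of PUT /_ml/trained_models/{model_id}/vocabulary.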


@@ -55,6 +55,8 @@ SUPPORTED_TOKENIZERS = (
     transformers.ElectraTokenizer,
     transformers.MobileBertTokenizer,
     transformers.RetriBertTokenizer,
+    transformers.RobertaTokenizer,
+    transformers.BartTokenizer,
     transformers.SqueezeBertTokenizer,
 )
 SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
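Because SUPPORTED_TOKENIZERS is a plain tuple of classes, support is an ordinary isinstance check; a minimal sketch of what now passes (the import path and model name are assumptions for illustration):

    import transformers
    from eland.ml.pytorch.transformers import SUPPORTED_TOKENIZERS

    tokenizer = transformers.RobertaTokenizer.from_pretrained("roberta-base")
    # RobertaTokenizer and BartTokenizer are now accepted alongside the BERT-style tokenizers.
    assert isinstance(tokenizer, SUPPORTED_TOKENIZERS)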
@@ -120,8 +122,15 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
     ) -> Optional[Any]:
         if model_id.startswith("sentence-transformers/"):
             model = AutoModel.from_pretrained(model_id, torchscript=True)
-            if isinstance(model.config, transformers.MPNetConfig):
-                return _MPNetSentenceTransformerWrapper(model, output_key)
+            if isinstance(
+                model.config,
+                (
+                    transformers.MPNetConfig,
+                    transformers.RobertaConfig,
+                    transformers.BartConfig,
+                ),
+            ):
+                return _TwoParameterSentenceTransformerWrapper(model, output_key)
             else:
                 return _SentenceTransformerWrapper(model, output_key)
         else:
@@ -177,7 +186,7 @@ class _SentenceTransformerWrapper(_SentenceTransformerWrapperModule):
         return self._st_model(inputs)[self._output_key]


-class _MPNetSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
+class _TwoParameterSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
     def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
         super().__init__(model=model, output_key=output_key)
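The rename captures what the wrapper actually distinguishes: MPNet, RoBERTa, and BART models are called with only input_ids and attention_mask, without the token_type_ids that BERT-style models also take. A hedged sketch of that shape, not the actual class body:

    import torch
    from torch import nn

    class TwoParameterWrapperSketch(nn.Module):
        # Illustrative stand-in for _TwoParameterSentenceTransformerWrapper.
        def __init__(self, model: nn.Module):
            super().__init__()
            self._hf_model = model

        def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
            # Only two inputs: these architectures are traced without token_type_ids.
            return self._hf_model(input_ids=input_ids, attention_mask=attention_mask)[0]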
@@ -280,7 +289,14 @@ class _TraceableModel(ABC):
             inputs["token_type_ids"] = torch.zeros(
                 inputs["input_ids"].size(1), dtype=torch.long
             )
-        if isinstance(self._model.config, transformers.MPNetConfig):
+        if isinstance(
+            self._model.config,
+            (
+                transformers.MPNetConfig,
+                transformers.RobertaConfig,
+                transformers.BartConfig,
+            ),
+        ):
             return torch.jit.trace(
                 self._model,
                 (inputs["input_ids"], inputs["attention_mask"]),
@@ -394,26 +410,37 @@ class TransformerModel:
     def _load_vocab(self) -> Dict[str, List[str]]:
         vocab_items = self._tokenizer.get_vocab().items()
         vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]  # type: ignore
-        return {
+        vocab_obj = {
             "vocabulary": vocabulary,
         }
+        ranks = getattr(self._tokenizer, "bpe_ranks", {})
+        if len(ranks) > 0:
+            merges = [
+                " ".join(m) for m, _ in sorted(ranks.items(), key=lambda kv: kv[1])
+            ]
+            vocab_obj["merges"] = merges
+        return vocab_obj

     def _create_config(self) -> Dict[str, Any]:
-        tokenizer_type = (
-            "mpnet"
-            if isinstance(self._tokenizer, transformers.MPNetTokenizer)
-            else "bert"
-        )
-        inference_config: Dict[str, Dict[str, Any]] = {
-            self._task_type: {
-                "tokenization": {
-                    tokenizer_type: {
-                        "do_lower_case": getattr(
-                            self._tokenizer, "do_lower_case", False
-                        ),
-                    }
-                }
-            }
-        }
+        if isinstance(self._tokenizer, transformers.MPNetTokenizer):
+            tokenizer_type = "mpnet"
+            tokenizer_obj = {
+                "do_lower_case": getattr(self._tokenizer, "do_lower_case", False)
+            }
+        elif isinstance(
+            self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer)
+        ):
+            tokenizer_type = "roberta"
+            tokenizer_obj = {
+                "add_prefix_space": getattr(self._tokenizer, "add_prefix_space", False)
+            }
+        else:
+            tokenizer_type = "bert"
+            tokenizer_obj = {
+                "do_lower_case": getattr(self._tokenizer, "do_lower_case", False)
+            }
+        inference_config: Dict[str, Dict[str, Any]] = {
+            self._task_type: {"tokenization": {tokenizer_type: tokenizer_obj}}
+        }

         if hasattr(self._tokenizer, "max_model_input_sizes"):
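Putting the tokenization change together, a hypothetical inference_config built for a RoBERTa model on a text_classification task would look like this (the field names come from the code above; the task name is just an example):

    inference_config = {
        "text_classification": {
            "tokenization": {
                "roberta": {
                    "add_prefix_space": False,
                }
            }
        }
    }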