Mirror of https://github.com/elastic/eland.git
[ML] Add RoBERTa and BART transformer upload support (#443)
Related to: https://github.com/elastic/elasticsearch/pull/84777. This allows BART and RoBERTa models to be uploaded to Elasticsearch for our currently defined NLP tasks.
This commit is contained in:
parent 5678525b15
commit 15a3007288
@@ -57,8 +57,11 @@ class PyTorchModel:
     def put_vocab(self, path: str) -> None:
         with open(path) as f:
             vocab = json.load(f)
-        self._client.ml.put_trained_model_vocabulary(
-            model_id=self.model_id, vocabulary=vocab["vocabulary"]
+        self._client.perform_request(
+            method="PUT",
+            path=f"/_ml/trained_models/{self.model_id}/vocabulary",
+            headers={"accept": "application/json", "content-type": "application/json"},
+            body=vocab,
         )
 
     def put_model(self, model_path: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> None:
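The switch from the typed ml.put_trained_model_vocabulary helper to a raw perform_request call lets the request body carry the new optional "merges" field before the typed client API accepts it. A minimal sketch of the equivalent standalone call, assuming a local cluster and a hypothetical model id (my-roberta-model) with a toy vocabulary:

    from elasticsearch import Elasticsearch

    client = Elasticsearch("http://localhost:9200")

    # Toy body mirroring what put_vocab sends; "merges" is only present for
    # BPE tokenizers such as RoBERTa's and BART's.
    vocab_body = {
        "vocabulary": ["<s>", "<pad>", "</s>", "<unk>", "Ġthe", "Ġof"],
        "merges": ["Ġ t", "h e"],
    }

    client.perform_request(
        method="PUT",
        path="/_ml/trained_models/my-roberta-model/vocabulary",
        headers={"accept": "application/json", "content-type": "application/json"},
        body=vocab_body,
    )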
@@ -55,6 +55,8 @@ SUPPORTED_TOKENIZERS = (
     transformers.ElectraTokenizer,
     transformers.MobileBertTokenizer,
     transformers.RetriBertTokenizer,
+    transformers.RobertaTokenizer,
+    transformers.BartTokenizer,
     transformers.SqueezeBertTokenizer,
 )
 SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
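Membership in SUPPORTED_TOKENIZERS is what gates a model at load time: only tokenizer classes that eland can map to an Elasticsearch tokenization config get through. A sketch of how such a guard typically looks (the function name and error wording here are illustrative, not eland's):

    import transformers

    SUPPORTED_TOKENIZERS = (
        transformers.RobertaTokenizer,
        transformers.BartTokenizer,
        # ... plus the other entries from the tuple above
    )

    def check_tokenizer(tokenizer: transformers.PreTrainedTokenizer) -> None:
        # isinstance accepts a tuple of classes, so one check covers them all.
        if not isinstance(tokenizer, SUPPORTED_TOKENIZERS):
            raise TypeError(f"unsupported tokenizer type {type(tokenizer).__name__}")

    # Downloads the roberta-base tokenizer files on first use.
    check_tokenizer(transformers.RobertaTokenizer.from_pretrained("roberta-base"))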
@@ -120,8 +122,15 @@ class _SentenceTransformerWrapperModule(nn.Module):  # type: ignore
     ) -> Optional[Any]:
         if model_id.startswith("sentence-transformers/"):
             model = AutoModel.from_pretrained(model_id, torchscript=True)
-            if isinstance(model.config, transformers.MPNetConfig):
-                return _MPNetSentenceTransformerWrapper(model, output_key)
+            if isinstance(
+                model.config,
+                (
+                    transformers.MPNetConfig,
+                    transformers.RobertaConfig,
+                    transformers.BartConfig,
+                ),
+            ):
+                return _TwoParameterSentenceTransformerWrapper(model, output_key)
             else:
                 return _SentenceTransformerWrapper(model, output_key)
         else:
@@ -177,7 +186,7 @@ class _SentenceTransformerWrapper(_SentenceTransformerWrapperModule):
         return self._st_model(inputs)[self._output_key]
 
 
-class _MPNetSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
+class _TwoParameterSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
     def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
         super().__init__(model=model, output_key=output_key)
 
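The rename makes the intent explicit: MPNet, RoBERTa, and BART forward passes all take exactly two tensors (input_ids and attention_mask) and no token_type_ids, so one wrapper now serves all three model families. A sketch of the shape such a wrapper plausibly has; the forward body is an assumption based on the class name, not copied from the diff:

    import torch
    from torch import nn

    class TwoParameterWrapper(nn.Module):
        # Illustrative stand-in for _TwoParameterSentenceTransformerWrapper:
        # the forward signature must match the two tensors eland traces with.
        def __init__(self, model: nn.Module, output_key: str = "sentence_embedding"):
            super().__init__()
            self._model = model
            self._output_key = output_key

        def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
            return self._model(input_ids=input_ids, attention_mask=attention_mask)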
@@ -280,7 +289,14 @@ class _TraceableModel(ABC):
             inputs["token_type_ids"] = torch.zeros(
                 inputs["input_ids"].size(1), dtype=torch.long
             )
-        if isinstance(self._model.config, transformers.MPNetConfig):
+        if isinstance(
+            self._model.config,
+            (
+                transformers.MPNetConfig,
+                transformers.RobertaConfig,
+                transformers.BartConfig,
+            ),
+        ):
             return torch.jit.trace(
                 self._model,
                 (inputs["input_ids"], inputs["attention_mask"]),
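torch.jit.trace fixes the traced module's calling convention to the example inputs it is given, which is why two-input models receive a tuple of two tensors here. A self-contained sketch of the same tracing pattern on a toy module:

    import torch
    from torch import nn

    class Toy(nn.Module):
        # Stand-in for the wrapped transformer; it consumes the same two
        # tensors eland passes when tracing RoBERTa/BART/MPNet models.
        def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
            return input_ids * attention_mask

    model = Toy().eval()
    input_ids = torch.randint(0, 100, (1, 8))
    attention_mask = torch.ones((1, 8), dtype=torch.long)

    # Tracing with a tuple produces a module taking exactly these two arguments.
    traced = torch.jit.trace(model, (input_ids, attention_mask))
    print(traced(input_ids, attention_mask).shape)  # torch.Size([1, 8])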
@@ -394,26 +410,37 @@ class TransformerModel:
     def _load_vocab(self) -> Dict[str, List[str]]:
         vocab_items = self._tokenizer.get_vocab().items()
         vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]  # type: ignore
-        return {
+        vocab_obj = {
             "vocabulary": vocabulary,
         }
+        ranks = getattr(self._tokenizer, "bpe_ranks", {})
+        if len(ranks) > 0:
+            merges = [
+                " ".join(m) for m, _ in sorted(ranks.items(), key=lambda kv: kv[1])
+            ]
+            vocab_obj["merges"] = merges
+        return vocab_obj
 
     def _create_config(self) -> Dict[str, Any]:
-        tokenizer_type = (
-            "mpnet"
-            if isinstance(self._tokenizer, transformers.MPNetTokenizer)
-            else "bert"
-        )
-        inference_config: Dict[str, Dict[str, Any]] = {
-            self._task_type: {
-                "tokenization": {
-                    tokenizer_type: {
-                        "do_lower_case": getattr(
-                            self._tokenizer, "do_lower_case", False
-                        ),
-                    }
-                }
-            }
-        }
+        if isinstance(self._tokenizer, transformers.MPNetTokenizer):
+            tokenizer_type = "mpnet"
+            tokenizer_obj = {
+                "do_lower_case": getattr(self._tokenizer, "do_lower_case", False)
+            }
+        elif isinstance(
+            self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer)
+        ):
+            tokenizer_type = "roberta"
+            tokenizer_obj = {
+                "add_prefix_space": getattr(self._tokenizer, "add_prefix_space", False)
+            }
+        else:
+            tokenizer_type = "bert"
+            tokenizer_obj = {
+                "do_lower_case": getattr(self._tokenizer, "do_lower_case", False)
+            }
+        inference_config: Dict[str, Dict[str, Any]] = {
+            self._task_type: {"tokenization": {tokenizer_type: tokenizer_obj}}
+        }
 
         if hasattr(self._tokenizer, "max_model_input_sizes"):
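Two things change for BPE models here: _load_vocab now serializes the tokenizer's bpe_ranks (a map from merge pair to priority) into the "merges" list, and _create_config emits a "roberta" tokenization section keyed on add_prefix_space instead of the BERT-style do_lower_case. A sketch of both pieces, assuming roberta-base can be downloaded and using text_embedding as an example task type:

    import transformers

    tokenizer = transformers.RobertaTokenizer.from_pretrained("roberta-base")

    # bpe_ranks maps ("h", "e")-style pairs to merge priority; joining each
    # pair with a space reproduces the merges.txt lines Elasticsearch expects.
    ranks = getattr(tokenizer, "bpe_ranks", {})
    merges = [" ".join(pair) for pair, _ in sorted(ranks.items(), key=lambda kv: kv[1])]
    print(merges[:3])  # e.g. ['Ġ t', 'Ġ a', 'h e']

    # Shape of the inference config the new _create_config branch produces.
    inference_config = {
        "text_embedding": {
            "tokenization": {
                "roberta": {
                    "add_prefix_space": getattr(tokenizer, "add_prefix_space", False)
                }
            }
        }
    }
    print(inference_config)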