Mirror of https://github.com/elastic/eland.git
Add initial implementation of PyTorch ML models
commit 014943d3b8 (parent 995f2432b6)

bin/import_hub_model.py (new executable file, 105 lines)
@@ -0,0 +1,105 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Copies a model from the Hugging Face model hub into an Elasticsearch cluster.
This will create local cached copies that will be traced (necessary) before
uploading to Elasticsearch. This will also check that the task type is supported
as well as the model and tokenizer types. All necessary configuration is
uploaded along with the model.
"""

import argparse
import tempfile

import elasticsearch

from eland.ml.pytorch import PyTorchModel
from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES, TransformerModel

MODEL_HUB_URL = "https://huggingface.co"


def main():
    parser = argparse.ArgumentParser(prog="import_hub_model.py")
    parser.add_argument(
        "--url",
        required=True,
        help="An Elasticsearch connection URL, e.g. http://user:secret@localhost:9200",
    )
    parser.add_argument(
        "--hub-model-id",
        required=True,
        help="The model ID in the Hugging Face model hub, "
        "e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
    )
    parser.add_argument(
        "--elasticsearch-model-id",
        required=False,
        default=None,
        help="The model ID to use in Elasticsearch, "
        "e.g. bert-large-cased-finetuned-conll03-english. "
        "When left unspecified, this will be auto-created from the `--hub-model-id`",
    )
    parser.add_argument(
        "--task-type",
        required=True,
        choices=SUPPORTED_TASK_TYPES,
        help="The task type that the model will be used for.",
    )
    parser.add_argument(
        "--quantize",
        action="store_true",
        default=False,
        help="Quantize the model before uploading. Default: False",
    )
    parser.add_argument(
        "--start",
        action="store_true",
        default=False,
        help="Start the model deployment after uploading. Default: False",
    )
    args = parser.parse_args()

    es = elasticsearch.Elasticsearch(args.url, timeout=300)  # 5 minute timeout

    # trace and save the model, then upload it from the temp file
    with tempfile.TemporaryDirectory() as tmp_dir:
        print("Loading HuggingFace transformer tokenizer and model")
        tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
        model_path, config_path, vocab_path = tm.save(tmp_dir)

        es_model_id = (
            args.elasticsearch_model_id
            if args.elasticsearch_model_id
            else tm.elasticsearch_model_id()
        )

        # remove any pre-existing model with the same ID before importing
        ptm = PyTorchModel(es, es_model_id)
        ptm.stop()
        ptm.delete()
        print(f"Importing model: {ptm.model_id}")
        ptm.import_model(model_path, config_path, vocab_path)

        # start the deployed model
        if args.start:
            print("Starting model deployment")
            ptm.start()


if __name__ == "__main__":
    main()
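
A typical end-to-end invocation of this script, assuming a local unsecured cluster and reusing the NER model from the --hub-model-id help text above, would look something like:

python bin/import_hub_model.py --url http://localhost:9200 --hub-model-id dbmdz/bert-large-cased-finetuned-conll03-english --task-type ner --start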

eland/ml/pytorch/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from eland.ml.pytorch._pytorch_model import PyTorchModel  # noqa: F401

__all__ = ["PyTorchModel"]

eland/ml/pytorch/_pytorch_model.py (new file, 141 lines)
@@ -0,0 +1,141 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import base64
import json
import math
import os
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Set, Tuple, Union

from tqdm.auto import tqdm

from eland.common import ensure_es_client

if TYPE_CHECKING:
    from elasticsearch import Elasticsearch

DEFAULT_CHUNK_SIZE = 4 * 1024 * 1024  # 4MB
DEFAULT_TIMEOUT = "60s"


class PyTorchModel:
    """
    A PyTorch model managed by Elasticsearch.

    These models must be trained outside of Elasticsearch, conform to the
    supported tokenization and inference interfaces, and be exported as their
    TorchScript representations.
    """

    def __init__(
        self,
        es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
        model_id: str,
    ):
        self._client = ensure_es_client(es_client)
        self.model_id = model_id

    def put_config(self, path: str) -> None:
        with open(path) as f:
            config = json.load(f)
        self._client.ml.put_trained_model(model_id=self.model_id, body=config)

    def put_vocab(self, path: str) -> None:
        with open(path) as f:
            vocab = json.load(f)
        self._client.transport.perform_request(
            "PUT",
            f"/_ml/trained_models/{self.model_id}/vocabulary",
            body=vocab,
        )

    def put_model(self, model_path: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> None:
        model_size = os.stat(model_path).st_size
        total_parts = math.ceil(model_size / chunk_size)

        def model_file_chunk_generator() -> Iterable[str]:
            # stream the TorchScript file as base64-encoded chunks so that
            # each uploaded definition part stays within the chunk size
            with open(model_path, "rb") as f:
                while True:
                    data = f.read(chunk_size)
                    if not data:
                        break
                    yield base64.b64encode(data).decode()

        for i, data in tqdm(
            enumerate(model_file_chunk_generator()), total=total_parts
        ):
            body = {
                "total_definition_length": model_size,
                "total_parts": total_parts,
                "definition": data,
            }
            self._client.transport.perform_request(
                "PUT",
                f"/_ml/trained_models/{self.model_id}/definition/{i}",
                body=body,
            )

    def import_model(
        self,
        model_path: str,
        config_path: str,
        vocab_path: str,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
    ) -> None:
        # TODO: Implement some pre-flight checks on config, vocab, and model
        self.put_config(config_path)
        self.put_vocab(vocab_path)
        self.put_model(model_path, chunk_size)

    def infer(
        self, body: Dict[str, Any], timeout: str = DEFAULT_TIMEOUT
    ) -> Dict[str, Any]:
        return self._client.transport.perform_request(
            "POST",
            f"/_ml/trained_models/{self.model_id}/deployment/_infer",
            body=body,
            params={"timeout": timeout},
        )

    def start(self, timeout: str = DEFAULT_TIMEOUT) -> None:
        self._client.transport.perform_request(
            "POST",
            f"/_ml/trained_models/{self.model_id}/deployment/_start",
            params={"timeout": timeout, "wait_for": "started"},
        )

    def stop(self) -> None:
        self._client.transport.perform_request(
            "POST",
            f"/_ml/trained_models/{self.model_id}/deployment/_stop",
            params={"ignore": 404},
        )

    def delete(self) -> None:
        self._client.ml.delete_trained_model(self.model_id, ignore=(404,))

    @classmethod
    def list(
        cls, es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
    ) -> Set[str]:
        client = ensure_es_client(es_client)
        res = client.ml.get_trained_models(model_id="*", allow_no_match=True)
        return {
            model["model_id"]
            for model in res["trained_model_configs"]
            if model["model_type"] == "pytorch"
        }
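
As a review aid, here is a minimal sketch of driving this class once a model has been imported. The cluster URL, model ID, and the shape of the _infer request body are illustrative assumptions, not something this commit defines:

import elasticsearch

from eland.ml.pytorch import PyTorchModel

# Assumed local, unsecured cluster; adjust the URL for a real deployment.
es = elasticsearch.Elasticsearch("http://localhost:9200", timeout=300)
ptm = PyTorchModel(es, "dbmdz__bert-large-cased-finetuned-conll03-english")

ptm.start()  # deploy the previously imported model
# Hypothetical request body; the exact document format is defined by the
# Elasticsearch inference API, not by this file.
result = ptm.infer(body={"docs": [{"text_field": "Jim works at Elastic."}]})
print(result)
ptm.stop()

# enumerate all PyTorch models known to the cluster
print(PyTorchModel.list(es))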

eland/ml/pytorch/transformers.py (new file, 465 lines)
@@ -0,0 +1,465 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Support for and interoperability with HuggingFace transformers and related
libraries such as sentence-transformers.
"""

import json
import os.path
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple, Union

import torch
import transformers
from sentence_transformers import SentenceTransformer
from torch import Tensor, nn
from transformers import (
    AutoConfig,
    AutoModel,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

DEFAULT_OUTPUT_KEY = "sentence_embedding"
SUPPORTED_TASK_TYPES = {
    "fill_mask",
    "ner",
    "text_classification",
    "text_embedding",
    "zero_shot_classification",
}
SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
SUPPORTED_TOKENIZERS = (
    transformers.BertTokenizer,
    transformers.DPRContextEncoderTokenizer,
    transformers.DPRQuestionEncoderTokenizer,
    transformers.DistilBertTokenizer,
    transformers.ElectraTokenizer,
    transformers.MobileBertTokenizer,
    transformers.RetriBertTokenizer,
    transformers.SqueezeBertTokenizer,
)
SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))

TracedModelTypes = Union[
    torch.nn.Module,
    torch.ScriptModule,
    torch.jit.ScriptModule,
    torch.jit.TopLevelTracedModule,
]


class _DistilBertWrapper(nn.Module):
    """
    A simple wrapper around the DistilBERT model which makes the model inputs
    conform to Elasticsearch's native inference processor interface.
    """

    def __init__(self, model: transformers.PreTrainedModel):
        super().__init__()
        self._model = model
        self.config = model.config

    @staticmethod
    def try_wrapping(model: PreTrainedModel) -> Optional[Any]:
        if isinstance(model.config, transformers.DistilBertConfig):
            return _DistilBertWrapper(model)
        else:
            return model

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor,
        token_type_ids: Tensor,
        position_ids: Tensor,
    ) -> Tensor:
        """Wrap the input and output to conform to the native process interface."""

        return self._model(input_ids=input_ids, attention_mask=attention_mask)


class _SentenceTransformerWrapper(nn.Module):
    """
    A wrapper around sentence-transformer models to provide pooling,
    normalization and other graph layers that are not defined in the base
    HuggingFace transformer model.
    """

    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
        super().__init__()
        self._hf_model = model
        self._st_model = SentenceTransformer(model.config.name_or_path)
        self._output_key = output_key

        self._remove_pooling_layer()
        self._replace_transformer_layer()

    @staticmethod
    def from_pretrained(
        model_id: str, output_key: str = DEFAULT_OUTPUT_KEY
    ) -> Optional[Any]:
        if model_id.startswith("sentence-transformers/"):
            model = AutoModel.from_pretrained(model_id, torchscript=True)
            return _SentenceTransformerWrapper(model, output_key)
        else:
            return None

    def _remove_pooling_layer(self):
        """
        Removes any last pooling layer, which is not used to create embeddings.
        Leaving this layer in will cause it to return a NoneType, which in turn
        will fail to load in libtorch. Alternatively, we could use the output
        of the pooling layer as a dummy, but this also affects (if only in a
        minor way) the performance of inference, so we're better off removing
        the layer if we can.
        """

        if hasattr(self._hf_model, "pooler"):
            self._hf_model.pooler = None

    def _replace_transformer_layer(self):
        """
        Replaces the HuggingFace Transformer layer in the SentenceTransformer
        modules so we can set it with one that has the pooling layer removed
        and was loaded ready for TorchScript export.
        """

        self._st_model._modules["0"].auto_model = self._hf_model

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor,
        token_type_ids: Tensor,
        position_ids: Tensor,
    ) -> Tensor:
        """Wrap the input and output to conform to the native process interface."""

        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
        }

        # remove inputs for specific model types
        if isinstance(self._hf_model.config, transformers.DistilBertConfig):
            del inputs["token_type_ids"]

        return self._st_model(inputs)[self._output_key]


class _DPREncoderWrapper(nn.Module):
    """
    AutoModel loading does not work for DPRContextEncoders, so this wrapper
    exists only as a workaround. This may never be fixed, so the workaround
    is likely permanent.
    See: https://github.com/huggingface/transformers/issues/13670
    """

    _SUPPORTED_MODELS = {
        transformers.DPRContextEncoder,
        transformers.DPRQuestionEncoder,
    }
    _SUPPORTED_MODELS_NAMES = set([x.__name__ for x in _SUPPORTED_MODELS])

    def __init__(
        self,
        model: Union[transformers.DPRContextEncoder, transformers.DPRQuestionEncoder],
    ):
        super().__init__()
        self._model = model

    @staticmethod
    def from_pretrained(model_id: str) -> Optional[Any]:

        config = AutoConfig.from_pretrained(model_id)

        def is_compatible() -> bool:
            is_dpr_model = config.model_type == "dpr"
            has_architectures = len(config.architectures) == 1
            is_supported_architecture = (
                config.architectures[0] in _DPREncoderWrapper._SUPPORTED_MODELS_NAMES
            )
            return is_dpr_model and has_architectures and is_supported_architecture

        if is_compatible():
            model = getattr(transformers, config.architectures[0]).from_pretrained(
                model_id, torchscript=True
            )
            return _DPREncoderWrapper(model)
        else:
            return None

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor,
        token_type_ids: Tensor,
        position_ids: Tensor,
    ) -> Tensor:
        """Wrap the input and output to conform to the native process interface."""

        return self._model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )


class _TraceableModel(ABC):
    """A base class representing a HuggingFace transformer model that can be traced."""

    def __init__(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        model: Union[
            PreTrainedModel,
            _SentenceTransformerWrapper,
            _DPREncoderWrapper,
            _DistilBertWrapper,
        ],
    ):
        self._tokenizer = tokenizer
        self._model = model

    def classification_labels(self) -> Optional[List[str]]:
        return None

    def quantize(self):
        torch.quantization.quantize_dynamic(
            self._model, {torch.nn.Linear}, dtype=torch.qint8
        )

    def trace(self) -> TracedModelTypes:
        # the model needs to be in evaluation mode for tracing
        self._model.eval()

        inputs = self._prepare_inputs()
        position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)

        # Add params when not provided by the tokenizer (e.g. DistilBERT),
        # to conform to the BERT interface
        if "token_type_ids" not in inputs:
            inputs["token_type_ids"] = torch.zeros(
                inputs["input_ids"].size(1), dtype=torch.long
            )

        return torch.jit.trace(
            self._model,
            (
                inputs["input_ids"],
                inputs["attention_mask"],
                inputs["token_type_ids"],
                position_ids,
            ),
        )

    @abstractmethod
    def _prepare_inputs(self) -> transformers.BatchEncoding:
        ...


class _TraceableClassificationModel(_TraceableModel, ABC):
    def classification_labels(self) -> Optional[List[str]]:
        id_label_items = self._model.config.id2label.items()
        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]

        # Make classes like I-PER into I_PER which fits Java enumerations
        return [label.replace("-", "_") for label in labels]


class _TraceableFillMaskModel(_TraceableModel):
    def _prepare_inputs(self) -> transformers.BatchEncoding:
        return self._tokenizer(
            "Who was Jim Henson?",
            "[MASK] Henson was a puppeteer",
            return_tensors="pt",
        )


class _TraceableNerModel(_TraceableClassificationModel):
    def _prepare_inputs(self) -> transformers.BatchEncoding:
        return self._tokenizer(
            (
                "Hugging Face Inc. is a company based in New York City. "
                "Its headquarters are in DUMBO, therefore very close to the Manhattan Bridge."
            ),
            return_tensors="pt",
        )


class _TraceableTextClassificationModel(_TraceableClassificationModel):
    def _prepare_inputs(self) -> transformers.BatchEncoding:
        return self._tokenizer(
            "This is an example sentence.",
            return_tensors="pt",
        )


class _TraceableTextEmbeddingModel(_TraceableModel):
    def _prepare_inputs(self) -> transformers.BatchEncoding:
        return self._tokenizer(
            "This is an example sentence.",
            return_tensors="pt",
        )


class _TraceableZeroShotClassificationModel(_TraceableClassificationModel):
    def _prepare_inputs(self) -> transformers.BatchEncoding:
        return self._tokenizer(
            "This is an example sentence.",
            "This example is an example.",
            return_tensors="pt",
            truncation_strategy="only_first",
        )


class TransformerModel:
    def __init__(self, model_id: str, task_type: str, quantize: bool = False):
        self._model_id = model_id
        self._task_type = task_type.replace("-", "_")

        # load the Hugging Face model and tokenizer
        # use padding in the tokenizer to ensure max length sequences are used for tracing
        # - see: https://huggingface.co/transformers/serialization.html#dummy-inputs-and-standard-lengths
        self._tokenizer = transformers.AutoTokenizer.from_pretrained(
            self._model_id, padding=True, use_fast=False
        )

        # check for a supported tokenizer
        if not isinstance(self._tokenizer, SUPPORTED_TOKENIZERS):
            raise TypeError(
                f"Tokenizer type {self._tokenizer} not supported, must be one of: {SUPPORTED_TOKENIZERS_NAMES}"
            )

        self._traceable_model = self._create_traceable_model()
        if quantize:
            self._traceable_model.quantize()
        self._traced_model = self._traceable_model.trace()
        self._vocab = self._load_vocab()
        self._config = self._create_config()

    def _load_vocab(self):
        vocab_items = self._tokenizer.get_vocab().items()
        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]
        return {
            "vocabulary": vocabulary,
        }

    def _create_config(self):
        inference_config = {
            self._task_type: {
                "tokenization": {
                    "bert": {
                        "do_lower_case": getattr(
                            self._tokenizer, "do_lower_case", False
                        ),
                    }
                }
            }
        }

        if hasattr(self._tokenizer, "max_model_input_sizes"):
            max_sequence_length = self._tokenizer.max_model_input_sizes.get(
                self._model_id
            )
            if max_sequence_length:
                inference_config[self._task_type]["tokenization"]["bert"][
                    "max_sequence_length"
                ] = max_sequence_length

        if self._traceable_model.classification_labels():
            inference_config[self._task_type][
                "classification_labels"
            ] = self._traceable_model.classification_labels()

        return {
            "description": f"Model {self._model_id} for task type '{self._task_type}'",
            "model_type": "pytorch",
            "inference_config": inference_config,
            "input": {
                "field_names": ["text_field"],
            },
        }

    def _create_traceable_model(self) -> _TraceableModel:
        if self._task_type == "fill_mask":
            model = transformers.AutoModelForMaskedLM.from_pretrained(
                self._model_id, torchscript=True
            )
            model = _DistilBertWrapper.try_wrapping(model)
            return _TraceableFillMaskModel(self._tokenizer, model)

        elif self._task_type == "ner":
            model = transformers.AutoModelForTokenClassification.from_pretrained(
                self._model_id, torchscript=True
            )
            model = _DistilBertWrapper.try_wrapping(model)
            return _TraceableNerModel(self._tokenizer, model)

        elif self._task_type == "text_classification":
            model = transformers.AutoModelForSequenceClassification.from_pretrained(
                self._model_id, torchscript=True
            )
            model = _DistilBertWrapper.try_wrapping(model)
            return _TraceableTextClassificationModel(self._tokenizer, model)

        elif self._task_type == "text_embedding":
            model = _SentenceTransformerWrapper.from_pretrained(self._model_id)
            if not model:
                model = _DPREncoderWrapper.from_pretrained(self._model_id)
            if not model:
                model = transformers.AutoModel.from_pretrained(
                    self._model_id, torchscript=True
                )
            return _TraceableTextEmbeddingModel(self._tokenizer, model)

        elif self._task_type == "zero_shot_classification":
            model = transformers.AutoModelForSequenceClassification.from_pretrained(
                self._model_id, torchscript=True
            )
            model = _DistilBertWrapper.try_wrapping(model)
            return _TraceableZeroShotClassificationModel(self._tokenizer, model)

        else:
            raise TypeError(
                f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
            )

    def elasticsearch_model_id(self):
        # Elasticsearch model IDs need to be in a specific format: no special chars, all lowercase, max 64 chars
        return self._model_id.replace("/", "__").lower()[:64]

    def save(self, path: str) -> Tuple[str, str, str]:
        # save traced model
        model_path = os.path.join(path, "traced_pytorch_model.pt")
        torch.jit.save(self._traced_model, model_path)

        # save configuration
        config_path = os.path.join(path, "config.json")
        with open(config_path, "w") as outfile:
            json.dump(self._config, outfile)

        # save vocabulary
        vocab_path = os.path.join(path, "vocabulary.json")
        with open(vocab_path, "w") as outfile:
            json.dump(self._vocab, outfile)

        return model_path, config_path, vocab_path
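
To make the tracing flow above concrete, here is a minimal sketch of exercising TransformerModel on its own, without an Elasticsearch cluster. The hub model ID is an illustrative choice, not something this commit prescribes:

import tempfile

from eland.ml.pytorch.transformers import TransformerModel

# Trace a text classification model from the Hugging Face hub and write the
# three artifacts (TorchScript model, config JSON, vocabulary JSON) to disk.
tm = TransformerModel(
    "distilbert-base-uncased-finetuned-sst-2-english",  # assumed example model
    "text_classification",
    quantize=False,
)
with tempfile.TemporaryDirectory() as tmp_dir:
    model_path, config_path, vocab_path = tm.save(tmp_dir)
    print(model_path, config_path, vocab_path)
    print(tm.elasticsearch_model_id())  # auto-derived Elasticsearch model ID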

noxfile.py
@@ -22,7 +22,7 @@ from pathlib import Path
 import nox

 BASE_DIR = Path(__file__).parent
-SOURCE_FILES = ("setup.py", "noxfile.py", "eland/", "docs/", "utils/", "tests/")
+SOURCE_FILES = ("setup.py", "noxfile.py", "eland/", "docs/", "utils/", "tests/", "bin/")

 # Whenever type-hints are completed on a file it should
 # be added here so that this file will continue to be checked

requirements-dev.txt
@@ -11,3 +11,7 @@ nox
 lightgbm
 pytest-cov
 mypy
+huggingface-hub>=0.0.17
+sentence-transformers>=2.0.0
+torch>=1.9.0
+transformers[torch]>=4.11.0