diff --git a/bin/import_hub_model.py b/bin/import_hub_model.py
new file mode 100755
index 0000000..11cfc8e
--- /dev/null
+++ b/bin/import_hub_model.py
@@ -0,0 +1,105 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Copies a model from the Hugging Face model hub into an Elasticsearch cluster.
+A local cached copy of the model is created and traced (a required step)
+before uploading to Elasticsearch. The task type, model type and tokenizer
+type are also checked for support. All necessary configuration is uploaded
+along with the model.
+"""
+
+import argparse
+import tempfile
+
+import elasticsearch
+
+from eland.ml.pytorch import PyTorchModel
+from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES, TransformerModel
+
+MODEL_HUB_URL = "https://huggingface.co"
+
+
+def main():
+    parser = argparse.ArgumentParser(prog="import_hub_model.py")
+    parser.add_argument(
+        "--url",
+        required=True,
+        help="An Elasticsearch connection URL, e.g. http://user:secret@localhost:9200",
+    )
+    parser.add_argument(
+        "--hub-model-id",
+        required=True,
+        help="The model ID in the Hugging Face model hub, "
+        "e.g. dbmdz/bert-large-cased-finetuned-conll03-english",
+    )
+    parser.add_argument(
+        "--elasticsearch-model-id",
+        required=False,
+        default=None,
+        help="The model ID to use in Elasticsearch, "
+        "e.g. bert-large-cased-finetuned-conll03-english. "
+        "When left unspecified, this will be auto-created from the `--hub-model-id`",
+    )
+    parser.add_argument(
+        "--task-type",
+        required=True,
+        choices=SUPPORTED_TASK_TYPES,
+        help="The task type that the model will be used for.",
+    )
+    parser.add_argument(
+        "--quantize",
+        action="store_true",
+        default=False,
+        help="Quantize the model before uploading. Default: False",
+    )
+    parser.add_argument(
+        "--start",
+        action="store_true",
+        default=False,
+        help="Start the model deployment after uploading. Default: False",
+    )
+    args = parser.parse_args()
+
+    es = elasticsearch.Elasticsearch(args.url, timeout=300)  # 5 minute timeout
+
+    # trace and save model, then upload it from temp file
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        print("Loading HuggingFace transformer tokenizer and model")
+        tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
+        model_path, config_path, vocab_path = tm.save(tmp_dir)
+
+        es_model_id = (
+            args.elasticsearch_model_id
+            if args.elasticsearch_model_id
+            else tm.elasticsearch_model_id()
+        )
+
+        ptm = PyTorchModel(es, es_model_id)
+        ptm.stop()
+        ptm.delete()
+        print(f"Importing model: {ptm.model_id}")
+        ptm.import_model(model_path, config_path, vocab_path)
+
+        # start the deployed model
+        if args.start:
+            print("Starting model deployment")
+            ptm.start()
+
+
+if __name__ == "__main__":
+    main()
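For orientation only (not part of the patch): a hedged sketch of how the script above might be invoked once merged. The cluster URL, credentials and hub model ID below are placeholders, and the flags simply mirror the argparse definitions above.

```python
# Illustrative invocation of bin/import_hub_model.py with placeholder values.
# The URL, credentials and hub model ID are assumptions, not part of this change.
import subprocess

subprocess.run(
    [
        "python",
        "bin/import_hub_model.py",
        "--url", "http://elastic:changeme@localhost:9200",
        "--hub-model-id", "distilbert-base-uncased-finetuned-sst-2-english",
        "--task-type", "text_classification",
        "--start",
    ],
    check=True,
)
```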
diff --git a/eland/ml/pytorch/__init__.py b/eland/ml/pytorch/__init__.py
new file mode 100644
index 0000000..be83855
--- /dev/null
+++ b/eland/ml/pytorch/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from eland.ml.pytorch._pytorch_model import PyTorchModel  # noqa: F401
+
+__all__ = ["PyTorchModel"]
diff --git a/eland/ml/pytorch/_pytorch_model.py b/eland/ml/pytorch/_pytorch_model.py
new file mode 100644
index 0000000..a397974
--- /dev/null
+++ b/eland/ml/pytorch/_pytorch_model.py
@@ -0,0 +1,141 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import base64
+import json
+import math
+import os
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Set, Tuple, Union
+
+from tqdm.auto import tqdm
+
+from eland.common import ensure_es_client
+
+if TYPE_CHECKING:
+    from elasticsearch import Elasticsearch
+
+DEFAULT_CHUNK_SIZE = 4 * 1024 * 1024  # 4MB
+DEFAULT_TIMEOUT = "60s"
+
+
+class PyTorchModel:
+    """
+    A PyTorch model managed by Elasticsearch.
+
+    These models must be trained outside of Elasticsearch, conform to the
+    supported tokenization and inference interfaces, and be exported in their
+    TorchScript representations.
+    """
+
+    def __init__(
+        self,
+        es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
+        model_id: str,
+    ):
+        self._client = ensure_es_client(es_client)
+        self.model_id = model_id
+
+    def put_config(self, path: str) -> None:
+        with open(path) as f:
+            config = json.load(f)
+        self._client.ml.put_trained_model(model_id=self.model_id, body=config)
+
+    def put_vocab(self, path: str) -> None:
+        with open(path) as f:
+            vocab = json.load(f)
+        self._client.transport.perform_request(
+            "PUT",
+            f"/_ml/trained_models/{self.model_id}/vocabulary",
+            body=vocab,
+        )
+
+    def put_model(self, model_path: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> None:
+        model_size = os.stat(model_path).st_size
+        total_parts = math.ceil(model_size / chunk_size)
+
+        def model_file_chunk_generator() -> Iterable[str]:
+            with open(model_path, "rb") as f:
+                while True:
+                    data = f.read(chunk_size)
+                    if not data:
+                        break
+                    yield base64.b64encode(data).decode()
+
+        for i, data in tqdm(enumerate(model_file_chunk_generator()), total=total_parts):
+            body = {
+                "total_definition_length": model_size,
+                "total_parts": total_parts,
+                "definition": data,
+            }
+            self._client.transport.perform_request(
+                "PUT",
+                f"/_ml/trained_models/{self.model_id}/definition/{i}",
+                body=body,
+            )
+
+    def import_model(
+        self,
+        model_path: str,
+        config_path: str,
+        vocab_path: str,
+        chunk_size: int = DEFAULT_CHUNK_SIZE,
+    ) -> None:
+        # TODO: Implement some pre-flight checks on config, vocab, and model
+        self.put_config(config_path)
+        self.put_vocab(vocab_path)
+        self.put_model(model_path, chunk_size)
+
+    def infer(
+        self, body: Dict[str, Any], timeout: str = DEFAULT_TIMEOUT
+    ) -> Dict[str, Any]:
+        return self._client.transport.perform_request(
+            "POST",
+            f"/_ml/trained_models/{self.model_id}/deployment/_infer",
+            body=body,
+            params={"timeout": timeout},
+        )
+
+    def start(self, timeout: str = DEFAULT_TIMEOUT) -> None:
+        self._client.transport.perform_request(
+            "POST",
+            f"/_ml/trained_models/{self.model_id}/deployment/_start",
+            params={"timeout": timeout, "wait_for": "started"},
+        )
+
+    def stop(self) -> None:
+        self._client.transport.perform_request(
+            "POST",
+            f"/_ml/trained_models/{self.model_id}/deployment/_stop",
+            params={"ignore": 404},
+        )
+
+    def delete(self) -> None:
+        self._client.ml.delete_trained_model(self.model_id, ignore=(404,))
+
+    @classmethod
+    def list(
+        cls, es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
+    ) -> Set[str]:
+        client = ensure_es_client(es_client)
+        res = client.ml.get_trained_models(model_id="*", allow_no_match=True)
+        return set(
+            [
+                model["model_id"]
+                for model in res["trained_model_configs"]
+                if model["model_type"] == "pytorch"
+            ]
+        )
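For reference (not part of the patch): a minimal sketch of the `PyTorchModel` API added above, assuming a local cluster and a model that has already been imported. The request body passed to `infer()` is an assumed shape for the `_infer` endpoint, not something defined by this change.

```python
# A minimal sketch of the PyTorchModel API, assuming a local cluster and an
# already imported model named "my-model". The body given to infer() is an
# assumption about the _infer endpoint, not defined in this change.
from elasticsearch import Elasticsearch

from eland.ml.pytorch import PyTorchModel

es = Elasticsearch("http://localhost:9200")
print(PyTorchModel.list(es))  # IDs of PyTorch models known to the cluster

model = PyTorchModel(es, "my-model")
model.start()  # deploy the model so it can serve inference requests
result = model.infer(body={"docs": [{"text_field": "This is a test sentence."}]})
print(result)
model.stop()
```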
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
new file mode 100644
index 0000000..9143579
--- /dev/null
+++ b/eland/ml/pytorch/transformers.py
@@ -0,0 +1,465 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Support for and interoperability with HuggingFace transformers and related
+libraries such as sentence-transformers.
+"""
+
+import json
+import os.path
+from abc import ABC, abstractmethod
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+import transformers
+from sentence_transformers import SentenceTransformer
+from torch import Tensor, nn
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)
+
+DEFAULT_OUTPUT_KEY = "sentence_embedding"
+SUPPORTED_TASK_TYPES = {
+    "fill_mask",
+    "ner",
+    "text_classification",
+    "text_embedding",
+    "zero_shot_classification",
+}
+SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
+SUPPORTED_TOKENIZERS = (
+    transformers.BertTokenizer,
+    transformers.DPRContextEncoderTokenizer,
+    transformers.DPRQuestionEncoderTokenizer,
+    transformers.DistilBertTokenizer,
+    transformers.ElectraTokenizer,
+    transformers.MobileBertTokenizer,
+    transformers.RetriBertTokenizer,
+    transformers.SqueezeBertTokenizer,
+)
+SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
+
+TracedModelTypes = Union[
+    torch.nn.Module,
+    torch.ScriptModule,
+    torch.jit.ScriptModule,
+    torch.jit.TopLevelTracedModule,
+]
+
+
+class _DistilBertWrapper(nn.Module):
+    """
+    A simple wrapper around a DistilBERT model which makes the model inputs
+    conform to Elasticsearch's native inference processor interface.
+    """
+
+    def __init__(self, model: transformers.PreTrainedModel):
+        super().__init__()
+        self._model = model
+        self.config = model.config
+
+    @staticmethod
+    def try_wrapping(model: PreTrainedModel) -> Optional[Any]:
+        if isinstance(model.config, transformers.DistilBertConfig):
+            return _DistilBertWrapper(model)
+        else:
+            return model
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        attention_mask: Tensor,
+        token_type_ids: Tensor,
+        position_ids: Tensor,
+    ) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+
+        return self._model(input_ids=input_ids, attention_mask=attention_mask)
+
+
+class _SentenceTransformerWrapper(nn.Module):
+    """
+    A wrapper around sentence-transformer models to provide pooling,
+    normalization and other graph layers that are not defined in the base
+    HuggingFace transformer model.
+    """
+
+    def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
+        super().__init__()
+        self._hf_model = model
+        self._st_model = SentenceTransformer(model.config.name_or_path)
+        self._output_key = output_key
+
+        self._remove_pooling_layer()
+        self._replace_transformer_layer()
+
+    @staticmethod
+    def from_pretrained(
+        model_id: str, output_key: str = DEFAULT_OUTPUT_KEY
+    ) -> Optional[Any]:
+        if model_id.startswith("sentence-transformers/"):
+            model = AutoModel.from_pretrained(model_id, torchscript=True)
+            return _SentenceTransformerWrapper(model, output_key)
+        else:
+            return None
+
+    def _remove_pooling_layer(self):
+        """
+        Removes any trailing pooling layer, which is not used to create the
+        embeddings. Leaving this layer in place causes it to return a NoneType,
+        which in turn fails to load in libtorch. Alternatively, we could keep
+        the pooling layer's output as a dummy, but that also affects (if only
+        in a minor way) the performance of inference, so we're better off
+        removing the layer when we can.
+        """
+
+        if hasattr(self._hf_model, "pooler"):
+            self._hf_model.pooler = None
+
+    def _replace_transformer_layer(self):
+        """
+        Replaces the HuggingFace transformer layer in the SentenceTransformer
+        modules with the copy loaded above, which has its pooling layer removed
+        and was loaded ready for TorchScript export.
+        """
+
+        self._st_model._modules["0"].auto_model = self._hf_model
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        attention_mask: Tensor,
+        token_type_ids: Tensor,
+        position_ids: Tensor,
+    ) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": token_type_ids,
+            "position_ids": position_ids,
+        }
+
+        # remove inputs for specific model types
+        if isinstance(self._hf_model.config, transformers.DistilBertConfig):
+            del inputs["token_type_ids"]
+
+        return self._st_model(inputs)[self._output_key]
+
+
+class _DPREncoderWrapper(nn.Module):
+    """
+    AutoModel loading does not work for DPRContextEncoders, so this wrapper is
+    a workaround. The issue may never be fixed upstream, so this is likely permanent.
+    See: https://github.com/huggingface/transformers/issues/13670
+    """
+
+    _SUPPORTED_MODELS = {
+        transformers.DPRContextEncoder,
+        transformers.DPRQuestionEncoder,
+    }
+    _SUPPORTED_MODELS_NAMES = set([x.__name__ for x in _SUPPORTED_MODELS])
+
+    def __init__(
+        self,
+        model: Union[transformers.DPRContextEncoder, transformers.DPRQuestionEncoder],
+    ):
+        super().__init__()
+        self._model = model
+
+    @staticmethod
+    def from_pretrained(model_id: str) -> Optional[Any]:
+
+        config = AutoConfig.from_pretrained(model_id)
+
+        def is_compatible() -> bool:
+            is_dpr_model = config.model_type == "dpr"
+            has_architectures = len(config.architectures) == 1
+            is_supported_architecture = (
+                config.architectures[0] in _DPREncoderWrapper._SUPPORTED_MODELS_NAMES
+            )
+            return is_dpr_model and has_architectures and is_supported_architecture
+
+        if is_compatible():
+            model = getattr(transformers, config.architectures[0]).from_pretrained(
+                model_id, torchscript=True
+            )
+            return _DPREncoderWrapper(model)
+        else:
+            return None
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        attention_mask: Tensor,
+        token_type_ids: Tensor,
+        position_ids: Tensor,
+    ) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+
+        return self._model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+        )
+
+
+class _TraceableModel(ABC):
+    """A base class representing a HuggingFace transformer model that can be traced."""
+
+    def __init__(
+        self,
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+        model: Union[
+            PreTrainedModel,
+            _SentenceTransformerWrapper,
+            _DPREncoderWrapper,
+            _DistilBertWrapper,
+        ],
+    ):
+        self._tokenizer = tokenizer
+        self._model = model
+
+    def classification_labels(self) -> Optional[List[str]]:
+        return None
+
+    def quantize(self):
+        self._model = torch.quantization.quantize_dynamic(
+            self._model, {torch.nn.Linear}, dtype=torch.qint8
+        )
+
+    def trace(self) -> TracedModelTypes:
+        # model needs to be in evaluate mode
+        self._model.eval()
+
+        inputs = self._prepare_inputs()
+        position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
+
+        # Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
+        if "token_type_ids" not in inputs:
+            inputs["token_type_ids"] = torch.zeros(
+                inputs["input_ids"].size(1), dtype=torch.long
+            )
+
+        return torch.jit.trace(
+            self._model,
+            (
+                inputs["input_ids"],
+                inputs["attention_mask"],
+                inputs["token_type_ids"],
+                position_ids,
+            ),
+        )
+
+    @abstractmethod
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        ...
+
+
+class _TraceableClassificationModel(_TraceableModel, ABC):
+    def classification_labels(self) -> Optional[List[str]]:
+        id_label_items = self._model.config.id2label.items()
+        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]
+
+        # Make classes like I-PER into I_PER which fits Java enumerations
+        return [label.replace("-", "_") for label in labels]
+
+
+class _TraceableFillMaskModel(_TraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "Who was Jim Henson?",
+            "[MASK] Henson was a puppeteer",
+            return_tensors="pt",
+        )
+
+
+class _TraceableNerModel(_TraceableClassificationModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            (
+                "Hugging Face Inc. is a company based in New York City. "
+                "Its headquarters are in DUMBO, therefore very close to the Manhattan Bridge."
+            ),
+            return_tensors="pt",
+        )
+
+
+class _TraceableTextClassificationModel(_TraceableClassificationModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "This is an example sentence.",
+            return_tensors="pt",
+        )
+
+
+class _TraceableTextEmbeddingModel(_TraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "This is an example sentence.",
+            return_tensors="pt",
+        )
+
+
+class _TraceableZeroShotClassificationModel(_TraceableClassificationModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "This is an example sentence.",
+            "This example is an example.",
+            return_tensors="pt",
+            truncation_strategy="only_first",
+        )
+
+
+class TransformerModel:
+    def __init__(self, model_id: str, task_type: str, quantize: bool = False):
+        self._model_id = model_id
+        self._task_type = task_type.replace("-", "_")
+
+        # load Hugging Face model and tokenizer
+        # use padding in the tokenizer to ensure max length sequences are used for tracing
+        # - see: https://huggingface.co/transformers/serialization.html#dummy-inputs-and-standard-lengths
+        self._tokenizer = transformers.AutoTokenizer.from_pretrained(
+            self._model_id, padding=True, use_fast=False
+        )
+
+        # check for a supported tokenizer
+        if not isinstance(self._tokenizer, SUPPORTED_TOKENIZERS):
+            raise TypeError(
+                f"Tokenizer type {self._tokenizer} not supported, must be one of: {SUPPORTED_TOKENIZERS_NAMES}"
+            )
+
+        self._traceable_model = self._create_traceable_model()
+        if quantize:
+            self._traceable_model.quantize()
+        self._traced_model = self._traceable_model.trace()
+        self._vocab = self._load_vocab()
+        self._config = self._create_config()
+
+    def _load_vocab(self):
+        vocab_items = self._tokenizer.get_vocab().items()
+        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]
+        return {
+            "vocabulary": vocabulary,
+        }
+
+    def _create_config(self):
+        inference_config = {
+            self._task_type: {
+                "tokenization": {
+                    "bert": {
+                        "do_lower_case": getattr(
+                            self._tokenizer, "do_lower_case", False
+                        ),
+                    }
+                }
+            }
+        }
+
+        if hasattr(self._tokenizer, "max_model_input_sizes"):
+            max_sequence_length = self._tokenizer.max_model_input_sizes.get(
+                self._model_id
+            )
+            if max_sequence_length:
+                inference_config[self._task_type]["tokenization"]["bert"][
+                    "max_sequence_length"
+                ] = max_sequence_length
+
+        if self._traceable_model.classification_labels():
+            inference_config[self._task_type][
+                "classification_labels"
+            ] = self._traceable_model.classification_labels()
+
+        return {
+            "description": f"Model {self._model_id} for task type '{self._task_type}'",
+            "model_type": "pytorch",
+            "inference_config": inference_config,
+            "input": {
+                "field_names": ["text_field"],
+            },
+        }
+
+    def _create_traceable_model(self) -> _TraceableModel:
+        if self._task_type == "fill_mask":
+            model = transformers.AutoModelForMaskedLM.from_pretrained(
+                self._model_id, torchscript=True
+            )
+            model = _DistilBertWrapper.try_wrapping(model)
+            return _TraceableFillMaskModel(self._tokenizer, model)
+
+        elif self._task_type == "ner":
+            model = transformers.AutoModelForTokenClassification.from_pretrained(
+                self._model_id, torchscript=True
+            )
+            model = _DistilBertWrapper.try_wrapping(model)
+            return _TraceableNerModel(self._tokenizer, model)
+
+        elif self._task_type == "text_classification":
+            model = transformers.AutoModelForSequenceClassification.from_pretrained(
+                self._model_id, torchscript=True
+            )
+            model = _DistilBertWrapper.try_wrapping(model)
+            return _TraceableTextClassificationModel(self._tokenizer, model)
+
+        elif self._task_type == "text_embedding":
+            model = _SentenceTransformerWrapper.from_pretrained(self._model_id)
+            if not model:
+                model = _DPREncoderWrapper.from_pretrained(self._model_id)
+            if not model:
+                model = transformers.AutoModel.from_pretrained(
+                    self._model_id, torchscript=True
+                )
+            return _TraceableTextEmbeddingModel(self._tokenizer, model)
+
+        elif self._task_type == "zero_shot_classification":
+            model = transformers.AutoModelForSequenceClassification.from_pretrained(
+                self._model_id, torchscript=True
+            )
+            model = _DistilBertWrapper.try_wrapping(model)
+            return _TraceableZeroShotClassificationModel(self._tokenizer, model)
+
+        else:
+            raise TypeError(
+                f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
+            )
+
+    def elasticsearch_model_id(self):
+        # Elasticsearch model IDs need to be in a specific format: no special chars, all lowercase, max 64 chars
+        return self._model_id.replace("/", "__").lower()[:64]
+
+    def save(self, path: str) -> Tuple[str, str, str]:
+        # save traced model
+        model_path = os.path.join(path, "traced_pytorch_model.pt")
+        torch.jit.save(self._traced_model, model_path)
+
+        # save configuration
+        config_path = os.path.join(path, "config.json")
+        with open(config_path, "w") as outfile:
+            json.dump(self._config, outfile)
+
+        # save vocabulary
+        vocab_path = os.path.join(path, "vocabulary.json")
+        with open(vocab_path, "w") as outfile:
+            json.dump(self._vocab, outfile)
+
+        return model_path, config_path, vocab_path
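Again for reference (not part of the patch): a hedged sketch of using `TransformerModel` on its own, e.g. to inspect the traced artifacts before uploading anything. The hub model ID is a placeholder.

```python
# A sketch of tracing a model locally with TransformerModel and inspecting the
# generated artifacts; "distilbert-base-uncased" is a placeholder hub model ID.
import json
import tempfile

from eland.ml.pytorch.transformers import TransformerModel

tm = TransformerModel("distilbert-base-uncased", "fill_mask", quantize=False)
print(tm.elasticsearch_model_id())  # ID that would be used in Elasticsearch

with tempfile.TemporaryDirectory() as tmp_dir:
    model_path, config_path, vocab_path = tm.save(tmp_dir)
    with open(config_path) as f:
        print(json.dumps(json.load(f), indent=2))  # inference_config, input, etc.
```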
diff --git a/noxfile.py b/noxfile.py
index 083c858..e904314 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -22,7 +22,7 @@ from pathlib import Path
 import nox
 
 BASE_DIR = Path(__file__).parent
-SOURCE_FILES = ("setup.py", "noxfile.py", "eland/", "docs/", "utils/", "tests/")
+SOURCE_FILES = ("setup.py", "noxfile.py", "eland/", "docs/", "utils/", "tests/", "bin/")
 
 # Whenever type-hints are completed on a file it should
 # be added here so that this file will continue to be checked
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 9310580..a074f42 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -11,3 +11,7 @@ nox
 lightgbm
 pytest-cov
 mypy
+huggingface-hub>=0.0.17
+sentence-transformers>=2.0.0
+torch>=1.9.0
+transformers[torch]>=4.11.0
diff --git a/setup.py b/setup.py
index 279b973..908a87e 100644
--- a/setup.py
+++ b/setup.py
@@ -84,5 +84,11 @@ setup(
         "xgboost": ["xgboost>=0.90,<2"],
         "scikit-learn": ["scikit-learn>=0.22.1,<1"],
         "lightgbm": ["lightgbm>=2,<4"],
+        "pytorch": [
+            "huggingface-hub>=0.0.17,<1",
+            "sentence-transformers>=2.0.0,<3",
+            "torch>=1.9.0,<2",
+            "transformers[torch]>=4.11.0,<5",
+        ],
     },
 )
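Finally, the `pytorch` extra declared in setup.py makes the new dependencies installable alongside eland. A hedged sketch of the equivalent install step (the exact invocation is an assumption about packaging, not part of the diff):

```python
# Installs eland together with the optional PyTorch/transformers dependencies
# declared in setup.py; equivalent to `python -m pip install 'eland[pytorch]'`.
import subprocess
import sys

subprocess.run([sys.executable, "-m", "pip", "install", "eland[pytorch]"], check=True)
```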