From afe08f810710f452202abae60c43a9e73cc0145f Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 3 May 2022 15:19:03 -0400 Subject: [PATCH] [ML] Improve NLP model import by using nicely defined types (#459) This adds some more definite types for our NLP tasks and tokenization configurations. This is the first step in allowing users to more easily import their own transformer models via something other than hugging face. --- README.md | 4 +- bin/eland_import_hub_model | 4 +- eland/ml/pytorch/_pytorch_model.py | 37 ++- eland/ml/pytorch/nlp_ml_model.py | 228 ++++++++++++++++++ eland/ml/pytorch/transformers.py | 115 +++++---- tests/ml/pytorch/test_pytorch_model_pytest.py | 6 +- 6 files changed, 333 insertions(+), 61 deletions(-) create mode 100644 eland/ml/pytorch/nlp_ml_model.py diff --git a/README.md b/README.md index b28d0b6..fdcd10c 100644 --- a/README.md +++ b/README.md @@ -250,11 +250,11 @@ Downloading: 100%|██████████| 249M/249M [00:23<00:00, 11.2MB # Export the model in a TorchScrpt representation which Elasticsearch uses >>> tmp_path = "models" >>> Path(tmp_path).mkdir(parents=True, exist_ok=True) ->>> model_path, config_path, vocab_path = tm.save(tmp_path) +>>> model_path, config, vocab_path = tm.save(tmp_path) # Import model into Elasticsearch >>> es = elasticsearch.Elasticsearch("http://elastic:mlqa_admin@localhost:9200", timeout=300) # 5 minute timeout >>> ptm = PyTorchModel(es, tm.elasticsearch_model_id()) ->>> ptm.import_model(model_path, config_path, vocab_path) +>>> ptm.import_model(model_path=model_path, config_path=None, vocab_path=vocab_path, config=config) 100%|██████████| 63/63 [00:12<00:00, 5.02it/s] ``` diff --git a/bin/eland_import_hub_model b/bin/eland_import_hub_model index e8e6ee9..332c982 100755 --- a/bin/eland_import_hub_model +++ b/bin/eland_import_hub_model @@ -188,7 +188,7 @@ if __name__ == "__main__": logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'") tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize) - model_path, config_path, vocab_path = tm.save(tmp_dir) + model_path, config, vocab_path = tm.save(tmp_dir) ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id()) model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200 @@ -206,7 +206,7 @@ if __name__ == "__main__": exit(1) logger.info(f"Creating model with id '{ptm.model_id}'") - ptm.put_config(config_path) + ptm.put_config(config=config) logger.info(f"Uploading model definition") ptm.put_model(model_path) diff --git a/eland/ml/pytorch/_pytorch_model.py b/eland/ml/pytorch/_pytorch_model.py index e823e53..10b4d01 100644 --- a/eland/ml/pytorch/_pytorch_model.py +++ b/eland/ml/pytorch/_pytorch_model.py @@ -19,11 +19,22 @@ import base64 import json import math import os -from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Set, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Union, +) from tqdm.auto import tqdm # type: ignore from eland.common import ensure_es_client +from eland.ml.pytorch.nlp_ml_model import NlpTrainedModelConfig if TYPE_CHECKING: from elasticsearch import Elasticsearch @@ -49,10 +60,19 @@ class PyTorchModel: self._client: Elasticsearch = ensure_es_client(es_client) self.model_id = model_id - def put_config(self, path: str) -> None: - with open(path) as f: - config = json.load(f) - self._client.ml.put_trained_model(model_id=self.model_id, **config) + def put_config( + self, path: Optional[str] = None, config: Optional[NlpTrainedModelConfig] = None + ) -> None: + if path is not None and config is not None: + raise ValueError("Only include path or config. Not both") + if path is not None: + with open(path) as f: + config_map = json.load(f) + elif config is not None: + config_map = config.to_dict() + else: + raise ValueError("Must provide path or config") + self._client.ml.put_trained_model(model_id=self.model_id, **config_map) def put_vocab(self, path: str) -> None: with open(path) as f: @@ -89,13 +109,14 @@ class PyTorchModel: def import_model( self, + *, model_path: str, - config_path: str, + config_path: Optional[str], vocab_path: str, + config: Optional[NlpTrainedModelConfig] = None, chunk_size: int = DEFAULT_CHUNK_SIZE, ) -> None: - # TODO: Implement some pre-flight checks on config, vocab, and model - self.put_config(config_path) + self.put_config(path=config_path, config=config) self.put_model(model_path, chunk_size) self.put_vocab(vocab_path) diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py new file mode 100644 index 0000000..b9441e0 --- /dev/null +++ b/eland/ml/pytorch/nlp_ml_model.py @@ -0,0 +1,228 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import typing as t + + +class NlpTokenizationConfig: + def __init__(self, *, configuration_type: str): + self.name = configuration_type + + def to_dict(self): + return { + self.name: { + k: v for k, v in self.__dict__.items() if v is not None and k != "name" + } + } + + +class NlpRobertaTokenizationConfig(NlpTokenizationConfig): + def __init__( + self, + *, + add_prefix_space: t.Optional[bool] = None, + with_special_tokens: t.Optional[bool] = None, + max_sequence_length: t.Optional[int] = None, + truncate: t.Optional[ + t.Union["t.Literal['first', 'none', 'second']", str] + ] = None, + span: t.Optional[int] = None, + ): + super().__init__(configuration_type="roberta") + self.add_prefix_space = add_prefix_space + self.with_special_tokens = with_special_tokens + self.max_sequence_length = max_sequence_length + self.truncate = truncate + self.span = span + + +class NlpBertTokenizationConfig(NlpTokenizationConfig): + def __init__( + self, + *, + do_lower_case: t.Optional[bool] = None, + with_special_tokens: t.Optional[bool] = None, + max_sequence_length: t.Optional[int] = None, + truncate: t.Optional[ + t.Union["t.Literal['first', 'none', 'second']", str] + ] = None, + span: t.Optional[int] = None, + ): + super().__init__(configuration_type="bert") + self.do_lower_case = do_lower_case + self.with_special_tokens = with_special_tokens + self.max_sequence_length = max_sequence_length + self.truncate = truncate + self.span = span + + +class NlpMPNetTokenizationConfig(NlpTokenizationConfig): + def __init__( + self, + *, + do_lower_case: t.Optional[bool] = None, + with_special_tokens: t.Optional[bool] = None, + max_sequence_length: t.Optional[int] = None, + truncate: t.Optional[ + t.Union["t.Literal['first', 'none', 'second']", str] + ] = None, + span: t.Optional[int] = None, + ): + super().__init__(configuration_type="mpnet") + self.do_lower_case = do_lower_case + self.with_special_tokens = with_special_tokens + self.max_sequence_length = max_sequence_length + self.truncate = truncate + self.span = span + + +class InferenceConfig: + def __init__(self, *, configuration_type: str): + self.name = configuration_type + + def to_dict(self) -> t.Dict[str, t.Any]: + return { + self.name: { + k: v.to_dict() if hasattr(v, "to_dict") else v + for k, v in self.__dict__.items() + if v is not None and k != "name" + } + } + + +class TextClassificationInferenceOptions(InferenceConfig): + def __init__( + self, + *, + classification_labels: t.Union[t.List[str], t.Tuple[str, ...]], + tokenization: NlpTokenizationConfig, + results_field: t.Optional[str] = None, + num_top_classes: t.Optional[int] = None, + ): + super().__init__(configuration_type="text_classification") + self.results_field = results_field + self.num_top_classes = num_top_classes + self.tokenization = tokenization + self.classification_labels = classification_labels + + +class ZeroShotClassificationInferenceOptions(InferenceConfig): + def __init__( + self, + *, + tokenization: NlpTokenizationConfig, + classification_labels: t.Union[t.List[str], t.Tuple[str, ...]], + results_field: t.Optional[str] = None, + multi_label: t.Optional[bool] = None, + labels: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None, + hypothesis_template: t.Optional[str] = None, + ): + super().__init__(configuration_type="zero_shot_classification") + self.tokenization = tokenization + self.hypothesis_template = hypothesis_template + self.classification_labels = classification_labels + self.results_field = results_field + self.multi_label = multi_label + self.labels = labels + + +class FillMaskInferenceOptions(InferenceConfig): + def __init__( + self, + *, + tokenization: NlpTokenizationConfig, + results_field: t.Optional[str] = None, + num_top_classes: t.Optional[int] = None, + ): + super().__init__(configuration_type="fill_mask") + self.num_top_classes = num_top_classes + self.tokenization = tokenization + self.results_field = results_field + + +class NerInferenceOptions(InferenceConfig): + def __init__( + self, + *, + tokenization: NlpTokenizationConfig, + classification_labels: t.Union[t.List[str], t.Tuple[str, ...]], + results_field: t.Optional[str] = None, + ): + super().__init__(configuration_type="ner") + self.tokenization = tokenization + self.classification_labels = classification_labels + self.results_field = results_field + + +class PassThroughInferenceOptions(InferenceConfig): + def __init__( + self, + *, + tokenization: NlpTokenizationConfig, + results_field: t.Optional[str] = None, + ): + super().__init__(configuration_type="pass_through") + self.tokenization = tokenization + self.results_field = results_field + + +class TextEmbeddingInferenceOptions(InferenceConfig): + def __init__( + self, + *, + tokenization: NlpTokenizationConfig, + results_field: t.Optional[str] = None, + ): + super().__init__(configuration_type="text_embedding") + self.tokenization = tokenization + self.results_field = results_field + + +class TrainedModelInput: + def __init__(self, *, field_names: t.List[str]): + self.field_names = field_names + + def to_dict(self) -> t.Dict[str, t.Any]: + return self.__dict__ + + +class NlpTrainedModelConfig: + def __init__( + self, + *, + description: str, + inference_config: InferenceConfig, + input: TrainedModelInput = TrainedModelInput(field_names=["text_field"]), + metadata: t.Optional[dict] = None, + model_type: t.Union["t.Literal['pytorch']", str] = "pytorch", + default_field_map: t.Optional[t.Mapping[str, str]] = None, + tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None, + ): + self.tags = tags + self.default_field_map = default_field_map + self.description = description + self.inference_config = inference_config + self.input = input + self.metadata = metadata + self.model_type = model_type + + def to_dict(self) -> t.Dict[str, t.Any]: + return { + k: v.to_dict() if hasattr(v, "to_dict") else v + for k, v in self.__dict__.items() + if v is not None + } diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index ed6f002..c360e27 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -37,6 +37,21 @@ from transformers import ( PreTrainedTokenizerFast, ) +from eland.ml.pytorch.nlp_ml_model import ( + FillMaskInferenceOptions, + NerInferenceOptions, + NlpBertTokenizationConfig, + NlpMPNetTokenizationConfig, + NlpRobertaTokenizationConfig, + NlpTokenizationConfig, + NlpTrainedModelConfig, + PassThroughInferenceOptions, + TextClassificationInferenceOptions, + TextEmbeddingInferenceOptions, + TrainedModelInput, + ZeroShotClassificationInferenceOptions, +) + DEFAULT_OUTPUT_KEY = "sentence_embedding" SUPPORTED_TASK_TYPES = { "fill_mask", @@ -45,6 +60,14 @@ SUPPORTED_TASK_TYPES = { "text_embedding", "zero_shot_classification", } +TASK_TYPE_TO_INFERENCE_CONFIG = { + "fill_mask": FillMaskInferenceOptions, + "ner": NerInferenceOptions, + "text_classification": TextClassificationInferenceOptions, + "text_embedding": TextEmbeddingInferenceOptions, + "zero_shot_classification": ZeroShotClassificationInferenceOptions, + "pass_through": PassThroughInferenceOptions, +} SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES)) SUPPORTED_TOKENIZERS = ( transformers.BertTokenizer, @@ -91,8 +114,8 @@ class _DistilBertWrapper(nn.Module): # type: ignore self, input_ids: Tensor, attention_mask: Tensor, - token_type_ids: Tensor, - position_ids: Tensor, + _token_type_ids: Tensor, + _position_ids: Tensor, ) -> Tensor: """Wrap the input and output to conform to the native process interface.""" @@ -246,7 +269,7 @@ class _DPREncoderWrapper(nn.Module): # type: ignore input_ids: Tensor, attention_mask: Tensor, token_type_ids: Tensor, - position_ids: Tensor, + _position_ids: Tensor, ) -> Tensor: """Wrap the input and output to conform to the native process interface.""" @@ -421,50 +444,53 @@ class TransformerModel: vocab_obj["merges"] = merges return vocab_obj - def _create_config(self) -> Dict[str, Any]: + def _create_tokenization_config(self) -> NlpTokenizationConfig: if isinstance(self._tokenizer, transformers.MPNetTokenizer): - tokenizer_type = "mpnet" - tokenizer_obj = { - "do_lower_case": getattr(self._tokenizer, "do_lower_case", False) - } + return NlpMPNetTokenizationConfig( + do_lower_case=getattr(self._tokenizer, "do_lower_case", None), + max_sequence_length=getattr( + self._tokenizer, "max_model_input_sizes", dict() + ).get(self._model_id), + ) elif isinstance( self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer) ): - tokenizer_type = "roberta" - tokenizer_obj = { - "add_prefix_space": getattr(self._tokenizer, "add_prefix_space", False) - } - else: - tokenizer_type = "bert" - tokenizer_obj = { - "do_lower_case": getattr(self._tokenizer, "do_lower_case", False) - } - inference_config: Dict[str, Dict[str, Any]] = { - self._task_type: {"tokenization": {tokenizer_type: tokenizer_obj}} - } - - if hasattr(self._tokenizer, "max_model_input_sizes"): - max_sequence_length = self._tokenizer.max_model_input_sizes.get( - self._model_id + return NlpRobertaTokenizationConfig( + add_prefix_space=getattr(self._tokenizer, "add_prefix_space", None), + max_sequence_length=getattr( + self._tokenizer, "max_model_input_sizes", dict() + ).get(self._model_id), + ) + else: + return NlpBertTokenizationConfig( + do_lower_case=getattr(self._tokenizer, "do_lower_case", None), + max_sequence_length=getattr( + self._tokenizer, "max_model_input_sizes", dict() + ).get(self._model_id), ) - if max_sequence_length: - inference_config[self._task_type]["tokenization"][tokenizer_type][ - "max_sequence_length" - ] = max_sequence_length - if self._traceable_model.classification_labels(): - inference_config[self._task_type][ - "classification_labels" - ] = self._traceable_model.classification_labels() + def _create_config(self) -> NlpTrainedModelConfig: + tokenization_config = self._create_tokenization_config() - return { - "description": f"Model {self._model_id} for task type '{self._task_type}'", - "model_type": "pytorch", - "inference_config": inference_config, - "input": { - "field_names": ["text_field"], - }, - } + inference_config = ( + TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type]( + tokenization=tokenization_config, + classification_labels=self._traceable_model.classification_labels(), + ) + if self._traceable_model.classification_labels() + else TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type]( + tokenization=tokenization_config + ) + ) + + return NlpTrainedModelConfig( + description=f"Model {self._model_id} for task type '{self._task_type}'", + model_type="pytorch", + inference_config=inference_config, + input=TrainedModelInput( + field_names=["text_field"], + ), + ) def _create_traceable_model(self) -> _TraceableModel: if self._task_type == "fill_mask": @@ -514,19 +540,14 @@ class TransformerModel: # Elasticsearch model IDs need to be a specific format: no special chars, all lowercase, max 64 chars return self._model_id.replace("/", "__").lower()[:64] - def save(self, path: str) -> Tuple[str, str, str]: + def save(self, path: str) -> Tuple[str, NlpTrainedModelConfig, str]: # save traced model model_path = os.path.join(path, "traced_pytorch_model.pt") torch.jit.save(self._traced_model, model_path) - # save configuration - config_path = os.path.join(path, "config.json") - with open(config_path, "w") as outfile: - json.dump(self._config, outfile) - # save vocabulary vocab_path = os.path.join(path, "vocabulary.json") with open(vocab_path, "w") as outfile: json.dump(self._vocab, outfile) - return model_path, config_path, vocab_path + return model_path, self._config, vocab_path diff --git a/tests/ml/pytorch/test_pytorch_model_pytest.py b/tests/ml/pytorch/test_pytorch_model_pytest.py index 5cbd3ad..353467d 100644 --- a/tests/ml/pytorch/test_pytorch_model_pytest.py +++ b/tests/ml/pytorch/test_pytorch_model_pytest.py @@ -79,7 +79,7 @@ def setup_and_tear_down(): def download_model_and_start_deployment(tmp_dir, quantize, model_id, task): print("Loading HuggingFace transformer tokenizer and model") tm = TransformerModel(model_id, task, quantize) - model_path, config_path, vocab_path = tm.save(tmp_dir) + model_path, config, vocab_path = tm.save(tmp_dir) ptm = PyTorchModel(ES_TEST_CLIENT, tm.elasticsearch_model_id()) try: ptm.stop() @@ -87,7 +87,9 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task): except NotFoundError: pass print(f"Importing model: {ptm.model_id}") - ptm.import_model(model_path, config_path, vocab_path) + ptm.import_model( + model_path=model_path, config_path=None, vocab_path=vocab_path, config=config + ) ptm.start() return ptm