diff --git a/README.md b/README.md
index b28d0b6..fdcd10c 100644
--- a/README.md
+++ b/README.md
@@ -250,11 +250,11 @@ Downloading: 100%|██████████| 249M/249M [00:23<00:00, 11.2MB
 # Export the model in a TorchScript representation which Elasticsearch uses
 >>> tmp_path = "models"
 >>> Path(tmp_path).mkdir(parents=True, exist_ok=True)
->>> model_path, config_path, vocab_path = tm.save(tmp_path)
+>>> model_path, config, vocab_path = tm.save(tmp_path)
 
 # Import model into Elasticsearch
 >>> es = elasticsearch.Elasticsearch("http://elastic:mlqa_admin@localhost:9200", timeout=300)  # 5 minute timeout
 >>> ptm = PyTorchModel(es, tm.elasticsearch_model_id())
->>> ptm.import_model(model_path, config_path, vocab_path)
+>>> ptm.import_model(model_path=model_path, config_path=None, vocab_path=vocab_path, config=config)
 100%|██████████| 63/63 [00:12<00:00, 5.02it/s]
 ```
diff --git a/bin/eland_import_hub_model b/bin/eland_import_hub_model
index e8e6ee9..332c982 100755
--- a/bin/eland_import_hub_model
+++ b/bin/eland_import_hub_model
@@ -188,7 +188,7 @@ if __name__ == "__main__":
     logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'")
 
     tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
-    model_path, config_path, vocab_path = tm.save(tmp_dir)
+    model_path, config, vocab_path = tm.save(tmp_dir)
 
     ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id())
     model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200
@@ -206,7 +206,7 @@ if __name__ == "__main__":
         exit(1)
 
     logger.info(f"Creating model with id '{ptm.model_id}'")
-    ptm.put_config(config_path)
+    ptm.put_config(config=config)
 
     logger.info(f"Uploading model definition")
     ptm.put_model(model_path)
diff --git a/eland/ml/pytorch/_pytorch_model.py b/eland/ml/pytorch/_pytorch_model.py
index e823e53..10b4d01 100644
--- a/eland/ml/pytorch/_pytorch_model.py
+++ b/eland/ml/pytorch/_pytorch_model.py
@@ -19,11 +19,22 @@ import base64
 import json
 import math
 import os
-from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Set, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 from tqdm.auto import tqdm  # type: ignore
 
 from eland.common import ensure_es_client
+from eland.ml.pytorch.nlp_ml_model import NlpTrainedModelConfig
 
 if TYPE_CHECKING:
     from elasticsearch import Elasticsearch
@@ -49,10 +60,19 @@ class PyTorchModel:
         self._client: Elasticsearch = ensure_es_client(es_client)
         self.model_id = model_id
 
-    def put_config(self, path: str) -> None:
-        with open(path) as f:
-            config = json.load(f)
-        self._client.ml.put_trained_model(model_id=self.model_id, **config)
+    def put_config(
+        self, path: Optional[str] = None, config: Optional[NlpTrainedModelConfig] = None
+    ) -> None:
+        if path is not None and config is not None:
+            raise ValueError("Only include path or config. Not both")
Not both") + if path is not None: + with open(path) as f: + config_map = json.load(f) + elif config is not None: + config_map = config.to_dict() + else: + raise ValueError("Must provide path or config") + self._client.ml.put_trained_model(model_id=self.model_id, **config_map) def put_vocab(self, path: str) -> None: with open(path) as f: @@ -89,13 +109,14 @@ class PyTorchModel: def import_model( self, + *, model_path: str, - config_path: str, + config_path: Optional[str], vocab_path: str, + config: Optional[NlpTrainedModelConfig] = None, chunk_size: int = DEFAULT_CHUNK_SIZE, ) -> None: - # TODO: Implement some pre-flight checks on config, vocab, and model - self.put_config(config_path) + self.put_config(path=config_path, config=config) self.put_model(model_path, chunk_size) self.put_vocab(vocab_path) diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py new file mode 100644 index 0000000..b9441e0 --- /dev/null +++ b/eland/ml/pytorch/nlp_ml_model.py @@ -0,0 +1,228 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+import typing as t
+
+
+class NlpTokenizationConfig:
+    def __init__(self, *, configuration_type: str):
+        self.name = configuration_type
+
+    def to_dict(self) -> t.Dict[str, t.Any]:
+        return {
+            self.name: {
+                k: v for k, v in self.__dict__.items() if v is not None and k != "name"
+            }
+        }
+
+
+class NlpRobertaTokenizationConfig(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        add_prefix_space: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(configuration_type="roberta")
+        self.add_prefix_space = add_prefix_space
+        self.with_special_tokens = with_special_tokens
+        self.max_sequence_length = max_sequence_length
+        self.truncate = truncate
+        self.span = span
+
+
+class NlpBertTokenizationConfig(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        do_lower_case: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(configuration_type="bert")
+        self.do_lower_case = do_lower_case
+        self.with_special_tokens = with_special_tokens
+        self.max_sequence_length = max_sequence_length
+        self.truncate = truncate
+        self.span = span
+
+
+class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        do_lower_case: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(configuration_type="mpnet")
+        self.do_lower_case = do_lower_case
+        self.with_special_tokens = with_special_tokens
+        self.max_sequence_length = max_sequence_length
+        self.truncate = truncate
+        self.span = span
+
+
+class InferenceConfig:
+    def __init__(self, *, configuration_type: str):
+        self.name = configuration_type
+
+    def to_dict(self) -> t.Dict[str, t.Any]:
+        return {
+            self.name: {
+                k: v.to_dict() if hasattr(v, "to_dict") else v
+                for k, v in self.__dict__.items()
+                if v is not None and k != "name"
+            }
+        }
+
+
+class TextClassificationInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        classification_labels: t.Union[t.List[str], t.Tuple[str, ...]],
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+        num_top_classes: t.Optional[int] = None,
+    ):
+        super().__init__(configuration_type="text_classification")
+        self.results_field = results_field
+        self.num_top_classes = num_top_classes
+        self.tokenization = tokenization
+        self.classification_labels = classification_labels
+
+
+class ZeroShotClassificationInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        classification_labels: t.Union[t.List[str], t.Tuple[str, ...]],
+        results_field: t.Optional[str] = None,
+        multi_label: t.Optional[bool] = None,
+        labels: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
+        hypothesis_template: t.Optional[str] = None,
+    ):
+        super().__init__(configuration_type="zero_shot_classification")
+        self.tokenization = tokenization
+        self.hypothesis_template = hypothesis_template
+        self.classification_labels = classification_labels
+        self.results_field = results_field
+        self.multi_label = multi_label
+        self.labels = labels
+
+
+class FillMaskInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+        num_top_classes: t.Optional[int] = None,
+    ):
+        super().__init__(configuration_type="fill_mask")
+        self.num_top_classes = num_top_classes
+        self.tokenization = tokenization
+        self.results_field = results_field
+
+
+class NerInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        classification_labels: t.Union[t.List[str], t.Tuple[str, ...]],
+        results_field: t.Optional[str] = None,
+    ):
+        super().__init__(configuration_type="ner")
+        self.tokenization = tokenization
+        self.classification_labels = classification_labels
+        self.results_field = results_field
+
+
+class PassThroughInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+    ):
+        super().__init__(configuration_type="pass_through")
+        self.tokenization = tokenization
+        self.results_field = results_field
+
+
+class TextEmbeddingInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+    ):
+        super().__init__(configuration_type="text_embedding")
+        self.tokenization = tokenization
+        self.results_field = results_field
+
+
+class TrainedModelInput:
+    def __init__(self, *, field_names: t.List[str]):
+        self.field_names = field_names
+
+    def to_dict(self) -> t.Dict[str, t.Any]:
+        return self.__dict__
+
+
+class NlpTrainedModelConfig:
+    def __init__(
+        self,
+        *,
+        description: str,
+        inference_config: InferenceConfig,
+        input: TrainedModelInput = TrainedModelInput(field_names=["text_field"]),
+        metadata: t.Optional[dict] = None,
+        model_type: t.Union["t.Literal['pytorch']", str] = "pytorch",
+        default_field_map: t.Optional[t.Mapping[str, str]] = None,
+        tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
+    ):
+        self.tags = tags
+        self.default_field_map = default_field_map
+        self.description = description
+        self.inference_config = inference_config
+        self.input = input
+        self.metadata = metadata
+        self.model_type = model_type
+
+    def to_dict(self) -> t.Dict[str, t.Any]:
+        return {
+            k: v.to_dict() if hasattr(v, "to_dict") else v
+            for k, v in self.__dict__.items()
+            if v is not None
+        }
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index ed6f002..c360e27 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -37,6 +37,21 @@ from transformers import (
     PreTrainedTokenizerFast,
 )
 
+from eland.ml.pytorch.nlp_ml_model import (
+    FillMaskInferenceOptions,
+    NerInferenceOptions,
+    NlpBertTokenizationConfig,
+    NlpMPNetTokenizationConfig,
+    NlpRobertaTokenizationConfig,
+    NlpTokenizationConfig,
+    NlpTrainedModelConfig,
+    PassThroughInferenceOptions,
+    TextClassificationInferenceOptions,
+    TextEmbeddingInferenceOptions,
+    TrainedModelInput,
+    ZeroShotClassificationInferenceOptions,
+)
+
 DEFAULT_OUTPUT_KEY = "sentence_embedding"
 SUPPORTED_TASK_TYPES = {
     "fill_mask",
@@ -45,6 +60,14 @@ SUPPORTED_TASK_TYPES = {
     "text_embedding",
     "zero_shot_classification",
 }
+TASK_TYPE_TO_INFERENCE_CONFIG = {
+    "fill_mask": FillMaskInferenceOptions,
+    "ner": NerInferenceOptions,
+    "text_classification": TextClassificationInferenceOptions,
+    "text_embedding": TextEmbeddingInferenceOptions,
+    "zero_shot_classification": ZeroShotClassificationInferenceOptions,
+    "pass_through": PassThroughInferenceOptions,
+}
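+# Editor's note: this mapping pairs each supported task type with the
+# InferenceConfig subclass that renders its "inference_config" section;
+# _create_config() below indexes it with the user-supplied task type.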
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
     transformers.BertTokenizer,
@@ -91,8 +114,8 @@ class _DistilBertWrapper(nn.Module):  # type: ignore
         self,
         input_ids: Tensor,
         attention_mask: Tensor,
-        token_type_ids: Tensor,
-        position_ids: Tensor,
+        _token_type_ids: Tensor,
+        _position_ids: Tensor,
     ) -> Tensor:
         """Wrap the input and output to conform to the native process interface."""
 
@@ -246,7 +269,7 @@ class _DPREncoderWrapper(nn.Module):  # type: ignore
         input_ids: Tensor,
         attention_mask: Tensor,
         token_type_ids: Tensor,
-        position_ids: Tensor,
+        _position_ids: Tensor,
     ) -> Tensor:
         """Wrap the input and output to conform to the native process interface."""
 
@@ -421,50 +444,53 @@ class TransformerModel:
             vocab_obj["merges"] = merges
         return vocab_obj
 
-    def _create_config(self) -> Dict[str, Any]:
+    def _create_tokenization_config(self) -> NlpTokenizationConfig:
         if isinstance(self._tokenizer, transformers.MPNetTokenizer):
-            tokenizer_type = "mpnet"
-            tokenizer_obj = {
-                "do_lower_case": getattr(self._tokenizer, "do_lower_case", False)
-            }
+            return NlpMPNetTokenizationConfig(
+                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+                max_sequence_length=getattr(
+                    self._tokenizer, "max_model_input_sizes", dict()
+                ).get(self._model_id),
+            )
         elif isinstance(
             self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer)
         ):
-            tokenizer_type = "roberta"
-            tokenizer_obj = {
-                "add_prefix_space": getattr(self._tokenizer, "add_prefix_space", False)
-            }
-        else:
-            tokenizer_type = "bert"
-            tokenizer_obj = {
-                "do_lower_case": getattr(self._tokenizer, "do_lower_case", False)
-            }
-        inference_config: Dict[str, Dict[str, Any]] = {
-            self._task_type: {"tokenization": {tokenizer_type: tokenizer_obj}}
-        }
-
-        if hasattr(self._tokenizer, "max_model_input_sizes"):
-            max_sequence_length = self._tokenizer.max_model_input_sizes.get(
-                self._model_id
+            return NlpRobertaTokenizationConfig(
+                add_prefix_space=getattr(self._tokenizer, "add_prefix_space", None),
+                max_sequence_length=getattr(
+                    self._tokenizer, "max_model_input_sizes", dict()
+                ).get(self._model_id),
+            )
+        else:
+            return NlpBertTokenizationConfig(
+                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+                max_sequence_length=getattr(
+                    self._tokenizer, "max_model_input_sizes", dict()
+                ).get(self._model_id),
             )
-        if max_sequence_length:
-            inference_config[self._task_type]["tokenization"][tokenizer_type][
-                "max_sequence_length"
-            ] = max_sequence_length
 
-        if self._traceable_model.classification_labels():
-            inference_config[self._task_type][
-                "classification_labels"
-            ] = self._traceable_model.classification_labels()
+    def _create_config(self) -> NlpTrainedModelConfig:
+        tokenization_config = self._create_tokenization_config()
 
-        return {
-            "description": f"Model {self._model_id} for task type '{self._task_type}'",
-            "model_type": "pytorch",
-            "inference_config": inference_config,
-            "input": {
-                "field_names": ["text_field"],
-            },
-        }
+        inference_config = (
+            TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+                tokenization=tokenization_config,
+                classification_labels=self._traceable_model.classification_labels(),
+            )
+            if self._traceable_model.classification_labels()
+            else TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+                tokenization=tokenization_config
+            )
+        )
+
+        return NlpTrainedModelConfig(
+            description=f"Model {self._model_id} for task type '{self._task_type}'",
+            model_type="pytorch",
+            inference_config=inference_config,
+            input=TrainedModelInput(
field_names=["text_field"], + ), + ) def _create_traceable_model(self) -> _TraceableModel: if self._task_type == "fill_mask": @@ -514,19 +540,14 @@ class TransformerModel: # Elasticsearch model IDs need to be a specific format: no special chars, all lowercase, max 64 chars return self._model_id.replace("/", "__").lower()[:64] - def save(self, path: str) -> Tuple[str, str, str]: + def save(self, path: str) -> Tuple[str, NlpTrainedModelConfig, str]: # save traced model model_path = os.path.join(path, "traced_pytorch_model.pt") torch.jit.save(self._traced_model, model_path) - # save configuration - config_path = os.path.join(path, "config.json") - with open(config_path, "w") as outfile: - json.dump(self._config, outfile) - # save vocabulary vocab_path = os.path.join(path, "vocabulary.json") with open(vocab_path, "w") as outfile: json.dump(self._vocab, outfile) - return model_path, config_path, vocab_path + return model_path, self._config, vocab_path diff --git a/tests/ml/pytorch/test_pytorch_model_pytest.py b/tests/ml/pytorch/test_pytorch_model_pytest.py index 5cbd3ad..353467d 100644 --- a/tests/ml/pytorch/test_pytorch_model_pytest.py +++ b/tests/ml/pytorch/test_pytorch_model_pytest.py @@ -79,7 +79,7 @@ def setup_and_tear_down(): def download_model_and_start_deployment(tmp_dir, quantize, model_id, task): print("Loading HuggingFace transformer tokenizer and model") tm = TransformerModel(model_id, task, quantize) - model_path, config_path, vocab_path = tm.save(tmp_dir) + model_path, config, vocab_path = tm.save(tmp_dir) ptm = PyTorchModel(ES_TEST_CLIENT, tm.elasticsearch_model_id()) try: ptm.stop() @@ -87,7 +87,9 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task): except NotFoundError: pass print(f"Importing model: {ptm.model_id}") - ptm.import_model(model_path, config_path, vocab_path) + ptm.import_model( + model_path=model_path, config_path=None, vocab_path=vocab_path, config=config + ) ptm.start() return ptm