[ML] Improve NLP model import by using nicely defined types (#459)

This adds some more definite types for our NLP tasks and tokenization configurations.

This is the first step in allowing users to more easily import their own transformer models via something other than hugging face.
This commit is contained in:
Benjamin Trent 2022-05-03 15:19:03 -04:00 committed by GitHub
parent 3255f55d71
commit afe08f8107
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 333 additions and 61 deletions

View File

@ -250,11 +250,11 @@ Downloading: 100%|██████████| 249M/249M [00:23<00:00, 11.2MB
# Export the model in a TorchScrpt representation which Elasticsearch uses # Export the model in a TorchScrpt representation which Elasticsearch uses
>>> tmp_path = "models" >>> tmp_path = "models"
>>> Path(tmp_path).mkdir(parents=True, exist_ok=True) >>> Path(tmp_path).mkdir(parents=True, exist_ok=True)
>>> model_path, config_path, vocab_path = tm.save(tmp_path) >>> model_path, config, vocab_path = tm.save(tmp_path)
# Import model into Elasticsearch # Import model into Elasticsearch
>>> es = elasticsearch.Elasticsearch("http://elastic:mlqa_admin@localhost:9200", timeout=300) # 5 minute timeout >>> es = elasticsearch.Elasticsearch("http://elastic:mlqa_admin@localhost:9200", timeout=300) # 5 minute timeout
>>> ptm = PyTorchModel(es, tm.elasticsearch_model_id()) >>> ptm = PyTorchModel(es, tm.elasticsearch_model_id())
>>> ptm.import_model(model_path, config_path, vocab_path) >>> ptm.import_model(model_path=model_path, config_path=None, vocab_path=vocab_path, config=config)
100%|██████████| 63/63 [00:12<00:00, 5.02it/s] 100%|██████████| 63/63 [00:12<00:00, 5.02it/s]
``` ```

View File

@ -188,7 +188,7 @@ if __name__ == "__main__":
logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'") logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'")
tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize) tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
model_path, config_path, vocab_path = tm.save(tmp_dir) model_path, config, vocab_path = tm.save(tmp_dir)
ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id()) ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id())
model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200 model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200
@ -206,7 +206,7 @@ if __name__ == "__main__":
exit(1) exit(1)
logger.info(f"Creating model with id '{ptm.model_id}'") logger.info(f"Creating model with id '{ptm.model_id}'")
ptm.put_config(config_path) ptm.put_config(config=config)
logger.info(f"Uploading model definition") logger.info(f"Uploading model definition")
ptm.put_model(model_path) ptm.put_model(model_path)

View File

@ -19,11 +19,22 @@ import base64
import json import json
import math import math
import os import os
from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Set, Tuple, Union from typing import (
TYPE_CHECKING,
Any,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
Union,
)
from tqdm.auto import tqdm # type: ignore from tqdm.auto import tqdm # type: ignore
from eland.common import ensure_es_client from eland.common import ensure_es_client
from eland.ml.pytorch.nlp_ml_model import NlpTrainedModelConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
@ -49,10 +60,19 @@ class PyTorchModel:
self._client: Elasticsearch = ensure_es_client(es_client) self._client: Elasticsearch = ensure_es_client(es_client)
self.model_id = model_id self.model_id = model_id
def put_config(self, path: str) -> None: def put_config(
with open(path) as f: self, path: Optional[str] = None, config: Optional[NlpTrainedModelConfig] = None
config = json.load(f) ) -> None:
self._client.ml.put_trained_model(model_id=self.model_id, **config) if path is not None and config is not None:
raise ValueError("Only include path or config. Not both")
if path is not None:
with open(path) as f:
config_map = json.load(f)
elif config is not None:
config_map = config.to_dict()
else:
raise ValueError("Must provide path or config")
self._client.ml.put_trained_model(model_id=self.model_id, **config_map)
def put_vocab(self, path: str) -> None: def put_vocab(self, path: str) -> None:
with open(path) as f: with open(path) as f:
@ -89,13 +109,14 @@ class PyTorchModel:
def import_model( def import_model(
self, self,
*,
model_path: str, model_path: str,
config_path: str, config_path: Optional[str],
vocab_path: str, vocab_path: str,
config: Optional[NlpTrainedModelConfig] = None,
chunk_size: int = DEFAULT_CHUNK_SIZE, chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> None: ) -> None:
# TODO: Implement some pre-flight checks on config, vocab, and model self.put_config(path=config_path, config=config)
self.put_config(config_path)
self.put_model(model_path, chunk_size) self.put_model(model_path, chunk_size)
self.put_vocab(vocab_path) self.put_vocab(vocab_path)

View File

@ -0,0 +1,228 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import typing as t
class NlpTokenizationConfig:
def __init__(self, *, configuration_type: str):
self.name = configuration_type
def to_dict(self):
return {
self.name: {
k: v for k, v in self.__dict__.items() if v is not None and k != "name"
}
}
class NlpRobertaTokenizationConfig(NlpTokenizationConfig):
def __init__(
self,
*,
add_prefix_space: t.Optional[bool] = None,
with_special_tokens: t.Optional[bool] = None,
max_sequence_length: t.Optional[int] = None,
truncate: t.Optional[
t.Union["t.Literal['first', 'none', 'second']", str]
] = None,
span: t.Optional[int] = None,
):
super().__init__(configuration_type="roberta")
self.add_prefix_space = add_prefix_space
self.with_special_tokens = with_special_tokens
self.max_sequence_length = max_sequence_length
self.truncate = truncate
self.span = span
class NlpBertTokenizationConfig(NlpTokenizationConfig):
def __init__(
self,
*,
do_lower_case: t.Optional[bool] = None,
with_special_tokens: t.Optional[bool] = None,
max_sequence_length: t.Optional[int] = None,
truncate: t.Optional[
t.Union["t.Literal['first', 'none', 'second']", str]
] = None,
span: t.Optional[int] = None,
):
super().__init__(configuration_type="bert")
self.do_lower_case = do_lower_case
self.with_special_tokens = with_special_tokens
self.max_sequence_length = max_sequence_length
self.truncate = truncate
self.span = span
class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
def __init__(
self,
*,
do_lower_case: t.Optional[bool] = None,
with_special_tokens: t.Optional[bool] = None,
max_sequence_length: t.Optional[int] = None,
truncate: t.Optional[
t.Union["t.Literal['first', 'none', 'second']", str]
] = None,
span: t.Optional[int] = None,
):
super().__init__(configuration_type="mpnet")
self.do_lower_case = do_lower_case
self.with_special_tokens = with_special_tokens
self.max_sequence_length = max_sequence_length
self.truncate = truncate
self.span = span
class InferenceConfig:
def __init__(self, *, configuration_type: str):
self.name = configuration_type
def to_dict(self) -> t.Dict[str, t.Any]:
return {
self.name: {
k: v.to_dict() if hasattr(v, "to_dict") else v
for k, v in self.__dict__.items()
if v is not None and k != "name"
}
}
class TextClassificationInferenceOptions(InferenceConfig):
def __init__(
self,
*,
classification_labels: t.Union[t.List[str], t.Tuple[str, ...]],
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
num_top_classes: t.Optional[int] = None,
):
super().__init__(configuration_type="text_classification")
self.results_field = results_field
self.num_top_classes = num_top_classes
self.tokenization = tokenization
self.classification_labels = classification_labels
class ZeroShotClassificationInferenceOptions(InferenceConfig):
def __init__(
self,
*,
tokenization: NlpTokenizationConfig,
classification_labels: t.Union[t.List[str], t.Tuple[str, ...]],
results_field: t.Optional[str] = None,
multi_label: t.Optional[bool] = None,
labels: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
hypothesis_template: t.Optional[str] = None,
):
super().__init__(configuration_type="zero_shot_classification")
self.tokenization = tokenization
self.hypothesis_template = hypothesis_template
self.classification_labels = classification_labels
self.results_field = results_field
self.multi_label = multi_label
self.labels = labels
class FillMaskInferenceOptions(InferenceConfig):
def __init__(
self,
*,
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
num_top_classes: t.Optional[int] = None,
):
super().__init__(configuration_type="fill_mask")
self.num_top_classes = num_top_classes
self.tokenization = tokenization
self.results_field = results_field
class NerInferenceOptions(InferenceConfig):
def __init__(
self,
*,
tokenization: NlpTokenizationConfig,
classification_labels: t.Union[t.List[str], t.Tuple[str, ...]],
results_field: t.Optional[str] = None,
):
super().__init__(configuration_type="ner")
self.tokenization = tokenization
self.classification_labels = classification_labels
self.results_field = results_field
class PassThroughInferenceOptions(InferenceConfig):
def __init__(
self,
*,
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
):
super().__init__(configuration_type="pass_through")
self.tokenization = tokenization
self.results_field = results_field
class TextEmbeddingInferenceOptions(InferenceConfig):
def __init__(
self,
*,
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
):
super().__init__(configuration_type="text_embedding")
self.tokenization = tokenization
self.results_field = results_field
class TrainedModelInput:
def __init__(self, *, field_names: t.List[str]):
self.field_names = field_names
def to_dict(self) -> t.Dict[str, t.Any]:
return self.__dict__
class NlpTrainedModelConfig:
def __init__(
self,
*,
description: str,
inference_config: InferenceConfig,
input: TrainedModelInput = TrainedModelInput(field_names=["text_field"]),
metadata: t.Optional[dict] = None,
model_type: t.Union["t.Literal['pytorch']", str] = "pytorch",
default_field_map: t.Optional[t.Mapping[str, str]] = None,
tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
):
self.tags = tags
self.default_field_map = default_field_map
self.description = description
self.inference_config = inference_config
self.input = input
self.metadata = metadata
self.model_type = model_type
def to_dict(self) -> t.Dict[str, t.Any]:
return {
k: v.to_dict() if hasattr(v, "to_dict") else v
for k, v in self.__dict__.items()
if v is not None
}

View File

@ -37,6 +37,21 @@ from transformers import (
PreTrainedTokenizerFast, PreTrainedTokenizerFast,
) )
from eland.ml.pytorch.nlp_ml_model import (
FillMaskInferenceOptions,
NerInferenceOptions,
NlpBertTokenizationConfig,
NlpMPNetTokenizationConfig,
NlpRobertaTokenizationConfig,
NlpTokenizationConfig,
NlpTrainedModelConfig,
PassThroughInferenceOptions,
TextClassificationInferenceOptions,
TextEmbeddingInferenceOptions,
TrainedModelInput,
ZeroShotClassificationInferenceOptions,
)
DEFAULT_OUTPUT_KEY = "sentence_embedding" DEFAULT_OUTPUT_KEY = "sentence_embedding"
SUPPORTED_TASK_TYPES = { SUPPORTED_TASK_TYPES = {
"fill_mask", "fill_mask",
@ -45,6 +60,14 @@ SUPPORTED_TASK_TYPES = {
"text_embedding", "text_embedding",
"zero_shot_classification", "zero_shot_classification",
} }
TASK_TYPE_TO_INFERENCE_CONFIG = {
"fill_mask": FillMaskInferenceOptions,
"ner": NerInferenceOptions,
"text_classification": TextClassificationInferenceOptions,
"text_embedding": TextEmbeddingInferenceOptions,
"zero_shot_classification": ZeroShotClassificationInferenceOptions,
"pass_through": PassThroughInferenceOptions,
}
SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES)) SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
SUPPORTED_TOKENIZERS = ( SUPPORTED_TOKENIZERS = (
transformers.BertTokenizer, transformers.BertTokenizer,
@ -91,8 +114,8 @@ class _DistilBertWrapper(nn.Module): # type: ignore
self, self,
input_ids: Tensor, input_ids: Tensor,
attention_mask: Tensor, attention_mask: Tensor,
token_type_ids: Tensor, _token_type_ids: Tensor,
position_ids: Tensor, _position_ids: Tensor,
) -> Tensor: ) -> Tensor:
"""Wrap the input and output to conform to the native process interface.""" """Wrap the input and output to conform to the native process interface."""
@ -246,7 +269,7 @@ class _DPREncoderWrapper(nn.Module): # type: ignore
input_ids: Tensor, input_ids: Tensor,
attention_mask: Tensor, attention_mask: Tensor,
token_type_ids: Tensor, token_type_ids: Tensor,
position_ids: Tensor, _position_ids: Tensor,
) -> Tensor: ) -> Tensor:
"""Wrap the input and output to conform to the native process interface.""" """Wrap the input and output to conform to the native process interface."""
@ -421,50 +444,53 @@ class TransformerModel:
vocab_obj["merges"] = merges vocab_obj["merges"] = merges
return vocab_obj return vocab_obj
def _create_config(self) -> Dict[str, Any]: def _create_tokenization_config(self) -> NlpTokenizationConfig:
if isinstance(self._tokenizer, transformers.MPNetTokenizer): if isinstance(self._tokenizer, transformers.MPNetTokenizer):
tokenizer_type = "mpnet" return NlpMPNetTokenizationConfig(
tokenizer_obj = { do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
"do_lower_case": getattr(self._tokenizer, "do_lower_case", False) max_sequence_length=getattr(
} self._tokenizer, "max_model_input_sizes", dict()
).get(self._model_id),
)
elif isinstance( elif isinstance(
self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer) self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer)
): ):
tokenizer_type = "roberta" return NlpRobertaTokenizationConfig(
tokenizer_obj = { add_prefix_space=getattr(self._tokenizer, "add_prefix_space", None),
"add_prefix_space": getattr(self._tokenizer, "add_prefix_space", False) max_sequence_length=getattr(
} self._tokenizer, "max_model_input_sizes", dict()
else: ).get(self._model_id),
tokenizer_type = "bert" )
tokenizer_obj = { else:
"do_lower_case": getattr(self._tokenizer, "do_lower_case", False) return NlpBertTokenizationConfig(
} do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
inference_config: Dict[str, Dict[str, Any]] = { max_sequence_length=getattr(
self._task_type: {"tokenization": {tokenizer_type: tokenizer_obj}} self._tokenizer, "max_model_input_sizes", dict()
} ).get(self._model_id),
if hasattr(self._tokenizer, "max_model_input_sizes"):
max_sequence_length = self._tokenizer.max_model_input_sizes.get(
self._model_id
) )
if max_sequence_length:
inference_config[self._task_type]["tokenization"][tokenizer_type][
"max_sequence_length"
] = max_sequence_length
if self._traceable_model.classification_labels(): def _create_config(self) -> NlpTrainedModelConfig:
inference_config[self._task_type][ tokenization_config = self._create_tokenization_config()
"classification_labels"
] = self._traceable_model.classification_labels()
return { inference_config = (
"description": f"Model {self._model_id} for task type '{self._task_type}'", TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
"model_type": "pytorch", tokenization=tokenization_config,
"inference_config": inference_config, classification_labels=self._traceable_model.classification_labels(),
"input": { )
"field_names": ["text_field"], if self._traceable_model.classification_labels()
}, else TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
} tokenization=tokenization_config
)
)
return NlpTrainedModelConfig(
description=f"Model {self._model_id} for task type '{self._task_type}'",
model_type="pytorch",
inference_config=inference_config,
input=TrainedModelInput(
field_names=["text_field"],
),
)
def _create_traceable_model(self) -> _TraceableModel: def _create_traceable_model(self) -> _TraceableModel:
if self._task_type == "fill_mask": if self._task_type == "fill_mask":
@ -514,19 +540,14 @@ class TransformerModel:
# Elasticsearch model IDs need to be a specific format: no special chars, all lowercase, max 64 chars # Elasticsearch model IDs need to be a specific format: no special chars, all lowercase, max 64 chars
return self._model_id.replace("/", "__").lower()[:64] return self._model_id.replace("/", "__").lower()[:64]
def save(self, path: str) -> Tuple[str, str, str]: def save(self, path: str) -> Tuple[str, NlpTrainedModelConfig, str]:
# save traced model # save traced model
model_path = os.path.join(path, "traced_pytorch_model.pt") model_path = os.path.join(path, "traced_pytorch_model.pt")
torch.jit.save(self._traced_model, model_path) torch.jit.save(self._traced_model, model_path)
# save configuration
config_path = os.path.join(path, "config.json")
with open(config_path, "w") as outfile:
json.dump(self._config, outfile)
# save vocabulary # save vocabulary
vocab_path = os.path.join(path, "vocabulary.json") vocab_path = os.path.join(path, "vocabulary.json")
with open(vocab_path, "w") as outfile: with open(vocab_path, "w") as outfile:
json.dump(self._vocab, outfile) json.dump(self._vocab, outfile)
return model_path, config_path, vocab_path return model_path, self._config, vocab_path

View File

@ -79,7 +79,7 @@ def setup_and_tear_down():
def download_model_and_start_deployment(tmp_dir, quantize, model_id, task): def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
print("Loading HuggingFace transformer tokenizer and model") print("Loading HuggingFace transformer tokenizer and model")
tm = TransformerModel(model_id, task, quantize) tm = TransformerModel(model_id, task, quantize)
model_path, config_path, vocab_path = tm.save(tmp_dir) model_path, config, vocab_path = tm.save(tmp_dir)
ptm = PyTorchModel(ES_TEST_CLIENT, tm.elasticsearch_model_id()) ptm = PyTorchModel(ES_TEST_CLIENT, tm.elasticsearch_model_id())
try: try:
ptm.stop() ptm.stop()
@ -87,7 +87,9 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
except NotFoundError: except NotFoundError:
pass pass
print(f"Importing model: {ptm.model_id}") print(f"Importing model: {ptm.model_id}")
ptm.import_model(model_path, config_path, vocab_path) ptm.import_model(
model_path=model_path, config_path=None, vocab_path=vocab_path, config=config
)
ptm.start() ptm.start()
return ptm return ptm