[ML] improve general pytorch model import and add tests (#463)

This improves the user-facing functions and classes for uploading PyTorch NLP models to Elasticsearch.

Previously it was difficult to wrap your own module for uploading to Elasticsearch.

This commit splits some classes out, adds new ones, and adds tests showing how to wrap some simple modules.
Benjamin Trent 2022-05-05 10:50:53 -04:00 committed by GitHub
parent 70fadc9986
commit 650e02d16e
5 changed files with 364 additions and 25 deletions
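
As the diff below shows, wrapping a custom module now only takes a small TraceableModel subclass that implements _trace(). A minimal, illustrative sketch (DoubleModule and DoublingTraceableModel are toy stand-ins, not part of this commit):

```python
import torch
from torch import Tensor, nn

from eland.ml.pytorch import TraceableModel


class DoubleModule(nn.Module):
    """Toy module that doubles its input tensor."""

    def forward(self, input_ids: Tensor) -> Tensor:
        return input_ids * 2.0


class DoublingTraceableModel(TraceableModel):
    def _trace(self) -> torch.jit.TopLevelTracedModule:
        # Record the graph by running the wrapped module on example inputs.
        return torch.jit.trace(self._model, (torch.tensor([1.0, 2.0, 3.0]),))


wrapped = DoublingTraceableModel(model=DoubleModule())
model_path = wrapped.save("/tmp")  # writes /tmp/traced_pytorch_model.pt
```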

eland/ml/pytorch/__init__.py

@@ -16,5 +16,19 @@
# under the License.
from eland.ml.pytorch._pytorch_model import PyTorchModel # noqa: F401
+from eland.ml.pytorch.nlp_ml_model import (
+    NlpBertTokenizationConfig,
+    NlpMPNetTokenizationConfig,
+    NlpRobertaTokenizationConfig,
+    NlpTrainedModelConfig,
+)
+from eland.ml.pytorch.traceable_model import TraceableModel  # noqa: F401

-__all__ = ["PyTorchModel"]
+__all__ = [
+    "PyTorchModel",
+    "TraceableModel",
+    "NlpTrainedModelConfig",
+    "NlpBertTokenizationConfig",
+    "NlpRobertaTokenizationConfig",
+    "NlpMPNetTokenizationConfig",
+]
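
With these exports, user code can import the wrapper and configuration classes straight from the package root; for example:

```python
from eland.ml.pytorch import (
    NlpBertTokenizationConfig,
    NlpTrainedModelConfig,
    PyTorchModel,
    TraceableModel,
)
```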

eland/ml/pytorch/traceable_model.py

@@ -0,0 +1,63 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os.path
from abc import ABC, abstractmethod
from typing import List, Optional, Union
import torch # type: ignore
from torch import nn
TracedModelTypes = Union[
torch.nn.Module,
torch.ScriptModule,
torch.jit.ScriptModule,
torch.jit.TopLevelTracedModule,
]
class TraceableModel(ABC):
"""A base class representing a pytorch model that can be traced."""
def __init__(
self,
model: nn.Module,
):
self._model = model
def quantize(self) -> None:
torch.quantization.quantize_dynamic(
self._model, {torch.nn.Linear}, dtype=torch.qint8
)
def trace(self) -> TracedModelTypes:
# model needs to be in evaluate mode
self._model.eval()
return self._trace()
@abstractmethod
def _trace(self) -> TracedModelTypes:
...
def classification_labels(self) -> Optional[List[str]]:
return None
def save(self, path: str) -> str:
model_path = os.path.join(path, "traced_pytorch_model.pt")
trace_model = self.trace()
torch.jit.save(trace_model, model_path)
return model_path
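
Because save() writes a standard TorchScript artifact via torch.jit.save, the output can be sanity-checked independently of Elasticsearch. A small sketch, assuming a prior save() call (e.g. the DoubleModule example above) wrote to /tmp:

```python
import torch

# Reload the traced artifact produced by TraceableModel.save("/tmp")
loaded = torch.jit.load("/tmp/traced_pytorch_model.pt")
print(loaded(torch.tensor([1.0, 2.0, 3.0])))  # callable like the original module
```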

eland/ml/pytorch/transformers.py

@@ -53,6 +53,7 @@ from eland.ml.pytorch.nlp_ml_model import (
TrainedModelInput,
ZeroShotClassificationInferenceOptions,
)
+from eland.ml.pytorch.traceable_model import TraceableModel
DEFAULT_OUTPUT_KEY = "sentence_embedding"
SUPPORTED_TASK_TYPES = {
@@ -364,7 +365,7 @@ class _DPREncoderWrapper(nn.Module):  # type: ignore
)
-class _TraceableModel(ABC):
+class _TransformerTraceableModel(TraceableModel):
"""A base class representing a HuggingFace transformer model that can be traced."""
def __init__(
@@ -377,18 +378,10 @@ class _TraceableModel(ABC):
_DistilBertWrapper,
],
):
+        super(_TransformerTraceableModel, self).__init__(model=model)
        self._tokenizer = tokenizer
-        self._model = model
-
-    def quantize(self) -> None:
-        torch.quantization.quantize_dynamic(
-            self._model, {torch.nn.Linear}, dtype=torch.qint8
-        )
-
-    def trace(self) -> TracedModelTypes:
-        # model needs to be in evaluate mode
-        self._model.eval()
+    def _trace(self) -> TracedModelTypes:
inputs = self._prepare_inputs()
# Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
@@ -425,11 +418,8 @@
def _prepare_inputs(self) -> transformers.BatchEncoding:
...
-    def classification_labels(self) -> Optional[List[str]]:
-        return None
-
-class _TraceableClassificationModel(_TraceableModel, ABC):
+class _TraceableClassificationModel(_TransformerTraceableModel, ABC):
def classification_labels(self) -> Optional[List[str]]:
id_label_items = self._model.config.id2label.items()
labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])] # type: ignore
@@ -438,7 +428,7 @@ class _TraceableClassificationModel(_TraceableModel, ABC):
return [label.replace("-", "_") for label in labels]
-class _TraceableFillMaskModel(_TraceableModel):
+class _TraceableFillMaskModel(_TransformerTraceableModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
return self._tokenizer(
"Who was Jim Henson?",
@@ -469,7 +459,7 @@ class _TraceableTextClassificationModel(_TraceableClassificationModel):
)
-class _TraceableTextEmbeddingModel(_TraceableModel):
+class _TraceableTextEmbeddingModel(_TransformerTraceableModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
return self._tokenizer(
"This is an example sentence.",
@@ -488,7 +478,7 @@ class _TraceableZeroShotClassificationModel(_TraceableClassificationModel):
)
-class _TraceableQuestionAnsweringModel(_TraceableModel):
+class _TraceableQuestionAnsweringModel(_TransformerTraceableModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
return self._tokenizer(
"What is the meaning of life?"
@@ -520,7 +510,6 @@ class TransformerModel:
self._traceable_model = self._create_traceable_model()
if quantize:
self._traceable_model.quantize()
-        self._traced_model = self._traceable_model.trace()
self._vocab = self._load_vocab()
self._config = self._create_config()
@@ -591,7 +580,7 @@ class TransformerModel:
),
)
-    def _create_traceable_model(self) -> _TraceableModel:
+    def _create_traceable_model(self) -> TraceableModel:
if self._task_type == "fill_mask":
model = transformers.AutoModelForMaskedLM.from_pretrained(
self._model_id, torchscript=True
@@ -643,8 +632,7 @@ class TransformerModel:
def save(self, path: str) -> Tuple[str, NlpTrainedModelConfig, str]:
# save traced model
-        model_path = os.path.join(path, "traced_pytorch_model.pt")
-        torch.jit.save(self._traced_model, model_path)
+        model_path = self._traceable_model.save(path)
# save vocabulary
vocab_path = os.path.join(path, "vocabulary.json")
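
With this refactor, TransformerModel delegates artifact writing to TraceableModel.save(). A hedged usage sketch (constructor signature as of this eland version; the model id is just an example from the Hugging Face hub):

```python
from eland.ml.pytorch.transformers import TransformerModel

tm = TransformerModel("distilbert-base-uncased", "fill_mask", quantize=False)
# save() returns (model_path, NlpTrainedModelConfig, vocab_path); these feed
# straight into PyTorchModel.import_model, as the new tests below demonstrate.
model_path, config, vocab_path = tm.save("/tmp")
```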

tests/ml/pytorch/test_pytorch_model_pytest.py

@@ -27,8 +27,6 @@ except ImportError:
HAS_SKLEARN = False
try:
import transformers # noqa: F401
from eland.ml.pytorch import PyTorchModel
from eland.ml.pytorch.transformers import TransformerModel

tests/ml/pytorch/test_pytorch_model_upload_pytest.py

@@ -0,0 +1,276 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import os
import tempfile
from abc import ABC
from typing import Union
import numpy as np
import pytest
from elasticsearch import NotFoundError
try:
import sklearn # noqa: F401
HAS_SKLEARN = True
except ImportError:
HAS_SKLEARN = False
try:
import torch # noqa: F401
from torch import Tensor, nn # noqa: F401
from eland.ml.pytorch import ( # noqa: F401
NlpBertTokenizationConfig,
NlpTrainedModelConfig,
PyTorchModel,
TraceableModel,
)
from eland.ml.pytorch.nlp_ml_model import (
NerInferenceOptions,
TextClassificationInferenceOptions,
TextEmbeddingInferenceOptions,
)
TracedModelTypes = Union[ # noqa: F401
torch.nn.Module,
torch.ScriptModule,
torch.jit.ScriptModule,
torch.jit.TopLevelTracedModule,
]
HAS_PYTORCH = True
except ImportError:
HAS_PYTORCH = False
from tests import ES_TEST_CLIENT, ES_VERSION
pytestmark = [
pytest.mark.skipif(
ES_VERSION < (8, 0, 0),
reason="This test requires at least Elasticsearch version 8.0.0",
),
pytest.mark.skipif(
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
),
pytest.mark.skipif(
not HAS_PYTORCH, reason="This test requires 'pytorch' package to run"
),
]
TEST_BERT_VOCAB = [
"Elastic",
"##search",
"is",
"fun",
"my",
"little",
"red",
"car",
"God",
"##zilla",
".",
",",
"[CLS]",
"[SEP]",
"[MASK]",
"[PAD]",
"[UNK]",
"day",
"Pancake",
"with",
"?",
]
NER_LABELS = [
"O",
"B_MISC",
"I_MISC",
"B_PER",
"I_PER",
"B_ORG",
"I_ORG",
"B_LOC",
"I_LOC",
]
TEXT_CLASSIFICATION_LABELS = ["foo", "bar", "baz"]
if not HAS_PYTORCH:
pytest.skip("This test requires 'pytorch' package to run", allow_module_level=True)
class TestTraceableModel(TraceableModel, ABC):
def __init__(self, model: nn.Module):
super().__init__(model)
def _trace(self) -> TracedModelTypes:
input_ids = torch.tensor(np.array(range(0, len(TEST_BERT_VOCAB))))
attention_mask = torch.tensor([1] * len(TEST_BERT_VOCAB))
token_type_ids = torch.tensor([0] * len(TEST_BERT_VOCAB))
position_ids = torch.arange(len(TEST_BERT_VOCAB), dtype=torch.long)
return torch.jit.trace(
self._model,
(
input_ids,
attention_mask,
token_type_ids,
position_ids,
),
)
class NerModule(nn.Module):
def forward(
self,
input_ids: Tensor,
_attention_mask: Tensor,
_token_type_ids: Tensor,
_position_ids: Tensor,
) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
outside = [0] * len(NER_LABELS)
outside[0] = 1
person = [0] * len(NER_LABELS)
person[3] = 1
person[4] = 1
result = [outside for _t in np.array(input_ids.data)]
result[1] = person
result[2] = person
return torch.tensor([result], dtype=torch.float)
class EmbeddingModule(nn.Module):
def forward(
self,
_input_ids: Tensor,
_attention_mask: Tensor,
_token_type_ids: Tensor,
_position_ids: Tensor,
) -> Tensor:
"""Wrap the input and output to conform to the native process interface."""
result = [0] * 512
return torch.tensor([result], dtype=torch.float)
class TextClassificationModule(nn.Module):
def forward(
self,
_input_ids: Tensor,
_attention_mask: Tensor,
_token_type_ids: Tensor,
_position_ids: Tensor,
) -> Tensor:
# foo, bar, baz are the classification labels
result = [0, 1.0, 0]
return torch.tensor([result], dtype=torch.float)
MODELS_TO_TEST = [
(
"ner",
TestTraceableModel(model=NerModule()),
NlpTrainedModelConfig(
description="test ner model",
inference_config=NerInferenceOptions(
tokenization=NlpBertTokenizationConfig(),
classification_labels=NER_LABELS,
),
),
"Godzilla Pancake Elasticsearch is fun.",
"[Godzilla](PER&Godzilla) Pancake Elasticsearch is fun.",
),
(
"embedding",
TestTraceableModel(model=EmbeddingModule()),
NlpTrainedModelConfig(
description="test text_embedding model",
inference_config=TextEmbeddingInferenceOptions(
tokenization=NlpBertTokenizationConfig()
),
),
"Godzilla Pancake Elasticsearch is fun.",
[0] * 512,
),
(
"text_classification",
TestTraceableModel(model=TextClassificationModule()),
NlpTrainedModelConfig(
description="test text_classification model",
inference_config=TextClassificationInferenceOptions(
tokenization=NlpBertTokenizationConfig(),
classification_labels=TEXT_CLASSIFICATION_LABELS,
),
),
"Godzilla Pancake Elasticsearch is fun.",
"bar",
),
]
@pytest.fixture(scope="function", autouse=True)
def setup_and_tear_down():
ES_TEST_CLIENT.cluster.put_settings(
body={"transient": {"logger.org.elasticsearch.xpack.ml": "DEBUG"}}
)
yield
for (
model_id,
_,
_,
_,
_,
) in MODELS_TO_TEST:
model = PyTorchModel(ES_TEST_CLIENT, model_id.replace("/", "__").lower()[:64])
try:
model.stop()
model.delete()
except NotFoundError:
pass
def upload_model_and_start_deployment(
tmp_dir: str, model: TraceableModel, config: NlpTrainedModelConfig, model_id: str
):
print("Loading HuggingFace transformer tokenizer and model")
model_path = model.save(tmp_dir)
vocab_path = os.path.join(tmp_dir, "vocabulary.json")
with open(vocab_path, "w") as outfile:
json.dump({"vocabulary": TEST_BERT_VOCAB}, outfile)
ptm = PyTorchModel(ES_TEST_CLIENT, model_id)
try:
ptm.stop()
ptm.delete()
except NotFoundError:
pass
print(f"Importing model: {ptm.model_id}")
ptm.import_model(
model_path=model_path, config_path=None, vocab_path=vocab_path, config=config
)
ptm.start()
return ptm
class TestPytorchModelUpload:
@pytest.mark.parametrize("model_id,model,config,input,prediction", MODELS_TO_TEST)
def test_model_upload(self, model_id, model, config, input, prediction):
with tempfile.TemporaryDirectory() as tmp_dir:
ptm = upload_model_and_start_deployment(tmp_dir, model, config, model_id)
result = ptm.infer(docs=[{"text_field": input}])
assert result.get("predicted_value") is not None
assert result["predicted_value"] == prediction