mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
[ML] improve general pytorch model import and add tests (#463)
This improves the user consumed functions and classes for PyTorch NLP model upload to Elasticsearch. Previously it was difficult to wrap your own module for uploading to Elasticsearch. This commit splits some classes out, adds new ones, and adds tests showing how to wrap some simple modules.
This commit is contained in:
parent
70fadc9986
commit
650e02d16e
@ -16,5 +16,19 @@
|
||||
# under the License.
|
||||
|
||||
from eland.ml.pytorch._pytorch_model import PyTorchModel # noqa: F401
|
||||
from eland.ml.pytorch.nlp_ml_model import (
|
||||
NlpBertTokenizationConfig,
|
||||
NlpMPNetTokenizationConfig,
|
||||
NlpRobertaTokenizationConfig,
|
||||
NlpTrainedModelConfig,
|
||||
)
|
||||
from eland.ml.pytorch.traceable_model import TraceableModel # noqa: F401
|
||||
|
||||
__all__ = ["PyTorchModel"]
|
||||
__all__ = [
|
||||
"PyTorchModel",
|
||||
"TraceableModel",
|
||||
"NlpTrainedModelConfig",
|
||||
"NlpBertTokenizationConfig",
|
||||
"NlpRobertaTokenizationConfig",
|
||||
"NlpMPNetTokenizationConfig",
|
||||
]
|
||||
|
63
eland/ml/pytorch/traceable_model.py
Normal file
63
eland/ml/pytorch/traceable_model.py
Normal file
@ -0,0 +1,63 @@
|
||||
# Licensed to Elasticsearch B.V. under one or more contributor
|
||||
# license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright
|
||||
# ownership. Elasticsearch B.V. licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import os.path
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch # type: ignore
|
||||
from torch import nn
|
||||
|
||||
TracedModelTypes = Union[
|
||||
torch.nn.Module,
|
||||
torch.ScriptModule,
|
||||
torch.jit.ScriptModule,
|
||||
torch.jit.TopLevelTracedModule,
|
||||
]
|
||||
|
||||
|
||||
class TraceableModel(ABC):
|
||||
"""A base class representing a pytorch model that can be traced."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
):
|
||||
self._model = model
|
||||
|
||||
def quantize(self) -> None:
|
||||
torch.quantization.quantize_dynamic(
|
||||
self._model, {torch.nn.Linear}, dtype=torch.qint8
|
||||
)
|
||||
|
||||
def trace(self) -> TracedModelTypes:
|
||||
# model needs to be in evaluate mode
|
||||
self._model.eval()
|
||||
return self._trace()
|
||||
|
||||
@abstractmethod
|
||||
def _trace(self) -> TracedModelTypes:
|
||||
...
|
||||
|
||||
def classification_labels(self) -> Optional[List[str]]:
|
||||
return None
|
||||
|
||||
def save(self, path: str) -> str:
|
||||
model_path = os.path.join(path, "traced_pytorch_model.pt")
|
||||
trace_model = self.trace()
|
||||
torch.jit.save(trace_model, model_path)
|
||||
return model_path
|
@ -53,6 +53,7 @@ from eland.ml.pytorch.nlp_ml_model import (
|
||||
TrainedModelInput,
|
||||
ZeroShotClassificationInferenceOptions,
|
||||
)
|
||||
from eland.ml.pytorch.traceable_model import TraceableModel
|
||||
|
||||
DEFAULT_OUTPUT_KEY = "sentence_embedding"
|
||||
SUPPORTED_TASK_TYPES = {
|
||||
@ -364,7 +365,7 @@ class _DPREncoderWrapper(nn.Module): # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class _TraceableModel(ABC):
|
||||
class _TransformerTraceableModel(TraceableModel):
|
||||
"""A base class representing a HuggingFace transformer model that can be traced."""
|
||||
|
||||
def __init__(
|
||||
@ -377,18 +378,10 @@ class _TraceableModel(ABC):
|
||||
_DistilBertWrapper,
|
||||
],
|
||||
):
|
||||
super(_TransformerTraceableModel, self).__init__(model=model)
|
||||
self._tokenizer = tokenizer
|
||||
self._model = model
|
||||
|
||||
def quantize(self) -> None:
|
||||
torch.quantization.quantize_dynamic(
|
||||
self._model, {torch.nn.Linear}, dtype=torch.qint8
|
||||
)
|
||||
|
||||
def trace(self) -> TracedModelTypes:
|
||||
# model needs to be in evaluate mode
|
||||
self._model.eval()
|
||||
|
||||
def _trace(self) -> TracedModelTypes:
|
||||
inputs = self._prepare_inputs()
|
||||
|
||||
# Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
|
||||
@ -425,11 +418,8 @@ class _TraceableModel(ABC):
|
||||
def _prepare_inputs(self) -> transformers.BatchEncoding:
|
||||
...
|
||||
|
||||
def classification_labels(self) -> Optional[List[str]]:
|
||||
return None
|
||||
|
||||
|
||||
class _TraceableClassificationModel(_TraceableModel, ABC):
|
||||
class _TraceableClassificationModel(_TransformerTraceableModel, ABC):
|
||||
def classification_labels(self) -> Optional[List[str]]:
|
||||
id_label_items = self._model.config.id2label.items()
|
||||
labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])] # type: ignore
|
||||
@ -438,7 +428,7 @@ class _TraceableClassificationModel(_TraceableModel, ABC):
|
||||
return [label.replace("-", "_") for label in labels]
|
||||
|
||||
|
||||
class _TraceableFillMaskModel(_TraceableModel):
|
||||
class _TraceableFillMaskModel(_TransformerTraceableModel):
|
||||
def _prepare_inputs(self) -> transformers.BatchEncoding:
|
||||
return self._tokenizer(
|
||||
"Who was Jim Henson?",
|
||||
@ -469,7 +459,7 @@ class _TraceableTextClassificationModel(_TraceableClassificationModel):
|
||||
)
|
||||
|
||||
|
||||
class _TraceableTextEmbeddingModel(_TraceableModel):
|
||||
class _TraceableTextEmbeddingModel(_TransformerTraceableModel):
|
||||
def _prepare_inputs(self) -> transformers.BatchEncoding:
|
||||
return self._tokenizer(
|
||||
"This is an example sentence.",
|
||||
@ -488,7 +478,7 @@ class _TraceableZeroShotClassificationModel(_TraceableClassificationModel):
|
||||
)
|
||||
|
||||
|
||||
class _TraceableQuestionAnsweringModel(_TraceableModel):
|
||||
class _TraceableQuestionAnsweringModel(_TransformerTraceableModel):
|
||||
def _prepare_inputs(self) -> transformers.BatchEncoding:
|
||||
return self._tokenizer(
|
||||
"What is the meaning of life?"
|
||||
@ -520,7 +510,6 @@ class TransformerModel:
|
||||
self._traceable_model = self._create_traceable_model()
|
||||
if quantize:
|
||||
self._traceable_model.quantize()
|
||||
self._traced_model = self._traceable_model.trace()
|
||||
self._vocab = self._load_vocab()
|
||||
self._config = self._create_config()
|
||||
|
||||
@ -591,7 +580,7 @@ class TransformerModel:
|
||||
),
|
||||
)
|
||||
|
||||
def _create_traceable_model(self) -> _TraceableModel:
|
||||
def _create_traceable_model(self) -> TraceableModel:
|
||||
if self._task_type == "fill_mask":
|
||||
model = transformers.AutoModelForMaskedLM.from_pretrained(
|
||||
self._model_id, torchscript=True
|
||||
@ -643,8 +632,7 @@ class TransformerModel:
|
||||
|
||||
def save(self, path: str) -> Tuple[str, NlpTrainedModelConfig, str]:
|
||||
# save traced model
|
||||
model_path = os.path.join(path, "traced_pytorch_model.pt")
|
||||
torch.jit.save(self._traced_model, model_path)
|
||||
model_path = self._traceable_model.save(path)
|
||||
|
||||
# save vocabulary
|
||||
vocab_path = os.path.join(path, "vocabulary.json")
|
||||
|
@ -27,8 +27,6 @@ except ImportError:
|
||||
HAS_SKLEARN = False
|
||||
|
||||
try:
|
||||
import transformers # noqa: F401
|
||||
|
||||
from eland.ml.pytorch import PyTorchModel
|
||||
from eland.ml.pytorch.transformers import TransformerModel
|
||||
|
276
tests/ml/pytorch/test_transformer_pytorch_model_pytest.py
Normal file
276
tests/ml/pytorch/test_transformer_pytorch_model_pytest.py
Normal file
@ -0,0 +1,276 @@
|
||||
# Licensed to Elasticsearch B.V. under one or more contributor
|
||||
# license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright
|
||||
# ownership. Elasticsearch B.V. licenses this file to you under
|
||||
# the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from abc import ABC
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from elasticsearch import NotFoundError
|
||||
|
||||
try:
|
||||
import sklearn # noqa: F401
|
||||
|
||||
HAS_SKLEARN = True
|
||||
except ImportError:
|
||||
HAS_SKLEARN = False
|
||||
|
||||
try:
|
||||
import torch # noqa: F401
|
||||
from torch import Tensor, nn # noqa: F401
|
||||
|
||||
from eland.ml.pytorch import ( # noqa: F401
|
||||
NlpBertTokenizationConfig,
|
||||
NlpTrainedModelConfig,
|
||||
PyTorchModel,
|
||||
TraceableModel,
|
||||
)
|
||||
from eland.ml.pytorch.nlp_ml_model import (
|
||||
NerInferenceOptions,
|
||||
TextClassificationInferenceOptions,
|
||||
TextEmbeddingInferenceOptions,
|
||||
)
|
||||
|
||||
TracedModelTypes = Union[ # noqa: F401
|
||||
torch.nn.Module,
|
||||
torch.ScriptModule,
|
||||
torch.jit.ScriptModule,
|
||||
torch.jit.TopLevelTracedModule,
|
||||
]
|
||||
|
||||
HAS_PYTORCH = True
|
||||
except ImportError:
|
||||
HAS_PYTORCH = False
|
||||
|
||||
from tests import ES_TEST_CLIENT, ES_VERSION
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.skipif(
|
||||
ES_VERSION < (8, 0, 0),
|
||||
reason="This test requires at least Elasticsearch version 8.0.0",
|
||||
),
|
||||
pytest.mark.skipif(
|
||||
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
|
||||
),
|
||||
pytest.mark.skipif(
|
||||
not HAS_PYTORCH, reason="This test requires 'pytorch' package to run"
|
||||
),
|
||||
]
|
||||
|
||||
TEST_BERT_VOCAB = [
|
||||
"Elastic",
|
||||
"##search",
|
||||
"is",
|
||||
"fun",
|
||||
"my",
|
||||
"little",
|
||||
"red",
|
||||
"car",
|
||||
"God",
|
||||
"##zilla",
|
||||
".",
|
||||
",",
|
||||
"[CLS]",
|
||||
"[SEP]",
|
||||
"[MASK]",
|
||||
"[PAD]",
|
||||
"[UNK]",
|
||||
"day",
|
||||
"Pancake",
|
||||
"with",
|
||||
"?",
|
||||
]
|
||||
|
||||
NER_LABELS = [
|
||||
"O",
|
||||
"B_MISC",
|
||||
"I_MISC",
|
||||
"B_PER",
|
||||
"I_PER",
|
||||
"B_ORG",
|
||||
"I_ORG",
|
||||
"B_LOC",
|
||||
"I_LOC",
|
||||
]
|
||||
|
||||
TEXT_CLASSIFICATION_LABELS = ["foo", "bar", "baz"]
|
||||
|
||||
if not HAS_PYTORCH:
|
||||
pytest.skip("This test requires 'pytorch' package to run", allow_module_level=True)
|
||||
|
||||
|
||||
class TestTraceableModel(TraceableModel, ABC):
|
||||
def __init__(self, model: nn.Module):
|
||||
super().__init__(model)
|
||||
|
||||
def _trace(self) -> TracedModelTypes:
|
||||
input_ids = torch.tensor(np.array(range(0, len(TEST_BERT_VOCAB))))
|
||||
attention_mask = torch.tensor([1] * len(TEST_BERT_VOCAB))
|
||||
token_type_ids = torch.tensor([0] * len(TEST_BERT_VOCAB))
|
||||
position_ids = torch.arange(len(TEST_BERT_VOCAB), dtype=torch.long)
|
||||
return torch.jit.trace(
|
||||
self._model,
|
||||
(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class NerModule(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
_attention_mask: Tensor,
|
||||
_token_type_ids: Tensor,
|
||||
_position_ids: Tensor,
|
||||
) -> Tensor:
|
||||
"""Wrap the input and output to conform to the native process interface."""
|
||||
outside = [0] * len(NER_LABELS)
|
||||
outside[0] = 1
|
||||
person = [0] * len(NER_LABELS)
|
||||
person[3] = 1
|
||||
person[4] = 1
|
||||
result = [outside for _t in np.array(input_ids.data)]
|
||||
result[1] = person
|
||||
result[2] = person
|
||||
return torch.tensor([result], dtype=torch.float)
|
||||
|
||||
|
||||
class EmbeddingModule(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
_input_ids: Tensor,
|
||||
_attention_mask: Tensor,
|
||||
_token_type_ids: Tensor,
|
||||
_position_ids: Tensor,
|
||||
) -> Tensor:
|
||||
"""Wrap the input and output to conform to the native process interface."""
|
||||
result = [0] * 512
|
||||
return torch.tensor([result], dtype=torch.float)
|
||||
|
||||
|
||||
class TextClassificationModule(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
_input_ids: Tensor,
|
||||
_attention_mask: Tensor,
|
||||
_token_type_ids: Tensor,
|
||||
_position_ids: Tensor,
|
||||
) -> Tensor:
|
||||
# foo, bar, baz are the classification labels
|
||||
result = [0, 1.0, 0]
|
||||
return torch.tensor([result], dtype=torch.float)
|
||||
|
||||
|
||||
MODELS_TO_TEST = [
|
||||
(
|
||||
"ner",
|
||||
TestTraceableModel(model=NerModule()),
|
||||
NlpTrainedModelConfig(
|
||||
description="test ner model",
|
||||
inference_config=NerInferenceOptions(
|
||||
tokenization=NlpBertTokenizationConfig(),
|
||||
classification_labels=NER_LABELS,
|
||||
),
|
||||
),
|
||||
"Godzilla Pancake Elasticsearch is fun.",
|
||||
"[Godzilla](PER&Godzilla) Pancake Elasticsearch is fun.",
|
||||
),
|
||||
(
|
||||
"embedding",
|
||||
TestTraceableModel(model=EmbeddingModule()),
|
||||
NlpTrainedModelConfig(
|
||||
description="test text_embedding model",
|
||||
inference_config=TextEmbeddingInferenceOptions(
|
||||
tokenization=NlpBertTokenizationConfig()
|
||||
),
|
||||
),
|
||||
"Godzilla Pancake Elasticsearch is fun.",
|
||||
[0] * 512,
|
||||
),
|
||||
(
|
||||
"text_classification",
|
||||
TestTraceableModel(model=TextClassificationModule()),
|
||||
NlpTrainedModelConfig(
|
||||
description="test text_classification model",
|
||||
inference_config=TextClassificationInferenceOptions(
|
||||
tokenization=NlpBertTokenizationConfig(),
|
||||
classification_labels=TEXT_CLASSIFICATION_LABELS,
|
||||
),
|
||||
),
|
||||
"Godzilla Pancake Elasticsearch is fun.",
|
||||
"bar",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def setup_and_tear_down():
|
||||
ES_TEST_CLIENT.cluster.put_settings(
|
||||
body={"transient": {"logger.org.elasticsearch.xpack.ml": "DEBUG"}}
|
||||
)
|
||||
yield
|
||||
for (
|
||||
model_id,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) in MODELS_TO_TEST:
|
||||
model = PyTorchModel(ES_TEST_CLIENT, model_id.replace("/", "__").lower()[:64])
|
||||
try:
|
||||
model.stop()
|
||||
model.delete()
|
||||
except NotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def upload_model_and_start_deployment(
|
||||
tmp_dir: str, model: TraceableModel, config: NlpTrainedModelConfig, model_id: str
|
||||
):
|
||||
print("Loading HuggingFace transformer tokenizer and model")
|
||||
model_path = model.save(tmp_dir)
|
||||
vocab_path = os.path.join(tmp_dir, "vocabulary.json")
|
||||
with open(vocab_path, "w") as outfile:
|
||||
json.dump({"vocabulary": TEST_BERT_VOCAB}, outfile)
|
||||
ptm = PyTorchModel(ES_TEST_CLIENT, model_id)
|
||||
try:
|
||||
ptm.stop()
|
||||
ptm.delete()
|
||||
except NotFoundError:
|
||||
pass
|
||||
print(f"Importing model: {ptm.model_id}")
|
||||
ptm.import_model(
|
||||
model_path=model_path, config_path=None, vocab_path=vocab_path, config=config
|
||||
)
|
||||
ptm.start()
|
||||
return ptm
|
||||
|
||||
|
||||
class TestPytorchModelUpload:
|
||||
@pytest.mark.parametrize("model_id,model,config,input,prediction", MODELS_TO_TEST)
|
||||
def test_model_upload(self, model_id, model, config, input, prediction):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
ptm = upload_model_and_start_deployment(tmp_dir, model, config, model_id)
|
||||
result = ptm.infer(docs=[{"text_field": input}])
|
||||
assert result.get("predicted_value") is not None
|
||||
assert result["predicted_value"] == prediction
|
Loading…
x
Reference in New Issue
Block a user