Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)
[ML] Improve NLP model import by using nicely defined types (#459)
This adds more clearly defined types for our NLP tasks and tokenization configurations. It is the first step toward letting users more easily import their own transformer models from sources other than Hugging Face.
commit afe08f8107 (parent 3255f55d71)
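The shape of the change: `TransformerModel.save` now returns an in-memory `NlpTrainedModelConfig` instead of writing a `config.json` to disk, and `PyTorchModel.put_config`/`import_model` accept that object directly. As a rough sketch of what this first step is building toward, a user could eventually assemble a config by hand for a model that never came from Hugging Face. The model ID, connection details, file paths, and tokenization settings below are illustrative, not part of this diff, and the import locations assume the modules touched below:

```
from elasticsearch import Elasticsearch

from eland.ml.pytorch import PyTorchModel
from eland.ml.pytorch.nlp_ml_model import (
    NlpBertTokenizationConfig,
    NlpTrainedModelConfig,
    TextEmbeddingInferenceOptions,
)

# Hand-built config for a custom (non-Hugging Face) TorchScript model.
config = NlpTrainedModelConfig(
    description="My locally trained text embedding model",
    inference_config=TextEmbeddingInferenceOptions(
        tokenization=NlpBertTokenizationConfig(
            do_lower_case=True, max_sequence_length=512
        )
    ),
)

es = Elasticsearch("http://localhost:9200")
ptm = PyTorchModel(es, "my_custom_text_embedding_model")
# config_path=None selects the typed config object over a config.json on disk.
ptm.import_model(
    model_path="traced_pytorch_model.pt",
    config_path=None,
    vocab_path="vocabulary.json",
    config=config,
)
```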
@@ -250,11 +250,11 @@ Downloading: 100%|██████████| 249M/249M [00:23<00:00, 11.2MB
 # Export the model in a TorchScript representation which Elasticsearch uses
 >>> tmp_path = "models"
 >>> Path(tmp_path).mkdir(parents=True, exist_ok=True)
->>> model_path, config_path, vocab_path = tm.save(tmp_path)
+>>> model_path, config, vocab_path = tm.save(tmp_path)

 # Import model into Elasticsearch
 >>> es = elasticsearch.Elasticsearch("http://elastic:mlqa_admin@localhost:9200", timeout=300) # 5 minute timeout
 >>> ptm = PyTorchModel(es, tm.elasticsearch_model_id())
->>> ptm.import_model(model_path, config_path, vocab_path)
+>>> ptm.import_model(model_path=model_path, config_path=None, vocab_path=vocab_path, config=config)
 100%|██████████| 63/63 [00:12<00:00, 5.02it/s]
 ```
@@ -188,7 +188,7 @@ if __name__ == "__main__":
     logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'")

     tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
-    model_path, config_path, vocab_path = tm.save(tmp_dir)
+    model_path, config, vocab_path = tm.save(tmp_dir)

     ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id())
     model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200
@@ -206,7 +206,7 @@ if __name__ == "__main__":
         exit(1)

     logger.info(f"Creating model with id '{ptm.model_id}'")
-    ptm.put_config(config_path)
+    ptm.put_config(config=config)

     logger.info(f"Uploading model definition")
     ptm.put_model(model_path)
@@ -19,11 +19,22 @@ import base64
 import json
 import math
 import os
-from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Set, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)

 from tqdm.auto import tqdm  # type: ignore

 from eland.common import ensure_es_client
+from eland.ml.pytorch.nlp_ml_model import NlpTrainedModelConfig

 if TYPE_CHECKING:
     from elasticsearch import Elasticsearch
@@ -49,10 +60,19 @@ class PyTorchModel:
         self._client: Elasticsearch = ensure_es_client(es_client)
         self.model_id = model_id

-    def put_config(self, path: str) -> None:
-        with open(path) as f:
-            config = json.load(f)
-        self._client.ml.put_trained_model(model_id=self.model_id, **config)
+    def put_config(
+        self, path: Optional[str] = None, config: Optional[NlpTrainedModelConfig] = None
+    ) -> None:
+        if path is not None and config is not None:
+            raise ValueError("Only include path or config. Not both")
+        if path is not None:
+            with open(path) as f:
+                config_map = json.load(f)
+        elif config is not None:
+            config_map = config.to_dict()
+        else:
+            raise ValueError("Must provide path or config")
+        self._client.ml.put_trained_model(model_id=self.model_id, **config_map)

     def put_vocab(self, path: str) -> None:
         with open(path) as f:
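`put_config` keeps backwards compatibility: callers can still point it at a serialized JSON file, or hand it the new typed object, but not both at once. A short sketch of the two calling modes, reusing `ptm` and `config` from the sketch under the commit summary above:

```
# Old style: load a previously serialized config.json from disk.
ptm.put_config(path="models/config.json")

# New style: serialize a typed NlpTrainedModelConfig via its to_dict().
ptm.put_config(config=config)

# ptm.put_config(path="models/config.json", config=config)  -> ValueError
# ptm.put_config()                                           -> ValueError
```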
@@ -89,13 +109,14 @@ class PyTorchModel:

     def import_model(
         self,
+        *,
         model_path: str,
-        config_path: str,
+        config_path: Optional[str],
         vocab_path: str,
+        config: Optional[NlpTrainedModelConfig] = None,
         chunk_size: int = DEFAULT_CHUNK_SIZE,
     ) -> None:
-        # TODO: Implement some pre-flight checks on config, vocab, and model
-        self.put_config(config_path)
+        self.put_config(path=config_path, config=config)
         self.put_model(model_path, chunk_size)
         self.put_vocab(vocab_path)

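Note the bare `*` added to the signature: every argument to `import_model` is now keyword-only, so existing positional call sites need updating, as the docs and test changes elsewhere in this commit show:

```
# Before: positional arguments were accepted.
ptm.import_model(model_path, config_path, vocab_path)

# After: keyword-only; config_path may be None when a config object is supplied.
ptm.import_model(
    model_path=model_path,
    config_path=None,
    vocab_path=vocab_path,
    config=config,
)
```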
eland/ml/pytorch/nlp_ml_model.py (new file, 228 lines)
@@ -0,0 +1,228 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import typing as t
+
+
+class NlpTokenizationConfig:
+    def __init__(self, *, configuration_type: str):
+        self.name = configuration_type
+
+    def to_dict(self):
+        return {
+            self.name: {
+                k: v for k, v in self.__dict__.items() if v is not None and k != "name"
+            }
+        }
+
+
+class NlpRobertaTokenizationConfig(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        add_prefix_space: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(configuration_type="roberta")
+        self.add_prefix_space = add_prefix_space
+        self.with_special_tokens = with_special_tokens
+        self.max_sequence_length = max_sequence_length
+        self.truncate = truncate
+        self.span = span
+
+
+class NlpBertTokenizationConfig(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        do_lower_case: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(configuration_type="bert")
+        self.do_lower_case = do_lower_case
+        self.with_special_tokens = with_special_tokens
+        self.max_sequence_length = max_sequence_length
+        self.truncate = truncate
+        self.span = span
+
+
+class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
+    def __init__(
+        self,
+        *,
+        do_lower_case: t.Optional[bool] = None,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
+        super().__init__(configuration_type="mpnet")
+        self.do_lower_case = do_lower_case
+        self.with_special_tokens = with_special_tokens
+        self.max_sequence_length = max_sequence_length
+        self.truncate = truncate
+        self.span = span
+
+
+class InferenceConfig:
+    def __init__(self, *, configuration_type: str):
+        self.name = configuration_type
+
+    def to_dict(self) -> t.Dict[str, t.Any]:
+        return {
+            self.name: {
+                k: v.to_dict() if hasattr(v, "to_dict") else v
+                for k, v in self.__dict__.items()
+                if v is not None and k != "name"
+            }
+        }
+
+
+class TextClassificationInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        classification_labels: t.Union[t.List[str], t.Tuple[str, ...]],
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+        num_top_classes: t.Optional[int] = None,
+    ):
+        super().__init__(configuration_type="text_classification")
+        self.results_field = results_field
+        self.num_top_classes = num_top_classes
+        self.tokenization = tokenization
+        self.classification_labels = classification_labels
+
+
+class ZeroShotClassificationInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        classification_labels: t.Union[t.List[str], t.Tuple[str, ...]],
+        results_field: t.Optional[str] = None,
+        multi_label: t.Optional[bool] = None,
+        labels: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
+        hypothesis_template: t.Optional[str] = None,
+    ):
+        super().__init__(configuration_type="zero_shot_classification")
+        self.tokenization = tokenization
+        self.hypothesis_template = hypothesis_template
+        self.classification_labels = classification_labels
+        self.results_field = results_field
+        self.multi_label = multi_label
+        self.labels = labels
+
+
+class FillMaskInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+        num_top_classes: t.Optional[int] = None,
+    ):
+        super().__init__(configuration_type="fill_mask")
+        self.num_top_classes = num_top_classes
+        self.tokenization = tokenization
+        self.results_field = results_field
+
+
+class NerInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        classification_labels: t.Union[t.List[str], t.Tuple[str, ...]],
+        results_field: t.Optional[str] = None,
+    ):
+        super().__init__(configuration_type="ner")
+        self.tokenization = tokenization
+        self.classification_labels = classification_labels
+        self.results_field = results_field
+
+
+class PassThroughInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+    ):
+        super().__init__(configuration_type="pass_through")
+        self.tokenization = tokenization
+        self.results_field = results_field
+
+
+class TextEmbeddingInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+    ):
+        super().__init__(configuration_type="text_embedding")
+        self.tokenization = tokenization
+        self.results_field = results_field
+
+
+class TrainedModelInput:
+    def __init__(self, *, field_names: t.List[str]):
+        self.field_names = field_names
+
+    def to_dict(self) -> t.Dict[str, t.Any]:
+        return self.__dict__
+
+
+class NlpTrainedModelConfig:
+    def __init__(
+        self,
+        *,
+        description: str,
+        inference_config: InferenceConfig,
+        input: TrainedModelInput = TrainedModelInput(field_names=["text_field"]),
+        metadata: t.Optional[dict] = None,
+        model_type: t.Union["t.Literal['pytorch']", str] = "pytorch",
+        default_field_map: t.Optional[t.Mapping[str, str]] = None,
+        tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
+    ):
+        self.tags = tags
+        self.default_field_map = default_field_map
+        self.description = description
+        self.inference_config = inference_config
+        self.input = input
+        self.metadata = metadata
+        self.model_type = model_type
+
+    def to_dict(self) -> t.Dict[str, t.Any]:
+        return {
+            k: v.to_dict() if hasattr(v, "to_dict") else v
+            for k, v in self.__dict__.items()
+            if v is not None
+        }
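The `to_dict` methods share one convention: each object serializes under its `name` key and drops `None`-valued fields, so only explicitly set options reach the Elasticsearch `put_trained_model` API. A quick illustration of the nesting this produces (labels are illustrative; the expected output is shown as a comment):

```
from eland.ml.pytorch.nlp_ml_model import (
    NerInferenceOptions,
    NlpBertTokenizationConfig,
    NlpTrainedModelConfig,
)

config = NlpTrainedModelConfig(
    description="NER demo",
    inference_config=NerInferenceOptions(
        tokenization=NlpBertTokenizationConfig(do_lower_case=True),
        classification_labels=["O", "B_PER", "I_PER"],
    ),
)

print(config.to_dict())
# {'description': 'NER demo',
#  'inference_config': {'ner': {'tokenization': {'bert': {'do_lower_case': True}},
#                               'classification_labels': ['O', 'B_PER', 'I_PER']}},
#  'input': {'field_names': ['text_field']},
#  'model_type': 'pytorch'}
```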
@@ -37,6 +37,21 @@ from transformers import (
     PreTrainedTokenizerFast,
 )

+from eland.ml.pytorch.nlp_ml_model import (
+    FillMaskInferenceOptions,
+    NerInferenceOptions,
+    NlpBertTokenizationConfig,
+    NlpMPNetTokenizationConfig,
+    NlpRobertaTokenizationConfig,
+    NlpTokenizationConfig,
+    NlpTrainedModelConfig,
+    PassThroughInferenceOptions,
+    TextClassificationInferenceOptions,
+    TextEmbeddingInferenceOptions,
+    TrainedModelInput,
+    ZeroShotClassificationInferenceOptions,
+)
+
 DEFAULT_OUTPUT_KEY = "sentence_embedding"
 SUPPORTED_TASK_TYPES = {
     "fill_mask",
@@ -45,6 +60,14 @@ SUPPORTED_TASK_TYPES = {
     "text_embedding",
     "zero_shot_classification",
 }
+TASK_TYPE_TO_INFERENCE_CONFIG = {
+    "fill_mask": FillMaskInferenceOptions,
+    "ner": NerInferenceOptions,
+    "text_classification": TextClassificationInferenceOptions,
+    "text_embedding": TextEmbeddingInferenceOptions,
+    "zero_shot_classification": ZeroShotClassificationInferenceOptions,
+    "pass_through": PassThroughInferenceOptions,
+}
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
     transformers.BertTokenizer,
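This mapping lets `_create_config` (rewritten below) dispatch on the task-type string instead of hand-assembling nested dictionaries. A minimal sketch of the dispatch, with illustrative labels and tokenization settings:

```
from eland.ml.pytorch.nlp_ml_model import NlpBertTokenizationConfig

options_cls = TASK_TYPE_TO_INFERENCE_CONFIG["text_classification"]
inference_config = options_cls(
    tokenization=NlpBertTokenizationConfig(do_lower_case=True),
    classification_labels=["negative", "positive"],
)
# inference_config.to_dict() -> {"text_classification": {...}}
```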
@@ -91,8 +114,8 @@ class _DistilBertWrapper(nn.Module):  # type: ignore
         self,
         input_ids: Tensor,
         attention_mask: Tensor,
-        token_type_ids: Tensor,
-        position_ids: Tensor,
+        _token_type_ids: Tensor,
+        _position_ids: Tensor,
     ) -> Tensor:
         """Wrap the input and output to conform to the native process interface."""

@@ -246,7 +269,7 @@ class _DPREncoderWrapper(nn.Module):  # type: ignore
         input_ids: Tensor,
         attention_mask: Tensor,
         token_type_ids: Tensor,
-        position_ids: Tensor,
+        _position_ids: Tensor,
     ) -> Tensor:
         """Wrap the input and output to conform to the native process interface."""

@@ -421,50 +444,53 @@ class TransformerModel:
         vocab_obj["merges"] = merges
         return vocab_obj

-    def _create_config(self) -> Dict[str, Any]:
+    def _create_tokenization_config(self) -> NlpTokenizationConfig:
         if isinstance(self._tokenizer, transformers.MPNetTokenizer):
-            tokenizer_type = "mpnet"
-            tokenizer_obj = {
-                "do_lower_case": getattr(self._tokenizer, "do_lower_case", False)
-            }
+            return NlpMPNetTokenizationConfig(
+                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+                max_sequence_length=getattr(
+                    self._tokenizer, "max_model_input_sizes", dict()
+                ).get(self._model_id),
+            )
         elif isinstance(
             self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer)
         ):
-            tokenizer_type = "roberta"
-            tokenizer_obj = {
-                "add_prefix_space": getattr(self._tokenizer, "add_prefix_space", False)
-            }
+            return NlpRobertaTokenizationConfig(
+                add_prefix_space=getattr(self._tokenizer, "add_prefix_space", None),
+                max_sequence_length=getattr(
+                    self._tokenizer, "max_model_input_sizes", dict()
+                ).get(self._model_id),
+            )
         else:
-            tokenizer_type = "bert"
-            tokenizer_obj = {
-                "do_lower_case": getattr(self._tokenizer, "do_lower_case", False)
-            }
-        inference_config: Dict[str, Dict[str, Any]] = {
-            self._task_type: {"tokenization": {tokenizer_type: tokenizer_obj}}
-        }
-
-        if hasattr(self._tokenizer, "max_model_input_sizes"):
-            max_sequence_length = self._tokenizer.max_model_input_sizes.get(
-                self._model_id
-            )
-            if max_sequence_length:
-                inference_config[self._task_type]["tokenization"][tokenizer_type][
-                    "max_sequence_length"
-                ] = max_sequence_length
-
-        if self._traceable_model.classification_labels():
-            inference_config[self._task_type][
-                "classification_labels"
-            ] = self._traceable_model.classification_labels()
-
-        return {
-            "description": f"Model {self._model_id} for task type '{self._task_type}'",
-            "model_type": "pytorch",
-            "inference_config": inference_config,
-            "input": {
-                "field_names": ["text_field"],
-            },
-        }
+            return NlpBertTokenizationConfig(
+                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
+                max_sequence_length=getattr(
+                    self._tokenizer, "max_model_input_sizes", dict()
+                ).get(self._model_id),
+            )
+
+    def _create_config(self) -> NlpTrainedModelConfig:
+        tokenization_config = self._create_tokenization_config()
+
+        inference_config = (
+            TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+                tokenization=tokenization_config,
+                classification_labels=self._traceable_model.classification_labels(),
+            )
+            if self._traceable_model.classification_labels()
+            else TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+                tokenization=tokenization_config
+            )
+        )
+
+        return NlpTrainedModelConfig(
+            description=f"Model {self._model_id} for task type '{self._task_type}'",
+            model_type="pytorch",
+            inference_config=inference_config,
+            input=TrainedModelInput(
+                field_names=["text_field"],
+            ),
+        )

     def _create_traceable_model(self) -> _TraceableModel:
         if self._task_type == "fill_mask":
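Two behavioral details of this rewrite are worth calling out: unset tokenizer attributes now default to `None` (and are therefore omitted from the serialized config by `to_dict`) instead of `False`, and `max_sequence_length` is resolved inline via `getattr(..., "max_model_input_sizes", dict()).get(self._model_id)` rather than a separate `hasattr` branch. In Hugging Face tokenizers, `max_model_input_sizes` maps known checkpoint names to their maximum input length, so unknown model IDs simply yield `None`. A small check of that lookup (the second model ID is a made-up example):

```
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# 512 for a checkpoint the tokenizer knows about ...
print(getattr(tokenizer, "max_model_input_sizes", dict()).get("bert-base-uncased"))
# ... and None for one it does not, leaving the field out of the config.
print(getattr(tokenizer, "max_model_input_sizes", dict()).get("my-org/custom-model"))
```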
@@ -514,19 +540,14 @@ class TransformerModel:
         # Elasticsearch model IDs need to be a specific format: no special chars, all lowercase, max 64 chars
         return self._model_id.replace("/", "__").lower()[:64]

-    def save(self, path: str) -> Tuple[str, str, str]:
+    def save(self, path: str) -> Tuple[str, NlpTrainedModelConfig, str]:
         # save traced model
         model_path = os.path.join(path, "traced_pytorch_model.pt")
         torch.jit.save(self._traced_model, model_path)

-        # save configuration
-        config_path = os.path.join(path, "config.json")
-        with open(config_path, "w") as outfile:
-            json.dump(self._config, outfile)
-
         # save vocabulary
         vocab_path = os.path.join(path, "vocabulary.json")
         with open(vocab_path, "w") as outfile:
             json.dump(self._vocab, outfile)

-        return model_path, config_path, vocab_path
+        return model_path, self._config, vocab_path
@@ -79,7 +79,7 @@ def setup_and_tear_down():
 def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
     print("Loading HuggingFace transformer tokenizer and model")
     tm = TransformerModel(model_id, task, quantize)
-    model_path, config_path, vocab_path = tm.save(tmp_dir)
+    model_path, config, vocab_path = tm.save(tmp_dir)
     ptm = PyTorchModel(ES_TEST_CLIENT, tm.elasticsearch_model_id())
     try:
         ptm.stop()
@@ -87,7 +87,9 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
     except NotFoundError:
         pass
     print(f"Importing model: {ptm.model_id}")
-    ptm.import_model(model_path, config_path, vocab_path)
+    ptm.import_model(
+        model_path=model_path, config_path=None, vocab_path=vocab_path, config=config
+    )
     ptm.start()
     return ptm
