[ML] adds new auto task type that attempts to automatically determine NLP task type from model config (#475)
For many model types, we don't need to require the user to specify the task type: we can infer it from the model configuration and architecture.
This commit makes the `task-type` parameter optional for the model upload script and adds logic for auto-detecting the task type from the 🤗 model.
This commit is contained in:
parent 8448b3ba4e
commit 8892f4fd64
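For illustration, a minimal sketch of the new behavior at the Python level (a sketch, not part of this commit; `dslim/bert-base-NER` is used purely as an example Hugging Face model id, and we assume `quantize` defaults to False):

from eland.ml.pytorch.transformers import TransformerModel

# With task_type="auto" (now the default), the task is inferred from the
# model's architecture: "BertForTokenClassification" implies "ner".
tm = TransformerModel("dslim/bert-base-NER", "auto")
print(tm.elasticsearch_model_id())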
@@ -84,9 +84,11 @@ def get_arg_parser():
     )
     parser.add_argument(
         "--task-type",
-        required=True,
+        required=False,
         choices=SUPPORTED_TASK_TYPES,
-        help="The task type for the model usage.",
+        help="The task type for the model usage. Will attempt to auto-detect task type for the model if not provided. "
+        "Default: auto",
+        default="auto"
     )
     parser.add_argument(
         "--quantize",
@@ -165,7 +167,11 @@ if __name__ == "__main__":
 
     try:
         from eland.ml.pytorch import PyTorchModel
-        from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES, TransformerModel
+        from eland.ml.pytorch.transformers import (
+            SUPPORTED_TASK_TYPES,
+            TaskTypeError,
+            TransformerModel,
+        )
     except ModuleNotFoundError as e:
         logger.error(textwrap.dedent(f"""\
             \033[31mFailed to run because module '{e.name}' is not available.\033[0m
@@ -187,8 +193,12 @@ if __name__ == "__main__":
     with tempfile.TemporaryDirectory() as tmp_dir:
         logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'")
 
-        tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
-        model_path, config, vocab_path = tm.save(tmp_dir)
+        try:
+            tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
+            model_path, config, vocab_path = tm.save(tmp_dir)
+        except TaskTypeError as err:
+            logger.error(f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}")
+            exit(1)
 
     ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id())
     model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200
@@ -23,6 +23,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     NlpTrainedModelConfig,
 )
 from eland.ml.pytorch.traceable_model import TraceableModel  # noqa: F401
+from eland.ml.pytorch.transformers import task_type_from_model_config
 
 __all__ = [
     "PyTorchModel",
@@ -31,4 +32,5 @@ __all__ = [
     "NlpBertTokenizationConfig",
     "NlpRobertaTokenizationConfig",
     "NlpMPNetTokenizationConfig",
+    "task_type_from_model_config",
 ]
@@ -23,7 +23,7 @@ libraries such as sentence-transformers.
 import json
 import os.path
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 import torch  # type: ignore
 import transformers  # type: ignore
@@ -33,6 +33,7 @@ from transformers import (
     AutoConfig,
     AutoModel,
     AutoModelForQuestionAnswering,
+    PretrainedConfig,
     PreTrainedModel,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
@@ -64,6 +65,15 @@ SUPPORTED_TASK_TYPES = {
     "zero_shot_classification",
     "question_answering",
 }
+ARCHITECTURE_TO_TASK_TYPE = {
+    "MaskedLM": ["fill_mask", "text_embedding"],
+    "TokenClassification": ["ner"],
+    "SequenceClassification": ["text_classification", "zero_shot_classification"],
+    "QuestionAnswering": ["question_answering"],
+    "DPRQuestionEncoder": ["text_embedding"],
+    "DPRContextEncoder": ["text_embedding"],
+}
+ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"}
 TASK_TYPE_TO_INFERENCE_CONFIG = {
     "fill_mask": FillMaskInferenceOptions,
     "ner": NerInferenceOptions,
@@ -97,6 +107,37 @@ TracedModelTypes = Union[
 ]
 
 
+class TaskTypeError(Exception):
+    pass
+
+
+def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]:
+    if model_config.architectures is None:
+        if model_config.name_or_path.startswith("sentence-transformers/"):
+            return "text_embedding"
+        return None
+    potential_task_types: Set[str] = set()
+    for architecture in model_config.architectures:
+        for (substr, task_type) in ARCHITECTURE_TO_TASK_TYPE.items():
+            if substr in architecture:
+                for t in task_type:
+                    potential_task_types.add(t)
+    if len(potential_task_types) == 0:
+        return None
+    if len(potential_task_types) > 1:
+        if "zero_shot_classification" in potential_task_types:
+            if model_config.label2id:
+                labels = set([x.lower() for x in model_config.label2id.keys()])
+                if len(labels.difference(ZERO_SHOT_LABELS)) == 0:
+                    return "zero_shot_classification"
+            return "text_classification"
+        if "text_embedding" in potential_task_types:
+            if model_config.name_or_path.startswith("sentence-transformers/"):
+                return "text_embedding"
+            return "fill_mask"
+    return potential_task_types.pop()
+
+
 class _QuestionAnsweringWrapperModule(nn.Module):  # type: ignore
     """
     A wrapper around a question answering model.
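To make the resolution rules above concrete, here is a hedged usage sketch (the config values are invented for illustration):

from transformers import PretrainedConfig

from eland.ml.pytorch.transformers import task_type_from_model_config

# Unambiguous: "TokenClassification" appears in the architecture name.
ner_config = PretrainedConfig(
    name_or_path="anynermodel", architectures=["BertForTokenClassification"]
)
assert task_type_from_model_config(ner_config) == "ner"

# Ambiguous: "SequenceClassification" maps to two candidate tasks; NLI-style
# labels (contradiction/neutral/entailment) tip it to zero_shot_classification.
nli_config = PretrainedConfig(
    name_or_path="any_bert",
    architectures=["BertForSequenceClassification"],
    label2id={"contradiction": 0, "neutral": 1, "entailment": 2},
)
assert task_type_from_model_config(nli_config) == "zero_shot_classification"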
@@ -581,6 +622,18 @@ class TransformerModel:
         )
 
     def _create_traceable_model(self) -> TraceableModel:
+        if self._task_type == "auto":
+            model = transformers.AutoModel.from_pretrained(
+                self._model_id, torchscript=True
+            )
+            maybe_task_type = task_type_from_model_config(model.config)
+            if maybe_task_type is None:
+                raise TaskTypeError(
+                    f"Unable to automatically determine task type for model {self._model_id}, please supply task type: {SUPPORTED_TASK_TYPES_NAMES}"
+                )
+            else:
+                self._task_type = maybe_task_type
+
         if self._task_type == "fill_mask":
             model = transformers.AutoModelForMaskedLM.from_pretrained(
                 self._model_id, torchscript=True
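When the config is inconclusive, the new TaskTypeError surfaces while the model is loaded and traced, which is why the upload script above wraps both the constructor and save() in try/except. A sketch of what a caller sees, assuming a hypothetical model id and temp directory:

from eland.ml.pytorch.transformers import TaskTypeError, TransformerModel

try:
    # Hypothetical model whose config names no recognized architecture.
    tm = TransformerModel("some-org/unrecognized-model", "auto")
    model_path, config, vocab_path = tm.save("/tmp/model")
except TaskTypeError as err:
    print(f"Could not auto-detect task type: {err}")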
@@ -34,12 +34,14 @@ except ImportError:
 try:
     import torch  # noqa: F401
     from torch import Tensor, nn  # noqa: F401
+    from transformers import PretrainedConfig  # noqa: F401
 
     from eland.ml.pytorch import (  # noqa: F401
         NlpBertTokenizationConfig,
         NlpTrainedModelConfig,
         PyTorchModel,
         TraceableModel,
+        task_type_from_model_config,
     )
     from eland.ml.pytorch.nlp_ml_model import (
         NerInferenceOptions,
@@ -222,6 +224,41 @@ MODELS_TO_TEST = [
     ),
 ]
 
+AUTO_TASK_RESULTS = [
+    ("any_bert", "BERTMaskedLM", None, "fill_mask"),
+    ("any_roberta", "RoBERTaMaskedLM", None, "fill_mask"),
+    ("sentence-transformers/any_bert", "BERTMaskedLM", None, "text_embedding"),
+    ("sentence-transformers/any_roberta", "RoBERTaMaskedLM", None, "text_embedding"),
+    ("sentence-transformers/mpnet", "MPNetMaskedLM", None, "text_embedding"),
+    ("anynermodel", "BERTForTokenClassification", None, "ner"),
+    ("anynermodel", "MPNetForTokenClassification", None, "ner"),
+    ("anynermodel", "RoBERTaForTokenClassification", None, "ner"),
+    ("anynermodel", "BERTForQuestionAnswering", None, "question_answering"),
+    ("anynermodel", "MPNetForQuestionAnswering", None, "question_answering"),
+    ("anynermodel", "RoBERTaForQuestionAnswering", None, "question_answering"),
+    ("aqaModel", "DPRQuestionEncoder", None, "text_embedding"),
+    ("aqaModel", "DPRContextEncoder", None, "text_embedding"),
+    (
+        "any_bert",
+        "BERTForSequenceClassification",
+        ["foo", "bar", "baz"],
+        "text_classification",
+    ),
+    (
+        "any_bert",
+        "BERTForSequenceClassification",
+        ["contradiction", "neutral", "entailment"],
+        "zero_shot_classification",
+    ),
+    (
+        "any_bert",
+        "BERTForSequenceClassification",
+        ["CONTRADICTION", "NEUTRAL", "ENTAILMENT"],
+        "zero_shot_classification",
+    ),
+    ("any_bert", "SomeUnknownType", None, None),
+]
+
 
 @pytest.fixture(scope="function", autouse=True)
 def setup_and_tear_down():
@@ -274,3 +311,22 @@ class TestPytorchModelUpload:
         result = ptm.infer(docs=[{"text_field": input}])
         assert result.get("predicted_value") is not None
         assert result["predicted_value"] == prediction
+
+    @pytest.mark.parametrize(
+        "model_id,architecture,labels,expected_task", AUTO_TASK_RESULTS
+    )
+    def test_auto_task_type(self, model_id, architecture, labels, expected_task):
+        config = (
+            PretrainedConfig(
+                name_or_path=model_id,
+                architectures=[architecture],
+                label2id=dict(zip(labels, range(len(labels)))),
+                id2label=dict(zip(range(len(labels)), labels)),
+            )
+            if labels
+            else PretrainedConfig(
+                name_or_path=model_id,
+                architectures=[architecture],
+            )
+        )
+        assert task_type_from_model_config(model_config=config) == expected_task