Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)
[ML] adds new auto task type that attempts to automatically determine NLP task type from model config (#475)
For many model types we don't need to require the requested task: the task type can be inferred from the model configuration and architecture.
This commit makes the `task-type` parameter optional for the model upload script and adds logic for auto-detecting the task type from the 🤗 model configuration.
This commit is contained in:
parent 8448b3ba4e · commit 8892f4fd64
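The detection itself is done by the new task_type_from_model_config helper shown in the diff below. A minimal sketch of its behavior, mirroring the test cases this commit adds (the model names are placeholder strings, exactly as in the tests):

    from transformers import PretrainedConfig

    from eland.ml.pytorch import task_type_from_model_config

    # A token-classification architecture maps unambiguously to "ner".
    ner_config = PretrainedConfig(
        name_or_path="anynermodel",
        architectures=["BERTForTokenClassification"],
    )
    assert task_type_from_model_config(model_config=ner_config) == "ner"

    # SequenceClassification is ambiguous between text_classification and
    # zero_shot_classification; NLI-style labels resolve it to the latter.
    labels = ["contradiction", "neutral", "entailment"]
    nli_config = PretrainedConfig(
        name_or_path="any_bert",
        architectures=["BERTForSequenceClassification"],
        label2id=dict(zip(labels, range(len(labels)))),
    )
    assert task_type_from_model_config(model_config=nli_config) == "zero_shot_classification"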
@@ -84,9 +84,11 @@ def get_arg_parser():
     )
     parser.add_argument(
         "--task-type",
-        required=True,
+        required=False,
         choices=SUPPORTED_TASK_TYPES,
-        help="The task type for the model usage.",
+        help="The task type for the model usage. Will attempt to auto-detect task type for the model if not provided. "
+        "Default: auto",
+        default="auto",
     )
     parser.add_argument(
         "--quantize",
@@ -165,7 +167,11 @@ if __name__ == "__main__":

     try:
         from eland.ml.pytorch import PyTorchModel
-        from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES, TransformerModel
+        from eland.ml.pytorch.transformers import (
+            SUPPORTED_TASK_TYPES,
+            TaskTypeError,
+            TransformerModel,
+        )
     except ModuleNotFoundError as e:
         logger.error(textwrap.dedent(f"""\
             \033[31mFailed to run because module '{e.name}' is not available.\033[0m
@@ -187,8 +193,12 @@ if __name__ == "__main__":
     with tempfile.TemporaryDirectory() as tmp_dir:
         logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'")

-        tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
-        model_path, config, vocab_path = tm.save(tmp_dir)
+        try:
+            tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
+            model_path, config, vocab_path = tm.save(tmp_dir)
+        except TaskTypeError as err:
+            logger.error(f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}")
+            exit(1)

         ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id())
         model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200
@@ -23,6 +23,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     NlpTrainedModelConfig,
 )
 from eland.ml.pytorch.traceable_model import TraceableModel  # noqa: F401
+from eland.ml.pytorch.transformers import task_type_from_model_config

 __all__ = [
     "PyTorchModel",
@@ -31,4 +32,5 @@ __all__ = [
     "NlpBertTokenizationConfig",
     "NlpRobertaTokenizationConfig",
     "NlpMPNetTokenizationConfig",
+    "task_type_from_model_config",
 ]
@@ -23,7 +23,7 @@ libraries such as sentence-transformers.
 import json
 import os.path
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union

 import torch  # type: ignore
 import transformers  # type: ignore
@@ -33,6 +33,7 @@ from transformers import (
     AutoConfig,
     AutoModel,
     AutoModelForQuestionAnswering,
+    PretrainedConfig,
     PreTrainedModel,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
@@ -64,6 +65,15 @@ SUPPORTED_TASK_TYPES = {
     "zero_shot_classification",
     "question_answering",
 }
+ARCHITECTURE_TO_TASK_TYPE = {
+    "MaskedLM": ["fill_mask", "text_embedding"],
+    "TokenClassification": ["ner"],
+    "SequenceClassification": ["text_classification", "zero_shot_classification"],
+    "QuestionAnswering": ["question_answering"],
+    "DPRQuestionEncoder": ["text_embedding"],
+    "DPRContextEncoder": ["text_embedding"],
+}
+ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"}
 TASK_TYPE_TO_INFERENCE_CONFIG = {
     "fill_mask": FillMaskInferenceOptions,
     "ner": NerInferenceOptions,
@@ -97,6 +107,37 @@ TracedModelTypes = Union[
 ]


+class TaskTypeError(Exception):
+    pass
+
+
+def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]:
+    if model_config.architectures is None:
+        if model_config.name_or_path.startswith("sentence-transformers/"):
+            return "text_embedding"
+        return None
+    potential_task_types: Set[str] = set()
+    for architecture in model_config.architectures:
+        for (substr, task_type) in ARCHITECTURE_TO_TASK_TYPE.items():
+            if substr in architecture:
+                for t in task_type:
+                    potential_task_types.add(t)
+    if len(potential_task_types) == 0:
+        return None
+    if len(potential_task_types) > 1:
+        if "zero_shot_classification" in potential_task_types:
+            if model_config.label2id:
+                labels = set([x.lower() for x in model_config.label2id.keys()])
+                if len(labels.difference(ZERO_SHOT_LABELS)) == 0:
+                    return "zero_shot_classification"
+            return "text_classification"
+        if "text_embedding" in potential_task_types:
+            if model_config.name_or_path.startswith("sentence-transformers/"):
+                return "text_embedding"
+            return "fill_mask"
+    return potential_task_types.pop()
+
+
 class _QuestionAnsweringWrapperModule(nn.Module):  # type: ignore
     """
     A wrapper around a question answering model.
@@ -581,6 +622,18 @@ class TransformerModel:
         )

     def _create_traceable_model(self) -> TraceableModel:
+        if self._task_type == "auto":
+            model = transformers.AutoModel.from_pretrained(
+                self._model_id, torchscript=True
+            )
+            maybe_task_type = task_type_from_model_config(model.config)
+            if maybe_task_type is None:
+                raise TaskTypeError(
+                    f"Unable to automatically determine task type for model {self._model_id}, please supply task type: {SUPPORTED_TASK_TYPES_NAMES}"
+                )
+            else:
+                self._task_type = maybe_task_type
+
         if self._task_type == "fill_mask":
             model = transformers.AutoModelForMaskedLM.from_pretrained(
                 self._model_id, torchscript=True
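End to end, the upload script exercises the new branch above whenever the task type is left at its "auto" default. A hedged sketch of the same flow outside the script (the model id is only an example; the try/except mirrors the script change earlier in this diff):

    import tempfile

    from eland.ml.pytorch.transformers import TaskTypeError, TransformerModel

    try:
        # "auto" makes _create_traceable_model consult task_type_from_model_config
        tm = TransformerModel("distilbert-base-uncased", "auto", False)
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path, config, vocab_path = tm.save(tmp_dir)
    except TaskTypeError as err:
        # Raised when no task type can be inferred from the model architectures
        print(f"Supply an explicit task type: {err}")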
@@ -34,12 +34,14 @@ except ImportError:
 try:
     import torch  # noqa: F401
     from torch import Tensor, nn  # noqa: F401
+    from transformers import PretrainedConfig  # noqa: F401

     from eland.ml.pytorch import (  # noqa: F401
         NlpBertTokenizationConfig,
         NlpTrainedModelConfig,
         PyTorchModel,
         TraceableModel,
+        task_type_from_model_config,
     )
     from eland.ml.pytorch.nlp_ml_model import (
         NerInferenceOptions,
@@ -222,6 +224,41 @@ MODELS_TO_TEST = [
     ),
 ]

+AUTO_TASK_RESULTS = [
+    ("any_bert", "BERTMaskedLM", None, "fill_mask"),
+    ("any_roberta", "RoBERTaMaskedLM", None, "fill_mask"),
+    ("sentence-transformers/any_bert", "BERTMaskedLM", None, "text_embedding"),
+    ("sentence-transformers/any_roberta", "RoBERTaMaskedLM", None, "text_embedding"),
+    ("sentence-transformers/mpnet", "MPNetMaskedLM", None, "text_embedding"),
+    ("anynermodel", "BERTForTokenClassification", None, "ner"),
+    ("anynermodel", "MPNetForTokenClassification", None, "ner"),
+    ("anynermodel", "RoBERTaForTokenClassification", None, "ner"),
+    ("anynermodel", "BERTForQuestionAnswering", None, "question_answering"),
+    ("anynermodel", "MPNetForQuestionAnswering", None, "question_answering"),
+    ("anynermodel", "RoBERTaForQuestionAnswering", None, "question_answering"),
+    ("aqaModel", "DPRQuestionEncoder", None, "text_embedding"),
+    ("aqaModel", "DPRContextEncoder", None, "text_embedding"),
+    (
+        "any_bert",
+        "BERTForSequenceClassification",
+        ["foo", "bar", "baz"],
+        "text_classification",
+    ),
+    (
+        "any_bert",
+        "BERTForSequenceClassification",
+        ["contradiction", "neutral", "entailment"],
+        "zero_shot_classification",
+    ),
+    (
+        "any_bert",
+        "BERTForSequenceClassification",
+        ["CONTRADICTION", "NEUTRAL", "ENTAILMENT"],
+        "zero_shot_classification",
+    ),
+    ("any_bert", "SomeUnknownType", None, None),
+]
+

 @pytest.fixture(scope="function", autouse=True)
 def setup_and_tear_down():
@@ -274,3 +311,22 @@ class TestPytorchModelUpload:
         result = ptm.infer(docs=[{"text_field": input}])
         assert result.get("predicted_value") is not None
         assert result["predicted_value"] == prediction
+
+    @pytest.mark.parametrize(
+        "model_id,architecture,labels,expected_task", AUTO_TASK_RESULTS
+    )
+    def test_auto_task_type(self, model_id, architecture, labels, expected_task):
+        config = (
+            PretrainedConfig(
+                name_or_path=model_id,
+                architectures=[architecture],
+                label2id=dict(zip(labels, range(len(labels)))),
+                id2label=dict(zip(range(len(labels)), labels)),
+            )
+            if labels
+            else PretrainedConfig(
+                name_or_path=model_id,
+                architectures=[architecture],
+            )
+        )
+        assert task_type_from_model_config(model_config=config) == expected_task