[ML] add support for question_answering NLP tasks (#457)
Adds support for `question_answering` NLP models within the pytorch model uploader. Related: https://github.com/elastic/elasticsearch/pull/85958
parent afe08f8107
commit 70fadc9986
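For orientation, here is a minimal sketch of how a question-answering model could be pulled through the uploader once this lands. Only the `TransformerModel` constructor shown in the diff below comes from this commit; the hub model ID is illustrative and treating the constructor call as the whole workflow is an assumption (saving and uploading follow eland's existing PyTorch uploader flow).

```python
# Sketch only: the hub model ID is an illustrative assumption; the
# (model_id, task_type, quantize) signature matches the diff below.
from eland.ml.pytorch.transformers import TransformerModel

tm = TransformerModel(
    "distilbert-base-cased-distilled-squad",  # illustrative QA model from the HF hub
    "question_answering",                     # task type added by this commit
    quantize=False,
)
```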
Changes to `eland/ml/pytorch/nlp_ml_model.py` (the module path is visible in the import hunk further down): the base `NlpTokenizationConfig` gains the shared tokenization options.

```diff
@@ -19,8 +19,22 @@ import typing as t
 class NlpTokenizationConfig:
-    def __init__(self, *, configuration_type: str):
+    def __init__(
+        self,
+        *,
+        configuration_type: str,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
         self.name = configuration_type
+        self.with_special_tokens = with_special_tokens
+        self.max_sequence_length = max_sequence_length
+        self.truncate = truncate
+        self.span = span
 
     def to_dict(self):
         return {
```
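With the shared options now on the base class, a subclass such as `NlpBertTokenizationConfig` (introduced just below) can be configured in one call. A hedged sketch; the keyword names come straight from the signature above, the values are illustrative:

```python
# All options are keyword-only, per the signature above.
cfg = NlpBertTokenizationConfig(
    max_sequence_length=386,  # window size; also the SQuAD default set later in this commit
    truncate="none",          # one of 'first' | 'second' | 'none'
    span=128,                 # stride between overlapping windows when spanning
)
print(cfg.to_dict())
```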
`NlpRobertaTokenizationConfig` now forwards the options to the base class instead of assigning them itself:

```diff
@@ -42,12 +56,14 @@ class NlpRobertaTokenizationConfig(NlpTokenizationConfig):
         ] = None,
         span: t.Optional[int] = None,
     ):
-        super().__init__(configuration_type="roberta")
+        super().__init__(
+            configuration_type="roberta",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
         self.add_prefix_space = add_prefix_space
-        self.with_special_tokens = with_special_tokens
-        self.max_sequence_length = max_sequence_length
-        self.truncate = truncate
-        self.span = span
 
 
 class NlpBertTokenizationConfig(NlpTokenizationConfig):
```
The same change for `NlpBertTokenizationConfig`:

```diff
@@ -62,12 +78,14 @@ class NlpBertTokenizationConfig(NlpTokenizationConfig):
         ] = None,
         span: t.Optional[int] = None,
     ):
-        super().__init__(configuration_type="bert")
+        super().__init__(
+            configuration_type="bert",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
         self.do_lower_case = do_lower_case
-        self.with_special_tokens = with_special_tokens
-        self.max_sequence_length = max_sequence_length
-        self.truncate = truncate
-        self.span = span
 
 
 class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
```
...and for `NlpMPNetTokenizationConfig`:

```diff
@@ -82,12 +100,14 @@ class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
         ] = None,
         span: t.Optional[int] = None,
     ):
-        super().__init__(configuration_type="mpnet")
+        super().__init__(
+            configuration_type="mpnet",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
         self.do_lower_case = do_lower_case
-        self.with_special_tokens = with_special_tokens
-        self.max_sequence_length = max_sequence_length
-        self.truncate = truncate
-        self.span = span
 
 
 class InferenceConfig:
```
A new `QuestionAnsweringInferenceOptions` class joins the other inference configurations:

```diff
@@ -180,6 +200,24 @@ class PassThroughInferenceOptions(InferenceConfig):
         self.results_field = results_field
 
 
+class QuestionAnsweringInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+        max_answer_length: t.Optional[int] = None,
+        question: t.Optional[str] = None,
+        num_top_classes: t.Optional[int] = None,
+    ):
+        super().__init__(configuration_type="question_answering")
+        self.tokenization = tokenization
+        self.results_field = results_field
+        self.max_answer_length = max_answer_length
+        self.question = question
+        self.num_top_classes = num_top_classes
+
+
 class TextEmbeddingInferenceOptions(InferenceConfig):
     def __init__(
         self,
```
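A short sketch of constructing the new options object; all parameter names are taken from the `__init__` just above, and the values are illustrative:

```python
qa_options = QuestionAnsweringInferenceOptions(
    tokenization=NlpBertTokenizationConfig(
        max_sequence_length=386, truncate="none", span=128
    ),
    max_answer_length=30,  # illustrative cap on the answer span length
    num_top_classes=3,     # illustrative: return the top three candidate answers
)
```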
Changes to the PyTorch uploader module (`eland/ml/pytorch/transformers.py`, inferred from the surrounding context): import the question-answering auto-model class.

```diff
@@ -32,6 +32,7 @@ from torch import Tensor, nn
 from transformers import (
     AutoConfig,
     AutoModel,
+    AutoModelForQuestionAnswering,
     PreTrainedModel,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
```
...and the new inference options:

```diff
@@ -46,6 +47,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     NlpTokenizationConfig,
     NlpTrainedModelConfig,
     PassThroughInferenceOptions,
+    QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
     TextEmbeddingInferenceOptions,
     TrainedModelInput,
```
Register the task type:

```diff
@@ -59,6 +61,7 @@ SUPPORTED_TASK_TYPES = {
     "text_classification",
     "text_embedding",
     "zero_shot_classification",
+    "question_answering",
 }
 TASK_TYPE_TO_INFERENCE_CONFIG = {
     "fill_mask": FillMaskInferenceOptions,
```
...and map it to its inference options:

```diff
@@ -67,6 +70,7 @@ TASK_TYPE_TO_INFERENCE_CONFIG = {
     "text_embedding": TextEmbeddingInferenceOptions,
     "zero_shot_classification": ZeroShotClassificationInferenceOptions,
     "pass_through": PassThroughInferenceOptions,
+    "question_answering": QuestionAnsweringInferenceOptions,
 }
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
```
New wrapper modules adapt Hugging Face question-answering models to the native process interface; models whose tokenizers emit only `input_ids` and `attention_mask` get the two-parameter variant:

```diff
@@ -92,6 +96,86 @@ TracedModelTypes = Union[
 ]
 
 
+class _QuestionAnsweringWrapperModule(nn.Module):  # type: ignore
+    """
+    A wrapper around a question answering model.
+    Our inference engine only takes the first tuple if the inference response
+    is a tuple.
+
+    This wrapper transforms the output into a stacked tensor if it is a tuple.
+
+    Otherwise it passes it through unchanged.
+    """
+
+    def __init__(self, model: PreTrainedModel):
+        super().__init__()
+        self._hf_model = model
+        self.config = model.config
+
+    @staticmethod
+    def from_pretrained(model_id: str) -> Optional[Any]:
+        model = AutoModelForQuestionAnswering.from_pretrained(
+            model_id, torchscript=True
+        )
+        if isinstance(
+            model.config,
+            (
+                transformers.MPNetConfig,
+                transformers.RobertaConfig,
+                transformers.BartConfig,
+            ),
+        ):
+            return _TwoParameterQuestionAnsweringWrapper(model)
+        else:
+            return _QuestionAnsweringWrapper(model)
+
+
+class _QuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
+    def __init__(self, model: PreTrainedModel):
+        super().__init__(model=model)
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        attention_mask: Tensor,
+        token_type_ids: Tensor,
+        position_ids: Tensor,
+    ) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": token_type_ids,
+            "position_ids": position_ids,
+        }
+
+        # remove inputs for specific model types
+        if isinstance(self._hf_model.config, transformers.DistilBertConfig):
+            del inputs["token_type_ids"]
+            del inputs["position_ids"]
+        response = self._hf_model(**inputs)
+        if isinstance(response, tuple):
+            return torch.stack(list(response), dim=0)
+        return response
+
+
+class _TwoParameterQuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
+    def __init__(self, model: PreTrainedModel):
+        super().__init__(model=model)
+
+    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        response = self._hf_model(**inputs)
+        if isinstance(response, tuple):
+            return torch.stack(list(response), dim=0)
+        return response
+
+
 class _DistilBertWrapper(nn.Module):  # type: ignore
     """
     A simple wrapper around DistilBERT model which makes the model inputs
```
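The reason both wrappers stack tuple outputs: a TorchScript-exported Hugging Face QA model returns a `(start_logits, end_logits)` tuple, while the native inference process reads a single tensor. A toy illustration of the transform performed by `forward` above:

```python
import torch

# Stand-ins for the (start_logits, end_logits) pair a QA head returns.
start_logits = torch.randn(1, 386)
end_logits = torch.randn(1, 386)

stacked = torch.stack([start_logits, end_logits], dim=0)
assert stacked.shape == (2, 1, 386)  # one tensor instead of a tuple
```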
A traceable wrapper supplies example inputs for TorchScript tracing:

```diff
@@ -404,6 +488,16 @@ class _TraceableZeroShotClassificationModel(_TraceableClassificationModel):
         )
 
 
+class _TraceableQuestionAnsweringModel(_TraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "What is the meaning of life?"
+            "The meaning of life, according to the hitchikers guide, is 42.",
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+
 class TransformerModel:
     def __init__(self, model_id: str, task_type: str, quantize: bool = False):
         self._model_id = model_id
```
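Worth noting about `_prepare_inputs` above: the two adjacent string literals concatenate into a single argument, so the tokenizer traces on one combined question-plus-context string rather than a `(question, context)` pair:

```python
# Python's implicit string-literal concatenation joins the two lines:
sample = (
    "What is the meaning of life?"
    "The meaning of life, according to the hitchikers guide, is 42."
)
assert sample.count("meaning of life") == 2  # one string, both halves present
```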
`_create_config` applies SQuAD-style tokenization defaults for the new task:

```diff
@@ -472,6 +566,11 @@ class TransformerModel:
     def _create_config(self) -> NlpTrainedModelConfig:
         tokenization_config = self._create_tokenization_config()
+
+        # Set SQuAD well-known defaults
+        if self._task_type == "question_answering":
+            tokenization_config.max_sequence_length = 386
+            tokenization_config.span = 128
+            tokenization_config.truncate = "none"
         inference_config = (
             TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
                 tokenization=tokenization_config,
```
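The three overrides mirror the conventional SQuAD windowing setup: a 386-token window, a 128-token stride between overlapping windows, and no truncation so that long contexts are spanned rather than cut. Written out directly as a sketch:

```python
# Equivalent effect of the branch above; the config object is whatever
# _create_tokenization_config() returned for the model's tokenizer.
tokenization_config.max_sequence_length = 386  # SQuAD-style window
tokenization_config.span = 128                 # overlap stride between windows
tokenization_config.truncate = "none"          # span long contexts instead of cutting them
```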
Finally, task dispatch in model loading handles the new type:

```diff
@@ -530,7 +629,9 @@ class TransformerModel:
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableZeroShotClassificationModel(self._tokenizer, model)
+        elif self._task_type == "question_answering":
+            model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
+            return _TraceableQuestionAnsweringModel(self._tokenizer, model)
         else:
             raise TypeError(
                 f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
```