Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)
[ML] add support for question_answering NLP tasks (#457)
Adds support for `question_answering` NLP models within the pytorch model uploader. Related: https://github.com/elastic/elasticsearch/pull/85958
This commit is contained in: commit 70fadc9986, parent afe08f8107.
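For context, a minimal sketch of how the new task type is driven from the uploader. Only the `TransformerModel(model_id, task_type, quantize)` constructor appears in this diff; the model id, the `save()` call, and its return values are assumptions about the surrounding uploader flow:

```python
# Hypothetical usage sketch. TransformerModel(model_id, task_type, ...) is
# taken from this diff; save() and its return shape are assumed from the
# surrounding uploader code, and the model id is only an example.
from eland.ml.pytorch.transformers import TransformerModel

tm = TransformerModel(
    "distilbert-base-cased-distilled-squad",  # example SQuAD-tuned model
    "question_answering",
)
model_path, model_config, vocab_path = tm.save("/tmp/qa-model")  # assumed API
```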
eland/ml/pytorch/nlp_ml_model.py

@@ -19,8 +19,22 @@ import typing as t
 
 
 class NlpTokenizationConfig:
-    def __init__(self, *, configuration_type: str):
+    def __init__(
+        self,
+        *,
+        configuration_type: str,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
         self.name = configuration_type
+        self.with_special_tokens = with_special_tokens
+        self.max_sequence_length = max_sequence_length
+        self.truncate = truncate
+        self.span = span
 
     def to_dict(self):
         return {
@@ -42,12 +56,14 @@ class NlpRobertaTokenizationConfig(NlpTokenizationConfig):
         ] = None,
         span: t.Optional[int] = None,
     ):
-        super().__init__(configuration_type="roberta")
+        super().__init__(
+            configuration_type="roberta",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
         self.add_prefix_space = add_prefix_space
-        self.with_special_tokens = with_special_tokens
-        self.max_sequence_length = max_sequence_length
-        self.truncate = truncate
-        self.span = span
 
 
 class NlpBertTokenizationConfig(NlpTokenizationConfig):
@@ -62,12 +78,14 @@ class NlpBertTokenizationConfig(NlpTokenizationConfig):
         ] = None,
         span: t.Optional[int] = None,
     ):
-        super().__init__(configuration_type="bert")
+        super().__init__(
+            configuration_type="bert",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
         self.do_lower_case = do_lower_case
-        self.with_special_tokens = with_special_tokens
-        self.max_sequence_length = max_sequence_length
-        self.truncate = truncate
-        self.span = span
 
 
 class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
@@ -82,12 +100,14 @@ class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
         ] = None,
         span: t.Optional[int] = None,
     ):
-        super().__init__(configuration_type="mpnet")
+        super().__init__(
+            configuration_type="mpnet",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
         self.do_lower_case = do_lower_case
-        self.with_special_tokens = with_special_tokens
-        self.max_sequence_length = max_sequence_length
-        self.truncate = truncate
-        self.span = span
 
 
 class InferenceConfig:
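With the shared options hoisted into `NlpTokenizationConfig`, each subclass now keeps only its model-specific flag and forwards the rest to the base class. A small sketch of the resulting API (parameter names are taken from the signatures above; nothing beyond the visible constructors is assumed):

```python
# The common options now live on the base class; subclasses contribute only
# add_prefix_space (RoBERTa) or do_lower_case (BERT, MPNet).
from eland.ml.pytorch.nlp_ml_model import (
    NlpBertTokenizationConfig,
    NlpRobertaTokenizationConfig,
)

bert = NlpBertTokenizationConfig(
    do_lower_case=True,
    max_sequence_length=386,
    truncate="none",
    span=128,
)
roberta = NlpRobertaTokenizationConfig(add_prefix_space=True)

# configuration_type is recorded as .name by the base constructor
assert bert.name == "bert" and roberta.name == "roberta"
```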
@@ -180,6 +200,24 @@ class PassThroughInferenceOptions(InferenceConfig):
         self.results_field = results_field
 
 
+class QuestionAnsweringInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+        max_answer_length: t.Optional[int] = None,
+        question: t.Optional[str] = None,
+        num_top_classes: t.Optional[int] = None,
+    ):
+        super().__init__(configuration_type="question_answering")
+        self.tokenization = tokenization
+        self.results_field = results_field
+        self.max_answer_length = max_answer_length
+        self.question = question
+        self.num_top_classes = num_top_classes
+
+
 class TextEmbeddingInferenceOptions(InferenceConfig):
     def __init__(
         self,
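A sketch of constructing the new options object; everything here follows directly from the constructor above, and `NlpBertTokenizationConfig` is just one valid choice of tokenization:

```python
from eland.ml.pytorch.nlp_ml_model import (
    NlpBertTokenizationConfig,
    QuestionAnsweringInferenceOptions,
)

# All parameters are keyword-only, mirroring the constructor in this hunk.
qa_options = QuestionAnsweringInferenceOptions(
    tokenization=NlpBertTokenizationConfig(
        max_sequence_length=386,
        span=128,
        truncate="none",
    ),
    max_answer_length=15,  # cap on the length of the extracted answer
    num_top_classes=3,     # how many candidate answers to return
)
```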
eland/ml/pytorch/transformers.py

@@ -32,6 +32,7 @@ from torch import Tensor, nn
 from transformers import (
     AutoConfig,
     AutoModel,
+    AutoModelForQuestionAnswering,
     PreTrainedModel,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
@@ -46,6 +47,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     NlpTokenizationConfig,
     NlpTrainedModelConfig,
     PassThroughInferenceOptions,
+    QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
     TextEmbeddingInferenceOptions,
     TrainedModelInput,
@@ -59,6 +61,7 @@ SUPPORTED_TASK_TYPES = {
     "text_classification",
     "text_embedding",
     "zero_shot_classification",
+    "question_answering",
 }
 TASK_TYPE_TO_INFERENCE_CONFIG = {
     "fill_mask": FillMaskInferenceOptions,
@@ -67,6 +70,7 @@ TASK_TYPE_TO_INFERENCE_CONFIG = {
     "text_embedding": TextEmbeddingInferenceOptions,
     "zero_shot_classification": ZeroShotClassificationInferenceOptions,
     "pass_through": PassThroughInferenceOptions,
+    "question_answering": QuestionAnsweringInferenceOptions,
 }
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
@@ -92,6 +96,86 @@ TracedModelTypes = Union[
 ]
 
 
+class _QuestionAnsweringWrapperModule(nn.Module):  # type: ignore
+    """
+    A wrapper around a question answering model.
+    Our inference engine only takes the first element if the inference
+    response is a tuple.
+
+    This wrapper transforms the output to be a stacked tensor if it is a
+    tuple.
+
+    Otherwise it passes it through.
+    """
+
+    def __init__(self, model: PreTrainedModel):
+        super().__init__()
+        self._hf_model = model
+        self.config = model.config
+
+    @staticmethod
+    def from_pretrained(model_id: str) -> Optional[Any]:
+        model = AutoModelForQuestionAnswering.from_pretrained(
+            model_id, torchscript=True
+        )
+        if isinstance(
+            model.config,
+            (
+                transformers.MPNetConfig,
+                transformers.RobertaConfig,
+                transformers.BartConfig,
+            ),
+        ):
+            return _TwoParameterQuestionAnsweringWrapper(model)
+        else:
+            return _QuestionAnsweringWrapper(model)
+
+
+class _QuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
+    def __init__(self, model: PreTrainedModel):
+        super().__init__(model=model)
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        attention_mask: Tensor,
+        token_type_ids: Tensor,
+        position_ids: Tensor,
+    ) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": token_type_ids,
+            "position_ids": position_ids,
+        }
+
+        # remove inputs for specific model types
+        if isinstance(self._hf_model.config, transformers.DistilBertConfig):
+            del inputs["token_type_ids"]
+            del inputs["position_ids"]
+        response = self._hf_model(**inputs)
+        if isinstance(response, tuple):
+            return torch.stack(list(response), dim=0)
+        return response
+
+
+class _TwoParameterQuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
+    def __init__(self, model: PreTrainedModel):
+        super().__init__(model=model)
+
+    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        response = self._hf_model(**inputs)
+        if isinstance(response, tuple):
+            return torch.stack(list(response), dim=0)
+        return response
+
+
 class _DistilBertWrapper(nn.Module):  # type: ignore
     """
     A simple wrapper around DistilBERT model which makes the model inputs
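The stacking is the crux of the wrappers: a QA head loaded with `torchscript=True` returns a tuple of `(start_logits, end_logits)`, and the native process only reads the first element of a tuple response, so the end logits would otherwise be dropped. A standalone illustration of the transform, independent of the classes above:

```python
import torch

batch, seq_len = 1, 386
start_logits = torch.randn(batch, seq_len)
end_logits = torch.randn(batch, seq_len)

# What forward() does with a tuple response: one (2, batch, seq_len) tensor,
# so both logit sets survive the single-tensor interface.
stacked = torch.stack([start_logits, end_logits], dim=0)
assert stacked.shape == (2, batch, seq_len)
```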
@@ -404,6 +488,16 @@ class _TraceableZeroShotClassificationModel(_TraceableClassificationModel):
         )
 
 
+class _TraceableQuestionAnsweringModel(_TraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "What is the meaning of life?"
+            "The meaning of life, according to the hitchikers guide, is 42.",
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+
 class TransformerModel:
     def __init__(self, model_id: str, task_type: str, quantize: bool = False):
         self._model_id = model_id
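`_prepare_inputs` supplies the representative batch the tracer runs the model against; note that the question and context are adjacent string literals, so Python concatenates them into a single sequence before tokenization. A hedged sketch of the tracing step itself, which happens elsewhere in `_TraceableModel` and is not part of this diff:

```python
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Assumed flow: tokenize one representative sample (mirroring
# _prepare_inputs above), then trace the model against those tensors.
model_id = "distilbert-base-cased-distilled-squad"  # example model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForQuestionAnswering.from_pretrained(model_id, torchscript=True)
model.eval()

inputs = tokenizer(
    "What is the meaning of life?"
    "The meaning of life, according to the hitchikers guide, is 42.",
    padding="max_length",
    return_tensors="pt",
)
# DistilBERT takes only input_ids and attention_mask
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))
```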
@@ -472,6 +566,11 @@ class TransformerModel:
     def _create_config(self) -> NlpTrainedModelConfig:
         tokenization_config = self._create_tokenization_config()
 
+        # Set SQuAD well-known defaults
+        if self._task_type == "question_answering":
+            tokenization_config.max_sequence_length = 386
+            tokenization_config.span = 128
+            tokenization_config.truncate = "none"
         inference_config = (
             TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
                 tokenization=tokenization_config,
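These are the standard SQuAD settings: 386-token windows with a 128-token overlap (`span`, analogous to Hugging Face's `doc_stride`), and `truncate="none"` so long contexts are split into overlapping windows rather than cut off. For intuition, the equivalent windowing expressed against the Hugging Face tokenizer API (a sketch only; Elasticsearch performs this splitting itself from `max_sequence_length` and `span`):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")

# Long contexts become several overlapping 386-token windows instead of
# being truncated; stride is the overlap between consecutive windows.
windows = tokenizer(
    "What is the meaning of life?",
    "A very long context ... " * 200,
    truncation="only_second",  # never truncate the question
    max_length=386,
    stride=128,
    return_overflowing_tokens=True,
)
print(len(windows["input_ids"]))  # number of windows produced
```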
@@ -530,7 +629,9 @@ class TransformerModel:
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableZeroShotClassificationModel(self._tokenizer, model)
-
+        elif self._task_type == "question_answering":
+            model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
+            return _TraceableQuestionAnsweringModel(self._tokenizer, model)
         else:
             raise TypeError(
                 f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"