[ML] add support for question_answering NLP tasks (#457)

Adds support for `question_answering` NLP models within the PyTorch model uploader.

Related: https://github.com/elastic/elasticsearch/pull/85958
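
For context, this lets an extractive QA model be uploaded with the same flow as the other task types by passing `question_answering` as the task type. A minimal sketch, not part of this diff: the cluster URL and `deepset/roberta-base-squad2` are placeholders, and the `save()` / `PyTorchModel.import_model()` calls follow the pattern in the eland docs, so their exact signatures may differ between versions.

    import tempfile

    from elasticsearch import Elasticsearch
    from eland.ml.pytorch import PyTorchModel
    from eland.ml.pytorch.transformers import TransformerModel

    es = Elasticsearch("http://localhost:9200")  # placeholder cluster URL

    # TransformerModel(model_id, task_type, quantize) matches the constructor shown below.
    tm = TransformerModel("deepset/roberta-base-squad2", "question_answering", quantize=False)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Trace the model locally, then push it to the cluster under a sanitized model id.
        model_path, config, vocab_path = tm.save(tmp_dir)
        ptm = PyTorchModel(es, tm.elasticsearch_model_id())
        ptm.import_model(
            model_path=model_path, config_path=None, vocab_path=vocab_path, config=config
        )

Once uploaded, the model is started and evaluated through the Elasticsearch ML trained models APIs (see the linked Elasticsearch PR for the server-side task).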
commit 70fadc9986 (parent afe08f8107)
Benjamin Trent authored on 2022-05-04 13:15:33 -04:00; committed by GitHub
2 changed files with 156 additions and 17 deletions

eland/ml/pytorch/nlp_ml_model.py

@@ -19,8 +19,22 @@ import typing as t
 class NlpTokenizationConfig:
-    def __init__(self, *, configuration_type: str):
+    def __init__(
+        self,
+        *,
+        configuration_type: str,
+        with_special_tokens: t.Optional[bool] = None,
+        max_sequence_length: t.Optional[int] = None,
+        truncate: t.Optional[
+            t.Union["t.Literal['first', 'none', 'second']", str]
+        ] = None,
+        span: t.Optional[int] = None,
+    ):
         self.name = configuration_type
+        self.with_special_tokens = with_special_tokens
+        self.max_sequence_length = max_sequence_length
+        self.truncate = truncate
+        self.span = span

     def to_dict(self):
         return {
@@ -42,12 +56,14 @@ class NlpRobertaTokenizationConfig(NlpTokenizationConfig):
         ] = None,
         span: t.Optional[int] = None,
     ):
-        super().__init__(configuration_type="roberta")
+        super().__init__(
+            configuration_type="roberta",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
         self.add_prefix_space = add_prefix_space
-        self.with_special_tokens = with_special_tokens
-        self.max_sequence_length = max_sequence_length
-        self.truncate = truncate
-        self.span = span

 class NlpBertTokenizationConfig(NlpTokenizationConfig):
@@ -62,12 +78,14 @@ class NlpBertTokenizationConfig(NlpTokenizationConfig):
         ] = None,
         span: t.Optional[int] = None,
     ):
-        super().__init__(configuration_type="bert")
+        super().__init__(
+            configuration_type="bert",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
         self.do_lower_case = do_lower_case
-        self.with_special_tokens = with_special_tokens
-        self.max_sequence_length = max_sequence_length
-        self.truncate = truncate
-        self.span = span

 class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
@@ -82,12 +100,14 @@ class NlpMPNetTokenizationConfig(NlpTokenizationConfig):
         ] = None,
         span: t.Optional[int] = None,
     ):
-        super().__init__(configuration_type="mpnet")
+        super().__init__(
+            configuration_type="mpnet",
+            with_special_tokens=with_special_tokens,
+            max_sequence_length=max_sequence_length,
+            truncate=truncate,
+            span=span,
+        )
         self.do_lower_case = do_lower_case
-        self.with_special_tokens = with_special_tokens
-        self.max_sequence_length = max_sequence_length
-        self.truncate = truncate
-        self.span = span

 class InferenceConfig:
@@ -180,6 +200,24 @@ class PassThroughInferenceOptions(InferenceConfig):
         self.results_field = results_field

+class QuestionAnsweringInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+        max_answer_length: t.Optional[int] = None,
+        question: t.Optional[str] = None,
+        num_top_classes: t.Optional[int] = None,
+    ):
+        super().__init__(configuration_type="question_answering")
+        self.tokenization = tokenization
+        self.results_field = results_field
+        self.max_answer_length = max_answer_length
+        self.question = question
+        self.num_top_classes = num_top_classes

 class TextEmbeddingInferenceOptions(InferenceConfig):
     def __init__(
         self,

eland/ml/pytorch/transformers.py

@@ -32,6 +32,7 @@ from torch import Tensor, nn
 from transformers import (
     AutoConfig,
     AutoModel,
+    AutoModelForQuestionAnswering,
     PreTrainedModel,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
@@ -46,6 +47,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     NlpTokenizationConfig,
     NlpTrainedModelConfig,
     PassThroughInferenceOptions,
+    QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
     TextEmbeddingInferenceOptions,
     TrainedModelInput,
@@ -59,6 +61,7 @@ SUPPORTED_TASK_TYPES = {
     "text_classification",
     "text_embedding",
     "zero_shot_classification",
+    "question_answering",
 }
 TASK_TYPE_TO_INFERENCE_CONFIG = {
     "fill_mask": FillMaskInferenceOptions,
@@ -67,6 +70,7 @@ TASK_TYPE_TO_INFERENCE_CONFIG = {
     "text_embedding": TextEmbeddingInferenceOptions,
     "zero_shot_classification": ZeroShotClassificationInferenceOptions,
     "pass_through": PassThroughInferenceOptions,
+    "question_answering": QuestionAnsweringInferenceOptions,
 }
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
@@ -92,6 +96,86 @@ TracedModelTypes = Union[
 ]

+class _QuestionAnsweringWrapperModule(nn.Module):  # type: ignore
+    """
+    A wrapper around a question answering model.
+    Our inference engine only takes the first element if the inference response
+    is a tuple.
+
+    This wrapper transforms the output into a stacked tensor if it is a tuple;
+    otherwise it passes it through unchanged.
+    """
+
+    def __init__(self, model: PreTrainedModel):
+        super().__init__()
+        self._hf_model = model
+        self.config = model.config
+
+    @staticmethod
+    def from_pretrained(model_id: str) -> Optional[Any]:
+        model = AutoModelForQuestionAnswering.from_pretrained(
+            model_id, torchscript=True
+        )
+        if isinstance(
+            model.config,
+            (
+                transformers.MPNetConfig,
+                transformers.RobertaConfig,
+                transformers.BartConfig,
+            ),
+        ):
+            return _TwoParameterQuestionAnsweringWrapper(model)
+        else:
+            return _QuestionAnsweringWrapper(model)
+
+class _QuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
+    def __init__(self, model: PreTrainedModel):
+        super().__init__(model=model)
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        attention_mask: Tensor,
+        token_type_ids: Tensor,
+        position_ids: Tensor,
+    ) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": token_type_ids,
+            "position_ids": position_ids,
+        }
+        # remove inputs for specific model types
+        if isinstance(self._hf_model.config, transformers.DistilBertConfig):
+            del inputs["token_type_ids"]
+            del inputs["position_ids"]
+        response = self._hf_model(**inputs)
+        if isinstance(response, tuple):
+            return torch.stack(list(response), dim=0)
+        return response
+
+class _TwoParameterQuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
+    def __init__(self, model: PreTrainedModel):
+        super().__init__(model=model)
+
+    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
+        """Wrap the input and output to conform to the native process interface."""
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        response = self._hf_model(**inputs)
+        if isinstance(response, tuple):
+            return torch.stack(list(response), dim=0)
+        return response

 class _DistilBertWrapper(nn.Module):  # type: ignore
     """
     A simple wrapper around DistilBERT model which makes the model inputs
@@ -404,6 +488,16 @@ class _TraceableZeroShotClassificationModel(_TraceableClassificationModel):
         )

+class _TraceableQuestionAnsweringModel(_TraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "What is the meaning of life?"
+            "The meaning of life, according to the hitchikers guide, is 42.",
+            padding="max_length",
+            return_tensors="pt",
+        )

 class TransformerModel:
     def __init__(self, model_id: str, task_type: str, quantize: bool = False):
         self._model_id = model_id
@@ -472,6 +566,11 @@ class TransformerModel:
     def _create_config(self) -> NlpTrainedModelConfig:
         tokenization_config = self._create_tokenization_config()
+        # Set SQuAD well-known defaults
+        if self._task_type == "question_answering":
+            tokenization_config.max_sequence_length = 386
+            tokenization_config.span = 128
+            tokenization_config.truncate = "none"
         inference_config = (
             TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
                 tokenization=tokenization_config,
@@ -530,7 +629,9 @@ class TransformerModel:
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableZeroShotClassificationModel(self._tokenizer, model)
+        elif self._task_type == "question_answering":
+            model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
+            return _TraceableQuestionAnsweringModel(self._tokenizer, model)
         else:
             raise TypeError(
                 f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"