[ML] add text_similarity task support (#486)

Adds support for the text_similarity task. This is a cross-encoder transformer task where both input sequences are passed to the transformer together in a single pass.

As far as 🤗 is concerned (or at least as far as its cross-encoder models are concerned), this is a sequence classification task with just one classification "label". In reality it isn't labeled at all and behaves more like a regression model, producing a single similarity score.
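For context, this is roughly what a cross-encoder does with a text pair; a minimal sketch using 🤗 transformers directly, with cross-encoder/ms-marco-MiniLM-L-6-v2 as an assumed example checkpoint (any single-label cross-encoder behaves the same way):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumed example checkpoint; not part of this change.
model_id = "cross-encoder/ms-marco-MiniLM-L-6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

# Both sequences are tokenized as one pair and go through the transformer together.
inputs = tokenizer(
    "What is the meaning of life?",
    "The meaning of life, according to the hitchikers guide, is 42.",
    return_tensors="pt",
)
with torch.no_grad():
    # num_labels == 1, so logits has shape (1, 1): effectively a regression score.
    score = model(**inputs).logits[0, 0].item()
print(score)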

related: elastic/elasticsearch#88439
Benjamin Trent 2022-08-01 09:04:34 -04:00 committed by GitHub
parent 11ea68a443
commit a8c8726634
2 changed files with 44 additions and 1 deletion

eland/ml/pytorch/nlp_ml_model.py

@@ -218,6 +218,20 @@ class QuestionAnsweringInferenceOptions(InferenceConfig):
         self.num_top_classes = num_top_classes
+
+
+class TextSimilarityInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+        text: t.Optional[str] = None,
+    ):
+        super().__init__(configuration_type="text_similarity")
+        self.tokenization = tokenization
+        self.results_field = results_field
+        self.text = text
 
 
 class TextEmbeddingInferenceOptions(InferenceConfig):
     def __init__(
         self,
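For illustration, the new options class is constructed like the other inference option classes in this module; the NlpBertTokenizationConfig defaults and the results_field value below are illustrative assumptions, not part of this change:

from eland.ml.pytorch.nlp_ml_model import (
    NlpBertTokenizationConfig,
    TextSimilarityInferenceOptions,
)

# Sketch only: tokenization defaults and field names here are placeholders.
options = TextSimilarityInferenceOptions(
    tokenization=NlpBertTokenizationConfig(),
    results_field="similarity",
    text="What is the meaning of life?",
)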

eland/ml/pytorch/transformers.py

@@ -51,6 +51,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
     TextEmbeddingInferenceOptions,
+    TextSimilarityInferenceOptions,
     TrainedModelInput,
     ZeroShotClassificationInferenceOptions,
 )
@@ -64,11 +65,16 @@ SUPPORTED_TASK_TYPES = {
     "text_embedding",
     "zero_shot_classification",
     "question_answering",
+    "text_similarity",
 }
 ARCHITECTURE_TO_TASK_TYPE = {
     "MaskedLM": ["fill_mask", "text_embedding"],
     "TokenClassification": ["ner"],
-    "SequenceClassification": ["text_classification", "zero_shot_classification"],
+    "SequenceClassification": [
+        "text_classification",
+        "zero_shot_classification",
+        "text_similarity",
+    ],
     "QuestionAnswering": ["question_answering"],
     "DPRQuestionEncoder": ["text_embedding"],
     "DPRContextEncoder": ["text_embedding"],
@@ -82,6 +88,7 @@ TASK_TYPE_TO_INFERENCE_CONFIG = {
     "zero_shot_classification": ZeroShotClassificationInferenceOptions,
     "pass_through": PassThroughInferenceOptions,
     "question_answering": QuestionAnsweringInferenceOptions,
+    "text_similarity": TextSimilarityInferenceOptions,
 }
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
@@ -124,6 +131,12 @@ def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]
                     potential_task_types.add(t)
     if len(potential_task_types) == 0:
         return None
+    if (
+        "text_classification" in potential_task_types
+        and model_config.id2label
+        and len(model_config.id2label) == 1
+    ):
+        return "text_similarity"
     if len(potential_task_types) > 1:
         if "zero_shot_classification" in potential_task_types:
             if model_config.label2id:
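The check above keys off the model config's label map: a SequenceClassification architecture that exposes exactly one entry in id2label is treated as a cross-encoder (text_similarity) rather than a text classifier. Roughly, with a hypothetical checkpoint id:

from transformers import AutoConfig

# Hypothetical checkpoint; any single-label cross-encoder config looks the same.
config = AutoConfig.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")

print(config.architectures)  # e.g. ["BertForSequenceClassification"]
print(config.id2label)       # a single unnamed label, e.g. {0: "LABEL_0"}
# -> task_type_from_model_config(config) would now return "text_similarity"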
@@ -529,6 +542,16 @@ class _TraceableQuestionAnsweringModel(_TransformerTraceableModel):
         )
 
 
+class _TraceableTextSimilarityModel(_TransformerTraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "What is the meaning of life?"
+            "The meaning of life, according to the hitchikers guide, is 42.",
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+
 class TransformerModel:
     def __init__(self, model_id: str, task_type: str, quantize: bool = False):
         self._model_id = model_id
@@ -674,6 +697,12 @@ class TransformerModel:
         elif self._task_type == "question_answering":
             model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
             return _TraceableQuestionAnsweringModel(self._tokenizer, model)
+        elif self._task_type == "text_similarity":
+            model = transformers.AutoModelForSequenceClassification.from_pretrained(
+                self._model_id, torchscript=True
+            )
+            model = _DistilBertWrapper.try_wrapping(model)
+            return _TraceableTextSimilarityModel(self._tokenizer, model)
         else:
             raise TypeError(
                 f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"