[ML] add text_similarity task support (#486)

Adds text_similarity task support. This is a cross-encoder transformer task where both sequences are given to the transformer at once.

According to 🤗 (or at least how the cross-encoder models are concerned) this is a sequence classification task with just one classification "label". But really, it isn't labeled at all and is more akin to a regression model.

related: elastic/elasticsearch#88439
This commit is contained in:
Benjamin Trent 2022-08-01 09:04:34 -04:00 committed by GitHub
parent 11ea68a443
commit a8c8726634
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 1 deletions

View File

@ -218,6 +218,20 @@ class QuestionAnsweringInferenceOptions(InferenceConfig):
self.num_top_classes = num_top_classes
class TextSimilarityInferenceOptions(InferenceConfig):
def __init__(
self,
*,
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
text: t.Optional[str] = None,
):
super().__init__(configuration_type="text_similarity")
self.tokenization = tokenization
self.results_field = results_field
self.text = text
class TextEmbeddingInferenceOptions(InferenceConfig):
def __init__(
self,

View File

@ -51,6 +51,7 @@ from eland.ml.pytorch.nlp_ml_model import (
QuestionAnsweringInferenceOptions,
TextClassificationInferenceOptions,
TextEmbeddingInferenceOptions,
TextSimilarityInferenceOptions,
TrainedModelInput,
ZeroShotClassificationInferenceOptions,
)
@ -64,11 +65,16 @@ SUPPORTED_TASK_TYPES = {
"text_embedding",
"zero_shot_classification",
"question_answering",
"text_similarity",
}
ARCHITECTURE_TO_TASK_TYPE = {
"MaskedLM": ["fill_mask", "text_embedding"],
"TokenClassification": ["ner"],
"SequenceClassification": ["text_classification", "zero_shot_classification"],
"SequenceClassification": [
"text_classification",
"zero_shot_classification",
"text_similarity",
],
"QuestionAnswering": ["question_answering"],
"DPRQuestionEncoder": ["text_embedding"],
"DPRContextEncoder": ["text_embedding"],
@ -82,6 +88,7 @@ TASK_TYPE_TO_INFERENCE_CONFIG = {
"zero_shot_classification": ZeroShotClassificationInferenceOptions,
"pass_through": PassThroughInferenceOptions,
"question_answering": QuestionAnsweringInferenceOptions,
"text_similarity": TextSimilarityInferenceOptions,
}
SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
SUPPORTED_TOKENIZERS = (
@ -124,6 +131,12 @@ def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]
potential_task_types.add(t)
if len(potential_task_types) == 0:
return None
if (
"text_classification" in potential_task_types
and model_config.id2label
and len(model_config.id2label) == 1
):
return "text_similarity"
if len(potential_task_types) > 1:
if "zero_shot_classification" in potential_task_types:
if model_config.label2id:
@ -529,6 +542,16 @@ class _TraceableQuestionAnsweringModel(_TransformerTraceableModel):
)
class _TraceableTextSimilarityModel(_TransformerTraceableModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
return self._tokenizer(
"What is the meaning of life?"
"The meaning of life, according to the hitchikers guide, is 42.",
padding="max_length",
return_tensors="pt",
)
class TransformerModel:
def __init__(self, model_id: str, task_type: str, quantize: bool = False):
self._model_id = model_id
@ -674,6 +697,12 @@ class TransformerModel:
elif self._task_type == "question_answering":
model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
return _TraceableQuestionAnsweringModel(self._tokenizer, model)
elif self._task_type == "text_similarity":
model = transformers.AutoModelForSequenceClassification.from_pretrained(
self._model_id, torchscript=True
)
model = _DistilBertWrapper.try_wrapping(model)
return _TraceableTextSimilarityModel(self._tokenizer, model)
else:
raise TypeError(
f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"