mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
[ML] add text_similarity task support (#486)
Adds text_similarity task support. This is a cross-encoder transformer task where both sequences are given to the transformer at once. According to 🤗 (or at least how the cross-encoder models are concerned) this is a sequence classification task with just one classification "label". But really, it isn't labeled at all and is more akin to a regression model. related: elastic/elasticsearch#88439
This commit is contained in:
parent
11ea68a443
commit
a8c8726634
@ -218,6 +218,20 @@ class QuestionAnsweringInferenceOptions(InferenceConfig):
|
||||
self.num_top_classes = num_top_classes
|
||||
|
||||
|
||||
class TextSimilarityInferenceOptions(InferenceConfig):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tokenization: NlpTokenizationConfig,
|
||||
results_field: t.Optional[str] = None,
|
||||
text: t.Optional[str] = None,
|
||||
):
|
||||
super().__init__(configuration_type="text_similarity")
|
||||
self.tokenization = tokenization
|
||||
self.results_field = results_field
|
||||
self.text = text
|
||||
|
||||
|
||||
class TextEmbeddingInferenceOptions(InferenceConfig):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -51,6 +51,7 @@ from eland.ml.pytorch.nlp_ml_model import (
|
||||
QuestionAnsweringInferenceOptions,
|
||||
TextClassificationInferenceOptions,
|
||||
TextEmbeddingInferenceOptions,
|
||||
TextSimilarityInferenceOptions,
|
||||
TrainedModelInput,
|
||||
ZeroShotClassificationInferenceOptions,
|
||||
)
|
||||
@ -64,11 +65,16 @@ SUPPORTED_TASK_TYPES = {
|
||||
"text_embedding",
|
||||
"zero_shot_classification",
|
||||
"question_answering",
|
||||
"text_similarity",
|
||||
}
|
||||
ARCHITECTURE_TO_TASK_TYPE = {
|
||||
"MaskedLM": ["fill_mask", "text_embedding"],
|
||||
"TokenClassification": ["ner"],
|
||||
"SequenceClassification": ["text_classification", "zero_shot_classification"],
|
||||
"SequenceClassification": [
|
||||
"text_classification",
|
||||
"zero_shot_classification",
|
||||
"text_similarity",
|
||||
],
|
||||
"QuestionAnswering": ["question_answering"],
|
||||
"DPRQuestionEncoder": ["text_embedding"],
|
||||
"DPRContextEncoder": ["text_embedding"],
|
||||
@ -82,6 +88,7 @@ TASK_TYPE_TO_INFERENCE_CONFIG = {
|
||||
"zero_shot_classification": ZeroShotClassificationInferenceOptions,
|
||||
"pass_through": PassThroughInferenceOptions,
|
||||
"question_answering": QuestionAnsweringInferenceOptions,
|
||||
"text_similarity": TextSimilarityInferenceOptions,
|
||||
}
|
||||
SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
|
||||
SUPPORTED_TOKENIZERS = (
|
||||
@ -124,6 +131,12 @@ def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]
|
||||
potential_task_types.add(t)
|
||||
if len(potential_task_types) == 0:
|
||||
return None
|
||||
if (
|
||||
"text_classification" in potential_task_types
|
||||
and model_config.id2label
|
||||
and len(model_config.id2label) == 1
|
||||
):
|
||||
return "text_similarity"
|
||||
if len(potential_task_types) > 1:
|
||||
if "zero_shot_classification" in potential_task_types:
|
||||
if model_config.label2id:
|
||||
@ -529,6 +542,16 @@ class _TraceableQuestionAnsweringModel(_TransformerTraceableModel):
|
||||
)
|
||||
|
||||
|
||||
class _TraceableTextSimilarityModel(_TransformerTraceableModel):
|
||||
def _prepare_inputs(self) -> transformers.BatchEncoding:
|
||||
return self._tokenizer(
|
||||
"What is the meaning of life?"
|
||||
"The meaning of life, according to the hitchikers guide, is 42.",
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
|
||||
class TransformerModel:
|
||||
def __init__(self, model_id: str, task_type: str, quantize: bool = False):
|
||||
self._model_id = model_id
|
||||
@ -674,6 +697,12 @@ class TransformerModel:
|
||||
elif self._task_type == "question_answering":
|
||||
model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
|
||||
return _TraceableQuestionAnsweringModel(self._tokenizer, model)
|
||||
elif self._task_type == "text_similarity":
|
||||
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
||||
self._model_id, torchscript=True
|
||||
)
|
||||
model = _DistilBertWrapper.try_wrapping(model)
|
||||
return _TraceableTextSimilarityModel(self._tokenizer, model)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user