[ML] add text_similarity task support (#486)
Adds text_similarity task support. This is a cross-encoder transformer task where both sequences are given to the transformer at once. According to 🤗 (or at least as far as cross-encoder models are concerned) this is a sequence classification task with just one classification "label". But really, it isn't labeled at all and is more akin to a regression model. Related: elastic/elasticsearch#88439
This commit is contained in:
parent
11ea68a443
commit
a8c8726634
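
For context, the cross-encoder idea described in the commit message can be exercised directly with 🤗 transformers before any Elasticsearch import. The sketch below is illustrative only; the model id is an arbitrary public cross-encoder checkpoint, not something referenced by this commit:

# Standalone illustration of the cross-encoder idea; the model id is an
# arbitrary public checkpoint, not part of this change.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "cross-encoder/ms-marco-MiniLM-L-6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

# Both sequences go through the transformer in a single pass; the lone "label"
# logit is an unnormalized relevance score, which is why it feels regression-like.
inputs = tokenizer(
    "What is the meaning of life?",
    "The meaning of life, according to the hitchhikers guide, is 42.",
    return_tensors="pt",
)
with torch.no_grad():
    score = model(**inputs).logits.squeeze().item()
print(score)  # higher = more similar / more relevant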
@@ -218,6 +218,20 @@ class QuestionAnsweringInferenceOptions(InferenceConfig):
         self.num_top_classes = num_top_classes
 
 
+class TextSimilarityInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+        text: t.Optional[str] = None,
+    ):
+        super().__init__(configuration_type="text_similarity")
+        self.tokenization = tokenization
+        self.results_field = results_field
+        self.text = text
+
+
 class TextEmbeddingInferenceOptions(InferenceConfig):
     def __init__(
         self,
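
The new options class follows the same pattern as the other InferenceConfig subclasses in this module. A hedged sketch of constructing it directly, assuming the NlpBertTokenizationConfig helper from the same module; the results_field value is hypothetical and the optional text argument is omitted:

# Hedged sketch only; kwargs shown are assumptions based on the class signatures.
from eland.ml.pytorch.nlp_ml_model import (
    NlpBertTokenizationConfig,
    TextSimilarityInferenceOptions,
)

options = TextSimilarityInferenceOptions(
    tokenization=NlpBertTokenizationConfig(max_sequence_length=512),
    results_field="similarity",  # hypothetical results field name
)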
@@ -51,6 +51,7 @@ from eland.ml.pytorch.nlp_ml_model import (
     QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
     TextEmbeddingInferenceOptions,
+    TextSimilarityInferenceOptions,
     TrainedModelInput,
     ZeroShotClassificationInferenceOptions,
 )
@@ -64,11 +65,16 @@ SUPPORTED_TASK_TYPES = {
     "text_embedding",
     "zero_shot_classification",
     "question_answering",
+    "text_similarity",
 }
 ARCHITECTURE_TO_TASK_TYPE = {
     "MaskedLM": ["fill_mask", "text_embedding"],
     "TokenClassification": ["ner"],
-    "SequenceClassification": ["text_classification", "zero_shot_classification"],
+    "SequenceClassification": [
+        "text_classification",
+        "zero_shot_classification",
+        "text_similarity",
+    ],
     "QuestionAnswering": ["question_answering"],
     "DPRQuestionEncoder": ["text_embedding"],
     "DPRContextEncoder": ["text_embedding"],
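
To illustrate the effect of the mapping change (the architecture name below is an example, not part of the diff): a model whose config lists a SequenceClassification architecture now yields text_similarity among its candidate task types when the mapping is scanned the way task_type_from_model_config does:

# Illustrative only; mirrors the loop in task_type_from_model_config.
architectures = ["BertForSequenceClassification"]  # e.g. from a model's config.json
potential_task_types = set()
for architecture in architectures:
    for substr, task_types in ARCHITECTURE_TO_TASK_TYPE.items():
        if substr in architecture:
            potential_task_types.update(task_types)
# potential_task_types now contains "text_classification",
# "zero_shot_classification" and, with this change, "text_similarity"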
@@ -82,6 +88,7 @@ TASK_TYPE_TO_INFERENCE_CONFIG = {
     "zero_shot_classification": ZeroShotClassificationInferenceOptions,
     "pass_through": PassThroughInferenceOptions,
     "question_answering": QuestionAnsweringInferenceOptions,
+    "text_similarity": TextSimilarityInferenceOptions,
 }
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
@@ -124,6 +131,12 @@ def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]:
                     potential_task_types.add(t)
     if len(potential_task_types) == 0:
         return None
+    if (
+        "text_classification" in potential_task_types
+        and model_config.id2label
+        and len(model_config.id2label) == 1
+    ):
+        return "text_similarity"
     if len(potential_task_types) > 1:
         if "zero_shot_classification" in potential_task_types:
             if model_config.label2id:
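
The reasoning behind the new branch: cross-encoder checkpoints are SequenceClassification models configured with exactly one label, so a single-entry id2label is taken as a signal for text_similarity rather than text_classification. A hedged illustration with a public cross-encoder config; the model id is an example, not from this commit:

# Example only; exact label names depend on the checkpoint.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")
print(cfg.architectures)  # e.g. ["BertForSequenceClassification"]
print(cfg.id2label)       # a single entry, e.g. {0: "LABEL_0"}
# task_type_from_model_config(cfg) is therefore expected to return "text_similarity"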
@@ -529,6 +542,16 @@ class _TraceableQuestionAnsweringModel(_TransformerTraceableModel):
         )
 
 
+class _TraceableTextSimilarityModel(_TransformerTraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "What is the meaning of life?"
+            "The meaning of life, according to the hitchikers guide, is 42.",
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+
 class TransformerModel:
     def __init__(self, model_id: str, task_type: str, quantize: bool = False):
         self._model_id = model_id
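
As with the other traceable wrappers, _prepare_inputs supplies a fixed example encoding that the shared tracing code can feed to torch.jit.trace. A rough sketch of that idea, assuming a model that takes (input_ids, attention_mask); the real orchestration lives in _TransformerTraceableModel:

# Rough sketch only: how a prepared encoding is typically fed to torch.jit.trace.
import torch

traceable = _TraceableTextSimilarityModel(tokenizer, model)  # model loaded with torchscript=True
encoding = traceable._prepare_inputs()
traced = torch.jit.trace(
    model, example_inputs=(encoding["input_ids"], encoding["attention_mask"])
)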
@@ -674,6 +697,12 @@ class TransformerModel:
         elif self._task_type == "question_answering":
             model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
             return _TraceableQuestionAnsweringModel(self._tokenizer, model)
+        elif self._task_type == "text_similarity":
+            model = transformers.AutoModelForSequenceClassification.from_pretrained(
+                self._model_id, torchscript=True
+            )
+            model = _DistilBertWrapper.try_wrapping(model)
+            return _TraceableTextSimilarityModel(self._tokenizer, model)
         else:
             raise TypeError(
                 f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
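
Putting the pieces together, importing a cross-encoder as a text_similarity model would follow the same pattern the eland docs show for other task types; the model id and cluster URL below are illustrative, not part of this commit:

# Illustrative end-to-end import sketch; adjust model id and cluster URL.
import tempfile

from elasticsearch import Elasticsearch

from eland.ml.pytorch import PyTorchModel
from eland.ml.pytorch.transformers import TransformerModel

es = Elasticsearch("http://localhost:9200")

tm = TransformerModel("cross-encoder/ms-marco-MiniLM-L-6-v2", "text_similarity")
tmp_dir = tempfile.mkdtemp()
model_path, config, vocab_path = tm.save(tmp_dir)

ptm = PyTorchModel(es, tm.elasticsearch_model_id())
ptm.import_model(model_path=model_path, config=config, vocab_path=vocab_path)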