mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
[ML] Text expansion model config support (#520)
This commit is contained in:
parent
d5578637cb
commit
7f4687c791
@ -244,6 +244,18 @@ class TextEmbeddingInferenceOptions(InferenceConfig):
|
|||||||
self.results_field = results_field
|
self.results_field = results_field
|
||||||
|
|
||||||
|
|
||||||
|
class TextExpansionInferenceOptions(InferenceConfig):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tokenization: NlpTokenizationConfig,
|
||||||
|
results_field: t.Optional[str] = None,
|
||||||
|
):
|
||||||
|
super().__init__(configuration_type="text_expansion")
|
||||||
|
self.tokenization = tokenization
|
||||||
|
self.results_field = results_field
|
||||||
|
|
||||||
|
|
||||||
class TrainedModelInput:
|
class TrainedModelInput:
|
||||||
def __init__(self, *, field_names: t.List[str]):
|
def __init__(self, *, field_names: t.List[str]):
|
||||||
self.field_names = field_names
|
self.field_names = field_names
|
||||||
|
@ -51,6 +51,7 @@ from eland.ml.pytorch.nlp_ml_model import (
|
|||||||
QuestionAnsweringInferenceOptions,
|
QuestionAnsweringInferenceOptions,
|
||||||
TextClassificationInferenceOptions,
|
TextClassificationInferenceOptions,
|
||||||
TextEmbeddingInferenceOptions,
|
TextEmbeddingInferenceOptions,
|
||||||
|
TextExpansionInferenceOptions,
|
||||||
TextSimilarityInferenceOptions,
|
TextSimilarityInferenceOptions,
|
||||||
TrainedModelInput,
|
TrainedModelInput,
|
||||||
ZeroShotClassificationInferenceOptions,
|
ZeroShotClassificationInferenceOptions,
|
||||||
@ -63,6 +64,7 @@ SUPPORTED_TASK_TYPES = {
|
|||||||
"ner",
|
"ner",
|
||||||
"text_classification",
|
"text_classification",
|
||||||
"text_embedding",
|
"text_embedding",
|
||||||
|
"text_expansion",
|
||||||
"zero_shot_classification",
|
"zero_shot_classification",
|
||||||
"question_answering",
|
"question_answering",
|
||||||
"text_similarity",
|
"text_similarity",
|
||||||
@ -83,6 +85,7 @@ ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"}
|
|||||||
TASK_TYPE_TO_INFERENCE_CONFIG = {
|
TASK_TYPE_TO_INFERENCE_CONFIG = {
|
||||||
"fill_mask": FillMaskInferenceOptions,
|
"fill_mask": FillMaskInferenceOptions,
|
||||||
"ner": NerInferenceOptions,
|
"ner": NerInferenceOptions,
|
||||||
|
"text_expansion": TextExpansionInferenceOptions,
|
||||||
"text_classification": TextClassificationInferenceOptions,
|
"text_classification": TextClassificationInferenceOptions,
|
||||||
"text_embedding": TextEmbeddingInferenceOptions,
|
"text_embedding": TextEmbeddingInferenceOptions,
|
||||||
"zero_shot_classification": ZeroShotClassificationInferenceOptions,
|
"zero_shot_classification": ZeroShotClassificationInferenceOptions,
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user