mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
[ML] Text expansion model config support (#520)
This commit is contained in:
parent
d5578637cb
commit
7f4687c791
@ -244,6 +244,18 @@ class TextEmbeddingInferenceOptions(InferenceConfig):
|
||||
self.results_field = results_field
|
||||
|
||||
|
||||
class TextExpansionInferenceOptions(InferenceConfig):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tokenization: NlpTokenizationConfig,
|
||||
results_field: t.Optional[str] = None,
|
||||
):
|
||||
super().__init__(configuration_type="text_expansion")
|
||||
self.tokenization = tokenization
|
||||
self.results_field = results_field
|
||||
|
||||
|
||||
class TrainedModelInput:
|
||||
def __init__(self, *, field_names: t.List[str]):
|
||||
self.field_names = field_names
|
||||
|
@ -51,6 +51,7 @@ from eland.ml.pytorch.nlp_ml_model import (
|
||||
QuestionAnsweringInferenceOptions,
|
||||
TextClassificationInferenceOptions,
|
||||
TextEmbeddingInferenceOptions,
|
||||
TextExpansionInferenceOptions,
|
||||
TextSimilarityInferenceOptions,
|
||||
TrainedModelInput,
|
||||
ZeroShotClassificationInferenceOptions,
|
||||
@ -63,6 +64,7 @@ SUPPORTED_TASK_TYPES = {
|
||||
"ner",
|
||||
"text_classification",
|
||||
"text_embedding",
|
||||
"text_expansion",
|
||||
"zero_shot_classification",
|
||||
"question_answering",
|
||||
"text_similarity",
|
||||
@ -83,6 +85,7 @@ ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"}
|
||||
TASK_TYPE_TO_INFERENCE_CONFIG = {
|
||||
"fill_mask": FillMaskInferenceOptions,
|
||||
"ner": NerInferenceOptions,
|
||||
"text_expansion": TextExpansionInferenceOptions,
|
||||
"text_classification": TextClassificationInferenceOptions,
|
||||
"text_embedding": TextEmbeddingInferenceOptions,
|
||||
"zero_shot_classification": ZeroShotClassificationInferenceOptions,
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user