[ML] Text expansion model config support (#520)

This commit is contained in:
David Kyle 2023-03-08 15:40:14 +00:00 committed by GitHub
parent d5578637cb
commit 7f4687c791
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 15 deletions

View File

@ -244,6 +244,18 @@ class TextEmbeddingInferenceOptions(InferenceConfig):
self.results_field = results_field
class TextExpansionInferenceOptions(InferenceConfig):
def __init__(
self,
*,
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
):
super().__init__(configuration_type="text_expansion")
self.tokenization = tokenization
self.results_field = results_field
class TrainedModelInput:
def __init__(self, *, field_names: t.List[str]):
self.field_names = field_names

View File

@ -51,6 +51,7 @@ from eland.ml.pytorch.nlp_ml_model import (
QuestionAnsweringInferenceOptions,
TextClassificationInferenceOptions,
TextEmbeddingInferenceOptions,
TextExpansionInferenceOptions,
TextSimilarityInferenceOptions,
TrainedModelInput,
ZeroShotClassificationInferenceOptions,
@ -63,6 +64,7 @@ SUPPORTED_TASK_TYPES = {
"ner",
"text_classification",
"text_embedding",
"text_expansion",
"zero_shot_classification",
"question_answering",
"text_similarity",
@ -83,6 +85,7 @@ ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"}
TASK_TYPE_TO_INFERENCE_CONFIG = {
"fill_mask": FillMaskInferenceOptions,
"ner": NerInferenceOptions,
"text_expansion": TextExpansionInferenceOptions,
"text_classification": TextClassificationInferenceOptions,
"text_embedding": TextEmbeddingInferenceOptions,
"zero_shot_classification": ZeroShotClassificationInferenceOptions,

File diff suppressed because one or more lines are too long