This commit is contained in:
David Kyle 2023-03-06 10:08:27 +00:00
parent e9babdcc37
commit 7186a41d74
2 changed files with 14 additions and 14 deletions

View File

@ -124,18 +124,6 @@ class InferenceConfig:
} }
class SlimInferenceOptions(InferenceConfig):
def __init__(
self,
*,
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
):
super().__init__(configuration_type="slim")
self.tokenization = tokenization
self.results_field = results_field
class TextClassificationInferenceOptions(InferenceConfig): class TextClassificationInferenceOptions(InferenceConfig):
def __init__( def __init__(
self, self,
@ -256,6 +244,18 @@ class TextEmbeddingInferenceOptions(InferenceConfig):
self.results_field = results_field self.results_field = results_field
class TextExpansionInferenceOptions(InferenceConfig):
def __init__(
self,
*,
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
):
super().__init__(configuration_type="text_expansion")
self.tokenization = tokenization
self.results_field = results_field
class TrainedModelInput: class TrainedModelInput:
def __init__(self, *, field_names: t.List[str]): def __init__(self, *, field_names: t.List[str]):
self.field_names = field_names self.field_names = field_names

View File

@ -49,9 +49,9 @@ from eland.ml.pytorch.nlp_ml_model import (
NlpTrainedModelConfig, NlpTrainedModelConfig,
PassThroughInferenceOptions, PassThroughInferenceOptions,
QuestionAnsweringInferenceOptions, QuestionAnsweringInferenceOptions,
SlimInferenceOptions,
TextClassificationInferenceOptions, TextClassificationInferenceOptions,
TextEmbeddingInferenceOptions, TextEmbeddingInferenceOptions,
TextExpansionInferenceOptions,
TextSimilarityInferenceOptions, TextSimilarityInferenceOptions,
TrainedModelInput, TrainedModelInput,
ZeroShotClassificationInferenceOptions, ZeroShotClassificationInferenceOptions,
@ -85,7 +85,7 @@ ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"}
TASK_TYPE_TO_INFERENCE_CONFIG = { TASK_TYPE_TO_INFERENCE_CONFIG = {
"fill_mask": FillMaskInferenceOptions, "fill_mask": FillMaskInferenceOptions,
"ner": NerInferenceOptions, "ner": NerInferenceOptions,
"slim": SlimInferenceOptions, "text_expansion": TextExpansionInferenceOptions,
"text_classification": TextClassificationInferenceOptions, "text_classification": TextClassificationInferenceOptions,
"text_embedding": TextEmbeddingInferenceOptions, "text_embedding": TextEmbeddingInferenceOptions,
"zero_shot_classification": ZeroShotClassificationInferenceOptions, "zero_shot_classification": ZeroShotClassificationInferenceOptions,