This commit is contained in:
David Kyle 2023-03-06 10:08:27 +00:00
parent e9babdcc37
commit 7186a41d74
2 changed files with 14 additions and 14 deletions

View File

@ -124,18 +124,6 @@ class InferenceConfig:
}
class SlimInferenceOptions(InferenceConfig):
def __init__(
self,
*,
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
):
super().__init__(configuration_type="slim")
self.tokenization = tokenization
self.results_field = results_field
class TextClassificationInferenceOptions(InferenceConfig):
def __init__(
self,
@ -256,6 +244,18 @@ class TextEmbeddingInferenceOptions(InferenceConfig):
self.results_field = results_field
class TextExpansionInferenceOptions(InferenceConfig):
def __init__(
self,
*,
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
):
super().__init__(configuration_type="text_expansion")
self.tokenization = tokenization
self.results_field = results_field
class TrainedModelInput:
def __init__(self, *, field_names: t.List[str]):
self.field_names = field_names

View File

@ -49,9 +49,9 @@ from eland.ml.pytorch.nlp_ml_model import (
NlpTrainedModelConfig,
PassThroughInferenceOptions,
QuestionAnsweringInferenceOptions,
SlimInferenceOptions,
TextClassificationInferenceOptions,
TextEmbeddingInferenceOptions,
TextExpansionInferenceOptions,
TextSimilarityInferenceOptions,
TrainedModelInput,
ZeroShotClassificationInferenceOptions,
@ -85,7 +85,7 @@ ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"}
TASK_TYPE_TO_INFERENCE_CONFIG = {
"fill_mask": FillMaskInferenceOptions,
"ner": NerInferenceOptions,
"slim": SlimInferenceOptions,
"text_expansion": TextExpansionInferenceOptions,
"text_classification": TextClassificationInferenceOptions,
"text_embedding": TextEmbeddingInferenceOptions,
"zero_shot_classification": ZeroShotClassificationInferenceOptions,