diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py index 6d84d43..13d1e6d 100644 --- a/eland/ml/pytorch/nlp_ml_model.py +++ b/eland/ml/pytorch/nlp_ml_model.py @@ -124,18 +124,6 @@ class InferenceConfig: } -class SlimInferenceOptions(InferenceConfig): - def __init__( - self, - *, - tokenization: NlpTokenizationConfig, - results_field: t.Optional[str] = None, - ): - super().__init__(configuration_type="slim") - self.tokenization = tokenization - self.results_field = results_field - - class TextClassificationInferenceOptions(InferenceConfig): def __init__( self, @@ -256,6 +244,18 @@ class TextEmbeddingInferenceOptions(InferenceConfig): self.results_field = results_field +class TextExpansionInferenceOptions(InferenceConfig): + def __init__( + self, + *, + tokenization: NlpTokenizationConfig, + results_field: t.Optional[str] = None, + ): + super().__init__(configuration_type="text_expansion") + self.tokenization = tokenization + self.results_field = results_field + + class TrainedModelInput: def __init__(self, *, field_names: t.List[str]): self.field_names = field_names diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index 99fa738..03772de 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -49,9 +49,9 @@ from eland.ml.pytorch.nlp_ml_model import ( NlpTrainedModelConfig, PassThroughInferenceOptions, QuestionAnsweringInferenceOptions, - SlimInferenceOptions, TextClassificationInferenceOptions, TextEmbeddingInferenceOptions, + TextExpansionInferenceOptions, TextSimilarityInferenceOptions, TrainedModelInput, ZeroShotClassificationInferenceOptions, @@ -85,7 +85,7 @@ ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"} TASK_TYPE_TO_INFERENCE_CONFIG = { "fill_mask": FillMaskInferenceOptions, "ner": NerInferenceOptions, - "slim": SlimInferenceOptions, + "text_expansion": TextExpansionInferenceOptions, "text_classification": TextClassificationInferenceOptions, "text_embedding": TextEmbeddingInferenceOptions, "zero_shot_classification": ZeroShotClassificationInferenceOptions,