diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py index c485cbd..3de9e9b 100644 --- a/eland/ml/pytorch/nlp_ml_model.py +++ b/eland/ml/pytorch/nlp_ml_model.py @@ -123,6 +123,16 @@ class InferenceConfig: } } +class SlimInferenceOptions(InferenceConfig): + def __init__( + self, + *, + tokenization: NlpTokenizationConfig, + results_field: t.Optional[str] = None, + ): + super().__init__(configuration_type="slim") + self.tokenization = tokenization + self.results_field = results_field class TextClassificationInferenceOptions(InferenceConfig): def __init__( diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index b8c4090..99fa738 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -49,6 +49,7 @@ from eland.ml.pytorch.nlp_ml_model import ( NlpTrainedModelConfig, PassThroughInferenceOptions, QuestionAnsweringInferenceOptions, + SlimInferenceOptions, TextClassificationInferenceOptions, TextEmbeddingInferenceOptions, TextSimilarityInferenceOptions, @@ -61,6 +62,7 @@ DEFAULT_OUTPUT_KEY = "sentence_embedding" SUPPORTED_TASK_TYPES = { "fill_mask", "ner", + "slim", "text_classification", "text_embedding", "zero_shot_classification", @@ -83,6 +85,7 @@ ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"} TASK_TYPE_TO_INFERENCE_CONFIG = { "fill_mask": FillMaskInferenceOptions, "ner": NerInferenceOptions, + "slim": SlimInferenceOptions, "text_classification": TextClassificationInferenceOptions, "text_embedding": TextEmbeddingInferenceOptions, "zero_shot_classification": ZeroShotClassificationInferenceOptions,