From 3a044a76e3f8debf4697b8f42be5ce5b573cd21a Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 25 Jan 2023 10:09:25 +0000 Subject: [PATCH] Weighted token model config --- eland/ml/pytorch/nlp_ml_model.py | 10 ++++++++++ eland/ml/pytorch/transformers.py | 3 +++ 2 files changed, 13 insertions(+) diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py index c485cbd..3de9e9b 100644 --- a/eland/ml/pytorch/nlp_ml_model.py +++ b/eland/ml/pytorch/nlp_ml_model.py @@ -123,6 +123,16 @@ class InferenceConfig: } } +class SlimInferenceOptions(InferenceConfig): + def __init__( + self, + *, + tokenization: NlpTokenizationConfig, + results_field: t.Optional[str] = None, + ): + super().__init__(configuration_type="slim") + self.tokenization = tokenization + self.results_field = results_field class TextClassificationInferenceOptions(InferenceConfig): def __init__( diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index b8c4090..99fa738 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -49,6 +49,7 @@ from eland.ml.pytorch.nlp_ml_model import ( NlpTrainedModelConfig, PassThroughInferenceOptions, QuestionAnsweringInferenceOptions, + SlimInferenceOptions, TextClassificationInferenceOptions, TextEmbeddingInferenceOptions, TextSimilarityInferenceOptions, @@ -61,6 +62,7 @@ DEFAULT_OUTPUT_KEY = "sentence_embedding" SUPPORTED_TASK_TYPES = { "fill_mask", "ner", + "slim", "text_classification", "text_embedding", "zero_shot_classification", @@ -83,6 +85,7 @@ ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"} TASK_TYPE_TO_INFERENCE_CONFIG = { "fill_mask": FillMaskInferenceOptions, "ner": NerInferenceOptions, + "slim": SlimInferenceOptions, "text_classification": TextClassificationInferenceOptions, "text_embedding": TextEmbeddingInferenceOptions, "zero_shot_classification": ZeroShotClassificationInferenceOptions,