Weighted token model config

This commit is contained in:
David Kyle 2023-01-25 10:09:25 +00:00
parent d5578637cb
commit 3a044a76e3
2 changed files with 13 additions and 0 deletions

View File

@ -123,6 +123,16 @@ class InferenceConfig:
} }
} }
class SlimInferenceOptions(InferenceConfig):
def __init__(
self,
*,
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
):
super().__init__(configuration_type="slim")
self.tokenization = tokenization
self.results_field = results_field
class TextClassificationInferenceOptions(InferenceConfig): class TextClassificationInferenceOptions(InferenceConfig):
def __init__( def __init__(

View File

@ -49,6 +49,7 @@ from eland.ml.pytorch.nlp_ml_model import (
NlpTrainedModelConfig, NlpTrainedModelConfig,
PassThroughInferenceOptions, PassThroughInferenceOptions,
QuestionAnsweringInferenceOptions, QuestionAnsweringInferenceOptions,
SlimInferenceOptions,
TextClassificationInferenceOptions, TextClassificationInferenceOptions,
TextEmbeddingInferenceOptions, TextEmbeddingInferenceOptions,
TextSimilarityInferenceOptions, TextSimilarityInferenceOptions,
@ -61,6 +62,7 @@ DEFAULT_OUTPUT_KEY = "sentence_embedding"
SUPPORTED_TASK_TYPES = { SUPPORTED_TASK_TYPES = {
"fill_mask", "fill_mask",
"ner", "ner",
"slim",
"text_classification", "text_classification",
"text_embedding", "text_embedding",
"zero_shot_classification", "zero_shot_classification",
@ -83,6 +85,7 @@ ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"}
TASK_TYPE_TO_INFERENCE_CONFIG = { TASK_TYPE_TO_INFERENCE_CONFIG = {
"fill_mask": FillMaskInferenceOptions, "fill_mask": FillMaskInferenceOptions,
"ner": NerInferenceOptions, "ner": NerInferenceOptions,
"slim": SlimInferenceOptions,
"text_classification": TextClassificationInferenceOptions, "text_classification": TextClassificationInferenceOptions,
"text_embedding": TextEmbeddingInferenceOptions, "text_embedding": TextEmbeddingInferenceOptions,
"zero_shot_classification": ZeroShotClassificationInferenceOptions, "zero_shot_classification": ZeroShotClassificationInferenceOptions,