Weighted token model config

This commit is contained in:
David Kyle 2023-01-25 10:09:25 +00:00
parent d5578637cb
commit 3a044a76e3
2 changed files with 13 additions and 0 deletions

View File

@ -123,6 +123,16 @@ class InferenceConfig:
}
}
class SlimInferenceOptions(InferenceConfig):
def __init__(
self,
*,
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
):
super().__init__(configuration_type="slim")
self.tokenization = tokenization
self.results_field = results_field
class TextClassificationInferenceOptions(InferenceConfig):
def __init__(

View File

@ -49,6 +49,7 @@ from eland.ml.pytorch.nlp_ml_model import (
NlpTrainedModelConfig,
PassThroughInferenceOptions,
QuestionAnsweringInferenceOptions,
SlimInferenceOptions,
TextClassificationInferenceOptions,
TextEmbeddingInferenceOptions,
TextSimilarityInferenceOptions,
@ -61,6 +62,7 @@ DEFAULT_OUTPUT_KEY = "sentence_embedding"
SUPPORTED_TASK_TYPES = {
"fill_mask",
"ner",
"slim",
"text_classification",
"text_embedding",
"zero_shot_classification",
@ -83,6 +85,7 @@ ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"}
TASK_TYPE_TO_INFERENCE_CONFIG = {
"fill_mask": FillMaskInferenceOptions,
"ner": NerInferenceOptions,
"slim": SlimInferenceOptions,
"text_classification": TextClassificationInferenceOptions,
"text_embedding": TextEmbeddingInferenceOptions,
"zero_shot_classification": ZeroShotClassificationInferenceOptions,