mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Weighted token model config
This commit is contained in:
parent
d5578637cb
commit
3a044a76e3
@ -123,6 +123,16 @@ class InferenceConfig:
|
||||
}
|
||||
}
|
||||
|
||||
class SlimInferenceOptions(InferenceConfig):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tokenization: NlpTokenizationConfig,
|
||||
results_field: t.Optional[str] = None,
|
||||
):
|
||||
super().__init__(configuration_type="slim")
|
||||
self.tokenization = tokenization
|
||||
self.results_field = results_field
|
||||
|
||||
class TextClassificationInferenceOptions(InferenceConfig):
|
||||
def __init__(
|
||||
|
@ -49,6 +49,7 @@ from eland.ml.pytorch.nlp_ml_model import (
|
||||
NlpTrainedModelConfig,
|
||||
PassThroughInferenceOptions,
|
||||
QuestionAnsweringInferenceOptions,
|
||||
SlimInferenceOptions,
|
||||
TextClassificationInferenceOptions,
|
||||
TextEmbeddingInferenceOptions,
|
||||
TextSimilarityInferenceOptions,
|
||||
@ -61,6 +62,7 @@ DEFAULT_OUTPUT_KEY = "sentence_embedding"
|
||||
SUPPORTED_TASK_TYPES = {
|
||||
"fill_mask",
|
||||
"ner",
|
||||
"slim",
|
||||
"text_classification",
|
||||
"text_embedding",
|
||||
"zero_shot_classification",
|
||||
@ -83,6 +85,7 @@ ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"}
|
||||
TASK_TYPE_TO_INFERENCE_CONFIG = {
|
||||
"fill_mask": FillMaskInferenceOptions,
|
||||
"ner": NerInferenceOptions,
|
||||
"slim": SlimInferenceOptions,
|
||||
"text_classification": TextClassificationInferenceOptions,
|
||||
"text_embedding": TextEmbeddingInferenceOptions,
|
||||
"zero_shot_classification": ZeroShotClassificationInferenceOptions,
|
||||
|
Loading…
x
Reference in New Issue
Block a user