mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
[NLP] Add support for the pass_through task #526
This commit is contained in:
parent
8e0d897171
commit
940f2a9bad
@ -62,6 +62,7 @@ DEFAULT_OUTPUT_KEY = "sentence_embedding"
|
|||||||
SUPPORTED_TASK_TYPES = {
|
SUPPORTED_TASK_TYPES = {
|
||||||
"fill_mask",
|
"fill_mask",
|
||||||
"ner",
|
"ner",
|
||||||
|
"pass_through",
|
||||||
"text_classification",
|
"text_classification",
|
||||||
"text_embedding",
|
"text_embedding",
|
||||||
"text_expansion",
|
"text_expansion",
|
||||||
@ -510,6 +511,15 @@ class _TraceableNerModel(_TraceableClassificationModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _TraceablePassThroughModel(_TransformerTraceableModel):
|
||||||
|
def _prepare_inputs(self) -> transformers.BatchEncoding:
|
||||||
|
return self._tokenizer(
|
||||||
|
"This is an example sentence.",
|
||||||
|
padding="max_length",
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class _TraceableTextClassificationModel(_TraceableClassificationModel):
|
class _TraceableTextClassificationModel(_TraceableClassificationModel):
|
||||||
def _prepare_inputs(self) -> transformers.BatchEncoding:
|
def _prepare_inputs(self) -> transformers.BatchEncoding:
|
||||||
return self._tokenizer(
|
return self._tokenizer(
|
||||||
@ -709,6 +719,11 @@ class TransformerModel:
|
|||||||
)
|
)
|
||||||
model = _DistilBertWrapper.try_wrapping(model)
|
model = _DistilBertWrapper.try_wrapping(model)
|
||||||
return _TraceableTextSimilarityModel(self._tokenizer, model)
|
return _TraceableTextSimilarityModel(self._tokenizer, model)
|
||||||
|
elif self._task_type == "pass_through":
|
||||||
|
model = transformers.AutoModel.from_pretrained(
|
||||||
|
self._model_id, torchscript=True
|
||||||
|
)
|
||||||
|
return _TraceablePassThroughModel(self._tokenizer, model)
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
|
f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user