mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Expansion support (#740)
This commit is contained in:
parent
04102f2a4e
commit
82492fe771
@ -527,6 +527,15 @@ class _TraceableFillMaskModel(_TransformerTraceableModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _TraceableTextExpansionModel(_TransformerTraceableModel):
|
||||||
|
def _prepare_inputs(self) -> transformers.BatchEncoding:
|
||||||
|
return self._tokenizer(
|
||||||
|
"This is an example sentence.",
|
||||||
|
padding="max_length",
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class _TraceableNerModel(_TraceableClassificationModel):
|
class _TraceableNerModel(_TraceableClassificationModel):
|
||||||
def _prepare_inputs(self) -> transformers.BatchEncoding:
|
def _prepare_inputs(self) -> transformers.BatchEncoding:
|
||||||
return self._tokenizer(
|
return self._tokenizer(
|
||||||
@ -984,6 +993,13 @@ class TransformerModel:
|
|||||||
else:
|
else:
|
||||||
self._task_type = maybe_task_type
|
self._task_type = maybe_task_type
|
||||||
|
|
||||||
|
if self._task_type == "text_expansion":
|
||||||
|
model = transformers.AutoModelForMaskedLM.from_pretrained(
|
||||||
|
self._model_id, token=self._access_token, torchscript=True
|
||||||
|
)
|
||||||
|
model = _DistilBertWrapper.try_wrapping(model)
|
||||||
|
return _TraceableTextExpansionModel(self._tokenizer, model)
|
||||||
|
|
||||||
if self._task_type == "fill_mask":
|
if self._task_type == "fill_mask":
|
||||||
model = transformers.AutoModelForMaskedLM.from_pretrained(
|
model = transformers.AutoModelForMaskedLM.from_pretrained(
|
||||||
self._model_id, token=self._access_token, torchscript=True
|
self._model_id, token=self._access_token, torchscript=True
|
||||||
@ -1043,7 +1059,7 @@ class TransformerModel:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
|
f"Task {self._task_type} is not supported, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def elasticsearch_model_id(self) -> str:
|
def elasticsearch_model_id(self) -> str:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user