Expansion support (#740)

This commit is contained in:
Dai Sugimori 2024-11-23 00:20:58 +09:00 committed by GitHub
parent 04102f2a4e
commit 82492fe771
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -527,6 +527,15 @@ class _TraceableFillMaskModel(_TransformerTraceableModel):
)
class _TraceableTextExpansionModel(_TransformerTraceableModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
return self._tokenizer(
"This is an example sentence.",
padding="max_length",
return_tensors="pt",
)
class _TraceableNerModel(_TraceableClassificationModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
return self._tokenizer(
@ -984,6 +993,13 @@ class TransformerModel:
else:
self._task_type = maybe_task_type
if self._task_type == "text_expansion":
model = transformers.AutoModelForMaskedLM.from_pretrained(
self._model_id, token=self._access_token, torchscript=True
)
model = _DistilBertWrapper.try_wrapping(model)
return _TraceableTextExpansionModel(self._tokenizer, model)
if self._task_type == "fill_mask":
model = transformers.AutoModelForMaskedLM.from_pretrained(
self._model_id, token=self._access_token, torchscript=True
@ -1043,7 +1059,7 @@ class TransformerModel:
else:
raise TypeError(
f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
f"Task {self._task_type} is not supported, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
)
def elasticsearch_model_id(self) -> str: