Expansion support (#740)

This commit is contained in:
Dai Sugimori 2024-11-23 00:20:58 +09:00 committed by GitHub
parent 04102f2a4e
commit 82492fe771
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -527,6 +527,15 @@ class _TraceableFillMaskModel(_TransformerTraceableModel):
) )
class _TraceableTextExpansionModel(_TransformerTraceableModel):
    """Traceable model wrapper selected for the ``text_expansion`` task type."""

    def _prepare_inputs(self) -> transformers.BatchEncoding:
        """Tokenize a fixed sample sentence and return the encoded batch.

        The sentence is padded with ``padding="max_length"`` and returned as
        PyTorch tensors (``return_tensors="pt"``).
        """
        # NOTE(review): presumably consumed as the example input when the
        # model is traced -- confirm against the base class.
        sample_sentence = "This is an example sentence."
        encoded = self._tokenizer(
            sample_sentence,
            padding="max_length",
            return_tensors="pt",
        )
        return encoded
class _TraceableNerModel(_TraceableClassificationModel): class _TraceableNerModel(_TraceableClassificationModel):
def _prepare_inputs(self) -> transformers.BatchEncoding: def _prepare_inputs(self) -> transformers.BatchEncoding:
return self._tokenizer( return self._tokenizer(
@ -984,6 +993,13 @@ class TransformerModel:
else: else:
self._task_type = maybe_task_type self._task_type = maybe_task_type
if self._task_type == "text_expansion":
model = transformers.AutoModelForMaskedLM.from_pretrained(
self._model_id, token=self._access_token, torchscript=True
)
model = _DistilBertWrapper.try_wrapping(model)
return _TraceableTextExpansionModel(self._tokenizer, model)
if self._task_type == "fill_mask": if self._task_type == "fill_mask":
model = transformers.AutoModelForMaskedLM.from_pretrained( model = transformers.AutoModelForMaskedLM.from_pretrained(
self._model_id, token=self._access_token, torchscript=True self._model_id, token=self._access_token, torchscript=True
@ -1043,7 +1059,7 @@ class TransformerModel:
else: else:
raise TypeError( raise TypeError(
f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}" f"Task {self._task_type} is not supported, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
) )
def elasticsearch_model_id(self) -> str: def elasticsearch_model_id(self) -> str: