[NLP] Add support for the pass_through task #526

This commit is contained in:
David Kyle 2023-04-06 15:43:00 +01:00 committed by GitHub
parent 8e0d897171
commit 940f2a9bad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -62,6 +62,7 @@ DEFAULT_OUTPUT_KEY = "sentence_embedding"
SUPPORTED_TASK_TYPES = {
"fill_mask",
"ner",
"pass_through",
"text_classification",
"text_embedding",
"text_expansion",
@ -510,6 +511,15 @@ class _TraceableNerModel(_TraceableClassificationModel):
)
class _TraceablePassThroughModel(_TransformerTraceableModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
return self._tokenizer(
"This is an example sentence.",
padding="max_length",
return_tensors="pt",
)
class _TraceableTextClassificationModel(_TraceableClassificationModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
return self._tokenizer(
@ -709,6 +719,11 @@ class TransformerModel:
)
model = _DistilBertWrapper.try_wrapping(model)
return _TraceableTextSimilarityModel(self._tokenizer, model)
elif self._task_type == "pass_through":
model = transformers.AutoModel.from_pretrained(
self._model_id, torchscript=True
)
return _TraceablePassThroughModel(self._tokenizer, model)
else:
raise TypeError(
f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"