diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index 69ef7db..f48f7be 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -62,6 +62,7 @@ DEFAULT_OUTPUT_KEY = "sentence_embedding"
 SUPPORTED_TASK_TYPES = {
     "fill_mask",
     "ner",
+    "pass_through",
     "text_classification",
     "text_embedding",
     "text_expansion",
@@ -510,6 +511,15 @@ class _TraceableNerModel(_TraceableClassificationModel):
         )
 
 
+class _TraceablePassThroughModel(_TransformerTraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "This is an example sentence.",
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+
 class _TraceableTextClassificationModel(_TraceableClassificationModel):
     def _prepare_inputs(self) -> transformers.BatchEncoding:
         return self._tokenizer(
@@ -709,6 +719,11 @@ class TransformerModel:
             )
             model = _DistilBertWrapper.try_wrapping(model)
             return _TraceableTextSimilarityModel(self._tokenizer, model)
+        elif self._task_type == "pass_through":
+            model = transformers.AutoModel.from_pretrained(
+                self._model_id, torchscript=True
+            )
+            return _TraceablePassThroughModel(self._tokenizer, model)
         else:
             raise TypeError(
                 f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
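
For context, a minimal standalone sketch of what the new "pass_through" branch amounts to outside eland's tracing machinery: load the raw AutoModel with torchscript=True, tokenize the same fixed example sentence, and trace the model so it can be serialized. The model id used here is an illustrative assumption, not part of this change.

# Standalone sketch of the pass_through tracing path (assumed example model id).
import torch
import transformers

model_id = "sentence-transformers/all-MiniLM-L6-v2"  # illustrative assumption

tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
# torchscript=True mirrors the new branch in TransformerModel._create_traceable_model
model = transformers.AutoModel.from_pretrained(model_id, torchscript=True)
model.eval()

# Mirrors _TraceablePassThroughModel._prepare_inputs
inputs = tokenizer(
    "This is an example sentence.",
    padding="max_length",
    return_tensors="pt",
)

# Trace with positional tensors, as TorchScript export requires
traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]))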