diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index 83faaf8..29d4ef8 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -527,6 +527,15 @@ class _TraceableFillMaskModel(_TransformerTraceableModel): ) +class _TraceableTextExpansionModel(_TransformerTraceableModel): + def _prepare_inputs(self) -> transformers.BatchEncoding: + return self._tokenizer( + "This is an example sentence.", + padding="max_length", + return_tensors="pt", + ) + + class _TraceableNerModel(_TraceableClassificationModel): def _prepare_inputs(self) -> transformers.BatchEncoding: return self._tokenizer( @@ -984,6 +993,13 @@ class TransformerModel: else: self._task_type = maybe_task_type + if self._task_type == "text_expansion": + model = transformers.AutoModelForMaskedLM.from_pretrained( + self._model_id, token=self._access_token, torchscript=True + ) + model = _DistilBertWrapper.try_wrapping(model) + return _TraceableTextExpansionModel(self._tokenizer, model) + if self._task_type == "fill_mask": model = transformers.AutoModelForMaskedLM.from_pretrained( self._model_id, token=self._access_token, torchscript=True @@ -1043,7 +1059,7 @@ class TransformerModel: else: raise TypeError( - f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}" + f"Task {self._task_type} is not supported, must be one of: {SUPPORTED_TASK_TYPES_NAMES}" ) def elasticsearch_model_id(self) -> str: