diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index 4731659..2b16b1f 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -729,7 +729,7 @@ class TransformerModel: else: sample_embedding = self._traceable_model.sample_output() if type(sample_embedding) is tuple: - text_embedding, _ = sample_embedding + text_embedding = sample_embedding[0] else: text_embedding = sample_embedding