diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index 770e4f7..ab89e55 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -770,6 +770,9 @@ class TransformerModel: tokenization_config.span = 128 tokenization_config.truncate = "none" + if self._task_type == "text_similarity": + tokenization_config.truncate = "second" + if self._traceable_model.classification_labels(): inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type]( tokenization=tokenization_config, diff --git a/tests/ml/pytorch/test_pytorch_model_config_pytest.py b/tests/ml/pytorch/test_pytorch_model_config_pytest.py index 664e2d5..50ea4aa 100644 --- a/tests/ml/pytorch/test_pytorch_model_config_pytest.py +++ b/tests/ml/pytorch/test_pytorch_model_config_pytest.py @@ -217,6 +217,9 @@ class TestModelConfguration: assert isinstance(config.inference_config.classification_labels, list) assert len(config.inference_config.classification_labels) > 0 + if task_type == "text_similarity": + assert tokenization.truncate == "second" + del tm def test_model_config_with_prefix_string(self):