From fd8886da6a7bc99c954a78c43f27b5b6f4ae8157 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 5 Aug 2024 11:47:15 +0100 Subject: [PATCH] Default truncation to `second` for text similarity the task type(#713) In reranking the first input (the query) is generally shorter. In this case it makes more sense to truncate the second input (the document text) --- eland/ml/pytorch/transformers.py | 3 +++ tests/ml/pytorch/test_pytorch_model_config_pytest.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index 770e4f7..ab89e55 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -770,6 +770,9 @@ class TransformerModel: tokenization_config.span = 128 tokenization_config.truncate = "none" + if self._task_type == "text_similarity": + tokenization_config.truncate = "second" + if self._traceable_model.classification_labels(): inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type]( tokenization=tokenization_config, diff --git a/tests/ml/pytorch/test_pytorch_model_config_pytest.py b/tests/ml/pytorch/test_pytorch_model_config_pytest.py index 664e2d5..50ea4aa 100644 --- a/tests/ml/pytorch/test_pytorch_model_config_pytest.py +++ b/tests/ml/pytorch/test_pytorch_model_config_pytest.py @@ -217,6 +217,9 @@ class TestModelConfguration: assert isinstance(config.inference_config.classification_labels, list) assert len(config.inference_config.classification_labels) > 0 + if task_type == "text_similarity": + assert tokenization.truncate == "second" + del tm def test_model_config_with_prefix_string(self):