From 525350170484fc8d236950d3171537f987c5d378 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 30 Sep 2024 10:22:02 +0100 Subject: [PATCH] Upgrade PyTorch to version 2.3.1 (#718) Upgrades the PyTorch, transformers and sentence transformer requirements. Elasticsearch has upgraded to PyTorch to 2.3.1 in 8.16 and 8.15.2. For compatibility reasons Eland will refuse to upload to an Elasticsearch cluster that has is using an earlier version of PyTorch. --- eland/cli/eland_import_hub_model.py | 7 +++---- eland/ml/pytorch/transformers.py | 4 ++++ noxfile.py | 2 +- setup.py | 8 +++++--- .../ml/pytorch/test_pytorch_model_config_pytest.py | 13 ++----------- .../ml/pytorch/test_pytorch_model_upload_pytest.py | 4 ++-- 6 files changed, 17 insertions(+), 21 deletions(-) diff --git a/eland/cli/eland_import_hub_model.py b/eland/cli/eland_import_hub_model.py index 4ca8544..9496e91 100755 --- a/eland/cli/eland_import_hub_model.py +++ b/eland/cli/eland_import_hub_model.py @@ -229,7 +229,6 @@ def check_cluster_version(es_client, logger): sem_ver = parse_es_version(es_info["version"]["number"]) major_version = sem_ver[0] - minor_version = sem_ver[1] # NLP models added in 8 if major_version < 8: @@ -238,13 +237,13 @@ def check_cluster_version(es_client, logger): ) exit(1) - # PyTorch was upgraded to version 2.1.2 in 8.13 + # PyTorch was upgraded to version 2.3.1 in 8.15.2 # and is incompatible with earlier versions - if major_version == 8 and minor_version < 13: + if sem_ver < (8, 15, 2): import torch logger.error( - f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.13. Please upgrade Elasticsearch to at least version 8.13" + f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.15.2. Please upgrade Elasticsearch to at least version 8.15.2" ) exit(1) diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index ab89e55..271a243 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -36,6 +36,7 @@ from transformers import ( AutoConfig, AutoModel, AutoModelForQuestionAnswering, + BertTokenizer, PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, @@ -757,6 +758,9 @@ class TransformerModel: if max_len is not None and max_len < REASONABLE_MAX_LENGTH: return int(max_len) + if isinstance(self._tokenizer, BertTokenizer): + return 512 + raise UnknownModelInputSizeError("Cannot determine model max input length") def _create_config( diff --git a/noxfile.py b/noxfile.py index a60950e..e8a5719 100644 --- a/noxfile.py +++ b/noxfile.py @@ -121,7 +121,7 @@ def test(session, pandas_version: str): "--nbval", ) - # PyTorch 2.1.2 doesn't support Python 3.12 + # PyTorch 2.3.1 doesn't support Python 3.12 if session.python == "3.12": pytest_args += ("--ignore=eland/ml/pytorch",) session.run( diff --git a/setup.py b/setup.py index 1767ea3..1befe7d 100644 --- a/setup.py +++ b/setup.py @@ -60,10 +60,12 @@ extras = { "lightgbm": ["lightgbm>=2,<4"], "pytorch": [ "requests<3", - "torch==2.1.2", + "torch==2.3.1", "tqdm", - "sentence-transformers>=2.1.0,<=2.3.1", - "transformers[torch]>=4.31.0,<4.36.0", + "sentence-transformers>=2.1.0,<=2.7.0", + # sentencepiece is a required dependency for the slow tokenizers + # https://huggingface.co/transformers/v4.4.2/migration.html#sentencepiece-is-removed-from-the-required-dependencies + "transformers[sentencepiece]>=4.31.0,<4.44.0", ], } extras["all"] = list({dep for deps in extras.values() for dep in deps}) diff --git a/tests/ml/pytorch/test_pytorch_model_config_pytest.py b/tests/ml/pytorch/test_pytorch_model_config_pytest.py index 50ea4aa..c12be3a 100644 --- a/tests/ml/pytorch/test_pytorch_model_config_pytest.py +++ b/tests/ml/pytorch/test_pytorch_model_config_pytest.py @@ -58,8 +58,8 @@ from tests import ES_VERSION pytestmark = [ pytest.mark.skipif( - ES_VERSION < (8, 13, 0), - reason="Eland uses Pytorch 2.1.2, versions of Elasticsearch prior to 8.13.0 are incompatible with PyTorch 2.1.2", + ES_VERSION < (8, 15, 1), + reason="Eland uses Pytorch 2.3.1, versions of Elasticsearch prior to 8.15.1 are incompatible with PyTorch 2.3.1", ), pytest.mark.skipif( not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run" @@ -149,21 +149,12 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS: 1024, None, ), - ( - "cardiffnlp/twitter-roberta-base-sentiment", - "text_classification", - TextClassificationInferenceOptions, - NlpRobertaTokenizationConfig, - 512, - None, - ), ] else: MODEL_CONFIGURATIONS = [] class TestModelConfguration: - @pytest.mark.skip(reason="https://github.com/elastic/eland/issues/633") @pytest.mark.parametrize( "model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size", MODEL_CONFIGURATIONS, diff --git a/tests/ml/pytorch/test_pytorch_model_upload_pytest.py b/tests/ml/pytorch/test_pytorch_model_upload_pytest.py index 7eac6a8..c84a77e 100644 --- a/tests/ml/pytorch/test_pytorch_model_upload_pytest.py +++ b/tests/ml/pytorch/test_pytorch_model_upload_pytest.py @@ -39,8 +39,8 @@ from tests import ES_TEST_CLIENT, ES_VERSION pytestmark = [ pytest.mark.skipif( - ES_VERSION < (8, 13, 0), - reason="Eland uses Pytorch 2.1.2, versions of Elasticsearch prior to 8.13.0 are incompatible with PyTorch 2.1.2", + ES_VERSION < (8, 15, 2), + reason="Eland uses Pytorch 2.3.1, versions of Elasticsearch prior to 8.15.2 are incompatible with PyTorch 2.3.1", ), pytest.mark.skipif( not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"