Upgrade PyTorch to version 2.3.1 (#718)

Upgrades the PyTorch, transformers and sentence transformer requirements.
Elasticsearch has upgraded PyTorch to 2.3.1 in 8.16 and 8.15.2. For
compatibility reasons Eland will refuse to upload to an Elasticsearch cluster
that is using an earlier version of PyTorch.
This commit is contained in:
David Kyle 2024-09-30 10:22:02 +01:00 committed by GitHub
parent ec66b5f320
commit 5253501704
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 17 additions and 21 deletions

View File

@ -229,7 +229,6 @@ def check_cluster_version(es_client, logger):
sem_ver = parse_es_version(es_info["version"]["number"])
major_version = sem_ver[0]
minor_version = sem_ver[1]
# NLP models added in 8
if major_version < 8:
@ -238,13 +237,13 @@ def check_cluster_version(es_client, logger):
)
exit(1)
# PyTorch was upgraded to version 2.1.2 in 8.13
# PyTorch was upgraded to version 2.3.1 in 8.15.2
# and is incompatible with earlier versions
if major_version == 8 and minor_version < 13:
if sem_ver < (8, 15, 2):
import torch
logger.error(
f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.13. Please upgrade Elasticsearch to at least version 8.13"
f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.15.2. Please upgrade Elasticsearch to at least version 8.15.2"
)
exit(1)

View File

@ -36,6 +36,7 @@ from transformers import (
AutoConfig,
AutoModel,
AutoModelForQuestionAnswering,
BertTokenizer,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizer,
@ -757,6 +758,9 @@ class TransformerModel:
if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
return int(max_len)
if isinstance(self._tokenizer, BertTokenizer):
return 512
raise UnknownModelInputSizeError("Cannot determine model max input length")
def _create_config(

View File

@ -121,7 +121,7 @@ def test(session, pandas_version: str):
"--nbval",
)
# PyTorch 2.1.2 doesn't support Python 3.12
# PyTorch 2.3.1 doesn't support Python 3.12
if session.python == "3.12":
pytest_args += ("--ignore=eland/ml/pytorch",)
session.run(

View File

@ -60,10 +60,12 @@ extras = {
"lightgbm": ["lightgbm>=2,<4"],
"pytorch": [
"requests<3",
"torch==2.1.2",
"torch==2.3.1",
"tqdm",
"sentence-transformers>=2.1.0,<=2.3.1",
"transformers[torch]>=4.31.0,<4.36.0",
"sentence-transformers>=2.1.0,<=2.7.0",
# sentencepiece is a required dependency for the slow tokenizers
# https://huggingface.co/transformers/v4.4.2/migration.html#sentencepiece-is-removed-from-the-required-dependencies
"transformers[sentencepiece]>=4.31.0,<4.44.0",
],
}
extras["all"] = list({dep for deps in extras.values() for dep in deps})

View File

@ -58,8 +58,8 @@ from tests import ES_VERSION
pytestmark = [
pytest.mark.skipif(
ES_VERSION < (8, 13, 0),
reason="Eland uses Pytorch 2.1.2, versions of Elasticsearch prior to 8.13.0 are incompatible with PyTorch 2.1.2",
ES_VERSION < (8, 15, 1),
reason="Eland uses Pytorch 2.3.1, versions of Elasticsearch prior to 8.15.1 are incompatible with PyTorch 2.3.1",
),
pytest.mark.skipif(
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
@ -149,21 +149,12 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
1024,
None,
),
(
"cardiffnlp/twitter-roberta-base-sentiment",
"text_classification",
TextClassificationInferenceOptions,
NlpRobertaTokenizationConfig,
512,
None,
),
]
else:
MODEL_CONFIGURATIONS = []
class TestModelConfguration:
@pytest.mark.skip(reason="https://github.com/elastic/eland/issues/633")
@pytest.mark.parametrize(
"model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size",
MODEL_CONFIGURATIONS,

View File

@ -39,8 +39,8 @@ from tests import ES_TEST_CLIENT, ES_VERSION
pytestmark = [
pytest.mark.skipif(
ES_VERSION < (8, 13, 0),
reason="Eland uses Pytorch 2.1.2, versions of Elasticsearch prior to 8.13.0 are incompatible with PyTorch 2.1.2",
ES_VERSION < (8, 15, 2),
reason="Eland uses Pytorch 2.3.1, versions of Elasticsearch prior to 8.15.2 are incompatible with PyTorch 2.3.1",
),
pytest.mark.skipif(
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"