mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Upgrade PyTorch to version 2.3.1 (#718)
Upgrades the PyTorch, transformers and sentence transformer requirements. Elasticsearch has upgraded to PyTorch to 2.3.1 in 8.16 and 8.15.2. For compatibility reasons Eland will refuse to upload to an Elasticsearch cluster that has is using an earlier version of PyTorch.
This commit is contained in:
parent
ec66b5f320
commit
5253501704
@ -229,7 +229,6 @@ def check_cluster_version(es_client, logger):
|
||||
|
||||
sem_ver = parse_es_version(es_info["version"]["number"])
|
||||
major_version = sem_ver[0]
|
||||
minor_version = sem_ver[1]
|
||||
|
||||
# NLP models added in 8
|
||||
if major_version < 8:
|
||||
@ -238,13 +237,13 @@ def check_cluster_version(es_client, logger):
|
||||
)
|
||||
exit(1)
|
||||
|
||||
# PyTorch was upgraded to version 2.1.2 in 8.13
|
||||
# PyTorch was upgraded to version 2.3.1 in 8.15.2
|
||||
# and is incompatible with earlier versions
|
||||
if major_version == 8 and minor_version < 13:
|
||||
if sem_ver < (8, 15, 2):
|
||||
import torch
|
||||
|
||||
logger.error(
|
||||
f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.13. Please upgrade Elasticsearch to at least version 8.13"
|
||||
f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.15.2. Please upgrade Elasticsearch to at least version 8.15.2"
|
||||
)
|
||||
exit(1)
|
||||
|
||||
|
@ -36,6 +36,7 @@ from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForQuestionAnswering,
|
||||
BertTokenizer,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
@ -757,6 +758,9 @@ class TransformerModel:
|
||||
if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
|
||||
return int(max_len)
|
||||
|
||||
if isinstance(self._tokenizer, BertTokenizer):
|
||||
return 512
|
||||
|
||||
raise UnknownModelInputSizeError("Cannot determine model max input length")
|
||||
|
||||
def _create_config(
|
||||
|
@ -121,7 +121,7 @@ def test(session, pandas_version: str):
|
||||
"--nbval",
|
||||
)
|
||||
|
||||
# PyTorch 2.1.2 doesn't support Python 3.12
|
||||
# PyTorch 2.3.1 doesn't support Python 3.12
|
||||
if session.python == "3.12":
|
||||
pytest_args += ("--ignore=eland/ml/pytorch",)
|
||||
session.run(
|
||||
|
8
setup.py
8
setup.py
@ -60,10 +60,12 @@ extras = {
|
||||
"lightgbm": ["lightgbm>=2,<4"],
|
||||
"pytorch": [
|
||||
"requests<3",
|
||||
"torch==2.1.2",
|
||||
"torch==2.3.1",
|
||||
"tqdm",
|
||||
"sentence-transformers>=2.1.0,<=2.3.1",
|
||||
"transformers[torch]>=4.31.0,<4.36.0",
|
||||
"sentence-transformers>=2.1.0,<=2.7.0",
|
||||
# sentencepiece is a required dependency for the slow tokenizers
|
||||
# https://huggingface.co/transformers/v4.4.2/migration.html#sentencepiece-is-removed-from-the-required-dependencies
|
||||
"transformers[sentencepiece]>=4.31.0,<4.44.0",
|
||||
],
|
||||
}
|
||||
extras["all"] = list({dep for deps in extras.values() for dep in deps})
|
||||
|
@ -58,8 +58,8 @@ from tests import ES_VERSION
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.skipif(
|
||||
ES_VERSION < (8, 13, 0),
|
||||
reason="Eland uses Pytorch 2.1.2, versions of Elasticsearch prior to 8.13.0 are incompatible with PyTorch 2.1.2",
|
||||
ES_VERSION < (8, 15, 1),
|
||||
reason="Eland uses Pytorch 2.3.1, versions of Elasticsearch prior to 8.15.1 are incompatible with PyTorch 2.3.1",
|
||||
),
|
||||
pytest.mark.skipif(
|
||||
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
|
||||
@ -149,21 +149,12 @@ if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
|
||||
1024,
|
||||
None,
|
||||
),
|
||||
(
|
||||
"cardiffnlp/twitter-roberta-base-sentiment",
|
||||
"text_classification",
|
||||
TextClassificationInferenceOptions,
|
||||
NlpRobertaTokenizationConfig,
|
||||
512,
|
||||
None,
|
||||
),
|
||||
]
|
||||
else:
|
||||
MODEL_CONFIGURATIONS = []
|
||||
|
||||
|
||||
class TestModelConfguration:
|
||||
@pytest.mark.skip(reason="https://github.com/elastic/eland/issues/633")
|
||||
@pytest.mark.parametrize(
|
||||
"model_id,task_type,config_type,tokenizer_type,max_sequence_len,embedding_size",
|
||||
MODEL_CONFIGURATIONS,
|
||||
|
@ -39,8 +39,8 @@ from tests import ES_TEST_CLIENT, ES_VERSION
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.skipif(
|
||||
ES_VERSION < (8, 13, 0),
|
||||
reason="Eland uses Pytorch 2.1.2, versions of Elasticsearch prior to 8.13.0 are incompatible with PyTorch 2.1.2",
|
||||
ES_VERSION < (8, 15, 2),
|
||||
reason="Eland uses Pytorch 2.3.1, versions of Elasticsearch prior to 8.15.2 are incompatible with PyTorch 2.3.1",
|
||||
),
|
||||
pytest.mark.skipif(
|
||||
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
|
||||
|
Loading…
x
Reference in New Issue
Block a user