Upgrade torch to 2.1.2 (#671)

Compatible with Elasticsearch 8.13 where the same upgrade has been made
This commit is contained in:
David Kyle 2024-03-26 10:06:50 +00:00 committed by GitHub
parent aaec995b1b
commit ae0bba34c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 20 additions and 24 deletions

View File

@ -31,6 +31,6 @@ steps:
- '3.9'
- '3.8'
stack:
- '8.11-SNAPSHOT'
- '8.12-SNAPSHOT'
- '8.13.0-SNAPSHOT'
- '8.12.2'
command: ./.buildkite/run-tests

View File

@ -18,7 +18,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
python3 -m pip install \
--no-cache-dir --disable-pip-version-check --extra-index-url https://download.pytorch.org/whl/cpu \
torch==1.13.1+cpu .[all]; \
torch==2.1.2+cpu .[all]; \
else \
python3 -m pip install \
--no-cache-dir --disable-pip-version-check \

View File

@ -54,7 +54,7 @@ $ conda install -c conda-forge eland
### Compatibility
- Supports Python 3.8, 3.9, 3.10 and Pandas 1.5
- Supports Elasticsearch clusters that are 7.11+, recommended 8.3 or later for all features to work.
- Supports Elasticsearch clusters that are 7.11+, recommended 8.13 or later for all features to work.
If you are using the NLP with PyTorch feature make sure your Eland minor version matches the minor
version of your Elasticsearch cluster. For all other features it is sufficient for the major versions
to match.

View File

@ -209,13 +209,13 @@ def check_cluster_version(es_client, logger):
)
exit(1)
# PyTorch was upgraded to version 1.13.1 in 8.7.
# PyTorch was upgraded to version 2.1.2 in 8.13
# and is incompatible with earlier versions
if major_version == 8 and minor_version < 7:
if major_version == 8 and minor_version < 13:
import torch
logger.error(
f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.7. Please upgrade Elasticsearch to at least version 8.7"
f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.13. Please upgrade Elasticsearch to at least version 8.13"
)
exit(1)

View File

@ -464,7 +464,7 @@ class _TransformerTraceableModel(TraceableModel):
def _trace(self) -> TracedModelTypes:
inputs = self._compatible_inputs()
return torch.jit.trace(self._model, inputs)
return torch.jit.trace(self._model, example_inputs=inputs)
def sample_output(self) -> Tensor:
inputs = self._compatible_inputs()

View File

@ -14,14 +14,12 @@ scikit-learn>=1.3,<1.4
xgboost>=0.90,<2
lightgbm>=2,<4
# PyTorch doesn't support Python 3.11 yet (pytorch/pytorch#86566)
# Elasticsearch uses v1.13.1 of PyTorch
torch>=1.13.1,<2.0; python_version<'3.11'
# Versions known to be compatible with PyTorch 1.13.1
sentence-transformers>=2.1.0,<=2.2.2; python_version<'3.11'
transformers[torch]>=4.31.0,<=4.33.2; python_version<'3.11'
# Elasticsearch uses v2.1.2 of PyTorch
torch==2.1.2
# Versions known to be compatible with PyTorch 2.1.2
sentence-transformers>=2.1.0,<=2.3.1
transformers[torch]>=4.31.0,<4.36.0
#
# Testing
#

View File

@ -58,9 +58,9 @@ extras = {
"scikit-learn": ["scikit-learn>=1.3,<1.4"],
"lightgbm": ["lightgbm>=2,<4"],
"pytorch": [
"torch>=1.13.1,<2.0",
"sentence-transformers>=2.1.0,<=2.2.2",
"transformers[torch]>=4.31.0,<=4.33.2",
"torch==2.1.2",
"sentence-transformers>=2.1.0,<=2.3.1",
"transformers[torch]>=4.31.0,<4.36.0",
],
}
extras["all"] = list({dep for deps in extras.values() for dep in deps})

View File

@ -27,7 +27,6 @@ from tests.common import ROOT_DIR, TestData
class TestDataFrameToJSON(TestData):
def test_to_json_default_arguments(self):
ed_flights = self.ed_flights()
pd_flights = self.pd_flights()
@ -96,7 +95,6 @@ class TestDataFrameToJSON(TestData):
)
def test_to_json_with_other_buffer(self):
ed_flights = self.ed_flights()
pd_flights = self.pd_flights()
output_buffer = StringIO()

View File

@ -58,8 +58,8 @@ from tests import ES_VERSION
pytestmark = [
pytest.mark.skipif(
ES_VERSION < (8, 7, 0),
reason="Eland uses Pytorch 1.13.1, versions of Elasticsearch prior to 8.7.0 are incompatible with PyTorch 1.13.1",
ES_VERSION < (8, 13, 0),
reason="Eland uses Pytorch 2.1.2, versions of Elasticsearch prior to 8.13.0 are incompatible with PyTorch 2.1.2",
),
pytest.mark.skipif(
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"

View File

@ -39,8 +39,8 @@ from tests import ES_TEST_CLIENT, ES_VERSION
pytestmark = [
pytest.mark.skipif(
ES_VERSION < (8, 7, 0),
reason="Eland uses Pytorch 1.13.1, versions of Elasticsearch prior to 8.7.0 are incompatible with PyTorch 1.13.1",
ES_VERSION < (8, 13, 0),
reason="Eland uses Pytorch 2.1.2, versions of Elasticsearch prior to 8.13.0 are incompatible with PyTorch 2.1.2",
),
pytest.mark.skipif(
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"