mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Upgrade torch to 2.1.2 (#671)
Compatible with Elasticsearch 8.13 where the same upgrade has been made
This commit is contained in:
parent
aaec995b1b
commit
ae0bba34c6
@ -31,6 +31,6 @@ steps:
|
||||
- '3.9'
|
||||
- '3.8'
|
||||
stack:
|
||||
- '8.11-SNAPSHOT'
|
||||
- '8.12-SNAPSHOT'
|
||||
- '8.13.0-SNAPSHOT'
|
||||
- '8.12.2'
|
||||
command: ./.buildkite/run-tests
|
||||
|
@ -18,7 +18,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
if [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
|
||||
python3 -m pip install \
|
||||
--no-cache-dir --disable-pip-version-check --extra-index-url https://download.pytorch.org/whl/cpu \
|
||||
torch==1.13.1+cpu .[all]; \
|
||||
torch==2.1.2+cpu .[all]; \
|
||||
else \
|
||||
python3 -m pip install \
|
||||
--no-cache-dir --disable-pip-version-check \
|
||||
|
@ -54,7 +54,7 @@ $ conda install -c conda-forge eland
|
||||
### Compatibility
|
||||
|
||||
- Supports Python 3.8, 3.9, 3.10 and Pandas 1.5
|
||||
- Supports Elasticsearch clusters that are 7.11+, recommended 8.3 or later for all features to work.
|
||||
- Supports Elasticsearch clusters that are 7.11+, recommended 8.13 or later for all features to work.
|
||||
If you are using the NLP with PyTorch feature make sure your Eland minor version matches the minor
|
||||
version of your Elasticsearch cluster. For all other features it is sufficient for the major versions
|
||||
to match.
|
||||
|
@ -209,13 +209,13 @@ def check_cluster_version(es_client, logger):
|
||||
)
|
||||
exit(1)
|
||||
|
||||
# PyTorch was upgraded to version 1.13.1 in 8.7.
|
||||
# PyTorch was upgraded to version 2.1.2 in 8.13
|
||||
# and is incompatible with earlier versions
|
||||
if major_version == 8 and minor_version < 7:
|
||||
if major_version == 8 and minor_version < 13:
|
||||
import torch
|
||||
|
||||
logger.error(
|
||||
f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.7. Please upgrade Elasticsearch to at least version 8.7"
|
||||
f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.13. Please upgrade Elasticsearch to at least version 8.13"
|
||||
)
|
||||
exit(1)
|
||||
|
||||
|
@ -464,7 +464,7 @@ class _TransformerTraceableModel(TraceableModel):
|
||||
|
||||
def _trace(self) -> TracedModelTypes:
|
||||
inputs = self._compatible_inputs()
|
||||
return torch.jit.trace(self._model, inputs)
|
||||
return torch.jit.trace(self._model, example_inputs=inputs)
|
||||
|
||||
def sample_output(self) -> Tensor:
|
||||
inputs = self._compatible_inputs()
|
||||
|
@ -14,14 +14,12 @@ scikit-learn>=1.3,<1.4
|
||||
xgboost>=0.90,<2
|
||||
lightgbm>=2,<4
|
||||
|
||||
# PyTorch doesn't support Python 3.11 yet (pytorch/pytorch#86566)
|
||||
|
||||
# Elasticsearch uses v1.13.1 of PyTorch
|
||||
torch>=1.13.1,<2.0; python_version<'3.11'
|
||||
# Versions known to be compatible with PyTorch 1.13.1
|
||||
sentence-transformers>=2.1.0,<=2.2.2; python_version<'3.11'
|
||||
transformers[torch]>=4.31.0,<=4.33.2; python_version<'3.11'
|
||||
|
||||
# Elasticsearch uses v2.1.2 of PyTorch
|
||||
torch==2.1.2
|
||||
# Versions known to be compatible with PyTorch 2.1.2
|
||||
sentence-transformers>=2.1.0,<=2.3.1
|
||||
transformers[torch]>=4.31.0,<4.36.0
|
||||
#
|
||||
# Testing
|
||||
#
|
||||
|
6
setup.py
6
setup.py
@ -58,9 +58,9 @@ extras = {
|
||||
"scikit-learn": ["scikit-learn>=1.3,<1.4"],
|
||||
"lightgbm": ["lightgbm>=2,<4"],
|
||||
"pytorch": [
|
||||
"torch>=1.13.1,<2.0",
|
||||
"sentence-transformers>=2.1.0,<=2.2.2",
|
||||
"transformers[torch]>=4.31.0,<=4.33.2",
|
||||
"torch==2.1.2",
|
||||
"sentence-transformers>=2.1.0,<=2.3.1",
|
||||
"transformers[torch]>=4.31.0,<4.36.0",
|
||||
],
|
||||
}
|
||||
extras["all"] = list({dep for deps in extras.values() for dep in deps})
|
||||
|
@ -27,7 +27,6 @@ from tests.common import ROOT_DIR, TestData
|
||||
|
||||
|
||||
class TestDataFrameToJSON(TestData):
|
||||
|
||||
def test_to_json_default_arguments(self):
|
||||
ed_flights = self.ed_flights()
|
||||
pd_flights = self.pd_flights()
|
||||
@ -96,7 +95,6 @@ class TestDataFrameToJSON(TestData):
|
||||
)
|
||||
|
||||
def test_to_json_with_other_buffer(self):
|
||||
|
||||
ed_flights = self.ed_flights()
|
||||
pd_flights = self.pd_flights()
|
||||
output_buffer = StringIO()
|
||||
|
@ -58,8 +58,8 @@ from tests import ES_VERSION
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.skipif(
|
||||
ES_VERSION < (8, 7, 0),
|
||||
reason="Eland uses Pytorch 1.13.1, versions of Elasticsearch prior to 8.7.0 are incompatible with PyTorch 1.13.1",
|
||||
ES_VERSION < (8, 13, 0),
|
||||
reason="Eland uses Pytorch 2.1.2, versions of Elasticsearch prior to 8.13.0 are incompatible with PyTorch 2.1.2",
|
||||
),
|
||||
pytest.mark.skipif(
|
||||
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
|
||||
|
@ -39,8 +39,8 @@ from tests import ES_TEST_CLIENT, ES_VERSION
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.skipif(
|
||||
ES_VERSION < (8, 7, 0),
|
||||
reason="Eland uses Pytorch 1.13.1, versions of Elasticsearch prior to 8.7.0 are incompatible with PyTorch 1.13.1",
|
||||
ES_VERSION < (8, 13, 0),
|
||||
reason="Eland uses Pytorch 2.1.2, versions of Elasticsearch prior to 8.13.0 are incompatible with PyTorch 2.1.2",
|
||||
),
|
||||
pytest.mark.skipif(
|
||||
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
|
||||
|
Loading…
x
Reference in New Issue
Block a user