mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Upgrade torch to 2.1.2 (#671)
Compatible with Elasticsearch 8.13 where the same upgrade has been made
This commit is contained in:
parent
aaec995b1b
commit
ae0bba34c6
@ -31,6 +31,6 @@ steps:
|
|||||||
- '3.9'
|
- '3.9'
|
||||||
- '3.8'
|
- '3.8'
|
||||||
stack:
|
stack:
|
||||||
- '8.11-SNAPSHOT'
|
- '8.13.0-SNAPSHOT'
|
||||||
- '8.12-SNAPSHOT'
|
- '8.12.2'
|
||||||
command: ./.buildkite/run-tests
|
command: ./.buildkite/run-tests
|
||||||
|
@ -18,7 +18,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
|||||||
if [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
|
if [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
|
||||||
python3 -m pip install \
|
python3 -m pip install \
|
||||||
--no-cache-dir --disable-pip-version-check --extra-index-url https://download.pytorch.org/whl/cpu \
|
--no-cache-dir --disable-pip-version-check --extra-index-url https://download.pytorch.org/whl/cpu \
|
||||||
torch==1.13.1+cpu .[all]; \
|
torch==2.1.2+cpu .[all]; \
|
||||||
else \
|
else \
|
||||||
python3 -m pip install \
|
python3 -m pip install \
|
||||||
--no-cache-dir --disable-pip-version-check \
|
--no-cache-dir --disable-pip-version-check \
|
||||||
|
@ -54,7 +54,7 @@ $ conda install -c conda-forge eland
|
|||||||
### Compatibility
|
### Compatibility
|
||||||
|
|
||||||
- Supports Python 3.8, 3.9, 3.10 and Pandas 1.5
|
- Supports Python 3.8, 3.9, 3.10 and Pandas 1.5
|
||||||
- Supports Elasticsearch clusters that are 7.11+, recommended 8.3 or later for all features to work.
|
- Supports Elasticsearch clusters that are 7.11+, recommended 8.13 or later for all features to work.
|
||||||
If you are using the NLP with PyTorch feature make sure your Eland minor version matches the minor
|
If you are using the NLP with PyTorch feature make sure your Eland minor version matches the minor
|
||||||
version of your Elasticsearch cluster. For all other features it is sufficient for the major versions
|
version of your Elasticsearch cluster. For all other features it is sufficient for the major versions
|
||||||
to match.
|
to match.
|
||||||
|
@ -209,13 +209,13 @@ def check_cluster_version(es_client, logger):
|
|||||||
)
|
)
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
# PyTorch was upgraded to version 1.13.1 in 8.7.
|
# PyTorch was upgraded to version 2.1.2 in 8.13
|
||||||
# and is incompatible with earlier versions
|
# and is incompatible with earlier versions
|
||||||
if major_version == 8 and minor_version < 7:
|
if major_version == 8 and minor_version < 13:
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.7. Please upgrade Elasticsearch to at least version 8.7"
|
f"Eland uses PyTorch version {torch.__version__} which is incompatible with Elasticsearch versions prior to 8.13. Please upgrade Elasticsearch to at least version 8.13"
|
||||||
)
|
)
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
@ -464,7 +464,7 @@ class _TransformerTraceableModel(TraceableModel):
|
|||||||
|
|
||||||
def _trace(self) -> TracedModelTypes:
|
def _trace(self) -> TracedModelTypes:
|
||||||
inputs = self._compatible_inputs()
|
inputs = self._compatible_inputs()
|
||||||
return torch.jit.trace(self._model, inputs)
|
return torch.jit.trace(self._model, example_inputs=inputs)
|
||||||
|
|
||||||
def sample_output(self) -> Tensor:
|
def sample_output(self) -> Tensor:
|
||||||
inputs = self._compatible_inputs()
|
inputs = self._compatible_inputs()
|
||||||
|
@ -14,14 +14,12 @@ scikit-learn>=1.3,<1.4
|
|||||||
xgboost>=0.90,<2
|
xgboost>=0.90,<2
|
||||||
lightgbm>=2,<4
|
lightgbm>=2,<4
|
||||||
|
|
||||||
# PyTorch doesn't support Python 3.11 yet (pytorch/pytorch#86566)
|
|
||||||
|
|
||||||
# Elasticsearch uses v1.13.1 of PyTorch
|
|
||||||
torch>=1.13.1,<2.0; python_version<'3.11'
|
|
||||||
# Versions known to be compatible with PyTorch 1.13.1
|
|
||||||
sentence-transformers>=2.1.0,<=2.2.2; python_version<'3.11'
|
|
||||||
transformers[torch]>=4.31.0,<=4.33.2; python_version<'3.11'
|
|
||||||
|
|
||||||
|
# Elasticsearch uses v2.1.2 of PyTorch
|
||||||
|
torch==2.1.2
|
||||||
|
# Versions known to be compatible with PyTorch 2.1.2
|
||||||
|
sentence-transformers>=2.1.0,<=2.3.1
|
||||||
|
transformers[torch]>=4.31.0,<4.36.0
|
||||||
#
|
#
|
||||||
# Testing
|
# Testing
|
||||||
#
|
#
|
||||||
|
6
setup.py
6
setup.py
@ -58,9 +58,9 @@ extras = {
|
|||||||
"scikit-learn": ["scikit-learn>=1.3,<1.4"],
|
"scikit-learn": ["scikit-learn>=1.3,<1.4"],
|
||||||
"lightgbm": ["lightgbm>=2,<4"],
|
"lightgbm": ["lightgbm>=2,<4"],
|
||||||
"pytorch": [
|
"pytorch": [
|
||||||
"torch>=1.13.1,<2.0",
|
"torch==2.1.2",
|
||||||
"sentence-transformers>=2.1.0,<=2.2.2",
|
"sentence-transformers>=2.1.0,<=2.3.1",
|
||||||
"transformers[torch]>=4.31.0,<=4.33.2",
|
"transformers[torch]>=4.31.0,<4.36.0",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
extras["all"] = list({dep for deps in extras.values() for dep in deps})
|
extras["all"] = list({dep for deps in extras.values() for dep in deps})
|
||||||
|
@ -27,7 +27,6 @@ from tests.common import ROOT_DIR, TestData
|
|||||||
|
|
||||||
|
|
||||||
class TestDataFrameToJSON(TestData):
|
class TestDataFrameToJSON(TestData):
|
||||||
|
|
||||||
def test_to_json_default_arguments(self):
|
def test_to_json_default_arguments(self):
|
||||||
ed_flights = self.ed_flights()
|
ed_flights = self.ed_flights()
|
||||||
pd_flights = self.pd_flights()
|
pd_flights = self.pd_flights()
|
||||||
@ -96,7 +95,6 @@ class TestDataFrameToJSON(TestData):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_to_json_with_other_buffer(self):
|
def test_to_json_with_other_buffer(self):
|
||||||
|
|
||||||
ed_flights = self.ed_flights()
|
ed_flights = self.ed_flights()
|
||||||
pd_flights = self.pd_flights()
|
pd_flights = self.pd_flights()
|
||||||
output_buffer = StringIO()
|
output_buffer = StringIO()
|
||||||
|
@ -58,8 +58,8 @@ from tests import ES_VERSION
|
|||||||
|
|
||||||
pytestmark = [
|
pytestmark = [
|
||||||
pytest.mark.skipif(
|
pytest.mark.skipif(
|
||||||
ES_VERSION < (8, 7, 0),
|
ES_VERSION < (8, 13, 0),
|
||||||
reason="Eland uses Pytorch 1.13.1, versions of Elasticsearch prior to 8.7.0 are incompatible with PyTorch 1.13.1",
|
reason="Eland uses Pytorch 2.1.2, versions of Elasticsearch prior to 8.13.0 are incompatible with PyTorch 2.1.2",
|
||||||
),
|
),
|
||||||
pytest.mark.skipif(
|
pytest.mark.skipif(
|
||||||
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
|
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
|
||||||
|
@ -39,8 +39,8 @@ from tests import ES_TEST_CLIENT, ES_VERSION
|
|||||||
|
|
||||||
pytestmark = [
|
pytestmark = [
|
||||||
pytest.mark.skipif(
|
pytest.mark.skipif(
|
||||||
ES_VERSION < (8, 7, 0),
|
ES_VERSION < (8, 13, 0),
|
||||||
reason="Eland uses Pytorch 1.13.1, versions of Elasticsearch prior to 8.7.0 are incompatible with PyTorch 1.13.1",
|
reason="Eland uses Pytorch 2.1.2, versions of Elasticsearch prior to 8.13.0 are incompatible with PyTorch 2.1.2",
|
||||||
),
|
),
|
||||||
pytest.mark.skipif(
|
pytest.mark.skipif(
|
||||||
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
|
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user