diff --git a/eland/ml/_model_serializer.py b/eland/ml/_model_serializer.py index e9dc26f..976a85c 100644 --- a/eland/ml/_model_serializer.py +++ b/eland/ml/_model_serializer.py @@ -94,7 +94,12 @@ class TreeNode: add_if_exists(d, "split_feature", self._split_feature) add_if_exists(d, "threshold", self._threshold) else: - add_if_exists(d, "leaf_value", self._leaf_value) + if len(self._leaf_value) == 1: + # Support Elasticsearch 7.6 which only supports + # singular leaf_values not in arrays + add_if_exists(d, "leaf_value", self._leaf_value[0]) + else: + add_if_exists(d, "leaf_value", self._leaf_value) return d diff --git a/eland/ml/imported_ml_model.py b/eland/ml/imported_ml_model.py index deed962..23418fe 100644 --- a/eland/ml/imported_ml_model.py +++ b/eland/ml/imported_ml_model.py @@ -39,7 +39,7 @@ if TYPE_CHECKING: except ImportError: pass try: - from lightgbm import LGBMRegressor, LGBMClassifier # type: ignore # noqa: f401 + from lightgbm import LGBMRegressor, LGBMClassifier # type: ignore # noqa: F401 except ImportError: pass diff --git a/eland/tests/__init__.py b/eland/tests/__init__.py index f416bf2..44c94c8 100644 --- a/eland/tests/__init__.py +++ b/eland/tests/__init__.py @@ -19,6 +19,7 @@ import os import pandas as pd from elasticsearch import Elasticsearch +from eland.common import es_version ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -34,6 +35,8 @@ if TEST_SUITE == "xpack": else: ES_TEST_CLIENT = Elasticsearch(ELASTICSEARCH_HOST) +ES_VERSION = es_version(ES_TEST_CLIENT) + FLIGHTS_INDEX_NAME = "flights" FLIGHTS_MAPPING = { "mappings": { diff --git a/eland/tests/ml/test_imported_ml_model_pytest.py b/eland/tests/ml/test_imported_ml_model_pytest.py index 01018d2..c3a9543 100644 --- a/eland/tests/ml/test_imported_ml_model_pytest.py +++ b/eland/tests/ml/test_imported_ml_model_pytest.py @@ -19,7 +19,7 @@ import pytest import numpy as np from eland.ml import ImportedMLModel -from eland.tests import ES_TEST_CLIENT +from eland.tests import ES_TEST_CLIENT, 
ES_VERSION try: @@ -62,6 +62,14 @@ requires_lightgbm = pytest.mark.skipif( ) +def skip_if_multiclass_classifition(): + if ES_VERSION < (7, 7): + raise pytest.skip( + "Skipped because multiclass classification " + "isn't supported on Elasticsearch 7.6" + ) + + def random_rows(data, size): return data[np.random.randint(data.shape[0], size=size), :].tolist() @@ -241,6 +249,7 @@ class TestImportedMLModel: def test_xgb_classifier(self, compress_model_definition, multi_class): # test both multiple and binary classification if multi_class: + skip_if_multiclass_classifition() training_data = datasets.make_classification( n_features=5, n_classes=3, n_informative=3 ) @@ -280,6 +289,7 @@ class TestImportedMLModel: def test_xgb_classifier_objectives_and_booster(self, objective, booster): # test both multiple and binary classification if objective.startswith("multi"): + skip_if_multiclass_classifition() training_data = datasets.make_classification( n_features=5, n_classes=3, n_informative=3 ) @@ -420,6 +430,7 @@ class TestImportedMLModel: ): # test both multiple and binary classification if objective.startswith("multi"): + skip_if_multiclass_classifition() training_data = datasets.make_classification( n_features=5, n_classes=3, n_informative=3 ) diff --git a/noxfile.py b/noxfile.py index cf00731..194a5a1 100644 --- a/noxfile.py +++ b/noxfile.py @@ -98,25 +98,19 @@ def test(session): session.run("python", "-m", "eland.tests.setup_tests") session.run("pytest", "--doctest-modules", *(session.posargs or ("eland/",))) - session.run("python", "-m", "pip", "uninstall", "--yes", "scikit-learn", "xgboost") + session.run( + "python", + "-m", + "pip", + "uninstall", + "--yes", + "scikit-learn", + "xgboost", + "lightgbm", + ) session.run("pytest", "eland/tests/ml/") -@nox.session(python=["3.6", "3.7", "3.8"], name="test-ml-deps") -def test_ml_deps(session): - def session_uninstall(*deps): - session.run("python", "-m", "pip", "uninstall", "--yes", *deps) - - session.install("-r", 
"requirements-dev.txt") - session.run("python", "-m", "eland.tests.setup_tests") - - session_uninstall("xgboost", "scikit-learn", "lightgbm") - session.run("pytest", *(session.posargs or ("eland/tests/ml/",))) - - session.install(".[scikit-learn]") - session.run("pytest", *(session.posargs or ("eland/tests/ml/",))) - - @nox.session(reuse_venv=True) def docs(session): # Run this so users get an error if they don't have Pandoc installed.