mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Elasticsearch 7.6 only supports scalar leaf_values
This commit is contained in:
parent
92170c22d9
commit
d238bc5d42
@ -94,7 +94,12 @@ class TreeNode:
|
|||||||
add_if_exists(d, "split_feature", self._split_feature)
|
add_if_exists(d, "split_feature", self._split_feature)
|
||||||
add_if_exists(d, "threshold", self._threshold)
|
add_if_exists(d, "threshold", self._threshold)
|
||||||
else:
|
else:
|
||||||
add_if_exists(d, "leaf_value", self._leaf_value)
|
if len(self._leaf_value) == 1:
|
||||||
|
# Support Elasticsearch 7.6 which only
|
||||||
|
# singular leaf_values not in arrays
|
||||||
|
add_if_exists(d, "leaf_value", self._leaf_value[0])
|
||||||
|
else:
|
||||||
|
add_if_exists(d, "leaf_value", self._leaf_value)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ if TYPE_CHECKING:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
from lightgbm import LGBMRegressor, LGBMClassifier # type: ignore # noqa: f401
|
from lightgbm import LGBMRegressor, LGBMClassifier # type: ignore # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ import os
|
|||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from elasticsearch import Elasticsearch
|
from elasticsearch import Elasticsearch
|
||||||
|
from eland.common import es_version
|
||||||
|
|
||||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
@ -34,6 +35,8 @@ if TEST_SUITE == "xpack":
|
|||||||
else:
|
else:
|
||||||
ES_TEST_CLIENT = Elasticsearch(ELASTICSEARCH_HOST)
|
ES_TEST_CLIENT = Elasticsearch(ELASTICSEARCH_HOST)
|
||||||
|
|
||||||
|
ES_VERSION = es_version(ES_TEST_CLIENT)
|
||||||
|
|
||||||
FLIGHTS_INDEX_NAME = "flights"
|
FLIGHTS_INDEX_NAME = "flights"
|
||||||
FLIGHTS_MAPPING = {
|
FLIGHTS_MAPPING = {
|
||||||
"mappings": {
|
"mappings": {
|
||||||
|
@ -19,7 +19,7 @@ import pytest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from eland.ml import ImportedMLModel
|
from eland.ml import ImportedMLModel
|
||||||
from eland.tests import ES_TEST_CLIENT
|
from eland.tests import ES_TEST_CLIENT, ES_VERSION
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -62,6 +62,14 @@ requires_lightgbm = pytest.mark.skipif(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def skip_if_multiclass_classifition():
|
||||||
|
if ES_VERSION < (7, 7):
|
||||||
|
raise pytest.skip(
|
||||||
|
"Skipped because multiclass classification "
|
||||||
|
"isn't supported on Elasticsearch 7.6"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def random_rows(data, size):
|
def random_rows(data, size):
|
||||||
return data[np.random.randint(data.shape[0], size=size), :].tolist()
|
return data[np.random.randint(data.shape[0], size=size), :].tolist()
|
||||||
|
|
||||||
@ -241,6 +249,7 @@ class TestImportedMLModel:
|
|||||||
def test_xgb_classifier(self, compress_model_definition, multi_class):
|
def test_xgb_classifier(self, compress_model_definition, multi_class):
|
||||||
# test both multiple and binary classification
|
# test both multiple and binary classification
|
||||||
if multi_class:
|
if multi_class:
|
||||||
|
skip_if_multiclass_classifition()
|
||||||
training_data = datasets.make_classification(
|
training_data = datasets.make_classification(
|
||||||
n_features=5, n_classes=3, n_informative=3
|
n_features=5, n_classes=3, n_informative=3
|
||||||
)
|
)
|
||||||
@ -280,6 +289,7 @@ class TestImportedMLModel:
|
|||||||
def test_xgb_classifier_objectives_and_booster(self, objective, booster):
|
def test_xgb_classifier_objectives_and_booster(self, objective, booster):
|
||||||
# test both multiple and binary classification
|
# test both multiple and binary classification
|
||||||
if objective.startswith("multi"):
|
if objective.startswith("multi"):
|
||||||
|
skip_if_multiclass_classifition()
|
||||||
training_data = datasets.make_classification(
|
training_data = datasets.make_classification(
|
||||||
n_features=5, n_classes=3, n_informative=3
|
n_features=5, n_classes=3, n_informative=3
|
||||||
)
|
)
|
||||||
@ -420,6 +430,7 @@ class TestImportedMLModel:
|
|||||||
):
|
):
|
||||||
# test both multiple and binary classification
|
# test both multiple and binary classification
|
||||||
if objective.startswith("multi"):
|
if objective.startswith("multi"):
|
||||||
|
skip_if_multiclass_classifition()
|
||||||
training_data = datasets.make_classification(
|
training_data = datasets.make_classification(
|
||||||
n_features=5, n_classes=3, n_informative=3
|
n_features=5, n_classes=3, n_informative=3
|
||||||
)
|
)
|
||||||
|
26
noxfile.py
26
noxfile.py
@ -98,25 +98,19 @@ def test(session):
|
|||||||
session.run("python", "-m", "eland.tests.setup_tests")
|
session.run("python", "-m", "eland.tests.setup_tests")
|
||||||
session.run("pytest", "--doctest-modules", *(session.posargs or ("eland/",)))
|
session.run("pytest", "--doctest-modules", *(session.posargs or ("eland/",)))
|
||||||
|
|
||||||
session.run("python", "-m", "pip", "uninstall", "--yes", "scikit-learn", "xgboost")
|
session.run(
|
||||||
|
"python",
|
||||||
|
"-m",
|
||||||
|
"pip",
|
||||||
|
"uninstall",
|
||||||
|
"--yes",
|
||||||
|
"scikit-learn",
|
||||||
|
"xgboost",
|
||||||
|
"lightgbm",
|
||||||
|
)
|
||||||
session.run("pytest", "eland/tests/ml/")
|
session.run("pytest", "eland/tests/ml/")
|
||||||
|
|
||||||
|
|
||||||
@nox.session(python=["3.6", "3.7", "3.8"], name="test-ml-deps")
|
|
||||||
def test_ml_deps(session):
|
|
||||||
def session_uninstall(*deps):
|
|
||||||
session.run("python", "-m", "pip", "uninstall", "--yes", *deps)
|
|
||||||
|
|
||||||
session.install("-r", "requirements-dev.txt")
|
|
||||||
session.run("python", "-m", "eland.tests.setup_tests")
|
|
||||||
|
|
||||||
session_uninstall("xgboost", "scikit-learn", "lightgbm")
|
|
||||||
session.run("pytest", *(session.posargs or ("eland/tests/ml/",)))
|
|
||||||
|
|
||||||
session.install(".[scikit-learn]")
|
|
||||||
session.run("pytest", *(session.posargs or ("eland/tests/ml/",)))
|
|
||||||
|
|
||||||
|
|
||||||
@nox.session(reuse_venv=True)
|
@nox.session(reuse_venv=True)
|
||||||
def docs(session):
|
def docs(session):
|
||||||
# Run this so users get an error if they don't have Pandoc installed.
|
# Run this so users get an error if they don't have Pandoc installed.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user