mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Elasticsearch 7.6 only supports scalar leaf_values
This commit is contained in:
parent
92170c22d9
commit
d238bc5d42
@ -93,6 +93,11 @@ class TreeNode:
|
||||
add_if_exists(d, "right_child", self._right_child)
|
||||
add_if_exists(d, "split_feature", self._split_feature)
|
||||
add_if_exists(d, "threshold", self._threshold)
|
||||
else:
|
||||
if len(self._leaf_value) == 1:
|
||||
# Support Elasticsearch 7.6 which only
|
||||
# singular leaf_values not in arrays
|
||||
add_if_exists(d, "leaf_value", self._leaf_value[0])
|
||||
else:
|
||||
add_if_exists(d, "leaf_value", self._leaf_value)
|
||||
return d
|
||||
|
@ -39,7 +39,7 @@ if TYPE_CHECKING:
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from lightgbm import LGBMRegressor, LGBMClassifier # type: ignore # noqa: f401
|
||||
from lightgbm import LGBMRegressor, LGBMClassifier # type: ignore # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
@ -19,6 +19,7 @@ import os
|
||||
|
||||
import pandas as pd
|
||||
from elasticsearch import Elasticsearch
|
||||
from eland.common import es_version
|
||||
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
@ -34,6 +35,8 @@ if TEST_SUITE == "xpack":
|
||||
else:
|
||||
ES_TEST_CLIENT = Elasticsearch(ELASTICSEARCH_HOST)
|
||||
|
||||
ES_VERSION = es_version(ES_TEST_CLIENT)
|
||||
|
||||
FLIGHTS_INDEX_NAME = "flights"
|
||||
FLIGHTS_MAPPING = {
|
||||
"mappings": {
|
||||
|
@ -19,7 +19,7 @@ import pytest
|
||||
import numpy as np
|
||||
|
||||
from eland.ml import ImportedMLModel
|
||||
from eland.tests import ES_TEST_CLIENT
|
||||
from eland.tests import ES_TEST_CLIENT, ES_VERSION
|
||||
|
||||
|
||||
try:
|
||||
@ -62,6 +62,14 @@ requires_lightgbm = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
|
||||
def skip_if_multiclass_classifition():
|
||||
if ES_VERSION < (7, 7):
|
||||
raise pytest.skip(
|
||||
"Skipped because multiclass classification "
|
||||
"isn't supported on Elasticsearch 7.6"
|
||||
)
|
||||
|
||||
|
||||
def random_rows(data, size):
|
||||
return data[np.random.randint(data.shape[0], size=size), :].tolist()
|
||||
|
||||
@ -241,6 +249,7 @@ class TestImportedMLModel:
|
||||
def test_xgb_classifier(self, compress_model_definition, multi_class):
|
||||
# test both multiple and binary classification
|
||||
if multi_class:
|
||||
skip_if_multiclass_classifition()
|
||||
training_data = datasets.make_classification(
|
||||
n_features=5, n_classes=3, n_informative=3
|
||||
)
|
||||
@ -280,6 +289,7 @@ class TestImportedMLModel:
|
||||
def test_xgb_classifier_objectives_and_booster(self, objective, booster):
|
||||
# test both multiple and binary classification
|
||||
if objective.startswith("multi"):
|
||||
skip_if_multiclass_classifition()
|
||||
training_data = datasets.make_classification(
|
||||
n_features=5, n_classes=3, n_informative=3
|
||||
)
|
||||
@ -420,6 +430,7 @@ class TestImportedMLModel:
|
||||
):
|
||||
# test both multiple and binary classification
|
||||
if objective.startswith("multi"):
|
||||
skip_if_multiclass_classifition()
|
||||
training_data = datasets.make_classification(
|
||||
n_features=5, n_classes=3, n_informative=3
|
||||
)
|
||||
|
26
noxfile.py
26
noxfile.py
@ -98,25 +98,19 @@ def test(session):
|
||||
session.run("python", "-m", "eland.tests.setup_tests")
|
||||
session.run("pytest", "--doctest-modules", *(session.posargs or ("eland/",)))
|
||||
|
||||
session.run("python", "-m", "pip", "uninstall", "--yes", "scikit-learn", "xgboost")
|
||||
session.run(
|
||||
"python",
|
||||
"-m",
|
||||
"pip",
|
||||
"uninstall",
|
||||
"--yes",
|
||||
"scikit-learn",
|
||||
"xgboost",
|
||||
"lightgbm",
|
||||
)
|
||||
session.run("pytest", "eland/tests/ml/")
|
||||
|
||||
|
||||
@nox.session(python=["3.6", "3.7", "3.8"], name="test-ml-deps")
|
||||
def test_ml_deps(session):
|
||||
def session_uninstall(*deps):
|
||||
session.run("python", "-m", "pip", "uninstall", "--yes", *deps)
|
||||
|
||||
session.install("-r", "requirements-dev.txt")
|
||||
session.run("python", "-m", "eland.tests.setup_tests")
|
||||
|
||||
session_uninstall("xgboost", "scikit-learn", "lightgbm")
|
||||
session.run("pytest", *(session.posargs or ("eland/tests/ml/",)))
|
||||
|
||||
session.install(".[scikit-learn]")
|
||||
session.run("pytest", *(session.posargs or ("eland/tests/ml/",)))
|
||||
|
||||
|
||||
@nox.session(reuse_venv=True)
|
||||
def docs(session):
|
||||
# Run this so users get an error if they don't have Pandoc installed.
|
||||
|
Loading…
x
Reference in New Issue
Block a user