mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Fix failing unit tests (#558)
I updated the tree serialization format for newer scikit-learn versions (the node dtype now includes the `missing_go_to_left` field). I also raised the minimum required scikit-learn version to 1.3 to ensure compatibility. Fixes #555
This commit is contained in:
parent
5ac8a053f0
commit
f38de0ed05
@ -97,18 +97,23 @@ class Tree:
|
|||||||
impurity[i],
|
impurity[i],
|
||||||
n_node_samples[i],
|
n_node_samples[i],
|
||||||
weighted_n_node_samples[i],
|
weighted_n_node_samples[i],
|
||||||
|
True,
|
||||||
)
|
)
|
||||||
for i in range(node_count)
|
for i in range(node_count)
|
||||||
],
|
],
|
||||||
dtype=[
|
dtype={
|
||||||
("left_child", "<i8"),
|
"names": [
|
||||||
("right_child", "<i8"),
|
"left_child",
|
||||||
("feature", "<i8"),
|
"right_child",
|
||||||
("threshold", "<f8"),
|
"feature",
|
||||||
("impurity", "<f8"),
|
"threshold",
|
||||||
("n_node_samples", "<i8"),
|
"impurity",
|
||||||
("weighted_n_node_samples", "<f8"),
|
"n_node_samples",
|
||||||
],
|
"weighted_n_node_samples",
|
||||||
|
"missing_go_to_left",
|
||||||
|
],
|
||||||
|
"formats": ["<i8", "<i8", "<i8", "<f8", "<f8", "<i8", "<f8", "u1"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
state = {
|
state = {
|
||||||
"max_depth": self.max_depth,
|
"max_depth": self.max_depth,
|
||||||
|
@ -12,7 +12,7 @@ tqdm<5
|
|||||||
#
|
#
|
||||||
# Extras
|
# Extras
|
||||||
#
|
#
|
||||||
scikit-learn>=0.22.1,<2
|
scikit-learn>=1.3,<2
|
||||||
xgboost>=0.90,<2
|
xgboost>=0.90,<2
|
||||||
lightgbm>=2,<4
|
lightgbm>=2,<4
|
||||||
|
|
||||||
|
2
setup.py
2
setup.py
@ -56,7 +56,7 @@ with open(path.join(here, "README.md"), "r", "utf-8") as f:
|
|||||||
|
|
||||||
extras = {
|
extras = {
|
||||||
"xgboost": ["xgboost>=0.90,<2"],
|
"xgboost": ["xgboost>=0.90,<2"],
|
||||||
"scikit-learn": ["scikit-learn>=0.22.1,<2"],
|
"scikit-learn": ["scikit-learn>=1.3,<2"],
|
||||||
"lightgbm": ["lightgbm>=2,<4"],
|
"lightgbm": ["lightgbm>=2,<4"],
|
||||||
"pytorch": [
|
"pytorch": [
|
||||||
"torch>=1.13.1,<2.0",
|
"torch>=1.13.1,<2.0",
|
||||||
|
@ -103,17 +103,31 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
|
|||||||
|
|
||||||
|
|
||||||
class TestPytorchModel:
|
class TestPytorchModel:
|
||||||
|
def __init__(self):
|
||||||
|
# quantization does not work on ARM processors
|
||||||
|
# TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
|
||||||
|
# revisit this when we upgrade to PyTorch 2.0.
|
||||||
|
import platform
|
||||||
|
|
||||||
|
self.quantize = (
|
||||||
|
True if platform.machine() not in ["arm64", "aarch64"] else False
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
|
@pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
|
||||||
def test_text_prediction(self, model_id, task, text_input, value):
|
def test_text_prediction(self, model_id, task, text_input, value):
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
ptm = download_model_and_start_deployment(tmp_dir, True, model_id, task)
|
ptm = download_model_and_start_deployment(
|
||||||
|
tmp_dir, self.quantize, model_id, task
|
||||||
|
)
|
||||||
result = ptm.infer(docs=[{"text_field": text_input}])
|
result = ptm.infer(docs=[{"text_field": text_input}])
|
||||||
assert result["predicted_value"] == value
|
assert result["predicted_value"] == value
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_id,task,text_input", TEXT_EMBEDDING_MODELS)
|
@pytest.mark.parametrize("model_id,task,text_input", TEXT_EMBEDDING_MODELS)
|
||||||
def test_text_embedding(self, model_id, task, text_input):
|
def test_text_embedding(self, model_id, task, text_input):
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
ptm = download_model_and_start_deployment(tmp_dir, True, model_id, task)
|
ptm = download_model_and_start_deployment(
|
||||||
|
tmp_dir, self.quantize, model_id, task
|
||||||
|
)
|
||||||
ptm.infer(docs=[{"text_field": text_input}])
|
ptm.infer(docs=[{"text_field": text_input}])
|
||||||
|
|
||||||
if ES_VERSION >= (8, 8, 0):
|
if ES_VERSION >= (8, 8, 0):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user