mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Fix failing unit tests (#558)
I updated the tree serialization format for the new scikit learn versions. I also updated the minimum requirement of scikit learn to 1.3 to ensure compatibility. Fixes #555
This commit is contained in:
parent
5ac8a053f0
commit
f38de0ed05
@ -97,18 +97,23 @@ class Tree:
|
||||
impurity[i],
|
||||
n_node_samples[i],
|
||||
weighted_n_node_samples[i],
|
||||
True,
|
||||
)
|
||||
for i in range(node_count)
|
||||
],
|
||||
dtype=[
|
||||
("left_child", "<i8"),
|
||||
("right_child", "<i8"),
|
||||
("feature", "<i8"),
|
||||
("threshold", "<f8"),
|
||||
("impurity", "<f8"),
|
||||
("n_node_samples", "<i8"),
|
||||
("weighted_n_node_samples", "<f8"),
|
||||
],
|
||||
dtype={
|
||||
"names": [
|
||||
"left_child",
|
||||
"right_child",
|
||||
"feature",
|
||||
"threshold",
|
||||
"impurity",
|
||||
"n_node_samples",
|
||||
"weighted_n_node_samples",
|
||||
"missing_go_to_left",
|
||||
],
|
||||
"formats": ["<i8", "<i8", "<i8", "<f8", "<f8", "<i8", "<f8", "u1"],
|
||||
},
|
||||
)
|
||||
state = {
|
||||
"max_depth": self.max_depth,
|
||||
|
@ -12,7 +12,7 @@ tqdm<5
|
||||
#
|
||||
# Extras
|
||||
#
|
||||
scikit-learn>=0.22.1,<2
|
||||
scikit-learn>=1.3,<2
|
||||
xgboost>=0.90,<2
|
||||
lightgbm>=2,<4
|
||||
|
||||
|
2
setup.py
2
setup.py
@ -56,7 +56,7 @@ with open(path.join(here, "README.md"), "r", "utf-8") as f:
|
||||
|
||||
extras = {
|
||||
"xgboost": ["xgboost>=0.90,<2"],
|
||||
"scikit-learn": ["scikit-learn>=0.22.1,<2"],
|
||||
"scikit-learn": ["scikit-learn>=1.3,<2"],
|
||||
"lightgbm": ["lightgbm>=2,<4"],
|
||||
"pytorch": [
|
||||
"torch>=1.13.1,<2.0",
|
||||
|
@ -103,17 +103,31 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
|
||||
|
||||
|
||||
class TestPytorchModel:
|
||||
def __init__(self):
|
||||
# quantization does not work on ARM processors
|
||||
# TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
|
||||
# revisit this when we upgrade to PyTorch 2.0.
|
||||
import platform
|
||||
|
||||
self.quantize = (
|
||||
True if platform.machine() not in ["arm64", "aarch64"] else False
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
|
||||
def test_text_prediction(self, model_id, task, text_input, value):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
ptm = download_model_and_start_deployment(tmp_dir, True, model_id, task)
|
||||
ptm = download_model_and_start_deployment(
|
||||
tmp_dir, self.quantize, model_id, task
|
||||
)
|
||||
result = ptm.infer(docs=[{"text_field": text_input}])
|
||||
assert result["predicted_value"] == value
|
||||
|
||||
@pytest.mark.parametrize("model_id,task,text_input", TEXT_EMBEDDING_MODELS)
|
||||
def test_text_embedding(self, model_id, task, text_input):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
ptm = download_model_and_start_deployment(tmp_dir, True, model_id, task)
|
||||
ptm = download_model_and_start_deployment(
|
||||
tmp_dir, self.quantize, model_id, task
|
||||
)
|
||||
ptm.infer(docs=[{"text_field": text_input}])
|
||||
|
||||
if ES_VERSION >= (8, 8, 0):
|
||||
|
Loading…
x
Reference in New Issue
Block a user