Fix failing unit tests (#558)

I updated the tree serialization format to match newer scikit-learn versions. I also raised the minimum required scikit-learn version to 1.3 to ensure compatibility.

Fixes #555
This commit is contained in:
Valeriy Khakhutskyy 2023-07-10 15:15:58 +02:00 committed by GitHub
parent 5ac8a053f0
commit f38de0ed05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 13 deletions

View File

@ -97,18 +97,23 @@ class Tree:
impurity[i],
n_node_samples[i],
weighted_n_node_samples[i],
True,
)
for i in range(node_count)
],
dtype=[
("left_child", "<i8"),
("right_child", "<i8"),
("feature", "<i8"),
("threshold", "<f8"),
("impurity", "<f8"),
("n_node_samples", "<i8"),
("weighted_n_node_samples", "<f8"),
],
dtype={
"names": [
"left_child",
"right_child",
"feature",
"threshold",
"impurity",
"n_node_samples",
"weighted_n_node_samples",
"missing_go_to_left",
],
"formats": ["<i8", "<i8", "<i8", "<f8", "<f8", "<i8", "<f8", "u1"],
},
)
state = {
"max_depth": self.max_depth,

View File

@ -12,7 +12,7 @@ tqdm<5
#
# Extras
#
scikit-learn>=0.22.1,<2
scikit-learn>=1.3,<2
xgboost>=0.90,<2
lightgbm>=2,<4

View File

@ -56,7 +56,7 @@ with open(path.join(here, "README.md"), "r", "utf-8") as f:
extras = {
"xgboost": ["xgboost>=0.90,<2"],
"scikit-learn": ["scikit-learn>=0.22.1,<2"],
"scikit-learn": ["scikit-learn>=1.3,<2"],
"lightgbm": ["lightgbm>=2,<4"],
"pytorch": [
"torch>=1.13.1,<2.0",

View File

@ -103,17 +103,31 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
class TestPytorchModel:
def __init__(self):
# quantization does not work on ARM processors
# TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
# revisit this when we upgrade to PyTorch 2.0.
import platform
self.quantize = (
True if platform.machine() not in ["arm64", "aarch64"] else False
)
@pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
def test_text_prediction(self, model_id, task, text_input, value):
with tempfile.TemporaryDirectory() as tmp_dir:
ptm = download_model_and_start_deployment(tmp_dir, True, model_id, task)
ptm = download_model_and_start_deployment(
tmp_dir, self.quantize, model_id, task
)
result = ptm.infer(docs=[{"text_field": text_input}])
assert result["predicted_value"] == value
@pytest.mark.parametrize("model_id,task,text_input", TEXT_EMBEDDING_MODELS)
def test_text_embedding(self, model_id, task, text_input):
with tempfile.TemporaryDirectory() as tmp_dir:
ptm = download_model_and_start_deployment(tmp_dir, True, model_id, task)
ptm = download_model_and_start_deployment(
tmp_dir, self.quantize, model_id, task
)
ptm.infer(docs=[{"text_field": text_input}])
if ES_VERSION >= (8, 8, 0):