diff --git a/eland/ml/exporters/_sklearn_deserializers.py b/eland/ml/exporters/_sklearn_deserializers.py index 086038b..6df5df7 100644 --- a/eland/ml/exporters/_sklearn_deserializers.py +++ b/eland/ml/exporters/_sklearn_deserializers.py @@ -97,18 +97,23 @@ class Tree: impurity[i], n_node_samples[i], weighted_n_node_samples[i], + True, ) for i in range(node_count) ], - dtype=[ - ("left_child", "=0.22.1,<2 +scikit-learn>=1.3,<2 xgboost>=0.90,<2 lightgbm>=2,<4 diff --git a/setup.py b/setup.py index 63ab3c7..ffb96c2 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ with open(path.join(here, "README.md"), "r", "utf-8") as f: extras = { "xgboost": ["xgboost>=0.90,<2"], - "scikit-learn": ["scikit-learn>=0.22.1,<2"], + "scikit-learn": ["scikit-learn>=1.3,<2"], "lightgbm": ["lightgbm>=2,<4"], "pytorch": [ "torch>=1.13.1,<2.0", diff --git a/tests/ml/pytorch/test_pytorch_model_upload_pytest.py b/tests/ml/pytorch/test_pytorch_model_upload_pytest.py index 5722858..36c4086 100644 --- a/tests/ml/pytorch/test_pytorch_model_upload_pytest.py +++ b/tests/ml/pytorch/test_pytorch_model_upload_pytest.py @@ -103,17 +103,31 @@ class TestPytorchModel: + def setup_method(self): + # quantization does not work on ARM processors + # TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should + # revisit this when we upgrade to PyTorch 2.0. 
+ import platform + + self.quantize = ( + platform.machine() not in ["arm64", "aarch64"] + ) + @pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS) def test_text_prediction(self, model_id, task, text_input, value): with tempfile.TemporaryDirectory() as tmp_dir: - ptm = download_model_and_start_deployment(tmp_dir, True, model_id, task) + ptm = download_model_and_start_deployment( + tmp_dir, self.quantize, model_id, task + ) result = ptm.infer(docs=[{"text_field": text_input}]) assert result["predicted_value"] == value @pytest.mark.parametrize("model_id,task,text_input", TEXT_EMBEDDING_MODELS) def test_text_embedding(self, model_id, task, text_input): with tempfile.TemporaryDirectory() as tmp_dir: - ptm = download_model_and_start_deployment(tmp_dir, True, model_id, task) + ptm = download_model_and_start_deployment( + tmp_dir, self.quantize, model_id, task + ) ptm.infer(docs=[{"text_field": text_input}]) if ES_VERSION >= (8, 8, 0):