From 013cab0162e901ad1ca8868a349efd7041c31a64 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Tue, 18 Aug 2020 10:58:36 -0500 Subject: [PATCH] Fix get_feature_id() for named feature 0 --- eland/ml/transformers/xgboost.py | 5 +++-- eland/tests/ml/test_imported_ml_model_pytest.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/eland/ml/transformers/xgboost.py b/eland/ml/transformers/xgboost.py index c5712e2..1f092c1 100644 --- a/eland/ml/transformers/xgboost.py +++ b/eland/ml/transformers/xgboost.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import re from typing import Optional, List, Dict, Any, Type from .base import ModelTransformer import pandas as pd # type: ignore @@ -52,13 +53,13 @@ class XGBoostForestTransformer(ModelTransformer): self._feature_dict = dict(zip(feature_names, range(len(feature_names)))) def get_feature_id(self, feature_id: str) -> int: - if feature_id[0] == "f": + if re.match(r"^f[0-9]+$", feature_id): try: return int(feature_id[1:]) except ValueError: raise RuntimeError(f"Unable to interpret '{feature_id}'") f_id = self._feature_dict.get(feature_id) - if f_id: + if f_id is not None: return f_id else: try: diff --git a/eland/tests/ml/test_imported_ml_model_pytest.py b/eland/tests/ml/test_imported_ml_model_pytest.py index 99d87ed..493e3e2 100644 --- a/eland/tests/ml/test_imported_ml_model_pytest.py +++ b/eland/tests/ml/test_imported_ml_model_pytest.py @@ -302,7 +302,7 @@ class TestImportedMLModel: classifier.fit(training_data[0], training_data[1]) # Serialise the models to Elasticsearch - feature_names = ["f0", "f1", "f2", "f3", "f4"] + feature_names = ["feature0", "feature1", "feature2", "feature3", "feature4"] model_id = "test_xgb_classifier" es_model = ImportedMLModel(