Fix get_feature_id() for named feature 0

This commit is contained in:
Seth Michael Larson 2020-08-18 10:58:36 -05:00 committed by GitHub
parent 4576951f37
commit 013cab0162
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 3 deletions

View File

@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import re
from typing import Optional, List, Dict, Any, Type
from .base import ModelTransformer
import pandas as pd # type: ignore
@@ -52,13 +53,13 @@ class XGBoostForestTransformer(ModelTransformer):
self._feature_dict = dict(zip(feature_names, range(len(feature_names))))
def get_feature_id(self, feature_id: str) -> int:
if feature_id[0] == "f":
if re.match(r"^f[0-9]+$", feature_id):
try:
return int(feature_id[1:])
except ValueError:
raise RuntimeError(f"Unable to interpret '{feature_id}'")
f_id = self._feature_dict.get(feature_id)
if f_id:
if f_id is not None:
return f_id
else:
try:

View File

@@ -302,7 +302,7 @@ class TestImportedMLModel:
classifier.fit(training_data[0], training_data[1])
# Serialise the models to Elasticsearch
feature_names = ["f0", "f1", "f2", "f3", "f4"]
feature_names = ["feature0", "feature1", "feature2", "feature3", "feature4"]
model_id = "test_xgb_classifier"
es_model = ImportedMLModel(