Fix get_feature_id() for named feature 0

This commit is contained in:
Seth Michael Larson 2020-08-18 10:58:36 -05:00 committed by GitHub
parent 4576951f37
commit 013cab0162
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 3 deletions

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import re
from typing import Optional, List, Dict, Any, Type from typing import Optional, List, Dict, Any, Type
from .base import ModelTransformer from .base import ModelTransformer
import pandas as pd # type: ignore import pandas as pd # type: ignore
@ -52,13 +53,13 @@ class XGBoostForestTransformer(ModelTransformer):
self._feature_dict = dict(zip(feature_names, range(len(feature_names)))) self._feature_dict = dict(zip(feature_names, range(len(feature_names))))
def get_feature_id(self, feature_id: str) -> int: def get_feature_id(self, feature_id: str) -> int:
if feature_id[0] == "f": if re.match(r"^f[0-9]+$", feature_id):
try: try:
return int(feature_id[1:]) return int(feature_id[1:])
except ValueError: except ValueError:
raise RuntimeError(f"Unable to interpret '{feature_id}'") raise RuntimeError(f"Unable to interpret '{feature_id}'")
f_id = self._feature_dict.get(feature_id) f_id = self._feature_dict.get(feature_id)
if f_id: if f_id is not None:
return f_id return f_id
else: else:
try: try:

View File

@ -302,7 +302,7 @@ class TestImportedMLModel:
classifier.fit(training_data[0], training_data[1]) classifier.fit(training_data[0], training_data[1])
# Serialise the models to Elasticsearch # Serialise the models to Elasticsearch
feature_names = ["f0", "f1", "f2", "f3", "f4"] feature_names = ["feature0", "feature1", "feature2", "feature3", "feature4"]
model_id = "test_xgb_classifier" model_id = "test_xgb_classifier"
es_model = ImportedMLModel( es_model = ImportedMLModel(