mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Fix get_feature_id() for named feature 0
This commit is contained in:
parent
4576951f37
commit
013cab0162
@ -15,6 +15,7 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import re
|
||||
from typing import Optional, List, Dict, Any, Type
|
||||
from .base import ModelTransformer
|
||||
import pandas as pd # type: ignore
|
||||
@ -52,13 +53,13 @@ class XGBoostForestTransformer(ModelTransformer):
|
||||
self._feature_dict = dict(zip(feature_names, range(len(feature_names))))
|
||||
|
||||
def get_feature_id(self, feature_id: str) -> int:
|
||||
if feature_id[0] == "f":
|
||||
if re.match(r"^f[0-9]+$", feature_id):
|
||||
try:
|
||||
return int(feature_id[1:])
|
||||
except ValueError:
|
||||
raise RuntimeError(f"Unable to interpret '{feature_id}'")
|
||||
f_id = self._feature_dict.get(feature_id)
|
||||
if f_id:
|
||||
if f_id is not None:
|
||||
return f_id
|
||||
else:
|
||||
try:
|
||||
|
@ -302,7 +302,7 @@ class TestImportedMLModel:
|
||||
classifier.fit(training_data[0], training_data[1])
|
||||
|
||||
# Serialise the models to Elasticsearch
|
||||
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
||||
feature_names = ["feature0", "feature1", "feature2", "feature3", "feature4"]
|
||||
model_id = "test_xgb_classifier"
|
||||
|
||||
es_model = ImportedMLModel(
|
||||
|
Loading…
x
Reference in New Issue
Block a user