mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Fix get_feature_id() for named feature 0
This commit is contained in:
parent
4576951f37
commit
013cab0162
@ -15,6 +15,7 @@
|
|||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
|
import re
|
||||||
from typing import Optional, List, Dict, Any, Type
|
from typing import Optional, List, Dict, Any, Type
|
||||||
from .base import ModelTransformer
|
from .base import ModelTransformer
|
||||||
import pandas as pd # type: ignore
|
import pandas as pd # type: ignore
|
||||||
@ -52,13 +53,13 @@ class XGBoostForestTransformer(ModelTransformer):
|
|||||||
self._feature_dict = dict(zip(feature_names, range(len(feature_names))))
|
self._feature_dict = dict(zip(feature_names, range(len(feature_names))))
|
||||||
|
|
||||||
def get_feature_id(self, feature_id: str) -> int:
|
def get_feature_id(self, feature_id: str) -> int:
|
||||||
if feature_id[0] == "f":
|
if re.match(r"^f[0-9]+$", feature_id):
|
||||||
try:
|
try:
|
||||||
return int(feature_id[1:])
|
return int(feature_id[1:])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise RuntimeError(f"Unable to interpret '{feature_id}'")
|
raise RuntimeError(f"Unable to interpret '{feature_id}'")
|
||||||
f_id = self._feature_dict.get(feature_id)
|
f_id = self._feature_dict.get(feature_id)
|
||||||
if f_id:
|
if f_id is not None:
|
||||||
return f_id
|
return f_id
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
@ -302,7 +302,7 @@ class TestImportedMLModel:
|
|||||||
classifier.fit(training_data[0], training_data[1])
|
classifier.fit(training_data[0], training_data[1])
|
||||||
|
|
||||||
# Serialise the models to Elasticsearch
|
# Serialise the models to Elasticsearch
|
||||||
feature_names = ["f0", "f1", "f2", "f3", "f4"]
|
feature_names = ["feature0", "feature1", "feature2", "feature3", "feature4"]
|
||||||
model_id = "test_xgb_classifier"
|
model_id = "test_xgb_classifier"
|
||||||
|
|
||||||
es_model = ImportedMLModel(
|
es_model = ImportedMLModel(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user