ML add external models (#125)

* Partial implementation of ml.ExternalModel

* Adding eland.ml.ExternalMLModel

More testing to be added, plus broader MLModel support.
stevedodson 2020-02-15 15:54:29 +01:00 committed by GitHub
parent 4ac67a73ea
commit 7c1c2945a7
21 changed files with 1238 additions and 39 deletions


@@ -38,7 +38,7 @@ sys.path.extend(
# -- Project information -----------------------------------------------------
project = 'eland'
copyright = '2019, Elasticsearch B.V.'
copyright = '2020, Elasticsearch B.V.'
# The full version, including alpha/beta/rc tags
import eland


@@ -753,7 +753,7 @@
{
"data": {
"text/plain": [
"<eland.index.Index at 0x11a122310>"
"<eland.index.Index at 0x11631ffd0>"
]
},
"execution_count": 17,
@@ -2704,10 +2704,10 @@
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>410.008918</td>\n",
" <td>410.011039</td>\n",
" <td>2470.545974</td>\n",
" <td>...</td>\n",
" <td>251.682199</td>\n",
" <td>251.773003</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" <tr>\n",
@@ -2720,11 +2720,11 @@
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>842.233478</td>\n",
" <td>9735.660463</td>\n",
" <td>842.213490</td>\n",
" <td>9734.960478</td>\n",
" <td>...</td>\n",
" <td>720.534532</td>\n",
" <td>4.288079</td>\n",
" <td>720.505705</td>\n",
" <td>4.172535</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
@@ -2745,9 +2745,9 @@
"mean 628.253689 7092.142457 ... 511.127842 2.835975\n",
"std 266.386661 4578.263193 ... 334.741135 1.939365\n",
"min 100.020531 0.000000 ... 0.000000 0.000000\n",
"25% 410.008918 2470.545974 ... 251.682199 1.000000\n",
"25% 410.011039 2470.545974 ... 251.773003 1.000000\n",
"50% 640.387285 7612.072403 ... 503.148975 3.000000\n",
"75% 842.233478 9735.660463 ... 720.534532 4.288079\n",
"75% 842.213490 9734.960478 ... 720.505705 4.172535\n",
"max 1199.729004 19881.482422 ... 1902.901978 6.000000\n",
"\n",
"[8 rows x 7 columns]"
@@ -3676,11 +3676,11 @@
" is_source_field: False\n",
"Mappings:\n",
" capabilities:\n",
" es_field_name is_source es_dtype es_date_format pd_dtype is_searchable is_aggregatable is_scripted aggregatable_es_field_name\n",
"timestamp timestamp True date None datetime64[ns] True True False timestamp\n",
"OriginAirportID OriginAirportID True keyword None object True True False OriginAirportID\n",
"DestAirportID DestAirportID True keyword None object True True False DestAirportID\n",
"FlightDelayMin FlightDelayMin True integer None int64 True True False FlightDelayMin\n",
" es_field_name is_source es_dtype es_date_format pd_dtype is_searchable is_aggregatable is_scripted aggregatable_es_field_name\n",
"timestamp timestamp True date strict_date_hour_minute_second datetime64[ns] True True False timestamp\n",
"OriginAirportID OriginAirportID True keyword None object True True False OriginAirportID\n",
"DestAirportID DestAirportID True keyword None object True True False DestAirportID\n",
"FlightDelayMin FlightDelayMin True integer None int64 True True False FlightDelayMin\n",
"Operations:\n",
" tasks: [('boolean_filter': ('boolean_filter': {'bool': {'must': [{'term': {'OriginAirportID': 'AMS'}}, {'range': {'FlightDelayMin': {'gt': 60}}}]}})), ('tail': ('sort_field': '_doc', 'count': 5))]\n",
" size: 5\n",


@@ -1023,21 +1023,21 @@
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>14221.960201</td>\n",
" <td>14217.474239</td>\n",
" <td>1.000000</td>\n",
" <td>1.250000</td>\n",
" <td>1.250068</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>15671.712170</td>\n",
" <td>15662.024630</td>\n",
" <td>2.000000</td>\n",
" <td>2.510000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>17214.376367</td>\n",
" <td>6.615042</td>\n",
" <td>4.210533</td>\n",
" <td>17212.723881</td>\n",
" <td>6.671951</td>\n",
" <td>4.210000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
@@ -1055,9 +1055,9 @@
"mean 15590.776680 7.464000 4.103233\n",
"std 1764.025160 85.924387 20.104873\n",
"min 12347.000000 -9360.000000 0.000000\n",
"25% 14221.960201 1.000000 1.250000\n",
"50% 15671.712170 2.000000 2.510000\n",
"75% 17214.376367 6.615042 4.210533\n",
"25% 14217.474239 1.000000 1.250068\n",
"50% 15662.024630 2.000000 2.510000\n",
"75% 17212.723881 6.671951 4.210000\n",
"max 18239.000000 2880.000000 950.990000"
]
},


@@ -35,6 +35,7 @@ In general, the data resides in elasticsearch and not in memory, which allows el
* :doc:`reference/dataframe`
* :doc:`reference/series`
* :doc:`reference/indexing`
* :doc:`reference/ml`
* :doc:`implementation/index`


@@ -0,0 +1,6 @@
eland.ml.ExternalMLModel.predict
================================

.. currentmodule:: eland.ml

.. automethod:: ExternalMLModel.predict


@@ -0,0 +1,6 @@
eland.ml.ExternalMLModel
========================

.. currentmodule:: eland.ml

.. autoclass:: ExternalMLModel


@@ -15,3 +15,4 @@ methods. All classes and functions exposed in ``eland.*`` namespace are public.
dataframe
series
indexing
ml


@@ -0,0 +1,25 @@
.. _api.ml:

================
Machine Learning
================

.. currentmodule:: eland.ml

ExternalMLModel
~~~~~~~~~~~~~~~

.. currentmodule:: eland.ml

Constructor
^^^^^^^^^^^

.. autosummary::
   :toctree: api/

   ExternalMLModel

Learning API
^^^^^^^^^^^^

.. autosummary::
   :toctree: api/

   ExternalMLModel.predict


@@ -27,3 +27,4 @@ from eland.ndframe import *
from eland.series import *
from eland.dataframe import *
from eland.utils import *


@@ -56,3 +56,6 @@ class Client:
def count(self, **kwargs):
count_json = self._es.count(**kwargs)
return count_json['count']
def perform_request(self, method, url, headers=None, params=None, body=None):
return self._es.transport.perform_request(method, url, headers, params, body)
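
The new pass-through exposes any Elasticsearch REST endpoint through eland's Client. A minimal sketch, assuming a local cluster with the test flights index loaded:

from elasticsearch import Elasticsearch
import eland as ed

client = ed.Client(Elasticsearch('localhost'))

# Equivalent to `GET /_cat/indices/flights` against the cluster
print(client.perform_request("GET", "/_cat/indices/flights"))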

eland/ml/__init__.py

@@ -0,0 +1,16 @@
# Copyright 2019 Elasticsearch BV
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from eland.ml.ml_model import *
from eland.ml.external_ml_model import *


@@ -0,0 +1,116 @@
import base64
import gzip
import json
from abc import ABC
from typing import List
def add_if_exists(d: dict, k: str, v) -> dict:
if v is not None:
d[k] = v
return d
class ModelSerializer(ABC):
def __init__(self,
feature_names: List[str],
target_type: str = None,
classification_labels: List[str] = None):
self._target_type = target_type
self._feature_names = feature_names
self._classification_labels = classification_labels
def to_dict(self):
d = dict()
add_if_exists(d, "target_type", self._target_type)
add_if_exists(d, "feature_names", self._feature_names)
add_if_exists(d, "classification_labels", self._classification_labels)
return d
@property
def feature_names(self):
return self._feature_names
def serialize_and_compress_model(self) -> str:
json_string = json.dumps({'trained_model': self.to_dict()})
# gzip then base64-encode, decoding to str to match the declared return type
return base64.b64encode(gzip.compress(bytes(json_string, 'utf-8'))).decode('utf-8')
class TreeNode:
def __init__(self,
node_idx: int,
default_left: bool = None,
decision_type: str = None,
left_child: int = None,
right_child: int = None,
split_feature: int = None,
threshold: float = None,
leaf_value: float = None):
self._node_idx = node_idx
self._decision_type = decision_type
self._left_child = left_child
self._right_child = right_child
self._split_feature = split_feature
self._threshold = threshold
self._leaf_value = leaf_value
self._default_left = default_left
def to_dict(self):
d = dict()
add_if_exists(d, 'node_index', self._node_idx)
add_if_exists(d, 'decision_type', self._decision_type)
if self._leaf_value is None:
add_if_exists(d, 'left_child', self._left_child)
add_if_exists(d, 'right_child', self._right_child)
add_if_exists(d, 'split_feature', self._split_feature)
add_if_exists(d, 'threshold', self._threshold)
else:
add_if_exists(d, 'leaf_value', self._leaf_value)
return d
class Tree(ModelSerializer):
def __init__(self,
feature_names: List[str],
target_type: str = None,
tree_structure: List[TreeNode] = None,  # avoid a shared mutable default
classification_labels: List[str] = None):
super().__init__(
feature_names=feature_names,
target_type=target_type,
classification_labels=classification_labels
)
if target_type == 'regression' and classification_labels:
raise ValueError("regression does not support classification_labels")
self._tree_structure = tree_structure if tree_structure is not None else []
def to_dict(self):
d = super().to_dict()
add_if_exists(d, 'tree_structure', [t.to_dict() for t in self._tree_structure])
return {'tree': d}
class Ensemble(ModelSerializer):
def __init__(self,
feature_names: List[str],
trained_models: List[ModelSerializer],
output_aggregator: dict,
target_type: str = None,
classification_labels: List[str] = None,
classification_weights: List[float] = None):
super().__init__(feature_names=feature_names,
target_type=target_type,
classification_labels=classification_labels)
self._trained_models = trained_models
self._classification_weights = classification_weights
self._output_aggregator = output_aggregator
def to_dict(self):
d = super().to_dict()
trained_models = None
if self._trained_models:
trained_models = [t.to_dict() for t in self._trained_models]
add_if_exists(d, 'trained_models', trained_models)
add_if_exists(d, 'classification_weights', self._classification_weights)
add_if_exists(d, 'aggregate_output', self._output_aggregator)
return {'ensemble': d}
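
As a sanity check on the wire format (not part of this change), a sketch that builds a one-leaf tree and reverses the gzip/base64 steps:

import base64
import gzip
import json

from eland.ml._model_serializer import Tree, TreeNode

# A one-node regression "tree": a single leaf that always predicts 4.2
stump = Tree(feature_names=["f0"],
             target_type="regression",
             tree_structure=[TreeNode(0, leaf_value=4.2)])

# Undo base64 then gzip to inspect the JSON document sent to Elasticsearch
payload = stump.serialize_and_compress_model()
print(json.loads(gzip.decompress(base64.b64decode(payload))))
# {'trained_model': {'tree': {'target_type': 'regression',
#                             'feature_names': ['f0'],
#                             'tree_structure': [{'node_index': 0, 'leaf_value': 4.2}]}}}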


@@ -0,0 +1,389 @@
# Copyright 2019 Elasticsearch BV
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import numpy as np
from eland.ml._optional import import_optional_dependency
from eland.ml._model_serializer import Tree, TreeNode, Ensemble
sklearn = import_optional_dependency("sklearn")
xgboost = import_optional_dependency("xgboost")
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils.validation import check_is_fitted
from xgboost import Booster, XGBRegressor, XGBClassifier
class ModelTransformer:
def __init__(self,
model,
feature_names: List[str],
classification_labels: List[str] = None,
classification_weights: List[float] = None
):
self._feature_names = feature_names
self._model = model
self._classification_labels = classification_labels
self._classification_weights = classification_weights
def is_supported(self):
return isinstance(self._model, (DecisionTreeClassifier,
DecisionTreeRegressor,
RandomForestRegressor,
RandomForestClassifier,
XGBClassifier,
XGBRegressor,
Booster))
class SKLearnTransformer(ModelTransformer):
"""
Base class for SKLearn transformers.
warning: do not use this class directly; use derived classes instead
"""
def __init__(self,
model,
feature_names: List[str],
classification_labels: List[str] = None,
classification_weights: List[float] = None
):
"""
Base class for SKLearn transformations
:param model: sklearn trained model
:param feature_names: The feature names for the model
:param classification_labels: Optional classification labels (if not encoded in the model)
:param classification_weights: Optional classification weights
"""
super().__init__(model, feature_names, classification_labels, classification_weights)
self._node_decision_type = "lte"
def build_tree_node(self, node_index: int, node_data: dict, value) -> TreeNode:
"""
This builds out a TreeNode class given the sklearn tree node definition.
Node decision types are defaulted to "lte" to match the behavior of SKLearn
:param node_index: The node index
:param node_data: Opaque node data contained in the sklearn tree state
:param value: Opaque node value (i.e. leaf/node values) from tree state
:return: TreeNode object
"""
if value.shape[0] != 1:
raise ValueError("unexpected multiple values returned from leaf node '{0}'".format(node_index))
if node_data[0] == -1: # is leaf node
if value.shape[1] == 1: # classification requires more than one value, so assume regression
leaf_value = float(value[0][0])
else:
# the classification value, which is the index of the largest value
leaf_value = int(np.argmax(value))
return TreeNode(node_index, decision_type=self._node_decision_type, leaf_value=leaf_value)
else:
return TreeNode(node_index,
decision_type=self._node_decision_type,
left_child=int(node_data[0]),
right_child=int(node_data[1]),
split_feature=int(node_data[2]),
threshold=float(node_data[3]))
class SKLearnDecisionTreeTransformer(SKLearnTransformer):
"""
class for transforming SKLearn decision tree models into Tree model formats supported by Elasticsearch.
"""
def __init__(self,
model: Union[DecisionTreeRegressor, DecisionTreeClassifier],
feature_names: List[str],
classification_labels: List[str] = None):
"""
Transforms a Decision Tree model (Regressor|Classifier) into an ES-supported Tree format
:param model: fitted decision tree model
:param feature_names: model feature names
:param classification_labels: Optional classification labels
"""
super().__init__(model, feature_names, classification_labels)
def transform(self) -> Tree:
"""
Transform the provided model into an ES supported Tree object
:return: Tree object for ES storage and use
"""
target_type = "regression" if isinstance(self._model, DecisionTreeRegressor) else "classification"
check_is_fitted(self._model, ["tree_"])
tree_classes = None
if self._classification_labels:
tree_classes = self._classification_labels
if isinstance(self._model, DecisionTreeClassifier):
check_is_fitted(self._model, ["classes_"])
if tree_classes is None:
tree_classes = [str(c) for c in self._model.classes_]
nodes = list()
tree_state = self._model.tree_.__getstate__()
for i in range(len(tree_state["nodes"])):
nodes.append(self.build_tree_node(i, tree_state["nodes"][i], tree_state["values"][i]))
return Tree(self._feature_names,
target_type,
nodes,
tree_classes)
class SKLearnForestTransformer(SKLearnTransformer):
"""
Base class for transforming SKLearn forest models into Ensemble model formats supported by Elasticsearch.
warning: do not use this class directly. Use a derived class instead
"""
def __init__(self,
model: Union[RandomForestClassifier,
RandomForestRegressor],
feature_names: List[str],
classification_labels: List[str] = None,
classification_weights: List[float] = None
):
super().__init__(model, feature_names, classification_labels, classification_weights)
def build_aggregator_output(self) -> dict:
raise NotImplementedError("build_aggregator_output must be implemented")
def determine_target_type(self) -> str:
raise NotImplementedError("determine_target_type must be implemented")
def transform(self) -> Ensemble:
check_is_fitted(self._model, ["estimators_"])
estimators = self._model.estimators_
ensemble_classes = None
if self._classification_labels:
ensemble_classes = self._classification_labels
if isinstance(self._model, RandomForestClassifier):
check_is_fitted(self._model, ["classes_"])
if ensemble_classes is None:
ensemble_classes = [str(c) for c in self._model.classes_]
ensemble_models = [SKLearnDecisionTreeTransformer(m,
self._feature_names).transform() for m in estimators]
return Ensemble(self._feature_names,
ensemble_models,
self.build_aggregator_output(),
target_type=self.determine_target_type(),
classification_labels=ensemble_classes,
classification_weights=self._classification_weights)
class SKLearnForestRegressorTransformer(SKLearnForestTransformer):
"""
Class for transforming RandomForestRegressor models into an ensemble model supported by Elasticsearch
"""
def __init__(self,
model: RandomForestRegressor,
feature_names: List[str]
):
super().__init__(model, feature_names)
def build_aggregator_output(self) -> dict:
return {
"weighted_sum": {"weights": [1.0 / len(self._model.estimators_)] * len(self._model.estimators_), }
}
def determine_target_type(self) -> str:
return "regression"
class SKLearnForestClassifierTransformer(SKLearnForestTransformer):
"""
Class for transforming RandomForestClassifier models into an ensemble model supported by Elasticsearch
"""
def __init__(self,
model: RandomForestClassifier,
feature_names: List[str],
classification_labels: List[str] = None,
):
super().__init__(model, feature_names, classification_labels)
def build_aggregator_output(self) -> dict:
return {"weighted_mode": {"num_classes": len(self._model.classes_)}}
def determine_target_type(self) -> str:
return "classification"
class XGBoostForestTransformer(ModelTransformer):
"""
Base class for transforming XGBoost models into ensemble models supported by Elasticsearch
warning: do not use directly. Use a derived class instead
"""
def __init__(self,
model: Booster,
feature_names: List[str],
base_score: float = 0.5,
objective: str = "reg:squarederror",
classification_labels: List[str] = None,
classification_weights: List[float] = None
):
super().__init__(model, feature_names, classification_labels, classification_weights)
self._node_decision_type = "lt"
self._base_score = base_score
self._objective = objective
def get_feature_id(self, feature_id: str) -> int:
if feature_id[0] == "f":
try:
return int(feature_id[1:])
except ValueError:
raise RuntimeError("Unable to interpret '{0}'".format(feature_id))
else:
try:
return int(feature_id)
except ValueError:
raise RuntimeError("Unable to interpret '{0}'".format(feature_id))
def extract_node_id(self, node_id: str, curr_tree: int) -> int:
parts = node_id.split("-")
if len(parts) != 2:
raise RuntimeError(
"cannot determine node index or tree from '{0}' for tree {1}".format(node_id, curr_tree))
t_id, n_id = parts
try:
t_id = int(t_id)
n_id = int(n_id)
if t_id != curr_tree:
raise RuntimeError("extracted tree id {0} does not match current tree {1}".format(t_id, curr_tree))
return n_id
except ValueError:
raise RuntimeError(
"cannot determine node index or tree from '{0}' for tree {1}".format(node_id, curr_tree))
def build_tree_node(self, row, curr_tree: int) -> TreeNode:
node_index = row["Node"]
if row["Feature"] == "Leaf":
return TreeNode(node_idx=node_index, leaf_value=float(row["Gain"]))
else:
return TreeNode(node_idx=node_index,
decision_type=self._node_decision_type,
left_child=self.extract_node_id(row["Yes"], curr_tree),
right_child=self.extract_node_id(row["No"], curr_tree),
threshold=float(row["Split"]),
split_feature=self.get_feature_id(row["Feature"]))
def build_tree(self, nodes: List[TreeNode]) -> Tree:
return Tree(feature_names=self._feature_names,
tree_structure=nodes)
def build_base_score_stump(self) -> Tree:
return Tree(feature_names=self._feature_names,
tree_structure=[TreeNode(0, leaf_value=self._base_score)])
def build_forest(self) -> List[Tree]:
"""
This builds out the forest of trees as described by XGBoost into a format
supported by Elasticsearch
:return: A list of Tree objects
"""
if self._model.booster not in {'dart', 'gbtree'}:
raise ValueError("booster must exist and be of type dart or gbtree")
tree_table = self._model.trees_to_dataframe()
transformed_trees = list()
curr_tree = None
tree_nodes = list()
for _, row in tree_table.iterrows():
if row["Tree"] != curr_tree:
if len(tree_nodes) > 0:
transformed_trees.append(self.build_tree(tree_nodes))
curr_tree = row["Tree"]
tree_nodes = list()
tree_nodes.append(self.build_tree_node(row, curr_tree))
# add last tree
if len(tree_nodes) > 0:
transformed_trees.append(self.build_tree(tree_nodes))
# We add this stump as XGBoost adds the base_score to the regression outputs
if self._objective.startswith("reg"):
transformed_trees.append(self.build_base_score_stump())
return transformed_trees
def build_aggregator_output(self) -> dict:
raise NotImplementedError("build_aggregator_output must be implemented")
def determine_target_type(self) -> str:
raise NotImplementedError("determine_target_type must be implemented")
def is_objective_supported(self) -> bool:
return False
def transform(self) -> Ensemble:
if self._model.booster not in {'dart', 'gbtree'}:
raise ValueError("booster must exist and be of type dart or gbtree")
if not self.is_objective_supported():
raise ValueError("Unsupported objective '{0}'".format(self._objective))
forest = self.build_forest()
return Ensemble(feature_names=self._feature_names,
trained_models=forest,
output_aggregator=self.build_aggregator_output(),
classification_labels=self._classification_labels,
classification_weights=self._classification_weights,
target_type=self.determine_target_type())
class XGBoostRegressorTransformer(XGBoostForestTransformer):
def __init__(self,
model: XGBRegressor,
feature_names: List[str]):
super().__init__(model.get_booster(),
feature_names,
model.base_score,
model.objective)
def determine_target_type(self) -> str:
return "regression"
def is_objective_supported(self) -> bool:
return self._objective in {'reg:squarederror',
'reg:linear',
'reg:squaredlogerror',
'reg:logistic'}
def build_aggregator_output(self) -> dict:
return {"weighted_sum": {}}
class XGBoostClassifierTransformer(XGBoostForestTransformer):
def __init__(self,
model: XGBClassifier,
feature_names: List[str],
classification_labels: List[str] = None):
super().__init__(model.get_booster(),
feature_names,
model.base_score,
model.objective,
classification_labels)
def determine_target_type(self) -> str:
return "classification"
def is_objective_supported(self) -> bool:
return self._objective in {'binary:logistic', 'binary:hinge'}
def build_aggregator_output(self) -> dict:
return {"logistic_regression": {}}

eland/ml/_optional.py

@@ -0,0 +1,115 @@
# Copyright 2019 Elasticsearch BV
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import distutils.version
import importlib
import types
import warnings
# ----------------------------------------------------------------------------
# Much of the code in this module comes from pandas, whose compat helpers are in
# turn largely based on the six module.
# The license for pandas can be found in NOTICE.txt and the original code at:
# https://raw.githubusercontent.com/pandas-dev/pandas/v1.0.1/pandas/compat/_optional.py
VERSIONS = {
"xgboost": "0.90",
"sklearn": "0.22.1"
}
# Update install.rst when updating versions!
message = (
"Missing optional dependency '{name}'. {extra} "
"Use pip or conda to install {name}."
)
version_message = (
"Eland requires version '{minimum_version}' or newer of '{name}' "
"(version '{actual_version}' currently installed). "
"Use pip or conda to update {name}."
)
def _get_version(module: types.ModuleType) -> str:
version = getattr(module, "__version__", None)
if version is None:
# xlrd uses a capitalized attribute name
version = getattr(module, "__VERSION__", None)
if version is None:
raise ImportError("Can't determine version for {}".format(module.__name__))
return version
def import_optional_dependency(
name: str, extra: str = "", raise_on_missing: bool = True, on_version: str = "raise"
):
"""
Import an optional dependency.
By default, if a dependency is missing an ImportError with a nice
message will be raised. If a dependency is present, but too old,
we raise.
Parameters
----------
name : str
The module name. This should be top-level only, so that the
version may be checked.
extra : str
Additional text to include in the ImportError message.
raise_on_missing : bool, default True
Whether to raise if the optional dependency is not found.
When False and the module is not present, None is returned.
on_version : str {'raise', 'warn'}
What to do when a dependency's version is too old.
* raise : Raise an ImportError
* warn : Warn that the version is too old. Returns None
* ignore: Return the module, even if the version is too old.
It's expected that users validate the version locally when
using ``on_version="ignore"``.
Returns
-------
maybe_module : Optional[ModuleType]
The imported module, when found and the version is correct.
None is returned when the package is not found and `raise_on_missing`
is False, or when the package's version is too old and `on_version`
is ``'warn'``.
"""
try:
module = importlib.import_module(name)
except ImportError:
if raise_on_missing:
raise ImportError(message.format(name=name, extra=extra)) from None
else:
return None
minimum_version = VERSIONS.get(name)
if minimum_version:
version = _get_version(module)
if distutils.version.LooseVersion(version) < minimum_version:
assert on_version in {"warn", "raise", "ignore"}
msg = version_message.format(
minimum_version=minimum_version, name=name, actual_version=version
)
if on_version == "warn":
warnings.warn(msg, UserWarning)
return None
elif on_version == "raise":
raise ImportError(msg)
return module
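
Typical use of the helper, as the ml modules above do ("lightgbm" below is purely illustrative):

from eland.ml._optional import import_optional_dependency

# Raises a helpful ImportError if xgboost is missing or older than VERSIONS["xgboost"]
xgboost = import_optional_dependency("xgboost")

# Soft import: returns None instead of raising when the package is absent
maybe_lightgbm = import_optional_dependency("lightgbm", raise_on_missing=False)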


@@ -0,0 +1,238 @@
# Copyright 2020 Elasticsearch BV
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, List
import numpy as np
from eland.ml._model_transformers import SKLearnDecisionTreeTransformer, SKLearnForestRegressorTransformer, \
SKLearnForestClassifierTransformer, XGBoostRegressorTransformer, XGBoostClassifierTransformer
from eland.ml._optional import import_optional_dependency
from eland.ml.ml_model import MLModel
sklearn = import_optional_dependency("sklearn")
xgboost = import_optional_dependency("xgboost")
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from xgboost import XGBRegressor, XGBClassifier
class ExternalMLModel(MLModel):
"""
Put a trained inference model in Elasticsearch, based on an external model.
An external model is one trained outside Elasticsearch (e.g. with scikit-learn
or xgboost) that is transformed and added to Elasticsearch.
Parameters
----------
es_client: Elasticsearch client argument(s)
- elasticsearch-py parameters or
- elasticsearch-py instance or
- eland.Client instance
model_id: str
The unique identifier of the trained inference model in Elasticsearch.
model: An instance of a supported python model. We support the following model types:
- sklearn.tree.DecisionTreeClassifier
- sklearn.tree.DecisionTreeRegressor
- sklearn.ensemble.RandomForestRegressor
- sklearn.ensemble.RandomForestClassifier
- xgboost.XGBClassifier
- xgboost.XGBRegressor
feature_names: List[str]
Names of the features (required)
classification_labels: List[str]
Labels of the classification targets
classification_weights: List[str]
Weights of the classification targets
overwrite: bool
Delete and overwrite existing model (if exists)
Examples
--------
>>> from sklearn import datasets
>>> from sklearn.tree import DecisionTreeClassifier
>>> from eland.ml import ExternalMLModel
>>> # Train model
>>> training_data = datasets.make_classification(n_features=5, random_state=0)
>>> test_data = [[-50.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
>>> classifier = DecisionTreeClassifier()
>>> classifier = classifier.fit(training_data[0], training_data[1])
>>> # Get some test results
>>> classifier.predict(test_data)
array([0, 1])
>>> # Serialise the model to Elasticsearch
>>> feature_names = ["f0", "f1", "f2", "f3", "f4"]
>>> model_id = "test_decision_tree_classifier"
>>> es_model = ExternalMLModel('localhost', model_id, classifier, feature_names, overwrite=True)
>>> # Get some test results from Elasticsearch model
>>> es_model.predict(test_data)
array([0, 1])
>>> # Delete model from Elasticsearch
>>> es_model.delete_model()
"""
def __init__(self,
es_client,
model_id: str,
model: Union[DecisionTreeClassifier,
DecisionTreeRegressor,
RandomForestRegressor,
RandomForestClassifier,
XGBClassifier,
XGBRegressor],
feature_names: List[str],
classification_labels: List[str] = None,
classification_weights: List[float] = None,
overwrite=False):
super().__init__(
es_client,
model_id
)
self._feature_names = feature_names
self._model_type = None
# Transform model
if isinstance(model, DecisionTreeRegressor):
serializer = SKLearnDecisionTreeTransformer(model, feature_names).transform()
self._model_type = MLModel.TYPE_REGRESSION
elif isinstance(model, DecisionTreeClassifier):
serializer = SKLearnDecisionTreeTransformer(model, feature_names, classification_labels).transform()
self._model_type = MLModel.TYPE_CLASSIFICATION
elif isinstance(model, RandomForestRegressor):
serializer = SKLearnForestRegressorTransformer(model, feature_names).transform()
self._model_type = MLModel.TYPE_REGRESSION
elif isinstance(model, RandomForestClassifier):
serializer = SKLearnForestClassifierTransformer(model, feature_names, classification_labels).transform()
self._model_type = MLModel.TYPE_CLASSIFICATION
elif isinstance(model, XGBRegressor):
serializer = XGBoostRegressorTransformer(model, feature_names).transform()
self._model_type = MLModel.TYPE_REGRESSION
elif isinstance(model, XGBClassifier):
serializer = XGBoostClassifierTransformer(model, feature_names, classification_labels).transform()
self._model_type = MLModel.TYPE_CLASSIFICATION
else:
raise NotImplementedError("ML model of type {}, not currently implemented".format(type(model)))
if overwrite:
self.delete_model()
serialized_model = serializer.serialize_and_compress_model()
self._client.perform_request(
"PUT", "/_ml/inference/" + self._model_id,
body={
"input": {
"field_names": feature_names
},
"compressed_definition": serialized_model
}
)
def predict(self, X):
"""
Make a prediction using a trained inference model in Elasticsearch.
Parameters for this method are not fully compatible with standard sklearn.predict.
Parameters
----------
X: list or list of lists of type float
Input feature vector - TODO support DataFrame and other formats
Returns
-------
y: np.ndarray of dtype float for regressors or int for classifiers
Examples
--------
>>> from sklearn import datasets
>>> from xgboost import XGBRegressor
>>> from eland.ml import ExternalMLModel
>>> # Train model
>>> training_data = datasets.make_classification(n_features=6, random_state=0)
>>> test_data = [[-1, -2, -3, -4, -5, -6], [10, 20, 30, 40, 50, 60]]
>>> regressor = XGBRegressor(objective='reg:squarederror')
>>> regressor = regressor.fit(training_data[0], training_data[1])
>>> # Get some test results
>>> regressor.predict(np.array(test_data))
array([0.23733574, 1.1897984 ], dtype=float32)
>>> # Serialise the model to Elasticsearch
>>> feature_names = ["f0", "f1", "f2", "f3", "f4", "f5"]
>>> model_id = "test_xgb_regressor"
>>> es_model = ExternalMLModel('localhost', model_id, regressor, feature_names, overwrite=True)
>>> # Get some test results from Elasticsearch model
>>> es_model.predict(test_data)
array([0.2373357, 1.1897984], dtype=float32)
>>> # Delete model from Elasticsearch
>>> es_model.delete_model()
"""
docs = []
if isinstance(X, list):
# Is it a list of lists?
if all(isinstance(i, list) for i in X):
for i in X:
doc = dict()
doc['_source'] = dict(zip(self._feature_names, i))
docs.append(doc)
else: # single feature vector
doc = dict()
doc['_source'] = dict(zip(self._feature_names, X))
docs.append(doc)
else:
raise NotImplementedError("Prediction for type {} is not supported".format(type(X)))
results = self._client.perform_request(
"POST",
"/_ingest/pipeline/_simulate",
body={
"pipeline": {
"processors": [
{"inference": {
"model_id": self._model_id,
"inference_config": {self._model_type: {}},
"field_mappings": {}
}}
]
},
"docs": docs
})
y = [
doc['doc']['_source']['ml']['inference']['predicted_value'] for doc in results['docs']
]
# Return results as np.ndarray of float32 or int (consistent with sklearn/xgboost)
if self._model_type == MLModel.TYPE_CLASSIFICATION:
dt = int  # np.int is deprecated; the builtin is equivalent
else:
dt = np.float32
return np.asarray(y, dtype=dt)
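
For reference, the _ingest/pipeline/_simulate exchange that predict() performs looks roughly like this (model_id and values illustrative):

# Request body built by predict() for X = [[0.1, 0.2], [1.6, 2.1]]
# with feature_names = ["f0", "f1"]:
body = {
    "pipeline": {
        "processors": [
            {"inference": {
                "model_id": "my_model",
                "inference_config": {"classification": {}},
                "field_mappings": {}
            }}
        ]
    },
    "docs": [
        {"_source": {"f0": 0.1, "f1": 0.2}},
        {"_source": {"f0": 1.6, "f1": 2.1}},
    ],
}

# Each simulated doc carries its prediction under ml.inference, which predict()
# collects via:
#   results['docs'][i]['doc']['_source']['ml']['inference']['predicted_value']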

eland/ml/ml_model.py

@@ -0,0 +1,58 @@
# Copyright 2019 Elasticsearch BV
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import elasticsearch
from eland import Client
class MLModel:
"""
A machine learning model managed by Elasticsearch.
(See https://www.elastic.co/guide/en/elasticsearch/reference/master/put-inference.html)
These models can be created by Elastic ML, or transformed from supported Python libraries such as scikit-learn or
xgboost and imported into Elasticsearch.
The methods of this class attempt to mirror those of standard Python ML classes.
"""
TYPE_CLASSIFICATION = "classification"
TYPE_REGRESSION = "regression"
def __init__(self,
es_client,
model_id: str):
"""
Parameters
----------
es_client: Elasticsearch client argument(s)
- elasticsearch-py parameters or
- elasticsearch-py instance or
- eland.Client instance
model_id: str
The unique identifier of the trained inference model in Elasticsearch.
"""
self._client = Client(es_client)
self._model_id = model_id
def delete_model(self):
"""
Delete an inference model saved in Elasticsearch.
If the model doesn't exist, the failure is ignored.
"""
try:
self._client.perform_request("DELETE", "/_ml/inference/" + self._model_id)
except elasticsearch.exceptions.NotFoundError:
pass
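
Because NotFoundError is swallowed, delete_model is idempotent. A sketch, assuming a local cluster:

from eland.ml.ml_model import MLModel

# Safe to call whether or not the model exists
MLModel('localhost', 'test_decision_tree_classifier').delete_model()
MLModel('localhost', 'test_decision_tree_classifier').delete_model()  # still no error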


@@ -13,28 +13,29 @@
# limitations under the License.
# File called _pytest for PyCharm compatibility
from elasticsearch import Elasticsearch
import elasticsearch
import pytest
import eland as ed
from eland.tests import ES_TEST_CLIENT
from eland.tests.common import TestData
class TestClientEq(TestData):

    def test_self_eq(self):
        es = Elasticsearch('localhost')
        client = ed.Client(es)
        assert client != es
        assert client == client

    def test_non_self_ne(self):
        es1 = Elasticsearch('localhost')
        es2 = Elasticsearch('localhost')
        client1 = ed.Client(es1)
        client2 = ed.Client(es2)
        assert client1 != client2

    def test_perform_request(self):
        client = ed.Client(ES_TEST_CLIENT)
        response = client.perform_request("GET", "/_cat/indices/flights")
        # yellow open flights TNUv0iysQSi7F-N5ykWfWQ 1 1 13059 0 5.7mb 5.7mb
        tokens = response.split(' ')
        assert tokens[2] == 'flights'
        assert tokens[6] == '13059'

    def test_bad_perform_request(self):
        client = ed.Client(ES_TEST_CLIENT)
        with pytest.raises(elasticsearch.exceptions.NotFoundError):
            response = client.perform_request("GET", "/_cat/indices/non_existant_index")


@@ -0,0 +1,157 @@
# Copyright 2020 Elasticsearch BV
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
from sklearn import datasets
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from xgboost import XGBRegressor, XGBClassifier
from eland.ml import ExternalMLModel
from eland.tests import ES_TEST_CLIENT
class TestExternalMLModel:
def test_decision_tree_classifier(self):
# Train model
training_data = datasets.make_classification(n_features=5)
classifier = DecisionTreeClassifier()
classifier.fit(training_data[0], training_data[1])
# Get some test results
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
test_results = classifier.predict(test_data)
# Serialise the models to Elasticsearch
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_decision_tree_classifier"
es_model = ExternalMLModel(ES_TEST_CLIENT, model_id, classifier, feature_names, overwrite=True)
es_results = es_model.predict(test_data)
np.testing.assert_almost_equal(test_results, es_results, decimal=4)
# Clean up
es_model.delete_model()
def test_decision_tree_regressor(self):
# Train model
training_data = datasets.make_regression(n_features=5)
regressor = DecisionTreeRegressor()
regressor.fit(training_data[0], training_data[1])
# Get some test results
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
test_results = regressor.predict(test_data)
# Serialise the models to Elasticsearch
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_decision_tree_regressor"
es_model = ExternalMLModel(ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True)
es_results = es_model.predict(test_data)
np.testing.assert_almost_equal(test_results, es_results, decimal=4)
# Clean up
es_model.delete_model()
def test_random_forest_classifier(self):
# Train model
training_data = datasets.make_classification(n_features=5)
classifier = RandomForestClassifier()
classifier.fit(training_data[0], training_data[1])
# Get some test results
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
test_results = classifier.predict(test_data)
# Serialise the models to Elasticsearch
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_random_forest_classifier"
es_model = ExternalMLModel(ES_TEST_CLIENT, model_id, classifier, feature_names, overwrite=True)
es_results = es_model.predict(test_data)
np.testing.assert_almost_equal(test_results, es_results, decimal=4)
# Clean up
es_model.delete_model()
def test_random_forest_regressor(self):
# Train model
training_data = datasets.make_regression(n_features=5)
regressor = RandomForestRegressor()
regressor.fit(training_data[0], training_data[1])
# Get some test results
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
test_results = regressor.predict(test_data)
# Serialise the models to Elasticsearch
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_random_forest_regressor"
es_model = ExternalMLModel(ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True)
es_results = es_model.predict(test_data)
np.testing.assert_almost_equal(test_results, es_results, decimal=4)
# Clean up
es_model.delete_model()
def test_xgb_classifier(self):
# Train model
training_data = datasets.make_classification(n_features=5)
classifier = XGBClassifier()
classifier.fit(training_data[0], training_data[1])
# Get some test results
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
test_results = classifier.predict(test_data)
# Serialise the models to Elasticsearch
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_xgb_classifier"
es_model = ExternalMLModel(ES_TEST_CLIENT, model_id, classifier, feature_names, overwrite=True)
es_results = es_model.predict(test_data)
np.testing.assert_almost_equal(test_results, es_results, decimal=4)
# Clean up
es_model.delete_model()
def test_xgb_regressor(self):
# Train model
training_data = datasets.make_regression(n_features=5)
regressor = XGBRegressor()
regressor.fit(training_data[0], training_data[1])
# Get some test results
test_data = [[0.1, 0.2, 0.3, -0.5, 1.0], [1.6, 2.1, -10, 50, -1.0]]
test_results = regressor.predict(test_data)
# Serialise the models to Elasticsearch
feature_names = ["f0", "f1", "f2", "f3", "f4"]
model_id = "test_xgb_regressor"
es_model = ExternalMLModel(ES_TEST_CLIENT, model_id, regressor, feature_names, overwrite=True)
es_results = es_model.predict(test_data)
np.testing.assert_almost_equal(test_results, es_results, decimal=4)
# Clean up
es_model.delete_model()


@@ -0,0 +1,64 @@
# Copyright 2019 Elasticsearch BV
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import types
import pytest
from eland.ml._optional import VERSIONS, import_optional_dependency
def test_import_optional():
match = "Missing .*notapackage.* pip .* conda .* notapackage"
with pytest.raises(ImportError, match=match):
import_optional_dependency("notapackage")
result = import_optional_dependency("notapackage", raise_on_missing=False)
assert result is None
def test_xlrd_version_fallback():
pytest.importorskip("xlrd")
import_optional_dependency("xlrd")
def test_bad_version():
name = "fakemodule"
module = types.ModuleType(name)
module.__version__ = "0.9.0"
sys.modules[name] = module
VERSIONS[name] = "1.0.0"
match = "Eland requires .*1.0.0.* of .fakemodule.*'0.9.0'"
with pytest.raises(ImportError, match=match):
import_optional_dependency("fakemodule")
with pytest.warns(UserWarning):
result = import_optional_dependency("fakemodule", on_version="warn")
assert result is None
module.__version__ = "1.0.0" # exact match is OK
result = import_optional_dependency("fakemodule")
assert result is module
def test_no_version_raises():
name = "fakemodule"
module = types.ModuleType(name)
sys.modules[name] = module
VERSIONS[name] = "1.0.0"
with pytest.raises(ImportError, match="Can't determine .* fakemodule"):
import_optional_dependency(name)


@@ -2,8 +2,8 @@
python setup.py install
jupyter nbconvert --to notebook --inplace --execute docs/source/examples/demo_notebook.ipynb
jupyter nbconvert --to notebook --inplace --execute docs/source/examples/online_retail_analysis.ipynb
#jupyter nbconvert --to notebook --inplace --execute docs/source/examples/demo_notebook.ipynb
#jupyter nbconvert --to notebook --inplace --execute docs/source/examples/online_retail_analysis.ipynb
cd docs


@@ -4,3 +4,5 @@ matplotlib
pytest>=5.2.1
nbval
numpydoc>=0.9.0
scikit-learn>=0.22.1
xgboost>=0.90