mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Add NDFrame.var() and .std() aggregations
This commit is contained in:
parent
064d43b9ef
commit
7a1c636e56
@ -251,6 +251,12 @@ class NDFrame(ABC):
|
||||
"""
|
||||
return self._query_compiler.min(numeric_only=numeric_only)
|
||||
|
||||
def var(self, numeric_only=True):
    """
    Return the variance for each numeric column

    The computation is delegated to ``self._query_compiler.var``.

    Parameters
    ----------
    numeric_only : bool, default True
        Forwarded unchanged to the query compiler.

    Returns
    -------
    Whatever ``self._query_compiler.var`` returns.
    """
    return self._query_compiler.var(numeric_only=numeric_only)
|
||||
|
||||
def std(self, numeric_only=True):
    """
    Return the standard deviation for each numeric column

    The computation is delegated to ``self._query_compiler.std``.

    Parameters
    ----------
    numeric_only : bool, default True
        Forwarded unchanged to the query compiler.

    Returns
    -------
    Whatever ``self._query_compiler.std`` returns.
    """
    return self._query_compiler.std(numeric_only=numeric_only)
|
||||
|
||||
def max(self, numeric_only=True):
|
||||
"""
|
||||
Return the maximum value for each numeric column
|
||||
|
@ -123,6 +123,18 @@ class Operations:
|
||||
def mean(self, query_compiler, numeric_only=True):
    """
    Run the "avg" metric aggregation for ``query_compiler``.

    Parameters
    ----------
    query_compiler
        Passed through to ``self._metric_aggs``.
    numeric_only : bool, default True
        Passed through to ``self._metric_aggs``.

    Returns
    -------
    Whatever ``self._metric_aggs`` returns.
    """
    return self._metric_aggs(query_compiler, "avg", numeric_only=numeric_only)
|
||||
|
||||
def var(self, query_compiler, numeric_only=True):
    """
    Run the variance metric aggregation for ``query_compiler``.

    The aggregation is requested as the tuple ("extended_stats", "variance"):
    the first item is the aggregation type and the second is the key read
    from that aggregation's result.

    NOTE(review): per the Elasticsearch documentation, the ``variance`` key of
    ``extended_stats`` appears to be the population variance, whereas pandas
    ``var()`` defaults to ddof=1 -- confirm the difference is acceptable.

    Parameters
    ----------
    query_compiler
        Passed through to ``self._metric_aggs``.
    numeric_only : bool, default True
        Passed through to ``self._metric_aggs``.

    Returns
    -------
    Whatever ``self._metric_aggs`` returns.
    """
    return self._metric_aggs(
        query_compiler, ("extended_stats", "variance"), numeric_only=numeric_only
    )
|
||||
|
||||
def std(self, query_compiler, numeric_only=True):
    """
    Run the standard-deviation metric aggregation for ``query_compiler``.

    The aggregation is requested as the tuple
    ("extended_stats", "std_deviation"): the first item is the aggregation
    type and the second is the key read from that aggregation's result.

    NOTE(review): per the Elasticsearch documentation, the ``std_deviation``
    key of ``extended_stats`` appears to be the population standard deviation,
    whereas pandas ``std()`` defaults to ddof=1 -- confirm the difference is
    acceptable.

    Parameters
    ----------
    query_compiler
        Passed through to ``self._metric_aggs``.
    numeric_only : bool, default True
        Passed through to ``self._metric_aggs``.

    Returns
    -------
    Whatever ``self._metric_aggs`` returns.
    """
    return self._metric_aggs(
        query_compiler,
        ("extended_stats", "std_deviation"),
        numeric_only=numeric_only,
    )
|
||||
|
||||
def sum(self, query_compiler, numeric_only=True):
    """
    Run the "sum" metric aggregation for ``query_compiler``.

    Parameters
    ----------
    query_compiler
        Passed through to ``self._metric_aggs``.
    numeric_only : bool, default True
        Passed through to ``self._metric_aggs``.

    Returns
    -------
    Whatever ``self._metric_aggs`` returns.
    """
    return self._metric_aggs(query_compiler, "sum", numeric_only=numeric_only)
|
||||
|
||||
@ -226,7 +238,10 @@ class Operations:
|
||||
)
|
||||
|
||||
for field in source_fields:
|
||||
body.metric_aggs(field, func, field)
|
||||
if isinstance(func, tuple):
|
||||
body.metric_aggs(func[0] + "_" + field, func[0], field)
|
||||
else:
|
||||
body.metric_aggs(field, func, field)
|
||||
|
||||
response = query_compiler._client.search(
|
||||
index=query_compiler._index_pattern, size=0, body=body.to_search_body()
|
||||
@ -250,11 +265,21 @@ class Operations:
|
||||
response["aggregations"][field]["value_as_string"], date_format
|
||||
)
|
||||
elif keep_original_dtype:
|
||||
results[field] = pd_dtype.type(
|
||||
response["aggregations"][field]["value"]
|
||||
)
|
||||
if isinstance(func, tuple):
    # Tuple funcs are registered under "<func[0]>_<field>"; func[1] picks the
    # value out of that aggregation's response.
    # Store per field: assigning to `results` itself would replace the whole
    # results mapping with a single scalar.
    results[field] = pd_dtype.type(
        response["aggregations"][func[0] + "_" + field][func[1]]
    )
else:
    results[field] = pd_dtype.type(
        response["aggregations"][field]["value"]
    )
|
||||
else:
|
||||
results[field] = response["aggregations"][field]["value"]
|
||||
if isinstance(func, tuple):
|
||||
results[field] = response["aggregations"][
|
||||
func[0] + "_" + field
|
||||
][func[1]]
|
||||
else:
|
||||
results[field] = response["aggregations"][field]["value"]
|
||||
|
||||
# Return single value if this is a series
|
||||
# if len(numeric_source_fields) == 1:
|
||||
|
@ -463,6 +463,12 @@ class QueryCompiler:
|
||||
def mean(self, numeric_only=None):
    """
    Delegate the mean aggregation to ``self._operations.mean``.

    Parameters
    ----------
    numeric_only : optional, default None
        Forwarded unchanged; this compiler is passed as the first argument.

    Returns
    -------
    Whatever ``self._operations.mean`` returns.
    """
    return self._operations.mean(self, numeric_only=numeric_only)
|
||||
|
||||
def var(self, numeric_only=None):
    """
    Delegate the variance aggregation to ``self._operations.var``.

    Parameters
    ----------
    numeric_only : optional, default None
        Forwarded unchanged; this compiler is passed as the first argument.

    Returns
    -------
    Whatever ``self._operations.var`` returns.
    """
    return self._operations.var(self, numeric_only=numeric_only)
|
||||
|
||||
def std(self, numeric_only=None):
    """
    Delegate the standard-deviation aggregation to ``self._operations.std``.

    Parameters
    ----------
    numeric_only : optional, default None
        Forwarded unchanged; this compiler is passed as the first argument.

    Returns
    -------
    Whatever ``self._operations.std`` returns.
    """
    return self._operations.std(self, numeric_only=numeric_only)
|
||||
|
||||
def sum(self, numeric_only=None):
    """
    Delegate the sum aggregation to ``self._operations.sum``.

    Parameters
    ----------
    numeric_only : optional, default None
        Forwarded unchanged; this compiler is passed as the first argument.

    Returns
    -------
    Whatever ``self._operations.sum`` returns.
    """
    return self._operations.sum(self, numeric_only=numeric_only)
|
||||
|
||||
|
@ -14,13 +14,14 @@
|
||||
|
||||
# File called _pytest for PyCharm compatibility
|
||||
|
||||
from pandas.util.testing import assert_series_equal
|
||||
from pandas.util.testing import assert_series_equal, assert_almost_equal
|
||||
|
||||
from eland.tests.common import TestData
|
||||
|
||||
|
||||
class TestDataFrameMetrics(TestData):
|
||||
funcs = ["max", "min", "mean", "sum"]
|
||||
extended_funcs = ["var", "std"]
|
||||
|
||||
def test_flights_metrics(self):
|
||||
pd_flights = self.pd_flights()
|
||||
@ -32,6 +33,16 @@ class TestDataFrameMetrics(TestData):
|
||||
|
||||
assert_series_equal(pd_metric, ed_metric)
|
||||
|
||||
def test_flights_extended_metrics(self):
    """
    Compare eland and pandas results for every name in ``self.extended_funcs``.

    Each metric is called with ``numeric_only=True`` on both the pandas and the
    eland flights frames. Results are compared with ``assert_almost_equal`` and
    ``check_less_precise=True`` rather than for exact equality.
    """
    pd_flights = self.pd_flights()
    ed_flights = self.ed_flights()

    for func in self.extended_funcs:
        expected = getattr(pd_flights, func)(numeric_only=True)
        actual = getattr(ed_flights, func)(numeric_only=True)

        assert_almost_equal(expected, actual, check_less_precise=True)
|
||||
|
||||
def test_ecommerce_selected_non_numeric_source_fields(self):
|
||||
# None of these are numeric
|
||||
columns = [
|
||||
|
Loading…
x
Reference in New Issue
Block a user