Add NDFrame.var() and .std() aggregations

This commit is contained in:
Daniel Mesejo-León 2020-04-12 22:48:13 +02:00 committed by GitHub
parent 064d43b9ef
commit 7a1c636e56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 54 additions and 6 deletions

View File

@ -251,6 +251,12 @@ class NDFrame(ABC):
"""
return self._query_compiler.min(numeric_only=numeric_only)
def var(self, numeric_only=True):
return self._query_compiler.var(numeric_only=numeric_only)
def std(self, numeric_only=True):
return self._query_compiler.std(numeric_only=numeric_only)
def max(self, numeric_only=True):
"""
Return the maximum value for each numeric column

View File

@ -123,6 +123,18 @@ class Operations:
def mean(self, query_compiler, numeric_only=True):
return self._metric_aggs(query_compiler, "avg", numeric_only=numeric_only)
def var(self, query_compiler, numeric_only=True):
return self._metric_aggs(
query_compiler, ("extended_stats", "variance"), numeric_only=numeric_only
)
def std(self, query_compiler, numeric_only=True):
return self._metric_aggs(
query_compiler,
("extended_stats", "std_deviation"),
numeric_only=numeric_only,
)
def sum(self, query_compiler, numeric_only=True):
return self._metric_aggs(query_compiler, "sum", numeric_only=numeric_only)
@ -226,7 +238,10 @@ class Operations:
)
for field in source_fields:
body.metric_aggs(field, func, field)
if isinstance(func, tuple):
body.metric_aggs(func[0] + "_" + field, func[0], field)
else:
body.metric_aggs(field, func, field)
response = query_compiler._client.search(
index=query_compiler._index_pattern, size=0, body=body.to_search_body()
@ -250,11 +265,21 @@ class Operations:
response["aggregations"][field]["value_as_string"], date_format
)
elif keep_original_dtype:
results[field] = pd_dtype.type(
response["aggregations"][field]["value"]
)
if isinstance(func, tuple):
results = pd_dtype.type(
response["aggregations"][func[0] + "_" + field][func[1]]
)
else:
results[field] = pd_dtype.type(
response["aggregations"][field]["value"]
)
else:
results[field] = response["aggregations"][field]["value"]
if isinstance(func, tuple):
results[field] = response["aggregations"][
func[0] + "_" + field
][func[1]]
else:
results[field] = response["aggregations"][field]["value"]
# Return single value if this is a series
# if len(numeric_source_fields) == 1:

View File

@ -463,6 +463,12 @@ class QueryCompiler:
def mean(self, numeric_only=None):
return self._operations.mean(self, numeric_only=numeric_only)
def var(self, numeric_only=None):
return self._operations.var(self, numeric_only=numeric_only)
def std(self, numeric_only=None):
return self._operations.std(self, numeric_only=numeric_only)
def sum(self, numeric_only=None):
return self._operations.sum(self, numeric_only=numeric_only)

View File

@ -14,13 +14,14 @@
# File called _pytest for PyCharm compatability
from pandas.util.testing import assert_series_equal
from pandas.util.testing import assert_series_equal, assert_almost_equal
from eland.tests.common import TestData
class TestDataFrameMetrics(TestData):
funcs = ["max", "min", "mean", "sum"]
extended_funcs = ["var", "std"]
def test_flights_metrics(self):
pd_flights = self.pd_flights()
@ -32,6 +33,16 @@ class TestDataFrameMetrics(TestData):
assert_series_equal(pd_metric, ed_metric)
def test_flights_extended_metrics(self):
pd_flights = self.pd_flights()
ed_flights = self.ed_flights()
for func in self.extended_funcs:
pd_metric = getattr(pd_flights, func)(numeric_only=True)
ed_metric = getattr(ed_flights, func)(numeric_only=True)
assert_almost_equal(pd_metric, ed_metric, check_less_precise=True)
def test_ecommerce_selected_non_numeric_source_fields(self):
# None of these are numeric
columns = [