diff --git a/eland/ndframe.py b/eland/ndframe.py
index 3d8b646..d1c9372 100644
--- a/eland/ndframe.py
+++ b/eland/ndframe.py
@@ -251,6 +251,24 @@ class NDFrame(ABC):
         """
         return self._query_compiler.min(numeric_only=numeric_only)
 
+    def var(self, numeric_only=True):
+        """
+        Return the variance for each numeric column
+
+        Delegates to the query compiler, which requests the ``variance``
+        value of an Elasticsearch ``extended_stats`` aggregation.
+        """
+        return self._query_compiler.var(numeric_only=numeric_only)
+
+    def std(self, numeric_only=True):
+        """
+        Return the standard deviation for each numeric column
+
+        Delegates to the query compiler, which requests the ``std_deviation``
+        value of an Elasticsearch ``extended_stats`` aggregation.
+        """
+        return self._query_compiler.std(numeric_only=numeric_only)
+
     def max(self, numeric_only=True):
         """
         Return the maximum value for each numeric column
diff --git a/eland/operations.py b/eland/operations.py
index 63f4563..a650d46 100644
--- a/eland/operations.py
+++ b/eland/operations.py
@@ -123,6 +123,18 @@ class Operations:
     def mean(self, query_compiler, numeric_only=True):
         return self._metric_aggs(query_compiler, "avg", numeric_only=numeric_only)
 
+    def var(self, query_compiler, numeric_only=True):
+        return self._metric_aggs(
+            query_compiler, ("extended_stats", "variance"), numeric_only=numeric_only
+        )
+
+    def std(self, query_compiler, numeric_only=True):
+        return self._metric_aggs(
+            query_compiler,
+            ("extended_stats", "std_deviation"),
+            numeric_only=numeric_only,
+        )
+
     def sum(self, query_compiler, numeric_only=True):
         return self._metric_aggs(query_compiler, "sum", numeric_only=numeric_only)
 
@@ -226,7 +238,12 @@ class Operations:
         )
 
         for field in source_fields:
-            body.metric_aggs(field, func, field)
+            if isinstance(func, tuple):
+                # func is (aggregation type, result key); the aggregation is
+                # named "<type>_<field>" so the response lookup can find it.
+                body.metric_aggs(func[0] + "_" + field, func[0], field)
+            else:
+                body.metric_aggs(field, func, field)
 
         response = query_compiler._client.search(
             index=query_compiler._index_pattern, size=0, body=body.to_search_body()
@@ -250,11 +267,21 @@ class Operations:
                     response["aggregations"][field]["value_as_string"], date_format
                 )
             elif keep_original_dtype:
-                results[field] = pd_dtype.type(
-                    response["aggregations"][field]["value"]
-                )
+                if isinstance(func, tuple):
+                    results[field] = pd_dtype.type(
+                        response["aggregations"][func[0] + "_" + field][func[1]]
+                    )
+                else:
+                    results[field] = pd_dtype.type(
+                        response["aggregations"][field]["value"]
+                    )
             else:
-                results[field] = response["aggregations"][field]["value"]
+                if isinstance(func, tuple):
+                    results[field] = response["aggregations"][
+                        func[0] + "_" + field
+                    ][func[1]]
+                else:
+                    results[field] = response["aggregations"][field]["value"]
 
         # Return single value if this is a series
         # if len(numeric_source_fields) == 1:
diff --git a/eland/query_compiler.py b/eland/query_compiler.py
index b32c4e6..011d5dc 100644
--- a/eland/query_compiler.py
+++ b/eland/query_compiler.py
@@ -463,6 +463,12 @@ class QueryCompiler:
     def mean(self, numeric_only=None):
         return self._operations.mean(self, numeric_only=numeric_only)
 
+    def var(self, numeric_only=None):
+        return self._operations.var(self, numeric_only=numeric_only)
+
+    def std(self, numeric_only=None):
+        return self._operations.std(self, numeric_only=numeric_only)
+
     def sum(self, numeric_only=None):
         return self._operations.sum(self, numeric_only=numeric_only)
 
diff --git a/eland/tests/dataframe/test_metrics_pytest.py b/eland/tests/dataframe/test_metrics_pytest.py
index 1551b75..bb55cc1 100644
--- a/eland/tests/dataframe/test_metrics_pytest.py
+++ b/eland/tests/dataframe/test_metrics_pytest.py
@@ -14,13 +14,14 @@
 
 # File called _pytest for PyCharm compatability
 
-from pandas.util.testing import assert_series_equal
+from pandas.util.testing import assert_series_equal, assert_almost_equal
 
 from eland.tests.common import TestData
 
 
 class TestDataFrameMetrics(TestData):
     funcs = ["max", "min", "mean", "sum"]
+    extended_funcs = ["var", "std"]
 
     def test_flights_metrics(self):
         pd_flights = self.pd_flights()
@@ -32,6 +33,16 @@ class TestDataFrameMetrics(TestData):
 
             assert_series_equal(pd_metric, ed_metric)
 
+    def test_flights_extended_metrics(self):
+        pd_flights = self.pd_flights()
+        ed_flights = self.ed_flights()
+
+        for func in self.extended_funcs:
+            pd_metric = getattr(pd_flights, func)(numeric_only=True)
+            ed_metric = getattr(ed_flights, func)(numeric_only=True)
+
+            assert_almost_equal(pd_metric, ed_metric, check_less_precise=True)
+
     def test_ecommerce_selected_non_numeric_source_fields(self):
         # None of these are numeric
         columns = [