diff --git a/eland/dataframe.py b/eland/dataframe.py
index 89e110c..87a60dc 100644
--- a/eland/dataframe.py
+++ b/eland/dataframe.py
@@ -1474,6 +1474,7 @@ class DataFrame(NDFrame):
                     True           121.280296  1175.709961  0.0  6.0
         <BLANKLINE>
         [63 rows x 4 columns]
+
         >>> ed_flights.groupby(["DestCountry", "Cancelled"]).mean(numeric_only=True) # doctest: +NORMALIZE_WHITESPACE
                                AvgTicketPrice  dayOfWeek
         DestCountry Cancelled
@@ -1490,6 +1491,23 @@ class DataFrame(NDFrame):
                     True           677.794078   2.928571
         <BLANKLINE>
         [63 rows x 2 columns]
+
+        >>> ed_flights.groupby(["DestCountry", "Cancelled"]).min(numeric_only=False) # doctest: +NORMALIZE_WHITESPACE
+                               AvgTicketPrice  dayOfWeek           timestamp
+        DestCountry Cancelled
+        AE          False          110.799911          0 2018-01-01 19:31:30
+                    True           132.443756          0 2018-01-06 13:03:25
+        AR          False          125.589394          0 2018-01-01 01:30:47
+                    True           251.389603          0 2018-01-01 02:13:17
+        AT          False          100.020531          0 2018-01-01 05:24:19
+        ...                               ...        ...                 ...
+        TR          True           307.915649          0 2018-01-08 04:35:10
+        US          False          100.145966          0 2018-01-01 00:06:27
+                    True           102.153069          0 2018-01-01 09:02:36
+        ZA          False          102.002663          0 2018-01-01 06:44:44
+                    True           121.280296          0 2018-01-04 00:37:01
+        <BLANKLINE>
+        [63 rows x 3 columns]
         """
         if by is None:
             raise ValueError("by parameter should be specified to groupby")
diff --git a/eland/field_mappings.py b/eland/field_mappings.py
index 7754d29..7226d41 100644
--- a/eland/field_mappings.py
+++ b/eland/field_mappings.py
@@ -100,12 +100,12 @@ class Field(NamedTuple):
         elif es_agg[0] == "percentiles":
             es_agg = "percentiles"
-        # Cardinality works for all types
-        # Numerics and bools work for all aggs
         # Except "median_absolute_deviation" which doesn't support bool
         if es_agg == "median_absolute_deviation" and self.is_bool:
             return False
-        if es_agg == "cardinality" or self.is_numeric or self.is_bool:
+        # Cardinality and Count work for all types
+        # Numerics and bools work for all aggs
+        if es_agg in ("cardinality", "value_count") or self.is_numeric or self.is_bool:
             return True
         # Timestamps also work for 'min', 'max' and 'avg'
         if es_agg in {"min", "max", "avg", "percentiles"} and self.is_timestamp:
             return True
@@ -730,7 +730,6 @@ class FieldMappings:
         """
         groupby_fields: Dict[str, Field] = {}
-        # groupby_fields: Union[List[Field], List[None]] = [None] * len(by)
         aggregatable_fields: List[Field] = []
         for column, row in self._mappings_capabilities.iterrows():
             row = row.to_dict()
 
diff --git a/eland/groupby.py b/eland/groupby.py
index 3679a8c..97f9661 100644
--- a/eland/groupby.py
+++ b/eland/groupby.py
@@ -152,15 +152,42 @@ class GroupByDataFrame(GroupBy):
             - True: returns all values with float64, NaN/NaT are ignored.
             - False: returns all values with float64.
             - None: returns all values with default datatype.
+
+        Returns
+        -------
+        A Pandas DataFrame
+
         """
+        # Controls whether a MultiIndex is used for the
+        # columns of the result DataFrame.
+        is_dataframe_agg = True
         if isinstance(func, str):
             func = [func]
+            is_dataframe_agg = False
+
         return self._query_compiler.aggs_groupby(
             by=self._by,
             pd_aggs=func,
             dropna=self._dropna,
             numeric_only=numeric_only,
-            is_dataframe_agg=True,
+            is_dataframe_agg=is_dataframe_agg,
         )
 
     agg = aggregate
+
+    def count(self) -> "pd.DataFrame":
+        """
+        Compute the count of each group.
+
+        Returns
+        -------
+        A Pandas DataFrame
+
+        """
+        return self._query_compiler.aggs_groupby(
+            by=self._by,
+            pd_aggs=["count"],
+            dropna=self._dropna,
+            numeric_only=False,
+            is_dataframe_agg=False,
+        )
diff --git a/eland/operations.py b/eland/operations.py
index 7025156..acb93c9 100644
--- a/eland/operations.py
+++ b/eland/operations.py
@@ -512,7 +512,7 @@ class Operations:
                     agg_value = np.NaN
 
                 # Cardinality is always either NaN or integer.
-                elif pd_agg == "nunique":
+                elif pd_agg in ("nunique", "count"):
                     agg_value = int(agg_value)
 
                 # If this is a non-null timestamp field convert to a pd.Timestamp()
@@ -579,14 +579,63 @@ class Operations:
                 numeric_only=numeric_only,
             )
 
-        agg_df = pd.DataFrame(results, columns=results.keys()).set_index(by)
+        agg_df = pd.DataFrame(results).set_index(by)
 
         if is_dataframe_agg:
             # Convert header columns to MultiIndex
             agg_df.columns = pd.MultiIndex.from_product([headers, pd_aggs])
+        else:
+            # Convert header columns to Index
+            agg_df.columns = pd.Index(headers)
 
         return agg_df
 
+    @staticmethod
+    def bucket_generator(
+        query_compiler: "QueryCompiler", body: "Query"
+    ) -> Generator[List[str], None, List[str]]:
+        """
+        This can be used for all groupby operations.
+        e.g.
+        "aggregations": {
+            "groupby_buckets": {
+                "after_key": {"total_quantity": 8},
+                "buckets": [
+                    {
+                        "key": {"total_quantity": 1},
+                        "doc_count": 87,
+                        "taxful_total_price_avg": {"value": 48.035978536496216},
+                    }
+                ],
+            }
+        }
+        Returns
+        -------
+        A generator which yields a page of buckets at a time.
+        If 'after_key' is present, it is used to fetch the next page.
+
+        """
+        while True:
+            res = query_compiler._client.search(
+                index=query_compiler._index_pattern,
+                size=0,
+                body=body.to_search_body(),
+            )
+
+            # Pagination Logic
+            composite_buckets = res["aggregations"]["groupby_buckets"]
+            if "after_key" in composite_buckets:
+
+                # yield the bucket which contains the result
+                yield composite_buckets["buckets"]
+
+                body.composite_agg_after_key(
+                    name="groupby_buckets",
+                    after_key=composite_buckets["after_key"],
+                )
+            else:
+                return composite_buckets["buckets"]
+
     def _groupby_aggs(
         self,
         query_compiler: "QueryCompiler",
@@ -640,33 +689,45 @@ class Operations:
         body = Query(query_params.query)
 
+        # Headers are returned to the caller to build the result's column index
+        headers = [field.column for field in agg_fields]
+
         # Convert pandas aggs to ES equivalent
         es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs)
 
+        # pd_agg 'count' is handled via 'doc_count' from buckets
+        using_pd_agg_count = "count" in pd_aggs
+
         # Construct Query
         for by_field in by_fields:
 
             # groupby fields will be term aggregations
             body.composite_agg_bucket_terms(
-                name=f"groupby_{by_field.column}", field=by_field.es_field_name
+                name=f"groupby_{by_field.column}",
+                field=by_field.aggregatable_es_field_name,
             )
 
-        for field in agg_fields:
+        for agg_field in agg_fields:
             for es_agg in es_aggs:
-                if not field.is_es_agg_compatible(es_agg):
+                # Skip if the field isn't compatible or if the agg is
+                # 'value_count' as this value is pulled from bucket.doc_count.
+                if (
+                    not agg_field.is_es_agg_compatible(es_agg)
+                    or es_agg == "value_count"
+                ):
                     continue
 
                 # If we have multiple 'extended_stats' etc. here we simply NOOP on 2nd call
                 if isinstance(es_agg, tuple):
                     body.metric_aggs(
-                        f"{es_agg[0]}_{field.es_field_name}",
+                        f"{es_agg[0]}_{agg_field.es_field_name}",
                         es_agg[0],
-                        field.aggregatable_es_field_name,
+                        agg_field.aggregatable_es_field_name,
                     )
                 else:
                     body.metric_aggs(
-                        f"{es_agg}_{field.es_field_name}",
+                        f"{es_agg}_{agg_field.es_field_name}",
                         es_agg,
-                        field.aggregatable_es_field_name,
+                        agg_field.aggregatable_es_field_name,
                     )
 
         # Composite aggregation
@@ -674,49 +735,7 @@
             size=DEFAULT_PAGINATION_SIZE, name="groupby_buckets", dropna=dropna
         )
 
-        def bucket_generator() -> Generator[List[str], None, List[str]]:
-            """
-            e.g.
-            "aggregations": {
-                "groupby_buckets": {
-                    "after_key": {"total_quantity": 8},
-                    "buckets": [
-                        {
-                            "key": {"total_quantity": 1},
-                            "doc_count": 87,
-                            "taxful_total_price_avg": {"value": 48.035978536496216},
-                        }
-                    ],
-                }
-            }
-            Returns
-            -------
-            A generator which initially yields the bucket
-            If after_key is found, use it to fetch the next set of buckets.
-
-            """
-            while True:
-                res = query_compiler._client.search(
-                    index=query_compiler._index_pattern,
-                    size=0,
-                    body=body.to_search_body(),
-                )
-
-                # Pagination Logic
-                composite_buckets = res["aggregations"]["groupby_buckets"]
-                if "after_key" in composite_buckets:
-
-                    # yield the bucket which contains the result
-                    yield composite_buckets["buckets"]
-
-                    body.composite_agg_after_key(
-                        name="groupby_buckets",
-                        after_key=composite_buckets["after_key"],
-                    )
-                else:
-                    return composite_buckets["buckets"]
-
-        for buckets in bucket_generator():
+        for buckets in self.bucket_generator(query_compiler, body):
             # We receive response row-wise
             for bucket in buckets:
                 # groupby columns are added to result same way they are returned
@@ -729,6 +748,15 @@
 
                 response[by_field.column].append(bucket_key)
 
+                # Copy the bucket's 'doc_count' into each 'agg_field'
+                # so it can be extracted by _unpack_metric_aggs()
+                if using_pd_agg_count:
+                    doc_count = bucket["doc_count"]
+                    for agg_field in agg_fields:
+                        bucket[f"value_count_{agg_field.es_field_name}"] = {
+                            "value": doc_count
+                        }
+
                 agg_calculation = self._unpack_metric_aggs(
                     fields=agg_fields,
                     es_aggs=es_aggs,
@@ -737,15 +765,16 @@
                     numeric_only=numeric_only,
                     is_dataframe_agg=is_dataframe_agg,
                 )
 
+                # Process the calculated agg values into the response
                 for key, value in agg_calculation.items():
-                    if isinstance(value, list):
-                        for pd_agg, val in zip(pd_aggs, value):
-                            response[f"{key}_{pd_agg}"].append(val)
-                    else:
+                    if not isinstance(value, list):
                         response[key].append(value)
+                        continue
+                    for pd_agg, val in zip(pd_aggs, value):
+                        response[f"{key}_{pd_agg}"].append(val)
 
-        return [field.column for field in agg_fields], response
+        return headers, response
 
     @staticmethod
     def _map_pd_aggs_to_es_aggs(pd_aggs):
@@ -781,8 +810,8 @@
         """
         # pd aggs that will be mapped to es aggs
         # that can use 'extended_stats'.
-        extended_stats_pd_aggs = {"mean", "min", "max", "count", "sum", "var", "std"}
-        extended_stats_es_aggs = {"avg", "min", "max", "count", "sum"}
+        extended_stats_pd_aggs = {"mean", "min", "max", "sum", "var", "std"}
+        extended_stats_es_aggs = {"avg", "min", "max", "sum"}
         extended_stats_calls = 0
 
         es_aggs = []
@@ -792,7 +821,7 @@
 
             # Aggs that are 'extended_stats' compatible
             if pd_agg == "count":
-                es_aggs.append("count")
+                es_aggs.append("value_count")
             elif pd_agg == "max":
                 es_aggs.append("max")
             elif pd_agg == "min":
diff --git a/eland/query.py b/eland/query.py
index 8d55fa2..5f7fe2e 100644
--- a/eland/query.py
+++ b/eland/query.py
@@ -195,8 +195,8 @@ class Query:
 
         Parameters
         ----------
-        size: int
-            Pagination size.
+        size: int or None
+            Use composite aggregation with pagination if size is not None
         name: str
             Name of the buckets
         dropna: bool
@@ -215,10 +215,13 @@
             sources.append({bucket_agg_name: bucket_agg})
         self._composite_aggs.clear()
 
-        aggs = {
-            "composite": {"size": size, "sources": sources},
-            "aggregations": self._aggs.copy(),
+        aggs: Dict[str, Dict[str, Any]] = {
+            "composite": {"size": size, "sources": sources}
         }
+
+        if self._aggs:
+            aggs["aggregations"] = self._aggs.copy()
+
         self._aggs.clear()
         self._aggs[name] = aggs
 
diff --git a/eland/tests/dataframe/test_count_pytest.py b/eland/tests/dataframe/test_count_pytest.py
index 4936009..eff36fc 100644
--- a/eland/tests/dataframe/test_count_pytest.py
+++ b/eland/tests/dataframe/test_count_pytest.py
@@ -16,9 +16,30 @@
 # under the License.
 
 # File called _pytest for PyCharm compatibility
+from pandas.testing import assert_series_equal
+
+from eland.tests.common import TestData
 
 
-class TestDataFrameCount:
+class TestDataFrameCount(TestData):
+    filter_data = [
+        "AvgTicketPrice",
+        "Cancelled",
+        "dayOfWeek",
+        "timestamp",
+        "DestCountry",
+    ]
+
     def test_count(self, df):
         df.load_dataset("ecommerce")
         df.count()
+
+    def test_count_flights(self):
+
+        pd_flights = self.pd_flights().filter(self.filter_data)
+        ed_flights = self.ed_flights().filter(self.filter_data)
+
+        pd_count = pd_flights.count()
+        ed_count = ed_flights.count()
+
+        assert_series_equal(pd_count, ed_count)
diff --git a/eland/tests/dataframe/test_groupby_pytest.py b/eland/tests/dataframe/test_groupby_pytest.py
index 3ae95a0..8027caf 100644
--- a/eland/tests/dataframe/test_groupby_pytest.py
+++ b/eland/tests/dataframe/test_groupby_pytest.py
@@ -19,7 +19,7 @@
 import pandas as pd
 import pytest
-from pandas.testing import assert_frame_equal, assert_series_equal
+from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal
 
 from eland.tests.common import TestData
 
@@ -154,3 +154,25 @@ class TestGroupbyDataFrame(TestData):
     def test_groupby_dropna(self):
         # TODO Add tests once dropna is implemented
         pass
+
+    @pytest.mark.parametrize("groupby", ["dayOfWeek", ["dayOfWeek", "Cancelled"]])
+    @pytest.mark.parametrize(
+        ["func", "func_args"],
+        [
+            ("count", ()),
+            ("agg", ("count",)),
+            ("agg", (["count"],)),
+            ("agg", (["max", "count", "min"],)),
+        ],
+    )
+    def test_groupby_dataframe_count(self, groupby, func, func_args):
+        pd_flights = self.pd_flights().filter(self.filter_data)
+        ed_flights = self.ed_flights().filter(self.filter_data)
+
+        pd_count = getattr(pd_flights.groupby(groupby), func)(*func_args)
+        ed_count = getattr(ed_flights.groupby(groupby), func)(*func_args)
+
+        assert_index_equal(pd_count.columns, ed_count.columns)
+        assert_index_equal(pd_count.index, ed_count.index)
+        assert_frame_equal(pd_count, ed_count)
+        assert_series_equal(pd_count.dtypes, ed_count.dtypes)
diff --git a/eland/tests/dataframe/test_metrics_pytest.py b/eland/tests/dataframe/test_metrics_pytest.py
index d3d5785..6eba69d 100644
--- a/eland/tests/dataframe/test_metrics_pytest.py
+++ b/eland/tests/dataframe/test_metrics_pytest.py
@@ -20,7 +20,7 @@ import pandas as pd
 
 # File called _pytest for PyCharm compatibility
 import pytest
-from pandas.testing import assert_series_equal
+from pandas.testing import assert_frame_equal, assert_series_equal
 
 from eland.tests.common import TestData
 
@@ -414,3 +414,13 @@ class TestDataFrameMetrics(TestData):
         assert isinstance(calculated_values["AvgTicketPrice"], float)
         assert isinstance(calculated_values["dayOfWeek"], float)
         assert calculated_values.shape == (2,)
+
+    def test_aggs_count(self):
+
+        pd_flights = self.pd_flights().filter(self.filter_data)
+        ed_flights = self.ed_flights().filter(self.filter_data)
+
+        pd_count = pd_flights.agg(["count"])
+        ed_count = ed_flights.agg(["count"])
+
+        assert_frame_equal(pd_count, ed_count)
diff --git a/eland/tests/operations/test_map_pd_aggs_to_es_aggs_pytest.py b/eland/tests/operations/test_map_pd_aggs_to_es_aggs_pytest.py
index 36dde8f..2ec7882 100644
--- a/eland/tests/operations/test_map_pd_aggs_to_es_aggs_pytest.py
+++ b/eland/tests/operations/test_map_pd_aggs_to_es_aggs_pytest.py
@@ -30,7 +30,7 @@ def test_all_aggs():
         ("extended_stats", "std_deviation"),
         ("extended_stats", "variance"),
         "median_absolute_deviation",
-        ("extended_stats", "count"),
+        "value_count",
         "cardinality",
         ("percentiles", "50.0"),
     ]
@@ -40,7 +40,7 @@ def test_extended_stats_optimization():
     # Tests that when '<agg>' and an 'extended_stats' agg are used together
    # that ('extended_stats', '<agg>') is used instead of '<agg>'.
     es_aggs = Operations._map_pd_aggs_to_es_aggs(["count", "nunique"])
-    assert es_aggs == ["count", "cardinality"]
+    assert es_aggs == ["value_count", "cardinality"]
 
     for pd_agg in ["var", "std"]:
         extended_es_agg = Operations._map_pd_aggs_to_es_aggs([pd_agg])[0]
@@ -49,4 +49,4 @@
         assert es_aggs == [extended_es_agg, "cardinality"]
 
         es_aggs = Operations._map_pd_aggs_to_es_aggs(["count", pd_agg, "nunique"])
-        assert es_aggs == [("extended_stats", "count"), extended_es_agg, "cardinality"]
+        assert es_aggs == ["value_count", extended_es_agg, "cardinality"]
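
Usage sketch (illustrative, not part of the patch): the snippet below shows how the groupby count support added here could be exercised. It assumes a local Elasticsearch cluster with the eland "flights" test index loaded, as in the eland test suite; the connection arguments are an assumption and may need adjusting for other environments.

    import eland as ed

    # Attach an eland DataFrame to the flights test index
    # (hypothetical local setup; adjust host/index as needed)
    ed_flights = ed.DataFrame("localhost", "flights")

    # New in this patch: per-group counts, served from each composite
    # aggregation bucket's 'doc_count' rather than a per-field metric agg
    print(ed_flights.groupby("DestCountry").count())

    # 'count' can also be mixed with other aggs via agg()
    print(ed_flights.groupby(["DestCountry", "Cancelled"]).agg(["min", "count"]))

    # A single string agg now returns flat columns instead of a MultiIndex,
    # since aggregate() only sets is_dataframe_agg for list inputs
    print(ed_flights.groupby("DestCountry").agg("min"))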