Implement DataFrameGroupBy.count()

P. Sai Vinay 2020-10-23 19:11:50 +05:30 committed by GitHub
parent bd7956ea72
commit 475e0f41ef
9 changed files with 205 additions and 76 deletions

View File

@@ -1474,6 +1474,7 @@ class DataFrame(NDFrame):
 True 121.280296 1175.709961 0.0 6.0
 <BLANKLINE>
 [63 rows x 4 columns]
 >>> ed_flights.groupby(["DestCountry", "Cancelled"]).mean(numeric_only=True) # doctest: +NORMALIZE_WHITESPACE
 AvgTicketPrice dayOfWeek
 DestCountry Cancelled
@@ -1490,6 +1491,23 @@ class DataFrame(NDFrame):
 True 677.794078 2.928571
 <BLANKLINE>
 [63 rows x 2 columns]
+>>> ed_flights.groupby(["DestCountry", "Cancelled"]).min(numeric_only=False) # doctest: +NORMALIZE_WHITESPACE
+AvgTicketPrice dayOfWeek timestamp
+DestCountry Cancelled
+AE False 110.799911 0 2018-01-01 19:31:30
+True 132.443756 0 2018-01-06 13:03:25
+AR False 125.589394 0 2018-01-01 01:30:47
+True 251.389603 0 2018-01-01 02:13:17
+AT False 100.020531 0 2018-01-01 05:24:19
+... ... ... ...
+TR True 307.915649 0 2018-01-08 04:35:10
+US False 100.145966 0 2018-01-01 00:06:27
+True 102.153069 0 2018-01-01 09:02:36
+ZA False 102.002663 0 2018-01-01 06:44:44
+True 121.280296 0 2018-01-04 00:37:01
+<BLANKLINE>
+[63 rows x 3 columns]
 """
 if by is None:
 raise ValueError("by parameter should be specified to groupby")

View File

@@ -100,12 +100,12 @@ class Field(NamedTuple):
 elif es_agg[0] == "percentiles":
 es_agg = "percentiles"
-# Cardinality works for all types
-# Numerics and bools work for all aggs
 # Except "median_absolute_deviation" which doesn't support bool
 if es_agg == "median_absolute_deviation" and self.is_bool:
 return False
-if es_agg == "cardinality" or self.is_numeric or self.is_bool:
+# Cardinality and Count work for all types
+# Numerics and bools work for all aggs
+if es_agg in ("cardinality", "value_count") or self.is_numeric or self.is_bool:
 return True
 # Timestamps also work for 'min', 'max' and 'avg'
 if es_agg in {"min", "max", "avg", "percentiles"} and self.is_timestamp:
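As a rough illustration of the new rule (keyword_field and timestamp_field are hypothetical Field instances, not part of this change): counts are now accepted for every field type, while metric aggs still require numeric, bool or timestamp fields.

# Sketch only - hypothetical Field instances from a demo mapping
keyword_field.is_es_agg_compatible("value_count")   # True: counts work for all types
keyword_field.is_es_agg_compatible("avg")           # False: not numeric, bool or timestamp
timestamp_field.is_es_agg_compatible("min")         # True: timestamps support min/max/avg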
@@ -730,7 +730,6 @@ class FieldMappings:
 """
 groupby_fields: Dict[str, Field] = {}
-# groupby_fields: Union[List[Field], List[None]] = [None] * len(by)
 aggregatable_fields: List[Field] = []
 for column, row in self._mappings_capabilities.iterrows():
 row = row.to_dict()

View File

@@ -152,15 +152,42 @@ class GroupByDataFrame(GroupBy):
 - True: returns all values with float64, NaN/NaT are ignored.
 - False: returns all values with float64.
 - None: returns all values with default datatype.
+Returns
+-------
+A Pandas DataFrame
 """
+# Controls whether a MultiIndex is used for the
+# columns of the result DataFrame.
+is_dataframe_agg = True
 if isinstance(func, str):
 func = [func]
+is_dataframe_agg = False
 return self._query_compiler.aggs_groupby(
 by=self._by,
 pd_aggs=func,
 dropna=self._dropna,
 numeric_only=numeric_only,
-is_dataframe_agg=True,
+is_dataframe_agg=is_dataframe_agg,
 )
 agg = aggregate
+def count(self) -> "pd.DataFrame":
+"""
+Used to groupby and count
+Returns
+-------
+A Pandas DataFrame
+"""
+return self._query_compiler.aggs_groupby(
+by=self._by,
+pd_aggs=["count"],
+dropna=self._dropna,
+numeric_only=False,
+is_dataframe_agg=False,
+)
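For orientation, this is roughly how the new method is used, assuming the same ed_flights demo DataFrame as in the doctests above (results omitted, they depend on the dataset):

# Per-group counts for every selected column; returns a pandas DataFrame with a flat column Index
ed_flights.groupby(["DestCountry", "Cancelled"]).count()

# The same through agg(): a plain string keeps flat columns,
# while a list of functions produces MultiIndex columns such as ("AvgTicketPrice", "count")
ed_flights.groupby("dayOfWeek").agg("count")
ed_flights.groupby("dayOfWeek").agg(["max", "count", "min"])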

View File

@@ -512,7 +512,7 @@ class Operations:
 agg_value = np.NaN
 # Cardinality is always either NaN or integer.
-elif pd_agg == "nunique":
+elif pd_agg in ("nunique", "count"):
 agg_value = int(agg_value)
 # If this is a non-null timestamp field convert to a pd.Timestamp()
@@ -579,14 +579,63 @@ class Operations:
 numeric_only=numeric_only,
 )
-agg_df = pd.DataFrame(results, columns=results.keys()).set_index(by)
+agg_df = pd.DataFrame(results).set_index(by)
 if is_dataframe_agg:
 # Convert header columns to MultiIndex
 agg_df.columns = pd.MultiIndex.from_product([headers, pd_aggs])
+else:
+# Convert header columns to Index
+agg_df.columns = pd.Index(headers)
 return agg_df
+@staticmethod
+def bucket_generator(
+query_compiler: "QueryCompiler", body: "Query"
+) -> Generator[List[str], None, List[str]]:
+"""
+This can be used for all groupby operations.
+e.g.
+"aggregations": {
+"groupby_buckets": {
+"after_key": {"total_quantity": 8},
+"buckets": [
+{
+"key": {"total_quantity": 1},
+"doc_count": 87,
+"taxful_total_price_avg": {"value": 48.035978536496216},
+}
+],
+}
+}
+Returns
+-------
+A generator which initially yields the bucket
+If after_key is found, use it to fetch the next set of buckets.
+"""
+while True:
+res = query_compiler._client.search(
+index=query_compiler._index_pattern,
+size=0,
+body=body.to_search_body(),
+)
+# Pagination Logic
+composite_buckets = res["aggregations"]["groupby_buckets"]
+if "after_key" in composite_buckets:
+# yield the bucket which contains the result
+yield composite_buckets["buckets"]
+body.composite_agg_after_key(
+name="groupby_buckets",
+after_key=composite_buckets["after_key"],
+)
+else:
+return composite_buckets["buckets"]
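Outside eland, the same after_key pagination pattern looks roughly like the sketch below, assuming an elasticsearch-py client and the demo flights index (index, field and size values are illustrative):

from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")
body = {
    "size": 0,
    "aggs": {
        "groupby_buckets": {
            "composite": {
                "size": 100,
                "sources": [{"groupby_DestCountry": {"terms": {"field": "DestCountry"}}}],
            }
        }
    },
}
while True:
    res = es.search(index="flights", body=body)
    agg = res["aggregations"]["groupby_buckets"]
    for bucket in agg["buckets"]:
        print(bucket["key"], bucket["doc_count"])
    if "after_key" not in agg:
        break
    # Resume from the last returned composite key on the next request
    body["aggs"]["groupby_buckets"]["composite"]["after"] = agg["after_key"]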
 def _groupby_aggs(
 self,
 query_compiler: "QueryCompiler",
@@ -640,33 +689,45 @@ class Operations:
 body = Query(query_params.query)
+# To return for creating multi-index on columns
+headers = [field.column for field in agg_fields]
 # Convert pandas aggs to ES equivalent
 es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs)
+# pd_agg 'count' is handled via 'doc_count' from buckets
+using_pd_agg_count = "count" in pd_aggs
 # Construct Query
 for by_field in by_fields:
 # groupby fields will be term aggregations
 body.composite_agg_bucket_terms(
-name=f"groupby_{by_field.column}", field=by_field.es_field_name
+name=f"groupby_{by_field.column}",
+field=by_field.aggregatable_es_field_name,
 )
-for field in agg_fields:
+for agg_field in agg_fields:
 for es_agg in es_aggs:
-if not field.is_es_agg_compatible(es_agg):
+# Skip if the field isn't compatible or if the agg is
+# 'value_count' as this value is pulled from bucket.doc_count.
+if (
+not agg_field.is_es_agg_compatible(es_agg)
+or es_agg == "value_count"
+):
 continue
 # If we have multiple 'extended_stats' etc. here we simply NOOP on 2nd call
 if isinstance(es_agg, tuple):
 body.metric_aggs(
-f"{es_agg[0]}_{field.es_field_name}",
+f"{es_agg[0]}_{agg_field.es_field_name}",
 es_agg[0],
-field.aggregatable_es_field_name,
+agg_field.aggregatable_es_field_name,
 )
 else:
 body.metric_aggs(
-f"{es_agg}_{field.es_field_name}",
+f"{es_agg}_{agg_field.es_field_name}",
 es_agg,
-field.aggregatable_es_field_name,
+agg_field.aggregatable_es_field_name,
 )
 # Composite aggregation
@@ -674,49 +735,7 @@ class Operations:
 size=DEFAULT_PAGINATION_SIZE, name="groupby_buckets", dropna=dropna
 )
-def bucket_generator() -> Generator[List[str], None, List[str]]:
-"""
-e.g.
-"aggregations": {
-"groupby_buckets": {
-"after_key": {"total_quantity": 8},
-"buckets": [
-{
-"key": {"total_quantity": 1},
-"doc_count": 87,
-"taxful_total_price_avg": {"value": 48.035978536496216},
-}
-],
-}
-}
-Returns
--------
-A generator which initially yields the bucket
-If after_key is found, use it to fetch the next set of buckets.
-"""
-while True:
-res = query_compiler._client.search(
-index=query_compiler._index_pattern,
-size=0,
-body=body.to_search_body(),
-)
-# Pagination Logic
-composite_buckets = res["aggregations"]["groupby_buckets"]
-if "after_key" in composite_buckets:
-# yield the bucket which contains the result
-yield composite_buckets["buckets"]
-body.composite_agg_after_key(
-name="groupby_buckets",
-after_key=composite_buckets["after_key"],
-)
-else:
-return composite_buckets["buckets"]
-for buckets in bucket_generator():
+for buckets in self.bucket_generator(query_compiler, body):
 # We receive response row-wise
 for bucket in buckets:
 # groupby columns are added to result same way they are returned
@@ -729,6 +748,15 @@ class Operations:
 response[by_field.column].append(bucket_key)
+# Put 'doc_count' from bucket into each 'agg_field'
+# to be extracted from _unpack_metric_aggs()
+if using_pd_agg_count:
+doc_count = bucket["doc_count"]
+for agg_field in agg_fields:
+bucket[f"value_count_{agg_field.es_field_name}"] = {
+"value": doc_count
+}
 agg_calculation = self._unpack_metric_aggs(
 fields=agg_fields,
 es_aggs=es_aggs,
@@ -737,15 +765,16 @@ class Operations:
 numeric_only=numeric_only,
 is_dataframe_agg=is_dataframe_agg,
 )
 # Process the calculated agg values to response
 for key, value in agg_calculation.items():
-if isinstance(value, list):
-for pd_agg, val in zip(pd_aggs, value):
-response[f"{key}_{pd_agg}"].append(val)
-else:
+if not isinstance(value, list):
 response[key].append(value)
+continue
+for pd_agg, val in zip(pd_aggs, value):
+response[f"{key}_{pd_agg}"].append(val)
-return [field.column for field in agg_fields], response
+return headers, response
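To make the doc_count handling above concrete, here is a sketch of one bucket before the injection and the response keys it ends up in (values and field names are illustrative):

# A composite bucket as returned by Elasticsearch, with pd_aggs=["mean", "count"]
bucket = {
    "key": {"groupby_DestCountry": "AE"},
    "doc_count": 21,
    "avg_AvgTicketPrice": {"value": 605.13},
}
# The loop above injects bucket["value_count_AvgTicketPrice"] = {"value": 21},
# so _unpack_metric_aggs() treats 'count' like any other metric and the row is
# appended to response roughly as:
#   response["AvgTicketPrice_mean"]  -> [605.13, ...]
#   response["AvgTicketPrice_count"] -> [21, ...]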
 @staticmethod
 def _map_pd_aggs_to_es_aggs(pd_aggs):
@@ -781,8 +810,8 @@ class Operations:
 """
 # pd aggs that will be mapped to es aggs
 # that can use 'extended_stats'.
-extended_stats_pd_aggs = {"mean", "min", "max", "count", "sum", "var", "std"}
-extended_stats_es_aggs = {"avg", "min", "max", "count", "sum"}
+extended_stats_pd_aggs = {"mean", "min", "max", "sum", "var", "std"}
+extended_stats_es_aggs = {"avg", "min", "max", "sum"}
 extended_stats_calls = 0
 es_aggs = []
@@ -792,7 +821,7 @@ class Operations:
 # Aggs that are 'extended_stats' compatible
 if pd_agg == "count":
-es_aggs.append("count")
+es_aggs.append("value_count")
 elif pd_agg == "max":
 es_aggs.append("max")
 elif pd_agg == "min":
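For reference, the mapping now behaves roughly as below; the first call is taken directly from the updated tests at the end of this commit, the second is the expected result under the same mapping:

Operations._map_pd_aggs_to_es_aggs(["count", "nunique"])
# -> ["value_count", "cardinality"]
Operations._map_pd_aggs_to_es_aggs(["max", "count", "min"])
# -> ["max", "value_count", "min"]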

View File

@@ -195,8 +195,8 @@ class Query:
 Parameters
 ----------
-size: int
-Pagination size.
+size: int or None
+Use composite aggregation with pagination if size is not None
 name: str
 Name of the buckets
 dropna: bool
@@ -215,10 +215,13 @@ class Query:
 sources.append({bucket_agg_name: bucket_agg})
 self._composite_aggs.clear()
-aggs = {
-"composite": {"size": size, "sources": sources},
-"aggregations": self._aggs.copy(),
+aggs: Dict[str, Dict[str, Any]] = {
+"composite": {"size": size, "sources": sources}
 }
+if self._aggs:
+aggs["aggregations"] = self._aggs.copy()
 self._aggs.clear()
 self._aggs[name] = aggs
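The aggregation body this produces for a simple groupby then looks roughly like the sketch below (one terms source on "DestCountry" plus one avg metric, following the groupby_*/<agg>_<field> naming used in operations.py; the size value and field names are illustrative):

{
    "groupby_buckets": {
        "composite": {
            "size": 100,
            "sources": [{"groupby_DestCountry": {"terms": {"field": "DestCountry"}}}],
        },
        "aggregations": {
            "avg_AvgTicketPrice": {"avg": {"field": "AvgTicketPrice"}}
        },
    }
}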

View File

@@ -16,9 +16,30 @@
 # under the License.
 # File called _pytest for PyCharm compatibility
+from pandas.testing import assert_series_equal
+from eland.tests.common import TestData
-class TestDataFrameCount:
+class TestDataFrameCount(TestData):
+filter_data = [
+"AvgTicketPrice",
+"Cancelled",
+"dayOfWeek",
+"timestamp",
+"DestCountry",
+]
 def test_count(self, df):
 df.load_dataset("ecommerce")
 df.count()
+def test_count_flights(self):
+pd_flights = self.pd_flights().filter(self.filter_data)
+ed_flights = self.ed_flights().filter(self.filter_data)
+pd_count = pd_flights.count()
+ed_count = ed_flights.count()
+assert_series_equal(pd_count, ed_count)

View File

@@ -19,7 +19,7 @@
 import pandas as pd
 import pytest
-from pandas.testing import assert_frame_equal, assert_series_equal
+from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal
 from eland.tests.common import TestData
@@ -154,3 +154,25 @@ class TestGroupbyDataFrame(TestData):
 def test_groupby_dropna(self):
 # TODO Add tests once dropna is implemented
 pass
+@pytest.mark.parametrize("groupby", ["dayOfWeek", ["dayOfWeek", "Cancelled"]])
+@pytest.mark.parametrize(
+["func", "func_args"],
+[
+("count", ()),
+("agg", ("count",)),
+("agg", (["count"],)),
+("agg", (["max", "count", "min"],)),
+],
+)
+def test_groupby_dataframe_count(self, groupby, func, func_args):
+pd_flights = self.pd_flights().filter(self.filter_data)
+ed_flights = self.ed_flights().filter(self.filter_data)
+pd_count = getattr(pd_flights.groupby(groupby), func)(*func_args)
+ed_count = getattr(ed_flights.groupby(groupby), func)(*func_args)
+assert_index_equal(pd_count.columns, ed_count.columns)
+assert_index_equal(pd_count.index, ed_count.index)
+assert_frame_equal(pd_count, ed_count)
+assert_series_equal(pd_count.dtypes, ed_count.dtypes)

View File

@@ -20,7 +20,7 @@ import pandas as pd
 # File called _pytest for PyCharm compatibility
 import pytest
-from pandas.testing import assert_series_equal
+from pandas.testing import assert_frame_equal, assert_series_equal
 from eland.tests.common import TestData
@@ -414,3 +414,13 @@ class TestDataFrameMetrics(TestData):
 assert isinstance(calculated_values["AvgTicketPrice"], float)
 assert isinstance(calculated_values["dayOfWeek"], float)
 assert calculated_values.shape == (2,)
+def test_aggs_count(self):
+pd_flights = self.pd_flights().filter(self.filter_data)
+ed_flights = self.ed_flights().filter(self.filter_data)
+pd_count = pd_flights.agg(["count"])
+ed_count = ed_flights.agg(["count"])
+assert_frame_equal(pd_count, ed_count)

View File

@@ -30,7 +30,7 @@ def test_all_aggs():
 ("extended_stats", "std_deviation"),
 ("extended_stats", "variance"),
 "median_absolute_deviation",
-("extended_stats", "count"),
+"value_count",
 "cardinality",
 ("percentiles", "50.0"),
 ]
@@ -40,7 +40,7 @@ def test_extended_stats_optimization():
 # Tests that when '<agg>' and an 'extended_stats' agg are used together
 # that ('extended_stats', '<agg>') is used instead of '<agg>'.
 es_aggs = Operations._map_pd_aggs_to_es_aggs(["count", "nunique"])
-assert es_aggs == ["count", "cardinality"]
+assert es_aggs == ["value_count", "cardinality"]
 for pd_agg in ["var", "std"]:
 extended_es_agg = Operations._map_pd_aggs_to_es_aggs([pd_agg])[0]
@@ -49,4 +49,4 @@ def test_extended_stats_optimization():
 assert es_aggs == [extended_es_agg, "cardinality"]
 es_aggs = Operations._map_pd_aggs_to_es_aggs(["count", pd_agg, "nunique"])
-assert es_aggs == [("extended_stats", "count"), extended_es_agg, "cardinality"]
+assert es_aggs == ["value_count", extended_es_agg, "cardinality"]