Implement DataFrameGroupBy.count()

P. Sai Vinay 2020-10-23 19:11:50 +05:30 committed by GitHub
parent bd7956ea72
commit 475e0f41ef
9 changed files with 205 additions and 76 deletions


@ -1474,6 +1474,7 @@ class DataFrame(NDFrame):
True 121.280296 1175.709961 0.0 6.0
<BLANKLINE>
[63 rows x 4 columns]
>>> ed_flights.groupby(["DestCountry", "Cancelled"]).mean(numeric_only=True) # doctest: +NORMALIZE_WHITESPACE
AvgTicketPrice dayOfWeek
DestCountry Cancelled
@ -1490,6 +1491,23 @@ class DataFrame(NDFrame):
True 677.794078 2.928571
<BLANKLINE>
[63 rows x 2 columns]
>>> ed_flights.groupby(["DestCountry", "Cancelled"]).min(numeric_only=False) # doctest: +NORMALIZE_WHITESPACE
AvgTicketPrice dayOfWeek timestamp
DestCountry Cancelled
AE False 110.799911 0 2018-01-01 19:31:30
True 132.443756 0 2018-01-06 13:03:25
AR False 125.589394 0 2018-01-01 01:30:47
True 251.389603 0 2018-01-01 02:13:17
AT False 100.020531 0 2018-01-01 05:24:19
... ... ... ...
TR True 307.915649 0 2018-01-08 04:35:10
US False 100.145966 0 2018-01-01 00:06:27
True 102.153069 0 2018-01-01 09:02:36
ZA False 102.002663 0 2018-01-01 06:44:44
True 121.280296 0 2018-01-04 00:37:01
<BLANKLINE>
[63 rows x 3 columns]
"""
if by is None:
raise ValueError("by parameter should be specified to groupby")
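
For illustration, the guard above means grouping always needs an explicit `by`; a minimal sketch using the flights frame from the doctests above (output omitted, names assumed from those doctests):

try:
    ed_flights.groupby()  # no 'by' -> raises the ValueError above
except ValueError as err:
    print(err)  # "by parameter should be specified to groupby"
# The new count() aggregation added by this commit:
ed_flights.groupby(["DestCountry", "Cancelled"]).count()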


@ -100,12 +100,12 @@ class Field(NamedTuple):
elif es_agg[0] == "percentiles":
es_agg = "percentiles"
# Cardinality works for all types
# Numerics and bools work for all aggs
# Except "median_absolute_deviation" which doesn't support bool
if es_agg == "median_absolute_deviation" and self.is_bool:
return False
if es_agg == "cardinality" or self.is_numeric or self.is_bool:
# Cardinality and Count work for all types
# Numerics and bools work for all aggs
if es_agg in ("cardinality", "value_count") or self.is_numeric or self.is_bool:
return True
# Timestamps also work for 'min', 'max' and 'avg'
if es_agg in {"min", "max", "avg", "percentiles"} and self.is_timestamp:
@ -730,7 +730,6 @@ class FieldMappings:
"""
groupby_fields: Dict[str, Field] = {}
# groupby_fields: Union[List[Field], List[None]] = [None] * len(by)
aggregatable_fields: List[Field] = []
for column, row in self._mappings_capabilities.iterrows():
row = row.to_dict()


@ -152,15 +152,42 @@ class GroupByDataFrame(GroupBy):
- True: returns all values with float64, NaN/NaT are ignored.
- False: returns all values with float64.
- None: returns all values with default datatype.
Returns
-------
A Pandas DataFrame
"""
# Controls whether a MultiIndex is used for the
# columns of the result DataFrame.
is_dataframe_agg = True
if isinstance(func, str):
func = [func]
is_dataframe_agg = False
return self._query_compiler.aggs_groupby(
by=self._by,
pd_aggs=func,
dropna=self._dropna,
numeric_only=numeric_only,
is_dataframe_agg=True,
is_dataframe_agg=is_dataframe_agg,
)
agg = aggregate
def count(self) -> "pd.DataFrame":
"""
Compute the count of values within each group.
Returns
-------
A Pandas DataFrame
"""
return self._query_compiler.aggs_groupby(
by=self._by,
pd_aggs=["count"],
dropna=self._dropna,
numeric_only=False,
is_dataframe_agg=False,
)
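
A short usage sketch of the two entry points above (assumes the flights frame used in the tests); the only difference is the is_dataframe_agg flag, which controls whether the result columns become a MultiIndex:

g = ed_flights.groupby("DestCountry")
g.count()          # flat columns: one per-group count per aggregated field
g.agg(["count"])   # same numbers, but columns like ('dayOfWeek', 'count')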


@ -512,7 +512,7 @@ class Operations:
agg_value = np.NaN
# Cardinality is always either NaN or integer.
elif pd_agg == "nunique":
elif pd_agg in ("nunique", "count"):
agg_value = int(agg_value)
# If this is a non-null timestamp field convert to a pd.Timestamp()
@ -579,14 +579,63 @@ class Operations:
numeric_only=numeric_only,
)
agg_df = pd.DataFrame(results, columns=results.keys()).set_index(by)
agg_df = pd.DataFrame(results).set_index(by)
if is_dataframe_agg:
# Convert header columns to MultiIndex
agg_df.columns = pd.MultiIndex.from_product([headers, pd_aggs])
else:
# Convert header columns to Index
agg_df.columns = pd.Index(headers)
return agg_df
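
For reference, the two column layouts produced above look like this in plain pandas (hypothetical headers and aggs):

import pandas as pd

headers, pd_aggs = ["AvgTicketPrice", "dayOfWeek"], ["max", "count"]
pd.MultiIndex.from_product([headers, pd_aggs])
# ('AvgTicketPrice', 'max'), ('AvgTicketPrice', 'count'), ('dayOfWeek', 'max'), ('dayOfWeek', 'count')
pd.Index(headers)
# Index(['AvgTicketPrice', 'dayOfWeek'], dtype='object')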
@staticmethod
def bucket_generator(
query_compiler: "QueryCompiler", body: "Query"
) -> Generator[List[str], None, List[str]]:
"""
Used by all groupby operations to page through composite aggregation
buckets, e.g.
"aggregations": {
"groupby_buckets": {
"after_key": {"total_quantity": 8},
"buckets": [
{
"key": {"total_quantity": 1},
"doc_count": 87,
"taxful_total_price_avg": {"value": 48.035978536496216},
}
],
}
}
Returns
-------
A generator which yields each page of buckets.
If 'after_key' is present in the response, it is used to fetch the next page.
"""
while True:
res = query_compiler._client.search(
index=query_compiler._index_pattern,
size=0,
body=body.to_search_body(),
)
# Pagination Logic
composite_buckets = res["aggregations"]["groupby_buckets"]
if "after_key" in composite_buckets:
# yield the bucket which contains the result
yield composite_buckets["buckets"]
body.composite_agg_after_key(
name="groupby_buckets",
after_key=composite_buckets["after_key"],
)
else:
return composite_buckets["buckets"]
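
The same pagination loop, sketched as a standalone function against the low-level client (assumes an elasticsearch.Elasticsearch instance `es` and a hypothetical index name; only the documented 'buckets'/'after_key' response fields are used):

def iter_composite_buckets(es, index, composite_agg):
    body = {"size": 0, "aggs": {"groupby_buckets": {"composite": composite_agg}}}
    while True:
        res = es.search(index=index, body=body)
        agg = res["aggregations"]["groupby_buckets"]
        yield from agg["buckets"]
        if "after_key" not in agg:
            break
        # Resume the next page where the previous one left off.
        body["aggs"]["groupby_buckets"]["composite"]["after"] = agg["after_key"]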
def _groupby_aggs(
self,
query_compiler: "QueryCompiler",
@ -640,33 +689,45 @@ class Operations:
body = Query(query_params.query)
# Column headers, returned later for building the result's (Multi)Index columns
headers = [field.column for field in agg_fields]
# Convert pandas aggs to ES equivalent
es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs)
# pd_agg 'count' is handled via 'doc_count' from buckets
using_pd_agg_count = "count" in pd_aggs
# Construct Query
for by_field in by_fields:
# groupby fields will be term aggregations
body.composite_agg_bucket_terms(
name=f"groupby_{by_field.column}", field=by_field.es_field_name
name=f"groupby_{by_field.column}",
field=by_field.aggregatable_es_field_name,
)
for field in agg_fields:
for agg_field in agg_fields:
for es_agg in es_aggs:
if not field.is_es_agg_compatible(es_agg):
# Skip if the field isn't compatible or if the agg is
# 'value_count' as this value is pulled from bucket.doc_count.
if (
not agg_field.is_es_agg_compatible(es_agg)
or es_agg == "value_count"
):
continue
# If we have multiple 'extended_stats' aggs, this is simply a no-op on the 2nd call
if isinstance(es_agg, tuple):
body.metric_aggs(
f"{es_agg[0]}_{field.es_field_name}",
f"{es_agg[0]}_{agg_field.es_field_name}",
es_agg[0],
field.aggregatable_es_field_name,
agg_field.aggregatable_es_field_name,
)
else:
body.metric_aggs(
f"{es_agg}_{field.es_field_name}",
f"{es_agg}_{agg_field.es_field_name}",
es_agg,
field.aggregatable_es_field_name,
agg_field.aggregatable_es_field_name,
)
# Composite aggregation
@ -674,49 +735,7 @@ class Operations:
size=DEFAULT_PAGINATION_SIZE, name="groupby_buckets", dropna=dropna
)
def bucket_generator() -> Generator[List[str], None, List[str]]:
"""
e.g.
"aggregations": {
"groupby_buckets": {
"after_key": {"total_quantity": 8},
"buckets": [
{
"key": {"total_quantity": 1},
"doc_count": 87,
"taxful_total_price_avg": {"value": 48.035978536496216},
}
],
}
}
Returns
-------
A generator which initially yields the bucket
If after_key is found, use it to fetch the next set of buckets.
"""
while True:
res = query_compiler._client.search(
index=query_compiler._index_pattern,
size=0,
body=body.to_search_body(),
)
# Pagination Logic
composite_buckets = res["aggregations"]["groupby_buckets"]
if "after_key" in composite_buckets:
# yield the bucket which contains the result
yield composite_buckets["buckets"]
body.composite_agg_after_key(
name="groupby_buckets",
after_key=composite_buckets["after_key"],
)
else:
return composite_buckets["buckets"]
for buckets in bucket_generator():
for buckets in self.bucket_generator(query_compiler, body):
# We receive the response row-wise
for bucket in buckets:
# groupby columns are added to result same way they are returned
@ -729,6 +748,15 @@ class Operations:
response[by_field.column].append(bucket_key)
# Put 'doc_count' from the bucket into each 'agg_field'
# so it can be extracted by _unpack_metric_aggs()
if using_pd_agg_count:
doc_count = bucket["doc_count"]
for agg_field in agg_fields:
bucket[f"value_count_{agg_field.es_field_name}"] = {
"value": doc_count
}
agg_calculation = self._unpack_metric_aggs(
fields=agg_fields,
es_aggs=es_aggs,
@ -737,15 +765,16 @@ class Operations:
numeric_only=numeric_only,
is_dataframe_agg=is_dataframe_agg,
)
# Add the calculated agg values to the response
for key, value in agg_calculation.items():
if isinstance(value, list):
for pd_agg, val in zip(pd_aggs, value):
response[f"{key}_{pd_agg}"].append(val)
else:
if not isinstance(value, list):
response[key].append(value)
continue
for pd_agg, val in zip(pd_aggs, value):
response[f"{key}_{pd_agg}"].append(val)
return [field.column for field in agg_fields], response
return headers, response
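
To make the 'count' path above concrete, this is roughly what the doc_count injection does to a single bucket before _unpack_metric_aggs() reads it (hypothetical field names):

bucket = {"key": {"DestCountry": "AE", "Cancelled": False}, "doc_count": 87}
for es_field_name in ("AvgTicketPrice", "dayOfWeek"):
    bucket[f"value_count_{es_field_name}"] = {"value": bucket["doc_count"]}
# Each field now carries a metric shaped exactly like a real 'value_count'
# aggregation result, so no extra ES aggregation is needed for count().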
@staticmethod
def _map_pd_aggs_to_es_aggs(pd_aggs):
@ -781,8 +810,8 @@ class Operations:
"""
# pd aggs that will be mapped to es aggs
# that can use 'extended_stats'.
extended_stats_pd_aggs = {"mean", "min", "max", "count", "sum", "var", "std"}
extended_stats_es_aggs = {"avg", "min", "max", "count", "sum"}
extended_stats_pd_aggs = {"mean", "min", "max", "sum", "var", "std"}
extended_stats_es_aggs = {"avg", "min", "max", "sum"}
extended_stats_calls = 0
es_aggs = []
@ -792,7 +821,7 @@ class Operations:
# Aggs that are 'extended_stats' compatible
if pd_agg == "count":
es_aggs.append("count")
es_aggs.append("value_count")
elif pd_agg == "max":
es_aggs.append("max")
elif pd_agg == "min":


@ -195,8 +195,8 @@ class Query:
Parameters
----------
size: int
Pagination size.
size: int or None
If not None, use a composite aggregation with this pagination size.
name: str
Name of the buckets
dropna: bool
@ -215,10 +215,13 @@ class Query:
sources.append({bucket_agg_name: bucket_agg})
self._composite_aggs.clear()
aggs = {
"composite": {"size": size, "sources": sources},
"aggregations": self._aggs.copy(),
aggs: Dict[str, Dict[str, Any]] = {
"composite": {"size": size, "sources": sources}
}
if self._aggs:
aggs["aggregations"] = self._aggs.copy()
self._aggs.clear()
self._aggs[name] = aggs
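
With this change 'aggregations' is only attached when metric sub-aggs exist (a pure count() registers none); a hypothetical composed aggregation then looks roughly like:

{
    "groupby_buckets": {
        "composite": {
            "size": 100,  # illustrative; the caller passes DEFAULT_PAGINATION_SIZE
            "sources": [
                {"groupby_DestCountry": {"terms": {"field": "DestCountry"}}},
            ],
        },
        # present only when metric aggs were registered, e.g. for mean():
        "aggregations": {"avg_AvgTicketPrice": {"avg": {"field": "AvgTicketPrice"}}},
    }
}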


@ -16,9 +16,30 @@
# under the License.
# File called _pytest for PyCharm compatibility
from pandas.testing import assert_series_equal
from eland.tests.common import TestData
class TestDataFrameCount:
class TestDataFrameCount(TestData):
filter_data = [
"AvgTicketPrice",
"Cancelled",
"dayOfWeek",
"timestamp",
"DestCountry",
]
def test_count(self, df):
df.load_dataset("ecommerce")
df.count()
def test_count_flights(self):
pd_flights = self.pd_flights().filter(self.filter_data)
ed_flights = self.ed_flights().filter(self.filter_data)
pd_count = pd_flights.count()
ed_count = ed_flights.count()
assert_series_equal(pd_count, ed_count)


@ -19,7 +19,7 @@
import pandas as pd
import pytest
from pandas.testing import assert_frame_equal, assert_series_equal
from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal
from eland.tests.common import TestData
@ -154,3 +154,25 @@ class TestGroupbyDataFrame(TestData):
def test_groupby_dropna(self):
# TODO Add tests once dropna is implemented
pass
@pytest.mark.parametrize("groupby", ["dayOfWeek", ["dayOfWeek", "Cancelled"]])
@pytest.mark.parametrize(
["func", "func_args"],
[
("count", ()),
("agg", ("count",)),
("agg", (["count"],)),
("agg", (["max", "count", "min"],)),
],
)
def test_groupby_dataframe_count(self, groupby, func, func_args):
pd_flights = self.pd_flights().filter(self.filter_data)
ed_flights = self.ed_flights().filter(self.filter_data)
pd_count = getattr(pd_flights.groupby(groupby), func)(*func_args)
ed_count = getattr(ed_flights.groupby(groupby), func)(*func_args)
assert_index_equal(pd_count.columns, ed_count.columns)
assert_index_equal(pd_count.index, ed_count.index)
assert_frame_equal(pd_count, ed_count)
assert_series_equal(pd_count.dtypes, ed_count.dtypes)


@ -20,7 +20,7 @@ import pandas as pd
# File called _pytest for PyCharm compatibility
import pytest
from pandas.testing import assert_series_equal
from pandas.testing import assert_frame_equal, assert_series_equal
from eland.tests.common import TestData
@ -414,3 +414,13 @@ class TestDataFrameMetrics(TestData):
assert isinstance(calculated_values["AvgTicketPrice"], float)
assert isinstance(calculated_values["dayOfWeek"], float)
assert calculated_values.shape == (2,)
def test_aggs_count(self):
pd_flights = self.pd_flights().filter(self.filter_data)
ed_flights = self.ed_flights().filter(self.filter_data)
pd_count = pd_flights.agg(["count"])
ed_count = ed_flights.agg(["count"])
assert_frame_equal(pd_count, ed_count)


@ -30,7 +30,7 @@ def test_all_aggs():
("extended_stats", "std_deviation"),
("extended_stats", "variance"),
"median_absolute_deviation",
("extended_stats", "count"),
"value_count",
"cardinality",
("percentiles", "50.0"),
]
@ -40,7 +40,7 @@ def test_extended_stats_optimization():
# Tests that when '<agg>' and an 'extended_stats' agg are used together
# that ('extended_stats', '<agg>') is used instead of '<agg>'.
es_aggs = Operations._map_pd_aggs_to_es_aggs(["count", "nunique"])
assert es_aggs == ["count", "cardinality"]
assert es_aggs == ["value_count", "cardinality"]
for pd_agg in ["var", "std"]:
extended_es_agg = Operations._map_pd_aggs_to_es_aggs([pd_agg])[0]
@ -49,4 +49,4 @@ def test_extended_stats_optimization():
assert es_aggs == [extended_es_agg, "cardinality"]
es_aggs = Operations._map_pd_aggs_to_es_aggs(["count", pd_agg, "nunique"])
assert es_aggs == [("extended_stats", "count"), extended_es_agg, "cardinality"]
assert es_aggs == ["value_count", extended_es_agg, "cardinality"]