Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)
Implement DataFrameGroupBy.count()

commit 475e0f41ef (parent bd7956ea72)
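Before the diff, a minimal doctest-style sketch of what the new method enables, in the style eland's docstrings already use (the connection details are hypothetical; column names come from the demo flights dataset used by the tests below; output omitted):

    >>> import eland as ed
    >>> ed_flights = ed.DataFrame("localhost", "flights")   # hypothetical cluster and index name
    >>> ed_flights.groupby("DestCountry").count()            # per-group count of non-missing values per column
    >>> ed_flights.groupby(["DestCountry", "Cancelled"]).agg(["min", "count"])   # count mixed with other aggs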
@@ -1474,6 +1474,7 @@ class DataFrame(NDFrame):
True 121.280296 1175.709961 0.0 6.0
<BLANKLINE>
[63 rows x 4 columns]

>>> ed_flights.groupby(["DestCountry", "Cancelled"]).mean(numeric_only=True) # doctest: +NORMALIZE_WHITESPACE
AvgTicketPrice dayOfWeek
DestCountry Cancelled
@@ -1490,6 +1491,23 @@ class DataFrame(NDFrame):
True 677.794078 2.928571
<BLANKLINE>
[63 rows x 2 columns]

>>> ed_flights.groupby(["DestCountry", "Cancelled"]).min(numeric_only=False) # doctest: +NORMALIZE_WHITESPACE
AvgTicketPrice dayOfWeek timestamp
DestCountry Cancelled
AE False 110.799911 0 2018-01-01 19:31:30
True 132.443756 0 2018-01-06 13:03:25
AR False 125.589394 0 2018-01-01 01:30:47
True 251.389603 0 2018-01-01 02:13:17
AT False 100.020531 0 2018-01-01 05:24:19
... ... ... ...
TR True 307.915649 0 2018-01-08 04:35:10
US False 100.145966 0 2018-01-01 00:06:27
True 102.153069 0 2018-01-01 09:02:36
ZA False 102.002663 0 2018-01-01 06:44:44
True 121.280296 0 2018-01-04 00:37:01
<BLANKLINE>
[63 rows x 3 columns]
"""
if by is None:
raise ValueError("by parameter should be specified to groupby")
@@ -100,12 +100,12 @@ class Field(NamedTuple):
elif es_agg[0] == "percentiles":
es_agg = "percentiles"

# Cardinality works for all types
# Numerics and bools work for all aggs
# Except "median_absolute_deviation" which doesn't support bool
if es_agg == "median_absolute_deviation" and self.is_bool:
return False
if es_agg == "cardinality" or self.is_numeric or self.is_bool:
# Cardinality and Count work for all types
# Numerics and bools work for all aggs
if es_agg in ("cardinality", "value_count") or self.is_numeric or self.is_bool:
return True
# Timestamps also work for 'min', 'max' and 'avg'
if es_agg in {"min", "max", "avg", "percentiles"} and self.is_timestamp:
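For orientation, a standalone sketch of the compatibility rule this hunk ends up encoding (this is not the eland API itself; the type flags are plain parameters here purely for illustration):

    def es_agg_compatible(es_agg, is_numeric=False, is_bool=False, is_timestamp=False):
        # bools are not supported by median_absolute_deviation
        if es_agg == "median_absolute_deviation" and is_bool:
            return False
        # 'cardinality' and 'value_count' (the new mapping for pandas 'count') work for every type;
        # numerics and bools work for all remaining aggs
        if es_agg in ("cardinality", "value_count") or is_numeric or is_bool:
            return True
        # timestamps additionally support min, max, avg and percentiles
        return es_agg in {"min", "max", "avg", "percentiles"} and is_timestamp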
@@ -730,7 +730,6 @@ class FieldMappings:

"""
groupby_fields: Dict[str, Field] = {}
# groupby_fields: Union[List[Field], List[None]] = [None] * len(by)
aggregatable_fields: List[Field] = []
for column, row in self._mappings_capabilities.iterrows():
row = row.to_dict()
@@ -152,15 +152,42 @@ class GroupByDataFrame(GroupBy):
- True: returns all values with float64, NaN/NaT are ignored.
- False: returns all values with float64.
- None: returns all values with default datatype.

Returns
-------
A Pandas DataFrame

"""
# Controls whether a MultiIndex is used for the
# columns of the result DataFrame.
is_dataframe_agg = True
if isinstance(func, str):
func = [func]
is_dataframe_agg = False

return self._query_compiler.aggs_groupby(
by=self._by,
pd_aggs=func,
dropna=self._dropna,
numeric_only=numeric_only,
is_dataframe_agg=True,
is_dataframe_agg=is_dataframe_agg,
)

agg = aggregate

def count(self) -> "pd.DataFrame":
"""
Used to groupby and count

Returns
-------
A Pandas DataFrame

"""
return self._query_compiler.aggs_groupby(
by=self._by,
pd_aggs=["count"],
dropna=self._dropna,
numeric_only=False,
is_dataframe_agg=False,
)
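The is_dataframe_agg flag set above mirrors pandas behaviour: count(), like a single aggregation name passed to agg(), produces flat columns, while a list of aggregations produces a column MultiIndex. A doctest-style sketch of the resulting column shapes (values omitted):

    >>> ed_flights.groupby("DestCountry").count().columns
    Index([...])          # one entry per aggregated field
    >>> ed_flights.groupby("DestCountry").agg(["min", "count"]).columns
    MultiIndex([...])     # (field, agg) pairs, built with pd.MultiIndex.from_product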
@@ -512,7 +512,7 @@ class Operations:
agg_value = np.NaN

# Cardinality is always either NaN or integer.
elif pd_agg == "nunique":
elif pd_agg in ("nunique", "count"):
agg_value = int(agg_value)

# If this is a non-null timestamp field convert to a pd.Timestamp()
@@ -579,14 +579,63 @@ class Operations:
numeric_only=numeric_only,
)

agg_df = pd.DataFrame(results, columns=results.keys()).set_index(by)
agg_df = pd.DataFrame(results).set_index(by)

if is_dataframe_agg:
# Convert header columns to MultiIndex
agg_df.columns = pd.MultiIndex.from_product([headers, pd_aggs])
else:
# Convert header columns to Index
agg_df.columns = pd.Index(headers)

return agg_df

@staticmethod
def bucket_generator(
query_compiler: "QueryCompiler", body: "Query"
) -> Generator[List[str], None, List[str]]:
"""
This can be used for all groupby operations.
e.g.
"aggregations": {
"groupby_buckets": {
"after_key": {"total_quantity": 8},
"buckets": [
{
"key": {"total_quantity": 1},
"doc_count": 87,
"taxful_total_price_avg": {"value": 48.035978536496216},
}
],
}
}
Returns
-------
A generator which initially yields the bucket
If after_key is found, use it to fetch the next set of buckets.

"""
while True:
res = query_compiler._client.search(
index=query_compiler._index_pattern,
size=0,
body=body.to_search_body(),
)

# Pagination Logic
composite_buckets = res["aggregations"]["groupby_buckets"]
if "after_key" in composite_buckets:

# yield the bucket which contains the result
yield composite_buckets["buckets"]

body.composite_agg_after_key(
name="groupby_buckets",
after_key=composite_buckets["after_key"],
)
else:
return composite_buckets["buckets"]

def _groupby_aggs(
self,
query_compiler: "QueryCompiler",
@@ -640,33 +689,45 @@ class Operations:

body = Query(query_params.query)

# To return for creating multi-index on columns
headers = [field.column for field in agg_fields]

# Convert pandas aggs to ES equivalent
es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs)

# pd_agg 'count' is handled via 'doc_count' from buckets
using_pd_agg_count = "count" in pd_aggs

# Construct Query
for by_field in by_fields:
# groupby fields will be term aggregations
body.composite_agg_bucket_terms(
name=f"groupby_{by_field.column}", field=by_field.es_field_name
name=f"groupby_{by_field.column}",
field=by_field.aggregatable_es_field_name,
)

for field in agg_fields:
for agg_field in agg_fields:
for es_agg in es_aggs:
if not field.is_es_agg_compatible(es_agg):
# Skip if the field isn't compatible or if the agg is
# 'value_count' as this value is pulled from bucket.doc_count.
if (
not agg_field.is_es_agg_compatible(es_agg)
or es_agg == "value_count"
):
continue

# If we have multiple 'extended_stats' etc. here we simply NOOP on 2nd call
if isinstance(es_agg, tuple):
body.metric_aggs(
f"{es_agg[0]}_{field.es_field_name}",
f"{es_agg[0]}_{agg_field.es_field_name}",
es_agg[0],
field.aggregatable_es_field_name,
agg_field.aggregatable_es_field_name,
)
else:
body.metric_aggs(
f"{es_agg}_{field.es_field_name}",
f"{es_agg}_{agg_field.es_field_name}",
es_agg,
field.aggregatable_es_field_name,
agg_field.aggregatable_es_field_name,
)

# Composite aggregation
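For reference, the search body assembled above is a composite aggregation of roughly the following shape (a sketch pieced together from composite_agg_bucket_terms, metric_aggs and the Query changes later in this diff; the field names are illustrative and the size placeholder stands for DEFAULT_PAGINATION_SIZE):

    "aggregations": {
        "groupby_buckets": {
            "composite": {
                "size": <DEFAULT_PAGINATION_SIZE>,
                "sources": [{"groupby_DestCountry": {"terms": {"field": "DestCountry"}}}]
            },
            "aggregations": {"avg_AvgTicketPrice": {"avg": {"field": "AvgTicketPrice"}}}
        }
    }

Subsequent pages add an "after" key inside "composite" via composite_agg_after_key(), using the "after_key" from the previous response. For a bare count() no metric sub-aggregations are attached at all, because the per-group count is read from each bucket's doc_count.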
@@ -674,49 +735,7 @@ class Operations:
size=DEFAULT_PAGINATION_SIZE, name="groupby_buckets", dropna=dropna
)

def bucket_generator() -> Generator[List[str], None, List[str]]:
"""
e.g.
"aggregations": {
"groupby_buckets": {
"after_key": {"total_quantity": 8},
"buckets": [
{
"key": {"total_quantity": 1},
"doc_count": 87,
"taxful_total_price_avg": {"value": 48.035978536496216},
}
],
}
}
Returns
-------
A generator which initially yields the bucket
If after_key is found, use it to fetch the next set of buckets.

"""
while True:
res = query_compiler._client.search(
index=query_compiler._index_pattern,
size=0,
body=body.to_search_body(),
)

# Pagination Logic
composite_buckets = res["aggregations"]["groupby_buckets"]
if "after_key" in composite_buckets:

# yield the bucket which contains the result
yield composite_buckets["buckets"]

body.composite_agg_after_key(
name="groupby_buckets",
after_key=composite_buckets["after_key"],
)
else:
return composite_buckets["buckets"]

for buckets in bucket_generator():
for buckets in self.bucket_generator(query_compiler, body):
# We receive the response row-wise
for bucket in buckets:
# groupby columns are added to result same way they are returned
@@ -729,6 +748,15 @@ class Operations:

response[by_field.column].append(bucket_key)

# Put 'doc_count' from bucket into each 'agg_field'
# to be extracted from _unpack_metric_aggs()
if using_pd_agg_count:
doc_count = bucket["doc_count"]
for agg_field in agg_fields:
bucket[f"value_count_{agg_field.es_field_name}"] = {
"value": doc_count
}

agg_calculation = self._unpack_metric_aggs(
fields=agg_fields,
es_aggs=es_aggs,
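Concretely, for a bucket such as the one sketched below the rewrite adds one synthetic value_count entry per aggregated field, so that _unpack_metric_aggs() can treat the pandas 'count' like any other metric (field names and values are illustrative):

    # before: {"key": {"groupby_DestCountry": "AE"}, "doc_count": 87, "avg_AvgTicketPrice": {"value": 245.0}}
    # after:  the same bucket, plus {"value_count_AvgTicketPrice": {"value": 87}}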
@@ -737,15 +765,16 @@ class Operations:
numeric_only=numeric_only,
is_dataframe_agg=is_dataframe_agg,
)

# Process the calculated agg values to response
for key, value in agg_calculation.items():
if isinstance(value, list):
for pd_agg, val in zip(pd_aggs, value):
response[f"{key}_{pd_agg}"].append(val)
else:
if not isinstance(value, list):
response[key].append(value)
continue
for pd_agg, val in zip(pd_aggs, value):
response[f"{key}_{pd_agg}"].append(val)

return [field.column for field in agg_fields], response
return headers, response

@staticmethod
def _map_pd_aggs_to_es_aggs(pd_aggs):
@@ -781,8 +810,8 @@ class Operations:
"""
# pd aggs that will be mapped to es aggs
# that can use 'extended_stats'.
extended_stats_pd_aggs = {"mean", "min", "max", "count", "sum", "var", "std"}
extended_stats_es_aggs = {"avg", "min", "max", "count", "sum"}
extended_stats_pd_aggs = {"mean", "min", "max", "sum", "var", "std"}
extended_stats_es_aggs = {"avg", "min", "max", "sum"}
extended_stats_calls = 0

es_aggs = []
@@ -792,7 +821,7 @@ class Operations:

# Aggs that are 'extended_stats' compatible
if pd_agg == "count":
es_aggs.append("count")
es_aggs.append("value_count")
elif pd_agg == "max":
es_aggs.append("max")
elif pd_agg == "min":
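The net effect of this remapping, as asserted by the updated tests at the end of this diff:

    Operations._map_pd_aggs_to_es_aggs(["count", "nunique"])
    # now        -> ["value_count", "cardinality"]
    # previously -> ["count", "cardinality"], with 'count' folded into 'extended_stats' where possible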
@@ -195,8 +195,8 @@ class Query:

Parameters
----------
size: int
Pagination size.
size: int or None
Use composite aggregation with pagination if size is not None
name: str
Name of the buckets
dropna: bool
@@ -215,10 +215,13 @@ class Query:
sources.append({bucket_agg_name: bucket_agg})
self._composite_aggs.clear()

aggs = {
"composite": {"size": size, "sources": sources},
"aggregations": self._aggs.copy(),
aggs: Dict[str, Dict[str, Any]] = {
"composite": {"size": size, "sources": sources}
}

if self._aggs:
aggs["aggregations"] = self._aggs.copy()

self._aggs.clear()
self._aggs[name] = aggs
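The "aggregations" key is now attached only when metric sub-aggregations were actually registered; on the new count()-only path nothing is registered (the count comes from doc_count), so the composite body stays minimal. A sketch of the two shapes this produces:

    # with metric aggs: {"composite": {"size": ..., "sources": [...]}, "aggregations": {...}}
    # count() only:     {"composite": {"size": ..., "sources": [...]}}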
@@ -16,9 +16,30 @@
# under the License.

# File called _pytest for PyCharm compatibility
from pandas.testing import assert_series_equal

from eland.tests.common import TestData


class TestDataFrameCount:
class TestDataFrameCount(TestData):
filter_data = [
"AvgTicketPrice",
"Cancelled",
"dayOfWeek",
"timestamp",
"DestCountry",
]

def test_count(self, df):
df.load_dataset("ecommerce")
df.count()

def test_count_flights(self):

pd_flights = self.pd_flights().filter(self.filter_data)
ed_flights = self.ed_flights().filter(self.filter_data)

pd_count = pd_flights.count()
ed_count = ed_flights.count()

assert_series_equal(pd_count, ed_count)
@@ -19,7 +19,7 @@

import pandas as pd
import pytest
from pandas.testing import assert_frame_equal, assert_series_equal
from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal

from eland.tests.common import TestData
@@ -154,3 +154,25 @@ class TestGroupbyDataFrame(TestData):
def test_groupby_dropna(self):
# TODO Add tests once dropna is implemented
pass

@pytest.mark.parametrize("groupby", ["dayOfWeek", ["dayOfWeek", "Cancelled"]])
@pytest.mark.parametrize(
["func", "func_args"],
[
("count", ()),
("agg", ("count",)),
("agg", (["count"],)),
("agg", (["max", "count", "min"],)),
],
)
def test_groupby_dataframe_count(self, groupby, func, func_args):
pd_flights = self.pd_flights().filter(self.filter_data)
ed_flights = self.ed_flights().filter(self.filter_data)

pd_count = getattr(pd_flights.groupby(groupby), func)(*func_args)
ed_count = getattr(ed_flights.groupby(groupby), func)(*func_args)

assert_index_equal(pd_count.columns, ed_count.columns)
assert_index_equal(pd_count.index, ed_count.index)
assert_frame_equal(pd_count, ed_count)
assert_series_equal(pd_count.dtypes, ed_count.dtypes)
@@ -20,7 +20,7 @@ import pandas as pd

# File called _pytest for PyCharm compatibility
import pytest
from pandas.testing import assert_series_equal
from pandas.testing import assert_frame_equal, assert_series_equal

from eland.tests.common import TestData
@@ -414,3 +414,13 @@ class TestDataFrameMetrics(TestData):
assert isinstance(calculated_values["AvgTicketPrice"], float)
assert isinstance(calculated_values["dayOfWeek"], float)
assert calculated_values.shape == (2,)

def test_aggs_count(self):

pd_flights = self.pd_flights().filter(self.filter_data)
ed_flights = self.ed_flights().filter(self.filter_data)

pd_count = pd_flights.agg(["count"])
ed_count = ed_flights.agg(["count"])

assert_frame_equal(pd_count, ed_count)
@@ -30,7 +30,7 @@ def test_all_aggs():
("extended_stats", "std_deviation"),
("extended_stats", "variance"),
"median_absolute_deviation",
("extended_stats", "count"),
"value_count",
"cardinality",
("percentiles", "50.0"),
]
@@ -40,7 +40,7 @@ def test_extended_stats_optimization():
# Tests that when '<agg>' and an 'extended_stats' agg are used together
# that ('extended_stats', '<agg>') is used instead of '<agg>'.
es_aggs = Operations._map_pd_aggs_to_es_aggs(["count", "nunique"])
assert es_aggs == ["count", "cardinality"]
assert es_aggs == ["value_count", "cardinality"]

for pd_agg in ["var", "std"]:
extended_es_agg = Operations._map_pd_aggs_to_es_aggs([pd_agg])[0]
@@ -49,4 +49,4 @@ def test_extended_stats_optimization():
assert es_aggs == [extended_es_agg, "cardinality"]

es_aggs = Operations._map_pd_aggs_to_es_aggs(["count", pd_agg, "nunique"])
assert es_aggs == [("extended_stats", "count"), extended_es_agg, "cardinality"]
assert es_aggs == ["value_count", extended_es_agg, "cardinality"]