Document DataFrame.groupby() and rename Field.index -> .column
parent abc5ca927b, commit 18fb4af731
docs/source/reference/api/eland.DataFrame.groupby.rst (new file, +6 lines)
@@ -0,0 +1,6 @@
+eland.DataFrame.groupby
+=======================
+
+.. currentmodule:: eland
+
+.. automethod:: DataFrame.groupby
@@ -46,6 +46,7 @@ Function Application, GroupBy & Window

   DataFrame.agg
   DataFrame.aggregate
+  DataFrame.groupby

.. _api.dataframe.stats:

@@ -1442,13 +1442,10 @@ class DataFrame(NDFrame):
        by:
            column or list of columns used to groupby
            Currently accepts column or list of columns
-           TODO Implement other combinations of by similar to pandas

        dropna: default True
            If True, and if group keys contain NA values, NA values together with row/column will be dropped.
-           TODO Implement False

-       TODO Implement remainder of pandas arguments
        Returns
        -------
        GroupByDataFrame
@@ -1495,18 +1492,18 @@ class DataFrame(NDFrame):
        [63 rows x 2 columns]
        """
        if by is None:
-           raise TypeError("by parameter should be specified to groupby")
+           raise ValueError("by parameter should be specified to groupby")
        if isinstance(by, str):
            by = [by]
        if isinstance(by, (list, tuple)):
-           remaining_columns = set(by) - set(self._query_compiler.columns)
+           remaining_columns = sorted(set(by) - set(self._query_compiler.columns))
            if remaining_columns:
                raise KeyError(
-                   f"Requested columns {remaining_columns} not in the DataFrame."
+                   f"Requested columns {repr(remaining_columns)[1:-1]} not in the DataFrame"
                )

        return GroupByDataFrame(
-           by=by, query_compiler=self._query_compiler, dropna=dropna
+           by=by, query_compiler=self._query_compiler.copy(), dropna=dropna
        )

    def query(self, expr) -> "DataFrame":
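For orientation, a minimal usage sketch of the behaviour this hunk settles on. The connection string and the demo "flights" index are assumptions borrowed from eland's documentation examples, not part of this diff:

    import eland as ed

    # Assumes a local Elasticsearch with the demo "flights" index loaded.
    df = ed.DataFrame("http://localhost:9200", "flights")

    # `by` accepts a single column name or a list of column names.
    print(df.groupby(["DestCountry", "Cancelled"]).mean(numeric_only=True).head())

    # df.groupby(None) now raises ValueError, and unknown columns raise
    # KeyError: "Requested columns 'ABC' not in the DataFrame"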
@@ -64,7 +64,7 @@ ES_COMPATIBLE_TYPES: Dict[str, Set[str]] = {
class Field(NamedTuple):
    """Holds all information on a particular field in the mapping"""

-   index: str
+   column: str
    es_field_name: str
    is_source: bool
    es_dtype: str
@@ -129,7 +129,7 @@ class FieldMappings:
    _mappings_capabilities: pandas.DataFrame
        A data frame summarising the capabilities of the index mapping

-       index - the eland display name
+       column (index) - the eland display name

        es_field_name - the Elasticsearch field name
        is_source - is top level field (i.e. not a multi-field sub-field)
@@ -537,13 +537,13 @@ class FieldMappings:
        """

        mapping_props = {}
-       for field_name_name, dtype in dataframe.dtypes.iteritems():
-           if es_type_overrides is not None and field_name_name in es_type_overrides:
-               es_dtype = es_type_overrides[field_name_name]
+       for column, dtype in dataframe.dtypes.iteritems():
+           if es_type_overrides is not None and column in es_type_overrides:
+               es_dtype = es_type_overrides[column]
            else:
                es_dtype = FieldMappings._pd_dtype_to_es_dtype(dtype)

-           mapping_props[field_name_name] = {"type": es_dtype}
+           mapping_props[column] = {"type": es_dtype}

        return {"mappings": {"properties": mapping_props}}

@@ -708,9 +708,9 @@ class FieldMappings:

        """
        source_fields: List[Field] = []
-       for index, row in self._mappings_capabilities.iterrows():
+       for column, row in self._mappings_capabilities.iterrows():
            row = row.to_dict()
-           row["index"] = index
+           row["column"] = column
            source_fields.append(Field(**row))
        return source_fields

@@ -731,13 +731,13 @@ class FieldMappings:
        groupby_fields: Dict[str, Field] = {}
        # groupby_fields: Union[List[Field], List[None]] = [None] * len(by)
        aggregatable_fields: List[Field] = []
-       for index_name, row in self._mappings_capabilities.iterrows():
+       for column, row in self._mappings_capabilities.iterrows():
            row = row.to_dict()
-           row["index"] = index_name
-           if index_name not in by:
+           row["column"] = column
+           if column not in by:
                aggregatable_fields.append(Field(**row))
            else:
-               groupby_fields[index_name] = Field(**row)
+               groupby_fields[column] = Field(**row)

        # Maintain groupby order as given input
        return [groupby_fields[column] for column in by], aggregatable_fields
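To make the rename concrete, here is a small self-contained sketch of the split that groupby_source_fields() performs, using a trimmed-down Field that keeps only the attributes visible in this hunk (the real NamedTuple carries more mapping metadata, and the example values are illustrative):

    from typing import List, NamedTuple, Tuple

    class Field(NamedTuple):
        column: str        # the eland display name (formerly `index`)
        es_field_name: str
        is_source: bool
        es_dtype: str

    def split_fields(fields: List[Field], by: List[str]) -> Tuple[List[Field], List[Field]]:
        groupby_fields = {f.column: f for f in fields if f.column in by}
        agg_fields = [f for f in fields if f.column not in by]
        # Maintain groupby order as given input
        return [groupby_fields[column] for column in by], agg_fields

    fields = [
        Field("DestCountry", "DestCountry", True, "keyword"),
        Field("AvgTicketPrice", "AvgTicketPrice", True, "float"),
    ]
    by_fields, agg_fields = split_fields(fields, by=["DestCountry"])
    print([f.column for f in by_fields], [f.column for f in agg_fields])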
@@ -24,7 +24,7 @@ if TYPE_CHECKING:

class GroupBy:
    """
-   This holds all the groupby base methods
+   Base class for calls to X.groupby([...])

    Parameters
    ----------
@@ -34,7 +34,6 @@ class GroupBy:
        Query compiler object
    dropna:
        default is true, drop None/NaT/NaN values while grouping

    """

    def __init__(
@@ -47,9 +46,8 @@ class GroupBy:
        self._dropna: bool = dropna
        self._by: List[str] = by

    # numeric_only=True by default for all aggs because pandas does the same
    def mean(self, numeric_only: bool = True) -> "pd.DataFrame":
-       return self._query_compiler.groupby(
+       return self._query_compiler.aggs_groupby(
            by=self._by,
            pd_aggs=["mean"],
            dropna=self._dropna,
@@ -57,7 +55,7 @@ class GroupBy:
        )

    def var(self, numeric_only: bool = True) -> "pd.DataFrame":
-       return self._query_compiler.groupby(
+       return self._query_compiler.aggs_groupby(
            by=self._by,
            pd_aggs=["var"],
            dropna=self._dropna,
@@ -65,7 +63,7 @@ class GroupBy:
        )

    def std(self, numeric_only: bool = True) -> "pd.DataFrame":
-       return self._query_compiler.groupby(
+       return self._query_compiler.aggs_groupby(
            by=self._by,
            pd_aggs=["std"],
            dropna=self._dropna,
@@ -73,7 +71,7 @@ class GroupBy:
        )

    def mad(self, numeric_only: bool = True) -> "pd.DataFrame":
-       return self._query_compiler.groupby(
+       return self._query_compiler.aggs_groupby(
            by=self._by,
            pd_aggs=["mad"],
            dropna=self._dropna,
@@ -81,7 +79,7 @@ class GroupBy:
        )

    def median(self, numeric_only: bool = True) -> "pd.DataFrame":
-       return self._query_compiler.groupby(
+       return self._query_compiler.aggs_groupby(
            by=self._by,
            pd_aggs=["median"],
            dropna=self._dropna,
@@ -89,7 +87,7 @@ class GroupBy:
        )

    def sum(self, numeric_only: bool = True) -> "pd.DataFrame":
-       return self._query_compiler.groupby(
+       return self._query_compiler.aggs_groupby(
            by=self._by,
            pd_aggs=["sum"],
            dropna=self._dropna,
@@ -97,7 +95,7 @@ class GroupBy:
        )

    def min(self, numeric_only: bool = True) -> "pd.DataFrame":
-       return self._query_compiler.groupby(
+       return self._query_compiler.aggs_groupby(
            by=self._by,
            pd_aggs=["min"],
            dropna=self._dropna,
@@ -105,7 +103,7 @@ class GroupBy:
        )

    def max(self, numeric_only: bool = True) -> "pd.DataFrame":
-       return self._query_compiler.groupby(
+       return self._query_compiler.aggs_groupby(
            by=self._by,
            pd_aggs=["max"],
            dropna=self._dropna,
@@ -113,7 +111,7 @@ class GroupBy:
        )

    def nunique(self) -> "pd.DataFrame":
-       return self._query_compiler.groupby(
+       return self._query_compiler.aggs_groupby(
            by=self._by,
            pd_aggs=["nunique"],
            dropna=self._dropna,
@@ -133,7 +131,6 @@ class GroupByDataFrame(GroupBy):
        Query compiler object
    dropna:
        default is true, drop None/NaT/NaN values while grouping

    """

    def aggregate(self, func: List[str], numeric_only: bool = False) -> "pd.DataFrame":
@@ -157,13 +154,12 @@ class GroupByDataFrame(GroupBy):
        """
        if isinstance(func, str):
            func = [func]
        # numeric_only is by default False because pandas does the same
-       return self._query_compiler.groupby(
+       return self._query_compiler.aggs_groupby(
            by=self._by,
            pd_aggs=func,
            dropna=self._dropna,
            numeric_only=numeric_only,
-           is_agg=True,
+           is_dataframe_agg=True,
        )

    agg = aggregate
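Seen from the user side, every method above now funnels into QueryCompiler.aggs_groupby(). Roughly, assuming the demo flights DataFrame `df` from the earlier sketch:

    # Single aggregations keep flat result columns:
    df.groupby("Cancelled").mean(numeric_only=True)

    # agg()/aggregate() passes is_dataframe_agg=True, so the result columns become
    # a (column, agg) MultiIndex, mirroring pandas:
    multi = df.groupby("Cancelled").agg(["min", "max"], numeric_only=True)
    print(multi.columns)  # e.g. ("AvgTicketPrice", "min"), ("AvgTicketPrice", "max"), ...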
@@ -271,7 +271,7 @@ class Operations:
        min 1.000205e+02 0.000000e+00 0.000000e+00 0
        """

-       return self._calculate_single_agg(
+       return self._unpack_metric_aggs(
            fields=fields,
            es_aggs=es_aggs,
            pd_aggs=pd_aggs,
@@ -415,7 +415,7 @@ class Operations:
        df_weights = pd.DataFrame(data=weights)
        return df_bins, df_weights

-   def _calculate_single_agg(
+   def _unpack_metric_aggs(
        self,
        fields: List["Field"],
        es_aggs: Union[List[str], List[Tuple[str, str]]],
@@ -425,8 +425,9 @@ class Operations:
        is_dataframe_agg: bool = False,
    ):
        """
-       This method is used to calculate single agg calculations.
-       Common for both metric aggs and groupby aggs
+       This method unpacks metric aggregations JSON response.
+       This can be called either directly on an aggs query
+       or on an individual bucket within a composite aggregation.

        Parameters
        ----------
@@ -533,21 +534,21 @@ class Operations:

            # If numeric_only is True and We only have a NaN type field then we check for empty.
            if values:
-               results[field.index] = values if len(values) > 1 else values[0]
+               results[field.column] = values if len(values) > 1 else values[0]

        return results

-   def groupby(
+   def aggs_groupby(
        self,
        query_compiler: "QueryCompiler",
        by: List[str],
        pd_aggs: List[str],
        dropna: bool = True,
-       is_agg: bool = False,
+       is_dataframe_agg: bool = False,
        numeric_only: bool = True,
    ) -> pd.DataFrame:
        """
-       This method is used to construct groupby dataframe
+       This method is used to construct groupby aggregation dataframe

        Parameters
        ----------
@@ -560,7 +561,7 @@ class Operations:
        dropna:
            Drop None values if True.
            TODO Not yet implemented
-       is_agg:
+       is_dataframe_agg:
            Know if groupby with aggregation or single agg is called.
        numeric_only:
            return either numeric values or NaN/NaT
@@ -574,13 +575,13 @@ class Operations:
            by=by,
            pd_aggs=pd_aggs,
            dropna=dropna,
-           is_agg=is_agg,
+           is_dataframe_agg=is_dataframe_agg,
            numeric_only=numeric_only,
        )

        agg_df = pd.DataFrame(results, columns=results.keys()).set_index(by)

-       if is_agg:
+       if is_dataframe_agg:
            # Convert header columns to MultiIndex
            agg_df.columns = pd.MultiIndex.from_product([headers, pd_aggs])

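The frame assembly above can be reproduced in isolation. A sketch with made-up numbers showing how the row-wise `results` dict becomes the returned DataFrame (column names and values are illustrative):

    import pandas as pd

    by = ["Cancelled"]
    headers = ["AvgTicketPrice", "dayOfWeek"]
    pd_aggs = ["min", "max"]

    # One list per output column, filled bucket by bucket in the composite
    # aggregation loop further down in this diff.
    results = {
        "Cancelled": [False, True],
        "AvgTicketPrice_min": [100.0, 110.0],
        "AvgTicketPrice_max": [1199.7, 1195.1],
        "dayOfWeek_min": [0, 0],
        "dayOfWeek_max": [6, 6],
    }

    agg_df = pd.DataFrame(results, columns=results.keys()).set_index(by)

    # For DataFrame-style agg() the flat "<column>_<agg>" headers are replaced
    # by a (column, agg) MultiIndex:
    agg_df.columns = pd.MultiIndex.from_product([headers, pd_aggs])
    print(agg_df)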
@@ -592,7 +593,7 @@ class Operations:
        by: List[str],
        pd_aggs: List[str],
        dropna: bool = True,
-       is_agg: bool = False,
+       is_dataframe_agg: bool = False,
        numeric_only: bool = True,
    ) -> Tuple[List[str], Dict[str, Any]]:
        """
@@ -609,8 +610,8 @@ class Operations:
        dropna:
            Drop None values if True.
            TODO Not yet implemented
-       is_agg:
-           Know if groupby aggregation or single agg is called.
+       is_dataframe_agg:
+           Know if multi aggregation or single agg is called.
        numeric_only:
            return either numeric values or NaN/NaT

@@ -627,13 +628,15 @@ class Operations:
                f"Can not count field matches if size is set {size}"
            )

-       by, fields = query_compiler._mappings.groupby_source_fields(by=by)
+       by_fields, agg_fields = query_compiler._mappings.groupby_source_fields(by=by)

        # Used defaultdict to avoid initialization of columns with lists
        response: Dict[str, List[Any]] = defaultdict(list)

        if numeric_only:
-           fields = [field for field in fields if (field.is_numeric or field.is_bool)]
+           agg_fields = [
+               field for field in agg_fields if (field.is_numeric or field.is_bool)
+           ]

        body = Query(query_params.query)

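A quick sketch of why `response` is a defaultdict(list): every composite bucket appends one value per output column, so columns can be created lazily on first use (the bucket shapes below are illustrative):

    from collections import defaultdict
    from typing import Any, Dict, List

    response: Dict[str, List[Any]] = defaultdict(list)

    buckets = [
        {"key": {"groupby_Cancelled": False}, "AvgTicketPrice_avg": {"value": 640.3}},
        {"key": {"groupby_Cancelled": True}, "AvgTicketPrice_avg": {"value": 613.5}},
    ]
    for bucket in buckets:
        response["Cancelled"].append(bucket["key"]["groupby_Cancelled"])
        response["AvgTicketPrice_avg"].append(bucket["AvgTicketPrice_avg"]["value"])

    print(dict(response))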
@@ -641,11 +644,13 @@ class Operations:
        es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs)

        # Construct Query
-       for b in by:
+       for by_field in by_fields:
            # groupby fields will be term aggregations
-           body.term_aggs(f"groupby_{b.index}", b.index)
+           body.composite_agg_bucket_terms(
+               name=f"groupby_{by_field.column}", field=by_field.es_field_name
+           )

-       for field in fields:
+       for field in agg_fields:
            for es_agg in es_aggs:
                if not field.is_es_agg_compatible(es_agg):
                    continue
@@ -665,11 +670,11 @@ class Operations:
            )

        # Composite aggregation
-       body.composite_agg(
+       body.composite_agg_start(
            size=DEFAULT_PAGINATION_SIZE, name="groupby_buckets", dropna=dropna
        )

-       def response_generator() -> Generator[List[str], None, List[str]]:
+       def bucket_generator() -> Generator[List[str], None, List[str]]:
            """
            e.g.
            "aggregations": {
@@ -696,43 +701,51 @@ class Operations:
                    size=0,
                    body=body.to_search_body(),
                )

                # Pagination Logic
-               if "after_key" in res["aggregations"]["groupby_buckets"]:
+               composite_buckets = res["aggregations"]["groupby_buckets"]
+               if "after_key" in composite_buckets:

                    # yield the bucket which contains the result
-                   yield res["aggregations"]["groupby_buckets"]["buckets"]
+                   yield composite_buckets["buckets"]

                    body.composite_agg_after_key(
                        name="groupby_buckets",
-                       after_key=res["aggregations"]["groupby_buckets"]["after_key"],
+                       after_key=composite_buckets["after_key"],
                    )
                else:
-                   return res["aggregations"]["groupby_buckets"]["buckets"]
+                   return composite_buckets["buckets"]

-       for buckets in response_generator():
+       for buckets in bucket_generator():
            # We recieve response row-wise
            for bucket in buckets:
                # groupby columns are added to result same way they are returned
-               for b in by:
-                   response[b.index].append(bucket["key"][f"groupby_{b.index}"])
+               for by_field in by_fields:
+                   bucket_key = bucket["key"][f"groupby_{by_field.column}"]
+
+                   # Datetimes always come back as integers, convert to pd.Timestamp()
+                   if by_field.is_timestamp and isinstance(bucket_key, int):
+                       bucket_key = pd.to_datetime(bucket_key, unit="ms")
+
+                   response[by_field.column].append(bucket_key)

-               agg_calculation = self._calculate_single_agg(
-                   fields=fields,
+               agg_calculation = self._unpack_metric_aggs(
+                   fields=agg_fields,
                    es_aggs=es_aggs,
                    pd_aggs=pd_aggs,
                    response={"aggregations": bucket},
                    numeric_only=numeric_only,
-                   is_dataframe_agg=is_agg,
+                   is_dataframe_agg=is_dataframe_agg,
                )
                # Process the calculated agg values to response
                for key, value in agg_calculation.items():
-                   if not is_agg:
-                       response[key].append(value)
+                   if isinstance(value, list):
+                       for pd_agg, val in zip(pd_aggs, value):
+                           response[f"{key}_{pd_agg}"].append(val)
                    else:
-                       for i in range(0, len(pd_aggs)):
-                           response[f"{key}_{pd_aggs[i]}"].append(value[i])
+                       response[key].append(value)

-       return [field.index for field in fields], response
+       return [field.column for field in agg_fields], response

    @staticmethod
    def _map_pd_aggs_to_es_aggs(pd_aggs):
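bucket_generator() pages through the composite aggregation via `after_key`. A hedged sketch of the same loop written directly against elasticsearch-py (the index name, field names and page size are assumptions; eland itself drives this through its Query and Operations classes):

    from elasticsearch import Elasticsearch

    es = Elasticsearch("http://localhost:9200")
    body = {
        "size": 0,
        "aggs": {
            "groupby_buckets": {
                "composite": {
                    "size": 100,
                    "sources": [{"groupby_Cancelled": {"terms": {"field": "Cancelled"}}}],
                },
                "aggregations": {"AvgTicketPrice_avg": {"avg": {"field": "AvgTicketPrice"}}},
            }
        },
    }

    while True:
        res = es.search(index="flights", body=body)
        agg = res["aggregations"]["groupby_buckets"]
        for bucket in agg["buckets"]:
            # Date-typed group keys come back as epoch milliseconds; eland converts
            # them with pd.to_datetime(value, unit="ms") as shown above.
            print(bucket["key"], bucket["AvgTicketPrice_avg"]["value"])
        if "after_key" not in agg:
            break
        # Resume the next page where the previous one stopped.
        body["aggs"]["groupby_buckets"]["composite"]["after"] = agg["after_key"]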
@@ -38,14 +38,17 @@ class Query:
        # type defs
        self._query: BooleanFilter
        self._aggs: Dict[str, Any]
+       self._composite_aggs: Dict[str, Any]

        if query is None:
            self._query = BooleanFilter()
            self._aggs = {}
+           self._composite_aggs = {}
        else:
            # Deep copy the incoming query so we can change it
            self._query = deepcopy(query._query)
            self._aggs = deepcopy(query._aggs)
+           self._composite_aggs = deepcopy(query._composite_aggs)

    def exists(self, field: str, must: bool = True) -> None:
        """
@@ -136,9 +139,9 @@ class Query:
        agg = {func: {"field": field}}
        self._aggs[name] = agg

-   def term_aggs(self, name: str, field: str) -> None:
+   def composite_agg_bucket_terms(self, name: str, field: str) -> None:
        """
-       Add term agg e.g.
+       Add terms agg for composite aggregation

        "aggs": {
            "name": {
@@ -148,17 +151,36 @@ class Query:
            }
        }
        """
-       agg = {"terms": {"field": field}}
-       self._aggs[name] = agg
+       self._composite_aggs[name] = {"terms": {"field": field}}

-   def composite_agg(
+   def composite_agg_bucket_date_histogram(
+       self,
+       name: str,
+       field: str,
+       calendar_interval: Optional[str] = None,
+       fixed_interval: Optional[str] = None,
+   ) -> None:
+       if (calendar_interval is None) == (fixed_interval is None):
+           raise ValueError(
+               "calendar_interval and fixed_interval parmaeters are mutually exclusive"
+           )
+       agg = {"field": field}
+       if calendar_interval is not None:
+           agg["calendar_interval"] = calendar_interval
+       elif fixed_interval is not None:
+           agg["fixed_interval"] = fixed_interval
+       self._composite_aggs[name] = {"date_histogram": agg}
+
+   def composite_agg_start(
        self,
        name: str,
        size: int,
        dropna: bool = True,
    ) -> None:
        """
-       Add composite aggregation e.g.
+       Start a composite aggregation. This should be called
+       after calls to composite_agg_bucket_*(), etc.

        https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-bucket-composite-aggregation.html

        "aggs": {
@@ -190,22 +212,22 @@ class Query:

        """
        sources: List[Dict[str, Dict[str, str]]] = []
-       aggregations: Dict[str, Dict[str, str]] = {}

-       for _name, agg in self._aggs.items():
-           if agg.get("terms"):
-               if not dropna:
-                   agg["terms"]["missing_bucket"] = "true"
-               sources.append({_name: agg})
-           else:
-               aggregations[_name] = agg
+       # Go through all composite source aggregations
+       # and apply dropna if needed.
+       for bucket_agg_name, bucket_agg in self._composite_aggs.items():
+           if bucket_agg.get("terms") and not dropna:
+               bucket_agg = bucket_agg.copy()
+               bucket_agg["terms"]["missing_bucket"] = "true"
+           sources.append({bucket_agg_name: bucket_agg})
+       self._composite_aggs.clear()

-       agg = {
+       aggs = {
            "composite": {"size": size, "sources": sources},
-           "aggregations": aggregations,
+           "aggregations": self._aggs.copy(),
        }
        self._aggs.clear()
-       self._aggs[name] = agg
+       self._aggs[name] = aggs

    def composite_agg_after_key(self, name: str, after_key: Dict[str, Any]) -> None:
        """
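Putting the Query pieces together: composite_agg_bucket_terms() registers one terms source per groupby column, the metric aggregations already registered in self._aggs are nested under the composite, and composite_agg_start() folds everything into one named aggregation along these lines (field names and the pagination size are illustrative):

    search_body = {
        "aggs": {
            "groupby_buckets": {
                "composite": {
                    "size": 1000,
                    "sources": [
                        # With dropna=False a terms source also gets "missing_bucket": "true".
                        {"groupby_DestCountry": {"terms": {"field": "DestCountry"}}},
                    ],
                },
                "aggregations": {
                    "AvgTicketPrice_avg": {"avg": {"field": "AvgTicketPrice"}}
                },
            }
        }
    }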
@@ -550,15 +550,22 @@ class QueryCompiler:
            self, ["nunique"], numeric_only=False
        )

-   def groupby(
+   def aggs_groupby(
        self,
        by: List[str],
        pd_aggs: List[str],
        dropna: bool = True,
        is_agg: bool = False,
        is_dataframe_agg: bool = False,
        numeric_only: bool = True,
    ) -> pd.DataFrame:
-       return self._operations.groupby(self, by, pd_aggs, dropna, is_agg, numeric_only)
+       return self._operations.aggs_groupby(
+           self,
+           by=by,
+           pd_aggs=pd_aggs,
+           dropna=dropna,
+           is_dataframe_agg=is_dataframe_agg,
+           numeric_only=numeric_only,
+       )

    def value_counts(self, es_size):
        return self._operations.value_counts(self, es_size)
@@ -25,13 +25,10 @@ import pandas as pd

class TestGroupbyDataFrame(TestData):
    funcs = ["max", "min", "mean", "sum"]
    extended_funcs = ["median", "mad", "var", "std"]
    filter_data = [
        "AvgTicketPrice",
        "Cancelled",
        "dayOfWeek",
-       "timestamp",
-       "DestCountry",
    ]

    @pytest.mark.parametrize("numeric_only", [True])
@@ -41,14 +38,29 @@ class TestGroupbyDataFrame(TestData):
        pd_flights = self.pd_flights().filter(self.filter_data)
        ed_flights = self.ed_flights().filter(self.filter_data)

-       pd_groupby = pd_flights.groupby("Cancelled").agg(self.funcs, numeric_only)
-       ed_groupby = ed_flights.groupby("Cancelled").agg(self.funcs, numeric_only)
+       pd_groupby = pd_flights.groupby("Cancelled").agg(
+           self.funcs, numeric_only=numeric_only
+       )
+       ed_groupby = ed_flights.groupby("Cancelled").agg(
+           self.funcs, numeric_only=numeric_only
+       )

        # checking only values because dtypes are checked in aggs tests
        assert_frame_equal(pd_groupby, ed_groupby, check_exact=False, check_dtype=False)

+   @pytest.mark.parametrize("pd_agg", funcs)
+   def test_groupby_aggregate_single_aggs(self, pd_agg):
+       pd_flights = self.pd_flights().filter(self.filter_data)
+       ed_flights = self.ed_flights().filter(self.filter_data)
+
+       pd_groupby = pd_flights.groupby("Cancelled").agg([pd_agg], numeric_only=True)
+       ed_groupby = ed_flights.groupby("Cancelled").agg([pd_agg], numeric_only=True)
+
+       # checking only values because dtypes are checked in aggs tests
+       assert_frame_equal(pd_groupby, ed_groupby, check_exact=False, check_dtype=False)
+
    @pytest.mark.parametrize("pd_agg", ["max", "min", "mean", "sum", "median"])
-   def test_groupby_aggs_true(self, pd_agg):
+   def test_groupby_aggs_numeric_only_true(self, pd_agg):
        # Pandas has numeric_only applicable for the above aggs with groupby only.

        pd_flights = self.pd_flights().filter(self.filter_data)
@@ -59,7 +71,7 @@ class TestGroupbyDataFrame(TestData):

        # checking only values because dtypes are checked in aggs tests
        assert_frame_equal(
-           pd_groupby, ed_groupby, check_exact=False, check_dtype=False, rtol=4
+           pd_groupby, ed_groupby, check_exact=False, check_dtype=False, rtol=2
        )

    @pytest.mark.parametrize("pd_agg", ["mad", "var", "std"])
@@ -90,9 +102,9 @@ class TestGroupbyDataFrame(TestData):
        )

    @pytest.mark.parametrize("pd_agg", ["max", "min", "mean", "median"])
-   def test_groupby_aggs_false(self, pd_agg):
-       pd_flights = self.pd_flights().filter(self.filter_data)
-       ed_flights = self.ed_flights().filter(self.filter_data)
+   def test_groupby_aggs_numeric_only_false(self, pd_agg):
+       pd_flights = self.pd_flights().filter(self.filter_data + ["timestamp"])
+       ed_flights = self.ed_flights().filter(self.filter_data + ["timestamp"])

        # pandas numeric_only=False, matches with Eland numeric_only=None
        pd_groupby = getattr(pd_flights.groupby("Cancelled"), pd_agg)(
@@ -114,14 +126,30 @@ class TestGroupbyDataFrame(TestData):
        ed_flights = self.ed_flights().filter(self.filter_data)

        match = "by parameter should be specified to groupby"
-       with pytest.raises(TypeError, match=match):
+       with pytest.raises(ValueError, match=match):
            ed_flights.groupby(None).mean()

        by = ["ABC", "Cancelled"]
-       match = "Requested columns {'ABC'} not in the DataFrame."
+       match = "Requested columns 'ABC' not in the DataFrame"
        with pytest.raises(KeyError, match=match):
            ed_flights.groupby(by).mean()

+   @pytest.mark.parametrize(
+       "by",
+       ["timestamp", "dayOfWeek", "Carrier", "Cancelled", ["dayOfWeek", "Carrier"]],
+   )
+   def test_groupby_different_dtypes(self, by):
+       columns = ["dayOfWeek", "Carrier", "timestamp", "Cancelled"]
+       pd_flights = self.pd_flights_small().filter(columns)
+       ed_flights = self.ed_flights_small().filter(columns)
+
+       pd_groupby = pd_flights.groupby(by).nunique()
+       ed_groupby = ed_flights.groupby(by).nunique()
+
+       assert list(pd_groupby.index) == list(ed_groupby.index)
+       assert pd_groupby.index.dtype == ed_groupby.index.dtype
+       assert list(pd_groupby.columns) == list(ed_groupby.columns)
+
    def test_groupby_dropna(self):
        # TODO Add tests once dropna is implemeted
        pass