Add support for DataFrame.groupby() with aggregations
parent adafeed667
commit abc5ca927b
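In short, this commit brings pandas-style grouping to Eland DataFrames, backed by Elasticsearch composite aggregations. A minimal usage sketch, distilled from the docstring examples added below (a hypothetical session; assumes a local cluster with the flights demo index):

    import eland as ed

    ed_flights = ed.DataFrame('localhost', 'flights')

    # Single aggregation, translated into one composite-aggregation search
    ed_flights.groupby('DestCountry').mean(numeric_only=True)

    # Multiple aggregations via agg(); the result gets MultiIndex (field, agg) columns
    ed_flights.groupby(['DestCountry', 'Cancelled']).agg(['min', 'max'], numeric_only=True)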
eland/common.py

@@ -31,6 +31,7 @@ DEFAULT_CHUNK_SIZE = 10000
 DEFAULT_CSV_BATCH_OUTPUT_SIZE = 10000
 DEFAULT_PROGRESS_REPORTING_NUM_ROWS = 10000
 DEFAULT_ES_MAX_RESULT_WINDOW = 10000  # index.max_result_window
+DEFAULT_PAGINATION_SIZE = 5000  # for composite aggregations

 with warnings.catch_warnings():
eland/dataframe.py

@@ -19,7 +19,7 @@ import sys
 import warnings
 from io import StringIO
 import re
-from typing import Optional, Sequence, Union, Tuple, List
+from typing import List, Optional, Sequence, Union, Tuple

 import numpy as np
 import pandas as pd

@@ -39,6 +39,7 @@ from eland.series import Series
 from eland.common import DEFAULT_NUM_ROWS_DISPLAYED, docstring_parameter
 from eland.filter import BooleanFilter
 from eland.utils import deprecated_api, is_valid_attr_name
+from eland.groupby import GroupByDataFrame


 class DataFrame(NDFrame):

@@ -1430,6 +1431,84 @@ class DataFrame(NDFrame):

     hist = gfx.ed_hist_frame

+    def groupby(
+        self, by: Optional[Union[str, List[str]]] = None, dropna: bool = True
+    ) -> "GroupByDataFrame":
+        """
+        Used to perform groupby operations.
+
+        Parameters
+        ----------
+        by:
+            Column or list of columns used to group by.
+            Currently accepts a column or a list of columns.
+            TODO Implement other combinations of by similar to pandas
+        dropna: default True
+            If True, and if group keys contain NA values, the NA values together
+            with the row/column will be dropped.
+            TODO Implement False
+
+        TODO Implement remainder of pandas arguments
+
+        Returns
+        -------
+        GroupByDataFrame
+
+        See Also
+        --------
+        :pandas_api_docs:`pandas.DataFrame.groupby`
+
+        Examples
+        --------
+        >>> ed_flights = ed.DataFrame('localhost', 'flights', columns=["AvgTicketPrice", "Cancelled", "dayOfWeek", "timestamp", "DestCountry"])
+        >>> ed_flights.groupby(["DestCountry", "Cancelled"]).agg(["min", "max"], numeric_only=True)  # doctest: +NORMALIZE_WHITESPACE
+                              AvgTicketPrice               dayOfWeek
+                                         min          max        min  max
+        DestCountry Cancelled
+        AE          False         110.799911  1126.148682        0.0  6.0
+                    True          132.443756   817.931030        0.0  6.0
+        AR          False         125.589394  1199.642822        0.0  6.0
+                    True          251.389603  1172.382568        0.0  6.0
+        AT          False         100.020531  1181.835815        0.0  6.0
+        ...                              ...          ...        ...  ...
+        TR          True          307.915649   307.915649        0.0  0.0
+        US          False         100.145966  1199.729004        0.0  6.0
+                    True          102.153069  1192.429932        0.0  6.0
+        ZA          False         102.002663  1196.186157        0.0  6.0
+                    True          121.280296  1175.709961        0.0  6.0
+        <BLANKLINE>
+        [63 rows x 4 columns]
+        >>> ed_flights.groupby(["DestCountry", "Cancelled"]).mean(numeric_only=True)  # doctest: +NORMALIZE_WHITESPACE
+                               AvgTicketPrice  dayOfWeek
+        DestCountry Cancelled
+        AE          False          643.956793   2.717949
+                    True           388.828809   2.571429
+        AR          False          673.551677   2.746154
+                    True           682.197241   2.733333
+        AT          False          647.158290   2.819936
+        ...                               ...        ...
+        TR          True           307.915649   0.000000
+        US          False          598.063146   2.752014
+                    True           579.799066   2.767068
+        ZA          False          636.998605   2.738589
+                    True           677.794078   2.928571
+        <BLANKLINE>
+        [63 rows x 2 columns]
+        """
+        if by is None:
+            raise TypeError("by parameter should be specified to groupby")
+        if isinstance(by, str):
+            by = [by]
+        if isinstance(by, (list, tuple)):
+            remaining_columns = set(by) - set(self._query_compiler.columns)
+            if remaining_columns:
+                raise KeyError(
+                    f"Requested columns {remaining_columns} not in the DataFrame."
+                )
+
+        return GroupByDataFrame(
+            by=by, query_compiler=self._query_compiler, dropna=dropna
+        )
+
     def query(self, expr) -> "DataFrame":
         """
         Query the columns of a DataFrame with a boolean expression.
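The validation at the top of groupby() fails fast on bad input. For illustration (behaviour matching the tests added later in this commit, with ed_flights as in the docstring above):

    ed_flights.groupby(None)     # TypeError: by parameter should be specified to groupby
    ed_flights.groupby(["ABC"])  # KeyError: Requested columns {'ABC'} not in the DataFrame.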
eland/field_mappings.py

@@ -33,6 +33,7 @@ from typing import (
     Mapping,
     Dict,
     Any,
+    Tuple,
     TYPE_CHECKING,
     List,
     Set,

@@ -697,14 +698,50 @@ class FieldMappings:
         pd_dtypes, es_field_names, es_date_formats = self.metric_source_fields()
         return es_field_names

-    def all_source_fields(self):
-        source_fields = []
+    def all_source_fields(self) -> List[Field]:
+        """
+        Returns all Field mappings.
+
+        Returns
+        -------
+        A list of Field mappings
+        """
+        source_fields: List[Field] = []
         for index, row in self._mappings_capabilities.iterrows():
             row = row.to_dict()
             row["index"] = index
             source_fields.append(Field(**row))
         return source_fields

+    def groupby_source_fields(self, by: List[str]) -> Tuple[List[Field], List[Field]]:
+        """
+        Returns all Field mappings, split into groupby fields and non-groupby
+        (aggregatable) fields.
+
+        Parameters
+        ----------
+        by:
+            A list of groupby fields
+
+        Returns
+        -------
+        A tuple of two lists: the Field mappings for the groupby fields and the
+        Field mappings for the non-groupby fields
+        """
+        groupby_fields: Dict[str, Field] = {}
+        aggregatable_fields: List[Field] = []
+        for index_name, row in self._mappings_capabilities.iterrows():
+            row = row.to_dict()
+            row["index"] = index_name
+            if index_name not in by:
+                aggregatable_fields.append(Field(**row))
+            else:
+                groupby_fields[index_name] = Field(**row)
+
+        # Maintain the groupby order as given in the input
+        return [groupby_fields[column] for column in by], aggregatable_fields
+
     def metric_source_fields(self, include_bool=False, include_timestamp=False):
         """
         Returns
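groupby_source_fields partitions the mapped fields while preserving the caller's by order, which keeps the resulting index levels stable. Conceptually (a standalone sketch, not eland code; plain field names stand in for the Field mappings):

    fields = ["AvgTicketPrice", "Cancelled", "DestCountry", "dayOfWeek"]
    by = ["DestCountry", "Cancelled"]

    groupby_fields = {f: f for f in fields if f in by}
    aggregatable_fields = [f for f in fields if f not in by]

    # Re-emitting in 'by' order, not mapping order
    assert [groupby_fields[c] for c in by] == ["DestCountry", "Cancelled"]
    assert aggregatable_fields == ["AvgTicketPrice", "dayOfWeek"]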
eland/groupby.py (new file, 169 lines)

# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import List, TYPE_CHECKING

from eland.query_compiler import QueryCompiler

if TYPE_CHECKING:
    import pandas as pd  # type: ignore


class GroupBy:
    """
    Holds all the groupby base methods.

    Parameters
    ----------
    by:
        List of columns to group by
    query_compiler:
        Query compiler object
    dropna:
        default is True; drop None/NaT/NaN values while grouping
    """

    def __init__(
        self,
        by: List[str],
        query_compiler: "QueryCompiler",
        dropna: bool = True,
    ) -> None:
        self._query_compiler: "QueryCompiler" = QueryCompiler(to_copy=query_compiler)
        self._dropna: bool = dropna
        self._by: List[str] = by

    # numeric_only=True by default for all aggs because pandas does the same
    def mean(self, numeric_only: bool = True) -> "pd.DataFrame":
        return self._query_compiler.groupby(
            by=self._by,
            pd_aggs=["mean"],
            dropna=self._dropna,
            numeric_only=numeric_only,
        )

    def var(self, numeric_only: bool = True) -> "pd.DataFrame":
        return self._query_compiler.groupby(
            by=self._by,
            pd_aggs=["var"],
            dropna=self._dropna,
            numeric_only=numeric_only,
        )

    def std(self, numeric_only: bool = True) -> "pd.DataFrame":
        return self._query_compiler.groupby(
            by=self._by,
            pd_aggs=["std"],
            dropna=self._dropna,
            numeric_only=numeric_only,
        )

    def mad(self, numeric_only: bool = True) -> "pd.DataFrame":
        return self._query_compiler.groupby(
            by=self._by,
            pd_aggs=["mad"],
            dropna=self._dropna,
            numeric_only=numeric_only,
        )

    def median(self, numeric_only: bool = True) -> "pd.DataFrame":
        return self._query_compiler.groupby(
            by=self._by,
            pd_aggs=["median"],
            dropna=self._dropna,
            numeric_only=numeric_only,
        )

    def sum(self, numeric_only: bool = True) -> "pd.DataFrame":
        return self._query_compiler.groupby(
            by=self._by,
            pd_aggs=["sum"],
            dropna=self._dropna,
            numeric_only=numeric_only,
        )

    def min(self, numeric_only: bool = True) -> "pd.DataFrame":
        return self._query_compiler.groupby(
            by=self._by,
            pd_aggs=["min"],
            dropna=self._dropna,
            numeric_only=numeric_only,
        )

    def max(self, numeric_only: bool = True) -> "pd.DataFrame":
        return self._query_compiler.groupby(
            by=self._by,
            pd_aggs=["max"],
            dropna=self._dropna,
            numeric_only=numeric_only,
        )

    def nunique(self) -> "pd.DataFrame":
        return self._query_compiler.groupby(
            by=self._by,
            pd_aggs=["nunique"],
            dropna=self._dropna,
            numeric_only=False,
        )


class GroupByDataFrame(GroupBy):
    """
    Holds all the groupby methods for DataFrame.

    Parameters
    ----------
    by:
        List of columns to group by
    query_compiler:
        Query compiler object
    dropna:
        default is True; drop None/NaT/NaN values while grouping
    """

    def aggregate(self, func: List[str], numeric_only: bool = False) -> "pd.DataFrame":
        """
        Used to groupby and aggregate.

        Parameters
        ----------
        func:
            Functions to use for aggregating the data.

            Accepted combinations are:
            - function
            - list of functions
        numeric_only: {True, False, None} default is False
            Which datatypes to return
            - True: returns all values as float64; NaN/NaT values are ignored.
            - False: returns all values as float64.
            - None: returns all values with the default datatype.
        """
        if isinstance(func, str):
            func = [func]
        # numeric_only defaults to False because pandas does the same
        return self._query_compiler.groupby(
            by=self._by,
            pd_aggs=func,
            dropna=self._dropna,
            numeric_only=numeric_only,
            is_agg=True,
        )

    agg = aggregate
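A note on the design: each single-agg method above is a thin wrapper that forwards a one-element pd_aggs list to QueryCompiler.groupby(), while aggregate() forwards the whole list with is_agg=True; the binding agg = aggregate makes the two spellings interchangeable, mirroring pandas. For example (hypothetical session, ed_flights as in the earlier docstrings):

    grouped = ed_flights.groupby("Cancelled")

    # These two calls are identical; both produce MultiIndex (field, agg) columns.
    grouped.aggregate(["min", "max"], numeric_only=True)
    grouped.agg(["min", "max"], numeric_only=True)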
eland/operations.py

@@ -16,12 +16,22 @@
 # under the License.

 import copy
-import typing
 import warnings
-from typing import Optional, Sequence, Tuple, List, Dict, Any
+from typing import (
+    Generator,
+    Optional,
+    Sequence,
+    Tuple,
+    List,
+    Dict,
+    Any,
+    TYPE_CHECKING,
+    Union,
+)

 import numpy as np
 import pandas as pd
+from collections import defaultdict
 from elasticsearch.helpers import scan

 from eland.index import Index

@@ -31,6 +41,7 @@ from eland.common import (
     DEFAULT_ES_MAX_RESULT_WINDOW,
     elasticsearch_date_to_pandas_date,
     build_pd_series,
+    DEFAULT_PAGINATION_SIZE,
 )
 from eland.query import Query
 from eland.actions import PostProcessingAction, SortFieldAction

@@ -46,8 +57,9 @@ from eland.tasks import (
     SizeTask,
 )

-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
     from eland.query_compiler import QueryCompiler
+    from eland.field_mappings import Field


 class QueryParams:
@@ -186,10 +198,29 @@ class Operations:
     def _metric_aggs(
         self,
         query_compiler: "QueryCompiler",
-        pd_aggs,
+        pd_aggs: List[str],
         numeric_only: Optional[bool] = None,
         is_dataframe_agg: bool = False,
-    ) -> Dict:
+    ) -> Dict[str, Any]:
+        """
+        Used to calculate metric aggregations
+        https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-metrics.html
+
+        Parameters
+        ----------
+        query_compiler:
+            Query compiler object
+        pd_aggs:
+            aggregations to be performed on the DataFrame or Series
+        numeric_only:
+            return either all numeric values or NaN/NaT
+        is_dataframe_agg:
+            whether this method was called from a DataFrame aggregation
+            (agg()) or from a single-agg method (e.g. mean())
+
+        Returns
+        -------
+        A dictionary containing all the calculated aggregations
+        """
         query_params, post_processing = self._resolve_tasks(query_compiler)

         size = self._size(query_params, post_processing)

@@ -198,7 +229,6 @@ class Operations:
                 f"Can not count field matches if size is set {size}"
             )

-        results = {}
         fields = query_compiler._mappings.all_source_fields()
         if numeric_only:
             # Consider if field is Int/Float/Bool
@@ -240,95 +270,15 @@ class Operations:
         sum    8.204365e+06  9.261629e+07  5.754909e+07  618150
         min    1.000205e+02  0.000000e+00  0.000000e+00       0
         """
-        [the ~80-line per-field unpacking loop removed here moves verbatim into
-         the new _calculate_single_agg() method shown in the next hunk]
+        return self._calculate_single_agg(
+            fields=fields,
+            es_aggs=es_aggs,
+            pd_aggs=pd_aggs,
+            response=response,
+            numeric_only=numeric_only,
+            is_dataframe_agg=is_dataframe_agg,
+        )

     def _terms_aggs(self, query_compiler, func, es_size=None):
         """
@@ -465,6 +415,325 @@ class Operations:
         df_weights = pd.DataFrame(data=weights)
         return df_bins, df_weights

+    def _calculate_single_agg(
+        self,
+        fields: List["Field"],
+        es_aggs: Union[List[str], List[Tuple[str, str]]],
+        pd_aggs: List[str],
+        response: Dict[str, Any],
+        numeric_only: Optional[bool],
+        is_dataframe_agg: bool = False,
+    ):
+        """
+        Calculates the values for a single aggregation; common to both
+        metric aggs and groupby aggs.
+
+        Parameters
+        ----------
+        fields:
+            a list of Field mappings
+        es_aggs:
+            Elasticsearch equivalents of the pandas aggs
+        pd_aggs:
+            a list of pandas aggs
+        response:
+            a dict containing the response from Elasticsearch
+        numeric_only:
+            return either numeric values or NaN/NaT
+
+        Returns
+        -------
+        a dictionary of the calculated aggregation values
+        """
+        results: Dict[str, Any] = {}
+
+        for field in fields:
+            values = []
+            for es_agg, pd_agg in zip(es_aggs, pd_aggs):
+                # is_dataframe_agg is used to differentiate agg() from a single
+                # aggregation called through e.g. .mean(): if the field and agg
+                # aren't compatible we add a NaN/NaT for agg(), but we don't
+                # add NaN/NaT for a single aggregation.
+                if not field.is_es_agg_compatible(es_agg):
+                    if is_dataframe_agg and not numeric_only:
+                        values.append(field.nan_value)
+                    elif not is_dataframe_agg and numeric_only is False:
+                        values.append(field.nan_value)
+                    # Explicit condition for mad to add NaN because it doesn't support bool
+                    elif is_dataframe_agg and numeric_only:
+                        if pd_agg == "mad":
+                            values.append(field.nan_value)
+                    continue
+
+                if isinstance(es_agg, tuple):
+                    agg_value = response["aggregations"][
+                        f"{es_agg[0]}_{field.es_field_name}"
+                    ]
+
+                    # Pull multiple values from 'percentiles' result.
+                    if es_agg[0] == "percentiles":
+                        agg_value = agg_value["values"]
+
+                    agg_value = agg_value[es_agg[1]]
+
+                    # Need to convert 'population' stddev and variance from
+                    # Elasticsearch into the 'sample' stddev and variance
+                    # which is what pandas uses.
+                    if es_agg[1] in ("std_deviation", "variance"):
+                        # Neither transformation works with count <= 1
+                        count = response["aggregations"][
+                            f"{es_agg[0]}_{field.es_field_name}"
+                        ]["count"]
+
+                        # All of the below calculations result in NaN if count <= 1
+                        if count <= 1:
+                            agg_value = np.NaN
+
+                        elif es_agg[1] == "std_deviation":
+                            agg_value *= count / (count - 1.0)
+
+                        else:  # es_agg[1] == "variance"
+                            # sample_std = \sqrt{\frac{1}{N-1}\sum_{i=1}^N(x_i-\bar{x})^2}
+                            # population_std = \sqrt{\frac{1}{N}\sum_{i=1}^N(x_i-\bar{x})^2}
+                            # sample_std = \sqrt{\frac{N}{N-1}} \cdot population_std
+                            agg_value = np.sqrt(
+                                (count / (count - 1.0)) * agg_value * agg_value
+                            )
+                else:
+                    agg_value = response["aggregations"][
+                        f"{es_agg}_{field.es_field_name}"
+                    ]["value"]
+
+                # Null usually means there were no results.
+                if agg_value is None or np.isnan(agg_value):
+                    if is_dataframe_agg and not numeric_only:
+                        agg_value = np.NaN
+                    elif not is_dataframe_agg and numeric_only is False:
+                        agg_value = np.NaN
+
+                # Cardinality is always either NaN or an integer.
+                elif pd_agg == "nunique":
+                    agg_value = int(agg_value)
+
+                # If this is a non-null timestamp field convert to a pd.Timestamp()
+                elif field.is_timestamp:
+                    agg_value = elasticsearch_date_to_pandas_date(
+                        agg_value, field.es_date_format
+                    )
+                # If numeric_only is False or None then maintain the column datatype
+                elif not numeric_only:
+                    # We only convert back for lossless aggs like min, max, and median.
+                    if pd_agg in {"max", "min", "median", "sum"}:
+                        # 'sum' isn't representable with bool, use int64
+                        if pd_agg == "sum" and field.is_bool:
+                            agg_value = np.int64(agg_value)
+                        else:
+                            agg_value = field.np_dtype.type(agg_value)
+
+                values.append(agg_value)
+
+            # If numeric_only is True and we only have a NaN-type field then we check for empty.
+            if values:
+                results[field.index] = values if len(values) > 1 else values[0]
+
+        return results
+
+    def groupby(
+        self,
+        query_compiler: "QueryCompiler",
+        by: List[str],
+        pd_aggs: List[str],
+        dropna: bool = True,
+        is_agg: bool = False,
+        numeric_only: bool = True,
+    ) -> pd.DataFrame:
+        """
+        Constructs the groupby DataFrame.
+
+        Parameters
+        ----------
+        query_compiler:
+            A Query compiler
+        by:
+            a list of columns on which the groupby operations are performed
+        pd_aggs:
+            a list of aggregations to be performed
+        dropna:
+            Drop None values if True.
+            TODO Not yet implemented
+        is_agg:
+            whether a groupby aggregation (agg()) or a single agg was called
+        numeric_only:
+            return either numeric values or NaN/NaT
+
+        Returns
+        -------
+        A DataFrame containing the grouped data
+        """
+        headers, results = self._groupby_aggs(
+            query_compiler,
+            by=by,
+            pd_aggs=pd_aggs,
+            dropna=dropna,
+            is_agg=is_agg,
+            numeric_only=numeric_only,
+        )
+
+        agg_df = pd.DataFrame(results, columns=results.keys()).set_index(by)
+
+        if is_agg:
+            # Convert header columns to a MultiIndex
+            agg_df.columns = pd.MultiIndex.from_product([headers, pd_aggs])
+
+        return agg_df
+
+    def _groupby_aggs(
+        self,
+        query_compiler: "QueryCompiler",
+        by: List[str],
+        pd_aggs: List[str],
+        dropna: bool = True,
+        is_agg: bool = False,
+        numeric_only: bool = True,
+    ) -> Tuple[List[str], Dict[str, Any]]:
+        """
+        Calculates the groupby aggregations.
+
+        Parameters
+        ----------
+        query_compiler:
+            A Query compiler
+        by:
+            a list of columns on which the groupby operations are performed
+        pd_aggs:
+            a list of aggregations to be performed
+        dropna:
+            Drop None values if True.
+            TODO Not yet implemented
+        is_agg:
+            whether a groupby aggregation or a single agg was called
+        numeric_only:
+            return either numeric values or NaN/NaT
+
+        Returns
+        -------
+        headers: the columns to which the MultiIndex is applied
+        response: a dictionary of the aggregated groupby values
+        """
+        query_params, post_processing = self._resolve_tasks(query_compiler)
+
+        size = self._size(query_params, post_processing)
+        if size is not None:
+            raise NotImplementedError(
+                f"Can not count field matches if size is set {size}"
+            )
+
+        by, fields = query_compiler._mappings.groupby_source_fields(by=by)
+
+        # defaultdict avoids having to initialize each column with a list
+        response: Dict[str, List[Any]] = defaultdict(list)
+
+        if numeric_only:
+            fields = [field for field in fields if (field.is_numeric or field.is_bool)]
+
+        body = Query(query_params.query)
+
+        # Convert pandas aggs to their ES equivalents
+        es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs)
+
+        # Construct the query: groupby fields become term aggregations
+        for b in by:
+            body.term_aggs(f"groupby_{b.index}", b.index)
+
+        for field in fields:
+            for es_agg in es_aggs:
+                if not field.is_es_agg_compatible(es_agg):
+                    continue
+
+                # If we have multiple 'extended_stats' etc. here we simply NOOP on the 2nd call
+                if isinstance(es_agg, tuple):
+                    body.metric_aggs(
+                        f"{es_agg[0]}_{field.es_field_name}",
+                        es_agg[0],
+                        field.aggregatable_es_field_name,
+                    )
+                else:
+                    body.metric_aggs(
+                        f"{es_agg}_{field.es_field_name}",
+                        es_agg,
+                        field.aggregatable_es_field_name,
+                    )
+
+        # Composite aggregation
+        body.composite_agg(
+            size=DEFAULT_PAGINATION_SIZE, name="groupby_buckets", dropna=dropna
+        )
+
+        def response_generator() -> Generator[List[str], None, List[str]]:
+            """
+            e.g.
+            "aggregations": {
+                "groupby_buckets": {
+                    "after_key": {"total_quantity": 8},
+                    "buckets": [
+                        {
+                            "key": {"total_quantity": 1},
+                            "doc_count": 87,
+                            "taxful_total_price_avg": {"value": 48.035978536496216},
+                        }
+                    ],
+                }
+            }
+
+            Returns
+            -------
+            A generator that yields each page of buckets; if an after_key is
+            present, it is used to fetch the next page of buckets.
+            """
+            while True:
+                res = query_compiler._client.search(
+                    index=query_compiler._index_pattern,
+                    size=0,
+                    body=body.to_search_body(),
+                )
+                # Pagination logic
+                if "after_key" in res["aggregations"]["groupby_buckets"]:
+                    # yield the buckets which contain the results
+                    yield res["aggregations"]["groupby_buckets"]["buckets"]
+
+                    body.composite_agg_after_key(
+                        name="groupby_buckets",
+                        after_key=res["aggregations"]["groupby_buckets"]["after_key"],
+                    )
+                else:
+                    return res["aggregations"]["groupby_buckets"]["buckets"]
+
+        for buckets in response_generator():
+            # We receive the response row-wise
+            for bucket in buckets:
+                # groupby columns are added to the result in the same order they are returned
+                for b in by:
+                    response[b.index].append(bucket["key"][f"groupby_{b.index}"])
+
+                agg_calculation = self._calculate_single_agg(
+                    fields=fields,
+                    es_aggs=es_aggs,
+                    pd_aggs=pd_aggs,
+                    response={"aggregations": bucket},
+                    numeric_only=numeric_only,
+                    is_dataframe_agg=is_agg,
+                )
+                # Add the calculated agg values to the response
+                for key, value in agg_calculation.items():
+                    if not is_agg:
+                        response[key].append(value)
+                    else:
+                        for i in range(0, len(pd_aggs)):
+                            response[f"{key}_{pd_aggs[i]}"].append(value[i])
+
+        return [field.index for field in fields], response
+
     @staticmethod
     def _map_pd_aggs_to_es_aggs(pd_aggs):
         """
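The population-to-sample conversion in _calculate_single_agg rests on the identities sample_var = N/(N-1) · population_var and sample_std = sqrt(N/(N-1)) · population_std. A self-contained NumPy check of those identities, independent of eland:

    import numpy as np

    x = np.array([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])
    n = len(x)

    # ddof=0 is the 'population' statistic that Elasticsearch's extended_stats
    # returns; ddof=1 is the 'sample' statistic that pandas reports.
    assert np.isclose(x.var(ddof=1), (n / (n - 1)) * x.var(ddof=0))
    assert np.isclose(x.std(ddof=1), np.sqrt(n / (n - 1)) * x.std(ddof=0))

For large bucket counts both correction factors approach 1, which is why the tests below can compare against pandas with a loose relative tolerance.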
eland/query.py

@@ -136,6 +136,90 @@ class Query:
         agg = {func: {"field": field}}
         self._aggs[name] = agg

+    def term_aggs(self, name: str, field: str) -> None:
+        """
+        Add a term aggregation, e.g.
+
+        "aggs": {
+            "name": {
+                "terms": {
+                    "field": "AvgTicketPrice"
+                }
+            }
+        }
+        """
+        agg = {"terms": {"field": field}}
+        self._aggs[name] = agg
+
+    def composite_agg(
+        self,
+        name: str,
+        size: int,
+        dropna: bool = True,
+    ) -> None:
+        """
+        Add a composite aggregation, e.g.
+        https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-bucket-composite-aggregation.html
+
+        "aggs": {
+            "groupby_buckets": {
+                "composite": {
+                    "size": 10,
+                    "sources": [
+                        {"total_quantity": {"terms": {"field": "total_quantity"}}}
+                    ],
+                    "after": {"total_quantity": 8},
+                },
+                "aggregations": {
+                    "taxful_total_price_avg": {
+                        "avg": {"field": "taxful_total_price"}
+                    }
+                },
+            }
+        }
+
+        Parameters
+        ----------
+        name: str
+            Name of the buckets
+        size: int
+            Pagination size.
+        dropna: bool
+            Drop None values if True.
+            TODO Not yet implemented
+        """
+        sources: List[Dict[str, Dict[str, str]]] = []
+        aggregations: Dict[str, Dict[str, str]] = {}
+
+        for _name, agg in self._aggs.items():
+            if agg.get("terms"):
+                if not dropna:
+                    agg["terms"]["missing_bucket"] = "true"
+                sources.append({_name: agg})
+            else:
+                aggregations[_name] = agg
+
+        agg = {
+            "composite": {"size": size, "sources": sources},
+            "aggregations": aggregations,
+        }
+        self._aggs.clear()
+        self._aggs[name] = agg
+
+    def composite_agg_after_key(self, name: str, after_key: Dict[str, Any]) -> None:
+        """
+        Adds the after_key to the existing query, to fetch the next page of results.
+
+        Parameters
+        ----------
+        name: str
+            Name of the buckets
+        after_key: Dict[str, Any]
+            Dictionary returned with the previous query's results
+        """
+        self._aggs[name]["composite"]["after"] = after_key
+
     def hist_aggs(
         self,
         name: str,
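The composite_agg / composite_agg_after_key pair implements standard Elasticsearch composite-aggregation paging. A minimal standalone sketch of the same loop against the raw client (hypothetical index and field names; assumes elasticsearch-py 7.x):

    from elasticsearch import Elasticsearch

    es = Elasticsearch("localhost")

    body = {
        "size": 0,
        "aggs": {
            "groupby_buckets": {
                "composite": {
                    "size": 5000,  # DEFAULT_PAGINATION_SIZE
                    "sources": [
                        {"groupby_DestCountry": {"terms": {"field": "DestCountry"}}}
                    ],
                },
                "aggregations": {
                    "avg_AvgTicketPrice": {"avg": {"field": "AvgTicketPrice"}}
                },
            }
        },
    }

    while True:
        result = es.search(index="flights", body=body)["aggregations"]["groupby_buckets"]
        for bucket in result["buckets"]:
            # Each bucket carries the composite key plus the per-bucket metrics
            print(bucket["key"], bucket["avg_AvgTicketPrice"]["value"])
        if "after_key" not in result:
            break
        # Feed the last composite key back in to fetch the next page
        body["aggs"]["groupby_buckets"]["composite"]["after"] = result["after_key"]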
eland/query_compiler.py

@@ -19,8 +19,8 @@ import copy
 from datetime import datetime
 from typing import Optional, Sequence, TYPE_CHECKING, List

-import numpy as np
-import pandas as pd
+import numpy as np  # type: ignore
+import pandas as pd  # type: ignore

 from eland.field_mappings import FieldMappings
 from eland.filter import QueryFilter

@@ -72,7 +72,7 @@ class QueryCompiler:
         display_names=None,
         index_field=None,
         to_copy=None,
-    ):
+    ) -> None:
         # Implement copy as we don't deep copy the client
         if to_copy is not None:
             self._client = to_copy._client

@@ -550,6 +550,16 @@ class QueryCompiler:
             self, ["nunique"], numeric_only=False
         )

+    def groupby(
+        self,
+        by: List[str],
+        pd_aggs: List[str],
+        dropna: bool = True,
+        is_agg: bool = False,
+        numeric_only: bool = True,
+    ) -> pd.DataFrame:
+        return self._operations.groupby(self, by, pd_aggs, dropna, is_agg, numeric_only)
+
     def value_counts(self, es_size):
         return self._operations.value_counts(self, es_size)
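Taken together, the call chain for every groupby operation in this commit is: DataFrame.groupby() builds a GroupByDataFrame, whose methods call QueryCompiler.groupby(), which delegates to Operations.groupby(); that in turn issues the composite-aggregation search and unpacks each bucket via _calculate_single_agg().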
eland/tests/dataframe/test_groupby_pytest.py (new file, 127 lines)

# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# File called _pytest for PyCharm compatibility

import pytest
from pandas.testing import assert_frame_equal, assert_series_equal
from eland.tests.common import TestData
import pandas as pd


class TestGroupbyDataFrame(TestData):
    funcs = ["max", "min", "mean", "sum"]
    extended_funcs = ["median", "mad", "var", "std"]
    filter_data = [
        "AvgTicketPrice",
        "Cancelled",
        "dayOfWeek",
        "timestamp",
        "DestCountry",
    ]

    @pytest.mark.parametrize("numeric_only", [True])
    def test_groupby_aggregate(self, numeric_only):
        # TODO Add tests for numeric_only=False for aggs
        # when we support aggregations on text fields
        pd_flights = self.pd_flights().filter(self.filter_data)
        ed_flights = self.ed_flights().filter(self.filter_data)

        pd_groupby = pd_flights.groupby("Cancelled").agg(self.funcs, numeric_only)
        ed_groupby = ed_flights.groupby("Cancelled").agg(self.funcs, numeric_only)

        # check only the values because dtypes are checked in the aggs tests
        assert_frame_equal(pd_groupby, ed_groupby, check_exact=False, check_dtype=False)

    @pytest.mark.parametrize("pd_agg", ["max", "min", "mean", "sum", "median"])
    def test_groupby_aggs_true(self, pd_agg):
        # pandas supports numeric_only for these aggs only when combined with groupby.
        pd_flights = self.pd_flights().filter(self.filter_data)
        ed_flights = self.ed_flights().filter(self.filter_data)

        pd_groupby = getattr(pd_flights.groupby("Cancelled"), pd_agg)(numeric_only=True)
        ed_groupby = getattr(ed_flights.groupby("Cancelled"), pd_agg)(numeric_only=True)

        # check only the values because dtypes are checked in the aggs tests
        assert_frame_equal(
            pd_groupby, ed_groupby, check_exact=False, check_dtype=False, rtol=4
        )

    @pytest.mark.parametrize("pd_agg", ["mad", "var", "std"])
    def test_groupby_aggs_mad_var_std(self, pd_agg):
        # pandas doesn't support numeric_only for these aggs
        pd_flights = self.pd_flights().filter(self.filter_data)
        ed_flights = self.ed_flights().filter(self.filter_data)

        pd_groupby = getattr(pd_flights.groupby("Cancelled"), pd_agg)()
        ed_groupby = getattr(ed_flights.groupby("Cancelled"), pd_agg)(numeric_only=True)

        # check only the values because dtypes are checked in the aggs tests
        assert_frame_equal(
            pd_groupby, ed_groupby, check_exact=False, check_dtype=False, rtol=4
        )

    @pytest.mark.parametrize("pd_agg", ["nunique"])
    def test_groupby_aggs_nunique(self, pd_agg):
        pd_flights = self.pd_flights().filter(self.filter_data)
        ed_flights = self.ed_flights().filter(self.filter_data)

        pd_groupby = getattr(pd_flights.groupby("Cancelled"), pd_agg)()
        ed_groupby = getattr(ed_flights.groupby("Cancelled"), pd_agg)()

        # check only the values because dtypes are checked in the aggs tests
        assert_frame_equal(
            pd_groupby, ed_groupby, check_exact=False, check_dtype=False, rtol=4
        )

    @pytest.mark.parametrize("pd_agg", ["max", "min", "mean", "median"])
    def test_groupby_aggs_false(self, pd_agg):
        pd_flights = self.pd_flights().filter(self.filter_data)
        ed_flights = self.ed_flights().filter(self.filter_data)

        # pandas numeric_only=False matches Eland numeric_only=None
        pd_groupby = getattr(pd_flights.groupby("Cancelled"), pd_agg)(
            numeric_only=False
        )
        ed_groupby = getattr(ed_flights.groupby("Cancelled"), pd_agg)(numeric_only=None)

        # sum usually returns NaT from Eland while nothing is returned from pandas;
        # we only check the timestamp field here because the remaining columns are
        # covered by the numeric_only=True tests.
        # assert_frame_equal doesn't work well for timestamp fields (it converts
        # them to int), so we convert them to float instead.
        pd_timestamp = pd.to_numeric(pd_groupby["timestamp"], downcast="float")
        ed_timestamp = pd.to_numeric(ed_groupby["timestamp"], downcast="float")

        assert_series_equal(pd_timestamp, ed_timestamp, check_exact=False, rtol=4)

    def test_groupby_columns(self):
        # Check the error cases
        ed_flights = self.ed_flights().filter(self.filter_data)

        match = "by parameter should be specified to groupby"
        with pytest.raises(TypeError, match=match):
            ed_flights.groupby(None).mean()

        by = ["ABC", "Cancelled"]
        match = "Requested columns {'ABC'} not in the DataFrame."
        with pytest.raises(KeyError, match=match):
            ed_flights.groupby(by).mean()

    def test_groupby_dropna(self):
        # TODO Add tests once dropna is implemented
        pass
noxfile.py

@@ -44,6 +44,7 @@ TYPED_FILES = (
     "eland/query.py",
     "eland/tasks.py",
     "eland/utils.py",
+    "eland/groupby.py",
     "eland/ml/__init__.py",
     "eland/ml/_model_serializer.py",
     "eland/ml/ml_model.py",