mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Deal with the mad aggregation being removed in Pandas 2 (#602)
This commit is contained in:
parent
5b3a83e7f2
commit
5e5f36bdf8
@ -23,6 +23,17 @@ from pandas.testing import assert_frame_equal, assert_index_equal, assert_series
|
||||
|
||||
from tests.common import TestData
|
||||
|
||||
PANDAS_MAJOR_VERSION = int(pd.__version__.split(".")[0])
|
||||
|
||||
|
||||
# The mean absolute difference (mad) aggregation has been removed from
|
||||
# pandas with major version 2:
|
||||
# https://github.com/pandas-dev/pandas/issues/11787
|
||||
# To compare whether eland's version of it works, we need to implement
|
||||
# it here ourselves.
|
||||
def mad(x):
|
||||
return abs(x - x.mean()).mean()
|
||||
|
||||
|
||||
class TestGroupbyDataFrame(TestData):
|
||||
funcs = ["max", "min", "mean", "sum"]
|
||||
@ -71,7 +82,7 @@ class TestGroupbyDataFrame(TestData):
|
||||
@pytest.mark.parametrize("dropna", [True, False])
|
||||
@pytest.mark.parametrize("pd_agg", ["max", "min", "mean", "sum", "median"])
|
||||
def test_groupby_aggs_numeric_only_true(self, pd_agg, dropna):
|
||||
# Pandas has numeric_only applicable for the above aggs with groupby only.
|
||||
# Pandas has numeric_only applicable for the above aggs with groupby only.
|
||||
|
||||
pd_flights = self.pd_flights().filter(self.filter_data)
|
||||
ed_flights = self.ed_flights().filter(self.filter_data)
|
||||
@ -95,7 +106,14 @@ class TestGroupbyDataFrame(TestData):
|
||||
pd_flights = self.pd_flights().filter(self.filter_data)
|
||||
ed_flights = self.ed_flights().filter(self.filter_data)
|
||||
|
||||
pd_groupby = getattr(pd_flights.groupby("Cancelled", dropna=dropna), pd_agg)()
|
||||
# The mad aggregation has been removed in Pandas 2, so we need to use
|
||||
# our own implementation if we run the tests with Pandas 2 or higher
|
||||
if PANDAS_MAJOR_VERSION >= 2 and pd_agg == "mad":
|
||||
pd_groupby = pd_flights.groupby("Cancelled", dropna=dropna).aggregate(mad)
|
||||
else:
|
||||
pd_groupby = getattr(
|
||||
pd_flights.groupby("Cancelled", dropna=dropna), pd_agg
|
||||
)()
|
||||
ed_groupby = getattr(ed_flights.groupby("Cancelled", dropna=dropna), pd_agg)(
|
||||
numeric_only=True
|
||||
)
|
||||
@ -211,14 +229,20 @@ class TestGroupbyDataFrame(TestData):
|
||||
pd_flights = self.pd_flights().filter(self.filter_data + ["DestCountry"])
|
||||
ed_flights = self.ed_flights().filter(self.filter_data + ["DestCountry"])
|
||||
|
||||
pd_mad = pd_flights.groupby("DestCountry").mad()
|
||||
if PANDAS_MAJOR_VERSION < 2:
|
||||
pd_mad = pd_flights.groupby("DestCountry").mad()
|
||||
else:
|
||||
pd_mad = pd_flights.groupby("DestCountry").aggregate(mad)
|
||||
ed_mad = ed_flights.groupby("DestCountry").mad()
|
||||
|
||||
assert_index_equal(pd_mad.columns, ed_mad.columns)
|
||||
assert_index_equal(pd_mad.index, ed_mad.index)
|
||||
assert_series_equal(pd_mad.dtypes, ed_mad.dtypes)
|
||||
|
||||
pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", "mad"])
|
||||
if PANDAS_MAJOR_VERSION < 2:
|
||||
pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", "mad"])
|
||||
else:
|
||||
pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", mad])
|
||||
ed_min_mad = ed_flights.groupby("DestCountry").aggregate(["min", "mad"])
|
||||
|
||||
assert_index_equal(pd_min_mad.columns, ed_min_mad.columns)
|
||||
|
Loading…
x
Reference in New Issue
Block a user