Deal with the mad aggregation being removed in Pandas 2 (#602)

This commit is contained in:
Bart Broere 2023-11-06 06:12:16 +01:00 committed by GitHub
parent 5b3a83e7f2
commit 5e5f36bdf8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,6 +23,17 @@ from pandas.testing import assert_frame_equal, assert_index_equal, assert_series
from tests.common import TestData
PANDAS_MAJOR_VERSION = int(pd.__version__.split(".")[0])
# The mean absolute difference (mad) aggregation has been removed from
# pandas with major version 2:
# https://github.com/pandas-dev/pandas/issues/11787
# To compare whether eland's version of it works, we need to implement
# it here ourselves.
def mad(x):
return abs(x - x.mean()).mean()
class TestGroupbyDataFrame(TestData):
funcs = ["max", "min", "mean", "sum"]
@ -71,7 +82,7 @@ class TestGroupbyDataFrame(TestData):
@pytest.mark.parametrize("dropna", [True, False])
@pytest.mark.parametrize("pd_agg", ["max", "min", "mean", "sum", "median"])
def test_groupby_aggs_numeric_only_true(self, pd_agg, dropna):
# Pandas has numeric_only applicable for the above aggs with groupby only.
# Pandas has numeric_only applicable for the above aggs with groupby only.
pd_flights = self.pd_flights().filter(self.filter_data)
ed_flights = self.ed_flights().filter(self.filter_data)
@ -95,7 +106,14 @@ class TestGroupbyDataFrame(TestData):
pd_flights = self.pd_flights().filter(self.filter_data)
ed_flights = self.ed_flights().filter(self.filter_data)
pd_groupby = getattr(pd_flights.groupby("Cancelled", dropna=dropna), pd_agg)()
# The mad aggregation has been removed in Pandas 2, so we need to use
# our own implementation if we run the tests with Pandas 2 or higher
if PANDAS_MAJOR_VERSION >= 2 and pd_agg == "mad":
pd_groupby = pd_flights.groupby("Cancelled", dropna=dropna).aggregate(mad)
else:
pd_groupby = getattr(
pd_flights.groupby("Cancelled", dropna=dropna), pd_agg
)()
ed_groupby = getattr(ed_flights.groupby("Cancelled", dropna=dropna), pd_agg)(
numeric_only=True
)
@ -211,14 +229,20 @@ class TestGroupbyDataFrame(TestData):
pd_flights = self.pd_flights().filter(self.filter_data + ["DestCountry"])
ed_flights = self.ed_flights().filter(self.filter_data + ["DestCountry"])
pd_mad = pd_flights.groupby("DestCountry").mad()
if PANDAS_MAJOR_VERSION < 2:
pd_mad = pd_flights.groupby("DestCountry").mad()
else:
pd_mad = pd_flights.groupby("DestCountry").aggregate(mad)
ed_mad = ed_flights.groupby("DestCountry").mad()
assert_index_equal(pd_mad.columns, ed_mad.columns)
assert_index_equal(pd_mad.index, ed_mad.index)
assert_series_equal(pd_mad.dtypes, ed_mad.dtypes)
pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", "mad"])
if PANDAS_MAJOR_VERSION < 2:
pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", "mad"])
else:
pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", mad])
ed_min_mad = ed_flights.groupby("DestCountry").aggregate(["min", "mad"])
assert_index_equal(pd_min_mad.columns, ed_min_mad.columns)