mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Deal with the mad aggregation being removed in Pandas 2 (#602)
This commit is contained in:
parent
5b3a83e7f2
commit
5e5f36bdf8
@ -23,6 +23,17 @@ from pandas.testing import assert_frame_equal, assert_index_equal, assert_series
|
|||||||
|
|
||||||
from tests.common import TestData
|
from tests.common import TestData
|
||||||
|
|
||||||
|
PANDAS_MAJOR_VERSION = int(pd.__version__.split(".")[0])
|
||||||
|
|
||||||
|
|
||||||
|
# The mean absolute difference (mad) aggregation has been removed from
|
||||||
|
# pandas with major version 2:
|
||||||
|
# https://github.com/pandas-dev/pandas/issues/11787
|
||||||
|
# To compare whether eland's version of it works, we need to implement
|
||||||
|
# it here ourselves.
|
||||||
|
def mad(x):
|
||||||
|
return abs(x - x.mean()).mean()
|
||||||
|
|
||||||
|
|
||||||
class TestGroupbyDataFrame(TestData):
|
class TestGroupbyDataFrame(TestData):
|
||||||
funcs = ["max", "min", "mean", "sum"]
|
funcs = ["max", "min", "mean", "sum"]
|
||||||
@ -95,7 +106,14 @@ class TestGroupbyDataFrame(TestData):
|
|||||||
pd_flights = self.pd_flights().filter(self.filter_data)
|
pd_flights = self.pd_flights().filter(self.filter_data)
|
||||||
ed_flights = self.ed_flights().filter(self.filter_data)
|
ed_flights = self.ed_flights().filter(self.filter_data)
|
||||||
|
|
||||||
pd_groupby = getattr(pd_flights.groupby("Cancelled", dropna=dropna), pd_agg)()
|
# The mad aggregation has been removed in Pandas 2, so we need to use
|
||||||
|
# our own implementation if we run the tests with Pandas 2 or higher
|
||||||
|
if PANDAS_MAJOR_VERSION >= 2 and pd_agg == "mad":
|
||||||
|
pd_groupby = pd_flights.groupby("Cancelled", dropna=dropna).aggregate(mad)
|
||||||
|
else:
|
||||||
|
pd_groupby = getattr(
|
||||||
|
pd_flights.groupby("Cancelled", dropna=dropna), pd_agg
|
||||||
|
)()
|
||||||
ed_groupby = getattr(ed_flights.groupby("Cancelled", dropna=dropna), pd_agg)(
|
ed_groupby = getattr(ed_flights.groupby("Cancelled", dropna=dropna), pd_agg)(
|
||||||
numeric_only=True
|
numeric_only=True
|
||||||
)
|
)
|
||||||
@ -211,14 +229,20 @@ class TestGroupbyDataFrame(TestData):
|
|||||||
pd_flights = self.pd_flights().filter(self.filter_data + ["DestCountry"])
|
pd_flights = self.pd_flights().filter(self.filter_data + ["DestCountry"])
|
||||||
ed_flights = self.ed_flights().filter(self.filter_data + ["DestCountry"])
|
ed_flights = self.ed_flights().filter(self.filter_data + ["DestCountry"])
|
||||||
|
|
||||||
|
if PANDAS_MAJOR_VERSION < 2:
|
||||||
pd_mad = pd_flights.groupby("DestCountry").mad()
|
pd_mad = pd_flights.groupby("DestCountry").mad()
|
||||||
|
else:
|
||||||
|
pd_mad = pd_flights.groupby("DestCountry").aggregate(mad)
|
||||||
ed_mad = ed_flights.groupby("DestCountry").mad()
|
ed_mad = ed_flights.groupby("DestCountry").mad()
|
||||||
|
|
||||||
assert_index_equal(pd_mad.columns, ed_mad.columns)
|
assert_index_equal(pd_mad.columns, ed_mad.columns)
|
||||||
assert_index_equal(pd_mad.index, ed_mad.index)
|
assert_index_equal(pd_mad.index, ed_mad.index)
|
||||||
assert_series_equal(pd_mad.dtypes, ed_mad.dtypes)
|
assert_series_equal(pd_mad.dtypes, ed_mad.dtypes)
|
||||||
|
|
||||||
|
if PANDAS_MAJOR_VERSION < 2:
|
||||||
pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", "mad"])
|
pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", "mad"])
|
||||||
|
else:
|
||||||
|
pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", mad])
|
||||||
ed_min_mad = ed_flights.groupby("DestCountry").aggregate(["min", "mad"])
|
ed_min_mad = ed_flights.groupby("DestCountry").aggregate(["min", "mad"])
|
||||||
|
|
||||||
assert_index_equal(pd_min_mad.columns, ed_min_mad.columns)
|
assert_index_equal(pd_min_mad.columns, ed_min_mad.columns)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user