Fix DataFrame.agg() with string argument to return Series

This commit is contained in:
P. Sai Vinay 2020-08-25 23:09:34 +05:30 committed by GitHub
parent d73e8a241c
commit 1d6311164e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 5 deletions

1
.gitignore vendored
View File

@ -49,3 +49,4 @@ venv/
ENV/ ENV/
env.bak/ env.bak/
venv.bak/ venv.bak/
.mypy_cache

View File

@ -1386,9 +1386,8 @@ class DataFrame(NDFrame):
# ['count', 'mad', 'max', 'mean', 'median', 'min', 'mode', 'quantile', # ['count', 'mad', 'max', 'mean', 'median', 'min', 'mode', 'quantile',
# 'rank', 'sem', 'skew', 'sum', 'std', 'var', 'nunique'] # 'rank', 'sem', 'skew', 'sum', 'std', 'var', 'nunique']
if isinstance(func, str): if isinstance(func, str):
# wrap in list # Wrap in list
func = [func] return self._query_compiler.aggs([func]).squeeze().rename(None)
return self._query_compiler.aggs(func)
elif is_list_like(func): elif is_list_like(func):
# we have a list! # we have a list!
return self._query_compiler.aggs(func) return self._query_compiler.aggs(func)

View File

@ -18,8 +18,8 @@
# File called _pytest for PyCharm compatability # File called _pytest for PyCharm compatability
import numpy as np import numpy as np
from pandas.testing import assert_frame_equal from pandas.testing import assert_frame_equal, assert_series_equal
import pytest
from eland.tests.common import TestData from eland.tests.common import TestData
@ -94,3 +94,32 @@ class TestDataFrameAggs(TestData):
# TODO - investigate this more # TODO - investigate this more
pd_aggs = pd_aggs.astype("float64") pd_aggs = pd_aggs.astype("float64")
assert_frame_equal(pd_aggs, ed_aggs, check_exact=False, check_less_precise=2) assert_frame_equal(pd_aggs, ed_aggs, check_exact=False, check_less_precise=2)
# If Aggregate is given a string then series is returned.
@pytest.mark.parametrize("agg", ["mean", "min", "max"])
def test_terms_aggs_series(self, agg):
pd_flights = self.pd_flights()
ed_flights = self.ed_flights()
pd_sum_min_std = pd_flights.select_dtypes(include=[np.number]).agg(agg)
ed_sum_min_std = ed_flights.select_dtypes(include=[np.number]).agg(agg)
assert_series_equal(pd_sum_min_std, ed_sum_min_std)
def test_terms_aggs_series_with_single_list_agg(self):
# aggs list with single agg should return dataframe.
pd_flights = self.pd_flights()
ed_flights = self.ed_flights()
pd_sum_min = pd_flights.select_dtypes(include=[np.number]).agg(["mean"])
ed_sum_min = ed_flights.select_dtypes(include=[np.number]).agg(["mean"])
assert_frame_equal(pd_sum_min, ed_sum_min)
# If Wrong Aggregate value is given.
def test_terms_wrongaggs(self):
ed_flights = self.ed_flights()[["FlightDelayMin"]]
match = "('abc', ' not currently implemented')"
with pytest.raises(NotImplementedError, match=match):
ed_flights.select_dtypes(include=[np.number]).agg("abc")