mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Fix DataFrame.agg() with string argument to return Series
This commit is contained in:
parent
d73e8a241c
commit
1d6311164e
1
.gitignore
vendored
1
.gitignore
vendored
@ -49,3 +49,4 @@ venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
.mypy_cache
|
||||
|
@ -1386,9 +1386,8 @@ class DataFrame(NDFrame):
|
||||
# ['count', 'mad', 'max', 'mean', 'median', 'min', 'mode', 'quantile',
|
||||
# 'rank', 'sem', 'skew', 'sum', 'std', 'var', 'nunique']
|
||||
if isinstance(func, str):
|
||||
# wrap in list
|
||||
func = [func]
|
||||
return self._query_compiler.aggs(func)
|
||||
# Wrap in list
|
||||
return self._query_compiler.aggs([func]).squeeze().rename(None)
|
||||
elif is_list_like(func):
|
||||
# we have a list!
|
||||
return self._query_compiler.aggs(func)
|
||||
|
@ -18,8 +18,8 @@
|
||||
# File called _pytest for PyCharm compatibility
|
||||
|
||||
import numpy as np
|
||||
from pandas.testing import assert_frame_equal
|
||||
|
||||
from pandas.testing import assert_frame_equal, assert_series_equal
|
||||
import pytest
|
||||
from eland.tests.common import TestData
|
||||
|
||||
|
||||
@ -94,3 +94,32 @@ class TestDataFrameAggs(TestData):
|
||||
# TODO - investigate this more
|
||||
pd_aggs = pd_aggs.astype("float64")
|
||||
assert_frame_equal(pd_aggs, ed_aggs, check_exact=False, check_less_precise=2)
|
||||
|
||||
# If agg() is given a single string, a Series is returned.
|
||||
@pytest.mark.parametrize("agg", ["mean", "min", "max"])
def test_terms_aggs_series(self, agg):
    """Check that a single aggregation name given as a string yields a Series.

    The eland result is compared against the pandas result for the same
    aggregation over the numeric columns of the flights data.
    """
    pandas_flights = self.pd_flights()
    eland_flights = self.ed_flights()

    # Restrict both frames to numeric columns before aggregating.
    pandas_numeric = pandas_flights.select_dtypes(include=[np.number])
    eland_numeric = eland_flights.select_dtypes(include=[np.number])

    expected = pandas_numeric.agg(agg)
    actual = eland_numeric.agg(agg)

    assert_series_equal(expected, actual)
|
||||
|
||||
def test_terms_aggs_series_with_single_list_agg(self):
    """Check that a one-element list of aggregations still yields a DataFrame."""

    def mean_of_numeric_columns(frame):
        # Same pipeline for both libraries: numeric columns only, then a
        # list containing the single aggregation "mean".
        return frame.select_dtypes(include=[np.number]).agg(["mean"])

    pandas_flights = self.pd_flights()
    eland_flights = self.ed_flights()

    expected = mean_of_numeric_columns(pandas_flights)
    actual = mean_of_numeric_columns(eland_flights)

    assert_frame_equal(expected, actual)
|
||||
|
||||
# If an unsupported aggregation name is given, NotImplementedError is raised.
|
||||
def test_terms_wrongaggs(self):
    """Check that an unknown aggregation name raises NotImplementedError.

    Raises (expected by the test):
        NotImplementedError: when ``agg`` is called with ``"abc"``.
    """
    # Local import so this block does not depend on the file's import section.
    import re

    ed_flights = self.ed_flights()[["FlightDelayMin"]]

    # pytest.raises(match=...) treats its argument as a regular expression
    # (re.search). Unescaped, the parentheses in the expected text form a
    # regex group and are never matched literally, so escape the whole
    # string to assert on the exact message.
    match = re.escape("('abc', ' not currently implemented')")
    with pytest.raises(NotImplementedError, match=match):
        ed_flights.select_dtypes(include=[np.number]).agg("abc")
|
||||
|
Loading…
x
Reference in New Issue
Block a user