mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Fix DataFrame.agg() with string argument to return Series
This commit is contained in:
parent
d73e8a241c
commit
1d6311164e
1
.gitignore
vendored
1
.gitignore
vendored
@ -49,3 +49,4 @@ venv/
|
|||||||
ENV/
|
ENV/
|
||||||
env.bak/
|
env.bak/
|
||||||
venv.bak/
|
venv.bak/
|
||||||
|
.mypy_cache
|
||||||
|
@ -1386,9 +1386,8 @@ class DataFrame(NDFrame):
|
|||||||
# ['count', 'mad', 'max', 'mean', 'median', 'min', 'mode', 'quantile',
|
# ['count', 'mad', 'max', 'mean', 'median', 'min', 'mode', 'quantile',
|
||||||
# 'rank', 'sem', 'skew', 'sum', 'std', 'var', 'nunique']
|
# 'rank', 'sem', 'skew', 'sum', 'std', 'var', 'nunique']
|
||||||
if isinstance(func, str):
|
if isinstance(func, str):
|
||||||
# wrap in list
|
# Wrap in list
|
||||||
func = [func]
|
return self._query_compiler.aggs([func]).squeeze().rename(None)
|
||||||
return self._query_compiler.aggs(func)
|
|
||||||
elif is_list_like(func):
|
elif is_list_like(func):
|
||||||
# we have a list!
|
# we have a list!
|
||||||
return self._query_compiler.aggs(func)
|
return self._query_compiler.aggs(func)
|
||||||
|
@ -18,8 +18,8 @@
|
|||||||
# File called _pytest for PyCharm compatability
|
# File called _pytest for PyCharm compatability
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pandas.testing import assert_frame_equal
|
from pandas.testing import assert_frame_equal, assert_series_equal
|
||||||
|
import pytest
|
||||||
from eland.tests.common import TestData
|
from eland.tests.common import TestData
|
||||||
|
|
||||||
|
|
||||||
@ -94,3 +94,32 @@ class TestDataFrameAggs(TestData):
|
|||||||
# TODO - investigate this more
|
# TODO - investigate this more
|
||||||
pd_aggs = pd_aggs.astype("float64")
|
pd_aggs = pd_aggs.astype("float64")
|
||||||
assert_frame_equal(pd_aggs, ed_aggs, check_exact=False, check_less_precise=2)
|
assert_frame_equal(pd_aggs, ed_aggs, check_exact=False, check_less_precise=2)
|
||||||
|
|
||||||
|
# If Aggregate is given a string then series is returned.
|
||||||
|
@pytest.mark.parametrize("agg", ["mean", "min", "max"])
|
||||||
|
def test_terms_aggs_series(self, agg):
|
||||||
|
pd_flights = self.pd_flights()
|
||||||
|
ed_flights = self.ed_flights()
|
||||||
|
|
||||||
|
pd_sum_min_std = pd_flights.select_dtypes(include=[np.number]).agg(agg)
|
||||||
|
ed_sum_min_std = ed_flights.select_dtypes(include=[np.number]).agg(agg)
|
||||||
|
|
||||||
|
assert_series_equal(pd_sum_min_std, ed_sum_min_std)
|
||||||
|
|
||||||
|
def test_terms_aggs_series_with_single_list_agg(self):
|
||||||
|
# aggs list with single agg should return dataframe.
|
||||||
|
pd_flights = self.pd_flights()
|
||||||
|
ed_flights = self.ed_flights()
|
||||||
|
|
||||||
|
pd_sum_min = pd_flights.select_dtypes(include=[np.number]).agg(["mean"])
|
||||||
|
ed_sum_min = ed_flights.select_dtypes(include=[np.number]).agg(["mean"])
|
||||||
|
|
||||||
|
assert_frame_equal(pd_sum_min, ed_sum_min)
|
||||||
|
|
||||||
|
# If Wrong Aggregate value is given.
|
||||||
|
def test_terms_wrongaggs(self):
|
||||||
|
ed_flights = self.ed_flights()[["FlightDelayMin"]]
|
||||||
|
|
||||||
|
match = "('abc', ' not currently implemented')"
|
||||||
|
with pytest.raises(NotImplementedError, match=match):
|
||||||
|
ed_flights.select_dtypes(include=[np.number]).agg("abc")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user