# eland/eland/tests/dataframe/test_groupby_pytest.py
# (128 lines, 5.5 KiB, Python)
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# File called _pytest for PyCharm compatibility
import re

import pandas as pd
import pytest
from pandas.testing import assert_frame_equal, assert_series_equal

from eland.tests.common import TestData
class TestGroupbyDataFrame(TestData):
    """Compare eland ``groupby`` aggregations against pandas.

    Every test restricts the flights fixtures to ``filter_data``, groups by
    ``Cancelled`` and compares the eland result with the pandas result.
    """

    # Aggregations passed as a list to ``groupby(...).agg``.
    funcs = ["max", "min", "mean", "sum"]
    # NOTE(review): not referenced by any test in this class - confirm whether
    # it is still needed.
    extended_funcs = ["median", "mad", "var", "std"]
    # Columns kept from the flights data; ``timestamp`` is a date column and
    # ``DestCountry`` is presumably a text/keyword column, so numeric_only
    # handling is exercised.
    filter_data = [
        "AvgTicketPrice",
        "Cancelled",
        "dayOfWeek",
        "timestamp",
        "DestCountry",
    ]

    def _filtered_flights(self):
        """Return the ``(pandas, eland)`` flights frames restricted to ``filter_data``."""
        pd_flights = self.pd_flights().filter(self.filter_data)
        ed_flights = self.ed_flights().filter(self.filter_data)
        return pd_flights, ed_flights

    @staticmethod
    def _assert_values_close(pd_result, ed_result):
        """Assert two grouped frames hold approximately equal values.

        dtypes are not compared because they are checked in the aggs tests.
        """
        # NOTE(review): rtol=4 tolerates a relative difference of up to 400%,
        # which is extremely loose; presumably chosen to absorb approximate
        # Elasticsearch results (e.g. median/mad) - confirm and tighten.
        assert_frame_equal(
            pd_result, ed_result, check_exact=False, check_dtype=False, rtol=4
        )

    @pytest.mark.parametrize("numeric_only", [True])
    def test_groupby_aggregate(self, numeric_only):
        """``agg`` with a list of functions matches pandas."""
        # TODO Add tests for numeric_only=False for aggs
        # when we support aggregations on text fields
        pd_flights, ed_flights = self._filtered_flights()

        # Pass numeric_only by keyword: pandas' ``agg`` collects extra
        # positional arguments into ``*args`` instead of binding them by name.
        pd_groupby = pd_flights.groupby("Cancelled").agg(
            self.funcs, numeric_only=numeric_only
        )
        ed_groupby = ed_flights.groupby("Cancelled").agg(
            self.funcs, numeric_only=numeric_only
        )

        # checking only values because dtypes are checked in aggs tests
        assert_frame_equal(pd_groupby, ed_groupby, check_exact=False, check_dtype=False)

    @pytest.mark.parametrize("pd_agg", ["max", "min", "mean", "sum", "median"])
    def test_groupby_aggs_true(self, pd_agg):
        """Single aggregations with ``numeric_only=True`` match pandas."""
        # Pandas has numeric_only applicable for the above aggs with groupby only.
        pd_flights, ed_flights = self._filtered_flights()

        pd_groupby = getattr(pd_flights.groupby("Cancelled"), pd_agg)(numeric_only=True)
        ed_groupby = getattr(ed_flights.groupby("Cancelled"), pd_agg)(numeric_only=True)

        self._assert_values_close(pd_groupby, ed_groupby)

    @pytest.mark.parametrize("pd_agg", ["mad", "var", "std"])
    def test_groupby_aggs_mad_var_std(self, pd_agg):
        """mad/var/std match pandas (pandas is called without numeric_only)."""
        # For these aggs pandas doesn't support numeric_only
        pd_flights, ed_flights = self._filtered_flights()

        pd_groupby = getattr(pd_flights.groupby("Cancelled"), pd_agg)()
        ed_groupby = getattr(ed_flights.groupby("Cancelled"), pd_agg)(numeric_only=True)

        self._assert_values_close(pd_groupby, ed_groupby)

    @pytest.mark.parametrize("pd_agg", ["nunique"])
    def test_groupby_aggs_nunique(self, pd_agg):
        """``nunique`` with default arguments matches pandas."""
        pd_flights, ed_flights = self._filtered_flights()

        pd_groupby = getattr(pd_flights.groupby("Cancelled"), pd_agg)()
        ed_groupby = getattr(ed_flights.groupby("Cancelled"), pd_agg)()

        self._assert_values_close(pd_groupby, ed_groupby)

    @pytest.mark.parametrize("pd_agg", ["max", "min", "mean", "median"])
    def test_groupby_aggs_false(self, pd_agg):
        """Non-numeric aggregation of the ``timestamp`` column matches pandas."""
        pd_flights, ed_flights = self._filtered_flights()

        # pandas numeric_only=False matches eland numeric_only=None
        pd_groupby = getattr(pd_flights.groupby("Cancelled"), pd_agg)(
            numeric_only=False
        )
        ed_groupby = getattr(ed_flights.groupby("Cancelled"), pd_agg)(numeric_only=None)

        # "sum" is not parametrized: eland usually returns NaT for it while
        # pandas returns nothing.
        # Only the timestamp field is checked here; the remaining columns are
        # covered by the numeric_only=True tests.
        # assert_frame_equal doesn't work well for timestamp fields (it converts
        # them into int), so both sides are converted to numbers first.
        pd_timestamp = pd.to_numeric(pd_groupby["timestamp"], downcast="float")
        ed_timestamp = pd.to_numeric(ed_groupby["timestamp"], downcast="float")

        assert_series_equal(pd_timestamp, ed_timestamp, check_exact=False, rtol=4)

    def test_groupby_columns(self):
        """Invalid ``by`` arguments raise informative errors."""
        ed_flights = self.ed_flights().filter(self.filter_data)

        # ``match`` is a regular expression searched in the error message;
        # escape it so "." and "{" are matched literally.
        match = re.escape("by parameter should be specified to groupby")
        with pytest.raises(TypeError, match=match):
            ed_flights.groupby(None).mean()

        by = ["ABC", "Cancelled"]
        match = re.escape("Requested columns {'ABC'} not in the DataFrame.")
        with pytest.raises(KeyError, match=match):
            ed_flights.groupby(by).mean()

    @pytest.mark.skip(reason="dropna is not implemented yet")
    def test_groupby_dropna(self):
        """Placeholder for groupby ``dropna`` support; reported as skipped."""
        # TODO Add tests once dropna is implemented