diff --git a/eland/tests/conftest.py b/eland/tests/conftest.py
new file mode 100644
index 0000000..3142270
--- /dev/null
+++ b/eland/tests/conftest.py
@@ -0,0 +1,175 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import inspect
+import pytest
+import pandas as pd
+from .common import (
+    assert_pandas_eland_frame_equal,
+    assert_pandas_eland_series_equal,
+    assert_frame_equal,
+    assert_series_equal,
+    _ed_flights,
+    _pd_flights,
+    _ed_ecommerce,
+    _pd_ecommerce,
+    _ed_flights_small,
+    _pd_flights_small,
+)
+import eland as ed
+
+
+class SymmetricAPIChecker:
+    """Pairs an eland object with a pandas object and replays every operation
+    on both, asserting that the results (or raised exception types) agree.
+    """
+
+    def __init__(self, ed_obj, pd_obj):
+        self.ed = ed_obj
+        self.pd = pd_obj
+
+    def load_dataset(self, dataset):
+        """Replaces the wrapped pair with the named dataset ("flights",
+        "flights_small" or "ecommerce"); raises ValueError for any other name.
+        """
+        if dataset == "flights":
+            self.ed = _ed_flights
+            self.pd = _pd_flights.copy()
+        elif dataset == "flights_small":
+            self.ed = _ed_flights_small
+            self.pd = _pd_flights_small.copy()
+        elif dataset == "ecommerce":
+            self.ed = _ed_ecommerce
+            self.pd = _pd_ecommerce.copy()
+        else:
+            raise ValueError(f"Unknown dataset {dataset!r}")
+
+    def return_value_checker(self, func_name):
+        """Returns a function which wraps the requested function
+        and checks the return value when that function is inevitably
+        called.
+        """
+
+        def f(*args, **kwargs):
+            ed_exc = None
+            try:
+                ed_obj = getattr(self.ed, func_name)(*args, **kwargs)
+            except Exception as e:
+                ed_exc = e
+            pd_exc = None
+            try:
+                if func_name == "to_pandas":
+                    pd_obj = self.pd
+                else:
+                    pd_obj = getattr(self.pd, func_name)(*args, **kwargs)
+            except Exception as e:
+                pd_exc = e
+
+            # check_exception() raises unless both calls succeeded, so
+            # ed_obj and pd_obj are guaranteed to be bound afterwards.
+            self.check_exception(ed_exc, pd_exc)
+            self.check_values(ed_obj, pd_obj)
+
+            if isinstance(ed_obj, (ed.DataFrame, ed.Series)):
+                return SymmetricAPIChecker(ed_obj, pd_obj)
+            return pd_obj
+
+        return f
+
+    def check_values(self, ed_obj, pd_obj):
+        """Checks that any two values coming from eland and pandas are equal"""
+        if isinstance(ed_obj, ed.DataFrame):
+            assert_pandas_eland_frame_equal(pd_obj, ed_obj)
+        elif isinstance(ed_obj, ed.Series):
+            assert_pandas_eland_series_equal(pd_obj, ed_obj)
+        elif isinstance(ed_obj, pd.DataFrame):
+            assert_frame_equal(ed_obj, pd_obj)
+        elif isinstance(ed_obj, pd.Series):
+            assert_series_equal(ed_obj, pd_obj)
+        elif isinstance(ed_obj, pd.Index):
+            assert ed_obj.equals(pd_obj)
+        else:
+            assert ed_obj == pd_obj
+
+    def check_exception(self, ed_exc, pd_exc):
+        """Checks that eland and pandas either both succeeded or both raised the
+        same exception type; re-raises the pandas exception in the latter case.
+        """
+        # type(None) never matches an exception class, so this single comparison
+        # also catches the case where only one of the two sides raised.
+        if type(ed_exc) is not type(pd_exc):
+            raise AssertionError(
+                f"eland raised {ed_exc!r} but pandas raised {pd_exc!r}"
+            ) from (ed_exc if ed_exc is not None else pd_exc)
+        if pd_exc is not None:
+            raise pd_exc
+
+    def __getitem__(self, item):
+        """Indexes both wrapped objects with ``item`` (unwrapping it first when
+        it is itself a SymmetricAPIChecker) and checks that the results agree.
+        """
+        if isinstance(item, SymmetricAPIChecker):
+            pd_item = item.pd
+            ed_item = item.ed
+        else:
+            pd_item = ed_item = item
+
+        ed_exc = None
+        pd_exc = None
+        try:
+            pd_obj = self.pd[pd_item]
+        except Exception as e:
+            pd_exc = e
+        try:
+            ed_obj = self.ed[ed_item]
+        except Exception as e:
+            ed_exc = e
+
+        self.check_exception(ed_exc, pd_exc)
+        # Compare the selected values too, exactly like method calls and
+        # attribute access do; otherwise ``df[...]`` would verify nothing.
+        self.check_values(ed_obj, pd_obj)
+
+        if isinstance(ed_obj, (ed.DataFrame, ed.Series)):
+            return SymmetricAPIChecker(ed_obj, pd_obj)
+        return pd_obj
+
+    def __getattr__(self, item):
+        """Forwards unknown attributes to both wrapped objects: functions and
+        bound methods come back as checking wrappers, others are compared now.
+        """
+        if item == "to_pandas":
+            return self.return_value_checker("to_pandas")
+
+        pd_obj = getattr(self.pd, item)
+        if inspect.isfunction(pd_obj) or inspect.ismethod(pd_obj):
+            return self.return_value_checker(item)
+        ed_obj = getattr(self.ed, item)
+
+        self.check_values(ed_obj, pd_obj)
+
+        if isinstance(ed_obj, (ed.DataFrame, ed.Series)):
+            return SymmetricAPIChecker(ed_obj, pd_obj)
+        return pd_obj
+
+
+@pytest.fixture(scope="function")
+def df():
+    """Fresh checker over the small flights dataset for each test."""
+    return SymmetricAPIChecker(
+        ed_obj=_ed_flights_small, pd_obj=_pd_flights_small.copy()
+    )
diff --git a/eland/tests/dataframe/test_count_pytest.py b/eland/tests/dataframe/test_count_pytest.py
index 9d42171..4936009 100644
--- a/eland/tests/dataframe/test_count_pytest.py
+++ b/eland/tests/dataframe/test_count_pytest.py
@@ -17,20 +17,8 @@
 
 # File called _pytest for PyCharm compatability
 
-from pandas.testing import assert_series_equal
 
-from eland.tests.common import TestData
-
-
-class TestDataFrameCount(TestData):
-    def test_ecommerce_count(self):
-        pd_ecommerce = self.pd_ecommerce()
-        ed_ecommerce = self.ed_ecommerce()
-
-        pd_count = pd_ecommerce.count()
-        ed_count = ed_ecommerce.count()
-
-        print(pd_count)
-        print(ed_count)
-
-        assert_series_equal(pd_count, ed_count)
+class TestDataFrameCount:
+    def test_count(self, df):
+        df.load_dataset("ecommerce")
+        df.count()
diff --git a/eland/tests/dataframe/test_drop_pytest.py b/eland/tests/dataframe/test_drop_pytest.py
index b7c0ab6..6800fb4 100644
--- a/eland/tests/dataframe/test_drop_pytest.py
+++ b/eland/tests/dataframe/test_drop_pytest.py
@@ -17,51 +17,36 @@
 
 # File called _pytest for PyCharm compatability
 
-from eland.tests.common import TestData
-from eland.tests.common import assert_pandas_eland_frame_equal
 
+class TestDataFrameDrop:
+    def test_drop(self, df):
+        df.drop(["Carrier", "DestCityName"], axis=1)
+        df.drop(columns=["Carrier", "DestCityName"])
 
-class TestDataFrameDrop(TestData):
-    def test_flights_small_drop(self):
-        ed_flights_small = self.ed_flights_small()
-        pd_flights_small = self.pd_flights_small()
+        df.drop(["1", "2"])
+        df.drop(["1", "2"], axis=0)
+        df.drop(index=["1", "2"])
 
-        # ['AvgTicketPrice', 'Cancelled', 'Carrier', 'Dest', 'DestAirportID',
-        # 'DestCityName', 'DestCountry', 'DestLocation', 'DestRegion',
-        # 'DestWeather', 'DistanceKilometers', 'DistanceMiles', 'FlightDelay',
-        # 'FlightDelayMin', 'FlightDelayType', 'FlightNum', 'FlightTimeHour',
-        # 'FlightTimeMin', 'Origin', 'OriginAirportID', 'OriginCityName',
-        # 'OriginCountry', 'OriginLocation', 'OriginRegion', 'OriginWeather',
-        # 'dayOfWeek', 'timestamp']
-        pd_col0 = pd_flights_small.drop(["Carrier", "DestCityName"], axis=1)
-        pd_col1 = pd_flights_small.drop(columns=["Carrier", "DestCityName"])
+    def test_drop_all_columns(self, df):
+        all_columns = list(df.columns)
+        rows = df.shape[0]
 
-        ed_col0 = ed_flights_small.drop(["Carrier", "DestCityName"], axis=1)
-        ed_col1 = ed_flights_small.drop(columns=["Carrier", "DestCityName"])
+        for dropped in (
+            df.drop(labels=all_columns, axis=1),
+            df.drop(columns=all_columns),
+            df.drop(all_columns, axis=1),
+        ):
+            assert dropped.shape == (rows, 0)
+            assert list(dropped.columns) == []
 
-        assert_pandas_eland_frame_equal(pd_col0, ed_col0)
-        assert_pandas_eland_frame_equal(pd_col1, ed_col1)
+    def test_drop_all_index(self, df):
+        all_index = list(df.pd.index)
+        cols = df.shape[1]
 
-        # Drop rows by index
-        pd_idx0 = pd_flights_small.drop(["1", "2"])
-        ed_idx0 = ed_flights_small.drop(["1", "2"])
-
-        assert_pandas_eland_frame_equal(pd_idx0, ed_idx0)
-
-    def test_flights_drop_all_columns(self):
-        ed_flights_small = self.ed_flights_small()
-        pd_flights_small = self.pd_flights_small()
-
-        all_columns = ed_flights_small.columns
-
-        pd_col0 = pd_flights_small.drop(labels=all_columns, axis=1)
-        pd_col1 = pd_flights_small.drop(columns=all_columns)
-
-        ed_col0 = ed_flights_small.drop(labels=all_columns, axis=1)
-        ed_col1 = ed_flights_small.drop(columns=all_columns)
-
-        assert_pandas_eland_frame_equal(pd_col0, ed_col0)
-        assert_pandas_eland_frame_equal(pd_col1, ed_col1)
-
-        assert ed_col0.columns.equals(pd_col0.columns)
-        assert ed_col1.columns.equals(pd_col1.columns)
+        for dropped in (
+            df.drop(all_index),
+            df.drop(all_index, axis=0),
+            df.drop(index=all_index),
+        ):
+            assert dropped.shape == (0, cols)
+            assert list(dropped.to_pandas().index) == []
diff --git a/eland/tests/dataframe/test_dtypes_pytest.py b/eland/tests/dataframe/test_dtypes_pytest.py
index 9d56578..6e63495 100644
--- a/eland/tests/dataframe/test_dtypes_pytest.py
+++ b/eland/tests/dataframe/test_dtypes_pytest.py
@@ -18,30 +18,19 @@
 # File called _pytest for PyCharm compatability
 
 import numpy as np
-from pandas.testing import assert_series_equal
-
-from eland.tests.common import TestData
-from eland.tests.common import assert_pandas_eland_frame_equal
 
 
-class TestDataFrameDtypes(TestData):
-    def test_flights_dtypes(self):
-        pd_flights = self.pd_flights()
-        ed_flights = self.ed_flights()
+class TestDataFrameDtypes:
+    def test_dtypes(self, df):
+        print(df.dtypes)
 
-        print(pd_flights.dtypes)
-        print(ed_flights.dtypes)
+        # ``df.dtypes`` above already asserted that both Series are equal; also
+        # check every column pairwise instead of comparing pandas with itself.
+        for pd_dtype, ed_dtype in zip(df.pd.dtypes, df.ed.dtypes):
+            assert isinstance(pd_dtype, type(ed_dtype))
 
-        assert_series_equal(pd_flights.dtypes, ed_flights.dtypes)
-
-        for i in range(0, len(pd_flights.dtypes) - 1):
-            assert isinstance(pd_flights.dtypes[i], type(ed_flights.dtypes[i]))
-
-    def test_flights_select_dtypes(self):
-        pd_flights = self.pd_flights_small()
-        ed_flights = self.ed_flights_small()
-
-        assert_pandas_eland_frame_equal(
-            pd_flights.select_dtypes(include=np.number),
-            ed_flights.select_dtypes(include=np.number),
-        )
+    def test_select_dtypes(self, df):
+        df.select_dtypes(include=np.number)
+        df.select_dtypes(exclude=np.number)
+        df.select_dtypes(include=np.float64)
+        df.select_dtypes(exclude=np.float64)
diff --git a/eland/tests/dataframe/test_filter_pytest.py b/eland/tests/dataframe/test_filter_pytest.py
index dded02d..e6ec8c8 100644
--- a/eland/tests/dataframe/test_filter_pytest.py
+++ b/eland/tests/dataframe/test_filter_pytest.py
@@ -19,23 +19,20 @@
 
 import pytest
 from eland.tests.common import TestData
-from eland.tests.common import assert_pandas_eland_frame_equal
 
 
 class TestDataFrameFilter(TestData):
-    def test_filter_arguments_mutually_exclusive(self):
-        ed_flights_small = self.ed_flights_small()
-
+    def test_filter_arguments_mutually_exclusive(self, df):
         with pytest.raises(TypeError):
-            ed_flights_small.filter(items=[], like="!", regex="!")
+            df.filter(items=[], like="!", regex="!")
         with pytest.raises(TypeError):
-            ed_flights_small.filter(items=[], regex="!")
+            df.filter(items=[], regex="!")
         with pytest.raises(TypeError):
-            ed_flights_small.filter(items=[], like="!")
+            df.filter(items=[], like="!")
         with pytest.raises(TypeError):
-            ed_flights_small.filter(like="!", regex="!")
+            df.filter(like="!", regex="!")
         with pytest.raises(TypeError):
-            ed_flights_small.filter()
+            df.filter()
 
     @pytest.mark.parametrize(
         "items",
@@ -45,46 +42,22 @@ class TestDataFrameFilter(TestData):
             ["notfound", "AvgTicketPrice"],
         ],
     )
-    def test_flights_filter_columns_items(self, items):
-        ed_flights_small = self.ed_flights_small()
-        pd_flights_small = self.pd_flights_small()
-
-        ed_df = ed_flights_small.filter(items=items)
-        pd_df = pd_flights_small.filter(items=items)
-
-        assert_pandas_eland_frame_equal(pd_df, ed_df)
+    def test_filter_columns_items(self, df, items):
+        df.filter(items=items)
 
     @pytest.mark.parametrize("like", ["Flight", "Nope"])
-    def test_flights_filter_columns_like(self, like):
-        ed_flights_small = self.ed_flights_small()
-        pd_flights_small = self.pd_flights_small()
-
-        ed_df = ed_flights_small.filter(like=like)
-        pd_df = pd_flights_small.filter(like=like)
-
-        assert_pandas_eland_frame_equal(pd_df, ed_df)
+    def test_filter_columns_like(self, df, like):
+        df.filter(like=like)
 
     @pytest.mark.parametrize("regex", ["^Flig", "^Flight.*r$", ".*", "^[^C]"])
-    def test_flights_filter_columns_regex(self, regex):
-        ed_flights_small = self.ed_flights_small()
-        pd_flights_small = self.pd_flights_small()
-
-        ed_df = ed_flights_small.filter(regex=regex)
-        pd_df = pd_flights_small.filter(regex=regex)
-
-        assert_pandas_eland_frame_equal(pd_df, ed_df)
+    def test_filter_columns_regex(self, df, regex):
+        df.filter(regex=regex)
 
     @pytest.mark.parametrize("items", [[], ["20"], [str(x) for x in range(30)]])
-    def test_flights_filter_index_items(self, items):
-        ed_flights_small = self.ed_flights_small()
-        pd_flights_small = self.pd_flights_small()
+    def test_filter_index_items(self, df, items):
+        df.filter(items=items, axis=0)
 
-        ed_df = ed_flights_small.filter(items=items, axis=0)
-        pd_df = pd_flights_small.filter(items=items, axis=0)
-
-        assert_pandas_eland_frame_equal(pd_df, ed_df)
-
-    def test_flights_filter_index_like_and_regex(self):
+    def test_filter_index_like_and_regex(self):
         ed_flights_small = self.ed_flights_small()
 
         with pytest.raises(NotImplementedError):
diff --git a/eland/tests/dataframe/test_get_pytest.py b/eland/tests/dataframe/test_get_pytest.py
index f12509a..75ea0fe 100644
--- a/eland/tests/dataframe/test_get_pytest.py
+++ b/eland/tests/dataframe/test_get_pytest.py
@@ -17,16 +17,7 @@
 
 # File called _pytest for PyCharm compatability
 
-from eland.tests.common import TestData
 
-
-class TestDataFrameGet(TestData):
-    def test_get_one_attribute(self):
-        ed_flights = self.ed_flights()
-        pd_flights = self.pd_flights()
-
-        ed_get0 = ed_flights.get("Carrier")
-        pd_get0 = pd_flights.get("Carrier")
-
-        print(ed_get0, type(ed_get0))
-        print(pd_get0, type(pd_get0))
+class TestDataFrameGet:
+    def test_get_one_attribute(self, df):
+        df.get("Carrier")
diff --git a/eland/tests/dataframe/test_getitem_pytest.py b/eland/tests/dataframe/test_getitem_pytest.py
index fdbf3a1..c9a052d 100644
--- a/eland/tests/dataframe/test_getitem_pytest.py
+++ b/eland/tests/dataframe/test_getitem_pytest.py
@@ -17,58 +17,23 @@
 
 # File called _pytest for PyCharm compatability
 
-from eland.tests.common import TestData
-from eland.tests.common import (
-    assert_pandas_eland_frame_equal,
-    assert_pandas_eland_series_equal,
-)
+import pytest
 
 
-class TestDataFrameGetItem(TestData):
-    def test_getitem_one_attribute(self):
-        ed_flights = self.ed_flights().head(103)
-        pd_flights = self.pd_flights().head(103)
+class TestDataFrameGetItem:
+    def test_getitem_one_attribute(self, df):
+        df.load_dataset("flights")
+        print(df.head(103)["OriginAirportID"])
 
-        ed_flights_OriginAirportID = ed_flights["OriginAirportID"]
-        pd_flights_OriginAirportID = pd_flights["OriginAirportID"]
+    def test_getitem_attribute_list(self, df):
+        print(df[["OriginAirportID", "AvgTicketPrice", "Carrier"]])
 
-        assert_pandas_eland_series_equal(
-            pd_flights_OriginAirportID, ed_flights_OriginAirportID
-        )
+    def test_getitem_one_argument(self, df):
+        print(df.OriginAirportID)
 
-    def test_getitem_attribute_list(self):
-        ed_flights = self.ed_flights().head(42)
-        pd_flights = self.pd_flights().head(42)
+    def test_getitem_multiple_calls(self, df):
+        df = df[["DestCityName", "DestCountry", "DestLocation", "DestRegion"]]
+        with pytest.raises(KeyError):
+            df["Carrier"]
 
-        ed_flights_slice = ed_flights[["OriginAirportID", "AvgTicketPrice", "Carrier"]]
-        pd_flights_slice = pd_flights[["OriginAirportID", "AvgTicketPrice", "Carrier"]]
-
-        assert_pandas_eland_frame_equal(pd_flights_slice, ed_flights_slice)
-
-    def test_getitem_one_argument(self):
-        ed_flights = self.ed_flights().head(89)
-        pd_flights = self.pd_flights().head(89)
-
-        ed_flights_OriginAirportID = ed_flights.OriginAirportID
-        pd_flights_OriginAirportID = pd_flights.OriginAirportID
-
-        assert_pandas_eland_series_equal(
-            pd_flights_OriginAirportID, ed_flights_OriginAirportID
-        )
-
-    def test_getitem_multiple_calls(self):
-        ed_flights = self.ed_flights().head(89)
-        pd_flights = self.pd_flights().head(89)
-
-        ed_col0 = ed_flights[
-            ["DestCityName", "DestCountry", "DestLocation", "DestRegion"]
-        ]
-        try:
-            ed_col1 = ed_col0["Carrier"]
-        except KeyError:
-            pass
-
-        pd_col1 = pd_flights["DestCountry"]
-        ed_col1 = ed_col0["DestCountry"]
-
-        assert_pandas_eland_series_equal(pd_col1, ed_col1)
+        df["DestCountry"]