From 3db93cd789dcc1f47e4a66574217d7fff9d2f688 Mon Sep 17 00:00:00 2001 From: Florian Winkler Date: Tue, 4 Jan 2022 21:46:18 +0100 Subject: [PATCH] Allow using datetime types in filters --- eland/series.py | 32 ++++++++++++--- tests/series/test_arithmetics_pytest.py | 54 +++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/eland/series.py b/eland/series.py index 4a02b79..2d7f6b4 100644 --- a/eland/series.py +++ b/eland/series.py @@ -34,6 +34,7 @@ Based on NDFrame which underpins eland.DataFrame import sys import warnings from collections.abc import Collection +from datetime import datetime from io import StringIO from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union @@ -461,51 +462,67 @@ class Series(NDFrame): return self._query_compiler.es_dtypes[0] def __gt__(self, other: Union[int, float, "Series"]) -> BooleanFilter: + if isinstance(other, np.datetime64): + # convert the numpy datetime64 object because it has no `strftime` method + other = pd.to_datetime(other) + if isinstance(other, Series): # Need to use scripted query to compare to values painless = f"doc['{self.name}'].value > doc['{other.name}'].value" return ScriptFilter(painless, lang="painless") - elif isinstance(other, (int, float)): + elif isinstance(other, (int, float, datetime)): return Greater(field=self.name, value=other) else: raise NotImplementedError(other, type(other)) def __lt__(self, other: Union[int, float, "Series"]) -> BooleanFilter: + if isinstance(other, np.datetime64): + other = pd.to_datetime(other) + if isinstance(other, Series): # Need to use scripted query to compare to values painless = f"doc['{self.name}'].value < doc['{other.name}'].value" return ScriptFilter(painless, lang="painless") - elif isinstance(other, (int, float)): + elif isinstance(other, (int, float, datetime)): return Less(field=self.name, value=other) else: raise NotImplementedError(other, type(other)) def __ge__(self, other: Union[int, float, "Series"]) -> 
BooleanFilter: + if isinstance(other, np.datetime64): + other = pd.to_datetime(other) + if isinstance(other, Series): # Need to use scripted query to compare to values painless = f"doc['{self.name}'].value >= doc['{other.name}'].value" return ScriptFilter(painless, lang="painless") - elif isinstance(other, (int, float)): + elif isinstance(other, (int, float, datetime)): return GreaterEqual(field=self.name, value=other) else: raise NotImplementedError(other, type(other)) def __le__(self, other: Union[int, float, "Series"]) -> BooleanFilter: + if isinstance(other, np.datetime64): + other = pd.to_datetime(other) + if isinstance(other, Series): # Need to use scripted query to compare to values painless = f"doc['{self.name}'].value <= doc['{other.name}'].value" return ScriptFilter(painless, lang="painless") - elif isinstance(other, (int, float)): + elif isinstance(other, (int, float, datetime)): return LessEqual(field=self.name, value=other) else: raise NotImplementedError(other, type(other)) def __eq__(self, other: Union[int, float, str, "Series"]) -> BooleanFilter: + if isinstance(other, np.datetime64): + other = pd.to_datetime(other) + if isinstance(other, Series): # Need to use scripted query to compare to values painless = f"doc['{self.name}'].value == doc['{other.name}'].value" return ScriptFilter(painless, lang="painless") - elif isinstance(other, (int, float)): + elif isinstance(other, (int, float, datetime)): return Equal(field=self.name, value=other) elif isinstance(other, str): return Equal(field=self.name, value=other) @@ -513,11 +530,14 @@ class Series(NDFrame): raise NotImplementedError(other, type(other)) def __ne__(self, other: Union[int, float, str, "Series"]) -> BooleanFilter: + if isinstance(other, np.datetime64): + other = pd.to_datetime(other) + if isinstance(other, Series): # Need to use scripted query to compare to values painless = f"doc['{self.name}'].value != doc['{other.name}'].value" return ScriptFilter(painless, lang="painless") - elif 
isinstance(other, (int, float)): + elif isinstance(other, (int, float, datetime)): return NotFilter(Equal(field=self.name, value=other)) elif isinstance(other, str): return NotFilter(Equal(field=self.name, value=other)) diff --git a/tests/series/test_arithmetics_pytest.py b/tests/series/test_arithmetics_pytest.py index 9f118e8..4a6251a 100644 --- a/tests/series/test_arithmetics_pytest.py +++ b/tests/series/test_arithmetics_pytest.py @@ -15,14 +15,68 @@ # specific language governing permissions and limitations # under the License. +from datetime import datetime + # File called _pytest for PyCharm compatability import numpy as np import pytest +from eland import Series from tests.common import TestData, assert_pandas_eland_series_equal class TestSeriesArithmetics(TestData): + def test_ecommerce_datetime_comparisons(self): + pd_df = self.pd_ecommerce() + ed_df = self.ed_ecommerce() + + ops = ["__le__", "__lt__", "__gt__", "__ge__", "__eq__", "__ne__"] + + # this datetime object is timezone naive + datetime_obj = datetime(2016, 12, 18) + + # FIXME: the following timezone conversions are just a temporary fix + # to run the datetime comparison tests + # + # The problem: + # - the datetime objects of the pandas DataFrame are timezone aware and + # can't be compared with timezone naive datetime objects + # - the datetime objects of the eland DataFrame are timezone naive (which + # should be fixed) + # - however if the eland DataFrame is converted to a pandas DataFrame + # (using the `to_pandas` function) the datetime objects become timezone aware + # + # This test converts the datetime objects of the pandas Series to + # timezone naive ones and utilizes a class to make the datetime objects of the + # eland Series timezone naive before the result of `to_pandas` is returned. 
+ # The `to_pandas` function is executed by the `assert_pandas_eland_series_equal` + # function, which compares the eland and pandas Series + + # convert to timezone naive datetime object + pd_df["order_date"] = pd_df["order_date"].dt.tz_localize(None) + + class ModifiedElandSeries(Series): + def to_pandas(self): + """remove timezone awareness before returning the pandas dataframe""" + series = super().to_pandas() + series = series.dt.tz_localize(None) + return series + + for op in ops: + pd_series = pd_df[getattr(pd_df["order_date"], op)(datetime_obj)][ + "order_date" + ] + ed_series = ed_df[getattr(ed_df["order_date"], op)(datetime_obj)][ + "order_date" + ] + + # "type cast" to modified class (inherits from ed.Series) that overrides the `to_pandas` function + ed_series.__class__ = ModifiedElandSeries + + assert_pandas_eland_series_equal( + pd_series, ed_series, check_less_precise=True + ) + def test_ecommerce_series_invalid_div(self): pd_df = self.pd_ecommerce() ed_df = self.ed_ecommerce()