Allow using datetime types in filters

This commit is contained in:
Florian Winkler 2022-01-04 21:46:18 +01:00 committed by GitHub
parent c14bc24032
commit 3db93cd789
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 80 additions and 6 deletions

View File

@ -34,6 +34,7 @@ Based on NDFrame which underpins eland.DataFrame
import sys
import warnings
from collections.abc import Collection
from datetime import datetime
from io import StringIO
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
@ -461,51 +462,67 @@ class Series(NDFrame):
return self._query_compiler.es_dtypes[0]
def __gt__(
    self, other: Union[int, float, datetime, np.datetime64, "Series"]
) -> BooleanFilter:
    """Build a filter selecting rows where this series is greater than ``other``.

    Parameters
    ----------
    other:
        A number, a ``datetime``, a ``numpy.datetime64`` or another ``Series``.

    Returns
    -------
    BooleanFilter
        A painless ``ScriptFilter`` when ``other`` is a ``Series``, otherwise
        a ``Greater`` filter on this series' field.

    Raises
    ------
    NotImplementedError
        If ``other`` is of any other type.
    """
    if isinstance(other, np.datetime64):
        # numpy datetime64 objects have no `strftime` method, so convert
        # them with pandas first; the result is matched by the `datetime`
        # check below
        other = pd.to_datetime(other)

    if isinstance(other, Series):
        # Need to use scripted query to compare to values
        painless = f"doc['{self.name}'].value > doc['{other.name}'].value"
        return ScriptFilter(painless, lang="painless")
    elif isinstance(other, (int, float, datetime)):
        return Greater(field=self.name, value=other)
    else:
        raise NotImplementedError(other, type(other))
def __lt__(
    self, other: Union[int, float, datetime, np.datetime64, "Series"]
) -> BooleanFilter:
    """Build a filter selecting rows where this series is less than ``other``.

    Parameters
    ----------
    other:
        A number, a ``datetime``, a ``numpy.datetime64`` or another ``Series``.

    Returns
    -------
    BooleanFilter
        A painless ``ScriptFilter`` when ``other`` is a ``Series``, otherwise
        a ``Less`` filter on this series' field.

    Raises
    ------
    NotImplementedError
        If ``other`` is of any other type.
    """
    if isinstance(other, np.datetime64):
        # numpy datetime64 objects have no `strftime` method, so convert
        # them with pandas first; the result is matched by the `datetime`
        # check below
        other = pd.to_datetime(other)

    if isinstance(other, Series):
        # Need to use scripted query to compare to values
        painless = f"doc['{self.name}'].value < doc['{other.name}'].value"
        return ScriptFilter(painless, lang="painless")
    elif isinstance(other, (int, float, datetime)):
        return Less(field=self.name, value=other)
    else:
        raise NotImplementedError(other, type(other))
def __ge__(
    self, other: Union[int, float, datetime, np.datetime64, "Series"]
) -> BooleanFilter:
    """Build a filter selecting rows where this series is >= ``other``.

    Parameters
    ----------
    other:
        A number, a ``datetime``, a ``numpy.datetime64`` or another ``Series``.

    Returns
    -------
    BooleanFilter
        A painless ``ScriptFilter`` when ``other`` is a ``Series``, otherwise
        a ``GreaterEqual`` filter on this series' field.

    Raises
    ------
    NotImplementedError
        If ``other`` is of any other type.
    """
    if isinstance(other, np.datetime64):
        # numpy datetime64 objects have no `strftime` method, so convert
        # them with pandas first; the result is matched by the `datetime`
        # check below
        other = pd.to_datetime(other)

    if isinstance(other, Series):
        # Need to use scripted query to compare to values
        painless = f"doc['{self.name}'].value >= doc['{other.name}'].value"
        return ScriptFilter(painless, lang="painless")
    elif isinstance(other, (int, float, datetime)):
        return GreaterEqual(field=self.name, value=other)
    else:
        raise NotImplementedError(other, type(other))
def __le__(
    self, other: Union[int, float, datetime, np.datetime64, "Series"]
) -> BooleanFilter:
    """Build a filter selecting rows where this series is <= ``other``.

    Parameters
    ----------
    other:
        A number, a ``datetime``, a ``numpy.datetime64`` or another ``Series``.

    Returns
    -------
    BooleanFilter
        A painless ``ScriptFilter`` when ``other`` is a ``Series``, otherwise
        a ``LessEqual`` filter on this series' field.

    Raises
    ------
    NotImplementedError
        If ``other`` is of any other type.
    """
    if isinstance(other, np.datetime64):
        # numpy datetime64 objects have no `strftime` method, so convert
        # them with pandas first; the result is matched by the `datetime`
        # check below
        other = pd.to_datetime(other)

    if isinstance(other, Series):
        # Need to use scripted query to compare to values
        painless = f"doc['{self.name}'].value <= doc['{other.name}'].value"
        return ScriptFilter(painless, lang="painless")
    elif isinstance(other, (int, float, datetime)):
        return LessEqual(field=self.name, value=other)
    else:
        raise NotImplementedError(other, type(other))
def __eq__(self, other: Union[int, float, str, "Series"]) -> BooleanFilter:
if isinstance(other, np.datetime64):
other = pd.to_datetime(other)
if isinstance(other, Series):
# Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value == doc['{other.name}'].value"
return ScriptFilter(painless, lang="painless")
elif isinstance(other, (int, float)):
elif isinstance(other, (int, float, datetime)):
return Equal(field=self.name, value=other)
elif isinstance(other, str):
return Equal(field=self.name, value=other)
@ -513,11 +530,14 @@ class Series(NDFrame):
raise NotImplementedError(other, type(other))
def __ne__(self, other: Union[int, float, str, "Series"]) -> BooleanFilter:
if isinstance(other, np.datetime64):
other = pd.to_datetime(other)
if isinstance(other, Series):
# Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value != doc['{other.name}'].value"
return ScriptFilter(painless, lang="painless")
elif isinstance(other, (int, float)):
elif isinstance(other, (int, float, datetime)):
return NotFilter(Equal(field=self.name, value=other))
elif isinstance(other, str):
return NotFilter(Equal(field=self.name, value=other))

View File

@ -15,14 +15,68 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
# File called _pytest for PyCharm compatibility
import numpy as np
import pytest
from eland import Series
from tests.common import TestData, assert_pandas_eland_series_equal
class TestSeriesArithmetics(TestData):
def test_ecommerce_datetime_comparisons(self):
    """Check that datetime comparisons on ``order_date`` filter the pandas
    and eland ecommerce frames to equal series."""
    pandas_frame = self.pd_ecommerce()
    eland_frame = self.ed_ecommerce()

    # timezone naive datetime used as the right-hand operand of every comparison
    cutoff = datetime(2016, 12, 18)

    # FIXME: the timezone handling below is only a temporary workaround so
    # that the datetime comparison tests can run.
    #
    # The problem:
    # - the datetime values of the pandas DataFrame are timezone aware and
    #   cannot be compared with timezone naive datetime objects
    # - the datetime values of the eland DataFrame are timezone naive (which
    #   should be fixed)
    # - once the eland DataFrame is converted with `to_pandas`, however, the
    #   datetime values become timezone aware
    #
    # The workaround: strip the timezone from the pandas column up front, and
    # re-class the eland result so that its `to_pandas` (which is executed by
    # `assert_pandas_eland_series_equal`) strips the timezone as well.
    pandas_frame["order_date"] = pandas_frame["order_date"].dt.tz_localize(None)

    class ModifiedElandSeries(Series):
        def to_pandas(self):
            """remove timezone awareness before returning the pandas dataframe"""
            return super().to_pandas().dt.tz_localize(None)

    for dunder in ["__le__", "__lt__", "__gt__", "__ge__", "__eq__", "__ne__"]:
        pandas_mask = getattr(pandas_frame["order_date"], dunder)(cutoff)
        expected = pandas_frame[pandas_mask]["order_date"]

        eland_mask = getattr(eland_frame["order_date"], dunder)(cutoff)
        actual = eland_frame[eland_mask]["order_date"]

        # swap in the subclass (it inherits from ed.Series) so that the
        # overridden `to_pandas` is the one used during the comparison
        actual.__class__ = ModifiedElandSeries

        assert_pandas_eland_series_equal(expected, actual, check_less_precise=True)
def test_ecommerce_series_invalid_div(self):
pd_df = self.pd_ecommerce()
ed_df = self.ed_ecommerce()