Allow using datetime types in filters

This commit is contained in:
Florian Winkler 2022-01-04 21:46:18 +01:00 committed by GitHub
parent c14bc24032
commit 3db93cd789
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 80 additions and 6 deletions

View File

@ -34,6 +34,7 @@ Based on NDFrame which underpins eland.DataFrame
import sys import sys
import warnings import warnings
from collections.abc import Collection from collections.abc import Collection
from datetime import datetime
from io import StringIO from io import StringIO
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
@ -461,51 +462,67 @@ class Series(NDFrame):
return self._query_compiler.es_dtypes[0] return self._query_compiler.es_dtypes[0]
def __gt__(self, other: Union[int, float, "Series"]) -> BooleanFilter: def __gt__(self, other: Union[int, float, "Series"]) -> BooleanFilter:
if isinstance(other, np.datetime64):
# convert numpy datetime64 object it has no `strftime` method
other = pd.to_datetime(other)
if isinstance(other, Series): if isinstance(other, Series):
# Need to use scripted query to compare to values # Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value > doc['{other.name}'].value" painless = f"doc['{self.name}'].value > doc['{other.name}'].value"
return ScriptFilter(painless, lang="painless") return ScriptFilter(painless, lang="painless")
elif isinstance(other, (int, float)): elif isinstance(other, (int, float, datetime)):
return Greater(field=self.name, value=other) return Greater(field=self.name, value=other)
else: else:
raise NotImplementedError(other, type(other)) raise NotImplementedError(other, type(other))
def __lt__(self, other: Union[int, float, "Series"]) -> BooleanFilter: def __lt__(self, other: Union[int, float, "Series"]) -> BooleanFilter:
if isinstance(other, np.datetime64):
other = pd.to_datetime(other)
if isinstance(other, Series): if isinstance(other, Series):
# Need to use scripted query to compare to values # Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value < doc['{other.name}'].value" painless = f"doc['{self.name}'].value < doc['{other.name}'].value"
return ScriptFilter(painless, lang="painless") return ScriptFilter(painless, lang="painless")
elif isinstance(other, (int, float)): elif isinstance(other, (int, float, datetime)):
return Less(field=self.name, value=other) return Less(field=self.name, value=other)
else: else:
raise NotImplementedError(other, type(other)) raise NotImplementedError(other, type(other))
def __ge__(self, other: Union[int, float, "Series"]) -> BooleanFilter: def __ge__(self, other: Union[int, float, "Series"]) -> BooleanFilter:
if isinstance(other, np.datetime64):
other = pd.to_datetime(other)
if isinstance(other, Series): if isinstance(other, Series):
# Need to use scripted query to compare to values # Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value >= doc['{other.name}'].value" painless = f"doc['{self.name}'].value >= doc['{other.name}'].value"
return ScriptFilter(painless, lang="painless") return ScriptFilter(painless, lang="painless")
elif isinstance(other, (int, float)): elif isinstance(other, (int, float, datetime)):
return GreaterEqual(field=self.name, value=other) return GreaterEqual(field=self.name, value=other)
else: else:
raise NotImplementedError(other, type(other)) raise NotImplementedError(other, type(other))
def __le__(self, other: Union[int, float, "Series"]) -> BooleanFilter: def __le__(self, other: Union[int, float, "Series"]) -> BooleanFilter:
if isinstance(other, np.datetime64):
other = pd.to_datetime(other)
if isinstance(other, Series): if isinstance(other, Series):
# Need to use scripted query to compare to values # Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value <= doc['{other.name}'].value" painless = f"doc['{self.name}'].value <= doc['{other.name}'].value"
return ScriptFilter(painless, lang="painless") return ScriptFilter(painless, lang="painless")
elif isinstance(other, (int, float)): elif isinstance(other, (int, float, datetime)):
return LessEqual(field=self.name, value=other) return LessEqual(field=self.name, value=other)
else: else:
raise NotImplementedError(other, type(other)) raise NotImplementedError(other, type(other))
def __eq__(self, other: Union[int, float, str, "Series"]) -> BooleanFilter: def __eq__(self, other: Union[int, float, str, "Series"]) -> BooleanFilter:
if isinstance(other, np.datetime64):
other = pd.to_datetime(other)
if isinstance(other, Series): if isinstance(other, Series):
# Need to use scripted query to compare to values # Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value == doc['{other.name}'].value" painless = f"doc['{self.name}'].value == doc['{other.name}'].value"
return ScriptFilter(painless, lang="painless") return ScriptFilter(painless, lang="painless")
elif isinstance(other, (int, float)): elif isinstance(other, (int, float, datetime)):
return Equal(field=self.name, value=other) return Equal(field=self.name, value=other)
elif isinstance(other, str): elif isinstance(other, str):
return Equal(field=self.name, value=other) return Equal(field=self.name, value=other)
@ -513,11 +530,14 @@ class Series(NDFrame):
raise NotImplementedError(other, type(other)) raise NotImplementedError(other, type(other))
def __ne__(self, other: Union[int, float, str, "Series"]) -> BooleanFilter: def __ne__(self, other: Union[int, float, str, "Series"]) -> BooleanFilter:
if isinstance(other, np.datetime64):
other = pd.to_datetime(other)
if isinstance(other, Series): if isinstance(other, Series):
# Need to use scripted query to compare to values # Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value != doc['{other.name}'].value" painless = f"doc['{self.name}'].value != doc['{other.name}'].value"
return ScriptFilter(painless, lang="painless") return ScriptFilter(painless, lang="painless")
elif isinstance(other, (int, float)): elif isinstance(other, (int, float, datetime)):
return NotFilter(Equal(field=self.name, value=other)) return NotFilter(Equal(field=self.name, value=other))
elif isinstance(other, str): elif isinstance(other, str):
return NotFilter(Equal(field=self.name, value=other)) return NotFilter(Equal(field=self.name, value=other))

View File

@ -15,14 +15,68 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from datetime import datetime
# File called _pytest for PyCharm compatability # File called _pytest for PyCharm compatability
import numpy as np import numpy as np
import pytest import pytest
from eland import Series
from tests.common import TestData, assert_pandas_eland_series_equal from tests.common import TestData, assert_pandas_eland_series_equal
class TestSeriesArithmetics(TestData): class TestSeriesArithmetics(TestData):
def test_ecommerce_datetime_comparisons(self):
pd_df = self.pd_ecommerce()
ed_df = self.ed_ecommerce()
ops = ["__le__", "__lt__", "__gt__", "__ge__", "__eq__", "__ne__"]
# this datetime object is timezone naive
datetime_obj = datetime(2016, 12, 18)
# FIXME: the following timezone conversions are just a temporary fix
# to run the datetime comparison tests
#
# The problem:
# - the datetime objects of the pandas DataFrame are timezone aware and
# can't be compared with timezone naive datetime objects
# - the datetime objects of the eland DataFrame are timezone naive (which
# should be fixed)
# - however if the eland DataFrame is converted to a pandas DataFrame
# (using the `to_pandas` function) the datetime objects become timezone aware
#
# This tests converts the datetime objects of the pandas Series to
# timezone naive ones and utilizes a class to make the datetime objects of the
# eland Series timezone naive before the result of `to_pandas` is returned.
# The `to_pandas` function is executed by the `assert_pandas_eland_series_equal`
# function, which compares the eland and pandas Series
# convert to timezone naive datetime object
pd_df["order_date"] = pd_df["order_date"].dt.tz_localize(None)
class ModifiedElandSeries(Series):
def to_pandas(self):
"""remove timezone awareness before returning the pandas dataframe"""
series = super().to_pandas()
series = series.dt.tz_localize(None)
return series
for op in ops:
pd_series = pd_df[getattr(pd_df["order_date"], op)(datetime_obj)][
"order_date"
]
ed_series = ed_df[getattr(ed_df["order_date"], op)(datetime_obj)][
"order_date"
]
# "type cast" to modified class (inherits from ed.Series) that overrides the `to_pandas` function
ed_series.__class__ = ModifiedElandSeries
assert_pandas_eland_series_equal(
pd_series, ed_series, check_less_precise=True
)
def test_ecommerce_series_invalid_div(self): def test_ecommerce_series_invalid_div(self):
pd_df = self.pd_ecommerce() pd_df = self.pd_ecommerce()
ed_df = self.ed_ecommerce() ed_df = self.ed_ecommerce()