mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Allow using datetime types in filters
This commit is contained in:
parent
c14bc24032
commit
3db93cd789
@ -34,6 +34,7 @@ Based on NDFrame which underpins eland.DataFrame
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Collection
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
@ -461,51 +462,67 @@ class Series(NDFrame):
|
||||
return self._query_compiler.es_dtypes[0]
|
||||
|
||||
def __gt__(
    self, other: Union[int, float, datetime, np.datetime64, "Series"]
) -> BooleanFilter:
    """Build a filter matching rows where this series is greater than ``other``.

    Parameters
    ----------
    other:
        An ``int``, ``float`` or ``datetime`` value, a ``numpy.datetime64``
        (converted with ``pd.to_datetime`` first), or another ``Series``.

    Returns
    -------
    BooleanFilter
        A ``ScriptFilter`` comparing the two fields when ``other`` is a
        ``Series``, otherwise a ``Greater`` filter on this series' field.

    Raises
    ------
    NotImplementedError
        If ``other`` is of any other type.
    """
    if isinstance(other, np.datetime64):
        # convert the numpy datetime64 object, since it has no `strftime` method
        other = pd.to_datetime(other)

    if isinstance(other, Series):
        # Need to use scripted query to compare to values
        painless = f"doc['{self.name}'].value > doc['{other.name}'].value"
        return ScriptFilter(painless, lang="painless")
    elif isinstance(other, (int, float, datetime)):
        return Greater(field=self.name, value=other)
    else:
        raise NotImplementedError(other, type(other))
|
||||
|
||||
def __lt__(
    self, other: Union[int, float, datetime, np.datetime64, "Series"]
) -> BooleanFilter:
    """Build a filter matching rows where this series is less than ``other``.

    Parameters
    ----------
    other:
        An ``int``, ``float`` or ``datetime`` value, a ``numpy.datetime64``
        (converted with ``pd.to_datetime`` first), or another ``Series``.

    Returns
    -------
    BooleanFilter
        A ``ScriptFilter`` comparing the two fields when ``other`` is a
        ``Series``, otherwise a ``Less`` filter on this series' field.

    Raises
    ------
    NotImplementedError
        If ``other`` is of any other type.
    """
    if isinstance(other, np.datetime64):
        # normalise numpy datetime64 so it is handled by the datetime branch below
        other = pd.to_datetime(other)

    if isinstance(other, Series):
        # Need to use scripted query to compare to values
        painless = f"doc['{self.name}'].value < doc['{other.name}'].value"
        return ScriptFilter(painless, lang="painless")
    elif isinstance(other, (int, float, datetime)):
        return Less(field=self.name, value=other)
    else:
        raise NotImplementedError(other, type(other))
|
||||
|
||||
def __ge__(
    self, other: Union[int, float, datetime, np.datetime64, "Series"]
) -> BooleanFilter:
    """Build a filter matching rows where this series is >= ``other``.

    Parameters
    ----------
    other:
        An ``int``, ``float`` or ``datetime`` value, a ``numpy.datetime64``
        (converted with ``pd.to_datetime`` first), or another ``Series``.

    Returns
    -------
    BooleanFilter
        A ``ScriptFilter`` comparing the two fields when ``other`` is a
        ``Series``, otherwise a ``GreaterEqual`` filter on this series' field.

    Raises
    ------
    NotImplementedError
        If ``other`` is of any other type.
    """
    if isinstance(other, np.datetime64):
        # normalise numpy datetime64 so it is handled by the datetime branch below
        other = pd.to_datetime(other)

    if isinstance(other, Series):
        # Need to use scripted query to compare to values
        painless = f"doc['{self.name}'].value >= doc['{other.name}'].value"
        return ScriptFilter(painless, lang="painless")
    elif isinstance(other, (int, float, datetime)):
        return GreaterEqual(field=self.name, value=other)
    else:
        raise NotImplementedError(other, type(other))
|
||||
|
||||
def __le__(
    self, other: Union[int, float, datetime, np.datetime64, "Series"]
) -> BooleanFilter:
    """Build a filter matching rows where this series is <= ``other``.

    Parameters
    ----------
    other:
        An ``int``, ``float`` or ``datetime`` value, a ``numpy.datetime64``
        (converted with ``pd.to_datetime`` first), or another ``Series``.

    Returns
    -------
    BooleanFilter
        A ``ScriptFilter`` comparing the two fields when ``other`` is a
        ``Series``, otherwise a ``LessEqual`` filter on this series' field.

    Raises
    ------
    NotImplementedError
        If ``other`` is of any other type.
    """
    if isinstance(other, np.datetime64):
        # normalise numpy datetime64 so it is handled by the datetime branch below
        other = pd.to_datetime(other)

    if isinstance(other, Series):
        # Need to use scripted query to compare to values
        painless = f"doc['{self.name}'].value <= doc['{other.name}'].value"
        return ScriptFilter(painless, lang="painless")
    elif isinstance(other, (int, float, datetime)):
        return LessEqual(field=self.name, value=other)
    else:
        raise NotImplementedError(other, type(other))
|
||||
|
||||
def __eq__(self, other: Union[int, float, str, "Series"]) -> BooleanFilter:
|
||||
if isinstance(other, np.datetime64):
|
||||
other = pd.to_datetime(other)
|
||||
|
||||
if isinstance(other, Series):
|
||||
# Need to use scripted query to compare to values
|
||||
painless = f"doc['{self.name}'].value == doc['{other.name}'].value"
|
||||
return ScriptFilter(painless, lang="painless")
|
||||
elif isinstance(other, (int, float)):
|
||||
elif isinstance(other, (int, float, datetime)):
|
||||
return Equal(field=self.name, value=other)
|
||||
elif isinstance(other, str):
|
||||
return Equal(field=self.name, value=other)
|
||||
@ -513,11 +530,14 @@ class Series(NDFrame):
|
||||
raise NotImplementedError(other, type(other))
|
||||
|
||||
def __ne__(self, other: Union[int, float, str, "Series"]) -> BooleanFilter:
|
||||
if isinstance(other, np.datetime64):
|
||||
other = pd.to_datetime(other)
|
||||
|
||||
if isinstance(other, Series):
|
||||
# Need to use scripted query to compare to values
|
||||
painless = f"doc['{self.name}'].value != doc['{other.name}'].value"
|
||||
return ScriptFilter(painless, lang="painless")
|
||||
elif isinstance(other, (int, float)):
|
||||
elif isinstance(other, (int, float, datetime)):
|
||||
return NotFilter(Equal(field=self.name, value=other))
|
||||
elif isinstance(other, str):
|
||||
return NotFilter(Equal(field=self.name, value=other))
|
||||
|
@ -15,14 +15,68 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
# File called _pytest for PyCharm compatibility
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from eland import Series
|
||||
from tests.common import TestData, assert_pandas_eland_series_equal
|
||||
|
||||
|
||||
class TestSeriesArithmetics(TestData):
|
||||
def test_ecommerce_datetime_comparisons(self):
    """Compare eland and pandas results when filtering `order_date` by a datetime.

    Each rich-comparison operator in `ops` is applied to the `order_date`
    column of both DataFrames with the same timezone-naive datetime, and the
    filtered `order_date` Series are asserted to be equal.
    """
    pd_df = self.pd_ecommerce()
    ed_df = self.ed_ecommerce()

    ops = ["__le__", "__lt__", "__gt__", "__ge__", "__eq__", "__ne__"]

    # this datetime object is timezone naive
    datetime_obj = datetime(2016, 12, 18)

    # FIXME: the following timezone conversions are just a temporary fix
    # to run the datetime comparison tests
    #
    # The problem:
    # - the datetime objects of the pandas DataFrame are timezone aware and
    #   can't be compared with timezone naive datetime objects
    # - the datetime objects of the eland DataFrame are timezone naive (which
    #   should be fixed)
    # - however, if the eland DataFrame is converted to a pandas DataFrame
    #   (using the `to_pandas` function) the datetime objects become timezone aware
    #
    # This test converts the datetime objects of the pandas Series to
    # timezone naive ones and uses a subclass to make the datetime objects of the
    # eland Series timezone naive before the result of `to_pandas` is returned.
    # The `to_pandas` function is executed by the `assert_pandas_eland_series_equal`
    # function, which compares the eland and pandas Series

    # convert to timezone naive datetime object
    pd_df["order_date"] = pd_df["order_date"].dt.tz_localize(None)

    class ModifiedElandSeries(Series):
        def to_pandas(self):
            """Remove timezone awareness before returning the pandas Series."""
            series = super().to_pandas()
            series = series.dt.tz_localize(None)
            return series

    for op in ops:
        pd_series = pd_df[getattr(pd_df["order_date"], op)(datetime_obj)][
            "order_date"
        ]
        ed_series = ed_df[getattr(ed_df["order_date"], op)(datetime_obj)][
            "order_date"
        ]

        # "type cast" to modified class (inherits from ed.Series) that
        # overrides the `to_pandas` function
        ed_series.__class__ = ModifiedElandSeries

        assert_pandas_eland_series_equal(
            pd_series, ed_series, check_less_precise=True
        )
|
||||
|
||||
def test_ecommerce_series_invalid_div(self):
|
||||
pd_df = self.pd_ecommerce()
|
||||
ed_df = self.ed_ecommerce()
|
||||
|
Loading…
x
Reference in New Issue
Block a user