From bfd0ee6f906de191d7b2bd5ae6118ce7c661a161 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Mesejo-Le=C3=B3n?= Date: Wed, 6 May 2020 20:59:47 +0200 Subject: [PATCH] Fix DataFrame.shape when smaller than its SizedTask --- eland/operations.py | 6 ++--- eland/tasks.py | 27 ++++++++----------- .../tests/dataframe/test_head_tail_pytest.py | 11 ++++++++ 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/eland/operations.py b/eland/operations.py index 9520371..a5c2856 100644 --- a/eland/operations.py +++ b/eland/operations.py @@ -77,16 +77,16 @@ class Operations: def head(self, index, n): # Add a task that is an ascending sort with size=n - task = HeadTask(index.sort_field, n) + task = HeadTask(index, n) self._tasks.append(task) def tail(self, index, n): # Add a task that is descending sort with size=n - task = TailTask(index.sort_field, n) + task = TailTask(index, n) self._tasks.append(task) def sample(self, index, n, random_state): - task = SampleTask(index.sort_field, n, random_state) + task = SampleTask(index, n, random_state) self._tasks.append(task) def arithmetic_op_fields(self, display_name, arithmetic_series): diff --git a/eland/tasks.py b/eland/tasks.py index 6c8ca1e..a1542fe 100644 --- a/eland/tasks.py +++ b/eland/tasks.py @@ -50,6 +50,11 @@ class Task(ABC): class SizeTask(Task): + def __init__(self, task_type: str, index, count: int): + super().__init__(task_type) + self._sort_field = index.sort_field + self._count = min(len(index), count) + @abstractmethod def size(self) -> int: # must override @@ -57,12 +62,8 @@ class SizeTask(Task): class HeadTask(SizeTask): - def __init__(self, sort_field: str, count: int): - super().__init__("head") - - # Add a task that is an ascending sort with size=count - self._sort_field = sort_field - self._count = count + def __init__(self, index, count: int): + super().__init__("head", index, count) def __repr__(self) -> str: return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))" @@ -108,12 +109,8 @@ class HeadTask(SizeTask): class TailTask(SizeTask): - def __init__(self, sort_field: str, count: int): - super().__init__("tail") - - # Add a task that is descending sort with size=count - self._sort_field = sort_field - self._count = count + def __init__(self, index, count: int): + super().__init__("tail", index, count) def resolve_task( self, @@ -175,11 +172,9 @@ class TailTask(SizeTask): class SampleTask(SizeTask): - def __init__(self, sort_field: str, count: int, random_state: int): - super().__init__("sample") - self._count = count + def __init__(self, index, count: int, random_state: int): + super().__init__("sample", index, count) self._random_state = random_state - self._sort_field = sort_field def resolve_task( self, diff --git a/eland/tests/dataframe/test_head_tail_pytest.py b/eland/tests/dataframe/test_head_tail_pytest.py index 6055f01..0e20f31 100644 --- a/eland/tests/dataframe/test_head_tail_pytest.py +++ b/eland/tests/dataframe/test_head_tail_pytest.py @@ -95,3 +95,14 @@ class TestDataFrameHeadTail(TestData): df = df[["timestamp", "OriginAirportID", "DestAirportID", "FlightDelayMin"]] df = df.tail() print(df) + + def test_doc_test_tail_empty(self): + df = self.ed_flights() + df = df[df.OriginAirportID == "NADA"] + df = df.tail() + assert df.shape[0] == 0 + + def test_doc_test_tail_single(self): + df = self.ed_flights_small() + df = df[(df.Carrier == "Kibana Airlines") & (df.DestAirportID == "ITM")].tail() + assert df.shape[0] == 1