Fix DataFrame.shape when the DataFrame is smaller than its SizeTask count

This commit is contained in:
Daniel Mesejo-León 2020-05-06 20:59:47 +02:00 committed by GitHub
parent 94dbb36081
commit bfd0ee6f90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 19 deletions

View File

@@ -77,16 +77,16 @@ class Operations:
def head(self, index, n):
# Add a task that is an ascending sort with size=n
task = HeadTask(index.sort_field, n)
task = HeadTask(index, n)
self._tasks.append(task)
def tail(self, index, n):
# Add a task that is descending sort with size=n
task = TailTask(index.sort_field, n)
task = TailTask(index, n)
self._tasks.append(task)
def sample(self, index, n, random_state):
task = SampleTask(index.sort_field, n, random_state)
task = SampleTask(index, n, random_state)
self._tasks.append(task)
def arithmetic_op_fields(self, display_name, arithmetic_series):

View File

@@ -50,6 +50,11 @@ class Task(ABC):
class SizeTask(Task):
def __init__(self, task_type: str, index, count: int):
super().__init__(task_type)
self._sort_field = index.sort_field
self._count = min(len(index), count)
@abstractmethod
def size(self) -> int:
# must override
@@ -57,12 +62,8 @@ class SizeTask(Task):
class HeadTask(SizeTask):
def __init__(self, sort_field: str, count: int):
super().__init__("head")
# Add a task that is an ascending sort with size=count
self._sort_field = sort_field
self._count = count
def __init__(self, index, count: int):
super().__init__("head", index, count)
def __repr__(self) -> str:
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
@@ -108,12 +109,8 @@ class HeadTask(SizeTask):
class TailTask(SizeTask):
def __init__(self, sort_field: str, count: int):
super().__init__("tail")
# Add a task that is descending sort with size=count
self._sort_field = sort_field
self._count = count
def __init__(self, index, count: int):
super().__init__("tail", index, count)
def resolve_task(
self,
@@ -175,11 +172,9 @@ class TailTask(SizeTask):
class SampleTask(SizeTask):
def __init__(self, sort_field: str, count: int, random_state: int):
super().__init__("sample")
self._count = count
def __init__(self, index, count: int, random_state: int):
super().__init__("sample", index, count)
self._random_state = random_state
self._sort_field = sort_field
def resolve_task(
self,

View File

@@ -95,3 +95,14 @@ class TestDataFrameHeadTail(TestData):
df = df[["timestamp", "OriginAirportID", "DestAirportID", "FlightDelayMin"]]
df = df.tail()
print(df)
def test_doc_test_tail_empty(self):
df = self.ed_flights()
df = df[df.OriginAirportID == "NADA"]
df = df.tail()
assert df.shape[0] == 0
def test_doc_test_tail_single(self):
df = self.ed_flights_small()
df = df[(df.Carrier == "Kibana Airlines") & (df.DestAirportID == "ITM")].tail()
assert df.shape[0] == 1