mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Fix DataFrame.shape when smaller than its SizedTask
This commit is contained in:
parent
94dbb36081
commit
bfd0ee6f90
@ -77,16 +77,16 @@ class Operations:
|
||||
|
||||
def head(self, index, n):
|
||||
# Add a task that is an ascending sort with size=n
|
||||
task = HeadTask(index.sort_field, n)
|
||||
task = HeadTask(index, n)
|
||||
self._tasks.append(task)
|
||||
|
||||
def tail(self, index, n):
|
||||
# Add a task that is descending sort with size=n
|
||||
task = TailTask(index.sort_field, n)
|
||||
task = TailTask(index, n)
|
||||
self._tasks.append(task)
|
||||
|
||||
def sample(self, index, n, random_state):
|
||||
task = SampleTask(index.sort_field, n, random_state)
|
||||
task = SampleTask(index, n, random_state)
|
||||
self._tasks.append(task)
|
||||
|
||||
def arithmetic_op_fields(self, display_name, arithmetic_series):
|
||||
|
@ -50,6 +50,11 @@ class Task(ABC):
|
||||
|
||||
|
||||
class SizeTask(Task):
|
||||
def __init__(self, task_type: str, index, count: int):
|
||||
super().__init__(task_type)
|
||||
self._sort_field = index.sort_field
|
||||
self._count = min(len(index), count)
|
||||
|
||||
@abstractmethod
|
||||
def size(self) -> int:
|
||||
# must override
|
||||
@ -57,12 +62,8 @@ class SizeTask(Task):
|
||||
|
||||
|
||||
class HeadTask(SizeTask):
|
||||
def __init__(self, sort_field: str, count: int):
|
||||
super().__init__("head")
|
||||
|
||||
# Add a task that is an ascending sort with size=count
|
||||
self._sort_field = sort_field
|
||||
self._count = count
|
||||
def __init__(self, index, count: int):
|
||||
super().__init__("head", index, count)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
|
||||
@ -108,12 +109,8 @@ class HeadTask(SizeTask):
|
||||
|
||||
|
||||
class TailTask(SizeTask):
|
||||
def __init__(self, sort_field: str, count: int):
|
||||
super().__init__("tail")
|
||||
|
||||
# Add a task that is descending sort with size=count
|
||||
self._sort_field = sort_field
|
||||
self._count = count
|
||||
def __init__(self, index, count: int):
|
||||
super().__init__("tail", index, count)
|
||||
|
||||
def resolve_task(
|
||||
self,
|
||||
@ -175,11 +172,9 @@ class TailTask(SizeTask):
|
||||
|
||||
|
||||
class SampleTask(SizeTask):
|
||||
def __init__(self, sort_field: str, count: int, random_state: int):
|
||||
super().__init__("sample")
|
||||
self._count = count
|
||||
def __init__(self, index, count: int, random_state: int):
|
||||
super().__init__("sample", index, count)
|
||||
self._random_state = random_state
|
||||
self._sort_field = sort_field
|
||||
|
||||
def resolve_task(
|
||||
self,
|
||||
|
@ -95,3 +95,14 @@ class TestDataFrameHeadTail(TestData):
|
||||
df = df[["timestamp", "OriginAirportID", "DestAirportID", "FlightDelayMin"]]
|
||||
df = df.tail()
|
||||
print(df)
|
||||
|
||||
def test_doc_test_tail_empty(self):
|
||||
df = self.ed_flights()
|
||||
df = df[df.OriginAirportID == "NADA"]
|
||||
df = df.tail()
|
||||
assert df.shape[0] == 0
|
||||
|
||||
def test_doc_test_tail_single(self):
|
||||
df = self.ed_flights_small()
|
||||
df = df[(df.Carrier == "Kibana Airlines") & (df.DestAirportID == "ITM")].tail()
|
||||
assert df.shape[0] == 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user