From b7c6c26606f0262d980000ed8defd7b369726830 Mon Sep 17 00:00:00 2001 From: "P. Sai Vinay" <33659563+V1NAY8@users.noreply.github.com> Date: Tue, 13 Oct 2020 21:28:09 +0530 Subject: [PATCH] Change DataFrame.filter() to preserve the order of items --- eland/actions.py | 10 +++-- eland/operations.py | 41 +++++++++++++++------ eland/query_compiler.py | 9 ++++- eland/tasks.py | 5 ++- eland/tests/dataframe/test_filter_pytest.py | 10 +++++ eland/utils.py | 3 +- 6 files changed, 59 insertions(+), 19 deletions(-) diff --git a/eland/actions.py b/eland/actions.py index 32c52fc..6da1751 100644 --- a/eland/actions.py +++ b/eland/actions.py @@ -16,7 +16,7 @@ # under the License. from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING, Union from eland import SortOrder @@ -50,11 +50,15 @@ class PostProcessingAction(ABC): class SortIndexAction(PostProcessingAction): - def __init__(self) -> None: + def __init__(self, items: Optional[Union[List[int], List[str]]] = None) -> None: super().__init__("sort_index") + self._items = items def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame": - return df.sort_index() + if self._items is not None: + return df.reindex(self._items) + else: + return df.sort_index() def __repr__(self) -> str: return f"('{self.type}')" diff --git a/eland/operations.py b/eland/operations.py index cd48ef5..d6c9543 100644 --- a/eland/operations.py +++ b/eland/operations.py @@ -18,7 +18,7 @@ import copy import typing import warnings -from typing import Optional, Tuple, List, Dict, Any +from typing import Optional, Sequence, Tuple, List, Dict, Any import numpy as np import pandas as pd @@ -33,9 +33,10 @@ from eland.common import ( build_pd_series, ) from eland.query import Query -from eland.actions import SortFieldAction +from eland.actions import PostProcessingAction, SortFieldAction from eland.tasks import ( HeadTask, + RESOLVED_TASK_TYPE, TailTask, SampleTask, BooleanFilterTask, @@ -559,7 +560,13 @@ class Operations: return es_aggs - def filter(self, query_compiler, items=None, like=None, regex=None): + def filter( + self, + query_compiler: "QueryCompiler", + items: Optional[Sequence[str]] = None, + like: Optional[str] = None, + regex: Optional[str] = None, + ) -> None: # This function is only called for axis='index', # DataFrame.filter(..., axis="columns") calls .drop() if items is not None: @@ -795,7 +802,9 @@ class Operations: index=query_compiler._index_pattern, body=body.to_count_body() )["count"] - def _validate_index_operation(self, query_compiler, items): + def _validate_index_operation( + self, query_compiler: "QueryCompiler", items: Sequence[str] + ) -> RESOLVED_TASK_TYPE: if not isinstance(items, list): raise TypeError(f"list item required - not {type(items)}") @@ -828,7 +837,9 @@ class Operations: index=query_compiler._index_pattern, body=body.to_count_body() )["count"] - def drop_index_values(self, query_compiler, field, items): + def drop_index_values( + self, query_compiler: "QueryCompiler", field: str, items: Sequence[str] + ) -> None: self._validate_index_operation(query_compiler, items) # Putting boolean queries together @@ -845,12 +856,14 @@ class Operations: task = QueryTermsTask(False, field, items) self._tasks.append(task) - def filter_index_values(self, query_compiler, field, items): + def filter_index_values( + self, query_compiler: "QueryCompiler", field: str, items: Sequence[str] + ) -> None: # Basically .drop_index_values() except with must=True on tasks. self._validate_index_operation(query_compiler, items) if field == Index.ID_INDEX_FIELD: - task = QueryIdsTask(True, items) + task = QueryIdsTask(True, items, sort_index_by_ids=True) else: task = QueryTermsTask(True, field, items) self._tasks.append(task) @@ -869,7 +882,9 @@ class Operations: return size, sort_params @staticmethod - def _count_post_processing(post_processing): + def _count_post_processing( + post_processing: List["PostProcessingAction"], + ) -> Optional[int]: size = None for action in post_processing: if isinstance(action, SizeTask): @@ -878,13 +893,15 @@ class Operations: return size @staticmethod - def _apply_df_post_processing(df, post_processing): + def _apply_df_post_processing( + df: "pd.DataFrame", post_processing: List["PostProcessingAction"] + ) -> pd.DataFrame: for action in post_processing: df = action.resolve_action(df) return df - def _resolve_tasks(self, query_compiler): + def _resolve_tasks(self, query_compiler: "QueryCompiler") -> RESOLVED_TASK_TYPE: # We now try and combine all tasks into an Elasticsearch query # Some operations can be simply combined into a single query # other operations require pre-queries and then combinations @@ -907,7 +924,9 @@ class Operations: return query_params, post_processing - def _size(self, query_params, post_processing): + def _size( + self, query_params: "QueryParams", post_processing: List["PostProcessingAction"] + ) -> Optional[int]: # Shrink wrap code around checking if size parameter is set size = query_params.size diff --git a/eland/query_compiler.py b/eland/query_compiler.py index 60309f7..0d8a394 100644 --- a/eland/query_compiler.py +++ b/eland/query_compiler.py @@ -17,7 +17,7 @@ import copy from datetime import datetime -from typing import Optional, TYPE_CHECKING, List +from typing import Optional, Sequence, TYPE_CHECKING, List import numpy as np import pandas as pd @@ -482,7 +482,12 @@ class QueryCompiler: return result - def filter(self, items=None, like=None, regex=None): + def filter( + self, + items: Optional[Sequence[str]] = None, + like: Optional[str] = None, + regex: Optional[str] = None, + ) -> "QueryCompiler": # field will be es_index_field for DataFrames or the column for Series. # This function is only called for axis='index', # DataFrame.filter(..., axis="columns") calls .drop() diff --git a/eland/tasks.py b/eland/tasks.py index 2fa835d..f2b5b39 100644 --- a/eland/tasks.py +++ b/eland/tasks.py @@ -221,7 +221,7 @@ class SampleTask(SizeTask): class QueryIdsTask(Task): - def __init__(self, must: bool, ids: List[str]): + def __init__(self, must: bool, ids: List[str], sort_index_by_ids: bool = False): """ Parameters ---------- @@ -235,6 +235,7 @@ class QueryIdsTask(Task): self._must = must self._ids = ids + self._sort_index_by_ids = sort_index_by_ids def resolve_task( self, @@ -243,6 +244,8 @@ class QueryIdsTask(Task): query_compiler: "QueryCompiler", ) -> RESOLVED_TASK_TYPE: query_params.query.ids(self._ids, must=self._must) + if self._sort_index_by_ids: + post_processing.append(SortIndexAction(items=self._ids)) return query_params, post_processing def __repr__(self) -> str: diff --git a/eland/tests/dataframe/test_filter_pytest.py b/eland/tests/dataframe/test_filter_pytest.py index e6ec8c8..9524e17 100644 --- a/eland/tests/dataframe/test_filter_pytest.py +++ b/eland/tests/dataframe/test_filter_pytest.py @@ -64,3 +64,13 @@ class TestDataFrameFilter(TestData): ed_flights_small.filter(like="2", axis=0) with pytest.raises(NotImplementedError): ed_flights_small.filter(regex="^2", axis=0) + + def test_filter_index_order(self): + # Filtering dataframe should retain order of items + ed_flights = self.ed_flights() + + items = ["4", "2", "3", "1", "0"] + + assert [ + i for i in ed_flights.filter(axis="index", items=items).to_pandas().index + ] == items diff --git a/eland/utils.py b/eland/utils.py index 330db9a..66bf586 100644 --- a/eland/utils.py +++ b/eland/utils.py @@ -24,7 +24,6 @@ import pandas as pd # type: ignore RT = TypeVar("RT") -Item = TypeVar("Item") def deprecated_api( @@ -62,7 +61,7 @@ def to_list(x: Union[Collection[Any], pd.Series]) -> List[Any]: raise NotImplementedError(f"Could not convert {type(x).__name__} into a list") -def try_sort(iterable: Iterable[Item]) -> Iterable[Item]: +def try_sort(iterable: Iterable[str]) -> Iterable[str]: # Pulled from pandas.core.common since # it was deprecated and removed in 1.1 listed = list(iterable)