Change DataFrame.filter() to preserve the order of items

This commit is contained in:
P. Sai Vinay 2020-10-13 21:28:09 +05:30 committed by GitHub
parent 0dd247b693
commit b7c6c26606
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 59 additions and 19 deletions

View File

@ -16,7 +16,7 @@
# under the License.
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import List, Optional, TYPE_CHECKING, Union
from eland import SortOrder
@ -50,10 +50,14 @@ class PostProcessingAction(ABC):
class SortIndexAction(PostProcessingAction):
def __init__(self) -> None:
def __init__(self, items: Optional[Union[List[int], List[str]]] = None) -> None:
super().__init__("sort_index")
self._items = items
def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
if self._items is not None:
return df.reindex(self._items)
else:
return df.sort_index()
def __repr__(self) -> str:

View File

@ -18,7 +18,7 @@
import copy
import typing
import warnings
from typing import Optional, Tuple, List, Dict, Any
from typing import Optional, Sequence, Tuple, List, Dict, Any
import numpy as np
import pandas as pd
@ -33,9 +33,10 @@ from eland.common import (
build_pd_series,
)
from eland.query import Query
from eland.actions import SortFieldAction
from eland.actions import PostProcessingAction, SortFieldAction
from eland.tasks import (
HeadTask,
RESOLVED_TASK_TYPE,
TailTask,
SampleTask,
BooleanFilterTask,
@ -559,7 +560,13 @@ class Operations:
return es_aggs
def filter(self, query_compiler, items=None, like=None, regex=None):
def filter(
self,
query_compiler: "QueryCompiler",
items: Optional[Sequence[str]] = None,
like: Optional[str] = None,
regex: Optional[str] = None,
) -> None:
# This function is only called for axis='index',
# DataFrame.filter(..., axis="columns") calls .drop()
if items is not None:
@ -795,7 +802,9 @@ class Operations:
index=query_compiler._index_pattern, body=body.to_count_body()
)["count"]
def _validate_index_operation(self, query_compiler, items):
def _validate_index_operation(
self, query_compiler: "QueryCompiler", items: Sequence[str]
) -> RESOLVED_TASK_TYPE:
if not isinstance(items, list):
raise TypeError(f"list item required - not {type(items)}")
@ -828,7 +837,9 @@ class Operations:
index=query_compiler._index_pattern, body=body.to_count_body()
)["count"]
def drop_index_values(self, query_compiler, field, items):
def drop_index_values(
self, query_compiler: "QueryCompiler", field: str, items: Sequence[str]
) -> None:
self._validate_index_operation(query_compiler, items)
# Putting boolean queries together
@ -845,12 +856,14 @@ class Operations:
task = QueryTermsTask(False, field, items)
self._tasks.append(task)
def filter_index_values(self, query_compiler, field, items):
def filter_index_values(
self, query_compiler: "QueryCompiler", field: str, items: Sequence[str]
) -> None:
# Basically .drop_index_values() except with must=True on tasks.
self._validate_index_operation(query_compiler, items)
if field == Index.ID_INDEX_FIELD:
task = QueryIdsTask(True, items)
task = QueryIdsTask(True, items, sort_index_by_ids=True)
else:
task = QueryTermsTask(True, field, items)
self._tasks.append(task)
@ -869,7 +882,9 @@ class Operations:
return size, sort_params
@staticmethod
def _count_post_processing(post_processing):
def _count_post_processing(
post_processing: List["PostProcessingAction"],
) -> Optional[int]:
size = None
for action in post_processing:
if isinstance(action, SizeTask):
@ -878,13 +893,15 @@ class Operations:
return size
@staticmethod
def _apply_df_post_processing(df, post_processing):
def _apply_df_post_processing(
df: "pd.DataFrame", post_processing: List["PostProcessingAction"]
) -> pd.DataFrame:
for action in post_processing:
df = action.resolve_action(df)
return df
def _resolve_tasks(self, query_compiler):
def _resolve_tasks(self, query_compiler: "QueryCompiler") -> RESOLVED_TASK_TYPE:
# We now try and combine all tasks into an Elasticsearch query
# Some operations can be simply combined into a single query
# other operations require pre-queries and then combinations
@ -907,7 +924,9 @@ class Operations:
return query_params, post_processing
def _size(self, query_params, post_processing):
def _size(
self, query_params: "QueryParams", post_processing: List["PostProcessingAction"]
) -> Optional[int]:
# Shrink wrap code around checking if size parameter is set
size = query_params.size

View File

@ -17,7 +17,7 @@
import copy
from datetime import datetime
from typing import Optional, TYPE_CHECKING, List
from typing import Optional, Sequence, TYPE_CHECKING, List
import numpy as np
import pandas as pd
@ -482,7 +482,12 @@ class QueryCompiler:
return result
def filter(self, items=None, like=None, regex=None):
def filter(
self,
items: Optional[Sequence[str]] = None,
like: Optional[str] = None,
regex: Optional[str] = None,
) -> "QueryCompiler":
# field will be es_index_field for DataFrames or the column for Series.
# This function is only called for axis='index',
# DataFrame.filter(..., axis="columns") calls .drop()

View File

@ -221,7 +221,7 @@ class SampleTask(SizeTask):
class QueryIdsTask(Task):
def __init__(self, must: bool, ids: List[str]):
def __init__(self, must: bool, ids: List[str], sort_index_by_ids: bool = False):
"""
Parameters
----------
@ -235,6 +235,7 @@ class QueryIdsTask(Task):
self._must = must
self._ids = ids
self._sort_index_by_ids = sort_index_by_ids
def resolve_task(
self,
@ -243,6 +244,8 @@ class QueryIdsTask(Task):
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
query_params.query.ids(self._ids, must=self._must)
if self._sort_index_by_ids:
post_processing.append(SortIndexAction(items=self._ids))
return query_params, post_processing
def __repr__(self) -> str:

View File

@ -64,3 +64,13 @@ class TestDataFrameFilter(TestData):
ed_flights_small.filter(like="2", axis=0)
with pytest.raises(NotImplementedError):
ed_flights_small.filter(regex="^2", axis=0)
def test_filter_index_order(self):
# Filtering dataframe should retain order of items
ed_flights = self.ed_flights()
items = ["4", "2", "3", "1", "0"]
assert [
i for i in ed_flights.filter(axis="index", items=items).to_pandas().index
] == items

View File

@ -24,7 +24,6 @@ import pandas as pd # type: ignore
RT = TypeVar("RT")
Item = TypeVar("Item")
def deprecated_api(
@ -62,7 +61,7 @@ def to_list(x: Union[Collection[Any], pd.Series]) -> List[Any]:
raise NotImplementedError(f"Could not convert {type(x).__name__} into a list")
def try_sort(iterable: Iterable[Item]) -> Iterable[Item]:
def try_sort(iterable: Iterable[str]) -> Iterable[str]:
# Pulled from pandas.core.common since
# it was deprecated and removed in 1.1
listed = list(iterable)