Change DataFrame.filter() to preserve the order of items
parent 0dd247b693
commit b7c6c26606
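For context, a minimal before/after sketch of the user-visible behavior this commit targets. The host URL and index name below are illustrative, not taken from the commit:

    import eland as ed

    # Hypothetical cluster and index; substitute your own.
    df = ed.DataFrame("http://localhost:9200", "flights")

    items = ["4", "2", "3", "1", "0"]
    filtered = df.filter(items=items, axis="index")

    # Previously the rows came back in sorted-index order ("0", "1", ...).
    # After this change the index should follow `items` exactly, matching
    # pandas' DataFrame.filter(items=...) semantics.
    print(list(filtered.to_pandas().index))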
@@ -16,7 +16,7 @@
 # under the License.
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from typing import List, Optional, TYPE_CHECKING, Union
 from eland import SortOrder
 
 
@@ -50,11 +50,15 @@ class PostProcessingAction(ABC):
 
 
 class SortIndexAction(PostProcessingAction):
-    def __init__(self) -> None:
+    def __init__(self, items: Optional[Union[List[int], List[str]]] = None) -> None:
         super().__init__("sort_index")
+        self._items = items
 
     def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
-        return df.sort_index()
+        if self._items is not None:
+            return df.reindex(self._items)
+        else:
+            return df.sort_index()
 
     def __repr__(self) -> str:
         return f"('{self.type}')"
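The behavioral difference in resolve_action can be seen with plain pandas (a standalone sketch, not eland code): sort_index() orders rows by index value, while reindex(items) returns them in the requested order.

    import pandas as pd

    df = pd.DataFrame({"x": [10, 20, 30]}, index=["2", "0", "1"])
    items = ["1", "2", "0"]

    print(list(df.sort_index().index))    # old behavior: ['0', '1', '2']
    print(list(df.reindex(items).index))  # new behavior: ['1', '2', '0']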
@@ -18,7 +18,7 @@
 import copy
 import typing
 import warnings
-from typing import Optional, Tuple, List, Dict, Any
+from typing import Optional, Sequence, Tuple, List, Dict, Any
 
 import numpy as np
 import pandas as pd
@@ -33,9 +33,10 @@ from eland.common import (
     build_pd_series,
 )
 from eland.query import Query
-from eland.actions import SortFieldAction
+from eland.actions import PostProcessingAction, SortFieldAction
 from eland.tasks import (
     HeadTask,
+    RESOLVED_TASK_TYPE,
     TailTask,
     SampleTask,
     BooleanFilterTask,
@@ -559,7 +560,13 @@ class Operations:
 
         return es_aggs
 
-    def filter(self, query_compiler, items=None, like=None, regex=None):
+    def filter(
+        self,
+        query_compiler: "QueryCompiler",
+        items: Optional[Sequence[str]] = None,
+        like: Optional[str] = None,
+        regex: Optional[str] = None,
+    ) -> None:
         # This function is only called for axis='index',
         # DataFrame.filter(..., axis="columns") calls .drop()
         if items is not None:
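As the comments note, this method only handles axis='index'; column filtering takes a different path. A quick illustration of the two axes at the pandas API level, with a made-up frame:

    import pandas as pd

    df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}, index=["r1", "r0"])

    # Column selection; per the comment above, eland handles this via .drop().
    print(df.filter(items=["b"], axis="columns"))

    # Row selection by index label; this is the path annotated above.
    print(df.filter(items=["r0", "r1"], axis="index"))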
@@ -795,7 +802,9 @@ class Operations:
             index=query_compiler._index_pattern, body=body.to_count_body()
         )["count"]
 
-    def _validate_index_operation(self, query_compiler, items):
+    def _validate_index_operation(
+        self, query_compiler: "QueryCompiler", items: Sequence[str]
+    ) -> RESOLVED_TASK_TYPE:
         if not isinstance(items, list):
             raise TypeError(f"list item required - not {type(items)}")
 
@@ -828,7 +837,9 @@ class Operations:
             index=query_compiler._index_pattern, body=body.to_count_body()
         )["count"]
 
-    def drop_index_values(self, query_compiler, field, items):
+    def drop_index_values(
+        self, query_compiler: "QueryCompiler", field: str, items: Sequence[str]
+    ) -> None:
         self._validate_index_operation(query_compiler, items)
 
         # Putting boolean queries together
@@ -845,12 +856,14 @@ class Operations:
             task = QueryTermsTask(False, field, items)
         self._tasks.append(task)
 
-    def filter_index_values(self, query_compiler, field, items):
+    def filter_index_values(
+        self, query_compiler: "QueryCompiler", field: str, items: Sequence[str]
+    ) -> None:
         # Basically .drop_index_values() except with must=True on tasks.
         self._validate_index_operation(query_compiler, items)
 
         if field == Index.ID_INDEX_FIELD:
-            task = QueryIdsTask(True, items)
+            task = QueryIdsTask(True, items, sort_index_by_ids=True)
         else:
             task = QueryTermsTask(True, field, items)
         self._tasks.append(task)
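Why the extra sort_index_by_ids flag is needed (a sketch; the exact request body eland builds is not shown in this diff): an Elasticsearch ids query matches the requested documents, but the order of the returned hits is not guaranteed to follow the order of the supplied values, so the requested order has to be restored client-side.

    # Roughly the shape of an _id-based filter (illustrative only):
    es_body = {"query": {"ids": {"values": ["4", "2", "3", "1", "0"]}}}

    # Hits may come back in a different order (e.g. index order or by score),
    # which is why QueryIdsTask now schedules a SortIndexAction as post-processing.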
@@ -869,7 +882,9 @@ class Operations:
         return size, sort_params
 
     @staticmethod
-    def _count_post_processing(post_processing):
+    def _count_post_processing(
+        post_processing: List["PostProcessingAction"],
+    ) -> Optional[int]:
         size = None
         for action in post_processing:
             if isinstance(action, SizeTask):
@@ -878,13 +893,15 @@ class Operations:
         return size
 
     @staticmethod
-    def _apply_df_post_processing(df, post_processing):
+    def _apply_df_post_processing(
+        df: "pd.DataFrame", post_processing: List["PostProcessingAction"]
+    ) -> pd.DataFrame:
         for action in post_processing:
             df = action.resolve_action(df)
 
         return df
 
-    def _resolve_tasks(self, query_compiler):
+    def _resolve_tasks(self, query_compiler: "QueryCompiler") -> RESOLVED_TASK_TYPE:
         # We now try and combine all tasks into an Elasticsearch query
         # Some operations can be simply combined into a single query
         # other operations require pre-queries and then combinations
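For readers unfamiliar with the post-processing chain these annotations describe, here is a simplified, standalone sketch of the pattern (the lambdas stand in for PostProcessingAction objects and are not eland classes): each action transforms the DataFrame in turn, which is how the new SortIndexAction gets a chance to restore the requested row order.

    from typing import Callable, List
    import pandas as pd

    df = pd.DataFrame({"x": range(3)}, index=["2", "0", "1"])

    # Simplified stand-ins for PostProcessingAction.resolve_action:
    actions: List[Callable[[pd.DataFrame], pd.DataFrame]] = [
        lambda d: d.head(3),                   # a size-style action
        lambda d: d.reindex(["0", "1", "2"]),  # a SortIndexAction-style action
    ]

    # Mirrors the loop in _apply_df_post_processing above.
    for action in actions:
        df = action(df)

    print(list(df.index))  # ['0', '1', '2']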
@@ -907,7 +924,9 @@ class Operations:
 
         return query_params, post_processing
 
-    def _size(self, query_params, post_processing):
+    def _size(
+        self, query_params: "QueryParams", post_processing: List["PostProcessingAction"]
+    ) -> Optional[int]:
         # Shrink wrap code around checking if size parameter is set
         size = query_params.size
 
@@ -17,7 +17,7 @@
 
 import copy
 from datetime import datetime
-from typing import Optional, TYPE_CHECKING, List
+from typing import Optional, Sequence, TYPE_CHECKING, List
 
 import numpy as np
 import pandas as pd
@@ -482,7 +482,12 @@ class QueryCompiler:
 
         return result
 
-    def filter(self, items=None, like=None, regex=None):
+    def filter(
+        self,
+        items: Optional[Sequence[str]] = None,
+        like: Optional[str] = None,
+        regex: Optional[str] = None,
+    ) -> "QueryCompiler":
         # field will be es_index_field for DataFrames or the column for Series.
         # This function is only called for axis='index',
         # DataFrame.filter(..., axis="columns") calls .drop()
@@ -221,7 +221,7 @@ class SampleTask(SizeTask):
 
 
 class QueryIdsTask(Task):
-    def __init__(self, must: bool, ids: List[str]):
+    def __init__(self, must: bool, ids: List[str], sort_index_by_ids: bool = False):
         """
         Parameters
         ----------
@@ -235,6 +235,7 @@ class QueryIdsTask(Task):
 
         self._must = must
         self._ids = ids
+        self._sort_index_by_ids = sort_index_by_ids
 
     def resolve_task(
         self,
@@ -243,6 +244,8 @@ class QueryIdsTask(Task):
         query_compiler: "QueryCompiler",
     ) -> RESOLVED_TASK_TYPE:
         query_params.query.ids(self._ids, must=self._must)
+        if self._sort_index_by_ids:
+            post_processing.append(SortIndexAction(items=self._ids))
         return query_params, post_processing
 
     def __repr__(self) -> str:
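To see what the appended action does in isolation, a small sketch using the constructor signature from the actions hunk above (the sample frame is made up, and the import path assumes SortIndexAction lives in eland.actions, as the other imports in this diff suggest):

    import pandas as pd
    from eland.actions import SortIndexAction

    ids = ["4", "2", "0"]

    # Pretend these are the hits returned for the ids query, in Elasticsearch's order.
    hits = pd.DataFrame({"Carrier": ["x", "y", "z"]}, index=["0", "2", "4"])

    # The post-processing step appended in resolve_task above:
    reordered = SortIndexAction(items=ids).resolve_action(hits)
    print(list(reordered.index))  # ['4', '2', '0']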
@@ -64,3 +64,13 @@ class TestDataFrameFilter(TestData):
             ed_flights_small.filter(like="2", axis=0)
         with pytest.raises(NotImplementedError):
             ed_flights_small.filter(regex="^2", axis=0)
+
+    def test_filter_index_order(self):
+        # Filtering dataframe should retain order of items
+        ed_flights = self.ed_flights()
+
+        items = ["4", "2", "3", "1", "0"]
+
+        assert [
+            i for i in ed_flights.filter(axis="index", items=items).to_pandas().index
+        ] == items
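For comparison, the same assertion holds for a plain pandas DataFrame (a standalone sketch with made-up data), which is the pandas-parity behavior the new test checks eland against:

    import pandas as pd

    pd_flights = pd.DataFrame({"Carrier": list("abcde")}, index=["0", "1", "2", "3", "4"])
    items = ["4", "2", "3", "1", "0"]

    assert list(pd_flights.filter(axis="index", items=items).index) == items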
@@ -24,7 +24,6 @@ import pandas as pd # type: ignore
 
 
 RT = TypeVar("RT")
-Item = TypeVar("Item")
 
 
 def deprecated_api(
@@ -62,7 +61,7 @@ def to_list(x: Union[Collection[Any], pd.Series]) -> List[Any]:
     raise NotImplementedError(f"Could not convert {type(x).__name__} into a list")
 
 
-def try_sort(iterable: Iterable[Item]) -> Iterable[Item]:
+def try_sort(iterable: Iterable[str]) -> Iterable[str]:
     # Pulled from pandas.core.common since
     # it was deprecated and removed in 1.1
     listed = list(iterable)
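The body of try_sort is cut off by the hunk; for reference, the pandas helper it was copied from behaves roughly like the sketch below (an assumption based on the old pandas.core.common helper, not taken from the eland source). Narrowing the annotation to str fits this usage, since the index labels being sorted here are strings.

    from typing import Iterable, List


    def try_sort_sketch(iterable: Iterable[str]) -> List[str]:
        # Approximation of the deprecated pandas.core.common helper:
        # sort if the elements are comparable, otherwise return them as-is.
        listed = list(iterable)
        try:
            return sorted(listed)
        except TypeError:
            return listed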