mirror of https://github.com/elastic/eland.git
Change DataFrame.filter() to preserve the order of items
This commit is contained in:
parent 0dd247b693
commit b7c6c26606
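In short: DataFrame.filter(items=...) now returns rows in the same order as the given items, rather than in whatever order Elasticsearch returns the hits. A minimal usage sketch of the new behaviour (the connection string is hypothetical; the "flights" index and the ids mirror the test added in this commit):

import eland as ed

df = ed.DataFrame("http://localhost:9200", "flights")

# Document _ids requested deliberately out of order.
items = ["4", "2", "3", "1", "0"]
filtered = df.filter(axis="index", items=items)

# After this change the resulting index preserves the requested order.
assert list(filtered.to_pandas().index) == items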
@@ -16,7 +16,7 @@
 # under the License.
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from typing import List, Optional, TYPE_CHECKING, Union
 from eland import SortOrder
 
 
@@ -50,10 +50,14 @@ class PostProcessingAction(ABC):
 
 
 class SortIndexAction(PostProcessingAction):
-    def __init__(self) -> None:
+    def __init__(self, items: Optional[Union[List[int], List[str]]] = None) -> None:
         super().__init__("sort_index")
+        self._items = items
 
     def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
-        return df.sort_index()
+        if self._items is not None:
+            return df.reindex(self._items)
+        else:
+            return df.sort_index()
 
     def __repr__(self) -> str:
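The ordering fix above hinges on the difference between sort_index() and reindex(): sorting puts labels in lexicographic order, while reindexing reproduces exactly the order of the labels it is given. A pandas-only sketch of the distinction SortIndexAction now relies on when items is supplied:

import pandas as pd

df = pd.DataFrame({"v": [10, 20, 30]}, index=["0", "1", "2"])

# sort_index() discards the caller's ordering:
print(df.loc[["2", "0"]].sort_index().index.tolist())  # ['0', '2']

# reindex(items) keeps the requested order, which is what
# SortIndexAction.resolve_action() now does when items is not None:
print(df.reindex(["2", "0"]).index.tolist())  # ['2', '0']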
@@ -18,7 +18,7 @@
 import copy
 import typing
 import warnings
-from typing import Optional, Tuple, List, Dict, Any
+from typing import Optional, Sequence, Tuple, List, Dict, Any
 
 import numpy as np
 import pandas as pd
@@ -33,9 +33,10 @@ from eland.common import (
     build_pd_series,
 )
 from eland.query import Query
-from eland.actions import SortFieldAction
+from eland.actions import PostProcessingAction, SortFieldAction
 from eland.tasks import (
     HeadTask,
+    RESOLVED_TASK_TYPE,
     TailTask,
     SampleTask,
     BooleanFilterTask,
@@ -559,7 +560,13 @@ class Operations:
 
         return es_aggs
 
-    def filter(self, query_compiler, items=None, like=None, regex=None):
+    def filter(
+        self,
+        query_compiler: "QueryCompiler",
+        items: Optional[Sequence[str]] = None,
+        like: Optional[str] = None,
+        regex: Optional[str] = None,
+    ) -> None:
         # This function is only called for axis='index',
         # DataFrame.filter(..., axis="columns") calls .drop()
         if items is not None:
@@ -795,7 +802,9 @@ class Operations:
             index=query_compiler._index_pattern, body=body.to_count_body()
         )["count"]
 
-    def _validate_index_operation(self, query_compiler, items):
+    def _validate_index_operation(
+        self, query_compiler: "QueryCompiler", items: Sequence[str]
+    ) -> RESOLVED_TASK_TYPE:
         if not isinstance(items, list):
             raise TypeError(f"list item required - not {type(items)}")
 
@@ -828,7 +837,9 @@ class Operations:
             index=query_compiler._index_pattern, body=body.to_count_body()
         )["count"]
 
-    def drop_index_values(self, query_compiler, field, items):
+    def drop_index_values(
+        self, query_compiler: "QueryCompiler", field: str, items: Sequence[str]
+    ) -> None:
         self._validate_index_operation(query_compiler, items)
 
         # Putting boolean queries together
@@ -845,12 +856,14 @@ class Operations:
             task = QueryTermsTask(False, field, items)
         self._tasks.append(task)
 
-    def filter_index_values(self, query_compiler, field, items):
+    def filter_index_values(
+        self, query_compiler: "QueryCompiler", field: str, items: Sequence[str]
+    ) -> None:
         # Basically .drop_index_values() except with must=True on tasks.
         self._validate_index_operation(query_compiler, items)
 
         if field == Index.ID_INDEX_FIELD:
-            task = QueryIdsTask(True, items)
+            task = QueryIdsTask(True, items, sort_index_by_ids=True)
         else:
             task = QueryTermsTask(True, field, items)
         self._tasks.append(task)
@@ -869,7 +882,9 @@ class Operations:
         return size, sort_params
 
     @staticmethod
-    def _count_post_processing(post_processing):
+    def _count_post_processing(
+        post_processing: List["PostProcessingAction"],
+    ) -> Optional[int]:
         size = None
         for action in post_processing:
             if isinstance(action, SizeTask):
@@ -878,13 +893,15 @@ class Operations:
         return size
 
     @staticmethod
-    def _apply_df_post_processing(df, post_processing):
+    def _apply_df_post_processing(
+        df: "pd.DataFrame", post_processing: List["PostProcessingAction"]
+    ) -> pd.DataFrame:
         for action in post_processing:
             df = action.resolve_action(df)
 
         return df
 
-    def _resolve_tasks(self, query_compiler):
+    def _resolve_tasks(self, query_compiler: "QueryCompiler") -> RESOLVED_TASK_TYPE:
         # We now try and combine all tasks into an Elasticsearch query
         # Some operations can be simply combined into a single query
         # other operations require pre-queries and then combinations
@@ -907,7 +924,9 @@ class Operations:
 
         return query_params, post_processing
 
-    def _size(self, query_params, post_processing):
+    def _size(
+        self, query_params: "QueryParams", post_processing: List["PostProcessingAction"]
+    ) -> Optional[int]:
         # Shrink wrap code around checking if size parameter is set
         size = query_params.size
 
@@ -17,7 +17,7 @@
 
 import copy
 from datetime import datetime
-from typing import Optional, TYPE_CHECKING, List
+from typing import Optional, Sequence, TYPE_CHECKING, List
 
 import numpy as np
 import pandas as pd
@@ -482,7 +482,12 @@ class QueryCompiler:
 
         return result
 
-    def filter(self, items=None, like=None, regex=None):
+    def filter(
+        self,
+        items: Optional[Sequence[str]] = None,
+        like: Optional[str] = None,
+        regex: Optional[str] = None,
+    ) -> "QueryCompiler":
         # field will be es_index_field for DataFrames or the column for Series.
         # This function is only called for axis='index',
         # DataFrame.filter(..., axis="columns") calls .drop()
@@ -221,7 +221,7 @@ class SampleTask(SizeTask):
 
 
 class QueryIdsTask(Task):
-    def __init__(self, must: bool, ids: List[str]):
+    def __init__(self, must: bool, ids: List[str], sort_index_by_ids: bool = False):
         """
         Parameters
         ----------
@@ -235,6 +235,7 @@ class QueryIdsTask(Task):
 
         self._must = must
         self._ids = ids
+        self._sort_index_by_ids = sort_index_by_ids
 
     def resolve_task(
         self,
@@ -243,6 +244,8 @@ class QueryIdsTask(Task):
         query_compiler: "QueryCompiler",
     ) -> RESOLVED_TASK_TYPE:
         query_params.query.ids(self._ids, must=self._must)
+        if self._sort_index_by_ids:
+            post_processing.append(SortIndexAction(items=self._ids))
         return query_params, post_processing
 
     def __repr__(self) -> str:
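Putting the pieces together: when filter_index_values() queries by _id, QueryIdsTask now appends a SortIndexAction carrying the requested ids, and Operations._apply_df_post_processing() later applies it to the hits returned by Elasticsearch. A simplified, self-contained sketch of that post-processing step (assumes eland at this commit is installed; the pandas frame stands in for real search hits):

import pandas as pd
from eland.actions import SortIndexAction

# Hits as Elasticsearch might return them, in its own order.
hits = pd.DataFrame({"v": [1, 2, 3, 4, 5]}, index=["0", "1", "2", "3", "4"])

ids = ["4", "2", "3", "1", "0"]
post_processing = [SortIndexAction(items=ids)]  # appended by QueryIdsTask

# Operations._apply_df_post_processing() folds each action over the frame.
for action in post_processing:
    hits = action.resolve_action(hits)

assert list(hits.index) == ids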
@@ -64,3 +64,13 @@ class TestDataFrameFilter(TestData):
             ed_flights_small.filter(like="2", axis=0)
         with pytest.raises(NotImplementedError):
             ed_flights_small.filter(regex="^2", axis=0)
+
+    def test_filter_index_order(self):
+        # Filtering dataframe should retain order of items
+        ed_flights = self.ed_flights()
+
+        items = ["4", "2", "3", "1", "0"]
+
+        assert [
+            i for i in ed_flights.filter(axis="index", items=items).to_pandas().index
+        ] == items
@@ -24,7 +24,6 @@ import pandas as pd  # type: ignore
 
 
 RT = TypeVar("RT")
-Item = TypeVar("Item")
 
 
 def deprecated_api(
@@ -62,7 +61,7 @@ def to_list(x: Union[Collection[Any], pd.Series]) -> List[Any]:
     raise NotImplementedError(f"Could not convert {type(x).__name__} into a list")
 
 
-def try_sort(iterable: Iterable[Item]) -> Iterable[Item]:
+def try_sort(iterable: Iterable[str]) -> Iterable[str]:
     # Pulled from pandas.core.common since
     # it was deprecated and removed in 1.1
     listed = list(iterable)