mirror of
https://github.com/elastic/eland.git
synced 2025-07-24 00:00:39 +08:00
Add .sample() method to DataFrame and Series
This commit is contained in:
parent
def3a46af9
commit
94dbb36081
6
docs/source/reference/api/eland.DataFrame.sample.rst
Normal file
6
docs/source/reference/api/eland.DataFrame.sample.rst
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
eland.DataFrame.sample
|
||||||
|
======================
|
||||||
|
|
||||||
|
.. currentmodule:: eland
|
||||||
|
|
||||||
|
.. automethod:: DataFrame.sample
|
6
docs/source/reference/api/eland.Series.sample.rst
Normal file
6
docs/source/reference/api/eland.Series.sample.rst
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
eland.Series.sample
|
||||||
|
===================
|
||||||
|
|
||||||
|
.. currentmodule:: eland
|
||||||
|
|
||||||
|
.. automethod:: Series.sample
|
@ -37,6 +37,7 @@ Indexing, iteration
|
|||||||
DataFrame.tail
|
DataFrame.tail
|
||||||
DataFrame.get
|
DataFrame.get
|
||||||
DataFrame.query
|
DataFrame.query
|
||||||
|
DataFrame.sample
|
||||||
|
|
||||||
Function application, GroupBy & window
|
Function application, GroupBy & window
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
@ -31,6 +31,7 @@ Indexing, iteration
|
|||||||
|
|
||||||
Series.head
|
Series.head
|
||||||
Series.tail
|
Series.tail
|
||||||
|
Series.sample
|
||||||
|
|
||||||
Binary operator functions
|
Binary operator functions
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
@ -256,6 +256,45 @@ class DataFrame(NDFrame):
|
|||||||
"""
|
"""
|
||||||
return DataFrame(query_compiler=self._query_compiler.tail(n))
|
return DataFrame(query_compiler=self._query_compiler.tail(n))
|
||||||
|
|
||||||
|
def sample(
    self, n: int = None, frac: float = None, random_state: int = None
) -> "DataFrame":
    """
    Return `n` randomly sampled rows, or the specified fraction of rows.

    Parameters
    ----------
    n : int, optional
        Number of documents from index to return. Cannot be used with `frac`.
        Default = 1 if `frac` = None.
    frac : float, optional
        Fraction of axis items to return, greater than 0 and at most 1.
        Cannot be used with `n`.
    random_state : int, optional
        Seed for the random number generator.

    Returns
    -------
    eland.DataFrame:
        eland DataFrame containing the randomly sampled rows

    Raises
    ------
    ValueError
        If `frac` is outside the range (0, 1], if `n` is not a whole
        number, or if both `n` and `frac` are given.

    See Also
    --------
    :pandas_api_docs:`pandas.DataFrame.sample`
    """

    if frac is not None and not (0.0 < frac <= 1.0):
        raise ValueError("`frac` must be between 0. and 1.")
    elif n is not None and frac is None and n % 1 != 0:
        raise ValueError("Only integers accepted as `n` values")
    elif n is not None and frac is not None:
        # Only the "both given" case is an error. Omitting both is valid and
        # falls through to the documented default of a single row.
        raise ValueError("Please enter a value for `frac` OR `n`, not both")

    return DataFrame(
        query_compiler=self._query_compiler.sample(
            n=n, frac=frac, random_state=random_state
        )
    )
|
||||||
|
|
||||||
def drop(
|
def drop(
|
||||||
self,
|
self,
|
||||||
labels=None,
|
labels=None,
|
||||||
|
@ -170,3 +170,19 @@ class QueryFilter(BooleanFilter):
|
|||||||
def __init__(self, query: Dict[str, Any]) -> None:
|
def __init__(self, query: Dict[str, Any]) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._filter = query
|
self._filter = query
|
||||||
|
|
||||||
|
|
||||||
|
class MatchAllFilter(QueryFilter):
    """Query filter whose body is the ``match_all`` query."""

    def __init__(self) -> None:
        match_all_query = {"match_all": {}}
        super().__init__(match_all_query)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomScoreFilter(QueryFilter):
    """``function_score`` query that applies ``random_score`` to a base query."""

    def __init__(self, query: QueryFilter, random_state: int) -> None:
        # Use a match_all query when the incoming filter is empty.
        base = query if not query.empty() else MatchAllFilter()

        # Without a seed the random_score options stay empty; with one, the
        # seed is paired with the `_seq_no` field.
        random_score = (
            {"seed": random_state, "field": "_seq_no"}
            if random_state is not None
            else {}
        )

        body = {
            "function_score": {"query": base.build(), "random_score": random_score}
        }
        super().__init__(body)
|
||||||
|
@ -488,3 +488,7 @@ class NDFrame(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def tail(self, n=5):
|
def tail(self, n=5):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
def sample(self, n=None, frac=None, random_state=None):
    """Return a random sample; concrete subclasses supply the implementation."""
|
||||||
|
@ -24,6 +24,7 @@ from eland.actions import SortFieldAction
|
|||||||
from eland.tasks import (
|
from eland.tasks import (
|
||||||
HeadTask,
|
HeadTask,
|
||||||
TailTask,
|
TailTask,
|
||||||
|
SampleTask,
|
||||||
BooleanFilterTask,
|
BooleanFilterTask,
|
||||||
ArithmeticOpFieldsTask,
|
ArithmeticOpFieldsTask,
|
||||||
QueryTermsTask,
|
QueryTermsTask,
|
||||||
@ -84,6 +85,10 @@ class Operations:
|
|||||||
task = TailTask(index.sort_field, n)
|
task = TailTask(index.sort_field, n)
|
||||||
self._tasks.append(task)
|
self._tasks.append(task)
|
||||||
|
|
||||||
|
def sample(self, index, n, random_state):
    """Queue a SampleTask built from the index's sort field, `n` and `random_state`."""
    self._tasks.append(SampleTask(index.sort_field, n, random_state))
|
||||||
|
|
||||||
def arithmetic_op_fields(self, display_name, arithmetic_series):
|
def arithmetic_op_fields(self, display_name, arithmetic_series):
|
||||||
if self._arithmetic_op_fields_task is None:
|
if self._arithmetic_op_fields_task is None:
|
||||||
self._arithmetic_op_fields_task = ArithmeticOpFieldsTask(
|
self._arithmetic_op_fields_task = ArithmeticOpFieldsTask(
|
||||||
|
@ -6,7 +6,7 @@ import warnings
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Optional, Dict, List, Any
|
from typing import Optional, Dict, List, Any
|
||||||
|
|
||||||
from eland.filter import BooleanFilter, NotNull, IsNull, IsIn
|
from eland.filter import RandomScoreFilter, BooleanFilter, NotNull, IsNull, IsIn
|
||||||
|
|
||||||
|
|
||||||
class Query:
|
class Query:
|
||||||
@ -152,5 +152,8 @@ class Query:
|
|||||||
else:
|
else:
|
||||||
self._query = self._query & boolean_filter
|
self._query = self._query & boolean_filter
|
||||||
|
|
||||||
|
def random_score(self, random_state: int) -> None:
    """Replace the current query with a RandomScoreFilter built from it and `random_state`."""
    randomized = RandomScoreFilter(self._query, random_state)
    self._query = randomized
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return repr(self.to_search_body())
|
return repr(self.to_search_body())
|
||||||
|
@ -9,10 +9,10 @@ from typing import Optional, TYPE_CHECKING
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from eland import Index
|
|
||||||
from eland.field_mappings import FieldMappings
|
from eland.field_mappings import FieldMappings
|
||||||
from eland.operations import Operations
|
|
||||||
from eland.filter import QueryFilter
|
from eland.filter import QueryFilter
|
||||||
|
from eland.operations import Operations
|
||||||
|
from eland import Index
|
||||||
from eland.common import (
|
from eland.common import (
|
||||||
ensure_es_client,
|
ensure_es_client,
|
||||||
DEFAULT_PROGRESS_REPORTING_NUM_ROWS,
|
DEFAULT_PROGRESS_REPORTING_NUM_ROWS,
|
||||||
@ -393,6 +393,24 @@ class QueryCompiler:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def sample(self, n=None, frac=None, random_state=None):
    """
    Return a copy of this query compiler with a sample operation queued.

    When `n` is omitted, the row count is 1 if `frac` is also omitted,
    otherwise `frac` times the index count, rounded to the nearest integer.

    Raises
    ------
    ValueError
        If the resulting number of rows is negative.
    """
    sampled = self.copy()

    if n is None:
        if frac is None:
            n = 1
        else:
            total_rows = self._index_count()
            n = int(round(frac * total_rows))

    if n < 0:
        raise ValueError(
            "A negative number of rows requested. Please provide positive value."
        )

    sampled._operations.sample(self._index, n, random_state)
    return sampled
|
||||||
|
|
||||||
def es_query(self, query):
|
def es_query(self, query):
|
||||||
return self._update_query(QueryFilter(query))
|
return self._update_query(QueryFilter(query))
|
||||||
|
|
||||||
|
@ -225,6 +225,9 @@ class Series(NDFrame):
|
|||||||
def tail(self, n=5):
|
def tail(self, n=5):
|
||||||
return Series(query_compiler=self._query_compiler.tail(n))
|
return Series(query_compiler=self._query_compiler.tail(n))
|
||||||
|
|
||||||
|
def sample(
    self, n: int = None, frac: float = None, random_state: int = None
) -> "Series":
    """
    Return `n` randomly sampled values, or the specified fraction of values.

    Parameters
    ----------
    n : int, optional
        Number of documents from index to return. Cannot be used with `frac`.
        Default = 1 if `frac` = None.
    frac : float, optional
        Fraction of axis items to return, greater than 0 and at most 1.
        Cannot be used with `n`.
    random_state : int, optional
        Seed for the random number generator.

    Returns
    -------
    eland.Series:
        eland Series containing the randomly sampled values

    Raises
    ------
    ValueError
        If `frac` is outside the range (0, 1], if `n` is not a whole
        number, or if both `n` and `frac` are given.

    See Also
    --------
    :pandas_api_docs:`pandas.Series.sample`
    """
    # Apply the same argument checks as DataFrame.sample so that conflicting
    # or out-of-range arguments are rejected instead of silently ignored.
    if frac is not None and not (0.0 < frac <= 1.0):
        raise ValueError("`frac` must be between 0. and 1.")
    elif n is not None and frac is None and n % 1 != 0:
        raise ValueError("Only integers accepted as `n` values")
    elif n is not None and frac is not None:
        raise ValueError("Please enter a value for `frac` OR `n`, not both")

    return Series(query_compiler=self._query_compiler.sample(n, frac, random_state))
|
||||||
|
|
||||||
def value_counts(self, es_size=10):
|
def value_counts(self, es_size=10):
|
||||||
"""
|
"""
|
||||||
Return the value counts for the specified field.
|
Return the value counts for the specified field.
|
||||||
|
@ -9,7 +9,6 @@ from eland import SortOrder
|
|||||||
from eland.actions import HeadAction, TailAction, SortIndexAction
|
from eland.actions import HeadAction, TailAction, SortIndexAction
|
||||||
from eland.arithmetics import ArithmeticSeries
|
from eland.arithmetics import ArithmeticSeries
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .actions import PostProcessingAction # noqa: F401
|
from .actions import PostProcessingAction # noqa: F401
|
||||||
from .filter import BooleanFilter # noqa: F401
|
from .filter import BooleanFilter # noqa: F401
|
||||||
@ -175,6 +174,43 @@ class TailTask(SizeTask):
|
|||||||
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
|
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
|
||||||
|
|
||||||
|
|
||||||
|
class SampleTask(SizeTask):
    """Size task that applies random scoring to the query and caps the result size."""

    def __init__(self, sort_field: str, count: int, random_state: int):
        super().__init__("sample")
        self._sort_field = sort_field
        self._random_state = random_state
        self._count = count

    def resolve_task(
        self,
        query_params: "QueryParams",
        post_processing: List["PostProcessingAction"],
        query_compiler: "QueryCompiler",
    ) -> RESOLVED_TASK_TYPE:
        """
        Apply this task to `query_params` and queue an index sort as post-processing.

        The query gets random scoring, the size becomes the smaller of any
        existing size and this task's count, and the sort field is set only
        when none is present yet.
        """
        query_params.query.random_score(self._random_state)

        current_size = query_params.size
        query_params.size = (
            self._count if current_size is None else min(self._count, current_size)
        )

        # An already chosen sort field is left untouched.
        if query_params.sort_field is None:
            query_params.sort_field = self._sort_field

        post_processing.append(SortIndexAction())

        return query_params, post_processing

    def size(self) -> int:
        """Return the number of rows requested by this task."""
        return self._count

    def __repr__(self) -> str:
        return f"('{self._task_type}': ('count': {self._count}))"
|
||||||
|
|
||||||
|
|
||||||
class QueryIdsTask(Task):
|
class QueryIdsTask(Task):
|
||||||
def __init__(self, must: bool, ids: List[str]):
|
def __init__(self, must: bool, ids: List[str]):
|
||||||
"""
|
"""
|
||||||
|
88
eland/tests/dataframe/test_sample_pytest.py
Normal file
88
eland/tests/dataframe/test_sample_pytest.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
# Licensed to Elasticsearch B.V under one or more agreements.
|
||||||
|
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
|
||||||
|
# See the LICENSE file in the project root for more information
|
||||||
|
|
||||||
|
# File called _pytest for PyCharm compatibility
|
||||||
|
import pytest
|
||||||
|
from pandas.testing import assert_frame_equal
|
||||||
|
|
||||||
|
from eland.tests.common import TestData
|
||||||
|
from eland.utils import eland_to_pandas
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataFrameSample(TestData):
    """Tests for ``DataFrame.sample`` using the small flights dataset."""

    SEED = 42

    def build_from_index(self, sample_ed_flights):
        """Select from the pandas flights frame the rows and columns of the given sample."""
        reference = self.pd_flights_small()
        return reference.loc[sample_ed_flights.index, sample_ed_flights.columns]

    def test_sample(self):
        # Two samples drawn with the same seed must be identical.
        ed_flights_small = self.ed_flights_small()
        first_sample = ed_flights_small.sample(n=10, random_state=self.SEED)
        second_sample = ed_flights_small.sample(n=10, random_state=self.SEED)

        first_pd = eland_to_pandas(first_sample)
        second_pd = eland_to_pandas(second_sample)
        assert_frame_equal(first_pd, second_pd)

    def test_sample_raises(self):
        ed_flights_small = self.ed_flights_small()

        invalid_arguments = (
            {"n": 10, "frac": 0.1},
            {"frac": 1.5},
            {"n": -1},
        )
        for kwargs in invalid_arguments:
            with pytest.raises(ValueError):
                ed_flights_small.sample(**kwargs)

    def test_sample_basic(self):
        sample_ed_flights = self.ed_flights_small().sample(
            n=10, random_state=self.SEED
        )
        pd_from_eland = eland_to_pandas(sample_ed_flights)

        # Rebuild the expected frame from pandas using the sampled index.
        sample_pd_flights = self.build_from_index(pd_from_eland)

        assert_frame_equal(sample_pd_flights, pd_from_eland)

    def test_sample_frac_01(self):
        frac = 0.15
        ed_flights = self.ed_flights_small().sample(frac=frac, random_state=self.SEED)
        pd_from_eland = eland_to_pandas(ed_flights)
        pd_flights = self.build_from_index(pd_from_eland)

        assert_frame_equal(pd_flights, pd_from_eland)

        # The sample size must equal the rounded fraction of the full frame.
        expected_rows = int(round(frac * len(self.pd_flights_small())))
        assert len(pd_flights) == expected_rows

    def test_sample_on_boolean_filter(self):
        columns = ["timestamp", "OriginAirportID", "DestAirportID", "FlightDelayMin"]
        ed_subset = self.ed_flights_small()[columns]
        sample_ed_flights = ed_subset.sample(n=5, random_state=self.SEED)
        pd_from_eland = eland_to_pandas(sample_ed_flights)
        sample_pd_flights = self.build_from_index(pd_from_eland)

        assert_frame_equal(sample_pd_flights, pd_from_eland)

    def test_sample_head(self):
        sample_ed_flights = self.ed_flights_small().sample(
            n=10, random_state=self.SEED
        )
        sample_pd_flights = self.build_from_index(eland_to_pandas(sample_ed_flights))

        pd_head_5 = sample_pd_flights.head(5)
        ed_head_5 = sample_ed_flights.head(5)
        assert_frame_equal(pd_head_5, eland_to_pandas(ed_head_5))

    def test_sample_shape(self):
        sample_ed_flights = self.ed_flights_small().sample(
            n=10, random_state=self.SEED
        )
        sample_pd_flights = self.build_from_index(eland_to_pandas(sample_ed_flights))

        assert sample_pd_flights.shape == sample_ed_flights.shape
|
25
eland/tests/series/test_sample_pytest.py
Normal file
25
eland/tests/series/test_sample_pytest.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# Licensed to Elasticsearch B.V under one or more agreements.
|
||||||
|
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
|
||||||
|
# See the LICENSE file in the project root for more information
|
||||||
|
|
||||||
|
# File called _pytest for PyCharm compatibility
|
||||||
|
import eland as ed
|
||||||
|
from eland.tests import ES_TEST_CLIENT
|
||||||
|
from eland.tests import FLIGHTS_INDEX_NAME
|
||||||
|
from eland.tests.common import TestData
|
||||||
|
from eland.tests.common import assert_pandas_eland_series_equal
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeriesSample(TestData):
    """Tests for ``Series.sample`` on the flights ``Carrier`` field."""

    SEED = 42

    def build_from_index(self, ed_series):
        """Pick from the pandas ``Carrier`` column the entries at the sampled series' index."""
        sampled_index = ed_series._to_pandas().index
        return self.pd_flights()["Carrier"].iloc[sampled_index]

    def test_sample(self):
        ed_s = ed.Series(ES_TEST_CLIENT, FLIGHTS_INDEX_NAME, "Carrier")

        # Expected values come from a first seeded sample; a second sample
        # with the same seed is compared against them.
        seeded_sample = ed_s.sample(n=10, random_state=self.SEED)
        pd_s = self.build_from_index(seeded_sample)

        ed_s_sample = ed_s.sample(n=10, random_state=self.SEED)
        assert_pandas_eland_series_equal(pd_s, ed_s_sample)
|
Loading…
x
Reference in New Issue
Block a user