mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Add .sample() method to DataFrame and Series
This commit is contained in:
parent
def3a46af9
commit
94dbb36081
6
docs/source/reference/api/eland.DataFrame.sample.rst
Normal file
6
docs/source/reference/api/eland.DataFrame.sample.rst
Normal file
@ -0,0 +1,6 @@
|
||||
eland.DataFrame.sample
|
||||
======================
|
||||
|
||||
.. currentmodule:: eland
|
||||
|
||||
.. automethod:: DataFrame.sample
|
6
docs/source/reference/api/eland.Series.sample.rst
Normal file
6
docs/source/reference/api/eland.Series.sample.rst
Normal file
@ -0,0 +1,6 @@
|
||||
eland.Series.sample
|
||||
===================
|
||||
|
||||
.. currentmodule:: eland
|
||||
|
||||
.. automethod:: Series.sample
|
@ -21,10 +21,10 @@ Attributes and underlying data
|
||||
|
||||
DataFrame.index
|
||||
DataFrame.columns
|
||||
DataFrame.dtypes
|
||||
DataFrame.select_dtypes
|
||||
DataFrame.values
|
||||
DataFrame.empty
|
||||
DataFrame.dtypes
|
||||
DataFrame.select_dtypes
|
||||
DataFrame.values
|
||||
DataFrame.empty
|
||||
DataFrame.shape
|
||||
|
||||
Indexing, iteration
|
||||
@ -37,6 +37,7 @@ Indexing, iteration
|
||||
DataFrame.tail
|
||||
DataFrame.get
|
||||
DataFrame.query
|
||||
DataFrame.sample
|
||||
|
||||
Function application, GroupBy & window
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -21,8 +21,8 @@ Attributes and underlying data
|
||||
|
||||
Series.index
|
||||
Series.shape
|
||||
Series.name
|
||||
Series.empty
|
||||
Series.name
|
||||
Series.empty
|
||||
|
||||
Indexing, iteration
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
@ -31,6 +31,7 @@ Indexing, iteration
|
||||
|
||||
Series.head
|
||||
Series.tail
|
||||
Series.sample
|
||||
|
||||
Binary operator functions
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -256,6 +256,45 @@ class DataFrame(NDFrame):
|
||||
"""
|
||||
return DataFrame(query_compiler=self._query_compiler.tail(n))
|
||||
|
||||
def sample(
    self, n: int = None, frac: float = None, random_state: int = None
) -> "DataFrame":
    """
    Return ``n`` randomly sampled rows, or the specified fraction of rows.

    Parameters
    ----------
    n : int, optional
        Number of documents from index to return. Cannot be used with `frac`.
        Default = 1 if `frac` = None.
    frac : float, optional
        Fraction of axis items to return; must satisfy ``0 < frac <= 1``.
        Cannot be used with `n`.
    random_state : int, optional
        Seed for the random number generator.

    Returns
    -------
    eland.DataFrame:
        eland DataFrame containing the randomly sampled rows

    Raises
    ------
    ValueError
        If `frac` is outside ``(0, 1]``, if `n` is not a whole number,
        or if both `n` and `frac` are supplied.

    See Also
    --------
    :pandas_api_docs:`pandas.DataFrame.sample`
    """

    if frac is not None and not (0.0 < frac <= 1.0):
        raise ValueError("`frac` must be between 0. and 1.")
    elif n is not None and frac is None and n % 1 != 0:
        raise ValueError("Only integers accepted as `n` values")
    elif n is not None and frac is not None:
        # Reject only the case where BOTH are supplied. Leaving both unset
        # must fall through so the documented default (n = 1) can apply.
        raise ValueError("Please enter a value for `frac` OR `n`, not both")

    return DataFrame(
        query_compiler=self._query_compiler.sample(
            n=n, frac=frac, random_state=random_state
        )
    )
|
||||
|
||||
def drop(
|
||||
self,
|
||||
labels=None,
|
||||
|
@ -170,3 +170,19 @@ class QueryFilter(BooleanFilter):
|
||||
def __init__(self, query: Dict[str, Any]) -> None:
|
||||
super().__init__()
|
||||
self._filter = query
|
||||
|
||||
|
||||
class MatchAllFilter(QueryFilter):
    """Query filter whose body is the ``match_all`` query."""

    def __init__(self) -> None:
        match_all_query = {"match_all": {}}
        super().__init__(match_all_query)
|
||||
|
||||
|
||||
class RandomScoreFilter(QueryFilter):
    """Wrap a filter in a ``function_score`` query that uses ``random_score``."""

    def __init__(self, query: QueryFilter, random_state: int) -> None:
        """
        Parameters
        ----------
        query : QueryFilter
            Filter to wrap; a ``MatchAllFilter`` is substituted when it
            reports itself empty.
        random_state : int
            Seed for ``random_score``. ``None`` leaves the ``random_score``
            body empty (no seed, no field).
        """
        inner = query if not query.empty() else MatchAllFilter()

        # The seed is emitted together with a `field` (here `_seq_no`).
        random_score = (
            {"seed": random_state, "field": "_seq_no"}
            if random_state is not None
            else {}
        )

        super().__init__(
            {"function_score": {"query": inner.build(), "random_score": random_score}}
        )
|
||||
|
@ -488,3 +488,7 @@ class NDFrame(ABC):
|
||||
@abstractmethod
|
||||
def tail(self, n=5):
|
||||
pass
|
||||
|
||||
@abstractmethod
def sample(self, n=None, frac=None, random_state=None):
    """Return a random sample of rows; concrete subclasses provide the implementation."""
    return None
|
||||
|
@ -24,6 +24,7 @@ from eland.actions import SortFieldAction
|
||||
from eland.tasks import (
|
||||
HeadTask,
|
||||
TailTask,
|
||||
SampleTask,
|
||||
BooleanFilterTask,
|
||||
ArithmeticOpFieldsTask,
|
||||
QueryTermsTask,
|
||||
@ -84,6 +85,10 @@ class Operations:
|
||||
task = TailTask(index.sort_field, n)
|
||||
self._tasks.append(task)
|
||||
|
||||
def sample(self, index, n, random_state):
    """Queue a ``SampleTask`` for ``n`` rows using the index's sort field.

    Parameters
    ----------
    index
        Object exposing a ``sort_field`` attribute.
    n
        Number of rows the task should request.
    random_state
        Seed handed to the task (may be ``None``).
    """
    sort_field = index.sort_field
    sample_task = SampleTask(sort_field, n, random_state)
    self._tasks.append(sample_task)
|
||||
|
||||
def arithmetic_op_fields(self, display_name, arithmetic_series):
|
||||
if self._arithmetic_op_fields_task is None:
|
||||
self._arithmetic_op_fields_task = ArithmeticOpFieldsTask(
|
||||
|
@ -6,7 +6,7 @@ import warnings
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Dict, List, Any
|
||||
|
||||
from eland.filter import BooleanFilter, NotNull, IsNull, IsIn
|
||||
from eland.filter import RandomScoreFilter, BooleanFilter, NotNull, IsNull, IsIn
|
||||
|
||||
|
||||
class Query:
|
||||
@ -152,5 +152,8 @@ class Query:
|
||||
else:
|
||||
self._query = self._query & boolean_filter
|
||||
|
||||
def random_score(self, random_state: int) -> None:
    """Replace the current query with a ``RandomScoreFilter`` wrapping it.

    ``random_state`` is passed through as the seed of the new filter.
    """
    current_query = self._query
    self._query = RandomScoreFilter(current_query, random_state)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(self.to_search_body())
|
||||
|
@ -9,10 +9,10 @@ from typing import Optional, TYPE_CHECKING
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from eland import Index
|
||||
from eland.field_mappings import FieldMappings
|
||||
from eland.operations import Operations
|
||||
from eland.filter import QueryFilter
|
||||
from eland.operations import Operations
|
||||
from eland import Index
|
||||
from eland.common import (
|
||||
ensure_es_client,
|
||||
DEFAULT_PROGRESS_REPORTING_NUM_ROWS,
|
||||
@ -393,6 +393,24 @@ class QueryCompiler:
|
||||
|
||||
return result
|
||||
|
||||
def sample(self, n=None, frac=None, random_state=None):
    """Return a copy of this query compiler with a sample operation queued.

    Parameters
    ----------
    n : int, optional
        Number of rows to sample. When given, it takes precedence over
        ``frac``. Defaults to 1 when ``frac`` is also ``None``.
    frac : float, optional
        Fraction of the index to sample; only consulted when ``n`` is
        ``None``, in which case ``n`` becomes
        ``int(round(frac * index_count))``.
    random_state : int, optional
        Seed forwarded to the sample operation.

    Raises
    ------
    ValueError
        If the resolved ``n`` is negative.
    """
    sampled = self.copy()

    if n is None:
        if frac is None:
            n = 1
        else:
            n = int(round(frac * self._index_count()))

    if n < 0:
        raise ValueError(
            "A negative number of rows requested. Please provide positive value."
        )

    sampled._operations.sample(self._index, n, random_state)
    return sampled
|
||||
|
||||
def es_query(self, query):
|
||||
return self._update_query(QueryFilter(query))
|
||||
|
||||
|
@ -225,6 +225,9 @@ class Series(NDFrame):
|
||||
def tail(self, n=5):
|
||||
return Series(query_compiler=self._query_compiler.tail(n))
|
||||
|
||||
def sample(self, n: int = None, frac: float = None, random_state: int = None):
    """Return a Series built from a sampled query compiler.

    ``n``, ``frac`` and ``random_state`` are forwarded unchanged to the
    query compiler's ``sample``; no validation is performed here.
    """
    sampled_compiler = self._query_compiler.sample(n, frac, random_state)
    return Series(query_compiler=sampled_compiler)
|
||||
|
||||
def value_counts(self, es_size=10):
|
||||
"""
|
||||
Return the value counts for the specified field.
|
||||
|
@ -9,7 +9,6 @@ from eland import SortOrder
|
||||
from eland.actions import HeadAction, TailAction, SortIndexAction
|
||||
from eland.arithmetics import ArithmeticSeries
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .actions import PostProcessingAction # noqa: F401
|
||||
from .filter import BooleanFilter # noqa: F401
|
||||
@ -175,6 +174,43 @@ class TailTask(SizeTask):
|
||||
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
|
||||
|
||||
|
||||
class SampleTask(SizeTask):
    """Size task that applies random scoring and caps the query size."""

    def __init__(self, sort_field: str, count: int, random_state: int):
        """
        Parameters
        ----------
        sort_field : str
            Sort field used when the query has none set.
        count : int
            Number of rows requested by the sample.
        random_state : int
            Seed passed to the query's ``random_score``.
        """
        super().__init__("sample")
        self._sort_field = sort_field
        self._random_state = random_state
        self._count = count

    def resolve_task(
        self,
        query_params: "QueryParams",
        post_processing: List["PostProcessingAction"],
        query_compiler: "QueryCompiler",
    ) -> RESOLVED_TASK_TYPE:
        """Apply this task to ``query_params`` and queue a ``SortIndexAction``.

        Returns the (mutated) ``query_params`` and ``post_processing``.
        """
        query_params.query.random_score(self._random_state)

        # An existing size limit may only be tightened, never widened.
        requested = self._count
        current = query_params.size
        query_params.size = requested if current is None else min(requested, current)

        # Keep any sort field that is already set.
        if query_params.sort_field is None:
            query_params.sort_field = self._sort_field

        post_processing.append(SortIndexAction())

        return query_params, post_processing

    def size(self) -> int:
        """Return the requested number of rows."""
        return self._count

    def __repr__(self) -> str:
        return f"('{self._task_type}': ('count': {self._count}))"
|
||||
|
||||
|
||||
class QueryIdsTask(Task):
|
||||
def __init__(self, must: bool, ids: List[str]):
|
||||
"""
|
||||
|
88
eland/tests/dataframe/test_sample_pytest.py
Normal file
88
eland/tests/dataframe/test_sample_pytest.py
Normal file
@ -0,0 +1,88 @@
|
||||
# Licensed to Elasticsearch B.V under one or more agreements.
|
||||
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
|
||||
# See the LICENSE file in the project root for more information
|
||||
|
||||
# File called _pytest for PyCharm compatibility
|
||||
import pytest
|
||||
from pandas.testing import assert_frame_equal
|
||||
|
||||
from eland.tests.common import TestData
|
||||
from eland.utils import eland_to_pandas
|
||||
|
||||
|
||||
class TestDataFrameSample(TestData):
    """Tests for ``DataFrame.sample`` using the small flights data."""

    SEED = 42

    def build_from_index(self, sample_ed_flights):
        """Select from the small pandas flights frame the rows and columns of the given sample."""
        pd_flights_small = self.pd_flights_small()
        return pd_flights_small.loc[sample_ed_flights.index, sample_ed_flights.columns]

    def test_sample(self):
        """The same seed gives the same sample twice."""
        ed_flights_small = self.ed_flights_small()
        first_sample = ed_flights_small.sample(n=10, random_state=self.SEED)
        second_sample = ed_flights_small.sample(n=10, random_state=self.SEED)

        first_pd = eland_to_pandas(first_sample)
        second_pd = eland_to_pandas(second_sample)
        assert_frame_equal(first_pd, second_pd)

    def test_sample_raises(self):
        """Invalid argument combinations raise ``ValueError``."""
        ed_flights_small = self.ed_flights_small()

        invalid_arguments = (
            {"n": 10, "frac": 0.1},
            {"frac": 1.5},
            {"n": -1},
        )
        for kwargs in invalid_arguments:
            with pytest.raises(ValueError):
                ed_flights_small.sample(**kwargs)

    def test_sample_basic(self):
        """Sampled rows equal the pandas rows with the same index."""
        ed_sample = self.ed_flights_small().sample(n=10, random_state=self.SEED)
        pd_from_eland = eland_to_pandas(ed_sample)

        # build using index
        expected = self.build_from_index(pd_from_eland)

        assert_frame_equal(expected, pd_from_eland)

    def test_sample_frac_01(self):
        """A fractional sample has the expected rows and row count."""
        frac = 0.15
        ed_sample = self.ed_flights_small().sample(frac=frac, random_state=self.SEED)
        pd_from_eland = eland_to_pandas(ed_sample)
        pd_flights = self.build_from_index(pd_from_eland)

        assert_frame_equal(pd_flights, pd_from_eland)

        # assert right size from pd_flights
        total_rows = len(self.pd_flights_small())
        assert len(pd_flights) == int(round(frac * total_rows))

    def test_sample_on_boolean_filter(self):
        """Sampling a column subset matches pandas."""
        columns = ["timestamp", "OriginAirportID", "DestAirportID", "FlightDelayMin"]
        ed_flights = self.ed_flights_small()
        ed_sample = ed_flights[columns].sample(n=5, random_state=self.SEED)
        pd_from_eland = eland_to_pandas(ed_sample)
        expected = self.build_from_index(pd_from_eland)

        assert_frame_equal(expected, pd_from_eland)

    def test_sample_head(self):
        """``head(5)`` of a sample matches ``head(5)`` of the pandas equivalent."""
        ed_sample = self.ed_flights_small().sample(n=10, random_state=self.SEED)
        pd_sample = self.build_from_index(eland_to_pandas(ed_sample))

        pd_head_5 = pd_sample.head(5)
        ed_head_5 = ed_sample.head(5)
        assert_frame_equal(pd_head_5, eland_to_pandas(ed_head_5))

    def test_sample_shape(self):
        """The sample's shape matches its pandas equivalent."""
        ed_sample = self.ed_flights_small().sample(n=10, random_state=self.SEED)
        pd_sample = self.build_from_index(eland_to_pandas(ed_sample))

        assert pd_sample.shape == ed_sample.shape
|
25
eland/tests/series/test_sample_pytest.py
Normal file
25
eland/tests/series/test_sample_pytest.py
Normal file
@ -0,0 +1,25 @@
|
||||
# Licensed to Elasticsearch B.V under one or more agreements.
|
||||
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
|
||||
# See the LICENSE file in the project root for more information
|
||||
|
||||
# File called _pytest for PyCharm compatibility
|
||||
import eland as ed
|
||||
from eland.tests import ES_TEST_CLIENT
|
||||
from eland.tests import FLIGHTS_INDEX_NAME
|
||||
from eland.tests.common import TestData
|
||||
from eland.tests.common import assert_pandas_eland_series_equal
|
||||
|
||||
|
||||
class TestSeriesSample(TestData):
    """Tests for ``Series.sample`` on the flights ``Carrier`` field."""

    SEED = 42

    def build_from_index(self, ed_series):
        """Pick from the pandas ``Carrier`` column the positions given by the eland series' index."""
        sampled_index = ed_series._to_pandas().index
        carrier = self.pd_flights()["Carrier"]
        return carrier.iloc[sampled_index]

    def test_sample(self):
        """A seeded sample matches the pandas rows rebuilt from an identically seeded sample."""
        ed_s = ed.Series(ES_TEST_CLIENT, FLIGHTS_INDEX_NAME, "Carrier")

        reference_sample = ed_s.sample(n=10, random_state=self.SEED)
        pd_s = self.build_from_index(reference_sample)

        ed_s_sample = ed_s.sample(n=10, random_state=self.SEED)
        assert_pandas_eland_series_equal(pd_s, ed_s_sample)
|
Loading…
x
Reference in New Issue
Block a user