Add .sample() method to DataFrame and Series

This commit is contained in:
Daniel Mesejo-León 2020-05-04 19:07:21 +02:00 committed by GitHub
parent def3a46af9
commit 94dbb36081
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 261 additions and 10 deletions

View File

@ -0,0 +1,6 @@
eland.DataFrame.sample
======================
.. currentmodule:: eland
.. automethod:: DataFrame.sample

View File

@ -0,0 +1,6 @@
eland.Series.sample
===================
.. currentmodule:: eland
.. automethod:: Series.sample

View File

@ -37,6 +37,7 @@ Indexing, iteration
DataFrame.tail DataFrame.tail
DataFrame.get DataFrame.get
DataFrame.query DataFrame.query
DataFrame.sample
Function application, GroupBy & window Function application, GroupBy & window
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -31,6 +31,7 @@ Indexing, iteration
Series.head Series.head
Series.tail Series.tail
Series.sample
Binary operator functions Binary operator functions
~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -256,6 +256,45 @@ class DataFrame(NDFrame):
""" """
return DataFrame(query_compiler=self._query_compiler.tail(n)) return DataFrame(query_compiler=self._query_compiler.tail(n))
def sample(
self, n: int = None, frac: float = None, random_state: int = None
) -> "DataFrame":
"""
Return n randomly sample rows or the specify fraction of rows
Parameters
----------
n : int, optional
Number of documents from index to return. Cannot be used with `frac`.
Default = 1 if `frac` = None.
frac : float, optional
Fraction of axis items to return. Cannot be used with `n`.
random_state : int, optional
Seed for the random number generator.
Returns
-------
eland.DataFrame:
eland DataFrame filtered containing n rows randomly sampled
See Also
--------
:pandas_api_docs:`pandas.DataFrame.sample`
"""
if frac is not None and not (0.0 < frac <= 1.0):
raise ValueError("`frac` must be between 0. and 1.")
elif n is not None and frac is None and n % 1 != 0:
raise ValueError("Only integers accepted as `n` values")
elif (n is not None) == (frac is not None):
raise ValueError("Please enter a value for `frac` OR `n`, not both")
return DataFrame(
query_compiler=self._query_compiler.sample(
n=n, frac=frac, random_state=random_state
)
)
def drop( def drop(
self, self,
labels=None, labels=None,

View File

@ -170,3 +170,19 @@ class QueryFilter(BooleanFilter):
def __init__(self, query: Dict[str, Any]) -> None: def __init__(self, query: Dict[str, Any]) -> None:
super().__init__() super().__init__()
self._filter = query self._filter = query
class MatchAllFilter(QueryFilter):
def __init__(self) -> None:
super().__init__({"match_all": {}})
class RandomScoreFilter(QueryFilter):
def __init__(self, query: QueryFilter, random_state: int) -> None:
q = MatchAllFilter() if query.empty() else query
seed = {}
if random_state is not None:
seed = {"seed": random_state, "field": "_seq_no"}
super().__init__({"function_score": {"query": q.build(), "random_score": seed}})

View File

@ -488,3 +488,7 @@ class NDFrame(ABC):
@abstractmethod @abstractmethod
def tail(self, n=5): def tail(self, n=5):
pass pass
@abstractmethod
def sample(self, n=None, frac=None, random_state=None):
pass

View File

@ -24,6 +24,7 @@ from eland.actions import SortFieldAction
from eland.tasks import ( from eland.tasks import (
HeadTask, HeadTask,
TailTask, TailTask,
SampleTask,
BooleanFilterTask, BooleanFilterTask,
ArithmeticOpFieldsTask, ArithmeticOpFieldsTask,
QueryTermsTask, QueryTermsTask,
@ -84,6 +85,10 @@ class Operations:
task = TailTask(index.sort_field, n) task = TailTask(index.sort_field, n)
self._tasks.append(task) self._tasks.append(task)
def sample(self, index, n, random_state):
task = SampleTask(index.sort_field, n, random_state)
self._tasks.append(task)
def arithmetic_op_fields(self, display_name, arithmetic_series): def arithmetic_op_fields(self, display_name, arithmetic_series):
if self._arithmetic_op_fields_task is None: if self._arithmetic_op_fields_task is None:
self._arithmetic_op_fields_task = ArithmeticOpFieldsTask( self._arithmetic_op_fields_task = ArithmeticOpFieldsTask(

View File

@ -6,7 +6,7 @@ import warnings
from copy import deepcopy from copy import deepcopy
from typing import Optional, Dict, List, Any from typing import Optional, Dict, List, Any
from eland.filter import BooleanFilter, NotNull, IsNull, IsIn from eland.filter import RandomScoreFilter, BooleanFilter, NotNull, IsNull, IsIn
class Query: class Query:
@ -152,5 +152,8 @@ class Query:
else: else:
self._query = self._query & boolean_filter self._query = self._query & boolean_filter
def random_score(self, random_state: int) -> None:
self._query = RandomScoreFilter(self._query, random_state)
def __repr__(self) -> str: def __repr__(self) -> str:
return repr(self.to_search_body()) return repr(self.to_search_body())

View File

@ -9,10 +9,10 @@ from typing import Optional, TYPE_CHECKING
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from eland import Index
from eland.field_mappings import FieldMappings from eland.field_mappings import FieldMappings
from eland.operations import Operations
from eland.filter import QueryFilter from eland.filter import QueryFilter
from eland.operations import Operations
from eland import Index
from eland.common import ( from eland.common import (
ensure_es_client, ensure_es_client,
DEFAULT_PROGRESS_REPORTING_NUM_ROWS, DEFAULT_PROGRESS_REPORTING_NUM_ROWS,
@ -393,6 +393,24 @@ class QueryCompiler:
return result return result
def sample(self, n=None, frac=None, random_state=None):
result = self.copy()
if n is None and frac is None:
n = 1
elif n is None and frac is not None:
index_length = self._index_count()
n = int(round(frac * index_length))
if n < 0:
raise ValueError(
"A negative number of rows requested. Please provide positive value."
)
result._operations.sample(self._index, n, random_state)
return result
def es_query(self, query): def es_query(self, query):
return self._update_query(QueryFilter(query)) return self._update_query(QueryFilter(query))

View File

@ -225,6 +225,9 @@ class Series(NDFrame):
def tail(self, n=5): def tail(self, n=5):
return Series(query_compiler=self._query_compiler.tail(n)) return Series(query_compiler=self._query_compiler.tail(n))
def sample(self, n: int = None, frac: float = None, random_state: int = None):
return Series(query_compiler=self._query_compiler.sample(n, frac, random_state))
def value_counts(self, es_size=10): def value_counts(self, es_size=10):
""" """
Return the value counts for the specified field. Return the value counts for the specified field.

View File

@ -9,7 +9,6 @@ from eland import SortOrder
from eland.actions import HeadAction, TailAction, SortIndexAction from eland.actions import HeadAction, TailAction, SortIndexAction
from eland.arithmetics import ArithmeticSeries from eland.arithmetics import ArithmeticSeries
if TYPE_CHECKING: if TYPE_CHECKING:
from .actions import PostProcessingAction # noqa: F401 from .actions import PostProcessingAction # noqa: F401
from .filter import BooleanFilter # noqa: F401 from .filter import BooleanFilter # noqa: F401
@ -175,6 +174,43 @@ class TailTask(SizeTask):
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))" return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
class SampleTask(SizeTask):
def __init__(self, sort_field: str, count: int, random_state: int):
super().__init__("sample")
self._count = count
self._random_state = random_state
self._sort_field = sort_field
def resolve_task(
self,
query_params: "QueryParams",
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
query_params.query.random_score(self._random_state)
query_sort_field = self._sort_field
query_size = self._count
if query_params.size is not None:
query_params.size = min(query_size, query_params.size)
else:
query_params.size = query_size
if query_params.sort_field is None:
query_params.sort_field = query_sort_field
post_processing.append(SortIndexAction())
return query_params, post_processing
def size(self) -> int:
return self._count
def __repr__(self) -> str:
return f"('{self._task_type}': ('count': {self._count}))"
class QueryIdsTask(Task): class QueryIdsTask(Task):
def __init__(self, must: bool, ids: List[str]): def __init__(self, must: bool, ids: List[str]):
""" """

View File

@ -0,0 +1,88 @@
# Licensed to Elasticsearch B.V under one or more agreements.
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
# See the LICENSE file in the project root for more information
# File called _pytest for PyCharm compatibility
import pytest
from pandas.testing import assert_frame_equal
from eland.tests.common import TestData
from eland.utils import eland_to_pandas
class TestDataFrameSample(TestData):
SEED = 42
def build_from_index(self, sample_ed_flights):
sample_pd_flights = self.pd_flights_small().loc[
sample_ed_flights.index, sample_ed_flights.columns
]
return sample_pd_flights
def test_sample(self):
ed_flights_small = self.ed_flights_small()
first_sample = ed_flights_small.sample(n=10, random_state=self.SEED)
second_sample = ed_flights_small.sample(n=10, random_state=self.SEED)
assert_frame_equal(
eland_to_pandas(first_sample), eland_to_pandas(second_sample)
)
def test_sample_raises(self):
ed_flights_small = self.ed_flights_small()
with pytest.raises(ValueError):
ed_flights_small.sample(n=10, frac=0.1)
with pytest.raises(ValueError):
ed_flights_small.sample(frac=1.5)
with pytest.raises(ValueError):
ed_flights_small.sample(n=-1)
def test_sample_basic(self):
ed_flights_small = self.ed_flights_small()
sample_ed_flights = ed_flights_small.sample(n=10, random_state=self.SEED)
pd_from_eland = eland_to_pandas(sample_ed_flights)
# build using index
sample_pd_flights = self.build_from_index(pd_from_eland)
assert_frame_equal(sample_pd_flights, pd_from_eland)
def test_sample_frac_01(self):
frac = 0.15
ed_flights = self.ed_flights_small().sample(frac=frac, random_state=self.SEED)
pd_from_eland = eland_to_pandas(ed_flights)
pd_flights = self.build_from_index(pd_from_eland)
assert_frame_equal(pd_flights, pd_from_eland)
# assert right size from pd_flights
size = len(self.pd_flights_small())
assert len(pd_flights) == int(round(frac * size))
def test_sample_on_boolean_filter(self):
ed_flights = self.ed_flights_small()
columns = ["timestamp", "OriginAirportID", "DestAirportID", "FlightDelayMin"]
sample_ed_flights = ed_flights[columns].sample(n=5, random_state=self.SEED)
pd_from_eland = eland_to_pandas(sample_ed_flights)
sample_pd_flights = self.build_from_index(pd_from_eland)
assert_frame_equal(sample_pd_flights, pd_from_eland)
def test_sample_head(self):
ed_flights = self.ed_flights_small()
sample_ed_flights = ed_flights.sample(n=10, random_state=self.SEED)
sample_pd_flights = self.build_from_index(eland_to_pandas(sample_ed_flights))
pd_head_5 = sample_pd_flights.head(5)
ed_head_5 = sample_ed_flights.head(5)
assert_frame_equal(pd_head_5, eland_to_pandas(ed_head_5))
def test_sample_shape(self):
ed_flights = self.ed_flights_small()
sample_ed_flights = ed_flights.sample(n=10, random_state=self.SEED)
sample_pd_flights = self.build_from_index(eland_to_pandas(sample_ed_flights))
assert sample_pd_flights.shape == sample_ed_flights.shape

View File

@ -0,0 +1,25 @@
# Licensed to Elasticsearch B.V under one or more agreements.
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
# See the LICENSE file in the project root for more information
# File called _pytest for PyCharm compatibility
import eland as ed
from eland.tests import ES_TEST_CLIENT
from eland.tests import FLIGHTS_INDEX_NAME
from eland.tests.common import TestData
from eland.tests.common import assert_pandas_eland_series_equal
class TestSeriesSample(TestData):
SEED = 42
def build_from_index(self, ed_series):
ed2pd_series = ed_series._to_pandas()
return self.pd_flights()["Carrier"].iloc[ed2pd_series.index]
def test_sample(self):
ed_s = ed.Series(ES_TEST_CLIENT, FLIGHTS_INDEX_NAME, "Carrier")
pd_s = self.build_from_index(ed_s.sample(n=10, random_state=self.SEED))
ed_s_sample = ed_s.sample(n=10, random_state=self.SEED)
assert_pandas_eland_series_equal(pd_s, ed_s_sample)