Add .sample() method to DataFrame and Series

This commit is contained in:
Daniel Mesejo-León 2020-05-04 19:07:21 +02:00 committed by GitHub
parent def3a46af9
commit 94dbb36081
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 261 additions and 10 deletions

View File

@ -0,0 +1,6 @@
eland.DataFrame.sample
======================
.. currentmodule:: eland
.. automethod:: DataFrame.sample

View File

@ -0,0 +1,6 @@
eland.Series.sample
===================
.. currentmodule:: eland
.. automethod:: Series.sample

View File

@ -21,10 +21,10 @@ Attributes and underlying data
DataFrame.index
DataFrame.columns
DataFrame.dtypes
DataFrame.select_dtypes
DataFrame.values
DataFrame.empty
DataFrame.dtypes
DataFrame.select_dtypes
DataFrame.values
DataFrame.empty
DataFrame.shape
Indexing, iteration
@ -37,6 +37,7 @@ Indexing, iteration
DataFrame.tail
DataFrame.get
DataFrame.query
DataFrame.sample
Function application, GroupBy & window
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -21,8 +21,8 @@ Attributes and underlying data
Series.index
Series.shape
Series.name
Series.empty
Series.name
Series.empty
Indexing, iteration
~~~~~~~~~~~~~~~~~~~
@ -31,6 +31,7 @@ Indexing, iteration
Series.head
Series.tail
Series.sample
Binary operator functions
~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -256,6 +256,45 @@ class DataFrame(NDFrame):
"""
return DataFrame(query_compiler=self._query_compiler.tail(n))
def sample(
self, n: int = None, frac: float = None, random_state: int = None
) -> "DataFrame":
"""
Return n randomly sample rows or the specify fraction of rows
Parameters
----------
n : int, optional
Number of documents from index to return. Cannot be used with `frac`.
Default = 1 if `frac` = None.
frac : float, optional
Fraction of axis items to return. Cannot be used with `n`.
random_state : int, optional
Seed for the random number generator.
Returns
-------
eland.DataFrame:
eland DataFrame filtered containing n rows randomly sampled
See Also
--------
:pandas_api_docs:`pandas.DataFrame.sample`
"""
if frac is not None and not (0.0 < frac <= 1.0):
raise ValueError("`frac` must be between 0. and 1.")
elif n is not None and frac is None and n % 1 != 0:
raise ValueError("Only integers accepted as `n` values")
elif (n is not None) == (frac is not None):
raise ValueError("Please enter a value for `frac` OR `n`, not both")
return DataFrame(
query_compiler=self._query_compiler.sample(
n=n, frac=frac, random_state=random_state
)
)
def drop(
self,
labels=None,

View File

@ -170,3 +170,19 @@ class QueryFilter(BooleanFilter):
def __init__(self, query: Dict[str, Any]) -> None:
super().__init__()
self._filter = query
class MatchAllFilter(QueryFilter):
def __init__(self) -> None:
super().__init__({"match_all": {}})
class RandomScoreFilter(QueryFilter):
def __init__(self, query: QueryFilter, random_state: int) -> None:
q = MatchAllFilter() if query.empty() else query
seed = {}
if random_state is not None:
seed = {"seed": random_state, "field": "_seq_no"}
super().__init__({"function_score": {"query": q.build(), "random_score": seed}})

View File

@ -488,3 +488,7 @@ class NDFrame(ABC):
@abstractmethod
def tail(self, n=5):
pass
@abstractmethod
def sample(self, n=None, frac=None, random_state=None):
pass

View File

@ -24,6 +24,7 @@ from eland.actions import SortFieldAction
from eland.tasks import (
HeadTask,
TailTask,
SampleTask,
BooleanFilterTask,
ArithmeticOpFieldsTask,
QueryTermsTask,
@ -84,6 +85,10 @@ class Operations:
task = TailTask(index.sort_field, n)
self._tasks.append(task)
def sample(self, index, n, random_state):
task = SampleTask(index.sort_field, n, random_state)
self._tasks.append(task)
def arithmetic_op_fields(self, display_name, arithmetic_series):
if self._arithmetic_op_fields_task is None:
self._arithmetic_op_fields_task = ArithmeticOpFieldsTask(

View File

@ -6,7 +6,7 @@ import warnings
from copy import deepcopy
from typing import Optional, Dict, List, Any
from eland.filter import BooleanFilter, NotNull, IsNull, IsIn
from eland.filter import RandomScoreFilter, BooleanFilter, NotNull, IsNull, IsIn
class Query:
@ -152,5 +152,8 @@ class Query:
else:
self._query = self._query & boolean_filter
def random_score(self, random_state: int) -> None:
self._query = RandomScoreFilter(self._query, random_state)
def __repr__(self) -> str:
return repr(self.to_search_body())

View File

@ -9,10 +9,10 @@ from typing import Optional, TYPE_CHECKING
import numpy as np
import pandas as pd
from eland import Index
from eland.field_mappings import FieldMappings
from eland.operations import Operations
from eland.filter import QueryFilter
from eland.operations import Operations
from eland import Index
from eland.common import (
ensure_es_client,
DEFAULT_PROGRESS_REPORTING_NUM_ROWS,
@ -393,6 +393,24 @@ class QueryCompiler:
return result
def sample(self, n=None, frac=None, random_state=None):
result = self.copy()
if n is None and frac is None:
n = 1
elif n is None and frac is not None:
index_length = self._index_count()
n = int(round(frac * index_length))
if n < 0:
raise ValueError(
"A negative number of rows requested. Please provide positive value."
)
result._operations.sample(self._index, n, random_state)
return result
def es_query(self, query):
return self._update_query(QueryFilter(query))

View File

@ -225,6 +225,9 @@ class Series(NDFrame):
def tail(self, n=5):
return Series(query_compiler=self._query_compiler.tail(n))
def sample(self, n: int = None, frac: float = None, random_state: int = None):
return Series(query_compiler=self._query_compiler.sample(n, frac, random_state))
def value_counts(self, es_size=10):
"""
Return the value counts for the specified field.

View File

@ -9,7 +9,6 @@ from eland import SortOrder
from eland.actions import HeadAction, TailAction, SortIndexAction
from eland.arithmetics import ArithmeticSeries
if TYPE_CHECKING:
from .actions import PostProcessingAction # noqa: F401
from .filter import BooleanFilter # noqa: F401
@ -175,6 +174,43 @@ class TailTask(SizeTask):
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
class SampleTask(SizeTask):
def __init__(self, sort_field: str, count: int, random_state: int):
super().__init__("sample")
self._count = count
self._random_state = random_state
self._sort_field = sort_field
def resolve_task(
self,
query_params: "QueryParams",
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
query_params.query.random_score(self._random_state)
query_sort_field = self._sort_field
query_size = self._count
if query_params.size is not None:
query_params.size = min(query_size, query_params.size)
else:
query_params.size = query_size
if query_params.sort_field is None:
query_params.sort_field = query_sort_field
post_processing.append(SortIndexAction())
return query_params, post_processing
def size(self) -> int:
return self._count
def __repr__(self) -> str:
return f"('{self._task_type}': ('count': {self._count}))"
class QueryIdsTask(Task):
def __init__(self, must: bool, ids: List[str]):
"""

View File

@ -0,0 +1,88 @@
# Licensed to Elasticsearch B.V under one or more agreements.
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
# See the LICENSE file in the project root for more information
# File called _pytest for PyCharm compatibility
import pytest
from pandas.testing import assert_frame_equal
from eland.tests.common import TestData
from eland.utils import eland_to_pandas
class TestDataFrameSample(TestData):
SEED = 42
def build_from_index(self, sample_ed_flights):
sample_pd_flights = self.pd_flights_small().loc[
sample_ed_flights.index, sample_ed_flights.columns
]
return sample_pd_flights
def test_sample(self):
ed_flights_small = self.ed_flights_small()
first_sample = ed_flights_small.sample(n=10, random_state=self.SEED)
second_sample = ed_flights_small.sample(n=10, random_state=self.SEED)
assert_frame_equal(
eland_to_pandas(first_sample), eland_to_pandas(second_sample)
)
def test_sample_raises(self):
ed_flights_small = self.ed_flights_small()
with pytest.raises(ValueError):
ed_flights_small.sample(n=10, frac=0.1)
with pytest.raises(ValueError):
ed_flights_small.sample(frac=1.5)
with pytest.raises(ValueError):
ed_flights_small.sample(n=-1)
def test_sample_basic(self):
ed_flights_small = self.ed_flights_small()
sample_ed_flights = ed_flights_small.sample(n=10, random_state=self.SEED)
pd_from_eland = eland_to_pandas(sample_ed_flights)
# build using index
sample_pd_flights = self.build_from_index(pd_from_eland)
assert_frame_equal(sample_pd_flights, pd_from_eland)
def test_sample_frac_01(self):
frac = 0.15
ed_flights = self.ed_flights_small().sample(frac=frac, random_state=self.SEED)
pd_from_eland = eland_to_pandas(ed_flights)
pd_flights = self.build_from_index(pd_from_eland)
assert_frame_equal(pd_flights, pd_from_eland)
# assert right size from pd_flights
size = len(self.pd_flights_small())
assert len(pd_flights) == int(round(frac * size))
def test_sample_on_boolean_filter(self):
ed_flights = self.ed_flights_small()
columns = ["timestamp", "OriginAirportID", "DestAirportID", "FlightDelayMin"]
sample_ed_flights = ed_flights[columns].sample(n=5, random_state=self.SEED)
pd_from_eland = eland_to_pandas(sample_ed_flights)
sample_pd_flights = self.build_from_index(pd_from_eland)
assert_frame_equal(sample_pd_flights, pd_from_eland)
def test_sample_head(self):
ed_flights = self.ed_flights_small()
sample_ed_flights = ed_flights.sample(n=10, random_state=self.SEED)
sample_pd_flights = self.build_from_index(eland_to_pandas(sample_ed_flights))
pd_head_5 = sample_pd_flights.head(5)
ed_head_5 = sample_ed_flights.head(5)
assert_frame_equal(pd_head_5, eland_to_pandas(ed_head_5))
def test_sample_shape(self):
ed_flights = self.ed_flights_small()
sample_ed_flights = ed_flights.sample(n=10, random_state=self.SEED)
sample_pd_flights = self.build_from_index(eland_to_pandas(sample_ed_flights))
assert sample_pd_flights.shape == sample_ed_flights.shape

View File

@ -0,0 +1,25 @@
# Licensed to Elasticsearch B.V under one or more agreements.
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
# See the LICENSE file in the project root for more information
# File called _pytest for PyCharm compatibility
import eland as ed
from eland.tests import ES_TEST_CLIENT
from eland.tests import FLIGHTS_INDEX_NAME
from eland.tests.common import TestData
from eland.tests.common import assert_pandas_eland_series_equal
class TestSeriesSample(TestData):
SEED = 42
def build_from_index(self, ed_series):
ed2pd_series = ed_series._to_pandas()
return self.pd_flights()["Carrier"].iloc[ed2pd_series.index]
def test_sample(self):
ed_s = ed.Series(ES_TEST_CLIENT, FLIGHTS_INDEX_NAME, "Carrier")
pd_s = self.build_from_index(ed_s.sample(n=10, random_state=self.SEED))
ed_s_sample = ed_s.sample(n=10, random_state=self.SEED)
assert_pandas_eland_series_equal(pd_s, ed_s_sample)