From 94dbb3608114d91fc0d487125f6ae88d1eb4d5cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Mesejo-Le=C3=B3n?= Date: Mon, 4 May 2020 19:07:21 +0200 Subject: [PATCH] Add .sample() method to DataFrame and Series --- .../reference/api/eland.DataFrame.sample.rst | 6 ++ .../reference/api/eland.Series.sample.rst | 6 ++ docs/source/reference/dataframe.rst | 9 +- docs/source/reference/series.rst | 5 +- eland/dataframe.py | 39 ++++++++ eland/filter.py | 16 ++++ eland/ndframe.py | 4 + eland/operations.py | 5 ++ eland/query.py | 5 +- eland/query_compiler.py | 22 ++++- eland/series.py | 3 + eland/tasks.py | 38 +++++++- eland/tests/dataframe/test_sample_pytest.py | 88 +++++++++++++++++++ eland/tests/series/test_sample_pytest.py | 25 ++++++ 14 files changed, 261 insertions(+), 10 deletions(-) create mode 100644 docs/source/reference/api/eland.DataFrame.sample.rst create mode 100644 docs/source/reference/api/eland.Series.sample.rst create mode 100644 eland/tests/dataframe/test_sample_pytest.py create mode 100644 eland/tests/series/test_sample_pytest.py diff --git a/docs/source/reference/api/eland.DataFrame.sample.rst b/docs/source/reference/api/eland.DataFrame.sample.rst new file mode 100644 index 0000000..59e9e82 --- /dev/null +++ b/docs/source/reference/api/eland.DataFrame.sample.rst @@ -0,0 +1,6 @@ +eland.DataFrame.sample +====================== + +.. currentmodule:: eland + +.. automethod:: DataFrame.sample diff --git a/docs/source/reference/api/eland.Series.sample.rst b/docs/source/reference/api/eland.Series.sample.rst new file mode 100644 index 0000000..22a85bc --- /dev/null +++ b/docs/source/reference/api/eland.Series.sample.rst @@ -0,0 +1,6 @@ +eland.Series.sample +=================== + +.. currentmodule:: eland + +.. 
automethod:: Series.sample diff --git a/docs/source/reference/dataframe.rst b/docs/source/reference/dataframe.rst index 9da5a29..cf94043 100644 --- a/docs/source/reference/dataframe.rst +++ b/docs/source/reference/dataframe.rst @@ -21,10 +21,10 @@ Attributes and underlying data DataFrame.index DataFrame.columns - DataFrame.dtypes - DataFrame.select_dtypes - DataFrame.values - DataFrame.empty + DataFrame.dtypes + DataFrame.select_dtypes + DataFrame.values + DataFrame.empty DataFrame.shape Indexing, iteration @@ -37,6 +37,7 @@ Indexing, iteration DataFrame.tail DataFrame.get DataFrame.query + DataFrame.sample Function application, GroupBy & window ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/reference/series.rst b/docs/source/reference/series.rst index f9d2804..3ab6889 100644 --- a/docs/source/reference/series.rst +++ b/docs/source/reference/series.rst @@ -21,8 +21,8 @@ Attributes and underlying data Series.index Series.shape - Series.name - Series.empty + Series.name + Series.empty Indexing, iteration ~~~~~~~~~~~~~~~~~~~ @@ -31,6 +31,7 @@ Indexing, iteration Series.head Series.tail + Series.sample Binary operator functions ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/eland/dataframe.py b/eland/dataframe.py index 7b84fc8..2510528 100644 --- a/eland/dataframe.py +++ b/eland/dataframe.py @@ -256,6 +256,45 @@ class DataFrame(NDFrame): """ return DataFrame(query_compiler=self._query_compiler.tail(n)) + def sample( + self, n: int = None, frac: float = None, random_state: int = None + ) -> "DataFrame": + """ + Return n randomly sampled rows or the specified fraction of rows + + Parameters + ---------- + n : int, optional + Number of documents from index to return. Cannot be used with `frac`. + Default = 1 if `frac` = None. + frac : float, optional + Fraction of axis items to return. Cannot be used with `n`. + random_state : int, optional + Seed for the random number generator.
+ + Returns + ------- + eland.DataFrame: + eland DataFrame containing the randomly sampled rows + + See Also + -------- + :pandas_api_docs:`pandas.DataFrame.sample` + """ + + if frac is not None and not (0.0 < frac <= 1.0): + raise ValueError("`frac` must be between 0. and 1.") + elif n is not None and frac is None and n % 1 != 0: + raise ValueError("Only integers accepted as `n` values") + elif n is not None and frac is not None: + raise ValueError("Please enter a value for `frac` OR `n`, not both") + + return DataFrame( + query_compiler=self._query_compiler.sample( + n=n, frac=frac, random_state=random_state + ) + ) + def drop( self, labels=None, diff --git a/eland/filter.py b/eland/filter.py index f78729f..bd2801b 100644 --- a/eland/filter.py +++ b/eland/filter.py @@ -170,3 +170,19 @@ class QueryFilter(BooleanFilter): def __init__(self, query: Dict[str, Any]) -> None: super().__init__() self._filter = query + + +class MatchAllFilter(QueryFilter): + def __init__(self) -> None: + super().__init__({"match_all": {}}) + + +class RandomScoreFilter(QueryFilter): + def __init__(self, query: QueryFilter, random_state: int) -> None: + q = MatchAllFilter() if query.empty() else query + + seed = {} + if random_state is not None: + seed = {"seed": random_state, "field": "_seq_no"} + + super().__init__({"function_score": {"query": q.build(), "random_score": seed}}) diff --git a/eland/ndframe.py b/eland/ndframe.py index 1b79f43..f037605 100644 --- a/eland/ndframe.py +++ b/eland/ndframe.py @@ -488,3 +488,7 @@ class NDFrame(ABC): @abstractmethod def tail(self, n=5): pass + + @abstractmethod + def sample(self, n=None, frac=None, random_state=None): + pass diff --git a/eland/operations.py b/eland/operations.py index ce23cfd..9520371 100644 --- a/eland/operations.py +++ b/eland/operations.py @@ -24,6 +24,7 @@ from eland.actions import SortFieldAction from eland.tasks import ( HeadTask, TailTask, + SampleTask, BooleanFilterTask, ArithmeticOpFieldsTask, QueryTermsTask, @@
-84,6 +85,10 @@ class Operations: task = TailTask(index.sort_field, n) self._tasks.append(task) + def sample(self, index, n, random_state): + task = SampleTask(index.sort_field, n, random_state) + self._tasks.append(task) + def arithmetic_op_fields(self, display_name, arithmetic_series): if self._arithmetic_op_fields_task is None: self._arithmetic_op_fields_task = ArithmeticOpFieldsTask( diff --git a/eland/query.py b/eland/query.py index 4083489..a04b5e6 100644 --- a/eland/query.py +++ b/eland/query.py @@ -6,7 +6,7 @@ import warnings from copy import deepcopy from typing import Optional, Dict, List, Any -from eland.filter import BooleanFilter, NotNull, IsNull, IsIn +from eland.filter import RandomScoreFilter, BooleanFilter, NotNull, IsNull, IsIn class Query: @@ -152,5 +152,8 @@ class Query: else: self._query = self._query & boolean_filter + def random_score(self, random_state: int) -> None: + self._query = RandomScoreFilter(self._query, random_state) + def __repr__(self) -> str: return repr(self.to_search_body()) diff --git a/eland/query_compiler.py b/eland/query_compiler.py index 8283eb1..1386082 100644 --- a/eland/query_compiler.py +++ b/eland/query_compiler.py @@ -9,10 +9,10 @@ from typing import Optional, TYPE_CHECKING import numpy as np import pandas as pd -from eland import Index from eland.field_mappings import FieldMappings -from eland.operations import Operations from eland.filter import QueryFilter +from eland.operations import Operations +from eland import Index from eland.common import ( ensure_es_client, DEFAULT_PROGRESS_REPORTING_NUM_ROWS, @@ -393,6 +393,24 @@ class QueryCompiler: return result + def sample(self, n=None, frac=None, random_state=None): + result = self.copy() + + if n is None and frac is None: + n = 1 + elif n is None and frac is not None: + index_length = self._index_count() + n = int(round(frac * index_length)) + + if n < 0: + raise ValueError( + "A negative number of rows requested. Please provide positive value." 
+ ) + + result._operations.sample(self._index, n, random_state) + + return result + def es_query(self, query): return self._update_query(QueryFilter(query)) diff --git a/eland/series.py b/eland/series.py index a809afb..4343d84 100644 --- a/eland/series.py +++ b/eland/series.py @@ -225,6 +225,9 @@ class Series(NDFrame): def tail(self, n=5): return Series(query_compiler=self._query_compiler.tail(n)) + def sample(self, n: int = None, frac: float = None, random_state: int = None): + return Series(query_compiler=self._query_compiler.sample(n, frac, random_state)) + def value_counts(self, es_size=10): """ Return the value counts for the specified field. diff --git a/eland/tasks.py b/eland/tasks.py index bcfe701..6c8ca1e 100644 --- a/eland/tasks.py +++ b/eland/tasks.py @@ -9,7 +9,6 @@ from eland import SortOrder from eland.actions import HeadAction, TailAction, SortIndexAction from eland.arithmetics import ArithmeticSeries - if TYPE_CHECKING: from .actions import PostProcessingAction # noqa: F401 from .filter import BooleanFilter # noqa: F401 @@ -175,6 +174,43 @@ class TailTask(SizeTask): return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))" +class SampleTask(SizeTask): + def __init__(self, sort_field: str, count: int, random_state: int): + super().__init__("sample") + self._count = count + self._random_state = random_state + self._sort_field = sort_field + + def resolve_task( + self, + query_params: "QueryParams", + post_processing: List["PostProcessingAction"], + query_compiler: "QueryCompiler", + ) -> RESOLVED_TASK_TYPE: + query_params.query.random_score(self._random_state) + + query_sort_field = self._sort_field + query_size = self._count + + if query_params.size is not None: + query_params.size = min(query_size, query_params.size) + else: + query_params.size = query_size + + if query_params.sort_field is None: + query_params.sort_field = query_sort_field + + post_processing.append(SortIndexAction()) + + return query_params, 
post_processing + + def size(self) -> int: + return self._count + + def __repr__(self) -> str: + return f"('{self._task_type}': ('count': {self._count}))" + + class QueryIdsTask(Task): def __init__(self, must: bool, ids: List[str]): """ diff --git a/eland/tests/dataframe/test_sample_pytest.py b/eland/tests/dataframe/test_sample_pytest.py new file mode 100644 index 0000000..4bf265d --- /dev/null +++ b/eland/tests/dataframe/test_sample_pytest.py @@ -0,0 +1,88 @@ +# Licensed to Elasticsearch B.V under one or more agreements. +# Elasticsearch B.V licenses this file to you under the Apache 2.0 License. +# See the LICENSE file in the project root for more information + +# File called _pytest for PyCharm compatibility +import pytest +from pandas.testing import assert_frame_equal + +from eland.tests.common import TestData +from eland.utils import eland_to_pandas + + +class TestDataFrameSample(TestData): + SEED = 42 + + def build_from_index(self, sample_ed_flights): + sample_pd_flights = self.pd_flights_small().loc[ + sample_ed_flights.index, sample_ed_flights.columns + ] + return sample_pd_flights + + def test_sample(self): + ed_flights_small = self.ed_flights_small() + first_sample = ed_flights_small.sample(n=10, random_state=self.SEED) + second_sample = ed_flights_small.sample(n=10, random_state=self.SEED) + + assert_frame_equal( + eland_to_pandas(first_sample), eland_to_pandas(second_sample) + ) + + def test_sample_raises(self): + ed_flights_small = self.ed_flights_small() + + with pytest.raises(ValueError): + ed_flights_small.sample(n=10, frac=0.1) + + with pytest.raises(ValueError): + ed_flights_small.sample(frac=1.5) + + with pytest.raises(ValueError): + ed_flights_small.sample(n=-1) + + def test_sample_basic(self): + ed_flights_small = self.ed_flights_small() + sample_ed_flights = ed_flights_small.sample(n=10, random_state=self.SEED) + pd_from_eland = eland_to_pandas(sample_ed_flights) + + # build using index + sample_pd_flights = 
self.build_from_index(pd_from_eland) + + assert_frame_equal(sample_pd_flights, pd_from_eland) + + def test_sample_frac_01(self): + frac = 0.15 + ed_flights = self.ed_flights_small().sample(frac=frac, random_state=self.SEED) + pd_from_eland = eland_to_pandas(ed_flights) + pd_flights = self.build_from_index(pd_from_eland) + + assert_frame_equal(pd_flights, pd_from_eland) + + # assert right size from pd_flights + size = len(self.pd_flights_small()) + assert len(pd_flights) == int(round(frac * size)) + + def test_sample_on_boolean_filter(self): + ed_flights = self.ed_flights_small() + columns = ["timestamp", "OriginAirportID", "DestAirportID", "FlightDelayMin"] + sample_ed_flights = ed_flights[columns].sample(n=5, random_state=self.SEED) + pd_from_eland = eland_to_pandas(sample_ed_flights) + sample_pd_flights = self.build_from_index(pd_from_eland) + + assert_frame_equal(sample_pd_flights, pd_from_eland) + + def test_sample_head(self): + ed_flights = self.ed_flights_small() + sample_ed_flights = ed_flights.sample(n=10, random_state=self.SEED) + sample_pd_flights = self.build_from_index(eland_to_pandas(sample_ed_flights)) + + pd_head_5 = sample_pd_flights.head(5) + ed_head_5 = sample_ed_flights.head(5) + assert_frame_equal(pd_head_5, eland_to_pandas(ed_head_5)) + + def test_sample_shape(self): + ed_flights = self.ed_flights_small() + sample_ed_flights = ed_flights.sample(n=10, random_state=self.SEED) + sample_pd_flights = self.build_from_index(eland_to_pandas(sample_ed_flights)) + + assert sample_pd_flights.shape == sample_ed_flights.shape diff --git a/eland/tests/series/test_sample_pytest.py b/eland/tests/series/test_sample_pytest.py new file mode 100644 index 0000000..856d59e --- /dev/null +++ b/eland/tests/series/test_sample_pytest.py @@ -0,0 +1,25 @@ +# Licensed to Elasticsearch B.V under one or more agreements. +# Elasticsearch B.V licenses this file to you under the Apache 2.0 License. 
+# See the LICENSE file in the project root for more information + +# File called _pytest for PyCharm compatibility +import eland as ed +from eland.tests import ES_TEST_CLIENT +from eland.tests import FLIGHTS_INDEX_NAME +from eland.tests.common import TestData +from eland.tests.common import assert_pandas_eland_series_equal + + +class TestSeriesSample(TestData): + SEED = 42 + + def build_from_index(self, ed_series): + ed2pd_series = ed_series._to_pandas() + return self.pd_flights()["Carrier"].iloc[ed2pd_series.index] + + def test_sample(self): + ed_s = ed.Series(ES_TEST_CLIENT, FLIGHTS_INDEX_NAME, "Carrier") + pd_s = self.build_from_index(ed_s.sample(n=10, random_state=self.SEED)) + + ed_s_sample = ed_s.sample(n=10, random_state=self.SEED) + assert_pandas_eland_series_equal(pd_s, ed_s_sample)