mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Add mode() method to DataFrame and Series
This commit is contained in:
parent
27717eead1
commit
421d84fd20
6
docs/sphinx/reference/api/eland.DataFrame.mode.rst
Normal file
6
docs/sphinx/reference/api/eland.DataFrame.mode.rst
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
eland.DataFrame.mode
|
||||||
|
====================
|
||||||
|
|
||||||
|
.. currentmodule:: eland
|
||||||
|
|
||||||
|
.. automethod:: DataFrame.mode
|
6
docs/sphinx/reference/api/eland.Series.mode.rst
Normal file
6
docs/sphinx/reference/api/eland.Series.mode.rst
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
eland.Series.mode
|
||||||
|
====================
|
||||||
|
|
||||||
|
.. currentmodule:: eland
|
||||||
|
|
||||||
|
.. automethod:: Series.mode
|
@ -89,6 +89,7 @@ Computations / Descriptive Stats
|
|||||||
DataFrame.var
|
DataFrame.var
|
||||||
DataFrame.sum
|
DataFrame.sum
|
||||||
DataFrame.nunique
|
DataFrame.nunique
|
||||||
|
DataFrame.mode
|
||||||
|
|
||||||
Reindexing / Selection / Label Manipulation
|
Reindexing / Selection / Label Manipulation
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
@ -79,6 +79,7 @@ Computations / Descriptive Stats
|
|||||||
Series.var
|
Series.var
|
||||||
Series.nunique
|
Series.nunique
|
||||||
Series.value_counts
|
Series.value_counts
|
||||||
|
Series.mode
|
||||||
|
|
||||||
Reindexing / Selection / Label Manipulation
|
Reindexing / Selection / Label Manipulation
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
@ -1624,6 +1624,68 @@ class DataFrame(NDFrame):
|
|||||||
by=by, query_compiler=self._query_compiler.copy(), dropna=dropna
|
by=by, query_compiler=self._query_compiler.copy(), dropna=dropna
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def mode(
|
||||||
|
self,
|
||||||
|
numeric_only: bool = False,
|
||||||
|
dropna: bool = True,
|
||||||
|
es_size: int = 10,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Calculate mode of a DataFrame
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
numeric_only: {True, False} Default is False
|
||||||
|
Which datatype to be returned
|
||||||
|
- True: Returns all numeric or timestamp columns
|
||||||
|
- False: Returns all columns
|
||||||
|
dropna: {True, False} Default is True
|
||||||
|
- True: Don’t consider counts of NaN/NaT.
|
||||||
|
- False: Consider counts of NaN/NaT.
|
||||||
|
es_size: default 10
|
||||||
|
number of rows to be returned if mode has multiple values
|
||||||
|
|
||||||
|
See Also
|
||||||
|
--------
|
||||||
|
:pandas_api_docs:`pandas.DataFrame.mode`
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> ed_ecommerce = ed.DataFrame('localhost', 'ecommerce')
|
||||||
|
>>> ed_df = ed_ecommerce.filter(["total_quantity", "geoip.city_name", "customer_birth_date", "day_of_week", "taxful_total_price"])
|
||||||
|
>>> ed_df.mode(numeric_only=False)
|
||||||
|
total_quantity geoip.city_name customer_birth_date day_of_week taxful_total_price
|
||||||
|
0 2 New York NaT Thursday 53.98
|
||||||
|
|
||||||
|
>>> ed_df.mode(numeric_only=True)
|
||||||
|
total_quantity taxful_total_price
|
||||||
|
0 2 53.98
|
||||||
|
|
||||||
|
>>> ed_df = ed_ecommerce.filter(["products.tax_amount","order_date"])
|
||||||
|
>>> ed_df.mode()
|
||||||
|
products.tax_amount order_date
|
||||||
|
0 0.0 2016-12-02 20:36:58
|
||||||
|
1 NaN 2016-12-04 23:44:10
|
||||||
|
2 NaN 2016-12-08 06:21:36
|
||||||
|
3 NaN 2016-12-08 09:38:53
|
||||||
|
4 NaN 2016-12-12 11:38:24
|
||||||
|
5 NaN 2016-12-12 19:46:34
|
||||||
|
6 NaN 2016-12-14 18:00:00
|
||||||
|
7 NaN 2016-12-15 11:38:24
|
||||||
|
8 NaN 2016-12-22 19:39:22
|
||||||
|
9 NaN 2016-12-24 06:21:36
|
||||||
|
|
||||||
|
>>> ed_df.mode(es_size = 3)
|
||||||
|
products.tax_amount order_date
|
||||||
|
0 0.0 2016-12-02 20:36:58
|
||||||
|
1 NaN 2016-12-04 23:44:10
|
||||||
|
2 NaN 2016-12-08 06:21:36
|
||||||
|
"""
|
||||||
|
# TODO dropna=False
|
||||||
|
return self._query_compiler.mode(
|
||||||
|
numeric_only=numeric_only, dropna=True, is_dataframe=True, es_size=es_size
|
||||||
|
)
|
||||||
|
|
||||||
def query(self, expr) -> "DataFrame":
|
def query(self, expr) -> "DataFrame":
|
||||||
"""
|
"""
|
||||||
Query the columns of a DataFrame with a boolean expression.
|
Query the columns of a DataFrame with a boolean expression.
|
||||||
|
@ -102,9 +102,13 @@ class Field(NamedTuple):
|
|||||||
# Except "median_absolute_deviation" which doesn't support bool
|
# Except "median_absolute_deviation" which doesn't support bool
|
||||||
if es_agg == "median_absolute_deviation" and self.is_bool:
|
if es_agg == "median_absolute_deviation" and self.is_bool:
|
||||||
return False
|
return False
|
||||||
# Cardinality and Count work for all types
|
# Cardinality, Count and mode work for all types
|
||||||
# Numerics and bools work for all aggs
|
# Numerics and bools work for all aggs
|
||||||
if es_agg in ("cardinality", "value_count") or self.is_numeric or self.is_bool:
|
if (
|
||||||
|
es_agg in {"cardinality", "value_count", "mode"}
|
||||||
|
or self.is_numeric
|
||||||
|
or self.is_bool
|
||||||
|
):
|
||||||
return True
|
return True
|
||||||
# Timestamps also work for 'min', 'max' and 'avg'
|
# Timestamps also work for 'min', 'max' and 'avg'
|
||||||
if es_agg in {"min", "max", "avg", "percentiles"} and self.is_timestamp:
|
if es_agg in {"min", "max", "avg", "percentiles"} and self.is_timestamp:
|
||||||
|
@ -617,3 +617,6 @@ class DataFrameGroupBy(GroupBy):
|
|||||||
numeric_only=False,
|
numeric_only=False,
|
||||||
is_dataframe_agg=False,
|
is_dataframe_agg=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def mode(self) -> None:
|
||||||
|
raise NotImplementedError("Currently mode is not supported for groupby")
|
||||||
|
@ -181,7 +181,7 @@ class Operations:
|
|||||||
dtype = "object"
|
dtype = "object"
|
||||||
return build_pd_series(results, index=results.keys(), dtype=dtype)
|
return build_pd_series(results, index=results.keys(), dtype=dtype)
|
||||||
|
|
||||||
def value_counts(self, query_compiler, es_size):
|
def value_counts(self, query_compiler: "QueryCompiler", es_size: int) -> pd.Series:
|
||||||
return self._terms_aggs(query_compiler, "terms", es_size)
|
return self._terms_aggs(query_compiler, "terms", es_size)
|
||||||
|
|
||||||
def hist(self, query_compiler, bins):
|
def hist(self, query_compiler, bins):
|
||||||
@ -195,12 +195,54 @@ class Operations:
|
|||||||
results, index=pd_aggs, dtype=(np.float64 if numeric_only else None)
|
results, index=pd_aggs, dtype=(np.float64 if numeric_only else None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def mode(
|
||||||
|
self,
|
||||||
|
query_compiler: "QueryCompiler",
|
||||||
|
pd_aggs: List[str],
|
||||||
|
is_dataframe: bool,
|
||||||
|
es_size: int,
|
||||||
|
numeric_only: bool = False,
|
||||||
|
dropna: bool = True,
|
||||||
|
) -> Union[pd.DataFrame, pd.Series]:
|
||||||
|
|
||||||
|
results = self._metric_aggs(
|
||||||
|
query_compiler,
|
||||||
|
pd_aggs=pd_aggs,
|
||||||
|
numeric_only=numeric_only,
|
||||||
|
dropna=dropna,
|
||||||
|
es_mode_size=es_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
pd_dict: Dict[str, Any] = {}
|
||||||
|
row_diff: Optional[int] = None
|
||||||
|
|
||||||
|
if is_dataframe:
|
||||||
|
# If multiple mode values are returned for a particular column
|
||||||
|
# find the maximum length and use that to fill dataframe with NaN/NaT
|
||||||
|
rows_len = max([len(value) for value in results.values()])
|
||||||
|
for key, values in results.items():
|
||||||
|
row_diff = rows_len - len(values)
|
||||||
|
# Convert np.ndarray to list
|
||||||
|
values = list(values)
|
||||||
|
if row_diff:
|
||||||
|
if isinstance(values[0], pd.Timestamp):
|
||||||
|
values.extend([pd.NaT] * row_diff)
|
||||||
|
else:
|
||||||
|
values.extend([np.NaN] * row_diff)
|
||||||
|
pd_dict[key] = values
|
||||||
|
|
||||||
|
return pd.DataFrame(pd_dict)
|
||||||
|
else:
|
||||||
|
return pd.DataFrame(results.values()).iloc[0].rename()
|
||||||
|
|
||||||
def _metric_aggs(
|
def _metric_aggs(
|
||||||
self,
|
self,
|
||||||
query_compiler: "QueryCompiler",
|
query_compiler: "QueryCompiler",
|
||||||
pd_aggs: List[str],
|
pd_aggs: List[str],
|
||||||
numeric_only: Optional[bool] = None,
|
numeric_only: Optional[bool] = None,
|
||||||
is_dataframe_agg: bool = False,
|
is_dataframe_agg: bool = False,
|
||||||
|
es_mode_size: Optional[int] = None,
|
||||||
|
dropna: bool = True,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Used to calculate metric aggregations
|
Used to calculate metric aggregations
|
||||||
@ -216,6 +258,10 @@ class Operations:
|
|||||||
return either all numeric values or NaN/NaT
|
return either all numeric values or NaN/NaT
|
||||||
is_dataframe_agg:
|
is_dataframe_agg:
|
||||||
know if this method is called from single-agg or aggregation method
|
know if this method is called from single-agg or aggregation method
|
||||||
|
es_mode_size:
|
||||||
|
number of rows to return when multiple mode values are present.
|
||||||
|
dropna:
|
||||||
|
drop NaN/NaT for a dataframe
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -252,6 +298,15 @@ class Operations:
|
|||||||
es_agg[0],
|
es_agg[0],
|
||||||
field.aggregatable_es_field_name,
|
field.aggregatable_es_field_name,
|
||||||
)
|
)
|
||||||
|
elif es_agg == "mode":
|
||||||
|
# TODO for dropna=False, Check If field is timestamp or boolean or numeric,
|
||||||
|
# then use missing parameter for terms aggregation.
|
||||||
|
body.terms_aggs(
|
||||||
|
f"{es_agg}_{field.es_field_name}",
|
||||||
|
"terms",
|
||||||
|
field.aggregatable_es_field_name,
|
||||||
|
es_mode_size,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
body.metric_aggs(
|
body.metric_aggs(
|
||||||
f"{es_agg}_{field.es_field_name}",
|
f"{es_agg}_{field.es_field_name}",
|
||||||
@ -280,7 +335,9 @@ class Operations:
|
|||||||
is_dataframe_agg=is_dataframe_agg,
|
is_dataframe_agg=is_dataframe_agg,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _terms_aggs(self, query_compiler, func, es_size=None):
|
def _terms_aggs(
|
||||||
|
self, query_compiler: "QueryCompiler", func: str, es_size: int
|
||||||
|
) -> pd.Series:
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -499,13 +556,43 @@ class Operations:
|
|||||||
agg_value = np.sqrt(
|
agg_value = np.sqrt(
|
||||||
(count / (count - 1.0)) * agg_value * agg_value
|
(count / (count - 1.0)) * agg_value * agg_value
|
||||||
)
|
)
|
||||||
|
elif es_agg == "mode":
|
||||||
|
# For terms aggregation buckets are returned
|
||||||
|
# agg_value will be of type list
|
||||||
|
agg_value = response["aggregations"][
|
||||||
|
f"{es_agg}_{field.es_field_name}"
|
||||||
|
]["buckets"]
|
||||||
else:
|
else:
|
||||||
agg_value = response["aggregations"][
|
agg_value = response["aggregations"][
|
||||||
f"{es_agg}_{field.es_field_name}"
|
f"{es_agg}_{field.es_field_name}"
|
||||||
]["value"]
|
]["value"]
|
||||||
|
|
||||||
|
if isinstance(agg_value, list):
|
||||||
|
# include top-terms in the result.
|
||||||
|
if not agg_value:
|
||||||
|
# If all the documents for a field are empty
|
||||||
|
agg_value = [field.nan_value]
|
||||||
|
else:
|
||||||
|
max_doc_count = agg_value[0]["doc_count"]
|
||||||
|
# We need only keys which are equal to max_doc_count
|
||||||
|
# lesser values are ignored
|
||||||
|
agg_value = [
|
||||||
|
item["key"]
|
||||||
|
for item in agg_value
|
||||||
|
if item["doc_count"] == max_doc_count
|
||||||
|
]
|
||||||
|
|
||||||
|
# Maintain datatype by default because pandas does the same
|
||||||
|
# text are returned as-is
|
||||||
|
if field.is_bool or field.is_numeric:
|
||||||
|
agg_value = [
|
||||||
|
field.np_dtype.type(value) for value in agg_value
|
||||||
|
]
|
||||||
|
|
||||||
# Null usually means there were no results.
|
# Null usually means there were no results.
|
||||||
if agg_value is None or np.isnan(agg_value):
|
if not isinstance(agg_value, list) and (
|
||||||
|
agg_value is None or np.isnan(agg_value)
|
||||||
|
):
|
||||||
if is_dataframe_agg and not numeric_only:
|
if is_dataframe_agg and not numeric_only:
|
||||||
agg_value = np.NaN
|
agg_value = np.NaN
|
||||||
elif not is_dataframe_agg and numeric_only is False:
|
elif not is_dataframe_agg and numeric_only is False:
|
||||||
@ -517,13 +604,22 @@ class Operations:
|
|||||||
|
|
||||||
# If this is a non-null timestamp field convert to a pd.Timestamp()
|
# If this is a non-null timestamp field convert to a pd.Timestamp()
|
||||||
elif field.is_timestamp:
|
elif field.is_timestamp:
|
||||||
|
if isinstance(agg_value, list):
|
||||||
|
# convert to timestamp results for mode
|
||||||
|
agg_value = [
|
||||||
|
elasticsearch_date_to_pandas_date(
|
||||||
|
value, field.es_date_format
|
||||||
|
)
|
||||||
|
for value in agg_value
|
||||||
|
]
|
||||||
|
else:
|
||||||
agg_value = elasticsearch_date_to_pandas_date(
|
agg_value = elasticsearch_date_to_pandas_date(
|
||||||
agg_value, field.es_date_format
|
agg_value, field.es_date_format
|
||||||
)
|
)
|
||||||
# If numeric_only is False | None then maintain column datatype
|
# If numeric_only is False | None then maintain column datatype
|
||||||
elif not numeric_only:
|
elif not numeric_only:
|
||||||
# we're only converting to bool for lossless aggs like min, max, and median.
|
# we're only converting to bool for lossless aggs like min, max, and median.
|
||||||
if pd_agg in {"max", "min", "median", "sum"}:
|
if pd_agg in {"max", "min", "median", "sum", "mode"}:
|
||||||
# 'sum' isn't representable with bool, use int64
|
# 'sum' isn't representable with bool, use int64
|
||||||
if pd_agg == "sum" and field.is_bool:
|
if pd_agg == "sum" and field.is_bool:
|
||||||
agg_value = np.int64(agg_value)
|
agg_value = np.int64(agg_value)
|
||||||
@ -791,10 +887,15 @@ class Operations:
|
|||||||
elif pd_agg == "median":
|
elif pd_agg == "median":
|
||||||
es_aggs.append(("percentiles", "50.0"))
|
es_aggs.append(("percentiles", "50.0"))
|
||||||
|
|
||||||
# Not implemented
|
|
||||||
elif pd_agg == "mode":
|
elif pd_agg == "mode":
|
||||||
# We could do this via top term
|
if len(pd_aggs) != 1:
|
||||||
raise NotImplementedError(pd_agg, " not currently implemented")
|
raise NotImplementedError(
|
||||||
|
"Currently mode is not supported in df.agg(...). Try df.mode()"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
es_aggs.append("mode")
|
||||||
|
|
||||||
|
# Not implemented
|
||||||
elif pd_agg == "quantile":
|
elif pd_agg == "quantile":
|
||||||
# TODO
|
# TODO
|
||||||
raise NotImplementedError(pd_agg, " not currently implemented")
|
raise NotImplementedError(pd_agg, " not currently implemented")
|
||||||
|
@ -101,7 +101,14 @@ class Query:
|
|||||||
else:
|
else:
|
||||||
self._query = self._query & Rlike(field, value)
|
self._query = self._query & Rlike(field, value)
|
||||||
|
|
||||||
def terms_aggs(self, name: str, func: str, field: str, es_size: int) -> None:
|
def terms_aggs(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
func: str,
|
||||||
|
field: str,
|
||||||
|
es_size: Optional[int] = None,
|
||||||
|
missing: Optional[Any] = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add terms agg e.g
|
Add terms agg e.g
|
||||||
|
|
||||||
@ -109,12 +116,18 @@ class Query:
|
|||||||
"name": {
|
"name": {
|
||||||
"terms": {
|
"terms": {
|
||||||
"field": "Airline",
|
"field": "Airline",
|
||||||
"size": 10
|
"size": 10,
|
||||||
|
"missing": "null"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
agg = {func: {"field": field, "size": es_size}}
|
agg = {func: {"field": field}}
|
||||||
|
if es_size:
|
||||||
|
agg[func]["size"] = str(es_size)
|
||||||
|
|
||||||
|
if missing:
|
||||||
|
agg[func]["missing"] = missing
|
||||||
self._aggs[name] = agg
|
self._aggs[name] = agg
|
||||||
|
|
||||||
def metric_aggs(self, name: str, func: str, field: str) -> None:
|
def metric_aggs(self, name: str, func: str, field: str) -> None:
|
||||||
|
@ -621,6 +621,22 @@ class QueryCompiler:
|
|||||||
self, ["nunique"], numeric_only=False
|
self, ["nunique"], numeric_only=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def mode(
|
||||||
|
self,
|
||||||
|
es_size: int,
|
||||||
|
numeric_only: bool = False,
|
||||||
|
dropna: bool = True,
|
||||||
|
is_dataframe: bool = True,
|
||||||
|
) -> Union[pd.DataFrame, pd.Series]:
|
||||||
|
return self._operations.mode(
|
||||||
|
self,
|
||||||
|
pd_aggs=["mode"],
|
||||||
|
numeric_only=numeric_only,
|
||||||
|
dropna=dropna,
|
||||||
|
is_dataframe=is_dataframe,
|
||||||
|
es_size=es_size,
|
||||||
|
)
|
||||||
|
|
||||||
def aggs_groupby(
|
def aggs_groupby(
|
||||||
self,
|
self,
|
||||||
by: List[str],
|
by: List[str],
|
||||||
@ -638,7 +654,7 @@ class QueryCompiler:
|
|||||||
numeric_only=numeric_only,
|
numeric_only=numeric_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
def value_counts(self, es_size):
|
def value_counts(self, es_size: int) -> pd.Series:
|
||||||
return self._operations.value_counts(self, es_size)
|
return self._operations.value_counts(self, es_size)
|
||||||
|
|
||||||
def es_info(self, buf):
|
def es_info(self, buf):
|
||||||
|
@ -637,6 +637,48 @@ class Series(NDFrame):
|
|||||||
)
|
)
|
||||||
return Series(_query_compiler=new_query_compiler)
|
return Series(_query_compiler=new_query_compiler)
|
||||||
|
|
||||||
|
def mode(self, es_size: int = 10) -> pd.Series:
|
||||||
|
"""
|
||||||
|
Calculate mode of a series
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
es_size: default 10
|
||||||
|
number of rows to be returned if mode has multiple values
|
||||||
|
|
||||||
|
See Also
|
||||||
|
--------
|
||||||
|
:pandas_api_docs:`pandas.Series.mode`
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> ed_ecommerce = ed.DataFrame('localhost', 'ecommerce')
|
||||||
|
>>> ed_ecommerce["day_of_week"].mode()
|
||||||
|
0 Thursday
|
||||||
|
dtype: object
|
||||||
|
|
||||||
|
>>> ed_ecommerce["order_date"].mode()
|
||||||
|
0 2016-12-02 20:36:58
|
||||||
|
1 2016-12-04 23:44:10
|
||||||
|
2 2016-12-08 06:21:36
|
||||||
|
3 2016-12-08 09:38:53
|
||||||
|
4 2016-12-12 11:38:24
|
||||||
|
5 2016-12-12 19:46:34
|
||||||
|
6 2016-12-14 18:00:00
|
||||||
|
7 2016-12-15 11:38:24
|
||||||
|
8 2016-12-22 19:39:22
|
||||||
|
9 2016-12-24 06:21:36
|
||||||
|
dtype: datetime64[ns]
|
||||||
|
|
||||||
|
>>> ed_ecommerce["order_date"].mode(es_size=3)
|
||||||
|
0 2016-12-02 20:36:58
|
||||||
|
1 2016-12-04 23:44:10
|
||||||
|
2 2016-12-08 06:21:36
|
||||||
|
dtype: datetime64[ns]
|
||||||
|
|
||||||
|
"""
|
||||||
|
return self._query_compiler.mode(is_dataframe=False, es_size=es_size)
|
||||||
|
|
||||||
def es_match(
|
def es_match(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
|
@ -194,3 +194,9 @@ class TestGroupbyDataFrame(TestData):
|
|||||||
assert_index_equal(pd_min_mad.columns, ed_min_mad.columns)
|
assert_index_equal(pd_min_mad.columns, ed_min_mad.columns)
|
||||||
assert_index_equal(pd_min_mad.index, ed_min_mad.index)
|
assert_index_equal(pd_min_mad.index, ed_min_mad.index)
|
||||||
assert_series_equal(pd_min_mad.dtypes, ed_min_mad.dtypes)
|
assert_series_equal(pd_min_mad.dtypes, ed_min_mad.dtypes)
|
||||||
|
|
||||||
|
def test_groupby_mode(self):
|
||||||
|
ed_flights = self.ed_flights()
|
||||||
|
match = "Currently mode is not supported for groupby"
|
||||||
|
with pytest.raises(NotImplementedError, match=match):
|
||||||
|
ed_flights.groupby("Cancelled").mode()
|
||||||
|
@ -426,3 +426,23 @@ class TestDataFrameMetrics(TestData):
|
|||||||
ed_count = ed_flights.agg(["count"])
|
ed_count = ed_flights.agg(["count"])
|
||||||
|
|
||||||
assert_frame_equal(pd_count, ed_count)
|
assert_frame_equal(pd_count, ed_count)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("numeric_only", [True, False])
|
||||||
|
@pytest.mark.parametrize("es_size", [1, 2, 20, 100, 5000, 3000])
|
||||||
|
def test_aggs_mode(self, es_size, numeric_only):
|
||||||
|
# FlightNum has unique values, so we can test `fill` NaN/NaT for remaining columns
|
||||||
|
pd_flights = self.pd_flights().filter(
|
||||||
|
["Cancelled", "dayOfWeek", "timestamp", "DestCountry", "FlightNum"]
|
||||||
|
)
|
||||||
|
ed_flights = self.ed_flights().filter(
|
||||||
|
["Cancelled", "dayOfWeek", "timestamp", "DestCountry", "FlightNum"]
|
||||||
|
)
|
||||||
|
|
||||||
|
pd_mode = pd_flights.mode(numeric_only=numeric_only)[:es_size]
|
||||||
|
ed_mode = ed_flights.mode(numeric_only=numeric_only, es_size=es_size)
|
||||||
|
|
||||||
|
# Skipping dtype check because eland is giving Cancelled dtype as bool
|
||||||
|
# but pandas is referring to it as object
|
||||||
|
assert_frame_equal(
|
||||||
|
pd_mode, ed_mode, check_dtype=(False if es_size == 1 else True)
|
||||||
|
)
|
||||||
|
@ -22,6 +22,7 @@ from datetime import timedelta
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
|
from pandas.testing import assert_series_equal
|
||||||
|
|
||||||
from tests.common import TestData, assert_almost_equal
|
from tests.common import TestData, assert_almost_equal
|
||||||
|
|
||||||
@ -114,3 +115,25 @@ class TestSeriesMetrics(TestData):
|
|||||||
<= median
|
<= median
|
||||||
<= pd.to_datetime("2018-01-01 12:00:00.000")
|
<= pd.to_datetime("2018-01-01 12:00:00.000")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"column", ["day_of_week", "geoip.region_name", "taxful_total_price", "user"]
|
||||||
|
)
|
||||||
|
def test_ecommerce_mode(self, column):
|
||||||
|
ed_series = self.ed_ecommerce()
|
||||||
|
pd_series = self.pd_ecommerce()
|
||||||
|
|
||||||
|
ed_mode = ed_series[column].mode()
|
||||||
|
pd_mode = pd_series[column].mode()
|
||||||
|
|
||||||
|
assert_series_equal(ed_mode, pd_mode)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("es_size", [1, 2, 10, 20])
|
||||||
|
def test_ecommerce_mode_es_size(self, es_size):
|
||||||
|
ed_series = self.ed_ecommerce()
|
||||||
|
pd_series = self.pd_ecommerce()
|
||||||
|
|
||||||
|
pd_mode = pd_series["order_date"].mode()[:es_size]
|
||||||
|
ed_mode = ed_series["order_date"].mode(es_size)
|
||||||
|
|
||||||
|
assert_series_equal(pd_mode, ed_mode)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user