Add idxmax and idxmin methods to DataFrame

This commit is contained in:
P. Sai Vinay 2021-07-28 18:25:26 +05:30 committed by GitHub
parent c74fccbd74
commit 4c1af42c14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 177 additions and 5 deletions

View File

@ -0,0 +1,6 @@
eland.DataFrame.idxmax
========================
.. currentmodule:: eland
.. automethod:: DataFrame.idxmax

View File

@ -0,0 +1,6 @@
eland.DataFrame.idxmin
========================
.. currentmodule:: eland
.. automethod:: DataFrame.idxmin

View File

@ -101,6 +101,8 @@ Computations / Descriptive Stats
DataFrame.nunique
DataFrame.mode
DataFrame.quantile
DataFrame.idxmax
DataFrame.idxmin
Reindexing / Selection / Label Manipulation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -1738,6 +1738,70 @@ class DataFrame(NDFrame):
"""
return self._query_compiler.quantile(quantiles=q, numeric_only=numeric_only)
def idxmax(self, axis: int = 0) -> pd.Series:
"""
Return index of first occurrence of maximum over requested axis.
NA/null values are excluded.
Parameters
----------
axis : {0, 1}, default 0
The axis to filter on, expressed as index (int).
Returns
-------
pandas.Series
See Also
--------
:pandas_api_docs:`pandas.DataFrame.idxmax`
Examples
--------
>>> ed_df = ed.DataFrame('localhost', 'flights')
>>> ed_flights = ed_df.filter(["AvgTicketPrice", "FlightDelayMin", "dayOfWeek", "timestamp"])
>>> ed_flights.idxmax()
AvgTicketPrice 1843
FlightDelayMin 109
dayOfWeek 1988
dtype: object
"""
return self._query_compiler.idx(axis=axis, sort_order="desc")
def idxmin(self, axis: int = 0) -> pd.Series:
"""
Return index of first occurrence of minimum over requested axis.
NA/null values are excluded.
Parameters
----------
axis : {0, 1}, default 0
The axis to filter on, expressed as index (int).
Returns
-------
pandas.Series
See Also
--------
:pandas_api_docs:`pandas.DataFrame.idxmin`
Examples
--------
>>> ed_df = ed.DataFrame('localhost', 'flights')
>>> ed_flights = ed_df.filter(["AvgTicketPrice", "FlightDelayMin", "dayOfWeek", "timestamp"])
>>> ed_flights.idxmin()
AvgTicketPrice 5454
FlightDelayMin 0
dayOfWeek 0
dtype: object
"""
return self._query_compiler.idx(axis=axis, sort_order="asc")
def query(self, expr) -> "DataFrame":
"""
Query the columns of a DataFrame with a boolean expression.

View File

@ -187,6 +187,56 @@ class Operations:
def hist(self, query_compiler, bins):
return self._hist_aggs(query_compiler, bins)
def idx(
self, query_compiler: "QueryCompiler", axis: int, sort_order: str
) -> pd.Series:
if axis == 1:
# Fetch idx on Columns
raise NotImplementedError(
"This feature is not implemented yet for 'axis = 1'"
)
# Fetch idx on Index
query_params, post_processing = self._resolve_tasks(query_compiler)
fields = query_compiler._mappings.all_source_fields()
# Consider only Numeric fields
fields = [field for field in fields if (field.is_numeric)]
body = Query(query_params.query)
for field in fields:
body.top_hits_agg(
name=f"top_hits_{field.es_field_name}",
source_columns=[field.es_field_name],
sort_order=sort_order,
size=1,
)
# Fetch Response
response = query_compiler._client.search(
index=query_compiler._index_pattern, size=0, body=body.to_search_body()
)
response = response["aggregations"]
results = {}
for field in fields:
res = response[f"top_hits_{field.es_field_name}"]["hits"]
if not res["total"]["value"] > 0:
raise ValueError("Empty Index with no rows")
if not res["hits"][0]["_source"]:
# This means there are NaN Values, we skip them
# Implement this when skipna is implemented
continue
else:
results[field.es_field_name] = res["hits"][0]["_id"]
return pd.Series(results)
def aggs(self, query_compiler, pd_aggs, numeric_only=None) -> pd.DataFrame:
results = self._metric_aggs(
query_compiler, pd_aggs, numeric_only=numeric_only, is_dataframe_agg=True

View File

@ -163,6 +163,22 @@ class Query:
agg = {"percentiles": {"field": field, "percents": percents}}
self._aggs[name] = agg
def top_hits_agg(
self,
name: str,
source_columns: List[str],
sort_order: str,
size: int = 1,
) -> None:
top_hits: Any = {}
if sort_order:
top_hits["sort"] = [{i: {"order": sort_order}} for i in source_columns]
if source_columns:
top_hits["_source"] = {"includes": source_columns}
top_hits["size"] = size
self._aggs[name] = {"top_hits": top_hits}
def composite_agg_bucket_terms(self, name: str, field: str) -> None:
"""
Add terms agg for composite aggregation

View File

@ -685,6 +685,9 @@ class QueryCompiler:
numeric_only=numeric_only,
)
def idx(self, axis: int, sort_order: str) -> pd.Series:
return self._operations.idx(self, axis=axis, sort_order=sort_order)
def value_counts(self, es_size: int) -> pd.Series:
return self._operations.value_counts(self, es_size)

View File

@ -72,16 +72,17 @@ def lint(session):
for typed_file in TYPED_FILES:
if not os.path.isfile(typed_file):
session.error(f"The file {typed_file!r} couldn't be found")
popen = subprocess.Popen(
f"mypy --strict {typed_file}",
process = subprocess.run(
["mypy", "--strict", typed_file],
env=session.env,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
popen.wait()
# Ensure that mypy itself ran successfully
assert process.returncode in (0, 1)
errors = []
for line in popen.stdout.read().decode().split("\n"):
for line in process.stdout.decode().split("\n"):
filepath = line.partition(":")[0]
if filepath in TYPED_FILES:
errors.append(line)

View File

@ -498,3 +498,27 @@ class TestDataFrameMetrics(TestData):
assert_frame_equal(
pd_quantile, ed_quantile, check_exact=False, rtol=4, check_dtype=False
)
def test_flights_idx_on_index(self):
pd_flights = self.pd_flights().filter(
["AvgTicketPrice", "FlightDelayMin", "dayOfWeek"]
)
ed_flights = self.ed_flights().filter(
["AvgTicketPrice", "FlightDelayMin", "dayOfWeek"]
)
pd_idxmax = pd_flights.idxmax()
ed_idxmax = ed_flights.idxmax()
assert_series_equal(pd_idxmax, ed_idxmax)
pd_idxmin = pd_flights.idxmin()
ed_idxmin = ed_flights.idxmin()
assert_series_equal(pd_idxmin, ed_idxmin)
def test_flights_idx_on_columns(self):
match = "This feature is not implemented yet for 'axis = 1'"
with pytest.raises(NotImplementedError, match=match):
ed_flights = self.ed_flights().filter(
["AvgTicketPrice", "FlightDelayMin", "dayOfWeek"]
)
ed_flights.idxmax(axis=1)