Support Series/collections in Series.isin(), add type hints

This commit is contained in:
Seth Michael Larson 2020-07-14 11:39:52 -05:00 committed by GitHub
parent 6e6ad04c5c
commit 140623283a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 157 additions and 99 deletions

View File

@ -1040,7 +1040,7 @@ script instead of being modified manually.
+---------------------------------------+------------+
| ``ed.Series.__dir__()`` | **Yes** |
+---------------------------------------+------------+
| ``ed.Series.__div__()`` | No |
| ``ed.Series.__div__()`` | **Yes** |
+---------------------------------------+------------+
| ``ed.Series.__divmod__()`` | No |
+---------------------------------------+------------+
@ -1134,7 +1134,7 @@ script instead of being modified manually.
+---------------------------------------+------------+
| ``ed.Series.__rand__()`` | No |
+---------------------------------------+------------+
| ``ed.Series.__rdiv__()`` | No |
| ``ed.Series.__rdiv__()`` | **Yes** |
+---------------------------------------+------------+
| ``ed.Series.__rdivmod__()`` | No |
+---------------------------------------+------------+

View File

@ -138,7 +138,7 @@ class ArithmeticSeries(ArithmeticObject):
for task in self._tasks:
if task.op_name == "__add__":
value = f"({value} + {task.object.resolve()})"
elif task.op_name == "__truediv__":
elif task.op_name in ("__truediv__", "__div__"):
value = f"({value} / {task.object.resolve()})"
elif task.op_name == "__floordiv__":
value = f"Math.floor({value} / {task.object.resolve()})"

View File

@ -19,7 +19,7 @@ import sys
import warnings
from io import StringIO
import re
from typing import Optional, Sequence, Union
from typing import Optional, Sequence, Union, Tuple
import numpy as np
import pandas as pd
@ -631,7 +631,7 @@ class DataFrame(NDFrame):
def info_es(self):
return self.es_info()
def es_query(self, query):
def es_query(self, query) -> "DataFrame":
"""Applies an Elasticsearch DSL query to the current DataFrame.
Parameters
@ -705,7 +705,7 @@ class DataFrame(NDFrame):
def info(
self, verbose=None, buf=None, max_cols=None, memory_usage=None, null_counts=None
):
) -> None:
"""
Print a concise summary of a DataFrame.
@ -822,7 +822,7 @@ class DataFrame(NDFrame):
dtype = dtypes.iloc[i]
col = pprint_thing(col)
line_no = _put_str(" {num}".format(num=i), space_num)
line_no = _put_str(f" {i}", space_num)
count = ""
if show_counts:
@ -1223,7 +1223,7 @@ class DataFrame(NDFrame):
}
return self._query_compiler.to_csv(**kwargs)
def to_pandas(self, show_progress: bool = False) -> "DataFrame":
def to_pandas(self, show_progress: bool = False) -> pd.DataFrame:
"""
Utility method to convert eland.DataFrame to pandas.DataFrame
@ -1233,10 +1233,10 @@ class DataFrame(NDFrame):
"""
return self._query_compiler.to_pandas(show_progress=show_progress)
def _empty_pd_df(self):
def _empty_pd_df(self) -> pd.DataFrame:
return self._query_compiler._empty_pd_ef()
def select_dtypes(self, include=None, exclude=None):
def select_dtypes(self, include=None, exclude=None) -> "DataFrame":
"""
Return a subset of the DataFrame's columns based on the column dtypes.
@ -1272,7 +1272,7 @@ class DataFrame(NDFrame):
return self._getitem_array(empty_df.columns)
@property
def shape(self):
def shape(self) -> Tuple[int, int]:
"""
Return a tuple representing the dimensionality of the DataFrame.
@ -1299,7 +1299,23 @@ class DataFrame(NDFrame):
return num_rows, num_columns
def keys(self):
@property
def ndim(self) -> int:
"""
Return the number of dimensions of a DataFrame, which is always 2.
Returns
-------
int
By definition 2
See Also
--------
:pandas_api_docs:`pandas.DataFrame.ndim`
"""
return 2
def keys(self) -> pd.Index:
"""
Return columns
@ -1381,7 +1397,7 @@ class DataFrame(NDFrame):
hist = gfx.ed_hist_frame
def query(self, expr):
def query(self, expr) -> "DataFrame":
"""
Query the columns of a DataFrame with a boolean expression.
@ -1474,7 +1490,7 @@ class DataFrame(NDFrame):
like: Optional[str] = None,
regex: Optional[str] = None,
axis: Optional[Union[int, str]] = None,
):
) -> "DataFrame":
"""
Subset the dataframe rows or columns according to the specified index labels.
Note that this routine does not filter a dataframe on its

View File

@ -27,15 +27,16 @@ from pandas.core.dtypes.common import (
is_string_dtype,
)
from pandas.core.dtypes.inference import is_list_like
from typing import NamedTuple, Optional, Mapping, Dict, Any, TYPE_CHECKING
from typing import NamedTuple, Optional, Mapping, Dict, Any, TYPE_CHECKING, List, Set
if TYPE_CHECKING:
from elasticsearch import Elasticsearch
from eland import DataFrame
ES_FLOAT_TYPES = {"double", "float", "half_float", "scaled_float"}
ES_INTEGER_TYPES = {"long", "integer", "short", "byte"}
ES_COMPATIBLE_TYPES = {
ES_FLOAT_TYPES: Set[str] = {"double", "float", "half_float", "scaled_float"}
ES_INTEGER_TYPES: Set[str] = {"long", "integer", "short", "byte"}
ES_COMPATIBLE_TYPES: Dict[str, Set[str]] = {
"double": ES_FLOAT_TYPES,
"scaled_float": ES_FLOAT_TYPES,
"float": ES_FLOAT_TYPES,
@ -80,7 +81,7 @@ class Field(NamedTuple):
def np_dtype(self):
return np.dtype(self.pd_dtype)
def is_es_agg_compatible(self, es_agg):
def is_es_agg_compatible(self, es_agg) -> bool:
# Cardinality works for all types
# Numerics and bools work for all aggs
if es_agg == "cardinality" or self.is_numeric or self.is_bool:
@ -115,7 +116,7 @@ class FieldMappings:
or es_field_name.keyword (if exists) or None
"""
ES_DTYPE_TO_PD_DTYPE = {
ES_DTYPE_TO_PD_DTYPE: Dict[str, str] = {
"text": "object",
"keyword": "object",
"long": "int64",
@ -133,7 +134,7 @@ class FieldMappings:
}
# the labels for each column (display_name is index)
column_labels = [
column_labels: List[str] = [
"es_field_name",
"is_source",
"es_dtype",
@ -145,7 +146,12 @@ class FieldMappings:
"aggregatable_es_field_name",
]
def __init__(self, client=None, index_pattern=None, display_names=None):
def __init__(
self,
client: "Elasticsearch",
index_pattern: str,
display_names: Optional[List[str]] = None,
):
"""
Parameters
----------
@ -184,7 +190,9 @@ class FieldMappings:
self.display_names = display_names
@staticmethod
def _extract_fields_from_mapping(mappings, source_only=False, date_format=None):
def _extract_fields_from_mapping(
mappings: Dict[str, Any], source_only: bool = False
) -> Dict[str, str]:
"""
Extract all field names and types from a mapping.
```
@ -256,10 +264,10 @@ class FieldMappings:
# Recurse until we get a 'type: xxx'
def flatten(x, name=""):
if type(x) is dict:
if isinstance(x, dict):
for a in x:
if (
a == "type" and type(x[a]) is str
if a == "type" and isinstance(
x[a], str
): # 'type' can be a name of a field
field_name = name[:-1]
field_type = x[a]

View File

@ -17,10 +17,14 @@
import sys
from abc import ABC, abstractmethod
from typing import Tuple
from typing import TYPE_CHECKING, Tuple
import pandas as pd
from eland.query_compiler import QueryCompiler
if TYPE_CHECKING:
from eland.index import Index
"""
NDFrame
---------
@ -73,7 +77,8 @@ class NDFrame(ABC):
)
self._query_compiler = _query_compiler
def _get_index(self):
@property
def index(self) -> "Index":
"""
Return eland index referencing Elasticsearch field to index a DataFrame/Series
@ -100,10 +105,8 @@ class NDFrame(ABC):
"""
return self._query_compiler.index
index = property(_get_index)
@property
def dtypes(self):
def dtypes(self) -> pd.Series:
"""
Return the pandas dtypes in the DataFrame. Elasticsearch types are mapped
to pandas dtypes via Mappings._es_dtype_to_pd_dtype.__doc__
@ -129,7 +132,7 @@ class NDFrame(ABC):
"""
return self._query_compiler.dtypes
def _build_repr(self, num_rows):
def _build_repr(self, num_rows) -> pd.DataFrame:
# self could be Series or DataFrame
if len(self.index) <= num_rows:
return self.to_pandas()
@ -144,11 +147,11 @@ class NDFrame(ABC):
return head.append(tail)
def __sizeof__(self):
def __sizeof__(self) -> int:
# Don't default to pandas, just return approximation TODO - make this more accurate
return sys.getsizeof(self._query_compiler)
def __len__(self):
def __len__(self) -> int:
"""Gets the length of the DataFrame.
Returns:
@ -159,7 +162,7 @@ class NDFrame(ABC):
def _es_info(self, buf):
self._query_compiler.es_info(buf)
def mean(self, numeric_only=True):
def mean(self, numeric_only: bool = True) -> pd.Series:
"""
Return mean value for each numeric column
@ -191,7 +194,7 @@ class NDFrame(ABC):
"""
return self._query_compiler.mean(numeric_only=numeric_only)
def sum(self, numeric_only=True):
def sum(self, numeric_only: bool = True) -> pd.Series:
"""
Return sum for each numeric column
@ -223,7 +226,7 @@ class NDFrame(ABC):
"""
return self._query_compiler.sum(numeric_only=numeric_only)
def min(self, numeric_only=True):
def min(self, numeric_only: bool = True) -> pd.Series:
"""
Return the minimum value for each numeric column
@ -255,7 +258,7 @@ class NDFrame(ABC):
"""
return self._query_compiler.min(numeric_only=numeric_only)
def var(self, numeric_only=True):
def var(self, numeric_only: bool = True) -> pd.Series:
"""
Return variance for each numeric column
@ -285,7 +288,7 @@ class NDFrame(ABC):
"""
return self._query_compiler.var(numeric_only=numeric_only)
def std(self, numeric_only=True):
def std(self, numeric_only: bool = True) -> pd.Series:
"""
Return standard deviation for each numeric column
@ -315,7 +318,7 @@ class NDFrame(ABC):
"""
return self._query_compiler.std(numeric_only=numeric_only)
def median(self, numeric_only=True):
def median(self, numeric_only: bool = True) -> pd.Series:
"""
Return the median value for each numeric column
@ -345,7 +348,7 @@ class NDFrame(ABC):
"""
return self._query_compiler.median(numeric_only=numeric_only)
def max(self, numeric_only=True):
def max(self, numeric_only: bool = True) -> pd.Series:
"""
Return the maximum value for each numeric column
@ -377,7 +380,7 @@ class NDFrame(ABC):
"""
return self._query_compiler.max(numeric_only=numeric_only)
def nunique(self):
def nunique(self) -> pd.Series:
"""
Return cardinality of each field.
@ -423,7 +426,7 @@ class NDFrame(ABC):
"""
return self._query_compiler.nunique()
def mad(self, numeric_only=True):
def mad(self, numeric_only: bool = True) -> pd.Series:
"""
Return standard deviation for each numeric column
@ -456,7 +459,7 @@ class NDFrame(ABC):
def _hist(self, num_bins):
return self._query_compiler._hist(num_bins)
def describe(self):
def describe(self) -> pd.DataFrame:
"""
Generate descriptive statistics that summarize the central tendency, dispersion and shape of a
dataset's distribution, excluding NaN values.

View File

@ -33,8 +33,9 @@ Based on NDFrame which underpins eland.DataFrame
import sys
import warnings
from collections.abc import Collection
from io import StringIO
from typing import Optional, Union, Sequence
from typing import Optional, Union, Sequence, Any, Tuple, TYPE_CHECKING
import numpy as np
import pandas as pd
@ -45,6 +46,7 @@ from eland import NDFrame
from eland.arithmetics import ArithmeticSeries, ArithmeticString, ArithmeticNumber
from eland.common import DEFAULT_NUM_ROWS_DISPLAYED, docstring_parameter
from eland.filter import (
BooleanFilter,
NotFilter,
Equal,
Greater,
@ -56,10 +58,14 @@ from eland.filter import (
IsNull,
NotNull,
)
from eland.utils import deprecated_api
from eland.utils import deprecated_api, to_list
if TYPE_CHECKING: # type: ignore
from elasticsearch import Elasticsearch # noqa: F401
from eland.query_compiler import QueryCompiler # noqa: F401
def _get_method_name():
def _get_method_name() -> str:
return sys._getframe(1).f_code.co_name
@ -106,12 +112,12 @@ class Series(NDFrame):
def __init__(
self,
es_client=None,
es_index_pattern=None,
name=None,
es_index_field=None,
_query_compiler=None,
):
es_client: Optional["Elasticsearch"] = None,
es_index_pattern: Optional[str] = None,
name: Optional[str] = None,
es_index_field: Optional[str] = None,
_query_compiler: Optional["QueryCompiler"] = None,
) -> None:
# Series has 1 column
if name is None:
columns = None
@ -129,7 +135,7 @@ class Series(NDFrame):
hist = eland.plotting.ed_hist_series
@property
def empty(self):
def empty(self) -> bool:
"""Determines if the Series is empty.
Returns:
@ -139,7 +145,7 @@ class Series(NDFrame):
return len(self.index) == 0
@property
def shape(self):
def shape(self) -> Tuple[int, int]:
"""
Return a tuple representing the dimensionality of the Series.
@ -167,7 +173,7 @@ class Series(NDFrame):
return num_rows, num_columns
@property
def es_field_name(self):
def es_field_name(self) -> str:
"""
Returns
-------
@ -176,15 +182,15 @@ class Series(NDFrame):
"""
return self._query_compiler.get_field_names(include_scripted_fields=True)[0]
def _get_name(self):
@property
def name(self) -> str:
return self._query_compiler.columns[0]
def _set_name(self, name):
@name.setter
def name(self, name: str) -> None:
self._query_compiler.rename({self.name: name}, inplace=True)
name = property(_get_name, _set_name)
def rename(self, new_name):
def rename(self, new_name: str) -> "Series":
"""
Rename name of series. Only column rename is supported. This does not change the underlying
Elasticsearch index, but adds a symbolic link from the new name (column) to the Elasticsearch field name.
@ -238,18 +244,23 @@ class Series(NDFrame):
_query_compiler=self._query_compiler.rename({self.name: new_name})
)
def head(self, n=5):
def head(self, n: int = 5) -> "Series":
return Series(_query_compiler=self._query_compiler.head(n))
def tail(self, n=5):
def tail(self, n: int = 5) -> "Series":
return Series(_query_compiler=self._query_compiler.tail(n))
def sample(self, n: int = None, frac: float = None, random_state: int = None):
def sample(
self,
n: Optional[int] = None,
frac: Optional[float] = None,
random_state: Optional[int] = None,
) -> "Series":
return Series(
_query_compiler=self._query_compiler.sample(n, frac, random_state)
)
def value_counts(self, es_size=10):
def value_counts(self, es_size: int = 10) -> pd.Series:
"""
Return the value counts for the specified field.
@ -287,9 +298,8 @@ class Series(NDFrame):
"""
if not isinstance(es_size, int):
raise TypeError("es_size must be a positive integer.")
if not es_size > 0:
elif es_size <= 0:
raise ValueError("es_size must be a positive integer.")
return self._query_compiler.value_counts(es_size)
# dtype not implemented for Series as causes query to fail
@ -297,7 +307,7 @@ class Series(NDFrame):
# ----------------------------------------------------------------------
# Rendering Methods
def __repr__(self):
def __repr__(self) -> str:
"""
Return a string representation for a particular Series.
"""
@ -339,7 +349,7 @@ class Series(NDFrame):
name=False,
max_rows=None,
min_rows=None,
):
) -> Optional[str]:
"""
Render a string representation of the Series.
@ -411,15 +421,15 @@ class Series(NDFrame):
result = _buf.getvalue()
return result
def to_pandas(self, show_progress=False):
def to_pandas(self, show_progress: bool = False) -> pd.Series:
return self._query_compiler.to_pandas(show_progress=show_progress)[self.name]
@property
def _dtype(self):
def _dtype(self) -> np.dtype:
# DO NOT MAKE PUBLIC (i.e. def dtype) as this breaks query eval implementation
return self._query_compiler.dtypes[0]
def __gt__(self, other):
def __gt__(self, other: Union[int, float, "Series"]) -> BooleanFilter:
if isinstance(other, Series):
# Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value > doc['{other.name}'].value"
@ -429,7 +439,7 @@ class Series(NDFrame):
else:
raise NotImplementedError(other, type(other))
def __lt__(self, other):
def __lt__(self, other: Union[int, float, "Series"]) -> BooleanFilter:
if isinstance(other, Series):
# Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value < doc['{other.name}'].value"
@ -439,7 +449,7 @@ class Series(NDFrame):
else:
raise NotImplementedError(other, type(other))
def __ge__(self, other):
def __ge__(self, other: Union[int, float, "Series"]) -> BooleanFilter:
if isinstance(other, Series):
# Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value >= doc['{other.name}'].value"
@ -449,7 +459,7 @@ class Series(NDFrame):
else:
raise NotImplementedError(other, type(other))
def __le__(self, other):
def __le__(self, other: Union[int, float, "Series"]) -> BooleanFilter:
if isinstance(other, Series):
# Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value <= doc['{other.name}'].value"
@ -459,7 +469,7 @@ class Series(NDFrame):
else:
raise NotImplementedError(other, type(other))
def __eq__(self, other):
def __eq__(self, other: Union[int, float, str, "Series"]) -> BooleanFilter:
if isinstance(other, Series):
# Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value == doc['{other.name}'].value"
@ -471,7 +481,7 @@ class Series(NDFrame):
else:
raise NotImplementedError(other, type(other))
def __ne__(self, other):
def __ne__(self, other: Union[int, float, str, "Series"]) -> BooleanFilter:
if isinstance(other, Series):
# Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value != doc['{other.name}'].value"
@ -483,13 +493,13 @@ class Series(NDFrame):
else:
raise NotImplementedError(other, type(other))
def isin(self, other):
if isinstance(other, list):
return IsIn(field=self.name, value=other)
def isin(self, other: Union[Collection, pd.Series]) -> BooleanFilter:
if isinstance(other, (Collection, pd.Series)):
return IsIn(field=self.name, value=to_list(other))
else:
raise NotImplementedError(other, type(other))
def isna(self):
def isna(self) -> BooleanFilter:
"""
Detect missing values.
@ -506,7 +516,7 @@ class Series(NDFrame):
isnull = isna
def notna(self):
def notna(self) -> BooleanFilter:
"""
Detect existing (non-missing) values.
@ -525,7 +535,7 @@ class Series(NDFrame):
notnull = notna
@property
def ndim(self):
def ndim(self) -> int:
"""
Returns 1 by definition of a Series
@ -596,7 +606,7 @@ class Series(NDFrame):
)
return Series(_query_compiler=new_query_compiler)
def es_info(self):
def es_info(self) -> str:
buf = StringIO()
super()._es_info(buf)
@ -604,7 +614,7 @@ class Series(NDFrame):
return buf.getvalue()
@deprecated_api("eland.Series.es_info()")
def info_es(self):
def info_es(self) -> str:
return self.es_info()
def __add__(self, right):
@ -1149,7 +1159,12 @@ class Series(NDFrame):
rsubtract = __rsub__
rtruediv = __rtruediv__
def _numeric_op(self, right, method_name):
# __div__ is technically Python 2.x only
# but pandas has it so we do too.
__div__ = __truediv__
__rdiv__ = __rtruediv__
def _numeric_op(self, right: Any, method_name: str) -> "Series":
"""
return a op b

View File

@ -86,10 +86,11 @@ class TestDataFrameQuery(TestData):
ed_flights = self.ed_flights()
pd_flights = self.pd_flights()
assert (
pd_flights[pd_flights.OriginAirportID.isin(["LHR", "SYD"])].shape
== ed_flights[ed_flights.OriginAirportID.isin(["LHR", "SYD"])].shape
)
for obj in (["LHR", "SYD"], ("LHR", "SYD"), pd.Series(data=["LHR", "SYD"])):
assert (
pd_flights[pd_flights.OriginAirportID.isin(obj)].shape
== ed_flights[ed_flights.OriginAirportID.isin(obj)].shape
)
def test_multiitem_query(self):
# Examples from:

View File

@ -18,16 +18,21 @@
import re
import functools
import warnings
from typing import Callable, TypeVar
from typing import Callable, TypeVar, Any, Union, List, cast, Collection
from collections.abc import Collection as ABCCollection
import pandas as pd # type: ignore
F = TypeVar("F")
RT = TypeVar("RT")
Item = TypeVar("Item")
def deprecated_api(replace_with: str) -> Callable[[F], F]:
def wrapper(f: F) -> F:
def deprecated_api(
replace_with: str,
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
def wrapper(f: Callable[..., RT]) -> Callable[..., RT]:
@functools.wraps(f)
def wrapped(*args, **kwargs):
def wrapped(*args: Any, **kwargs: Any) -> RT:
warnings.warn(
f"{f.__name__} is deprecated, use {replace_with} instead",
DeprecationWarning,
@ -39,10 +44,18 @@ def deprecated_api(replace_with: str) -> Callable[[F], F]:
return wrapper
def is_valid_attr_name(s):
def is_valid_attr_name(s: str) -> bool:
"""
Ensure the given string can be used as attribute on an object instance.
"""
return isinstance(s, str) and re.search(
string=s, pattern=r"^[a-zA-Z_][a-zA-Z0-9_]*$"
return bool(
isinstance(s, str) and re.search(string=s, pattern=r"^[a-zA-Z_][a-zA-Z0-9_]*$")
)
def to_list(x: Union[Collection[Any], pd.Series]) -> List[Any]:
    """
    Convert a collection or a ``pandas.Series`` into a plain ``list``.

    Parameters
    ----------
    x:
        A sized, iterable container (list, tuple, set, ...) or a
        ``pandas.Series``.

    Returns
    -------
    list
        A new list holding the items of ``x``.

    Raises
    ------
    NotImplementedError
        If ``x`` is a ``str``/``bytes`` value or is not a collection at all.
    """
    # str and bytes satisfy the Collection ABC, so without this guard a single
    # value such as "LHR" would be silently exploded into ['L', 'H', 'R'].
    if isinstance(x, (str, bytes)):
        raise NotImplementedError(f"Could not convert {type(x).__name__} into a list")
    # Checked before the generic Collection test: a Series also defines
    # __len__/__iter__/__contains__ and would otherwise never reach this branch.
    if isinstance(x, pd.Series):
        return cast(List[Any], x.to_list())
    if isinstance(x, ABCCollection):
        return list(x)
    raise NotImplementedError(f"Could not convert {type(x).__name__} into a list")

View File

@ -43,6 +43,7 @@ TYPED_FILES = {
"eland/index.py",
"eland/query.py",
"eland/tasks.py",
"eland/utils.py",
"eland/ml/__init__.py",
"eland/ml/_model_serializer.py",
"eland/ml/imported_ml_model.py",
@ -75,6 +76,7 @@ def lint(session):
session.error(f"The file {typed_file!r} couldn't be found")
popen = subprocess.Popen(
f"mypy --strict {typed_file}",
env=session.env,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,