Support Series/collections in Series.isin(), add type hints

This commit is contained in:
Seth Michael Larson 2020-07-14 11:39:52 -05:00 committed by GitHub
parent 6e6ad04c5c
commit 140623283a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 157 additions and 99 deletions

View File

@ -1040,7 +1040,7 @@ script instead of being modified manually.
+---------------------------------------+------------+ +---------------------------------------+------------+
| ``ed.Series.__dir__()`` | **Yes** | | ``ed.Series.__dir__()`` | **Yes** |
+---------------------------------------+------------+ +---------------------------------------+------------+
| ``ed.Series.__div__()`` | No | | ``ed.Series.__div__()`` | **Yes** |
+---------------------------------------+------------+ +---------------------------------------+------------+
| ``ed.Series.__divmod__()`` | No | | ``ed.Series.__divmod__()`` | No |
+---------------------------------------+------------+ +---------------------------------------+------------+
@ -1134,7 +1134,7 @@ script instead of being modified manually.
+---------------------------------------+------------+ +---------------------------------------+------------+
| ``ed.Series.__rand__()`` | No | | ``ed.Series.__rand__()`` | No |
+---------------------------------------+------------+ +---------------------------------------+------------+
| ``ed.Series.__rdiv__()`` | No | | ``ed.Series.__rdiv__()`` | **Yes** |
+---------------------------------------+------------+ +---------------------------------------+------------+
| ``ed.Series.__rdivmod__()`` | No | | ``ed.Series.__rdivmod__()`` | No |
+---------------------------------------+------------+ +---------------------------------------+------------+

View File

@ -138,7 +138,7 @@ class ArithmeticSeries(ArithmeticObject):
for task in self._tasks: for task in self._tasks:
if task.op_name == "__add__": if task.op_name == "__add__":
value = f"({value} + {task.object.resolve()})" value = f"({value} + {task.object.resolve()})"
elif task.op_name == "__truediv__": elif task.op_name in ("__truediv__", "__div__"):
value = f"({value} / {task.object.resolve()})" value = f"({value} / {task.object.resolve()})"
elif task.op_name == "__floordiv__": elif task.op_name == "__floordiv__":
value = f"Math.floor({value} / {task.object.resolve()})" value = f"Math.floor({value} / {task.object.resolve()})"

View File

@ -19,7 +19,7 @@ import sys
import warnings import warnings
from io import StringIO from io import StringIO
import re import re
from typing import Optional, Sequence, Union from typing import Optional, Sequence, Union, Tuple
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -631,7 +631,7 @@ class DataFrame(NDFrame):
def info_es(self): def info_es(self):
return self.es_info() return self.es_info()
def es_query(self, query): def es_query(self, query) -> "DataFrame":
"""Applies an Elasticsearch DSL query to the current DataFrame. """Applies an Elasticsearch DSL query to the current DataFrame.
Parameters Parameters
@ -705,7 +705,7 @@ class DataFrame(NDFrame):
def info( def info(
self, verbose=None, buf=None, max_cols=None, memory_usage=None, null_counts=None self, verbose=None, buf=None, max_cols=None, memory_usage=None, null_counts=None
): ) -> None:
""" """
Print a concise summary of a DataFrame. Print a concise summary of a DataFrame.
@ -822,7 +822,7 @@ class DataFrame(NDFrame):
dtype = dtypes.iloc[i] dtype = dtypes.iloc[i]
col = pprint_thing(col) col = pprint_thing(col)
line_no = _put_str(" {num}".format(num=i), space_num) line_no = _put_str(f" {i}", space_num)
count = "" count = ""
if show_counts: if show_counts:
@ -1223,7 +1223,7 @@ class DataFrame(NDFrame):
} }
return self._query_compiler.to_csv(**kwargs) return self._query_compiler.to_csv(**kwargs)
def to_pandas(self, show_progress: bool = False) -> "DataFrame": def to_pandas(self, show_progress: bool = False) -> pd.DataFrame:
""" """
Utility method to convert eland.DataFrame to pandas.DataFrame Utility method to convert eland.DataFrame to pandas.DataFrame
@ -1233,10 +1233,10 @@ class DataFrame(NDFrame):
""" """
return self._query_compiler.to_pandas(show_progress=show_progress) return self._query_compiler.to_pandas(show_progress=show_progress)
def _empty_pd_df(self): def _empty_pd_df(self) -> pd.DataFrame:
return self._query_compiler._empty_pd_ef() return self._query_compiler._empty_pd_ef()
def select_dtypes(self, include=None, exclude=None): def select_dtypes(self, include=None, exclude=None) -> "DataFrame":
""" """
Return a subset of the DataFrame's columns based on the column dtypes. Return a subset of the DataFrame's columns based on the column dtypes.
@ -1272,7 +1272,7 @@ class DataFrame(NDFrame):
return self._getitem_array(empty_df.columns) return self._getitem_array(empty_df.columns)
@property @property
def shape(self): def shape(self) -> Tuple[int, int]:
""" """
Return a tuple representing the dimensionality of the DataFrame. Return a tuple representing the dimensionality of the DataFrame.
@ -1299,7 +1299,23 @@ class DataFrame(NDFrame):
return num_rows, num_columns return num_rows, num_columns
def keys(self): @property
def ndim(self) -> int:
"""
Returns 2 by definition of a DataFrame
Returns
-------
int
By definition 2
See Also
--------
:pandas_api_docs:`pandas.DataFrame.ndim`
"""
return 2
def keys(self) -> pd.Index:
""" """
Return columns Return columns
@ -1381,7 +1397,7 @@ class DataFrame(NDFrame):
hist = gfx.ed_hist_frame hist = gfx.ed_hist_frame
def query(self, expr): def query(self, expr) -> "DataFrame":
""" """
Query the columns of a DataFrame with a boolean expression. Query the columns of a DataFrame with a boolean expression.
@ -1474,7 +1490,7 @@ class DataFrame(NDFrame):
like: Optional[str] = None, like: Optional[str] = None,
regex: Optional[str] = None, regex: Optional[str] = None,
axis: Optional[Union[int, str]] = None, axis: Optional[Union[int, str]] = None,
): ) -> "DataFrame":
""" """
Subset the dataframe rows or columns according to the specified index labels. Subset the dataframe rows or columns according to the specified index labels.
Note that this routine does not filter a dataframe on its Note that this routine does not filter a dataframe on its

View File

@ -27,15 +27,16 @@ from pandas.core.dtypes.common import (
is_string_dtype, is_string_dtype,
) )
from pandas.core.dtypes.inference import is_list_like from pandas.core.dtypes.inference import is_list_like
from typing import NamedTuple, Optional, Mapping, Dict, Any, TYPE_CHECKING from typing import NamedTuple, Optional, Mapping, Dict, Any, TYPE_CHECKING, List, Set
if TYPE_CHECKING: if TYPE_CHECKING:
from elasticsearch import Elasticsearch
from eland import DataFrame from eland import DataFrame
ES_FLOAT_TYPES = {"double", "float", "half_float", "scaled_float"} ES_FLOAT_TYPES: Set[str] = {"double", "float", "half_float", "scaled_float"}
ES_INTEGER_TYPES = {"long", "integer", "short", "byte"} ES_INTEGER_TYPES: Set[str] = {"long", "integer", "short", "byte"}
ES_COMPATIBLE_TYPES = { ES_COMPATIBLE_TYPES: Dict[str, Set[str]] = {
"double": ES_FLOAT_TYPES, "double": ES_FLOAT_TYPES,
"scaled_float": ES_FLOAT_TYPES, "scaled_float": ES_FLOAT_TYPES,
"float": ES_FLOAT_TYPES, "float": ES_FLOAT_TYPES,
@ -80,7 +81,7 @@ class Field(NamedTuple):
def np_dtype(self): def np_dtype(self):
return np.dtype(self.pd_dtype) return np.dtype(self.pd_dtype)
def is_es_agg_compatible(self, es_agg): def is_es_agg_compatible(self, es_agg) -> bool:
# Cardinality works for all types # Cardinality works for all types
# Numerics and bools work for all aggs # Numerics and bools work for all aggs
if es_agg == "cardinality" or self.is_numeric or self.is_bool: if es_agg == "cardinality" or self.is_numeric or self.is_bool:
@ -115,7 +116,7 @@ class FieldMappings:
or es_field_name.keyword (if exists) or None or es_field_name.keyword (if exists) or None
""" """
ES_DTYPE_TO_PD_DTYPE = { ES_DTYPE_TO_PD_DTYPE: Dict[str, str] = {
"text": "object", "text": "object",
"keyword": "object", "keyword": "object",
"long": "int64", "long": "int64",
@ -133,7 +134,7 @@ class FieldMappings:
} }
# the labels for each column (display_name is index) # the labels for each column (display_name is index)
column_labels = [ column_labels: List[str] = [
"es_field_name", "es_field_name",
"is_source", "is_source",
"es_dtype", "es_dtype",
@ -145,7 +146,12 @@ class FieldMappings:
"aggregatable_es_field_name", "aggregatable_es_field_name",
] ]
def __init__(self, client=None, index_pattern=None, display_names=None): def __init__(
self,
client: "Elasticsearch",
index_pattern: str,
display_names: Optional[List[str]] = None,
):
""" """
Parameters Parameters
---------- ----------
@ -184,7 +190,9 @@ class FieldMappings:
self.display_names = display_names self.display_names = display_names
@staticmethod @staticmethod
def _extract_fields_from_mapping(mappings, source_only=False, date_format=None): def _extract_fields_from_mapping(
mappings: Dict[str, Any], source_only: bool = False
) -> Dict[str, str]:
""" """
Extract all field names and types from a mapping. Extract all field names and types from a mapping.
``` ```
@ -256,10 +264,10 @@ class FieldMappings:
# Recurse until we get a 'type: xxx' # Recurse until we get a 'type: xxx'
def flatten(x, name=""): def flatten(x, name=""):
if type(x) is dict: if isinstance(x, dict):
for a in x: for a in x:
if ( if a == "type" and isinstance(
a == "type" and type(x[a]) is str x[a], str
): # 'type' can be a name of a field ): # 'type' can be a name of a field
field_name = name[:-1] field_name = name[:-1]
field_type = x[a] field_type = x[a]

View File

@ -17,10 +17,14 @@
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Tuple from typing import TYPE_CHECKING, Tuple
import pandas as pd
from eland.query_compiler import QueryCompiler from eland.query_compiler import QueryCompiler
if TYPE_CHECKING:
from eland.index import Index
""" """
NDFrame NDFrame
--------- ---------
@ -73,7 +77,8 @@ class NDFrame(ABC):
) )
self._query_compiler = _query_compiler self._query_compiler = _query_compiler
def _get_index(self): @property
def index(self) -> "Index":
""" """
Return eland index referencing Elasticsearch field to index a DataFrame/Series Return eland index referencing Elasticsearch field to index a DataFrame/Series
@ -100,10 +105,8 @@ class NDFrame(ABC):
""" """
return self._query_compiler.index return self._query_compiler.index
index = property(_get_index)
@property @property
def dtypes(self): def dtypes(self) -> pd.Series:
""" """
Return the pandas dtypes in the DataFrame. Elasticsearch types are mapped Return the pandas dtypes in the DataFrame. Elasticsearch types are mapped
to pandas dtypes via Mappings._es_dtype_to_pd_dtype.__doc__ to pandas dtypes via Mappings._es_dtype_to_pd_dtype.__doc__
@ -129,7 +132,7 @@ class NDFrame(ABC):
""" """
return self._query_compiler.dtypes return self._query_compiler.dtypes
def _build_repr(self, num_rows): def _build_repr(self, num_rows) -> pd.DataFrame:
# self could be Series or DataFrame # self could be Series or DataFrame
if len(self.index) <= num_rows: if len(self.index) <= num_rows:
return self.to_pandas() return self.to_pandas()
@ -144,11 +147,11 @@ class NDFrame(ABC):
return head.append(tail) return head.append(tail)
def __sizeof__(self): def __sizeof__(self) -> int:
# Don't default to pandas, just return approximation TODO - make this more accurate # Don't default to pandas, just return approximation TODO - make this more accurate
return sys.getsizeof(self._query_compiler) return sys.getsizeof(self._query_compiler)
def __len__(self): def __len__(self) -> int:
"""Gets the length of the DataFrame. """Gets the length of the DataFrame.
Returns: Returns:
@ -159,7 +162,7 @@ class NDFrame(ABC):
def _es_info(self, buf): def _es_info(self, buf):
self._query_compiler.es_info(buf) self._query_compiler.es_info(buf)
def mean(self, numeric_only=True): def mean(self, numeric_only: bool = True) -> pd.Series:
""" """
Return mean value for each numeric column Return mean value for each numeric column
@ -191,7 +194,7 @@ class NDFrame(ABC):
""" """
return self._query_compiler.mean(numeric_only=numeric_only) return self._query_compiler.mean(numeric_only=numeric_only)
def sum(self, numeric_only=True): def sum(self, numeric_only: bool = True) -> pd.Series:
""" """
Return sum for each numeric column Return sum for each numeric column
@ -223,7 +226,7 @@ class NDFrame(ABC):
""" """
return self._query_compiler.sum(numeric_only=numeric_only) return self._query_compiler.sum(numeric_only=numeric_only)
def min(self, numeric_only=True): def min(self, numeric_only: bool = True) -> pd.Series:
""" """
Return the minimum value for each numeric column Return the minimum value for each numeric column
@ -255,7 +258,7 @@ class NDFrame(ABC):
""" """
return self._query_compiler.min(numeric_only=numeric_only) return self._query_compiler.min(numeric_only=numeric_only)
def var(self, numeric_only=True): def var(self, numeric_only: bool = True) -> pd.Series:
""" """
Return variance for each numeric column Return variance for each numeric column
@ -285,7 +288,7 @@ class NDFrame(ABC):
""" """
return self._query_compiler.var(numeric_only=numeric_only) return self._query_compiler.var(numeric_only=numeric_only)
def std(self, numeric_only=True): def std(self, numeric_only: bool = True) -> pd.Series:
""" """
Return standard deviation for each numeric column Return standard deviation for each numeric column
@ -315,7 +318,7 @@ class NDFrame(ABC):
""" """
return self._query_compiler.std(numeric_only=numeric_only) return self._query_compiler.std(numeric_only=numeric_only)
def median(self, numeric_only=True): def median(self, numeric_only: bool = True) -> pd.Series:
""" """
Return the median value for each numeric column Return the median value for each numeric column
@ -345,7 +348,7 @@ class NDFrame(ABC):
""" """
return self._query_compiler.median(numeric_only=numeric_only) return self._query_compiler.median(numeric_only=numeric_only)
def max(self, numeric_only=True): def max(self, numeric_only: bool = True) -> pd.Series:
""" """
Return the maximum value for each numeric column Return the maximum value for each numeric column
@ -377,7 +380,7 @@ class NDFrame(ABC):
""" """
return self._query_compiler.max(numeric_only=numeric_only) return self._query_compiler.max(numeric_only=numeric_only)
def nunique(self): def nunique(self) -> pd.Series:
""" """
Return cardinality of each field. Return cardinality of each field.
@ -423,7 +426,7 @@ class NDFrame(ABC):
""" """
return self._query_compiler.nunique() return self._query_compiler.nunique()
def mad(self, numeric_only=True): def mad(self, numeric_only: bool = True) -> pd.Series:
""" """
Return median absolute deviation for each numeric column Return median absolute deviation for each numeric column
@ -456,7 +459,7 @@ class NDFrame(ABC):
def _hist(self, num_bins): def _hist(self, num_bins):
return self._query_compiler._hist(num_bins) return self._query_compiler._hist(num_bins)
def describe(self): def describe(self) -> pd.DataFrame:
""" """
Generate descriptive statistics that summarize the central tendency, dispersion and shape of a Generate descriptive statistics that summarize the central tendency, dispersion and shape of a
dataset's distribution, excluding NaN values. dataset's distribution, excluding NaN values.

View File

@ -33,8 +33,9 @@ Based on NDFrame which underpins eland.DataFrame
import sys import sys
import warnings import warnings
from collections.abc import Collection
from io import StringIO from io import StringIO
from typing import Optional, Union, Sequence from typing import Optional, Union, Sequence, Any, Tuple, TYPE_CHECKING
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -45,6 +46,7 @@ from eland import NDFrame
from eland.arithmetics import ArithmeticSeries, ArithmeticString, ArithmeticNumber from eland.arithmetics import ArithmeticSeries, ArithmeticString, ArithmeticNumber
from eland.common import DEFAULT_NUM_ROWS_DISPLAYED, docstring_parameter from eland.common import DEFAULT_NUM_ROWS_DISPLAYED, docstring_parameter
from eland.filter import ( from eland.filter import (
BooleanFilter,
NotFilter, NotFilter,
Equal, Equal,
Greater, Greater,
@ -56,10 +58,14 @@ from eland.filter import (
IsNull, IsNull,
NotNull, NotNull,
) )
from eland.utils import deprecated_api from eland.utils import deprecated_api, to_list
if TYPE_CHECKING: # type: ignore
from elasticsearch import Elasticsearch # noqa: F401
from eland.query_compiler import QueryCompiler # noqa: F401
def _get_method_name(): def _get_method_name() -> str:
return sys._getframe(1).f_code.co_name return sys._getframe(1).f_code.co_name
@ -106,12 +112,12 @@ class Series(NDFrame):
def __init__( def __init__(
self, self,
es_client=None, es_client: Optional["Elasticsearch"] = None,
es_index_pattern=None, es_index_pattern: Optional[str] = None,
name=None, name: Optional[str] = None,
es_index_field=None, es_index_field: Optional[str] = None,
_query_compiler=None, _query_compiler: Optional["QueryCompiler"] = None,
): ) -> None:
# Series has 1 column # Series has 1 column
if name is None: if name is None:
columns = None columns = None
@ -129,7 +135,7 @@ class Series(NDFrame):
hist = eland.plotting.ed_hist_series hist = eland.plotting.ed_hist_series
@property @property
def empty(self): def empty(self) -> bool:
"""Determines if the Series is empty. """Determines if the Series is empty.
Returns: Returns:
@ -139,7 +145,7 @@ class Series(NDFrame):
return len(self.index) == 0 return len(self.index) == 0
@property @property
def shape(self): def shape(self) -> Tuple[int, int]:
""" """
Return a tuple representing the dimensionality of the Series. Return a tuple representing the dimensionality of the Series.
@ -167,7 +173,7 @@ class Series(NDFrame):
return num_rows, num_columns return num_rows, num_columns
@property @property
def es_field_name(self): def es_field_name(self) -> str:
""" """
Returns Returns
------- -------
@ -176,15 +182,15 @@ class Series(NDFrame):
""" """
return self._query_compiler.get_field_names(include_scripted_fields=True)[0] return self._query_compiler.get_field_names(include_scripted_fields=True)[0]
def _get_name(self): @property
def name(self) -> str:
return self._query_compiler.columns[0] return self._query_compiler.columns[0]
def _set_name(self, name): @name.setter
def name(self, name: str) -> None:
self._query_compiler.rename({self.name: name}, inplace=True) self._query_compiler.rename({self.name: name}, inplace=True)
name = property(_get_name, _set_name) def rename(self, new_name: str) -> "Series":
def rename(self, new_name):
""" """
Rename name of series. Only column rename is supported. This does not change the underlying Rename name of series. Only column rename is supported. This does not change the underlying
Elasticsearch index, but adds a symbolic link from the new name (column) to the Elasticsearch field name. Elasticsearch index, but adds a symbolic link from the new name (column) to the Elasticsearch field name.
@ -238,18 +244,23 @@ class Series(NDFrame):
_query_compiler=self._query_compiler.rename({self.name: new_name}) _query_compiler=self._query_compiler.rename({self.name: new_name})
) )
def head(self, n=5): def head(self, n: int = 5) -> "Series":
return Series(_query_compiler=self._query_compiler.head(n)) return Series(_query_compiler=self._query_compiler.head(n))
def tail(self, n=5): def tail(self, n: int = 5) -> "Series":
return Series(_query_compiler=self._query_compiler.tail(n)) return Series(_query_compiler=self._query_compiler.tail(n))
def sample(self, n: int = None, frac: float = None, random_state: int = None): def sample(
self,
n: Optional[int] = None,
frac: Optional[float] = None,
random_state: Optional[int] = None,
) -> "Series":
return Series( return Series(
_query_compiler=self._query_compiler.sample(n, frac, random_state) _query_compiler=self._query_compiler.sample(n, frac, random_state)
) )
def value_counts(self, es_size=10): def value_counts(self, es_size: int = 10) -> pd.Series:
""" """
Return the value counts for the specified field. Return the value counts for the specified field.
@ -287,9 +298,8 @@ class Series(NDFrame):
""" """
if not isinstance(es_size, int): if not isinstance(es_size, int):
raise TypeError("es_size must be a positive integer.") raise TypeError("es_size must be a positive integer.")
if not es_size > 0: elif es_size <= 0:
raise ValueError("es_size must be a positive integer.") raise ValueError("es_size must be a positive integer.")
return self._query_compiler.value_counts(es_size) return self._query_compiler.value_counts(es_size)
# dtype not implemented for Series as causes query to fail # dtype not implemented for Series as causes query to fail
@ -297,7 +307,7 @@ class Series(NDFrame):
# ---------------------------------------------------------------------- # ----------------------------------------------------------------------
# Rendering Methods # Rendering Methods
def __repr__(self): def __repr__(self) -> str:
""" """
Return a string representation for a particular Series. Return a string representation for a particular Series.
""" """
@ -339,7 +349,7 @@ class Series(NDFrame):
name=False, name=False,
max_rows=None, max_rows=None,
min_rows=None, min_rows=None,
): ) -> Optional[str]:
""" """
Render a string representation of the Series. Render a string representation of the Series.
@ -411,15 +421,15 @@ class Series(NDFrame):
result = _buf.getvalue() result = _buf.getvalue()
return result return result
def to_pandas(self, show_progress=False): def to_pandas(self, show_progress: bool = False) -> pd.Series:
return self._query_compiler.to_pandas(show_progress=show_progress)[self.name] return self._query_compiler.to_pandas(show_progress=show_progress)[self.name]
@property @property
def _dtype(self): def _dtype(self) -> np.dtype:
# DO NOT MAKE PUBLIC (i.e. def dtype) as this breaks query eval implementation # DO NOT MAKE PUBLIC (i.e. def dtype) as this breaks query eval implementation
return self._query_compiler.dtypes[0] return self._query_compiler.dtypes[0]
def __gt__(self, other): def __gt__(self, other: Union[int, float, "Series"]) -> BooleanFilter:
if isinstance(other, Series): if isinstance(other, Series):
# Need to use scripted query to compare to values # Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value > doc['{other.name}'].value" painless = f"doc['{self.name}'].value > doc['{other.name}'].value"
@ -429,7 +439,7 @@ class Series(NDFrame):
else: else:
raise NotImplementedError(other, type(other)) raise NotImplementedError(other, type(other))
def __lt__(self, other): def __lt__(self, other: Union[int, float, "Series"]) -> BooleanFilter:
if isinstance(other, Series): if isinstance(other, Series):
# Need to use scripted query to compare to values # Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value < doc['{other.name}'].value" painless = f"doc['{self.name}'].value < doc['{other.name}'].value"
@ -439,7 +449,7 @@ class Series(NDFrame):
else: else:
raise NotImplementedError(other, type(other)) raise NotImplementedError(other, type(other))
def __ge__(self, other): def __ge__(self, other: Union[int, float, "Series"]) -> BooleanFilter:
if isinstance(other, Series): if isinstance(other, Series):
# Need to use scripted query to compare to values # Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value >= doc['{other.name}'].value" painless = f"doc['{self.name}'].value >= doc['{other.name}'].value"
@ -449,7 +459,7 @@ class Series(NDFrame):
else: else:
raise NotImplementedError(other, type(other)) raise NotImplementedError(other, type(other))
def __le__(self, other): def __le__(self, other: Union[int, float, "Series"]) -> BooleanFilter:
if isinstance(other, Series): if isinstance(other, Series):
# Need to use scripted query to compare to values # Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value <= doc['{other.name}'].value" painless = f"doc['{self.name}'].value <= doc['{other.name}'].value"
@ -459,7 +469,7 @@ class Series(NDFrame):
else: else:
raise NotImplementedError(other, type(other)) raise NotImplementedError(other, type(other))
def __eq__(self, other): def __eq__(self, other: Union[int, float, str, "Series"]) -> BooleanFilter:
if isinstance(other, Series): if isinstance(other, Series):
# Need to use scripted query to compare to values # Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value == doc['{other.name}'].value" painless = f"doc['{self.name}'].value == doc['{other.name}'].value"
@ -471,7 +481,7 @@ class Series(NDFrame):
else: else:
raise NotImplementedError(other, type(other)) raise NotImplementedError(other, type(other))
def __ne__(self, other): def __ne__(self, other: Union[int, float, str, "Series"]) -> BooleanFilter:
if isinstance(other, Series): if isinstance(other, Series):
# Need to use scripted query to compare to values # Need to use scripted query to compare to values
painless = f"doc['{self.name}'].value != doc['{other.name}'].value" painless = f"doc['{self.name}'].value != doc['{other.name}'].value"
@ -483,13 +493,13 @@ class Series(NDFrame):
else: else:
raise NotImplementedError(other, type(other)) raise NotImplementedError(other, type(other))
def isin(self, other): def isin(self, other: Union[Collection, pd.Series]) -> BooleanFilter:
if isinstance(other, list): if isinstance(other, (Collection, pd.Series)):
return IsIn(field=self.name, value=other) return IsIn(field=self.name, value=to_list(other))
else: else:
raise NotImplementedError(other, type(other)) raise NotImplementedError(other, type(other))
def isna(self): def isna(self) -> BooleanFilter:
""" """
Detect missing values. Detect missing values.
@ -506,7 +516,7 @@ class Series(NDFrame):
isnull = isna isnull = isna
def notna(self): def notna(self) -> BooleanFilter:
""" """
Detect existing (non-missing) values. Detect existing (non-missing) values.
@ -525,7 +535,7 @@ class Series(NDFrame):
notnull = notna notnull = notna
@property @property
def ndim(self): def ndim(self) -> int:
""" """
Returns 1 by definition of a Series Returns 1 by definition of a Series
@ -596,7 +606,7 @@ class Series(NDFrame):
) )
return Series(_query_compiler=new_query_compiler) return Series(_query_compiler=new_query_compiler)
def es_info(self): def es_info(self) -> str:
buf = StringIO() buf = StringIO()
super()._es_info(buf) super()._es_info(buf)
@ -604,7 +614,7 @@ class Series(NDFrame):
return buf.getvalue() return buf.getvalue()
@deprecated_api("eland.Series.es_info()") @deprecated_api("eland.Series.es_info()")
def info_es(self): def info_es(self) -> str:
return self.es_info() return self.es_info()
def __add__(self, right): def __add__(self, right):
@ -1149,7 +1159,12 @@ class Series(NDFrame):
rsubtract = __rsub__ rsubtract = __rsub__
rtruediv = __rtruediv__ rtruediv = __rtruediv__
def _numeric_op(self, right, method_name): # __div__ is technically Python 2.x only
# but pandas has it so we do too.
__div__ = __truediv__
__rdiv__ = __rtruediv__
def _numeric_op(self, right: Any, method_name: str) -> "Series":
""" """
return a op b return a op b

View File

@ -86,9 +86,10 @@ class TestDataFrameQuery(TestData):
ed_flights = self.ed_flights() ed_flights = self.ed_flights()
pd_flights = self.pd_flights() pd_flights = self.pd_flights()
for obj in (["LHR", "SYD"], ("LHR", "SYD"), pd.Series(data=["LHR", "SYD"])):
assert ( assert (
pd_flights[pd_flights.OriginAirportID.isin(["LHR", "SYD"])].shape pd_flights[pd_flights.OriginAirportID.isin(obj)].shape
== ed_flights[ed_flights.OriginAirportID.isin(["LHR", "SYD"])].shape == ed_flights[ed_flights.OriginAirportID.isin(obj)].shape
) )
def test_multiitem_query(self): def test_multiitem_query(self):

View File

@ -18,16 +18,21 @@
import re import re
import functools import functools
import warnings import warnings
from typing import Callable, TypeVar from typing import Callable, TypeVar, Any, Union, List, cast, Collection
from collections.abc import Collection as ABCCollection
import pandas as pd # type: ignore
F = TypeVar("F") RT = TypeVar("RT")
Item = TypeVar("Item")
def deprecated_api(replace_with: str) -> Callable[[F], F]: def deprecated_api(
def wrapper(f: F) -> F: replace_with: str,
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
def wrapper(f: Callable[..., RT]) -> Callable[..., RT]:
@functools.wraps(f) @functools.wraps(f)
def wrapped(*args, **kwargs): def wrapped(*args: Any, **kwargs: Any) -> RT:
warnings.warn( warnings.warn(
f"{f.__name__} is deprecated, use {replace_with} instead", f"{f.__name__} is deprecated, use {replace_with} instead",
DeprecationWarning, DeprecationWarning,
@ -39,10 +44,18 @@ def deprecated_api(replace_with: str) -> Callable[[F], F]:
return wrapper return wrapper
def is_valid_attr_name(s): def is_valid_attr_name(s: str) -> bool:
""" """
Ensure the given string can be used as attribute on an object instance. Ensure the given string can be used as attribute on an object instance.
""" """
return isinstance(s, str) and re.search( return bool(
string=s, pattern=r"^[a-zA-Z_][a-zA-Z0-9_]*$" isinstance(s, str) and re.search(string=s, pattern=r"^[a-zA-Z_][a-zA-Z0-9_]*$")
) )
def to_list(x: Union[Collection[Any], pd.Series]) -> List[Any]:
    """Convert a pandas Series or a collection into a plain Python list.

    Parameters
    ----------
    x:
        A pandas Series, or any sized, iterable container
        (list, tuple, set, ...).

    Returns
    -------
    list
        A new list holding the items of ``x``.

    Raises
    ------
    NotImplementedError
        If ``x`` is neither a pandas Series nor a collection
        (for example a generator or a scalar).

    Notes
    -----
    A ``str`` is itself a collection, so it is split into its characters.
    """
    # pandas.Series also satisfies the Collection ABC (it defines __len__,
    # __iter__ and __contains__), so it has to be tested first; otherwise
    # the Series-specific branch can never be reached.
    if isinstance(x, pd.Series):
        return cast(List[Any], x.to_list())
    if isinstance(x, ABCCollection):
        return list(x)
    raise NotImplementedError(f"Could not convert {type(x).__name__} into a list")

View File

@ -43,6 +43,7 @@ TYPED_FILES = {
"eland/index.py", "eland/index.py",
"eland/query.py", "eland/query.py",
"eland/tasks.py", "eland/tasks.py",
"eland/utils.py",
"eland/ml/__init__.py", "eland/ml/__init__.py",
"eland/ml/_model_serializer.py", "eland/ml/_model_serializer.py",
"eland/ml/imported_ml_model.py", "eland/ml/imported_ml_model.py",
@ -75,6 +76,7 @@ def lint(session):
session.error(f"The file {typed_file!r} couldn't be found") session.error(f"The file {typed_file!r} couldn't be found")
popen = subprocess.Popen( popen = subprocess.Popen(
f"mypy --strict {typed_file}", f"mypy --strict {typed_file}",
env=session.env,
shell=True, shell=True,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, stderr=subprocess.STDOUT,