From 823f01cc6c5e25f82fcdca25f1a70e7067d76db4 Mon Sep 17 00:00:00 2001 From: "P. Sai Vinay" <33659563+V1NAY8@users.noreply.github.com> Date: Mon, 2 Aug 2021 22:20:35 +0530 Subject: [PATCH] Add type hints to 'eland.operations' and 'eland.ndframe' --- eland/arithmetics.py | 18 +- eland/common.py | 22 ++- eland/conftest.py | 8 +- eland/dataframe.py | 95 ++++++---- eland/field_mappings.py | 32 ++-- eland/ml/_optional.py | 9 +- eland/ml/ml_model.py | 11 +- eland/ml/transformers/sklearn.py | 4 +- eland/ml/transformers/xgboost.py | 2 +- eland/ndframe.py | 6 +- eland/operations.py | 280 ++++++++++++++++++----------- eland/plotting/_matplotlib/hist.py | 15 +- eland/query_compiler.py | 125 ++++++++----- eland/series.py | 62 +++---- eland/tasks.py | 4 +- noxfile.py | 8 +- requirements-dev.txt | 1 + setup.cfg | 2 + 18 files changed, 428 insertions(+), 276 deletions(-) diff --git a/eland/arithmetics.py b/eland/arithmetics.py index b9d1b80..6242fe1 100644 --- a/eland/arithmetics.py +++ b/eland/arithmetics.py @@ -19,9 +19,11 @@ from abc import ABC, abstractmethod from io import StringIO from typing import TYPE_CHECKING, Any, List, Union -import numpy as np # type: ignore +import numpy as np if TYPE_CHECKING: + from numpy.typing import DTypeLike + from .query_compiler import QueryCompiler @@ -32,7 +34,7 @@ class ArithmeticObject(ABC): pass @abstractmethod - def dtype(self) -> np.dtype: + def dtype(self) -> "DTypeLike": pass @abstractmethod @@ -52,7 +54,7 @@ class ArithmeticString(ArithmeticObject): return self.value @property - def dtype(self) -> np.dtype: + def dtype(self) -> "DTypeLike": return np.dtype(object) @property @@ -64,7 +66,7 @@ class ArithmeticString(ArithmeticObject): class ArithmeticNumber(ArithmeticObject): - def __init__(self, value: Union[int, float], dtype: np.dtype): + def __init__(self, value: Union[int, float], dtype: "DTypeLike"): self._value = value self._dtype = dtype @@ -76,7 +78,7 @@ class ArithmeticNumber(ArithmeticObject): return f"{self._value}" @property - def dtype(self) -> np.dtype: + def dtype(self) -> "DTypeLike": return self._dtype def __repr__(self) -> str: @@ -89,8 +91,8 @@ class ArithmeticSeries(ArithmeticObject): """ def __init__( - self, query_compiler: "QueryCompiler", display_name: str, dtype: np.dtype - ): + self, query_compiler: "QueryCompiler", display_name: str, dtype: "DTypeLike" + ) -> None: # type defs self._value: str self._tasks: List["ArithmeticTask"] @@ -121,7 +123,7 @@ class ArithmeticSeries(ArithmeticObject): return self._value @property - def dtype(self) -> np.dtype: + def dtype(self) -> "DTypeLike": return self._dtype def __repr__(self) -> str: diff --git a/eland/common.py b/eland/common.py index d5fdca4..3fbaccc 100644 --- a/eland/common.py +++ b/eland/common.py @@ -18,12 +18,24 @@ import re import warnings from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) -import numpy as np # type: ignore import pandas as pd # type: ignore from elasticsearch import Elasticsearch +if TYPE_CHECKING: + from numpy.typing import DTypeLike + # Default number of rows displayed (different to pandas where ALL could be displayed) DEFAULT_NUM_ROWS_DISPLAYED = 60 DEFAULT_CHUNK_SIZE = 10000 @@ -42,7 +54,7 @@ with warnings.catch_warnings(): def build_pd_series( - data: Dict[str, Any], dtype: Optional[np.dtype] = None, **kwargs: Any + data: Dict[str, Any], dtype: Optional["DTypeLike"] = None, **kwargs: Any ) -> 
pd.Series: """Builds a pd.Series while squelching the warning for unspecified dtype on empty series @@ -88,7 +100,7 @@ class SortOrder(Enum): def elasticsearch_date_to_pandas_date( - value: Union[int, str], date_format: Optional[str] + value: Union[int, str, float], date_format: Optional[str] ) -> pd.Timestamp: """ Given a specific Elasticsearch format for a date datatype, returns the @@ -98,7 +110,7 @@ def elasticsearch_date_to_pandas_date( Parameters ---------- - value: Union[int, str] + value: Union[int, str, float] The date value. date_format: str The Elasticsearch date format (ex. 'epoch_millis', 'epoch_second', etc.) diff --git a/eland/conftest.py b/eland/conftest.py index 27bb070..dd08fc1 100644 --- a/eland/conftest.py +++ b/eland/conftest.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. +from typing import Any, Dict + import numpy as np -import pandas as pd -import pytest +import pandas as pd # type: ignore +import pytest # type: ignore import eland as ed @@ -28,7 +30,7 @@ pd.set_option("display.width", 100) @pytest.fixture(autouse=True) -def add_imports(doctest_namespace): +def add_imports(doctest_namespace: Dict[str, Any]) -> None: doctest_namespace["np"] = np doctest_namespace["pd"] = pd doctest_namespace["ed"] = ed diff --git a/eland/dataframe.py b/eland/dataframe.py index 5b49e02..fca8145 100644 --- a/eland/dataframe.py +++ b/eland/dataframe.py @@ -19,19 +19,19 @@ import re import sys import warnings from io import StringIO -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union import numpy as np -import pandas as pd -from pandas.core.common import apply_if_callable, is_bool_indexer -from pandas.core.computation.eval import eval -from pandas.core.dtypes.common import is_list_like -from pandas.core.indexing import check_bool_indexer -from pandas.io.common import _expand_user, stringify_path -from pandas.io.formats import console +import pandas as pd # type: ignore +from pandas.core.common import apply_if_callable, is_bool_indexer # type: ignore +from pandas.core.computation.eval import eval # type: ignore +from pandas.core.dtypes.common import is_list_like # type: ignore +from pandas.core.indexing import check_bool_indexer # type: ignore +from pandas.io.common import _expand_user, stringify_path # type: ignore +from pandas.io.formats import console # type: ignore from pandas.io.formats import format as fmt -from pandas.io.formats.printing import pprint_thing -from pandas.util._validators import validate_bool_kwarg +from pandas.io.formats.printing import pprint_thing # type: ignore +from pandas.util._validators import validate_bool_kwarg # type: ignore import eland.plotting as gfx from eland.common import DEFAULT_NUM_ROWS_DISPLAYED, docstring_parameter @@ -41,6 +41,11 @@ from eland.ndframe import NDFrame from eland.series import Series from eland.utils import is_valid_attr_name +if TYPE_CHECKING: + from elasticsearch import Elasticsearch + + from .query_compiler import QueryCompiler + class DataFrame(NDFrame): """ @@ -119,11 +124,13 @@ class DataFrame(NDFrame): def __init__( self, - es_client=None, - es_index_pattern=None, - es_index_field=None, - columns=None, - _query_compiler=None, + es_client: Optional[ + Union[str, List[str], Tuple[str, ...], "Elasticsearch"] + ] = None, + es_index_pattern: Optional[str] = None, + columns: Optional[List[str]] = None, + es_index_field: Optional[str] = None, + _query_compiler: Optional["QueryCompiler"] = 
None, ) -> None: """ There are effectively 2 constructors: @@ -147,7 +154,7 @@ class DataFrame(NDFrame): _query_compiler=_query_compiler, ) - def _get_columns(self): + def _get_columns(self) -> pd.Index: """ The column labels of the DataFrame. @@ -178,7 +185,7 @@ class DataFrame(NDFrame): columns = property(_get_columns) @property - def empty(self): + def empty(self) -> bool: """Determines if the DataFrame is empty. Returns @@ -278,7 +285,10 @@ class DataFrame(NDFrame): return DataFrame(_query_compiler=self._query_compiler.tail(n)) def sample( - self, n: int = None, frac: float = None, random_state: int = None + self, + n: Optional[int] = None, + frac: Optional[float] = None, + random_state: Optional[int] = None, ) -> "DataFrame": """ Return n randomly sampled rows or the specified fraction of rows @@ -469,7 +479,7 @@ class DataFrame(NDFrame): if is_valid_attr_name(column_name) ] - def __repr__(self): + def __repr__(self) -> str: """ From pandas """ @@ -501,7 +511,7 @@ class DataFrame(NDFrame): return buf.getvalue() - def _info_repr(self): + def _info_repr(self) -> bool: """ True if the repr should show the info view. """ @@ -510,7 +520,7 @@ class DataFrame(NDFrame): self._repr_fits_horizontal_() and self._repr_fits_vertical_() ) - def _repr_html_(self): + def _repr_html_(self) -> Optional[str]: """ From pandas - this is called by notebooks """ @@ -540,7 +550,7 @@ class DataFrame(NDFrame): else: return None - def count(self): + def count(self) -> pd.Series: """ Count non-NA cells for each column. @@ -855,10 +865,10 @@ class DataFrame(NDFrame): exceeds_info_cols = len(self.columns) > max_cols # From pandas.DataFrame - def _put_str(s, space): + def _put_str(s: Any, space: int) -> str: return f"{s}"[:space].ljust(space) - def _verbose_repr(): + def _verbose_repr() -> None: lines.append(f"Data columns (total {len(self.columns)} columns):") id_head = " # " @@ -930,10 +940,10 @@ class DataFrame(NDFrame): + _put_str(dtype, space_dtype) ) - def _non_verbose_repr(): + def _non_verbose_repr() -> None: lines.append(self.columns._summary(name="Columns")) - def _sizeof_fmt(num, size_qualifier): + def _sizeof_fmt(num: float, size_qualifier: str) -> str: # returns size in human readable format for x in ["bytes", "KB", "MB", "GB", "TB"]: if num < 1024.0: @@ -1004,7 +1014,7 @@ class DataFrame(NDFrame): border=None, table_id=None, render_links=False, - ): + ) -> Any: """ Render Elasticsearch data as an HTML table. @@ -1171,7 +1181,7 @@ class DataFrame(NDFrame): result = _buf.getvalue() return result - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: """After regular attribute access, looks up the name in the columns Parameters @@ -1190,7 +1200,7 @@ class DataFrame(NDFrame): return self[key] raise e - def _getitem(self, key): + def _getitem( + self, + key: Union[ + "DataFrame", "Series", pd.Index, List[str], str, BooleanFilter, np.ndarray + ], + ) -> Union["Series", "DataFrame"]: """Get the column specified by key for this DataFrame. 
Args: @@ -1215,13 +1230,13 @@ class DataFrame(NDFrame): else: return self._getitem_column(key) - def _getitem_column(self, key): + def _getitem_column(self, key: str) -> "Series": if key not in self.columns: raise KeyError(f"Requested column [{key}] is not in the DataFrame.") s = self._reduce_dimension(self._query_compiler.getitem_column_array([key])) return s - def _getitem_array(self, key): + def _getitem_array(self, key: Union[str, pd.Series]) -> "DataFrame": if isinstance(key, Series): key = key.to_pandas() if is_bool_indexer(key): @@ -1256,7 +1271,9 @@ class DataFrame(NDFrame): _query_compiler=self._query_compiler.getitem_column_array(key) ) - def _create_or_update_from_compiler(self, new_query_compiler, inplace=False): + def _create_or_update_from_compiler( + self, new_query_compiler: "QueryCompiler", inplace: bool = False + ) -> Union["QueryCompiler", "DataFrame"]: """Returns or updates a DataFrame given new query_compiler""" assert ( isinstance(new_query_compiler, type(self._query_compiler)) @@ -1265,10 +1282,10 @@ class DataFrame(NDFrame): if not inplace: return DataFrame(_query_compiler=new_query_compiler) else: - self._query_compiler = new_query_compiler + self._query_compiler: "QueryCompiler" = new_query_compiler @staticmethod - def _reduce_dimension(query_compiler): + def _reduce_dimension(query_compiler: "QueryCompiler") -> "Series": return Series(_query_compiler=query_compiler) def to_csv( @@ -1849,7 +1866,9 @@ class DataFrame(NDFrame): else: raise NotImplementedError(expr, type(expr)) - def get(self, key, default=None): + def get( + self, key: Any, default: Optional[Any] = None + ) -> Union["Series", "DataFrame"]: """ Get item from object for given key (ex: DataFrame column). Returns default value if not found. @@ -1956,7 +1975,7 @@ class DataFrame(NDFrame): elif like is not None: - def matcher(x): + def matcher(x: str) -> bool: return like in x else: @@ -1965,7 +1984,7 @@ class DataFrame(NDFrame): return self[[column for column in self.columns if matcher(column)]] @property - def values(self): + def values(self) -> None: """ Not implemented. @@ -1983,7 +2002,7 @@ class DataFrame(NDFrame): """ return self.to_numpy() - def to_numpy(self): + def to_numpy(self) -> None: """ Not implemented. diff --git a/eland/field_mappings.py b/eland/field_mappings.py index d8449e3..5c48d81 100644 --- a/eland/field_mappings.py +++ b/eland/field_mappings.py @@ -25,13 +25,14 @@ from typing import ( NamedTuple, Optional, Set, + TextIO, Tuple, Union, ) import numpy as np -import pandas as pd -from pandas.core.dtypes.common import ( +import pandas as pd # type: ignore +from pandas.core.dtypes.common import ( # type: ignore is_bool_dtype, is_datetime_or_timedelta_dtype, is_float_dtype, @@ -42,6 +43,7 @@ from pandas.core.dtypes.inference import is_list_like if TYPE_CHECKING: from elasticsearch import Elasticsearch + from numpy.typing import DTypeLike ES_FLOAT_TYPES: Set[str] = {"double", "float", "half_float", "scaled_float"} @@ -559,7 +561,7 @@ class FieldMappings: return {"mappings": {"properties": mapping_props}} - def aggregatable_field_name(self, display_name): + def aggregatable_field_name(self, display_name: str) -> Optional[str]: """ Return a single aggregatable field_name from display_name @@ -598,7 +600,7 @@ class FieldMappings: return self._mappings_capabilities.loc[display_name].aggregatable_es_field_name - def aggregatable_field_names(self): + def aggregatable_field_names(self) -> Dict[str, str]: """ Return a list of aggregatable Elasticsearch field_names for all display names. 
If a field is not aggregatable, nothing is returned for it. @@ -634,7 +636,7 @@ class FieldMappings: )["data"] ) - def date_field_format(self, es_field_name): + def date_field_format(self, es_field_name: str) -> str: """ Parameters ---------- @@ -650,7 +652,7 @@ class FieldMappings: self._mappings_capabilities.es_field_name == es_field_name ].es_date_format.squeeze() - def field_name_pd_dtype(self, es_field_name): + def field_name_pd_dtype(self, es_field_name: str) -> str: """ Parameters ---------- @@ -674,7 +676,7 @@ class FieldMappings: ].pd_dtype.squeeze() return pd_dtype - def add_scripted_field(self, scripted_field_name, display_name, pd_dtype): + def add_scripted_field( + self, scripted_field_name: str, display_name: str, pd_dtype: str + ) -> None: # if this display name is used somewhere else, drop it if display_name in self._mappings_capabilities.index: self._mappings_capabilities = self._mappings_capabilities.drop( @@ -706,8 +710,8 @@ class FieldMappings: capability_matrix_row ) - def numeric_source_fields(self): - pd_dtypes, es_field_names, es_date_formats = self.metric_source_fields() + def numeric_source_fields(self) -> List[str]: + _, es_field_names, _ = self.metric_source_fields() return es_field_names def all_source_fields(self) -> List[Field]: @@ -753,7 +757,9 @@ class FieldMappings: # Maintain groupby order as given input return [groupby_fields[column] for column in by], aggregatable_fields - def metric_source_fields(self, include_bool=False, include_timestamp=False): + def metric_source_fields( + self, include_bool: bool = False, include_timestamp: bool = False + ) -> Tuple[List["DTypeLike"], List[str], Optional[List[str]]]: """ Returns ------- @@ -790,7 +796,7 @@ class FieldMappings: # return in display_name order return pd_dtypes, es_field_names, es_date_formats - def get_field_names(self, include_scripted_fields=True): + def get_field_names(self, include_scripted_fields: bool = True) -> List[str]: if include_scripted_fields: return self._mappings_capabilities.es_field_name.to_list() @@ -801,7 +807,7 @@ class FieldMappings: def _get_display_names(self): return self._mappings_capabilities.index.to_list() - def _set_display_names(self, display_names): + def _set_display_names(self, display_names: List[str]) -> None: if not is_list_like(display_names): raise ValueError(f"'{display_names}' is not list like") @@ -842,7 +848,7 @@ class FieldMappings: es_dtypes.name = None return es_dtypes - def es_info(self, buf): + def es_info(self, buf: TextIO) -> None: buf.write("Mappings:\n") buf.write(f" capabilities:\n{self._mappings_capabilities.to_string()}\n") diff --git a/eland/ml/_optional.py b/eland/ml/_optional.py index 25e7091..83e02df 100644 --- a/eland/ml/_optional.py +++ b/eland/ml/_optional.py @@ -17,8 +17,11 @@ import distutils.version import importlib -import types import warnings +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from types import ModuleType # ---------------------------------------------------------------------------- # functions largely based / taken from the six module @@ -42,7 +45,7 @@ version_message = ( ) -def _get_version(module: types.ModuleType) -> str: +def _get_version(module: "ModuleType") -> Any: version = getattr(module, "__version__", None) if version is None: # xlrd uses a capitalized attribute name @@ -55,7 +58,7 @@ def import_optional_dependency( name: str, extra: str = "", raise_on_missing: bool = True, on_version: str = "raise" -): +) -> Optional["ModuleType"]: """Import an 
optional dependency. diff --git a/eland/ml/ml_model.py b/eland/ml/ml_model.py index 29c5b66..efaeea6 100644 --- a/eland/ml/ml_model.py +++ b/eland/ml/ml_model.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast import elasticsearch -import numpy as np # type: ignore +import numpy as np from eland.common import ensure_es_client, es_version from eland.utils import deprecated_api @@ -27,7 +27,8 @@ from .common import TYPE_CLASSIFICATION, TYPE_REGRESSION from .transformers import get_model_transformer if TYPE_CHECKING: - from elasticsearch import Elasticsearch # noqa: F401 + from elasticsearch import Elasticsearch + from numpy.typing import ArrayLike, DTypeLike # Try importing each ML lib separately so mypy users don't have to # have both installed to use type-checking. @@ -83,8 +84,8 @@ class MLModel: self._trained_model_config_cache: Optional[Dict[str, Any]] = None def predict( - self, X: Union[np.ndarray, List[float], List[List[float]]] - ) -> np.ndarray: + self, X: Union["ArrayLike", List[float], List[List[float]]] + ) -> "ArrayLike": """ Make a prediction using a trained model stored in Elasticsearch. @@ -196,7 +197,7 @@ class MLModel: # Return results as np.ndarray of float32 or int (consistent with sklearn/xgboost) if self.model_type == TYPE_CLASSIFICATION: - dt = np.int_ + dt: "DTypeLike" = np.int_ else: dt = np.float32 return np.asarray(y, dtype=dt) diff --git a/eland/ml/transformers/sklearn.py b/eland/ml/transformers/sklearn.py index 303f5ec..2f259c9 100644 --- a/eland/ml/transformers/sklearn.py +++ b/eland/ml/transformers/sklearn.py @@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union -import numpy as np # type: ignore +import numpy as np from .._model_serializer import Ensemble, Tree, TreeNode from .._optional import import_optional_dependency @@ -64,7 +64,7 @@ class SKLearnTransformer(ModelTransformer): self, node_index: int, node_data: Tuple[Union[int, float], ...], - value: np.ndarray, + value: np.ndarray, # type: ignore ) -> TreeNode: """ This builds out a TreeNode class given the sklearn tree node definition. 
diff --git a/eland/ml/transformers/xgboost.py b/eland/ml/transformers/xgboost.py index 5d4e85e..9eb1d18 100644 --- a/eland/ml/transformers/xgboost.py +++ b/eland/ml/transformers/xgboost.py @@ -229,7 +229,7 @@ class XGBoostClassifierTransformer(XGBoostForestTransformer): if model.classes_ is None: n_estimators = model.get_params()["n_estimators"] num_trees = model.get_booster().trees_to_dataframe()["Tree"].max() + 1 - self._num_classes = num_trees // n_estimators + self._num_classes: int = num_trees // n_estimators else: self._num_classes = len(model.classes_) diff --git a/eland/ndframe.py b/eland/ndframe.py index 320ffa4..6769293 100644 --- a/eland/ndframe.py +++ b/eland/ndframe.py @@ -17,7 +17,7 @@ import sys from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, TextIO, Tuple, Union import pandas as pd # type: ignore @@ -186,7 +186,7 @@ class NDFrame(ABC): """ return len(self.index) - def _es_info(self, buf): + def _es_info(self, buf: TextIO) -> None: self._query_compiler.es_info(buf) def mean(self, numeric_only: Optional[bool] = None) -> pd.Series: @@ -604,7 +604,7 @@ class NDFrame(ABC): """ return self._query_compiler.mad(numeric_only=numeric_only) - def _hist(self, num_bins): + def _hist(self, num_bins: int) -> Tuple[pd.DataFrame, pd.DataFrame]: return self._query_compiler._hist(num_bins) def describe(self) -> pd.DataFrame: diff --git a/eland/operations.py b/eland/operations.py index f345c85..64e6e9d 100644 --- a/eland/operations.py +++ b/eland/operations.py @@ -26,12 +26,13 @@ from typing import ( List, Optional, Sequence, + TextIO, Tuple, Union, ) import numpy as np -import pandas as pd +import pandas as pd # type: ignore from elasticsearch.helpers import scan from eland.actions import PostProcessingAction, SortFieldAction @@ -58,13 +59,18 @@ from eland.tasks import ( ) if TYPE_CHECKING: + from numpy.typing import DTypeLike + + from eland.arithmetics import ArithmeticSeries from eland.field_mappings import Field + from eland.filter import BooleanFilter from eland.query_compiler import QueryCompiler + from eland.tasks import Task class QueryParams: - def __init__(self): - self.query = Query() + def __init__(self) -> None: + self.query: Query = Query() self.sort_field: Optional[str] = None self.sort_order: Optional[SortOrder] = None self.size: Optional[int] = None @@ -85,37 +91,48 @@ class Operations: (see https://docs.dask.org/en/latest/spec.html) """ - def __init__(self, tasks=None, arithmetic_op_fields_task=None): + def __init__( + self, + tasks: Optional[List["Task"]] = None, + arithmetic_op_fields_task: Optional["ArithmeticOpFieldsTask"] = None, + ) -> None: + self._tasks: List["Task"] if tasks is None: self._tasks = [] else: self._tasks = tasks self._arithmetic_op_fields_task = arithmetic_op_fields_task - def __constructor__(self, *args, **kwargs): + def __constructor__( + self, + *args: Any, + **kwargs: Any, + ) -> "Operations": return type(self)(*args, **kwargs) - def copy(self): + def copy(self) -> "Operations": return self.__constructor__( tasks=copy.deepcopy(self._tasks), arithmetic_op_fields_task=copy.deepcopy(self._arithmetic_op_fields_task), ) - def head(self, index, n): + def head(self, index: "Index", n: int) -> None: # Add a task that is an ascending sort with size=n task = HeadTask(index, n) self._tasks.append(task) - def tail(self, index, n): + def tail(self, index: "Index", n: int) -> None: # Add a task that is descending sort with size=n task = TailTask(index, n) 
self._tasks.append(task) - def sample(self, index, n, random_state): + def sample(self, index: "Index", n: int, random_state: Optional[int]) -> None: task = SampleTask(index, n, random_state) self._tasks.append(task) - def arithmetic_op_fields(self, display_name, arithmetic_series): + def arithmetic_op_fields( + self, display_name: str, arithmetic_series: "ArithmeticSeries" + ) -> None: if self._arithmetic_op_fields_task is None: self._arithmetic_op_fields_task = ArithmeticOpFieldsTask( display_name, arithmetic_series ) @@ -127,10 +144,10 @@ class Operations: # get an ArithmeticOpFieldsTask if it exists return self._arithmetic_op_fields_task - def __repr__(self): + def __repr__(self) -> str: return repr(self._tasks) - def count(self, query_compiler): + def count(self, query_compiler: "QueryCompiler") -> pd.Series: query_params, post_processing = self._resolve_tasks(query_compiler) # Elasticsearch _count is very efficient and so used to return results here. This means that @@ -161,7 +178,7 @@ class Operations: def _metric_agg_series( self, query_compiler: "QueryCompiler", - agg: List, + agg: List[str], numeric_only: Optional[bool] = None, ) -> pd.Series: results = self._metric_aggs(query_compiler, agg, numeric_only=numeric_only) @@ -170,7 +187,7 @@ class Operations: else: # If all results are float convert into float64 if all(isinstance(i, float) for i in results.values()): - dtype = np.float64 + dtype: "DTypeLike" = np.float64 # If all results are int convert into int64 elif all(isinstance(i, int) for i in results.values()): dtype = np.int64 @@ -184,7 +201,9 @@ class Operations: def value_counts(self, query_compiler: "QueryCompiler", es_size: int) -> pd.Series: return self._terms_aggs(query_compiler, "terms", es_size) - def hist(self, query_compiler, bins): + def hist( + self, query_compiler: "QueryCompiler", bins: int + ) -> Tuple[pd.DataFrame, pd.DataFrame]: return self._hist_aggs(query_compiler, bins) def idx( @@ -237,7 +256,12 @@ class Operations: return pd.Series(results) - def aggs(self, query_compiler, pd_aggs, numeric_only=None) -> pd.DataFrame: + def aggs( + self, + query_compiler: "QueryCompiler", + pd_aggs: List[str], + numeric_only: Optional[bool] = None, + ) -> pd.DataFrame: results = self._metric_aggs( query_compiler, pd_aggs, numeric_only=numeric_only, is_dataframe_agg=True ) @@ -441,13 +465,15 @@ class Operations: try: # get first value in dict (key is .keyword) - name = list(aggregatable_field_names.values())[0] + name: Optional[str] = list(aggregatable_field_names.values())[0] except IndexError: name = None return build_pd_series(results, name=name) - def _hist_aggs(self, query_compiler, num_bins): + def _hist_aggs( + self, query_compiler: "QueryCompiler", num_bins: int + ) -> Tuple[pd.DataFrame, pd.DataFrame]: # Get histogram bins and weights for numeric field_names query_params, post_processing = self._resolve_tasks(query_compiler) @@ -488,8 +514,8 @@ class Operations: # }, # ... 
- bins = {} - weights = {} + bins: Dict[str, List[int]] = {} + weights: Dict[str, List[int]] = {} # There is one more bin that weights # len(bins) = len(weights) + 1 @@ -537,11 +563,11 @@ class Operations: def _unpack_metric_aggs( self, fields: List["Field"], - es_aggs: Union[List[str], List[Tuple[str, str]]], + es_aggs: Union[List[str], List[Tuple[str, List[float]]]], pd_aggs: List[str], response: Dict[str, Any], numeric_only: Optional[bool], - percentiles: Optional[List[float]] = None, + percentiles: Optional[Sequence[float]] = None, is_dataframe_agg: bool = False, is_groupby: bool = False, ) -> Dict[str, List[Any]]: @@ -574,7 +600,7 @@ class Operations: """ results: Dict[str, Any] = {} percentile_values: List[float] = [] - agg_value: Union[int, float] + agg_value: Any for field in fields: values = [] @@ -611,7 +637,10 @@ class Operations: agg_value = agg_value["50.0"] else: # Maintain order of percentiles - percentile_values = [agg_value[str(i)] for i in percentiles] + if percentiles: + percentile_values = [ + agg_value[str(i)] for i in percentiles + ] if not percentile_values and pd_agg not in ("quantile", "median"): agg_value = agg_value[es_agg[1]] @@ -682,7 +711,11 @@ class Operations: # Cardinality is always either NaN or integer. elif pd_agg in ("nunique", "count"): - agg_value = int(agg_value) + agg_value = ( + int(agg_value) + if isinstance(agg_value, (int, float)) + else np.NaN + ) # If this is a non-null timestamp field convert to a pd.Timestamp() elif field.is_timestamp: @@ -702,6 +735,7 @@ class Operations: for value in percentile_values ] else: + assert not isinstance(agg_value, dict) agg_value = elasticsearch_date_to_pandas_date( agg_value, field.es_date_format ) @@ -771,7 +805,7 @@ class Operations: by: List[str], pd_aggs: List[str], dropna: bool = True, - quantiles: Optional[List[float]] = None, + quantiles: Optional[Union[int, float, List[int], List[float]]] = None, is_dataframe_agg: bool = False, numeric_only: Optional[bool] = True, ) -> pd.DataFrame: @@ -811,7 +845,7 @@ class Operations: by_fields, agg_fields = query_compiler._mappings.groupby_source_fields(by=by) # Used defaultdict to avoid initialization of columns with lists - results: Dict[str, List[Any]] = defaultdict(list) + results: Dict[Any, List[Any]] = defaultdict(list) if numeric_only: agg_fields = [ @@ -823,7 +857,8 @@ class Operations: # To return for creating multi-index on columns headers = [agg_field.column for agg_field in agg_fields] - percentiles: Optional[List[str]] = None + percentiles: Optional[List[float]] = None + len_percentiles: int = 0 if quantiles: percentiles = [ quantile_to_percentile(x) @@ -833,6 +868,7 @@ class Operations: else quantiles ) ] + len_percentiles = len(percentiles) # Convert pandas aggs to ES equivalent es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs=pd_aggs, percentiles=percentiles) @@ -894,8 +930,8 @@ class Operations: if by_field.is_timestamp and isinstance(bucket_key, int): bucket_key = pd.to_datetime(bucket_key, unit="ms") - if pd_aggs == ["quantile"] and len(percentiles) > 1: - bucket_key = [bucket_key] * len(percentiles) + if pd_aggs == ["quantile"] and len_percentiles > 1: + bucket_key = [bucket_key] * len_percentiles results[by_field.column].extend( bucket_key if isinstance(bucket_key, list) else [bucket_key] @@ -915,7 +951,7 @@ class Operations: ) # to construct index with quantiles - if pd_aggs == ["quantile"] and len(percentiles) > 1: + if pd_aggs == ["quantile"] and percentiles and len_percentiles > 1: results[None].extend([i / 100 for i in percentiles]) # 
Process the calculated agg values to response @@ -929,9 +965,10 @@ class Operations: for pd_agg, val in zip(pd_aggs, value): results[f"{key}_{pd_agg}"].append(val) - # Just to maintain Output same as pandas with empty header. - if pd_aggs == ["quantile"] and len(percentiles) > 1: - by = by + [None] + if pd_aggs == ["quantile"] and len_percentiles > 1: + # by never holds None by default, we make an exception + # here to maintain output same as pandas, also mypy complains + by = by + [None] # type: ignore agg_df = pd.DataFrame(results).set_index(by).sort_index() @@ -947,7 +984,7 @@ class Operations: @staticmethod def bucket_generator( query_compiler: "QueryCompiler", body: "Query" - ) -> Generator[List[str], None, List[str]]: + ) -> Generator[Sequence[Dict[str, Any]], None, Sequence[Dict[str, Any]]]: """ This can be used for all groupby operations. e.g. @@ -977,18 +1014,24 @@ class Operations: ) # Pagination Logic - composite_buckets = res["aggregations"]["groupby_buckets"] - if "after_key" in composite_buckets: + composite_buckets: Dict[str, Any] = res["aggregations"]["groupby_buckets"] + + after_key: Optional[Dict[str, Any]] = composite_buckets.get( + "after_key", None + ) + buckets: Sequence[Dict[str, Any]] = composite_buckets["buckets"] + + if after_key: # yield the bucket which contains the result - yield composite_buckets["buckets"] + yield buckets body.composite_agg_after_key( name="groupby_buckets", - after_key=composite_buckets["after_key"], + after_key=after_key, ) else: - return composite_buckets["buckets"] + return buckets @staticmethod def _map_pd_aggs_to_es_aggs( @@ -1031,7 +1074,7 @@ class Operations: extended_stats_es_aggs = {"avg", "min", "max", "sum"} extended_stats_calls = 0 - es_aggs = [] + es_aggs: List[Any] = [] for pd_agg in pd_aggs: if pd_agg in extended_stats_pd_aggs: extended_stats_calls += 1 @@ -1100,7 +1143,7 @@ class Operations: def filter( self, query_compiler: "QueryCompiler", - items: Optional[Sequence[str]] = None, + items: Optional[List[str]] = None, like: Optional[str] = None, regex: Optional[str] = None, ) -> None: @@ -1122,7 +1165,7 @@ class Operations: f"to substring and regex operations not being available for Elasticsearch document IDs." ) - def describe(self, query_compiler): + def describe(self, query_compiler: "QueryCompiler") -> pd.DataFrame: query_params, post_processing = self._resolve_tasks(query_compiler) size = self._size(query_params, post_processing) @@ -1151,30 +1194,9 @@ class Operations: ["count", "mean", "std", "min", "25%", "50%", "75%", "max"] ) - def to_pandas(self, query_compiler, show_progress=False): - class PandasDataFrameCollector: - def __init__(self, show_progress): - self._df = None - self._show_progress = show_progress - - def collect(self, df): - # This collector does not batch data on output. Therefore, batch_size is fixed to None and this method - # is only called once. 
- if self._df is not None: - raise RuntimeError( - "Logic error in execution, this method must only be called once for this" - "collector - batch_size == None" - ) - self._df = df - - @staticmethod - def batch_size(): - # Do not change (see notes on collect) - return None - - @property - def show_progress(self): - return self._show_progress + def to_pandas( + self, query_compiler: "QueryCompiler", show_progress: bool = False + ) -> pd.DataFrame: collector = PandasDataFrameCollector(show_progress) @@ -1182,35 +1204,12 @@ class Operations: return collector._df - def to_csv(self, query_compiler, show_progress=False, **kwargs): - class PandasToCSVCollector: - def __init__(self, show_progress, **args): - self._args = args - self._show_progress = show_progress - self._ret = None - self._first_time = True - - def collect(self, df): - # If this is the first time we collect results, then write header, otherwise don't write header - # and append results - if self._first_time: - self._first_time = False - df.to_csv(**self._args) - else: - # Don't write header, and change mode to append - self._args["header"] = False - self._args["mode"] = "a" - df.to_csv(**self._args) - - @staticmethod - def batch_size(): - # By default read n docs and then dump to csv - batch_size = DEFAULT_CSV_BATCH_OUTPUT_SIZE - return batch_size - - @property - def show_progress(self): - return self._show_progress + def to_csv( + self, + query_compiler: "QueryCompiler", + show_progress: bool = False, + **kwargs: Union[bool, str], + ) -> None: collector = PandasToCSVCollector(show_progress, **kwargs) @@ -1218,7 +1217,11 @@ class Operations: return collector._ret - def _es_results(self, query_compiler, collector): + def _es_results( + self, + query_compiler: "QueryCompiler", + collector: Union["PandasToCSVCollector", "PandasDataFrameCollector"], + ) -> None: query_params, post_processing = self._resolve_tasks(query_compiler) size, sort_params = Operations._query_params_to_size_and_sort(query_params) @@ -1245,7 +1248,7 @@ class Operations: else: body["_source"] = False - es_results = None + es_results: Any = None # If size=None use scan not search - then post sort results when in df # If size>10000 use scan @@ -1283,7 +1286,7 @@ class Operations: df = self._apply_df_post_processing(df, post_processing) collector.collect(df) - def index_count(self, query_compiler, field): + def index_count(self, query_compiler: "QueryCompiler", field: str) -> int: # field is the index field so count values query_params, post_processing = self._resolve_tasks(query_compiler) @@ -1297,12 +1300,13 @@ class Operations: body = Query(query_params.query) body.exists(field, must=True) - return query_compiler._client.count( + count: int = query_compiler._client.count( index=query_compiler._index_pattern, body=body.to_count_body() )["count"] + return count def _validate_index_operation( - self, query_compiler: "QueryCompiler", items: Sequence[str] + self, query_compiler: "QueryCompiler", items: List[str] ) -> RESOLVED_TASK_TYPE: if not isinstance(items, list): raise TypeError(f"list item required - not {type(items)}") @@ -1320,7 +1324,9 @@ class Operations: return query_params, post_processing - def index_matches_count(self, query_compiler, field, items): + def index_matches_count( + self, query_compiler: "QueryCompiler", field: str, items: List[Any] + ) -> int: query_params, post_processing = self._validate_index_operation( query_compiler, items ) @@ -1332,12 +1338,13 @@ class Operations: else: body.terms(field, items, must=True) - return query_compiler._client.count( 
+ count: int = query_compiler._client.count( index=query_compiler._index_pattern, body=body.to_count_body() )["count"] + return count def drop_index_values( - self, query_compiler: "QueryCompiler", field: str, items: Sequence[str] + self, query_compiler: "QueryCompiler", field: str, items: List[str] ) -> None: self._validate_index_operation(query_compiler, items) @@ -1349,6 +1356,7 @@ class Operations: # a in ['a','b','c'] # b not in ['a','b','c'] # For now use term queries + task: Union["QueryIdsTask", "QueryTermsTask"] if field == Index.ID_INDEX_FIELD: task = QueryIdsTask(False, items) else: @@ -1356,11 +1364,12 @@ class Operations: self._tasks.append(task) def filter_index_values( - self, query_compiler: "QueryCompiler", field: str, items: Sequence[str] + self, query_compiler: "QueryCompiler", field: str, items: List[str] ) -> None: # Basically .drop_index_values() except with must=True on tasks. self._validate_index_operation(query_compiler, items) + task: Union["QueryIdsTask", "QueryTermsTask"] if field == Index.ID_INDEX_FIELD: task = QueryIdsTask(True, items, sort_index_by_ids=True) else: @@ -1406,7 +1415,7 @@ class Operations: # other operations require pre-queries and then combinations # other operations require in-core post-processing of results query_params = QueryParams() - post_processing = [] + post_processing: List["PostProcessingAction"] = [] for task in self._tasks: query_params, post_processing = task.resolve_task( @@ -1439,7 +1448,7 @@ class Operations: # This can return None return size - def es_info(self, query_compiler, buf): + def es_info(self, query_compiler: "QueryCompiler", buf: TextIO) -> None: buf.write("Operations:\n") buf.write(f" tasks: {self._tasks}\n") @@ -1459,7 +1468,7 @@ class Operations: buf.write(f" body: {body}\n") buf.write(f" post_processing: {post_processing}\n") - def update_query(self, boolean_filter): + def update_query(self, boolean_filter: "BooleanFilter") -> None: task = BooleanFilterTask(boolean_filter) self._tasks.append(task) @@ -1477,3 +1486,58 @@ def quantile_to_percentile(quantile: Union[int, float]) -> float: # quantile * 100 = percentile # return float(...) because min(1.0) gives 1 return float(min(100, max(0, quantile * 100))) + + +class PandasToCSVCollector: + def __init__(self, show_progress: bool, **kwargs: Union[bool, str]) -> None: + self._args = kwargs + self._show_progress = show_progress + self._ret = None + self._first_time = True + + def collect(self, df: "pd.DataFrame") -> None: + # If this is the first time we collect results, then write header, otherwise don't write header + # and append results + if self._first_time: + self._first_time = False + df.to_csv(**self._args) + else: + # Don't write header, and change mode to append + self._args["header"] = False + self._args["mode"] = "a" + df.to_csv(**self._args) + + @staticmethod + def batch_size() -> int: + # By default read n docs and then dump to csv + batch_size: int = DEFAULT_CSV_BATCH_OUTPUT_SIZE + return batch_size + + @property + def show_progress(self) -> bool: + return self._show_progress + + +class PandasDataFrameCollector: + def __init__(self, show_progress: bool) -> None: + self._df = None + self._show_progress = show_progress + + def collect(self, df: "pd.DataFrame") -> None: + # This collector does not batch data on output. Therefore, batch_size is fixed to None and this method + # is only called once. 
+ if self._df is not None: + raise RuntimeError( + "Logic error in execution, this method must only be called once for this " + "collector - batch_size == None" + ) + self._df = df + + @staticmethod + def batch_size() -> None: + # Do not change (see notes on collect) + return None + + @property + def show_progress(self) -> bool: + return self._show_progress diff --git a/eland/plotting/_matplotlib/hist.py b/eland/plotting/_matplotlib/hist.py index 5769710..37efda0 100644 --- a/eland/plotting/_matplotlib/hist.py +++ b/eland/plotting/_matplotlib/hist.py @@ -15,8 +15,10 @@ # specific language governing permissions and limitations # under the License. +from typing import TYPE_CHECKING + import numpy as np -from pandas.plotting._matplotlib import converter +from pandas.plotting._matplotlib import converter # type: ignore try: # pandas<1.3.0 @@ -26,13 +28,13 @@ except ImportError: from pandas.core.dtypes.generic import ABCIndex try: # pandas>=1.2.0 - from pandas.plotting._matplotlib.tools import ( + from pandas.plotting._matplotlib.tools import ( # type: ignore create_subplots, flatten_axes, set_ticks_props, ) except ImportError: # pandas<1.2.0 - from pandas.plotting._matplotlib.tools import ( + from pandas.plotting._matplotlib.tools import ( # type: ignore _flatten as flatten_axes, _set_ticks_props as set_ticks_props, _subplots as create_subplots, @@ -40,6 +42,9 @@ except ImportError: # pandas<1.2.0 from eland.utils import try_sort +if TYPE_CHECKING: + from numpy.typing import ArrayLike + def hist_series( self, @@ -53,8 +58,8 @@ figsize=None, bins=10, **kwds, -): - import matplotlib.pyplot as plt +) -> "ArrayLike": + import matplotlib.pyplot as plt # type: ignore if by is None: if kwds.get("layout", None) is not None: diff --git a/eland/query_compiler.py b/eland/query_compiler.py index 3448b35..cd5dfbb 100644 --- a/eland/query_compiler.py +++ b/eland/query_compiler.py @@ -17,9 +17,19 @@ import copy from datetime import datetime -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Sequence, + TextIO, + Tuple, + Union, +) -import numpy as np # type: ignore +import numpy as np import pandas as pd # type: ignore from eland.common import ( @@ -28,11 +38,15 @@ from eland.common import ( ensure_es_client, ) from eland.field_mappings import FieldMappings -from eland.filter import QueryFilter +from eland.filter import BooleanFilter, QueryFilter from eland.index import Index from eland.operations import Operations if TYPE_CHECKING: + from elasticsearch import Elasticsearch + + from eland.arithmetics import ArithmeticSeries + from .tasks import ArithmeticOpFieldsTask # noqa: F401 @@ -67,8 +81,10 @@ class QueryCompiler: def __init__( self, - client=None, - index_pattern=None, + client: Optional[ + Union[str, List[str], Tuple[str, ...], "Elasticsearch"] + ] = None, + index_pattern: Optional[str] = None, display_names=None, index_field=None, to_copy=None, ) -> None: if to_copy is not None: self._client = to_copy._client self._index_pattern = to_copy._index_pattern - self._index = Index(self, to_copy._index.es_index_field) - self._operations = copy.deepcopy(to_copy._operations) + self._index: "Index" = Index(self, to_copy._index.es_index_field) + self._operations: "Operations" = copy.deepcopy(to_copy._operations) self._mappings: FieldMappings = copy.deepcopy(to_copy._mappings) else: self._client = ensure_es_client(client) self._index_pattern = index_pattern # Get and 
persist mappings, this allows us to correctly # map returned types from Elasticsearch to pandas datatypes - self._mappings: FieldMappings = FieldMappings( + self._mappings = FieldMappings( client=self._client, index_pattern=self._index_pattern, display_names=display_names, ) @@ -103,15 +119,15 @@ class QueryCompiler: return pd.Index(columns) - def _get_display_names(self): + def _get_display_names(self) -> "pd.Index": display_names = self._mappings.display_names return pd.Index(display_names) - def _set_display_names(self, display_names): + def _set_display_names(self, display_names: List[str]) -> None: self._mappings.display_names = display_names - def get_field_names(self, include_scripted_fields): + def get_field_names(self, include_scripted_fields: bool) -> List[str]: return self._mappings.get_field_names(include_scripted_fields) def add_scripted_field(self, scripted_field_name, display_name, pd_dtype): @@ -129,7 +145,12 @@ class QueryCompiler: # END Index, columns, and dtypes objects - def _es_results_to_pandas(self, results, batch_size=None, show_progress=False): + def _es_results_to_pandas( + self, + results: Dict[Any, Any], + batch_size: Optional[int] = None, + show_progress: bool = False, + ) -> "pd.DataFrame": """ Parameters ---------- @@ -300,7 +321,7 @@ class QueryCompiler: return partial_result, df - def _flatten_dict(self, y, field_mapping_cache): + def _flatten_dict(self, y, field_mapping_cache: "FieldMappingCache"): out = {} def flatten(x, name=""): @@ -368,7 +389,7 @@ class QueryCompiler: """ return self._operations.index_count(self, self.index.es_index_field) - def _index_matches_count(self, items): + def _index_matches_count(self, items: List[Any]) -> int: """ Returns ------- @@ -386,10 +407,10 @@ class QueryCompiler: df[c] = pd.Series(dtype=d) return df - def copy(self): + def copy(self) -> "QueryCompiler": return QueryCompiler(to_copy=self) - def rename(self, renames, inplace=False): + def rename(self, renames, inplace: bool = False) -> "QueryCompiler": if inplace: self._mappings.rename(renames) return self @@ -398,21 +419,23 @@ class QueryCompiler: result._mappings.rename(renames) return result - def head(self, n): + def head(self, n: int) -> "QueryCompiler": result = self.copy() result._operations.head(self._index, n) return result - def tail(self, n): + def tail(self, n: int) -> "QueryCompiler": result = self.copy() result._operations.tail(self._index, n) return result - def sample(self, n=None, frac=None, random_state=None): + def sample( + self, n: Optional[int] = None, frac: Optional[float] = None, random_state: Optional[int] = None + ) -> "QueryCompiler": result = self.copy() if n is None and frac is None: @@ -501,11 +524,11 @@ class QueryCompiler: query = {"multi_match": options} return QueryFilter(query) - def es_query(self, query): + def es_query(self, query: Dict[str, Any]) -> "QueryCompiler": return self._update_query(QueryFilter(query)) # To/From Pandas - def to_pandas(self, show_progress=False): + def to_pandas(self, show_progress: bool = False) -> pd.DataFrame: """Converts Eland DataFrame to Pandas DataFrame. 
Returns: @@ -543,7 +566,9 @@ class QueryCompiler: return result - def drop(self, index=None, columns=None): + def drop( + self, index: Optional[str] = None, columns: Optional[List[str]] = None + ) -> "QueryCompiler": result = self.copy() # Drop gets all columns and removes drops @@ -559,7 +584,7 @@ class QueryCompiler: def filter( self, - items: Optional[Sequence[str]] = None, + items: Optional[List[str]] = None, like: Optional[str] = None, regex: Optional[str] = None, ) -> "QueryCompiler": @@ -570,53 +595,55 @@ class QueryCompiler: result._operations.filter(self, items=items, like=like, regex=regex) return result - def aggs(self, func: List[str], numeric_only: Optional[bool] = None): + def aggs( + self, func: List[str], numeric_only: Optional[bool] = None + ) -> pd.DataFrame: return self._operations.aggs(self, func, numeric_only=numeric_only) - def count(self): + def count(self) -> pd.Series: return self._operations.count(self) - def mean(self, numeric_only: Optional[bool] = None): + def mean(self, numeric_only: Optional[bool] = None) -> pd.Series: return self._operations._metric_agg_series( self, ["mean"], numeric_only=numeric_only ) - def var(self, numeric_only: Optional[bool] = None): + def var(self, numeric_only: Optional[bool] = None) -> pd.Series: return self._operations._metric_agg_series( self, ["var"], numeric_only=numeric_only ) - def std(self, numeric_only: Optional[bool] = None): + def std(self, numeric_only: Optional[bool] = None) -> pd.Series: return self._operations._metric_agg_series( self, ["std"], numeric_only=numeric_only ) - def mad(self, numeric_only: Optional[bool] = None): + def mad(self, numeric_only: Optional[bool] = None) -> pd.Series: return self._operations._metric_agg_series( self, ["mad"], numeric_only=numeric_only ) - def median(self, numeric_only: Optional[bool] = None): + def median(self, numeric_only: Optional[bool] = None) -> pd.Series: return self._operations._metric_agg_series( self, ["median"], numeric_only=numeric_only ) - def sum(self, numeric_only: Optional[bool] = None): + def sum(self, numeric_only: Optional[bool] = None) -> pd.Series: return self._operations._metric_agg_series( self, ["sum"], numeric_only=numeric_only ) - def min(self, numeric_only: Optional[bool] = None): + def min(self, numeric_only: Optional[bool] = None) -> pd.Series: return self._operations._metric_agg_series( self, ["min"], numeric_only=numeric_only ) - def max(self, numeric_only: Optional[bool] = None): + def max(self, numeric_only: Optional[bool] = None) -> pd.Series: return self._operations._metric_agg_series( self, ["max"], numeric_only=numeric_only ) - def nunique(self): + def nunique(self) -> pd.Series: return self._operations._metric_agg_series( self, ["nunique"], numeric_only=False ) @@ -673,7 +700,7 @@ class QueryCompiler: dropna: bool = True, is_dataframe_agg: bool = False, numeric_only: Optional[bool] = True, - quantiles: Union[int, float, List[int], List[float], None] = None, + quantiles: Optional[Union[int, float, List[int], List[float]]] = None, ) -> pd.DataFrame: return self._operations.aggs_groupby( self, @@ -691,27 +718,27 @@ class QueryCompiler: def value_counts(self, es_size: int) -> pd.Series: return self._operations.value_counts(self, es_size) - def es_info(self, buf): + def es_info(self, buf: TextIO) -> None: buf.write(f"es_index_pattern: {self._index_pattern}\n") self._index.es_info(buf) self._mappings.es_info(buf) self._operations.es_info(self, buf) - def describe(self): + def describe(self) -> pd.DataFrame: return self._operations.describe(self) - 
def _hist(self, num_bins): + def _hist(self, num_bins: int) -> Tuple[pd.DataFrame, pd.DataFrame]: return self._operations.hist(self, num_bins) - def _update_query(self, boolean_filter): + def _update_query(self, boolean_filter: "BooleanFilter") -> "QueryCompiler": result = self.copy() result._operations.update_query(boolean_filter) return result - def check_arithmetics(self, right): + def check_arithmetics(self, right: "QueryCompiler") -> None: """ Compare 2 query_compilers to see if arithmetic operations can be performed by the NDFrame object. @@ -750,7 +777,9 @@ class QueryCompiler: f"{self._index_pattern} != {right._index_pattern}" ) - def arithmetic_op_fields(self, display_name, arithmetic_object): + def arithmetic_op_fields( + self, display_name: str, arithmetic_object: "ArithmeticSeries" + ) -> "QueryCompiler": result = self.copy() # create a new field name for this display name @@ -758,7 +787,7 @@ class QueryCompiler: # add scripted field result._mappings.add_scripted_field( - scripted_field_name, display_name, arithmetic_object.dtype.name + scripted_field_name, display_name, arithmetic_object.dtype.name # type: ignore ) result._operations.arithmetic_op_fields(scripted_field_name, arithmetic_object) @@ -768,7 +797,7 @@ class QueryCompiler: def get_arithmetic_op_fields(self) -> Optional["ArithmeticOpFieldsTask"]: return self._operations.get_arithmetic_op_fields() - def display_name_to_aggregatable_name(self, display_name: str) -> str: + def display_name_to_aggregatable_name(self, display_name: str) -> Optional[str]: aggregatable_field_name = self._mappings.aggregatable_field_name(display_name) return aggregatable_field_name @@ -780,13 +809,13 @@ class FieldMappingCache: DataFrame access is slower than dict access. """ - def __init__(self, mappings): + def __init__(self, mappings: "FieldMappings") -> None: self._mappings = mappings - self._field_name_pd_dtype = dict() - self._date_field_format = dict() + self._field_name_pd_dtype: Dict[str, str] = dict() + self._date_field_format: Dict[str, str] = dict() - def field_name_pd_dtype(self, es_field_name): + def field_name_pd_dtype(self, es_field_name: str) -> str: if es_field_name in self._field_name_pd_dtype: return self._field_name_pd_dtype[es_field_name] @@ -797,7 +826,7 @@ class FieldMappingCache: return pd_dtype - def date_field_format(self, es_field_name): + def date_field_format(self, es_field_name: str) -> str: if es_field_name in self._date_field_format: return self._date_field_format[es_field_name] diff --git a/eland/series.py b/eland/series.py index 7abfd72..9a29e30 100644 --- a/eland/series.py +++ b/eland/series.py @@ -38,8 +38,8 @@ from io import StringIO from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union import numpy as np -import pandas as pd -from pandas.io.common import _expand_user, stringify_path +import pandas as pd # type: ignore +from pandas.io.common import _expand_user, stringify_path # type: ignore import eland.plotting from eland.arithmetics import ArithmeticNumber, ArithmeticSeries, ArithmeticString @@ -61,10 +61,10 @@ from eland.filter import ( from eland.ndframe import NDFrame from eland.utils import to_list -if TYPE_CHECKING: # type: ignore - from elasticsearch import Elasticsearch # noqa: F401 +if TYPE_CHECKING: + from elasticsearch import Elasticsearch - from eland.query_compiler import QueryCompiler # noqa: F401 + from eland.query_compiler import QueryCompiler def _get_method_name() -> str: @@ -175,7 +175,7 @@ class Series(NDFrame): return num_rows, num_columns @property - def 
es_field_name(self) -> str: + def es_field_name(self) -> pd.Index: """ Returns ------- @@ -185,7 +185,7 @@ class Series(NDFrame): return self._query_compiler.get_field_names(include_scripted_fields=True)[0] @property - def name(self) -> str: + def name(self) -> pd.Index: return self._query_compiler.columns[0] @name.setter @@ -793,7 +793,7 @@ class Series(NDFrame): return buf.getvalue() - def __add__(self, right): + def __add__(self, right: "Series") -> "Series": """ Return addition of series and right, element-wise (binary operator add). @@ -853,7 +853,7 @@ class Series(NDFrame): """ return self._numeric_op(right, _get_method_name()) - def __truediv__(self, right): + def __truediv__(self, right: "Series") -> "Series": """ Return floating division of series and right, element-wise (binary operator truediv). @@ -892,7 +892,7 @@ class Series(NDFrame): """ return self._numeric_op(right, _get_method_name()) - def __floordiv__(self, right): + def __floordiv__(self, right: "Series") -> "Series": """ Return integer division of series and right, element-wise (binary operator floordiv //). @@ -931,7 +931,7 @@ class Series(NDFrame): """ return self._numeric_op(right, _get_method_name()) - def __mod__(self, right): + def __mod__(self, right: "Series") -> "Series": """ Return modulo of series and right, element-wise (binary operator mod %). @@ -970,7 +970,7 @@ class Series(NDFrame): """ return self._numeric_op(right, _get_method_name()) - def __mul__(self, right): + def __mul__(self, right: "Series") -> "Series": """ Return multiplication of series and right, element-wise (binary operator mul). @@ -1009,7 +1009,7 @@ class Series(NDFrame): """ return self._numeric_op(right, _get_method_name()) - def __sub__(self, right): + def __sub__(self, right: "Series") -> "Series": """ Return subtraction of series and right, element-wise (binary operator sub). @@ -1048,7 +1048,7 @@ class Series(NDFrame): """ return self._numeric_op(right, _get_method_name()) - def __pow__(self, right): + def __pow__(self, right: "Series") -> "Series": """ Return exponential power of series and right, element-wise (binary operator pow). @@ -1087,7 +1087,7 @@ class Series(NDFrame): """ return self._numeric_op(right, _get_method_name()) - def __radd__(self, left): + def __radd__(self, left: "Series") -> "Series": """ Return addition of series and left, element-wise (binary operator add). @@ -1119,7 +1119,7 @@ class Series(NDFrame): """ return self._numeric_op(left, _get_method_name()) - def __rtruediv__(self, left): + def __rtruediv__(self, left: "Series") -> "Series": """ Return division of series and left, element-wise (binary operator div). @@ -1151,7 +1151,7 @@ class Series(NDFrame): """ return self._numeric_op(left, _get_method_name()) - def __rfloordiv__(self, left): + def __rfloordiv__(self, left: "Series") -> "Series": """ Return integer division of series and left, element-wise (binary operator floordiv //). @@ -1183,7 +1183,7 @@ class Series(NDFrame): """ return self._numeric_op(left, _get_method_name()) - def __rmod__(self, left): + def __rmod__(self, left: "Series") -> "Series": """ Return modulo of series and left, element-wise (binary operator mod %). @@ -1215,7 +1215,7 @@ class Series(NDFrame): """ return self._numeric_op(left, _get_method_name()) - def __rmul__(self, left): + def __rmul__(self, left: "Series") -> "Series": """ Return multiplication of series and left, element-wise (binary operator mul). 
@@ -1247,7 +1247,7 @@ class Series(NDFrame): """ return self._numeric_op(left, _get_method_name()) - def __rpow__(self, left): + def __rpow__(self, left: "Series") -> "Series": """ Return exponential power of series and left, element-wise (binary operator pow). @@ -1279,7 +1279,7 @@ class Series(NDFrame): """ return self._numeric_op(left, _get_method_name()) - def __rsub__(self, left): + def __rsub__(self, left: "Series") -> "Series": """ Return subtraction of series and left, element-wise (binary operator sub). @@ -1398,7 +1398,7 @@ class Series(NDFrame): return series - def max(self, numeric_only=None): + def max(self, numeric_only: Optional[bool] = None) -> pd.Series: """ Return the maximum of the Series values @@ -1422,7 +1422,7 @@ class Series(NDFrame): results = super().max(numeric_only=numeric_only) return results.squeeze() - def mean(self, numeric_only=None): + def mean(self, numeric_only: Optional[bool] = None) -> pd.Series: """ Return the mean of the Series values @@ -1446,7 +1446,7 @@ class Series(NDFrame): results = super().mean(numeric_only=numeric_only) return results.squeeze() - def median(self, numeric_only=None): + def median(self, numeric_only: Optional[bool] = None) -> pd.Series: """ Return the median of the Series values @@ -1470,7 +1470,7 @@ class Series(NDFrame): results = super().median(numeric_only=numeric_only) return results.squeeze() - def min(self, numeric_only=None): + def min(self, numeric_only: Optional[bool] = None) -> pd.Series: """ Return the minimum of the Series values @@ -1494,7 +1494,7 @@ class Series(NDFrame): results = super().min(numeric_only=numeric_only) return results.squeeze() - def sum(self, numeric_only=None): + def sum(self, numeric_only: Optional[bool] = None) -> pd.Series: """ Return the sum of the Series values @@ -1518,7 +1518,7 @@ class Series(NDFrame): results = super().sum(numeric_only=numeric_only) return results.squeeze() - def nunique(self): + def nunique(self) -> pd.Series: """ Return the number of unique values in a Series @@ -1540,7 +1540,7 @@ class Series(NDFrame): results = super().nunique() return results.squeeze() - def var(self, numeric_only=None): + def var(self, numeric_only: Optional[bool] = None) -> pd.Series: """ Return variance for a Series @@ -1562,7 +1562,7 @@ class Series(NDFrame): results = super().var(numeric_only=numeric_only) return results.squeeze() - def std(self, numeric_only=None): + def std(self, numeric_only: Optional[bool] = None) -> pd.Series: """ Return standard deviation for a Series @@ -1584,7 +1584,7 @@ class Series(NDFrame): results = super().std(numeric_only=numeric_only) return results.squeeze() - def mad(self, numeric_only=None): + def mad(self, numeric_only: Optional[bool] = None) -> pd.Series: """ Return median absolute deviation for a Series @@ -1643,7 +1643,7 @@ class Series(NDFrame): # def values TODO - not implemented as causes current implementation of query to fail - def to_numpy(self): + def to_numpy(self) -> None: """ Not implemented. diff --git a/eland/tasks.py b/eland/tasks.py index fff7ec0..6164577 100644 --- a/eland/tasks.py +++ b/eland/tasks.py @@ -16,7 +16,7 @@ # under the License. 
from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, List, Tuple +from typing import TYPE_CHECKING, List, Tuple from eland import SortOrder from eland.actions import HeadAction, SortIndexAction, TailAction @@ -253,7 +253,7 @@ class QueryIdsTask(Task): class QueryTermsTask(Task): - def __init__(self, must: bool, field: str, terms: List[Any]): + def __init__(self, must: bool, field: str, terms: List[str]): """ Parameters ---------- diff --git a/noxfile.py b/noxfile.py index f058c71..7230355 100644 --- a/noxfile.py +++ b/noxfile.py @@ -38,7 +38,10 @@ TYPED_FILES = ( "eland/tasks.py", "eland/utils.py", "eland/groupby.py", + "eland/operations.py", + "eland/ndframe.py", "eland/ml/__init__.py", + "eland/ml/_optional.py", "eland/ml/_model_serializer.py", "eland/ml/ml_model.py", "eland/ml/transformers/__init__.py", @@ -46,6 +49,7 @@ TYPED_FILES = ( "eland/ml/transformers/lightgbm.py", "eland/ml/transformers/sklearn.py", "eland/ml/transformers/xgboost.py", + "eland/plotting/_matplotlib/__init__.py", ) @@ -60,7 +64,9 @@ def format(session): @nox.session(reuse_venv=True) def lint(session): - session.install("black", "flake8", "mypy", "isort") + # Install numpy to use its mypy plugin + # https://numpy.org/devdocs/reference/typing.html#mypy-plugin + session.install("black", "flake8", "mypy", "isort", "numpy") session.install("--pre", "elasticsearch") session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) session.run("black", "--check", "--target-version=py37", *SOURCE_FILES) diff --git a/requirements-dev.txt b/requirements-dev.txt index 07ffb12..9310580 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,3 +10,4 @@ xgboost>=1 nox lightgbm pytest-cov +mypy diff --git a/setup.cfg b/setup.cfg index c76db01..cb6a426 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,4 @@ [isort] profile = black +[mypy] +plugins = numpy.typing.mypy_plugin
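
Usage note (illustrative commentary, not part of the applied patch): the recurring pattern across these hunks is to import numpy's typing aliases only under TYPE_CHECKING, so that numpy.typing (available from numpy 1.20 onwards) never becomes a runtime import, while the numpy.typing.mypy_plugin enabled in setup.cfg above lets mypy resolve the quoted "DTypeLike" annotations. A minimal self-contained sketch of that pattern follows; make_series is a hypothetical stand-in for a helper such as eland.common.build_pd_series, not code from this patch:

    from typing import TYPE_CHECKING, Any, Dict, Optional

    import numpy as np
    import pandas as pd

    if TYPE_CHECKING:
        # Resolved only during static analysis, so older numpy still works at runtime.
        from numpy.typing import DTypeLike

    def make_series(data: Dict[str, Any], dtype: Optional["DTypeLike"] = None) -> pd.Series:
        # DTypeLike accepts np.float64, np.dtype("float64"), or the string "float64" alike.
        return pd.Series(data, dtype=dtype)

    s = make_series({"a": 1, "b": 2}, dtype=np.float64)

Because the annotation is quoted and the import is guarded, modules import cleanly even where numpy.typing is unavailable, while "nox -s lint" (which now installs numpy alongside mypy for exactly this reason) still type-checks dtype arguments.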