Add type hints to 'eland.operations' and 'eland.ndframe'

P. Sai Vinay 2021-08-02 22:20:35 +05:30 committed by GitHub
parent c0e861dc77
commit 823f01cc6c
18 changed files with 428 additions and 276 deletions

View File

@@ -19,9 +19,11 @@ from abc import ABC, abstractmethod
 from io import StringIO
 from typing import TYPE_CHECKING, Any, List, Union
-import numpy as np  # type: ignore
+import numpy as np
 if TYPE_CHECKING:
+    from numpy.typing import DTypeLike
     from .query_compiler import QueryCompiler
@@ -32,7 +34,7 @@ class ArithmeticObject(ABC):
         pass
     @abstractmethod
-    def dtype(self) -> np.dtype:
+    def dtype(self) -> "DTypeLike":
         pass
     @abstractmethod
@@ -52,7 +54,7 @@ class ArithmeticString(ArithmeticObject):
         return self.value
     @property
-    def dtype(self) -> np.dtype:
+    def dtype(self) -> "DTypeLike":
         return np.dtype(object)
     @property
@@ -64,7 +66,7 @@ class ArithmeticString(ArithmeticObject):
 class ArithmeticNumber(ArithmeticObject):
-    def __init__(self, value: Union[int, float], dtype: np.dtype):
+    def __init__(self, value: Union[int, float], dtype: "DTypeLike"):
         self._value = value
         self._dtype = dtype
@@ -76,7 +78,7 @@ class ArithmeticNumber(ArithmeticObject):
         return f"{self._value}"
     @property
-    def dtype(self) -> np.dtype:
+    def dtype(self) -> "DTypeLike":
         return self._dtype
     def __repr__(self) -> str:
@@ -89,8 +91,8 @@ class ArithmeticSeries(ArithmeticObject):
     """
     def __init__(
-        self, query_compiler: "QueryCompiler", display_name: str, dtype: np.dtype
-    ):
+        self, query_compiler: "QueryCompiler", display_name: str, dtype: "DTypeLike"
+    ) -> None:
        # type defs
         self._value: str
         self._tasks: List["ArithmeticTask"]
@@ -121,7 +123,7 @@ class ArithmeticSeries(ArithmeticObject):
         return self._value
     @property
-    def dtype(self) -> np.dtype:
+    def dtype(self) -> "DTypeLike":
         return self._dtype
     def __repr__(self) -> str:
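Editor's note on the pattern above: numpy.typing only exists in NumPy >= 1.20, so DTypeLike is imported under typing.TYPE_CHECKING and referenced as a string literal; the annotation is then never evaluated at runtime. A minimal sketch of the idiom (the helper function is made up for illustration):

    from typing import TYPE_CHECKING

    import numpy as np

    if TYPE_CHECKING:
        # seen by the type checker only; never imported at runtime
        from numpy.typing import DTypeLike

    def string_dtype() -> "DTypeLike":
        # strings map to object dtype, mirroring ArithmeticString.dtype above
        return np.dtype(object)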

View File

@@ -18,12 +18,24 @@
 import re
 import warnings
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 import numpy as np  # type: ignore
 import pandas as pd  # type: ignore
 from elasticsearch import Elasticsearch
+if TYPE_CHECKING:
+    from numpy.typing import DTypeLike
 # Default number of rows displayed (different to pandas where ALL could be displayed)
 DEFAULT_NUM_ROWS_DISPLAYED = 60
 DEFAULT_CHUNK_SIZE = 10000
@@ -42,7 +54,7 @@ with warnings.catch_warnings():
 def build_pd_series(
-    data: Dict[str, Any], dtype: Optional[np.dtype] = None, **kwargs: Any
+    data: Dict[str, Any], dtype: Optional["DTypeLike"] = None, **kwargs: Any
 ) -> pd.Series:
     """Builds a pd.Series while squelching the warning
     for unspecified dtype on empty series
@@ -88,7 +100,7 @@ class SortOrder(Enum):
 def elasticsearch_date_to_pandas_date(
-    value: Union[int, str], date_format: Optional[str]
+    value: Union[int, str, float], date_format: Optional[str]
 ) -> pd.Timestamp:
     """
     Given a specific Elasticsearch format for a date datatype, returns the
@@ -98,7 +110,7 @@ def elasticsearch_date_to_pandas_date(
     Parameters
     ----------
-    value: Union[int, str]
+    value: Union[int, str, float]
         The date value.
     date_format: str
         The Elasticsearch date format (ex. 'epoch_millis', 'epoch_second', etc.)
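Editor's note: widening value to Union[int, str, float] matters because Elasticsearch JSON can deliver epoch dates as floats. A quick check of the widened signature, assuming the usual epoch_millis handling via pandas (the timestamp below is 2021-01-01 00:00:00 UTC):

    from eland.common import elasticsearch_date_to_pandas_date

    elasticsearch_date_to_pandas_date(1609459200000, "epoch_millis")
    # Timestamp('2021-01-01 00:00:00')
    elasticsearch_date_to_pandas_date(1609459200000.0, "epoch_millis")  # float now type-checks too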

View File

@@ -15,9 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
+from typing import Any, Dict
 import numpy as np
-import pandas as pd
-import pytest
+import pandas as pd  # type: ignore
+import pytest  # type: ignore
 import eland as ed
@@ -28,7 +30,7 @@ pd.set_option("display.width", 100)
 @pytest.fixture(autouse=True)
-def add_imports(doctest_namespace):
+def add_imports(doctest_namespace: Dict[str, Any]) -> None:
     doctest_namespace["np"] = np
     doctest_namespace["pd"] = pd
     doctest_namespace["ed"] = ed

View File

@@ -19,19 +19,19 @@ import re
 import sys
 import warnings
 from io import StringIO
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
 import numpy as np
-import pandas as pd
-from pandas.core.common import apply_if_callable, is_bool_indexer
-from pandas.core.computation.eval import eval
-from pandas.core.dtypes.common import is_list_like
-from pandas.core.indexing import check_bool_indexer
-from pandas.io.common import _expand_user, stringify_path
+import pandas as pd  # type: ignore
+from pandas.core.common import apply_if_callable, is_bool_indexer  # type: ignore
+from pandas.core.computation.eval import eval  # type: ignore
+from pandas.core.dtypes.common import is_list_like  # type: ignore
+from pandas.core.indexing import check_bool_indexer  # type: ignore
+from pandas.io.common import _expand_user, stringify_path  # type: ignore
 from pandas.io.formats import console
 from pandas.io.formats import format as fmt
-from pandas.io.formats.printing import pprint_thing
-from pandas.util._validators import validate_bool_kwarg
+from pandas.io.formats.printing import pprint_thing  # type: ignore
+from pandas.util._validators import validate_bool_kwarg  # type: ignore
 import eland.plotting as gfx
 from eland.common import DEFAULT_NUM_ROWS_DISPLAYED, docstring_parameter
@@ -41,6 +41,11 @@ from eland.ndframe import NDFrame
 from eland.series import Series
 from eland.utils import is_valid_attr_name
+if TYPE_CHECKING:
+    from elasticsearch import Elasticsearch
+    from .query_compiler import QueryCompiler
 class DataFrame(NDFrame):
     """
@@ -119,11 +124,13 @@ class DataFrame(NDFrame):
     def __init__(
         self,
-        es_client=None,
-        es_index_pattern=None,
-        es_index_field=None,
-        columns=None,
-        _query_compiler=None,
+        es_client: Optional[
+            Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
+        ] = None,
+        es_index_pattern: Optional[str] = None,
+        columns: Optional[List[str]] = None,
+        es_index_field: Optional[str] = None,
+        _query_compiler: Optional["QueryCompiler"] = None,
     ) -> None:
         """
         There are effectively 2 constructors:
@@ -147,7 +154,7 @@ class DataFrame(NDFrame):
             _query_compiler=_query_compiler,
         )
-    def _get_columns(self):
+    def _get_columns(self) -> pd.Index:
         """
         The column labels of the DataFrame.
@@ -178,7 +185,7 @@ class DataFrame(NDFrame):
     columns = property(_get_columns)
     @property
-    def empty(self):
+    def empty(self) -> bool:
         """Determines if the DataFrame is empty.
         Returns
@@ -278,7 +285,10 @@ class DataFrame(NDFrame):
         return DataFrame(_query_compiler=self._query_compiler.tail(n))
     def sample(
-        self, n: int = None, frac: float = None, random_state: int = None
+        self,
+        n: Optional[int] = None,
+        frac: Optional[float] = None,
+        random_state: Optional[int] = None,
     ) -> "DataFrame":
         """
         Return n randomly sample rows or the specify fraction of rows
@@ -469,7 +479,7 @@ class DataFrame(NDFrame):
             if is_valid_attr_name(column_name)
         ]
-    def __repr__(self):
+    def __repr__(self) -> None:
         """
         From pandas
         """
@@ -501,7 +511,7 @@ class DataFrame(NDFrame):
         return buf.getvalue()
-    def _info_repr(self):
+    def _info_repr(self) -> bool:
         """
         True if the repr should show the info view.
         """
@@ -510,7 +520,7 @@ class DataFrame(NDFrame):
             self._repr_fits_horizontal_() and self._repr_fits_vertical_()
         )
-    def _repr_html_(self):
+    def _repr_html_(self) -> Optional[str]:
         """
         From pandas - this is called by notebooks
         """
@@ -540,7 +550,7 @@ class DataFrame(NDFrame):
         else:
             return None
-    def count(self):
+    def count(self) -> pd.Series:
         """
         Count non-NA cells for each column.
@@ -855,10 +865,10 @@ class DataFrame(NDFrame):
         exceeds_info_cols = len(self.columns) > max_cols
         # From pandas.DataFrame
-        def _put_str(s, space):
+        def _put_str(s, space) -> str:
             return f"{s}"[:space].ljust(space)
-        def _verbose_repr():
+        def _verbose_repr() -> None:
             lines.append(f"Data columns (total {len(self.columns)} columns):")
             id_head = " # "
@@ -930,10 +940,10 @@ class DataFrame(NDFrame):
                 + _put_str(dtype, space_dtype)
             )
-        def _non_verbose_repr():
+        def _non_verbose_repr() -> None:
             lines.append(self.columns._summary(name="Columns"))
-        def _sizeof_fmt(num, size_qualifier):
+        def _sizeof_fmt(num: float, size_qualifier: str) -> str:
             # returns size in human readable format
             for x in ["bytes", "KB", "MB", "GB", "TB"]:
                 if num < 1024.0:
@@ -1004,7 +1014,7 @@ class DataFrame(NDFrame):
         border=None,
         table_id=None,
         render_links=False,
-    ):
+    ) -> Any:
         """
         Render a Elasticsearch data as an HTML table.
@@ -1171,7 +1181,7 @@ class DataFrame(NDFrame):
         result = _buf.getvalue()
         return result
-    def __getattr__(self, key):
+    def __getattr__(self, key: str) -> Any:
         """After regular attribute access, looks up the name in the columns
         Parameters
@@ -1190,7 +1200,12 @@ class DataFrame(NDFrame):
             return self[key]
         raise e
-    def _getitem(self, key):
+    def _getitem(
+        self,
+        key: Union[
+            "DataFrame", "Series", pd.Index, List[str], str, BooleanFilter, np.ndarray
+        ],
+    ) -> Union["Series", "DataFrame"]:
         """Get the column specified by key for this DataFrame.
         Args:
@@ -1215,13 +1230,13 @@ class DataFrame(NDFrame):
         else:
             return self._getitem_column(key)
-    def _getitem_column(self, key):
+    def _getitem_column(self, key: str) -> "Series":
         if key not in self.columns:
             raise KeyError(f"Requested column [{key}] is not in the DataFrame.")
         s = self._reduce_dimension(self._query_compiler.getitem_column_array([key]))
         return s
-    def _getitem_array(self, key):
+    def _getitem_array(self, key: Union[str, pd.Series]) -> "DataFrame":
         if isinstance(key, Series):
             key = key.to_pandas()
         if is_bool_indexer(key):
@@ -1256,7 +1271,9 @@ class DataFrame(NDFrame):
             _query_compiler=self._query_compiler.getitem_column_array(key)
         )
-    def _create_or_update_from_compiler(self, new_query_compiler, inplace=False):
+    def _create_or_update_from_compiler(
+        self, new_query_compiler: "QueryCompiler", inplace: bool = False
+    ) -> Union["QueryCompiler", "DataFrame"]:
         """Returns or updates a DataFrame given new query_compiler"""
         assert (
             isinstance(new_query_compiler, type(self._query_compiler))
@@ -1265,10 +1282,10 @@ class DataFrame(NDFrame):
         if not inplace:
             return DataFrame(_query_compiler=new_query_compiler)
         else:
-            self._query_compiler = new_query_compiler
+            self._query_compiler: "QueryCompiler" = new_query_compiler
     @staticmethod
-    def _reduce_dimension(query_compiler):
+    def _reduce_dimension(query_compiler: "QueryCompiler") -> "Series":
         return Series(_query_compiler=query_compiler)
     def to_csv(
@@ -1849,7 +1866,9 @@ class DataFrame(NDFrame):
         else:
             raise NotImplementedError(expr, type(expr))
-    def get(self, key, default=None):
+    def get(
+        self, key: Any, default: Optional[Any] = None
+    ) -> Union["Series", "DataFrame"]:
         """
         Get item from object for given key (ex: DataFrame column).
         Returns default value if not found.
@@ -1956,7 +1975,7 @@ class DataFrame(NDFrame):
         elif like is not None:
-            def matcher(x):
+            def matcher(x: str) -> bool:
                 return like in x
         else:
@@ -1965,7 +1984,7 @@ class DataFrame(NDFrame):
         return self[[column for column in self.columns if matcher(column)]]
     @property
-    def values(self):
+    def values(self) -> None:
         """
         Not implemented.
@@ -1983,7 +2002,7 @@ class DataFrame(NDFrame):
         """
         return self.to_numpy()
-    def to_numpy(self):
+    def to_numpy(self) -> None:
         """
         Not implemented.
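Editor's note: with the constructor typed, the accepted shapes of es_client are now explicit (and note that columns moved ahead of es_index_field in the signature). A sketch of each style, where the host URL and index name are placeholders:

    import eland as ed
    from elasticsearch import Elasticsearch

    # hostname string, list of hosts, or an existing client are all accepted
    df = ed.DataFrame("http://localhost:9200", es_index_pattern="flights")
    df = ed.DataFrame(["http://es1:9200", "http://es2:9200"], es_index_pattern="flights")
    df = ed.DataFrame(Elasticsearch("http://localhost:9200"), es_index_pattern="flights")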

View File

@@ -25,13 +25,14 @@ from typing import (
     NamedTuple,
     Optional,
     Set,
+    TextIO,
     Tuple,
     Union,
 )
 import numpy as np
-import pandas as pd
-from pandas.core.dtypes.common import (
+import pandas as pd  # type: ignore
+from pandas.core.dtypes.common import (  # type: ignore
     is_bool_dtype,
     is_datetime_or_timedelta_dtype,
     is_float_dtype,
@@ -42,6 +43,7 @@ from pandas.core.dtypes.inference import is_list_like
 if TYPE_CHECKING:
     from elasticsearch import Elasticsearch
+    from numpy.typing import DTypeLike
 ES_FLOAT_TYPES: Set[str] = {"double", "float", "half_float", "scaled_float"}
@@ -559,7 +561,7 @@ class FieldMappings:
         return {"mappings": {"properties": mapping_props}}
-    def aggregatable_field_name(self, display_name):
+    def aggregatable_field_name(self, display_name: str) -> Optional[str]:
         """
         Return a single aggregatable field_name from display_name
@@ -598,7 +600,7 @@ class FieldMappings:
         return self._mappings_capabilities.loc[display_name].aggregatable_es_field_name
-    def aggregatable_field_names(self):
+    def aggregatable_field_names(self) -> Dict[str, str]:
         """
         Return a list of aggregatable Elasticsearch field_names for all display names.
         If field is not aggregatable_field_names, return nothing.
@@ -634,7 +636,7 @@ class FieldMappings:
             )["data"]
         )
-    def date_field_format(self, es_field_name):
+    def date_field_format(self, es_field_name: str) -> str:
         """
         Parameters
         ----------
@@ -650,7 +652,7 @@ class FieldMappings:
             self._mappings_capabilities.es_field_name == es_field_name
         ].es_date_format.squeeze()
-    def field_name_pd_dtype(self, es_field_name):
+    def field_name_pd_dtype(self, es_field_name: str) -> str:
         """
         Parameters
         ----------
@@ -674,7 +676,9 @@ class FieldMappings:
         ].pd_dtype.squeeze()
         return pd_dtype
-    def add_scripted_field(self, scripted_field_name, display_name, pd_dtype):
+    def add_scripted_field(
+        self, scripted_field_name: str, display_name: str, pd_dtype: str
+    ) -> None:
         # if this display name is used somewhere else, drop it
         if display_name in self._mappings_capabilities.index:
             self._mappings_capabilities = self._mappings_capabilities.drop(
@@ -706,8 +710,8 @@ class FieldMappings:
             capability_matrix_row
         )
-    def numeric_source_fields(self):
-        pd_dtypes, es_field_names, es_date_formats = self.metric_source_fields()
+    def numeric_source_fields(self) -> List[str]:
+        _, es_field_names, _ = self.metric_source_fields()
         return es_field_names
     def all_source_fields(self) -> List[Field]:
@@ -753,7 +757,9 @@ class FieldMappings:
         # Maintain groupby order as given input
         return [groupby_fields[column] for column in by], aggregatable_fields
-    def metric_source_fields(self, include_bool=False, include_timestamp=False):
+    def metric_source_fields(
+        self, include_bool: bool = False, include_timestamp: bool = False
+    ) -> Tuple[List["DTypeLike"], List[str], Optional[List[str]]]:
         """
         Returns
         -------
@@ -790,7 +796,7 @@ class FieldMappings:
         # return in display_name order
         return pd_dtypes, es_field_names, es_date_formats
-    def get_field_names(self, include_scripted_fields=True):
+    def get_field_names(self, include_scripted_fields: bool = True) -> List[str]:
         if include_scripted_fields:
             return self._mappings_capabilities.es_field_name.to_list()
@@ -801,7 +807,7 @@ class FieldMappings:
     def _get_display_names(self):
         return self._mappings_capabilities.index.to_list()
-    def _set_display_names(self, display_names):
+    def _set_display_names(self, display_names: List[str]):
         if not is_list_like(display_names):
             raise ValueError(f"'{display_names}' is not list like")
@@ -842,7 +848,7 @@ class FieldMappings:
         es_dtypes.name = None
         return es_dtypes
-    def es_info(self, buf):
+    def es_info(self, buf: TextIO) -> None:
         buf.write("Mappings:\n")
         buf.write(f" capabilities:\n{self._mappings_capabilities.to_string()}\n")

View File

@@ -17,8 +17,11 @@
 import distutils.version
 import importlib
-import types
 import warnings
+from typing import TYPE_CHECKING, Any, Optional
+if TYPE_CHECKING:
+    from types import ModuleType
 # ----------------------------------------------------------------------------
 # functions largely based / taken from the six module
@@ -42,7 +45,7 @@ version_message = (
 )
-def _get_version(module: types.ModuleType) -> str:
+def _get_version(module: "ModuleType") -> Any:
     version = getattr(module, "__version__", None)
     if version is None:
         # xlrd uses a capitalized attribute name
@@ -55,7 +58,7 @@ def _get_version(module: types.ModuleType) -> str:
 def import_optional_dependency(
     name: str, extra: str = "", raise_on_missing: bool = True, on_version: str = "raise"
-):
+) -> Optional["ModuleType"]:
     """
     Import an optional dependency.
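Editor's note: since the return type is now Optional["ModuleType"], callers passing raise_on_missing=False must handle None. A hedged usage sketch:

    matplotlib = import_optional_dependency("matplotlib", raise_on_missing=False)
    if matplotlib is None:
        # dependency absent: degrade gracefully instead of raising
        ...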

View File

@@ -18,7 +18,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
 import elasticsearch
-import numpy as np  # type: ignore
+import numpy as np
 from eland.common import ensure_es_client, es_version
 from eland.utils import deprecated_api
@@ -27,7 +27,8 @@ from .common import TYPE_CLASSIFICATION, TYPE_REGRESSION
 from .transformers import get_model_transformer
 if TYPE_CHECKING:
-    from elasticsearch import Elasticsearch  # noqa: F401
+    from elasticsearch import Elasticsearch
+    from numpy.typing import ArrayLike, DTypeLike
 # Try importing each ML lib separately so mypy users don't have to
 # have both installed to use type-checking.
@@ -83,8 +84,8 @@ class MLModel:
         self._trained_model_config_cache: Optional[Dict[str, Any]] = None
     def predict(
-        self, X: Union[np.ndarray, List[float], List[List[float]]]
-    ) -> np.ndarray:
+        self, X: Union["ArrayLike", List[float], List[List[float]]]
+    ) -> "ArrayLike":
         """
         Make a prediction using a trained model stored in Elasticsearch.
@@ -196,7 +197,7 @@ class MLModel:
         # Return results as np.ndarray of float32 or int (consistent with sklearn/xgboost)
         if self.model_type == TYPE_CLASSIFICATION:
-            dt = np.int_
+            dt: "DTypeLike" = np.int_
         else:
             dt = np.float32
         return np.asarray(y, dtype=dt)
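Editor's note: annotating only the first assignment (dt: "DTypeLike" = np.int_) is enough for mypy to accept np.float32 on the other branch, since both satisfy DTypeLike. A hedged usage sketch, where the client address and model id are placeholders:

    import numpy as np
    from eland.ml import MLModel

    model = MLModel(es_client="http://localhost:9200", model_id="my-trained-model")
    # integer labels for classification, float32 for regression, per the branch above
    y = model.predict(np.array([[1.0, 2.0], [3.0, 4.0]]))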

View File

@@ -17,7 +17,7 @@
 from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
-import numpy as np  # type: ignore
+import numpy as np
 from .._model_serializer import Ensemble, Tree, TreeNode
 from .._optional import import_optional_dependency
@@ -64,7 +64,7 @@ class SKLearnTransformer(ModelTransformer):
         self,
         node_index: int,
         node_data: Tuple[Union[int, float], ...],
-        value: np.ndarray,
+        value: np.ndarray,  # type: ignore
     ) -> TreeNode:
         """
         This builds out a TreeNode class given the sklearn tree node definition.

View File

@@ -229,7 +229,7 @@ class XGBoostClassifierTransformer(XGBoostForestTransformer):
         if model.classes_ is None:
             n_estimators = model.get_params()["n_estimators"]
             num_trees = model.get_booster().trees_to_dataframe()["Tree"].max() + 1
-            self._num_classes = num_trees // n_estimators
+            self._num_classes: int = num_trees // n_estimators
         else:
             self._num_classes = len(model.classes_)

View File

@@ -17,7 +17,7 @@
 import sys
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, TextIO, Tuple, Union
 import pandas as pd  # type: ignore
@@ -186,7 +186,7 @@ class NDFrame(ABC):
         """
         return len(self.index)
-    def _es_info(self, buf):
+    def _es_info(self, buf: TextIO) -> None:
         self._query_compiler.es_info(buf)
     def mean(self, numeric_only: Optional[bool] = None) -> pd.Series:
@@ -604,7 +604,7 @@ class NDFrame(ABC):
         """
         return self._query_compiler.mad(numeric_only=numeric_only)
-    def _hist(self, num_bins):
+    def _hist(self, num_bins: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
         return self._query_compiler._hist(num_bins)
     def describe(self) -> pd.DataFrame:

View File

@@ -26,12 +26,13 @@ from typing import (
     List,
     Optional,
     Sequence,
+    TextIO,
     Tuple,
     Union,
 )
 import numpy as np
-import pandas as pd
+import pandas as pd  # type: ignore
 from elasticsearch.helpers import scan
 from eland.actions import PostProcessingAction, SortFieldAction
@@ -58,13 +59,18 @@ from eland.tasks import (
 )
 if TYPE_CHECKING:
+    from numpy.typing import DTypeLike
+    from eland.arithmetics import ArithmeticSeries
     from eland.field_mappings import Field
+    from eland.filter import BooleanFilter
     from eland.query_compiler import QueryCompiler
+    from eland.tasks import Task
 class QueryParams:
-    def __init__(self):
-        self.query = Query()
+    def __init__(self) -> None:
+        self.query: Query = Query()
         self.sort_field: Optional[str] = None
         self.sort_order: Optional[SortOrder] = None
         self.size: Optional[int] = None
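Editor's note: Operations is a deferred task queue — head(), tail() and friends only append Task objects, and nothing hits Elasticsearch until _resolve_tasks() folds the queue into query parameters. A stripped-down, hypothetical version of the idea (not eland's actual class):

    class TaskQueue:
        def __init__(self) -> None:
            self._tasks: list = []

        def head(self, n: int) -> None:
            # record intent only; no Elasticsearch request is made here
            self._tasks.append(("head", n))

        def resolve(self) -> dict:
            # all queued tasks fold into a single request body at execution time
            params: dict = {}
            for name, n in self._tasks:
                if name == "head":
                    params["size"] = min(n, params.get("size", n))
            return params

    q = TaskQueue()
    q.head(10)
    q.head(5)
    assert q.resolve() == {"size": 5}  # successive head() calls compose into one query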
@@ -85,37 +91,48 @@ class Operations:
     (see https://docs.dask.org/en/latest/spec.html)
     """
-    def __init__(self, tasks=None, arithmetic_op_fields_task=None):
+    def __init__(
+        self,
+        tasks: Optional[List["Task"]] = None,
+        arithmetic_op_fields_task: Optional["ArithmeticOpFieldsTask"] = None,
+    ) -> None:
+        self._tasks: List["Task"]
         if tasks is None:
             self._tasks = []
         else:
             self._tasks = tasks
         self._arithmetic_op_fields_task = arithmetic_op_fields_task
-    def __constructor__(self, *args, **kwargs):
+    def __constructor__(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> "Operations":
         return type(self)(*args, **kwargs)
-    def copy(self):
+    def copy(self) -> "Operations":
         return self.__constructor__(
             tasks=copy.deepcopy(self._tasks),
             arithmetic_op_fields_task=copy.deepcopy(self._arithmetic_op_fields_task),
         )
-    def head(self, index, n):
+    def head(self, index: "Index", n: int) -> None:
         # Add a task that is an ascending sort with size=n
         task = HeadTask(index, n)
         self._tasks.append(task)
-    def tail(self, index, n):
+    def tail(self, index: "Index", n: int) -> None:
         # Add a task that is descending sort with size=n
         task = TailTask(index, n)
         self._tasks.append(task)
-    def sample(self, index, n, random_state):
+    def sample(self, index: "Index", n: int, random_state: int) -> None:
         task = SampleTask(index, n, random_state)
         self._tasks.append(task)
-    def arithmetic_op_fields(self, display_name, arithmetic_series):
+    def arithmetic_op_fields(
+        self, display_name: str, arithmetic_series: "ArithmeticSeries"
+    ) -> None:
         if self._arithmetic_op_fields_task is None:
             self._arithmetic_op_fields_task = ArithmeticOpFieldsTask(
                 display_name, arithmetic_series
@@ -127,10 +144,10 @@ class Operations:
         # get an ArithmeticOpFieldsTask if it exists
         return self._arithmetic_op_fields_task
-    def __repr__(self):
+    def __repr__(self) -> str:
         return repr(self._tasks)
-    def count(self, query_compiler):
+    def count(self, query_compiler: "QueryCompiler") -> pd.Series:
         query_params, post_processing = self._resolve_tasks(query_compiler)
         # Elasticsearch _count is very efficient and so used to return results here. This means that
@@ -161,7 +178,7 @@ class Operations:
     def _metric_agg_series(
         self,
         query_compiler: "QueryCompiler",
-        agg: List,
+        agg: List["str"],
         numeric_only: Optional[bool] = None,
     ) -> pd.Series:
         results = self._metric_aggs(query_compiler, agg, numeric_only=numeric_only)
@@ -170,7 +187,7 @@ class Operations:
         else:
             # If all results are float convert into float64
             if all(isinstance(i, float) for i in results.values()):
-                dtype = np.float64
+                dtype: "DTypeLike" = np.float64
             # If all results are int convert into int64
             elif all(isinstance(i, int) for i in results.values()):
                 dtype = np.int64
@@ -184,7 +201,9 @@ class Operations:
     def value_counts(self, query_compiler: "QueryCompiler", es_size: int) -> pd.Series:
         return self._terms_aggs(query_compiler, "terms", es_size)
-    def hist(self, query_compiler, bins):
+    def hist(
+        self, query_compiler: "QueryCompiler", bins: int
+    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
         return self._hist_aggs(query_compiler, bins)
     def idx(
@@ -237,7 +256,12 @@ class Operations:
         return pd.Series(results)
-    def aggs(self, query_compiler, pd_aggs, numeric_only=None) -> pd.DataFrame:
+    def aggs(
+        self,
+        query_compiler: "QueryCompiler",
+        pd_aggs: List[str],
+        numeric_only: Optional[bool] = None,
+    ) -> pd.DataFrame:
         results = self._metric_aggs(
             query_compiler, pd_aggs, numeric_only=numeric_only, is_dataframe_agg=True
         )
@@ -441,13 +465,15 @@ class Operations:
         try:
             # get first value in dict (key is .keyword)
-            name = list(aggregatable_field_names.values())[0]
+            name: Optional[str] = list(aggregatable_field_names.values())[0]
         except IndexError:
             name = None
         return build_pd_series(results, name=name)
-    def _hist_aggs(self, query_compiler, num_bins):
+    def _hist_aggs(
+        self, query_compiler: "QueryCompiler", num_bins: int
+    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
         # Get histogram bins and weights for numeric field_names
         query_params, post_processing = self._resolve_tasks(query_compiler)
@@ -488,8 +514,8 @@ class Operations:
         #  },
         # ...
-        bins = {}
-        weights = {}
+        bins: Dict[str, List[int]] = {}
+        weights: Dict[str, List[int]] = {}
         # There is one more bin that weights
         # len(bins) = len(weights) + 1
@@ -537,11 +563,11 @@ class Operations:
     def _unpack_metric_aggs(
         self,
         fields: List["Field"],
-        es_aggs: Union[List[str], List[Tuple[str, str]]],
+        es_aggs: Union[List[str], List[Tuple[str, List[float]]]],
         pd_aggs: List[str],
         response: Dict[str, Any],
         numeric_only: Optional[bool],
-        percentiles: Optional[List[float]] = None,
+        percentiles: Optional[Sequence[float]] = None,
         is_dataframe_agg: bool = False,
         is_groupby: bool = False,
     ) -> Dict[str, List[Any]]:
@@ -574,7 +600,7 @@ class Operations:
         """
         results: Dict[str, Any] = {}
        percentile_values: List[float] = []
-        agg_value: Union[int, float]
+        agg_value: Any
         for field in fields:
             values = []
@@ -611,7 +637,10 @@ class Operations:
                         agg_value = agg_value["50.0"]
                     else:
                         # Maintain order of percentiles
-                        percentile_values = [agg_value[str(i)] for i in percentiles]
+                        if percentiles:
+                            percentile_values = [
+                                agg_value[str(i)] for i in percentiles
+                            ]
             if not percentile_values and pd_agg not in ("quantile", "median"):
                 agg_value = agg_value[es_agg[1]]
@@ -682,7 +711,11 @@ class Operations:
             # Cardinality is always either NaN or integer.
             elif pd_agg in ("nunique", "count"):
-                agg_value = int(agg_value)
+                agg_value = (
+                    int(agg_value)
+                    if isinstance(agg_value, (int, float))
+                    else np.NaN
+                )
             # If this is a non-null timestamp field convert to a pd.Timestamp()
             elif field.is_timestamp:
@@ -702,6 +735,7 @@ class Operations:
                         for value in percentile_values
                     ]
                 else:
+                    assert not isinstance(agg_value, dict)
                     agg_value = elasticsearch_date_to_pandas_date(
                         agg_value, field.es_date_format
                     )
@@ -771,7 +805,7 @@ class Operations:
         by: List[str],
         pd_aggs: List[str],
         dropna: bool = True,
-        quantiles: Optional[List[float]] = None,
+        quantiles: Optional[Union[int, float, List[int], List[float]]] = None,
         is_dataframe_agg: bool = False,
         numeric_only: Optional[bool] = True,
     ) -> pd.DataFrame:
@@ -811,7 +845,7 @@ class Operations:
         by_fields, agg_fields = query_compiler._mappings.groupby_source_fields(by=by)
         # Used defaultdict to avoid initialization of columns with lists
-        results: Dict[str, List[Any]] = defaultdict(list)
+        results: Dict[Any, List[Any]] = defaultdict(list)
         if numeric_only:
             agg_fields = [
@@ -823,7 +857,8 @@ class Operations:
         # To return for creating multi-index on columns
         headers = [agg_field.column for agg_field in agg_fields]
-        percentiles: Optional[List[str]] = None
+        percentiles: Optional[List[float]] = None
+        len_percentiles: int = 0
         if quantiles:
             percentiles = [
                 quantile_to_percentile(x)
@@ -833,6 +868,7 @@ class Operations:
                     else quantiles
                 )
             ]
+            len_percentiles = len(percentiles)
         # Convert pandas aggs to ES equivalent
         es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs=pd_aggs, percentiles=percentiles)
@@ -894,8 +930,8 @@ class Operations:
                 if by_field.is_timestamp and isinstance(bucket_key, int):
                     bucket_key = pd.to_datetime(bucket_key, unit="ms")
-                if pd_aggs == ["quantile"] and len(percentiles) > 1:
-                    bucket_key = [bucket_key] * len(percentiles)
+                if pd_aggs == ["quantile"] and len_percentiles > 1:
+                    bucket_key = [bucket_key] * len_percentiles
                 results[by_field.column].extend(
                     bucket_key if isinstance(bucket_key, list) else [bucket_key]
@@ -915,7 +951,7 @@ class Operations:
                 )
             # to construct index with quantiles
-            if pd_aggs == ["quantile"] and len(percentiles) > 1:
+            if pd_aggs == ["quantile"] and percentiles and len_percentiles > 1:
                 results[None].extend([i / 100 for i in percentiles])
         # Process the calculated agg values to response
@@ -929,9 +965,10 @@ class Operations:
             for pd_agg, val in zip(pd_aggs, value):
                 results[f"{key}_{pd_agg}"].append(val)
-        # Just to maintain Output same as pandas with empty header.
-        if pd_aggs == ["quantile"] and len(percentiles) > 1:
-            by = by + [None]
+        if pd_aggs == ["quantile"] and len_percentiles > 1:
+            # by never holds None by default, we make an exception
+            # here to maintain output same as pandas, also mypy complains
+            by = by + [None]  # type: ignore
         agg_df = pd.DataFrame(results).set_index(by).sort_index()
@@ -947,7 +984,7 @@ class Operations:
     @staticmethod
     def bucket_generator(
         query_compiler: "QueryCompiler", body: "Query"
-    ) -> Generator[List[str], None, List[str]]:
+    ) -> Generator[Sequence[Dict[str, Any]], None, Sequence[Dict[str, Any]]]:
         """
         This can be used for all groupby operations.
         e.g.
@@ -977,18 +1014,24 @@ class Operations:
         )
         # Pagination Logic
-        composite_buckets = res["aggregations"]["groupby_buckets"]
-        if "after_key" in composite_buckets:
+        composite_buckets: Dict[str, Any] = res["aggregations"]["groupby_buckets"]
+        after_key: Optional[Dict[str, Any]] = composite_buckets.get(
+            "after_key", None
+        )
+        buckets: Sequence[Dict[str, Any]] = composite_buckets["buckets"]
+        if after_key:
             # yield the bucket which contains the result
-            yield composite_buckets["buckets"]
+            yield buckets
             body.composite_agg_after_key(
                 name="groupby_buckets",
-                after_key=composite_buckets["after_key"],
+                after_key=after_key,
             )
         else:
-            return composite_buckets["buckets"]
+            return buckets
@@ -1031,7 +1074,7 @@ class Operations:
         extended_stats_es_aggs = {"avg", "min", "max", "sum"}
         extended_stats_calls = 0
-        es_aggs = []
+        es_aggs: List[Any] = []
         for pd_agg in pd_aggs:
             if pd_agg in extended_stats_pd_aggs:
                 extended_stats_calls += 1
@@ -1100,7 +1143,7 @@ class Operations:
     def filter(
         self,
         query_compiler: "QueryCompiler",
-        items: Optional[Sequence[str]] = None,
+        items: Optional[List[str]] = None,
         like: Optional[str] = None,
         regex: Optional[str] = None,
     ) -> None:
@@ -1122,7 +1165,7 @@ class Operations:
             f"to substring and regex operations not being available for Elasticsearch document IDs."
         )
-    def describe(self, query_compiler):
+    def describe(self, query_compiler: "QueryCompiler") -> pd.DataFrame:
         query_params, post_processing = self._resolve_tasks(query_compiler)
         size = self._size(query_params, post_processing)
@@ -1151,30 +1194,9 @@ class Operations:
             ["count", "mean", "std", "min", "25%", "50%", "75%", "max"]
         )
-    def to_pandas(self, query_compiler, show_progress=False):
-        class PandasDataFrameCollector:
-            def __init__(self, show_progress):
-                self._df = None
-                self._show_progress = show_progress
-            def collect(self, df):
-                # This collector does not batch data on output. Therefore, batch_size is fixed to None and this method
-                # is only called once.
-                if self._df is not None:
-                    raise RuntimeError(
-                        "Logic error in execution, this method must only be called once for this"
-                        "collector - batch_size == None"
-                    )
-                self._df = df
-            @staticmethod
-            def batch_size():
-                # Do not change (see notes on collect)
-                return None
-            @property
-            def show_progress(self):
-                return self._show_progress
+    def to_pandas(
+        self, query_compiler: "QueryCompiler", show_progress: bool = False
+    ) -> None:
         collector = PandasDataFrameCollector(show_progress)
@@ -1182,35 +1204,12 @@ class Operations:
         return collector._df
-    def to_csv(self, query_compiler, show_progress=False, **kwargs):
-        class PandasToCSVCollector:
-            def __init__(self, show_progress, **args):
-                self._args = args
-                self._show_progress = show_progress
-                self._ret = None
-                self._first_time = True
-            def collect(self, df):
-                # If this is the first time we collect results, then write header, otherwise don't write header
-                # and append results
-                if self._first_time:
-                    self._first_time = False
-                    df.to_csv(**self._args)
-                else:
-                    # Don't write header, and change mode to append
-                    self._args["header"] = False
-                    self._args["mode"] = "a"
-                    df.to_csv(**self._args)
-            @staticmethod
-            def batch_size():
-                # By default read n docs and then dump to csv
-                batch_size = DEFAULT_CSV_BATCH_OUTPUT_SIZE
-                return batch_size
-            @property
-            def show_progress(self):
-                return self._show_progress
+    def to_csv(
+        self,
+        query_compiler: "QueryCompiler",
+        show_progress: bool = False,
+        **kwargs: Union[bool, str],
+    ) -> None:
         collector = PandasToCSVCollector(show_progress, **kwargs)
@@ -1218,7 +1217,11 @@ class Operations:
         return collector._ret
-    def _es_results(self, query_compiler, collector):
+    def _es_results(
+        self,
+        query_compiler: "QueryCompiler",
+        collector: Union["PandasToCSVCollector", "PandasDataFrameCollector"],
+    ) -> None:
         query_params, post_processing = self._resolve_tasks(query_compiler)
         size, sort_params = Operations._query_params_to_size_and_sort(query_params)
@@ -1245,7 +1248,7 @@ class Operations:
         else:
             body["_source"] = False
-        es_results = None
+        es_results: Any = None
         # If size=None use scan not search - then post sort results when in df
         # If size>10000 use scan
@@ -1283,7 +1286,7 @@ class Operations:
         df = self._apply_df_post_processing(df, post_processing)
         collector.collect(df)
-    def index_count(self, query_compiler, field):
+    def index_count(self, query_compiler: "QueryCompiler", field: str) -> int:
         # field is the index field so count values
         query_params, post_processing = self._resolve_tasks(query_compiler)
@@ -1297,12 +1300,13 @@ class Operations:
         body = Query(query_params.query)
         body.exists(field, must=True)
-        return query_compiler._client.count(
+        count: int = query_compiler._client.count(
             index=query_compiler._index_pattern, body=body.to_count_body()
         )["count"]
+        return count
     def _validate_index_operation(
-        self, query_compiler: "QueryCompiler", items: Sequence[str]
+        self, query_compiler: "QueryCompiler", items: List[str]
     ) -> RESOLVED_TASK_TYPE:
         if not isinstance(items, list):
             raise TypeError(f"list item required - not {type(items)}")
@@ -1320,7 +1324,9 @@ class Operations:
         return query_params, post_processing
-    def index_matches_count(self, query_compiler, field, items):
+    def index_matches_count(
+        self, query_compiler: "QueryCompiler", field: str, items: List[Any]
+    ) -> int:
         query_params, post_processing = self._validate_index_operation(
             query_compiler, items
         )
@@ -1332,12 +1338,13 @@ class Operations:
         else:
             body.terms(field, items, must=True)
-        return query_compiler._client.count(
+        count: int = query_compiler._client.count(
             index=query_compiler._index_pattern, body=body.to_count_body()
         )["count"]
+        return count
     def drop_index_values(
-        self, query_compiler: "QueryCompiler", field: str, items: Sequence[str]
+        self, query_compiler: "QueryCompiler", field: str, items: List[str]
     ) -> None:
         self._validate_index_operation(query_compiler, items)
@@ -1349,6 +1356,7 @@ class Operations:
         # a in ['a','b','c']
        # b not in ['a','b','c']
         # For now use term queries
+        task: Union["QueryIdsTask", "QueryTermsTask"]
         if field == Index.ID_INDEX_FIELD:
             task = QueryIdsTask(False, items)
         else:
@@ -1356,11 +1364,12 @@ class Operations:
         self._tasks.append(task)
     def filter_index_values(
-        self, query_compiler: "QueryCompiler", field: str, items: Sequence[str]
+        self, query_compiler: "QueryCompiler", field: str, items: List[str]
     ) -> None:
         # Basically .drop_index_values() except with must=True on tasks.
         self._validate_index_operation(query_compiler, items)
+        task: Union["QueryIdsTask", "QueryTermsTask"]
         if field == Index.ID_INDEX_FIELD:
             task = QueryIdsTask(True, items, sort_index_by_ids=True)
         else:
@@ -1406,7 +1415,7 @@ class Operations:
         # other operations require pre-queries and then combinations
         # other operations require in-core post-processing of results
         query_params = QueryParams()
-        post_processing = []
+        post_processing: List["PostProcessingAction"] = []
         for task in self._tasks:
             query_params, post_processing = task.resolve_task(
@@ -1439,7 +1448,7 @@ class Operations:
         # This can return None
         return size
-    def es_info(self, query_compiler, buf):
+    def es_info(self, query_compiler: "QueryCompiler", buf: TextIO) -> None:
         buf.write("Operations:\n")
         buf.write(f" tasks: {self._tasks}\n")
@@ -1459,7 +1468,7 @@ class Operations:
             buf.write(f" body: {body}\n")
             buf.write(f" post_processing: {post_processing}\n")
-    def update_query(self, boolean_filter):
+    def update_query(self, boolean_filter: "BooleanFilter") -> None:
         task = BooleanFilterTask(boolean_filter)
         self._tasks.append(task)
@@ -1477,3 +1486,58 @@ def quantile_to_percentile(quantile: Union[int, float]) -> float:
     # quantile * 100 = percentile
     # return float(...) because min(1.0) gives 1
     return float(min(100, max(0, quantile * 100)))
+
+
+class PandasToCSVCollector:
+    def __init__(self, show_progress: bool, **kwargs: Union[bool, str]) -> None:
+        self._args = kwargs
+        self._show_progress = show_progress
+        self._ret = None
+        self._first_time = True
+
+    def collect(self, df: "pd.DataFrame") -> None:
+        # If this is the first time we collect results, then write header, otherwise don't write header
+        # and append results
+        if self._first_time:
+            self._first_time = False
+            df.to_csv(**self._args)
+        else:
+            # Don't write header, and change mode to append
+            self._args["header"] = False
+            self._args["mode"] = "a"
+            df.to_csv(**self._args)
+
+    @staticmethod
+    def batch_size() -> int:
+        # By default read n docs and then dump to csv
+        batch_size: int = DEFAULT_CSV_BATCH_OUTPUT_SIZE
+        return batch_size
+
+    @property
+    def show_progress(self) -> bool:
+        return self._show_progress
+
+
+class PandasDataFrameCollector:
+    def __init__(self, show_progress: bool) -> None:
+        self._df = None
+        self._show_progress = show_progress
+
+    def collect(self, df: "pd.DataFrame") -> None:
+        # This collector does not batch data on output. Therefore, batch_size is fixed to None and this method
+        # is only called once.
+        if self._df is not None:
+            raise RuntimeError(
+                "Logic error in execution, this method must only be called once for this"
+                "collector - batch_size == None"
+            )
+        self._df = df
+
+    @staticmethod
+    def batch_size() -> None:
+        # Do not change (see notes on collect)
+        return None
+
+    @property
+    def show_progress(self) -> bool:
+        return self._show_progress
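Editor's note: _es_results() depends only on the duck-typed trio collect(), batch_size() and show_progress, which is why hoisting the two collectors to module level needed no shared base class. A hypothetical third collector of the same shape, purely for illustration (not part of eland):

    import pandas as pd

    class ListCollector:
        # Hypothetical collector: accumulates batches in memory
        def __init__(self, show_progress: bool = False) -> None:
            self._batches: list = []
            self._show_progress = show_progress

        def collect(self, df: pd.DataFrame) -> None:
            self._batches.append(df)

        @staticmethod
        def batch_size() -> int:
            # arbitrary value for the sketch; eland's CSV collector
            # uses DEFAULT_CSV_BATCH_OUTPUT_SIZE here
            return 10000

        @property
        def show_progress(self) -> bool:
            return self._show_progress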

View File

@@ -15,8 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
+from typing import TYPE_CHECKING
 import numpy as np
-from pandas.plotting._matplotlib import converter
+from pandas.plotting._matplotlib import converter  # type: ignore
 try:
     # pandas<1.3.0
@@ -26,13 +28,13 @@ except ImportError:
     from pandas.core.dtypes.generic import ABCIndex
 try:  # pandas>=1.2.0
-    from pandas.plotting._matplotlib.tools import (
+    from pandas.plotting._matplotlib.tools import (  # type: ignore
         create_subplots,
         flatten_axes,
         set_ticks_props,
     )
 except ImportError:  # pandas<1.2.0
-    from pandas.plotting._matplotlib.tools import (
+    from pandas.plotting._matplotlib.tools import (  # type: ignore
         _flatten as flatten_axes,
         _set_ticks_props as set_ticks_props,
         _subplots as create_subplots,
@@ -40,6 +42,9 @@ except ImportError:  # pandas<1.2.0
 from eland.utils import try_sort
+if TYPE_CHECKING:
+    from numpy.typing import ArrayLike
 def hist_series(
     self,
@@ -53,8 +58,8 @@ def hist_series(
     figsize=None,
     bins=10,
     **kwds,
-):
-    import matplotlib.pyplot as plt
+) -> "ArrayLike":
+    import matplotlib.pyplot as plt  # type: ignore
     if by is None:
         if kwds.get("layout", None) is not None:

View File

@@ -17,9 +17,19 @@
 import copy
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    TextIO,
+    Tuple,
+    Union,
+)
-import numpy as np  # type: ignore
+import numpy as np
 import pandas as pd  # type: ignore
 from eland.common import (
@@ -28,11 +38,15 @@ from eland.common import (
     ensure_es_client,
 )
 from eland.field_mappings import FieldMappings
-from eland.filter import QueryFilter
+from eland.filter import BooleanFilter, QueryFilter
 from eland.index import Index
 from eland.operations import Operations
 if TYPE_CHECKING:
+    from elasticsearch import Elasticsearch
+    from eland.arithmetics import ArithmeticSeries
     from .tasks import ArithmeticOpFieldsTask  # noqa: F401
@@ -67,8 +81,10 @@ class QueryCompiler:
     def __init__(
         self,
-        client=None,
-        index_pattern=None,
+        client: Optional[
+            Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
+        ] = None,
+        index_pattern: Optional[str] = None,
         display_names=None,
         index_field=None,
         to_copy=None,
@ -77,15 +93,15 @@ class QueryCompiler:
if to_copy is not None: if to_copy is not None:
self._client = to_copy._client self._client = to_copy._client
self._index_pattern = to_copy._index_pattern self._index_pattern = to_copy._index_pattern
self._index = Index(self, to_copy._index.es_index_field) self._index: "Index" = Index(self, to_copy._index.es_index_field)
self._operations = copy.deepcopy(to_copy._operations) self._operations: "Operations" = copy.deepcopy(to_copy._operations)
self._mappings: FieldMappings = copy.deepcopy(to_copy._mappings) self._mappings: FieldMappings = copy.deepcopy(to_copy._mappings)
else: else:
self._client = ensure_es_client(client) self._client = ensure_es_client(client)
self._index_pattern = index_pattern self._index_pattern = index_pattern
# Get and persist mappings, this allows us to correctly # Get and persist mappings, this allows us to correctly
# map returned types from Elasticsearch to pandas datatypes # map returned types from Elasticsearch to pandas datatypes
self._mappings: FieldMappings = FieldMappings( self._mappings = FieldMappings(
client=self._client, client=self._client,
index_pattern=self._index_pattern, index_pattern=self._index_pattern,
display_names=display_names, display_names=display_names,
@ -103,15 +119,15 @@ class QueryCompiler:
         return pd.Index(columns)
 
-    def _get_display_names(self):
+    def _get_display_names(self) -> "pd.Index":
         display_names = self._mappings.display_names
         return pd.Index(display_names)
 
-    def _set_display_names(self, display_names):
+    def _set_display_names(self, display_names: List[str]) -> None:
         self._mappings.display_names = display_names
 
-    def get_field_names(self, include_scripted_fields):
+    def get_field_names(self, include_scripted_fields: bool) -> List[str]:
         return self._mappings.get_field_names(include_scripted_fields)
 
     def add_scripted_field(self, scripted_field_name, display_name, pd_dtype):
@ -129,7 +145,12 @@ class QueryCompiler:
     # END Index, columns, and dtypes objects
 
-    def _es_results_to_pandas(self, results, batch_size=None, show_progress=False):
+    def _es_results_to_pandas(
+        self,
+        results: Dict[Any, Any],
+        batch_size: Optional[int] = None,
+        show_progress: bool = False,
+    ) -> "pd.DataFrame":
         """
         Parameters
         ----------
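For orientation, here is a toy version of the hits-to-DataFrame conversion that `_es_results_to_pandas()` performs; the real method additionally flattens nested objects via `_flatten_dict()`, maps Elasticsearch types to pandas dtypes, batches, and can report progress:

    from typing import Any, Dict, List

    import pandas as pd

    def hits_to_df(results: Dict[str, Any]) -> pd.DataFrame:
        # Toy version: pull each hit's _source into one DataFrame row.
        rows: List[Dict[str, Any]] = [hit["_source"] for hit in results["hits"]["hits"]]
        return pd.DataFrame(rows)

    print(hits_to_df({"hits": {"hits": [{"_source": {"a": 1}}, {"_source": {"a": 2}}]}}))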
@ -300,7 +321,7 @@ class QueryCompiler:
         return partial_result, df
 
-    def _flatten_dict(self, y, field_mapping_cache):
+    def _flatten_dict(self, y, field_mapping_cache: "FieldMappingCache"):
         out = {}
 
         def flatten(x, name=""):
@ -368,7 +389,7 @@ class QueryCompiler:
         """
         return self._operations.index_count(self, self.index.es_index_field)
 
-    def _index_matches_count(self, items):
+    def _index_matches_count(self, items: List[Any]) -> int:
         """
         Returns
         -------
@ -386,10 +407,10 @@ class QueryCompiler:
             df[c] = pd.Series(dtype=d)
         return df
 
-    def copy(self):
+    def copy(self) -> "QueryCompiler":
         return QueryCompiler(to_copy=self)
 
-    def rename(self, renames, inplace=False):
+    def rename(self, renames, inplace: bool = False) -> "QueryCompiler":
         if inplace:
             self._mappings.rename(renames)
             return self
@ -398,21 +419,23 @@ class QueryCompiler:
         result._mappings.rename(renames)
         return result
 
-    def head(self, n):
+    def head(self, n: int) -> "QueryCompiler":
         result = self.copy()
         result._operations.head(self._index, n)
         return result
 
-    def tail(self, n):
+    def tail(self, n: int) -> "QueryCompiler":
         result = self.copy()
         result._operations.tail(self._index, n)
         return result
 
-    def sample(self, n=None, frac=None, random_state=None):
+    def sample(
+        self, n: Optional[int] = None, frac=None, random_state=None
+    ) -> "QueryCompiler":
         result = self.copy()
 
         if n is None and frac is None:
@ -501,11 +524,11 @@ class QueryCompiler:
             query = {"multi_match": options}
         return QueryFilter(query)
 
-    def es_query(self, query):
+    def es_query(self, query: Dict[str, Any]) -> "QueryCompiler":
         return self._update_query(QueryFilter(query))
 
     # To/From Pandas
 
-    def to_pandas(self, show_progress=False):
+    def to_pandas(self, show_progress: bool = False):
         """Converts Eland DataFrame to Pandas DataFrame.
 
         Returns:
@ -543,7 +566,9 @@ class QueryCompiler:
         return result
 
-    def drop(self, index=None, columns=None):
+    def drop(
+        self, index: Optional[str] = None, columns: Optional[List[str]] = None
+    ) -> "QueryCompiler":
         result = self.copy()
 
         # Drop gets all columns and removes drops
@ -559,7 +584,7 @@ class QueryCompiler:
     def filter(
         self,
-        items: Optional[Sequence[str]] = None,
+        items: Optional[List[str]] = None,
         like: Optional[str] = None,
         regex: Optional[str] = None,
     ) -> "QueryCompiler":
@ -570,53 +595,55 @@ class QueryCompiler:
         result._operations.filter(self, items=items, like=like, regex=regex)
         return result
 
-    def aggs(self, func: List[str], numeric_only: Optional[bool] = None):
+    def aggs(
+        self, func: List[str], numeric_only: Optional[bool] = None
+    ) -> pd.DataFrame:
         return self._operations.aggs(self, func, numeric_only=numeric_only)
 
-    def count(self):
+    def count(self) -> pd.Series:
         return self._operations.count(self)
 
-    def mean(self, numeric_only: Optional[bool] = None):
+    def mean(self, numeric_only: Optional[bool] = None) -> pd.Series:
         return self._operations._metric_agg_series(
             self, ["mean"], numeric_only=numeric_only
         )
 
-    def var(self, numeric_only: Optional[bool] = None):
+    def var(self, numeric_only: Optional[bool] = None) -> pd.Series:
         return self._operations._metric_agg_series(
             self, ["var"], numeric_only=numeric_only
         )
 
-    def std(self, numeric_only: Optional[bool] = None):
+    def std(self, numeric_only: Optional[bool] = None) -> pd.Series:
         return self._operations._metric_agg_series(
             self, ["std"], numeric_only=numeric_only
         )
 
-    def mad(self, numeric_only: Optional[bool] = None):
+    def mad(self, numeric_only: Optional[bool] = None) -> pd.Series:
         return self._operations._metric_agg_series(
             self, ["mad"], numeric_only=numeric_only
         )
 
-    def median(self, numeric_only: Optional[bool] = None):
+    def median(self, numeric_only: Optional[bool] = None) -> pd.Series:
         return self._operations._metric_agg_series(
             self, ["median"], numeric_only=numeric_only
         )
 
-    def sum(self, numeric_only: Optional[bool] = None):
+    def sum(self, numeric_only: Optional[bool] = None) -> pd.Series:
         return self._operations._metric_agg_series(
             self, ["sum"], numeric_only=numeric_only
         )
 
-    def min(self, numeric_only: Optional[bool] = None):
+    def min(self, numeric_only: Optional[bool] = None) -> pd.Series:
         return self._operations._metric_agg_series(
             self, ["min"], numeric_only=numeric_only
         )
 
-    def max(self, numeric_only: Optional[bool] = None):
+    def max(self, numeric_only: Optional[bool] = None) -> pd.Series:
         return self._operations._metric_agg_series(
             self, ["max"], numeric_only=numeric_only
         )
 
-    def nunique(self):
+    def nunique(self) -> pd.Series:
         return self._operations._metric_agg_series(
             self, ["nunique"], numeric_only=False
         )
@ -673,7 +700,7 @@ class QueryCompiler:
         dropna: bool = True,
         is_dataframe_agg: bool = False,
         numeric_only: Optional[bool] = True,
-        quantiles: Union[int, float, List[int], List[float], None] = None,
+        quantiles: Optional[Union[int, float, List[int], List[float]]] = None,
     ) -> pd.DataFrame:
         return self._operations.aggs_groupby(
             self,
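The `quantiles` rewrite changes spelling, not meaning: `Optional[X]` is exactly `Union[X, None]`, so both annotations describe the same type. For example:

    from typing import List, Optional, Union

    # These two aliases are interchangeable to the type checker.
    QuantilesOld = Union[int, float, List[int], List[float], None]
    QuantilesNew = Optional[Union[int, float, List[int], List[float]]]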
@ -691,27 +718,27 @@ class QueryCompiler:
     def value_counts(self, es_size: int) -> pd.Series:
         return self._operations.value_counts(self, es_size)
 
-    def es_info(self, buf):
+    def es_info(self, buf: TextIO) -> None:
         buf.write(f"es_index_pattern: {self._index_pattern}\n")
 
         self._index.es_info(buf)
         self._mappings.es_info(buf)
         self._operations.es_info(self, buf)
 
-    def describe(self):
+    def describe(self) -> pd.DataFrame:
         return self._operations.describe(self)
 
-    def _hist(self, num_bins):
+    def _hist(self, num_bins: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
         return self._operations.hist(self, num_bins)
 
-    def _update_query(self, boolean_filter):
+    def _update_query(self, boolean_filter: "BooleanFilter") -> "QueryCompiler":
         result = self.copy()
         result._operations.update_query(boolean_filter)
         return result
 
-    def check_arithmetics(self, right):
+    def check_arithmetics(self, right: "QueryCompiler") -> None:
         """
         Compare 2 query_compilers to see if arithmetic operations can be performed by the NDFrame object.
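Typing `buf` as `TextIO` keeps `es_info()` agnostic about where the report goes: any writable text stream works, which also makes it easy to test. A minimal sketch (the index pattern is made up):

    from io import StringIO
    from typing import TextIO

    def es_info(buf: TextIO) -> None:
        # An open file, sys.stdout, or an in-memory StringIO all satisfy TextIO.
        buf.write("es_index_pattern: flights\n")

    buf = StringIO()
    es_info(buf)
    print(buf.getvalue(), end="")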
@ -750,7 +777,9 @@ class QueryCompiler:
             f"{self._index_pattern} != {right._index_pattern}"
         )
 
-    def arithmetic_op_fields(self, display_name, arithmetic_object):
+    def arithmetic_op_fields(
+        self, display_name: str, arithmetic_object: "ArithmeticSeries"
+    ) -> "QueryCompiler":
         result = self.copy()
 
         # create a new field name for this display name
@ -758,7 +787,7 @@ class QueryCompiler:
         # add scripted field
         result._mappings.add_scripted_field(
-            scripted_field_name, display_name, arithmetic_object.dtype.name
+            scripted_field_name, display_name, arithmetic_object.dtype.name  # type: ignore
         )
 
         result._operations.arithmetic_op_fields(scripted_field_name, arithmetic_object)
@ -768,7 +797,7 @@ class QueryCompiler:
     def get_arithmetic_op_fields(self) -> Optional["ArithmeticOpFieldsTask"]:
         return self._operations.get_arithmetic_op_fields()
 
-    def display_name_to_aggregatable_name(self, display_name: str) -> str:
+    def display_name_to_aggregatable_name(self, display_name: str) -> Optional[str]:
         aggregatable_field_name = self._mappings.aggregatable_field_name(display_name)
         return aggregatable_field_name
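Widening the return type to `Optional[str]` makes the "no aggregatable variant" case explicit, so mypy forces callers to narrow before use. A self-contained sketch with a stub lookup (the mapping values are illustrative):

    from typing import Optional

    def display_name_to_aggregatable_name(display_name: str) -> Optional[str]:
        # Stand-in for the QueryCompiler method above.
        return {"customer_name": "customer_name.keyword"}.get(display_name)

    aggregatable = display_name_to_aggregatable_name("customer_name")
    if aggregatable is None:
        raise ValueError("customer_name has no aggregatable ES field")
    # After the check, mypy narrows `aggregatable` from Optional[str] to str.
    print(aggregatable)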
@ -780,13 +809,13 @@ class FieldMappingCache:
     DataFrame access is slower than dict access.
     """
 
-    def __init__(self, mappings):
+    def __init__(self, mappings: "FieldMappings") -> None:
         self._mappings = mappings
 
-        self._field_name_pd_dtype = dict()
-        self._date_field_format = dict()
+        self._field_name_pd_dtype: Dict[str, str] = dict()
+        self._date_field_format: Dict[str, str] = dict()
 
-    def field_name_pd_dtype(self, es_field_name):
+    def field_name_pd_dtype(self, es_field_name: str) -> str:
         if es_field_name in self._field_name_pd_dtype:
             return self._field_name_pd_dtype[es_field_name]
@ -797,7 +826,7 @@ class FieldMappingCache:
         return pd_dtype
 
-    def date_field_format(self, es_field_name):
+    def date_field_format(self, es_field_name: str) -> str:
         if es_field_name in self._date_field_format:
             return self._date_field_format[es_field_name]
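FieldMappingCache's speedup comes from memoizing per-field lookups in plain dicts, now with explicit `Dict[str, str]` value types. A runnable toy version of the same pattern (the lookup stub is hypothetical):

    from typing import Dict

    def lookup_pd_dtype(field: str) -> str:
        # Stand-in for the slower DataFrame-backed lookup in FieldMappings.
        return "float64"

    class DTypeCache:
        def __init__(self) -> None:
            self._field_name_pd_dtype: Dict[str, str] = {}

        def field_name_pd_dtype(self, es_field_name: str) -> str:
            # Dict access is cheap, so repeated lookups hit the cache.
            if es_field_name in self._field_name_pd_dtype:
                return self._field_name_pd_dtype[es_field_name]
            pd_dtype = lookup_pd_dtype(es_field_name)
            self._field_name_pd_dtype[es_field_name] = pd_dtype
            return pd_dtype

    print(DTypeCache().field_name_pd_dtype("total_quantity"))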

View File

@ -38,8 +38,8 @@ from io import StringIO
 from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
-import pandas as pd
-from pandas.io.common import _expand_user, stringify_path
+import pandas as pd  # type: ignore
+from pandas.io.common import _expand_user, stringify_path  # type: ignore
 
 import eland.plotting
 from eland.arithmetics import ArithmeticNumber, ArithmeticSeries, ArithmeticString
@ -61,10 +61,10 @@ from eland.filter import (
 from eland.ndframe import NDFrame
 from eland.utils import to_list
 
-if TYPE_CHECKING:  # type: ignore
-    from elasticsearch import Elasticsearch  # noqa: F401
-    from eland.query_compiler import QueryCompiler  # noqa: F401
+if TYPE_CHECKING:
+    from elasticsearch import Elasticsearch
+    from eland.query_compiler import QueryCompiler
 
 def _get_method_name() -> str:
@ -175,7 +175,7 @@ class Series(NDFrame):
         return num_rows, num_columns
 
     @property
-    def es_field_name(self) -> str:
+    def es_field_name(self) -> pd.Index:
         """
         Returns
         -------
@ -185,7 +185,7 @@ class Series(NDFrame):
         return self._query_compiler.get_field_names(include_scripted_fields=True)[0]
 
     @property
-    def name(self) -> str:
+    def name(self) -> pd.Index:
         return self._query_compiler.columns[0]
 
     @name.setter
@ -793,7 +793,7 @@ class Series(NDFrame):
         return buf.getvalue()
 
-    def __add__(self, right):
+    def __add__(self, right: "Series") -> "Series":
         """
         Return addition of series and right, element-wise (binary operator add).
@ -853,7 +853,7 @@ class Series(NDFrame):
         """
         return self._numeric_op(right, _get_method_name())
 
-    def __truediv__(self, right):
+    def __truediv__(self, right: "Series") -> "Series":
         """
         Return floating division of series and right, element-wise (binary operator truediv).
@ -892,7 +892,7 @@ class Series(NDFrame):
         """
         return self._numeric_op(right, _get_method_name())
 
-    def __floordiv__(self, right):
+    def __floordiv__(self, right: "Series") -> "Series":
         """
         Return integer division of series and right, element-wise (binary operator floordiv //).
@ -931,7 +931,7 @@ class Series(NDFrame):
         """
         return self._numeric_op(right, _get_method_name())
 
-    def __mod__(self, right):
+    def __mod__(self, right: "Series") -> "Series":
         """
         Return modulo of series and right, element-wise (binary operator mod %).
@ -970,7 +970,7 @@ class Series(NDFrame):
         """
         return self._numeric_op(right, _get_method_name())
 
-    def __mul__(self, right):
+    def __mul__(self, right: "Series") -> "Series":
         """
         Return multiplication of series and right, element-wise (binary operator mul).
@ -1009,7 +1009,7 @@ class Series(NDFrame):
         """
         return self._numeric_op(right, _get_method_name())
 
-    def __sub__(self, right):
+    def __sub__(self, right: "Series") -> "Series":
         """
         Return subtraction of series and right, element-wise (binary operator sub).
@ -1048,7 +1048,7 @@ class Series(NDFrame):
         """
         return self._numeric_op(right, _get_method_name())
 
-    def __pow__(self, right):
+    def __pow__(self, right: "Series") -> "Series":
         """
         Return exponential power of series and right, element-wise (binary operator pow).
@ -1087,7 +1087,7 @@ class Series(NDFrame):
         """
         return self._numeric_op(right, _get_method_name())
 
-    def __radd__(self, left):
+    def __radd__(self, left: "Series") -> "Series":
         """
         Return addition of series and left, element-wise (binary operator add).
@ -1119,7 +1119,7 @@ class Series(NDFrame):
         """
         return self._numeric_op(left, _get_method_name())
 
-    def __rtruediv__(self, left):
+    def __rtruediv__(self, left: "Series") -> "Series":
         """
         Return division of series and left, element-wise (binary operator div).
@ -1151,7 +1151,7 @@ class Series(NDFrame):
         """
         return self._numeric_op(left, _get_method_name())
 
-    def __rfloordiv__(self, left):
+    def __rfloordiv__(self, left: "Series") -> "Series":
         """
         Return integer division of series and left, element-wise (binary operator floordiv //).
@ -1183,7 +1183,7 @@ class Series(NDFrame):
         """
         return self._numeric_op(left, _get_method_name())
 
-    def __rmod__(self, left):
+    def __rmod__(self, left: "Series") -> "Series":
         """
         Return modulo of series and left, element-wise (binary operator mod %).
@ -1215,7 +1215,7 @@ class Series(NDFrame):
         """
         return self._numeric_op(left, _get_method_name())
 
-    def __rmul__(self, left):
+    def __rmul__(self, left: "Series") -> "Series":
         """
         Return multiplication of series and left, element-wise (binary operator mul).
@ -1247,7 +1247,7 @@ class Series(NDFrame):
         """
         return self._numeric_op(left, _get_method_name())
 
-    def __rpow__(self, left):
+    def __rpow__(self, left: "Series") -> "Series":
         """
         Return exponential power of series and left, element-wise (binary operator pow).
@ -1279,7 +1279,7 @@ class Series(NDFrame):
         """
         return self._numeric_op(left, _get_method_name())
 
-    def __rsub__(self, left):
+    def __rsub__(self, left: "Series") -> "Series":
         """
         Return subtraction of series and left, element-wise (binary operator sub).
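All of these dunder methods share one shape: a quoted "Series" forward reference for both parameter and return, because the class does not exist yet while its own body is being evaluated. A self-contained sketch of the pattern on a toy class:

    from typing import Union

    class Vec:
        def __init__(self, x: float) -> None:
            self.x = x

        def __add__(self, right: Union["Vec", float]) -> "Vec":
            # Quoted "Vec" is a forward reference, resolved lazily.
            other = right.x if isinstance(right, Vec) else right
            return Vec(self.x + other)

        def __radd__(self, left: float) -> "Vec":
            # Reflected operator: called for `1.0 + Vec(...)`.
            return Vec(self.x + left)

    print((Vec(1.0) + Vec(2.0)).x)  # 3.0
    print((1.0 + Vec(2.0)).x)       # 3.0, via __radd__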
@ -1398,7 +1398,7 @@ class Series(NDFrame):
         return series
 
-    def max(self, numeric_only=None):
+    def max(self, numeric_only: Optional[bool] = None) -> pd.Series:
         """
         Return the maximum of the Series values
@ -1422,7 +1422,7 @@ class Series(NDFrame):
         results = super().max(numeric_only=numeric_only)
         return results.squeeze()
 
-    def mean(self, numeric_only=None):
+    def mean(self, numeric_only: Optional[bool] = None) -> pd.Series:
         """
         Return the mean of the Series values
@ -1446,7 +1446,7 @@ class Series(NDFrame):
         results = super().mean(numeric_only=numeric_only)
         return results.squeeze()
 
-    def median(self, numeric_only=None):
+    def median(self, numeric_only: Optional[bool] = None) -> pd.Series:
         """
         Return the median of the Series values
@ -1470,7 +1470,7 @@ class Series(NDFrame):
         results = super().median(numeric_only=numeric_only)
         return results.squeeze()
 
-    def min(self, numeric_only=None):
+    def min(self, numeric_only: Optional[bool] = None) -> pd.Series:
         """
         Return the minimum of the Series values
@ -1494,7 +1494,7 @@ class Series(NDFrame):
         results = super().min(numeric_only=numeric_only)
         return results.squeeze()
 
-    def sum(self, numeric_only=None):
+    def sum(self, numeric_only: Optional[bool] = None) -> pd.Series:
         """
         Return the sum of the Series values
@ -1518,7 +1518,7 @@ class Series(NDFrame):
         results = super().sum(numeric_only=numeric_only)
         return results.squeeze()
 
-    def nunique(self):
+    def nunique(self) -> pd.Series:
         """
         Return the number of unique values in a Series
@ -1540,7 +1540,7 @@ class Series(NDFrame):
         results = super().nunique()
         return results.squeeze()
 
-    def var(self, numeric_only=None):
+    def var(self, numeric_only: Optional[bool] = None) -> pd.Series:
         """
         Return variance for a Series
@ -1562,7 +1562,7 @@ class Series(NDFrame):
         results = super().var(numeric_only=numeric_only)
         return results.squeeze()
 
-    def std(self, numeric_only=None):
+    def std(self, numeric_only: Optional[bool] = None) -> pd.Series:
         """
         Return standard deviation for a Series
@ -1584,7 +1584,7 @@ class Series(NDFrame):
         results = super().std(numeric_only=numeric_only)
         return results.squeeze()
 
-    def mad(self, numeric_only=None):
+    def mad(self, numeric_only: Optional[bool] = None) -> pd.Series:
         """
         Return median absolute deviation for a Series
@ -1643,7 +1643,7 @@ class Series(NDFrame):
     # def values TODO - not implemented as causes current implementation of query to fail
 
-    def to_numpy(self):
+    def to_numpy(self) -> None:
         """
         Not implemented.

View File

@ -16,7 +16,7 @@
 # under the License.
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, List, Tuple
+from typing import TYPE_CHECKING, List, Tuple
 
 from eland import SortOrder
 from eland.actions import HeadAction, SortIndexAction, TailAction
@ -253,7 +253,7 @@ class QueryIdsTask(Task):
 
 class QueryTermsTask(Task):
-    def __init__(self, must: bool, field: str, terms: List[Any]):
+    def __init__(self, must: bool, field: str, terms: List[str]):
         """
         Parameters
         ----------
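The narrowed `List[str]` reflects what ultimately reaches Elasticsearch: a terms filter whose values are sent as-is. Roughly this shape (field and values are illustrative):

    from typing import Any, Dict, List

    def terms_query(field: str, terms: List[str]) -> Dict[str, Any]:
        # Approximate shape of the filter a terms task produces.
        return {"terms": {field: terms}}

    print(terms_query("customer_id", ["a1", "b2"]))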

View File

@ -38,7 +38,10 @@ TYPED_FILES = (
     "eland/tasks.py",
     "eland/utils.py",
     "eland/groupby.py",
+    "eland/operations.py",
+    "eland/ndframe.py",
     "eland/ml/__init__.py",
+    "eland/ml/_optional.py",
     "eland/ml/_model_serializer.py",
     "eland/ml/ml_model.py",
     "eland/ml/transformers/__init__.py",
@ -46,6 +49,7 @@ TYPED_FILES = (
     "eland/ml/transformers/lightgbm.py",
     "eland/ml/transformers/sklearn.py",
     "eland/ml/transformers/xgboost.py",
+    "eland/plotting/_matplotlib/__init__.py",
 )
@ -60,7 +64,9 @@ def format(session):
 
 @nox.session(reuse_venv=True)
 def lint(session):
-    session.install("black", "flake8", "mypy", "isort")
+    # Install numpy to use its mypy plugin
+    # https://numpy.org/devdocs/reference/typing.html#mypy-plugin
+    session.install("black", "flake8", "mypy", "isort", "numpy")
     session.install("--pre", "elasticsearch")
     session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
     session.run("black", "--check", "--target-version=py37", *SOURCE_FILES)

View File

@ -10,3 +10,4 @@ xgboost>=1
 nox
 lightgbm
 pytest-cov
+mypy

View File

@ -1,2 +1,4 @@
 [isort]
 profile = black
+[mypy]
+plugins = numpy.typing.mypy_plugin
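Registering the plugin is what lets mypy understand numpy's dynamic dtype constructs, such as the `DTypeLike` alias this commit uses for annotations. A small self-contained example:

    import numpy as np
    from numpy.typing import DTypeLike

    def as_dtype(d: DTypeLike) -> np.dtype:
        # DTypeLike accepts anything np.dtype() does:
        # np.float64, "float64", np.dtype("float64"), ...
        return np.dtype(d)

    print(as_dtype("float64"))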