Add type hints to 'eland.operations' and 'eland.ndframe'

This commit is contained in:
P. Sai Vinay 2021-08-02 22:20:35 +05:30 committed by GitHub
parent c0e861dc77
commit 823f01cc6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 428 additions and 276 deletions

View File

@ -19,9 +19,11 @@ from abc import ABC, abstractmethod
from io import StringIO
from typing import TYPE_CHECKING, Any, List, Union
import numpy as np # type: ignore
import numpy as np
if TYPE_CHECKING:
from numpy.typing import DTypeLike
from .query_compiler import QueryCompiler
@ -32,7 +34,7 @@ class ArithmeticObject(ABC):
pass
@abstractmethod
def dtype(self) -> np.dtype:
def dtype(self) -> "DTypeLike":
pass
@abstractmethod
@ -52,7 +54,7 @@ class ArithmeticString(ArithmeticObject):
return self.value
@property
def dtype(self) -> np.dtype:
def dtype(self) -> "DTypeLike":
return np.dtype(object)
@property
@ -64,7 +66,7 @@ class ArithmeticString(ArithmeticObject):
class ArithmeticNumber(ArithmeticObject):
def __init__(self, value: Union[int, float], dtype: np.dtype):
def __init__(self, value: Union[int, float], dtype: "DTypeLike"):
self._value = value
self._dtype = dtype
@ -76,7 +78,7 @@ class ArithmeticNumber(ArithmeticObject):
return f"{self._value}"
@property
def dtype(self) -> np.dtype:
def dtype(self) -> "DTypeLike":
return self._dtype
def __repr__(self) -> str:
@ -89,8 +91,8 @@ class ArithmeticSeries(ArithmeticObject):
"""
def __init__(
self, query_compiler: "QueryCompiler", display_name: str, dtype: np.dtype
):
self, query_compiler: "QueryCompiler", display_name: str, dtype: "DTypeLike"
) -> None:
# type defs
self._value: str
self._tasks: List["ArithmeticTask"]
@ -121,7 +123,7 @@ class ArithmeticSeries(ArithmeticObject):
return self._value
@property
def dtype(self) -> np.dtype:
def dtype(self) -> "DTypeLike":
return self._dtype
def __repr__(self) -> str:

View File

@ -18,12 +18,24 @@
import re
import warnings
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Union,
cast,
)
import numpy as np # type: ignore
import pandas as pd # type: ignore
from elasticsearch import Elasticsearch
if TYPE_CHECKING:
from numpy.typing import DTypeLike
# Default number of rows displayed (different to pandas where ALL could be displayed)
DEFAULT_NUM_ROWS_DISPLAYED = 60
DEFAULT_CHUNK_SIZE = 10000
@ -42,7 +54,7 @@ with warnings.catch_warnings():
def build_pd_series(
data: Dict[str, Any], dtype: Optional[np.dtype] = None, **kwargs: Any
data: Dict[str, Any], dtype: Optional["DTypeLike"] = None, **kwargs: Any
) -> pd.Series:
"""Builds a pd.Series while squelching the warning
for unspecified dtype on empty series
@ -88,7 +100,7 @@ class SortOrder(Enum):
def elasticsearch_date_to_pandas_date(
value: Union[int, str], date_format: Optional[str]
value: Union[int, str, float], date_format: Optional[str]
) -> pd.Timestamp:
"""
Given a specific Elasticsearch format for a date datatype, returns the
@ -98,7 +110,7 @@ def elasticsearch_date_to_pandas_date(
Parameters
----------
value: Union[int, str]
value: Union[int, str, float]
The date value.
date_format: str
The Elasticsearch date format (ex. 'epoch_millis', 'epoch_second', etc.)

View File

@ -15,9 +15,11 @@
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict
import numpy as np
import pandas as pd
import pytest
import pandas as pd # type: ignore
import pytest # type: ignore
import eland as ed
@ -28,7 +30,7 @@ pd.set_option("display.width", 100)
@pytest.fixture(autouse=True)
def add_imports(doctest_namespace):
def add_imports(doctest_namespace: Dict[str, Any]) -> None:
doctest_namespace["np"] = np
doctest_namespace["pd"] = pd
doctest_namespace["ed"] = ed

View File

@ -19,19 +19,19 @@ import re
import sys
import warnings
from io import StringIO
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
import numpy as np
import pandas as pd
from pandas.core.common import apply_if_callable, is_bool_indexer
from pandas.core.computation.eval import eval
from pandas.core.dtypes.common import is_list_like
from pandas.core.indexing import check_bool_indexer
from pandas.io.common import _expand_user, stringify_path
from pandas.io.formats import console
import pandas as pd # type: ignore
from pandas.core.common import apply_if_callable, is_bool_indexer # type: ignore
from pandas.core.computation.eval import eval # type: ignore
from pandas.core.dtypes.common import is_list_like # type: ignore
from pandas.core.indexing import check_bool_indexer # type: ignore
from pandas.io.common import _expand_user, stringify_path # type: ignore
from pandas.io.formats import console # type: ignore
from pandas.io.formats import format as fmt
from pandas.io.formats.printing import pprint_thing
from pandas.util._validators import validate_bool_kwarg
from pandas.io.formats.printing import pprint_thing # type: ignore
from pandas.util._validators import validate_bool_kwarg # type: ignore
import eland.plotting as gfx
from eland.common import DEFAULT_NUM_ROWS_DISPLAYED, docstring_parameter
@ -41,6 +41,11 @@ from eland.ndframe import NDFrame
from eland.series import Series
from eland.utils import is_valid_attr_name
if TYPE_CHECKING:
from elasticsearch import Elasticsearch
from .query_compiler import QueryCompiler
class DataFrame(NDFrame):
"""
@ -119,11 +124,13 @@ class DataFrame(NDFrame):
def __init__(
self,
es_client=None,
es_index_pattern=None,
es_index_field=None,
columns=None,
_query_compiler=None,
es_client: Optional[
Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
] = None,
es_index_pattern: Optional[str] = None,
columns: Optional[List[str]] = None,
es_index_field: Optional[str] = None,
_query_compiler: Optional["QueryCompiler"] = None,
) -> None:
"""
There are effectively 2 constructors:
@ -147,7 +154,7 @@ class DataFrame(NDFrame):
_query_compiler=_query_compiler,
)
def _get_columns(self):
def _get_columns(self) -> pd.Index:
"""
The column labels of the DataFrame.
@ -178,7 +185,7 @@ class DataFrame(NDFrame):
columns = property(_get_columns)
@property
def empty(self):
def empty(self) -> bool:
"""Determines if the DataFrame is empty.
Returns
@ -278,7 +285,10 @@ class DataFrame(NDFrame):
return DataFrame(_query_compiler=self._query_compiler.tail(n))
def sample(
self, n: int = None, frac: float = None, random_state: int = None
self,
n: Optional[int] = None,
frac: Optional[float] = None,
random_state: Optional[int] = None,
) -> "DataFrame":
"""
Return n randomly sample rows or the specify fraction of rows
@ -469,7 +479,7 @@ class DataFrame(NDFrame):
if is_valid_attr_name(column_name)
]
def __repr__(self):
def __repr__(self) -> None:
"""
From pandas
"""
@ -501,7 +511,7 @@ class DataFrame(NDFrame):
return buf.getvalue()
def _info_repr(self):
def _info_repr(self) -> bool:
"""
True if the repr should show the info view.
"""
@ -510,7 +520,7 @@ class DataFrame(NDFrame):
self._repr_fits_horizontal_() and self._repr_fits_vertical_()
)
def _repr_html_(self):
def _repr_html_(self) -> Optional[str]:
"""
From pandas - this is called by notebooks
"""
@ -540,7 +550,7 @@ class DataFrame(NDFrame):
else:
return None
def count(self):
def count(self) -> pd.Series:
"""
Count non-NA cells for each column.
@ -855,10 +865,10 @@ class DataFrame(NDFrame):
exceeds_info_cols = len(self.columns) > max_cols
# From pandas.DataFrame
def _put_str(s, space):
def _put_str(s, space) -> str:
return f"{s}"[:space].ljust(space)
def _verbose_repr():
def _verbose_repr() -> None:
lines.append(f"Data columns (total {len(self.columns)} columns):")
id_head = " # "
@ -930,10 +940,10 @@ class DataFrame(NDFrame):
+ _put_str(dtype, space_dtype)
)
def _non_verbose_repr():
def _non_verbose_repr() -> None:
lines.append(self.columns._summary(name="Columns"))
def _sizeof_fmt(num, size_qualifier):
def _sizeof_fmt(num: float, size_qualifier: str) -> str:
# returns size in human readable format
for x in ["bytes", "KB", "MB", "GB", "TB"]:
if num < 1024.0:
@ -1004,7 +1014,7 @@ class DataFrame(NDFrame):
border=None,
table_id=None,
render_links=False,
):
) -> Any:
"""
Render a Elasticsearch data as an HTML table.
@ -1171,7 +1181,7 @@ class DataFrame(NDFrame):
result = _buf.getvalue()
return result
def __getattr__(self, key):
def __getattr__(self, key: str) -> Any:
"""After regular attribute access, looks up the name in the columns
Parameters
@ -1190,7 +1200,12 @@ class DataFrame(NDFrame):
return self[key]
raise e
def _getitem(self, key):
def _getitem(
self,
key: Union[
"DataFrame", "Series", pd.Index, List[str], str, BooleanFilter, np.ndarray
],
) -> Union["Series", "DataFrame"]:
"""Get the column specified by key for this DataFrame.
Args:
@ -1215,13 +1230,13 @@ class DataFrame(NDFrame):
else:
return self._getitem_column(key)
def _getitem_column(self, key):
def _getitem_column(self, key: str) -> "Series":
if key not in self.columns:
raise KeyError(f"Requested column [{key}] is not in the DataFrame.")
s = self._reduce_dimension(self._query_compiler.getitem_column_array([key]))
return s
def _getitem_array(self, key):
def _getitem_array(self, key: Union[str, pd.Series]) -> "DataFrame":
if isinstance(key, Series):
key = key.to_pandas()
if is_bool_indexer(key):
@ -1256,7 +1271,9 @@ class DataFrame(NDFrame):
_query_compiler=self._query_compiler.getitem_column_array(key)
)
def _create_or_update_from_compiler(self, new_query_compiler, inplace=False):
def _create_or_update_from_compiler(
self, new_query_compiler: "QueryCompiler", inplace: bool = False
) -> Union["QueryCompiler", "DataFrame"]:
"""Returns or updates a DataFrame given new query_compiler"""
assert (
isinstance(new_query_compiler, type(self._query_compiler))
@ -1265,10 +1282,10 @@ class DataFrame(NDFrame):
if not inplace:
return DataFrame(_query_compiler=new_query_compiler)
else:
self._query_compiler = new_query_compiler
self._query_compiler: "QueryCompiler" = new_query_compiler
@staticmethod
def _reduce_dimension(query_compiler):
def _reduce_dimension(query_compiler: "QueryCompiler") -> "Series":
return Series(_query_compiler=query_compiler)
def to_csv(
@ -1849,7 +1866,9 @@ class DataFrame(NDFrame):
else:
raise NotImplementedError(expr, type(expr))
def get(self, key, default=None):
def get(
self, key: Any, default: Optional[Any] = None
) -> Union["Series", "DataFrame"]:
"""
Get item from object for given key (ex: DataFrame column).
Returns default value if not found.
@ -1956,7 +1975,7 @@ class DataFrame(NDFrame):
elif like is not None:
def matcher(x):
def matcher(x: str) -> bool:
return like in x
else:
@ -1965,7 +1984,7 @@ class DataFrame(NDFrame):
return self[[column for column in self.columns if matcher(column)]]
@property
def values(self):
def values(self) -> None:
"""
Not implemented.
@ -1983,7 +2002,7 @@ class DataFrame(NDFrame):
"""
return self.to_numpy()
def to_numpy(self):
def to_numpy(self) -> None:
"""
Not implemented.

View File

@ -25,13 +25,14 @@ from typing import (
NamedTuple,
Optional,
Set,
TextIO,
Tuple,
Union,
)
import numpy as np
import pandas as pd
from pandas.core.dtypes.common import (
import pandas as pd # type: ignore
from pandas.core.dtypes.common import ( # type: ignore
is_bool_dtype,
is_datetime_or_timedelta_dtype,
is_float_dtype,
@ -42,6 +43,7 @@ from pandas.core.dtypes.inference import is_list_like
if TYPE_CHECKING:
from elasticsearch import Elasticsearch
from numpy.typing import DTypeLike
ES_FLOAT_TYPES: Set[str] = {"double", "float", "half_float", "scaled_float"}
@ -559,7 +561,7 @@ class FieldMappings:
return {"mappings": {"properties": mapping_props}}
def aggregatable_field_name(self, display_name):
def aggregatable_field_name(self, display_name: str) -> Optional[str]:
"""
Return a single aggregatable field_name from display_name
@ -598,7 +600,7 @@ class FieldMappings:
return self._mappings_capabilities.loc[display_name].aggregatable_es_field_name
def aggregatable_field_names(self):
def aggregatable_field_names(self) -> Dict[str, str]:
"""
Return a list of aggregatable Elasticsearch field_names for all display names.
If field is not aggregatable_field_names, return nothing.
@ -634,7 +636,7 @@ class FieldMappings:
)["data"]
)
def date_field_format(self, es_field_name):
def date_field_format(self, es_field_name: str) -> str:
"""
Parameters
----------
@ -650,7 +652,7 @@ class FieldMappings:
self._mappings_capabilities.es_field_name == es_field_name
].es_date_format.squeeze()
def field_name_pd_dtype(self, es_field_name):
def field_name_pd_dtype(self, es_field_name: str) -> str:
"""
Parameters
----------
@ -674,7 +676,9 @@ class FieldMappings:
].pd_dtype.squeeze()
return pd_dtype
def add_scripted_field(self, scripted_field_name, display_name, pd_dtype):
def add_scripted_field(
self, scripted_field_name: str, display_name: str, pd_dtype: str
) -> None:
# if this display name is used somewhere else, drop it
if display_name in self._mappings_capabilities.index:
self._mappings_capabilities = self._mappings_capabilities.drop(
@ -706,8 +710,8 @@ class FieldMappings:
capability_matrix_row
)
def numeric_source_fields(self):
pd_dtypes, es_field_names, es_date_formats = self.metric_source_fields()
def numeric_source_fields(self) -> List[str]:
_, es_field_names, _ = self.metric_source_fields()
return es_field_names
def all_source_fields(self) -> List[Field]:
@ -753,7 +757,9 @@ class FieldMappings:
# Maintain groupby order as given input
return [groupby_fields[column] for column in by], aggregatable_fields
def metric_source_fields(self, include_bool=False, include_timestamp=False):
def metric_source_fields(
self, include_bool: bool = False, include_timestamp: bool = False
) -> Tuple[List["DTypeLike"], List[str], Optional[List[str]]]:
"""
Returns
-------
@ -790,7 +796,7 @@ class FieldMappings:
# return in display_name order
return pd_dtypes, es_field_names, es_date_formats
def get_field_names(self, include_scripted_fields=True):
def get_field_names(self, include_scripted_fields: bool = True) -> List[str]:
if include_scripted_fields:
return self._mappings_capabilities.es_field_name.to_list()
@ -801,7 +807,7 @@ class FieldMappings:
def _get_display_names(self):
return self._mappings_capabilities.index.to_list()
def _set_display_names(self, display_names):
def _set_display_names(self, display_names: List[str]):
if not is_list_like(display_names):
raise ValueError(f"'{display_names}' is not list like")
@ -842,7 +848,7 @@ class FieldMappings:
es_dtypes.name = None
return es_dtypes
def es_info(self, buf):
def es_info(self, buf: TextIO) -> None:
buf.write("Mappings:\n")
buf.write(f" capabilities:\n{self._mappings_capabilities.to_string()}\n")

View File

@ -17,8 +17,11 @@
import distutils.version
import importlib
import types
import warnings
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from types import ModuleType
# ----------------------------------------------------------------------------
# functions largely based / taken from the six module
@ -42,7 +45,7 @@ version_message = (
)
def _get_version(module: types.ModuleType) -> str:
def _get_version(module: "ModuleType") -> Any:
version = getattr(module, "__version__", None)
if version is None:
# xlrd uses a capitalized attribute name
@ -55,7 +58,7 @@ def _get_version(module: types.ModuleType) -> str:
def import_optional_dependency(
name: str, extra: str = "", raise_on_missing: bool = True, on_version: str = "raise"
):
) -> Optional["ModuleType"]:
"""
Import an optional dependency.

View File

@ -18,7 +18,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import elasticsearch
import numpy as np # type: ignore
import numpy as np
from eland.common import ensure_es_client, es_version
from eland.utils import deprecated_api
@ -27,7 +27,8 @@ from .common import TYPE_CLASSIFICATION, TYPE_REGRESSION
from .transformers import get_model_transformer
if TYPE_CHECKING:
from elasticsearch import Elasticsearch # noqa: F401
from elasticsearch import Elasticsearch
from numpy.typing import ArrayLike, DTypeLike
# Try importing each ML lib separately so mypy users don't have to
# have both installed to use type-checking.
@ -83,8 +84,8 @@ class MLModel:
self._trained_model_config_cache: Optional[Dict[str, Any]] = None
def predict(
self, X: Union[np.ndarray, List[float], List[List[float]]]
) -> np.ndarray:
self, X: Union["ArrayLike", List[float], List[List[float]]]
) -> "ArrayLike":
"""
Make a prediction using a trained model stored in Elasticsearch.
@ -196,7 +197,7 @@ class MLModel:
# Return results as np.ndarray of float32 or int (consistent with sklearn/xgboost)
if self.model_type == TYPE_CLASSIFICATION:
dt = np.int_
dt: "DTypeLike" = np.int_
else:
dt = np.float32
return np.asarray(y, dtype=dt)

View File

@ -17,7 +17,7 @@
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
import numpy as np # type: ignore
import numpy as np
from .._model_serializer import Ensemble, Tree, TreeNode
from .._optional import import_optional_dependency
@ -64,7 +64,7 @@ class SKLearnTransformer(ModelTransformer):
self,
node_index: int,
node_data: Tuple[Union[int, float], ...],
value: np.ndarray,
value: np.ndarray, # type: ignore
) -> TreeNode:
"""
This builds out a TreeNode class given the sklearn tree node definition.

View File

@ -229,7 +229,7 @@ class XGBoostClassifierTransformer(XGBoostForestTransformer):
if model.classes_ is None:
n_estimators = model.get_params()["n_estimators"]
num_trees = model.get_booster().trees_to_dataframe()["Tree"].max() + 1
self._num_classes = num_trees // n_estimators
self._num_classes: int = num_trees // n_estimators
else:
self._num_classes = len(model.classes_)

View File

@ -17,7 +17,7 @@
import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, TextIO, Tuple, Union
import pandas as pd # type: ignore
@ -186,7 +186,7 @@ class NDFrame(ABC):
"""
return len(self.index)
def _es_info(self, buf):
def _es_info(self, buf: TextIO) -> None:
self._query_compiler.es_info(buf)
def mean(self, numeric_only: Optional[bool] = None) -> pd.Series:
@ -604,7 +604,7 @@ class NDFrame(ABC):
"""
return self._query_compiler.mad(numeric_only=numeric_only)
def _hist(self, num_bins):
def _hist(self, num_bins: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
return self._query_compiler._hist(num_bins)
def describe(self) -> pd.DataFrame:

View File

@ -26,12 +26,13 @@ from typing import (
List,
Optional,
Sequence,
TextIO,
Tuple,
Union,
)
import numpy as np
import pandas as pd
import pandas as pd # type: ignore
from elasticsearch.helpers import scan
from eland.actions import PostProcessingAction, SortFieldAction
@ -58,13 +59,18 @@ from eland.tasks import (
)
if TYPE_CHECKING:
from numpy.typing import DTypeLike
from eland.arithmetics import ArithmeticSeries
from eland.field_mappings import Field
from eland.filter import BooleanFilter
from eland.query_compiler import QueryCompiler
from eland.tasks import Task
class QueryParams:
def __init__(self):
self.query = Query()
def __init__(self) -> None:
self.query: Query = Query()
self.sort_field: Optional[str] = None
self.sort_order: Optional[SortOrder] = None
self.size: Optional[int] = None
@ -85,37 +91,48 @@ class Operations:
(see https://docs.dask.org/en/latest/spec.html)
"""
def __init__(self, tasks=None, arithmetic_op_fields_task=None):
def __init__(
self,
tasks: Optional[List["Task"]] = None,
arithmetic_op_fields_task: Optional["ArithmeticOpFieldsTask"] = None,
) -> None:
self._tasks: List["Task"]
if tasks is None:
self._tasks = []
else:
self._tasks = tasks
self._arithmetic_op_fields_task = arithmetic_op_fields_task
def __constructor__(self, *args, **kwargs):
def __constructor__(
self,
*args: Any,
**kwargs: Any,
) -> "Operations":
return type(self)(*args, **kwargs)
def copy(self):
def copy(self) -> "Operations":
return self.__constructor__(
tasks=copy.deepcopy(self._tasks),
arithmetic_op_fields_task=copy.deepcopy(self._arithmetic_op_fields_task),
)
def head(self, index, n):
def head(self, index: "Index", n: int) -> None:
# Add a task that is an ascending sort with size=n
task = HeadTask(index, n)
self._tasks.append(task)
def tail(self, index, n):
def tail(self, index: "Index", n: int) -> None:
# Add a task that is descending sort with size=n
task = TailTask(index, n)
self._tasks.append(task)
def sample(self, index, n, random_state):
def sample(self, index: "Index", n: int, random_state: int) -> None:
task = SampleTask(index, n, random_state)
self._tasks.append(task)
def arithmetic_op_fields(self, display_name, arithmetic_series):
def arithmetic_op_fields(
self, display_name: str, arithmetic_series: "ArithmeticSeries"
) -> None:
if self._arithmetic_op_fields_task is None:
self._arithmetic_op_fields_task = ArithmeticOpFieldsTask(
display_name, arithmetic_series
@ -127,10 +144,10 @@ class Operations:
# get an ArithmeticOpFieldsTask if it exists
return self._arithmetic_op_fields_task
def __repr__(self):
def __repr__(self) -> str:
return repr(self._tasks)
def count(self, query_compiler):
def count(self, query_compiler: "QueryCompiler") -> pd.Series:
query_params, post_processing = self._resolve_tasks(query_compiler)
# Elasticsearch _count is very efficient and so used to return results here. This means that
@ -161,7 +178,7 @@ class Operations:
def _metric_agg_series(
self,
query_compiler: "QueryCompiler",
agg: List,
agg: List["str"],
numeric_only: Optional[bool] = None,
) -> pd.Series:
results = self._metric_aggs(query_compiler, agg, numeric_only=numeric_only)
@ -170,7 +187,7 @@ class Operations:
else:
# If all results are float convert into float64
if all(isinstance(i, float) for i in results.values()):
dtype = np.float64
dtype: "DTypeLike" = np.float64
# If all results are int convert into int64
elif all(isinstance(i, int) for i in results.values()):
dtype = np.int64
@ -184,7 +201,9 @@ class Operations:
def value_counts(self, query_compiler: "QueryCompiler", es_size: int) -> pd.Series:
return self._terms_aggs(query_compiler, "terms", es_size)
def hist(self, query_compiler, bins):
def hist(
self, query_compiler: "QueryCompiler", bins: int
) -> Tuple[pd.DataFrame, pd.DataFrame]:
return self._hist_aggs(query_compiler, bins)
def idx(
@ -237,7 +256,12 @@ class Operations:
return pd.Series(results)
def aggs(self, query_compiler, pd_aggs, numeric_only=None) -> pd.DataFrame:
def aggs(
self,
query_compiler: "QueryCompiler",
pd_aggs: List[str],
numeric_only: Optional[bool] = None,
) -> pd.DataFrame:
results = self._metric_aggs(
query_compiler, pd_aggs, numeric_only=numeric_only, is_dataframe_agg=True
)
@ -441,13 +465,15 @@ class Operations:
try:
# get first value in dict (key is .keyword)
name = list(aggregatable_field_names.values())[0]
name: Optional[str] = list(aggregatable_field_names.values())[0]
except IndexError:
name = None
return build_pd_series(results, name=name)
def _hist_aggs(self, query_compiler, num_bins):
def _hist_aggs(
self, query_compiler: "QueryCompiler", num_bins: int
) -> Tuple[pd.DataFrame, pd.DataFrame]:
# Get histogram bins and weights for numeric field_names
query_params, post_processing = self._resolve_tasks(query_compiler)
@ -488,8 +514,8 @@ class Operations:
# },
# ...
bins = {}
weights = {}
bins: Dict[str, List[int]] = {}
weights: Dict[str, List[int]] = {}
# There is one more bin that weights
# len(bins) = len(weights) + 1
@ -537,11 +563,11 @@ class Operations:
def _unpack_metric_aggs(
self,
fields: List["Field"],
es_aggs: Union[List[str], List[Tuple[str, str]]],
es_aggs: Union[List[str], List[Tuple[str, List[float]]]],
pd_aggs: List[str],
response: Dict[str, Any],
numeric_only: Optional[bool],
percentiles: Optional[List[float]] = None,
percentiles: Optional[Sequence[float]] = None,
is_dataframe_agg: bool = False,
is_groupby: bool = False,
) -> Dict[str, List[Any]]:
@ -574,7 +600,7 @@ class Operations:
"""
results: Dict[str, Any] = {}
percentile_values: List[float] = []
agg_value: Union[int, float]
agg_value: Any
for field in fields:
values = []
@ -611,7 +637,10 @@ class Operations:
agg_value = agg_value["50.0"]
else:
# Maintain order of percentiles
percentile_values = [agg_value[str(i)] for i in percentiles]
if percentiles:
percentile_values = [
agg_value[str(i)] for i in percentiles
]
if not percentile_values and pd_agg not in ("quantile", "median"):
agg_value = agg_value[es_agg[1]]
@ -682,7 +711,11 @@ class Operations:
# Cardinality is always either NaN or integer.
elif pd_agg in ("nunique", "count"):
agg_value = int(agg_value)
agg_value = (
int(agg_value)
if isinstance(agg_value, (int, float))
else np.NaN
)
# If this is a non-null timestamp field convert to a pd.Timestamp()
elif field.is_timestamp:
@ -702,6 +735,7 @@ class Operations:
for value in percentile_values
]
else:
assert not isinstance(agg_value, dict)
agg_value = elasticsearch_date_to_pandas_date(
agg_value, field.es_date_format
)
@ -771,7 +805,7 @@ class Operations:
by: List[str],
pd_aggs: List[str],
dropna: bool = True,
quantiles: Optional[List[float]] = None,
quantiles: Optional[Union[int, float, List[int], List[float]]] = None,
is_dataframe_agg: bool = False,
numeric_only: Optional[bool] = True,
) -> pd.DataFrame:
@ -811,7 +845,7 @@ class Operations:
by_fields, agg_fields = query_compiler._mappings.groupby_source_fields(by=by)
# Used defaultdict to avoid initialization of columns with lists
results: Dict[str, List[Any]] = defaultdict(list)
results: Dict[Any, List[Any]] = defaultdict(list)
if numeric_only:
agg_fields = [
@ -823,7 +857,8 @@ class Operations:
# To return for creating multi-index on columns
headers = [agg_field.column for agg_field in agg_fields]
percentiles: Optional[List[str]] = None
percentiles: Optional[List[float]] = None
len_percentiles: int = 0
if quantiles:
percentiles = [
quantile_to_percentile(x)
@ -833,6 +868,7 @@ class Operations:
else quantiles
)
]
len_percentiles = len(percentiles)
# Convert pandas aggs to ES equivalent
es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs=pd_aggs, percentiles=percentiles)
@ -894,8 +930,8 @@ class Operations:
if by_field.is_timestamp and isinstance(bucket_key, int):
bucket_key = pd.to_datetime(bucket_key, unit="ms")
if pd_aggs == ["quantile"] and len(percentiles) > 1:
bucket_key = [bucket_key] * len(percentiles)
if pd_aggs == ["quantile"] and len_percentiles > 1:
bucket_key = [bucket_key] * len_percentiles
results[by_field.column].extend(
bucket_key if isinstance(bucket_key, list) else [bucket_key]
@ -915,7 +951,7 @@ class Operations:
)
# to construct index with quantiles
if pd_aggs == ["quantile"] and len(percentiles) > 1:
if pd_aggs == ["quantile"] and percentiles and len_percentiles > 1:
results[None].extend([i / 100 for i in percentiles])
# Process the calculated agg values to response
@ -929,9 +965,10 @@ class Operations:
for pd_agg, val in zip(pd_aggs, value):
results[f"{key}_{pd_agg}"].append(val)
# Just to maintain Output same as pandas with empty header.
if pd_aggs == ["quantile"] and len(percentiles) > 1:
by = by + [None]
if pd_aggs == ["quantile"] and len_percentiles > 1:
# by never holds None by default, we make an exception
# here to maintain output same as pandas, also mypy complains
by = by + [None] # type: ignore
agg_df = pd.DataFrame(results).set_index(by).sort_index()
@ -947,7 +984,7 @@ class Operations:
@staticmethod
def bucket_generator(
query_compiler: "QueryCompiler", body: "Query"
) -> Generator[List[str], None, List[str]]:
) -> Generator[Sequence[Dict[str, Any]], None, Sequence[Dict[str, Any]]]:
"""
This can be used for all groupby operations.
e.g.
@ -977,18 +1014,24 @@ class Operations:
)
# Pagination Logic
composite_buckets = res["aggregations"]["groupby_buckets"]
if "after_key" in composite_buckets:
composite_buckets: Dict[str, Any] = res["aggregations"]["groupby_buckets"]
after_key: Optional[Dict[str, Any]] = composite_buckets.get(
"after_key", None
)
buckets: Sequence[Dict[str, Any]] = composite_buckets["buckets"]
if after_key:
# yield the bucket which contains the result
yield composite_buckets["buckets"]
yield buckets
body.composite_agg_after_key(
name="groupby_buckets",
after_key=composite_buckets["after_key"],
after_key=after_key,
)
else:
return composite_buckets["buckets"]
return buckets
@staticmethod
def _map_pd_aggs_to_es_aggs(
@ -1031,7 +1074,7 @@ class Operations:
extended_stats_es_aggs = {"avg", "min", "max", "sum"}
extended_stats_calls = 0
es_aggs = []
es_aggs: List[Any] = []
for pd_agg in pd_aggs:
if pd_agg in extended_stats_pd_aggs:
extended_stats_calls += 1
@ -1100,7 +1143,7 @@ class Operations:
def filter(
self,
query_compiler: "QueryCompiler",
items: Optional[Sequence[str]] = None,
items: Optional[List[str]] = None,
like: Optional[str] = None,
regex: Optional[str] = None,
) -> None:
@ -1122,7 +1165,7 @@ class Operations:
f"to substring and regex operations not being available for Elasticsearch document IDs."
)
def describe(self, query_compiler):
def describe(self, query_compiler: "QueryCompiler") -> pd.DataFrame:
query_params, post_processing = self._resolve_tasks(query_compiler)
size = self._size(query_params, post_processing)
@ -1151,30 +1194,9 @@ class Operations:
["count", "mean", "std", "min", "25%", "50%", "75%", "max"]
)
def to_pandas(self, query_compiler, show_progress=False):
class PandasDataFrameCollector:
def __init__(self, show_progress):
self._df = None
self._show_progress = show_progress
def collect(self, df):
# This collector does not batch data on output. Therefore, batch_size is fixed to None and this method
# is only called once.
if self._df is not None:
raise RuntimeError(
"Logic error in execution, this method must only be called once for this"
"collector - batch_size == None"
)
self._df = df
@staticmethod
def batch_size():
# Do not change (see notes on collect)
return None
@property
def show_progress(self):
return self._show_progress
def to_pandas(
self, query_compiler: "QueryCompiler", show_progress: bool = False
) -> None:
collector = PandasDataFrameCollector(show_progress)
@ -1182,35 +1204,12 @@ class Operations:
return collector._df
def to_csv(self, query_compiler, show_progress=False, **kwargs):
class PandasToCSVCollector:
def __init__(self, show_progress, **args):
self._args = args
self._show_progress = show_progress
self._ret = None
self._first_time = True
def collect(self, df):
# If this is the first time we collect results, then write header, otherwise don't write header
# and append results
if self._first_time:
self._first_time = False
df.to_csv(**self._args)
else:
# Don't write header, and change mode to append
self._args["header"] = False
self._args["mode"] = "a"
df.to_csv(**self._args)
@staticmethod
def batch_size():
# By default read n docs and then dump to csv
batch_size = DEFAULT_CSV_BATCH_OUTPUT_SIZE
return batch_size
@property
def show_progress(self):
return self._show_progress
def to_csv(
self,
query_compiler: "QueryCompiler",
show_progress: bool = False,
**kwargs: Union[bool, str],
) -> None:
collector = PandasToCSVCollector(show_progress, **kwargs)
@ -1218,7 +1217,11 @@ class Operations:
return collector._ret
def _es_results(self, query_compiler, collector):
def _es_results(
self,
query_compiler: "QueryCompiler",
collector: Union["PandasToCSVCollector", "PandasDataFrameCollector"],
) -> None:
query_params, post_processing = self._resolve_tasks(query_compiler)
size, sort_params = Operations._query_params_to_size_and_sort(query_params)
@ -1245,7 +1248,7 @@ class Operations:
else:
body["_source"] = False
es_results = None
es_results: Any = None
# If size=None use scan not search - then post sort results when in df
# If size>10000 use scan
@ -1283,7 +1286,7 @@ class Operations:
df = self._apply_df_post_processing(df, post_processing)
collector.collect(df)
def index_count(self, query_compiler, field):
def index_count(self, query_compiler: "QueryCompiler", field: str) -> int:
# field is the index field so count values
query_params, post_processing = self._resolve_tasks(query_compiler)
@ -1297,12 +1300,13 @@ class Operations:
body = Query(query_params.query)
body.exists(field, must=True)
return query_compiler._client.count(
count: int = query_compiler._client.count(
index=query_compiler._index_pattern, body=body.to_count_body()
)["count"]
return count
def _validate_index_operation(
self, query_compiler: "QueryCompiler", items: Sequence[str]
self, query_compiler: "QueryCompiler", items: List[str]
) -> RESOLVED_TASK_TYPE:
if not isinstance(items, list):
raise TypeError(f"list item required - not {type(items)}")
@ -1320,7 +1324,9 @@ class Operations:
return query_params, post_processing
def index_matches_count(self, query_compiler, field, items):
def index_matches_count(
self, query_compiler: "QueryCompiler", field: str, items: List[Any]
) -> int:
query_params, post_processing = self._validate_index_operation(
query_compiler, items
)
@ -1332,12 +1338,13 @@ class Operations:
else:
body.terms(field, items, must=True)
return query_compiler._client.count(
count: int = query_compiler._client.count(
index=query_compiler._index_pattern, body=body.to_count_body()
)["count"]
return count
def drop_index_values(
self, query_compiler: "QueryCompiler", field: str, items: Sequence[str]
self, query_compiler: "QueryCompiler", field: str, items: List[str]
) -> None:
self._validate_index_operation(query_compiler, items)
@ -1349,6 +1356,7 @@ class Operations:
# a in ['a','b','c']
# b not in ['a','b','c']
# For now use term queries
task: Union["QueryIdsTask", "QueryTermsTask"]
if field == Index.ID_INDEX_FIELD:
task = QueryIdsTask(False, items)
else:
@ -1356,11 +1364,12 @@ class Operations:
self._tasks.append(task)
def filter_index_values(
self, query_compiler: "QueryCompiler", field: str, items: Sequence[str]
self, query_compiler: "QueryCompiler", field: str, items: List[str]
) -> None:
# Basically .drop_index_values() except with must=True on tasks.
self._validate_index_operation(query_compiler, items)
task: Union["QueryIdsTask", "QueryTermsTask"]
if field == Index.ID_INDEX_FIELD:
task = QueryIdsTask(True, items, sort_index_by_ids=True)
else:
@ -1406,7 +1415,7 @@ class Operations:
# other operations require pre-queries and then combinations
# other operations require in-core post-processing of results
query_params = QueryParams()
post_processing = []
post_processing: List["PostProcessingAction"] = []
for task in self._tasks:
query_params, post_processing = task.resolve_task(
@ -1439,7 +1448,7 @@ class Operations:
# This can return None
return size
def es_info(self, query_compiler, buf):
def es_info(self, query_compiler: "QueryCompiler", buf: TextIO) -> None:
buf.write("Operations:\n")
buf.write(f" tasks: {self._tasks}\n")
@ -1459,7 +1468,7 @@ class Operations:
buf.write(f" body: {body}\n")
buf.write(f" post_processing: {post_processing}\n")
def update_query(self, boolean_filter):
def update_query(self, boolean_filter: "BooleanFilter") -> None:
task = BooleanFilterTask(boolean_filter)
self._tasks.append(task)
@ -1477,3 +1486,58 @@ def quantile_to_percentile(quantile: Union[int, float]) -> float:
# quantile * 100 = percentile
# return float(...) because min(1.0) gives 1
return float(min(100, max(0, quantile * 100)))
class PandasToCSVCollector:
def __init__(self, show_progress: bool, **kwargs: Union[bool, str]) -> None:
self._args = kwargs
self._show_progress = show_progress
self._ret = None
self._first_time = True
def collect(self, df: "pd.DataFrame") -> None:
# If this is the first time we collect results, then write header, otherwise don't write header
# and append results
if self._first_time:
self._first_time = False
df.to_csv(**self._args)
else:
# Don't write header, and change mode to append
self._args["header"] = False
self._args["mode"] = "a"
df.to_csv(**self._args)
@staticmethod
def batch_size() -> int:
# By default read n docs and then dump to csv
batch_size: int = DEFAULT_CSV_BATCH_OUTPUT_SIZE
return batch_size
@property
def show_progress(self) -> bool:
return self._show_progress
class PandasDataFrameCollector:
def __init__(self, show_progress: bool) -> None:
self._df = None
self._show_progress = show_progress
def collect(self, df: "pd.DataFrame") -> None:
# This collector does not batch data on output. Therefore, batch_size is fixed to None and this method
# is only called once.
if self._df is not None:
raise RuntimeError(
"Logic error in execution, this method must only be called once for this"
"collector - batch_size == None"
)
self._df = df
@staticmethod
def batch_size() -> None:
# Do not change (see notes on collect)
return None
@property
def show_progress(self) -> bool:
return self._show_progress

View File

@ -15,8 +15,10 @@
# specific language governing permissions and limitations
# under the License.
from typing import TYPE_CHECKING
import numpy as np
from pandas.plotting._matplotlib import converter
from pandas.plotting._matplotlib import converter # type: ignore
try:
# pandas<1.3.0
@ -26,13 +28,13 @@ except ImportError:
from pandas.core.dtypes.generic import ABCIndex
try: # pandas>=1.2.0
from pandas.plotting._matplotlib.tools import (
from pandas.plotting._matplotlib.tools import ( # type: ignore
create_subplots,
flatten_axes,
set_ticks_props,
)
except ImportError: # pandas<1.2.0
from pandas.plotting._matplotlib.tools import (
from pandas.plotting._matplotlib.tools import ( # type: ignore
_flatten as flatten_axes,
_set_ticks_props as set_ticks_props,
_subplots as create_subplots,
@ -40,6 +42,9 @@ except ImportError: # pandas<1.2.0
from eland.utils import try_sort
if TYPE_CHECKING:
from numpy.typing import ArrayLike
def hist_series(
self,
@ -53,8 +58,8 @@ def hist_series(
figsize=None,
bins=10,
**kwds,
):
import matplotlib.pyplot as plt
) -> "ArrayLike":
import matplotlib.pyplot as plt # type: ignore
if by is None:
if kwds.get("layout", None) is not None:

View File

@ -17,9 +17,19 @@
import copy
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Sequence,
TextIO,
Tuple,
Union,
)
import numpy as np # type: ignore
import numpy as np
import pandas as pd # type: ignore
from eland.common import (
@ -28,11 +38,15 @@ from eland.common import (
ensure_es_client,
)
from eland.field_mappings import FieldMappings
from eland.filter import QueryFilter
from eland.filter import BooleanFilter, QueryFilter
from eland.index import Index
from eland.operations import Operations
if TYPE_CHECKING:
from elasticsearch import Elasticsearch
from eland.arithmetics import ArithmeticSeries
from .tasks import ArithmeticOpFieldsTask # noqa: F401
@ -67,8 +81,10 @@ class QueryCompiler:
def __init__(
self,
client=None,
index_pattern=None,
client: Optional[
Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
] = None,
index_pattern: Optional[str] = None,
display_names=None,
index_field=None,
to_copy=None,
@ -77,15 +93,15 @@ class QueryCompiler:
if to_copy is not None:
self._client = to_copy._client
self._index_pattern = to_copy._index_pattern
self._index = Index(self, to_copy._index.es_index_field)
self._operations = copy.deepcopy(to_copy._operations)
self._index: "Index" = Index(self, to_copy._index.es_index_field)
self._operations: "Operations" = copy.deepcopy(to_copy._operations)
self._mappings: FieldMappings = copy.deepcopy(to_copy._mappings)
else:
self._client = ensure_es_client(client)
self._index_pattern = index_pattern
# Get and persist mappings, this allows us to correctly
# map returned types from Elasticsearch to pandas datatypes
self._mappings: FieldMappings = FieldMappings(
self._mappings = FieldMappings(
client=self._client,
index_pattern=self._index_pattern,
display_names=display_names,
@ -103,15 +119,15 @@ class QueryCompiler:
return pd.Index(columns)
def _get_display_names(self):
def _get_display_names(self) -> "pd.Index":
display_names = self._mappings.display_names
return pd.Index(display_names)
def _set_display_names(self, display_names):
def _set_display_names(self, display_names: List[str]) -> None:
self._mappings.display_names = display_names
def get_field_names(self, include_scripted_fields):
def get_field_names(self, include_scripted_fields: bool) -> List[str]:
return self._mappings.get_field_names(include_scripted_fields)
def add_scripted_field(self, scripted_field_name, display_name, pd_dtype):
@ -129,7 +145,12 @@ class QueryCompiler:
# END Index, columns, and dtypes objects
def _es_results_to_pandas(self, results, batch_size=None, show_progress=False):
def _es_results_to_pandas(
self,
results: Dict[Any, Any],
batch_size: Optional[int] = None,
show_progress: bool = False,
) -> "pd.Dataframe":
"""
Parameters
----------
@ -300,7 +321,7 @@ class QueryCompiler:
return partial_result, df
def _flatten_dict(self, y, field_mapping_cache):
def _flatten_dict(self, y, field_mapping_cache: "FieldMappingCache"):
out = {}
def flatten(x, name=""):
@ -368,7 +389,7 @@ class QueryCompiler:
"""
return self._operations.index_count(self, self.index.es_index_field)
def _index_matches_count(self, items):
def _index_matches_count(self, items: List[Any]) -> int:
"""
Returns
-------
@ -386,10 +407,10 @@ class QueryCompiler:
df[c] = pd.Series(dtype=d)
return df
def copy(self):
def copy(self) -> "QueryCompiler":
return QueryCompiler(to_copy=self)
def rename(self, renames, inplace=False):
def rename(self, renames, inplace: bool = False) -> "QueryCompiler":
if inplace:
self._mappings.rename(renames)
return self
@ -398,21 +419,23 @@ class QueryCompiler:
result._mappings.rename(renames)
return result
def head(self, n):
def head(self, n: int) -> "QueryCompiler":
result = self.copy()
result._operations.head(self._index, n)
return result
def tail(self, n):
def tail(self, n: int) -> "QueryCompiler":
result = self.copy()
result._operations.tail(self._index, n)
return result
def sample(self, n=None, frac=None, random_state=None):
def sample(
self, n: Optional[int] = None, frac=None, random_state=None
) -> "QueryCompiler":
result = self.copy()
if n is None and frac is None:
@ -501,11 +524,11 @@ class QueryCompiler:
query = {"multi_match": options}
return QueryFilter(query)
def es_query(self, query):
def es_query(self, query: Dict[str, Any]) -> "QueryCompiler":
return self._update_query(QueryFilter(query))
# To/From Pandas
def to_pandas(self, show_progress=False):
def to_pandas(self, show_progress: bool = False):
"""Converts Eland DataFrame to Pandas DataFrame.
Returns:
@ -543,7 +566,9 @@ class QueryCompiler:
return result
def drop(self, index=None, columns=None):
def drop(
self, index: Optional[str] = None, columns: Optional[List[str]] = None
) -> "QueryCompiler":
result = self.copy()
# Drop gets all columns and removes drops
@ -559,7 +584,7 @@ class QueryCompiler:
def filter(
self,
items: Optional[Sequence[str]] = None,
items: Optional[List[str]] = None,
like: Optional[str] = None,
regex: Optional[str] = None,
) -> "QueryCompiler":
@ -570,53 +595,55 @@ class QueryCompiler:
result._operations.filter(self, items=items, like=like, regex=regex)
return result
def aggs(self, func: List[str], numeric_only: Optional[bool] = None):
def aggs(
self, func: List[str], numeric_only: Optional[bool] = None
) -> pd.DataFrame:
return self._operations.aggs(self, func, numeric_only=numeric_only)
def count(self):
def count(self) -> pd.Series:
return self._operations.count(self)
def mean(self, numeric_only: Optional[bool] = None):
def mean(self, numeric_only: Optional[bool] = None) -> pd.Series:
return self._operations._metric_agg_series(
self, ["mean"], numeric_only=numeric_only
)
def var(self, numeric_only: Optional[bool] = None):
def var(self, numeric_only: Optional[bool] = None) -> pd.Series:
return self._operations._metric_agg_series(
self, ["var"], numeric_only=numeric_only
)
def std(self, numeric_only: Optional[bool] = None):
def std(self, numeric_only: Optional[bool] = None) -> pd.Series:
return self._operations._metric_agg_series(
self, ["std"], numeric_only=numeric_only
)
def mad(self, numeric_only: Optional[bool] = None):
def mad(self, numeric_only: Optional[bool] = None) -> pd.Series:
return self._operations._metric_agg_series(
self, ["mad"], numeric_only=numeric_only
)
def median(self, numeric_only: Optional[bool] = None):
def median(self, numeric_only: Optional[bool] = None) -> pd.Series:
return self._operations._metric_agg_series(
self, ["median"], numeric_only=numeric_only
)
def sum(self, numeric_only: Optional[bool] = None):
def sum(self, numeric_only: Optional[bool] = None) -> pd.Series:
return self._operations._metric_agg_series(
self, ["sum"], numeric_only=numeric_only
)
def min(self, numeric_only: Optional[bool] = None):
def min(self, numeric_only: Optional[bool] = None) -> pd.Series:
return self._operations._metric_agg_series(
self, ["min"], numeric_only=numeric_only
)
def max(self, numeric_only: Optional[bool] = None):
def max(self, numeric_only: Optional[bool] = None) -> pd.Series:
return self._operations._metric_agg_series(
self, ["max"], numeric_only=numeric_only
)
def nunique(self):
def nunique(self) -> pd.Series:
return self._operations._metric_agg_series(
self, ["nunique"], numeric_only=False
)
@ -673,7 +700,7 @@ class QueryCompiler:
dropna: bool = True,
is_dataframe_agg: bool = False,
numeric_only: Optional[bool] = True,
quantiles: Union[int, float, List[int], List[float], None] = None,
quantiles: Optional[Union[int, float, List[int], List[float]]] = None,
) -> pd.DataFrame:
return self._operations.aggs_groupby(
self,
@ -691,27 +718,27 @@ class QueryCompiler:
def value_counts(self, es_size: int) -> pd.Series:
return self._operations.value_counts(self, es_size)
def es_info(self, buf):
def es_info(self, buf: TextIO) -> None:
buf.write(f"es_index_pattern: {self._index_pattern}\n")
self._index.es_info(buf)
self._mappings.es_info(buf)
self._operations.es_info(self, buf)
def describe(self):
def describe(self) -> pd.DataFrame:
return self._operations.describe(self)
def _hist(self, num_bins):
def _hist(self, num_bins: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
return self._operations.hist(self, num_bins)
def _update_query(self, boolean_filter):
def _update_query(self, boolean_filter: "BooleanFilter") -> "QueryCompiler":
result = self.copy()
result._operations.update_query(boolean_filter)
return result
def check_arithmetics(self, right):
def check_arithmetics(self, right: "QueryCompiler") -> None:
"""
Compare 2 query_compilers to see if arithmetic operations can be performed by the NDFrame object.
@ -750,7 +777,9 @@ class QueryCompiler:
f"{self._index_pattern} != {right._index_pattern}"
)
def arithmetic_op_fields(self, display_name, arithmetic_object):
def arithmetic_op_fields(
self, display_name: str, arithmetic_object: "ArithmeticSeries"
) -> "QueryCompiler":
result = self.copy()
# create a new field name for this display name
@ -758,7 +787,7 @@ class QueryCompiler:
# add scripted field
result._mappings.add_scripted_field(
scripted_field_name, display_name, arithmetic_object.dtype.name
scripted_field_name, display_name, arithmetic_object.dtype.name # type: ignore
)
result._operations.arithmetic_op_fields(scripted_field_name, arithmetic_object)
@ -768,7 +797,7 @@ class QueryCompiler:
def get_arithmetic_op_fields(self) -> Optional["ArithmeticOpFieldsTask"]:
return self._operations.get_arithmetic_op_fields()
def display_name_to_aggregatable_name(self, display_name: str) -> str:
def display_name_to_aggregatable_name(self, display_name: str) -> Optional[str]:
aggregatable_field_name = self._mappings.aggregatable_field_name(display_name)
return aggregatable_field_name
@ -780,13 +809,13 @@ class FieldMappingCache:
DataFrame access is slower than dict access.
"""
def __init__(self, mappings):
def __init__(self, mappings: "FieldMappings") -> None:
self._mappings = mappings
self._field_name_pd_dtype = dict()
self._date_field_format = dict()
self._field_name_pd_dtype: Dict[str, str] = dict()
self._date_field_format: Dict[str, str] = dict()
def field_name_pd_dtype(self, es_field_name):
def field_name_pd_dtype(self, es_field_name: str) -> str:
if es_field_name in self._field_name_pd_dtype:
return self._field_name_pd_dtype[es_field_name]
@ -797,7 +826,7 @@ class FieldMappingCache:
return pd_dtype
def date_field_format(self, es_field_name):
def date_field_format(self, es_field_name: str) -> str:
if es_field_name in self._date_field_format:
return self._date_field_format[es_field_name]

View File

@ -38,8 +38,8 @@ from io import StringIO
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
import numpy as np
import pandas as pd
from pandas.io.common import _expand_user, stringify_path
import pandas as pd # type: ignore
from pandas.io.common import _expand_user, stringify_path # type: ignore
import eland.plotting
from eland.arithmetics import ArithmeticNumber, ArithmeticSeries, ArithmeticString
@ -61,10 +61,10 @@ from eland.filter import (
from eland.ndframe import NDFrame
from eland.utils import to_list
if TYPE_CHECKING: # type: ignore
from elasticsearch import Elasticsearch # noqa: F401
if TYPE_CHECKING:
from elasticsearch import Elasticsearch
from eland.query_compiler import QueryCompiler # noqa: F401
from eland.query_compiler import QueryCompiler
def _get_method_name() -> str:
@ -175,7 +175,7 @@ class Series(NDFrame):
return num_rows, num_columns
@property
def es_field_name(self) -> str:
def es_field_name(self) -> pd.Index:
"""
Returns
-------
@ -185,7 +185,7 @@ class Series(NDFrame):
return self._query_compiler.get_field_names(include_scripted_fields=True)[0]
@property
def name(self) -> str:
def name(self) -> pd.Index:
return self._query_compiler.columns[0]
@name.setter
@ -793,7 +793,7 @@ class Series(NDFrame):
return buf.getvalue()
def __add__(self, right):
def __add__(self, right: "Series") -> "Series":
"""
Return addition of series and right, element-wise (binary operator add).
@ -853,7 +853,7 @@ class Series(NDFrame):
"""
return self._numeric_op(right, _get_method_name())
def __truediv__(self, right):
def __truediv__(self, right: "Series") -> "Series":
"""
Return floating division of series and right, element-wise (binary operator truediv).
@ -892,7 +892,7 @@ class Series(NDFrame):
"""
return self._numeric_op(right, _get_method_name())
def __floordiv__(self, right):
def __floordiv__(self, right: "Series") -> "Series":
"""
Return integer division of series and right, element-wise (binary operator floordiv //).
@ -931,7 +931,7 @@ class Series(NDFrame):
"""
return self._numeric_op(right, _get_method_name())
def __mod__(self, right):
def __mod__(self, right: "Series") -> "Series":
"""
Return modulo of series and right, element-wise (binary operator mod %).
@ -970,7 +970,7 @@ class Series(NDFrame):
"""
return self._numeric_op(right, _get_method_name())
def __mul__(self, right):
def __mul__(self, right: "Series") -> "Series":
"""
Return multiplication of series and right, element-wise (binary operator mul).
@ -1009,7 +1009,7 @@ class Series(NDFrame):
"""
return self._numeric_op(right, _get_method_name())
def __sub__(self, right):
def __sub__(self, right: "Series") -> "Series":
"""
Return subtraction of series and right, element-wise (binary operator sub).
@ -1048,7 +1048,7 @@ class Series(NDFrame):
"""
return self._numeric_op(right, _get_method_name())
def __pow__(self, right):
def __pow__(self, right: "Series") -> "Series":
"""
Return exponential power of series and right, element-wise (binary operator pow).
@ -1087,7 +1087,7 @@ class Series(NDFrame):
"""
return self._numeric_op(right, _get_method_name())
def __radd__(self, left):
def __radd__(self, left: "Series") -> "Series":
"""
Return addition of series and left, element-wise (binary operator add).
@ -1119,7 +1119,7 @@ class Series(NDFrame):
"""
return self._numeric_op(left, _get_method_name())
def __rtruediv__(self, left):
def __rtruediv__(self, left: "Series") -> "Series":
"""
Return division of series and left, element-wise (binary operator div).
@ -1151,7 +1151,7 @@ class Series(NDFrame):
"""
return self._numeric_op(left, _get_method_name())
def __rfloordiv__(self, left):
def __rfloordiv__(self, left: "Series") -> "Series":
"""
Return integer division of series and left, element-wise (binary operator floordiv //).
@ -1183,7 +1183,7 @@ class Series(NDFrame):
"""
return self._numeric_op(left, _get_method_name())
def __rmod__(self, left):
def __rmod__(self, left: "Series") -> "Series":
"""
Return modulo of series and left, element-wise (binary operator mod %).
@ -1215,7 +1215,7 @@ class Series(NDFrame):
"""
return self._numeric_op(left, _get_method_name())
def __rmul__(self, left):
def __rmul__(self, left: "Series") -> "Series":
"""
Return multiplication of series and left, element-wise (binary operator mul).
@ -1247,7 +1247,7 @@ class Series(NDFrame):
"""
return self._numeric_op(left, _get_method_name())
def __rpow__(self, left):
def __rpow__(self, left: "Series") -> "Series":
"""
Return exponential power of series and left, element-wise (binary operator pow).
@ -1279,7 +1279,7 @@ class Series(NDFrame):
"""
return self._numeric_op(left, _get_method_name())
def __rsub__(self, left):
def __rsub__(self, left: "Series") -> "Series":
"""
Return subtraction of series and left, element-wise (binary operator sub).
@ -1398,7 +1398,7 @@ class Series(NDFrame):
return series
def max(self, numeric_only=None):
def max(self, numeric_only: Optional[bool] = None) -> pd.Series:
"""
Return the maximum of the Series values
@ -1422,7 +1422,7 @@ class Series(NDFrame):
results = super().max(numeric_only=numeric_only)
return results.squeeze()
def mean(self, numeric_only=None):
def mean(self, numeric_only: Optional[bool] = None) -> pd.Series:
"""
Return the mean of the Series values
@ -1446,7 +1446,7 @@ class Series(NDFrame):
results = super().mean(numeric_only=numeric_only)
return results.squeeze()
def median(self, numeric_only=None):
def median(self, numeric_only: Optional[bool] = None) -> pd.Series:
"""
Return the median of the Series values
@ -1470,7 +1470,7 @@ class Series(NDFrame):
results = super().median(numeric_only=numeric_only)
return results.squeeze()
def min(self, numeric_only=None):
def min(self, numeric_only: Optional[bool] = None) -> pd.Series:
"""
Return the minimum of the Series values
@ -1494,7 +1494,7 @@ class Series(NDFrame):
results = super().min(numeric_only=numeric_only)
return results.squeeze()
def sum(self, numeric_only=None):
def sum(self, numeric_only: Optional[bool] = None) -> pd.Series:
"""
Return the sum of the Series values
@ -1518,7 +1518,7 @@ class Series(NDFrame):
results = super().sum(numeric_only=numeric_only)
return results.squeeze()
def nunique(self):
def nunique(self) -> pd.Series:
"""
Return the number of unique values in a Series
@ -1540,7 +1540,7 @@ class Series(NDFrame):
results = super().nunique()
return results.squeeze()
def var(self, numeric_only=None):
def var(self, numeric_only: Optional[bool] = None) -> pd.Series:
"""
Return variance for a Series
@ -1562,7 +1562,7 @@ class Series(NDFrame):
results = super().var(numeric_only=numeric_only)
return results.squeeze()
def std(self, numeric_only=None):
def std(self, numeric_only: Optional[bool] = None) -> pd.Series:
"""
Return standard deviation for a Series
@ -1584,7 +1584,7 @@ class Series(NDFrame):
results = super().std(numeric_only=numeric_only)
return results.squeeze()
def mad(self, numeric_only=None):
def mad(self, numeric_only: Optional[bool] = None) -> pd.Series:
"""
Return median absolute deviation for a Series
@ -1643,7 +1643,7 @@ class Series(NDFrame):
# def values TODO - not implemented as causes current implementation of query to fail
def to_numpy(self):
def to_numpy(self) -> None:
"""
Not implemented.

View File

@ -16,7 +16,7 @@
# under the License.
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Tuple
from typing import TYPE_CHECKING, List, Tuple
from eland import SortOrder
from eland.actions import HeadAction, SortIndexAction, TailAction
@ -253,7 +253,7 @@ class QueryIdsTask(Task):
class QueryTermsTask(Task):
def __init__(self, must: bool, field: str, terms: List[Any]):
def __init__(self, must: bool, field: str, terms: List[str]):
"""
Parameters
----------

View File

@ -38,7 +38,10 @@ TYPED_FILES = (
"eland/tasks.py",
"eland/utils.py",
"eland/groupby.py",
"eland/operations.py",
"eland/ndframe.py",
"eland/ml/__init__.py",
"eland/ml/_optional.py",
"eland/ml/_model_serializer.py",
"eland/ml/ml_model.py",
"eland/ml/transformers/__init__.py",
@ -46,6 +49,7 @@ TYPED_FILES = (
"eland/ml/transformers/lightgbm.py",
"eland/ml/transformers/sklearn.py",
"eland/ml/transformers/xgboost.py",
"eland/plotting/_matplotlib/__init__.py",
)
@ -60,7 +64,9 @@ def format(session):
@nox.session(reuse_venv=True)
def lint(session):
session.install("black", "flake8", "mypy", "isort")
# Install numpy to use its mypy plugin
# https://numpy.org/devdocs/reference/typing.html#mypy-plugin
session.install("black", "flake8", "mypy", "isort", "numpy")
session.install("--pre", "elasticsearch")
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
session.run("black", "--check", "--target-version=py37", *SOURCE_FILES)

View File

@ -10,3 +10,4 @@ xgboost>=1
nox
lightgbm
pytest-cov
mypy

View File

@ -1,2 +1,4 @@
[isort]
profile = black
[mypy]
plugins = numpy.typing.mypy_plugin