Mirror of https://github.com/elastic/eland.git
Add type hints to 'eland.operations' and 'eland.ndframe'
parent c0e861dc77
commit 823f01cc6c
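
The pattern this commit applies throughout: typing-only imports move behind a TYPE_CHECKING guard (so they exist for static analysis but never at runtime) and annotations reference them as strings. A minimal runnable sketch of that pattern — ExampleSeries is illustrative, not eland code:

    from typing import TYPE_CHECKING, Optional

    import numpy as np

    if TYPE_CHECKING:
        # Evaluated by type checkers such as mypy only, never at runtime
        from numpy.typing import DTypeLike


    class ExampleSeries:
        def __init__(self, dtype: Optional["DTypeLike"] = None) -> None:
            # np.dtype() accepts any DTypeLike (str, Python type, np.dtype, ...)
            self._dtype = np.dtype(dtype) if dtype is not None else np.dtype(float)

        @property
        def dtype(self) -> "DTypeLike":
            return self._dtype
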
@@ -19,9 +19,11 @@ from abc import ABC, abstractmethod
 from io import StringIO
 from typing import TYPE_CHECKING, Any, List, Union

-import numpy as np  # type: ignore
+import numpy as np

 if TYPE_CHECKING:
+    from numpy.typing import DTypeLike
+
     from .query_compiler import QueryCompiler


@@ -32,7 +34,7 @@ class ArithmeticObject(ABC):
         pass

     @abstractmethod
-    def dtype(self) -> np.dtype:
+    def dtype(self) -> "DTypeLike":
         pass

     @abstractmethod
@@ -52,7 +54,7 @@ class ArithmeticString(ArithmeticObject):
         return self.value

     @property
-    def dtype(self) -> np.dtype:
+    def dtype(self) -> "DTypeLike":
         return np.dtype(object)

     @property
@@ -64,7 +66,7 @@ class ArithmeticString(ArithmeticObject):


 class ArithmeticNumber(ArithmeticObject):
-    def __init__(self, value: Union[int, float], dtype: np.dtype):
+    def __init__(self, value: Union[int, float], dtype: "DTypeLike"):
         self._value = value
         self._dtype = dtype

@@ -76,7 +78,7 @@ class ArithmeticNumber(ArithmeticObject):
         return f"{self._value}"

     @property
-    def dtype(self) -> np.dtype:
+    def dtype(self) -> "DTypeLike":
         return self._dtype

     def __repr__(self) -> str:
@@ -89,8 +91,8 @@ class ArithmeticSeries(ArithmeticObject):
     """

     def __init__(
-        self, query_compiler: "QueryCompiler", display_name: str, dtype: np.dtype
-    ):
+        self, query_compiler: "QueryCompiler", display_name: str, dtype: "DTypeLike"
+    ) -> None:
         # type defs
         self._value: str
         self._tasks: List["ArithmeticTask"]
@@ -121,7 +123,7 @@ class ArithmeticSeries(ArithmeticObject):
         return self._value

     @property
-    def dtype(self) -> np.dtype:
+    def dtype(self) -> "DTypeLike":
         return self._dtype

     def __repr__(self) -> str:

@@ -18,12 +18,24 @@
 import re
 import warnings
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)

 import numpy as np  # type: ignore
 import pandas as pd  # type: ignore
 from elasticsearch import Elasticsearch

+if TYPE_CHECKING:
+    from numpy.typing import DTypeLike
+
 # Default number of rows displayed (different to pandas where ALL could be displayed)
 DEFAULT_NUM_ROWS_DISPLAYED = 60
 DEFAULT_CHUNK_SIZE = 10000
@@ -42,7 +54,7 @@ with warnings.catch_warnings():


 def build_pd_series(
-    data: Dict[str, Any], dtype: Optional[np.dtype] = None, **kwargs: Any
+    data: Dict[str, Any], dtype: Optional["DTypeLike"] = None, **kwargs: Any
 ) -> pd.Series:
     """Builds a pd.Series while squelching the warning
     for unspecified dtype on empty series
@@ -88,7 +100,7 @@ class SortOrder(Enum):


 def elasticsearch_date_to_pandas_date(
-    value: Union[int, str], date_format: Optional[str]
+    value: Union[int, str, float], date_format: Optional[str]
 ) -> pd.Timestamp:
     """
     Given a specific Elasticsearch format for a date datatype, returns the
@@ -98,7 +110,7 @@ def elasticsearch_date_to_pandas_date(

     Parameters
     ----------
-    value: Union[int, str]
+    value: Union[int, str, float]
         The date value.
     date_format: str
         The Elasticsearch date format (ex. 'epoch_millis', 'epoch_second', etc.)

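Why the value parameter widens to Union[int, str, float]: epoch-based dates can come back from Elasticsearch aggregations as floats, not just ints. A hedged usage sketch — the parsing logic inside the helper is not shown in this hunk:

    from eland.common import elasticsearch_date_to_pandas_date

    # Both epoch forms should resolve to the same pd.Timestamp
    ts_from_int = elasticsearch_date_to_pandas_date(1609459200000, "epoch_millis")
    ts_from_float = elasticsearch_date_to_pandas_date(1609459200000.0, "epoch_millis")
    assert ts_from_int == ts_from_float
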
@@ -15,9 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.

+from typing import Any, Dict
+
 import numpy as np
-import pandas as pd
-import pytest
+import pandas as pd  # type: ignore
+import pytest  # type: ignore

 import eland as ed

@@ -28,7 +30,7 @@ pd.set_option("display.width", 100)


 @pytest.fixture(autouse=True)
-def add_imports(doctest_namespace):
+def add_imports(doctest_namespace: Dict[str, Any]) -> None:
     doctest_namespace["np"] = np
     doctest_namespace["pd"] = pd
     doctest_namespace["ed"] = ed

@@ -19,19 +19,19 @@ import re
 import sys
 import warnings
 from io import StringIO
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union

 import numpy as np
-import pandas as pd
-from pandas.core.common import apply_if_callable, is_bool_indexer
-from pandas.core.computation.eval import eval
-from pandas.core.dtypes.common import is_list_like
-from pandas.core.indexing import check_bool_indexer
-from pandas.io.common import _expand_user, stringify_path
-from pandas.io.formats import console
+import pandas as pd  # type: ignore
+from pandas.core.common import apply_if_callable, is_bool_indexer  # type: ignore
+from pandas.core.computation.eval import eval  # type: ignore
+from pandas.core.dtypes.common import is_list_like  # type: ignore
+from pandas.core.indexing import check_bool_indexer  # type: ignore
+from pandas.io.common import _expand_user, stringify_path  # type: ignore
+from pandas.io.formats import console  # type: ignore
 from pandas.io.formats import format as fmt
-from pandas.io.formats.printing import pprint_thing
-from pandas.util._validators import validate_bool_kwarg
+from pandas.io.formats.printing import pprint_thing  # type: ignore
+from pandas.util._validators import validate_bool_kwarg  # type: ignore

 import eland.plotting as gfx
 from eland.common import DEFAULT_NUM_ROWS_DISPLAYED, docstring_parameter
@@ -41,6 +41,11 @@ from eland.ndframe import NDFrame
 from eland.series import Series
 from eland.utils import is_valid_attr_name

+if TYPE_CHECKING:
+    from elasticsearch import Elasticsearch
+
+    from .query_compiler import QueryCompiler
+

 class DataFrame(NDFrame):
     """
@@ -119,11 +124,13 @@ class DataFrame(NDFrame):

     def __init__(
         self,
-        es_client=None,
-        es_index_pattern=None,
-        es_index_field=None,
-        columns=None,
-        _query_compiler=None,
-    ):
+        es_client: Optional[
+            Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
+        ] = None,
+        es_index_pattern: Optional[str] = None,
+        columns: Optional[List[str]] = None,
+        es_index_field: Optional[str] = None,
+        _query_compiler: Optional["QueryCompiler"] = None,
+    ) -> None:
         """
         There are effectively 2 constructors:
@@ -147,7 +154,7 @@ class DataFrame(NDFrame):
             _query_compiler=_query_compiler,
         )

-    def _get_columns(self):
+    def _get_columns(self) -> pd.Index:
         """
         The column labels of the DataFrame.

@@ -178,7 +185,7 @@ class DataFrame(NDFrame):
     columns = property(_get_columns)

     @property
-    def empty(self):
+    def empty(self) -> bool:
         """Determines if the DataFrame is empty.

         Returns
@@ -278,7 +285,10 @@ class DataFrame(NDFrame):
         return DataFrame(_query_compiler=self._query_compiler.tail(n))

     def sample(
-        self, n: int = None, frac: float = None, random_state: int = None
+        self,
+        n: Optional[int] = None,
+        frac: Optional[float] = None,
+        random_state: Optional[int] = None,
     ) -> "DataFrame":
         """
         Return n randomly sample rows or the specify fraction of rows
@@ -469,7 +479,7 @@ class DataFrame(NDFrame):
             if is_valid_attr_name(column_name)
         ]

-    def __repr__(self):
+    def __repr__(self) -> None:
         """
         From pandas
         """
@@ -501,7 +511,7 @@ class DataFrame(NDFrame):

         return buf.getvalue()

-    def _info_repr(self):
+    def _info_repr(self) -> bool:
         """
         True if the repr should show the info view.
         """
@@ -510,7 +520,7 @@ class DataFrame(NDFrame):
             self._repr_fits_horizontal_() and self._repr_fits_vertical_()
         )

-    def _repr_html_(self):
+    def _repr_html_(self) -> Optional[str]:
         """
         From pandas - this is called by notebooks
         """
@@ -540,7 +550,7 @@ class DataFrame(NDFrame):
         else:
             return None

-    def count(self):
+    def count(self) -> pd.Series:
         """
         Count non-NA cells for each column.

@@ -855,10 +865,10 @@ class DataFrame(NDFrame):
         exceeds_info_cols = len(self.columns) > max_cols

         # From pandas.DataFrame
-        def _put_str(s, space):
+        def _put_str(s, space) -> str:
             return f"{s}"[:space].ljust(space)

-        def _verbose_repr():
+        def _verbose_repr() -> None:
             lines.append(f"Data columns (total {len(self.columns)} columns):")

             id_head = " # "
@@ -930,10 +940,10 @@ class DataFrame(NDFrame):
                 + _put_str(dtype, space_dtype)
             )

-        def _non_verbose_repr():
+        def _non_verbose_repr() -> None:
             lines.append(self.columns._summary(name="Columns"))

-        def _sizeof_fmt(num, size_qualifier):
+        def _sizeof_fmt(num: float, size_qualifier: str) -> str:
             # returns size in human readable format
             for x in ["bytes", "KB", "MB", "GB", "TB"]:
                 if num < 1024.0:
@@ -1004,7 +1014,7 @@ class DataFrame(NDFrame):
         border=None,
         table_id=None,
         render_links=False,
-    ):
+    ) -> Any:
         """
         Render a Elasticsearch data as an HTML table.

@@ -1171,7 +1181,7 @@ class DataFrame(NDFrame):
         result = _buf.getvalue()
         return result

-    def __getattr__(self, key):
+    def __getattr__(self, key: str) -> Any:
         """After regular attribute access, looks up the name in the columns

         Parameters
@@ -1190,7 +1200,12 @@ class DataFrame(NDFrame):
                 return self[key]
             raise e

-    def _getitem(self, key):
+    def _getitem(
+        self,
+        key: Union[
+            "DataFrame", "Series", pd.Index, List[str], str, BooleanFilter, np.ndarray
+        ],
+    ) -> Union["Series", "DataFrame"]:
         """Get the column specified by key for this DataFrame.

         Args:
@@ -1215,13 +1230,13 @@ class DataFrame(NDFrame):
         else:
             return self._getitem_column(key)

-    def _getitem_column(self, key):
+    def _getitem_column(self, key: str) -> "Series":
         if key not in self.columns:
             raise KeyError(f"Requested column [{key}] is not in the DataFrame.")
         s = self._reduce_dimension(self._query_compiler.getitem_column_array([key]))
         return s

-    def _getitem_array(self, key):
+    def _getitem_array(self, key: Union[str, pd.Series]) -> "DataFrame":
         if isinstance(key, Series):
             key = key.to_pandas()
         if is_bool_indexer(key):
@@ -1256,7 +1271,9 @@ class DataFrame(NDFrame):
             _query_compiler=self._query_compiler.getitem_column_array(key)
         )

-    def _create_or_update_from_compiler(self, new_query_compiler, inplace=False):
+    def _create_or_update_from_compiler(
+        self, new_query_compiler: "QueryCompiler", inplace: bool = False
+    ) -> Union["QueryCompiler", "DataFrame"]:
         """Returns or updates a DataFrame given new query_compiler"""
         assert (
             isinstance(new_query_compiler, type(self._query_compiler))
@@ -1265,10 +1282,10 @@ class DataFrame(NDFrame):
         if not inplace:
             return DataFrame(_query_compiler=new_query_compiler)
         else:
-            self._query_compiler = new_query_compiler
+            self._query_compiler: "QueryCompiler" = new_query_compiler

     @staticmethod
-    def _reduce_dimension(query_compiler):
+    def _reduce_dimension(query_compiler: "QueryCompiler") -> "Series":
         return Series(_query_compiler=query_compiler)

     def to_csv(
@@ -1849,7 +1866,9 @@ class DataFrame(NDFrame):
         else:
             raise NotImplementedError(expr, type(expr))

-    def get(self, key, default=None):
+    def get(
+        self, key: Any, default: Optional[Any] = None
+    ) -> Union["Series", "DataFrame"]:
         """
         Get item from object for given key (ex: DataFrame column).
         Returns default value if not found.
@@ -1956,7 +1975,7 @@ class DataFrame(NDFrame):

         elif like is not None:

-            def matcher(x):
+            def matcher(x: str) -> bool:
                 return like in x

         else:
@@ -1965,7 +1984,7 @@ class DataFrame(NDFrame):
         return self[[column for column in self.columns if matcher(column)]]

     @property
-    def values(self):
+    def values(self) -> None:
         """
         Not implemented.

@@ -1983,7 +2002,7 @@ class DataFrame(NDFrame):
         """
         return self.to_numpy()

-    def to_numpy(self):
+    def to_numpy(self) -> None:
         """
         Not implemented.

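The newly typed _sizeof_fmt helper, filled out as a self-contained sketch; the loop body past the first comparison follows the usual pandas idiom and is reconstructed here, not quoted from the diff:

    def _sizeof_fmt(num: float, size_qualifier: str) -> str:
        # returns size in human readable format
        for x in ["bytes", "KB", "MB", "GB", "TB"]:
            if num < 1024.0:
                return f"{num:3.1f}{size_qualifier} {x}"
            num /= 1024.0
        return f"{num:3.1f}{size_qualifier} PB"


    print(_sizeof_fmt(123456789, "+"))  # "117.7+ MB"
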
@@ -25,13 +25,14 @@ from typing import (
     NamedTuple,
     Optional,
     Set,
+    TextIO,
     Tuple,
     Union,
 )

 import numpy as np
-import pandas as pd
-from pandas.core.dtypes.common import (
+import pandas as pd  # type: ignore
+from pandas.core.dtypes.common import (  # type: ignore
     is_bool_dtype,
     is_datetime_or_timedelta_dtype,
     is_float_dtype,
@@ -42,6 +43,7 @@ from pandas.core.dtypes.inference import is_list_like

 if TYPE_CHECKING:
     from elasticsearch import Elasticsearch
+    from numpy.typing import DTypeLike


 ES_FLOAT_TYPES: Set[str] = {"double", "float", "half_float", "scaled_float"}
@@ -559,7 +561,7 @@ class FieldMappings:

         return {"mappings": {"properties": mapping_props}}

-    def aggregatable_field_name(self, display_name):
+    def aggregatable_field_name(self, display_name: str) -> Optional[str]:
         """
         Return a single aggregatable field_name from display_name

@@ -598,7 +600,7 @@ class FieldMappings:

         return self._mappings_capabilities.loc[display_name].aggregatable_es_field_name

-    def aggregatable_field_names(self):
+    def aggregatable_field_names(self) -> Dict[str, str]:
         """
         Return a list of aggregatable Elasticsearch field_names for all display names.
         If field is not aggregatable_field_names, return nothing.
@@ -634,7 +636,7 @@ class FieldMappings:
             )["data"]
         )

-    def date_field_format(self, es_field_name):
+    def date_field_format(self, es_field_name: str) -> str:
         """
         Parameters
         ----------
@@ -650,7 +652,7 @@ class FieldMappings:
             self._mappings_capabilities.es_field_name == es_field_name
         ].es_date_format.squeeze()

-    def field_name_pd_dtype(self, es_field_name):
+    def field_name_pd_dtype(self, es_field_name: str) -> str:
         """
         Parameters
         ----------
@@ -674,7 +676,9 @@ class FieldMappings:
         ].pd_dtype.squeeze()
         return pd_dtype

-    def add_scripted_field(self, scripted_field_name, display_name, pd_dtype):
+    def add_scripted_field(
+        self, scripted_field_name: str, display_name: str, pd_dtype: str
+    ) -> None:
         # if this display name is used somewhere else, drop it
         if display_name in self._mappings_capabilities.index:
             self._mappings_capabilities = self._mappings_capabilities.drop(
@@ -706,8 +710,8 @@ class FieldMappings:
             capability_matrix_row
         )

-    def numeric_source_fields(self):
-        pd_dtypes, es_field_names, es_date_formats = self.metric_source_fields()
+    def numeric_source_fields(self) -> List[str]:
+        _, es_field_names, _ = self.metric_source_fields()
         return es_field_names

     def all_source_fields(self) -> List[Field]:
@@ -753,7 +757,9 @@ class FieldMappings:
         # Maintain groupby order as given input
         return [groupby_fields[column] for column in by], aggregatable_fields

-    def metric_source_fields(self, include_bool=False, include_timestamp=False):
+    def metric_source_fields(
+        self, include_bool: bool = False, include_timestamp: bool = False
+    ) -> Tuple[List["DTypeLike"], List[str], Optional[List[str]]]:
         """
         Returns
         -------
@@ -790,7 +796,7 @@ class FieldMappings:
         # return in display_name order
         return pd_dtypes, es_field_names, es_date_formats

-    def get_field_names(self, include_scripted_fields=True):
+    def get_field_names(self, include_scripted_fields: bool = True) -> List[str]:
         if include_scripted_fields:
             return self._mappings_capabilities.es_field_name.to_list()

@@ -801,7 +807,7 @@ class FieldMappings:
     def _get_display_names(self):
         return self._mappings_capabilities.index.to_list()

-    def _set_display_names(self, display_names):
+    def _set_display_names(self, display_names: List[str]):
         if not is_list_like(display_names):
             raise ValueError(f"'{display_names}' is not list like")

@@ -842,7 +848,7 @@ class FieldMappings:
         es_dtypes.name = None
         return es_dtypes

-    def es_info(self, buf):
+    def es_info(self, buf: TextIO) -> None:
         buf.write("Mappings:\n")
         buf.write(f" capabilities:\n{self._mappings_capabilities.to_string()}\n")

@@ -17,8 +17,11 @@

 import distutils.version
 import importlib
-import types
 import warnings
+from typing import TYPE_CHECKING, Any, Optional
+
+if TYPE_CHECKING:
+    from types import ModuleType

 # ----------------------------------------------------------------------------
 # functions largely based / taken from the six module
@@ -42,7 +45,7 @@ version_message = (
 )


-def _get_version(module: types.ModuleType) -> str:
+def _get_version(module: "ModuleType") -> Any:
     version = getattr(module, "__version__", None)
     if version is None:
         # xlrd uses a capitalized attribute name
@@ -55,7 +58,7 @@ def _get_version(module: types.ModuleType) -> str:

 def import_optional_dependency(
     name: str, extra: str = "", raise_on_missing: bool = True, on_version: str = "raise"
-):
+) -> Optional["ModuleType"]:
     """
     Import an optional dependency.

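Because import_optional_dependency now returns Optional["ModuleType"], callers that pass raise_on_missing=False must handle None. A hedged usage sketch — the module path is inferred from the relative imports elsewhere in this diff, not stated in this hunk:

    from eland.ml._optional import import_optional_dependency

    sklearn = import_optional_dependency("sklearn", raise_on_missing=False)
    if sklearn is None:
        print("scikit-learn is not installed; skipping sklearn model support")
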
@@ -18,7 +18,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

 import elasticsearch
-import numpy as np  # type: ignore
+import numpy as np

 from eland.common import ensure_es_client, es_version
 from eland.utils import deprecated_api
@@ -27,7 +28,7 @@ from .common import TYPE_CLASSIFICATION, TYPE_REGRESSION
 from .transformers import get_model_transformer

 if TYPE_CHECKING:
-    from elasticsearch import Elasticsearch  # noqa: F401
+    from elasticsearch import Elasticsearch
+    from numpy.typing import ArrayLike, DTypeLike

 # Try importing each ML lib separately so mypy users don't have to
 # have both installed to use type-checking.
@@ -83,8 +84,8 @@ class MLModel:
         self._trained_model_config_cache: Optional[Dict[str, Any]] = None

     def predict(
-        self, X: Union[np.ndarray, List[float], List[List[float]]]
-    ) -> np.ndarray:
+        self, X: Union["ArrayLike", List[float], List[List[float]]]
+    ) -> "ArrayLike":
         """
         Make a prediction using a trained model stored in Elasticsearch.

@@ -196,7 +197,7 @@ class MLModel:

         # Return results as np.ndarray of float32 or int (consistent with sklearn/xgboost)
         if self.model_type == TYPE_CLASSIFICATION:
-            dt = np.int_
+            dt: "DTypeLike" = np.int_
         else:
             dt = np.float32
         return np.asarray(y, dtype=dt)

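The dt: "DTypeLike" annotation lands on the first assignment because mypy fixes a variable's type at its first binding; without the wider type, the np.float32 branch would not unify with np.int_. A standalone sketch of the same shape — is_classification is an illustrative flag, not eland code:

    from typing import TYPE_CHECKING

    import numpy as np

    if TYPE_CHECKING:
        from numpy.typing import DTypeLike

    is_classification = True  # illustrative flag

    if is_classification:
        dt: "DTypeLike" = np.int_
    else:
        dt = np.float32

    print(np.asarray([1, 2], dtype=dt))
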
@@ -17,7 +17,7 @@

 from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

-import numpy as np  # type: ignore
+import numpy as np

 from .._model_serializer import Ensemble, Tree, TreeNode
 from .._optional import import_optional_dependency
@@ -64,7 +64,7 @@ class SKLearnTransformer(ModelTransformer):
         self,
         node_index: int,
         node_data: Tuple[Union[int, float], ...],
-        value: np.ndarray,
+        value: np.ndarray,  # type: ignore
     ) -> TreeNode:
         """
         This builds out a TreeNode class given the sklearn tree node definition.

@@ -229,7 +229,7 @@ class XGBoostClassifierTransformer(XGBoostForestTransformer):
         if model.classes_ is None:
             n_estimators = model.get_params()["n_estimators"]
             num_trees = model.get_booster().trees_to_dataframe()["Tree"].max() + 1
-            self._num_classes = num_trees // n_estimators
+            self._num_classes: int = num_trees // n_estimators
         else:
             self._num_classes = len(model.classes_)

@@ -17,7 +17,7 @@

 import sys
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, TextIO, Tuple, Union

 import pandas as pd  # type: ignore

@@ -186,7 +186,7 @@ class NDFrame(ABC):
         """
         return len(self.index)

-    def _es_info(self, buf):
+    def _es_info(self, buf: TextIO) -> None:
         self._query_compiler.es_info(buf)

     def mean(self, numeric_only: Optional[bool] = None) -> pd.Series:
@@ -604,7 +604,7 @@ class NDFrame(ABC):
         """
         return self._query_compiler.mad(numeric_only=numeric_only)

-    def _hist(self, num_bins):
+    def _hist(self, num_bins: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
         return self._query_compiler._hist(num_bins)

     def describe(self) -> pd.DataFrame:

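buf is typed TextIO rather than a concrete file type because the info methods only need write(); an in-memory StringIO satisfies the same interface. A minimal sketch:

    from io import StringIO

    buf = StringIO()
    buf.write("Operations:\n")
    buf.write(" tasks: []\n")
    print(buf.getvalue())
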
@@ -26,12 +26,13 @@ from typing import (
     List,
     Optional,
     Sequence,
+    TextIO,
     Tuple,
     Union,
 )

 import numpy as np
-import pandas as pd
+import pandas as pd  # type: ignore
 from elasticsearch.helpers import scan

 from eland.actions import PostProcessingAction, SortFieldAction
@@ -58,13 +59,18 @@ from eland.tasks import (
 )

 if TYPE_CHECKING:
+    from numpy.typing import DTypeLike
+
     from eland.arithmetics import ArithmeticSeries
+    from eland.field_mappings import Field
+    from eland.filter import BooleanFilter
     from eland.query_compiler import QueryCompiler
+    from eland.tasks import Task


 class QueryParams:
-    def __init__(self):
-        self.query = Query()
+    def __init__(self) -> None:
+        self.query: Query = Query()
         self.sort_field: Optional[str] = None
         self.sort_order: Optional[SortOrder] = None
         self.size: Optional[int] = None
@@ -85,37 +91,48 @@ class Operations:
     (see https://docs.dask.org/en/latest/spec.html)
     """

-    def __init__(self, tasks=None, arithmetic_op_fields_task=None):
+    def __init__(
+        self,
+        tasks: Optional[List["Task"]] = None,
+        arithmetic_op_fields_task: Optional["ArithmeticOpFieldsTask"] = None,
+    ) -> None:
+        self._tasks: List["Task"]
         if tasks is None:
             self._tasks = []
         else:
             self._tasks = tasks
         self._arithmetic_op_fields_task = arithmetic_op_fields_task

-    def __constructor__(self, *args, **kwargs):
+    def __constructor__(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> "Operations":
         return type(self)(*args, **kwargs)

-    def copy(self):
+    def copy(self) -> "Operations":
         return self.__constructor__(
             tasks=copy.deepcopy(self._tasks),
             arithmetic_op_fields_task=copy.deepcopy(self._arithmetic_op_fields_task),
         )

-    def head(self, index, n):
+    def head(self, index: "Index", n: int) -> None:
         # Add a task that is an ascending sort with size=n
         task = HeadTask(index, n)
         self._tasks.append(task)

-    def tail(self, index, n):
+    def tail(self, index: "Index", n: int) -> None:
         # Add a task that is descending sort with size=n
         task = TailTask(index, n)
         self._tasks.append(task)

-    def sample(self, index, n, random_state):
+    def sample(self, index: "Index", n: int, random_state: int) -> None:
         task = SampleTask(index, n, random_state)
         self._tasks.append(task)

-    def arithmetic_op_fields(self, display_name, arithmetic_series):
+    def arithmetic_op_fields(
+        self, display_name: str, arithmetic_series: "ArithmeticSeries"
+    ) -> None:
         if self._arithmetic_op_fields_task is None:
             self._arithmetic_op_fields_task = ArithmeticOpFieldsTask(
                 display_name, arithmetic_series
@@ -127,10 +144,10 @@ class Operations:
         # get an ArithmeticOpFieldsTask if it exists
         return self._arithmetic_op_fields_task

-    def __repr__(self):
+    def __repr__(self) -> str:
         return repr(self._tasks)

-    def count(self, query_compiler):
+    def count(self, query_compiler: "QueryCompiler") -> pd.Series:
         query_params, post_processing = self._resolve_tasks(query_compiler)

         # Elasticsearch _count is very efficient and so used to return results here. This means that
@@ -161,7 +178,7 @@ class Operations:
     def _metric_agg_series(
         self,
         query_compiler: "QueryCompiler",
-        agg: List,
+        agg: List["str"],
         numeric_only: Optional[bool] = None,
     ) -> pd.Series:
         results = self._metric_aggs(query_compiler, agg, numeric_only=numeric_only)
@@ -170,7 +187,7 @@ class Operations:
         else:
             # If all results are float convert into float64
             if all(isinstance(i, float) for i in results.values()):
-                dtype = np.float64
+                dtype: "DTypeLike" = np.float64
             # If all results are int convert into int64
             elif all(isinstance(i, int) for i in results.values()):
                 dtype = np.int64
@@ -184,7 +201,9 @@ class Operations:
     def value_counts(self, query_compiler: "QueryCompiler", es_size: int) -> pd.Series:
         return self._terms_aggs(query_compiler, "terms", es_size)

-    def hist(self, query_compiler, bins):
+    def hist(
+        self, query_compiler: "QueryCompiler", bins: int
+    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
         return self._hist_aggs(query_compiler, bins)

     def idx(
@@ -237,7 +256,12 @@ class Operations:

         return pd.Series(results)

-    def aggs(self, query_compiler, pd_aggs, numeric_only=None) -> pd.DataFrame:
+    def aggs(
+        self,
+        query_compiler: "QueryCompiler",
+        pd_aggs: List[str],
+        numeric_only: Optional[bool] = None,
+    ) -> pd.DataFrame:
         results = self._metric_aggs(
             query_compiler, pd_aggs, numeric_only=numeric_only, is_dataframe_agg=True
         )
@@ -441,13 +465,15 @@ class Operations:

         try:
             # get first value in dict (key is .keyword)
-            name = list(aggregatable_field_names.values())[0]
+            name: Optional[str] = list(aggregatable_field_names.values())[0]
         except IndexError:
             name = None

         return build_pd_series(results, name=name)

-    def _hist_aggs(self, query_compiler, num_bins):
+    def _hist_aggs(
+        self, query_compiler: "QueryCompiler", num_bins: int
+    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
         # Get histogram bins and weights for numeric field_names
         query_params, post_processing = self._resolve_tasks(query_compiler)

@@ -488,8 +514,8 @@ class Operations:
         # },
         # ...

-        bins = {}
-        weights = {}
+        bins: Dict[str, List[int]] = {}
+        weights: Dict[str, List[int]] = {}

         # There is one more bin that weights
         # len(bins) = len(weights) + 1
@@ -537,11 +563,11 @@ class Operations:
     def _unpack_metric_aggs(
         self,
         fields: List["Field"],
-        es_aggs: Union[List[str], List[Tuple[str, str]]],
+        es_aggs: Union[List[str], List[Tuple[str, List[float]]]],
         pd_aggs: List[str],
         response: Dict[str, Any],
         numeric_only: Optional[bool],
-        percentiles: Optional[List[float]] = None,
+        percentiles: Optional[Sequence[float]] = None,
         is_dataframe_agg: bool = False,
         is_groupby: bool = False,
     ) -> Dict[str, List[Any]]:
@@ -574,7 +600,7 @@ class Operations:
         """
         results: Dict[str, Any] = {}
         percentile_values: List[float] = []
-        agg_value: Union[int, float]
+        agg_value: Any

         for field in fields:
             values = []
@@ -611,7 +637,10 @@ class Operations:
                         agg_value = agg_value["50.0"]
                     else:
                         # Maintain order of percentiles
-                        percentile_values = [agg_value[str(i)] for i in percentiles]
+                        if percentiles:
+                            percentile_values = [
+                                agg_value[str(i)] for i in percentiles
+                            ]

                 if not percentile_values and pd_agg not in ("quantile", "median"):
                     agg_value = agg_value[es_agg[1]]
@@ -682,7 +711,11 @@ class Operations:

                 # Cardinality is always either NaN or integer.
                 elif pd_agg in ("nunique", "count"):
-                    agg_value = int(agg_value)
+                    agg_value = (
+                        int(agg_value)
+                        if isinstance(agg_value, (int, float))
+                        else np.NaN
+                    )

                 # If this is a non-null timestamp field convert to a pd.Timestamp()
                 elif field.is_timestamp:
@@ -702,6 +735,7 @@ class Operations:
                             for value in percentile_values
                         ]
                     else:
+                        assert not isinstance(agg_value, dict)
                         agg_value = elasticsearch_date_to_pandas_date(
                             agg_value, field.es_date_format
                         )
@@ -771,7 +805,7 @@ class Operations:
         by: List[str],
         pd_aggs: List[str],
         dropna: bool = True,
-        quantiles: Optional[List[float]] = None,
+        quantiles: Optional[Union[int, float, List[int], List[float]]] = None,
         is_dataframe_agg: bool = False,
         numeric_only: Optional[bool] = True,
     ) -> pd.DataFrame:
@@ -811,7 +845,7 @@ class Operations:
         by_fields, agg_fields = query_compiler._mappings.groupby_source_fields(by=by)

         # Used defaultdict to avoid initialization of columns with lists
-        results: Dict[str, List[Any]] = defaultdict(list)
+        results: Dict[Any, List[Any]] = defaultdict(list)

         if numeric_only:
             agg_fields = [
@@ -823,7 +857,8 @@ class Operations:
         # To return for creating multi-index on columns
         headers = [agg_field.column for agg_field in agg_fields]

-        percentiles: Optional[List[str]] = None
+        percentiles: Optional[List[float]] = None
+        len_percentiles: int = 0
         if quantiles:
             percentiles = [
                 quantile_to_percentile(x)
@@ -833,6 +868,7 @@ class Operations:
                 else quantiles
                 )
             ]
+            len_percentiles = len(percentiles)

         # Convert pandas aggs to ES equivalent
         es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs=pd_aggs, percentiles=percentiles)
@@ -894,8 +930,8 @@ class Operations:
             if by_field.is_timestamp and isinstance(bucket_key, int):
                 bucket_key = pd.to_datetime(bucket_key, unit="ms")

-            if pd_aggs == ["quantile"] and len(percentiles) > 1:
-                bucket_key = [bucket_key] * len(percentiles)
+            if pd_aggs == ["quantile"] and len_percentiles > 1:
+                bucket_key = [bucket_key] * len_percentiles

             results[by_field.column].extend(
                 bucket_key if isinstance(bucket_key, list) else [bucket_key]
@@ -915,7 +951,7 @@ class Operations:
             )

         # to construct index with quantiles
-        if pd_aggs == ["quantile"] and len(percentiles) > 1:
+        if pd_aggs == ["quantile"] and percentiles and len_percentiles > 1:
             results[None].extend([i / 100 for i in percentiles])

         # Process the calculated agg values to response
@@ -929,9 +965,10 @@ class Operations:
                 for pd_agg, val in zip(pd_aggs, value):
                     results[f"{key}_{pd_agg}"].append(val)

-        # Just to maintain Output same as pandas with empty header.
-        if pd_aggs == ["quantile"] and len(percentiles) > 1:
-            by = by + [None]
+        if pd_aggs == ["quantile"] and len_percentiles > 1:
+            # by never holds None by default, we make an exception
+            # here to maintain output same as pandas, also mypy complains
+            by = by + [None]  # type: ignore

         agg_df = pd.DataFrame(results).set_index(by).sort_index()
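
A worked example of the quantile handling above, using the quantile_to_percentile helper defined later in this diff: pandas-style quantiles in [0, 1] map to Elasticsearch percentiles in [0, 100], and len_percentiles is computed once so later checks don't re-evaluate an Optional:

    def quantile_to_percentile(quantile):
        # quantile * 100 = percentile
        # return float(...) because min(1.0) gives 1
        return float(min(100, max(0, quantile * 100)))


    percentiles = [quantile_to_percentile(q) for q in (0.25, 0.5, 0.75)]
    print(percentiles)  # [25.0, 50.0, 75.0]
    len_percentiles = len(percentiles)
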
@@ -947,7 +984,7 @@ class Operations:
     @staticmethod
     def bucket_generator(
         query_compiler: "QueryCompiler", body: "Query"
-    ) -> Generator[List[str], None, List[str]]:
+    ) -> Generator[Sequence[Dict[str, Any]], None, Sequence[Dict[str, Any]]]:
         """
         This can be used for all groupby operations.
         e.g.
@@ -977,18 +1014,24 @@ class Operations:
             )

             # Pagination Logic
-            composite_buckets = res["aggregations"]["groupby_buckets"]
-            if "after_key" in composite_buckets:
+            composite_buckets: Dict[str, Any] = res["aggregations"]["groupby_buckets"]
+
+            after_key: Optional[Dict[str, Any]] = composite_buckets.get(
+                "after_key", None
+            )
+            buckets: Sequence[Dict[str, Any]] = composite_buckets["buckets"]
+
+            if after_key:

                 # yield the bucket which contains the result
-                yield composite_buckets["buckets"]
+                yield buckets

                 body.composite_agg_after_key(
                     name="groupby_buckets",
-                    after_key=composite_buckets["after_key"],
+                    after_key=after_key,
                 )
             else:
-                return composite_buckets["buckets"]
+                return buckets

@@ -1031,7 +1074,7 @@ class Operations:
         extended_stats_es_aggs = {"avg", "min", "max", "sum"}
         extended_stats_calls = 0

-        es_aggs = []
+        es_aggs: List[Any] = []
         for pd_agg in pd_aggs:
             if pd_agg in extended_stats_pd_aggs:
                 extended_stats_calls += 1
@@ -1100,7 +1143,7 @@ class Operations:
     def filter(
         self,
         query_compiler: "QueryCompiler",
-        items: Optional[Sequence[str]] = None,
+        items: Optional[List[str]] = None,
         like: Optional[str] = None,
         regex: Optional[str] = None,
     ) -> None:
@@ -1122,7 +1165,7 @@ class Operations:
             f"to substring and regex operations not being available for Elasticsearch document IDs."
         )

-    def describe(self, query_compiler):
+    def describe(self, query_compiler: "QueryCompiler") -> pd.DataFrame:
         query_params, post_processing = self._resolve_tasks(query_compiler)

         size = self._size(query_params, post_processing)
@@ -1151,30 +1194,9 @@ class Operations:
             ["count", "mean", "std", "min", "25%", "50%", "75%", "max"]
         )

-    def to_pandas(self, query_compiler, show_progress=False):
-        class PandasDataFrameCollector:
-            def __init__(self, show_progress):
-                self._df = None
-                self._show_progress = show_progress
-
-            def collect(self, df):
-                # This collector does not batch data on output. Therefore, batch_size is fixed to None and this method
-                # is only called once.
-                if self._df is not None:
-                    raise RuntimeError(
-                        "Logic error in execution, this method must only be called once for this"
-                        "collector - batch_size == None"
-                    )
-                self._df = df
-
-            @staticmethod
-            def batch_size():
-                # Do not change (see notes on collect)
-                return None
-
-            @property
-            def show_progress(self):
-                return self._show_progress
+    def to_pandas(
+        self, query_compiler: "QueryCompiler", show_progress: bool = False
+    ) -> None:

         collector = PandasDataFrameCollector(show_progress)
@@ -1182,35 +1204,12 @@ class Operations:

         return collector._df

-    def to_csv(self, query_compiler, show_progress=False, **kwargs):
-        class PandasToCSVCollector:
-            def __init__(self, show_progress, **args):
-                self._args = args
-                self._show_progress = show_progress
-                self._ret = None
-                self._first_time = True
-
-            def collect(self, df):
-                # If this is the first time we collect results, then write header, otherwise don't write header
-                # and append results
-                if self._first_time:
-                    self._first_time = False
-                    df.to_csv(**self._args)
-                else:
-                    # Don't write header, and change mode to append
-                    self._args["header"] = False
-                    self._args["mode"] = "a"
-                    df.to_csv(**self._args)
-
-            @staticmethod
-            def batch_size():
-                # By default read n docs and then dump to csv
-                batch_size = DEFAULT_CSV_BATCH_OUTPUT_SIZE
-                return batch_size
-
-            @property
-            def show_progress(self):
-                return self._show_progress
+    def to_csv(
+        self,
+        query_compiler: "QueryCompiler",
+        show_progress: bool = False,
+        **kwargs: Union[bool, str],
+    ) -> None:

         collector = PandasToCSVCollector(show_progress, **kwargs)

@@ -1218,7 +1217,11 @@ class Operations:

         return collector._ret

-    def _es_results(self, query_compiler, collector):
+    def _es_results(
+        self,
+        query_compiler: "QueryCompiler",
+        collector: Union["PandasToCSVCollector", "PandasDataFrameCollector"],
+    ) -> None:
         query_params, post_processing = self._resolve_tasks(query_compiler)

         size, sort_params = Operations._query_params_to_size_and_sort(query_params)
@@ -1245,7 +1248,7 @@ class Operations:
         else:
             body["_source"] = False

-        es_results = None
+        es_results: Any = None

         # If size=None use scan not search - then post sort results when in df
         # If size>10000 use scan
@@ -1283,7 +1286,7 @@ class Operations:
             df = self._apply_df_post_processing(df, post_processing)
             collector.collect(df)

-    def index_count(self, query_compiler, field):
+    def index_count(self, query_compiler: "QueryCompiler", field: str) -> int:
         # field is the index field so count values
         query_params, post_processing = self._resolve_tasks(query_compiler)

@@ -1297,12 +1300,13 @@ class Operations:
         body = Query(query_params.query)
         body.exists(field, must=True)

-        return query_compiler._client.count(
+        count: int = query_compiler._client.count(
             index=query_compiler._index_pattern, body=body.to_count_body()
         )["count"]
+        return count

     def _validate_index_operation(
-        self, query_compiler: "QueryCompiler", items: Sequence[str]
+        self, query_compiler: "QueryCompiler", items: List[str]
     ) -> RESOLVED_TASK_TYPE:
         if not isinstance(items, list):
             raise TypeError(f"list item required - not {type(items)}")
@@ -1320,7 +1324,9 @@ class Operations:

         return query_params, post_processing

-    def index_matches_count(self, query_compiler, field, items):
+    def index_matches_count(
+        self, query_compiler: "QueryCompiler", field: str, items: List[Any]
+    ) -> int:
         query_params, post_processing = self._validate_index_operation(
             query_compiler, items
         )
@@ -1332,12 +1338,13 @@ class Operations:
         else:
             body.terms(field, items, must=True)

-        return query_compiler._client.count(
+        count: int = query_compiler._client.count(
             index=query_compiler._index_pattern, body=body.to_count_body()
         )["count"]
+        return count

     def drop_index_values(
-        self, query_compiler: "QueryCompiler", field: str, items: Sequence[str]
+        self, query_compiler: "QueryCompiler", field: str, items: List[str]
     ) -> None:
         self._validate_index_operation(query_compiler, items)

@@ -1349,6 +1356,7 @@ class Operations:
         # a in ['a','b','c']
         # b not in ['a','b','c']
         # For now use term queries
+        task: Union["QueryIdsTask", "QueryTermsTask"]
         if field == Index.ID_INDEX_FIELD:
             task = QueryIdsTask(False, items)
         else:
@@ -1356,11 +1364,12 @@ class Operations:
         self._tasks.append(task)

     def filter_index_values(
-        self, query_compiler: "QueryCompiler", field: str, items: Sequence[str]
+        self, query_compiler: "QueryCompiler", field: str, items: List[str]
     ) -> None:
         # Basically .drop_index_values() except with must=True on tasks.
         self._validate_index_operation(query_compiler, items)

+        task: Union["QueryIdsTask", "QueryTermsTask"]
         if field == Index.ID_INDEX_FIELD:
             task = QueryIdsTask(True, items, sort_index_by_ids=True)
         else:
@@ -1406,7 +1415,7 @@ class Operations:
         # other operations require pre-queries and then combinations
         # other operations require in-core post-processing of results
         query_params = QueryParams()
-        post_processing = []
+        post_processing: List["PostProcessingAction"] = []

         for task in self._tasks:
             query_params, post_processing = task.resolve_task(
@@ -1439,7 +1448,7 @@ class Operations:
         # This can return None
         return size

-    def es_info(self, query_compiler, buf):
+    def es_info(self, query_compiler: "QueryCompiler", buf: TextIO) -> None:
         buf.write("Operations:\n")
         buf.write(f" tasks: {self._tasks}\n")

@@ -1459,7 +1468,7 @@ class Operations:
         buf.write(f" body: {body}\n")
         buf.write(f" post_processing: {post_processing}\n")

-    def update_query(self, boolean_filter):
+    def update_query(self, boolean_filter: "BooleanFilter") -> None:
         task = BooleanFilterTask(boolean_filter)
         self._tasks.append(task)

@@ -1477,3 +1486,58 @@ def quantile_to_percentile(quantile: Union[int, float]) -> float:
     # quantile * 100 = percentile
     # return float(...) because min(1.0) gives 1
     return float(min(100, max(0, quantile * 100)))
+
+
+class PandasToCSVCollector:
+    def __init__(self, show_progress: bool, **kwargs: Union[bool, str]) -> None:
+        self._args = kwargs
+        self._show_progress = show_progress
+        self._ret = None
+        self._first_time = True
+
+    def collect(self, df: "pd.DataFrame") -> None:
+        # If this is the first time we collect results, then write header, otherwise don't write header
+        # and append results
+        if self._first_time:
+            self._first_time = False
+            df.to_csv(**self._args)
+        else:
+            # Don't write header, and change mode to append
+            self._args["header"] = False
+            self._args["mode"] = "a"
+            df.to_csv(**self._args)
+
+    @staticmethod
+    def batch_size() -> int:
+        # By default read n docs and then dump to csv
+        batch_size: int = DEFAULT_CSV_BATCH_OUTPUT_SIZE
+        return batch_size
+
+    @property
+    def show_progress(self) -> bool:
+        return self._show_progress
+
+
+class PandasDataFrameCollector:
+    def __init__(self, show_progress: bool) -> None:
+        self._df = None
+        self._show_progress = show_progress
+
+    def collect(self, df: "pd.DataFrame") -> None:
+        # This collector does not batch data on output. Therefore, batch_size is fixed to None and this method
+        # is only called once.
+        if self._df is not None:
+            raise RuntimeError(
+                "Logic error in execution, this method must only be called once for this"
+                "collector - batch_size == None"
+            )
+        self._df = df
+
+    @staticmethod
+    def batch_size() -> None:
+        # Do not change (see notes on collect)
+        return None
+
+    @property
+    def show_progress(self) -> bool:
+        return self._show_progress

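How the now module-level collectors are driven, as a hedged sketch assuming the classes from this diff are in scope: _es_results() constructs one and calls collect() per batch, and PandasDataFrameCollector declares batch_size() None, so collect() runs exactly once:

    import pandas as pd

    collector = PandasDataFrameCollector(show_progress=False)
    assert collector.batch_size() is None  # single collect() call by design

    collector.collect(pd.DataFrame({"a": [1, 2, 3]}))
    print(collector._df)

Moving the collectors out of to_pandas()/to_csv() lets the _es_results() signature name them in a Union type, which a nested class definition could not provide.
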
@@ -15,8 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.

+from typing import TYPE_CHECKING
+
 import numpy as np
-from pandas.plotting._matplotlib import converter
+from pandas.plotting._matplotlib import converter  # type: ignore

 try:
     # pandas<1.3.0
@@ -26,13 +28,13 @@ except ImportError:
     from pandas.core.dtypes.generic import ABCIndex

 try:  # pandas>=1.2.0
-    from pandas.plotting._matplotlib.tools import (
+    from pandas.plotting._matplotlib.tools import (  # type: ignore
         create_subplots,
         flatten_axes,
         set_ticks_props,
     )
 except ImportError:  # pandas<1.2.0
-    from pandas.plotting._matplotlib.tools import (
+    from pandas.plotting._matplotlib.tools import (  # type: ignore
         _flatten as flatten_axes,
         _set_ticks_props as set_ticks_props,
         _subplots as create_subplots,
@@ -40,6 +42,9 @@ except ImportError:  # pandas<1.2.0

 from eland.utils import try_sort

+if TYPE_CHECKING:
+    from numpy.typing import ArrayLike
+

 def hist_series(
     self,
@@ -53,8 +58,8 @@ def hist_series(
     figsize=None,
     bins=10,
     **kwds,
-):
-    import matplotlib.pyplot as plt
+) -> "ArrayLike":
+    import matplotlib.pyplot as plt  # type: ignore

     if by is None:
         if kwds.get("layout", None) is not None:

@@ -17,9 +17,19 @@

 import copy
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    TextIO,
+    Tuple,
+    Union,
+)

-import numpy as np  # type: ignore
+import numpy as np
 import pandas as pd  # type: ignore

 from eland.common import (
@@ -28,11 +38,15 @@ from eland.common import (
     ensure_es_client,
 )
 from eland.field_mappings import FieldMappings
-from eland.filter import QueryFilter
+from eland.filter import BooleanFilter, QueryFilter
 from eland.index import Index
 from eland.operations import Operations

 if TYPE_CHECKING:
+    from elasticsearch import Elasticsearch
+
+    from eland.arithmetics import ArithmeticSeries
+
     from .tasks import ArithmeticOpFieldsTask  # noqa: F401


@@ -67,8 +81,10 @@ class QueryCompiler:

     def __init__(
         self,
-        client=None,
-        index_pattern=None,
+        client: Optional[
+            Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
+        ] = None,
+        index_pattern: Optional[str] = None,
         display_names=None,
         index_field=None,
         to_copy=None,
@@ -77,15 +93,15 @@ class QueryCompiler:
         if to_copy is not None:
             self._client = to_copy._client
             self._index_pattern = to_copy._index_pattern
-            self._index = Index(self, to_copy._index.es_index_field)
-            self._operations = copy.deepcopy(to_copy._operations)
+            self._index: "Index" = Index(self, to_copy._index.es_index_field)
+            self._operations: "Operations" = copy.deepcopy(to_copy._operations)
             self._mappings: FieldMappings = copy.deepcopy(to_copy._mappings)
         else:
             self._client = ensure_es_client(client)
             self._index_pattern = index_pattern
             # Get and persist mappings, this allows us to correctly
             # map returned types from Elasticsearch to pandas datatypes
-            self._mappings: FieldMappings = FieldMappings(
+            self._mappings = FieldMappings(
                 client=self._client,
                 index_pattern=self._index_pattern,
                 display_names=display_names,
@@ -103,15 +119,15 @@ class QueryCompiler:

         return pd.Index(columns)

-    def _get_display_names(self):
+    def _get_display_names(self) -> "pd.Index":
         display_names = self._mappings.display_names

         return pd.Index(display_names)

-    def _set_display_names(self, display_names):
+    def _set_display_names(self, display_names: List[str]) -> None:
         self._mappings.display_names = display_names

-    def get_field_names(self, include_scripted_fields):
+    def get_field_names(self, include_scripted_fields: bool) -> List[str]:
         return self._mappings.get_field_names(include_scripted_fields)

     def add_scripted_field(self, scripted_field_name, display_name, pd_dtype):
@@ -129,7 +145,12 @@ class QueryCompiler:

     # END Index, columns, and dtypes objects

-    def _es_results_to_pandas(self, results, batch_size=None, show_progress=False):
+    def _es_results_to_pandas(
+        self,
+        results: Dict[Any, Any],
+        batch_size: Optional[int] = None,
+        show_progress: bool = False,
+    ) -> "pd.Dataframe":
         """
         Parameters
         ----------
@@ -300,7 +321,7 @@ class QueryCompiler:

         return partial_result, df

-    def _flatten_dict(self, y, field_mapping_cache):
+    def _flatten_dict(self, y, field_mapping_cache: "FieldMappingCache"):
         out = {}

         def flatten(x, name=""):
@@ -368,7 +389,7 @@ class QueryCompiler:
         """
         return self._operations.index_count(self, self.index.es_index_field)

-    def _index_matches_count(self, items):
+    def _index_matches_count(self, items: List[Any]) -> int:
         """
         Returns
         -------
@@ -386,10 +407,10 @@ class QueryCompiler:
             df[c] = pd.Series(dtype=d)
         return df

-    def copy(self):
+    def copy(self) -> "QueryCompiler":
         return QueryCompiler(to_copy=self)

-    def rename(self, renames, inplace=False):
+    def rename(self, renames, inplace: bool = False) -> "QueryCompiler":
         if inplace:
             self._mappings.rename(renames)
             return self
@@ -398,21 +419,23 @@ class QueryCompiler:
             result._mappings.rename(renames)
             return result

-    def head(self, n):
+    def head(self, n: int) -> "QueryCompiler":
         result = self.copy()

         result._operations.head(self._index, n)

         return result

-    def tail(self, n):
+    def tail(self, n: int) -> "QueryCompiler":
         result = self.copy()

         result._operations.tail(self._index, n)

         return result

-    def sample(self, n=None, frac=None, random_state=None):
+    def sample(
+        self, n: Optional[int] = None, frac=None, random_state=None
+    ) -> "QueryCompiler":
         result = self.copy()

         if n is None and frac is None:
@@ -501,11 +524,11 @@ class QueryCompiler:
         query = {"multi_match": options}
         return QueryFilter(query)

-    def es_query(self, query):
+    def es_query(self, query: Dict[str, Any]) -> "QueryCompiler":
         return self._update_query(QueryFilter(query))

     # To/From Pandas
-    def to_pandas(self, show_progress=False):
+    def to_pandas(self, show_progress: bool = False):
         """Converts Eland DataFrame to Pandas DataFrame.

         Returns:
@@ -543,7 +566,9 @@ class QueryCompiler:

         return result

-    def drop(self, index=None, columns=None):
+    def drop(
+        self, index: Optional[str] = None, columns: Optional[List[str]] = None
+    ) -> "QueryCompiler":
         result = self.copy()

         # Drop gets all columns and removes drops
@@ -559,7 +584,7 @@ class QueryCompiler:

     def filter(
         self,
-        items: Optional[Sequence[str]] = None,
+        items: Optional[List[str]] = None,
         like: Optional[str] = None,
         regex: Optional[str] = None,
     ) -> "QueryCompiler":
@ -570,53 +595,55 @@ class QueryCompiler:
|
||||
result._operations.filter(self, items=items, like=like, regex=regex)
|
||||
return result
|
||||
|
||||
def aggs(self, func: List[str], numeric_only: Optional[bool] = None):
|
||||
def aggs(
|
||||
self, func: List[str], numeric_only: Optional[bool] = None
|
||||
) -> pd.DataFrame:
|
||||
return self._operations.aggs(self, func, numeric_only=numeric_only)
|
||||
|
||||
def count(self):
|
||||
def count(self) -> pd.Series:
|
||||
return self._operations.count(self)
|
||||
|
||||
def mean(self, numeric_only: Optional[bool] = None):
|
||||
def mean(self, numeric_only: Optional[bool] = None) -> pd.Series:
|
||||
return self._operations._metric_agg_series(
|
||||
self, ["mean"], numeric_only=numeric_only
|
||||
)
|
||||
|
||||
def var(self, numeric_only: Optional[bool] = None):
|
||||
def var(self, numeric_only: Optional[bool] = None) -> pd.Series:
|
||||
return self._operations._metric_agg_series(
|
||||
self, ["var"], numeric_only=numeric_only
|
||||
)
|
||||
|
||||
def std(self, numeric_only: Optional[bool] = None):
|
||||
def std(self, numeric_only: Optional[bool] = None) -> pd.Series:
|
||||
return self._operations._metric_agg_series(
|
||||
self, ["std"], numeric_only=numeric_only
|
||||
)
|
||||
|
||||
def mad(self, numeric_only: Optional[bool] = None):
|
||||
def mad(self, numeric_only: Optional[bool] = None) -> pd.Series:
|
||||
return self._operations._metric_agg_series(
|
||||
self, ["mad"], numeric_only=numeric_only
|
||||
)
|
||||
|
||||
def median(self, numeric_only: Optional[bool] = None):
|
||||
def median(self, numeric_only: Optional[bool] = None) -> pd.Series:
|
||||
return self._operations._metric_agg_series(
|
||||
self, ["median"], numeric_only=numeric_only
|
||||
)
|
||||
|
||||
def sum(self, numeric_only: Optional[bool] = None):
|
||||
def sum(self, numeric_only: Optional[bool] = None) -> pd.Series:
|
||||
return self._operations._metric_agg_series(
|
||||
self, ["sum"], numeric_only=numeric_only
|
||||
)
|
||||
|
||||
def min(self, numeric_only: Optional[bool] = None):
|
||||
def min(self, numeric_only: Optional[bool] = None) -> pd.Series:
|
||||
return self._operations._metric_agg_series(
|
||||
self, ["min"], numeric_only=numeric_only
|
||||
)
|
||||
|
||||
def max(self, numeric_only: Optional[bool] = None):
|
||||
def max(self, numeric_only: Optional[bool] = None) -> pd.Series:
|
||||
return self._operations._metric_agg_series(
|
||||
self, ["max"], numeric_only=numeric_only
|
||||
)
|
||||
|
||||
def nunique(self):
|
||||
def nunique(self) -> pd.Series:
|
||||
return self._operations._metric_agg_series(
|
||||
self, ["nunique"], numeric_only=False
|
||||
)
|
||||
@ -673,7 +700,7 @@ class QueryCompiler:
|
||||
dropna: bool = True,
|
||||
is_dataframe_agg: bool = False,
|
||||
numeric_only: Optional[bool] = True,
|
||||
quantiles: Union[int, float, List[int], List[float], None] = None,
|
||||
quantiles: Optional[Union[int, float, List[int], List[float]]] = None,
|
||||
) -> pd.DataFrame:
|
||||
return self._operations.aggs_groupby(
|
||||
self,
|
||||
@ -691,27 +718,27 @@ class QueryCompiler:
    def value_counts(self, es_size: int) -> pd.Series:
        return self._operations.value_counts(self, es_size)

    def es_info(self, buf):
    def es_info(self, buf: TextIO) -> None:
        buf.write(f"es_index_pattern: {self._index_pattern}\n")

        self._index.es_info(buf)
        self._mappings.es_info(buf)
        self._operations.es_info(self, buf)

    def describe(self):
    def describe(self) -> pd.DataFrame:
        return self._operations.describe(self)

    def _hist(self, num_bins):
    def _hist(self, num_bins: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
        return self._operations.hist(self, num_bins)

    def _update_query(self, boolean_filter):
    def _update_query(self, boolean_filter: "BooleanFilter") -> "QueryCompiler":
        result = self.copy()

        result._operations.update_query(boolean_filter)

        return result

    def check_arithmetics(self, right):
    def check_arithmetics(self, right: "QueryCompiler") -> None:
        """
        Compare 2 query_compilers to see if arithmetic operations can be performed by the NDFrame object.
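Annotating buf as TextIO documents that es_info accepts any writable text stream, not only a file. A minimal sketch, assuming qc is an already-constructed QueryCompiler (not shown in this diff):

from io import StringIO

buf = StringIO()       # StringIO satisfies the TextIO interface
qc.es_info(buf)        # writes index pattern, mappings and operations info
print(buf.getvalue())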
@ -750,7 +777,9 @@ class QueryCompiler:
                f"{self._index_pattern} != {right._index_pattern}"
            )

    def arithmetic_op_fields(self, display_name, arithmetic_object):
    def arithmetic_op_fields(
        self, display_name: str, arithmetic_object: "ArithmeticSeries"
    ) -> "QueryCompiler":
        result = self.copy()

        # create a new field name for this display name
@ -758,7 +787,7 @@ class QueryCompiler:

        # add scripted field
        result._mappings.add_scripted_field(
            scripted_field_name, display_name, arithmetic_object.dtype.name
            scripted_field_name, display_name, arithmetic_object.dtype.name  # type: ignore
        )

        result._operations.arithmetic_op_fields(scripted_field_name, arithmetic_object)
@ -768,7 +797,7 @@ class QueryCompiler:
    def get_arithmetic_op_fields(self) -> Optional["ArithmeticOpFieldsTask"]:
        return self._operations.get_arithmetic_op_fields()

    def display_name_to_aggregatable_name(self, display_name: str) -> str:
    def display_name_to_aggregatable_name(self, display_name: str) -> Optional[str]:
        aggregatable_field_name = self._mappings.aggregatable_field_name(display_name)

        return aggregatable_field_name
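The new # type: ignore is needed presumably because the dtype property is now typed with numpy's permissive DTypeLike alias, which also admits plain strings and None, and mypy cannot prove those carry a .name attribute. A hedged sketch of one way around the ignore (dtype_name is a hypothetical helper, not part of this diff):

import numpy as np
from numpy.typing import DTypeLike

def dtype_name(dtype: DTypeLike) -> str:
    # np.dtype() normalizes any DTypeLike value to a concrete dtype with a name
    return np.dtype(dtype).name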
@ -780,13 +809,13 @@ class FieldMappingCache:
    DataFrame access is slower than dict access.
    """

    def __init__(self, mappings):
    def __init__(self, mappings: "FieldMappings") -> None:
        self._mappings = mappings

        self._field_name_pd_dtype = dict()
        self._date_field_format = dict()
        self._field_name_pd_dtype: Dict[str, str] = dict()
        self._date_field_format: Dict[str, str] = dict()

    def field_name_pd_dtype(self, es_field_name):
    def field_name_pd_dtype(self, es_field_name: str) -> str:
        if es_field_name in self._field_name_pd_dtype:
            return self._field_name_pd_dtype[es_field_name]

@ -797,7 +826,7 @@ class FieldMappingCache:

        return pd_dtype

    def date_field_format(self, es_field_name):
    def date_field_format(self, es_field_name: str) -> str:
        if es_field_name in self._date_field_format:
            return self._date_field_format[es_field_name]
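As its docstring notes, FieldMappingCache is a plain dict-backed memo: check the dict first, fall back to the slow path, store the result. A minimal standalone sketch of the same pattern (the upper-casing stands in for the expensive DataFrame lookup):

from typing import Dict

class MemoCache:
    def __init__(self) -> None:
        self._cache: Dict[str, str] = {}

    def lookup(self, key: str) -> str:
        if key in self._cache:
            return self._cache[key]   # fast path: dict access
        value = key.upper()           # stand-in for the slow DataFrame access
        self._cache[key] = value
        return value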
@ -38,8 +38,8 @@ from io import StringIO
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
from pandas.io.common import _expand_user, stringify_path
import pandas as pd  # type: ignore
from pandas.io.common import _expand_user, stringify_path  # type: ignore

import eland.plotting
from eland.arithmetics import ArithmeticNumber, ArithmeticSeries, ArithmeticString
@ -61,10 +61,10 @@ from eland.filter import (
from eland.ndframe import NDFrame
from eland.utils import to_list

if TYPE_CHECKING:  # type: ignore
    from elasticsearch import Elasticsearch  # noqa: F401
if TYPE_CHECKING:
    from elasticsearch import Elasticsearch

    from eland.query_compiler import QueryCompiler  # noqa: F401
    from eland.query_compiler import QueryCompiler
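Imports guarded by TYPE_CHECKING only execute during static analysis, so the annotations that use them must be written as string literals at runtime. A minimal sketch of the pattern (the ping helper is hypothetical):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from elasticsearch import Elasticsearch  # seen by mypy, never imported at runtime

def ping(client: "Elasticsearch") -> bool:
    # Elasticsearch.ping() returns True if the cluster responds
    return client.ping()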
def _get_method_name() -> str:
@ -175,7 +175,7 @@ class Series(NDFrame):
        return num_rows, num_columns

    @property
    def es_field_name(self) -> str:
    def es_field_name(self) -> pd.Index:
        """
        Returns
        -------
@ -185,7 +185,7 @@ class Series(NDFrame):
        return self._query_compiler.get_field_names(include_scripted_fields=True)[0]

    @property
    def name(self) -> str:
    def name(self) -> pd.Index:
        return self._query_compiler.columns[0]

    @name.setter
@ -793,7 +793,7 @@ class Series(NDFrame):

        return buf.getvalue()

    def __add__(self, right):
    def __add__(self, right: "Series") -> "Series":
        """
        Return addition of series and right, element-wise (binary operator add).

@ -853,7 +853,7 @@ class Series(NDFrame):
        """
        return self._numeric_op(right, _get_method_name())
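Every operator below funnels into the shared _numeric_op helper, passing its own method name. The body of _get_method_name falls outside these hunks; one plausible implementation reads the caller's name off the stack:

import inspect

def _get_method_name() -> str:
    # name of the function that invoked this helper, e.g. "__add__"
    return inspect.stack()[1].function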
    def __truediv__(self, right):
    def __truediv__(self, right: "Series") -> "Series":
        """
        Return floating division of series and right, element-wise (binary operator truediv).

@ -892,7 +892,7 @@ class Series(NDFrame):
        """
        return self._numeric_op(right, _get_method_name())

    def __floordiv__(self, right):
    def __floordiv__(self, right: "Series") -> "Series":
        """
        Return integer division of series and right, element-wise (binary operator floordiv //).

@ -931,7 +931,7 @@ class Series(NDFrame):
        """
        return self._numeric_op(right, _get_method_name())

    def __mod__(self, right):
    def __mod__(self, right: "Series") -> "Series":
        """
        Return modulo of series and right, element-wise (binary operator mod %).

@ -970,7 +970,7 @@ class Series(NDFrame):
        """
        return self._numeric_op(right, _get_method_name())

    def __mul__(self, right):
    def __mul__(self, right: "Series") -> "Series":
        """
        Return multiplication of series and right, element-wise (binary operator mul).

@ -1009,7 +1009,7 @@ class Series(NDFrame):
        """
        return self._numeric_op(right, _get_method_name())

    def __sub__(self, right):
    def __sub__(self, right: "Series") -> "Series":
        """
        Return subtraction of series and right, element-wise (binary operator sub).

@ -1048,7 +1048,7 @@ class Series(NDFrame):
        """
        return self._numeric_op(right, _get_method_name())

    def __pow__(self, right):
    def __pow__(self, right: "Series") -> "Series":
        """
        Return exponential power of series and right, element-wise (binary operator pow).

@ -1087,7 +1087,7 @@ class Series(NDFrame):
        """
        return self._numeric_op(right, _get_method_name())

    def __radd__(self, left):
    def __radd__(self, left: "Series") -> "Series":
        """
        Return addition of series and left, element-wise (binary operator add).

@ -1119,7 +1119,7 @@ class Series(NDFrame):
        """
        return self._numeric_op(left, _get_method_name())

    def __rtruediv__(self, left):
    def __rtruediv__(self, left: "Series") -> "Series":
        """
        Return division of series and left, element-wise (binary operator div).

@ -1151,7 +1151,7 @@ class Series(NDFrame):
        """
        return self._numeric_op(left, _get_method_name())

    def __rfloordiv__(self, left):
    def __rfloordiv__(self, left: "Series") -> "Series":
        """
        Return integer division of series and left, element-wise (binary operator floordiv //).

@ -1183,7 +1183,7 @@ class Series(NDFrame):
        """
        return self._numeric_op(left, _get_method_name())

    def __rmod__(self, left):
    def __rmod__(self, left: "Series") -> "Series":
        """
        Return modulo of series and left, element-wise (binary operator mod %).

@ -1215,7 +1215,7 @@ class Series(NDFrame):
        """
        return self._numeric_op(left, _get_method_name())

    def __rmul__(self, left):
    def __rmul__(self, left: "Series") -> "Series":
        """
        Return multiplication of series and left, element-wise (binary operator mul).

@ -1247,7 +1247,7 @@ class Series(NDFrame):
        """
        return self._numeric_op(left, _get_method_name())

    def __rpow__(self, left):
    def __rpow__(self, left: "Series") -> "Series":
        """
        Return exponential power of series and left, element-wise (binary operator pow).

@ -1279,7 +1279,7 @@ class Series(NDFrame):
        """
        return self._numeric_op(left, _get_method_name())

    def __rsub__(self, left):
    def __rsub__(self, left: "Series") -> "Series":
        """
        Return subtraction of series and left, element-wise (binary operator sub).

@ -1398,7 +1398,7 @@ class Series(NDFrame):

        return series
    def max(self, numeric_only=None):
    def max(self, numeric_only: Optional[bool] = None) -> pd.Series:
        """
        Return the maximum of the Series values

@ -1422,7 +1422,7 @@ class Series(NDFrame):
        results = super().max(numeric_only=numeric_only)
        return results.squeeze()
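super().max() yields a one-element pd.Series and .squeeze() collapses it to a scalar, so callers usually see a plain number even though the method is annotated pd.Series. A small illustration of the pandas behavior:

import pandas as pd

s = pd.Series([42.0], index=["total"])
assert s.squeeze() == 42.0  # a one-element Series squeezes to a scalar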
    def mean(self, numeric_only=None):
    def mean(self, numeric_only: Optional[bool] = None) -> pd.Series:
        """
        Return the mean of the Series values

@ -1446,7 +1446,7 @@ class Series(NDFrame):
        results = super().mean(numeric_only=numeric_only)
        return results.squeeze()

    def median(self, numeric_only=None):
    def median(self, numeric_only: Optional[bool] = None) -> pd.Series:
        """
        Return the median of the Series values

@ -1470,7 +1470,7 @@ class Series(NDFrame):
        results = super().median(numeric_only=numeric_only)
        return results.squeeze()

    def min(self, numeric_only=None):
    def min(self, numeric_only: Optional[bool] = None) -> pd.Series:
        """
        Return the minimum of the Series values

@ -1494,7 +1494,7 @@ class Series(NDFrame):
        results = super().min(numeric_only=numeric_only)
        return results.squeeze()

    def sum(self, numeric_only=None):
    def sum(self, numeric_only: Optional[bool] = None) -> pd.Series:
        """
        Return the sum of the Series values

@ -1518,7 +1518,7 @@ class Series(NDFrame):
        results = super().sum(numeric_only=numeric_only)
        return results.squeeze()

    def nunique(self):
    def nunique(self) -> pd.Series:
        """
        Return the number of unique values in a Series

@ -1540,7 +1540,7 @@ class Series(NDFrame):
        results = super().nunique()
        return results.squeeze()

    def var(self, numeric_only=None):
    def var(self, numeric_only: Optional[bool] = None) -> pd.Series:
        """
        Return variance for a Series

@ -1562,7 +1562,7 @@ class Series(NDFrame):
        results = super().var(numeric_only=numeric_only)
        return results.squeeze()

    def std(self, numeric_only=None):
    def std(self, numeric_only: Optional[bool] = None) -> pd.Series:
        """
        Return standard deviation for a Series

@ -1584,7 +1584,7 @@ class Series(NDFrame):
        results = super().std(numeric_only=numeric_only)
        return results.squeeze()

    def mad(self, numeric_only=None):
    def mad(self, numeric_only: Optional[bool] = None) -> pd.Series:
        """
        Return median absolute deviation for a Series

@ -1643,7 +1643,7 @@ class Series(NDFrame):

    # def values TODO - not implemented as causes current implementation of query to fail

    def to_numpy(self):
    def to_numpy(self) -> None:
        """
        Not implemented.
@ -16,7 +16,7 @@
# under the License.

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Tuple
from typing import TYPE_CHECKING, List, Tuple

from eland import SortOrder
from eland.actions import HeadAction, SortIndexAction, TailAction
@ -253,7 +253,7 @@ class QueryIdsTask(Task):


class QueryTermsTask(Task):
    def __init__(self, must: bool, field: str, terms: List[Any]):
    def __init__(self, must: bool, field: str, terms: List[str]):
        """
        Parameters
        ----------
@ -38,7 +38,10 @@ TYPED_FILES = (
    "eland/tasks.py",
    "eland/utils.py",
    "eland/groupby.py",
    "eland/operations.py",
    "eland/ndframe.py",
    "eland/ml/__init__.py",
    "eland/ml/_optional.py",
    "eland/ml/_model_serializer.py",
    "eland/ml/ml_model.py",
    "eland/ml/transformers/__init__.py",
@ -46,6 +49,7 @@ TYPED_FILES = (
    "eland/ml/transformers/lightgbm.py",
    "eland/ml/transformers/sklearn.py",
    "eland/ml/transformers/xgboost.py",
    "eland/plotting/_matplotlib/__init__.py",
)


@ -60,7 +64,9 @@ def format(session):

@nox.session(reuse_venv=True)
def lint(session):
    session.install("black", "flake8", "mypy", "isort")
    # Install numpy to use its mypy plugin
    # https://numpy.org/devdocs/reference/typing.html#mypy-plugin
    session.install("black", "flake8", "mypy", "isort", "numpy")
    session.install("--pre", "elasticsearch")
    session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
    session.run("black", "--check", "--target-version=py37", *SOURCE_FILES)
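The lint session now installs numpy so mypy can use numpy's plugin (the comment links the plugin docs); with it, annotations built on numpy's typing aliases check cleanly. A minimal sketch of the kind of signature this enables (to_numpy_dtype is a hypothetical helper, not part of this diff):

import numpy as np
import numpy.typing as npt

def to_numpy_dtype(dtype: npt.DTypeLike) -> np.dtype:
    # np.dtype() accepts any DTypeLike value: strings, Python types, np.dtype, ...
    return np.dtype(dtype)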
@ -10,3 +10,4 @@ xgboost>=1
nox
lightgbm
pytest-cov
mypy