Add type hints to base modules

This commit is contained in:
Seth Michael Larson 2020-04-24 12:39:13 -05:00 committed by GitHub
parent fe6589ae6a
commit 33b4976f9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 333 additions and 205 deletions

View File

@ -13,15 +13,16 @@
# limitations under the License. # limitations under the License.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
# -------------------------------------------------------------------------------------------------------------------- #
# PostProcessingActions #
# -------------------------------------------------------------------------------------------------------------------- #
from eland import SortOrder from eland import SortOrder
if TYPE_CHECKING:
import pandas as pd # type: ignore
class PostProcessingAction(ABC): class PostProcessingAction(ABC):
def __init__(self, action_type): def __init__(self, action_type: str) -> None:
""" """
Abstract class for postprocessing actions Abstract class for postprocessing actions
@ -33,76 +34,74 @@ class PostProcessingAction(ABC):
self._action_type = action_type self._action_type = action_type
@property @property
def type(self): def type(self) -> str:
return self._action_type return self._action_type
@abstractmethod @abstractmethod
def resolve_action(self, df): def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
pass pass
@abstractmethod @abstractmethod
def __repr__(self): def __repr__(self) -> str:
pass pass
class SortIndexAction(PostProcessingAction): class SortIndexAction(PostProcessingAction):
def __init__(self): def __init__(self) -> None:
super().__init__("sort_index") super().__init__("sort_index")
def resolve_action(self, df): def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df.sort_index() return df.sort_index()
def __repr__(self): def __repr__(self) -> str:
return f"('{self.type}')" return f"('{self.type}')"
class HeadAction(PostProcessingAction): class HeadAction(PostProcessingAction):
def __init__(self, count): def __init__(self, count: int) -> None:
super().__init__("head") super().__init__("head")
self._count = count self._count = count
def resolve_action(self, df): def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df.head(self._count) return df.head(self._count)
def __repr__(self): def __repr__(self) -> str:
return f"('{self.type}': ('count': {self._count}))" return f"('{self.type}': ('count': {self._count}))"
class TailAction(PostProcessingAction): class TailAction(PostProcessingAction):
def __init__(self, count): def __init__(self, count: int) -> None:
super().__init__("tail") super().__init__("tail")
self._count = count self._count = count
def resolve_action(self, df): def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df.tail(self._count) return df.tail(self._count)
def __repr__(self): def __repr__(self) -> str:
return f"('{self.type}': ('count': {self._count}))" return f"('{self.type}': ('count': {self._count}))"
class SortFieldAction(PostProcessingAction): class SortFieldAction(PostProcessingAction):
def __init__(self, sort_params_string): def __init__(self, sort_params_string: str) -> None:
super().__init__("sort_field") super().__init__("sort_field")
if sort_params_string is None: if sort_params_string is None:
raise ValueError("Expected valid string") raise ValueError("Expected valid string")
# Split string # Split string
sort_params = sort_params_string.split(":") sort_field, _, sort_order = sort_params_string.partition(":")
if len(sort_params) != 2: if not sort_field or sort_order not in ("asc", "desc"):
raise ValueError( raise ValueError(
f"Expected ES sort params string (e.g. _doc:desc). Got '{sort_params_string}'" f"Expected ES sort params string (e.g. _doc:desc). Got '{sort_params_string}'"
) )
self._sort_field = sort_params[0] self._sort_field = sort_field
self._sort_order = SortOrder.from_string(sort_params[1]) self._sort_order = SortOrder.from_string(sort_order)
def resolve_action(self, df): def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
if self._sort_order == SortOrder.ASC: return df.sort_values(self._sort_field, self._sort_order == SortOrder.ASC)
return df.sort_values(self._sort_field, True)
return df.sort_values(self._sort_field, False)
def __repr__(self): def __repr__(self) -> str:
return f"('{self.type}': ('sort_field': '{self._sort_field}', 'sort_order': {self._sort_order}))" return f"('{self.type}': ('sort_field': '{self._sort_field}', 'sort_order': {self._sort_order}))"

View File

@ -14,72 +14,88 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from io import StringIO from io import StringIO
from typing import Union, List, TYPE_CHECKING, Any
import numpy as np import numpy as np # type: ignore
if TYPE_CHECKING:
from .query_compiler import QueryCompiler
class ArithmeticObject(ABC): class ArithmeticObject(ABC):
@property @property
@abstractmethod @abstractmethod
def value(self): def value(self) -> str:
pass pass
@abstractmethod @abstractmethod
def dtype(self): def dtype(self) -> np.dtype:
pass pass
@abstractmethod @abstractmethod
def resolve(self): def resolve(self) -> str:
pass pass
@abstractmethod @abstractmethod
def __repr__(self): def __repr__(self) -> str:
pass pass
class ArithmeticString(ArithmeticObject): class ArithmeticString(ArithmeticObject):
def __init__(self, value): def __init__(self, value: str):
self._value = value self._value = value
def resolve(self): def resolve(self) -> str:
return self.value return self.value
@property @property
def dtype(self): def dtype(self) -> np.dtype:
return np.dtype(object) return np.dtype(object)
@property @property
def value(self): def value(self) -> str:
return f"'{self._value}'" return f"'{self._value}'"
def __repr__(self): def __repr__(self) -> str:
return self.value return self.value
class ArithmeticNumber(ArithmeticObject): class ArithmeticNumber(ArithmeticObject):
def __init__(self, value, dtype): def __init__(self, value: Union[int, float], dtype: np.dtype):
self._value = value self._value = value
self._dtype = dtype self._dtype = dtype
def resolve(self): def resolve(self) -> str:
return self.value return self.value
@property @property
def value(self): def value(self) -> str:
return f"{self._value}" return f"{self._value}"
@property @property
def dtype(self): def dtype(self) -> np.dtype:
return self._dtype return self._dtype
def __repr__(self): def __repr__(self) -> str:
return self.value return self.value
class ArithmeticSeries(ArithmeticObject): class ArithmeticSeries(ArithmeticObject):
def __init__(self, query_compiler, display_name, dtype): """Represents each item in a 'Series' by using painless scripts
to evaluate each document in an index as a part of a query.
"""
def __init__(
self, query_compiler: "QueryCompiler", display_name: str, dtype: np.dtype
):
# type defs
self._value: str
self._tasks: List["ArithmeticTask"]
task = query_compiler.get_arithmetic_op_fields() task = query_compiler.get_arithmetic_op_fields()
if task is not None: if task is not None:
assert isinstance(task._arithmetic_series, ArithmeticSeries)
self._value = task._arithmetic_series.value self._value = task._arithmetic_series.value
self._tasks = task._arithmetic_series._tasks.copy() self._tasks = task._arithmetic_series._tasks.copy()
self._dtype = dtype self._dtype = dtype
@ -98,14 +114,14 @@ class ArithmeticSeries(ArithmeticObject):
self._dtype = dtype self._dtype = dtype
@property @property
def value(self): def value(self) -> str:
return self._value return self._value
@property @property
def dtype(self): def dtype(self) -> np.dtype:
return self._dtype return self._dtype
def __repr__(self): def __repr__(self) -> str:
buf = StringIO() buf = StringIO()
buf.write(f"Series: {self.value} ") buf.write(f"Series: {self.value} ")
buf.write("Tasks: ") buf.write("Tasks: ")
@ -113,7 +129,7 @@ class ArithmeticSeries(ArithmeticObject):
buf.write(f"{task!r} ") buf.write(f"{task!r} ")
return buf.getvalue() return buf.getvalue()
def resolve(self): def resolve(self) -> str:
value = self._value value = self._value
for task in self._tasks: for task in self._tasks:
@ -148,7 +164,7 @@ class ArithmeticSeries(ArithmeticObject):
return value return value
def arithmetic_operation(self, op_name, right): def arithmetic_operation(self, op_name: str, right: Any) -> "ArithmeticSeries":
# check if operation is supported (raises on unsupported) # check if operation is supported (raises on unsupported)
self.check_is_supported(op_name, right) self.check_is_supported(op_name, right)
@ -156,7 +172,7 @@ class ArithmeticSeries(ArithmeticObject):
self._tasks.append(task) self._tasks.append(task)
return self return self
def check_is_supported(self, op_name, right): def check_is_supported(self, op_name: str, right: Any) -> bool:
# supported set is # supported set is
# series.number op_name number (all ops) # series.number op_name number (all ops)
# series.string op_name string (only add) # series.string op_name string (only add)
@ -165,22 +181,20 @@ class ArithmeticSeries(ArithmeticObject):
# series.int op_name string (none) # series.int op_name string (none)
# series.float op_name string (none) # series.float op_name string (none)
# see end of https://pandas.pydata.org/pandas-docs/stable/getting_started/basics.html?highlight=dtype for # see end of https://pandas.pydata.org/pandas-docs/stable/getting_started/basics.html?highlight=dtype
# dtype heirarchy # for dtype hierarchy
if np.issubdtype(self.dtype, np.number) and np.issubdtype( right_is_integer = np.issubdtype(right.dtype, np.number)
right.dtype, np.number if np.issubdtype(self.dtype, np.number) and right_is_integer:
):
# series.number op_name number (all ops) # series.number op_name number (all ops)
return True return True
elif np.issubdtype(self.dtype, np.object_) and np.issubdtype(
right.dtype, np.object_ self_is_object = np.issubdtype(self.dtype, np.object_)
): if self_is_object and np.issubdtype(right.dtype, np.object_):
# series.string op_name string (only add) # series.string op_name string (only add)
if op_name == "__add__" or op_name == "__radd__": if op_name == "__add__" or op_name == "__radd__":
return True return True
elif np.issubdtype(self.dtype, np.object_) and np.issubdtype(
right.dtype, np.integer if self_is_object and right_is_integer:
):
# series.string op_name int (only mul) # series.string op_name int (only mul)
if op_name == "__mul__": if op_name == "__mul__":
return True return True
@ -191,22 +205,22 @@ class ArithmeticSeries(ArithmeticObject):
class ArithmeticTask: class ArithmeticTask:
def __init__(self, op_name, object): def __init__(self, op_name: str, object: ArithmeticObject):
self._op_name = op_name self._op_name = op_name
if not isinstance(object, ArithmeticObject): if not isinstance(object, ArithmeticObject):
raise TypeError(f"Task requires ArithmeticObject not {type(object)}") raise TypeError(f"Task requires ArithmeticObject not {type(object)}")
self._object = object self._object = object
def __repr__(self): def __repr__(self) -> str:
buf = StringIO() buf = StringIO()
buf.write(f"op_name: {self.op_name} object: {self.object!r} ") buf.write(f"op_name: {self.op_name} object: {self.object!r} ")
return buf.getvalue() return buf.getvalue()
@property @property
def op_name(self): def op_name(self) -> str:
return self._op_name return self._op_name
@property @property
def object(self): def object(self) -> ArithmeticObject:
return self._object return self._object

View File

@ -15,10 +15,10 @@
import re import re
import warnings import warnings
from enum import Enum from enum import Enum
from typing import Union, List, Tuple from typing import Union, List, Tuple, cast, Callable, Any
import pandas as pd import pandas as pd # type: ignore
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch # type: ignore
# Default number of rows displayed (different to pandas where ALL could be displayed) # Default number of rows displayed (different to pandas where ALL could be displayed)
DEFAULT_NUM_ROWS_DISPLAYED = 60 DEFAULT_NUM_ROWS_DISPLAYED = 60
@ -29,8 +29,8 @@ DEFAULT_PROGRESS_REPORTING_NUM_ROWS = 10000
DEFAULT_ES_MAX_RESULT_WINDOW = 10000 # index.max_result_window DEFAULT_ES_MAX_RESULT_WINDOW = 10000 # index.max_result_window
def docstring_parameter(*sub): def docstring_parameter(*sub: Any) -> Callable[[Any], Any]:
def dec(obj): def dec(obj: Any) -> Any:
obj.__doc__ = obj.__doc__.format(*sub) obj.__doc__ = obj.__doc__.format(*sub)
return obj return obj
@ -42,21 +42,21 @@ class SortOrder(Enum):
DESC = 1 DESC = 1
@staticmethod @staticmethod
def reverse(order): def reverse(order: "SortOrder") -> "SortOrder":
if order == SortOrder.ASC: if order == SortOrder.ASC:
return SortOrder.DESC return SortOrder.DESC
return SortOrder.ASC return SortOrder.ASC
@staticmethod @staticmethod
def to_string(order): def to_string(order: "SortOrder") -> str:
if order == SortOrder.ASC: if order == SortOrder.ASC:
return "asc" return "asc"
return "desc" return "desc"
@staticmethod @staticmethod
def from_string(order): def from_string(order: str) -> "SortOrder":
if order == "asc": if order == "asc":
return SortOrder.ASC return SortOrder.ASC
@ -276,11 +276,13 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
property if one doesn't exist yet for the current Elasticsearch version. property if one doesn't exist yet for the current Elasticsearch version.
""" """
if not hasattr(es_client, "_eland_es_version"): if not hasattr(es_client, "_eland_es_version"):
major, minor, patch = [ version_info = es_client.info()["version"]["number"]
int(x) match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_info)
for x in re.match( if match is None:
r"^(\d+)\.(\d+)\.(\d+)", es_client.info()["version"]["number"] raise ValueError(
).groups() f"Unable to determine Elasticsearch version. "
] f"Received: {version_info}"
)
major, minor, patch = [int(x) for x in match.groups()]
es_client._eland_es_version = (major, minor, patch) es_client._eland_es_version = (major, minor, patch)
return es_client._eland_es_version return cast(Tuple[int, int, int], es_client._eland_es_version)

View File

@ -187,7 +187,7 @@ class DataFrame(NDFrame):
""" """
return len(self.columns) == 0 or len(self.index) == 0 return len(self.columns) == 0 or len(self.index) == 0
def head(self, n=5): def head(self, n: int = 5) -> "DataFrame":
""" """
Return the first n rows. Return the first n rows.
@ -222,7 +222,7 @@ class DataFrame(NDFrame):
""" """
return DataFrame(query_compiler=self._query_compiler.head(n)) return DataFrame(query_compiler=self._query_compiler.head(n))
def tail(self, n=5): def tail(self, n: int = 5) -> "DataFrame":
""" """
Return the last n rows. Return the last n rows.

View File

@ -14,12 +14,14 @@
# Originally based on code in MIT-licensed pandasticsearch filters # Originally based on code in MIT-licensed pandasticsearch filters
from typing import Dict, Any, List, Optional, Union, cast
class BooleanFilter: class BooleanFilter:
def __init__(self, *args): def __init__(self) -> None:
self._filter = None self._filter: Dict[str, Any] = {}
def __and__(self, x): def __and__(self, x: "BooleanFilter") -> "BooleanFilter":
# Combine results # Combine results
if isinstance(self, AndFilter): if isinstance(self, AndFilter):
if "must_not" in x.subtree: if "must_not" in x.subtree:
@ -37,7 +39,7 @@ class BooleanFilter:
return x return x
return AndFilter(self, x) return AndFilter(self, x)
def __or__(self, x): def __or__(self, x: "BooleanFilter") -> "BooleanFilter":
# Combine results # Combine results
if isinstance(self, OrFilter): if isinstance(self, OrFilter):
if "must_not" in x.subtree: if "must_not" in x.subtree:
@ -53,85 +55,79 @@ class BooleanFilter:
return x return x
return OrFilter(self, x) return OrFilter(self, x)
def __invert__(self): def __invert__(self) -> "BooleanFilter":
return NotFilter(self) return NotFilter(self)
def empty(self): def empty(self) -> bool:
if self._filter is None: return not bool(self._filter)
return True
return False
def __repr__(self): def __repr__(self) -> str:
return str(self._filter) return str(self._filter)
@property @property
def subtree(self): def subtree(self) -> Dict[str, Any]:
if "bool" in self._filter: if "bool" in self._filter:
return self._filter["bool"] return cast(Dict[str, Any], self._filter["bool"])
else: else:
return self._filter return self._filter
def build(self): def build(self) -> Dict[str, Any]:
return self._filter return self._filter
# Binary operator # Binary operator
class AndFilter(BooleanFilter): class AndFilter(BooleanFilter):
def __init__(self, *args): def __init__(self, *args: BooleanFilter) -> None:
[isinstance(x, BooleanFilter) for x in args]
super().__init__() super().__init__()
self._filter = {"bool": {"must": [x.build() for x in args]}} self._filter = {"bool": {"must": [x.build() for x in args]}}
class OrFilter(BooleanFilter): class OrFilter(BooleanFilter):
def __init__(self, *args): def __init__(self, *args: BooleanFilter) -> None:
[isinstance(x, BooleanFilter) for x in args]
super().__init__() super().__init__()
self._filter = {"bool": {"should": [x.build() for x in args]}} self._filter = {"bool": {"should": [x.build() for x in args]}}
class NotFilter(BooleanFilter): class NotFilter(BooleanFilter):
def __init__(self, x): def __init__(self, x: BooleanFilter) -> None:
assert isinstance(x, BooleanFilter)
super().__init__() super().__init__()
self._filter = {"bool": {"must_not": x.build()}} self._filter = {"bool": {"must_not": x.build()}}
# LeafBooleanFilter # LeafBooleanFilter
class GreaterEqual(BooleanFilter): class GreaterEqual(BooleanFilter):
def __init__(self, field, value): def __init__(self, field: str, value: Any) -> None:
super().__init__() super().__init__()
self._filter = {"range": {field: {"gte": value}}} self._filter = {"range": {field: {"gte": value}}}
class Greater(BooleanFilter): class Greater(BooleanFilter):
def __init__(self, field, value): def __init__(self, field: str, value: Any) -> None:
super().__init__() super().__init__()
self._filter = {"range": {field: {"gt": value}}} self._filter = {"range": {field: {"gt": value}}}
class LessEqual(BooleanFilter): class LessEqual(BooleanFilter):
def __init__(self, field, value): def __init__(self, field: str, value: Any) -> None:
super().__init__() super().__init__()
self._filter = {"range": {field: {"lte": value}}} self._filter = {"range": {field: {"lte": value}}}
class Less(BooleanFilter): class Less(BooleanFilter):
def __init__(self, field, value): def __init__(self, field: str, value: Any) -> None:
super().__init__() super().__init__()
self._filter = {"range": {field: {"lt": value}}} self._filter = {"range": {field: {"lt": value}}}
class Equal(BooleanFilter): class Equal(BooleanFilter):
def __init__(self, field, value): def __init__(self, field: str, value: Any) -> None:
super().__init__() super().__init__()
self._filter = {"term": {field: value}} self._filter = {"term": {field: value}}
class IsIn(BooleanFilter): class IsIn(BooleanFilter):
def __init__(self, field, value): def __init__(self, field: str, value: List[Any]) -> None:
super().__init__() super().__init__()
assert isinstance(value, list)
if field == "ids": if field == "ids":
self._filter = {"ids": {"values": value}} self._filter = {"ids": {"values": value}}
else: else:
@ -139,39 +135,44 @@ class IsIn(BooleanFilter):
class Like(BooleanFilter): class Like(BooleanFilter):
def __init__(self, field, value): def __init__(self, field: str, value: str) -> None:
super().__init__() super().__init__()
self._filter = {"wildcard": {field: value}} self._filter = {"wildcard": {field: value}}
class Rlike(BooleanFilter): class Rlike(BooleanFilter):
def __init__(self, field, value): def __init__(self, field: str, value: str) -> None:
super().__init__() super().__init__()
self._filter = {"regexp": {field: value}} self._filter = {"regexp": {field: value}}
class Startswith(BooleanFilter): class Startswith(BooleanFilter):
def __init__(self, field, value): def __init__(self, field: str, value: str) -> None:
super().__init__() super().__init__()
self._filter = {"prefix": {field: value}} self._filter = {"prefix": {field: value}}
class IsNull(BooleanFilter): class IsNull(BooleanFilter):
def __init__(self, field): def __init__(self, field: str) -> None:
super().__init__() super().__init__()
self._filter = {"missing": {"field": field}} self._filter = {"missing": {"field": field}}
class NotNull(BooleanFilter): class NotNull(BooleanFilter):
def __init__(self, field): def __init__(self, field: str) -> None:
super().__init__() super().__init__()
self._filter = {"exists": {"field": field}} self._filter = {"exists": {"field": field}}
class ScriptFilter(BooleanFilter): class ScriptFilter(BooleanFilter):
def __init__(self, inline, lang=None, params=None): def __init__(
self,
inline: str,
lang: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__() super().__init__()
script = {"source": inline} script: Dict[str, Union[str, Dict[str, Any]]] = {"source": inline}
if lang is not None: if lang is not None:
script["lang"] = lang script["lang"] = lang
if params is not None: if params is not None:
@ -180,6 +181,6 @@ class ScriptFilter(BooleanFilter):
class QueryFilter(BooleanFilter): class QueryFilter(BooleanFilter):
def __init__(self, query): def __init__(self, query: Dict[str, Any]) -> None:
super().__init__() super().__init__()
self._filter = query self._filter = query

View File

@ -13,6 +13,12 @@
# limitations under the License. # limitations under the License.
from typing import Optional, TextIO, TYPE_CHECKING
if TYPE_CHECKING:
from .query_compiler import QueryCompiler
class Index: class Index:
""" """
The index for an eland.DataFrame. The index for an eland.DataFrame.
@ -33,30 +39,36 @@ class Index:
ID_INDEX_FIELD = "_id" ID_INDEX_FIELD = "_id"
ID_SORT_FIELD = "_doc" # if index field is _id, sort by _doc ID_SORT_FIELD = "_doc" # if index field is _id, sort by _doc
def __init__(self, query_compiler, index_field=None): def __init__(
self, query_compiler: "QueryCompiler", index_field: Optional[str] = None
):
self._query_compiler = query_compiler self._query_compiler = query_compiler
# _is_source_field is set immediately within # _is_source_field is set immediately within
# index_field.setter # index_field.setter
self._is_source_field = False self._is_source_field = False
self.index_field = index_field
# The type:ignore is due to mypy not being smart enough
# to recognize the property.setter has a different type
# than the property.getter.
self.index_field = index_field # type: ignore
@property @property
def sort_field(self): def sort_field(self) -> str:
if self._index_field == self.ID_INDEX_FIELD: if self._index_field == self.ID_INDEX_FIELD:
return self.ID_SORT_FIELD return self.ID_SORT_FIELD
return self._index_field return self._index_field
@property @property
def is_source_field(self): def is_source_field(self) -> bool:
return self._is_source_field return self._is_source_field
@property @property
def index_field(self): def index_field(self) -> str:
return self._index_field return self._index_field
@index_field.setter @index_field.setter
def index_field(self, index_field): def index_field(self, index_field: Optional[str]) -> None:
if index_field is None or index_field == Index.ID_INDEX_FIELD: if index_field is None or index_field == Index.ID_INDEX_FIELD:
self._index_field = Index.ID_INDEX_FIELD self._index_field = Index.ID_INDEX_FIELD
self._is_source_field = False self._is_source_field = False
@ -64,18 +76,18 @@ class Index:
self._index_field = index_field self._index_field = index_field
self._is_source_field = True self._is_source_field = True
def __len__(self): def __len__(self) -> int:
return self._query_compiler._index_count() return self._query_compiler._index_count()
# Make iterable # Make iterable
def __next__(self): def __next__(self) -> None:
# TODO resolve this hack to make this 'iterable' # TODO resolve this hack to make this 'iterable'
raise StopIteration() raise StopIteration()
def __iter__(self): def __iter__(self) -> "Index":
return self return self
def info_es(self, buf): def info_es(self, buf: TextIO) -> None:
buf.write("Index:\n") buf.write("Index:\n")
buf.write(f" index_field: {self.index_field}\n") buf.write(f" index_field: {self.index_field}\n")
buf.write(f" is_source_field: {self.is_source_field}\n") buf.write(f" is_source_field: {self.is_source_field}\n")

View File

@ -14,6 +14,7 @@
import copy import copy
import warnings import warnings
from typing import Optional
import numpy as np import numpy as np
@ -98,7 +99,7 @@ class Operations:
else: else:
self._arithmetic_op_fields_task.update(display_name, arithmetic_series) self._arithmetic_op_fields_task.update(display_name, arithmetic_series)
def get_arithmetic_op_fields(self): def get_arithmetic_op_fields(self) -> Optional[ArithmeticOpFieldsTask]:
# get an ArithmeticOpFieldsTask if it exists # get an ArithmeticOpFieldsTask if it exists
return self._arithmetic_op_fields_task return self._arithmetic_op_fields_task

View File

@ -14,6 +14,7 @@
import warnings import warnings
from copy import deepcopy from copy import deepcopy
from typing import Optional, Dict, List, Any
from eland.filter import BooleanFilter, NotNull, IsNull, IsIn from eland.filter import BooleanFilter, NotNull, IsNull, IsIn
@ -23,7 +24,11 @@ class Query:
Simple class to manage building Elasticsearch queries. Simple class to manage building Elasticsearch queries.
""" """
def __init__(self, query=None): def __init__(self, query: Optional["Query"] = None):
# type defs
self._query: BooleanFilter
self._aggs: Dict[str, Any]
if query is None: if query is None:
self._query = BooleanFilter() self._query = BooleanFilter()
self._aggs = {} self._aggs = {}
@ -32,7 +37,7 @@ class Query:
self._query = deepcopy(query._query) self._query = deepcopy(query._query)
self._aggs = deepcopy(query._aggs) self._aggs = deepcopy(query._aggs)
def exists(self, field, must=True): def exists(self, field: str, must: bool = True) -> None:
""" """
Add exists query Add exists query
https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-exists-query.html https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-exists-query.html
@ -48,7 +53,7 @@ class Query:
else: else:
self._query = self._query & IsNull(field) self._query = self._query & IsNull(field)
def ids(self, items, must=True): def ids(self, items: List[Any], must: bool = True) -> None:
""" """
Add ids query Add ids query
https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-ids-query.html https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-ids-query.html
@ -64,7 +69,7 @@ class Query:
else: else:
self._query = self._query & ~(IsIn("ids", items)) self._query = self._query & ~(IsIn("ids", items))
def terms(self, field, items, must=True): def terms(self, field: str, items: List[str], must: bool = True) -> None:
""" """
Add terms query Add terms query
https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-terms-query.html https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-terms-query.html
@ -80,7 +85,7 @@ class Query:
else: else:
self._query = self._query & ~(IsIn(field, items)) self._query = self._query & ~(IsIn(field, items))
def terms_aggs(self, name, func, field, es_size): def terms_aggs(self, name: str, func: str, field: str, es_size: int) -> None:
""" """
Add terms agg e.g Add terms agg e.g
@ -96,7 +101,7 @@ class Query:
agg = {func: {"field": field, "size": es_size}} agg = {func: {"field": field, "size": es_size}}
self._aggs[name] = agg self._aggs[name] = agg
def metric_aggs(self, name, func, field): def metric_aggs(self, name: str, func: str, field: str) -> None:
""" """
Add metric agg e.g Add metric agg e.g
@ -111,7 +116,14 @@ class Query:
agg = {func: {"field": field}} agg = {func: {"field": field}}
self._aggs[name] = agg self._aggs[name] = agg
def hist_aggs(self, name, field, min_aggs, max_aggs, num_bins): def hist_aggs(
self,
name: str,
field: str,
min_aggs: Dict[str, Any],
max_aggs: Dict[str, Any],
num_bins: int,
) -> None:
""" """
Add histogram agg e.g. Add histogram agg e.g.
"aggs": { "aggs": {
@ -127,14 +139,15 @@ class Query:
max = max_aggs[field] max = max_aggs[field]
interval = (max - min) / num_bins interval = (max - min) / num_bins
offset = min
agg = {"histogram": {"field": field, "interval": interval, "offset": offset}}
if interval != 0: if interval != 0:
offset = min
agg = {
"histogram": {"field": field, "interval": interval, "offset": offset}
}
self._aggs[name] = agg self._aggs[name] = agg
def to_search_body(self): def to_search_body(self) -> Dict[str, Any]:
body = {} body = {}
if self._aggs: if self._aggs:
body["aggs"] = self._aggs body["aggs"] = self._aggs
@ -142,19 +155,19 @@ class Query:
body["query"] = self._query.build() body["query"] = self._query.build()
return body return body
def to_count_body(self): def to_count_body(self) -> Optional[Dict[str, Any]]:
if len(self._aggs) > 0: if len(self._aggs) > 0:
warnings.warn("Requesting count for agg query {}", self) warnings.warn(f"Requesting count for agg query {self}")
if self._query.empty(): if self._query.empty():
return None return None
else: else:
return {"query": self._query.build()} return {"query": self._query.build()}
def update_boolean_filter(self, boolean_filter): def update_boolean_filter(self, boolean_filter: BooleanFilter) -> None:
if self._query.empty(): if self._query.empty():
self._query = boolean_filter self._query = boolean_filter
else: else:
self._query = self._query & boolean_filter self._query = self._query & boolean_filter
def __repr__(self): def __repr__(self) -> str:
return repr(self.to_search_body()) return repr(self.to_search_body())

View File

@ -14,6 +14,7 @@
import copy import copy
from datetime import datetime from datetime import datetime
from typing import Optional, TYPE_CHECKING
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -28,6 +29,9 @@ from eland.common import (
elasticsearch_date_to_pandas_date, elasticsearch_date_to_pandas_date,
) )
if TYPE_CHECKING:
from .tasks import ArithmeticOpFieldsTask # noqa: F401
class QueryCompiler: class QueryCompiler:
""" """
@ -348,7 +352,7 @@ class QueryCompiler:
return out return out
def _index_count(self): def _index_count(self) -> int:
""" """
Returns Returns
------- -------
@ -562,10 +566,10 @@ class QueryCompiler:
return result return result
def get_arithmetic_op_fields(self): def get_arithmetic_op_fields(self) -> Optional["ArithmeticOpFieldsTask"]:
return self._operations.get_arithmetic_op_fields() return self._operations.get_arithmetic_op_fields()
def display_name_to_aggregatable_name(self, display_name): def display_name_to_aggregatable_name(self, display_name: str) -> str:
aggregatable_field_name = self._mappings.aggregatable_field_name(display_name) aggregatable_field_name = self._mappings.aggregatable_field_name(display_name)
return aggregatable_field_name return aggregatable_field_name

View File

@ -13,12 +13,22 @@
# limitations under the License. # limitations under the License.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Dict, Any, Tuple
from eland import SortOrder from eland import SortOrder
from eland.actions import HeadAction, TailAction, SortIndexAction from eland.actions import HeadAction, TailAction, SortIndexAction
from eland.arithmetics import ArithmeticSeries from eland.arithmetics import ArithmeticSeries
if TYPE_CHECKING:
from .actions import PostProcessingAction # noqa: F401
from .filter import BooleanFilter # noqa: F401
from .query_compiler import QueryCompiler # noqa: F401
QUERY_PARAMS_TYPE = Dict[str, Any]
RESOLVED_TASK_TYPE = Tuple[QUERY_PARAMS_TYPE, List["PostProcessingAction"]]
class Task(ABC): class Task(ABC):
""" """
Abstract class for tasks Abstract class for tasks
@ -29,44 +39,51 @@ class Task(ABC):
The task type (e.g. head, tail etc.) The task type (e.g. head, tail etc.)
""" """
def __init__(self, task_type): def __init__(self, task_type: str):
self._task_type = task_type self._task_type = task_type
@property @property
def type(self): def type(self) -> str:
return self._task_type return self._task_type
@abstractmethod @abstractmethod
def resolve_task(self, query_params, post_processing, query_compiler): def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
pass pass
@abstractmethod @abstractmethod
def __repr__(self): def __repr__(self) -> str:
pass pass
class SizeTask(Task): class SizeTask(Task):
def __init__(self, task_type):
super().__init__(task_type)
@abstractmethod @abstractmethod
def size(self): def size(self) -> int:
# must override # must override
pass pass
class HeadTask(SizeTask): class HeadTask(SizeTask):
def __init__(self, sort_field, count): def __init__(self, sort_field: str, count: int):
super().__init__("head") super().__init__("head")
# Add a task that is an ascending sort with size=count # Add a task that is an ascending sort with size=count
self._sort_field = sort_field self._sort_field = sort_field
self._count = count self._count = count
def __repr__(self): def __repr__(self) -> str:
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))" return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
def resolve_task(self, query_params, post_processing, query_compiler): def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
# head - sort asc, size n # head - sort asc, size n
# |12345-------------| # |12345-------------|
query_sort_field = self._sort_field query_sort_field = self._sort_field
@ -97,22 +114,24 @@ class HeadTask(SizeTask):
return query_params, post_processing return query_params, post_processing
def size(self): def size(self) -> int:
return self._count return self._count
class TailTask(SizeTask): class TailTask(SizeTask):
def __init__(self, sort_field, count): def __init__(self, sort_field: str, count: int):
super().__init__("tail") super().__init__("tail")
# Add a task that is descending sort with size=count # Add a task that is descending sort with size=count
self._sort_field = sort_field self._sort_field = sort_field
self._count = count self._count = count
def __repr__(self): def resolve_task(
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))" self,
query_params: QUERY_PARAMS_TYPE,
def resolve_task(self, query_params, post_processing, query_compiler): post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
# tail - sort desc, size n, post-process sort asc # tail - sort desc, size n, post-process sort asc
# |-------------12345| # |-------------12345|
query_sort_field = self._sort_field query_sort_field = self._sort_field
@ -123,7 +142,10 @@ class TailTask(SizeTask):
if ( if (
query_params["query_size"] is not None query_params["query_size"] is not None
and query_params["query_sort_order"] == query_sort_order and query_params["query_sort_order"] == query_sort_order
and post_processing == ["sort_index"] and (
len(post_processing) == 1
and isinstance(post_processing[0], SortIndexAction)
)
): ):
if query_size < query_params["query_size"]: if query_size < query_params["query_size"]:
query_params["query_size"] = query_size query_params["query_size"] = query_size
@ -156,12 +178,15 @@ class TailTask(SizeTask):
return query_params, post_processing return query_params, post_processing
def size(self): def size(self) -> int:
return self._count return self._count
def __repr__(self) -> str:
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
class QueryIdsTask(Task): class QueryIdsTask(Task):
def __init__(self, must, ids): def __init__(self, must: bool, ids: List[str]):
""" """
Parameters Parameters
---------- ----------
@ -176,17 +201,21 @@ class QueryIdsTask(Task):
self._must = must self._must = must
self._ids = ids self._ids = ids
def resolve_task(self, query_params, post_processing, query_compiler): def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
query_params["query"].ids(self._ids, must=self._must) query_params["query"].ids(self._ids, must=self._must)
return query_params, post_processing return query_params, post_processing
def __repr__(self): def __repr__(self) -> str:
return f"('{self._task_type}': ('must': {self._must}, 'ids': {self._ids}))" return f"('{self._task_type}': ('must': {self._must}, 'ids': {self._ids}))"
class QueryTermsTask(Task): class QueryTermsTask(Task):
def __init__(self, must, field, terms): def __init__(self, must: bool, field: str, terms: List[Any]):
""" """
Parameters Parameters
---------- ----------
@ -205,20 +234,24 @@ class QueryTermsTask(Task):
self._field = field self._field = field
self._terms = terms self._terms = terms
def __repr__(self): def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
query_params["query"].terms(self._field, self._terms, must=self._must)
return query_params, post_processing
def __repr__(self) -> str:
return ( return (
f"('{self._task_type}': ('must': {self._must}, " f"('{self._task_type}': ('must': {self._must}, "
f"'field': '{self._field}', 'terms': {self._terms}))" f"'field': '{self._field}', 'terms': {self._terms}))"
) )
def resolve_task(self, query_params, post_processing, query_compiler):
query_params["query"].terms(self._field, self._terms, must=self._must)
return query_params, post_processing
class BooleanFilterTask(Task): class BooleanFilterTask(Task):
def __init__(self, boolean_filter): def __init__(self, boolean_filter: "BooleanFilter"):
""" """
Parameters Parameters
---------- ----------
@ -229,17 +262,21 @@ class BooleanFilterTask(Task):
self._boolean_filter = boolean_filter self._boolean_filter = boolean_filter
def __repr__(self): def resolve_task(
return f"('{self._task_type}': ('boolean_filter': {self._boolean_filter!r}))" self,
query_params: QUERY_PARAMS_TYPE,
def resolve_task(self, query_params, post_processing, query_compiler): post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
query_params["query"].update_boolean_filter(self._boolean_filter) query_params["query"].update_boolean_filter(self._boolean_filter)
return query_params, post_processing return query_params, post_processing
def __repr__(self) -> str:
return f"('{self._task_type}': ('boolean_filter': {self._boolean_filter!r}))"
class ArithmeticOpFieldsTask(Task): class ArithmeticOpFieldsTask(Task):
def __init__(self, display_name, arithmetic_series): def __init__(self, display_name: str, arithmetic_series: ArithmeticSeries):
super().__init__("arithmetic_op_fields") super().__init__("arithmetic_op_fields")
self._display_name = display_name self._display_name = display_name
@ -248,19 +285,16 @@ class ArithmeticOpFieldsTask(Task):
raise TypeError(f"Expecting ArithmeticSeries got {type(arithmetic_series)}") raise TypeError(f"Expecting ArithmeticSeries got {type(arithmetic_series)}")
self._arithmetic_series = arithmetic_series self._arithmetic_series = arithmetic_series
def __repr__(self): def update(self, display_name: str, arithmetic_series: ArithmeticSeries) -> None:
return (
f"('{self._task_type}': ("
f"'display_name': {self._display_name}, "
f"'arithmetic_object': {self._arithmetic_series}"
f"))"
)
def update(self, display_name, arithmetic_series):
self._display_name = display_name self._display_name = display_name
self._arithmetic_series = arithmetic_series self._arithmetic_series = arithmetic_series
def resolve_task(self, query_params, post_processing, query_compiler): def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
# https://www.elastic.co/guide/en/elasticsearch/painless/current/painless-api-reference-shared-java-lang.html#painless-api-reference-shared-Math # https://www.elastic.co/guide/en/elasticsearch/painless/current/painless-api-reference-shared-java-lang.html#painless-api-reference-shared-Math
""" """
"script_fields": { "script_fields": {
@ -272,7 +306,10 @@ class ArithmeticOpFieldsTask(Task):
} }
""" """
if query_params["query_script_fields"] is None: if query_params["query_script_fields"] is None:
query_params["query_script_fields"] = dict() query_params["query_script_fields"] = {}
# TODO: Remove this once 'query_params' becomes a dataclass.
assert isinstance(query_params["query_script_fields"], dict)
if self._display_name in query_params["query_script_fields"]: if self._display_name in query_params["query_script_fields"]:
raise NotImplementedError( raise NotImplementedError(
@ -286,3 +323,11 @@ class ArithmeticOpFieldsTask(Task):
} }
return query_params, post_processing return query_params, post_processing
def __repr__(self) -> str:
return (
f"('{self._task_type}': ("
f"'display_name': {self._display_name}, "
f"'arithmetic_object': {self._arithmetic_series}"
f"))"
)

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import csv import csv
from typing import Union, List, Tuple, Optional, Mapping
import pandas as pd import pandas as pd
from pandas.io.parsers import _c_parser_defaults from pandas.io.parsers import _c_parser_defaults
@ -20,10 +21,14 @@ from pandas.io.parsers import _c_parser_defaults
from eland import DataFrame from eland import DataFrame
from eland.field_mappings import FieldMappings from eland.field_mappings import FieldMappings
from eland.common import ensure_es_client, DEFAULT_CHUNK_SIZE from eland.common import ensure_es_client, DEFAULT_CHUNK_SIZE
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk from elasticsearch.helpers import bulk
def read_es(es_client, es_index_pattern): def read_es(
es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch],
es_index_pattern: str,
) -> DataFrame:
""" """
Utility method to create an eland.Dataframe from an Elasticsearch index_pattern. Utility method to create an eland.Dataframe from an Elasticsearch index_pattern.
(Similar to pandas.read_csv, but source data is an Elasticsearch index rather than (Similar to pandas.read_csv, but source data is an Elasticsearch index rather than
@ -50,16 +55,16 @@ def read_es(es_client, es_index_pattern):
def pandas_to_eland( def pandas_to_eland(
pd_df, pd_df: pd.DataFrame,
es_client, es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch],
es_dest_index, es_dest_index: str,
es_if_exists="fail", es_if_exists: str = "fail",
es_refresh=False, es_refresh: bool = False,
es_dropna=False, es_dropna: bool = False,
es_type_overrides=None, es_type_overrides: Optional[Mapping[str, str]] = None,
chunksize=None, chunksize: Optional[int] = None,
use_pandas_index_for_es_ids=True, use_pandas_index_for_es_ids: bool = True,
): ) -> DataFrame:
""" """
Append a pandas DataFrame to an Elasticsearch index. Append a pandas DataFrame to an Elasticsearch index.
Mainly used in testing. Mainly used in testing.
@ -217,7 +222,7 @@ def pandas_to_eland(
return DataFrame(es_client, es_dest_index) return DataFrame(es_client, es_dest_index)
def eland_to_pandas(ed_df, show_progress=False): def eland_to_pandas(ed_df: DataFrame, show_progress: bool = False) -> pd.DataFrame:
""" """
Convert an eland.Dataframe to a pandas.DataFrame Convert an eland.Dataframe to a pandas.DataFrame

View File

@ -1,4 +1,5 @@
import os import os
import subprocess
from pathlib import Path from pathlib import Path
import nox import nox
import elasticsearch import elasticsearch
@ -12,6 +13,19 @@ SOURCE_FILES = (
"docs/", "docs/",
) )
# Whenever type hints are completed for a file, add that file here so that
# it continues to be checked by mypy. Errors reported for files that are
# not listed here are ignored.
TYPED_FILES = {
"eland/actions.py",
"eland/arithmetics.py",
"eland/common.py",
"eland/filter.py",
"eland/index.py",
"eland/query.py",
"eland/tasks.py",
}
@nox.session(reuse_venv=True) @nox.session(reuse_venv=True)
def blacken(session): def blacken(session):
@ -22,10 +36,28 @@ def blacken(session):
@nox.session(reuse_venv=True) @nox.session(reuse_venv=True)
def lint(session): def lint(session):
session.install("black", "flake8") session.install("black", "flake8", "mypy")
session.run("black", "--check", "--target-version=py36", *SOURCE_FILES) session.run("black", "--check", "--target-version=py36", *SOURCE_FILES)
session.run("flake8", "--ignore=E501,W503,E402,E712", *SOURCE_FILES) session.run("flake8", "--ignore=E501,W503,E402,E712", *SOURCE_FILES)
# TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/")
session.log("mypy --strict eland/")
# Run mypy once per typed file and keep only the diagnostics that point at a
# file listed in TYPED_FILES; output about any other file is dropped.
mypy_errors = set()
for typed_file in sorted(TYPED_FILES):
    # subprocess.run() drains the pipe while the child is running. Calling
    # Popen.wait() before reading a PIPE can deadlock as soon as the child
    # writes more than the OS pipe buffer holds.
    # An argument list (no shell=True) avoids any shell quoting issues.
    completed = subprocess.run(
        ["mypy", "--strict", typed_file],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    for line in completed.stdout.decode().splitlines():
        # The text before the first ':' is treated as the reported file path.
        filepath = line.partition(":")[0]
        if filepath in TYPED_FILES:
            mypy_errors.add(line)

# Report once, after every file was checked, so a single run shows all
# failures; the set removes diagnostics repeated across mypy invocations.
if mypy_errors:
    session.error("\n" + "\n".join(sorted(mypy_errors)))
@nox.session(python=["3.6", "3.7", "3.8"]) @nox.session(python=["3.6", "3.7", "3.8"])
def test(session): def test(session):