diff --git a/eland/actions.py b/eland/actions.py index 92cf4e0..ee1ceac 100644 --- a/eland/actions.py +++ b/eland/actions.py @@ -13,15 +13,16 @@ # limitations under the License. from abc import ABC, abstractmethod - -# -------------------------------------------------------------------------------------------------------------------- # -# PostProcessingActions # -# -------------------------------------------------------------------------------------------------------------------- # +from typing import TYPE_CHECKING from eland import SortOrder +if TYPE_CHECKING: + import pandas as pd # type: ignore + + class PostProcessingAction(ABC): - def __init__(self, action_type): + def __init__(self, action_type: str) -> None: """ Abstract class for postprocessing actions @@ -33,76 +34,74 @@ class PostProcessingAction(ABC): self._action_type = action_type @property - def type(self): + def type(self) -> str: return self._action_type @abstractmethod - def resolve_action(self, df): + def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame": pass @abstractmethod - def __repr__(self): + def __repr__(self) -> str: pass class SortIndexAction(PostProcessingAction): - def __init__(self): + def __init__(self) -> None: super().__init__("sort_index") - def resolve_action(self, df): + def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame": return df.sort_index() - def __repr__(self): + def __repr__(self) -> str: return f"('{self.type}')" class HeadAction(PostProcessingAction): - def __init__(self, count): + def __init__(self, count: int) -> None: super().__init__("head") self._count = count - def resolve_action(self, df): + def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame": return df.head(self._count) - def __repr__(self): + def __repr__(self) -> str: return f"('{self.type}': ('count': {self._count}))" class TailAction(PostProcessingAction): - def __init__(self, count): + def __init__(self, count: int) -> None: super().__init__("tail") self._count = count - def 
resolve_action(self, df): + def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame": return df.tail(self._count) - def __repr__(self): + def __repr__(self) -> str: return f"('{self.type}': ('count': {self._count}))" class SortFieldAction(PostProcessingAction): - def __init__(self, sort_params_string): + def __init__(self, sort_params_string: str) -> None: super().__init__("sort_field") if sort_params_string is None: raise ValueError("Expected valid string") # Split string - sort_params = sort_params_string.split(":") - if len(sort_params) != 2: + sort_field, _, sort_order = sort_params_string.partition(":") + if not sort_field or sort_order not in ("asc", "desc"): raise ValueError( f"Expected ES sort params string (e.g. _doc:desc). Got '{sort_params_string}'" ) - self._sort_field = sort_params[0] - self._sort_order = SortOrder.from_string(sort_params[1]) + self._sort_field = sort_field + self._sort_order = SortOrder.from_string(sort_order) - def resolve_action(self, df): - if self._sort_order == SortOrder.ASC: - return df.sort_values(self._sort_field, True) - return df.sort_values(self._sort_field, False) + def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame": + return df.sort_values(self._sort_field, self._sort_order == SortOrder.ASC) - def __repr__(self): + def __repr__(self) -> str: return f"('{self.type}': ('sort_field': '{self._sort_field}', 'sort_order': {self._sort_order}))" diff --git a/eland/arithmetics.py b/eland/arithmetics.py index 2be82ce..a821107 100644 --- a/eland/arithmetics.py +++ b/eland/arithmetics.py @@ -14,72 +14,88 @@ from abc import ABC, abstractmethod from io import StringIO +from typing import Union, List, TYPE_CHECKING, Any -import numpy as np +import numpy as np # type: ignore + +if TYPE_CHECKING: + from .query_compiler import QueryCompiler class ArithmeticObject(ABC): @property @abstractmethod - def value(self): + def value(self) -> str: pass @abstractmethod - def dtype(self): + def dtype(self) -> np.dtype: pass 
@abstractmethod - def resolve(self): + def resolve(self) -> str: pass @abstractmethod - def __repr__(self): + def __repr__(self) -> str: pass class ArithmeticString(ArithmeticObject): - def __init__(self, value): + def __init__(self, value: str): self._value = value - def resolve(self): + def resolve(self) -> str: return self.value @property - def dtype(self): + def dtype(self) -> np.dtype: return np.dtype(object) @property - def value(self): + def value(self) -> str: return f"'{self._value}'" - def __repr__(self): + def __repr__(self) -> str: return self.value class ArithmeticNumber(ArithmeticObject): - def __init__(self, value, dtype): + def __init__(self, value: Union[int, float], dtype: np.dtype): self._value = value self._dtype = dtype - def resolve(self): + def resolve(self) -> str: return self.value @property - def value(self): + def value(self) -> str: return f"{self._value}" @property - def dtype(self): + def dtype(self) -> np.dtype: return self._dtype - def __repr__(self): + def __repr__(self) -> str: return self.value class ArithmeticSeries(ArithmeticObject): - def __init__(self, query_compiler, display_name, dtype): + """Represents each item in a 'Series' by using painless scripts + to evaluate each document in an index as a part of a query. 
+ """ + + def __init__( + self, query_compiler: "QueryCompiler", display_name: str, dtype: np.dtype + ): + # type defs + self._value: str + self._tasks: List["ArithmeticTask"] + task = query_compiler.get_arithmetic_op_fields() + if task is not None: + assert isinstance(task._arithmetic_series, ArithmeticSeries) self._value = task._arithmetic_series.value self._tasks = task._arithmetic_series._tasks.copy() self._dtype = dtype @@ -98,14 +114,14 @@ class ArithmeticSeries(ArithmeticObject): self._dtype = dtype @property - def value(self): + def value(self) -> str: return self._value @property - def dtype(self): + def dtype(self) -> np.dtype: return self._dtype - def __repr__(self): + def __repr__(self) -> str: buf = StringIO() buf.write(f"Series: {self.value} ") buf.write("Tasks: ") @@ -113,7 +129,7 @@ class ArithmeticSeries(ArithmeticObject): buf.write(f"{task!r} ") return buf.getvalue() - def resolve(self): + def resolve(self) -> str: value = self._value for task in self._tasks: @@ -148,7 +164,7 @@ class ArithmeticSeries(ArithmeticObject): return value - def arithmetic_operation(self, op_name, right): + def arithmetic_operation(self, op_name: str, right: Any) -> "ArithmeticSeries": # check if operation is supported (raises on unsupported) self.check_is_supported(op_name, right) @@ -156,7 +172,7 @@ class ArithmeticSeries(ArithmeticObject): self._tasks.append(task) return self - def check_is_supported(self, op_name, right): + def check_is_supported(self, op_name: str, right: Any) -> bool: # supported set is # series.number op_name number (all ops) # series.string op_name string (only add) @@ -165,22 +181,20 @@ class ArithmeticSeries(ArithmeticObject): # series.int op_name string (none) # series.float op_name string (none) - # see end of https://pandas.pydata.org/pandas-docs/stable/getting_started/basics.html?highlight=dtype for - # dtype heirarchy - if np.issubdtype(self.dtype, np.number) and np.issubdtype( - right.dtype, np.number - ): + # see end of 
https://pandas.pydata.org/pandas-docs/stable/getting_started/basics.html?highlight=dtype + # for dtype hierarchy + right_is_number = np.issubdtype(right.dtype, np.number) + if np.issubdtype(self.dtype, np.number) and right_is_number: # series.number op_name number (all ops) return True - elif np.issubdtype(self.dtype, np.object_) and np.issubdtype( - right.dtype, np.object_ - ): + + self_is_object = np.issubdtype(self.dtype, np.object_) + if self_is_object and np.issubdtype(right.dtype, np.object_): # series.string op_name string (only add) if op_name == "__add__" or op_name == "__radd__": return True - elif np.issubdtype(self.dtype, np.object_) and np.issubdtype( - right.dtype, np.integer - ): + + if self_is_object and np.issubdtype(right.dtype, np.integer): # series.string op_name int (only mul) if op_name == "__mul__": return True @@ -191,22 +205,22 @@ class ArithmeticSeries(ArithmeticObject): class ArithmeticTask: - def __init__(self, op_name, object): + def __init__(self, op_name: str, object: ArithmeticObject): self._op_name = op_name if not isinstance(object, ArithmeticObject): raise TypeError(f"Task requires ArithmeticObject not {type(object)}") self._object = object - def __repr__(self): + def __repr__(self) -> str: buf = StringIO() buf.write(f"op_name: {self.op_name} object: {self.object!r} ") return buf.getvalue() @property - def op_name(self): + def op_name(self) -> str: return self._op_name @property - def object(self): + def object(self) -> ArithmeticObject: return self._object diff --git a/eland/common.py b/eland/common.py index fe19c09..30d1856 100644 --- a/eland/common.py +++ b/eland/common.py @@ -15,10 +15,10 @@ import re import warnings from enum import Enum -from typing import Union, List, Tuple +from typing import Union, List, Tuple, cast, Callable, Any -import pandas as pd -from elasticsearch import Elasticsearch +import pandas as pd # type: ignore +from elasticsearch import Elasticsearch # type: ignore # Default number of rows displayed (different to pandas
where ALL could be displayed) DEFAULT_NUM_ROWS_DISPLAYED = 60 @@ -29,8 +29,8 @@ DEFAULT_PROGRESS_REPORTING_NUM_ROWS = 10000 DEFAULT_ES_MAX_RESULT_WINDOW = 10000 # index.max_result_window -def docstring_parameter(*sub): - def dec(obj): +def docstring_parameter(*sub: Any) -> Callable[[Any], Any]: + def dec(obj: Any) -> Any: obj.__doc__ = obj.__doc__.format(*sub) return obj @@ -42,21 +42,21 @@ class SortOrder(Enum): DESC = 1 @staticmethod - def reverse(order): + def reverse(order: "SortOrder") -> "SortOrder": if order == SortOrder.ASC: return SortOrder.DESC return SortOrder.ASC @staticmethod - def to_string(order): + def to_string(order: "SortOrder") -> str: if order == SortOrder.ASC: return "asc" return "desc" @staticmethod - def from_string(order): + def from_string(order: str) -> "SortOrder": if order == "asc": return SortOrder.ASC @@ -276,11 +276,13 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]: property if one doesn't exist yet for the current Elasticsearch version. """ if not hasattr(es_client, "_eland_es_version"): - major, minor, patch = [ - int(x) - for x in re.match( - r"^(\d+)\.(\d+)\.(\d+)", es_client.info()["version"]["number"] - ).groups() - ] + version_info = es_client.info()["version"]["number"] + match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_info) + if match is None: + raise ValueError( + f"Unable to determine Elasticsearch version. " + f"Received: {version_info}" + ) + major, minor, patch = [int(x) for x in match.groups()] es_client._eland_es_version = (major, minor, patch) - return es_client._eland_es_version + return cast(Tuple[int, int, int], es_client._eland_es_version) diff --git a/eland/dataframe.py b/eland/dataframe.py index 719d1de..346619e 100644 --- a/eland/dataframe.py +++ b/eland/dataframe.py @@ -187,7 +187,7 @@ class DataFrame(NDFrame): """ return len(self.columns) == 0 or len(self.index) == 0 - def head(self, n=5): + def head(self, n: int = 5) -> "DataFrame": """ Return the first n rows. 
@@ -222,7 +222,7 @@ class DataFrame(NDFrame): """ return DataFrame(query_compiler=self._query_compiler.head(n)) - def tail(self, n=5): + def tail(self, n: int = 5) -> "DataFrame": """ Return the last n rows. diff --git a/eland/filter.py b/eland/filter.py index cf0387b..fe1de3d 100644 --- a/eland/filter.py +++ b/eland/filter.py @@ -14,12 +14,14 @@ # Originally based on code in MIT-licensed pandasticsearch filters +from typing import Dict, Any, List, Optional, Union, cast + class BooleanFilter: - def __init__(self, *args): - self._filter = None + def __init__(self) -> None: + self._filter: Dict[str, Any] = {} - def __and__(self, x): + def __and__(self, x: "BooleanFilter") -> "BooleanFilter": # Combine results if isinstance(self, AndFilter): if "must_not" in x.subtree: @@ -37,7 +39,7 @@ class BooleanFilter: return x return AndFilter(self, x) - def __or__(self, x): + def __or__(self, x: "BooleanFilter") -> "BooleanFilter": # Combine results if isinstance(self, OrFilter): if "must_not" in x.subtree: @@ -53,85 +55,79 @@ class BooleanFilter: return x return OrFilter(self, x) - def __invert__(self): + def __invert__(self) -> "BooleanFilter": return NotFilter(self) - def empty(self): - if self._filter is None: - return True - return False + def empty(self) -> bool: + return not bool(self._filter) - def __repr__(self): + def __repr__(self) -> str: return str(self._filter) @property - def subtree(self): + def subtree(self) -> Dict[str, Any]: if "bool" in self._filter: - return self._filter["bool"] + return cast(Dict[str, Any], self._filter["bool"]) else: return self._filter - def build(self): + def build(self) -> Dict[str, Any]: return self._filter # Binary operator class AndFilter(BooleanFilter): - def __init__(self, *args): - [isinstance(x, BooleanFilter) for x in args] + def __init__(self, *args: BooleanFilter) -> None: super().__init__() self._filter = {"bool": {"must": [x.build() for x in args]}} class OrFilter(BooleanFilter): - def __init__(self, *args): - 
[isinstance(x, BooleanFilter) for x in args] + def __init__(self, *args: BooleanFilter) -> None: super().__init__() self._filter = {"bool": {"should": [x.build() for x in args]}} class NotFilter(BooleanFilter): - def __init__(self, x): - assert isinstance(x, BooleanFilter) + def __init__(self, x: BooleanFilter) -> None: super().__init__() self._filter = {"bool": {"must_not": x.build()}} # LeafBooleanFilter class GreaterEqual(BooleanFilter): - def __init__(self, field, value): + def __init__(self, field: str, value: Any) -> None: super().__init__() self._filter = {"range": {field: {"gte": value}}} class Greater(BooleanFilter): - def __init__(self, field, value): + def __init__(self, field: str, value: Any) -> None: super().__init__() self._filter = {"range": {field: {"gt": value}}} class LessEqual(BooleanFilter): - def __init__(self, field, value): + def __init__(self, field: str, value: Any) -> None: super().__init__() self._filter = {"range": {field: {"lte": value}}} class Less(BooleanFilter): - def __init__(self, field, value): + def __init__(self, field: str, value: Any) -> None: super().__init__() self._filter = {"range": {field: {"lt": value}}} class Equal(BooleanFilter): - def __init__(self, field, value): + def __init__(self, field: str, value: Any) -> None: super().__init__() self._filter = {"term": {field: value}} class IsIn(BooleanFilter): - def __init__(self, field, value): + def __init__(self, field: str, value: List[Any]) -> None: super().__init__() - assert isinstance(value, list) if field == "ids": self._filter = {"ids": {"values": value}} else: @@ -139,39 +135,44 @@ class IsIn(BooleanFilter): class Like(BooleanFilter): - def __init__(self, field, value): + def __init__(self, field: str, value: str) -> None: super().__init__() self._filter = {"wildcard": {field: value}} class Rlike(BooleanFilter): - def __init__(self, field, value): + def __init__(self, field: str, value: str) -> None: super().__init__() self._filter = {"regexp": {field: value}} 
class Startswith(BooleanFilter): - def __init__(self, field, value): + def __init__(self, field: str, value: str) -> None: super().__init__() self._filter = {"prefix": {field: value}} class IsNull(BooleanFilter): - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__() self._filter = {"missing": {"field": field}} class NotNull(BooleanFilter): - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__() self._filter = {"exists": {"field": field}} class ScriptFilter(BooleanFilter): - def __init__(self, inline, lang=None, params=None): + def __init__( + self, + inline: str, + lang: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, + ) -> None: super().__init__() - script = {"source": inline} + script: Dict[str, Union[str, Dict[str, Any]]] = {"source": inline} if lang is not None: script["lang"] = lang if params is not None: @@ -180,6 +181,6 @@ class ScriptFilter(BooleanFilter): class QueryFilter(BooleanFilter): - def __init__(self, query): + def __init__(self, query: Dict[str, Any]) -> None: super().__init__() self._filter = query diff --git a/eland/index.py b/eland/index.py index 846c2e8..042370e 100644 --- a/eland/index.py +++ b/eland/index.py @@ -13,6 +13,12 @@ # limitations under the License. +from typing import Optional, TextIO, TYPE_CHECKING + +if TYPE_CHECKING: + from .query_compiler import QueryCompiler + + class Index: """ The index for an eland.DataFrame. 
@@ -33,30 +39,36 @@ class Index: ID_INDEX_FIELD = "_id" ID_SORT_FIELD = "_doc" # if index field is _id, sort by _doc - def __init__(self, query_compiler, index_field=None): + def __init__( + self, query_compiler: "QueryCompiler", index_field: Optional[str] = None + ): self._query_compiler = query_compiler # _is_source_field is set immediately within # index_field.setter self._is_source_field = False - self.index_field = index_field + + # The type:ignore is due to mypy not being smart enough + # to recognize the property.setter has a different type + # than the property.getter. + self.index_field = index_field # type: ignore @property - def sort_field(self): + def sort_field(self) -> str: if self._index_field == self.ID_INDEX_FIELD: return self.ID_SORT_FIELD return self._index_field @property - def is_source_field(self): + def is_source_field(self) -> bool: return self._is_source_field @property - def index_field(self): + def index_field(self) -> str: return self._index_field @index_field.setter - def index_field(self, index_field): + def index_field(self, index_field: Optional[str]) -> None: if index_field is None or index_field == Index.ID_INDEX_FIELD: self._index_field = Index.ID_INDEX_FIELD self._is_source_field = False @@ -64,18 +76,18 @@ class Index: self._index_field = index_field self._is_source_field = True - def __len__(self): + def __len__(self) -> int: return self._query_compiler._index_count() # Make iterable - def __next__(self): + def __next__(self) -> None: # TODO resolve this hack to make this 'iterable' raise StopIteration() - def __iter__(self): + def __iter__(self) -> "Index": return self - def info_es(self, buf): + def info_es(self, buf: TextIO) -> None: buf.write("Index:\n") buf.write(f" index_field: {self.index_field}\n") buf.write(f" is_source_field: {self.is_source_field}\n") diff --git a/eland/operations.py b/eland/operations.py index aa344ff..5c85d82 100644 --- a/eland/operations.py +++ b/eland/operations.py @@ -14,6 +14,7 @@ import copy 
import warnings +from typing import Optional import numpy as np @@ -98,7 +99,7 @@ class Operations: else: self._arithmetic_op_fields_task.update(display_name, arithmetic_series) - def get_arithmetic_op_fields(self): + def get_arithmetic_op_fields(self) -> Optional[ArithmeticOpFieldsTask]: # get an ArithmeticOpFieldsTask if it exists return self._arithmetic_op_fields_task diff --git a/eland/query.py b/eland/query.py index f695deb..ce712b0 100644 --- a/eland/query.py +++ b/eland/query.py @@ -14,6 +14,7 @@ import warnings from copy import deepcopy +from typing import Optional, Dict, List, Any from eland.filter import BooleanFilter, NotNull, IsNull, IsIn @@ -23,7 +24,11 @@ class Query: Simple class to manage building Elasticsearch queries. """ - def __init__(self, query=None): + def __init__(self, query: Optional["Query"] = None): + # type defs + self._query: BooleanFilter + self._aggs: Dict[str, Any] + if query is None: self._query = BooleanFilter() self._aggs = {} @@ -32,7 +37,7 @@ class Query: self._query = deepcopy(query._query) self._aggs = deepcopy(query._aggs) - def exists(self, field, must=True): + def exists(self, field: str, must: bool = True) -> None: """ Add exists query https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-exists-query.html @@ -48,7 +53,7 @@ class Query: else: self._query = self._query & IsNull(field) - def ids(self, items, must=True): + def ids(self, items: List[Any], must: bool = True) -> None: """ Add ids query https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-ids-query.html @@ -64,7 +69,7 @@ class Query: else: self._query = self._query & ~(IsIn("ids", items)) - def terms(self, field, items, must=True): + def terms(self, field: str, items: List[str], must: bool = True) -> None: """ Add ids query https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-terms-query.html @@ -80,7 +85,7 @@ class Query: else: self._query = self._query & ~(IsIn(field, items)) - def terms_aggs(self, 
name, func, field, es_size): + def terms_aggs(self, name: str, func: str, field: str, es_size: int) -> None: """ Add terms agg e.g @@ -96,7 +101,7 @@ class Query: agg = {func: {"field": field, "size": es_size}} self._aggs[name] = agg - def metric_aggs(self, name, func, field): + def metric_aggs(self, name: str, func: str, field: str) -> None: """ Add metric agg e.g @@ -111,7 +116,14 @@ class Query: agg = {func: {"field": field}} self._aggs[name] = agg - def hist_aggs(self, name, field, min_aggs, max_aggs, num_bins): + def hist_aggs( + self, + name: str, + field: str, + min_aggs: Dict[str, Any], + max_aggs: Dict[str, Any], + num_bins: int, + ) -> None: """ Add histogram agg e.g. "aggs": { @@ -127,14 +139,15 @@ class Query: max = max_aggs[field] interval = (max - min) / num_bins - offset = min - - agg = {"histogram": {"field": field, "interval": interval, "offset": offset}} if interval != 0: + offset = min + agg = { + "histogram": {"field": field, "interval": interval, "offset": offset} + } self._aggs[name] = agg - def to_search_body(self): + def to_search_body(self) -> Dict[str, Any]: body = {} if self._aggs: body["aggs"] = self._aggs @@ -142,19 +155,19 @@ class Query: body["query"] = self._query.build() return body - def to_count_body(self): + def to_count_body(self) -> Optional[Dict[str, Any]]: if len(self._aggs) > 0: - warnings.warn("Requesting count for agg query {}", self) + warnings.warn(f"Requesting count for agg query {self}") if self._query.empty(): return None else: return {"query": self._query.build()} - def update_boolean_filter(self, boolean_filter): + def update_boolean_filter(self, boolean_filter: BooleanFilter) -> None: if self._query.empty(): self._query = boolean_filter else: self._query = self._query & boolean_filter - def __repr__(self): + def __repr__(self) -> str: return repr(self.to_search_body()) diff --git a/eland/query_compiler.py b/eland/query_compiler.py index 2a4be4d..5d55076 100644 --- a/eland/query_compiler.py +++ 
b/eland/query_compiler.py @@ -14,6 +14,7 @@ import copy from datetime import datetime +from typing import Optional, TYPE_CHECKING import numpy as np import pandas as pd @@ -28,6 +29,9 @@ from eland.common import ( elasticsearch_date_to_pandas_date, ) +if TYPE_CHECKING: + from .tasks import ArithmeticOpFieldsTask # noqa: F401 + class QueryCompiler: """ @@ -348,7 +352,7 @@ class QueryCompiler: return out - def _index_count(self): + def _index_count(self) -> int: """ Returns ------- @@ -562,10 +566,10 @@ class QueryCompiler: return result - def get_arithmetic_op_fields(self): + def get_arithmetic_op_fields(self) -> Optional["ArithmeticOpFieldsTask"]: return self._operations.get_arithmetic_op_fields() - def display_name_to_aggregatable_name(self, display_name): + def display_name_to_aggregatable_name(self, display_name: str) -> str: aggregatable_field_name = self._mappings.aggregatable_field_name(display_name) return aggregatable_field_name diff --git a/eland/tasks.py b/eland/tasks.py index aa7b158..4086942 100644 --- a/eland/tasks.py +++ b/eland/tasks.py @@ -13,12 +13,22 @@ # limitations under the License. from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Dict, Any, Tuple from eland import SortOrder from eland.actions import HeadAction, TailAction, SortIndexAction from eland.arithmetics import ArithmeticSeries +if TYPE_CHECKING: + from .actions import PostProcessingAction # noqa: F401 + from .filter import BooleanFilter # noqa: F401 + from .query_compiler import QueryCompiler # noqa: F401 + +QUERY_PARAMS_TYPE = Dict[str, Any] +RESOLVED_TASK_TYPE = Tuple[QUERY_PARAMS_TYPE, List["PostProcessingAction"]] + + class Task(ABC): """ Abstract class for tasks @@ -29,44 +39,51 @@ class Task(ABC): The task type (e.g. head, tail etc.) 
""" - def __init__(self, task_type): + def __init__(self, task_type: str): self._task_type = task_type @property - def type(self): + def type(self) -> str: return self._task_type @abstractmethod - def resolve_task(self, query_params, post_processing, query_compiler): + def resolve_task( + self, + query_params: QUERY_PARAMS_TYPE, + post_processing: List["PostProcessingAction"], + query_compiler: "QueryCompiler", + ) -> RESOLVED_TASK_TYPE: pass @abstractmethod - def __repr__(self): + def __repr__(self) -> str: pass class SizeTask(Task): - def __init__(self, task_type): - super().__init__(task_type) - @abstractmethod - def size(self): + def size(self) -> int: # must override pass class HeadTask(SizeTask): - def __init__(self, sort_field, count): + def __init__(self, sort_field: str, count: int): super().__init__("head") # Add a task that is an ascending sort with size=count self._sort_field = sort_field self._count = count - def __repr__(self): + def __repr__(self) -> str: return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))" - def resolve_task(self, query_params, post_processing, query_compiler): + def resolve_task( + self, + query_params: QUERY_PARAMS_TYPE, + post_processing: List["PostProcessingAction"], + query_compiler: "QueryCompiler", + ) -> RESOLVED_TASK_TYPE: # head - sort asc, size n # |12345-------------| query_sort_field = self._sort_field @@ -97,22 +114,24 @@ class HeadTask(SizeTask): return query_params, post_processing - def size(self): + def size(self) -> int: return self._count class TailTask(SizeTask): - def __init__(self, sort_field, count): + def __init__(self, sort_field: str, count: int): super().__init__("tail") # Add a task that is descending sort with size=count self._sort_field = sort_field self._count = count - def __repr__(self): - return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))" - - def resolve_task(self, query_params, post_processing, query_compiler): + 
def resolve_task( + self, + query_params: QUERY_PARAMS_TYPE, + post_processing: List["PostProcessingAction"], + query_compiler: "QueryCompiler", + ) -> RESOLVED_TASK_TYPE: # tail - sort desc, size n, post-process sort asc # |-------------12345| query_sort_field = self._sort_field @@ -123,7 +142,10 @@ class TailTask(SizeTask): if ( query_params["query_size"] is not None and query_params["query_sort_order"] == query_sort_order - and post_processing == ["sort_index"] + and ( + len(post_processing) == 1 + and isinstance(post_processing[0], SortIndexAction) + ) ): if query_size < query_params["query_size"]: query_params["query_size"] = query_size @@ -156,12 +178,15 @@ class TailTask(SizeTask): return query_params, post_processing - def size(self): + def size(self) -> int: return self._count + def __repr__(self) -> str: + return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))" + class QueryIdsTask(Task): - def __init__(self, must, ids): + def __init__(self, must: bool, ids: List[str]): """ Parameters ---------- @@ -176,17 +201,21 @@ class QueryIdsTask(Task): self._must = must self._ids = ids - def resolve_task(self, query_params, post_processing, query_compiler): + def resolve_task( + self, + query_params: QUERY_PARAMS_TYPE, + post_processing: List["PostProcessingAction"], + query_compiler: "QueryCompiler", + ) -> RESOLVED_TASK_TYPE: query_params["query"].ids(self._ids, must=self._must) - return query_params, post_processing - def __repr__(self): + def __repr__(self) -> str: return f"('{self._task_type}': ('must': {self._must}, 'ids': {self._ids}))" class QueryTermsTask(Task): - def __init__(self, must, field, terms): + def __init__(self, must: bool, field: str, terms: List[Any]): """ Parameters ---------- @@ -205,20 +234,24 @@ class QueryTermsTask(Task): self._field = field self._terms = terms - def __repr__(self): + def resolve_task( + self, + query_params: QUERY_PARAMS_TYPE, + post_processing: List["PostProcessingAction"], + 
query_compiler: "QueryCompiler", + ) -> RESOLVED_TASK_TYPE: + query_params["query"].terms(self._field, self._terms, must=self._must) + return query_params, post_processing + + def __repr__(self) -> str: return ( f"('{self._task_type}': ('must': {self._must}, " f"'field': '{self._field}', 'terms': {self._terms}))" ) - def resolve_task(self, query_params, post_processing, query_compiler): - query_params["query"].terms(self._field, self._terms, must=self._must) - - return query_params, post_processing - class BooleanFilterTask(Task): - def __init__(self, boolean_filter): + def __init__(self, boolean_filter: "BooleanFilter"): """ Parameters ---------- @@ -229,17 +262,21 @@ class BooleanFilterTask(Task): self._boolean_filter = boolean_filter - def __repr__(self): - return f"('{self._task_type}': ('boolean_filter': {self._boolean_filter!r}))" - - def resolve_task(self, query_params, post_processing, query_compiler): + def resolve_task( + self, + query_params: QUERY_PARAMS_TYPE, + post_processing: List["PostProcessingAction"], + query_compiler: "QueryCompiler", + ) -> RESOLVED_TASK_TYPE: query_params["query"].update_boolean_filter(self._boolean_filter) - return query_params, post_processing + def __repr__(self) -> str: + return f"('{self._task_type}': ('boolean_filter': {self._boolean_filter!r}))" + class ArithmeticOpFieldsTask(Task): - def __init__(self, display_name, arithmetic_series): + def __init__(self, display_name: str, arithmetic_series: ArithmeticSeries): super().__init__("arithmetic_op_fields") self._display_name = display_name @@ -248,19 +285,16 @@ class ArithmeticOpFieldsTask(Task): raise TypeError(f"Expecting ArithmeticSeries got {type(arithmetic_series)}") self._arithmetic_series = arithmetic_series - def __repr__(self): - return ( - f"('{self._task_type}': (" - f"'display_name': {self._display_name}, " - f"'arithmetic_object': {self._arithmetic_series}" - f"))" - ) - - def update(self, display_name, arithmetic_series): + def update(self, display_name: str, 
arithmetic_series: ArithmeticSeries) -> None: self._display_name = display_name self._arithmetic_series = arithmetic_series - def resolve_task(self, query_params, post_processing, query_compiler): + def resolve_task( + self, + query_params: QUERY_PARAMS_TYPE, + post_processing: List["PostProcessingAction"], + query_compiler: "QueryCompiler", + ) -> RESOLVED_TASK_TYPE: # https://www.elastic.co/guide/en/elasticsearch/painless/current/painless-api-reference-shared-java-lang.html#painless-api-reference-shared-Math """ "script_fields": { @@ -272,7 +306,10 @@ class ArithmeticOpFieldsTask(Task): } """ if query_params["query_script_fields"] is None: - query_params["query_script_fields"] = dict() + query_params["query_script_fields"] = {} + + # TODO: Remove this once 'query_params' becomes a dataclass. + assert isinstance(query_params["query_script_fields"], dict) if self._display_name in query_params["query_script_fields"]: raise NotImplementedError( @@ -286,3 +323,11 @@ class ArithmeticOpFieldsTask(Task): } return query_params, post_processing + + def __repr__(self) -> str: + return ( + f"('{self._task_type}': (" + f"'display_name': {self._display_name}, " + f"'arithmetic_object': {self._arithmetic_series}" + f"))" + ) diff --git a/eland/utils.py b/eland/utils.py index e3e9a8f..37327ae 100644 --- a/eland/utils.py +++ b/eland/utils.py @@ -13,6 +13,7 @@ # limitations under the License. 
import csv +from typing import Union, List, Tuple, Optional, Mapping import pandas as pd from pandas.io.parsers import _c_parser_defaults @@ -20,10 +21,14 @@ from pandas.io.parsers import _c_parser_defaults from eland import DataFrame from eland.field_mappings import FieldMappings from eland.common import ensure_es_client, DEFAULT_CHUNK_SIZE +from elasticsearch import Elasticsearch from elasticsearch.helpers import bulk -def read_es(es_client, es_index_pattern): +def read_es( + es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch], + es_index_pattern: str, +) -> DataFrame: """ Utility method to create an eland.Dataframe from an Elasticsearch index_pattern. (Similar to pandas.read_csv, but source data is an Elasticsearch index rather than @@ -50,16 +55,16 @@ def read_es(es_client, es_index_pattern): def pandas_to_eland( - pd_df, - es_client, - es_dest_index, - es_if_exists="fail", - es_refresh=False, - es_dropna=False, - es_type_overrides=None, - chunksize=None, - use_pandas_index_for_es_ids=True, -): + pd_df: pd.DataFrame, + es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch], + es_dest_index: str, + es_if_exists: str = "fail", + es_refresh: bool = False, + es_dropna: bool = False, + es_type_overrides: Optional[Mapping[str, str]] = None, + chunksize: Optional[int] = None, + use_pandas_index_for_es_ids: bool = True, +) -> DataFrame: """ Append a pandas DataFrame to an Elasticsearch index. Mainly used in testing. 
@@ -217,7 +222,7 @@ def pandas_to_eland( return DataFrame(es_client, es_dest_index) -def eland_to_pandas(ed_df, show_progress=False): +def eland_to_pandas(ed_df: DataFrame, show_progress: bool = False) -> pd.DataFrame: """ Convert an eland.Dataframe to a pandas.DataFrame diff --git a/noxfile.py b/noxfile.py index ff80ad8..2d31e48 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,4 +1,5 @@ import os +import subprocess from pathlib import Path import nox import elasticsearch @@ -12,6 +13,19 @@ SOURCE_FILES = ( "docs/", ) +# Whenever type-hints are completed on a file it should +# be added here so that this file will continue to be checked +# by mypy. Errors from other files are ignored. +TYPED_FILES = { + "eland/actions.py", + "eland/arithmetics.py", + "eland/common.py", + "eland/filter.py", + "eland/index.py", + "eland/query.py", + "eland/tasks.py", +} + @nox.session(reuse_venv=True) def blacken(session): @@ -22,10 +36,28 @@ def blacken(session): @nox.session(reuse_venv=True) def lint(session): - session.install("black", "flake8") + session.install("black", "flake8", "mypy") session.run("black", "--check", "--target-version=py36", *SOURCE_FILES) session.run("flake8", "--ignore=E501,W503,E402,E712", *SOURCE_FILES) + # TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/") + session.log("mypy --strict eland/") + for typed_file in TYPED_FILES: + popen = subprocess.Popen( + f"mypy --strict {typed_file}", + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + output, _ = popen.communicate() + errors = [] + for line in output.decode().split("\n"): + filepath = line.partition(":")[0] + if filepath in TYPED_FILES: + errors.append(line) + if errors: + session.error("\n" + "\n".join(sorted(set(errors)))) + @nox.session(python=["3.6", "3.7", "3.8"]) def test(session):