Add type hints to base modules

Seth Michael Larson 2020-04-24 12:39:13 -05:00 committed by GitHub
parent fe6589ae6a
commit 33b4976f9a
12 changed files with 333 additions and 205 deletions

View File: eland/actions.py

@ -13,15 +13,16 @@
# limitations under the License.
from abc import ABC, abstractmethod
# -------------------------------------------------------------------------------------------------------------------- #
# PostProcessingActions #
# -------------------------------------------------------------------------------------------------------------------- #
from typing import TYPE_CHECKING
from eland import SortOrder
if TYPE_CHECKING:
import pandas as pd # type: ignore
class PostProcessingAction(ABC):
def __init__(self, action_type):
def __init__(self, action_type: str) -> None:
"""
Abstract class for postprocessing actions
@ -33,76 +34,74 @@ class PostProcessingAction(ABC):
self._action_type = action_type
@property
def type(self):
def type(self) -> str:
return self._action_type
@abstractmethod
def resolve_action(self, df):
def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
pass
@abstractmethod
def __repr__(self):
def __repr__(self) -> str:
pass
class SortIndexAction(PostProcessingAction):
def __init__(self):
def __init__(self) -> None:
super().__init__("sort_index")
def resolve_action(self, df):
def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df.sort_index()
def __repr__(self):
def __repr__(self) -> str:
return f"('{self.type}')"
class HeadAction(PostProcessingAction):
def __init__(self, count):
def __init__(self, count: int) -> None:
super().__init__("head")
self._count = count
def resolve_action(self, df):
def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df.head(self._count)
def __repr__(self):
def __repr__(self) -> str:
return f"('{self.type}': ('count': {self._count}))"
class TailAction(PostProcessingAction):
def __init__(self, count):
def __init__(self, count: int) -> None:
super().__init__("tail")
self._count = count
def resolve_action(self, df):
def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df.tail(self._count)
def __repr__(self):
def __repr__(self) -> str:
return f"('{self.type}': ('count': {self._count}))"
class SortFieldAction(PostProcessingAction):
def __init__(self, sort_params_string):
def __init__(self, sort_params_string: str) -> None:
super().__init__("sort_field")
if sort_params_string is None:
raise ValueError("Expected valid string")
# Split string
sort_params = sort_params_string.split(":")
if len(sort_params) != 2:
sort_field, _, sort_order = sort_params_string.partition(":")
if not sort_field or sort_order not in ("asc", "desc"):
raise ValueError(
f"Expected ES sort params string (e.g. _doc:desc). Got '{sort_params_string}'"
)
self._sort_field = sort_params[0]
self._sort_order = SortOrder.from_string(sort_params[1])
self._sort_field = sort_field
self._sort_order = SortOrder.from_string(sort_order)
def resolve_action(self, df):
if self._sort_order == SortOrder.ASC:
return df.sort_values(self._sort_field, True)
return df.sort_values(self._sort_field, False)
def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df.sort_values(self._sort_field, self._sort_order == SortOrder.ASC)
def __repr__(self):
def __repr__(self) -> str:
return f"('{self.type}': ('sort_field': '{self._sort_field}', 'sort_order': {self._sort_order}))"

View File: eland/arithmetics.py

@ -14,72 +14,88 @@
from abc import ABC, abstractmethod
from io import StringIO
from typing import Union, List, TYPE_CHECKING, Any
import numpy as np
import numpy as np # type: ignore
if TYPE_CHECKING:
from .query_compiler import QueryCompiler
class ArithmeticObject(ABC):
@property
@abstractmethod
def value(self):
def value(self) -> str:
pass
@abstractmethod
def dtype(self):
def dtype(self) -> np.dtype:
pass
@abstractmethod
def resolve(self):
def resolve(self) -> str:
pass
@abstractmethod
def __repr__(self):
def __repr__(self) -> str:
pass
class ArithmeticString(ArithmeticObject):
def __init__(self, value):
def __init__(self, value: str):
self._value = value
def resolve(self):
def resolve(self) -> str:
return self.value
@property
def dtype(self):
def dtype(self) -> np.dtype:
return np.dtype(object)
@property
def value(self):
def value(self) -> str:
return f"'{self._value}'"
def __repr__(self):
def __repr__(self) -> str:
return self.value
class ArithmeticNumber(ArithmeticObject):
def __init__(self, value, dtype):
def __init__(self, value: Union[int, float], dtype: np.dtype):
self._value = value
self._dtype = dtype
def resolve(self):
def resolve(self) -> str:
return self.value
@property
def value(self):
def value(self) -> str:
return f"{self._value}"
@property
def dtype(self):
def dtype(self) -> np.dtype:
return self._dtype
def __repr__(self):
def __repr__(self) -> str:
return self.value
class ArithmeticSeries(ArithmeticObject):
def __init__(self, query_compiler, display_name, dtype):
"""Represents each item in a 'Series' by using painless scripts
to evaluate each document in an index as a part of a query.
"""
def __init__(
self, query_compiler: "QueryCompiler", display_name: str, dtype: np.dtype
):
# type defs
self._value: str
self._tasks: List["ArithmeticTask"]
task = query_compiler.get_arithmetic_op_fields()
if task is not None:
assert isinstance(task._arithmetic_series, ArithmeticSeries)
self._value = task._arithmetic_series.value
self._tasks = task._arithmetic_series._tasks.copy()
self._dtype = dtype
@ -98,14 +114,14 @@ class ArithmeticSeries(ArithmeticObject):
self._dtype = dtype
@property
def value(self):
def value(self) -> str:
return self._value
@property
def dtype(self):
def dtype(self) -> np.dtype:
return self._dtype
def __repr__(self):
def __repr__(self) -> str:
buf = StringIO()
buf.write(f"Series: {self.value} ")
buf.write("Tasks: ")
@ -113,7 +129,7 @@ class ArithmeticSeries(ArithmeticObject):
buf.write(f"{task!r} ")
return buf.getvalue()
def resolve(self):
def resolve(self) -> str:
value = self._value
for task in self._tasks:
@ -148,7 +164,7 @@ class ArithmeticSeries(ArithmeticObject):
return value
def arithmetic_operation(self, op_name, right):
def arithmetic_operation(self, op_name: str, right: Any) -> "ArithmeticSeries":
# check if operation is supported (raises on unsupported)
self.check_is_supported(op_name, right)
@ -156,7 +172,7 @@ class ArithmeticSeries(ArithmeticObject):
self._tasks.append(task)
return self
def check_is_supported(self, op_name, right):
def check_is_supported(self, op_name: str, right: Any) -> bool:
# supported set is
# series.number op_name number (all ops)
# series.string op_name string (only add)
@ -165,22 +181,20 @@ class ArithmeticSeries(ArithmeticObject):
# series.int op_name string (none)
# series.float op_name string (none)
# see end of https://pandas.pydata.org/pandas-docs/stable/getting_started/basics.html?highlight=dtype for
# dtype heirarchy
if np.issubdtype(self.dtype, np.number) and np.issubdtype(
right.dtype, np.number
):
# see end of https://pandas.pydata.org/pandas-docs/stable/getting_started/basics.html?highlight=dtype
# for dtype hierarchy
right_is_integer = np.issubdtype(right.dtype, np.number)
if np.issubdtype(self.dtype, np.number) and right_is_integer:
# series.number op_name number (all ops)
return True
elif np.issubdtype(self.dtype, np.object_) and np.issubdtype(
right.dtype, np.object_
):
self_is_object = np.issubdtype(self.dtype, np.object_)
if self_is_object and np.issubdtype(right.dtype, np.object_):
# series.string op_name string (only add)
if op_name == "__add__" or op_name == "__radd__":
return True
elif np.issubdtype(self.dtype, np.object_) and np.issubdtype(
right.dtype, np.integer
):
if self_is_object and right_is_integer:
# series.string op_name int (only mul)
if op_name == "__mul__":
return True
@ -191,22 +205,22 @@ class ArithmeticSeries(ArithmeticObject):
class ArithmeticTask:
def __init__(self, op_name, object):
def __init__(self, op_name: str, object: ArithmeticObject):
self._op_name = op_name
if not isinstance(object, ArithmeticObject):
raise TypeError(f"Task requires ArithmeticObject not {type(object)}")
self._object = object
def __repr__(self):
def __repr__(self) -> str:
buf = StringIO()
buf.write(f"op_name: {self.op_name} object: {self.object!r} ")
return buf.getvalue()
@property
def op_name(self):
def op_name(self) -> str:
return self._op_name
@property
def object(self):
def object(self) -> ArithmeticObject:
return self._object
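
In the eland/arithmetics.py diff above the return annotations sit on the @property/@abstractmethod members of the base class, which means mypy --strict now checks every concrete subclass (ArithmeticString, ArithmeticNumber, ArithmeticSeries) against them. A compressed sketch of that pattern using hypothetical names (Operand, Literal):

from abc import ABC, abstractmethod

import numpy as np


class Operand(ABC):
    @property
    @abstractmethod
    def value(self) -> str:
        ...

    @property
    @abstractmethod
    def dtype(self) -> np.dtype:
        ...


class Literal(Operand):
    def __init__(self, value: str):
        self._value = value

    @property
    def value(self) -> str:
        return f"'{self._value}'"

    @property
    def dtype(self) -> np.dtype:
        # Returning e.g. a plain str here would now be flagged by mypy.
        return np.dtype(object)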

View File: eland/common.py

@ -15,10 +15,10 @@
import re
import warnings
from enum import Enum
from typing import Union, List, Tuple
from typing import Union, List, Tuple, cast, Callable, Any
import pandas as pd
from elasticsearch import Elasticsearch
import pandas as pd # type: ignore
from elasticsearch import Elasticsearch # type: ignore
# Default number of rows displayed (different to pandas where ALL could be displayed)
DEFAULT_NUM_ROWS_DISPLAYED = 60
@ -29,8 +29,8 @@ DEFAULT_PROGRESS_REPORTING_NUM_ROWS = 10000
DEFAULT_ES_MAX_RESULT_WINDOW = 10000 # index.max_result_window
def docstring_parameter(*sub):
def dec(obj):
def docstring_parameter(*sub: Any) -> Callable[[Any], Any]:
def dec(obj: Any) -> Any:
obj.__doc__ = obj.__doc__.format(*sub)
return obj
@ -42,21 +42,21 @@ class SortOrder(Enum):
DESC = 1
@staticmethod
def reverse(order):
def reverse(order: "SortOrder") -> "SortOrder":
if order == SortOrder.ASC:
return SortOrder.DESC
return SortOrder.ASC
@staticmethod
def to_string(order):
def to_string(order: "SortOrder") -> str:
if order == SortOrder.ASC:
return "asc"
return "desc"
@staticmethod
def from_string(order):
def from_string(order: str) -> "SortOrder":
if order == "asc":
return SortOrder.ASC
@ -276,11 +276,13 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
property if one doesn't exist yet for the current Elasticsearch version.
"""
if not hasattr(es_client, "_eland_es_version"):
major, minor, patch = [
int(x)
for x in re.match(
r"^(\d+)\.(\d+)\.(\d+)", es_client.info()["version"]["number"]
).groups()
]
version_info = es_client.info()["version"]["number"]
match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_info)
if match is None:
raise ValueError(
f"Unable to determine Elasticsearch version. "
f"Received: {version_info}"
)
major, minor, patch = [int(x) for x in match.groups()]
es_client._eland_es_version = (major, minor, patch)
return es_client._eland_es_version
return cast(Tuple[int, int, int], es_client._eland_es_version)
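
The es_version() rewrite above shows two chores that mypy --strict imposes on otherwise working code: re.match() is typed Optional[Match[str]], so the None case must be handled before calling .groups(), and an attribute cached on a third-party object (the Elasticsearch client) comes back as Any, so cast() is needed to recover the precise tuple type. A standalone sketch of the same shape (cached_version and _cached_version are hypothetical names):

import re
from typing import Any, Tuple, cast


def cached_version(client: Any, raw: str = "7.6.2") -> Tuple[int, int, int]:
    # 'raw' stands in for client.info()["version"]["number"].
    if not hasattr(client, "_cached_version"):
        match = re.match(r"^(\d+)\.(\d+)\.(\d+)", raw)
        if match is None:  # re.match returns Optional[Match[str]]
            raise ValueError(f"Unable to determine version from {raw!r}")
        client._cached_version = tuple(int(x) for x in match.groups())
    # The cached attribute is typed Any; cast() is a runtime no-op that
    # only tells the type checker what the tuple actually is.
    return cast(Tuple[int, int, int], client._cached_version)


class FakeClient:
    pass


print(cached_version(FakeClient()))  # (7, 6, 2)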

View File: eland/dataframe.py

@ -187,7 +187,7 @@ class DataFrame(NDFrame):
"""
return len(self.columns) == 0 or len(self.index) == 0
def head(self, n=5):
def head(self, n: int = 5) -> "DataFrame":
"""
Return the first n rows.
@ -222,7 +222,7 @@ class DataFrame(NDFrame):
"""
return DataFrame(query_compiler=self._query_compiler.head(n))
def tail(self, n=5):
def tail(self, n: int = 5) -> "DataFrame":
"""
Return the last n rows.

View File: eland/filter.py

@ -14,12 +14,14 @@
# Originally based on code in MIT-licensed pandasticsearch filters
from typing import Dict, Any, List, Optional, Union, cast
class BooleanFilter:
def __init__(self, *args):
self._filter = None
def __init__(self) -> None:
self._filter: Dict[str, Any] = {}
def __and__(self, x):
def __and__(self, x: "BooleanFilter") -> "BooleanFilter":
# Combine results
if isinstance(self, AndFilter):
if "must_not" in x.subtree:
@ -37,7 +39,7 @@ class BooleanFilter:
return x
return AndFilter(self, x)
def __or__(self, x):
def __or__(self, x: "BooleanFilter") -> "BooleanFilter":
# Combine results
if isinstance(self, OrFilter):
if "must_not" in x.subtree:
@ -53,85 +55,79 @@ class BooleanFilter:
return x
return OrFilter(self, x)
def __invert__(self):
def __invert__(self) -> "BooleanFilter":
return NotFilter(self)
def empty(self):
if self._filter is None:
return True
return False
def empty(self) -> bool:
return not bool(self._filter)
def __repr__(self):
def __repr__(self) -> str:
return str(self._filter)
@property
def subtree(self):
def subtree(self) -> Dict[str, Any]:
if "bool" in self._filter:
return self._filter["bool"]
return cast(Dict[str, Any], self._filter["bool"])
else:
return self._filter
def build(self):
def build(self) -> Dict[str, Any]:
return self._filter
# Binary operator
class AndFilter(BooleanFilter):
def __init__(self, *args):
[isinstance(x, BooleanFilter) for x in args]
def __init__(self, *args: BooleanFilter) -> None:
super().__init__()
self._filter = {"bool": {"must": [x.build() for x in args]}}
class OrFilter(BooleanFilter):
def __init__(self, *args):
[isinstance(x, BooleanFilter) for x in args]
def __init__(self, *args: BooleanFilter) -> None:
super().__init__()
self._filter = {"bool": {"should": [x.build() for x in args]}}
class NotFilter(BooleanFilter):
def __init__(self, x):
assert isinstance(x, BooleanFilter)
def __init__(self, x: BooleanFilter) -> None:
super().__init__()
self._filter = {"bool": {"must_not": x.build()}}
# LeafBooleanFilter
class GreaterEqual(BooleanFilter):
def __init__(self, field, value):
def __init__(self, field: str, value: Any) -> None:
super().__init__()
self._filter = {"range": {field: {"gte": value}}}
class Greater(BooleanFilter):
def __init__(self, field, value):
def __init__(self, field: str, value: Any) -> None:
super().__init__()
self._filter = {"range": {field: {"gt": value}}}
class LessEqual(BooleanFilter):
def __init__(self, field, value):
def __init__(self, field: str, value: Any) -> None:
super().__init__()
self._filter = {"range": {field: {"lte": value}}}
class Less(BooleanFilter):
def __init__(self, field, value):
def __init__(self, field: str, value: Any) -> None:
super().__init__()
self._filter = {"range": {field: {"lt": value}}}
class Equal(BooleanFilter):
def __init__(self, field, value):
def __init__(self, field: str, value: Any) -> None:
super().__init__()
self._filter = {"term": {field: value}}
class IsIn(BooleanFilter):
def __init__(self, field, value):
def __init__(self, field: str, value: List[Any]) -> None:
super().__init__()
assert isinstance(value, list)
if field == "ids":
self._filter = {"ids": {"values": value}}
else:
@ -139,39 +135,44 @@ class IsIn(BooleanFilter):
class Like(BooleanFilter):
def __init__(self, field, value):
def __init__(self, field: str, value: str) -> None:
super().__init__()
self._filter = {"wildcard": {field: value}}
class Rlike(BooleanFilter):
def __init__(self, field, value):
def __init__(self, field: str, value: str) -> None:
super().__init__()
self._filter = {"regexp": {field: value}}
class Startswith(BooleanFilter):
def __init__(self, field, value):
def __init__(self, field: str, value: str) -> None:
super().__init__()
self._filter = {"prefix": {field: value}}
class IsNull(BooleanFilter):
def __init__(self, field):
def __init__(self, field: str) -> None:
super().__init__()
self._filter = {"missing": {"field": field}}
class NotNull(BooleanFilter):
def __init__(self, field):
def __init__(self, field: str) -> None:
super().__init__()
self._filter = {"exists": {"field": field}}
class ScriptFilter(BooleanFilter):
def __init__(self, inline, lang=None, params=None):
def __init__(
self,
inline: str,
lang: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()
script = {"source": inline}
script: Dict[str, Union[str, Dict[str, Any]]] = {"source": inline}
if lang is not None:
script["lang"] = lang
if params is not None:
@ -180,6 +181,6 @@ class ScriptFilter(BooleanFilter):
class QueryFilter(BooleanFilter):
def __init__(self, query):
def __init__(self, query: Dict[str, Any]) -> None:
super().__init__()
self._filter = query
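
Beyond the added signatures, two details in the eland/filter.py diff above are worth noting: the root _filter is now an empty dict rather than None, so empty() reduces to not bool(self._filter), and ScriptFilter annotates script: Dict[str, Union[str, Dict[str, Any]]] explicitly because, from the literal {"source": inline} alone, mypy would infer Dict[str, str] and reject the later script["params"] = params assignment. A standalone sketch of that second point (build_script is a hypothetical helper, not eland API):

from typing import Any, Dict, Optional, Union


def build_script(
    inline: str,
    lang: Optional[str] = None,
    params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    # Without this annotation mypy infers Dict[str, str] from the literal
    # and flags the later assignment of a dict to script["params"].
    script: Dict[str, Union[str, Dict[str, Any]]] = {"source": inline}
    if lang is not None:
        script["lang"] = lang
    if params is not None:
        script["params"] = params
    return {"script": script}


print(build_script("doc['a'].value * params.factor", params={"factor": 2}))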

View File: eland/index.py

@ -13,6 +13,12 @@
# limitations under the License.
from typing import Optional, TextIO, TYPE_CHECKING
if TYPE_CHECKING:
from .query_compiler import QueryCompiler
class Index:
"""
The index for an eland.DataFrame.
@ -33,30 +39,36 @@ class Index:
ID_INDEX_FIELD = "_id"
ID_SORT_FIELD = "_doc" # if index field is _id, sort by _doc
def __init__(self, query_compiler, index_field=None):
def __init__(
self, query_compiler: "QueryCompiler", index_field: Optional[str] = None
):
self._query_compiler = query_compiler
# _is_source_field is set immediately within
# index_field.setter
self._is_source_field = False
self.index_field = index_field
# The type:ignore is due to mypy not being smart enough
# to recognize the property.setter has a different type
# than the property.getter.
self.index_field = index_field # type: ignore
@property
def sort_field(self):
def sort_field(self) -> str:
if self._index_field == self.ID_INDEX_FIELD:
return self.ID_SORT_FIELD
return self._index_field
@property
def is_source_field(self):
def is_source_field(self) -> bool:
return self._is_source_field
@property
def index_field(self):
def index_field(self) -> str:
return self._index_field
@index_field.setter
def index_field(self, index_field):
def index_field(self, index_field: Optional[str]) -> None:
if index_field is None or index_field == Index.ID_INDEX_FIELD:
self._index_field = Index.ID_INDEX_FIELD
self._is_source_field = False
@ -64,18 +76,18 @@ class Index:
self._index_field = index_field
self._is_source_field = True
def __len__(self):
def __len__(self) -> int:
return self._query_compiler._index_count()
# Make iterable
def __next__(self):
def __next__(self) -> None:
# TODO resolve this hack to make this 'iterable'
raise StopIteration()
def __iter__(self):
def __iter__(self) -> "Index":
return self
def info_es(self, buf):
def info_es(self, buf: TextIO) -> None:
buf.write("Index:\n")
buf.write(f" index_field: {self.index_field}\n")
buf.write(f" is_source_field: {self.is_source_field}\n")

View File: eland/operations.py

@ -14,6 +14,7 @@
import copy
import warnings
from typing import Optional
import numpy as np
@ -98,7 +99,7 @@ class Operations:
else:
self._arithmetic_op_fields_task.update(display_name, arithmetic_series)
def get_arithmetic_op_fields(self):
def get_arithmetic_op_fields(self) -> Optional[ArithmeticOpFieldsTask]:
# get an ArithmeticOpFieldsTask if it exists
return self._arithmetic_op_fields_task

View File: eland/query.py

@ -14,6 +14,7 @@
import warnings
from copy import deepcopy
from typing import Optional, Dict, List, Any
from eland.filter import BooleanFilter, NotNull, IsNull, IsIn
@ -23,7 +24,11 @@ class Query:
Simple class to manage building Elasticsearch queries.
"""
def __init__(self, query=None):
def __init__(self, query: Optional["Query"] = None):
# type defs
self._query: BooleanFilter
self._aggs: Dict[str, Any]
if query is None:
self._query = BooleanFilter()
self._aggs = {}
@ -32,7 +37,7 @@ class Query:
self._query = deepcopy(query._query)
self._aggs = deepcopy(query._aggs)
def exists(self, field, must=True):
def exists(self, field: str, must: bool = True) -> None:
"""
Add exists query
https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-exists-query.html
@ -48,7 +53,7 @@ class Query:
else:
self._query = self._query & IsNull(field)
def ids(self, items, must=True):
def ids(self, items: List[Any], must: bool = True) -> None:
"""
Add ids query
https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-ids-query.html
@ -64,7 +69,7 @@ class Query:
else:
self._query = self._query & ~(IsIn("ids", items))
def terms(self, field, items, must=True):
def terms(self, field: str, items: List[str], must: bool = True) -> None:
"""
Add ids query
https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-terms-query.html
@ -80,7 +85,7 @@ class Query:
else:
self._query = self._query & ~(IsIn(field, items))
def terms_aggs(self, name, func, field, es_size):
def terms_aggs(self, name: str, func: str, field: str, es_size: int) -> None:
"""
Add terms agg e.g
@ -96,7 +101,7 @@ class Query:
agg = {func: {"field": field, "size": es_size}}
self._aggs[name] = agg
def metric_aggs(self, name, func, field):
def metric_aggs(self, name: str, func: str, field: str) -> None:
"""
Add metric agg e.g
@ -111,7 +116,14 @@ class Query:
agg = {func: {"field": field}}
self._aggs[name] = agg
def hist_aggs(self, name, field, min_aggs, max_aggs, num_bins):
def hist_aggs(
self,
name: str,
field: str,
min_aggs: Dict[str, Any],
max_aggs: Dict[str, Any],
num_bins: int,
) -> None:
"""
Add histogram agg e.g.
"aggs": {
@ -127,14 +139,15 @@ class Query:
max = max_aggs[field]
interval = (max - min) / num_bins
offset = min
agg = {"histogram": {"field": field, "interval": interval, "offset": offset}}
if interval != 0:
offset = min
agg = {
"histogram": {"field": field, "interval": interval, "offset": offset}
}
self._aggs[name] = agg
def to_search_body(self):
def to_search_body(self) -> Dict[str, Any]:
body = {}
if self._aggs:
body["aggs"] = self._aggs
@ -142,19 +155,19 @@ class Query:
body["query"] = self._query.build()
return body
def to_count_body(self):
def to_count_body(self) -> Optional[Dict[str, Any]]:
if len(self._aggs) > 0:
warnings.warn("Requesting count for agg query {}", self)
warnings.warn(f"Requesting count for agg query {self}")
if self._query.empty():
return None
else:
return {"query": self._query.build()}
def update_boolean_filter(self, boolean_filter):
def update_boolean_filter(self, boolean_filter: BooleanFilter) -> None:
if self._query.empty():
self._query = boolean_filter
else:
self._query = self._query & boolean_filter
def __repr__(self):
def __repr__(self) -> str:
return repr(self.to_search_body())
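
Two changes in the eland/query.py diff above go beyond annotations. First, the bare declarations self._query: BooleanFilter and self._aggs: Dict[str, Any] give each attribute a single declared type even though the copy-constructor assigns them in different branches. Second, the old warnings.warn("Requesting count for agg query {}", self) passed the Query object where warnings.warn expects a Warning subclass as its second argument; the f-string version fixes that. A hypothetical builder showing the copy-constructor pattern:

from copy import deepcopy
from typing import Any, Dict, Optional


class SearchBuilder:  # hypothetical stand-in for eland's Query
    def __init__(self, other: Optional["SearchBuilder"] = None) -> None:
        # Bare declaration: one declared type for the attribute, regardless of
        # which branch below performs the assignment.
        self._aggs: Dict[str, Any]
        if other is None:
            self._aggs = {}
        else:
            self._aggs = deepcopy(other._aggs)

    def metric_agg(self, name: str, func: str, field: str) -> None:
        self._aggs[name] = {func: {"field": field}}

    def to_search_body(self) -> Dict[str, Any]:
        body: Dict[str, Any] = {}
        if self._aggs:
            body["aggs"] = self._aggs
        return body


original = SearchBuilder()
original.metric_agg("avg_price", "avg", "price")
copied = SearchBuilder(original)  # deep copy, like Query(query)
print(copied.to_search_body())    # {'aggs': {'avg_price': {'avg': {'field': 'price'}}}}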

View File: eland/query_compiler.py

@ -14,6 +14,7 @@
import copy
from datetime import datetime
from typing import Optional, TYPE_CHECKING
import numpy as np
import pandas as pd
@ -28,6 +29,9 @@ from eland.common import (
elasticsearch_date_to_pandas_date,
)
if TYPE_CHECKING:
from .tasks import ArithmeticOpFieldsTask # noqa: F401
class QueryCompiler:
"""
@ -348,7 +352,7 @@ class QueryCompiler:
return out
def _index_count(self):
def _index_count(self) -> int:
"""
Returns
-------
@ -562,10 +566,10 @@ class QueryCompiler:
return result
def get_arithmetic_op_fields(self):
def get_arithmetic_op_fields(self) -> Optional["ArithmeticOpFieldsTask"]:
return self._operations.get_arithmetic_op_fields()
def display_name_to_aggregatable_name(self, display_name):
def display_name_to_aggregatable_name(self, display_name: str) -> str:
aggregatable_field_name = self._mappings.aggregatable_field_name(display_name)
return aggregatable_field_name

View File: eland/tasks.py

@ -13,12 +13,22 @@
# limitations under the License.
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Dict, Any, Tuple
from eland import SortOrder
from eland.actions import HeadAction, TailAction, SortIndexAction
from eland.arithmetics import ArithmeticSeries
if TYPE_CHECKING:
from .actions import PostProcessingAction # noqa: F401
from .filter import BooleanFilter # noqa: F401
from .query_compiler import QueryCompiler # noqa: F401
QUERY_PARAMS_TYPE = Dict[str, Any]
RESOLVED_TASK_TYPE = Tuple[QUERY_PARAMS_TYPE, List["PostProcessingAction"]]
class Task(ABC):
"""
Abstract class for tasks
@ -29,44 +39,51 @@ class Task(ABC):
The task type (e.g. head, tail etc.)
"""
def __init__(self, task_type):
def __init__(self, task_type: str):
self._task_type = task_type
@property
def type(self):
def type(self) -> str:
return self._task_type
@abstractmethod
def resolve_task(self, query_params, post_processing, query_compiler):
def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
pass
@abstractmethod
def __repr__(self):
def __repr__(self) -> str:
pass
class SizeTask(Task):
def __init__(self, task_type):
super().__init__(task_type)
@abstractmethod
def size(self):
def size(self) -> int:
# must override
pass
class HeadTask(SizeTask):
def __init__(self, sort_field, count):
def __init__(self, sort_field: str, count: int):
super().__init__("head")
# Add a task that is an ascending sort with size=count
self._sort_field = sort_field
self._count = count
def __repr__(self):
def __repr__(self) -> str:
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
def resolve_task(self, query_params, post_processing, query_compiler):
def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
# head - sort asc, size n
# |12345-------------|
query_sort_field = self._sort_field
@ -97,22 +114,24 @@ class HeadTask(SizeTask):
return query_params, post_processing
def size(self):
def size(self) -> int:
return self._count
class TailTask(SizeTask):
def __init__(self, sort_field, count):
def __init__(self, sort_field: str, count: int):
super().__init__("tail")
# Add a task that is descending sort with size=count
self._sort_field = sort_field
self._count = count
def __repr__(self):
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
def resolve_task(self, query_params, post_processing, query_compiler):
def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
# tail - sort desc, size n, post-process sort asc
# |-------------12345|
query_sort_field = self._sort_field
@ -123,7 +142,10 @@ class TailTask(SizeTask):
if (
query_params["query_size"] is not None
and query_params["query_sort_order"] == query_sort_order
and post_processing == ["sort_index"]
and (
len(post_processing) == 1
and isinstance(post_processing[0], SortIndexAction)
)
):
if query_size < query_params["query_size"]:
query_params["query_size"] = query_size
@ -156,12 +178,15 @@ class TailTask(SizeTask):
return query_params, post_processing
def size(self):
def size(self) -> int:
return self._count
def __repr__(self) -> str:
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
class QueryIdsTask(Task):
def __init__(self, must, ids):
def __init__(self, must: bool, ids: List[str]):
"""
Parameters
----------
@ -176,17 +201,21 @@ class QueryIdsTask(Task):
self._must = must
self._ids = ids
def resolve_task(self, query_params, post_processing, query_compiler):
def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
query_params["query"].ids(self._ids, must=self._must)
return query_params, post_processing
def __repr__(self):
def __repr__(self) -> str:
return f"('{self._task_type}': ('must': {self._must}, 'ids': {self._ids}))"
class QueryTermsTask(Task):
def __init__(self, must, field, terms):
def __init__(self, must: bool, field: str, terms: List[Any]):
"""
Parameters
----------
@ -205,20 +234,24 @@ class QueryTermsTask(Task):
self._field = field
self._terms = terms
def __repr__(self):
def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
query_params["query"].terms(self._field, self._terms, must=self._must)
return query_params, post_processing
def __repr__(self) -> str:
return (
f"('{self._task_type}': ('must': {self._must}, "
f"'field': '{self._field}', 'terms': {self._terms}))"
)
def resolve_task(self, query_params, post_processing, query_compiler):
query_params["query"].terms(self._field, self._terms, must=self._must)
return query_params, post_processing
class BooleanFilterTask(Task):
def __init__(self, boolean_filter):
def __init__(self, boolean_filter: "BooleanFilter"):
"""
Parameters
----------
@ -229,17 +262,21 @@ class BooleanFilterTask(Task):
self._boolean_filter = boolean_filter
def __repr__(self):
return f"('{self._task_type}': ('boolean_filter': {self._boolean_filter!r}))"
def resolve_task(self, query_params, post_processing, query_compiler):
def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
query_params["query"].update_boolean_filter(self._boolean_filter)
return query_params, post_processing
def __repr__(self) -> str:
return f"('{self._task_type}': ('boolean_filter': {self._boolean_filter!r}))"
class ArithmeticOpFieldsTask(Task):
def __init__(self, display_name, arithmetic_series):
def __init__(self, display_name: str, arithmetic_series: ArithmeticSeries):
super().__init__("arithmetic_op_fields")
self._display_name = display_name
@ -248,19 +285,16 @@ class ArithmeticOpFieldsTask(Task):
raise TypeError(f"Expecting ArithmeticSeries got {type(arithmetic_series)}")
self._arithmetic_series = arithmetic_series
def __repr__(self):
return (
f"('{self._task_type}': ("
f"'display_name': {self._display_name}, "
f"'arithmetic_object': {self._arithmetic_series}"
f"))"
)
def update(self, display_name, arithmetic_series):
def update(self, display_name: str, arithmetic_series: ArithmeticSeries) -> None:
self._display_name = display_name
self._arithmetic_series = arithmetic_series
def resolve_task(self, query_params, post_processing, query_compiler):
def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
# https://www.elastic.co/guide/en/elasticsearch/painless/current/painless-api-reference-shared-java-lang.html#painless-api-reference-shared-Math
"""
"script_fields": {
@ -272,7 +306,10 @@ class ArithmeticOpFieldsTask(Task):
}
"""
if query_params["query_script_fields"] is None:
query_params["query_script_fields"] = dict()
query_params["query_script_fields"] = {}
# TODO: Remove this once 'query_params' becomes a dataclass.
assert isinstance(query_params["query_script_fields"], dict)
if self._display_name in query_params["query_script_fields"]:
raise NotImplementedError(
@ -286,3 +323,11 @@ class ArithmeticOpFieldsTask(Task):
}
return query_params, post_processing
def __repr__(self) -> str:
return (
f"('{self._task_type}': ("
f"'display_name': {self._display_name}, "
f"'arithmetic_object': {self._arithmetic_series}"
f"))"
)
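
The eland/tasks.py diff above introduces module-level aliases (QUERY_PARAMS_TYPE, RESOLVED_TASK_TYPE) so that every resolve_task override shares one readable signature, and it carries a small behavioural fix: post_processing now holds action objects rather than strings, so the old comparison post_processing == ["sort_index"] becomes a length-and-isinstance check against SortIndexAction. A reduced sketch of the alias pattern with hypothetical names (Step, LimitStep):

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple

QueryParams = Dict[str, Any]
Resolved = Tuple[QueryParams, List[str]]  # post-processing reduced to strings here


class Step(ABC):
    @abstractmethod
    def resolve(self, query_params: QueryParams, post_processing: List[str]) -> Resolved:
        ...


class LimitStep(Step):
    def __init__(self, count: int):
        self._count = count

    def resolve(self, query_params: QueryParams, post_processing: List[str]) -> Resolved:
        # mypy checks this override against the abstract signature above.
        query_params["size"] = self._count
        return query_params, post_processing


params, post = LimitStep(10).resolve({}, [])
print(params)  # {'size': 10}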

View File

@ -13,6 +13,7 @@
# limitations under the License.
import csv
from typing import Union, List, Tuple, Optional, Mapping
import pandas as pd
from pandas.io.parsers import _c_parser_defaults
@ -20,10 +21,14 @@ from pandas.io.parsers import _c_parser_defaults
from eland import DataFrame
from eland.field_mappings import FieldMappings
from eland.common import ensure_es_client, DEFAULT_CHUNK_SIZE
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
def read_es(es_client, es_index_pattern):
def read_es(
es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch],
es_index_pattern: str,
) -> DataFrame:
"""
Utility method to create an eland.Dataframe from an Elasticsearch index_pattern.
(Similar to pandas.read_csv, but source data is an Elasticsearch index rather than
@ -50,16 +55,16 @@ def read_es(es_client, es_index_pattern):
def pandas_to_eland(
pd_df,
es_client,
es_dest_index,
es_if_exists="fail",
es_refresh=False,
es_dropna=False,
es_type_overrides=None,
chunksize=None,
use_pandas_index_for_es_ids=True,
):
pd_df: pd.DataFrame,
es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch],
es_dest_index: str,
es_if_exists: str = "fail",
es_refresh: bool = False,
es_dropna: bool = False,
es_type_overrides: Optional[Mapping[str, str]] = None,
chunksize: Optional[int] = None,
use_pandas_index_for_es_ids: bool = True,
) -> DataFrame:
"""
Append a pandas DataFrame to an Elasticsearch index.
Mainly used in testing.
@ -217,7 +222,7 @@ def pandas_to_eland(
return DataFrame(es_client, es_dest_index)
def eland_to_pandas(ed_df, show_progress=False):
def eland_to_pandas(ed_df: DataFrame, show_progress: bool = False) -> pd.DataFrame:
"""
Convert an eland.Dataframe to a pandas.DataFrame
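
The I/O helpers above accept the connection as Union[str, List[str], Tuple[str, ...], Elasticsearch], i.e. either host strings or an already constructed client, which ensure_es_client then normalises. A hedged sketch of what such a normaliser can look like (ensure_client is a hypothetical stand-in, not eland's implementation):

from typing import List, Tuple, Union

from elasticsearch import Elasticsearch

HostsOrClient = Union[str, List[str], Tuple[str, ...], Elasticsearch]


def ensure_client(hosts_or_client: HostsOrClient) -> Elasticsearch:
    # Pass an existing client straight through, otherwise let the
    # Elasticsearch constructor build one from the host string(s).
    if isinstance(hosts_or_client, Elasticsearch):
        return hosts_or_client
    return Elasticsearch(hosts_or_client)


client = ensure_client("http://localhost:9200")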

View File: noxfile.py

@ -1,4 +1,5 @@
import os
import subprocess
from pathlib import Path
import nox
import elasticsearch
@ -12,6 +13,19 @@ SOURCE_FILES = (
"docs/",
)
# Whenever type-hints are completed on a file it should
# be added here so that this file will continue to be checked
# by mypy. Errors from other files are ignored.
TYPED_FILES = {
"eland/actions.py",
"eland/arithmetics.py",
"eland/common.py",
"eland/filter.py",
"eland/index.py",
"eland/query.py",
"eland/tasks.py",
}
@nox.session(reuse_venv=True)
def blacken(session):
@ -22,10 +36,28 @@ def blacken(session):
@nox.session(reuse_venv=True)
def lint(session):
session.install("black", "flake8")
session.install("black", "flake8", "mypy")
session.run("black", "--check", "--target-version=py36", *SOURCE_FILES)
session.run("flake8", "--ignore=E501,W503,E402,E712", *SOURCE_FILES)
# TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/")
session.log("mypy --strict eland/")
for typed_file in TYPED_FILES:
popen = subprocess.Popen(
f"mypy --strict {typed_file}",
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
popen.wait()
errors = []
for line in popen.stdout.read().decode().split("\n"):
filepath = line.partition(":")[0]
if filepath in TYPED_FILES:
errors.append(line)
if errors:
session.error("\n" + "\n".join(sorted(set(errors))))
@nox.session(python=["3.6", "3.7", "3.8"])
def test(session):
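
The new lint session runs mypy --strict once per entry in TYPED_FILES and keeps only the error lines whose reported path is itself in TYPED_FILES, so errors triggered by importing not-yet-annotated modules do not fail the build. The same filtering can be sketched standalone with subprocess.run (Python 3.7+), which also avoids a practical wrinkle of Popen: calling wait() before draining a PIPE can block if the output exceeds the pipe buffer:

import subprocess

TYPED_FILES = {"eland/actions.py", "eland/common.py"}  # subset, for illustration

errors = []
for typed_file in sorted(TYPED_FILES):
    result = subprocess.run(
        ["mypy", "--strict", typed_file],
        capture_output=True,
        text=True,
    )
    for line in result.stdout.splitlines():
        # mypy output lines look like "path:line: error: ...";
        # keep only errors originating in files that are already fully typed.
        if line.partition(":")[0] in TYPED_FILES:
            errors.append(line)

if errors:
    raise SystemExit("\n".join(sorted(set(errors))))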