mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Add type hints to base modules
This commit is contained in:
parent
fe6589ae6a
commit
33b4976f9a
@ -13,15 +13,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
# -------------------------------------------------------------------------------------------------------------------- #
|
||||
# PostProcessingActions #
|
||||
# -------------------------------------------------------------------------------------------------------------------- #
|
||||
from typing import TYPE_CHECKING
|
||||
from eland import SortOrder
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd # type: ignore
|
||||
|
||||
|
||||
class PostProcessingAction(ABC):
|
||||
def __init__(self, action_type):
|
||||
def __init__(self, action_type: str) -> None:
|
||||
"""
|
||||
Abstract class for postprocessing actions
|
||||
|
||||
@ -33,76 +34,74 @@ class PostProcessingAction(ABC):
|
||||
self._action_type = action_type
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
def type(self) -> str:
|
||||
return self._action_type
|
||||
|
||||
@abstractmethod
|
||||
def resolve_action(self, df):
|
||||
def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class SortIndexAction(PostProcessingAction):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("sort_index")
|
||||
|
||||
def resolve_action(self, df):
|
||||
def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
return df.sort_index()
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"('{self.type}')"
|
||||
|
||||
|
||||
class HeadAction(PostProcessingAction):
|
||||
def __init__(self, count):
|
||||
def __init__(self, count: int) -> None:
|
||||
super().__init__("head")
|
||||
|
||||
self._count = count
|
||||
|
||||
def resolve_action(self, df):
|
||||
def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
return df.head(self._count)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"('{self.type}': ('count': {self._count}))"
|
||||
|
||||
|
||||
class TailAction(PostProcessingAction):
|
||||
def __init__(self, count):
|
||||
def __init__(self, count: int) -> None:
|
||||
super().__init__("tail")
|
||||
|
||||
self._count = count
|
||||
|
||||
def resolve_action(self, df):
|
||||
def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
return df.tail(self._count)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"('{self.type}': ('count': {self._count}))"
|
||||
|
||||
|
||||
class SortFieldAction(PostProcessingAction):
|
||||
def __init__(self, sort_params_string):
|
||||
def __init__(self, sort_params_string: str) -> None:
|
||||
super().__init__("sort_field")
|
||||
|
||||
if sort_params_string is None:
|
||||
raise ValueError("Expected valid string")
|
||||
|
||||
# Split string
|
||||
sort_params = sort_params_string.split(":")
|
||||
if len(sort_params) != 2:
|
||||
sort_field, _, sort_order = sort_params_string.partition(":")
|
||||
if not sort_field or sort_order not in ("asc", "desc"):
|
||||
raise ValueError(
|
||||
f"Expected ES sort params string (e.g. _doc:desc). Got '{sort_params_string}'"
|
||||
)
|
||||
|
||||
self._sort_field = sort_params[0]
|
||||
self._sort_order = SortOrder.from_string(sort_params[1])
|
||||
self._sort_field = sort_field
|
||||
self._sort_order = SortOrder.from_string(sort_order)
|
||||
|
||||
def resolve_action(self, df):
|
||||
if self._sort_order == SortOrder.ASC:
|
||||
return df.sort_values(self._sort_field, True)
|
||||
return df.sort_values(self._sort_field, False)
|
||||
def resolve_action(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
return df.sort_values(self._sort_field, self._sort_order == SortOrder.ASC)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"('{self.type}': ('sort_field': '{self._sort_field}', 'sort_order': {self._sort_order}))"
|
||||
|
@ -14,72 +14,88 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from io import StringIO
|
||||
from typing import Union, List, TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import numpy as np # type: ignore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .query_compiler import QueryCompiler
|
||||
|
||||
|
||||
class ArithmeticObject(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def value(self):
|
||||
def value(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def dtype(self):
|
||||
def dtype(self) -> np.dtype:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def resolve(self):
|
||||
def resolve(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class ArithmeticString(ArithmeticObject):
|
||||
def __init__(self, value):
|
||||
def __init__(self, value: str):
|
||||
self._value = value
|
||||
|
||||
def resolve(self):
|
||||
def resolve(self) -> str:
|
||||
return self.value
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
def dtype(self) -> np.dtype:
|
||||
return np.dtype(object)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
def value(self) -> str:
|
||||
return f"'{self._value}'"
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class ArithmeticNumber(ArithmeticObject):
|
||||
def __init__(self, value, dtype):
|
||||
def __init__(self, value: Union[int, float], dtype: np.dtype):
|
||||
self._value = value
|
||||
self._dtype = dtype
|
||||
|
||||
def resolve(self):
|
||||
def resolve(self) -> str:
|
||||
return self.value
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
def value(self) -> str:
|
||||
return f"{self._value}"
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
def dtype(self) -> np.dtype:
|
||||
return self._dtype
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class ArithmeticSeries(ArithmeticObject):
|
||||
def __init__(self, query_compiler, display_name, dtype):
|
||||
"""Represents each item in a 'Series' by using painless scripts
|
||||
to evaluate each document in an index as a part of a query.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, query_compiler: "QueryCompiler", display_name: str, dtype: np.dtype
|
||||
):
|
||||
# type defs
|
||||
self._value: str
|
||||
self._tasks: List["ArithmeticTask"]
|
||||
|
||||
task = query_compiler.get_arithmetic_op_fields()
|
||||
|
||||
if task is not None:
|
||||
assert isinstance(task._arithmetic_series, ArithmeticSeries)
|
||||
self._value = task._arithmetic_series.value
|
||||
self._tasks = task._arithmetic_series._tasks.copy()
|
||||
self._dtype = dtype
|
||||
@ -98,14 +114,14 @@ class ArithmeticSeries(ArithmeticObject):
|
||||
self._dtype = dtype
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
def value(self) -> str:
|
||||
return self._value
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
def dtype(self) -> np.dtype:
|
||||
return self._dtype
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
buf = StringIO()
|
||||
buf.write(f"Series: {self.value} ")
|
||||
buf.write("Tasks: ")
|
||||
@ -113,7 +129,7 @@ class ArithmeticSeries(ArithmeticObject):
|
||||
buf.write(f"{task!r} ")
|
||||
return buf.getvalue()
|
||||
|
||||
def resolve(self):
|
||||
def resolve(self) -> str:
|
||||
value = self._value
|
||||
|
||||
for task in self._tasks:
|
||||
@ -148,7 +164,7 @@ class ArithmeticSeries(ArithmeticObject):
|
||||
|
||||
return value
|
||||
|
||||
def arithmetic_operation(self, op_name, right):
|
||||
def arithmetic_operation(self, op_name: str, right: Any) -> "ArithmeticSeries":
|
||||
# check if operation is supported (raises on unsupported)
|
||||
self.check_is_supported(op_name, right)
|
||||
|
||||
@ -156,7 +172,7 @@ class ArithmeticSeries(ArithmeticObject):
|
||||
self._tasks.append(task)
|
||||
return self
|
||||
|
||||
def check_is_supported(self, op_name, right):
|
||||
def check_is_supported(self, op_name: str, right: Any) -> bool:
|
||||
# supported set is
|
||||
# series.number op_name number (all ops)
|
||||
# series.string op_name string (only add)
|
||||
@ -165,22 +181,20 @@ class ArithmeticSeries(ArithmeticObject):
|
||||
# series.int op_name string (none)
|
||||
# series.float op_name string (none)
|
||||
|
||||
# see end of https://pandas.pydata.org/pandas-docs/stable/getting_started/basics.html?highlight=dtype for
|
||||
# dtype heirarchy
|
||||
if np.issubdtype(self.dtype, np.number) and np.issubdtype(
|
||||
right.dtype, np.number
|
||||
):
|
||||
# see end of https://pandas.pydata.org/pandas-docs/stable/getting_started/basics.html?highlight=dtype
|
||||
# for dtype hierarchy
|
||||
right_is_integer = np.issubdtype(right.dtype, np.number)
|
||||
if np.issubdtype(self.dtype, np.number) and right_is_integer:
|
||||
# series.number op_name number (all ops)
|
||||
return True
|
||||
elif np.issubdtype(self.dtype, np.object_) and np.issubdtype(
|
||||
right.dtype, np.object_
|
||||
):
|
||||
|
||||
self_is_object = np.issubdtype(self.dtype, np.object_)
|
||||
if self_is_object and np.issubdtype(right.dtype, np.object_):
|
||||
# series.string op_name string (only add)
|
||||
if op_name == "__add__" or op_name == "__radd__":
|
||||
return True
|
||||
elif np.issubdtype(self.dtype, np.object_) and np.issubdtype(
|
||||
right.dtype, np.integer
|
||||
):
|
||||
|
||||
if self_is_object and right_is_integer:
|
||||
# series.string op_name int (only mul)
|
||||
if op_name == "__mul__":
|
||||
return True
|
||||
@ -191,22 +205,22 @@ class ArithmeticSeries(ArithmeticObject):
|
||||
|
||||
|
||||
class ArithmeticTask:
|
||||
def __init__(self, op_name, object):
|
||||
def __init__(self, op_name: str, object: ArithmeticObject):
|
||||
self._op_name = op_name
|
||||
|
||||
if not isinstance(object, ArithmeticObject):
|
||||
raise TypeError(f"Task requires ArithmeticObject not {type(object)}")
|
||||
self._object = object
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
buf = StringIO()
|
||||
buf.write(f"op_name: {self.op_name} object: {self.object!r} ")
|
||||
return buf.getvalue()
|
||||
|
||||
@property
|
||||
def op_name(self):
|
||||
def op_name(self) -> str:
|
||||
return self._op_name
|
||||
|
||||
@property
|
||||
def object(self):
|
||||
def object(self) -> ArithmeticObject:
|
||||
return self._object
|
||||
|
@ -15,10 +15,10 @@
|
||||
import re
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import Union, List, Tuple
|
||||
from typing import Union, List, Tuple, cast, Callable, Any
|
||||
|
||||
import pandas as pd
|
||||
from elasticsearch import Elasticsearch
|
||||
import pandas as pd # type: ignore
|
||||
from elasticsearch import Elasticsearch # type: ignore
|
||||
|
||||
# Default number of rows displayed (different to pandas where ALL could be displayed)
|
||||
DEFAULT_NUM_ROWS_DISPLAYED = 60
|
||||
@ -29,8 +29,8 @@ DEFAULT_PROGRESS_REPORTING_NUM_ROWS = 10000
|
||||
DEFAULT_ES_MAX_RESULT_WINDOW = 10000 # index.max_result_window
|
||||
|
||||
|
||||
def docstring_parameter(*sub):
|
||||
def dec(obj):
|
||||
def docstring_parameter(*sub: Any) -> Callable[[Any], Any]:
|
||||
def dec(obj: Any) -> Any:
|
||||
obj.__doc__ = obj.__doc__.format(*sub)
|
||||
return obj
|
||||
|
||||
@ -42,21 +42,21 @@ class SortOrder(Enum):
|
||||
DESC = 1
|
||||
|
||||
@staticmethod
|
||||
def reverse(order):
|
||||
def reverse(order: "SortOrder") -> "SortOrder":
|
||||
if order == SortOrder.ASC:
|
||||
return SortOrder.DESC
|
||||
|
||||
return SortOrder.ASC
|
||||
|
||||
@staticmethod
|
||||
def to_string(order):
|
||||
def to_string(order: "SortOrder") -> str:
|
||||
if order == SortOrder.ASC:
|
||||
return "asc"
|
||||
|
||||
return "desc"
|
||||
|
||||
@staticmethod
|
||||
def from_string(order):
|
||||
def from_string(order: str) -> "SortOrder":
|
||||
if order == "asc":
|
||||
return SortOrder.ASC
|
||||
|
||||
@ -276,11 +276,13 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
|
||||
property if one doesn't exist yet for the current Elasticsearch version.
|
||||
"""
|
||||
if not hasattr(es_client, "_eland_es_version"):
|
||||
major, minor, patch = [
|
||||
int(x)
|
||||
for x in re.match(
|
||||
r"^(\d+)\.(\d+)\.(\d+)", es_client.info()["version"]["number"]
|
||||
).groups()
|
||||
]
|
||||
version_info = es_client.info()["version"]["number"]
|
||||
match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_info)
|
||||
if match is None:
|
||||
raise ValueError(
|
||||
f"Unable to determine Elasticsearch version. "
|
||||
f"Received: {version_info}"
|
||||
)
|
||||
major, minor, patch = [int(x) for x in match.groups()]
|
||||
es_client._eland_es_version = (major, minor, patch)
|
||||
return es_client._eland_es_version
|
||||
return cast(Tuple[int, int, int], es_client._eland_es_version)
|
||||
|
@ -187,7 +187,7 @@ class DataFrame(NDFrame):
|
||||
"""
|
||||
return len(self.columns) == 0 or len(self.index) == 0
|
||||
|
||||
def head(self, n=5):
|
||||
def head(self, n: int = 5) -> "DataFrame":
|
||||
"""
|
||||
Return the first n rows.
|
||||
|
||||
@ -222,7 +222,7 @@ class DataFrame(NDFrame):
|
||||
"""
|
||||
return DataFrame(query_compiler=self._query_compiler.head(n))
|
||||
|
||||
def tail(self, n=5):
|
||||
def tail(self, n: int = 5) -> "DataFrame":
|
||||
"""
|
||||
Return the last n rows.
|
||||
|
||||
|
@ -14,12 +14,14 @@
|
||||
|
||||
# Originally based on code in MIT-licensed pandasticsearch filters
|
||||
|
||||
from typing import Dict, Any, List, Optional, Union, cast
|
||||
|
||||
|
||||
class BooleanFilter:
|
||||
def __init__(self, *args):
|
||||
self._filter = None
|
||||
def __init__(self) -> None:
|
||||
self._filter: Dict[str, Any] = {}
|
||||
|
||||
def __and__(self, x):
|
||||
def __and__(self, x: "BooleanFilter") -> "BooleanFilter":
|
||||
# Combine results
|
||||
if isinstance(self, AndFilter):
|
||||
if "must_not" in x.subtree:
|
||||
@ -37,7 +39,7 @@ class BooleanFilter:
|
||||
return x
|
||||
return AndFilter(self, x)
|
||||
|
||||
def __or__(self, x):
|
||||
def __or__(self, x: "BooleanFilter") -> "BooleanFilter":
|
||||
# Combine results
|
||||
if isinstance(self, OrFilter):
|
||||
if "must_not" in x.subtree:
|
||||
@ -53,85 +55,79 @@ class BooleanFilter:
|
||||
return x
|
||||
return OrFilter(self, x)
|
||||
|
||||
def __invert__(self):
|
||||
def __invert__(self) -> "BooleanFilter":
|
||||
return NotFilter(self)
|
||||
|
||||
def empty(self):
|
||||
if self._filter is None:
|
||||
return True
|
||||
return False
|
||||
def empty(self) -> bool:
|
||||
return not bool(self._filter)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return str(self._filter)
|
||||
|
||||
@property
|
||||
def subtree(self):
|
||||
def subtree(self) -> Dict[str, Any]:
|
||||
if "bool" in self._filter:
|
||||
return self._filter["bool"]
|
||||
return cast(Dict[str, Any], self._filter["bool"])
|
||||
else:
|
||||
return self._filter
|
||||
|
||||
def build(self):
|
||||
def build(self) -> Dict[str, Any]:
|
||||
return self._filter
|
||||
|
||||
|
||||
# Binary operator
|
||||
class AndFilter(BooleanFilter):
|
||||
def __init__(self, *args):
|
||||
[isinstance(x, BooleanFilter) for x in args]
|
||||
def __init__(self, *args: BooleanFilter) -> None:
|
||||
super().__init__()
|
||||
self._filter = {"bool": {"must": [x.build() for x in args]}}
|
||||
|
||||
|
||||
class OrFilter(BooleanFilter):
|
||||
def __init__(self, *args):
|
||||
[isinstance(x, BooleanFilter) for x in args]
|
||||
def __init__(self, *args: BooleanFilter) -> None:
|
||||
super().__init__()
|
||||
self._filter = {"bool": {"should": [x.build() for x in args]}}
|
||||
|
||||
|
||||
class NotFilter(BooleanFilter):
|
||||
def __init__(self, x):
|
||||
assert isinstance(x, BooleanFilter)
|
||||
def __init__(self, x: BooleanFilter) -> None:
|
||||
super().__init__()
|
||||
self._filter = {"bool": {"must_not": x.build()}}
|
||||
|
||||
|
||||
# LeafBooleanFilter
|
||||
class GreaterEqual(BooleanFilter):
|
||||
def __init__(self, field, value):
|
||||
def __init__(self, field: str, value: Any) -> None:
|
||||
super().__init__()
|
||||
self._filter = {"range": {field: {"gte": value}}}
|
||||
|
||||
|
||||
class Greater(BooleanFilter):
|
||||
def __init__(self, field, value):
|
||||
def __init__(self, field: str, value: Any) -> None:
|
||||
super().__init__()
|
||||
self._filter = {"range": {field: {"gt": value}}}
|
||||
|
||||
|
||||
class LessEqual(BooleanFilter):
|
||||
def __init__(self, field, value):
|
||||
def __init__(self, field: str, value: Any) -> None:
|
||||
super().__init__()
|
||||
self._filter = {"range": {field: {"lte": value}}}
|
||||
|
||||
|
||||
class Less(BooleanFilter):
|
||||
def __init__(self, field, value):
|
||||
def __init__(self, field: str, value: Any) -> None:
|
||||
super().__init__()
|
||||
self._filter = {"range": {field: {"lt": value}}}
|
||||
|
||||
|
||||
class Equal(BooleanFilter):
|
||||
def __init__(self, field, value):
|
||||
def __init__(self, field: str, value: Any) -> None:
|
||||
super().__init__()
|
||||
self._filter = {"term": {field: value}}
|
||||
|
||||
|
||||
class IsIn(BooleanFilter):
|
||||
def __init__(self, field, value):
|
||||
def __init__(self, field: str, value: List[Any]) -> None:
|
||||
super().__init__()
|
||||
assert isinstance(value, list)
|
||||
if field == "ids":
|
||||
self._filter = {"ids": {"values": value}}
|
||||
else:
|
||||
@ -139,39 +135,44 @@ class IsIn(BooleanFilter):
|
||||
|
||||
|
||||
class Like(BooleanFilter):
|
||||
def __init__(self, field, value):
|
||||
def __init__(self, field: str, value: str) -> None:
|
||||
super().__init__()
|
||||
self._filter = {"wildcard": {field: value}}
|
||||
|
||||
|
||||
class Rlike(BooleanFilter):
|
||||
def __init__(self, field, value):
|
||||
def __init__(self, field: str, value: str) -> None:
|
||||
super().__init__()
|
||||
self._filter = {"regexp": {field: value}}
|
||||
|
||||
|
||||
class Startswith(BooleanFilter):
|
||||
def __init__(self, field, value):
|
||||
def __init__(self, field: str, value: str) -> None:
|
||||
super().__init__()
|
||||
self._filter = {"prefix": {field: value}}
|
||||
|
||||
|
||||
class IsNull(BooleanFilter):
|
||||
def __init__(self, field):
|
||||
def __init__(self, field: str) -> None:
|
||||
super().__init__()
|
||||
self._filter = {"missing": {"field": field}}
|
||||
|
||||
|
||||
class NotNull(BooleanFilter):
|
||||
def __init__(self, field):
|
||||
def __init__(self, field: str) -> None:
|
||||
super().__init__()
|
||||
self._filter = {"exists": {"field": field}}
|
||||
|
||||
|
||||
class ScriptFilter(BooleanFilter):
|
||||
def __init__(self, inline, lang=None, params=None):
|
||||
def __init__(
|
||||
self,
|
||||
inline: str,
|
||||
lang: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
script = {"source": inline}
|
||||
script: Dict[str, Union[str, Dict[str, Any]]] = {"source": inline}
|
||||
if lang is not None:
|
||||
script["lang"] = lang
|
||||
if params is not None:
|
||||
@ -180,6 +181,6 @@ class ScriptFilter(BooleanFilter):
|
||||
|
||||
|
||||
class QueryFilter(BooleanFilter):
|
||||
def __init__(self, query):
|
||||
def __init__(self, query: Dict[str, Any]) -> None:
|
||||
super().__init__()
|
||||
self._filter = query
|
||||
|
@ -13,6 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, TextIO, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .query_compiler import QueryCompiler
|
||||
|
||||
|
||||
class Index:
|
||||
"""
|
||||
The index for an eland.DataFrame.
|
||||
@ -33,30 +39,36 @@ class Index:
|
||||
ID_INDEX_FIELD = "_id"
|
||||
ID_SORT_FIELD = "_doc" # if index field is _id, sort by _doc
|
||||
|
||||
def __init__(self, query_compiler, index_field=None):
|
||||
def __init__(
|
||||
self, query_compiler: "QueryCompiler", index_field: Optional[str] = None
|
||||
):
|
||||
self._query_compiler = query_compiler
|
||||
|
||||
# _is_source_field is set immediately within
|
||||
# index_field.setter
|
||||
self._is_source_field = False
|
||||
self.index_field = index_field
|
||||
|
||||
# The type:ignore is due to mypy not being smart enough
|
||||
# to recognize the property.setter has a different type
|
||||
# than the property.getter.
|
||||
self.index_field = index_field # type: ignore
|
||||
|
||||
@property
|
||||
def sort_field(self):
|
||||
def sort_field(self) -> str:
|
||||
if self._index_field == self.ID_INDEX_FIELD:
|
||||
return self.ID_SORT_FIELD
|
||||
return self._index_field
|
||||
|
||||
@property
|
||||
def is_source_field(self):
|
||||
def is_source_field(self) -> bool:
|
||||
return self._is_source_field
|
||||
|
||||
@property
|
||||
def index_field(self):
|
||||
def index_field(self) -> str:
|
||||
return self._index_field
|
||||
|
||||
@index_field.setter
|
||||
def index_field(self, index_field):
|
||||
def index_field(self, index_field: Optional[str]) -> None:
|
||||
if index_field is None or index_field == Index.ID_INDEX_FIELD:
|
||||
self._index_field = Index.ID_INDEX_FIELD
|
||||
self._is_source_field = False
|
||||
@ -64,18 +76,18 @@ class Index:
|
||||
self._index_field = index_field
|
||||
self._is_source_field = True
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self._query_compiler._index_count()
|
||||
|
||||
# Make iterable
|
||||
def __next__(self):
|
||||
def __next__(self) -> None:
|
||||
# TODO resolve this hack to make this 'iterable'
|
||||
raise StopIteration()
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> "Index":
|
||||
return self
|
||||
|
||||
def info_es(self, buf):
|
||||
def info_es(self, buf: TextIO) -> None:
|
||||
buf.write("Index:\n")
|
||||
buf.write(f" index_field: {self.index_field}\n")
|
||||
buf.write(f" is_source_field: {self.is_source_field}\n")
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
import copy
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -98,7 +99,7 @@ class Operations:
|
||||
else:
|
||||
self._arithmetic_op_fields_task.update(display_name, arithmetic_series)
|
||||
|
||||
def get_arithmetic_op_fields(self):
|
||||
def get_arithmetic_op_fields(self) -> Optional[ArithmeticOpFieldsTask]:
|
||||
# get an ArithmeticOpFieldsTask if it exists
|
||||
return self._arithmetic_op_fields_task
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Dict, List, Any
|
||||
|
||||
from eland.filter import BooleanFilter, NotNull, IsNull, IsIn
|
||||
|
||||
@ -23,7 +24,11 @@ class Query:
|
||||
Simple class to manage building Elasticsearch queries.
|
||||
"""
|
||||
|
||||
def __init__(self, query=None):
|
||||
def __init__(self, query: Optional["Query"] = None):
|
||||
# type defs
|
||||
self._query: BooleanFilter
|
||||
self._aggs: Dict[str, Any]
|
||||
|
||||
if query is None:
|
||||
self._query = BooleanFilter()
|
||||
self._aggs = {}
|
||||
@ -32,7 +37,7 @@ class Query:
|
||||
self._query = deepcopy(query._query)
|
||||
self._aggs = deepcopy(query._aggs)
|
||||
|
||||
def exists(self, field, must=True):
|
||||
def exists(self, field: str, must: bool = True) -> None:
|
||||
"""
|
||||
Add exists query
|
||||
https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-exists-query.html
|
||||
@ -48,7 +53,7 @@ class Query:
|
||||
else:
|
||||
self._query = self._query & IsNull(field)
|
||||
|
||||
def ids(self, items, must=True):
|
||||
def ids(self, items: List[Any], must: bool = True) -> None:
|
||||
"""
|
||||
Add ids query
|
||||
https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-ids-query.html
|
||||
@ -64,7 +69,7 @@ class Query:
|
||||
else:
|
||||
self._query = self._query & ~(IsIn("ids", items))
|
||||
|
||||
def terms(self, field, items, must=True):
|
||||
def terms(self, field: str, items: List[str], must: bool = True) -> None:
|
||||
"""
|
||||
Add ids query
|
||||
https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-terms-query.html
|
||||
@ -80,7 +85,7 @@ class Query:
|
||||
else:
|
||||
self._query = self._query & ~(IsIn(field, items))
|
||||
|
||||
def terms_aggs(self, name, func, field, es_size):
|
||||
def terms_aggs(self, name: str, func: str, field: str, es_size: int) -> None:
|
||||
"""
|
||||
Add terms agg e.g
|
||||
|
||||
@ -96,7 +101,7 @@ class Query:
|
||||
agg = {func: {"field": field, "size": es_size}}
|
||||
self._aggs[name] = agg
|
||||
|
||||
def metric_aggs(self, name, func, field):
|
||||
def metric_aggs(self, name: str, func: str, field: str) -> None:
|
||||
"""
|
||||
Add metric agg e.g
|
||||
|
||||
@ -111,7 +116,14 @@ class Query:
|
||||
agg = {func: {"field": field}}
|
||||
self._aggs[name] = agg
|
||||
|
||||
def hist_aggs(self, name, field, min_aggs, max_aggs, num_bins):
|
||||
def hist_aggs(
|
||||
self,
|
||||
name: str,
|
||||
field: str,
|
||||
min_aggs: Dict[str, Any],
|
||||
max_aggs: Dict[str, Any],
|
||||
num_bins: int,
|
||||
) -> None:
|
||||
"""
|
||||
Add histogram agg e.g.
|
||||
"aggs": {
|
||||
@ -127,14 +139,15 @@ class Query:
|
||||
max = max_aggs[field]
|
||||
|
||||
interval = (max - min) / num_bins
|
||||
offset = min
|
||||
|
||||
agg = {"histogram": {"field": field, "interval": interval, "offset": offset}}
|
||||
|
||||
if interval != 0:
|
||||
offset = min
|
||||
agg = {
|
||||
"histogram": {"field": field, "interval": interval, "offset": offset}
|
||||
}
|
||||
self._aggs[name] = agg
|
||||
|
||||
def to_search_body(self):
|
||||
def to_search_body(self) -> Dict[str, Any]:
|
||||
body = {}
|
||||
if self._aggs:
|
||||
body["aggs"] = self._aggs
|
||||
@ -142,19 +155,19 @@ class Query:
|
||||
body["query"] = self._query.build()
|
||||
return body
|
||||
|
||||
def to_count_body(self):
|
||||
def to_count_body(self) -> Optional[Dict[str, Any]]:
|
||||
if len(self._aggs) > 0:
|
||||
warnings.warn("Requesting count for agg query {}", self)
|
||||
warnings.warn(f"Requesting count for agg query {self}")
|
||||
if self._query.empty():
|
||||
return None
|
||||
else:
|
||||
return {"query": self._query.build()}
|
||||
|
||||
def update_boolean_filter(self, boolean_filter):
|
||||
def update_boolean_filter(self, boolean_filter: BooleanFilter) -> None:
|
||||
if self._query.empty():
|
||||
self._query = boolean_filter
|
||||
else:
|
||||
self._query = self._query & boolean_filter
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return repr(self.to_search_body())
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -28,6 +29,9 @@ from eland.common import (
|
||||
elasticsearch_date_to_pandas_date,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .tasks import ArithmeticOpFieldsTask # noqa: F401
|
||||
|
||||
|
||||
class QueryCompiler:
|
||||
"""
|
||||
@ -348,7 +352,7 @@ class QueryCompiler:
|
||||
|
||||
return out
|
||||
|
||||
def _index_count(self):
|
||||
def _index_count(self) -> int:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
@ -562,10 +566,10 @@ class QueryCompiler:
|
||||
|
||||
return result
|
||||
|
||||
def get_arithmetic_op_fields(self):
|
||||
def get_arithmetic_op_fields(self) -> Optional["ArithmeticOpFieldsTask"]:
|
||||
return self._operations.get_arithmetic_op_fields()
|
||||
|
||||
def display_name_to_aggregatable_name(self, display_name):
|
||||
def display_name_to_aggregatable_name(self, display_name: str) -> str:
|
||||
aggregatable_field_name = self._mappings.aggregatable_field_name(display_name)
|
||||
|
||||
return aggregatable_field_name
|
||||
|
141
eland/tasks.py
141
eland/tasks.py
@ -13,12 +13,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, List, Dict, Any, Tuple
|
||||
|
||||
from eland import SortOrder
|
||||
from eland.actions import HeadAction, TailAction, SortIndexAction
|
||||
from eland.arithmetics import ArithmeticSeries
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .actions import PostProcessingAction # noqa: F401
|
||||
from .filter import BooleanFilter # noqa: F401
|
||||
from .query_compiler import QueryCompiler # noqa: F401
|
||||
|
||||
QUERY_PARAMS_TYPE = Dict[str, Any]
|
||||
RESOLVED_TASK_TYPE = Tuple[QUERY_PARAMS_TYPE, List["PostProcessingAction"]]
|
||||
|
||||
|
||||
class Task(ABC):
|
||||
"""
|
||||
Abstract class for tasks
|
||||
@ -29,44 +39,51 @@ class Task(ABC):
|
||||
The task type (e.g. head, tail etc.)
|
||||
"""
|
||||
|
||||
def __init__(self, task_type):
|
||||
def __init__(self, task_type: str):
|
||||
self._task_type = task_type
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
def type(self) -> str:
|
||||
return self._task_type
|
||||
|
||||
@abstractmethod
|
||||
def resolve_task(self, query_params, post_processing, query_compiler):
|
||||
def resolve_task(
|
||||
self,
|
||||
query_params: QUERY_PARAMS_TYPE,
|
||||
post_processing: List["PostProcessingAction"],
|
||||
query_compiler: "QueryCompiler",
|
||||
) -> RESOLVED_TASK_TYPE:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class SizeTask(Task):
|
||||
def __init__(self, task_type):
|
||||
super().__init__(task_type)
|
||||
|
||||
@abstractmethod
|
||||
def size(self):
|
||||
def size(self) -> int:
|
||||
# must override
|
||||
pass
|
||||
|
||||
|
||||
class HeadTask(SizeTask):
|
||||
def __init__(self, sort_field, count):
|
||||
def __init__(self, sort_field: str, count: int):
|
||||
super().__init__("head")
|
||||
|
||||
# Add a task that is an ascending sort with size=count
|
||||
self._sort_field = sort_field
|
||||
self._count = count
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
|
||||
|
||||
def resolve_task(self, query_params, post_processing, query_compiler):
|
||||
def resolve_task(
|
||||
self,
|
||||
query_params: QUERY_PARAMS_TYPE,
|
||||
post_processing: List["PostProcessingAction"],
|
||||
query_compiler: "QueryCompiler",
|
||||
) -> RESOLVED_TASK_TYPE:
|
||||
# head - sort asc, size n
|
||||
# |12345-------------|
|
||||
query_sort_field = self._sort_field
|
||||
@ -97,22 +114,24 @@ class HeadTask(SizeTask):
|
||||
|
||||
return query_params, post_processing
|
||||
|
||||
def size(self):
|
||||
def size(self) -> int:
|
||||
return self._count
|
||||
|
||||
|
||||
class TailTask(SizeTask):
|
||||
def __init__(self, sort_field, count):
|
||||
def __init__(self, sort_field: str, count: int):
|
||||
super().__init__("tail")
|
||||
|
||||
# Add a task that is descending sort with size=count
|
||||
self._sort_field = sort_field
|
||||
self._count = count
|
||||
|
||||
def __repr__(self):
|
||||
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
|
||||
|
||||
def resolve_task(self, query_params, post_processing, query_compiler):
|
||||
def resolve_task(
|
||||
self,
|
||||
query_params: QUERY_PARAMS_TYPE,
|
||||
post_processing: List["PostProcessingAction"],
|
||||
query_compiler: "QueryCompiler",
|
||||
) -> RESOLVED_TASK_TYPE:
|
||||
# tail - sort desc, size n, post-process sort asc
|
||||
# |-------------12345|
|
||||
query_sort_field = self._sort_field
|
||||
@ -123,7 +142,10 @@ class TailTask(SizeTask):
|
||||
if (
|
||||
query_params["query_size"] is not None
|
||||
and query_params["query_sort_order"] == query_sort_order
|
||||
and post_processing == ["sort_index"]
|
||||
and (
|
||||
len(post_processing) == 1
|
||||
and isinstance(post_processing[0], SortIndexAction)
|
||||
)
|
||||
):
|
||||
if query_size < query_params["query_size"]:
|
||||
query_params["query_size"] = query_size
|
||||
@ -156,12 +178,15 @@ class TailTask(SizeTask):
|
||||
|
||||
return query_params, post_processing
|
||||
|
||||
def size(self):
|
||||
def size(self) -> int:
|
||||
return self._count
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"
|
||||
|
||||
|
||||
class QueryIdsTask(Task):
|
||||
def __init__(self, must, ids):
|
||||
def __init__(self, must: bool, ids: List[str]):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@ -176,17 +201,21 @@ class QueryIdsTask(Task):
|
||||
self._must = must
|
||||
self._ids = ids
|
||||
|
||||
def resolve_task(self, query_params, post_processing, query_compiler):
|
||||
def resolve_task(
|
||||
self,
|
||||
query_params: QUERY_PARAMS_TYPE,
|
||||
post_processing: List["PostProcessingAction"],
|
||||
query_compiler: "QueryCompiler",
|
||||
) -> RESOLVED_TASK_TYPE:
|
||||
query_params["query"].ids(self._ids, must=self._must)
|
||||
|
||||
return query_params, post_processing
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"('{self._task_type}': ('must': {self._must}, 'ids': {self._ids}))"
|
||||
|
||||
|
||||
class QueryTermsTask(Task):
|
||||
def __init__(self, must, field, terms):
|
||||
def __init__(self, must: bool, field: str, terms: List[Any]):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@ -205,20 +234,24 @@ class QueryTermsTask(Task):
|
||||
self._field = field
|
||||
self._terms = terms
|
||||
|
||||
def __repr__(self):
|
||||
def resolve_task(
|
||||
self,
|
||||
query_params: QUERY_PARAMS_TYPE,
|
||||
post_processing: List["PostProcessingAction"],
|
||||
query_compiler: "QueryCompiler",
|
||||
) -> RESOLVED_TASK_TYPE:
|
||||
query_params["query"].terms(self._field, self._terms, must=self._must)
|
||||
return query_params, post_processing
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"('{self._task_type}': ('must': {self._must}, "
|
||||
f"'field': '{self._field}', 'terms': {self._terms}))"
|
||||
)
|
||||
|
||||
def resolve_task(self, query_params, post_processing, query_compiler):
|
||||
query_params["query"].terms(self._field, self._terms, must=self._must)
|
||||
|
||||
return query_params, post_processing
|
||||
|
||||
|
||||
class BooleanFilterTask(Task):
|
||||
def __init__(self, boolean_filter):
|
||||
def __init__(self, boolean_filter: "BooleanFilter"):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@ -229,17 +262,21 @@ class BooleanFilterTask(Task):
|
||||
|
||||
self._boolean_filter = boolean_filter
|
||||
|
||||
def __repr__(self):
|
||||
return f"('{self._task_type}': ('boolean_filter': {self._boolean_filter!r}))"
|
||||
|
||||
def resolve_task(self, query_params, post_processing, query_compiler):
|
||||
def resolve_task(
|
||||
self,
|
||||
query_params: QUERY_PARAMS_TYPE,
|
||||
post_processing: List["PostProcessingAction"],
|
||||
query_compiler: "QueryCompiler",
|
||||
) -> RESOLVED_TASK_TYPE:
|
||||
query_params["query"].update_boolean_filter(self._boolean_filter)
|
||||
|
||||
return query_params, post_processing
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"('{self._task_type}': ('boolean_filter': {self._boolean_filter!r}))"
|
||||
|
||||
|
||||
class ArithmeticOpFieldsTask(Task):
|
||||
def __init__(self, display_name, arithmetic_series):
|
||||
def __init__(self, display_name: str, arithmetic_series: ArithmeticSeries):
|
||||
super().__init__("arithmetic_op_fields")
|
||||
|
||||
self._display_name = display_name
|
||||
@ -248,19 +285,16 @@ class ArithmeticOpFieldsTask(Task):
|
||||
raise TypeError(f"Expecting ArithmeticSeries got {type(arithmetic_series)}")
|
||||
self._arithmetic_series = arithmetic_series
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"('{self._task_type}': ("
|
||||
f"'display_name': {self._display_name}, "
|
||||
f"'arithmetic_object': {self._arithmetic_series}"
|
||||
f"))"
|
||||
)
|
||||
|
||||
def update(self, display_name, arithmetic_series):
|
||||
def update(self, display_name: str, arithmetic_series: ArithmeticSeries) -> None:
|
||||
self._display_name = display_name
|
||||
self._arithmetic_series = arithmetic_series
|
||||
|
||||
def resolve_task(self, query_params, post_processing, query_compiler):
|
||||
def resolve_task(
|
||||
self,
|
||||
query_params: QUERY_PARAMS_TYPE,
|
||||
post_processing: List["PostProcessingAction"],
|
||||
query_compiler: "QueryCompiler",
|
||||
) -> RESOLVED_TASK_TYPE:
|
||||
# https://www.elastic.co/guide/en/elasticsearch/painless/current/painless-api-reference-shared-java-lang.html#painless-api-reference-shared-Math
|
||||
"""
|
||||
"script_fields": {
|
||||
@ -272,7 +306,10 @@ class ArithmeticOpFieldsTask(Task):
|
||||
}
|
||||
"""
|
||||
if query_params["query_script_fields"] is None:
|
||||
query_params["query_script_fields"] = dict()
|
||||
query_params["query_script_fields"] = {}
|
||||
|
||||
# TODO: Remove this once 'query_params' becomes a dataclass.
|
||||
assert isinstance(query_params["query_script_fields"], dict)
|
||||
|
||||
if self._display_name in query_params["query_script_fields"]:
|
||||
raise NotImplementedError(
|
||||
@ -286,3 +323,11 @@ class ArithmeticOpFieldsTask(Task):
|
||||
}
|
||||
|
||||
return query_params, post_processing
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"('{self._task_type}': ("
|
||||
f"'display_name': {self._display_name}, "
|
||||
f"'arithmetic_object': {self._arithmetic_series}"
|
||||
f"))"
|
||||
)
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import csv
|
||||
from typing import Union, List, Tuple, Optional, Mapping
|
||||
|
||||
import pandas as pd
|
||||
from pandas.io.parsers import _c_parser_defaults
|
||||
@ -20,10 +21,14 @@ from pandas.io.parsers import _c_parser_defaults
|
||||
from eland import DataFrame
|
||||
from eland.field_mappings import FieldMappings
|
||||
from eland.common import ensure_es_client, DEFAULT_CHUNK_SIZE
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch.helpers import bulk
|
||||
|
||||
|
||||
def read_es(es_client, es_index_pattern):
|
||||
def read_es(
|
||||
es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch],
|
||||
es_index_pattern: str,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Utility method to create an eland.Dataframe from an Elasticsearch index_pattern.
|
||||
(Similar to pandas.read_csv, but source data is an Elasticsearch index rather than
|
||||
@ -50,16 +55,16 @@ def read_es(es_client, es_index_pattern):
|
||||
|
||||
|
||||
def pandas_to_eland(
|
||||
pd_df,
|
||||
es_client,
|
||||
es_dest_index,
|
||||
es_if_exists="fail",
|
||||
es_refresh=False,
|
||||
es_dropna=False,
|
||||
es_type_overrides=None,
|
||||
chunksize=None,
|
||||
use_pandas_index_for_es_ids=True,
|
||||
):
|
||||
pd_df: pd.DataFrame,
|
||||
es_client: Union[str, List[str], Tuple[str, ...], Elasticsearch],
|
||||
es_dest_index: str,
|
||||
es_if_exists: str = "fail",
|
||||
es_refresh: bool = False,
|
||||
es_dropna: bool = False,
|
||||
es_type_overrides: Optional[Mapping[str, str]] = None,
|
||||
chunksize: Optional[int] = None,
|
||||
use_pandas_index_for_es_ids: bool = True,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Append a pandas DataFrame to an Elasticsearch index.
|
||||
Mainly used in testing.
|
||||
@ -217,7 +222,7 @@ def pandas_to_eland(
|
||||
return DataFrame(es_client, es_dest_index)
|
||||
|
||||
|
||||
def eland_to_pandas(ed_df, show_progress=False):
|
||||
def eland_to_pandas(ed_df: DataFrame, show_progress: bool = False) -> pd.DataFrame:
|
||||
"""
|
||||
Convert an eland.Dataframe to a pandas.DataFrame
|
||||
|
||||
|
34
noxfile.py
34
noxfile.py
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import nox
|
||||
import elasticsearch
|
||||
@ -12,6 +13,19 @@ SOURCE_FILES = (
|
||||
"docs/",
|
||||
)
|
||||
|
||||
# Whenever type-hints are completed on a file it should
|
||||
# be added here so that this file will continue to be checked
|
||||
# by mypy. Errors from other files are ignored.
|
||||
TYPED_FILES = {
|
||||
"eland/actions.py",
|
||||
"eland/arithmetics.py",
|
||||
"eland/common.py",
|
||||
"eland/filter.py",
|
||||
"eland/index.py",
|
||||
"eland/query.py",
|
||||
"eland/tasks.py",
|
||||
}
|
||||
|
||||
|
||||
@nox.session(reuse_venv=True)
|
||||
def blacken(session):
|
||||
@ -22,10 +36,28 @@ def blacken(session):
|
||||
|
||||
@nox.session(reuse_venv=True)
|
||||
def lint(session):
|
||||
session.install("black", "flake8")
|
||||
session.install("black", "flake8", "mypy")
|
||||
session.run("black", "--check", "--target-version=py36", *SOURCE_FILES)
|
||||
session.run("flake8", "--ignore=E501,W503,E402,E712", *SOURCE_FILES)
|
||||
|
||||
# TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/")
|
||||
session.log("mypy --strict eland/")
|
||||
for typed_file in TYPED_FILES:
|
||||
popen = subprocess.Popen(
|
||||
f"mypy --strict {typed_file}",
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
popen.wait()
|
||||
errors = []
|
||||
for line in popen.stdout.read().decode().split("\n"):
|
||||
filepath = line.partition(":")[0]
|
||||
if filepath in TYPED_FILES:
|
||||
errors.append(line)
|
||||
if errors:
|
||||
session.error("\n" + "\n".join(sorted(set(errors))))
|
||||
|
||||
|
||||
@nox.session(python=["3.6", "3.7", "3.8"])
|
||||
def test(session):
|
||||
|
Loading…
x
Reference in New Issue
Block a user