Make QueryParams a dataclass

This commit is contained in:
Seth Michael Larson 2020-04-27 16:21:26 -05:00 committed by GitHub
parent 15a1977dcf
commit df2a21ffd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 75 additions and 82 deletions

View File

@ -5,7 +5,7 @@
import copy
import typing
import warnings
from typing import Optional
from typing import Optional, Tuple, List, Dict, Any
import numpy as np
import pandas as pd
@ -35,6 +35,16 @@ if typing.TYPE_CHECKING:
from eland.query_compiler import QueryCompiler
class QueryParams:
def __init__(self):
self.query = Query()
self.sort_field: Optional[str] = None
self.sort_order: Optional[SortOrder] = None
self.size: Optional[int] = None
self.fields: Optional[List[str]] = None
self.script_fields: Optional[Dict[str, Dict[str, Any]]] = None
class Operations:
"""
A collector of the queries and selectors we apply to queries to return the appropriate results.
@ -107,7 +117,7 @@ class Operations:
counts = {}
for field in fields:
body = Query(query_params["query"])
body = Query(query_params.query)
body.exists(field, must=True)
field_exists_count = query_compiler._client.count(
@ -175,7 +185,7 @@ class Operations:
if numeric_only:
fields = [field for field in fields if (field.is_numeric or field.is_bool)]
body = Query(query_params["query"])
body = Query(query_params.query)
# Convert pandas aggs to ES equivalent
es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs)
@ -301,7 +311,7 @@ class Operations:
# Get just aggregatable field_names
aggregatable_field_names = query_compiler._mappings.aggregatable_field_names()
body = Query(query_params["query"])
body = Query(query_params.query)
for field in aggregatable_field_names.keys():
body.terms_aggs(field, func, field, es_size=es_size)
@ -338,7 +348,7 @@ class Operations:
numeric_source_fields = query_compiler._mappings.numeric_source_fields()
body = Query(query_params["query"])
body = Query(query_params.query)
results = self._metric_aggs(query_compiler, ["min", "max"], numeric_only=True)
min_aggs = {}
@ -348,7 +358,7 @@ class Operations:
max_aggs[field] = max_agg
for field in numeric_source_fields:
body.hist_aggs(field, field, min_aggs, max_aggs, num_bins)
body.hist_aggs(field, field, min_aggs[field], max_aggs[field], num_bins)
response = query_compiler._client.search(
index=query_compiler._index_pattern, size=0, body=body.to_search_body()
@ -525,7 +535,7 @@ class Operations:
# for each field we compute:
# count, mean, std, min, 25%, 50%, 75%, max
body = Query(query_params["query"])
body = Query(query_params.query)
for field in numeric_source_fields:
body.metric_aggs("extended_stats_" + field, "extended_stats", field)
@ -639,8 +649,8 @@ class Operations:
size, sort_params = Operations._query_params_to_size_and_sort(query_params)
script_fields = query_params["query_script_fields"]
query = Query(query_params["query"])
script_fields = query_params.script_fields
query = Query(query_params.query)
body = query.to_search_body()
if script_fields is not None:
@ -722,7 +732,7 @@ class Operations:
# TODO - this is not necessarily valid as the field may not exist in ALL these docs
return size
body = Query(query_params["query"])
body = Query(query_params.query)
body.exists(field, must=True)
return query_compiler._client.count(
@ -751,7 +761,7 @@ class Operations:
query_compiler, items
)
body = Query(query_params["query"])
body = Query(query_params.query)
if field == Index.ID_INDEX_FIELD:
body.ids(items, must=True)
@ -780,17 +790,16 @@ class Operations:
self._tasks.append(task)
@staticmethod
def _query_params_to_size_and_sort(query_params):
def _query_params_to_size_and_sort(
query_params: QueryParams,
) -> Tuple[Optional[int], Optional[str]]:
sort_params = None
if query_params["query_sort_field"] and query_params["query_sort_order"]:
if query_params.sort_field and query_params.sort_order:
sort_params = (
query_params["query_sort_field"]
+ ":"
+ SortOrder.to_string(query_params["query_sort_order"])
f"{query_params.sort_field}:"
f"{SortOrder.to_string(query_params.sort_order)}"
)
size = query_params["query_size"]
size = query_params.size
return size, sort_params
@staticmethod
@ -800,7 +809,6 @@ class Operations:
if isinstance(action, SizeTask):
if size is None or action.size() < size:
size = action.size()
return size
@staticmethod
@ -815,15 +823,7 @@ class Operations:
# Some operations can be simply combined into a single query
# other operations require pre-queries and then combinations
# other operations require in-core post-processing of results
query_params = {
"query_sort_field": None,
"query_sort_order": None,
"query_size": None,
"query_fields": None,
"query_script_fields": None,
"query": Query(),
}
query_params = QueryParams()
post_processing = []
for task in self._tasks:
@ -843,7 +843,7 @@ class Operations:
def _size(self, query_params, post_processing):
# Shrink wrap code around checking if size parameter is set
size = query_params["query_size"] # can be None
size = query_params.size
pp_size = self._count_post_processing(post_processing)
if pp_size is not None:
@ -863,8 +863,8 @@ class Operations:
size, sort_params = Operations._query_params_to_size_and_sort(query_params)
_source = query_compiler._mappings.get_field_names()
script_fields = query_params["query_script_fields"]
query = Query(query_params["query"])
script_fields = query_params.script_fields
query = Query(query_params.query)
body = query.to_search_body()
if script_fields is not None:
body["script_fields"] = script_fields

View File

@ -107,12 +107,7 @@ class Query:
self._aggs[name] = agg
def hist_aggs(
self,
name: str,
field: str,
min_aggs: Dict[str, Any],
max_aggs: Dict[str, Any],
num_bins: int,
self, name: str, field: str, min_value: Any, max_value: Any, num_bins: int,
) -> None:
"""
Add histogram agg e.g.
@ -120,20 +115,18 @@ class Query:
"name": {
"histogram": {
"field": "AvgTicketPrice"
"interval": (max_aggs[field] - min_aggs[field])/bins
"interval": (max_value - min_value)/bins
"offset": min_value
}
}
}
"""
min = min_aggs[field]
max = max_aggs[field]
interval = (max - min) / num_bins
interval = (max_value - min_value) / num_bins
if interval != 0:
offset = min
agg = {
"histogram": {"field": field, "interval": interval, "offset": offset}
"histogram": {"field": field, "interval": interval, "offset": min_value}
}
self._aggs[name] = agg

View File

@ -3,7 +3,7 @@
# See the LICENSE file in the project root for more information
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Dict, Any, Tuple
from typing import TYPE_CHECKING, List, Any, Tuple
from eland import SortOrder
from eland.actions import HeadAction, TailAction, SortIndexAction
@ -14,9 +14,9 @@ if TYPE_CHECKING:
from .actions import PostProcessingAction # noqa: F401
from .filter import BooleanFilter # noqa: F401
from .query_compiler import QueryCompiler # noqa: F401
from .operations import QueryParams # noqa: F401
QUERY_PARAMS_TYPE = Dict[str, Any]
RESOLVED_TASK_TYPE = Tuple[QUERY_PARAMS_TYPE, List["PostProcessingAction"]]
RESOLVED_TASK_TYPE = Tuple["QueryParams", List["PostProcessingAction"]]
class Task(ABC):
@ -39,7 +39,7 @@ class Task(ABC):
@abstractmethod
def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
query_params: "QueryParams",
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
@ -70,7 +70,7 @@ class HeadTask(SizeTask):
def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
query_params: "QueryParams",
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
@ -87,20 +87,20 @@ class HeadTask(SizeTask):
post_processing.append(HeadAction(self._count))
return query_params, post_processing
if query_params["query_sort_field"] is None:
query_params["query_sort_field"] = query_sort_field
if query_params.sort_field is None:
query_params.sort_field = query_sort_field
# if it is already sorted we use existing field
if query_params["query_sort_order"] is None:
query_params["query_sort_order"] = query_sort_order
if query_params.sort_order is None:
query_params.sort_order = query_sort_order
# if it is already sorted we get head of existing order
if query_params["query_size"] is None:
query_params["query_size"] = query_size
if query_params.size is None:
query_params.size = query_size
else:
# truncate if head is smaller
if query_size < query_params["query_size"]:
query_params["query_size"] = query_size
if query_size < query_params.size:
query_params.size = query_size
return query_params, post_processing
@ -118,7 +118,7 @@ class TailTask(SizeTask):
def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
query_params: "QueryParams",
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
@ -130,15 +130,15 @@ class TailTask(SizeTask):
# If this is a tail of a tail adjust settings and return
if (
query_params["query_size"] is not None
and query_params["query_sort_order"] == query_sort_order
query_params.size is not None
and query_params.sort_order == query_sort_order
and (
len(post_processing) == 1
and isinstance(post_processing[0], SortIndexAction)
)
):
if query_size < query_params["query_size"]:
query_params["query_size"] = query_size
if query_size < query_params.size:
query_params.size = query_size
return query_params, post_processing
# If we are already postprocessing the query results, just get 'tail' of these
@ -151,18 +151,18 @@ class TailTask(SizeTask):
# If results are already constrained, just get 'tail' of these
# (note, currently we just append another tail, we don't optimise by
# overwriting previous tail)
if query_params["query_size"] is not None:
if query_params.size is not None:
post_processing.append(TailAction(self._count))
return query_params, post_processing
else:
query_params["query_size"] = query_size
if query_params["query_sort_field"] is None:
query_params["query_sort_field"] = query_sort_field
if query_params["query_sort_order"] is None:
query_params["query_sort_order"] = query_sort_order
query_params.size = query_size
if query_params.sort_field is None:
query_params.sort_field = query_sort_field
if query_params.sort_order is None:
query_params.sort_order = query_sort_order
else:
# reverse sort order
query_params["query_sort_order"] = SortOrder.reverse(query_sort_order)
query_params.sort_order = SortOrder.reverse(query_sort_order)
post_processing.append(SortIndexAction())
@ -193,11 +193,11 @@ class QueryIdsTask(Task):
def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
query_params: "QueryParams",
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
query_params["query"].ids(self._ids, must=self._must)
query_params.query.ids(self._ids, must=self._must)
return query_params, post_processing
def __repr__(self) -> str:
@ -226,11 +226,11 @@ class QueryTermsTask(Task):
def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
query_params: "QueryParams",
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
query_params["query"].terms(self._field, self._terms, must=self._must)
query_params.query.terms(self._field, self._terms, must=self._must)
return query_params, post_processing
def __repr__(self) -> str:
@ -254,11 +254,11 @@ class BooleanFilterTask(Task):
def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
query_params: "QueryParams",
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
query_params["query"].update_boolean_filter(self._boolean_filter)
query_params.query.update_boolean_filter(self._boolean_filter)
return query_params, post_processing
def __repr__(self) -> str:
@ -281,7 +281,7 @@ class ArithmeticOpFieldsTask(Task):
def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
query_params: "QueryParams",
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
@ -295,20 +295,20 @@ class ArithmeticOpFieldsTask(Task):
}
}
"""
if query_params["query_script_fields"] is None:
query_params["query_script_fields"] = {}
if query_params.script_fields is None:
query_params.script_fields = {}
# TODO: Remove this once 'query_params' becomes a dataclass.
assert isinstance(query_params["query_script_fields"], dict)
assert isinstance(query_params.script_fields, dict)
if self._display_name in query_params["query_script_fields"]:
if self._display_name in query_params.script_fields:
raise NotImplementedError(
f"TODO code path - combine multiple ops "
f"'{self}'\n{query_params['query_script_fields']}\n"
f"'{self}'\n{query_params.script_fields}\n"
f"{self._display_name}\n{self._arithmetic_series.resolve()}"
)
query_params["query_script_fields"][self._display_name] = {
query_params.script_fields[self._display_name] = {
"script": {"source": self._arithmetic_series.resolve()}
}