Make QueryParams a dataclass

This commit is contained in:
Seth Michael Larson 2020-04-27 16:21:26 -05:00 committed by GitHub
parent 15a1977dcf
commit df2a21ffd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 75 additions and 82 deletions

View File

@ -5,7 +5,7 @@
import copy import copy
import typing import typing
import warnings import warnings
from typing import Optional from typing import Optional, Tuple, List, Dict, Any
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -35,6 +35,16 @@ if typing.TYPE_CHECKING:
from eland.query_compiler import QueryCompiler from eland.query_compiler import QueryCompiler
class QueryParams:
    """Mutable container for the search settings built up while resolving tasks.

    A new instance holds an empty ``Query`` and leaves every other setting
    as ``None`` (meaning "not set yet"). Tasks then read and assign the
    attributes directly while they are resolved.
    """

    def __init__(self):
        # Query object that filter/ids/terms tasks add their clauses to.
        self.query = Query()
        # Sort field and direction; a "<field>:<order>" sort string is only
        # produced when both of them are set.
        self.sort_field: Optional[str] = None
        self.sort_order: Optional[SortOrder] = None
        # Requested number of results; None until a head/tail task sets it.
        self.size: Optional[int] = None
        # NOTE(review): not read by any of the visible code paths — confirm
        # whether this attribute is still needed.
        self.fields: Optional[List[str]] = None
        # Script field definitions keyed by display name; copied into the
        # search body under "script_fields" when not None.
        self.script_fields: Optional[Dict[str, Dict[str, Any]]] = None

    def __repr__(self) -> str:
        """Return a readable dump of every setting, for logs and debugging."""
        return (
            f"{type(self).__name__}("
            f"query={self.query!r}, "
            f"sort_field={self.sort_field!r}, "
            f"sort_order={self.sort_order!r}, "
            f"size={self.size!r}, "
            f"fields={self.fields!r}, "
            f"script_fields={self.script_fields!r})"
        )
class Operations: class Operations:
""" """
A collector of the queries and selectors we apply to queries to return the appropriate results. A collector of the queries and selectors we apply to queries to return the appropriate results.
@ -107,7 +117,7 @@ class Operations:
counts = {} counts = {}
for field in fields: for field in fields:
body = Query(query_params["query"]) body = Query(query_params.query)
body.exists(field, must=True) body.exists(field, must=True)
field_exists_count = query_compiler._client.count( field_exists_count = query_compiler._client.count(
@ -175,7 +185,7 @@ class Operations:
if numeric_only: if numeric_only:
fields = [field for field in fields if (field.is_numeric or field.is_bool)] fields = [field for field in fields if (field.is_numeric or field.is_bool)]
body = Query(query_params["query"]) body = Query(query_params.query)
# Convert pandas aggs to ES equivalent # Convert pandas aggs to ES equivalent
es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs) es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs)
@ -301,7 +311,7 @@ class Operations:
# Get just aggregatable field_names # Get just aggregatable field_names
aggregatable_field_names = query_compiler._mappings.aggregatable_field_names() aggregatable_field_names = query_compiler._mappings.aggregatable_field_names()
body = Query(query_params["query"]) body = Query(query_params.query)
for field in aggregatable_field_names.keys(): for field in aggregatable_field_names.keys():
body.terms_aggs(field, func, field, es_size=es_size) body.terms_aggs(field, func, field, es_size=es_size)
@ -338,7 +348,7 @@ class Operations:
numeric_source_fields = query_compiler._mappings.numeric_source_fields() numeric_source_fields = query_compiler._mappings.numeric_source_fields()
body = Query(query_params["query"]) body = Query(query_params.query)
results = self._metric_aggs(query_compiler, ["min", "max"], numeric_only=True) results = self._metric_aggs(query_compiler, ["min", "max"], numeric_only=True)
min_aggs = {} min_aggs = {}
@ -348,7 +358,7 @@ class Operations:
max_aggs[field] = max_agg max_aggs[field] = max_agg
for field in numeric_source_fields: for field in numeric_source_fields:
body.hist_aggs(field, field, min_aggs, max_aggs, num_bins) body.hist_aggs(field, field, min_aggs[field], max_aggs[field], num_bins)
response = query_compiler._client.search( response = query_compiler._client.search(
index=query_compiler._index_pattern, size=0, body=body.to_search_body() index=query_compiler._index_pattern, size=0, body=body.to_search_body()
@ -525,7 +535,7 @@ class Operations:
# for each field we compute: # for each field we compute:
# count, mean, std, min, 25%, 50%, 75%, max # count, mean, std, min, 25%, 50%, 75%, max
body = Query(query_params["query"]) body = Query(query_params.query)
for field in numeric_source_fields: for field in numeric_source_fields:
body.metric_aggs("extended_stats_" + field, "extended_stats", field) body.metric_aggs("extended_stats_" + field, "extended_stats", field)
@ -639,8 +649,8 @@ class Operations:
size, sort_params = Operations._query_params_to_size_and_sort(query_params) size, sort_params = Operations._query_params_to_size_and_sort(query_params)
script_fields = query_params["query_script_fields"] script_fields = query_params.script_fields
query = Query(query_params["query"]) query = Query(query_params.query)
body = query.to_search_body() body = query.to_search_body()
if script_fields is not None: if script_fields is not None:
@ -722,7 +732,7 @@ class Operations:
# TODO - this is not necessarily valid as the field may not exist in ALL these docs # TODO - this is not necessarily valid as the field may not exist in ALL these docs
return size return size
body = Query(query_params["query"]) body = Query(query_params.query)
body.exists(field, must=True) body.exists(field, must=True)
return query_compiler._client.count( return query_compiler._client.count(
@ -751,7 +761,7 @@ class Operations:
query_compiler, items query_compiler, items
) )
body = Query(query_params["query"]) body = Query(query_params.query)
if field == Index.ID_INDEX_FIELD: if field == Index.ID_INDEX_FIELD:
body.ids(items, must=True) body.ids(items, must=True)
@ -780,17 +790,16 @@ class Operations:
self._tasks.append(task) self._tasks.append(task)
@staticmethod @staticmethod
def _query_params_to_size_and_sort(query_params): def _query_params_to_size_and_sort(
query_params: QueryParams,
) -> Tuple[Optional[int], Optional[str]]:
sort_params = None sort_params = None
if query_params["query_sort_field"] and query_params["query_sort_order"]: if query_params.sort_field and query_params.sort_order:
sort_params = ( sort_params = (
query_params["query_sort_field"] f"{query_params.sort_field}:"
+ ":" f"{SortOrder.to_string(query_params.sort_order)}"
+ SortOrder.to_string(query_params["query_sort_order"])
) )
size = query_params.size
size = query_params["query_size"]
return size, sort_params return size, sort_params
@staticmethod @staticmethod
@ -800,7 +809,6 @@ class Operations:
if isinstance(action, SizeTask): if isinstance(action, SizeTask):
if size is None or action.size() < size: if size is None or action.size() < size:
size = action.size() size = action.size()
return size return size
@staticmethod @staticmethod
@ -815,15 +823,7 @@ class Operations:
# Some operations can be simply combined into a single query # Some operations can be simply combined into a single query
# other operations require pre-queries and then combinations # other operations require pre-queries and then combinations
# other operations require in-core post-processing of results # other operations require in-core post-processing of results
query_params = { query_params = QueryParams()
"query_sort_field": None,
"query_sort_order": None,
"query_size": None,
"query_fields": None,
"query_script_fields": None,
"query": Query(),
}
post_processing = [] post_processing = []
for task in self._tasks: for task in self._tasks:
@ -843,7 +843,7 @@ class Operations:
def _size(self, query_params, post_processing): def _size(self, query_params, post_processing):
# Shrink wrap code around checking if size parameter is set # Shrink wrap code around checking if size parameter is set
size = query_params["query_size"] # can be None size = query_params.size
pp_size = self._count_post_processing(post_processing) pp_size = self._count_post_processing(post_processing)
if pp_size is not None: if pp_size is not None:
@ -863,8 +863,8 @@ class Operations:
size, sort_params = Operations._query_params_to_size_and_sort(query_params) size, sort_params = Operations._query_params_to_size_and_sort(query_params)
_source = query_compiler._mappings.get_field_names() _source = query_compiler._mappings.get_field_names()
script_fields = query_params["query_script_fields"] script_fields = query_params.script_fields
query = Query(query_params["query"]) query = Query(query_params.query)
body = query.to_search_body() body = query.to_search_body()
if script_fields is not None: if script_fields is not None:
body["script_fields"] = script_fields body["script_fields"] = script_fields

View File

@ -107,12 +107,7 @@ class Query:
self._aggs[name] = agg self._aggs[name] = agg
def hist_aggs( def hist_aggs(
self, self, name: str, field: str, min_value: Any, max_value: Any, num_bins: int,
name: str,
field: str,
min_aggs: Dict[str, Any],
max_aggs: Dict[str, Any],
num_bins: int,
) -> None: ) -> None:
""" """
Add histogram agg e.g. Add histogram agg e.g.
@ -120,20 +115,18 @@ class Query:
"name": { "name": {
"histogram": { "histogram": {
"field": "AvgTicketPrice" "field": "AvgTicketPrice"
"interval": (max_aggs[field] - min_aggs[field])/bins "interval": (max_value - min_value)/bins
"offset": min_value
} }
} }
} }
""" """
min = min_aggs[field]
max = max_aggs[field]
interval = (max - min) / num_bins interval = (max_value - min_value) / num_bins
if interval != 0: if interval != 0:
offset = min
agg = { agg = {
"histogram": {"field": field, "interval": interval, "offset": offset} "histogram": {"field": field, "interval": interval, "offset": min_value}
} }
self._aggs[name] = agg self._aggs[name] = agg

View File

@ -3,7 +3,7 @@
# See the LICENSE file in the project root for more information # See the LICENSE file in the project root for more information
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Dict, Any, Tuple from typing import TYPE_CHECKING, List, Any, Tuple
from eland import SortOrder from eland import SortOrder
from eland.actions import HeadAction, TailAction, SortIndexAction from eland.actions import HeadAction, TailAction, SortIndexAction
@ -14,9 +14,9 @@ if TYPE_CHECKING:
from .actions import PostProcessingAction # noqa: F401 from .actions import PostProcessingAction # noqa: F401
from .filter import BooleanFilter # noqa: F401 from .filter import BooleanFilter # noqa: F401
from .query_compiler import QueryCompiler # noqa: F401 from .query_compiler import QueryCompiler # noqa: F401
from .operations import QueryParams # noqa: F401
QUERY_PARAMS_TYPE = Dict[str, Any] RESOLVED_TASK_TYPE = Tuple["QueryParams", List["PostProcessingAction"]]
RESOLVED_TASK_TYPE = Tuple[QUERY_PARAMS_TYPE, List["PostProcessingAction"]]
class Task(ABC): class Task(ABC):
@ -39,7 +39,7 @@ class Task(ABC):
@abstractmethod @abstractmethod
def resolve_task( def resolve_task(
self, self,
query_params: QUERY_PARAMS_TYPE, query_params: "QueryParams",
post_processing: List["PostProcessingAction"], post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler", query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE: ) -> RESOLVED_TASK_TYPE:
@ -70,7 +70,7 @@ class HeadTask(SizeTask):
def resolve_task( def resolve_task(
self, self,
query_params: QUERY_PARAMS_TYPE, query_params: "QueryParams",
post_processing: List["PostProcessingAction"], post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler", query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE: ) -> RESOLVED_TASK_TYPE:
@ -87,20 +87,20 @@ class HeadTask(SizeTask):
post_processing.append(HeadAction(self._count)) post_processing.append(HeadAction(self._count))
return query_params, post_processing return query_params, post_processing
if query_params["query_sort_field"] is None: if query_params.sort_field is None:
query_params["query_sort_field"] = query_sort_field query_params.sort_field = query_sort_field
# if it is already sorted we use existing field # if it is already sorted we use existing field
if query_params["query_sort_order"] is None: if query_params.sort_order is None:
query_params["query_sort_order"] = query_sort_order query_params.sort_order = query_sort_order
# if it is already sorted we get head of existing order # if it is already sorted we get head of existing order
if query_params["query_size"] is None: if query_params.size is None:
query_params["query_size"] = query_size query_params.size = query_size
else: else:
# truncate if head is smaller # truncate if head is smaller
if query_size < query_params["query_size"]: if query_size < query_params.size:
query_params["query_size"] = query_size query_params.size = query_size
return query_params, post_processing return query_params, post_processing
@ -118,7 +118,7 @@ class TailTask(SizeTask):
def resolve_task( def resolve_task(
self, self,
query_params: QUERY_PARAMS_TYPE, query_params: "QueryParams",
post_processing: List["PostProcessingAction"], post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler", query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE: ) -> RESOLVED_TASK_TYPE:
@ -130,15 +130,15 @@ class TailTask(SizeTask):
# If this is a tail of a tail adjust settings and return # If this is a tail of a tail adjust settings and return
if ( if (
query_params["query_size"] is not None query_params.size is not None
and query_params["query_sort_order"] == query_sort_order and query_params.sort_order == query_sort_order
and ( and (
len(post_processing) == 1 len(post_processing) == 1
and isinstance(post_processing[0], SortIndexAction) and isinstance(post_processing[0], SortIndexAction)
) )
): ):
if query_size < query_params["query_size"]: if query_size < query_params.size:
query_params["query_size"] = query_size query_params.size = query_size
return query_params, post_processing return query_params, post_processing
# If we are already postprocessing the query results, just get 'tail' of these # If we are already postprocessing the query results, just get 'tail' of these
@ -151,18 +151,18 @@ class TailTask(SizeTask):
# If results are already constrained, just get 'tail' of these # If results are already constrained, just get 'tail' of these
# (note, currently we just append another tail, we don't optimise by # (note, currently we just append another tail, we don't optimise by
# overwriting previous tail) # overwriting previous tail)
if query_params["query_size"] is not None: if query_params.size is not None:
post_processing.append(TailAction(self._count)) post_processing.append(TailAction(self._count))
return query_params, post_processing return query_params, post_processing
else: else:
query_params["query_size"] = query_size query_params.size = query_size
if query_params["query_sort_field"] is None: if query_params.sort_field is None:
query_params["query_sort_field"] = query_sort_field query_params.sort_field = query_sort_field
if query_params["query_sort_order"] is None: if query_params.sort_order is None:
query_params["query_sort_order"] = query_sort_order query_params.sort_order = query_sort_order
else: else:
# reverse sort order # reverse sort order
query_params["query_sort_order"] = SortOrder.reverse(query_sort_order) query_params.sort_order = SortOrder.reverse(query_sort_order)
post_processing.append(SortIndexAction()) post_processing.append(SortIndexAction())
@ -193,11 +193,11 @@ class QueryIdsTask(Task):
def resolve_task( def resolve_task(
self, self,
query_params: QUERY_PARAMS_TYPE, query_params: "QueryParams",
post_processing: List["PostProcessingAction"], post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler", query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE: ) -> RESOLVED_TASK_TYPE:
query_params["query"].ids(self._ids, must=self._must) query_params.query.ids(self._ids, must=self._must)
return query_params, post_processing return query_params, post_processing
def __repr__(self) -> str: def __repr__(self) -> str:
@ -226,11 +226,11 @@ class QueryTermsTask(Task):
def resolve_task( def resolve_task(
self, self,
query_params: QUERY_PARAMS_TYPE, query_params: "QueryParams",
post_processing: List["PostProcessingAction"], post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler", query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE: ) -> RESOLVED_TASK_TYPE:
query_params["query"].terms(self._field, self._terms, must=self._must) query_params.query.terms(self._field, self._terms, must=self._must)
return query_params, post_processing return query_params, post_processing
def __repr__(self) -> str: def __repr__(self) -> str:
@ -254,11 +254,11 @@ class BooleanFilterTask(Task):
def resolve_task( def resolve_task(
self, self,
query_params: QUERY_PARAMS_TYPE, query_params: "QueryParams",
post_processing: List["PostProcessingAction"], post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler", query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE: ) -> RESOLVED_TASK_TYPE:
query_params["query"].update_boolean_filter(self._boolean_filter) query_params.query.update_boolean_filter(self._boolean_filter)
return query_params, post_processing return query_params, post_processing
def __repr__(self) -> str: def __repr__(self) -> str:
@ -281,7 +281,7 @@ class ArithmeticOpFieldsTask(Task):
def resolve_task( def resolve_task(
self, self,
query_params: QUERY_PARAMS_TYPE, query_params: "QueryParams",
post_processing: List["PostProcessingAction"], post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler", query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE: ) -> RESOLVED_TASK_TYPE:
@ -295,20 +295,20 @@ class ArithmeticOpFieldsTask(Task):
} }
} }
""" """
if query_params["query_script_fields"] is None: if query_params.script_fields is None:
query_params["query_script_fields"] = {} query_params.script_fields = {}
# TODO: Remove this once 'query_params' becomes a dataclass. # TODO: Remove this once 'query_params' becomes a dataclass.
assert isinstance(query_params["query_script_fields"], dict) assert isinstance(query_params.script_fields, dict)
if self._display_name in query_params["query_script_fields"]: if self._display_name in query_params.script_fields:
raise NotImplementedError( raise NotImplementedError(
f"TODO code path - combine multiple ops " f"TODO code path - combine multiple ops "
f"'{self}'\n{query_params['query_script_fields']}\n" f"'{self}'\n{query_params.script_fields}\n"
f"{self._display_name}\n{self._arithmetic_series.resolve()}" f"{self._display_name}\n{self._arithmetic_series.resolve()}"
) )
query_params["query_script_fields"][self._display_name] = { query_params.script_fields[self._display_name] = {
"script": {"source": self._arithmetic_series.resolve()} "script": {"source": self._arithmetic_series.resolve()}
} }