From df2a21ffd43428585d29ed46bc70b45309076ea1 Mon Sep 17 00:00:00 2001 From: Seth Michael Larson Date: Mon, 27 Apr 2020 16:21:26 -0500 Subject: [PATCH] Make QueryParams a dataclass --- eland/operations.py | 64 +++++++++++++++++++------------------- eland/query.py | 17 +++------- eland/tasks.py | 76 ++++++++++++++++++++++----------------------- 3 files changed, 75 insertions(+), 82 deletions(-) diff --git a/eland/operations.py b/eland/operations.py index d2f4b23..ce23cfd 100644 --- a/eland/operations.py +++ b/eland/operations.py @@ -5,7 +5,7 @@ import copy import typing import warnings -from typing import Optional +from typing import Optional, Tuple, List, Dict, Any import numpy as np import pandas as pd @@ -35,6 +35,16 @@ if typing.TYPE_CHECKING: from eland.query_compiler import QueryCompiler +class QueryParams: + def __init__(self): + self.query = Query() + self.sort_field: Optional[str] = None + self.sort_order: Optional[SortOrder] = None + self.size: Optional[int] = None + self.fields: Optional[List[str]] = None + self.script_fields: Optional[Dict[str, Dict[str, Any]]] = None + + class Operations: """ A collector of the queries and selectors we apply to queries to return the appropriate results. @@ -107,7 +117,7 @@ class Operations: counts = {} for field in fields: - body = Query(query_params["query"]) + body = Query(query_params.query) body.exists(field, must=True) field_exists_count = query_compiler._client.count( @@ -175,7 +185,7 @@ class Operations: if numeric_only: fields = [field for field in fields if (field.is_numeric or field.is_bool)] - body = Query(query_params["query"]) + body = Query(query_params.query) # Convert pandas aggs to ES equivalent es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs) @@ -301,7 +311,7 @@ class Operations: # Get just aggregatable field_names aggregatable_field_names = query_compiler._mappings.aggregatable_field_names() - body = Query(query_params["query"]) + body = Query(query_params.query) for field in aggregatable_field_names.keys(): body.terms_aggs(field, func, field, es_size=es_size) @@ -338,7 +348,7 @@ class Operations: numeric_source_fields = query_compiler._mappings.numeric_source_fields() - body = Query(query_params["query"]) + body = Query(query_params.query) results = self._metric_aggs(query_compiler, ["min", "max"], numeric_only=True) min_aggs = {} @@ -348,7 +358,7 @@ class Operations: max_aggs[field] = max_agg for field in numeric_source_fields: - body.hist_aggs(field, field, min_aggs, max_aggs, num_bins) + body.hist_aggs(field, field, min_aggs[field], max_aggs[field], num_bins) response = query_compiler._client.search( index=query_compiler._index_pattern, size=0, body=body.to_search_body() @@ -525,7 +535,7 @@ class Operations: # for each field we compute: # count, mean, std, min, 25%, 50%, 75%, max - body = Query(query_params["query"]) + body = Query(query_params.query) for field in numeric_source_fields: body.metric_aggs("extended_stats_" + field, "extended_stats", field) @@ -639,8 +649,8 @@ class Operations: size, sort_params = Operations._query_params_to_size_and_sort(query_params) - script_fields = query_params["query_script_fields"] - query = Query(query_params["query"]) + script_fields = query_params.script_fields + query = Query(query_params.query) body = query.to_search_body() if script_fields is not None: @@ -722,7 +732,7 @@ class Operations: # TODO - this is not necessarily valid as the field may not exist in ALL these docs return size - body = Query(query_params["query"]) + body = Query(query_params.query) body.exists(field, must=True) return query_compiler._client.count( @@ -751,7 +761,7 @@ class Operations: query_compiler, items ) - body = Query(query_params["query"]) + body = Query(query_params.query) if field == Index.ID_INDEX_FIELD: body.ids(items, must=True) @@ -780,17 +790,16 @@ class Operations: self._tasks.append(task) @staticmethod - def _query_params_to_size_and_sort(query_params): + def _query_params_to_size_and_sort( + query_params: QueryParams, + ) -> Tuple[Optional[int], Optional[str]]: sort_params = None - if query_params["query_sort_field"] and query_params["query_sort_order"]: + if query_params.sort_field and query_params.sort_order: sort_params = ( - query_params["query_sort_field"] - + ":" - + SortOrder.to_string(query_params["query_sort_order"]) + f"{query_params.sort_field}:" + f"{SortOrder.to_string(query_params.sort_order)}" ) - - size = query_params["query_size"] - + size = query_params.size return size, sort_params @staticmethod @@ -800,7 +809,6 @@ class Operations: if isinstance(action, SizeTask): if size is None or action.size() < size: size = action.size() - return size @staticmethod @@ -815,15 +823,7 @@ class Operations: # Some operations can be simply combined into a single query # other operations require pre-queries and then combinations # other operations require in-core post-processing of results - query_params = { - "query_sort_field": None, - "query_sort_order": None, - "query_size": None, - "query_fields": None, - "query_script_fields": None, - "query": Query(), - } - + query_params = QueryParams() post_processing = [] for task in self._tasks: @@ -843,7 +843,7 @@ class Operations: def _size(self, query_params, post_processing): # Shrink wrap code around checking if size parameter is set - size = query_params["query_size"] # can be None + size = query_params.size pp_size = self._count_post_processing(post_processing) if pp_size is not None: @@ -863,8 +863,8 @@ class Operations: size, sort_params = Operations._query_params_to_size_and_sort(query_params) _source = query_compiler._mappings.get_field_names() - script_fields = query_params["query_script_fields"] - query = Query(query_params["query"]) + script_fields = query_params.script_fields + query = Query(query_params.query) body = query.to_search_body() if script_fields is not None: body["script_fields"] = script_fields diff --git a/eland/query.py b/eland/query.py index 11e1fec..4083489 100644 --- a/eland/query.py +++ b/eland/query.py @@ -107,12 +107,7 @@ class Query: self._aggs[name] = agg def hist_aggs( - self, - name: str, - field: str, - min_aggs: Dict[str, Any], - max_aggs: Dict[str, Any], - num_bins: int, + self, name: str, field: str, min_value: Any, max_value: Any, num_bins: int, ) -> None: """ Add histogram agg e.g. @@ -120,20 +115,18 @@ class Query: "name": { "histogram": { "field": "AvgTicketPrice" - "interval": (max_aggs[field] - min_aggs[field])/bins + "interval": (max_value - min_value)/bins + "offset": min_value } } } """ - min = min_aggs[field] - max = max_aggs[field] - interval = (max - min) / num_bins + interval = (max_value - min_value) / num_bins if interval != 0: - offset = min agg = { - "histogram": {"field": field, "interval": interval, "offset": offset} + "histogram": {"field": field, "interval": interval, "offset": min_value} } self._aggs[name] = agg diff --git a/eland/tasks.py b/eland/tasks.py index 8e7da9b..bcfe701 100644 --- a/eland/tasks.py +++ b/eland/tasks.py @@ -3,7 +3,7 @@ # See the LICENSE file in the project root for more information from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Dict, Any, Tuple +from typing import TYPE_CHECKING, List, Any, Tuple from eland import SortOrder from eland.actions import HeadAction, TailAction, SortIndexAction @@ -14,9 +14,9 @@ if TYPE_CHECKING: from .actions import PostProcessingAction # noqa: F401 from .filter import BooleanFilter # noqa: F401 from .query_compiler import QueryCompiler # noqa: F401 + from .operations import QueryParams # noqa: F401 -QUERY_PARAMS_TYPE = Dict[str, Any] -RESOLVED_TASK_TYPE = Tuple[QUERY_PARAMS_TYPE, List["PostProcessingAction"]] +RESOLVED_TASK_TYPE = Tuple["QueryParams", List["PostProcessingAction"]] class Task(ABC): @@ -39,7 +39,7 @@ class Task(ABC): @abstractmethod def resolve_task( self, - query_params: QUERY_PARAMS_TYPE, + query_params: "QueryParams", post_processing: List["PostProcessingAction"], query_compiler: "QueryCompiler", ) -> RESOLVED_TASK_TYPE: @@ -70,7 +70,7 @@ class HeadTask(SizeTask): def resolve_task( self, - query_params: QUERY_PARAMS_TYPE, + query_params: "QueryParams", post_processing: List["PostProcessingAction"], query_compiler: "QueryCompiler", ) -> RESOLVED_TASK_TYPE: @@ -87,20 +87,20 @@ class HeadTask(SizeTask): post_processing.append(HeadAction(self._count)) return query_params, post_processing - if query_params["query_sort_field"] is None: - query_params["query_sort_field"] = query_sort_field + if query_params.sort_field is None: + query_params.sort_field = query_sort_field # if it is already sorted we use existing field - if query_params["query_sort_order"] is None: - query_params["query_sort_order"] = query_sort_order + if query_params.sort_order is None: + query_params.sort_order = query_sort_order # if it is already sorted we get head of existing order - if query_params["query_size"] is None: - query_params["query_size"] = query_size + if query_params.size is None: + query_params.size = query_size else: # truncate if head is smaller - if query_size < query_params["query_size"]: - query_params["query_size"] = query_size + if query_size < query_params.size: + query_params.size = query_size return query_params, post_processing @@ -118,7 +118,7 @@ class TailTask(SizeTask): def resolve_task( self, - query_params: QUERY_PARAMS_TYPE, + query_params: "QueryParams", post_processing: List["PostProcessingAction"], query_compiler: "QueryCompiler", ) -> RESOLVED_TASK_TYPE: @@ -130,15 +130,15 @@ class TailTask(SizeTask): # If this is a tail of a tail adjust settings and return if ( - query_params["query_size"] is not None - and query_params["query_sort_order"] == query_sort_order + query_params.size is not None + and query_params.sort_order == query_sort_order and ( len(post_processing) == 1 and isinstance(post_processing[0], SortIndexAction) ) ): - if query_size < query_params["query_size"]: - query_params["query_size"] = query_size + if query_size < query_params.size: + query_params.size = query_size return query_params, post_processing # If we are already postprocessing the query results, just get 'tail' of these @@ -151,18 +151,18 @@ class TailTask(SizeTask): # If results are already constrained, just get 'tail' of these # (note, currently we just append another tail, we don't optimise by # overwriting previous tail) - if query_params["query_size"] is not None: + if query_params.size is not None: post_processing.append(TailAction(self._count)) return query_params, post_processing else: - query_params["query_size"] = query_size - if query_params["query_sort_field"] is None: - query_params["query_sort_field"] = query_sort_field - if query_params["query_sort_order"] is None: - query_params["query_sort_order"] = query_sort_order + query_params.size = query_size + if query_params.sort_field is None: + query_params.sort_field = query_sort_field + if query_params.sort_order is None: + query_params.sort_order = query_sort_order else: # reverse sort order - query_params["query_sort_order"] = SortOrder.reverse(query_sort_order) + query_params.sort_order = SortOrder.reverse(query_sort_order) post_processing.append(SortIndexAction()) @@ -193,11 +193,11 @@ class QueryIdsTask(Task): def resolve_task( self, - query_params: QUERY_PARAMS_TYPE, + query_params: "QueryParams", post_processing: List["PostProcessingAction"], query_compiler: "QueryCompiler", ) -> RESOLVED_TASK_TYPE: - query_params["query"].ids(self._ids, must=self._must) + query_params.query.ids(self._ids, must=self._must) return query_params, post_processing def __repr__(self) -> str: @@ -226,11 +226,11 @@ class QueryTermsTask(Task): def resolve_task( self, - query_params: QUERY_PARAMS_TYPE, + query_params: "QueryParams", post_processing: List["PostProcessingAction"], query_compiler: "QueryCompiler", ) -> RESOLVED_TASK_TYPE: - query_params["query"].terms(self._field, self._terms, must=self._must) + query_params.query.terms(self._field, self._terms, must=self._must) return query_params, post_processing def __repr__(self) -> str: @@ -254,11 +254,11 @@ class BooleanFilterTask(Task): def resolve_task( self, - query_params: QUERY_PARAMS_TYPE, + query_params: "QueryParams", post_processing: List["PostProcessingAction"], query_compiler: "QueryCompiler", ) -> RESOLVED_TASK_TYPE: - query_params["query"].update_boolean_filter(self._boolean_filter) + query_params.query.update_boolean_filter(self._boolean_filter) return query_params, post_processing def __repr__(self) -> str: @@ -281,7 +281,7 @@ class ArithmeticOpFieldsTask(Task): def resolve_task( self, - query_params: QUERY_PARAMS_TYPE, + query_params: "QueryParams", post_processing: List["PostProcessingAction"], query_compiler: "QueryCompiler", ) -> RESOLVED_TASK_TYPE: @@ -295,20 +295,20 @@ class ArithmeticOpFieldsTask(Task): } } """ - if query_params["query_script_fields"] is None: - query_params["query_script_fields"] = {} + if query_params.script_fields is None: + query_params.script_fields = {} # TODO: Remove this once 'query_params' becomes a dataclass. - assert isinstance(query_params["query_script_fields"], dict) + assert isinstance(query_params.script_fields, dict) - if self._display_name in query_params["query_script_fields"]: + if self._display_name in query_params.script_fields: raise NotImplementedError( f"TODO code path - combine multiple ops " - f"'{self}'\n{query_params['query_script_fields']}\n" + f"'{self}'\n{query_params.script_fields}\n" f"{self._display_name}\n{self._arithmetic_series.resolve()}" ) - query_params["query_script_fields"][self._display_name] = { + query_params.script_fields[self._display_name] = { "script": {"source": self._arithmetic_series.resolve()} }