mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Make QueryParams a dataclass
This commit is contained in:
parent
15a1977dcf
commit
df2a21ffd4
@ -5,7 +5,7 @@
|
||||
import copy
|
||||
import typing
|
||||
import warnings
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple, List, Dict, Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -35,6 +35,16 @@ if typing.TYPE_CHECKING:
|
||||
from eland.query_compiler import QueryCompiler
|
||||
|
||||
|
||||
class QueryParams:
    """
    Mutable bundle of the Elasticsearch query settings that tasks build up
    while they are being resolved.

    Every attribute except ``query`` starts as ``None``, meaning "not set yet".
    """

    def __init__(self) -> None:
        # The query under construction; tasks add ids/terms/filters to it.
        self.query = Query()
        # Field to sort on and the direction of that sort.
        self.sort_field: Optional[str] = None
        self.sort_order: Optional[SortOrder] = None
        # Maximum number of hits to request.
        self.size: Optional[int] = None
        # NOTE(review): not read anywhere in the visible code - presumably
        # the list of source fields to return; confirm against callers.
        self.fields: Optional[List[str]] = None
        # Script fields keyed by display name; each value is the script body.
        self.script_fields: Optional[Dict[str, Dict[str, Any]]] = None
|
||||
|
||||
|
||||
class Operations:
|
||||
"""
|
||||
A collector of the queries and selectors we apply to queries to return the appropriate results.
|
||||
@ -107,7 +117,7 @@ class Operations:
|
||||
|
||||
counts = {}
|
||||
for field in fields:
|
||||
body = Query(query_params["query"])
|
||||
body = Query(query_params.query)
|
||||
body.exists(field, must=True)
|
||||
|
||||
field_exists_count = query_compiler._client.count(
|
||||
@ -175,7 +185,7 @@ class Operations:
|
||||
if numeric_only:
|
||||
fields = [field for field in fields if (field.is_numeric or field.is_bool)]
|
||||
|
||||
body = Query(query_params["query"])
|
||||
body = Query(query_params.query)
|
||||
|
||||
# Convert pandas aggs to ES equivalent
|
||||
es_aggs = self._map_pd_aggs_to_es_aggs(pd_aggs)
|
||||
@ -301,7 +311,7 @@ class Operations:
|
||||
# Get just aggregatable field_names
|
||||
aggregatable_field_names = query_compiler._mappings.aggregatable_field_names()
|
||||
|
||||
body = Query(query_params["query"])
|
||||
body = Query(query_params.query)
|
||||
|
||||
for field in aggregatable_field_names.keys():
|
||||
body.terms_aggs(field, func, field, es_size=es_size)
|
||||
@ -338,7 +348,7 @@ class Operations:
|
||||
|
||||
numeric_source_fields = query_compiler._mappings.numeric_source_fields()
|
||||
|
||||
body = Query(query_params["query"])
|
||||
body = Query(query_params.query)
|
||||
|
||||
results = self._metric_aggs(query_compiler, ["min", "max"], numeric_only=True)
|
||||
min_aggs = {}
|
||||
@ -348,7 +358,7 @@ class Operations:
|
||||
max_aggs[field] = max_agg
|
||||
|
||||
for field in numeric_source_fields:
|
||||
body.hist_aggs(field, field, min_aggs, max_aggs, num_bins)
|
||||
body.hist_aggs(field, field, min_aggs[field], max_aggs[field], num_bins)
|
||||
|
||||
response = query_compiler._client.search(
|
||||
index=query_compiler._index_pattern, size=0, body=body.to_search_body()
|
||||
@ -525,7 +535,7 @@ class Operations:
|
||||
|
||||
# for each field we compute:
|
||||
# count, mean, std, min, 25%, 50%, 75%, max
|
||||
body = Query(query_params["query"])
|
||||
body = Query(query_params.query)
|
||||
|
||||
for field in numeric_source_fields:
|
||||
body.metric_aggs("extended_stats_" + field, "extended_stats", field)
|
||||
@ -639,8 +649,8 @@ class Operations:
|
||||
|
||||
size, sort_params = Operations._query_params_to_size_and_sort(query_params)
|
||||
|
||||
script_fields = query_params["query_script_fields"]
|
||||
query = Query(query_params["query"])
|
||||
script_fields = query_params.script_fields
|
||||
query = Query(query_params.query)
|
||||
|
||||
body = query.to_search_body()
|
||||
if script_fields is not None:
|
||||
@ -722,7 +732,7 @@ class Operations:
|
||||
# TODO - this is not necessarily valid as the field may not exist in ALL these docs
|
||||
return size
|
||||
|
||||
body = Query(query_params["query"])
|
||||
body = Query(query_params.query)
|
||||
body.exists(field, must=True)
|
||||
|
||||
return query_compiler._client.count(
|
||||
@ -751,7 +761,7 @@ class Operations:
|
||||
query_compiler, items
|
||||
)
|
||||
|
||||
body = Query(query_params["query"])
|
||||
body = Query(query_params.query)
|
||||
|
||||
if field == Index.ID_INDEX_FIELD:
|
||||
body.ids(items, must=True)
|
||||
@ -780,17 +790,16 @@ class Operations:
|
||||
self._tasks.append(task)
|
||||
|
||||
@staticmethod
|
||||
def _query_params_to_size_and_sort(query_params):
|
||||
def _query_params_to_size_and_sort(
|
||||
query_params: QueryParams,
|
||||
) -> Tuple[Optional[int], Optional[str]]:
|
||||
sort_params = None
|
||||
if query_params["query_sort_field"] and query_params["query_sort_order"]:
|
||||
if query_params.sort_field and query_params.sort_order:
|
||||
sort_params = (
|
||||
query_params["query_sort_field"]
|
||||
+ ":"
|
||||
+ SortOrder.to_string(query_params["query_sort_order"])
|
||||
f"{query_params.sort_field}:"
|
||||
f"{SortOrder.to_string(query_params.sort_order)}"
|
||||
)
|
||||
|
||||
size = query_params["query_size"]
|
||||
|
||||
size = query_params.size
|
||||
return size, sort_params
|
||||
|
||||
@staticmethod
|
||||
@ -800,7 +809,6 @@ class Operations:
|
||||
if isinstance(action, SizeTask):
|
||||
if size is None or action.size() < size:
|
||||
size = action.size()
|
||||
|
||||
return size
|
||||
|
||||
@staticmethod
|
||||
@ -815,15 +823,7 @@ class Operations:
|
||||
# Some operations can be simply combined into a single query
|
||||
# other operations require pre-queries and then combinations
|
||||
# other operations require in-core post-processing of results
|
||||
query_params = {
|
||||
"query_sort_field": None,
|
||||
"query_sort_order": None,
|
||||
"query_size": None,
|
||||
"query_fields": None,
|
||||
"query_script_fields": None,
|
||||
"query": Query(),
|
||||
}
|
||||
|
||||
query_params = QueryParams()
|
||||
post_processing = []
|
||||
|
||||
for task in self._tasks:
|
||||
@ -843,7 +843,7 @@ class Operations:
|
||||
|
||||
def _size(self, query_params, post_processing):
|
||||
# Shrink wrap code around checking if size parameter is set
|
||||
size = query_params["query_size"] # can be None
|
||||
size = query_params.size
|
||||
|
||||
pp_size = self._count_post_processing(post_processing)
|
||||
if pp_size is not None:
|
||||
@ -863,8 +863,8 @@ class Operations:
|
||||
size, sort_params = Operations._query_params_to_size_and_sort(query_params)
|
||||
_source = query_compiler._mappings.get_field_names()
|
||||
|
||||
script_fields = query_params["query_script_fields"]
|
||||
query = Query(query_params["query"])
|
||||
script_fields = query_params.script_fields
|
||||
query = Query(query_params.query)
|
||||
body = query.to_search_body()
|
||||
if script_fields is not None:
|
||||
body["script_fields"] = script_fields
|
||||
|
@ -107,12 +107,7 @@ class Query:
|
||||
self._aggs[name] = agg
|
||||
|
||||
def hist_aggs(
    self, name: str, field: str, min_value: Any, max_value: Any, num_bins: int,
) -> None:
    """
    Add histogram agg e.g.

    "name": {
        "histogram": {
            "field": "AvgTicketPrice"
            "interval": (max_value - min_value)/bins
            "offset": min_value
        }
    }

    Parameters
    ----------
    name: str
        Key under which the aggregation is stored in ``self._aggs``.
    field: str
        Field the histogram is computed over.
    min_value, max_value:
        Lower and upper bound of the field's values; the buckets start at
        ``min_value`` and span ``max_value - min_value``.
    num_bins: int
        Number of equal-width buckets to divide the range into (non-zero).

    Notes
    -----
    When the computed interval is zero (``min_value == max_value``) no
    aggregation is registered for ``name``.
    """
    interval = (max_value - min_value) / num_bins

    # Skip degenerate ranges - presumably because a zero-width histogram
    # interval is not accepted by Elasticsearch; confirm against ES docs.
    if interval != 0:
        agg = {
            "histogram": {"field": field, "interval": interval, "offset": min_value}
        }
        self._aggs[name] = agg
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
# See the LICENSE file in the project root for more information
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, List, Dict, Any, Tuple
|
||||
from typing import TYPE_CHECKING, List, Any, Tuple
|
||||
|
||||
from eland import SortOrder
|
||||
from eland.actions import HeadAction, TailAction, SortIndexAction
|
||||
@ -14,9 +14,9 @@ if TYPE_CHECKING:
|
||||
from .actions import PostProcessingAction # noqa: F401
|
||||
from .filter import BooleanFilter # noqa: F401
|
||||
from .query_compiler import QueryCompiler # noqa: F401
|
||||
from .operations import QueryParams # noqa: F401
|
||||
|
||||
QUERY_PARAMS_TYPE = Dict[str, Any]
|
||||
RESOLVED_TASK_TYPE = Tuple[QUERY_PARAMS_TYPE, List["PostProcessingAction"]]
|
||||
RESOLVED_TASK_TYPE = Tuple["QueryParams", List["PostProcessingAction"]]
|
||||
|
||||
|
||||
class Task(ABC):
|
||||
@ -39,7 +39,7 @@ class Task(ABC):
|
||||
@abstractmethod
|
||||
def resolve_task(
|
||||
self,
|
||||
query_params: QUERY_PARAMS_TYPE,
|
||||
query_params: "QueryParams",
|
||||
post_processing: List["PostProcessingAction"],
|
||||
query_compiler: "QueryCompiler",
|
||||
) -> RESOLVED_TASK_TYPE:
|
||||
@ -70,7 +70,7 @@ class HeadTask(SizeTask):
|
||||
|
||||
def resolve_task(
|
||||
self,
|
||||
query_params: QUERY_PARAMS_TYPE,
|
||||
query_params: "QueryParams",
|
||||
post_processing: List["PostProcessingAction"],
|
||||
query_compiler: "QueryCompiler",
|
||||
) -> RESOLVED_TASK_TYPE:
|
||||
@ -87,20 +87,20 @@ class HeadTask(SizeTask):
|
||||
post_processing.append(HeadAction(self._count))
|
||||
return query_params, post_processing
|
||||
|
||||
if query_params["query_sort_field"] is None:
|
||||
query_params["query_sort_field"] = query_sort_field
|
||||
if query_params.sort_field is None:
|
||||
query_params.sort_field = query_sort_field
|
||||
# if it is already sorted we use existing field
|
||||
|
||||
if query_params["query_sort_order"] is None:
|
||||
query_params["query_sort_order"] = query_sort_order
|
||||
if query_params.sort_order is None:
|
||||
query_params.sort_order = query_sort_order
|
||||
# if it is already sorted we get head of existing order
|
||||
|
||||
if query_params["query_size"] is None:
|
||||
query_params["query_size"] = query_size
|
||||
if query_params.size is None:
|
||||
query_params.size = query_size
|
||||
else:
|
||||
# truncate if head is smaller
|
||||
if query_size < query_params["query_size"]:
|
||||
query_params["query_size"] = query_size
|
||||
if query_size < query_params.size:
|
||||
query_params.size = query_size
|
||||
|
||||
return query_params, post_processing
|
||||
|
||||
@ -118,7 +118,7 @@ class TailTask(SizeTask):
|
||||
|
||||
def resolve_task(
|
||||
self,
|
||||
query_params: QUERY_PARAMS_TYPE,
|
||||
query_params: "QueryParams",
|
||||
post_processing: List["PostProcessingAction"],
|
||||
query_compiler: "QueryCompiler",
|
||||
) -> RESOLVED_TASK_TYPE:
|
||||
@ -130,15 +130,15 @@ class TailTask(SizeTask):
|
||||
|
||||
# If this is a tail of a tail adjust settings and return
|
||||
if (
|
||||
query_params["query_size"] is not None
|
||||
and query_params["query_sort_order"] == query_sort_order
|
||||
query_params.size is not None
|
||||
and query_params.sort_order == query_sort_order
|
||||
and (
|
||||
len(post_processing) == 1
|
||||
and isinstance(post_processing[0], SortIndexAction)
|
||||
)
|
||||
):
|
||||
if query_size < query_params["query_size"]:
|
||||
query_params["query_size"] = query_size
|
||||
if query_size < query_params.size:
|
||||
query_params.size = query_size
|
||||
return query_params, post_processing
|
||||
|
||||
# If we are already postprocessing the query results, just get 'tail' of these
|
||||
@ -151,18 +151,18 @@ class TailTask(SizeTask):
|
||||
# If results are already constrained, just get 'tail' of these
|
||||
# (note, currently we just append another tail, we don't optimise by
|
||||
# overwriting previous tail)
|
||||
if query_params["query_size"] is not None:
|
||||
if query_params.size is not None:
|
||||
post_processing.append(TailAction(self._count))
|
||||
return query_params, post_processing
|
||||
else:
|
||||
query_params["query_size"] = query_size
|
||||
if query_params["query_sort_field"] is None:
|
||||
query_params["query_sort_field"] = query_sort_field
|
||||
if query_params["query_sort_order"] is None:
|
||||
query_params["query_sort_order"] = query_sort_order
|
||||
query_params.size = query_size
|
||||
if query_params.sort_field is None:
|
||||
query_params.sort_field = query_sort_field
|
||||
if query_params.sort_order is None:
|
||||
query_params.sort_order = query_sort_order
|
||||
else:
|
||||
# reverse sort order
|
||||
query_params["query_sort_order"] = SortOrder.reverse(query_sort_order)
|
||||
query_params.sort_order = SortOrder.reverse(query_sort_order)
|
||||
|
||||
post_processing.append(SortIndexAction())
|
||||
|
||||
@ -193,11 +193,11 @@ class QueryIdsTask(Task):
|
||||
|
||||
def resolve_task(
    self,
    query_params: "QueryParams",
    post_processing: List["PostProcessingAction"],
    query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
    """
    Apply this task's ids filter to the query being built.

    Calls ``ids`` on ``query_params.query`` with this task's ``_ids`` and
    ``_must`` values, then returns ``query_params`` and ``post_processing``
    (the latter untouched). ``query_compiler`` is not used.
    """
    query_params.query.ids(self._ids, must=self._must)
    return query_params, post_processing
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -226,11 +226,11 @@ class QueryTermsTask(Task):
|
||||
|
||||
def resolve_task(
    self,
    query_params: "QueryParams",
    post_processing: List["PostProcessingAction"],
    query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
    """
    Apply this task's terms filter to the query being built.

    Calls ``terms`` on ``query_params.query`` with this task's ``_field``,
    ``_terms`` and ``_must`` values, then returns ``query_params`` and
    ``post_processing`` (the latter untouched). ``query_compiler`` is not used.
    """
    query_params.query.terms(self._field, self._terms, must=self._must)
    return query_params, post_processing
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -254,11 +254,11 @@ class BooleanFilterTask(Task):
|
||||
|
||||
def resolve_task(
    self,
    query_params: "QueryParams",
    post_processing: List["PostProcessingAction"],
    query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
    """
    Merge this task's boolean filter into the query being built.

    Passes ``self._boolean_filter`` to ``update_boolean_filter`` on
    ``query_params.query``, then returns ``query_params`` and
    ``post_processing`` (the latter untouched). ``query_compiler`` is not used.
    """
    query_params.query.update_boolean_filter(self._boolean_filter)
    return query_params, post_processing
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -281,7 +281,7 @@ class ArithmeticOpFieldsTask(Task):
|
||||
|
||||
def resolve_task(
|
||||
self,
|
||||
query_params: QUERY_PARAMS_TYPE,
|
||||
query_params: "QueryParams",
|
||||
post_processing: List["PostProcessingAction"],
|
||||
query_compiler: "QueryCompiler",
|
||||
) -> RESOLVED_TASK_TYPE:
|
||||
@ -295,20 +295,20 @@ class ArithmeticOpFieldsTask(Task):
|
||||
}
|
||||
}
|
||||
"""
|
||||
if query_params["query_script_fields"] is None:
|
||||
query_params["query_script_fields"] = {}
|
||||
if query_params.script_fields is None:
|
||||
query_params.script_fields = {}
|
||||
|
||||
# TODO: Remove this once 'query_params' becomes a dataclass.
|
||||
assert isinstance(query_params["query_script_fields"], dict)
|
||||
assert isinstance(query_params.script_fields, dict)
|
||||
|
||||
if self._display_name in query_params["query_script_fields"]:
|
||||
if self._display_name in query_params.script_fields:
|
||||
raise NotImplementedError(
|
||||
f"TODO code path - combine multiple ops "
|
||||
f"'{self}'\n{query_params['query_script_fields']}\n"
|
||||
f"'{self}'\n{query_params.script_fields}\n"
|
||||
f"{self._display_name}\n{self._arithmetic_series.resolve()}"
|
||||
)
|
||||
|
||||
query_params["query_script_fields"][self._display_name] = {
|
||||
query_params.script_fields[self._display_name] = {
|
||||
"script": {"source": self._arithmetic_series.resolve()}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user