mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
226 lines
7.2 KiB
Python
226 lines
7.2 KiB
Python
# Licensed to Elasticsearch B.V. under one or more contributor
|
|
# license agreements. See the NOTICE file distributed with
|
|
# this work for additional information regarding copyright
|
|
# ownership. Elasticsearch B.V. licenses this file to you under
|
|
# the Apache License, Version 2.0 (the "License"); you may
|
|
# not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
from abc import ABC, abstractmethod
|
|
from io import StringIO
|
|
from typing import TYPE_CHECKING, Any, List, Union
|
|
|
|
import numpy as np
|
|
|
|
if TYPE_CHECKING:
|
|
from numpy.typing import DTypeLike
|
|
|
|
from .query_compiler import QueryCompiler
|
|
|
|
|
|
class ArithmeticObject(ABC):
|
|
@property
|
|
@abstractmethod
|
|
def value(self) -> str:
|
|
pass
|
|
|
|
@abstractmethod
|
|
def dtype(self) -> "DTypeLike":
|
|
pass
|
|
|
|
@abstractmethod
|
|
def resolve(self) -> str:
|
|
pass
|
|
|
|
@abstractmethod
|
|
def __repr__(self) -> str:
|
|
pass
|
|
|
|
|
|
class ArithmeticString(ArithmeticObject):
|
|
def __init__(self, value: str):
|
|
self._value = value
|
|
|
|
def resolve(self) -> str:
|
|
return self.value
|
|
|
|
@property
|
|
def dtype(self) -> "DTypeLike":
|
|
return np.dtype(object)
|
|
|
|
@property
|
|
def value(self) -> str:
|
|
return f"'{self._value}'"
|
|
|
|
def __repr__(self) -> str:
|
|
return self.value
|
|
|
|
|
|
class ArithmeticNumber(ArithmeticObject):
|
|
def __init__(self, value: Union[int, float], dtype: "DTypeLike"):
|
|
self._value = value
|
|
self._dtype = dtype
|
|
|
|
def resolve(self) -> str:
|
|
return self.value
|
|
|
|
@property
|
|
def value(self) -> str:
|
|
return f"{self._value}"
|
|
|
|
@property
|
|
def dtype(self) -> "DTypeLike":
|
|
return self._dtype
|
|
|
|
def __repr__(self) -> str:
|
|
return self.value
|
|
|
|
|
|
class ArithmeticSeries(ArithmeticObject):
|
|
"""Represents each item in a 'Series' by using painless scripts
|
|
to evaluate each document in an index as a part of a query.
|
|
"""
|
|
|
|
def __init__(
|
|
self, query_compiler: "QueryCompiler", display_name: str, dtype: "DTypeLike"
|
|
) -> None:
|
|
# type defs
|
|
self._value: str
|
|
self._tasks: List["ArithmeticTask"]
|
|
|
|
task = query_compiler.get_arithmetic_op_fields()
|
|
|
|
if task is not None:
|
|
assert isinstance(task._arithmetic_series, ArithmeticSeries)
|
|
self._value = task._arithmetic_series.value
|
|
self._tasks = task._arithmetic_series._tasks.copy()
|
|
self._dtype = dtype
|
|
else:
|
|
aggregatable_field_name = query_compiler.display_name_to_aggregatable_name(
|
|
display_name
|
|
)
|
|
self._value = f"doc['{aggregatable_field_name}'].value"
|
|
self._tasks = []
|
|
self._dtype = dtype
|
|
|
|
@property
|
|
def value(self) -> str:
|
|
return self._value
|
|
|
|
@property
|
|
def dtype(self) -> "DTypeLike":
|
|
return self._dtype
|
|
|
|
def __repr__(self) -> str:
|
|
buf = StringIO()
|
|
buf.write(f"Series: {self.value} ")
|
|
buf.write("Tasks: ")
|
|
for task in self._tasks:
|
|
buf.write(f"{task!r} ")
|
|
return buf.getvalue()
|
|
|
|
def resolve(self) -> str:
|
|
value = self._value
|
|
|
|
for task in self._tasks:
|
|
if task.op_name == "__add__":
|
|
value = f"({value} + {task.object.resolve()})"
|
|
elif task.op_name in {"__truediv__", "__div__"}:
|
|
value = f"({value} / {task.object.resolve()})"
|
|
elif task.op_name == "__floordiv__":
|
|
value = f"Math.floor({value} / {task.object.resolve()})"
|
|
elif task.op_name == "__mod__":
|
|
value = f"({value} % {task.object.resolve()})"
|
|
elif task.op_name == "__mul__":
|
|
value = f"({value} * {task.object.resolve()})"
|
|
elif task.op_name == "__pow__":
|
|
value = f"Math.pow({value}, {task.object.resolve()})"
|
|
elif task.op_name == "__sub__":
|
|
value = f"({value} - {task.object.resolve()})"
|
|
elif task.op_name == "__radd__":
|
|
value = f"({task.object.resolve()} + {value})"
|
|
elif task.op_name in {"__rtruediv__", "__rdiv__"}:
|
|
value = f"({task.object.resolve()} / {value})"
|
|
elif task.op_name == "__rfloordiv__":
|
|
value = f"Math.floor({task.object.resolve()} / {value})"
|
|
elif task.op_name == "__rmod__":
|
|
value = f"({task.object.resolve()} % {value})"
|
|
elif task.op_name == "__rmul__":
|
|
value = f"({task.object.resolve()} * {value})"
|
|
elif task.op_name == "__rpow__":
|
|
value = f"Math.pow({task.object.resolve()}, {value})"
|
|
elif task.op_name == "__rsub__":
|
|
value = f"({task.object.resolve()} - {value})"
|
|
|
|
return value
|
|
|
|
def arithmetic_operation(self, op_name: str, right: Any) -> "ArithmeticSeries":
|
|
# check if operation is supported (raises on unsupported)
|
|
self.check_is_supported(op_name, right)
|
|
|
|
task = ArithmeticTask(op_name, right)
|
|
self._tasks.append(task)
|
|
return self
|
|
|
|
def check_is_supported(self, op_name: str, right: Any) -> bool:
|
|
# supported set is
|
|
# series.number op_name number (all ops)
|
|
# series.string op_name string (only add)
|
|
# series.string op_name int (only mul)
|
|
# series.string op_name float (none)
|
|
# series.int op_name string (none)
|
|
# series.float op_name string (none)
|
|
|
|
# see end of https://pandas.pydata.org/pandas-docs/stable/getting_started/basics.html?highlight=dtype
|
|
# for dtype hierarchy
|
|
right_is_integer = np.issubdtype(right.dtype, np.number)
|
|
if np.issubdtype(self.dtype, np.number) and right_is_integer:
|
|
# series.number op_name number (all ops)
|
|
return True
|
|
|
|
self_is_object = np.issubdtype(self.dtype, np.object_)
|
|
if self_is_object and np.issubdtype(right.dtype, np.object_):
|
|
# series.string op_name string (only add)
|
|
if op_name == "__add__" or op_name == "__radd__":
|
|
return True
|
|
|
|
if self_is_object and right_is_integer:
|
|
# series.string op_name int (only mul)
|
|
if op_name == "__mul__":
|
|
return True
|
|
|
|
raise TypeError(
|
|
f"Arithmetic operation on incompatible types {self.dtype} {op_name} {right.dtype}"
|
|
)
|
|
|
|
|
|
class ArithmeticTask:
|
|
def __init__(self, op_name: str, object: ArithmeticObject):
|
|
self._op_name = op_name
|
|
|
|
if not isinstance(object, ArithmeticObject):
|
|
raise TypeError(f"Task requires ArithmeticObject not {type(object)}")
|
|
self._object = object
|
|
|
|
def __repr__(self) -> str:
|
|
buf = StringIO()
|
|
buf.write(f"op_name: {self.op_name} object: {self.object!r} ")
|
|
return buf.getvalue()
|
|
|
|
@property
|
|
def op_name(self) -> str:
|
|
return self._op_name
|
|
|
|
@property
|
|
def object(self) -> ArithmeticObject:
|
|
return self._object
|