diff --git a/eland/arithmetics.py b/eland/arithmetics.py index 6242fe1..5fe0d7d 100644 --- a/eland/arithmetics.py +++ b/eland/arithmetics.py @@ -108,12 +108,6 @@ class ArithmeticSeries(ArithmeticObject): aggregatable_field_name = query_compiler.display_name_to_aggregatable_name( display_name ) - if aggregatable_field_name is None: - raise ValueError( - f"Can not perform arithmetic operations on non aggregatable fields" - f"{display_name} is not aggregatable." - ) - self._value = f"doc['{aggregatable_field_name}'].value" self._tasks = [] self._dtype = dtype diff --git a/eland/field_mappings.py b/eland/field_mappings.py index 5c48d81..cb0936a 100644 --- a/eland/field_mappings.py +++ b/eland/field_mappings.py @@ -584,21 +584,19 @@ class FieldMappings: raise KeyError if the field_name doesn't exist in the mapping, or isn't aggregatable """ - if display_name not in self._mappings_capabilities.index: + mapping: Optional[pd.Series] = None + + try: + mapping = self._mappings_capabilities.loc[display_name] + except KeyError: raise KeyError( f"Can not get aggregatable field name for invalid display name {display_name}" - ) + ) from None - if ( - self._mappings_capabilities.loc[display_name].aggregatable_es_field_name - is None - ): - warnings.warn( - f"Aggregations not supported for '{display_name}' " - f"'{self._mappings_capabilities.loc[display_name].es_field_name}'" - ) + if mapping is not None and mapping.aggregatable_es_field_name is None: + warnings.warn(f"Aggregations not supported for '{display_name}'") - return self._mappings_capabilities.loc[display_name].aggregatable_es_field_name + return mapping.aggregatable_es_field_name def aggregatable_field_names(self) -> Dict[str, str]: """ diff --git a/eland/query_compiler.py b/eland/query_compiler.py index cd5dfbb..cefdef1 100644 --- a/eland/query_compiler.py +++ b/eland/query_compiler.py @@ -797,9 +797,13 @@ class QueryCompiler: def get_arithmetic_op_fields(self) -> Optional["ArithmeticOpFieldsTask"]: return self._operations.get_arithmetic_op_fields() - def display_name_to_aggregatable_name(self, display_name: str) -> Optional[str]: + def display_name_to_aggregatable_name(self, display_name: str) -> str: aggregatable_field_name = self._mappings.aggregatable_field_name(display_name) - + if aggregatable_field_name is None: + raise ValueError( + f"Can not perform arithmetic operations on non aggregatable fields" + f"{display_name} is not aggregatable." + ) return aggregatable_field_name