From 02190e74e78cd973a2eea8207d971f2786588b6b Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Wed, 31 Jan 2024 14:47:19 +0400 Subject: [PATCH] Switch to 2024 black style (#657) --- eland/dataframe.py | 6 ++++-- eland/field_mappings.py | 6 +++--- eland/ml/exporters/_sklearn_deserializers.py | 16 ++++++++++------ eland/ml/pytorch/traceable_model.py | 6 ++---- eland/ml/pytorch/transformers.py | 3 +-- eland/ml/transformers/lightgbm.py | 16 ++++++++++------ eland/operations.py | 8 +++++--- noxfile.py | 2 +- .../dataframe/test_iterrows_itertuples_pytest.py | 10 +++++++--- 9 files changed, 43 insertions(+), 30 deletions(-) diff --git a/eland/dataframe.py b/eland/dataframe.py index f08e21c..045166b 100644 --- a/eland/dataframe.py +++ b/eland/dataframe.py @@ -956,8 +956,10 @@ class DataFrame(NDFrame): elif verbose is False: # specifically set to False, not nesc None _non_verbose_repr() else: - _non_verbose_repr() if exceeds_info_cols else _verbose_repr( - number_of_columns + ( + _non_verbose_repr() + if exceeds_info_cols + else _verbose_repr(number_of_columns) ) # pandas 0.25.1 uses get_dtype_counts() here. This diff --git a/eland/field_mappings.py b/eland/field_mappings.py index 61e45a0..94cd2e6 100644 --- a/eland/field_mappings.py +++ b/eland/field_mappings.py @@ -443,9 +443,9 @@ class FieldMappings: try: series = df.loc[df.es_field_name == es_field_name_keyword] if not series.empty and series.is_aggregatable.squeeze(): - row_as_dict[ - "aggregatable_es_field_name" - ] = es_field_name_keyword + row_as_dict["aggregatable_es_field_name"] = ( + es_field_name_keyword + ) else: row_as_dict["aggregatable_es_field_name"] = None except KeyError: diff --git a/eland/ml/exporters/_sklearn_deserializers.py b/eland/ml/exporters/_sklearn_deserializers.py index 6df5df7..6de021c 100644 --- a/eland/ml/exporters/_sklearn_deserializers.py +++ b/eland/ml/exporters/_sklearn_deserializers.py @@ -169,9 +169,11 @@ class TargetMeanEncoder(FunctionTransformer): def func(column): return np.array( [ - target_map[str(category)] - if category in target_map - else fallback_value + ( + target_map[str(category)] + if category in target_map + else fallback_value + ) for category in column ] ).reshape(-1, 1) @@ -197,9 +199,11 @@ class FrequencyEncoder(FunctionTransformer): def func(column): return np.array( [ - frequency_map[str(category)] - if category in frequency_map - else fallback_value + ( + frequency_map[str(category)] + if category in frequency_map + else fallback_value + ) for category in column ] ).reshape(-1, 1) diff --git a/eland/ml/pytorch/traceable_model.py b/eland/ml/pytorch/traceable_model.py index a8a335c..7b8e13c 100644 --- a/eland/ml/pytorch/traceable_model.py +++ b/eland/ml/pytorch/traceable_model.py @@ -50,12 +50,10 @@ class TraceableModel(ABC): return self._trace() @abstractmethod - def sample_output(self) -> torch.Tensor: - ... + def sample_output(self) -> torch.Tensor: ... @abstractmethod - def _trace(self) -> TracedModelTypes: - ... + def _trace(self) -> TracedModelTypes: ... def classification_labels(self) -> Optional[List[str]]: return None diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index ff41870..002002b 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -496,8 +496,7 @@ class _TransformerTraceableModel(TraceableModel): ) @abstractmethod - def _prepare_inputs(self) -> transformers.BatchEncoding: - ... + def _prepare_inputs(self) -> transformers.BatchEncoding: ... class _TraceableClassificationModel(_TransformerTraceableModel, ABC): diff --git a/eland/ml/transformers/lightgbm.py b/eland/ml/transformers/lightgbm.py index f6293d6..55d55aa 100644 --- a/eland/ml/transformers/lightgbm.py +++ b/eland/ml/transformers/lightgbm.py @@ -97,9 +97,11 @@ class LGBMForestTransformer(ModelTransformer): return TreeNode( node_idx=node_id, leaf_value=[float(tree_node_json_obj["leaf_value"])], - number_samples=int(tree_node_json_obj["leaf_count"]) - if "leaf_count" in tree_node_json_obj - else None, + number_samples=( + int(tree_node_json_obj["leaf_count"]) + if "leaf_count" in tree_node_json_obj + else None + ), ) def build_tree(self, tree_id: int, tree_json_obj: Dict[str, Any]) -> Tree: @@ -235,9 +237,11 @@ class LGBMClassifierTransformer(LGBMForestTransformer): return TreeNode( node_idx=node_id, leaf_value=leaf_val, - number_samples=int(tree_node_json_obj["leaf_count"]) - if "leaf_count" in tree_node_json_obj - else None, + number_samples=( + int(tree_node_json_obj["leaf_count"]) + if "leaf_count" in tree_node_json_obj + else None + ), ) def check_model_booster(self) -> None: diff --git a/eland/operations.py b/eland/operations.py index cf25411..f9b58f0 100644 --- a/eland/operations.py +++ b/eland/operations.py @@ -1156,9 +1156,11 @@ class Operations: # piggy-back on that single aggregation. if extended_stats_calls >= 2: es_aggs = [ - ("extended_stats", es_agg) - if es_agg in extended_stats_es_aggs - else es_agg + ( + ("extended_stats", es_agg) + if es_agg in extended_stats_es_aggs + else es_agg + ) for es_agg in es_aggs ] diff --git a/noxfile.py b/noxfile.py index 0eeae2a..19698d3 100644 --- a/noxfile.py +++ b/noxfile.py @@ -75,7 +75,7 @@ def lint(session): session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) session.run("black", "--check", "--target-version=py38", *SOURCE_FILES) session.run("isort", "--check", "--profile=black", *SOURCE_FILES) - session.run("flake8", "--ignore=E501,W503,E402,E712,E203", *SOURCE_FILES) + session.run("flake8", "--extend-ignore=E203,E402,E501,E704,E712", *SOURCE_FILES) # TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/") session.log("mypy --show-error-codes --strict eland/") diff --git a/tests/dataframe/test_iterrows_itertuples_pytest.py b/tests/dataframe/test_iterrows_itertuples_pytest.py index 9dc495e..65a0e09 100644 --- a/tests/dataframe/test_iterrows_itertuples_pytest.py +++ b/tests/dataframe/test_iterrows_itertuples_pytest.py @@ -54,9 +54,13 @@ class TestDataFrameIterrowsItertuples(TestData): # Shim which uses pytest.approx() for floating point values inside tuples. assert len(left) == len(right) assert all( - (lt == rt) # Not floats? Use == - if not isinstance(lt, float) and not isinstance(rt, float) - else (lt == pytest.approx(rt)) # If both are floats use pytest.approx() + ( + # Not floats? Use == + (lt == rt) + if not isinstance(lt, float) and not isinstance(rt, float) + # If both are floats use pytest.approx() + else (lt == pytest.approx(rt)) + ) for lt, rt in zip(left, right) )