Switch to 2024 black style (#657)

This commit is contained in:
Quentin Pradet 2024-01-31 14:47:19 +04:00 committed by GitHub
parent 2a6a4b1f06
commit 02190e74e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 43 additions and 30 deletions

View File

@ -956,8 +956,10 @@ class DataFrame(NDFrame):
elif verbose is False: # specifically set to False, not nesc None
_non_verbose_repr()
else:
_non_verbose_repr() if exceeds_info_cols else _verbose_repr(
number_of_columns
(
_non_verbose_repr()
if exceeds_info_cols
else _verbose_repr(number_of_columns)
)
# pandas 0.25.1 uses get_dtype_counts() here. This

View File

@ -443,9 +443,9 @@ class FieldMappings:
try:
series = df.loc[df.es_field_name == es_field_name_keyword]
if not series.empty and series.is_aggregatable.squeeze():
row_as_dict[
"aggregatable_es_field_name"
] = es_field_name_keyword
row_as_dict["aggregatable_es_field_name"] = (
es_field_name_keyword
)
else:
row_as_dict["aggregatable_es_field_name"] = None
except KeyError:

View File

@ -169,9 +169,11 @@ class TargetMeanEncoder(FunctionTransformer):
def func(column):
return np.array(
[
target_map[str(category)]
if category in target_map
else fallback_value
(
target_map[str(category)]
if category in target_map
else fallback_value
)
for category in column
]
).reshape(-1, 1)
@ -197,9 +199,11 @@ class FrequencyEncoder(FunctionTransformer):
def func(column):
return np.array(
[
frequency_map[str(category)]
if category in frequency_map
else fallback_value
(
frequency_map[str(category)]
if category in frequency_map
else fallback_value
)
for category in column
]
).reshape(-1, 1)

View File

@ -50,12 +50,10 @@ class TraceableModel(ABC):
return self._trace()
@abstractmethod
def sample_output(self) -> torch.Tensor:
...
def sample_output(self) -> torch.Tensor: ...
@abstractmethod
def _trace(self) -> TracedModelTypes:
...
def _trace(self) -> TracedModelTypes: ...
def classification_labels(self) -> Optional[List[str]]:
return None

View File

@ -496,8 +496,7 @@ class _TransformerTraceableModel(TraceableModel):
)
@abstractmethod
def _prepare_inputs(self) -> transformers.BatchEncoding:
...
def _prepare_inputs(self) -> transformers.BatchEncoding: ...
class _TraceableClassificationModel(_TransformerTraceableModel, ABC):

View File

@ -97,9 +97,11 @@ class LGBMForestTransformer(ModelTransformer):
return TreeNode(
node_idx=node_id,
leaf_value=[float(tree_node_json_obj["leaf_value"])],
number_samples=int(tree_node_json_obj["leaf_count"])
if "leaf_count" in tree_node_json_obj
else None,
number_samples=(
int(tree_node_json_obj["leaf_count"])
if "leaf_count" in tree_node_json_obj
else None
),
)
def build_tree(self, tree_id: int, tree_json_obj: Dict[str, Any]) -> Tree:
@ -235,9 +237,11 @@ class LGBMClassifierTransformer(LGBMForestTransformer):
return TreeNode(
node_idx=node_id,
leaf_value=leaf_val,
number_samples=int(tree_node_json_obj["leaf_count"])
if "leaf_count" in tree_node_json_obj
else None,
number_samples=(
int(tree_node_json_obj["leaf_count"])
if "leaf_count" in tree_node_json_obj
else None
),
)
def check_model_booster(self) -> None:

View File

@ -1156,9 +1156,11 @@ class Operations:
# piggy-back on that single aggregation.
if extended_stats_calls >= 2:
es_aggs = [
("extended_stats", es_agg)
if es_agg in extended_stats_es_aggs
else es_agg
(
("extended_stats", es_agg)
if es_agg in extended_stats_es_aggs
else es_agg
)
for es_agg in es_aggs
]

View File

@ -75,7 +75,7 @@ def lint(session):
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
session.run("black", "--check", "--target-version=py38", *SOURCE_FILES)
session.run("isort", "--check", "--profile=black", *SOURCE_FILES)
session.run("flake8", "--ignore=E501,W503,E402,E712,E203", *SOURCE_FILES)
session.run("flake8", "--extend-ignore=E203,E402,E501,E704,E712", *SOURCE_FILES)
# TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/")
session.log("mypy --show-error-codes --strict eland/")

View File

@ -54,9 +54,13 @@ class TestDataFrameIterrowsItertuples(TestData):
# Shim which uses pytest.approx() for floating point values inside tuples.
assert len(left) == len(right)
assert all(
(lt == rt) # Not floats? Use ==
if not isinstance(lt, float) and not isinstance(rt, float)
else (lt == pytest.approx(rt)) # If both are floats use pytest.approx()
(
# Not floats? Use ==
(lt == rt)
if not isinstance(lt, float) and not isinstance(rt, float)
# If both are floats use pytest.approx()
else (lt == pytest.approx(rt))
)
for lt, rt in zip(left, right)
)