Switch to 2024 black style (#657)

Quentin Pradet 2024-01-31 14:47:19 +04:00 committed by GitHub
parent 2a6a4b1f06
commit 02190e74e7
9 changed files with 43 additions and 30 deletions
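
Background note (not part of the diff): the most common change below comes from Black's 2024 stable style (first shipped in Black 24.1.0), which encloses a conditional expression that has to wrap in its own pair of parentheses, one line each for the true branch, the if condition, and the else branch; a related rule parenthesizes the right-hand side of an over-long assignment instead of splitting the subscript (second hunk). A small self-contained sketch with made-up identifiers, not eland code:

# Hypothetical example, not taken from eland: the return value is a
# conditional expression too long for one line, so 2024-style Black wraps it
# in its own parentheses instead of leaving bare continuation lines.
def render(values: list, verbose: bool) -> str:
    return (
        ", ".join(str(v) for v in values)
        if verbose
        else f"{len(values)} values (pass verbose=True for the full listing)"
    )


print(render([1, 2, 3], verbose=True))   # -> 1, 2, 3
print(render([1, 2, 3], verbose=False))  # -> 3 values (pass verbose=True ...)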

View File

@@ -956,8 +956,10 @@ class DataFrame(NDFrame):
         elif verbose is False:  # specifically set to False, not nesc None
             _non_verbose_repr()
         else:
-            _non_verbose_repr() if exceeds_info_cols else _verbose_repr(
-                number_of_columns
-            )
+            (
+                _non_verbose_repr()
+                if exceeds_info_cols
+                else _verbose_repr(number_of_columns)
+            )
 
         # pandas 0.25.1 uses get_dtype_counts() here. This

View File

@@ -443,9 +443,9 @@ class FieldMappings:
                 try:
                     series = df.loc[df.es_field_name == es_field_name_keyword]
                     if not series.empty and series.is_aggregatable.squeeze():
-                        row_as_dict[
-                            "aggregatable_es_field_name"
-                        ] = es_field_name_keyword
+                        row_as_dict["aggregatable_es_field_name"] = (
+                            es_field_name_keyword
+                        )
                     else:
                         row_as_dict["aggregatable_es_field_name"] = None
                 except KeyError:

View File

@@ -169,9 +169,11 @@ class TargetMeanEncoder(FunctionTransformer):
         def func(column):
             return np.array(
                 [
-                    target_map[str(category)]
-                    if category in target_map
-                    else fallback_value
+                    (
+                        target_map[str(category)]
+                        if category in target_map
+                        else fallback_value
+                    )
                     for category in column
                 ]
             ).reshape(-1, 1)
@@ -197,9 +199,11 @@ class FrequencyEncoder(FunctionTransformer):
         def func(column):
             return np.array(
                 [
-                    frequency_map[str(category)]
-                    if category in frequency_map
-                    else fallback_value
+                    (
+                        frequency_map[str(category)]
+                        if category in frequency_map
+                        else fallback_value
+                    )
                     for category in column
                 ]
             ).reshape(-1, 1)

View File

@@ -50,12 +50,10 @@ class TraceableModel(ABC):
         return self._trace()
 
     @abstractmethod
-    def sample_output(self) -> torch.Tensor:
-        ...
+    def sample_output(self) -> torch.Tensor: ...
 
     @abstractmethod
-    def _trace(self) -> TracedModelTypes:
-        ...
+    def _trace(self) -> TracedModelTypes: ...
 
     def classification_labels(self) -> Optional[List[str]]:
         return None
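
This hunk and the next one show the other 2024-style change in the commit: a function body consisting only of ... is kept on the same line as the def. A minimal illustrative sketch (identifiers invented, not from eland):

# Hypothetical sketch, not eland code: under Black >= 24.1 the "..." stub
# bodies stay on the signature line; older releases put each "..." on its own
# indented line. Runtime behavior is identical either way.
from abc import ABC, abstractmethod


class Exporter(ABC):
    @abstractmethod
    def export(self) -> bytes: ...

    @abstractmethod
    def describe(self) -> str: ...


class JsonExporter(Exporter):
    def export(self) -> bytes:
        return b"{}"

    def describe(self) -> str:
        return "exports an empty JSON document"


print(JsonExporter().describe())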

View File

@@ -496,8 +496,7 @@ class _TransformerTraceableModel(TraceableModel):
         )
 
     @abstractmethod
-    def _prepare_inputs(self) -> transformers.BatchEncoding:
-        ...
+    def _prepare_inputs(self) -> transformers.BatchEncoding: ...
 
 
 class _TraceableClassificationModel(_TransformerTraceableModel, ABC):

View File

@@ -97,9 +97,11 @@ class LGBMForestTransformer(ModelTransformer):
         return TreeNode(
             node_idx=node_id,
             leaf_value=[float(tree_node_json_obj["leaf_value"])],
-            number_samples=int(tree_node_json_obj["leaf_count"])
-            if "leaf_count" in tree_node_json_obj
-            else None,
+            number_samples=(
+                int(tree_node_json_obj["leaf_count"])
+                if "leaf_count" in tree_node_json_obj
+                else None
+            ),
         )
 
     def build_tree(self, tree_id: int, tree_json_obj: Dict[str, Any]) -> Tree:
@@ -235,9 +237,11 @@ class LGBMClassifierTransformer(LGBMForestTransformer):
         return TreeNode(
             node_idx=node_id,
             leaf_value=leaf_val,
-            number_samples=int(tree_node_json_obj["leaf_count"])
-            if "leaf_count" in tree_node_json_obj
-            else None,
+            number_samples=(
+                int(tree_node_json_obj["leaf_count"])
+                if "leaf_count" in tree_node_json_obj
+                else None
+            ),
         )
 
     def check_model_booster(self) -> None:
def check_model_booster(self) -> None: def check_model_booster(self) -> None:

View File

@@ -1156,9 +1156,11 @@ class Operations:
         # piggy-back on that single aggregation.
         if extended_stats_calls >= 2:
             es_aggs = [
-                ("extended_stats", es_agg)
-                if es_agg in extended_stats_es_aggs
-                else es_agg
+                (
+                    ("extended_stats", es_agg)
+                    if es_agg in extended_stats_es_aggs
+                    else es_agg
+                )
                 for es_agg in es_aggs
             ]
 

View File

@@ -75,7 +75,7 @@ def lint(session):
     session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
     session.run("black", "--check", "--target-version=py38", *SOURCE_FILES)
     session.run("isort", "--check", "--profile=black", *SOURCE_FILES)
-    session.run("flake8", "--ignore=E501,W503,E402,E712,E203", *SOURCE_FILES)
+    session.run("flake8", "--extend-ignore=E203,E402,E501,E704,E712", *SOURCE_FILES)
 
     # TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/")
     session.log("mypy --show-error-codes --strict eland/")
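
A note on the lint change above: --ignore replaces flake8's default ignore list, while --extend-ignore appends to it, so W503 (line break before binary operator, which flake8 already skips by default) no longer needs to be listed. E704 is pycodestyle's "statement on same line as def" check, i.e. exactly the one-line ... stubs that the 2024 style produces, and E203 ("whitespace before ':'") stays ignored because Black's slice formatting conflicts with it. A hypothetical module showing the E704 pattern:

# Hypothetical module, not from eland: with the 2024 style, protocol and
# abstract-method stubs become one-liners, which is the construct that
# pycodestyle's E704 check targets.
import io
from typing import Protocol


class Closeable(Protocol):
    def close(self) -> None: ...  # E704 ("statement on same line as def")


def shutdown(resource: Closeable) -> None:
    resource.close()


shutdown(io.BytesIO())  # any object with a close() method satisfies the protocol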

View File

@@ -54,9 +54,13 @@ class TestDataFrameIterrowsItertuples(TestData):
         # Shim which uses pytest.approx() for floating point values inside tuples.
         assert len(left) == len(right)
         assert all(
-            (lt == rt)  # Not floats? Use ==
-            if not isinstance(lt, float) and not isinstance(rt, float)
-            else (lt == pytest.approx(rt))  # If both are floats use pytest.approx()
+            (
+                # Not floats? Use ==
+                (lt == rt)
+                if not isinstance(lt, float) and not isinstance(rt, float)
+                # If both are floats use pytest.approx()
+                else (lt == pytest.approx(rt))
+            )
             for lt, rt in zip(left, right)
         )