Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)
Switch to 2024 black style (#657)
parent 2a6a4b1f06
commit 02190e74e7
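
The hunks below are mechanical reformatting from Black's 2024 stable style. The recurring patterns: a conditional (ternary) expression that has to span several lines is wrapped in its own parentheses, a "..."-only stub body moves onto the def line, and a long right-hand side of an assignment is parenthesized instead of splitting the subscript. A minimal runnable sketch of the parenthesized-ternary shape (the names lookup, default, and scores are illustrative, not taken from eland):

    lookup = {"a": 1.0, "b": 2.0}
    default = 0.0

    scores = [
        # 2024 Black wraps a conditional expression that spans lines in its own
        # parentheses (expanded here for clarity; this short one would normally
        # collapse onto a single line).
        (
            lookup[key]
            if key in lookup
            else default
        )
        for key in ["a", "b", "c"]
    ]
    print(scores)  # [1.0, 2.0, 0.0]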
@@ -956,8 +956,10 @@ class DataFrame(NDFrame):
         elif verbose is False:  # specifically set to False, not nesc None
             _non_verbose_repr()
         else:
-            _non_verbose_repr() if exceeds_info_cols else _verbose_repr(
-                number_of_columns
+            (
+                _non_verbose_repr()
+                if exceeds_info_cols
+                else _verbose_repr(number_of_columns)
             )
 
         # pandas 0.25.1 uses get_dtype_counts() here. This
@@ -443,9 +443,9 @@ class FieldMappings:
         try:
             series = df.loc[df.es_field_name == es_field_name_keyword]
             if not series.empty and series.is_aggregatable.squeeze():
-                row_as_dict[
-                    "aggregatable_es_field_name"
-                ] = es_field_name_keyword
+                row_as_dict["aggregatable_es_field_name"] = (
+                    es_field_name_keyword
+                )
             else:
                 row_as_dict["aggregatable_es_field_name"] = None
         except KeyError:
@@ -169,9 +169,11 @@ class TargetMeanEncoder(FunctionTransformer):
         def func(column):
             return np.array(
                 [
-                    target_map[str(category)]
-                    if category in target_map
-                    else fallback_value
+                    (
+                        target_map[str(category)]
+                        if category in target_map
+                        else fallback_value
+                    )
                     for category in column
                 ]
             ).reshape(-1, 1)
@@ -197,9 +199,11 @@ class FrequencyEncoder(FunctionTransformer):
         def func(column):
             return np.array(
                 [
-                    frequency_map[str(category)]
-                    if category in frequency_map
-                    else fallback_value
+                    (
+                        frequency_map[str(category)]
+                        if category in frequency_map
+                        else fallback_value
+                    )
                     for category in column
                 ]
             ).reshape(-1, 1)
@@ -50,12 +50,10 @@ class TraceableModel(ABC):
         return self._trace()
 
     @abstractmethod
-    def sample_output(self) -> torch.Tensor:
-        ...
+    def sample_output(self) -> torch.Tensor: ...
 
     @abstractmethod
-    def _trace(self) -> TracedModelTypes:
-        ...
+    def _trace(self) -> TracedModelTypes: ...
 
     def classification_labels(self) -> Optional[List[str]]:
         return None
@@ -496,8 +496,7 @@ class _TransformerTraceableModel(TraceableModel):
         )
 
     @abstractmethod
-    def _prepare_inputs(self) -> transformers.BatchEncoding:
-        ...
+    def _prepare_inputs(self) -> transformers.BatchEncoding: ...
 
 
 class _TraceableClassificationModel(_TransformerTraceableModel, ABC):
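
The TraceableModel and _TransformerTraceableModel hunks above show the stub-body change: a method whose body is only "..." now keeps the "..." on the same line as the def. A self-contained sketch of that pattern (the Model and Doubler names are hypothetical, for illustration only):

    from abc import ABC, abstractmethod


    class Model(ABC):
        # 2024 Black formats a bare "..." stub body on the def line itself.
        @abstractmethod
        def predict(self, x: float) -> float: ...


    class Doubler(Model):
        def predict(self, x: float) -> float:
            return 2 * x


    print(Doubler().predict(3.0))  # prints 6.0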
@@ -97,9 +97,11 @@ class LGBMForestTransformer(ModelTransformer):
         return TreeNode(
             node_idx=node_id,
             leaf_value=[float(tree_node_json_obj["leaf_value"])],
-            number_samples=int(tree_node_json_obj["leaf_count"])
-            if "leaf_count" in tree_node_json_obj
-            else None,
+            number_samples=(
+                int(tree_node_json_obj["leaf_count"])
+                if "leaf_count" in tree_node_json_obj
+                else None
+            ),
         )
 
     def build_tree(self, tree_id: int, tree_json_obj: Dict[str, Any]) -> Tree:
@@ -235,9 +237,11 @@ class LGBMClassifierTransformer(LGBMForestTransformer):
         return TreeNode(
             node_idx=node_id,
             leaf_value=leaf_val,
-            number_samples=int(tree_node_json_obj["leaf_count"])
-            if "leaf_count" in tree_node_json_obj
-            else None,
+            number_samples=(
+                int(tree_node_json_obj["leaf_count"])
+                if "leaf_count" in tree_node_json_obj
+                else None
+            ),
         )
 
     def check_model_booster(self) -> None:
@@ -1156,9 +1156,11 @@ class Operations:
             # piggy-back on that single aggregation.
             if extended_stats_calls >= 2:
                 es_aggs = [
-                    ("extended_stats", es_agg)
-                    if es_agg in extended_stats_es_aggs
-                    else es_agg
+                    (
+                        ("extended_stats", es_agg)
+                        if es_agg in extended_stats_es_aggs
+                        else es_agg
+                    )
                     for es_agg in es_aggs
                 ]
 
@@ -75,7 +75,7 @@ def lint(session):
     session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
     session.run("black", "--check", "--target-version=py38", *SOURCE_FILES)
     session.run("isort", "--check", "--profile=black", *SOURCE_FILES)
-    session.run("flake8", "--ignore=E501,W503,E402,E712,E203", *SOURCE_FILES)
+    session.run("flake8", "--extend-ignore=E203,E402,E501,E704,E712", *SOURCE_FILES)
 
     # TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/")
     session.log("mypy --show-error-codes --strict eland/")
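
The noxfile change keeps flake8 in step with the reformatting: --extend-ignore adds to flake8's default ignore list (which already includes W503) instead of replacing it as --ignore does, and E704 ("statement on same line as def") is listed so the one-line "def ...: ..." stubs produced by the 2024 style pass linting. A minimal example of code that would trip E704 under a configuration that overrides the defaults:

    # pycodestyle E704: statement on same line as def
    def sample_output() -> int: ...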
@@ -54,9 +54,13 @@ class TestDataFrameIterrowsItertuples(TestData):
         # Shim which uses pytest.approx() for floating point values inside tuples.
         assert len(left) == len(right)
         assert all(
-            (lt == rt)  # Not floats? Use ==
-            if not isinstance(lt, float) and not isinstance(rt, float)
-            else (lt == pytest.approx(rt))  # If both are floats use pytest.approx()
+            (
+                # Not floats? Use ==
+                (lt == rt)
+                if not isinstance(lt, float) and not isinstance(rt, float)
+                # If both are floats use pytest.approx()
+                else (lt == pytest.approx(rt))
+            )
             for lt, rt in zip(left, right)
         )