mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Switch to 2024 black style (#657)
This commit is contained in:
parent
2a6a4b1f06
commit
02190e74e7
@ -956,8 +956,10 @@ class DataFrame(NDFrame):
|
|||||||
elif verbose is False: # specifically set to False, not nesc None
|
elif verbose is False: # specifically set to False, not nesc None
|
||||||
_non_verbose_repr()
|
_non_verbose_repr()
|
||||||
else:
|
else:
|
||||||
_non_verbose_repr() if exceeds_info_cols else _verbose_repr(
|
(
|
||||||
number_of_columns
|
_non_verbose_repr()
|
||||||
|
if exceeds_info_cols
|
||||||
|
else _verbose_repr(number_of_columns)
|
||||||
)
|
)
|
||||||
|
|
||||||
# pandas 0.25.1 uses get_dtype_counts() here. This
|
# pandas 0.25.1 uses get_dtype_counts() here. This
|
||||||
|
@ -443,9 +443,9 @@ class FieldMappings:
|
|||||||
try:
|
try:
|
||||||
series = df.loc[df.es_field_name == es_field_name_keyword]
|
series = df.loc[df.es_field_name == es_field_name_keyword]
|
||||||
if not series.empty and series.is_aggregatable.squeeze():
|
if not series.empty and series.is_aggregatable.squeeze():
|
||||||
row_as_dict[
|
row_as_dict["aggregatable_es_field_name"] = (
|
||||||
"aggregatable_es_field_name"
|
es_field_name_keyword
|
||||||
] = es_field_name_keyword
|
)
|
||||||
else:
|
else:
|
||||||
row_as_dict["aggregatable_es_field_name"] = None
|
row_as_dict["aggregatable_es_field_name"] = None
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -169,9 +169,11 @@ class TargetMeanEncoder(FunctionTransformer):
|
|||||||
def func(column):
|
def func(column):
|
||||||
return np.array(
|
return np.array(
|
||||||
[
|
[
|
||||||
target_map[str(category)]
|
(
|
||||||
if category in target_map
|
target_map[str(category)]
|
||||||
else fallback_value
|
if category in target_map
|
||||||
|
else fallback_value
|
||||||
|
)
|
||||||
for category in column
|
for category in column
|
||||||
]
|
]
|
||||||
).reshape(-1, 1)
|
).reshape(-1, 1)
|
||||||
@ -197,9 +199,11 @@ class FrequencyEncoder(FunctionTransformer):
|
|||||||
def func(column):
|
def func(column):
|
||||||
return np.array(
|
return np.array(
|
||||||
[
|
[
|
||||||
frequency_map[str(category)]
|
(
|
||||||
if category in frequency_map
|
frequency_map[str(category)]
|
||||||
else fallback_value
|
if category in frequency_map
|
||||||
|
else fallback_value
|
||||||
|
)
|
||||||
for category in column
|
for category in column
|
||||||
]
|
]
|
||||||
).reshape(-1, 1)
|
).reshape(-1, 1)
|
||||||
|
@ -50,12 +50,10 @@ class TraceableModel(ABC):
|
|||||||
return self._trace()
|
return self._trace()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def sample_output(self) -> torch.Tensor:
|
def sample_output(self) -> torch.Tensor: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _trace(self) -> TracedModelTypes:
|
def _trace(self) -> TracedModelTypes: ...
|
||||||
...
|
|
||||||
|
|
||||||
def classification_labels(self) -> Optional[List[str]]:
|
def classification_labels(self) -> Optional[List[str]]:
|
||||||
return None
|
return None
|
||||||
|
@ -496,8 +496,7 @@ class _TransformerTraceableModel(TraceableModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _prepare_inputs(self) -> transformers.BatchEncoding:
|
def _prepare_inputs(self) -> transformers.BatchEncoding: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class _TraceableClassificationModel(_TransformerTraceableModel, ABC):
|
class _TraceableClassificationModel(_TransformerTraceableModel, ABC):
|
||||||
|
@ -97,9 +97,11 @@ class LGBMForestTransformer(ModelTransformer):
|
|||||||
return TreeNode(
|
return TreeNode(
|
||||||
node_idx=node_id,
|
node_idx=node_id,
|
||||||
leaf_value=[float(tree_node_json_obj["leaf_value"])],
|
leaf_value=[float(tree_node_json_obj["leaf_value"])],
|
||||||
number_samples=int(tree_node_json_obj["leaf_count"])
|
number_samples=(
|
||||||
if "leaf_count" in tree_node_json_obj
|
int(tree_node_json_obj["leaf_count"])
|
||||||
else None,
|
if "leaf_count" in tree_node_json_obj
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def build_tree(self, tree_id: int, tree_json_obj: Dict[str, Any]) -> Tree:
|
def build_tree(self, tree_id: int, tree_json_obj: Dict[str, Any]) -> Tree:
|
||||||
@ -235,9 +237,11 @@ class LGBMClassifierTransformer(LGBMForestTransformer):
|
|||||||
return TreeNode(
|
return TreeNode(
|
||||||
node_idx=node_id,
|
node_idx=node_id,
|
||||||
leaf_value=leaf_val,
|
leaf_value=leaf_val,
|
||||||
number_samples=int(tree_node_json_obj["leaf_count"])
|
number_samples=(
|
||||||
if "leaf_count" in tree_node_json_obj
|
int(tree_node_json_obj["leaf_count"])
|
||||||
else None,
|
if "leaf_count" in tree_node_json_obj
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_model_booster(self) -> None:
|
def check_model_booster(self) -> None:
|
||||||
|
@ -1156,9 +1156,11 @@ class Operations:
|
|||||||
# piggy-back on that single aggregation.
|
# piggy-back on that single aggregation.
|
||||||
if extended_stats_calls >= 2:
|
if extended_stats_calls >= 2:
|
||||||
es_aggs = [
|
es_aggs = [
|
||||||
("extended_stats", es_agg)
|
(
|
||||||
if es_agg in extended_stats_es_aggs
|
("extended_stats", es_agg)
|
||||||
else es_agg
|
if es_agg in extended_stats_es_aggs
|
||||||
|
else es_agg
|
||||||
|
)
|
||||||
for es_agg in es_aggs
|
for es_agg in es_aggs
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ def lint(session):
|
|||||||
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
|
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
|
||||||
session.run("black", "--check", "--target-version=py38", *SOURCE_FILES)
|
session.run("black", "--check", "--target-version=py38", *SOURCE_FILES)
|
||||||
session.run("isort", "--check", "--profile=black", *SOURCE_FILES)
|
session.run("isort", "--check", "--profile=black", *SOURCE_FILES)
|
||||||
session.run("flake8", "--ignore=E501,W503,E402,E712,E203", *SOURCE_FILES)
|
session.run("flake8", "--extend-ignore=E203,E402,E501,E704,E712", *SOURCE_FILES)
|
||||||
|
|
||||||
# TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/")
|
# TODO: When all files are typed we can change this to .run("mypy", "--strict", "eland/")
|
||||||
session.log("mypy --show-error-codes --strict eland/")
|
session.log("mypy --show-error-codes --strict eland/")
|
||||||
|
@ -54,9 +54,13 @@ class TestDataFrameIterrowsItertuples(TestData):
|
|||||||
# Shim which uses pytest.approx() for floating point values inside tuples.
|
# Shim which uses pytest.approx() for floating point values inside tuples.
|
||||||
assert len(left) == len(right)
|
assert len(left) == len(right)
|
||||||
assert all(
|
assert all(
|
||||||
(lt == rt) # Not floats? Use ==
|
(
|
||||||
if not isinstance(lt, float) and not isinstance(rt, float)
|
# Not floats? Use ==
|
||||||
else (lt == pytest.approx(rt)) # If both are floats use pytest.approx()
|
(lt == rt)
|
||||||
|
if not isinstance(lt, float) and not isinstance(rt, float)
|
||||||
|
# If both are floats use pytest.approx()
|
||||||
|
else (lt == pytest.approx(rt))
|
||||||
|
)
|
||||||
for lt, rt in zip(left, right)
|
for lt, rt in zip(left, right)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user