Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions src/power_grid_model_ds/_core/model/graphs/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,16 +221,21 @@ def tmp_remove_nodes(self, nodes: list[int]) -> Generator:
considering certain nodes.
"""
edge_list = []
for node in nodes:
edge_list += list(self.in_branches(node))
self.delete_node(node)
node_list = []

yield
try:
for node in nodes:
edge_list += list(self.in_branches(node))

self.delete_node(node)
node_list.append(node)

for node in nodes:
self.add_node(int(node)) # convert to int to avoid type issues when input is e.g. a numpy array
for source, target in edge_list:
self.add_branch(source, target)
yield
finally:
for node in node_list:
self.add_node(int(node)) # convert to int to avoid type issues when input is e.g. a numpy array
for source, target in edge_list:
self.add_branch(source, target)

@contextmanager
def tmp_remove_branches(self, branches: list[tuple[int, int]]) -> Generator:
Expand Down
60 changes: 37 additions & 23 deletions tests/unit/model/graphs/test_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,39 +143,53 @@ def test_graph_in_branches(self, graph: BaseGraphModel):
assert list(graph.in_branches(2)) == [(1, 2), (1, 2), (1, 2)]


def test_tmp_remove_nodes(graph_with_2_routes: BaseGraphModel) -> None:
graph = graph_with_2_routes
class TestTmpRemoveNodes:
def test_tmp_remove_nodes(self, graph_with_2_routes: BaseGraphModel) -> None:
graph = graph_with_2_routes

assert graph.nr_branches == 4

assert graph.nr_branches == 4
# add parallel branches to test whether they are restored correctly
graph.add_branch(1, 5)
graph.add_branch(5, 1)

# add parallel branches to test whether they are restored correctly
graph.add_branch(1, 5)
graph.add_branch(5, 1)
assert graph.nr_nodes == 5
assert graph.nr_branches == 6

assert graph.nr_nodes == 5
assert graph.nr_branches == 6
before_sets = [frozenset(branch) for branch in graph.all_branches]
counter_before = Counter(before_sets)

before_sets = [frozenset(branch) for branch in graph.all_branches]
counter_before = Counter(before_sets)
with graph.tmp_remove_nodes([1, 2]):
assert graph.nr_nodes == 3
assert list(graph.all_branches) == [(5, 4)]

with graph.tmp_remove_nodes([1, 2]):
assert graph.nr_nodes == 3
assert list(graph.all_branches) == [(5, 4)]
assert graph.nr_nodes == 5
assert graph.nr_branches == 6

assert graph.nr_nodes == 5
assert graph.nr_branches == 6
after_sets = [frozenset(branch) for branch in graph.all_branches]
counter_after = Counter(after_sets)
assert counter_before == counter_after

after_sets = [frozenset(branch) for branch in graph.all_branches]
counter_after = Counter(after_sets)
assert counter_before == counter_after
def test_tmp_remove_nodes_array_input(self, graph_with_2_routes: BaseGraphModel) -> None:
with graph_with_2_routes.tmp_remove_nodes(np.array([1, 2])): # type: ignore[arg-type]
pass

# check that the external ids are still all integers instead of e.g. np.int
assert all([isinstance(e_id, int) for e_id in graph_with_2_routes.external_ids])

def test_invalid_tmp_remove_nodes(self, graph_with_2_routes: BaseGraphModel) -> None:
original_graph = deepcopy(graph_with_2_routes)
assert graph_with_2_routes.nr_nodes == 5
assert graph_with_2_routes.nr_branches == 4

def test_tmp_remove_nodes_array_input(graph_with_2_routes: BaseGraphModel) -> None:
with graph_with_2_routes.tmp_remove_nodes(np.array([1, 2])): # type: ignore[arg-type]
pass
# When we remove node 1 and then an non-existing node that crashes the process
with pytest.raises(MissingNodeError), graph_with_2_routes.tmp_remove_nodes([1, 99]):
pass

# check that the external ids are still all integers instead of e.g. np.int
assert all([isinstance(e_id, int) for e_id in graph_with_2_routes.external_ids])
# The remaining graph object should still contain the same nodes and edges.
assert graph_with_2_routes.nr_nodes == 5
assert graph_with_2_routes.nr_branches == 4
assert graph_with_2_routes == original_graph


class TestTmpRemoveBranches:
Expand Down