Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/power_grid_model_ds/_core/model/graphs/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,28 @@ def tmp_remove_nodes(self, nodes: list[int]) -> Generator:
for source, target in edge_list:
self.add_branch(source, target)

@contextmanager
def tmp_remove_branches(self, branches: list[tuple[int, int]]) -> Generator:
"""Context manager that temporarily removes branches from the graph.

Example:
>>> with graph.tmp_remove_branches([(1, 2), (2, 3)]):
>>> assert not graph.has_branch(1, 2)
>>> assert not graph.has_branch(2, 3)
>>> assert graph.has_branch(1, 2)
>>> assert graph.has_branch(2, 3)
"""
removed_branches = []
try:
for from_node, to_node in branches:
self.delete_branch(from_node, to_node)
removed_branches.append((from_node, to_node))

yield
finally:
for from_node, to_node in removed_branches:
self.add_branch(from_node, to_node)

def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]:
"""Calculate the shortest path between two nodes

Expand Down
29 changes: 29 additions & 0 deletions tests/unit/model/graphs/test_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,35 @@ def test_tmp_remove_nodes_array_input(graph_with_2_routes: BaseGraphModel) -> No
assert all([isinstance(e_id, int) for e_id in graph_with_2_routes.external_ids])


class TestTmpRemoveBranches:
def test_tmp_remove_branches(self, graph_with_2_routes: BaseGraphModel):
graph = deepcopy(graph_with_2_routes)

assert graph.has_branch(1, 2)
assert graph.has_branch(2, 3)

with graph.tmp_remove_branches([(1, 2), (2, 3)]):
assert not graph.has_branch(1, 2)
assert not graph.has_branch(2, 3)

assert graph == graph_with_2_routes
assert graph.has_branch(1, 2)
assert graph.has_branch(2, 3)

def test_tmp_remove_branches_non_existent_branch_keeps_graph_as_is(self, graph_with_2_routes: BaseGraphModel):
graph = deepcopy(graph_with_2_routes)

# If we remove a branch and then a non-existing branch, we should raise an error.
with (
pytest.raises(MissingBranchError, match="Branch between nodes 1 and 4 does NOT exist"),
graph.tmp_remove_branches([(1, 2), (1, 4)]),
):
pass

# And the graph should still be the same as the original afterwards.
assert graph == graph_with_2_routes


def test_get_components(graph_with_2_routes: BaseGraphModel):
graph = graph_with_2_routes
graph.add_node(99)
Expand Down