diff --git a/src/power_grid_model_ds/_core/model/graphs/models/base.py b/src/power_grid_model_ds/_core/model/graphs/models/base.py index 77da3c9a..0dd1503c 100644 --- a/src/power_grid_model_ds/_core/model/graphs/models/base.py +++ b/src/power_grid_model_ds/_core/model/graphs/models/base.py @@ -223,6 +223,28 @@ def tmp_remove_nodes(self, nodes: list[int]) -> Generator: for source, target in edge_list: self.add_branch(source, target) + @contextmanager + def tmp_remove_branches(self, branches: list[tuple[int, int]]) -> Generator: + """Context manager that temporarily removes branches from the graph. + + Example: + >>> with graph.tmp_remove_branches([(1, 2), (2, 3)]): + >>> assert not graph.has_branch(1, 2) + >>> assert not graph.has_branch(2, 3) + >>> assert graph.has_branch(1, 2) + >>> assert graph.has_branch(2, 3) + """ + removed_branches = [] + try: + for from_node, to_node in branches: + self.delete_branch(from_node, to_node) + removed_branches.append((from_node, to_node)) + + yield + finally: + for from_node, to_node in removed_branches: + self.add_branch(from_node, to_node) + def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]: """Calculate the shortest path between two nodes diff --git a/tests/unit/model/graphs/test_graph_model.py b/tests/unit/model/graphs/test_graph_model.py index e5a4fd4b..5ce35314 100644 --- a/tests/unit/model/graphs/test_graph_model.py +++ b/tests/unit/model/graphs/test_graph_model.py @@ -178,6 +178,35 @@ def test_tmp_remove_nodes_array_input(graph_with_2_routes: BaseGraphModel) -> No assert all([isinstance(e_id, int) for e_id in graph_with_2_routes.external_ids]) +class TestTmpRemoveBranches: + def test_tmp_remove_branches(self, graph_with_2_routes: BaseGraphModel): + graph = deepcopy(graph_with_2_routes) + + assert graph.has_branch(1, 2) + assert graph.has_branch(2, 3) + + with graph.tmp_remove_branches([(1, 2), (2, 3)]): + assert not graph.has_branch(1, 2) + assert not graph.has_branch(2, 3) + + assert graph == graph_with_2_routes + assert graph.has_branch(1, 2) + assert graph.has_branch(2, 3) + + def test_tmp_remove_branches_non_existent_branch_keeps_graph_as_is(self, graph_with_2_routes: BaseGraphModel): + graph = deepcopy(graph_with_2_routes) + + # If we remove a branch and then a non-existing branch, we should raise an error. + with ( + pytest.raises(MissingBranchError, match="Branch between nodes 1 and 4 does NOT exist"), + graph.tmp_remove_branches([(1, 2), (1, 4)]), + ): + pass + + # And the graph should still be the same as the original afterwards. + assert graph == graph_with_2_routes + + def test_get_components(graph_with_2_routes: BaseGraphModel): graph = graph_with_2_routes graph.add_node(99)