diff --git a/invokeai/app/services/shared/README.md b/invokeai/app/services/shared/README.md index 113b7a41e54..a11b04661d8 100644 --- a/invokeai/app/services/shared/README.md +++ b/invokeai/app/services/shared/README.md @@ -123,7 +123,8 @@ mutation helpers. Those helpers reject changes once the affected nodes have alre - `_PreparedExecRegistry` Owns the relationship between source graph nodes and prepared execution graph nodes, plus cached metadata such as iteration path and runtime state. - `_ExecutionMaterializer` Expands source graph nodes into concrete execution graph nodes when the scheduler runs out of - ready work. + ready work. When matching prepared parents for a downstream exec node, skipped prepared exec nodes are ignored and + cannot be selected as live inputs. - `_ExecutionScheduler` Owns indegree transitions, ready queues, class batching, and downstream release on completion. - `_ExecutionRuntime` Owns iteration-path lookup and input hydration for prepared exec nodes. - `_IfBranchScheduler` Applies lazy `If` semantics by deferring branch-local work until the condition is known, then @@ -178,7 +179,9 @@ Run `C` -> `D:0` -> enqueue `D`. Run `D` -> done. - For **CollectInvocation**: gather all incoming `item` values into `collection`, sorting inputs by iteration path so collected results are stable across expanded iterations. Incoming `collection` values are merged first, then incoming `item` values are appended. -- For **IfInvocation**: hydrate only `condition` and the selected branch input. +- For **IfInvocation**: hydrate only `condition` and the selected branch input. If the selected branch's upstream exec + node was skipped and therefore produced no runtime output, the branch input is left at its default value (typically + `None`) instead of raising during hydration. - For all others: deep-copy each incoming edge's value into the destination field. This prevents cross-node mutation through shared references. @@ -191,7 +194,11 @@ Run `C` -> `D:0` -> enqueue `D`. Run `D` -> done. - Once the prepared `If` node resolves its condition: - the selected branch is released - the unselected branch is marked skipped + - unselected input edges on the prepared `If` exec node are pruned from the execution graph so they no longer + participate in downstream indegree accounting - branch-exclusive ancestors of the unselected branch are never executed +- Skipped branch-local exec nodes may still be treated as executed for scheduling purposes, but they do not create + entries in `results`. - Shared ancestors still execute if they are required by the selected branch or by any other live path in the graph. This behavior is implemented in the runtime scheduler, not in the invocation body itself. diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 24c1dd1fe4f..e7f5c4bcd85 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -194,11 +194,11 @@ def _get_selected_branch_fields(self, node: IfInvocation) -> tuple[str, str]: def _prune_unselected_if_inputs(self, exec_node_id: str, unselected_field: str) -> None: for edge in self._state.execution_graph._get_input_edges(exec_node_id, unselected_field): - if edge.source.node_id in self._state.executed: - continue - if self._state.indegree[exec_node_id] == 0: - raise RuntimeError(f"indegree underflow for {exec_node_id} when pruning {unselected_field}") - self._state.indegree[exec_node_id] -= 1 + if edge.source.node_id not in self._state.executed: + if self._state.indegree[exec_node_id] == 0: + raise RuntimeError(f"indegree underflow for {exec_node_id} when pruning {unselected_field}") + self._state.indegree[exec_node_id] -= 1 + self._state.execution_graph.delete_edge(edge) def _apply_branch_resolution( self, @@ -424,7 +424,11 @@ def get_node_iterators(self, node_id: str, it_graph: Optional[nx.DiGraph] = None return [n for n in nx.ancestors(g, node_id) if isinstance(self._state.graph.get_node(n), IterateInvocation)] def _get_prepared_nodes_for_source(self, source_node_id: str) -> set[str]: - return self._state.source_prepared_mapping[source_node_id] + return { + exec_node_id + for exec_node_id in self._state.source_prepared_mapping[source_node_id] + if self._state._get_prepared_exec_metadata(exec_node_id).state != "skipped" + } def _get_parent_iterator_exec_nodes( self, source_node_id: str, graph: nx.DiGraph, prepared_iterator_nodes: list[str] @@ -743,6 +747,12 @@ def _sort_collect_input_edges(self, input_edges: list[Edge], field_name: str) -> def _get_copied_result_value(self, edge: Edge) -> Any: return copydeep(getattr(self._state.results[edge.source.node_id], edge.source.field)) + def _try_get_copied_result_value(self, edge: Edge) -> tuple[bool, Any]: + source_output = self._state.results.get(edge.source.node_id) + if source_output is None: + return False, None + return True, copydeep(getattr(source_output, edge.source.field)) + def _build_collect_collection(self, input_edges: list[Edge]) -> list[Any]: item_edges = self._sort_collect_input_edges(input_edges, ITEM_FIELD) collection_edges = self._sort_collect_input_edges(input_edges, COLLECTION_FIELD) @@ -771,7 +781,18 @@ def _prepare_collect_inputs(self, node: "CollectInvocation", input_edges: list[E def _prepare_if_inputs(self, node: IfInvocation, input_edges: list[Edge]) -> None: selected_field = self._state._resolved_if_exec_branches.get(node.id) allowed_fields = {"condition", selected_field} if selected_field is not None else {"condition"} - self._set_node_inputs(node, input_edges, allowed_fields) + + for edge in input_edges: + if edge.destination.field not in allowed_fields: + continue + + found_value, copied_value = self._try_get_copied_result_value(edge) + if not found_value: + # A skipped branch-local exec node is considered executed for scheduling purposes, but it does not + # produce an output payload. Leave the optional branch input at its default None instead of crashing. + continue + + setattr(node, edge.destination.field, copied_value) def _prepare_default_inputs(self, node: BaseInvocation, input_edges: list[Edge]) -> None: self._set_node_inputs(node, input_edges) diff --git a/tests/test_graph_execution_state.py b/tests/test_graph_execution_state.py index ffd0ca1559d..c2ea198bc29 100644 --- a/tests/test_graph_execution_state.py +++ b/tests/test_graph_execution_state.py @@ -7,7 +7,7 @@ from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.logic import IfInvocation, IfInvocationOutput from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation -from invokeai.app.invocations.primitives import BooleanCollectionInvocation, BooleanInvocation +from invokeai.app.invocations.primitives import BooleanCollectionInvocation, BooleanInvocation, BooleanOutput from invokeai.app.services.shared.graph import ( CollectInvocation, Graph, @@ -750,6 +750,146 @@ def test_if_graph_optimized_behavior_keeps_shared_live_consumers_per_iteration() assert executed_source_ids.count("false_branch") == 2 +def test_if_graph_optimized_behavior_handles_selected_true_branch_with_shared_false_input_ancestor(): + graph = Graph() + graph.add_node(BooleanInvocation(id="condition", value=True)) + graph.add_node(AnyTypeTestInvocation(id="shared_item", value="shared")) + graph.add_node(AnyTypeTestInvocation(id="true_item", value="true")) + graph.add_node(CollectInvocation(id="shared_collect")) + graph.add_node(CollectInvocation(id="true_collect")) + graph.add_node(IfInvocation(id="if")) + graph.add_node(AnyTypeTestInvocation(id="selected_output")) + + graph.add_edge(create_edge("condition", "value", "if", "condition")) + graph.add_edge(create_edge("shared_item", "value", "shared_collect", "item")) + graph.add_edge(create_edge("shared_collect", "collection", "true_collect", "collection")) + graph.add_edge(create_edge("true_item", "value", "true_collect", "item")) + graph.add_edge(create_edge("shared_collect", "collection", "if", "false_input")) + graph.add_edge(create_edge("true_collect", "collection", "if", "true_input")) + graph.add_edge(create_edge("if", "value", "selected_output", "value")) + + g = GraphExecutionState(graph=graph) + executed_source_ids = execute_all_nodes(g) + + prepared_selected_output_id = next(iter(g.source_prepared_mapping["selected_output"])) + assert g.results[prepared_selected_output_id].value == ["shared", "true"] + assert set(executed_source_ids) == { + "condition", + "shared_item", + "true_item", + "shared_collect", + "true_collect", + "if", + "selected_output", + } + + +def test_if_graph_optimized_behavior_handles_selected_false_branch_with_shared_true_input_ancestor(): + graph = Graph() + graph.add_node(BooleanInvocation(id="condition", value=False)) + graph.add_node(AnyTypeTestInvocation(id="shared_item", value="shared")) + graph.add_node(AnyTypeTestInvocation(id="true_item", value="true")) + graph.add_node(CollectInvocation(id="shared_collect")) + graph.add_node(CollectInvocation(id="true_collect")) + graph.add_node(IfInvocation(id="if")) + graph.add_node(AnyTypeTestInvocation(id="selected_output")) + + graph.add_edge(create_edge("condition", "value", "if", "condition")) + graph.add_edge(create_edge("shared_item", "value", "shared_collect", "item")) + graph.add_edge(create_edge("shared_collect", "collection", "true_collect", "collection")) + graph.add_edge(create_edge("true_item", "value", "true_collect", "item")) + graph.add_edge(create_edge("shared_collect", "collection", "if", "false_input")) + graph.add_edge(create_edge("true_collect", "collection", "if", "true_input")) + graph.add_edge(create_edge("if", "value", "selected_output", "value")) + + g = GraphExecutionState(graph=graph) + executed_source_ids = execute_all_nodes(g) + + prepared_selected_output_id = next(iter(g.source_prepared_mapping["selected_output"])) + assert g.results[prepared_selected_output_id].value == ["shared"] + assert set(executed_source_ids) == { + "condition", + "shared_item", + "shared_collect", + "if", + "selected_output", + } + assert "true_item" not in executed_source_ids + assert "true_collect" not in executed_source_ids + + +def test_prepare_if_inputs_ignores_selected_branch_sources_without_results(): + graph = Graph() + graph.add_node(BooleanInvocation(id="condition", value=True)) + graph.add_node(PromptTestInvocation(id="true_value", prompt="true branch")) + graph.add_node(IfInvocation(id="if")) + + graph.add_edge(create_edge("condition", "value", "if", "condition")) + graph.add_edge(create_edge("true_value", "prompt", "if", "true_input")) + + g = GraphExecutionState(graph=graph) + + condition_exec_id = g._create_execution_node("condition", [])[0] + true_value_exec_id = g._create_execution_node("true_value", [])[0] + if_exec_id = g._create_execution_node( + "if", + [("condition", condition_exec_id), ("true_value", true_value_exec_id)], + )[0] + + g.executed.add(condition_exec_id) + g.results[condition_exec_id] = BooleanOutput(value=True) + g.executed.add(true_value_exec_id) + g._resolved_if_exec_branches[if_exec_id] = "true_input" + + if_node = g.execution_graph.get_node(if_exec_id) + g._prepare_inputs(if_node) + + assert if_node.condition is True + assert if_node.true_input is None + + +def test_get_iteration_node_ignores_skipped_prepared_exec_nodes(): + graph = Graph() + graph.add_node(PromptTestInvocation(id="value", prompt="branch value")) + + g = GraphExecutionState(graph=graph) + + skipped_exec_id = g._create_execution_node("value", [])[0] + active_exec_id = g._create_execution_node("value", [])[0] + g._set_prepared_exec_state(skipped_exec_id, "skipped") + + selected_exec_id = g._get_iteration_node("value", graph.nx_graph_flat(), g.execution_graph.nx_graph_flat(), []) + + assert selected_exec_id == active_exec_id + + +def test_get_iteration_node_returns_single_active_prepared_exec_node(): + graph = Graph() + graph.add_node(PromptTestInvocation(id="value", prompt="branch value")) + + g = GraphExecutionState(graph=graph) + + active_exec_id = g._create_execution_node("value", [])[0] + + selected_exec_id = g._get_iteration_node("value", graph.nx_graph_flat(), g.execution_graph.nx_graph_flat(), []) + + assert selected_exec_id == active_exec_id + + +def test_get_iteration_node_returns_none_when_only_skipped_prepared_exec_nodes_exist(): + graph = Graph() + graph.add_node(PromptTestInvocation(id="value", prompt="branch value")) + + g = GraphExecutionState(graph=graph) + + skipped_exec_id = g._create_execution_node("value", [])[0] + g._set_prepared_exec_state(skipped_exec_id, "skipped") + + selected_exec_id = g._get_iteration_node("value", graph.nx_graph_flat(), g.execution_graph.nx_graph_flat(), []) + + assert selected_exec_id is None + + def test_are_connection_types_compatible_accepts_subclass_to_base(): """A subclass output should be connectable to a base-class input.