Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions invokeai/app/services/shared/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ mutation helpers. Those helpers reject changes once the affected nodes have alre
- `_PreparedExecRegistry` Owns the relationship between source graph nodes and prepared execution graph nodes, plus
cached metadata such as iteration path and runtime state.
- `_ExecutionMaterializer` Expands source graph nodes into concrete execution graph nodes when the scheduler runs out of
ready work.
ready work. When matching prepared parents for a downstream exec node, skipped prepared exec nodes are ignored and
cannot be selected as live inputs.
- `_ExecutionScheduler` Owns indegree transitions, ready queues, class batching, and downstream release on completion.
- `_ExecutionRuntime` Owns iteration-path lookup and input hydration for prepared exec nodes.
- `_IfBranchScheduler` Applies lazy `If` semantics by deferring branch-local work until the condition is known, then
Expand Down Expand Up @@ -178,7 +179,9 @@ Run `C` -> `D:0` -> enqueue `D`. Run `D` -> done.
- For **CollectInvocation**: gather all incoming `item` values into `collection`, sorting inputs by iteration path so
collected results are stable across expanded iterations. Incoming `collection` values are merged first, then incoming
`item` values are appended.
- For **IfInvocation**: hydrate only `condition` and the selected branch input.
- For **IfInvocation**: hydrate only `condition` and the selected branch input. If the selected branch's upstream exec
node was skipped and therefore produced no runtime output, the branch input is left at its default value (typically
`None`) instead of raising during hydration.
- For all others: deep-copy each incoming edge's value into the destination field. This prevents cross-node mutation
through shared references.

Expand All @@ -191,7 +194,11 @@ Run `C` -> `D:0` -> enqueue `D`. Run `D` -> done.
- Once the prepared `If` node resolves its condition:
- the selected branch is released
- the unselected branch is marked skipped
- unselected input edges on the prepared `If` exec node are pruned from the execution graph so they no longer
participate in downstream indegree accounting
- branch-exclusive ancestors of the unselected branch are never executed
- Skipped branch-local exec nodes may still be treated as executed for scheduling purposes, but they do not create
entries in `results`.
- Shared ancestors still execute if they are required by the selected branch or by any other live path in the graph.

This behavior is implemented in the runtime scheduler, not in the invocation body itself.
Expand Down
35 changes: 28 additions & 7 deletions invokeai/app/services/shared/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ def _get_selected_branch_fields(self, node: IfInvocation) -> tuple[str, str]:

def _prune_unselected_if_inputs(self, exec_node_id: str, unselected_field: str) -> None:
# Remove the input edges that feed the unselected If branch so they no longer
# count toward the prepared If exec node's indegree (see the README hunk above
# in this diff: pruned edges must stop participating in indegree accounting).
# NOTE(review): this span appears to contain BOTH the pre-change and the
# post-change variants of the loop body from the diff capture — the
# executed-source guard appears once as an early `continue` and once inverted
# as an `if ... not in` wrapper. Confirm against the applied graph.py; only
# one variant should survive.
for edge in self._state.execution_graph._get_input_edges(exec_node_id, unselected_field):
if edge.source.node_id in self._state.executed:
continue
# Underflow here would mean indegree bookkeeping and the edge set disagree.
if self._state.indegree[exec_node_id] == 0:
raise RuntimeError(f"indegree underflow for {exec_node_id} when pruning {unselected_field}")
self._state.indegree[exec_node_id] -= 1
if edge.source.node_id not in self._state.executed:
if self._state.indegree[exec_node_id] == 0:
raise RuntimeError(f"indegree underflow for {exec_node_id} when pruning {unselected_field}")
self._state.indegree[exec_node_id] -= 1
# The edge itself is always deleted, even when its source already executed.
self._state.execution_graph.delete_edge(edge)

def _apply_branch_resolution(
self,
Expand Down Expand Up @@ -424,7 +424,11 @@ def get_node_iterators(self, node_id: str, it_graph: Optional[nx.DiGraph] = None
return [n for n in nx.ancestors(g, node_id) if isinstance(self._state.graph.get_node(n), IterateInvocation)]

def _get_prepared_nodes_for_source(self, source_node_id: str) -> set[str]:
# Return the prepared exec node ids for a source node, excluding any whose
# recorded state is "skipped" so they cannot be matched as live parents.
# NOTE(review): both the old one-line return and its replacement comprehension
# are present in this diff capture; the effective (post-change) version is the
# filtered comprehension below. Confirm against the applied graph.py.
return self._state.source_prepared_mapping[source_node_id]
return {
exec_node_id
for exec_node_id in self._state.source_prepared_mapping[source_node_id]
if self._state._get_prepared_exec_metadata(exec_node_id).state != "skipped"
}

def _get_parent_iterator_exec_nodes(
self, source_node_id: str, graph: nx.DiGraph, prepared_iterator_nodes: list[str]
Expand Down Expand Up @@ -743,6 +747,12 @@ def _sort_collect_input_edges(self, input_edges: list[Edge], field_name: str) ->
def _get_copied_result_value(self, edge: Edge) -> Any:
    """Deep-copy the output field that *edge* reads from its source exec node.

    Raises ``KeyError`` if the source exec node has no recorded result.
    """
    source_output = self._state.results[edge.source.node_id]
    return copydeep(getattr(source_output, edge.source.field))

def _try_get_copied_result_value(self, edge: Edge) -> tuple[bool, Any]:
    """Tolerant variant of ``_get_copied_result_value``.

    Returns ``(True, value)`` with a deep copy of the referenced output field,
    or ``(False, None)`` when the source exec node has no recorded output
    (e.g. it was skipped and therefore produced no result payload).
    """
    output = self._state.results.get(edge.source.node_id)
    if output is not None:
        return True, copydeep(getattr(output, edge.source.field))
    return False, None

def _build_collect_collection(self, input_edges: list[Edge]) -> list[Any]:
item_edges = self._sort_collect_input_edges(input_edges, ITEM_FIELD)
collection_edges = self._sort_collect_input_edges(input_edges, COLLECTION_FIELD)
Expand Down Expand Up @@ -771,7 +781,18 @@ def _prepare_collect_inputs(self, node: "CollectInvocation", input_edges: list[E
def _prepare_if_inputs(self, node: IfInvocation, input_edges: list[Edge]) -> None:
# Hydrate an If node: only `condition` and, once resolved, the selected
# branch field are populated; the unselected branch is never copied in.
selected_field = self._state._resolved_if_exec_branches.get(node.id)
allowed_fields = {"condition", selected_field} if selected_field is not None else {"condition"}
# NOTE(review): the next line looks like the removed pre-change call from the
# diff capture — the loop below replaces it. Confirm against the applied
# graph.py; both together would hydrate the inputs twice.
self._set_node_inputs(node, input_edges, allowed_fields)

for edge in input_edges:
if edge.destination.field not in allowed_fields:
continue

found_value, copied_value = self._try_get_copied_result_value(edge)
if not found_value:
# A skipped branch-local exec node is considered executed for scheduling purposes, but it does not
# produce an output payload. Leave the optional branch input at its default None instead of crashing.
continue

setattr(node, edge.destination.field, copied_value)

def _prepare_default_inputs(self, node: BaseInvocation, input_edges: list[Edge]) -> None:
# Fallback hydration for every node type without a specialized preparer:
# copy each incoming edge value onto the node with no field filtering.
self._set_node_inputs(node, input_edges)
Expand Down
142 changes: 141 additions & 1 deletion tests/test_graph_execution_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.logic import IfInvocation, IfInvocationOutput
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.invocations.primitives import BooleanCollectionInvocation, BooleanInvocation
from invokeai.app.invocations.primitives import BooleanCollectionInvocation, BooleanInvocation, BooleanOutput
from invokeai.app.services.shared.graph import (
CollectInvocation,
Graph,
Expand Down Expand Up @@ -750,6 +750,146 @@ def test_if_graph_optimized_behavior_keeps_shared_live_consumers_per_iteration()
assert executed_source_ids.count("false_branch") == 2


def test_if_graph_optimized_behavior_handles_selected_true_branch_with_shared_false_input_ancestor():
    """Selecting the true branch still executes the shared ancestor that also feeds false_input."""
    graph = Graph()
    for node in (
        BooleanInvocation(id="condition", value=True),
        AnyTypeTestInvocation(id="shared_item", value="shared"),
        AnyTypeTestInvocation(id="true_item", value="true"),
        CollectInvocation(id="shared_collect"),
        CollectInvocation(id="true_collect"),
        IfInvocation(id="if"),
        AnyTypeTestInvocation(id="selected_output"),
    ):
        graph.add_node(node)

    for src, src_field, dst, dst_field in (
        ("condition", "value", "if", "condition"),
        ("shared_item", "value", "shared_collect", "item"),
        ("shared_collect", "collection", "true_collect", "collection"),
        ("true_item", "value", "true_collect", "item"),
        ("shared_collect", "collection", "if", "false_input"),
        ("true_collect", "collection", "if", "true_input"),
        ("if", "value", "selected_output", "value"),
    ):
        graph.add_edge(create_edge(src, src_field, dst, dst_field))

    state = GraphExecutionState(graph=graph)
    executed_source_ids = execute_all_nodes(state)

    prepared_output_id = next(iter(state.source_prepared_mapping["selected_output"]))
    assert state.results[prepared_output_id].value == ["shared", "true"]
    assert set(executed_source_ids) == {
        "condition",
        "shared_item",
        "true_item",
        "shared_collect",
        "true_collect",
        "if",
        "selected_output",
    }


def test_if_graph_optimized_behavior_handles_selected_false_branch_with_shared_true_input_ancestor():
    """Selecting the false branch skips the true-branch-exclusive nodes but keeps shared ancestors."""
    graph = Graph()
    for node in (
        BooleanInvocation(id="condition", value=False),
        AnyTypeTestInvocation(id="shared_item", value="shared"),
        AnyTypeTestInvocation(id="true_item", value="true"),
        CollectInvocation(id="shared_collect"),
        CollectInvocation(id="true_collect"),
        IfInvocation(id="if"),
        AnyTypeTestInvocation(id="selected_output"),
    ):
        graph.add_node(node)

    for src, src_field, dst, dst_field in (
        ("condition", "value", "if", "condition"),
        ("shared_item", "value", "shared_collect", "item"),
        ("shared_collect", "collection", "true_collect", "collection"),
        ("true_item", "value", "true_collect", "item"),
        ("shared_collect", "collection", "if", "false_input"),
        ("true_collect", "collection", "if", "true_input"),
        ("if", "value", "selected_output", "value"),
    ):
        graph.add_edge(create_edge(src, src_field, dst, dst_field))

    state = GraphExecutionState(graph=graph)
    executed_source_ids = execute_all_nodes(state)

    prepared_output_id = next(iter(state.source_prepared_mapping["selected_output"]))
    assert state.results[prepared_output_id].value == ["shared"]
    assert set(executed_source_ids) == {
        "condition",
        "shared_item",
        "shared_collect",
        "if",
        "selected_output",
    }
    assert "true_item" not in executed_source_ids
    assert "true_collect" not in executed_source_ids


def test_prepare_if_inputs_ignores_selected_branch_sources_without_results():
    """An executed-but-resultless selected-branch source leaves the branch input at its default."""
    graph = Graph()
    graph.add_node(BooleanInvocation(id="condition", value=True))
    graph.add_node(PromptTestInvocation(id="true_value", prompt="true branch"))
    graph.add_node(IfInvocation(id="if"))

    graph.add_edge(create_edge("condition", "value", "if", "condition"))
    graph.add_edge(create_edge("true_value", "prompt", "if", "true_input"))

    state = GraphExecutionState(graph=graph)

    cond_exec_id = state._create_execution_node("condition", [])[0]
    branch_exec_id = state._create_execution_node("true_value", [])[0]
    if_exec_id = state._create_execution_node(
        "if",
        [("condition", cond_exec_id), ("true_value", branch_exec_id)],
    )[0]

    # The condition has a real result; the branch source counts as executed
    # (e.g. skipped) but recorded no output payload.
    state.executed.update({cond_exec_id, branch_exec_id})
    state.results[cond_exec_id] = BooleanOutput(value=True)
    state._resolved_if_exec_branches[if_exec_id] = "true_input"

    if_node = state.execution_graph.get_node(if_exec_id)
    state._prepare_inputs(if_node)

    assert if_node.condition is True
    assert if_node.true_input is None


def test_get_iteration_node_ignores_skipped_prepared_exec_nodes():
    """A skipped prepared exec node must never be chosen as the iteration match."""
    graph = Graph()
    graph.add_node(PromptTestInvocation(id="value", prompt="branch value"))

    state = GraphExecutionState(graph=graph)

    first_exec_id = state._create_execution_node("value", [])[0]
    second_exec_id = state._create_execution_node("value", [])[0]
    state._set_prepared_exec_state(first_exec_id, "skipped")

    chosen_exec_id = state._get_iteration_node("value", graph.nx_graph_flat(), state.execution_graph.nx_graph_flat(), [])

    assert chosen_exec_id == second_exec_id


def test_get_iteration_node_returns_single_active_prepared_exec_node():
    """With exactly one live prepared exec node, that node is selected."""
    graph = Graph()
    graph.add_node(PromptTestInvocation(id="value", prompt="branch value"))

    state = GraphExecutionState(graph=graph)
    only_exec_id = state._create_execution_node("value", [])[0]

    chosen_exec_id = state._get_iteration_node("value", graph.nx_graph_flat(), state.execution_graph.nx_graph_flat(), [])

    assert chosen_exec_id == only_exec_id


def test_get_iteration_node_returns_none_when_only_skipped_prepared_exec_nodes_exist():
    """If every prepared exec node for the source is skipped, no match is returned."""
    graph = Graph()
    graph.add_node(PromptTestInvocation(id="value", prompt="branch value"))

    state = GraphExecutionState(graph=graph)

    lone_exec_id = state._create_execution_node("value", [])[0]
    state._set_prepared_exec_state(lone_exec_id, "skipped")

    chosen_exec_id = state._get_iteration_node("value", graph.nx_graph_flat(), state.execution_graph.nx_graph_flat(), [])

    assert chosen_exec_id is None


def test_are_connection_types_compatible_accepts_subclass_to_base():
"""A subclass output should be connectable to a base-class input.

Expand Down
Loading