Scan: remove unused outputs #2067
Conversation
re: OFG, we shouldn't have this problem if the OFG gets inlined, right?
Right, it only applies to non-inlined OFGs
should we just be more aggressive about inlining? We already are, iirc.
Not necessarily, but if there are late-pipeline optimizations we can do on non-inlined OFGs, we should, that's all. Like Scan, it's an inner-graph Op, so many of the questions posed are similar; unlike Scan, it's expected to have multiple instances of the same Op by nature (that's one of its applications), so we have to be more careful when we go about doing it. Another case is inplacing: it can make things much faster, but again only if it doesn't lead to de-duplicated OFGs in the final graph
Yeah, ordering is hard. This would be another place egglog could help (deciding the right moment to inplace) |
one test failing but ready for review
@node_rewriter([Scan])
def scan_inline_invariant_constants(fgraph, node):
This can be extended to Allocs that broadcast the front dim, or to constants that only have duplicates along the 0th axis (a bit more expensive to check). We did this in a client project and it yielded some useful simplifications. It also need not inline to provide a speedup: moving from sequences -> non_sequences is already nice, and then if it's something like an Alloc that's usually subsumed by other internal Elemwise/Alloc it can still be inlined, and the non-sequence is just the value/shape. Something to do next. The fact that we isolated the scope of this rewrite makes it nicer to extend later
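The uniformity check described here can be sketched in plain numpy (the helper name and signature are hypothetical, not part of PyTensor):

```python
import numpy as np

def is_iteration_invariant(seq_value: np.ndarray) -> bool:
    """Hypothetical check: True if every step of an outer sequence equals
    step 0, so the sequence could be moved to a single non-sequence value."""
    if seq_value.shape[0] == 0:
        return False  # empty sequence: nothing to replace
    return bool(np.all(seq_value == seq_value[0]))
```

Since this touches the full array it is the "a bit more expensive" part; an Alloc that broadcasts the front dim can be detected symbolically without looking at data.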
    """
    op = node.op
def _duplicates(inner_list, outer_list):
This rewrite is pretty expensive. It should rely on the merge optimizer and just check x is y. Don't want to regress on this PR though
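The cheap identity-based check suggested here could look like this sketch (names are made up; it assumes the merge optimizer has already made structurally equal inputs the same object):

```python
def duplicate_groups(outer_inputs):
    """Hypothetical sketch: map the first index of each distinct input
    object to the indices of its later duplicates, using only `x is y`."""
    first_seen = {}
    groups = {}
    for i, inp in enumerate(outer_inputs):
        key = id(inp)  # identity check only; no structural comparison
        if key in first_seen:
            groups.setdefault(first_seen[key], []).append(i)
        else:
            first_seen[key] = i
    return groups
```

Compared to equal_computations on every pair, this is linear in the number of inputs.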
jessegrabowski
left a comment
I did the best I could with this one while only reading the code. Approving on trust + things seem fine. I could check out the branch and really try to understand it if you want, or we could do a walkthrough call, but I'm also comfortable with it going in.
    op,
    node,
    *,
    drop_seqs=frozenset(),
consider type-hinting these as frozenset[int] for clarity
    inner_outputs = op.inner_outputs
    if inner_substitutions:
        inner_outputs = clone_replace(inner_outputs, replace=inner_substitutions)
why clone here, in case an input is another Op that needs to be rebuilt (e.g. scan into scan)?
as opposed to graph_replace you mean? I think it's just for safety, to avoid sharing variables between distinct scans (since we still mutate the fgraph)
    else:
        extra_dims = [_y.shape[i] for i in range(1, _y.ndim)]
        zero_buf = pt.zeros((nit_sot_size, *extra_dims), dtype=_y.dtype)
        y = set_subtensor(zero_buf[:n_steps], _y)
is it always correct to pad left?
Yeah it may just be useless because the elemwise already had the total shape. Or did you mean something else?
I was specifically thinking this was about the case where the buffer got trimmed by scan optimizations. I don't know if that trimming always happens on the left side (the first time steps) or if it can also be on the right.
trimming is always on the left (the last states are kept, never discarded in preference of earlier ones), but tbh I wouldn't be surprised if this rewrite completely broke with scan save mem.
Even the assumption above is wrong then: n_steps can be larger than the nit_sot size in that case.
I'm not interested in this rewrite tbh, I think it's myopic / worse in expectation. I just patched the bug in the logic assuming no scan save mem. May be worth checking if it handles the save-mem cases.
I wouldn't be surprised if "push out" rewrites don't handle this, just like pullback itself doesn't
Yeah this and other paths in this rewrite don't handle trimmed buffers.
In general we should distinguish scan with trimmed buffers at the Op level, because otherwise it complicates the logic and breaks many things.
This PR fixes the nit-sot case and makes the rewrite self-consistent with full scan. The pre-existing trimmed scan limitations remain.
Scan save mem is registered separately later and continues to be unsafe if called out of order. I'd postpone the cleanup it clearly needs.
    # candidate state inputs. Un-confirming a candidate folds its
    # outputs' reached inputs into the live set immediately, so later
    # candidates in the same pass see the update.
Are there any ordering issues that could arise from this procedure? Like if a certain candidate depends indirectly on the direct inputs of a certain other candidate, so if they trigger in the wrong order the 1st won't see the inputs in the live set yet and will be wrongly dropped.
I can't think of a concrete example, but wondering out loud.
What's an indirect dependency?
For a -> b -> c, everything that has b as an ancestor must also have a as an ancestor, since a is an ancestor of b.
If you mean sit-sot1 depends on sit-sot2 which depends on sit-sot3, then 1 can be dropped unconditionally, 2 only if 1 is also dropped (or else the corresponding input is visited by the survivor 1), and 3 only if 1 and 2 are dropped.
Think about the cases that invalidate dropping sit-sot3:
Direct: sit-sot2 is not a drop candidate, therefore sit-sot3's inner input is an ancestor of it, and is "reachable from survivors".
Two-stage: sit-sot2 is a drop candidate as well (but not sit-sot1). We already showed we can't drop sit-sot2 because sit-sot1 depends on it. Therefore the loop will remove sit-sot2 from the candidates (and add its root inputs as needed -- which includes sit-sot3's inner input), and we are back to the direct-case invalidation.
ok ok, thanks for thinking hard about it even when I didn't
Added a test like that, only had the cyclical dependency before (a depends on b, and b on a)
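The fixed-point argument above can be sketched with plain sets (a toy model, not the PyTensor implementation; deps maps each state to the states whose previous values it reads):

```python
def resolve_drops(states, unused, deps):
    """Toy model of the un-confirming pass: a state is dropped only if no
    surviving state (transitively) depends on it. Whenever a survivor
    depends on a candidate, the candidate is folded back into the live
    set immediately, so later checks see the update."""
    survivors = set(states) - set(unused)
    changed = True
    while changed:
        changed = False
        for s in list(survivors):
            for d in deps.get(s, ()):
                if d not in survivors:
                    survivors.add(d)
                    changed = True
    return set(states) - survivors
```

On the sit-sot1 -> sit-sot2 -> sit-sot3 chain from the thread: with only sit-sot1 used, nothing can be dropped; with all three unused, all three go; and cyclic a <-> b dependencies resolve the same way regardless of visit order.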
mypy :)
When folding a stateless nit_sot scan into an Elemwise, the pushed-out result has length == n_steps. If the nit_sot's declared outer buffer size is larger (e.g., a Scan built via Scan.pullback with truncate_gradient, where nit_sot_size tracks the forward's step count but n_steps is the truncated grad_steps), the direct fold silently drops the trailing slots that the scan's allocator would have left zero-initialized, breaking any downstream consumer that reads the full buffer. Pad the Elemwise result via set_subtensor(zeros(nit_sot_size, ...), _y) when the two sizes aren't the same Variable. The direct-fold shortcut still fires for the common case where pytensor.scan reuses the same n_steps Variable for both.
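Numerically, the fix amounts to a left-aligned write into the declared buffer; a numpy sketch (the helper name is hypothetical):

```python
import numpy as np

def pad_nit_sot(y, nit_sot_size):
    """Left-aligned write into a zero buffer, mimicking
    set_subtensor(zeros((nit_sot_size, ...))[:n_steps], y)."""
    n_steps = y.shape[0]
    buf = np.zeros((nit_sot_size, *y.shape[1:]), dtype=y.dtype)
    buf[:n_steps] = y  # trailing slots stay zero-initialized
    return buf
```

When nit_sot_size and n_steps are the same Variable the buffer write is a no-op and the direct fold can skip it.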
Pure reachability analysis that drops, in a single pass, unused outputs
and inputs from a Scan node:
* State slots (mit_mot / mit_sot / sit_sot / nit_sot / untraced_sit_sot)
whose outer output has no clients, provided none of their inner inputs
is reached from any surviving inner output. Cross-dependent unused
states are resolved together.
* Sequences and non-sequences that the rebuilt inner graph no longer
references.
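The reachability side of this is a plain ancestor traversal; a self-contained sketch over a node -> direct-inputs mapping (not the actual fgraph API):

```python
def ancestors(start, parents):
    """Iterative DFS collecting every node reachable from `start`
    through the `parents` (node -> direct inputs) mapping."""
    seen = set()
    stack = list(start)
    while stack:
        n = stack.pop()
        if n in seen:
            continue
        seen.add(n)
        stack.extend(parents.get(n, ()))
    return seen
```

An inner input (and its outer counterpart) can then be dropped exactly when it is absent from the ancestors of the surviving inner outputs.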
Rebuild plumbing is factored into ``_rebuild_scan_with_new_signature``,
a helper other rewrites can reuse to produce a Scan with a trimmed
signature (drop categories individually; optionally apply inner-graph
substitutions).
Registered in scan_eqopt2 and wired into Scan.L_op so the pullback graph
is cleaned eagerly when disconnected cotangents are present, avoiding
unused gradient computation in the user-facing graph.
``scan_save_mem`` is responsible for buffer-size trimming; unused-output removal is owned by ``scan_remove_unused``. Strip the orphan-output detection, the ``scan_can_remove_outs`` reachability call, and the ``compress_outs`` rebuild path. The new scan is built directly by reusing the existing op with only the resized outer inputs -- no state drops happen here.
The one piece of the old orphan path that was load-bearing for buffer sizing was the "required orphan" case (a state with no external clients but needed by the inner recurrence). It's replaced by a small post-loop that trims such mit_sot / sit_sot buffers to their minimum (``taps + 1`` under prealloc, ``taps`` otherwise).
``scan_can_remove_outs`` and ``compress_outs`` had no other callers and are deleted from ``pytensor.scan.utils``.
Also register ``scan_remove_unused`` at the top-level optdb so it is discoverable via its own tag -- needed for tests that explicitly include ``scan_save_mem`` under FAST_COMPILE, where buffer trimming without unused-output removal leaves orphan nit_sots in the final Scan.
The old rewrite bundled three responsibilities. They're now separated:
* ``scan_inline_invariant_constants`` -- inlines compile-time-constant,
iteration-invariant inputs (non-sequence Constants, and sequences
whose outer input is a ``TensorConstant`` with a uniform value) into
the inner graph. Enables inner constant-folding.
* ``scan_merge_duplicate_inputs`` -- deduplicates outer seqs / non_seqs
that are ``equal_computations``.
* ``scan_remove_unused`` (commit 1) -- drops unused outputs and inputs,
which also cleans up the stale inputs left behind by the other two.
The three are registered together in a single ``dfs_rewriter`` at the
four positions that used to host ``remove_constants_and_unused_inputs_scan``,
with ``scan_remove_unused`` first (most powerful, always reduces), then
inline, then merge.
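The ordering rationale can be illustrated with a toy driver (not how the actual rewriter schedules passes): whenever any rewrite fires, trying again from the top lets the strongest, always-reducing rewrite see the cleaned-up result first.

```python
def run_pipeline(graph, rewrites):
    """Toy driver: try rewrites in registration order; when one fires
    (returns a new graph), restart from the first rewrite."""
    changed = True
    while changed:
        changed = False
        for rw in rewrites:
            new = rw(graph)
            if new is not None:
                graph = new
                changed = True
                break  # restart from the first rewrite
    return graph
```

Here graphs are modeled as integers and rewrites as functions returning None when they don't apply.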
Tests for the old bundled rewrite now exercise the split combination.
We were half-assing this optimization: it was partially bundled with scan unused inputs and scan save memory (both of which do too much and not enough). Specifically, we never dropped mit-mots or untraced sit-sots (because they didn't play into scan save memory's main goal), and so a pullback of a scan with unused outputs always had a useless mit-mot and kept a reference to the equally useless output of the forward pass. Double waste.
This PR adds a rewrite whose single purpose is trimming away useless computation and inputs. It removes the respective behavior from the two other rewrites. It is also called eagerly in the pullback. There's a lot of bookkeeping already in the pullback method that I feel is better to keep as-is and just patch at the end, reusing the rewrite.
Also fixed a bug in scan_push_out_seq which violated the contract that a nit_sot can have a larger buffer than n_steps (shows up in truncated/while scan gradients). This bug revealed itself with the new pullback cleanup, but is otherwise orthogonal.
PS: We need something like this for OpFromGraph, but there we need to be careful about when to do it: OFGs that encapsulate an Op with a specific signature (AllocDiag, Einsum, RVs) shouldn't be mutated, because other rewrites may expect the standard signature when working with them as closed-box Ops. Similarly, if the same OFG is reused across multiple nodes, we only want to do it when outputs aren't used across all uses, since one of the goals of OFG is to reduce compilation/rewrite work for repeated subgraphs.