From 4a8c514c4644cd11d49e71cc2970f1db0ea11c11 Mon Sep 17 00:00:00 2001
From: Chen-Yu Yang
Date: Fri, 20 Feb 2026 21:31:32 -0500
Subject: [PATCH] fuse assign chain

---
Review notes (placed after the "---" separator, outside the commit message):
- line breaks and indentation were restored from a whitespace-collapsed copy; indentation assumes 2-space nesting,
  so check context and removed lines against the tree before applying
- test_no_fuse_assign_with_other_consumer now also asserts the value of u
- _fuse_assign_chain gained a return annotation and a fuller docstring; explanatory comments were added around the fuse loop
- hunk sizes and the diffstat were updated for the added lines; the "index" blob ids are carried over unchanged and need regenerating

 test/backend/test_schedule.py | 29 ++++++++++++++++++++++++++++-
 tinygrad/schedule/rangeify.py | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/test/backend/test_schedule.py b/test/backend/test_schedule.py
index c1f419ed40bf6..25f96c62a8872 100644
--- a/test/backend/test_schedule.py
+++ b/test/backend/test_schedule.py
@@ -1036,12 +1036,39 @@ def test_no_extra_contiguous_on_setitem_assign_back(self):
     idx = Tensor([1,2,5,6], dtype=dtypes.int32)
     flat_base[idx] = Tensor([99,99,99,99])
     base.assign(flat_base.reshape(4, 4))
-    sched = check_schedule(base, 6) # TODO: this is high
+    sched = check_schedule(base, 2)
     run_schedule(sched)
     expected = list(range(16))
     for i, v in zip([1,2,5,6], [99,99,99,99]): expected[i] = v
     np.testing.assert_equal(base.reshape(16).numpy(), expected)
 
+  def test_fuse_assign_chain_two(self):
+    t = Tensor([1, 2, 3, 4]).realize()
+    t += 1
+    t += 1
+    sched = check_schedule(t, 1)
+    run_schedule(sched)
+    np.testing.assert_equal(t.numpy(), [3, 4, 5, 6])
+
+  def test_fuse_assign_chain_three(self):
+    t = Tensor([1, 2, 3, 4]).realize()
+    t += 1
+    t += 1
+    t += 1
+    sched = check_schedule(t, 1)
+    run_schedule(sched)
+    np.testing.assert_equal(t.numpy(), [4, 5, 6, 7])
+
+  def test_no_fuse_assign_with_other_consumer(self):
+    t = Tensor([1, 2, 3, 4]).realize()
+    t += 1
+    u = t * 2 # intermediate is observed
+    t += 1
+    sched = check_schedule([t, u], 4)
+    run_schedule(sched)
+    np.testing.assert_equal(t.numpy(), [3, 4, 5, 6])
+    np.testing.assert_equal(u.numpy(), [4, 6, 8, 10]) # u was built from t after the first += only
+
   def test_sparse_categorical_crossentropy_simple(self):
     X = Tensor([[0, 2, 3], [1, 2, 3]]).realize()
     Y = Tensor([1, 2]).realize()
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index fa1ee0249dd8d..327321aeb1c94 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -55,6 +55,26 @@ def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
     if any(s is target.base for s in h.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.PARAM})):
       return assign.replace(src=(target, src.contiguous()))
 
+def _fuse_assign_chain(ctx:dict[UOp, int], assign:UOp, inner:UOp, root:UOp, inner_src:UOp, src:UOp) -> UOp|None:
+  """
+  Fuse ASSIGN(ASSIGN(root, f), g(ASSIGN(root, f))) -> ASSIGN(root, g(f)) when inner ASSIGN has no other consumers.
+
+  ctx maps a UOp to the number of src edges pointing at it; get_rangeify rebuilds it from the sink before every rewrite pass.
+  Returns the fused ASSIGN, or None when the chain is left alone.
+  """
+  if inner.arg is not None: return None # skip if inner has movement ops
+  # exactly two edges may reach inner: this assign's target slot and one use inside src (more uses inside src also block the fusion)
+  if ctx.get(inner, 0) != 2: return None
+  # two edges but inner is not under src: the second edge belongs to another consumer, so inner has to stay
+  if inner not in src.toposort(): return None
+  return assign.replace(src=(root, src.substitute({inner: inner_src})))
+
+# an ASSIGN whose target is itself an ASSIGN; _fuse_assign_chain decides whether the pair can be merged
+pm_fuse_assigns = PatternMatcher([
+  (UPat(Ops.ASSIGN, src=(UPat(Ops.ASSIGN, src=(UPat(name="root"), UPat(name="inner_src")), name="inner"), UPat(name="src")), name="assign"),
+   _fuse_assign_chain),
+])
+
 def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
   root_target = target
   while root_target.op is Ops.ASSIGN: root_target = root_target.src[0]
@@ -514,6 +534,18 @@ def split_store(x:UOp) -> UOp|None:
 
 def get_rangeify(sink:UOp) -> UOp:
   if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
+
+  # fuse chained assigns: ASSIGN(ASSIGN(buf, f), g(ASSIGN(buf, f))) -> ASSIGN(buf, g(f)) when intermediate has no other consumers
+  # rebuild the counts every round: UOps created by a fusion are missing from the old table, where ctx.get(inner, 0) would read 0
+  # NOTE(review): only edges reachable from this sink are counted -- confirm nothing outside the sink can still consume a fused-away ASSIGN
+  while True:
+    ref_counts: dict[UOp, int] = {}
+    for u in sink.toposort():
+      for s in u.src: ref_counts[s] = ref_counts.get(s, 0) + 1
+    new_sink = graph_rewrite(sink, pm_fuse_assigns, ctx=ref_counts, bottom_up=True, name="fuse chained assigns")
+    if new_sink is sink: break
+    sink = new_sink
+
   tsink = graph_rewrite(sink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
 
   # convert movement ops to ranges