Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion test/backend/test_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,12 +1036,38 @@ def test_no_extra_contiguous_on_setitem_assign_back(self):
idx = Tensor([1,2,5,6], dtype=dtypes.int32)
flat_base[idx] = Tensor([99,99,99,99])
base.assign(flat_base.reshape(4, 4))
sched = check_schedule(base, 6) # TODO: this is high
sched = check_schedule(base, 2)
run_schedule(sched)
expected = list(range(16))
for i, v in zip([1,2,5,6], [99,99,99,99]): expected[i] = v
np.testing.assert_equal(base.reshape(16).numpy(), expected)

def test_fuse_assign_chain_two(self):
  """Two back-to-back in-place adds on a realized tensor pass check_schedule with a count of 1."""
  t = Tensor([1, 2, 3, 4]).realize()
  # build the assign chain: each `+=` stacks another in-place update on the same tensor
  for _ in range(2): t += 1
  run_schedule(check_schedule(t, 1))
  # 1 + 1 + 1 applied elementwise to the starting values
  np.testing.assert_equal(t.numpy(), [3, 4, 5, 6])

def test_fuse_assign_chain_three(self):
  """Three back-to-back in-place adds on a realized tensor pass check_schedule with a count of 1."""
  t = Tensor([1, 2, 3, 4]).realize()
  # build a three-deep assign chain on the same tensor
  for _ in range(3): t += 1
  run_schedule(check_schedule(t, 1))
  # every element has been incremented three times
  np.testing.assert_equal(t.numpy(), [4, 5, 6, 7])

def test_no_fuse_assign_with_other_consumer(self):
  """An assign chain whose intermediate value is read elsewhere must keep that value observable.

  `u` is derived from `t` after the first in-place add but before the second, so it has to be
  computed from the intermediate contents. Both outputs are checked: `t` for the final result
  of the chain and `u` for the value it observed in between.
  """
  t = Tensor([1, 2, 3, 4]).realize()
  t += 1
  u = t * 2  # intermediate is observed
  t += 1
  sched = check_schedule([t, u], 4)
  run_schedule(sched)
  np.testing.assert_equal(t.numpy(), [3, 4, 5, 6])
  # u was built from t after only the first add ([2, 3, 4, 5]); without this check a wrongly
  # fused or reordered schedule that lets u read the final buffer would go unnoticed
  np.testing.assert_equal(u.numpy(), [4, 6, 8, 10])

def test_sparse_categorical_crossentropy_simple(self):
X = Tensor([[0, 2, 3], [1, 2, 3]]).realize()
Y = Tensor([1, 2]).realize()
Expand Down
22 changes: 22 additions & 0 deletions tinygrad/schedule/rangeify.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@ def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
if any(s is target.base for s in h.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.PARAM})):
return assign.replace(src=(target, src.contiguous()))

def _fuse_assign_chain(ctx:dict[UOp, int], assign:UOp, inner:UOp, root:UOp, inner_src:UOp, src:UOp):
  """Rewrite ASSIGN(ASSIGN(root, f), g(ASSIGN(root, f))) into ASSIGN(root, g(f)).

  `ctx` holds a reference count per UOp. The fusion is applied only when the inner ASSIGN has
  no consumer other than the outer ASSIGN. Returns the fused ASSIGN, or None when the pattern
  must be left untouched.
  """
  # all three conditions must hold, checked in this order:
  #  - inner has no arg (an arg means the inner assign has movement ops)
  #  - inner is referenced exactly twice: as this assign's target and by its source
  #  - inner is actually reachable from src
  can_fuse = inner.arg is None and ctx.get(inner, 0) == 2 and inner in src.toposort()
  if not can_fuse: return None
  # splice the inner assign's value in wherever the outer source read the inner assign
  fused_src = src.substitute({inner: inner_src})
  return assign.replace(src=(root, fused_src))

# matches an ASSIGN whose target is itself an ASSIGN(root, inner_src); _fuse_assign_chain then
# decides whether the two can be merged into a single ASSIGN on root
pm_fuse_assigns = PatternMatcher([
  (
    UPat(
      Ops.ASSIGN, name="assign",
      src=(
        UPat(Ops.ASSIGN, name="inner", src=(UPat(name="root"), UPat(name="inner_src"))),
        UPat(name="src"),
      ),
    ),
    _fuse_assign_chain,
  ),
])

def normalize_assign_target_chain(assign:UOp, target:UOp, src:UOp):
root_target = target
while root_target.op is Ops.ASSIGN: root_target = root_target.src[0]
Expand Down Expand Up @@ -514,6 +526,16 @@ def split_store(x:UOp) -> UOp|None:

def get_rangeify(sink:UOp) -> UOp:
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")

# fuse chained assigns: ASSIGN(ASSIGN(buf, f), g(ASSIGN(buf, f))) -> ASSIGN(buf, g(f)) when intermediate has no other consumers
while True:
ref_counts: dict[UOp, int] = {}
for u in sink.toposort():
for s in u.src: ref_counts[s] = ref_counts.get(s, 0) + 1
new_sink = graph_rewrite(sink, pm_fuse_assigns, ctx=ref_counts, bottom_up=True, name="fuse chained assigns")
if new_sink is sink: break
sink = new_sink

tsink = graph_rewrite(sink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")

# convert movement ops to ranges
Expand Down
Loading