Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions test/null/test_tensor_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from tinygrad.tensor import _METADATA
from tinygrad.engine.realize import capturing
from tinygrad.helpers import Context
from tinygrad.uop.ops import all_metadata

@unittest.skip("tensor metadata is no longer supported")
class TestTensorMetadata(unittest.TestCase):
def setUp(self) -> None:
_METADATA.set(None)
all_metadata.clear()
self._ctx = Context(SCACHE=0)
self._ctx.__enter__()
def tearDown(self) -> None:
Expand Down Expand Up @@ -83,6 +84,11 @@ def test_complex_backward(self):
#self.assertEqual(len(bw), 1)
#self.assertEqual(bw[0].name, "sigmoid")

def test_backward(self):
x = Tensor.rand(3, requires_grad=True).realize()
x.relu().sum().backward()
self.assertIn(("relu", True), {(m.name, m.backward) for m in x.grad.uop.metadata})

def test_tracemeta_0(self):
with Context(TRACEMETA=0):
x = Tensor.rand(3, requires_grad=True)
Expand All @@ -107,7 +113,6 @@ def test_metadata_survives_realize_pending_assign(self):
c[:4].assign(shared)
self.assertTrue(self._has_metadata(c[:4].relu(), "relu"))

@unittest.expectedFailure
def test_metadata_lost_realize_pending_assign(self):
shared = Tensor.rand(4)
c = Tensor.zeros(8).contiguous().realize()
Expand Down
3 changes: 2 additions & 1 deletion tinygrad/schedule/rangeify.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,8 @@ def split_store(x:UOp) -> UOp|None:
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW}: ret = stored.replace(src=stored.src + ret.ended_ranges)
else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))

kernel = ret.call(*lctx.map.values(), *lctx.vars.keys())
metadata = tuple(dedup(m for u in x.toposort(gate=lambda u: u.op is not Ops.AFTER) if u.metadata is not None for m in u.metadata))
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys(), metadata=metadata)
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
return kernel
Expand Down
6 changes: 3 additions & 3 deletions tinygrad/uop/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def replace(self, **kwargs) -> UOp:
kwargs.pop("arg", self.arg), kwargs.pop("tag", self.tag))
assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
if (self.op, self.dtype, self.src, self.arg, self.tag) == new_args: return self
return UOp(*new_args)
return UOp(*new_args, metadata=self.metadata) # type: ignore[call-arg]
def rtag(self, tag=True): return self.replace(tag=tag)
@recursive_property
def key(self) -> bytes:
Expand Down Expand Up @@ -1417,7 +1417,7 @@ def walk_rewrite(self, root:UOp) -> UOp:
else:
# rebuild node with rewritten srcs
new_src = tuple(self.replace.get(x, x) for x in n.src)
new_n = UOp(n.op, n.dtype, new_src, n.arg, n.tag) if new_src != n.src else n
new_n = UOp(n.op, n.dtype, new_src, n.arg, n.tag, metadata=n.metadata) if new_src != n.src else n # type: ignore[call-arg]
# top-down: try pm on rebuilt node, use result as-is (no re-traversal)
if self.pm is not None and (rewritten:=self.pm_rewrite(new_n)) is not None: new_n = rewritten
self.replace[n] = new_n
Expand Down Expand Up @@ -1474,7 +1474,7 @@ def unified_rewrite(self, root:UOp) -> UOp:
continue
else:
# if srcs changed from rewrites, construct a new UOp with the new srcs
new_src_n = UOp(new_n.op, new_n.dtype, new_src, new_n.arg, new_n.tag)
new_src_n = UOp(new_n.op, new_n.dtype, new_src, new_n.arg, new_n.tag, metadata=new_n.metadata) # type: ignore[call-arg]
# trigger a rewrite of new_src_n, then after that rewrite is done, link it back to n
stack.append((n, 2, new_src_n))
stack.append((new_src_n, 0, new_src_n))
Expand Down
Loading