use apply uop for assign to fix assign metadata (#12732)

* use apply uop for assign

* fix metadata for assign

* fix backward metadata

* those aren't real tests
George Hotz
2025-10-16 20:34:12 +08:00
committed by GitHub
parent 3aa2277b8f
commit 8be7844b2e
5 changed files with 37 additions and 23 deletions
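In practice, the first two changes mean an assign now shows up with its own metadata on the scheduled kernel, which is what the new test_assign below checks. A small usage sketch with the public API (the expected name is exactly what the new test asserts):

from tinygrad import Tensor

x = Tensor.empty(10, 10).realize()
x.assign(Tensor.ones(10, 10).contiguous())
si = x.schedule()[-1]                    # the schedule item for the assign kernel
print([m.name for m in si.metadata])     # the new test expects this to be ['assign']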

View File

@@ -1,6 +1,6 @@
import unittest
from tinygrad import Tensor, nn
-from tinygrad.helpers import Context, GlobalCounters
+from tinygrad.helpers import Context, GlobalCounters, CI
from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops
class TestRangeifyAssign(unittest.TestCase):
@@ -17,8 +17,22 @@ class TestRangeifyAssign(unittest.TestCase):
    self.assertListEqual(lst, lst3)
    self.assertListEqual(lst2, B.permute(1, 0).tolist())
+class TestRangeifyEdgeCase(unittest.TestCase):
+  def test_matmul_relu_cat(self):
+    a = Tensor.ones(100, 512).contiguous().realize()
+    c = Tensor.ones(1, 512).contiguous().realize()
+    cm = Tensor.ones(512, 512)
+    c = c @ cm
+    c = c.relu()
+    res = Tensor.cat(a, c, dim=0)
+    self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16)
# *** non CI rangeify tests below this line ***
N = 256
+@unittest.skipIf(CI, "useless in CI, doesn't test anything")
class TestRangeifyOpt(unittest.TestCase):
  def test_randperm(self):
    Tensor.randperm(10000).realize()
@@ -54,6 +68,7 @@ class TestRangeifyOpt(unittest.TestCase):
    A = Tensor.empty(8,8,8,8).permute(1,0,3,2).flatten()
    A.sum().realize()
+@unittest.skipIf(CI, "useless in CI, doesn't test anything")
class TestRangeify(unittest.TestCase):
  def test_groupnorm(self):
    # ranges 1 and 3 are merging
@@ -280,16 +295,5 @@ class TestRangeifyPM(unittest.TestCase):
    b = self.base.pad(((0,1),(0,0))).pad(((0,0),(0,1)))
    self.assert_same(a, b)
-class TestRangeifyEdgeCase(unittest.TestCase):
-  def test_matmul_relu_cat(self):
-    a = Tensor.ones(100, 512).contiguous().realize()
-    c = Tensor.ones(1, 512).contiguous().realize()
-    cm = Tensor.ones(512, 512)
-    c = c @ cm
-    c = c.relu()
-    res = Tensor.cat(a, c, dim=0)
-    self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16)
if __name__ == '__main__':
  unittest.main()

View File

@@ -810,6 +810,13 @@ class TestTensorMetadata(unittest.TestCase):
    self.assertEqual(len(si.metadata), 1)
    self.assertEqual(si.metadata[0].name, "relu")
+  def test_assign(self):
+    x = Tensor.empty(10, 10).realize()
+    x.assign(Tensor.ones(10, 10).contiguous())
+    si = x.schedule()[-1]
+    self.assertEqual(len(si.metadata), 1)
+    self.assertEqual(si.metadata[0].name, "assign")
  def test_complex(self):
    x = Tensor.rand(3, requires_grad=True)
    y = Tensor.rand(3, requires_grad=True)

View File

@@ -4,13 +4,13 @@ from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
from tinygrad.helpers import argsort
def reduce_gradient(ctx:UOp, ret:UOp):
-  def to_inp_shape(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
-  if ret.arg[0] == Ops.ADD: return (to_inp_shape(ctx),)
+  def broadcast_to_input(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
+  if ret.arg[0] == Ops.ADD: return (broadcast_to_input(ctx),)
  if ret.arg[0] == Ops.MAX:
-    max_is_1s = ret.src[0].eq(to_inp_shape(ret)).cast(ctx.dtype)
-    div = to_inp_shape(max_is_1s.r(Ops.ADD, ret.arg[1]))
-    return ((max_is_1s/div) * to_inp_shape(ctx),)
-  if ret.arg[0] == Ops.MUL: return (to_inp_shape(ctx * ret) / ret.src[0],)
+    mask = ret.src[0].eq(broadcast_to_input(ret)).cast(ctx.dtype)
+    count = mask.r(Ops.ADD, ret.arg[1])
+    return ((mask/broadcast_to_input(count)) * broadcast_to_input(ctx),)
+  if ret.arg[0] == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
# ctx is grad_output
pm_gradient = PatternMatcher([
@@ -60,5 +60,9 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
      if v is None: continue
      if k in grads: grads[k] = grads[k] + v
      else: grads[k] = v
-      if len(forward_metadata:=all_metadata.get(t0, ())): all_metadata[v] = tuple(dataclasses.replace(x, backward=True) for x in forward_metadata)
+      if len(forward_metadata:=all_metadata.get(t0, ())):
+        backward_metadata = tuple(dataclasses.replace(x, backward=True) for x in forward_metadata)
+        # we add the backward metadata to everything new in the graph
+        for bw_uop in v.toposort(lambda x: x not in (t0, *t0.src, grads[t0])):
+          all_metadata[bw_uop] = all_metadata.get(bw_uop, ())+backward_metadata
  return grads
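This hunk is the "fix backward metadata" part: instead of tagging only the final gradient UOp, every UOp newly created while building the gradient inherits the forward op's metadata with backward=True. A self-contained sketch of the idea, with plain dataclasses standing in for tinygrad's UOp and Metadata (names here are illustrative, not the real API):

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Meta:
  name: str
  backward: bool = False

@dataclass(frozen=True, eq=False)  # eq=False: nodes compare by identity, like graph nodes
class Node:
  op: str
  src: tuple = ()

def toposort(root, gate):
  # depth-first walk that only descends into nodes the gate allows
  seen, order = set(), []
  def visit(n):
    if n in seen or not gate(n): return
    seen.add(n)
    for s in n.src: visit(s)
    order.append(n)
  visit(root)
  return order

all_metadata = {}
x = Node("LOAD")
fwd = Node("RELU", (x,))                                   # forward op, tagged when it was created
all_metadata[fwd] = (Meta("relu"),)
grad = Node("MUL", (Node("CMPLT", (x,)), Node("CONST")))   # nodes built during backward

# tag everything new in the gradient subgraph with the forward metadata, marked backward
bw_meta = tuple(replace(m, backward=True) for m in all_metadata[fwd])
for n in toposort(grad, lambda n: n not in (fwd, *fwd.src)):
  all_metadata[n] = all_metadata.get(n, ()) + bw_meta

print({n.op: [(m.name, m.backward) for m in ms] for n, ms in all_metadata.items()})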

View File

@@ -272,7 +272,7 @@ def bufferize_to_store(x:UOp):
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
# in assign, this is the buffer size, not the bufferize size
# TODO: assign_mops here
-ret = assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=x.dtype)
+ret = assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=x.dtype).replace(tag=x.tag)
mops = []
walk = assign_mops
while walk is not assign_mops.base:
@@ -284,7 +284,7 @@ def bufferize_to_store(x:UOp):
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
if sdtype.addrspace == AddrSpace.GLOBAL:
buf = UOp.new_buffer(x.arg.device, size, x.dtype)
-ret = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=x.dtype)
+ret = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=x.dtype).replace(tag=x.tag)
ret = ret.forced_reshape(shape)
# TODO: is this right? what if it's offset
if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
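Both hunks here are the same one-line fix: the STORE that replaces the BUFFERIZE now carries over the original node's tag via .replace(tag=x.tag), so whatever the rewrite attached to the bufferize is not lost when it is turned into a store. A generic illustration with a plain dataclass (not tinygrad's actual UOp class):

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Node:                 # stand-in for a graph node with a tag field
  op: str
  tag: tuple | None = None

x = Node("BUFFERIZE", tag=("kept",))
ret = Node("STORE")             # a freshly built node defaults to tag=None
ret = replace(ret, tag=x.tag)   # the fix: copy the tag from the node being rewritten
assert ret.tag == ("kept",)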

View File

@@ -294,8 +294,7 @@ class Tensor(MathTrait):
    assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
    assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
    assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
-    self.uop = self.uop.assign(x.uop)
-    return self
+    return self.replace(self._apply_uop(UOp.assign, x))
  def detach(self) -> Tensor:
    """