mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
use apply uop for assign to fix assign metadata (#12732)
* use apply uop for assign
* fix metadata for assign
* fix backward metadata
* those aren't real tests
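What the fix is about: `Tensor.assign` used to build its ASSIGN uop by mutating `self.uop` directly (see the last hunk below), which skipped the metadata bookkeeping that other tensor ops get via `Tensor._apply_uop`. A minimal sketch of the now-observable behavior, using only calls that appear in this diff; the expected output is taken from the new test, not independently verified:

```python
from tinygrad import Tensor

# after the change, the scheduled assign kernel is tagged with "assign" metadata
x = Tensor.empty(10, 10).realize()
x.assign(Tensor.ones(10, 10).contiguous())
si = x.schedule()[-1]                  # last ScheduleItem is the assign
print([m.name for m in si.metadata])   # expected: ["assign"]
```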
@@ -1,6 +1,6 @@
 import unittest
 from tinygrad import Tensor, nn
-from tinygrad.helpers import Context, GlobalCounters
+from tinygrad.helpers import Context, GlobalCounters, CI
 from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops
 
 class TestRangeifyAssign(unittest.TestCase):
@@ -17,8 +17,22 @@ class TestRangeifyAssign(unittest.TestCase):
     self.assertListEqual(lst, lst3)
     self.assertListEqual(lst2, B.permute(1, 0).tolist())
 
+class TestRangeifyEdgeCase(unittest.TestCase):
+  def test_matmul_relu_cat(self):
+    a = Tensor.ones(100, 512).contiguous().realize()
+    c = Tensor.ones(1, 512).contiguous().realize()
+    cm = Tensor.ones(512, 512)
+    c = c @ cm
+    c = c.relu()
+
+    res = Tensor.cat(a, c, dim=0)
+    self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16)
+
+# *** non CI rangeify tests below this line ***
+
 N = 256
 
+@unittest.skipIf(CI, "useless in CI, doesn't test anything")
 class TestRangeifyOpt(unittest.TestCase):
   def test_randperm(self):
     Tensor.randperm(10000).realize()
@@ -54,6 +68,7 @@ class TestRangeifyOpt(unittest.TestCase):
     A = Tensor.empty(8,8,8,8).permute(1,0,3,2).flatten()
     A.sum().realize()
 
+@unittest.skipIf(CI, "useless in CI, doesn't test anything")
 class TestRangeify(unittest.TestCase):
   def test_groupnorm(self):
     # ranges 1 and 3 are merging
@@ -280,16 +295,5 @@ class TestRangeifyPM(unittest.TestCase):
     b = self.base.pad(((0,1),(0,0))).pad(((0,0),(0,1)))
     self.assert_same(a, b)
 
-class TestRangeifyEdgeCase(unittest.TestCase):
-  def test_matmul_relu_cat(self):
-    a = Tensor.ones(100, 512).contiguous().realize()
-    c = Tensor.ones(1, 512).contiguous().realize()
-    cm = Tensor.ones(512, 512)
-    c = c @ cm
-    c = c.relu()
-
-    res = Tensor.cat(a, c, dim=0)
-    self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16)
-
 if __name__ == '__main__':
   unittest.main()
@@ -810,6 +810,13 @@ class TestTensorMetadata(unittest.TestCase):
     self.assertEqual(len(si.metadata), 1)
     self.assertEqual(si.metadata[0].name, "relu")
 
+  def test_assign(self):
+    x = Tensor.empty(10, 10).realize()
+    x.assign(Tensor.ones(10, 10).contiguous())
+    si = x.schedule()[-1]
+    self.assertEqual(len(si.metadata), 1)
+    self.assertEqual(si.metadata[0].name, "assign")
+
   def test_complex(self):
     x = Tensor.rand(3, requires_grad=True)
     y = Tensor.rand(3, requires_grad=True)
@@ -4,13 +4,13 @@ from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
 from tinygrad.helpers import argsort
 
 def reduce_gradient(ctx:UOp, ret:UOp):
-  def to_inp_shape(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
-  if ret.arg[0] == Ops.ADD: return (to_inp_shape(ctx),)
+  def broadcast_to_input(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
+  if ret.arg[0] == Ops.ADD: return (broadcast_to_input(ctx),)
   if ret.arg[0] == Ops.MAX:
-    max_is_1s = ret.src[0].eq(to_inp_shape(ret)).cast(ctx.dtype)
-    div = to_inp_shape(max_is_1s.r(Ops.ADD, ret.arg[1]))
-    return ((max_is_1s/div) * to_inp_shape(ctx),)
-  if ret.arg[0] == Ops.MUL: return (to_inp_shape(ctx * ret) / ret.src[0],)
+    mask = ret.src[0].eq(broadcast_to_input(ret)).cast(ctx.dtype)
+    count = mask.r(Ops.ADD, ret.arg[1])
+    return ((mask/broadcast_to_input(count)) * broadcast_to_input(ctx),)
+  if ret.arg[0] == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
 
 # ctx is grad_output
 pm_gradient = PatternMatcher([
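For reference, the rewritten MAX branch is the standard max-reduce gradient: an equality mask selects the positions that attained the max, and dividing by the mask's sum splits the upstream gradient evenly among ties. A small hedged check of that behavior through the public Tensor API; the expected values follow from the mask/count formula above rather than from running this snippet:

```python
from tinygrad import Tensor

# gradient of max(): flows only to the argmax positions, split evenly among ties
x = Tensor([1.0, 3.0, 3.0, 2.0], requires_grad=True)
x.max().backward()
print(x.grad.tolist())  # expected: [0.0, 0.5, 0.5, 0.0]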
@@ -60,5 +60,9 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
       if v is None: continue
       if k in grads: grads[k] = grads[k] + v
       else: grads[k] = v
-      if len(forward_metadata:=all_metadata.get(t0, ())): all_metadata[v] = tuple(dataclasses.replace(x, backward=True) for x in forward_metadata)
+      if len(forward_metadata:=all_metadata.get(t0, ())):
+        backward_metadata = tuple(dataclasses.replace(x, backward=True) for x in forward_metadata)
+        # we add the backward metadata to everything new in the graph
+        for bw_uop in v.toposort(lambda x: x not in (t0, *t0.src, grads[t0])):
+          all_metadata[bw_uop] = all_metadata.get(bw_uop, ())+backward_metadata
   return grads
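The gate passed to `toposort` stops the walk at the forward op `t0`, its sources, and the already-accumulated gradient, so only uops newly created for this backward step pick up the `backward=True` copies of the forward metadata. A hedged end-to-end sketch of the intended effect; the exact metadata contents of the scheduled backward kernels are assumed here, not verified:

```python
from tinygrad import Tensor

# backward uops should now inherit the forward op's metadata with backward=True,
# so gradient kernels are attributed to the forward call site (e.g. "relu")
x = Tensor.rand(3, requires_grad=True)
x.relu().sum().backward()
meta = [(m.name, m.backward) for si in x.grad.schedule() for m in si.metadata]
print(meta)  # expected to include ("relu", True)
```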
@@ -272,7 +272,7 @@ def bufferize_to_store(x:UOp):
     assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
     # in assign, this is the buffer size, not the bufferize size
     # TODO: assign_mops here
-    ret = assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=x.dtype)
+    ret = assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=x.dtype).replace(tag=x.tag)
     mops = []
     walk = assign_mops
     while walk is not assign_mops.base:
@@ -284,7 +284,7 @@ def bufferize_to_store(x:UOp):
   # NOTE: the DEFINE_LOCAL needs to be disambiguated here
   if sdtype.addrspace == AddrSpace.GLOBAL:
     buf = UOp.new_buffer(x.arg.device, size, x.dtype)
-    ret = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=x.dtype)
+    ret = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=x.dtype).replace(tag=x.tag)
     ret = ret.forced_reshape(shape)
   # TODO: is this right? what if it's offset
   if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
@@ -294,8 +294,7 @@ class Tensor(MathTrait):
     assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
     assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
     assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
-    self.uop = self.uop.assign(x.uop)
-    return self
+    return self.replace(self._apply_uop(UOp.assign, x))
 
   def detach(self) -> Tensor:
     """
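User-facing behavior of `Tensor.assign` is meant to be unchanged by going through `_apply_uop`; only the metadata attached to the resulting uop differs. A quick hedged sanity check of the unchanged semantics:

```python
from tinygrad import Tensor

# assign still returns the tensor and writes into the existing buffer on realize
w = Tensor.zeros(4).contiguous().realize()
w.assign(w + 1).realize()
print(w.tolist())  # expected: [1.0, 1.0, 1.0, 1.0]
```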