mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
Cleanup const buffers (#12829)
* split pm_cleanups * update test_schedule * shrink when we remove bufferize * dont do shrink if shape is empty * update tests * remove *1 from metadata * deal with the noop bufferize * only noop on cvar * cleanup * fix if * rename
This commit is contained in:
@@ -67,12 +67,9 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
|
||||
def test_tensor_one_mul(self):
|
||||
_check_ast_count(0, Tensor.ones(4) * Tensor([1.0, 2, 3, 4]))
|
||||
|
||||
# TODO: these will be fixed with better folding
|
||||
@unittest.expectedFailure
|
||||
def test_bool_tensor_mul_bool(self):
|
||||
_check_ast_count(0, Tensor([True, False]) * True)
|
||||
_check_ast_count(0, Tensor([True, False]) * False)
|
||||
@unittest.expectedFailure
|
||||
def test_bool_mul_bool_tensor(self):
|
||||
_check_ast_count(0, True * Tensor([True, False]))
|
||||
_check_ast_count(0, False * Tensor([True, False]))
|
||||
|
||||
@@ -51,7 +51,7 @@ class TestFusionOp(unittest.TestCase):
|
||||
a = Tensor(val)
|
||||
for _ in range(24): a = Tensor.stack(a, a)[0]
|
||||
sched = a.schedule()
|
||||
self.assertEqual(len(sched), 1)
|
||||
self.assertEqual(len(sched), 0)
|
||||
self.assertLess(time.perf_counter()-st, 2.0)
|
||||
|
||||
def test_recursive_reshape(self):
|
||||
|
||||
@@ -52,7 +52,6 @@ class TestImageDType(unittest.TestCase):
|
||||
assert isinstance(it.uop.base.realized.dtype, ImageDType)
|
||||
np.testing.assert_equal(tst, it.numpy())
|
||||
|
||||
@unittest.expectedFailure # this isn't supported anymore, CAST to ImageDType stays ImageDType
|
||||
def test_image_cast_and_back_collapses(self):
|
||||
data = Tensor.randn(9*27*4).realize()
|
||||
tst = data.numpy()
|
||||
|
||||
@@ -446,7 +446,7 @@ class TestSchedule(unittest.TestCase):
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
|
||||
def test_fold_conv_batchnorm_optim(self):
|
||||
# this is too high
|
||||
for optim, cnt in [(nn.optim.Adam, 30), (nn.optim.SGD, 13)]:
|
||||
for optim, cnt in [(nn.optim.Adam, 21), (nn.optim.SGD, 8)]:
|
||||
with self.subTest(optim=optim.__name__):
|
||||
with Tensor.train():
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
@@ -1863,7 +1863,7 @@ class TestSchedule(unittest.TestCase):
|
||||
yt = Tensor.randn(BS, 10).realize()
|
||||
with Context(SPLIT_REDUCEOP=0):
|
||||
loss = yt.sparse_categorical_crossentropy(Y_train[samples])
|
||||
run_schedule(check_schedule(loss, 5))
|
||||
run_schedule(check_schedule(loss, 4))
|
||||
loss_fused = loss.numpy()
|
||||
loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())])
|
||||
np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6)
|
||||
@@ -2076,6 +2076,11 @@ class TestCopyFolding(unittest.TestCase):
|
||||
check_schedule(b, 0, filter_sink=False)
|
||||
assert b.item() == 1
|
||||
|
||||
def test_one_hot_with_copy(self):
|
||||
y = Tensor([1, 2, 3]).to("CPU")
|
||||
x = y.one_hot(10)
|
||||
check_schedule(x, 3, filter_sink=False)
|
||||
|
||||
def test_const_copy_multi(self):
|
||||
x = Tensor.ones(1, device="CPU").to_(["CPU", "CPU:1"])
|
||||
check_schedule(x, 0, filter_sink=False)
|
||||
@@ -2085,7 +2090,7 @@ class TestCopyFolding(unittest.TestCase):
|
||||
a = Tensor.arange(3).realize()
|
||||
zeros = Tensor.zeros(3).realize()
|
||||
b = (a*zeros).to("CPU")
|
||||
run_schedule(check_schedule(b, 2, filter_sink=False)) # TODO: 0?
|
||||
run_schedule(check_schedule(b, 0, filter_sink=False))
|
||||
self.assertListEqual(b.tolist(), [0, 0, 0])
|
||||
self.assertEqual(b.device, "CPU")
|
||||
|
||||
|
||||
@@ -839,12 +839,11 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
|
||||
self.assertTrue(y.grad.uop.metadata[0].backward)
|
||||
si = Tensor.schedule(out, x.grad, y.grad)[-1]
|
||||
self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
|
||||
self.assertSetEqual(set(m.name for m in si.metadata), {"__mul__", "sigmoid", "relu"})
|
||||
self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
|
||||
self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "relu"})
|
||||
bw = [m for m in si.metadata if m.backward]
|
||||
self.assertEqual(len(bw), 2)
|
||||
self.assertEqual(bw[0].name, "__mul__")
|
||||
self.assertEqual(bw[1].name, "sigmoid")
|
||||
self.assertEqual(len(bw), 1)
|
||||
self.assertEqual(bw[0].name, "sigmoid")
|
||||
|
||||
class TestIdxUpcast(unittest.TestCase):
|
||||
def _find_op(self, ast: UOp, op: Ops):
|
||||
|
||||
@@ -192,33 +192,42 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
|
||||
# NOTE: if buf src is a const, we don't replace it
|
||||
return src.substitute({k:v for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST}, extra_pm=pm_gate_substitute)
|
||||
|
||||
def pre_bufferize(b:UOp, x:UOp, copy:UOp):
|
||||
nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:])
|
||||
return copy.replace(src=(x.replace(src=(nb,)+x.src[1:]), copy.src[1]))
|
||||
def remove_noop_bufferize(idx,b2):
|
||||
if idx.src[1:] != b2.src[1:] or idx.src[0].op is Ops.BUFFER_VIEW: return None
|
||||
new_tag = (idx.src[0].tag or ()) + (b2.tag or ()) or None
|
||||
return idx.src[0].rtag(new_tag).shrink(tuple((0, s) for s in b2.shape)) if b2.shape else idx.src[0].rtag(new_tag)
|
||||
|
||||
pm_cleanups = pm_mops+PatternMatcher([
|
||||
pm_const_buffer_folding = pm_mops+PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes),
|
||||
(UPat(GroupOp.All-{Ops.BUFFERIZE, Ops.BUFFER}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
||||
(UPat((Ops.BUFFERIZE), name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType)
|
||||
and (resolve(prod(x.dtype.shape)!=prod(x.shape)) or x.shape[-1]%4!=0) else None),
|
||||
# remove noop buffers. if we look at the next index we can remove even more of these
|
||||
# NOTE: this is mostly the same case as below, but if there's no INDEX this gets more
|
||||
(UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"),
|
||||
lambda idx,b2: idx.src[0].replace(tag=nt if len(nt:=(idx.src[0].tag or ()) + (b2.tag or ())) else None) if idx.src[1:] == b2.src[1:] \
|
||||
and idx.src[0].op is not Ops.BUFFER_VIEW else None),
|
||||
# remove reindexing with cost function
|
||||
(UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
|
||||
(UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"), remove_noop_bufferize),
|
||||
# no buffers for const
|
||||
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg).rtag(b.tag)),
|
||||
# indexing a const is a const
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c),
|
||||
# copy on CONST is CONST
|
||||
(UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)),
|
||||
(UPat(Ops.COPY, src=(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.COPY}).f(Ops.BUFFERIZE, allow_any_len=True, name="b")
|
||||
.f(Ops.INDEX, allow_any_len=True, name="x"), UPat()), name="copy"), pre_bufferize),
|
||||
# hack if a noop turned to a const
|
||||
(UPat.cvar("c").f(Ops.NOOP).f(Ops.BUFFERIZE, allow_any_len=True, name="buf"), lambda c,buf: buf.replace(src=(c,)+buf.src[1:])),
|
||||
# mstack on CONST is CONST
|
||||
(UPat(Ops.MSTACK, src=(UPat.var("s"),), allow_any_len=True).f(Ops.INDEX, allow_any_len=True),
|
||||
lambda s: UOp.const(c.dtype, c.arg) if (c:=s.base).op is Ops.CONST else None),
|
||||
])
|
||||
|
||||
def pre_bufferize(b:UOp, x:UOp, copy:UOp):
|
||||
nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:])
|
||||
return copy.replace(src=(x.replace(src=(nb,)+x.src[1:]), copy.src[1]))
|
||||
pm_remove_bufferize = PatternMatcher([
|
||||
# hack so remove_bufferize doesnt remove the buffer before a copy
|
||||
(UPat(Ops.COPY, src=(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.COPY}).f(Ops.BUFFERIZE, allow_any_len=True, name="b")
|
||||
.f(Ops.INDEX, allow_any_len=True, name="x"), UPat()), name="copy"), pre_bufferize),
|
||||
# remove reindexing with cost function
|
||||
(UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
|
||||
])
|
||||
|
||||
def late_buffer_view(t:UOp, b:UOp):
|
||||
if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")):
|
||||
rngs = b.src[1:]
|
||||
@@ -488,9 +497,8 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
# convert movement ops to ranges
|
||||
tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)
|
||||
|
||||
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
|
||||
tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented, name="symbolic") # this supports const folding
|
||||
tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
|
||||
tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented+pm_const_buffer_folding, name="symbolic") # this supports const folding
|
||||
tsink = graph_rewrite(tsink, pm_remove_bufferize, bottom_up=True, name="remove bufferize with cost function")
|
||||
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")
|
||||
|
||||
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
|
||||
|
||||
Reference in New Issue
Block a user