Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
Cleanup const buffers (#12829)
* split pm_cleanups
* update test_schedule
* shrink when we remove bufferize
* dont do shrink if shape is empty
* update tests
* remove *1 from metadata
* deal with the noop bufferize
* only noop on cvar
* cleanup
* fix if
* rename
@@ -67,12 +67,9 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
   def test_tensor_one_mul(self):
     _check_ast_count(0, Tensor.ones(4) * Tensor([1.0, 2, 3, 4]))
 
-  # TODO: these will be fixed with better folding
-  @unittest.expectedFailure
   def test_bool_tensor_mul_bool(self):
     _check_ast_count(0, Tensor([True, False]) * True)
     _check_ast_count(0, Tensor([True, False]) * False)
-  @unittest.expectedFailure
   def test_bool_mul_bool_tensor(self):
     _check_ast_count(0, True * Tensor([True, False]))
     _check_ast_count(0, False * Tensor([True, False]))
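A minimal sketch of what _check_ast_count asserts: schedule the tensor and count real compute kernels (SINK asts), ignoring copies and buffer ops. With the const folding in this PR, bool-by-const multiplies no longer need a kernel, which is why the expectedFailure decorators above are dropped. The helper name check_ast_count below is illustrative (it mirrors the private helper in tinygrad's test suite), and the Ops import path varies across tinygrad versions.

from tinygrad import Tensor
from tinygrad.uop.ops import Ops  # older versions: from tinygrad.ops import Ops

def check_ast_count(desired:int, t:Tensor):
  # schedule the tensor and count only compute kernels (SINK asts)
  sched = t.schedule()
  assert len([si for si in sched if si.ast.op is Ops.SINK]) == desired

check_ast_count(0, Tensor([True, False]) * True)  # folds away, zero kernels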
@@ -51,7 +51,7 @@ class TestFusionOp(unittest.TestCase):
     a = Tensor(val)
     for _ in range(24): a = Tensor.stack(a, a)[0]
     sched = a.schedule()
-    self.assertEqual(len(sched), 1)
+    self.assertEqual(len(sched), 0)
     self.assertLess(time.perf_counter()-st, 2.0)
 
   def test_recursive_reshape(self):
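A hedged illustration of the assertion change above, assuming a tinygrad build that includes this PR: repeatedly stacking a scalar const Tensor and indexing element 0 back out now folds all the way to a const, so the schedule is empty instead of containing one kernel.

from tinygrad import Tensor

a = Tensor(2.0)
for _ in range(8): a = Tensor.stack(a, a)[0]
assert len(a.schedule()) == 0  # was 1 kernel before the const-buffer cleanup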
@@ -52,7 +52,6 @@ class TestImageDType(unittest.TestCase):
     assert isinstance(it.uop.base.realized.dtype, ImageDType)
     np.testing.assert_equal(tst, it.numpy())
 
-  @unittest.expectedFailure # this isn't supported anymore, CAST to ImageDType stays ImageDType
   def test_image_cast_and_back_collapses(self):
     data = Tensor.randn(9*27*4).realize()
     tst = data.numpy()
@@ -446,7 +446,7 @@ class TestSchedule(unittest.TestCase):
   @unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
   def test_fold_conv_batchnorm_optim(self):
     # this is too high
-    for optim, cnt in [(nn.optim.Adam, 30), (nn.optim.SGD, 13)]:
+    for optim, cnt in [(nn.optim.Adam, 21), (nn.optim.SGD, 8)]:
       with self.subTest(optim=optim.__name__):
         with Tensor.train():
           img = Tensor.ones(1,3,4,4)
@@ -1863,7 +1863,7 @@ class TestSchedule(unittest.TestCase):
     yt = Tensor.randn(BS, 10).realize()
     with Context(SPLIT_REDUCEOP=0):
       loss = yt.sparse_categorical_crossentropy(Y_train[samples])
-      run_schedule(check_schedule(loss, 5))
+      run_schedule(check_schedule(loss, 4))
       loss_fused = loss.numpy()
     loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())])
     np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6)
@@ -2076,6 +2076,11 @@ class TestCopyFolding(unittest.TestCase):
     check_schedule(b, 0, filter_sink=False)
     assert b.item() == 1
 
+  def test_one_hot_with_copy(self):
+    y = Tensor([1, 2, 3]).to("CPU")
+    x = y.one_hot(10)
+    check_schedule(x, 3, filter_sink=False)
+
   def test_const_copy_multi(self):
     x = Tensor.ones(1, device="CPU").to_(["CPU", "CPU:1"])
     check_schedule(x, 0, filter_sink=False)
@@ -2085,7 +2090,7 @@ class TestCopyFolding(unittest.TestCase):
     a = Tensor.arange(3).realize()
     zeros = Tensor.zeros(3).realize()
     b = (a*zeros).to("CPU")
-    run_schedule(check_schedule(b, 2, filter_sink=False)) # TODO: 0?
+    run_schedule(check_schedule(b, 0, filter_sink=False))
     self.assertListEqual(b.tolist(), [0, 0, 0])
     self.assertEqual(b.device, "CPU")
 
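The resolved TODO above is the user-visible win of this PR: a*zeros folds to a const, and copying a const to another device now needs zero schedule items instead of two. A hedged repro, assuming a post-#12829 tinygrad:

from tinygrad import Tensor

a = Tensor.arange(3).realize()
b = (a * Tensor.zeros(3).realize()).to("CPU")
assert len(b.schedule()) == 0  # the const copy folds away entirely
assert b.tolist() == [0, 0, 0]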
@@ -839,12 +839,11 @@ class TestTensorMetadata(unittest.TestCase):
     self.assertEqual(y.grad.uop.metadata[0].name, "sigmoid")
     self.assertTrue(y.grad.uop.metadata[0].backward)
     si = Tensor.schedule(out, x.grad, y.grad)[-1]
-    self.assertEqual(len(si.metadata), 4, f"failed with {si.metadata}")
-    self.assertSetEqual(set(m.name for m in si.metadata), {"__mul__", "sigmoid", "relu"})
+    self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
+    self.assertSetEqual(set(m.name for m in si.metadata), {"sigmoid", "relu"})
     bw = [m for m in si.metadata if m.backward]
-    self.assertEqual(len(bw), 2)
-    self.assertEqual(bw[0].name, "__mul__")
-    self.assertEqual(bw[1].name, "sigmoid")
+    self.assertEqual(len(bw), 1)
+    self.assertEqual(bw[0].name, "sigmoid")
 
 class TestIdxUpcast(unittest.TestCase):
   def _find_op(self, ast: UOp, op: Ops):
@@ -192,33 +192,42 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
   # NOTE: if buf src is a const, we don't replace it
   return src.substitute({k:v for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST}, extra_pm=pm_gate_substitute)
 
-def pre_bufferize(b:UOp, x:UOp, copy:UOp):
-  nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:])
-  return copy.replace(src=(x.replace(src=(nb,)+x.src[1:]), copy.src[1]))
+def remove_noop_bufferize(idx,b2):
+  if idx.src[1:] != b2.src[1:] or idx.src[0].op is Ops.BUFFER_VIEW: return None
+  new_tag = (idx.src[0].tag or ()) + (b2.tag or ()) or None
+  return idx.src[0].rtag(new_tag).shrink(tuple((0, s) for s in b2.shape)) if b2.shape else idx.src[0].rtag(new_tag)
 
-pm_cleanups = pm_mops+PatternMatcher([
+pm_const_buffer_folding = pm_mops+PatternMatcher([
   (UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes),
   (UPat(GroupOp.All-{Ops.BUFFERIZE, Ops.BUFFER}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
   (UPat((Ops.BUFFERIZE), name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType)
    and (resolve(prod(x.dtype.shape)!=prod(x.shape)) or x.shape[-1]%4!=0) else None),
   # remove noop buffers. if we look at the next index we can remove even more of these
-  # NOTE: this is mostly the same case as below, but if there's no INDEX this gets more
-  (UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"),
-   lambda idx,b2: idx.src[0].replace(tag=nt if len(nt:=(idx.src[0].tag or ()) + (b2.tag or ())) else None) if idx.src[1:] == b2.src[1:] \
-     and idx.src[0].op is not Ops.BUFFER_VIEW else None),
-  # remove reindexing with cost function
-  (UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
+  (UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"), remove_noop_bufferize),
   # no buffers for const
   (UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg).rtag(b.tag)),
+  # indexing a const is a const
+  (UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c),
   # copy on CONST is CONST
   (UPat(Ops.COPY, src=(UPat.cvar("x"), UPat()), name="copy"), lambda copy,x: copy.const_like(x.arg)),
-  (UPat(Ops.COPY, src=(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.COPY}).f(Ops.BUFFERIZE, allow_any_len=True, name="b")
-    .f(Ops.INDEX, allow_any_len=True, name="x"), UPat()), name="copy"), pre_bufferize),
+  # hack if a noop turned to a const
+  (UPat.cvar("c").f(Ops.NOOP).f(Ops.BUFFERIZE, allow_any_len=True, name="buf"), lambda c,buf: buf.replace(src=(c,)+buf.src[1:])),
   # mstack on CONST is CONST
   (UPat(Ops.MSTACK, src=(UPat.var("s"),), allow_any_len=True).f(Ops.INDEX, allow_any_len=True),
    lambda s: UOp.const(c.dtype, c.arg) if (c:=s.base).op is Ops.CONST else None),
 ])
 
+def pre_bufferize(b:UOp, x:UOp, copy:UOp):
+  nb = b.replace(src=(b.src[0].contiguous(),)+b.src[1:])
+  return copy.replace(src=(x.replace(src=(nb,)+x.src[1:]), copy.src[1]))
+pm_remove_bufferize = PatternMatcher([
+  # hack so remove_bufferize doesnt remove the buffer before a copy
+  (UPat(Ops.COPY, src=(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.COPY}).f(Ops.BUFFERIZE, allow_any_len=True, name="b")
+    .f(Ops.INDEX, allow_any_len=True, name="x"), UPat()), name="copy"), pre_bufferize),
+  # remove reindexing with cost function
+  (UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
+])
+
 def late_buffer_view(t:UOp, b:UOp):
   if isinstance(b.device, str) and (b.device.startswith("DISK") or b.device.startswith("TINYFS")):
     rngs = b.src[1:]
@@ -488,9 +497,8 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   # convert movement ops to ranges
   tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY)
 
-  # NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
-  tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented, name="symbolic") # this supports const folding
-  tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
+  tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_unparented+pm_const_buffer_folding, name="symbolic") # this supports const folding
+  tsink = graph_rewrite(tsink, pm_remove_bufferize, bottom_up=True, name="remove bufferize with cost function")
   tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")
 
   # rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
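Note the staging change this hunk makes: the const-folding rules (pm_const_buffer_folding) now run fused into the symbolic pass via PatternMatcher addition, while the cost-based bufferize removal runs as its own bottom_up pass. PatternMatcher supports + to concatenate rule sets, which is the mechanism behind symbolic_flat+pm_reduce_unparented+pm_const_buffer_folding. A hedged sketch (both rules exist in tinygrad's symbolic set; import paths vary by version):

from tinygrad.uop.ops import UOp, UPat, PatternMatcher, graph_rewrite
from tinygrad.dtype import dtypes

pm_a = PatternMatcher([(UPat.var("x") * 1, lambda x: x)])  # x*1 -> x
pm_b = PatternMatcher([(UPat.var("x") + 0, lambda x: x)])  # x+0 -> x
expr = (UOp.const(dtypes.int, 5) * 1) + 0
print(graph_rewrite(expr, pm_a + pm_b))  # both rules fire: collapses to CONST 5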