# Source: tinygrad/test/null/test_schedule.py (2026-03-04 05:26:04 -05:00; 1194 lines, 38 KiB, Python)

# schedule tests that pass on NULL backend (no copyout needed)
import gc, unittest, time
from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.device import is_dtype_supported
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
from tinygrad.helpers import DEBUG, GlobalCounters, Context
from tinygrad.engine.realize import CompiledRunner, run_schedule
class KernelCountException(Exception):
  """Raised by check_schedule when the counted kernels differ from the expected number."""
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
  """Schedule `t`, lower every scheduled item, and verify the number of kernels.

  Args:
    t: a Tensor, a non-empty list of Tensors scheduled together, or a UOp (wrapped in a Tensor before scheduling).
    allowed: the exact kernel count the schedule must have.
    to_prerealize: optional Tensors that are realized (with DEBUG and TRACK_MATCH_STATS set to 0) before `t` is scheduled.
    filter_sink: when True only items whose `prg` is a CompiledRunner are counted; when False every item is counted.

  Returns:
    The lowered schedule, so callers can hand it to run_schedule.

  Raises:
    KernelCountException: the counted kernels differ from `allowed`.
    AssertionError: `t` is not a Tensor, a non-empty list starting with a Tensor, or a UOp.
  """
  if to_prerealize:
    with Context(DEBUG=0, TRACK_MATCH_STATS=0): Tensor.realize(*to_prerealize)
  if isinstance(t, Tensor): sched = t.schedule()
  # check for non-empty before peeking at t[0]: an empty list now reaches the descriptive assert instead of raising IndexError
  elif isinstance(t, list) and t and isinstance(t[0], Tensor): sched = Tensor.schedule(*t)
  else:
    assert isinstance(t, UOp), f"can't schedule {t}"
    sched = Tensor(t).schedule()
  # test lowering all the ExecItems
  for si in sched: si.lower()
  kernel_cnt = sum(1 for si in sched if isinstance(si.prg, CompiledRunner) or not filter_sink)
  if kernel_cnt != allowed:
    print(f"SCHEDULE ISSUE, expecting {allowed} got {kernel_cnt}")
    # dump every scheduled AST only at high debug levels
    if DEBUG >= 3:
      for i,s in enumerate(sched):
        print("kernel", i+1)
        print(s.ast)
    raise KernelCountException(f"{kernel_cnt} != {allowed}")
  return sched
def _realize_weights(m):
  """Realize, one at a time, every parameter that nn.state.get_parameters returns for `m`."""
  for param in nn.state.get_parameters(m):
    param.realize()
class TestBufferUOp(unittest.TestCase):
  """Checks on when a Tensor's UOp exposes a device Buffer and when that Buffer is allocated/realized."""

  # BUFFER has a ShapeTracker of shape=(n,) and stride=(1,)
  def test_buffer_has_buffer(self):
    empty = Tensor.empty(10)
    self.assertIsNotNone(empty.uop.buffer)
    self.assertEqual(empty.uop.shape, (10,))
    # the device Buffer remains unallocated until we run the schedule
    self.assertFalse(empty.uop.buffer.is_allocated())
    plus_one = empty+1
    schedule = plus_one.schedule()
    # scheduling alone must not allocate
    self.assertFalse(empty.uop.buffer.is_allocated())
    run_schedule(schedule)
    self.assertTrue(empty.uop.buffer.is_allocated())

  def test_buffer_has_unique_buffer(self):
    empty = Tensor.empty(10)
    first = empty.uop.buffer
    second = empty.uop.buffer
    # repeated access returns the very same Buffer object
    self.assertIs(first, second)

  # we also allow VIEW(BUFFER) to access the underlying device Buffer, as long as it's contiguous
  def test_buffer_view_allowed(self):
    total = Tensor.empty(1, 1)+Tensor.empty(1, 1)
    total.realize()
    self.assertIsNotNone(total.uop.buffer)
    self.assertEqual(total.uop.shape, (1, 1))

  def test_buffer_view_not_allowed(self):
    permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1)
    with self.assertRaises(RuntimeError):
      permuted_view.uop.buffer # cannot access Buffer of a non contiguous VIEW

  def test_buffer_only_after_realize(self):
    total = Tensor([1])+Tensor([2])
    # accessing realized will return None
    self.assertIsNone(total.uop.realized)
    # accessing Buffer will assert
    with self.assertRaisesRegex(AssertionError, "must be BUFFER"):
      total.uop.buffer # there is no BUFFER on an unrealized ADD
    # Buffer only exists once we realize it
    total.realize()
    self.assertIsNotNone(total.uop.buffer)

  def test_const_does_not_realize(self):
    const = Tensor(1)
    run_schedule(check_schedule(const, 0))
    self.assertIsNone(const.uop.base.realized)

  def test_var_does_not_realize(self):
    bound = Tensor(UOp.variable("a", 0, 10).bind(1))
    run_schedule(check_schedule(bound, 0))
    self.assertIsNone(bound.uop.base.realized)

  def test_unused_var_not_in_var_vals(self):
    # unused variable should not appear in var_vals even when there's other work
    unused = Tensor(UOp.variable("unused", 0, 10).bind(1))
    work = Tensor.empty(3) + 1
    _, var_vals = Tensor.schedule_with_vars(unused, work)
    self.assertEqual(var_vals, {})
    self.assertIsNone(unused.uop.base.realized)

  def test_view_does_not_realize(self):
    expanded = Tensor.randn(1, 4).expand(4, 4)
    expanded.realize()
    # the realized base keeps the pre-expand size
    self.assertEqual(expanded.uop.base.realized.size, 4)
    dense = expanded.contiguous().realize()
    self.assertEqual(dense.uop.base.realized.size, 16)
class TestContiguous(unittest.TestCase):
  """Kernel counts for .contiguous() on plain buffers and on views of buffers."""

  def test_contiguous_buffer(self):
    buf = Tensor.empty(4)
    out = buf.contiguous()
    check_schedule(out, 0)

  def test_contiguous_buffer_view(self):
    buf = Tensor.empty(4)
    out = buf.reshape((2, 2)).contiguous()
    check_schedule(out, 0)

  def test_non_contiguous_buffer_view(self):
    buf = Tensor.empty(4, 1)
    out = buf.expand((4, 4)).contiguous()
    check_schedule(out, 1)

  def test_size_change_buffer_view(self):
    buf = Tensor.empty(4)
    out = buf.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).contiguous()
    check_schedule(out, 0) # contiguous shrink of a realized buffer is a zero-copy BUFFER_VIEW

  def test_double_contiguous_realizes_once(self):
    buf = Tensor.empty(4, 1)
    out = buf.expand((4, 4)).contiguous().contiguous()
    check_schedule(out, 1)

  def test_view_does_not_realize(self):
    buf = Tensor.empty(4)
    view = buf.expand((4, 4))
    check_schedule(view, 0)
    # the base buffer keeps the un-expanded size
    self.assertEqual(view.uop.base.buffer.size, 4)

  def test_contiguous_view_realizes(self):
    buf = Tensor.empty(4)
    dense = buf.expand((4, 4)).contiguous()
    check_schedule(dense, 1)
    self.assertEqual(dense.uop.base.buffer.size, 16)
class TestSimpleSchedule(unittest.TestCase):
  """Minimal scheduling checks that call Tensor.schedule directly."""

  def test_reduce_doesnt_split(self):
    reduced = Tensor.empty(16,16).sum(axis=1)
    as_square = reduced.reshape(4,4)
    as_column = reduced.reshape(16,1,1)
    # two reshapes of the same reduce are expected to produce a single schedule item
    schedule = Tensor.schedule(as_square, as_column)
    self.assertEqual(len(schedule), 1)
class TestSchedule(unittest.TestCase):
  """Kernel-count expectations for the scheduler: fusion, caching, conv/batchnorm/optimizer steps and indexing."""

  @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
  def test_error_on_device_mismatch(self):
    lhs = Tensor.empty(10)
    rhs = Tensor.empty(10, device="CPU")
    out = lhs+rhs
    with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"):
      check_schedule(out, 1)

  @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
  def test_error_on_device_mismatch_alt(self):
    lhs = Tensor.empty(10)
    rhs = Tensor.empty((1,), device="CPU").expand(10).contiguous()
    out = lhs+rhs
    with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"):
      check_schedule(out, 2)

  def test_rand(self):
    x = Tensor.rand(32)
    # the device RNG counter is realized up front
    check_schedule(x, 1, [Tensor._device_rng_counters[x.device]])

  def test_rand_recompute_arange(self):
    x = Tensor.rand(32)
    check_schedule(x, 1, [Tensor._device_rng_counters[x.device]])

  def test_empty_is_not_realized(self):
    empty = Tensor.empty(10)
    child = empty+2
    assert not empty.uop.is_realized
    child.realize()
    assert empty.uop.is_realized

  def test_realize_view_of_realized_has_empty_schedule(self):
    # views of realized buffers produce an empty schedule
    t = Tensor.zeros((3, 3)).contiguous().realize()
    v = t[1] # view - is_realized but not has_buffer_identity
    assert v.uop.is_realized
    sched, _ = Tensor.schedule_with_vars(v)
    self.assertEqual(len(sched), 0)

  # NOTE: because empty does not have a lowered ExecItem if realize is called on a childless empty, it never gets allocated.
  def test_childless_empty_never_allocates(self):
    empty = Tensor.empty(10)
    empty.realize()
    assert not empty.uop.is_realized

  def test_simplify_padded_const(self):
    values, _ = Tensor.empty(1022).cummax(axis=0)
    check_schedule(values, 3)

  @unittest.skip("should this pass?")
  def test_contiguous_assign(self):
    src = Tensor.ones(10) * 2
    dst = Tensor.empty(10)
    assigned = dst.assign(src.contiguous())
    check_schedule(assigned, 1)

  def test_basic_binop_fusion(self):
    a, b, c = (Tensor.empty(10) for _ in range(3))
    total = a+b+c
    check_schedule(total, 1)

  def test_basic_binop_fusion_assign(self):
    a, b, c = (Tensor.empty(10) for _ in range(3))
    total = a+b+c
    assigned = Tensor.empty(10).assign(total)
    check_schedule(assigned, 1)

  def test_basic_binop_fusion_deep(self):
    a, b, c, d = (Tensor.empty(10) for _ in range(4))
    total = a+b+c+d
    check_schedule(total, 1)

  def test_mulacc_fusion(self):
    a, b = (Tensor.empty(10) for _ in range(2))
    dot = (a*b).sum()
    check_schedule(dot, 1)

  def test_mulacc_fusion_assign(self):
    a, b = (Tensor.empty(10) for _ in range(2))
    dot = (a*b).sum()
    assigned = Tensor.empty(1).assign(dot)
    check_schedule(assigned, 1)

  def test_detach_assign(self):
    ones = Tensor.ones(4, 4).contiguous().realize()
    buf1, buf2 = Tensor.empty(4, 4).contiguous(), Tensor.empty(4, 4).contiguous()
    chained = buf2.assign(buf1.assign(ones + 1.0) * 2.0)
    check_schedule(chained.detach().contiguous(), 2)

  def test_contiguous_backward_assign(self):
    ones = Tensor.ones(4, 4).contiguous().realize()
    buf1, buf2 = Tensor.empty(4, 4).contiguous(), Tensor.empty(4, 4).contiguous()
    chained = buf2.assign(buf1.assign(ones + 1.0) * 2.0)
    check_schedule(chained.contiguous_backward().contiguous(), 2)

  def test_mulacc_relu_fusion(self):
    a, b = (Tensor.empty(10) for _ in range(2))
    out = (a*b).sum().relu()
    check_schedule(out, 1)

  def test_binop_reshape_fusion(self):
    a, b = (Tensor.empty(10) for _ in range(2))
    c = Tensor.empty(5,2)
    out = (a+b).reshape(5,2)+c
    check_schedule(out, 1)

  def test_binop_permute_fusion(self):
    a, b = (Tensor.empty(2,5) for _ in range(2))
    c = Tensor.empty(5,2)
    out = (a+b).permute(1,0)+c
    check_schedule(out, 1)

  def test_constants_are_embedded(self):
    scaled = Tensor.empty(3,3) * 2
    check_schedule(scaled, 1, filter_sink=False)

  # NOTE: the "tests_" spelling is kept so the test id stays stable; unittest still collects it via the "test" prefix
  def tests_constants_are_folded(self):
    const = Tensor(2)
    check_schedule(const, 0)

  def test_binop_elu_fusion(self):
    x = Tensor.empty(10)
    out = x.elu()
    check_schedule(out, 1)

  def test_binop_reshape_reduce_fusion(self):
    a, b = (Tensor.empty(100) for _ in range(2))
    out = (a+b).reshape(10, 10).sum(axis=0, keepdim=True)
    check_schedule(out, 1)

  def test_reduce_reshape_binop_fusion(self):
    mat = Tensor.empty(10,10)
    vec = Tensor.empty(10)
    out = mat.sum(axis=0) + vec
    check_schedule(out, 1)

  def test_reduce_permute_binop_fusion(self):
    cube = Tensor.empty(10,10,10)
    other = Tensor.empty(10,10,1)
    out = cube.sum(axis=0, keepdim=True).permute(2,1,0) + other
    check_schedule(out, 1)

  def test_binop_early_reshape_reduce_fusion(self):
    a, b = (Tensor.empty(100) for _ in range(2))
    c = Tensor.empty(10,10)
    out = ((a+b).reshape(10,10) + c).sum(axis=0)
    check_schedule(out, 1)

  def test_diamond_folded(self):
    a, b, c, d = (Tensor.empty(10) for _ in range(4))
    shared = a+b
    out = (shared+c) + (shared+d)
    check_schedule(out, 1)

  def test_cache_binaryop(self):
    a, b = (Tensor.empty(10) for _ in range(2))
    first = a+b
    second = a+b
    # `first` is realized up front; the identical `second` must need no kernel
    check_schedule(second, 0, [first])

  # failing in new lazy
  def test_cache_binaryop_reshaped(self):
    a, b = (Tensor.empty(10) for _ in range(2))
    flat = a+b
    column = a.reshape(10,1)+b.reshape(10,1)
    check_schedule(column, 1, [flat])

  # failing in new lazy
  def test_cache_binaryop_transpose(self):
    a, b = (Tensor.empty(10,10) for _ in range(2))
    transposed = (a.T*b.T).T #.contiguous()
    direct = a*b
    check_schedule(direct, 1, [transposed])

  def test_cache_two_reduceops(self):
    a = Tensor.empty(10)
    first = a.sum()
    second = a.sum()
    both = first+second
    check_schedule(both, 1)

  def test_cache_reduce_parent(self):
    x = Tensor.empty(32)
    r0 = x.mean(axis=0, keepdim=True)
    r1 = (x - r0).sum(axis=0).div(2)
    out = r0 + r1
    schedule = check_schedule(out, 2)
    reduceops = [u for si in schedule for u in si.ast.toposort() if u.op in {Ops.REDUCE_AXIS, Ops.REDUCE}]
    assert len(reduceops) == 2

  def test_cache_reduce_multiple_children(self):
    x = Tensor.empty(32)
    y = Tensor.empty(4, 4)
    r0 = x.mean(axis=0, keepdim=True)
    r1 = (x - r0).sum(axis=0).div(2)
    out0 = r0 + y
    out1 = r1 + y
    schedule = check_schedule([out0, out1], 3)
    reduceops = [u for si in schedule for u in si.ast.toposort() if u.op in {Ops.REDUCE_AXIS, Ops.REDUCE}]
    self.assertEqual(len(reduceops), 2) # why is RANGEIFY different?

  def test_dedup_assign(self):
    target = Tensor.ones(4).contiguous().realize()
    value = Tensor.full((4,), 2.).contiguous()
    first = target.assign(value)
    second = target.assign(value)
    check_schedule([first, second], 2) # TODO: 1?

  def test_no_dedup_empty(self):
    a = Tensor.empty((4,))
    b = Tensor.empty((4,))
    # NOTE: empty does not have any schedule
    check_schedule([a, b], 0, filter_sink=False)
    self.assertIsNot(a.uop.buffer, b.uop.buffer)

  def test_dedup_outputs(self):
    a = Tensor.full((4, 4), 1.).contiguous().realize()
    b = Tensor.full((4, 4), 1.).contiguous().realize()
    check_schedule([a+b, a+b], 1)

  def test_const_realize(self):
    ones = Tensor.ones(2)
    check_schedule(ones[0], 0)
    check_schedule(ones[1], 0)

  def test_fold_double_unary(self):
    y = Tensor.empty(2)
    out = y.sum(keepdim=True).sqrt().neg()
    check_schedule(out, 1)

  #@unittest.skip("may want to reconsider this")
  def test_fold_batchnorm(self):
    with Tensor.train():
      img = Tensor.empty(1,32,4,4)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(img)
      check_schedule(out, 3)

  def test_fold_conv_batchnorm_notrain(self):
    with Tensor.train(False):
      img = Tensor.empty(1,3,8,8)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=True)
      out = bn(c1(img)).relu()
      check_schedule(out, 1, [c1.weight, c1.bias])

  def test_fold_conv_batchnorm_notrain_no_running_stats(self):
    with Tensor.train(False):
      img = Tensor.empty(1,3,8,8)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(c1(img)).relu()
      check_schedule(out, 4, [c1.weight, c1.bias])

  def test_fold_conv_batchnorm(self):
    with Tensor.train():
      img = Tensor.empty(1,3,8,8)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(c1(img)).relu()
      check_schedule(out, 4, [c1.weight, c1.bias])

  def test_fold_conv_batchnorm_optim(self, adam=False):
    # 2 is too low?
    if adam:
      optim, cnt = nn.optim.Adam, 16
    else:
      optim, cnt = nn.optim.SGD, 2
    with Tensor.train():
      img = Tensor.ones(1,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      _realize_weights([c1, bn])
      opt = optim(nn.state.get_parameters([c1, bn]))
      img_bn = bn(c1(img)).elu().sum()
      opt.zero_grad()
      img_bn.backward()
      check_schedule(opt.schedule_step(), cnt)

  def test_fold_conv_batchnorm_optim_adam(self):
    # same scenario as above with the Adam branch selected
    self.test_fold_conv_batchnorm_optim(True)

  def test_fold_batchnorm_backward(self):
    with Tensor.train():
      x = Tensor.empty((2, 16, 8, 8)).contiguous()
      bn = nn.BatchNorm2d(16)
      bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
      fw = bn(x).contiguous_backward().relu().contiguous()
      fw.sum().backward()
      # TODO: this is too many
      check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 9)

  def test_fold_conv_relu(self):
    c1 = nn.Conv2d(3,16,3)
    # run
    img = Tensor.ones(2,3,64,64)
    out = c1(img).relu()
    check_schedule(out, 1, [c1.weight, c1.bias])

  def test_fold_conv_relu_alt(self):
    img = Tensor.ones(1,4,8,8)
    c1 = nn.Conv2d(4, 4, kernel_size=3)
    c2 = nn.Conv2d(4, 4, kernel_size=3)
    img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu])
    prerealize = [*nn.state.get_parameters(c1), *nn.state.get_parameters(c2), img]
    check_schedule(img_conv, 2, prerealize)

  def test_fold_conv_relu_nobias(self):
    img = Tensor.ones(1,4,8,8)
    c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
    c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
    out = img.sequential([c1, Tensor.relu, c2, Tensor.relu])
    check_schedule(out, 2, [c1.weight, c2.weight, img])

  def test_fold_conv_elu(self):
    c1 = nn.Conv2d(3,16,3)
    # run
    img = Tensor.rand(2,3,64,64)
    out = c1(img).elu()
    check_schedule(out, 1, [c1.weight, c1.bias, img])

  def test_fold_conv_elu_alt(self):
    img = Tensor.ones(1,4,8,8).contiguous()
    c1 = nn.Conv2d(4, 4, kernel_size=3)
    c2 = nn.Conv2d(4, 4, kernel_size=3)
    img_conv = img.sequential([c1, Tensor.elu, c2, Tensor.elu])
    prerealize = [*nn.state.get_parameters(c1), *nn.state.get_parameters(c2), img]
    check_schedule(img_conv, 2, prerealize)

  def test_two_sum(self):
    img = Tensor.empty(64,64)
    summed = (img.sum(0) + img.sum(1))
    out = summed.relu()
    check_schedule(out, 1)

  def test_push_permute_through_reshape(self):
    a, b = (Tensor.empty(16,16) for _ in range(2))
    out = (a+b).reshape(4,4,4,4).permute(2,3,0,1).contiguous()
    check_schedule(out, 1)

  #@unittest.skip("failing in old lazy")
  def test_push_permute_through_reshape_alt(self):
    a, b = (Tensor.empty(4,4,4,4) for _ in range(2))
    out = (a+b).reshape(16,16).permute(1,0).contiguous()
    check_schedule(out, 1)

  def test_no_binop_rerun(self):
    a, b = (Tensor.empty(16) for _ in range(2))
    flat = a+b
    column = (a+b).reshape(16,1)
    check_schedule(column, 0, [flat])

  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_multi_permute_should_collapse(self):
    a = Tensor.empty(4,4,4,4)
    b = Tensor.empty(16)
    out = a.sum((0,1)).cast(dtypes.float16).permute(1,0).reshape(4,4,1).permute(1,0,2).reshape(16) + b
    check_schedule(out, 1)

  def test_fancy_reshape_fusion(self):
    a, b = (Tensor.empty(10) for _ in range(2))
    flat = a+b
    column = a.reshape(10,1)+b.reshape(10,1)
    out = flat.sum() + column.sum()
    check_schedule(out, 1)

  def test_children_dont_push(self):
    a, b = (Tensor.empty(10, 10, 1) for _ in range(2))
    expanded = (a+b).expand(10, 10, 10)
    permuted = (a+b).permute(2,1,0)
    out = expanded+permuted
    check_schedule(out, 1)

  # failing in new lazy
  @unittest.skip("always fusing elementwise")
  def test_dont_fuse_binops_with_children(self):
    a, b, c = (Tensor.empty(10) for _ in range(3))
    keep_me = a+b
    e = keep_me.sum() # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
    d = keep_me+c
    check_schedule(d, 2)
    check_schedule(keep_me, 0, [d])

  #@unittest.skip("failing in old lazy")
  def test_permute_breaks_fusion(self):
    a = Tensor.empty(10, 10, 10)
    b = Tensor.empty(10, 10)
    once = (a.sum(axis=2) + b).permute(1,0)
    twice = once.permute(1,0)
    check_schedule(twice, 1)

  def test_some_permute_fusion(self):
    a = Tensor.empty(8192, 16)
    b = Tensor.empty(1, 16)
    d = (a.T + b.expand(8192, 16).T)
    c = a + b.expand(8192, 16)
    e = d.T
    check_schedule(c, 1)
    check_schedule(e, 1)

  def test_shrink_fuse(self):
    a = Tensor.empty(8192, 16)
    b = Tensor.empty(8192, 16)
    prod = a * b
    row = Tensor.empty(1, 16)
    out = prod[0] * row
    check_schedule(out, 1)

  def test_expand_fuse(self):
    a = Tensor.empty(1, 16)
    b = Tensor.empty(1, 16)
    prod = a * b
    big = Tensor.empty(8192, 16)
    out = prod * big
    check_schedule(out, 1)

  # this is the failing case in openpilot...it's very simple like this
  def test_image_conv_fusion(self):
    with Context(OPENPILOT_HACKS=1):
      w1 = Tensor.empty(16, 16, 1, 1)
      b1 = Tensor.empty(16)
      w2 = Tensor.empty(16, 16, 1, 1)
      b2 = Tensor.empty(16)
      w3 = Tensor.empty(16, 16, 1, 1)
      b3 = Tensor.empty(16)
      x = Tensor.empty(1, 16, 32, 32)
      x = base = x.image_conv2d(w1, b1)
      x = x.image_conv2d(w2, b2) + base
      x = x.image_conv2d(w3, b3)
      # NOOP, 3 convs, contiguous
      #check_schedule(x, 5)
      check_schedule(x, 7)

  def test_image_conv_fusion_minimal(self):
    b1 = Tensor.empty(16)
    b2 = Tensor.empty(16)
    def p(x):
      return x.permute(1,0).contiguous().reshape(32,16,1).expand(32,16,16).sum(axis=2).permute(1,0)
    x = Tensor.empty(16, 32)
    x = base = p(x) + b1.reshape(16,1)
    x = p(x)
    x = x + b2.reshape(16,1)
    x = x + base
    del base
    x = p(x)
    check_schedule(x, 4)

  def test_image_conv_fusion_more_minimal(self):
    b1 = Tensor.empty(16)
    def p(x):
      return x.permute(1,0).contiguous().reshape(32,16,1).expand(32,16,16).sum(axis=2).permute(1,0)
    x = Tensor.empty(16, 32)
    x = base = p(x) + b1.reshape(16,1)
    x = p(x)
    del base
    check_schedule(x, 3)

  def test_contiguous_while_contiguous(self):
    x = Tensor.empty(1, 64, 32, 32)
    out = x.contiguous()
    check_schedule(out, 0, filter_sink=False)

  def test_contiguous_while_not_contiguous(self):
    x = Tensor.empty(1, 64, 32, 32)
    out = x.permute(0,2,3,1).contiguous()
    check_schedule(out, 1, filter_sink=False)

  def test_fold_with_contiguous(self):
    a = Tensor.randn(16, 16, 16).realize()
    b = Tensor.randn(16, 16).realize()
    out = (a.sum(2).contiguous() + b).contiguous()
    check_schedule(out, 2)

  def _alu_from_tensor(self, t:Tensor):
    """Return the ALU ops, in toposort order, of the single SINK item in `t`'s schedule."""
    sinks = [item for item in t.schedule() if item.ast.op is Ops.SINK]
    self.assertEqual(len(sinks), 1)
    return [u.op for u in sinks[0].ast.toposort() if u.op in GroupOp.ALU]

  def test_2_pow_is_exp2(self):
    t = 2.0 ** Tensor([1.0, 2.0, 3.0])
    self.assertEqual(self._alu_from_tensor(t), [Ops.EXP2])

  def test_pow_05_is_sqrt(self):
    t = Tensor([1.0, 2.0, 3.0]) ** 0.5
    self.assertEqual(self._alu_from_tensor(t), [Ops.SQRT])

  def test_pow_neg_05_is_rsqrt(self):
    t = Tensor([1.0, 2.0, 3.0]) ** -0.5
    self.assertEqual(self._alu_from_tensor(t), [Ops.RECIPROCAL, Ops.SQRT])

  def test_pow_2_has_1_mul(self):
    t = Tensor([1.0, 2.0, 3.0]) ** Tensor(2.0)
    self.assertEqual(self._alu_from_tensor(t), [Ops.MUL])

  def test_pow_8_has_3_muls(self):
    t = Tensor([1.0, 2.0, 3.0]) ** 8
    self.assertEqual(self._alu_from_tensor(t), [Ops.MUL, Ops.MUL, Ops.MUL])

  @unittest.skip("const folding is removed")
  def test_pow_const_tensor_to_zero(self):
    x = Tensor([1,2,3,4])
    out = x ** Tensor(0.0)
    # NOTE: this is UOp.const(0) + UOp.const(1)
    check_schedule(out, 0)

  def test_zero_size(self):
    x = Tensor.empty(2, 3, 0)
    out = x + 1
    check_schedule(out, 0, filter_sink=False)

  def test_reduce_permute_nofuse(self):
    x = Tensor.empty(32, 32, 32)
    y = Tensor.empty(32, 32)
    out = x.sum(axis=2).T+y
    check_schedule(out, 1)

  def test_two_elus_sum(self):
    x, y = (Tensor.empty(32, 32) for _ in range(2))
    out = x.sum(1).relu().elu() + y.sum(1).relu().elu()
    check_schedule(out, 1)

  def test_multistage_reduce(self):
    x = Tensor.empty(32, 32, 32)
    out = x.sum(2).relu().sum(1)
    check_schedule(out, 1)

  def test_multistage_reduce_fork(self):
    x = Tensor.empty(32, 32, 32)
    x = x.sum(2)
    out2 = x + 1
    out = x.relu().sum(1) + out2[0]
    check_schedule(out, 2)

  def test_contiguous_add(self):
    x, y, z = (Tensor.empty(32) for _ in range(3))
    out = (x+y).contiguous()+z
    check_schedule(out, 2)

  def test_double_sum_ref(self):
    x = Tensor.empty(32, 32, 32)
    x = x.sum(2)
    out = x + x[:, 4]
    check_schedule(out, 2)

  def test_reduce_shrink(self):
    x = Tensor.empty(32, 32)
    y = Tensor.empty(16)
    x = x.sum(1)
    x = x[:16]
    out = x + y
    check_schedule(out, 1)

  def test_const_no_recompute(self):
    x = Tensor(2) + Tensor(2)
    y = Tensor(2) + Tensor(2)
    out = x.contiguous() + y.contiguous()
    check_schedule(out, 2, filter_sink=False)

  def test_reduce_shrink_child(self):
    a = Tensor.empty(100, 100)
    b = Tensor.empty(10,)
    c = a.sum() + b[0]
    d = a.sum() + 2
    check_schedule([c, d], 2) # TODO: 1?

  def test_reduce_multiple_paths_midshrink(self):
    a = Tensor.empty(4, 4)
    r = a.sum(axis=1)
    out0 = r.exp2()
    out1 = out0[0] + out0
    check_schedule([r, out0, out1], 3)

  def test_reduce_shrink_output(self):
    a = Tensor.empty(4, 4)
    r = a.sum(keepdim=True)
    out0 = r.exp2()
    out1 = out0[0] + Tensor.empty(1, )
    check_schedule([r, out0, out1], 3)

  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_softmax_upcast(self):
    # input half, softmax in float
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.half).realize()
    out = x.softmax(dtype=dtypes.float)
    sched = out.schedule()
    self.assertEqual(len(sched), 3)
    self.assertEqual(sched[0].bufs[0].dtype, dtypes.float)

    # input float, softmax in float
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.float).realize()
    out = x.softmax(dtype=dtypes.float)
    sched = out.schedule()
    self.assertEqual(len(sched), 3)
    self.assertEqual(sched[0].bufs[0].dtype, dtypes.float)

  def test_softmax_backward(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 12, 64, 64, requires_grad=True).realize()
    x.softmax().sum().backward()
    run_schedule(check_schedule(x.grad, 4))

  def test_scaled_dot_product_attention_fusion(self):
    x, y, z, m = (Tensor.empty(32, 8, 16, 16) for _ in range(4))
    out = Tensor.scaled_dot_product_attention(x, y, z, attn_mask=m)
    check_schedule(out, 4)

  def test_scaled_dot_product_attention_causal_fusion(self):
    x, y, z = (Tensor.empty(32, 8, 16, 16) for _ in range(3))
    out = Tensor.scaled_dot_product_attention(x, y, z, is_causal=True)
    check_schedule(out, 4)

  def test_adam_step_fusion(self):
    with Tensor.train():
      x = Tensor.empty(4, 64, 32)
      layer = nn.Linear(32, 32*4)
      _realize_weights(layer)
      opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
      layer(x).relu().sum().backward()
      check_schedule(opt.schedule_step(), 13)

  def test_adam_conv_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      _realize_weights(c1)
      opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
      opt.zero_grad()
      c1(img).relu().sum().backward()
      check_schedule(opt.schedule_step(), 13)

  def test_adam_2convs_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,2,bias=False)
      _realize_weights([c1, c2])
      opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 15)

  def test_sgd_conv_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      _realize_weights(c1)
      opt = nn.optim.SGD(nn.state.get_parameters(c1))
      opt.zero_grad()
      c1(img).relu().sum().backward()
      check_schedule(opt.schedule_step(), 5) # TODO: 3?

  def test_sgd_2convs_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,2,bias=False)
      _realize_weights([c1, c2])
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]))
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 7)

  def test_fold_2convs_sgd_nesterov_momentum_wd(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,2,bias=False)
      _realize_weights([c1, c2])
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 11)

  def test_sgd_4convs_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,16,16)
      c1 = nn.Conv2d(3,4,3,bias=False)
      c2 = nn.Conv2d(4,8,3,bias=False)
      c3 = nn.Conv2d(8,16,3,bias=False)
      c4 = nn.Conv2d(16,32,3,bias=False)
      convs = [c1, c2, c3, c4]
      _realize_weights(convs)
      opt = nn.optim.SGD(nn.state.get_parameters(convs))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 15)

  def test_sgd_4convs_fuse_conv_bw(self):
    with Tensor.train():
      img = Tensor.empty(2,3,16,16)
      c1 = nn.Conv2d(3,4,3,bias=False)
      c2 = nn.Conv2d(4,8,3,bias=False)
      c3 = nn.Conv2d(8,16,3,bias=False)
      c4 = nn.Conv2d(16,32,3,bias=False)
      convs = [c1, c2, c3, c4]
      _realize_weights(convs)
      opt = nn.optim.SGD(nn.state.get_parameters(convs))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 15)

  def test_reduce_simple_chase(self):
    a = Tensor.empty(4, 4, 4)
    r = a.sum(0) + 6
    b = r.sum(0) * 4
    c = r.sum(1) * 2
    check_schedule([b, c], 3)

  def test_push_permute_chase(self):
    a = Tensor.empty(4, 4, 4)
    b = Tensor.empty(4, 4)
    r = a.sum(2) + b
    d = r.T * 4
    e = r * d
    check_schedule([d, e], 3)

  def test_push_shrink_chase(self):
    a = Tensor.empty(16, 16)
    b = Tensor.empty(4)
    c = Tensor.empty(16, )
    r = a.sum(1) + c
    d = r[:4] * b
    check_schedule(d, 1)

  def test_midreduce_nochase(self):
    a = Tensor.empty(16, 16)
    out = (a.sum(0) + a.max(1)) + 2
    check_schedule(out, 1)

  def test_bitcast_fuses(self):
    x = Tensor.empty(1, dtype=dtypes.float32)
    a = x.exp2().bitcast(dtypes.int32)
    b = x.bitcast(dtypes.int32)
    check_schedule(a+b, 1) # this should fuse when it makes sense

  def test_reduceop_reshape_dont_push(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(10, 20).realize()
    out = x.argmax(1)
    run_schedule(check_schedule(out, 2))

  def test_resnet_conv2d(self):
    x = Tensor.empty(1, 8, 32, 32)
    w1 = Tensor.empty(8, 8, 3, 3)
    w2 = Tensor.empty(8, 8, 1, 1)
    out = x.conv2d(w1).conv2d(w2)
    check_schedule(out, 2)

  def test_schedule_mem_used(self):
    gc.collect()
    baseline = GlobalCounters.mem_used
    Tensor.ones(256).contiguous().realize()
    Tensor.ones(5, 5).contiguous().schedule()
    gc.collect()
    # nothing may stay counted once the temporaries are collected
    self.assertEqual(GlobalCounters.mem_used-baseline, 0)

  def test_const_schedule(self):
    constv = Tensor.empty(2, 2).uop.const_like(10)
    check_schedule(constv, 0)

  def test_const_schedule_contig(self):
    constv = Tensor.empty(2, 2).uop.const_like(10).contiguous()
    check_schedule(constv, 1)

  def test_advanced_simple_indexing_combined(self):
    X = Tensor.arange(16).reshape(4, 4)
    xt = X[1:2, [-1, 2]]
    check_schedule(xt, 1)

  def test_arange_index_shrink(self):
    Tensor.manual_seed(0)
    # NOTE(review): only the realize is placed under this Context; confirm the intended `with` scope
    with Context(TRACK_MATCH_STATS=0):
      x = Tensor.randn(11).realize()
    a = Tensor.arange(22)
    out = (x + a[:11]).sum()
    check_schedule(out, 1)

  def test_fuse_arange_avg_pool2d_ceil_mode(self):
    pooled = Tensor.avg_pool2d(Tensor.empty(1,1,6,6), kernel_size=(3,3), padding=1, stride=3, ceil_mode=True)
    sched = check_schedule(pooled, 1)
    reduces = [u for u in sched[0].ast.backward_slice_with_self if u.op is Ops.REDUCE]
    self.assertEqual(len(reduces), 1)

  def test_fuse_arange_pad_circular_mode_bw(self):
    x = Tensor.empty(1,1,5,5,5)
    out = x.pad((1,2,3,5,1,2), mode="circular")
    g = out.sum().gradient(x)[0]
    sched = check_schedule(g, 1)
    reduces = [u for u in sched[0].ast.backward_slice_with_self if u.op is Ops.REDUCE]
    self.assertEqual(len(reduces), 0)

  def test_resnet_block(self):
    with Tensor.train(False):
      in_planes, planes = 64, 64
      conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
      bn1 = nn.BatchNorm2d(planes)
      conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
      bn2 = nn.BatchNorm2d(planes)
      x = Tensor.empty(1, 64, 32, 32)
      out = bn1(conv1(x)).relu()
      out = bn2(conv2(out))
      out = (out + x).relu()
      run_schedule(check_schedule(out, 2, [conv1.weight, conv2.weight]))
class TestSwizzle(unittest.TestCase):
  """Kernel counts for softmax and argmax on a realized random input."""

  def test_softmax_one_kernel(self):
    Tensor.manual_seed(0)
    # NOTE(review): only the realize is placed under this Context; confirm the intended `with` scope
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      inp = Tensor.randn(32, 32).realize()
    out = inp.softmax()
    check_schedule(out, 3) # TODO: 1?

  def test_argmax_one_kernel(self):
    Tensor.manual_seed(0)
    # NOTE(review): only the realize is placed under this Context; confirm the intended `with` scope
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      inp = Tensor.randn(10, 20).realize()
    out = inp.argmax(0)
    check_schedule(out, 2) # TODO: 1?
class TestView(unittest.TestCase):
  """View-only scheduling checks."""

  def test_zero_size_alt(self):
    # padding a tensor that has a zero-length axis is expected to need no kernels
    empty = Tensor.empty(135, 0, 9)
    padded = empty.pad(((0, 0), (0, 0), (18, 0)))
    check_schedule(padded, 0)
class TestUOpBecome(unittest.TestCase):
  """Checks on what a Tensor's UOp looks like after it has been scheduled or realized."""

  # the simplest case, if we create a new BUFFER for this tensor UOp
  def test_new_buffer(self):
    lhs = Tensor.empty(4, 4)
    rhs = Tensor.empty(4, 4)
    add = lhs+rhs
    check_schedule(add, 1)
    # NOTE: realized base is always a flat buffer
    assert UPat(Ops.BUFFER).match(add.uop.base, {})
    # the Tensor UOp can optionally stack a VIEW on top of the BUFFER, in this case to preserve the (4, 4) shape of the tensor
    assert add.uop is not add.uop.base
    self.assertEqual(add.uop.size, 16)
    self.assertEqual(add.uop.shape, (4, 4))

  def test_new_buffer_view(self):
    lhs = Tensor.empty(4, 4)
    rhs = Tensor.empty(4, 4)
    add = (lhs+rhs).reshape(8, 2)
    check_schedule(add, 1)
    assert UPat(Ops.BUFFER).match(add.uop.base, {})
    # the shape is preserved in the becomes_map.
    self.assertEqual(add.uop.shape, (8, 2))
    assert add.uop is not add.uop.base

  def test_new_flat_buffer(self):
    lhs = Tensor.empty(4,)
    rhs = Tensor.empty(4,)
    add = lhs+rhs
    check_schedule(add, 1)
    # BUFFER already has a shape (4,), this tensor just becomes a contiguous BUFFER
    assert UPat(Ops.BUFFER).match(add.uop.base, {})

  # sometimes we prefer to perform an op before movement ops, in this case we should stack the mops on top of the new buffer
  @unittest.skip("no longer supported")
  def test_reorder_expand(self):
    column = Tensor.empty(4, 1)
    recip = column.expand(4, 4).reciprocal()
    check_schedule(recip, 1)
    self.assertEqual(recip.uop.base.buffer.size, 4)
    self.assertEqual(recip.uop.shape, (4, 4))

  def test_reorder_expand_alt(self):
    x = Tensor.empty(4, 1)
    y = Tensor.empty(4, 1)
    img = Tensor.empty(4, 4)
    z = (img*x) / y
    check_schedule(z, 1)

  # TODO: rangeify doesn't yet cleanup this kind of re-indexing
  @unittest.expectedFailure
  def test_become_existing_buffer(self):
    src = Tensor.empty(4, 4)
    same = src*1
    assert UPat(Ops.MUL).match(same.uop, {}) # before scheduling it's a mul
    check_schedule(same, 0)
    self.assertIs(src.uop.base.buffer, same.uop.base.buffer)

  def test_become_buf_with_mops(self):
    src = Tensor.empty(2, 4, 2)
    noop = src.shrink(((1, 2), (0, 4), (0, 2))).reshape(4, 2)*1+0
    # before realizing, this tensor is base
    assert noop.uop is noop.uop.base
    noop.realize()
    # it becomes a realized view after realize
    assert noop.uop is not noop.uop.base
    assert noop.uop.base.op is Ops.BUFFER
    late_add = noop+2
    late_add.realize()

  @unittest.skip("const folding is removed")
  def test_become_const_in_base(self):
    src = Tensor.empty(4)
    zeroed = src*0
    assert UPat(Ops.MUL).match(zeroed.uop, {}) # before scheduling it's a mul
    check_schedule(zeroed, 0)
    assert UPat(Ops.CONST, arg=0).match(zeroed.uop.base, {}) # scheduling replaces the tensor uop with a VIEW(BUFFER)

  @unittest.skip("const folding is removed")
  def test_become_const_from_const(self):
    const_add = Tensor(1)+Tensor(2)
    assert UPat(Ops.ADD).match(const_add.uop, {})
    check_schedule(const_add, 0)
    assert UPat(Ops.CONST, arg=3).match(const_add.uop.base, {})

  # tensors can become another realized tensor source
  @unittest.expectedFailure
  def test_become_existing_buf_simple(self):
    src = Tensor.empty(4, 4)
    same = src+0
    check_schedule(same, 0)
    assert same.uop.base.op is Ops.BUFFER
    self.assertIs(src.uop, same.uop)

  # they can also chain other movement ops on top of the tensor source
  @unittest.expectedFailure
  def test_become_existing_buf_view(self):
    src = Tensor.empty(4, 4)
    same = src.permute((1, 0))+0
    check_schedule(same, 0)
    self.assertEqual(same.uop.st, src.uop.permute((1, 0)).st)

  @unittest.expectedFailure
  def test_become_existing_buf_view_alt(self):
    src = Tensor.empty(4, 4)
    same = src.permute((1, 0)).reshape((8, 2))+0
    check_schedule(same, 0)
    self.assertEqual(same.uop.st, src.uop.permute((1, 0)).reshape((8, 2)).st)

  # they can also have other base parents that simplified, in that case we just backtrack to the chained mops
  @unittest.expectedFailure
  def test_become_existing_buf_complex(self):
    src = Tensor.empty(4, 4)
    same = (src.permute((1, 0))+0).reshape((8, 2))+0
    check_schedule(same, 0)
    self.assertEqual(same.uop.st, src.uop.permute((1, 0)).reshape((8, 2)).st)
    assert same.uop.base.op is Ops.BUFFER

  @unittest.expectedFailure
  def test_become_multiple_choices(self):
    a = Tensor.empty(16)
    b = (a.reshape(1, 1, 4, 1, 4)+0).reshape(1, 1, 4, 4).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
    c = (a.reshape(1, 1, 4, 4)+0).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
    check_schedule([b, c], 0)
    from tinygrad.helpers import all_same
    assert all_same([t.uop.base.realized for t in [a,b,c]])

  @unittest.skip("not clear if we want this")
  def test_setitem_becomes_subbuffer(self):
    full = Tensor.full((4,), 2.).contiguous().realize()
    head = full.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
    head.realize()
    assert full.uop.is_realized
    assert full.uop.buffer._base is None
    assert head.uop.op_in_backward_slice_with_self(Ops.SHRINK)
    assert head.uop.base is full.uop.base
class TestFusionOp(unittest.TestCase):
  """Deeply repeated ops must schedule within a 2 second wall-clock budget."""

  def test_recursive_add(self):
    start = time.perf_counter()
    acc = Tensor([1,2,3,4])
    for _ in range(24):
      acc = acc + acc
    sched = acc.schedule()
    sched[-1].lower()
    self.assertLess(time.perf_counter()-start, 2.0)
    # the generated source of the last item must stay short
    assert len(sched[-1].prg.p.src.splitlines()) < 250

  def test_recursive_add_cmp(self):
    start = time.perf_counter()
    a = Tensor([1,2,3,4])
    for _ in range(24):
      a = a + a
    sched1 = a.schedule()
    b = Tensor([1,2,3,4])
    for _ in range(24):
      b = b + b
    sched2 = b.schedule()
    c = Tensor([1,2,3,4])
    for _ in range(23):
      c = c + c
    sched3 = c.schedule()
    # same depth -> same final AST; one fewer doubling -> different AST
    self.assertEqual(sched1[-1].ast, sched2[-1].ast)
    with self.assertRaises(AssertionError):
      self.assertEqual(sched1[-1].ast, sched3[-1].ast)
    self.assertLess(time.perf_counter()-start, 2.0)

  def test_recursive_pad(self):
    start = time.perf_counter()
    val = 1.0
    acc = Tensor(val)
    for _ in range(24):
      acc = Tensor.stack(acc, acc)[0]
    sched = acc.schedule()
    self.assertLessEqual(len(sched), 1)
    self.assertLess(time.perf_counter()-start, 2.0)

  def test_recursive_reshape(self):
    start = time.perf_counter()
    a = Tensor.empty(32, 32).realize()
    b = Tensor.empty(16, 2).realize()
    r = a.sum(1)
    for _ in range(24):
      r = r.reshape(16, 2) + b
    sched = r.schedule()
    self.assertEqual(len(sched), 1)
    self.assertLess(time.perf_counter()-start, 2.0)
# NOTE: the NULL backend supports BUFFER_VIEW
class TestBufferView(unittest.TestCase):
  """Contiguous shrinks of a realized buffer are expected to schedule with zero kernels."""

  def test_shrink_contiguous_is_buffer_view(self):
    # simple 1D shrink of a realized buffer should be BUFFER_VIEW, not a copy kernel
    base = Tensor.arange(100).contiguous().realize()
    window = base.shrink(((10, 50),)).contiguous()
    run_schedule(check_schedule(window, 0))

  def test_shrink_2d_contiguous_is_buffer_view(self):
    base = Tensor.arange(100).reshape(10,10).contiguous().realize()
    window = base.shrink(((1, 5),None)).contiguous()
    run_schedule(check_schedule(window, 0))

  def test_chained_shrink_is_buffer_view(self):
    base = Tensor.arange(1000).contiguous().realize()
    window = base.shrink(((200, 800),)).shrink(((0, 300),)).reshape((30, 10)).shrink(((20, 25), (0, 10))).contiguous()
    run_schedule(check_schedule(window, 0))
# run every test in this module verbosely when executed as a script
if __name__ == "__main__":
  unittest.main(verbosity=2)