mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
1194 lines
38 KiB
Python
1194 lines
38 KiB
Python
# schedule tests that pass on NULL backend (no copyout needed)
|
|
import gc, unittest, time
|
|
from tinygrad import nn, dtypes, Device, Tensor
|
|
from tinygrad.device import is_dtype_supported
|
|
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
|
|
from tinygrad.helpers import DEBUG, GlobalCounters, Context
|
|
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
|
|
|
class KernelCountException(Exception):
  """Raised by check_schedule when the counted kernels differ from the allowed number."""
|
|
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
  """Schedule `t`, lower every schedule item, and verify the kernel count.

  Args:
    t: a Tensor, a non-empty list of Tensors (scheduled together), or a UOp (wrapped in a Tensor first).
    allowed: the exact kernel count expected.
    to_prerealize: tensors realized (with DEBUG=0 and TRACK_MATCH_STATS=0) before `t` is scheduled.
    filter_sink: when True only items whose `prg` is a CompiledRunner are counted;
      when False every schedule item is counted.

  Returns:
    The schedule, with every item already lowered.

  Raises:
    KernelCountException: if the counted kernels differ from `allowed`.
    AssertionError: if `t` is not a Tensor, a non-empty list of Tensors, or a UOp.
  """
  if to_prerealize:
    with Context(DEBUG=0, TRACK_MATCH_STATS=0): Tensor.realize(*to_prerealize)
  if isinstance(t, Tensor): sched = t.schedule()
  # guard on non-empty so an empty list reaches the descriptive assertion below instead of raising IndexError on t[0]
  elif isinstance(t, list) and t and isinstance(t[0], Tensor): sched = Tensor.schedule(*t)
  else:
    assert isinstance(t, UOp), f"can't schedule {t}"
    sched = Tensor(t).schedule()
  # test lowering all the ExecItems
  for si in sched: si.lower()
  # with filter_sink=False every item counts, otherwise only the compiled ones
  kernel_cnt = sum(1 for si in sched if isinstance(si.prg, CompiledRunner) or not filter_sink)
  if kernel_cnt != allowed:
    print(f"SCHEDULE ISSUE, expecting {allowed} got {kernel_cnt}")
    if DEBUG >= 3:
      # dump every scheduled AST to help diagnose the mismatch
      for i,s in enumerate(sched):
        print("kernel", i+1)
        print(s.ast)
    raise KernelCountException(f"{kernel_cnt} != {allowed}")
  return sched
|
|
|
|
def _realize_weights(m):
  """Realize, one at a time, every parameter returned by nn.state.get_parameters(m)."""
  params = nn.state.get_parameters(m)
  for param in params:
    param.realize()
|
|
|
|
class TestBufferUOp(unittest.TestCase):
  # Checks on when a Tensor's UOp exposes a device Buffer and when that Buffer gets allocated/realized.

  # BUFFER has a ShapeTracker of shape=(n,) and stride=(1,)
  def test_buffer_has_buffer(self):
    buf = Tensor.empty(10)
    self.assertIsNotNone(buf.uop.buffer)
    self.assertEqual(buf.uop.shape, (10,))
    # the device Buffer remains unallocated until we run the schedule
    self.assertFalse(buf.uop.buffer.is_allocated())
    add = buf+1
    sched = add.schedule()
    # creating the schedule alone must not allocate
    self.assertFalse(buf.uop.buffer.is_allocated())
    run_schedule(sched)
    self.assertTrue(buf.uop.buffer.is_allocated())

  # two accesses of .buffer on the same uop must return the same object
  def test_buffer_has_unique_buffer(self):
    buf = Tensor.empty(10)
    buf1 = buf.uop.buffer
    buf2 = buf.uop.buffer
    self.assertIs(buf1, buf2)

  # we also allow VIEW(BUFFER) to access the underlying device Buffer, as long as it's contiguous
  def test_buffer_view_allowed(self):
    add = Tensor.empty(1, 1)+Tensor.empty(1, 1)
    add.realize()
    self.assertIsNotNone(add.uop.buffer)
    self.assertEqual(add.uop.shape, (1, 1))

  # accessing .buffer through a permuted view is expected to raise RuntimeError
  def test_buffer_view_not_allowed(self):
    permuted_view = Tensor.empty(1, 2, 3).permute(0, 2, 1)
    with self.assertRaises(RuntimeError):
      permuted_view.uop.buffer # cannot access Buffer of a non contiguous VIEW

  def test_buffer_only_after_realize(self):
    a = Tensor([1])+Tensor([2])
    # accessing realized will return None
    self.assertIsNone(a.uop.realized)
    # accessing Buffer will assert
    with self.assertRaisesRegex(AssertionError, "must be BUFFER"):
      a.uop.buffer # there is no BUFFER on an unrealized ADD
    # Buffer only exists once we realize it
    a.realize()
    self.assertIsNotNone(a.uop.buffer)

  # a scalar const schedules 0 kernels and its base stays unrealized after running the schedule
  def test_const_does_not_realize(self):
    a = Tensor(1)
    run_schedule(check_schedule(a, 0))
    self.assertIsNone(a.uop.base.realized)

  # same as above for a bound variable
  def test_var_does_not_realize(self):
    a = Tensor(UOp.variable("a", 0, 10).bind(1))
    run_schedule(check_schedule(a, 0))
    self.assertIsNone(a.uop.base.realized)

  def test_unused_var_not_in_var_vals(self):
    # unused variable should not appear in var_vals even when there's other work
    a = Tensor(UOp.variable("unused", 0, 10).bind(1))
    b = Tensor.empty(3) + 1
    _, var_vals = Tensor.schedule_with_vars(a, b)
    self.assertEqual(var_vals, {})
    self.assertIsNone(a.uop.base.realized)

  # realizing an expanded view keeps the 4-element base; only contiguous() produces the 16-element buffer
  def test_view_does_not_realize(self):
    a = Tensor.randn(1, 4).expand(4, 4)
    a.realize()
    self.assertEqual(a.uop.base.realized.size, 4)
    a2 = a.contiguous().realize()
    self.assertEqual(a2.uop.base.realized.size, 16)
|
|
|
|
class TestContiguous(unittest.TestCase):
  # Kernel-count checks for .contiguous() applied to empty buffers and to views of them.

  # contiguous() directly on an empty buffer: 0 kernels
  def test_contiguous_buffer(self):
    a = Tensor.empty(4)
    b = a.contiguous()
    check_schedule(b, 0)

  # contiguous() on a reshape of an empty buffer: 0 kernels
  def test_contiguous_buffer_view(self):
    a = Tensor.empty(4)
    b = a.reshape((2, 2)).contiguous()
    check_schedule(b, 0)

  # contiguous() on an expand: 1 kernel
  def test_non_contiguous_buffer_view(self):
    a = Tensor.empty(4, 1)
    b = a.expand((4, 4)).contiguous()
    check_schedule(b, 1)

  def test_size_change_buffer_view(self):
    a = Tensor.empty(4)
    b = a.reshape((1, 1, 4)).shrink(((0, 1), (0, 1), (0, 3))).contiguous()
    check_schedule(b, 0) # contiguous shrink of a realized buffer is a zero-copy BUFFER_VIEW

  # two stacked contiguous() calls still schedule a single kernel
  def test_double_contiguous_realizes_once(self):
    a = Tensor.empty(4, 1)
    b = a.expand((4, 4)).contiguous().contiguous()
    check_schedule(b, 1)

  # an expand without contiguous(): 0 kernels and the base buffer keeps size 4
  def test_view_does_not_realize(self):
    a = Tensor.empty(4)
    b = a.expand((4, 4))
    check_schedule(b, 0)
    self.assertEqual(b.uop.base.buffer.size, 4)

  # the same expand with contiguous(): 1 kernel and a 16-element base buffer
  def test_contiguous_view_realizes(self):
    a = Tensor.empty(4)
    b = a.expand((4, 4)).contiguous()
    check_schedule(b, 1)
    self.assertEqual(b.uop.base.buffer.size, 16)
|
|
|
|
class TestSimpleSchedule(unittest.TestCase):
  # two different reshapes of the same reduce, scheduled together, must give a schedule of length 1
  def test_reduce_doesnt_split(self):
    a = Tensor.empty(16,16).sum(axis=1)
    a1 = a.reshape(4,4)
    a2 = a.reshape(16,1,1)
    self.assertEqual(len(Tensor.schedule(a1, a2)), 1)
|
|
|
|
class TestSchedule(unittest.TestCase):
  # Scheduler tests: most build a small tensor graph and use check_schedule to pin the exact kernel count.

  # adding tensors that live on different devices must raise at schedule time
  @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
  def test_error_on_device_mismatch(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10, device="CPU")
    c = a+b
    with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 1)

  # same check when the CPU operand goes through expand+contiguous first
  @unittest.skipIf(Device.DEFAULT == "CPU", "devices must mismatch")
  def test_error_on_device_mismatch_alt(self):
    a = Tensor.empty(10)
    b = Tensor.empty((1,), device="CPU").expand(10).contiguous()
    c = a+b
    with self.assertRaisesRegex(RuntimeError, "all buffers must be on the same device"): check_schedule(c, 2)

  # Tensor.rand with the device rng counter pre-realized: 1 kernel
  def test_rand(self):
    x = Tensor.rand(32)
    check_schedule(x, 1, [Tensor._device_rng_counters[x.device]])

  # NOTE(review): body is identical to test_rand
  def test_rand_recompute_arange(self):
    x = Tensor.rand(32)
    check_schedule(x, 1, [Tensor._device_rng_counters[x.device]])

  # an empty tensor's uop only reports is_realized after a child has been realized
  def test_empty_is_not_realized(self):
    a = Tensor.empty(10)
    child = a+2
    assert not a.uop.is_realized
    child.realize()
    assert a.uop.is_realized

  def test_realize_view_of_realized_has_empty_schedule(self):
    # views of realized buffers produce an empty schedule
    t = Tensor.zeros((3, 3)).contiguous().realize()
    v = t[1] # view - is_realized but not has_buffer_identity
    assert v.uop.is_realized
    sched, _ = Tensor.schedule_with_vars(v)
    self.assertEqual(len(sched), 0)

  # NOTE: because empty does not have a lowered ExecItem if realize is called on a childless empty, it never gets allocated.
  def test_childless_empty_never_allocates(self):
    a = Tensor.empty(10)
    a.realize()
    assert not a.uop.is_realized

  # cummax over 1022 elements: 3 kernels
  def test_simplify_padded_const(self):
    a, _ = Tensor.empty(1022).cummax(axis=0)
    check_schedule(a, 3)

  @unittest.skip("should this pass?")
  def test_contiguous_assign(self):
    a = Tensor.ones(10) * 2
    b = Tensor.empty(10)
    c = b.assign(a.contiguous())
    check_schedule(c, 1)

  # chained elementwise adds: 1 kernel
  def test_basic_binop_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(10)
    d = a+b+c
    check_schedule(d, 1)

  # same adds assigned into another empty tensor: still 1 kernel
  def test_basic_binop_fusion_assign(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(10)
    d = a+b+c
    e = Tensor.empty(10).assign(d)
    check_schedule(e, 1)

  # four-operand add chain: 1 kernel
  def test_basic_binop_fusion_deep(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(10)
    d = Tensor.empty(10)
    e = a+b+c+d
    check_schedule(e, 1)

  # multiply followed by sum: 1 kernel
  def test_mulacc_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = (a*b).sum()
    check_schedule(c, 1)

  # multiply+sum assigned into an empty tensor: 1 kernel
  def test_mulacc_fusion_assign(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = (a*b).sum()
    d = Tensor.empty(1).assign(c)
    check_schedule(d, 1)

  # nested assigns followed by detach().contiguous(): 2 kernels
  def test_detach_assign(self):
    a = Tensor.ones(4, 4).contiguous().realize()
    buf1, buf2 = Tensor.empty(4, 4).contiguous(), Tensor.empty(4, 4).contiguous()
    r = buf2.assign(buf1.assign(a + 1.0) * 2.0)
    check_schedule(r.detach().contiguous(), 2)

  # nested assigns followed by contiguous_backward().contiguous(): 2 kernels
  def test_contiguous_backward_assign(self):
    a = Tensor.ones(4, 4).contiguous().realize()
    buf1, buf2 = Tensor.empty(4, 4).contiguous(), Tensor.empty(4, 4).contiguous()
    r = buf2.assign(buf1.assign(a + 1.0) * 2.0)
    check_schedule(r.contiguous_backward().contiguous(), 2)

  # multiply+sum+relu: 1 kernel
  def test_mulacc_relu_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = (a*b).sum().relu()
    check_schedule(c, 1)

  # add, reshape, add: 1 kernel
  def test_binop_reshape_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(5,2)
    d = (a+b).reshape(5,2)+c
    check_schedule(d, 1)

  # add, permute, add: 1 kernel
  def test_binop_permute_fusion(self):
    a = Tensor.empty(2,5)
    b = Tensor.empty(2,5)
    c = Tensor.empty(5,2)
    d = (a+b).permute(1,0)+c
    check_schedule(d, 1)

  # multiplying by a python constant: exactly 1 schedule item (filter_sink=False counts every item)
  def test_constants_are_embedded(self):
    a = Tensor.empty(3,3) * 2
    check_schedule(a, 1, filter_sink=False)

  # a bare scalar const: 0 kernels
  def tests_constants_are_folded(self):
    a = Tensor(2)
    check_schedule(a, 0)

  # elu on an empty tensor: 1 kernel
  def test_binop_elu_fusion(self):
    a = Tensor.empty(10)
    b = a.elu()
    check_schedule(b, 1)

  # add, reshape, reduce: 1 kernel
  def test_binop_reshape_reduce_fusion(self):
    a = Tensor.empty(100)
    b = Tensor.empty(100)
    c = (a+b).reshape(10, 10).sum(axis=0, keepdim=True)
    check_schedule(c, 1)

  # reduce followed by add: 1 kernel
  def test_reduce_reshape_binop_fusion(self):
    a = Tensor.empty(10,10)
    b = Tensor.empty(10)
    c = a.sum(axis=0) + b
    check_schedule(c, 1)

  # reduce, permute, add: 1 kernel
  def test_reduce_permute_binop_fusion(self):
    a = Tensor.empty(10,10,10)
    b = Tensor.empty(10,10,1)
    c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
    check_schedule(c, 1)

  # add, reshape, add, reduce: 1 kernel
  def test_binop_early_reshape_reduce_fusion(self):
    a = Tensor.empty(100)
    b = Tensor.empty(100)
    c = Tensor.empty(10,10)
    d = ((a+b).reshape(10,10) + c).sum(axis=0)
    check_schedule(d, 1)

  # a shared subexpression (ab) used on both sides of an add: 1 kernel
  def test_diamond_folded(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(10)
    d = Tensor.empty(10)
    ab = a+b
    e = (ab+c) + (ab+d)
    check_schedule(e, 1)

  # an identical a+b, after the first one is pre-realized: 0 kernels
  def test_cache_binaryop(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = a+b
    d = a+b
    check_schedule(d, 0, [c])

  # failing in new lazy
  def test_cache_binaryop_reshaped(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = a+b
    d = a.reshape(10,1)+b.reshape(10,1)
    check_schedule(d, 1, [c])

  # failing in new lazy
  def test_cache_binaryop_transpose(self):
    a = Tensor.empty(10,10)
    b = Tensor.empty(10,10)
    c = (a.T*b.T).T #.contiguous()
    d = a*b
    check_schedule(d, 1, [c])

  # the same sum written twice and added together: 1 kernel
  def test_cache_two_reduceops(self):
    a = Tensor.empty(10)
    b = a.sum()
    c = a.sum()
    bc = b+c
    check_schedule(bc, 1)

  # mean feeding a second reduce: 2 kernels containing 2 reduce ops in total
  def test_cache_reduce_parent(self):
    x = Tensor.empty(32)
    r0 = x.mean(axis=0, keepdim=True)
    r1 = (x - r0).sum(axis=0).div(2)
    out = r0 + r1
    schedule = check_schedule(out, 2)
    reduceops = [x for si in schedule for x in si.ast.toposort() if x.op in {Ops.REDUCE_AXIS, Ops.REDUCE}]
    assert len(reduceops) == 2

  # both reduces consumed by separate outputs: 3 kernels, 2 reduce ops in total
  def test_cache_reduce_multiple_children(self):
    x = Tensor.empty(32)
    y = Tensor.empty(4, 4)
    r0 = x.mean(axis=0, keepdim=True)
    r1 = (x - r0).sum(axis=0).div(2)
    out0 = r0 + y
    out1 = r1 + y
    schedule = check_schedule([out0, out1], 3)
    reduceops = [x for si in schedule for x in si.ast.toposort() if x.op in {Ops.REDUCE_AXIS, Ops.REDUCE}]
    self.assertEqual(len(reduceops), 2) # why is RANGEIFY different?

  # the same assign issued twice and scheduled together
  def test_dedup_assign(self):
    a = Tensor.ones(4).contiguous().realize()
    b = Tensor.full((4,), 2.).contiguous()
    first = a.assign(b)
    second = a.assign(b)
    check_schedule([first, second], 2) # TODO: 1?

  # two separate empties schedule nothing and must not share a buffer
  def test_no_dedup_empty(self):
    a = Tensor.empty((4,))
    b = Tensor.empty((4,))
    # NOTE: empty does not have any schedule
    check_schedule([a, b], 0, filter_sink=False)
    self.assertIsNot(a.uop.buffer, b.uop.buffer)

  # the same a+b requested twice as outputs: 1 kernel
  def test_dedup_outputs(self):
    a = Tensor.full((4, 4), 1.).contiguous().realize()
    b = Tensor.full((4, 4), 1.).contiguous().realize()
    check_schedule([a+b, a+b], 1)

  # indexing into Tensor.ones: 0 kernels
  def test_const_realize(self):
    t = Tensor.ones(2)
    check_schedule(t[0], 0)
    check_schedule(t[1], 0)

  # sum followed by two unary ops: 1 kernel
  def test_fold_double_unary(self):
    y = Tensor.empty(2)
    out = y.sum(keepdim=True).sqrt().neg()
    check_schedule(out, 1)

  #@unittest.skip("may want to reconsider this")
  def test_fold_batchnorm(self):
    with Tensor.train():
      img = Tensor.empty(1,32,4,4)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(img)
      check_schedule(out, 3)

  # conv+batchnorm+relu outside training with running stats: 1 kernel
  def test_fold_conv_batchnorm_notrain(self):
    with Tensor.train(False):
      img = Tensor.empty(1,3,8,8)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=True)
      out = bn(c1(img)).relu()
      check_schedule(out, 1, [c1.weight, c1.bias])

  # same without running stats: 4 kernels
  def test_fold_conv_batchnorm_notrain_no_running_stats(self):
    with Tensor.train(False):
      img = Tensor.empty(1,3,8,8)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(c1(img)).relu()
      check_schedule(out, 4, [c1.weight, c1.bias])

  # conv+batchnorm+relu in training mode: 4 kernels
  def test_fold_conv_batchnorm(self):
    with Tensor.train():
      img = Tensor.empty(1,3,8,8)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(c1(img)).relu()
      check_schedule(out, 4, [c1.weight, c1.bias])

  # optimizer step after conv+batchnorm backward; `adam` selects Adam (16 kernels) instead of SGD (2 kernels)
  def test_fold_conv_batchnorm_optim(self, adam=False):
    # 2 is too low?
    optim, cnt = (nn.optim.Adam, 16) if adam else (nn.optim.SGD, 2)
    with Tensor.train():
      img = Tensor.ones(1,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      _realize_weights([c1, bn])
      opt = optim(nn.state.get_parameters([c1, bn]))
      img_bn = bn(c1(img)).elu().sum()
      opt.zero_grad()
      img_bn.backward()
      check_schedule(opt.schedule_step(), cnt)
  def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True)

  # batchnorm forward output plus the three gradients scheduled together
  def test_fold_batchnorm_backward(self):
    with Tensor.train():
      x = Tensor.empty((2, 16, 8, 8)).contiguous()
      bn = nn.BatchNorm2d(16)
      bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
      fw = bn(x).contiguous_backward().relu().contiguous()
      fw.sum().backward()
      # TODO: this is too many
      check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 9)

  # conv followed by relu with weights pre-realized: 1 kernel
  def test_fold_conv_relu(self):
    c1 = nn.Conv2d(3,16,3)
    # run
    img = Tensor.ones(2,3,64,64)
    out = c1(img).relu()
    check_schedule(out, 1, [c1.weight, c1.bias])

  # two conv+relu stages: 2 kernels
  def test_fold_conv_relu_alt(self):
    img = Tensor.ones(1,4,8,8)
    c1 = nn.Conv2d(4, 4, kernel_size=3)
    c2 = nn.Conv2d(4, 4, kernel_size=3)
    img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu])
    check_schedule(img_conv, 2, [*nn.state.get_parameters(c1), *nn.state.get_parameters(c2), img])

  # two bias-free conv+relu stages: 2 kernels
  def test_fold_conv_relu_nobias(self):
    img = Tensor.ones(1,4,8,8)
    c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
    c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
    out = img.sequential([c1, Tensor.relu, c2, Tensor.relu])
    check_schedule(out, 2, [c1.weight, c2.weight, img])

  # conv followed by elu with weights and input pre-realized: 1 kernel
  def test_fold_conv_elu(self):
    c1 = nn.Conv2d(3,16,3)
    # run
    img = Tensor.rand(2,3,64,64)
    out = c1(img).elu()
    check_schedule(out, 1, [c1.weight, c1.bias, img])

  # two conv+elu stages: 2 kernels
  def test_fold_conv_elu_alt(self):
    img = Tensor.ones(1,4,8,8).contiguous()
    c1 = nn.Conv2d(4, 4, kernel_size=3)
    c2 = nn.Conv2d(4, 4, kernel_size=3)
    img_conv = img.sequential([c1, Tensor.elu, c2, Tensor.elu])
    check_schedule(img_conv, 2, [*nn.state.get_parameters(c1), *nn.state.get_parameters(c2), img])

  # sums over two different axes of the same input, added and relu'd: 1 kernel
  def test_two_sum(self):
    img = Tensor.empty(64,64)
    x = (img.sum(0) + img.sum(1))
    out = x.relu()
    check_schedule(out, 1)

  # add, reshape, permute, contiguous: 1 kernel
  def test_push_permute_through_reshape(self):
    a = Tensor.empty(16,16)
    b = Tensor.empty(16,16)
    c = (a+b).reshape(4,4,4,4).permute(2,3,0,1).contiguous()
    check_schedule(c, 1)

  #@unittest.skip("failing in old lazy")
  def test_push_permute_through_reshape_alt(self):
    a = Tensor.empty(4,4,4,4)
    b = Tensor.empty(4,4,4,4)
    c = (a+b).reshape(16,16).permute(1,0).contiguous()
    check_schedule(c, 1)

  # a reshape of an a+b that was already pre-realized: 0 kernels
  def test_no_binop_rerun(self):
    a = Tensor.empty(16)
    b = Tensor.empty(16)
    c = a+b
    d = (a+b).reshape(16,1)
    check_schedule(d, 0, [c])

  # reduce, cast and a chain of permutes/reshapes, then add: 1 kernel
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_multi_permute_should_collapse(self):
    a = Tensor.empty(4,4,4,4)
    b = Tensor.empty(16)
    c = a.sum((0,1)).cast(dtypes.float16).permute(1,0).reshape(4,4,1).permute(1,0,2).reshape(16) + b
    check_schedule(c, 1)

  # sums of a+b and of its reshaped equivalent, added together: 1 kernel
  def test_fancy_reshape_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = a+b
    d = a.reshape(10,1)+b.reshape(10,1)
    out = c.sum() + d.sum()
    check_schedule(out, 1)

  # a+b consumed through both an expand and a permute: 1 kernel
  def test_children_dont_push(self):
    a = Tensor.empty(10, 10, 1)
    b = Tensor.empty(10, 10, 1)
    d = (a+b).expand(10, 10, 10)
    e = (a+b).permute(2,1,0)
    f = d+e
    check_schedule(f, 1)

  # failing in new lazy
  @unittest.skip("always fusing elementwise")
  def test_dont_fuse_binops_with_children(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(10)
    keep_me = a+b
    e = keep_me.sum() # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
    d = keep_me+c
    check_schedule(d, 2)
    check_schedule(keep_me, 0, [d])

  #@unittest.skip("failing in old lazy")
  def test_permute_breaks_fusion(self):
    a = Tensor.empty(10, 10, 10)
    b = Tensor.empty(10, 10)
    c = (a.sum(axis=2) + b).permute(1,0)
    d = c.permute(1,0)
    check_schedule(d, 1)

  # the add written with and without transposes: each variant schedules 1 kernel
  def test_some_permute_fusion(self):
    a = Tensor.empty(8192, 16)
    b = Tensor.empty(1, 16)
    d = (a.T + b.expand(8192, 16).T)
    c = a + b.expand(8192, 16)
    e = d.T
    check_schedule(c, 1)
    check_schedule(e, 1)

  # multiply, index row 0, multiply: 1 kernel
  def test_shrink_fuse(self):
    a = Tensor.empty(8192, 16)
    b = Tensor.empty(8192, 16)
    c = a * b
    d = Tensor.empty(1, 16)
    e = c[0] * d
    check_schedule(e, 1)

  # (1,16) product multiplied against a (8192,16) tensor: 1 kernel
  def test_expand_fuse(self):
    a = Tensor.empty(1, 16)
    b = Tensor.empty(1, 16)
    c = a * b
    d = Tensor.empty(8192, 16)
    e = c * d
    check_schedule(e, 1)

  # this is the failing case in openpilot...it's very simple like this
  def test_image_conv_fusion(self):
    with Context(OPENPILOT_HACKS=1):
      w1 = Tensor.empty(16, 16, 1, 1)
      b1 = Tensor.empty(16)
      w2 = Tensor.empty(16, 16, 1, 1)
      b2 = Tensor.empty(16)
      w3 = Tensor.empty(16, 16, 1, 1)
      b3 = Tensor.empty(16)

      x = Tensor.empty(1, 16, 32, 32)
      x = base = x.image_conv2d(w1, b1)
      x = x.image_conv2d(w2, b2) + base
      x = x.image_conv2d(w3, b3)

      # NOOP, 3 convs, contiguous
      #check_schedule(x, 5)
      check_schedule(x, 7)

  # reduced version of the case above using a local permute/contiguous/expand/sum helper: 4 kernels
  def test_image_conv_fusion_minimal(self):
    b1 = Tensor.empty(16)
    b2 = Tensor.empty(16)
    def p(x): return x.permute(1,0).contiguous().reshape(32,16,1).expand(32,16,16).sum(axis=2).permute(1,0)

    x = Tensor.empty(16, 32)
    x = base = p(x) + b1.reshape(16,1)
    x = p(x)
    x = x + b2.reshape(16,1)
    x = x + base
    del base
    x = p(x)
    check_schedule(x, 4)

  # further reduced version (two applications of the helper): 3 kernels
  def test_image_conv_fusion_more_minimal(self):
    b1 = Tensor.empty(16)
    def p(x): return x.permute(1,0).contiguous().reshape(32,16,1).expand(32,16,16).sum(axis=2).permute(1,0)

    x = Tensor.empty(16, 32)
    x = base = p(x) + b1.reshape(16,1)
    x = p(x)
    del base
    check_schedule(x, 3)

  # contiguous() on an empty tensor: 0 schedule items
  def test_contiguous_while_contiguous(self):
    x = Tensor.empty(1, 64, 32, 32)
    out = x.contiguous()
    check_schedule(out, 0, filter_sink=False)

  # contiguous() after a permute: 1 schedule item
  def test_contiguous_while_not_contiguous(self):
    x = Tensor.empty(1, 64, 32, 32)
    out = x.permute(0,2,3,1).contiguous()
    check_schedule(out, 1, filter_sink=False)

  # an explicit contiguous() after the reduce: 2 kernels
  def test_fold_with_contiguous(self):
    a = Tensor.randn(16, 16, 16).realize()
    b = Tensor.randn(16, 16).realize()
    c = (a.sum(2).contiguous() + b).contiguous()
    check_schedule(c, 2)

  # helper: schedule `t`, require exactly one SINK item, and return the ops of its ast that are in GroupOp.ALU
  def _alu_from_tensor(self, t:Tensor):
    s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
    self.assertEqual(len(s), 1)
    return [u.op for u in s[0].ast.toposort() if u.op in GroupOp.ALU]

  def test_2_pow_is_exp2(self):
    t = 2.0 ** Tensor([1.0, 2.0, 3.0])
    self.assertEqual(self._alu_from_tensor(t), [Ops.EXP2])

  def test_pow_05_is_sqrt(self):
    t = Tensor([1.0, 2.0, 3.0]) ** 0.5
    self.assertEqual(self._alu_from_tensor(t), [Ops.SQRT])

  def test_pow_neg_05_is_rsqrt(self):
    t = Tensor([1.0, 2.0, 3.0]) ** -0.5
    self.assertEqual(self._alu_from_tensor(t), [Ops.RECIPROCAL, Ops.SQRT])

  def test_pow_2_has_1_mul(self):
    t = Tensor([1.0, 2.0, 3.0]) ** Tensor(2.0)
    self.assertEqual(self._alu_from_tensor(t), [Ops.MUL])

  def test_pow_8_has_3_muls(self):
    t = Tensor([1.0, 2.0, 3.0]) ** 8
    self.assertEqual(self._alu_from_tensor(t), [Ops.MUL, Ops.MUL, Ops.MUL])

  @unittest.skip("const folding is removed")
  def test_pow_const_tensor_to_zero(self):
    x = Tensor([1,2,3,4])
    out = x ** Tensor(0.0)
    # NOTE: this is UOp.const(0) + UOp.const(1)
    check_schedule(out, 0)

  # an op on a tensor with a zero-sized dimension: 0 schedule items
  def test_zero_size(self):
    x = Tensor.empty(2, 3, 0)
    out = x + 1
    check_schedule(out, 0, filter_sink=False)

  # NOTE(review): despite the name, 1 kernel is expected here
  def test_reduce_permute_nofuse(self):
    x = Tensor.empty(32, 32, 32)
    y = Tensor.empty(32, 32)
    out = x.sum(axis=2).T+y
    check_schedule(out, 1)

  # two independent sum+relu+elu chains added together: 1 kernel
  def test_two_elus_sum(self):
    x = Tensor.empty(32, 32)
    y = Tensor.empty(32, 32)
    out = x.sum(1).relu().elu() + y.sum(1).relu().elu()
    check_schedule(out, 1)

  # sum, relu, sum: 1 kernel
  def test_multistage_reduce(self):
    x = Tensor.empty(32, 32, 32)
    out = x.sum(2).relu().sum(1)
    check_schedule(out, 1)

  # the first reduce is used by a second reduce and by an indexed elementwise branch: 2 kernels
  def test_multistage_reduce_fork(self):
    x = Tensor.empty(32, 32, 32)
    x = x.sum(2)
    out2 = x + 1
    out = x.relu().sum(1) + out2[0]
    check_schedule(out, 2)

  # an explicit contiguous() between two adds: 2 kernels
  def test_contiguous_add(self):
    x = Tensor.empty(32)
    y = Tensor.empty(32)
    z = Tensor.empty(32)
    out = (x+y).contiguous()+z
    check_schedule(out, 2)

  # a reduce added to a column slice of itself: 2 kernels
  def test_double_sum_ref(self):
    x = Tensor.empty(32, 32, 32)
    x = x.sum(2)
    out = x + x[:, 4]
    check_schedule(out, 2)

  # reduce, slice, add: 1 kernel
  def test_reduce_shrink(self):
    x = Tensor.empty(32, 32)
    y = Tensor.empty(16)
    x = x.sum(1)
    x = x[:16]
    out = x + y
    check_schedule(out, 1)

  # two contiguous const sums added together: 2 schedule items
  def test_const_no_recompute(self):
    x = Tensor(2) + Tensor(2)
    y = Tensor(2) + Tensor(2)
    out = x.contiguous() + y.contiguous()
    check_schedule(out, 2, filter_sink=False)

  # the same full reduce used by two outputs scheduled together
  def test_reduce_shrink_child(self):
    a = Tensor.empty(100, 100)
    b = Tensor.empty(10,)
    c = a.sum() + b[0]
    d = a.sum() + 2
    check_schedule([c, d], 2) # TODO: 1?

  # reduce, exp2, and exp2 added to its own first element, all requested as outputs: 3 kernels
  def test_reduce_multiple_paths_midshrink(self):
    a = Tensor.empty(4, 4)
    r = a.sum(axis=1)
    out0 = r.exp2()
    out1 = out0[0] + out0
    check_schedule([r, out0, out1], 3)

  # same shape of graph with a keepdim reduce and a separate empty operand: 3 kernels
  def test_reduce_shrink_output(self):
    a = Tensor.empty(4, 4)
    r = a.sum(keepdim=True)
    out0 = r.exp2()
    out1 = out0[0] + Tensor.empty(1, )
    check_schedule([r, out0, out1], 3)

  # softmax(dtype=float) gives a 3-item schedule whose first item writes a float buffer, for half and float inputs
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_softmax_upcast(self):
    # input half, softmax in float
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.half).realize()
    out = x.softmax(dtype=dtypes.float)
    sched = out.schedule()
    self.assertEqual(len(sched), 3)
    self.assertEqual(sched[0].bufs[0].dtype, dtypes.float)

    # input float, softmax in float
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.float).realize()
    out = x.softmax(dtype=dtypes.float)
    sched = out.schedule()
    self.assertEqual(len(sched), 3)
    self.assertEqual(sched[0].bufs[0].dtype, dtypes.float)

  # gradient of softmax().sum(): 4 kernels, and the schedule is also run
  def test_softmax_backward(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 12, 64, 64, requires_grad=True).realize()
    x.softmax().sum().backward()
    run_schedule(check_schedule(x.grad, 4))

  # attention with an explicit mask: 4 kernels
  def test_scaled_dot_product_attention_fusion(self):
    x, y, z, m = (Tensor.empty(32, 8, 16, 16) for _ in range(4))
    out = Tensor.scaled_dot_product_attention(x, y, z, attn_mask=m)
    check_schedule(out, 4)

  # attention with is_causal=True: 4 kernels
  def test_scaled_dot_product_attention_causal_fusion(self):
    x, y, z = (Tensor.empty(32, 8, 16, 16) for _ in range(3))
    out = Tensor.scaled_dot_product_attention(x, y, z, is_causal=True)
    check_schedule(out, 4)

  # Adam step for one Linear layer: 13 kernels
  def test_adam_step_fusion(self):
    with Tensor.train():
      x = Tensor.empty(4, 64, 32)
      layer = nn.Linear(32, 32*4)
      _realize_weights(layer)
      opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
      layer(x).relu().sum().backward()
      check_schedule(opt.schedule_step(), 13)

  # Adam step for one conv: 13 kernels
  def test_adam_conv_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      _realize_weights(c1)
      opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
      opt.zero_grad()
      c1(img).relu().sum().backward()
      check_schedule(opt.schedule_step(), 13)

  # Adam step for two bias-free convs: 15 kernels
  def test_adam_2convs_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,2,bias=False)
      _realize_weights([c1, c2])
      opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 15)

  # SGD step for one conv
  def test_sgd_conv_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      _realize_weights(c1)
      opt = nn.optim.SGD(nn.state.get_parameters(c1))
      opt.zero_grad()
      c1(img).relu().sum().backward()
      check_schedule(opt.schedule_step(), 5) # TODO: 3?

  # SGD step for two bias-free convs: 7 kernels
  def test_sgd_2convs_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,2,bias=False)
      _realize_weights([c1, c2])
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]))
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 7)

  # same two convs with nesterov momentum and weight decay: 11 kernels
  def test_fold_2convs_sgd_nesterov_momentum_wd(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,2,bias=False)
      _realize_weights([c1, c2])
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 11)

  # SGD step for four bias-free convs: 15 kernels
  def test_sgd_4convs_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,16,16)
      c1 = nn.Conv2d(3,4,3,bias=False)
      c2 = nn.Conv2d(4,8,3,bias=False)
      c3 = nn.Conv2d(8,16,3,bias=False)
      c4 = nn.Conv2d(16,32,3,bias=False)
      _realize_weights([c1, c2, c3, c4])
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 15)

  # NOTE(review): body is identical to test_sgd_4convs_fuse
  def test_sgd_4convs_fuse_conv_bw(self):
    with Tensor.train():
      img = Tensor.empty(2,3,16,16)
      c1 = nn.Conv2d(3,4,3,bias=False)
      c2 = nn.Conv2d(4,8,3,bias=False)
      c3 = nn.Conv2d(8,16,3,bias=False)
      c4 = nn.Conv2d(16,32,3,bias=False)
      _realize_weights([c1, c2, c3, c4])
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 15)

  # one reduce feeding two further reduces over different axes: 3 kernels
  def test_reduce_simple_chase(self):
    a = Tensor.empty(4, 4, 4)
    r = a.sum(0) + 6
    b = r.sum(0) * 4
    c = r.sum(1) * 2
    check_schedule([b, c], 3)

  # a reduce used both transposed and untransposed by two outputs: 3 kernels
  def test_push_permute_chase(self):
    a = Tensor.empty(4, 4, 4)
    b = Tensor.empty(4, 4)
    r = a.sum(2) + b
    d = r.T * 4
    e = r * d
    check_schedule([d, e], 3)

  # reduce, add, slice, multiply: 1 kernel
  def test_push_shrink_chase(self):
    a = Tensor.empty(16, 16)
    b = Tensor.empty(4)
    c = Tensor.empty(16, )
    r = a.sum(1) + c
    d = r[:4] * b
    check_schedule(d, 1)

  # a sum and a max of the same input over different axes, added: 1 kernel
  def test_midreduce_nochase(self):
    a = Tensor.empty(16, 16)
    b = (a.sum(0) + a.max(1)) + 2
    check_schedule(b, 1)

  # bitcasts of x and of x.exp2() added together
  def test_bitcast_fuses(self):
    x = Tensor.empty(1, dtype=dtypes.float32)
    a = x.exp2().bitcast(dtypes.int32)
    b = x.bitcast(dtypes.int32)
    check_schedule(a+b, 1) # this should fuse when it makes sense

  # argmax over axis 1: 2 kernels, and the schedule is also run
  def test_reduceop_reshape_dont_push(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(10, 20).realize()
    out = x.argmax(1)
    run_schedule(check_schedule(out, 2))

  # two chained conv2d calls: 2 kernels
  def test_resnet_conv2d(self):
    x = Tensor.empty(1, 8, 32, 32)
    w1 = Tensor.empty(8, 8, 3, 3)
    w2 = Tensor.empty(8, 8, 1, 1)
    out = x.conv2d(w1).conv2d(w2)
    check_schedule(out, 2)

  # after a realize and a schedule whose results are dropped, GlobalCounters.mem_used must be back at its starting value
  def test_schedule_mem_used(self):
    gc.collect()
    base = GlobalCounters.mem_used
    Tensor.ones(256).contiguous().realize()
    Tensor.ones(5, 5).contiguous().schedule()
    gc.collect()
    self.assertEqual(GlobalCounters.mem_used-base, 0)

  # a const_like UOp passed straight to check_schedule: 0 kernels
  def test_const_schedule(self):
    constv = Tensor.empty(2, 2).uop.const_like(10)
    check_schedule(constv, 0)

  # the same const made contiguous: 1 kernel
  def test_const_schedule_contig(self):
    constv = Tensor.empty(2, 2).uop.const_like(10).contiguous()
    check_schedule(constv, 1)

  # slice combined with list indexing on an arange: 1 kernel
  def test_advanced_simple_indexing_combined(self):
    X = Tensor.arange(16).reshape(4, 4)
    xt = X[1:2, [-1, 2]]
    check_schedule(xt, 1)

  # a sliced arange added to realized data and summed: 1 kernel
  def test_arange_index_shrink(self):
    Tensor.manual_seed(0)
    with Context(TRACK_MATCH_STATS=0):
      x = Tensor.randn(11).realize()
    a = Tensor.arange(22)
    out = (x + a[:11]).sum()
    check_schedule(out, 1)

  # avg_pool2d with ceil_mode: 1 kernel whose ast contains exactly one REDUCE
  def test_fuse_arange_avg_pool2d_ceil_mode(self):
    x = Tensor.avg_pool2d(Tensor.empty(1,1,6,6), kernel_size=(3,3), padding=1, stride=3, ceil_mode=True)
    sched = check_schedule(x, 1)
    self.assertEqual(len([x for x in sched[0].ast.backward_slice_with_self if x.op is Ops.REDUCE]), 1)

  # gradient through a circular pad: 1 kernel whose ast contains no REDUCE
  def test_fuse_arange_pad_circular_mode_bw(self):
    x = Tensor.empty(1,1,5,5,5)
    out = x.pad((1,2,3,5,1,2), mode="circular")
    g = out.sum().gradient(x)[0]
    sched = check_schedule(g, 1)
    self.assertEqual(len([x for x in sched[0].ast.backward_slice_with_self if x.op is Ops.REDUCE]), 0)

  # conv/bn/relu, conv/bn, residual add and relu outside training: 2 kernels, and the schedule is also run
  def test_resnet_block(self):
    with Tensor.train(False):
      in_planes, planes = 64, 64
      conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
      bn1 = nn.BatchNorm2d(planes)
      conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
      bn2 = nn.BatchNorm2d(planes)
      x = Tensor.empty(1, 64, 32, 32)
      out = bn1(conv1(x)).relu()
      out = bn2(conv2(out))
      out = (out + x).relu()
      run_schedule(check_schedule(out, 2, [conv1.weight, conv2.weight]))
|
|
|
|
class TestSwizzle(unittest.TestCase):
  # Kernel counts for softmax and argmax over realized random data; the TODOs mark the hoped-for single kernel.

  def test_softmax_one_kernel(self):
    Tensor.manual_seed(0)
    # create the input with DEBUG and TRACK_MATCH_STATS set to 0
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      a = Tensor.randn(32, 32).realize()
    t = a.softmax()
    check_schedule(t, 3) # TODO: 1?

  def test_argmax_one_kernel(self):
    Tensor.manual_seed(0)
    # create the input with DEBUG and TRACK_MATCH_STATS set to 0
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      a = Tensor.randn(10, 20).realize()
    t = a.argmax(0)
    check_schedule(t, 2) # TODO: 1?
|
|
|
|
class TestView(unittest.TestCase):
  def test_zero_size_alt(self):
    # the middle axis has length 0; padding the last axis of such a tensor must not need a kernel
    empty = Tensor.empty(135, 0, 9)
    padded = empty.pad(((0, 0), (0, 0), (18, 0)))
    check_schedule(padded, 0)
|
|
|
|
class TestUOpBecome(unittest.TestCase):
  # the simplest case: a new BUFFER is created for this tensor UOp
  def test_new_buffer(self):
    lhs = Tensor.empty(4, 4)
    rhs = Tensor.empty(4, 4)
    total = lhs+rhs
    check_schedule(total, 1)
    # NOTE: the realized base is always a flat buffer
    assert UPat(Ops.BUFFER).match(total.uop.base, {})
    # the Tensor UOp can optionally stack a VIEW on top of the BUFFER, here to preserve the (4, 4) shape of the tensor
    assert total.uop is not total.uop.base
    self.assertEqual(total.uop.size, 16)
    self.assertEqual(total.uop.shape, (4, 4))

  def test_new_buffer_view(self):
    lhs = Tensor.empty(4, 4)
    rhs = Tensor.empty(4, 4)
    reshaped = (lhs+rhs).reshape(8, 2)
    check_schedule(reshaped, 1)
    assert UPat(Ops.BUFFER).match(reshaped.uop.base, {})
    # the shape is preserved in the becomes_map.
    self.assertEqual(reshaped.uop.shape, (8, 2))
    assert reshaped.uop is not reshaped.uop.base

  def test_new_flat_buffer(self):
    lhs = Tensor.empty(4,)
    rhs = Tensor.empty(4,)
    total = lhs+rhs
    check_schedule(total, 1)
    # BUFFER already has a shape (4,), so this tensor just becomes a contiguous BUFFER
    assert UPat(Ops.BUFFER).match(total.uop.base, {})

  # sometimes we prefer to perform an op before movement ops, in this case we should stack the mops on top of the new buffer

  @unittest.skip("no longer supported")
  def test_reorder_expand(self):
    col = Tensor.empty(4, 1)
    recip = col.expand(4, 4).reciprocal()
    check_schedule(recip, 1)
    self.assertEqual(recip.uop.base.buffer.size, 4)
    self.assertEqual(recip.uop.shape, (4, 4))

  def test_reorder_expand_alt(self):
    # two (4, 1) operands combined with a (4, 4) image: still a single kernel
    x = Tensor.empty(4, 1)
    y = Tensor.empty(4, 1)
    img = Tensor.empty(4, 4)
    scaled = (img*x) / y
    check_schedule(scaled, 1)

  # TODO: rangeify doesn't yet cleanup this kind of re-indexing
  @unittest.expectedFailure
  def test_become_existing_buffer(self):
    src = Tensor.empty(4, 4)
    same = src*1
    assert UPat(Ops.MUL).match(same.uop, {}) # before scheduling it's a mul
    check_schedule(same, 0)
    self.assertIs(src.uop.base.buffer, same.uop.base.buffer)

  def test_become_buf_with_mops(self):
    src = Tensor.empty(2, 4, 2)
    passthru = src.shrink(((1, 2), (0, 4), (0, 2))).reshape(4, 2)*1+0
    # before realizing, this tensor is its own base
    assert passthru.uop is passthru.uop.base
    passthru.realize()
    # after realize it is a view whose base is a BUFFER
    assert passthru.uop is not passthru.uop.base
    assert passthru.uop.base.op is Ops.BUFFER
    # a later op on the realized view must still realize
    (passthru+2).realize()

  @unittest.skip("const folding is removed")
  def test_become_const_in_base(self):
    src = Tensor.empty(4)
    zeroed = src*0
    assert UPat(Ops.MUL).match(zeroed.uop, {}) # before scheduling it's a mul
    check_schedule(zeroed, 0)
    assert UPat(Ops.CONST, arg=0).match(zeroed.uop.base, {}) # after scheduling the base is expected to be a CONST 0

  @unittest.skip("const folding is removed")
  def test_become_const_from_const(self):
    folded = Tensor(1)+Tensor(2)
    assert UPat(Ops.ADD).match(folded.uop, {})
    check_schedule(folded, 0)
    assert UPat(Ops.CONST, arg=3).match(folded.uop.base, {})

  # tensors can become another realized tensor source
  @unittest.expectedFailure
  def test_become_existing_buf_simple(self):
    src = Tensor.empty(4, 4)
    same = src+0
    check_schedule(same, 0)
    assert same.uop.base.op is Ops.BUFFER
    self.assertIs(src.uop, same.uop)

  # they can also chain other movement ops on top of the tensor source
  @unittest.expectedFailure
  def test_become_existing_buf_view(self):
    src = Tensor.empty(4, 4)
    transposed = src.permute((1, 0))+0
    check_schedule(transposed, 0)
    self.assertEqual(transposed.uop.st, src.uop.permute((1, 0)).st)

  @unittest.expectedFailure
  def test_become_existing_buf_view_alt(self):
    src = Tensor.empty(4, 4)
    moved = src.permute((1, 0)).reshape((8, 2))+0
    check_schedule(moved, 0)
    self.assertEqual(moved.uop.st, src.uop.permute((1, 0)).reshape((8, 2)).st)

  # they can also have other base parents that simplified, in that case we just backtrack to the chained mops
  @unittest.expectedFailure
  def test_become_existing_buf_complex(self):
    src = Tensor.empty(4, 4)
    moved = (src.permute((1, 0))+0).reshape((8, 2))+0
    check_schedule(moved, 0)
    self.assertEqual(moved.uop.st, src.uop.permute((1, 0)).reshape((8, 2)).st)
    assert moved.uop.base.op is Ops.BUFFER

  @unittest.expectedFailure
  def test_become_multiple_choices(self):
    src = Tensor.empty(16)
    via_5d = (src.reshape(1, 1, 4, 1, 4)+0).reshape(1, 1, 4, 4).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
    via_4d = (src.reshape(1, 1, 4, 4)+0).shrink(((0, 1), (0, 1), (0, 3), (0, 3)))+0
    check_schedule([via_5d, via_4d], 0)
    from tinygrad.helpers import all_same
    # all three tensors must share one realized base
    assert all_same([t.uop.base.realized for t in [src, via_5d, via_4d]])

  @unittest.skip("not clear if we want this")
  def test_setitem_becomes_subbuffer(self):
    whole = Tensor.full((4,), 2.).contiguous().realize()
    head = whole.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
    head.realize()
    assert whole.uop.is_realized
    assert whole.uop.buffer._base is None
    assert head.uop.op_in_backward_slice_with_self(Ops.SHRINK)
    assert head.uop.base is whole.uop.base
|
|
|
|
class TestFusionOp(unittest.TestCase):
  def test_recursive_add(self):
    # 24 self-additions must schedule and lower in under 2 seconds and yield a short program
    start = time.perf_counter()
    acc = Tensor([1,2,3,4])
    for _ in range(24): acc = acc + acc
    sched = acc.schedule()
    sched[-1].lower()
    elapsed = time.perf_counter()-start
    self.assertLess(elapsed, 2.0)
    assert len(sched[-1].prg.p.src.splitlines()) < 250

  def test_recursive_add_cmp(self):
    start = time.perf_counter()
    # two identical 24-deep chains...
    first = Tensor([1,2,3,4])
    for _ in range(24): first = first + first
    sched_first = first.schedule()
    second = Tensor([1,2,3,4])
    for _ in range(24): second = second + second
    sched_second = second.schedule()
    # ...and a 23-deep chain that must differ from them
    shorter = Tensor([1,2,3,4])
    for _ in range(23): shorter = shorter + shorter
    sched_shorter = shorter.schedule()
    self.assertEqual(sched_first[-1].ast, sched_second[-1].ast)
    # NOTE(review): kept as assertRaises around assertEqual rather than assertNotEqual; presumably because
    # `!=` on these AST objects is not a plain boolean comparison. confirm before simplifying.
    with self.assertRaises(AssertionError): self.assertEqual(sched_first[-1].ast, sched_shorter[-1].ast)
    self.assertLess(time.perf_counter()-start, 2.0)

  def test_recursive_pad(self):
    # repeatedly stacking a scalar with itself and taking element 0 must stay at one schedule item or fewer
    start = time.perf_counter()
    acc = Tensor(1.0)
    for _ in range(24): acc = Tensor.stack(acc, acc)[0]
    sched = acc.schedule()
    self.assertLessEqual(len(sched), 1)
    self.assertLess(time.perf_counter()-start, 2.0)

  def test_recursive_reshape(self):
    # a row-sum followed by 24 reshape+add steps must schedule as exactly one item
    start = time.perf_counter()
    mat = Tensor.empty(32, 32).realize()
    bias = Tensor.empty(16, 2).realize()
    acc = mat.sum(1)
    for _ in range(24): acc = acc.reshape(16, 2) + bias
    sched = acc.schedule()
    self.assertEqual(len(sched), 1)
    self.assertLess(time.perf_counter()-start, 2.0)
|
|
|
|
# NOTE: the NULL backend supports BUFFER_VIEW
|
|
class TestBufferView(unittest.TestCase):
  def test_shrink_contiguous_is_buffer_view(self):
    # simple 1D shrink of a realized buffer should be BUFFER_VIEW, not a copy kernel
    base = Tensor.arange(100).contiguous().realize()
    window = base.shrink(((10, 50),)).contiguous()
    sched = check_schedule(window, 0)
    run_schedule(sched)

  def test_shrink_2d_contiguous_is_buffer_view(self):
    # shrinking only the leading axis (rows 1..5) of a realized 10x10 buffer: zero kernels allowed
    base = Tensor.arange(100).reshape(10,10).contiguous().realize()
    rows = base.shrink(((1, 5),None)).contiguous()
    sched = check_schedule(rows, 0)
    run_schedule(sched)

  def test_chained_shrink_is_buffer_view(self):
    # shrink -> shrink -> reshape -> shrink on a realized buffer: zero kernels allowed
    base = Tensor.arange(1000).contiguous().realize()
    window = base.shrink(((200, 800),)).shrink(((0, 300),)).reshape((30, 10)).shrink(((20, 25), (0, 10))).contiguous()
    sched = check_schedule(window, 0)
    run_schedule(sched)
|
|
|
|
# run every test in this module verbosely when executed as a script
if __name__ == "__main__":
  unittest.main(verbosity=2)
|