Files
tinygrad/test/backend/test_schedule.py
George Hotz d1cce7a476 put the ranges on store instead of after (#15759)
* put the ranges on store instead of after

* better assert

* fix stuff

* comment out slow rules i don't understand

* simpler rule

* closer

* return false for store

* fix loop

* only a few schedule failures remain

* remove stores to self

* all tests pass locally

* remove junk

* regression test and fix

* better test, bump broken torch count

* bugfix with regression test

* new fusion is better
2026-04-16 19:06:40 +08:00

1450 lines
61 KiB
Python

# this will be the new test_ops for the next level
# schedule confirms the right things are capable of fusing
# NOTE: this has overlap with external_test_opt.py
import gc, unittest, functools
import numpy as np
from typing import cast
from hypothesis import assume, given, settings, strategies as strat
from tinygrad import nn, dtypes, Device, Tensor, Variable
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType
from tinygrad.uop.ops import UOp, Ops, UPat
from tinygrad.helpers import CI, DEBUG, OSX, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.engine.realize import CompiledRunner, run_schedule
class KernelCountException(Exception):
  """Raised when a schedule produces a kernel count different from the expected one."""
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
  """Schedule t (a Tensor, list of Tensors, or UOp), lower every item, and require exactly
  `allowed` kernels (CompiledRunner items when filter_sink is set, all items otherwise).
  Raises KernelCountException on mismatch; returns the schedule on success."""
  if to_prerealize:
    with Context(DEBUG=0, TRACK_MATCH_STATS=0): Tensor.realize(*to_prerealize)
  if isinstance(t, Tensor):
    sched = t.schedule()
  elif isinstance(t, list) and isinstance(t[0], Tensor):
    sched = Tensor.schedule(*t)
  else:
    assert isinstance(t, UOp), f"can't schedule {t}"
    sched = Tensor(t).schedule()
  # every ExecItem must lower cleanly, even if the count check below fails
  for si in sched: si.lower()
  kernel_cnt = sum(1 for si in sched if not filter_sink or isinstance(si.prg, CompiledRunner))
  if kernel_cnt != allowed:
    print(f"SCHEDULE ISSUE, expecting {allowed} got {kernel_cnt}")
    if DEBUG >= 3:
      for i,s in enumerate(sched):
        print("kernel", i+1)
        print(s.ast)
    raise KernelCountException(f"{kernel_cnt} != {allowed}")
  return sched
def _realize_weights(m):
  """Realize every parameter of module m so later schedules start from real buffers."""
  for param in nn.state.get_parameters(m):
    param.realize()
def _test_conv2d(allowed:int, dtype:DType=dtypes.float):
  """Run a small conv2d forward+backward, assert the schedule has exactly `allowed` SINK
  kernels, and (unless CHECK=0) verify both gradients against a torch reference.

  Fixes vs. previous version: the redundant second `dtypes.default_float = dtype`
  assignment is removed, and the global default_float is restored in a finally block so
  a failure inside the graph construction can't leak the dtype into other tests."""
  old_default_float, dtypes.default_float = dtypes.default_float, dtype
  try:
    Tensor.manual_seed(0)
    BS, CIN = 2, 3
    img = Tensor.randn(BS, CIN, 64, 64, requires_grad=True).realize()
    w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True).realize()
    ret = Tensor.conv2d(img, w).relu().mean().backward()
  finally:
    # always restore the global default, even if construction raised
    dtypes.default_float = old_default_float
  s = Tensor.schedule(ret, img.grad, w.grad)
  run_schedule(s.copy())
  cnt = len([si for si in s if si.ast.op is Ops.SINK])
  assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
  if getenv("CHECK", 1):
    import torch
    ref_img = torch.tensor(img.numpy(), requires_grad=True)
    ref_w = torch.tensor(w.numpy(), requires_grad=True)
    torch.nn.functional.conv2d(ref_img, ref_w).relu().mean().backward()
    assert ref_img.grad is not None and ref_w.grad is not None and img.grad is not None and w.grad is not None
    # half precision needs a much looser tolerance than float32
    np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
    np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
class TestSchedule(unittest.TestCase):
def setUp(self):
  # run every test in this class with SPLIT_REDUCEOP=0 (context is closed in tearDown)
  self.ctx = Context(SPLIT_REDUCEOP=0)
  self.ctx.__enter__()
def tearDown(self):
  # close the Context opened in setUp
  self.ctx.__exit__(None, None, None)
def test_arange_avgpool2d(self, kcount=1):
  # avg_pool2d of an arange input schedules into `kcount` kernels and matches torch
  x = Tensor.arange(25).reshape(1,1,5,5).cast(dtypes.float32)
  t = x.avg_pool2d(padding=1)
  sched = t.schedule()
  self.assertEqual(len(sched), kcount)
  run_schedule(sched)
  import torch
  torch_out = torch.nn.functional.avg_pool2d(torch.arange(25).reshape(1,1,5,5).float(), kernel_size=(2,2), padding=1).numpy()
  np.testing.assert_allclose(t.numpy(), torch_out)
def test_arange_avgpool2d_fused_noopt(self):
  # same check as test_arange_avgpool2d, but with kernel optimizations disabled
  with Context(NOOPT=1): self.test_arange_avgpool2d(kcount=1)
# linearizer error
@unittest.skip("recursion error no longer raised")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "needs supports_float4 to fail")
def test_arange_avgpool2d_fused(self):
  # historical regression: fused avg_pool2d used to hit a RecursionError on float4 renderers
  with self.assertRaises(RecursionError):
    with Context(NOOPT=0): self.test_arange_avgpool2d(kcount=1)
# when we're fusing a reduce, all ReduceOps must have the same N in the dimensions
# all permutes, reshapes, expands and shrinks push through the reduce
def test_arange_sum(self):
  # arange + reduce fuses into a single kernel
  a = Tensor.arange(6).reshape(3, 2).sum(axis=1)
  run_schedule(check_schedule(a, 1))
  self.assertListEqual(a.tolist(), [1, 5, 9])
def test_arange_sum_alt(self):
  # arange with extra movement ops and a scalar multiply still fuses to one kernel
  a = (Tensor.arange(5).reshape(1,5).expand(6,5)*Tensor(2)).reshape(1,6,5).sum(axis=2)
  run_schedule(check_schedule(a, 1))
  np.testing.assert_equal(a.numpy(), 20)
def test_permute_arange(self):
  # a permute between the arange and the reduce does not block fusion
  a = Tensor.arange(6).reshape(6, 1, 1).permute(2, 0, 1).sum(axis=1)
  run_schedule(check_schedule(a, 1))
  self.assertListEqual(a.tolist(), [[15]])
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and OSX, "WEBGPU Metal backend is not accurate enough")
def test_expand_buffer_before_cast(self):
  # permute + cast + expand + add over a realized buffer stays one elementwise kernel
  a = Tensor.randn(4, 2, 1).realize().permute((1, 0, 2))
  b = a.cast(dtypes.half).expand((2, 4, 4))+2
  run_schedule(check_schedule(b, 1))
  np.testing.assert_allclose(b.numpy(), np.broadcast_to(a.numpy().astype(np.float16), (2, 4, 4))+2, rtol=1e-3)
def test_indexing_scalars_simple(self):
  # chained indexing with scalar Tensors fuses into one kernel
  X = Tensor.randn(2, 2).realize()
  xt = X[Tensor(1)][Tensor(0)]
  run_schedule(check_schedule(xt, 1))
  np.testing.assert_equal(xt.numpy(), X.numpy()[1][0])
@unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
def test_add_chain_buffers(self):
  # add N one-element buffers in chains of every width X; the total must stay correct
  N = 31
  with Context(TRACK_MATCH_STATS=0, DEBUG=0):
    bufs = [Tensor(i).reshape((1,)).contiguous().realize() for i in range(N)]
  for X in range(1,N):
    root = bufs[0]
    for i in range(1,N,X):
      # each step adds a sub-chain of X buffers reduced together
      root = root + functools.reduce(lambda a,b:a+b, bufs[i:i+X])
    self.assertEqual(root.item(), sum(range(N)))
@given(strat.sampled_from(range(2,4)), strat.sampled_from(range(2,4)), strat.sampled_from(range(0,4)), strat.sampled_from(range(0,4)))
@settings(deadline=None)
def test_indexing_scalars(self, x, y, a, b):
  # property-based version of scalar-Tensor indexing: any in-bounds (a,b) fuses to one kernel
  assume(a<x and b<y)
  X = Tensor.randn(x, y).realize()
  xt = X[Tensor(a)][Tensor(b)]
  run_schedule(check_schedule(xt, 1))
  np.testing.assert_equal(xt.numpy(), X.numpy()[a][b])
def test_push_pads_elementwise(self):
  # the pad pushes through the elementwise expression; padded sum stays one kernel
  x = Tensor.full((4,4), 2.).contiguous().realize()
  y = Tensor.full((4,4), 4.).contiguous().realize()
  z = (x.reciprocal()*y).pad((None, (0,1),)).sum()
  run_schedule(check_schedule(z, 1))
  self.assertEqual(z.item(), 32)
def test_push_pads_contiguous(self):
  # like test_push_pads_elementwise, but with an expand and prerealized contiguous inputs
  x = Tensor.full((4,1), 2.).contiguous()
  y = Tensor.full((4,4), 4.).contiguous()
  z = (x.reciprocal().expand(4,4)*y).pad((None, (0,1),)).sum()
  run_schedule(check_schedule(z, 1, [x,y]))
  self.assertEqual(z.item(), 32)
def test_constants_can_store(self):
  # a contiguous constant schedules exactly one kernel that stores the value
  a = Tensor(2).contiguous()
  run_schedule(check_schedule(a, 1))
  np.testing.assert_equal(a.numpy(), 2)
def test_allow_push_permutes(self):
  # permute of a reduce output plus an add still fuses into one kernel
  a = Tensor.randn(10,10,10).realize()
  b = Tensor.randn(10,10,1).realize()
  c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
  run_schedule(check_schedule(c, 1))
  np.testing.assert_allclose(c.numpy(), np.sum(a.numpy(), axis=0, keepdims=True).transpose(2,1,0)+b.numpy())
def test_div_collapse_buffer(self):
  # (a*b)/b over realized buffers schedules as a single kernel
  a = Tensor.full((4,), 4.0).contiguous().realize()
  b = Tensor.full((4,), 2.0).contiguous().realize()
  expr = (a*b)/b
  run_schedule(check_schedule(expr, 1))
  np.testing.assert_allclose(expr.numpy(), np.full((4,), 4.0))
def test_div_collapse_const(self):
  # a/a over one realized buffer schedules as a single kernel and evaluates to 1
  a = Tensor.full((4,), 4.0).contiguous().realize()
  expr = a/a
  run_schedule(check_schedule(expr, 1))
  np.testing.assert_allclose(expr.numpy(), np.full((4,), 1.0))
def test_div_collapse(self):
  # chained division (a/b)/c: one kernel, and at most 3 global ops per output element (4*3)
  a = Tensor.full((4,), 1.0).contiguous().realize()
  b = Tensor.full((4,), 2.0).contiguous().realize()
  c = Tensor.full((4,), 3.0).contiguous().realize()
  GlobalCounters.reset()
  expr = (a/b)/c
  expr.realize()
  self.assertEqual(GlobalCounters.kernel_count, 1)
  self.assertLessEqual(GlobalCounters.global_ops, 4*3)
  np.testing.assert_allclose(expr.numpy(), (a.numpy()/b.numpy())/c.numpy())
# NOTE: this is causing "LAZYCACHE=1 incorrectly reuses contiguous const" #4562
# should contiguous dedup?
@unittest.skip("we do the exact opposite now")
def test_dedup_contiguous(self):
  # if contiguous consts deduped, both outputs would share one kernel and one buffer
  a = Tensor.ones(4).contiguous()
  b = Tensor.ones(4).contiguous()
  sched = check_schedule([a, b], 1)
  run_schedule(sched)
  # a and b share the same underlying device memory
  self.assertIs(a.uop.realized, b.uop.realized)
def test_clone_doesnt_dedup(self):
  # clones of the same source must get distinct copy kernels and distinct buffers
  src = Tensor.ones(4).contiguous().realize()
  a = src.clone()
  b = src.clone()
  sched = check_schedule([a, b], 2, filter_sink=False)
  run_schedule(sched)
  # a and b must NOT be assigned to the same device Buffer
  self.assertIsNot(a.uop.base.realized, b.uop.base.realized)
@unittest.skip("no longer supported")
def test_double_from(self):
  # moving a Tensor to the python device used to require zero kernels
  x = Tensor([1,2,3,4])
  out = x.to('python')
  check_schedule(out, 0, filter_sink=False)
def test_zero_size_assign(self):
  # assigning into a zero-size shrink schedules no kernels and leaves the view empty
  f = Tensor.full((2,), 0.).contiguous().realize()
  a = f.shrink_to((0,))
  a.assign(Tensor.ones_like(a))
  check_schedule(a, 0)
  self.assertEqual(a.tolist(), [])
def test_zero_size_children(self):
  # a zero-size child of the reduce (ay is shrunk to nothing) doesn't break single-kernel fusion
  r = Tensor.ones(1,2).contiguous().realize().sum(axis=(1,), keepdim=True)
  ax = r.reshape(1)*2
  ay = r.reshape(1).shrink(((1,1),))*2
  out = ax+ay.pad(((1, 0),))
  run_schedule(check_schedule(out, 1))
  self.assertEqual(out.item(), 4.)
def test_preserve_multistage_reduce(self):
  # with SPLIT_REDUCEOP=1 and an input above the split threshold, the two maxes become 4 kernels
  big_enough = getenv("REDUCEOP_SPLIT_THRESHOLD", 32768)
  x = Tensor.randn(big_enough).realize()
  with Context(SPLIT_REDUCEOP=1):
    out = (x - x.max(keepdim=True)).max()
    run_schedule(check_schedule(out, 4))
  np.testing.assert_allclose(out.numpy(), (x.numpy() - x.numpy().max(keepdims=True)).max())
@unittest.skip("these two Tensors are the same")
def test_example_matmul(self):
x = Tensor.eye(64, requires_grad=True)
y = Tensor.eye(64, requires_grad=True)
z = y.matmul(x).sum()
z.backward()
out = x.grad.contiguous()
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
def test_example_matmul_contig(self):
x = Tensor.eye(64, requires_grad=True).contiguous().realize()
y = Tensor.eye(64, requires_grad=True).contiguous().realize()
z = y.matmul(x).sum()
z.backward()
out = x.grad.contiguous()
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
def test_example_matmul_same(self):
x = Tensor.eye(64, requires_grad=True)
z = x.matmul(x).sum()
z.backward()
out = x.grad.contiguous()
run_schedule(check_schedule(out, 1))
# NOTE: the gradient flows twice
np.testing.assert_allclose(out.numpy(), 2*np.ones((64,64)))
def test_multireduce_shrink(self):
Tensor.manual_seed(0)
a = Tensor.randn(32, 32).realize()
b = Tensor.randn(32, 32).realize()
c = Tensor.randn(16).realize()
a_out = a.sum(1)
a_out = a_out[:16]
b_out = b.sum(1)
b_out = b_out[:16]
out = a_out + b_out + c
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), a.numpy().sum(axis=1)[:16] + b.numpy().sum(axis=1)[:16] + c.numpy(), atol=1e-4, rtol=1e-4)
def test_reduce_same_size(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
out0 = a.sum() + 2
out1 = a.sum() + 4
out2 = out0 * out1
run_schedule(check_schedule([out0, out1, out2], 3)) # TODO: 1?
np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-6)
np.testing.assert_allclose(out1.numpy(), out1_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
np.testing.assert_allclose(out2.numpy(), out0_np*out1_np, atol=1e-4, rtol=1e-6)
def test_reduce_multiple_paths(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
out0 = a.sum().exp2()
# out1 has two paths to a.sum()
out1 = a.sum() + out0
run_schedule(check_schedule([out0, out1], 2)) # TODO: 1?
np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+out0_np, atol=1e-4, rtol=1e-6)
def test_multireduce_reduce_multiple_paths(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
out0 = a.sum().exp2()
out1 = a.sum() + out0
b = (a + out0 + out1)
out2 = b.sum().exp2()
out3 = b.sum() + out2
# run_schedule(check_schedule([out0, out1, out2, out3], 1))
run_schedule(check_schedule([out0, out1, out2, out3], 4))
np.testing.assert_allclose(out0.numpy(), np_out0:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), np_out1:=a.numpy().sum()+np_out0, atol=1e-4, rtol=1e-4)
np_b = (a.numpy() + np_out0 + np_out1)
with np.errstate(over='ignore'):
np.testing.assert_allclose(out2.numpy(), np_out2:=np.exp2(np_b.sum()), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out3.numpy(), np_b.sum()+np_out2, atol=1e-4, rtol=1e-4)
def test_reduce_ext_reduce_child(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
b = Tensor.randn(4, 4).realize()
# b.sum() is not a descendant of the fused nodes
out0 = a.sum() + b.sum() + 2
out1 = a.sum() + b.sum() + 4
# run_schedule(check_schedule([out0, out1], 1))
run_schedule(check_schedule([out0, out1], 2))
np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+b.numpy().sum()+2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy().sum()+4, atol=1e-4, rtol=1e-4)
def test_reduce_multiple_paths_midreduce(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
r = a.sum()
out0 = r.exp2()
# reduce node in the indirect path from r to out2
out1 = (a - out0).max()
out2 = r + out1
# run_schedule(check_schedule([r, out0, out1, out2], 1))
run_schedule(check_schedule([r, out0, out1, out2], 4))
np.testing.assert_allclose(r.numpy(), r_np:=a.numpy().sum(), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(r_np), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), out1_np:=(a.numpy() - out0_np).max(), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out2.numpy(), r_np + out1_np, atol=1e-4, rtol=1e-4)
def test_reduce_multiple_paths_midreduce_fused(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
b = Tensor.randn(4, 4).realize()
out0 = a.sum() + 4
out1 = b.max() + out0*2
out2 = a.sum() + out1
# run_schedule(check_schedule([out0, out1, out2], 1))
run_schedule(check_schedule([out0, out1, out2], 3))
np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
np.testing.assert_allclose(out1.numpy(), out1_np:=b.numpy().max() + out0_np*2, atol=1e-4, rtol=1e-6)
np.testing.assert_allclose(out2.numpy(), a.numpy().sum() + out1_np, atol=1e-4, rtol=1e-6)
def test_reduce_multiple_paths_midexpand(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4).realize()
b = Tensor.randn(4, 4, 4).realize()
r = a.sum()
out0 = r.exp2()
# e1 is in the indirect path from a.sum() to out1
e = b + out0
out1 = r + e[0][0][0]
# run_schedule(check_schedule([r, out0, out1, e], 3)) # 1 or 2 or 3? should be 1 (one reduce) but the different outputs might make it 3
run_schedule(check_schedule([r, out0, out1, e], 4))
np.testing.assert_allclose(r.numpy(), r_np:=a.numpy().sum(), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(r_np), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(e.numpy(), e_np:=b.numpy() + out0_np, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), r_np + e_np[0][0][0], atol=1e-4, rtol=1e-4)
def test_reduce_expand_child(self):
Tensor.manual_seed(0)
a = Tensor.randn((32, 32, 32)).realize()
b = Tensor.randn((1, 16)).realize()
out0 = a.sum() + 2
out1 = a.sum() + b
# run_schedule(check_schedule([out0, out1], 2))
run_schedule(check_schedule([out0, out1], 3))
np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy(), atol=1e-4, rtol=1e-4)
def test_std_multireduce_fusion(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.std(-1)
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
def test_scaled_dot_product_attention_multireduce_fusion(self):
Tensor.manual_seed(0)
q = Tensor.randn(32,8,16,8).realize()
k = Tensor.randn(32,8,16,8).realize()
v = Tensor.randn(32,8,16,8).realize()
out = Tensor.scaled_dot_product_attention(q,k,v)
run_schedule(check_schedule(out, 4))
if getenv("CHECK", 1):
import torch
compare = torch.nn.functional.scaled_dot_product_attention(torch.tensor(q.numpy()),torch.tensor(k.numpy()),torch.tensor(v.numpy()))
np.testing.assert_allclose(out.numpy(), compare.numpy(), atol=1e-6, rtol=1e-3)
out = Tensor.scaled_dot_product_attention(q,k,v)
run_schedule(check_schedule(out, 4)) # TODO: should be 1?
if getenv("CHECK", 1):
import torch
compare = torch.nn.functional.scaled_dot_product_attention(torch.tensor(q.numpy()),torch.tensor(k.numpy()),torch.tensor(v.numpy()))
np.testing.assert_allclose(out.numpy(), compare.numpy(), atol=1e-6, rtol=1e-3)
def test_ugly_reduceop_pairing(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
b = Tensor.randn(4, 32).realize()
c = Tensor.randn(4, 32).realize()
out = (c * a.sum(-1, keepdim=True)).sum(-1) + (b * a.sum(-1, keepdim=True)).sum(-1) # a.sum has >1 children but should still fuse
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), \
(c.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1) + (b.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1), atol=1e-4, rtol=1e-4)
def test_reduce_expand_reduce_fusion(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
out = (a+a.sum(-1, keepdim=True)).sum(-1)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
def test_reduce_expand_reduce_expand_fusion(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
out = a+(a+a.sum(-1,keepdim=True)).sum(-1, keepdim=True)
# run_schedule(check_schedule(out, 2))
run_schedule(check_schedule(out, 3))
np.testing.assert_allclose(out.numpy(), \
a.numpy()+(a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4)
def test_branching_reduces_and_expands_fusion(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 32).realize()
out0 = a+a.sum(-1, keepdim=True)
out1 = out0.sum(-1)
# run_schedule(check_schedule(out, 2))
run_schedule(check_schedule([out0, out1], 3))
np.testing.assert_allclose(out0.numpy(), a.numpy()+a.numpy().sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out1.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
def test_multireduce_fusion_simple_sequential(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
y = Tensor.randn(4, 32).realize()
out = (y + x.sum(axis=-1, keepdim=True)).sum(axis=-1)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), (y.numpy() + x.numpy().sum(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
def test_multireduce_fusion_simple_parallel(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
y = Tensor.randn(4, 32).realize()
out = y.sum(axis=-1) + x.sum(axis=-1)
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), y.numpy().sum(axis=-1) + x.numpy().sum(axis=-1), atol=1e-4, rtol=1e-4)
def test_multireduce_fusion_sequential(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.std(-1)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
def test_multireduce_fusion_parallel(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
y = Tensor.randn(4, 32).realize()
out = x.std(-1) + y.std(-1)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 3))
np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1) + y.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
def test_multireduce_diffops_sequential(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = (x - x.max(-1, keepdim=True)).sum(-1)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), (x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
def test_multireduce_fusion_diffops_parallel(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
y = Tensor.randn(4, 32).realize()
out = x.sum(-1) + y.max(-1)
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), x.numpy().sum(axis=-1) + y.numpy().max(axis=-1), atol=1e-4, rtol=1e-4)
def test_multireduce_fusion_sequential_and_parallel(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
y = Tensor.randn(4, 32).realize()
mu = (x - x.max(axis=-1, keepdim=True)).mean(axis=-1, keepdim=True) + (y - y.max(axis=-1, keepdim=True)).mean(axis=-1, keepdim=True)
out = [((x - mu).square().sum(-1)/x.shape[-1]).sqrt(), ((y - mu).square().sum(-1)/y.shape[-1]).sqrt()]
np_mu = (x.numpy() - x.numpy().max(axis=-1, keepdims=True)).mean(axis=-1, keepdims=True) + \
(y.numpy() - y.numpy().max(axis=-1, keepdims=True)).mean(axis=-1, keepdims=True)
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 5))
np.testing.assert_allclose(out[0].numpy(), np.sqrt(np.square(x.numpy() - np_mu).sum(-1)/x.shape[-1]), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4)
def test_cumsum_parallel_reduce_fused(self):
# two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END (same nesting context = should merge)
step, num_steps = 513, 10
t = Tensor.arange(step).float().realize()
phase = t.cumsum()
tiled = phase.repeat((num_steps,)).reshape(num_steps, step)
pattern = Tensor([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1)
out = (tiled * pattern).flatten()
expected = np.tile(np.arange(step).astype(np.float32).cumsum(), num_steps).reshape(num_steps, step)
expected = (expected * np.array([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1)).flatten()
np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
@unittest.skipIf(Device.DEFAULT == "CL", "TODO: fails on CI CL")
def test_reduce_different_nesting_depth(self):
# two REDUCEs sharing the same RANGE at different nesting depths must NOT merge
x = Tensor.arange(768).reshape(3, 256).float()
np.testing.assert_allclose((x.sum(axis=1) + x.sum(axis=1).sum()).numpy(), x.numpy().sum(axis=1) + x.numpy().sum(axis=1).sum())
def test_multimatmul_fusion(self):
Tensor.manual_seed(0)
a,b = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()
c,d = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()
out = a@b + c@d
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), a.numpy()@b.numpy() + c.numpy()@d.numpy(), atol=1e-4, rtol=1e-4)
def test_softmax_fusion(self):
  # softmax currently schedules as 3 kernels; checked against the stable exp/sum formula
  Tensor.manual_seed(0)
  x = Tensor.randn(4, 12, 64, 64).realize()
  out = x.softmax()
  run_schedule(check_schedule(out, 3))
  expected = (x_exp:=np.exp(x.numpy()-x.numpy().max(-1, keepdims=True)))/x_exp.sum(-1, keepdims=True)
  np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
def test_layernorm_onelayer_fusion(self):
Tensor.manual_seed(0)
layer = nn.LayerNorm([10, 10])
layer.weight = Tensor.randn(10,10).realize()
layer.bias = Tensor.randn(10,10).realize()
x = Tensor.randn(20, 5, 10, 10).realize()
out = layer(x)
# run_schedule(check_schedule(out, 2))
run_schedule(check_schedule(out, 3))
y = (x.numpy() - x.numpy().mean(layer.axis, keepdims=True))
expected = y / np.sqrt((y*y).mean(layer.axis, keepdims=True) + layer.eps)
np.testing.assert_allclose(out.numpy(), expected * layer.weight.numpy() + layer.bias.numpy(), atol=1e-4, rtol=1e-4)
def test_multireduce_simple_chase(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4, 4).realize()
r = (a + (a.sum(0, keepdim=True) + 6)).sum(0) * 2
b = r.sum(0) + 8
c = r.sum(1) + 12
np_r = (a.numpy() + (a.numpy().sum(0) + 6)).sum(0) * 2
# schedule = check_schedule([b,c], 3)
# self.assertIs(schedule[0].ast[0].src[0].arg, Ops.MUL)
schedule = check_schedule([b,c], 4)
run_schedule(schedule)
np.testing.assert_allclose(b.numpy(), np_r.sum(0) + 8, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(c.numpy(), np_r.sum(1) + 12, atol=1e-4, rtol=1e-4)
def test_multireduce_push_permute_chase(self):
Tensor.manual_seed(0)
a = Tensor.randn(4, 4, 4).realize()
b = Tensor.randn(4, 4).realize()
r = a.sum(2) + b
d = r.T * 4
e = r * (d + a).sum(2)
schedule = check_schedule([d, e], 3) # make sure it doesn't fuse
run_schedule(schedule)
np.testing.assert_allclose(d.numpy(), (a.numpy().sum(2) + b.numpy()).T * 4, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(e.numpy(), (a.numpy().sum(2) + b.numpy()) * (d.numpy() + a.numpy()).sum(2), atol=1e-4, rtol=1e-4)
def test_multireduce_push_shrink_chase(self):
Tensor.manual_seed(0)
a = Tensor.randn(16, 16).realize()
b = Tensor.randn(4).realize()
c = Tensor.randn(16, ).realize()
d = Tensor.randn(16, 16).realize()
r = a.sum(1) + c
out = r[:4] * b + d.sum(1)[:4]
schedule = check_schedule(out, 1)
run_schedule(schedule)
np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4)
def test_multireduce_midreduce_nochase(self):
Tensor.manual_seed(0)
a = Tensor.randn(16, 16).realize()
b = (a.sum(0)+a.max(0) + a.max(1)+a.sum(1)) + 2
schedule = check_schedule(b, 1)
run_schedule(schedule)
np.testing.assert_allclose(b.numpy(), a.numpy().sum(0)+a.numpy().max(0) + a.numpy().max(1)+a.numpy().sum(1)+2, atol=1e-4, rtol=1e-4)
# pattern in test_transformer
def test_partial_fuse1(self):
Tensor.manual_seed(0)
a = Tensor.randn(16, 16).realize()
b = Tensor.randn(16, 16).realize()
c = a.sum() + 2
d = (a.sum() - b.sum()) * 4
# run_schedule(check_schedule([c, d], 1))
run_schedule(check_schedule([c, d], 2))
np.testing.assert_allclose(c.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(d.numpy(), (a.numpy().sum() - b.numpy().sum()) * 4, atol=1e-4, rtol=1e-4)
# pattern in conv
def test_partial_fuse2(self):
Tensor.manual_seed(0)
a = Tensor.randn(16, 16).realize()
b = Tensor.randn(16, 16).realize()
c = a.sum() + 2
d = b.sum() - c
# run_schedule(check_schedule([c, d], 1))
run_schedule(check_schedule([c, d], 2))
np.testing.assert_allclose(c.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(d.numpy(), b.numpy().sum()-(a.numpy().sum()+2), atol=1e-4, rtol=1e-4)
# pattern in adam
def test_partial_fuse3(self):
Tensor.manual_seed(0)
a = Tensor.randn(16, 16).realize()
b = Tensor.randn(16, 16).realize()
c = a.sum() + 2
d = a.sum() * 2
e = c * d
f = b.sum() - e
# run_schedule(check_schedule([c, d, e, f], 1))
run_schedule(check_schedule([c, d, e, f], 4))
np.testing.assert_allclose(c.numpy(), c_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(d.numpy(), d_np:=a.numpy().sum()*2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(e.numpy(), e_np:=c_np*d_np, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(f.numpy(), b.numpy().sum() - e_np, atol=1e-4, rtol=1e-4)
def test_partial_fuse4(self):
Tensor.manual_seed(0)
a = Tensor.randn(16, 16).realize()
b = Tensor.randn(16, 16).realize()
c = a.sum() + 2
d = a.sum() * 2
e = c * d
f = (b - d).sum() - e
# run_schedule(check_schedule([c, d, e, f], 1))
run_schedule(check_schedule([c, d, e, f], 4))
np.testing.assert_allclose(c.numpy(), c_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(d.numpy(), d_np:=a.numpy().sum()*2, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(e.numpy(), e_np:=c_np*d_np, atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(f.numpy(), (b.numpy()-d_np).sum()-e_np, atol=1e-4, rtol=1e-4)
def test_pad_reduce_safe(self):
Tensor.manual_seed(0)
a = Tensor.rand(3, 4, 5).realize()
b = Tensor.rand(3, 4, 5).realize()
out = (a + b).pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum().contiguous()
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6)
def test_multireduce_pad_reduce_safe(self):
Tensor.manual_seed(0)
a = Tensor.randn(3, 4, 5).realize()
b = Tensor.randn(3, 4, 5).realize()
out = (a.pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum(keepdim=True)+b.pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum()).contiguous()
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), np.pad(a.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(keepdims=True) + \
np.pad(b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-4, rtol=1e-4)
def test_pad_reduce_unsafe(self):
Tensor.manual_seed(0)
a = Tensor.rand(3, 4, 5).realize()
out = a.log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum().contiguous()
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6)
def test_multireduce_pad_reduce_unsafe(self):
Tensor.manual_seed(0)
a = Tensor.randn(3, 4, 5).abs().realize()
b = Tensor.randn(3, 4, 5).abs().realize()
out = (a.log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum()+b).abs().log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum().contiguous()
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), np.pad(np.log2(np.abs(np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum() + \
b.numpy())), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=3e-4, rtol=1e-5)
def test_shrink_pad_safe(self):
  # shrink then pad around an add keeps one kernel; the padded element reads 0
  a = Tensor.ones((3, )).contiguous().realize()
  b = Tensor.ones((3, )).contiguous().realize()
  out = (a + b).shrink(((0, 1),)).pad(((0, 1),)).contiguous()
  run_schedule(check_schedule(out, 1))
  np.testing.assert_equal(out.numpy(), [2, 0])
def test_shrink_pad_unsafe(self):
  # exp2 is pad-unsafe (exp2(0) != 0), yet shrink+pad still schedules one correct kernel
  a = Tensor.ones((3, )).contiguous().realize()
  out = a.exp2().shrink(((0, 1),)).pad(((0, 1),)).contiguous()
  run_schedule(check_schedule(out, 1))
  np.testing.assert_equal(out.numpy(), [2, 0])
def test_base_change_shrink_pad(self):
a = Tensor.ones(3, 3).contiguous().realize()
b = a.exp2()
c = b[:-1, :-1]
d = c.pad(((0, 1), (0, 1))) * 2
run_schedule(check_schedule(d, 1))
np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:-1, :-1], ((0, 1), (0, 1)))*2)
def test_base_change_expand_pad(self):
a = Tensor.ones(3, 3).contiguous().realize()
b = a.exp2()
c = b[:, None, :]
d = c.pad(((0, 0), (1, 1), (0, 0))) * 2
run_schedule(check_schedule(d, 1))
np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:, None, :], ((0, 0), (1, 1), (0, 0)))*2)
def test_fuse_arange_pad_replicate_mode(self):
x = Tensor.empty(3,3,3,3, requires_grad=True)
y = x.pad((-1,2,2,-1), mode="replicate")
dx = y.sum().gradient(x)[0]
sched = check_schedule(dx, 1)
run_schedule(sched)
np.testing.assert_allclose(dx.numpy(), [[[[0.,3.,9.],[0,1.,3.],[0.,0.,0.]]]*3]*3)
# TODO like openpilot with imagef
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_base_change_expand_expand(self):
a = Tensor.ones(4, 4).contiguous().realize()
b = a.cast(dtypes.half).expand(2, 4, 4)
c = b.cast(dtypes.int).expand(2, 2, 4, 4)
run_schedule(check_schedule(c, 1))
np.testing.assert_equal(c.numpy(), np.ones(((2, 2, 4, 4)), dtype=np.int32))
def test_base_change_pad_expand(self):
  # pad + int cast + expand over an add fuses into one kernel
  a = Tensor.full((4, 4), 1.).contiguous().realize()
  b = Tensor.full((4, 4), 2.).contiguous().realize()
  c = (a + b).pad(((1, 1), (1, 1)))
  d = c.cast(dtypes.int).expand((2, 6, 6)) * 4
  run_schedule(check_schedule(d, 1))
  c_np = np.pad((np.full((4, 4), 2., dtype=np.float32) + np.full((4, 4), 1., dtype=np.float32)), ((1, 1), (1, 1)), constant_values=0.0)
  # NOTE(review): d casts to dtypes.int but the reference uses np.half; assert_equal only
  # compares values (all integral here) so it passes, but np.int32 would match the cast under test
  np.testing.assert_equal(d.numpy(), np.broadcast_to(c_np.astype(np.half), (2, *c_np.shape)) * 4)
  def test_pad_reduce_unsafe_multiview_st(self):
    # padding a row-normalized value (division by a reduce) is pad-unsafe: takes 3 kernels
    P = Tensor.ones(3, 3).contiguous()
    sums = P.sum(axis=1, keepdim=True)
    P /= sums
    p = P[0]
    p = p.pad(((1, 0), ))
    p = p.repeat([2])
    run_schedule(check_schedule(p, 3))
    tiny_ret = p.numpy()
    # numpy reference of the exact same computation
    P = np.ones((3, 3), dtype=np.float32)
    sums = P.sum(axis=1, keepdims=True)
    P /= sums
    p = P[0]
    p = np.pad(p, (1, 0), 'constant')
    p = np.tile(p, 2)
    np.testing.assert_allclose(tiny_ret, p)
  @unittest.skip("disabling subbuffer manually isn't supported anymore")
  def test_bitcast_disable_subbufer(self):
    # bitcasts built with allow_buffer_view=False should fuse into one kernel
    x = cast(UOp, Tensor.empty(1, dtype=dtypes.float32).realize().uop)
    a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=False)
    b = x.cast(dtypes.int32, True, allow_buffer_view=False)
    b = a.alu(Ops.ADD, b)
    check_schedule(b, 1)
def test_conv2d(self): _test_conv2d(4)
def test_conv2d_fused(self): _test_conv2d(4)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_conv2d_half(self): _test_conv2d(4, dtype=dtypes.half)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Causes other tests to fail")
def test_conv2d_fused_half(self): _test_conv2d(4, dtype=dtypes.half)
  def test_schedule_mem_used_with_inputs(self):
    # scheduling (without realizing) should only account for the one realized input buffer: 256 floats = 1024 bytes
    gc.collect()
    base = GlobalCounters.mem_used
    x = Tensor.ones(256).contiguous().realize()
    (x+Tensor.ones(256).contiguous()).schedule()
    gc.collect()
    self.assertEqual(GlobalCounters.mem_used-base, 1024)
  @unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
  def test_image_dot_f16_fusion(self):
    # two chained matmuls with a residual under IMAGE + FLOAT16: expects 5 kernels
    with Context(FLOAT16=1, OPENPILOT_HACKS=1):
      def cnt():
        x, y, z = Tensor.empty((64, 64), dtype='float'), Tensor.empty((64, 64), dtype='float'), Tensor.empty((64, 64), dtype='float')
        a = (x @ y).relu()
        sched = ((a @ z).relu() + a).schedule()
        for si in sched: si.lower()
        return len([si for si in sched if isinstance(si.prg, CompiledRunner)])
      with Context(IMAGE=1):
        self.assertEqual(cnt(), 5)
  @unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
  def test_image_f16_residual_fusion(self):
    # residual MLP block feeding two output heads under IMAGE + FLOAT16: expects 9 kernels
    with Context(FLOAT16=1, OPENPILOT_HACKS=1):
      def cnt():
        inp = Tensor.empty((512,), dtype='float')
        b1, b2 = Tensor.empty((512, 1024), dtype='float'), Tensor.empty((1024, 512), dtype='float')
        c1, c2 = Tensor.empty((1024,), dtype='float'), Tensor.empty((512,), dtype='float')
        rb = (((((inp @ b1) + c1).relu() @ b2) + c2).relu() + inp).relu()
        b16, c16 = Tensor.empty((512, 16), dtype='float'), Tensor.empty((16,), dtype='float')
        b32, c32 = Tensor.empty((512, 32), dtype='float'), Tensor.empty((32,), dtype='float')
        sched = Tensor.schedule((rb @ b16 + c16).relu(), (rb @ b32 + c32).relu())
        for si in sched: si.lower()
        return len([si for si in sched if isinstance(si.prg, CompiledRunner)])
      with Context(IMAGE=1):
        self.assertEqual(cnt(), 9)
  @unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
  def test_image_conv_fusion(self):
    # two grouped convs with a residual add under IMAGE: expects 5 kernels
    with Context(OPENPILOT_HACKS=1):
      def cnt():
        x, y, z = Tensor.empty((1, 4, 3, 3)), Tensor.empty((4, 1, 3, 3)), Tensor.empty((4, 1, 7, 7))
        a = x.conv2d(y, Tensor.empty(4), groups=4, padding=1)
        b = a.conv2d(z, groups=4, padding=3)
        sched = (a + b).schedule()
        for si in sched: si.lower()
        return len([si for si in sched if isinstance(si.prg, CompiledRunner)])
      with Context(IMAGE=1):
        self.assertEqual(cnt(), 5)
  def _test_fusion(self, shapes, f, cnt):
    # helper: build random tensors of the given shapes, check f schedules to cnt kernels,
    # and (unless COMPARE=0) compare the result against torch running the same f
    with Context(DEBUG=0, TRACK_MATCH_STATS=0): args = [Tensor.randn(s).realize() for s in shapes]
    run_schedule(check_schedule(compare:=f(*args), cnt))
    if getenv("COMPARE", 1):
      import torch
      good = f(*[torch.tensor(x.numpy()) for x in args])
      np.testing.assert_allclose(compare.numpy(), good.numpy(), atol=1e-4, rtol=1e-4)
  def test_late_fusion_simple(self):
    # keepdim reduce + broadcasted add fuses into one kernel
    self._test_fusion([(4, 4), (4, 1)], lambda a,b:a.sum(1, keepdim=True)+b, 1)
  def test_late_fusion_post_reshape(self):
    # a reshape after the reduce doesn't block late fusion
    self._test_fusion([(4, 4), (1, 4)], lambda a,b:a.sum(1).reshape(b.shape)+b, 1)
  def test_late_fusion_post_permute(self):
    # a permute after the reduce doesn't block late fusion
    self._test_fusion([(4, 6, 4), (4, 4, 1)], lambda a,b:a.sum(1, keepdim=True).permute((2, 0, 1))+b, 1)
  def test_late_fusion_double_transpose(self):
    # permute out and back around the reduce, contiguous at the end: still one kernel
    self._test_fusion([(32, 16, 1)],
      lambda a:(a.expand(32, 16, 16).sum((2,), keepdim=True).permute((1, 0, 2))+2).permute((1, 0, 2)).contiguous(), 1)
  def test_late_fusion_post_expand(self):
    # broadcasting the reduce result back over the reduced axis forces a second kernel
    self._test_fusion([(32, 32)], lambda a:a-a.sum(1), 2)
  def test_cast_padded_view(self):
    # casting a padded view realizes the full padded size (8 elements), not just the base
    a = Tensor.arange(4).reshape(1, 4)
    casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float)
    casted_view.realize()
    self.assertEqual(casted_view.uop.base.realized.size, 8)
    contig = casted_view.contiguous().realize()
    self.assertEqual(contig.uop.base.realized.size, 8)
    self.assertListEqual(contig.tolist(), [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]])
  # NOTE: we only reorder CAST if it's an EXPAND
  def test_cast_after_shrink(self):
    # casting a shrunk view only realizes the shrunk size (2 elements)
    a = Tensor.arange(4).reshape(1, 4)
    casted_view = a.shrink(((0, 1), (0, 2))).cast(dtypes.float)
    casted_view.realize()
    self.assertEqual(casted_view.uop.base.realized.size, 2)
    realized_view = casted_view.contiguous().realize()
    self.assertEqual(realized_view.uop.base.realized.size, 2)
    self.assertListEqual(realized_view.tolist(), [[0, 1]])
  def test_cast_const_view(self):
    # cast of a const tensor: one kernel to realize it, then its contiguous view needs none
    a = Tensor.ones((4, 4), dtype=dtypes.float32)
    casted_view = a.cast(dtypes.int32)
    run_schedule(check_schedule(casted_view, 1))
    realized_const_view = casted_view.contiguous()
    run_schedule(check_schedule(realized_const_view, 0))
    self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
  @given(strat.sampled_from(dtypes.all), strat.sampled_from(dtypes.all))
  @unittest.skip("kernel count depends on input")
  def test_cast_padded_const(self, dt1, dt2):
    # cast of a padded const view is free (0 kernels); materializing it takes one
    assume(is_dtype_supported(dt1) and is_dtype_supported(dt2))
    a = Tensor(1, dtype=dt1).reshape(1, 1).pad(((1, 1), None))
    casted_view = a.cast(dt2)
    run_schedule(check_schedule(casted_view, 0))
    realized_const_view = casted_view.contiguous()
    run_schedule(check_schedule(realized_const_view, 1))
    np.testing.assert_equal(realized_const_view.numpy(), [[0], [1], [0]])
def test_simple_indexing(self):
X = Tensor.randn(10, 10).realize()
idxs = Tensor([0, 2]).realize()
xt = X[idxs]
run_schedule(check_schedule(xt, 1))
np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()])
  def test_simple_indexing_alt(self):
    # advanced indexing with two python lists is one kernel
    X = Tensor.arange(16).reshape(4, 4)
    xt = X[[1, 2], [-1, 2]]
    run_schedule(check_schedule(xt, 1))
    np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [-1, 2]])
def test_advanced_indexing(self):
X = Tensor.arange(10)+1
xt = X[[0, -1]]
run_schedule(check_schedule(xt, 1))
np.testing.assert_equal(xt.numpy(), (np.arange(10)+1)[[0, -1]])
  def test_advanced_indexing_alt(self):
    # indexing with a list of Tensors selects element [2, 1] of X, which is 6
    X = Tensor.arange(6).reshape(3, 2)+1
    xt = X[[Tensor([2]), Tensor([1])]]
    run_schedule(check_schedule(xt, 1))
    np.testing.assert_equal(xt.numpy(), 6)
def test_push_through_reshape(self):
Tensor.manual_seed(0)
x = Tensor.randn(10, 20).realize()
out = x.argmax(1)
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), np.argmax(x.numpy(), 1))
  def test_arange_push_through_expand(self):
    # arange broadcast against a realized tensor fuses into the sum kernel
    Tensor.manual_seed(0)
    a = Tensor.arange(4,)
    b = Tensor.randn(4, 4).realize()
    out = (a+b).sum()
    run_schedule(check_schedule(out, 1))
    np.testing.assert_allclose(out.numpy(), (np.arange(4)+b.numpy()).sum(), atol=1e-5)
def test_argmin(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.argmin(-1)
run_schedule(check_schedule(out, 2))
np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1))
def test_argmax(self):
Tensor.manual_seed(0)
x = Tensor.randn(4, 32).realize()
out = x.argmax(-1)
run_schedule(check_schedule(out, 2))
np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1))
  def test_arange_transposed(self):
    # arange*x transposed then summed is still a single kernel
    Tensor.manual_seed(0)
    x = Tensor.randint(4, 1).realize()
    a = ((Tensor.arange(4,)*x).T).sum()
    run_schedule(check_schedule(a, 1))
    np.testing.assert_equal(a.numpy(), (np.arange(4)*x.numpy()).T.sum())
  def test_div_padded_arange(self):
    # idiv by a linspace, padded, then row-summed: one kernel
    x = Tensor.full((2,2), 16)
    y = x.idiv(Tensor.linspace(2, 8, steps=4, dtype=dtypes.int).reshape(2,2)).pad(((1,1), (1,1)))
    out = y.sum(axis=1)
    run_schedule(check_schedule(out, 1))
    self.assertListEqual(out.tolist(), [0, 12, 4, 0])
  def test_arange_transposed_descendants(self):
    # descendants of the transposed arange product still fuse into one kernel
    Tensor.manual_seed(0)
    x = Tensor.randint(4, 1).realize()
    a = (Tensor.arange(4,)*x).T
    b = Tensor.randint(4, 4).realize()
    out = (a+b).sum()
    run_schedule(check_schedule(out, 1))
    np.testing.assert_equal(out.numpy(), ((np.arange(4)*x.numpy()).T+b.numpy()).sum())
  def test_arange_index(self):
    # indexing an unrealized arange needs no kernel of its own: one kernel total
    Tensor.manual_seed(0)
    x = Tensor.randn(5, 2).realize()
    a = Tensor.arange(10)
    out = (x + a[2]).sum()
    run_schedule(check_schedule(out, 1))
    np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
  def test_arange_index_contiguous(self):
    # contiguous forces the arange to realize, so this takes two kernels
    Tensor.manual_seed(0)
    x = Tensor.randn(5, 2).realize()
    a = Tensor.arange(10).contiguous()
    out = (x + a[2]).sum()
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
  def test_arange_index_child(self):
    # a child op of the arange (arange+1) still fuses when indexed
    Tensor.manual_seed(0)
    x = Tensor.randn(5, 2).realize()
    a = Tensor.arange(10)+1
    out = (x + a[2]).sum()
    run_schedule(check_schedule(out, 1))
    np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
  def test_user_contiguous(self):
    # a user-requested contiguous on the arange child forces a second kernel
    Tensor.manual_seed(0)
    x = Tensor.randn(5, 2).realize()
    a = (Tensor.arange(10)+1).contiguous()
    out = (x + a[2]).sum()
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
@unittest.skip("BUFFER_VIEW no longer supported on non-disk devices")
def test_arange_view_op(self):
a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).contiguous()
sched = run_schedule(check_schedule(a, 1))
self.assertIs(sched[1].ast.op, Ops.BUFFER_VIEW)
np.testing.assert_equal(a.numpy(), [[4, 5]])
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_precompute_freqs_cis(self):
    # llama freqs_cis precompute should fuse into one kernel and match a second run
    from extra.models.llama import precompute_freqs_cis
    args = {"dim":32, "end":2048, "theta":10000}
    fused = precompute_freqs_cis(**args)
    run_schedule(check_schedule(fused, 1))
    if getenv("CHECK", 1):
      ref = precompute_freqs_cis(**args)
      run_schedule(check_schedule(ref, 1))
      np.testing.assert_equal(fused.numpy(), ref.numpy())
  def test_fuse_assign_contiguous(self):
    # assigning a contiguous arange into a shrunk view of x: 2 kernels
    x = Tensor.zeros(4, 4, dtype=dtypes.int).contiguous().realize()
    a = Tensor.arange(8).reshape(4, 2)
    run_schedule(check_schedule(x.shrink((None, (0, 2))).assign(a.contiguous()), 2))
    np.testing.assert_equal(x.numpy(), [[0, 1, 0, 0], [2, 3, 0, 0], [4, 5, 0, 0], [6, 7, 0, 0]])
def test_assign_non_contiguous_alt(self): self.test_assign_non_contiguous(alt=True)
  def test_assign_non_contiguous(self, alt=False):
    # assign into a shrunk (non-contiguous) view; alt=True shrinks rows instead of columns
    x = (Tensor.arange(16)-100).reshape(4,4).contiguous().realize()
    xref = x.numpy()
    if alt:
      y = Tensor.randint(2, 4).contiguous().realize()
      a = Tensor.arange(8).reshape(2, 4)+y
      tst = x.shrink(((0, 2), None)).assign(a).realize()
      xref[:2, :] = np.arange(8).reshape(2, 4)+y.numpy()
    else:
      y = Tensor.randint(4, 2).contiguous().realize()
      a = Tensor.arange(8).reshape(4, 2)+y
      tst = x.shrink((None, (0, 2))).assign(a).realize()
      xref[:, :2] = np.arange(8).reshape(4, 2)+y.numpy()
    # both the mutated base and the assign's return value must match the reference
    np.testing.assert_equal(x.numpy(), xref)
    np.testing.assert_equal(tst.numpy(), a.numpy())
  def test_setitem_sched(self, mop=lambda x:x, expected_kcount=1):
    # a.assign(a + mop(a)) should schedule in expected_kcount kernels, where mop is a movement op on a
    a = Tensor.arange(16, device="CPU").reshape(4, 4).contiguous().realize()
    a2 = mop(a)
    expected = (a+a2).tolist()
    a.assign(a+a2)
    kcount = len(sched:=a.schedule())
    run_schedule(sched)
    self.assertListEqual(a.tolist(), expected)
    self.assertEqual(kcount, expected_kcount)
def test_setitem_permuted_sched(self): self.test_setitem_sched(lambda x: x.T, 2)
def test_setitem_paddded_sched(self): self.test_setitem_sched(lambda x: x.shrink_to(4, 1).pad_to(4, 4), 1)
  def test_setitem_const_fused(self):
    # https://github.com/tinygrad/tinygrad/issues/10690
    # const setitem stays lazy (0 kernels) and realizes in exactly 1 kernel
    a = Tensor.arange(16).contiguous().realize()
    GlobalCounters.reset()
    a[4] = 3
    self.assertEqual(GlobalCounters.kernel_count, 0)
    a.realize()
    self.assertEqual(GlobalCounters.kernel_count, 1)
    self.assertListEqual(a.tolist(), [0, 1, 2, 3, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
  def test_no_extra_contiguous_on_setitem_assign_back(self):
    # pattern: contiguous copy, advanced setitem, assign back (e.g. torch backend _view_write)
    # expects exactly 4 kernels, i.e. no extra contiguous copies are scheduled
    base = Tensor.arange(16).reshape(4, 4).contiguous()
    flat_base = base.reshape(16).contiguous()
    idx = Tensor([1,2,5,6], dtype=dtypes.int32)
    flat_base[idx] = Tensor([99,99,99,99])
    base.assign(flat_base.reshape(4, 4))
    sched = check_schedule(base, 4)
    run_schedule(sched)
    expected = list(range(16))
    for i, v in zip([1,2,5,6], [99,99,99,99]): expected[i] = v
    np.testing.assert_equal(base.reshape(16).numpy(), expected)
  def test_sparse_categorical_crossentropy_simple(self):
    # sparse CE on a 2x3 logit batch: 3 kernels, matches the precomputed loss value
    X = Tensor([[0, 2, 3], [1, 2, 3]]).realize()
    Y = Tensor([1, 2]).realize()
    loss = X.sparse_categorical_crossentropy(Y)
    run_schedule(check_schedule(loss, 3))
    np.testing.assert_allclose(loss.item(), 0.878309, atol=1e-5, rtol=1e-6)
  def test_const_folding_alt(self):
    # t is all ones so (t < 0.) folds to const False and a folds to zeros
    t = Tensor.full((2,), 1.)
    lt = (t < 0.)
    a = Tensor.empty(2).assign(t*lt.where(-1., 0.))
    b = Tensor.empty(2, dtype=dtypes.bool).assign(lt)
    Tensor.realize(a, b)
    self.assertEqual(a.tolist(), [0., 0.])
    self.assertEqual(b.tolist(), [False, False])
  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "Validation error on WebGPU")
  def test_mnist_val(self):
    # sparse CE over gathered mnist labels: 4 kernels, matches torch CrossEntropyLoss
    from tinygrad.nn.datasets import mnist
    import torch
    _, Y_train, _, _ = mnist()
    samples = Tensor.randint(BS:=getenv("BS", 512), high=cast(int,Y_train.shape[-1])).realize()
    yt = Tensor.randn(BS, 10).realize()
    loss = yt.sparse_categorical_crossentropy(Y_train[samples])
    run_schedule(check_schedule(loss, 4))
    loss_fused = loss.numpy()
    loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())])
    np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6)
  def test_arange_fuse_grouped_children(self):
    # two consumers of the same fused arange reduce share one reduce kernel
    X = Tensor.randn(4, 4).realize()
    r = (X+Tensor.arange(16).reshape(4, 4)).sum()
    out0 = r+2
    out1 = r+3
    run_schedule(check_schedule([out0, out1], 2)) # TODO: 1?
    r_ref = (X.numpy()+np.arange(16).reshape(4, 4)).sum()
    np.testing.assert_allclose(out0.numpy(), r_ref+2, rtol=2e-7)
    np.testing.assert_allclose(out1.numpy(), r_ref+3, rtol=2e-7)
def test_recursive_swizzle(self):
a = Tensor([1,2,3,4]).realize()
for _ in range(24): a = a + a
new_uop = a.reshape(4,1).realize().uop
assert new_uop.base.op is Ops.BUFFER
  def test_self_assign_no_empty_kernel(self):
    # a.assign(a / 1) simplifies to a no-op: zero kernels scheduled, data unchanged
    for shape in [(3, 3), (4, 4)]:
      a = Tensor.ones(*shape).contiguous().realize()
      a.assign(a / 1)
      run_schedule(check_schedule(a, 0, filter_sink=False))
      self.assertListEqual(a.tolist(), [[1.]*shape[1]]*shape[0])
class TestLimitBufs(unittest.TestCase):
  # Schedules referencing many distinct buffers must respect the kernel buffer limit without crashing.
  @unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
  def test_limit_bufs_with_var(self):
    # 31 buffers each indexed with two bound Variables; each buffer contributes 1+1
    N = 31
    with Context(TRACK_MATCH_STATS=0, DEBUG=0):
      bufs = [Tensor([1]*10).contiguous().realize() for i in range(N)]
      vi = Variable("i", 0, 9).bind(1)
      vj = Variable("j", 0, 9).bind(2)
      root = bufs[0][vi] + bufs[0][vj]
      for X in range(1,N): root = root + bufs[X][vi] + bufs[X][vj]
      self.assertEqual(root.item(), N * 2)
  def test_limit_bufs_arange_condition(self):
    # WHERE with arange-based condition (pure index math, no device) and many buffer loads should not crash limit_bufs
    with Context(MAX_KERNEL_BUFFERS=8):
      N = 8
      idx = Tensor.arange(N)
      base = Tensor.zeros(N)
      for i in range(4):
        a, b = Tensor.rand(N).realize(), Tensor.rand(N).realize()
        base = (idx >= i).where(a + b, base)
      assert all(x > 0 for x in base.tolist())
class TestSwizzle(unittest.TestCase):
  # Swizzles: movement ops applied on top of reduces, and whether the scheduler can resolve them.
  def test_swizzle_simple(self):
    Tensor.manual_seed(0)
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      a = Tensor.randint(32, 32).realize()
    r = (a+a).sum(1).sum(0)
    # double reduce collapses to a single reduce
    run_schedule(check_schedule(r, 1))
    self.assertEqual(r.numpy(), (a.numpy()+a.numpy()).sum(1).sum(0))
  def test_single_swizzle(self):
    Tensor.manual_seed(0)
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      a = Tensor.randint(4, 1).realize()
      b = Tensor.ones((1, 1), dtype=a.dtype).contiguous().realize()
      # ADD(REDUCE(RESHAPE(LOAD)), LOAD) to ADD(REDUCE(RESHAPE(LOAD))), RESHAPE(LOAD)
      r = a.sum(0)+b
      run_schedule(check_schedule(r, 1))
      self.assertEqual(r.numpy(), a.numpy().sum(0)+1)
  def test_double_swizzle_possible(self):
    Tensor.manual_seed(0)
    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
      a = Tensor.randint(4,).realize()
      b = Tensor.randint(4,).realize()
    # parallel reduce!
    add = a.sum(0)+b.sum(0)
    run_schedule(check_schedule(add, 1))
    self.assertEqual(add.numpy(), a.numpy().sum(0)+b.numpy().sum(0))
  def test_swizzle_reduceop(self):
    # reshape+expand on top of a reduce feeding an elementwise add
    Tensor.manual_seed(0)
    x = Tensor.randn(4,4).realize()
    y = Tensor.randn(4,4,4).realize()
    out = x.reshape(4,4,1).expand(4,4,4).sum(axis=(1,))+y
    run_schedule(check_schedule(out, 2)) # TODO: 1?
    np.testing.assert_allclose(out.numpy(), np.tile(x.numpy().reshape(4,4,1), (1,1,4)).sum(axis=1)+y.numpy())
  def test_permute_rewrite(self):
    # permute on top of the reduce output, then add
    x = Tensor.randn(4, 4, 16).realize()
    y = Tensor.randn(4, 1, 16).realize()
    z = Tensor.randn(4, 4, 1).realize()
    t = (x*y).sum(axis=(0, 2)).reshape(1, 4, 1).permute(0, 2, 1)+z
    run_schedule(check_schedule(t, 2)) # TODO: 1?
    t_np = (x.numpy()*y.numpy()).sum(axis=(0, 2)).reshape(1, 4, 1).transpose(0, 2, 1)+z.numpy()
    np.testing.assert_allclose(t.numpy(), t_np, atol=1e-6, rtol=1e-3)
  @unittest.skip("TODO: this swizzle isn't resolvable when there's a mask")
  def test_swizzle_failure_permute(self):
    a = Tensor.empty(45,65).T.reshape(65,1,45).pad((None,None,(0,45))).expand(65,45,90)
    b = Tensor.empty(45,65)
    a_reduce = a.sum(axis=(2,), keepdim=True).sum(axis=(1,))
    b_reduce = b.sum(axis=(0,))
    t = a_reduce+b_reduce
    run_schedule(check_schedule(t, 1))
  def test_parallel_reduce_possible(self):
    # same reduce sizes: the two sums run in one kernel
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 2, 2).realize()
    y = Tensor.randn(4, 2, 2).realize()
    t = x.sum(axis=1)+y.sum(axis=1)
    run_schedule(check_schedule(t, 1))
    np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3)
  # kernels can only have 1 or n in each dim
  def test_dont_parallelize_different_n(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 2, 2).realize()
    y = Tensor.randn(4, 3, 2).realize()
    t = x.sum(axis=1)+y.sum(axis=1)
    run_schedule(check_schedule(t, 1))
    np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3)
  def test_unsafe_pad(self):
    # padding after a division by a reduce is unsafe to fuse: 3 kernels
    x = Tensor.full((2,2), 1.0).contiguous()
    y = x*x.sum((1,)).reciprocal()
    t = y.pad(((0,1),None))
    run_schedule(check_schedule(t, 3))
    np.testing.assert_equal(t.numpy(), [[0.5, 0.5], [0.5, 0.5], [0., 0.]])
zero_pm = UPat(Ops.CONST, arg=0)  # UOp pattern matching a constant 0
class TestView(unittest.TestCase):
  # Masked views (pad then shrink into the padding) collapsing to consts in the scheduler.
  def test_all_masked_out(self):
    # start with non CONST Ops
    a = Tensor.rand(10, 10).realize()
    # all masked out, degrades to const 0
    b = a.pad(((0, 10), None))[10:]
    sched = check_schedule(b.contiguous(), 1)
    run_schedule(sched)
    np.testing.assert_equal(b.numpy(), 0)
  def test_mask_dim_1(self):
    # mask out dim = 1 works too
    a = Tensor.rand(10, 10).realize()
    b = a.pad((None, (0, 10)))[:, 10:]
    assert b.shape == (10, 10)
    sched = check_schedule(b.contiguous(), 1)
    run_schedule(sched)
    np.testing.assert_equal(b.numpy(), 0)
  def test_partial_mask(self):
    # partial masked out does not degrade into CONST
    a = Tensor.rand(10, 10).realize()
    b = a.pad(((0, 5), None))[5:]
    assert b.shape == (10, 10)
    sched = check_schedule(b.contiguous(), 1)
    run_schedule(sched)
    np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
  # a*VIEW(x), where VIEW(x) = 0
  # x collapses along with its children
  def test_parent_view_collapses(self):
    a = Tensor([1, 2])
    b = Tensor.arange(3).contiguous()
    bv = b.pad(((0, 2),))[-2:]
    # this becomes a late a*0
    late_mul = a*bv
    run_schedule(check_schedule(late_mul, 2))
    # the arange doesn't realize
    #self.assertIsNone(b.uop.base.realized)
    # mul doesn't realize
    #self.assertIsNone(late_mul.uop.base.realized)
    self.assertEqual(late_mul.tolist(), [0, 0])
  # SINK has two branches:
  # a*VIEW(x), where VIEW(x) = 0
  # x+2
  # as long as one child realizes, x does not collapse
  def test_parent_multiple_children_no_collapse(self):
    a = Tensor([1, 2])
    b = Tensor.arange(3).contiguous()
    bv = b.pad(((0, 2),))[-2:]
    late_mul = a*bv
    other_child = b+2
    s = check_schedule([late_mul, other_child], 3)
    # the arange becomes a BUFFER
    self.assertIs(b.uop.base.op, Ops.BUFFER)
    # NOTE: no longer checked
    # mul still collapses
    #self.assertIs(late_mul.uop.base.op, Ops.CONST)
    run_schedule(s)
    self.assertEqual(other_child.tolist(), [2, 3, 4])
@unittest.skipIf(Device.DEFAULT == "CPU", "tests copy from another device to cpu")
class TestCopyFolding(unittest.TestCase):
  # COPY folding: when cross-device (or same-device) COPYs can be elided or const-folded.
  def test_const_copy_is_free(self):
    b = Tensor(1).to("CPU") * 4
    run_schedule(check_schedule(b, 1, filter_sink=False))
    assert b.item() == 4
  def test_one_hot_with_copy(self):
    y = Tensor([1, 2, 3]).to("CPU")
    x = y.one_hot(10)
    check_schedule(x, 3, filter_sink=False)
  def test_const_copy_multi(self):
    x = Tensor.ones(1, device="CPU").to_(["CPU", "CPU:1"]) * 2
    run_schedule(check_schedule(x, 2, filter_sink=False))
    self.assertEqual(x.item(), 2.0)
  def test_late_const_copy_folding(self):
    # a*zeros folds to const 0, so only the +1 survives the copy
    a = Tensor.arange(3).realize()
    zeros = Tensor.zeros(3).realize()
    b = (a*zeros).to("CPU") + 1
    run_schedule(check_schedule(b, 1, filter_sink=False))
    self.assertListEqual(b.tolist(), [1, 1, 1])
    self.assertEqual(b.device, "CPU")
  def test_alu_after_copy(self):
    a = Tensor.ones((4,)).to("CPU")
    b = Tensor.empty(4, device="CPU")
    add = a+b
    # bugfix: the failure message previously read add.src, but add is a Tensor (no .src);
    # the underlying UOp sources live on add.uop.src, as the assert condition itself uses
    assert all_same([x.device for x in add.uop.src]), f"ALU has different devices! {[x.device for x in add.uop.src]}"
    add.schedule()
  def test_alu_before_copy(self):
    buf = Tensor.ones(1).contiguous().realize()
    a = buf+1
    b = a.to("CPU")
    self.assertListEqual(b.tolist(), [2.])
  def test_copy_to_same_device(self):
    a = Tensor.empty(4).uop
    b = a.copy_to_device(a.device)
    check_schedule(b, 1, filter_sink=False) # TODO: 0?
  def test_copy_to_same_device_alt(self):
    a = Tensor.empty(4, 4).uop
    b = a.copy_to_device(a.device)
    check_schedule(b, 1, filter_sink=False) # TODO: 0?
  def test_copy_to_same_device_sched(self):
    a = Tensor.ones(4).contiguous().realize().uop.buf_uop
    t = Tensor(a.copy_to_device(a.device))
    sched = t.schedule()
    # a same-device copy of a realized buffer should leave no COPY in the schedule
    assert len([s for s in sched if s.ast.op is Ops.COPY]) == 0
    run_schedule(sched)
    assert t.uop.is_realized, f"didn't realize Tensor {t}"
    self.assertListEqual(t.tolist(), [1.,1.,1.,1.])
  def test_self_assign_same_device_copy(self):
    a = Tensor.ones(4, 4).contiguous().realize()
    # use copy_to_device to bypass Tensor.to() shortcircuit and force a real same-device COPY in the graph
    a.assign(Tensor(a.uop.copy_to_device(a.device), a.device))
    run_schedule(check_schedule(a, 2, filter_sink=False))
    self.assertListEqual(a.tolist(), [[1.]*4]*4)
  def test_clone(self):
    a = Tensor.empty(4)
    check_schedule(a.clone(), 1, filter_sink=False)
  def test_shrink_copy(self):
    # cloning a shrunk view only copies the shrunk size
    a = Tensor.arange(4)
    view = a.shrink(((0, 2),))
    b = view.clone()
    run_schedule(check_schedule(b, 1, filter_sink=False))
    self.assertEqual(b.uop.base.buffer.size, 2)
    self.assertEqual(b.uop.numel(), 2)
    self.assertListEqual(b.tolist(), [0, 1])
  def test_expanded_copy(self):
    # cloning an expanded view materializes the full expanded size
    a = Tensor.arange(2)
    view = a.reshape(2, 1).expand(2, 2)
    b = view.clone()
    run_schedule(check_schedule(b, 1, filter_sink=False))
    self.assertEqual(b.uop.base.buffer.size, 4)
    self.assertEqual(b.uop.numel(), 4)
    self.assertListEqual(b.tolist(), [[0, 0], [1, 1]])
  def test_permuted_copy(self):
    a = Tensor.arange(4)
    b = a.reshape(2, 2).permute(1, 0)
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
  def test_permute_on_disk(self):
    with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_memoryview())
    a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
    b = a.reshape(2, 2).permute(1, 0).to("CPU")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
  def test_permute_on_disk_contiguous(self):
    with open(temp('dt_arange_4_permute_contig'), "wb") as f: f.write(Tensor.arange(4).realize().uop.base.buffer.as_memoryview())
    a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute_contig')}")
    b = a.reshape(2, 2).permute(1, 0).contiguous().to("CPU")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
  def test_permute_after_shrink(self):
    a = Tensor.arange(5)
    b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
  # NOTE: disk permute must come after COPY
  def test_permute_after_shrink_on_disk(self):
    with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().uop.base.buffer.as_memoryview())
    a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
    b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
class TestUOpBecome(unittest.TestCase):
  # Checks views/assigns resolve correctly after realize.
  def test_setitem_offset(self):
    # assign through an offset multi-view of a: only the first 4 elements of b change
    a = Tensor.full((16,), 0.).contiguous().realize()
    b = Tensor.full((16,), 1.).contiguous().realize()
    a_view = a[4:].reshape(3, 4).shrink(((0,2),(0,2))).reshape((4,))
    b.shrink(((0,4),)).assign(a_view).realize()
    self.assertListEqual(b.tolist(), [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
class TestFusionOp(unittest.TestCase):
  def test_contiguous_add(self):
    # inserting a contiguous between the two permutes must not change the result
    def test(contig=False):
      bt = Tensor(np.arange(16), dtype=dtypes.float32).reshape(4,4)
      x = bt.permute(1,0)
      if contig: x = x.contiguous()
      return (x.permute(1,0) + bt).data()
    assert test() == test(True)
  def test_expand_fuse(self):
    # expand feeding a reduce fuses; each output element is 2 summed 10 times
    bt = Tensor(np.ones((10, 1)), dtype=dtypes.float32)
    out = (bt*2).expand(10,10).sum(1)
    sched = out.schedule()
    run_schedule(sched)
    outd = out.tolist()
    assert all(x == 20.0 for x in outd)
# run the full suite, printing each test name (verbosity=2)
if __name__ == '__main__':
  unittest.main(verbosity=2)