from typing import List, Tuple, Union, cast
import numpy as np
import unittest
from dataclasses import replace

from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
from tinygrad.codegen.lowerer import get_grouped_dims
from tinygrad.ops import UOp, UOps
from tinygrad.device import Device, Buffer
from extra.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, MetaOps, TernaryOps, ReduceOps, UnaryOps, to_uop
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
# from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup
from tinygrad.dtype import DType, dtypes

def helper_realized_ast(r:Union[Tensor, List[Tensor]]) -> Tuple[UOp, List[Buffer]]:
  if isinstance(r, Tensor): r = [r]
  s = create_schedule([x.lazydata for x in r])
  run_schedule(s[:-1])  # run all kernels except the last one
  # now all input LazyBuffers buffers in s[-1] should be realized
  # allocate an output buffer
  output_buffers = [Buffer((out).device, out.size, out.dtype).allocate() for out in s[-1].outputs]
  return s[-1].ast, output_buffers+list(s[-1].inputs)

def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0):
  a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
  np_a, np_b = a.numpy(), b.numpy()
  r = a.matmul(b, acc_dtype=dtype_out)
  sched = create_schedule([r.lazydata])
  realized_ast = sched[-1].ast
  run_schedule(sched)
  out = r.numpy()
  k = Kernel(realized_ast)
  k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
  k.linearize()
  assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
  assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
  np_c = np_a @ np_b
  if dtype_out == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
  elif dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 3e-3
  else: tc_atol, tc_rtol = 5e-3, 1e-4
  np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)

def helper_tc_ensure_uops_and_opts_count(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0, ensure_triggered:bool=True):
  a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
  r = a.matmul(b, acc_dtype=dtype_out)
  sched = create_schedule([r.lazydata])
  realized_ast = sched[-1].ast
  k = Kernel(realized_ast)
  k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
  k.linearize()
  wmmas = len([uop for uop in k.uops if uop.op is UOps.WMMA])
  tcs = len([x for x in k.applied_opts if x.op is OptOps.TC])
  if ensure_triggered:
    assert wmmas > 0, "tensor core not triggered"
    assert tcs == 1, "tensor core opt not included"
  else:
    assert wmmas == 0, "tensor core is incorrectly triggered"
    assert tcs == 0, "tensor core opt is incorrectly included"

class TestLinearizer(unittest.TestCase):
  def test_arg_dedup(self):
    a, b = Tensor.randn(4), Tensor.randn(4)
    np_a, np_b = a.numpy(), b.numpy()
    c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
    lowered = list(lower_schedule(create_schedule([c.lazydata])))
    for ei in lowered: ei.run()
    rawbufs = lowered[-1].bufs
    assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized}
    np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
    np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)

  def test_load_removed(self):
    a = Tensor.rand(1).realize()
    b = Tensor.rand(1).realize()
    ta = Tensor.where(Tensor(True), a, b).numpy()
    tb = Tensor.where(Tensor(False), a, b).numpy()
    np.testing.assert_equal(a.numpy(), ta)
    np.testing.assert_equal(b.numpy(), tb)

  def test_multioutput(self):
    dtype, st = dtypes.int, ShapeTracker.from_shape((8,))
    a = LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtype, st=st))
    b = LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=3, dtype=dtype, st=st))
    out0 = LazyOp(BufferOps.STORE, (LazyOp(op=BinaryOps.ADD, src=(a,b)),), MemBuffer(idx=0, dtype=dtype, st=st))
    out1 = LazyOp(BufferOps.STORE, (LazyOp(op=BinaryOps.MUL, src=(a,b)),), MemBuffer(idx=1, dtype=dtype, st=st))
    a_t = Tensor.full(st.shape, 2).contiguous().realize()
    b_t = Tensor.full(st.shape, 3).contiguous().realize()
    lin = helper_linearizer_ast((out0, out1), [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])[0]
    stores = [u for u in lin.uops if u.op is UOps.STORE]
    mutable_bufs = dedup(flatten([[x for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL] for u in stores]))
    assert len(mutable_bufs) == len(stores) == 2
    assert [u.arg for u in mutable_bufs] == [0, 1]

  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  def test_multireduce(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(32, dtype=dtypes.float).realize()
    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((1, 32)).expand((32, 32))))
    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (1,))
    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((32, 1))))
    diff = (second_x-first_reduce)
    second_reduce = LazyOp(ReduceOps.SUM, (diff,), (0,))
    store = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((1, 1))))
    opts = [
      # [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # grouping
      # [Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 8)],
      # [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 16)],
      # [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.GROUPTOP, 0, 32)],
      [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2)], # unroll reduce
      [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)],
      [Opt(OptOps.UNROLL, 0, 8), Opt(OptOps.UNROLL, 1, 8)] if Device.DEFAULT not in {"NV", "METAL"} else [], # can't do float8,
      # [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)], # grouping + unrolling
      # [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)],
      # [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UNROLL, 2, 8), Opt(OptOps.UNROLL, 2, 8)],
      # [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 0, 8)],
    ]
    wanna_output = (x.numpy()-x.numpy().sum(-1, keepdims=True)).sum(-1).reshape(1,1)
    lins = helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output], opts=opts)
    for l in lins:
      ranges = [u.op for u in l.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])]
      for i,u in enumerate(ranges):
        if i == 0: continue
        assert ranges[i-1] != u, f"multireduce nested the ranges! 
{ranges[i-1], {u}}" @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_mid_dim_multireduce(self): Tensor.manual_seed(0) x = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)))) first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,)) second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 32, 1, 5)))) diff = (second_x-first_reduce) second_reduce = LazyOp(ReduceOps.SUM, (diff,), (1,)) store = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5)))) opts = [ # locals [Opt(OptOps.LOCAL, 0, 3)], [Opt(OptOps.LOCAL, 0, 9)], [Opt(OptOps.LOCAL, 0, 27)], # # grouping # [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # [Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 8)], # [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 16)], # [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.GROUPTOP, 0, 32)], # # unroll [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2)], [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)], [Opt(OptOps.UNROLL, 0, 8), Opt(OptOps.UNROLL, 1, 8)] if Device.DEFAULT not in {"NV", "METAL"} else [], # # upcasting [Opt(OptOps.UPCAST, 0, 3)], [Opt(OptOps.UPCAST, 0, 9)], # # locals with grouping # [Opt(OptOps.LOCAL, 0, 3), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # # locals with unroll # [Opt(OptOps.LOCAL, 0, 3), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2)], # # locals with upcasting [Opt(OptOps.LOCAL, 0, 3), Opt(OptOps.UPCAST, 0, 9)], # # grouping with unrolling # [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)], # [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UNROLL, 2, 8), Opt(OptOps.UNROLL, 2, 8)], # # grouping with upcasting # [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UPCAST, 0, 3)], # # locals with grouping with unroll # [Opt(OptOps.LOCAL, 0, 3), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)], # [Opt(OptOps.LOCAL, 0, 3), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UNROLL, 2, 8), Opt(OptOps.UNROLL, 2, 8)], # # locals with grouping with upcasting # [Opt(OptOps.LOCAL, 0, 3), Opt(OptOps.UPCAST, 0, 3), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # [Opt(OptOps.LOCAL, 0, 9), Opt(OptOps.UPCAST, 0, 3), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # # grouping with unrolling and upcasting # [Opt(OptOps.UPCAST, 0, 3), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)], # [Opt(OptOps.UPCAST, 0, 3), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UNROLL, 2, 8), Opt(OptOps.UNROLL, 2, 8)], # # locals + grouping + unrolling + upcasting # [Opt(OptOps.LOCAL, 0, 3), Opt(OptOps.UPCAST, 0, 3), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), # Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)], ] wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5) lins = helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output], opts=opts) for l in lins: ranges = [u.op for u in l.uops if (u.op is 
UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])] for i,u in enumerate(ranges): if i == 0: continue assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}" def test_triple_multireduce(self): Tensor.manual_seed(0) x0 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() x1 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)))) first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,)) second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)))) diff = (second_x-first_reduce) second_reduce = LazyOp(ReduceOps.SUM, (diff,), (2,)) third_x = LazyOp(BufferOps.LOAD, (), MemBuffer(3, dtypes.float, x2.lazydata.st.reshape((27, 32, 1, 1, 5)))) mul = (third_x*second_reduce) third_reduce = second_reduce = LazyOp(ReduceOps.SUM, (mul,), (1,)) store = LazyOp(BufferOps.STORE, (third_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 1, 5)))) wanna_output = (x2.numpy()*(x1.numpy()-x0.numpy().sum(axis=1, keepdims=True)).sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,1,5) lins = helper_linearizer_ast((store, ), [x0,x1,x2], wanna_output=[wanna_output]) for l in lins: ranges = [u.op for u in l.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])] for i,u in enumerate(ranges): if i == 0: continue assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}" @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_double_reduce_multireduce(self): Tensor.manual_seed(0) x = Tensor.randn(8, 32, 8, 16, dtype=dtypes.float).realize() first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)))) first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,5)) second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((8, 32, 1, 8, 16, 1)))) squares = (second_x-first_reduce) squares_sum = LazyOp(ReduceOps.SUM, (squares,), (1,4)) store = LazyOp(BufferOps.STORE, (squares_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((8, 1, 1, 8, 1, 1)))) wanna_output = (x.numpy()-x.numpy().sum(axis=(1,3), keepdims=True)).sum(axis=(1,3)).reshape((8,1,1,8,1,1)) opts = [ # openCL / GPU=1 is 256 max threads # grouping # [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # first dim of both reduces # [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 3, 2)], # both dims of the second reduce # [Opt(OptOps.GROUPTOP, 2, 2), Opt(OptOps.GROUPTOP, 3, 2)], # second dim of both reduces # [Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.GROUPTOP, 3, 2)], # both dims of the first reduce # # group all reduce dims # [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.GROUPTOP, 2, 2), Opt(OptOps.GROUPTOP, 3, 2)], # # checking how it works with 2 grouped reduces + unrolling # [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.GROUPTOP, 2, 4), Opt(OptOps.GROUPTOP, 3, 4), # Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)], # # Checking how it works with 2 
grouped reduces + locals. # [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 0, 4), # Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.GROUPTOP, 2, 2), Opt(OptOps.GROUPTOP, 3, 2)], # # Checking how it works with 2 grouped reduces + locals + unroll. # [Opt(OptOps.LOCAL, 0, 2), # Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.GROUPTOP, 2, 4), Opt(OptOps.GROUPTOP, 3, 4), # Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)], # # Checking how it works with 2 grouped reduces + locals + upcast. # [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 0, 2), # Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.GROUPTOP, 2, 2), Opt(OptOps.GROUPTOP, 3, 2)], # # Checking how it works with 2 grouped reduces + locals + upcast + unroll. # [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 0, 2), # Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.GROUPTOP, 2, 4), Opt(OptOps.GROUPTOP, 3, 4), # Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)], ] lins = helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output], opts=opts) for l in lins: ranges = [u.op for u in l.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])] for i,u in enumerate(ranges): if i < 2: continue assert ranges[i-2] != u or ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-2], ranges[i-1], {u}}" @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_partial_opt_multireduce(self): # check how it works with one reduce optimized and one unoptimized Tensor.manual_seed(0) x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize() first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)))) first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,)) second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 15, 1, 5)))) diff = (second_x-first_reduce) second_reduce = LazyOp(ReduceOps.SUM, (diff,), (1,)) store = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5)))) opts = [ # [Opt(OptOps.GROUPTOP, 0, 3)], # grouping # [Opt(OptOps.GROUPTOP, 1, 3)], # [Opt(OptOps.GROUPTOP, 0, 15)], # [Opt(OptOps.GROUPTOP, 1, 15)], [Opt(OptOps.UNROLL, 0, 3)], [Opt(OptOps.UNROLL, 1, 3)], ] wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5) lins = helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output], opts=opts) for l in lins: ranges = [u.op for u in l.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])] for i,u in enumerate(ranges): if i == 0: continue assert ranges[i-1] != u, f"multireduce nested the ranges! 
{ranges[i-1], {u}}" @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_multireduce_with_parallel(self): Tensor.manual_seed(0) x = Tensor.randn(4, 32, dtype=dtypes.float).realize() x_p = Tensor.randn(4, 32, dtype=dtypes.float).realize() first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)))) first_x_p = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)))) first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,)) first_reduce_p = LazyOp(ReduceOps.SUM, (LazyOp(UnaryOps.EXP2, (first_x_p,)),), (2,)) second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((4, 32, 1)))) diff = (second_x-LazyOp(BinaryOps.ADD, (first_reduce,first_reduce_p))) second_reduce = LazyOp(ReduceOps.SUM, (diff,), (1,)) store = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((4, 1, 1)))) opts = [ # [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # grouping # [Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 8)], # [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 16)], # [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.GROUPTOP, 0, 32)], [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2)], # unroll reduce [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)], [Opt(OptOps.UNROLL, 0, 8), Opt(OptOps.UNROLL, 1, 8)] if Device.DEFAULT not in {"NV", "METAL"} else [], # can't do float8, # [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)], # grouping + unrolling # [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UNROLL, 2, 8), Opt(OptOps.UNROLL, 2, 8)], # [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 0, 8)], ] wanna_output = (x.numpy()-(x.numpy().sum(-1, keepdims=True)+np.exp2(x_p.numpy()).sum(-1, keepdims=True))).sum(-1).reshape(4, 1,1) lins = helper_linearizer_ast((store, ), [x,x_p], wanna_output=[wanna_output], opts=opts) for l in lins: ranges = [u.op for u in l.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])] for i,u in enumerate(ranges): if i == 0: continue assert ranges[i-1] != u, f"multireduce nested the ranges! 
{ranges[i-1], {u}}" @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") def test_multiout_multireduce(self): # check how multireduce works with multioutput Tensor.manual_seed(0) x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize() first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)))) first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,)) second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((27, 15, 1, 5)))) diff = (second_x-first_reduce) second_reduce = LazyOp(ReduceOps.SUM, (diff,), (1,)) store0 = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5)))) second_out = LazyOp(BinaryOps.MUL, (second_reduce, LazyOp(BufferOps.CONST, (), ConstBuffer(1/15, dtypes.float, ShapeTracker.from_shape((27,1,1,5)))))) store1 = LazyOp(BufferOps.STORE, (second_out,), MemBuffer(1, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5)))) wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5) helper_linearizer_ast((store0, store1, ), [x], wanna_output=[wanna_output, wanna_output/15]) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") @unittest.expectedFailure def test_multiout_intermediate_multireduce(self): # check how it outputing at different stages of the multireduce works # TODO: Fails because the stores shapes do not match: store1.shape = (27,15,1,5) != store0.shape = (27,1,1,5) # so the output shapes are different (FAIL!), # if we change the shape of store1 to be contiguous, it will match store0 but not the value it's storing (FAIL!) Tensor.manual_seed(0) x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize() first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)))) first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,)) second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((27, 15, 1, 5)))) diff = (second_x-first_reduce) second_reduce = LazyOp(ReduceOps.SUM, (diff,), (1,)) store0 = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5)))) store1 = LazyOp(BufferOps.STORE, (first_reduce,), MemBuffer(1, dtypes.float, ShapeTracker(views=(View(shape=(27,15,1,5), strides=(5,0,1,1), offset=0, mask=None, contiguous=False),)))) # noqa: E501 wanna_output0 = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5) wanna_output1 = x.numpy().sum(axis=1).reshape(27,1,1,5) ast = LazyOp(MetaOps.KERNEL, (store0,store1)) k = Kernel(ast) prg = CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT)) inbufs = [x.lazydata.base.buffer] outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.arg.st.size, out.arg.dtype).allocate() for out in ast.src] prg.exec(outbufs+inbufs) np.testing.assert_allclose(np.frombuffer(outbufs[0].as_buffer(), _to_np_dtype(outbufs[0].dtype)).reshape(27,1,1,5), wanna_output0) np.testing.assert_allclose(np.frombuffer(outbufs[1].as_buffer(), _to_np_dtype(outbufs[1].dtype))[:135].reshape(27,1,1,5), wanna_output1) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") def test_complete_unroll_multireduce(self): Tensor.manual_seed(0) x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize() first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, 
x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)))) first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,)) second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 3, 1, 5)))) diff = (second_x-first_reduce) second_reduce = LazyOp(ReduceOps.SUM, (diff,), (1,)) store = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5)))) opts = [[Opt(OptOps.UNROLL, 0, 3), Opt(OptOps.UNROLL, 0, 3)]] wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5) helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output], opts=opts) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") def test_upcast_multireduce(self): Tensor.manual_seed(0) x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize() first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)))) first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,)) second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 3, 1, 5)))) diff = (second_x-first_reduce) second_reduce = LazyOp(ReduceOps.SUM, (diff,), (1,)) store = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5)))) opts = [[Opt(OptOps.UPCAST, 0, 3)]] wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5) helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output], opts=opts) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skip("can't group with multiple reduces yet") def test_early_endif(self): # make sure the if block of a grouped reduce can be closed early and the result loaded back in Tensor.manual_seed(0) x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize() first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)))) first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,)) second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((27, 12, 1, 5)))) diff = (second_x-first_reduce) second_reduce = LazyOp(ReduceOps.SUM, (diff,), (1,)) store = LazyOp(BufferOps.STORE, (second_reduce,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((27, 1, 1, 5)))) opts = [[Opt(OptOps.GROUPTOP, 0, 3), Opt(OptOps.GROUPTOP, 1, 3)]] wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5) helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output], opts=opts) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") def test_mean_std_multireduce(self): Tensor.manual_seed(0) x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)))) first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,)) mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 35, 1)))) # noqa: E501 second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 35, 1)))) squares = (second_x-mean)*(second_x-mean) squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,)) variance = squares_sum * 
LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 1)))) # noqa: E501 std = LazyOp(UnaryOps.SQRT, (variance,), None) store = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 25, 1, 1)))) wanna_output = x.numpy().std(axis=2, ddof=0).reshape((15,25,1,1)) helper_linearizer_ast((store,), [x], wanna_output=[wanna_output]) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") def test_mean_std_multireduce_mid_dim(self): Tensor.manual_seed(0) x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)))) first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,)) mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(0.04, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 35)))) # noqa: E501 second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35)))) squares = (second_x-mean)*(second_x-mean) squares_sum = LazyOp(ReduceOps.SUM, (squares,), (1,)) variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(0.04, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 1, 1, 35)))) # noqa: E501 std = LazyOp(UnaryOps.SQRT, (variance,), None) store = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 1, 1, 35)))) wanna_output = x.numpy().std(axis=1, ddof=0).reshape((15,1,1,35)) helper_linearizer_ast((store,), [x], wanna_output=[wanna_output]) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") @unittest.expectedFailure def test_mean_std_multireduce_multiout(self): # TODO: Same error as in test_multiout_intermediate_multireduce Tensor.manual_seed(0) x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)))) first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,)) mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 35, 1)))) # noqa: E501 second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((15, 25, 35, 1)))) squares = (second_x-mean)*(second_x-mean) squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,)) variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 1)))) # noqa: E501 std = LazyOp(UnaryOps.SQRT, (variance,), None) store_mean = LazyOp(BufferOps.STORE, (mean,), MemBuffer(1, dtypes.float, ShapeTracker.from_shape((15,25,1,1)))) store_std = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 25, 1, 1)))) wanna_output = [x.numpy().std(axis=2, ddof=0).reshape(15,25,1,1), x.numpy().mean(axis=2).reshape(15,25,1,1)] lins = helper_linearizer_ast((store_std,store_mean), [x], wanna_output=wanna_output) for k in lins: assert len([u for u in k.uops if u.op is UOps.DEFINE_ACC]) == 2, "got more than two accs (implies the kernel didn't reuse the mean reduce)" @unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "ocelot/remu doesn't have multiple wave syncs yet") def test_var_multireduce(self): Tensor.manual_seed(0) x = Tensor.randn(3, 27, 32, 
dtype=dtypes.float).realize()
    # push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32))))
    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
    mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(0.03125, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 32, 1)))) # noqa: E501
    # store = LazyOp(BufferOps.STORE, (mean,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 32, 1))))
    # verify_lazyop(store)
    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 32, 1))))
    squares = (second_x-mean)*(second_x-mean)
    squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
    variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(0.03125, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 1, 1)))) # noqa: E501
    store = LazyOp(BufferOps.STORE, (variance,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 1, 1))))
    wanna_output = x.numpy().var(axis=2, ddof=0).reshape((3,27,1,1))
    helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output])
    # tinygrad ref
    y_tiny = x.var(axis=2, correction=0).reshape(3,27,1,1)
    np.testing.assert_allclose(y_tiny.numpy(), wanna_output, atol=1e-4, rtol=1e-4)

  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  def test_softmax_multireduce(self):
    x = Tensor.rand(4, 32).realize()
    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32))))
    max_x = LazyOp(op=ReduceOps.MAX, src=(first_x,), arg=(2,))
    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((4, 32, 1,))))
    centered_x = LazyOp(op=BinaryOps.ADD, src=(second_x, LazyOp(op=UnaryOps.NEG, src=(max_x,), arg=None)))
    exp_x = LazyOp(op=UnaryOps.EXP2, src=(centered_x,))
    sum_exp_x = LazyOp(op=ReduceOps.SUM, src=(exp_x,), arg=(1,))
    # y = LazyOp(op=BinaryOps.MUL, src=(exp_x, LazyOp(op=UnaryOps.RECIP, src=(sum_exp_x,)))) # kernels cannot do a return to full shape
    recip_sum_exp_x = LazyOp(op=UnaryOps.RECIP, src=(sum_exp_x,))
    store = LazyOp(op=BufferOps.STORE, src=(recip_sum_exp_x,), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1,1))))
    expected = 1/np.exp2(x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1, keepdims=True).reshape(4,1,1)
    helper_linearizer_ast((store,), [x], wanna_output=[expected])

  # *** buildup to fused indexing

  @unittest.skipIf(CI, "very slow because of recomputing")
  def test_arange_expanded(self):
    # Tensor.arange(16384) expanded such that output shape is (4, 16384, 256, 1)
    # basically it's pushing the expand through this reduce:
    tiny = Tensor.arange(16384).reshape(16384, 1).expand(4, 16384, 256).reshape(4, 16384, 256, 1)
    real_arange = np.broadcast_to(np.arange(16384).reshape(16384, 1), (4, 16384, 256)).reshape(4, 16384, 256, 1)
    # NOTE: this is stupidly recomputing because it's not fused, but it proves a point.
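    # descriptive note (added): the masked, overlapping two-view ShapeTracker below is the usual arange-as-reduce
    # trick: each row of the (16384, 16384) view sees a different number of in-mask positions of the const-1 LOAD,
    # so the SUM reduce followed by the "- 1" reconstructs arange(16384) without materializing an index tensor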
arange_input_st = ShapeTracker(views=(View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False), \ View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False))) arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384)) arange_axis = (3,) arange = LazyOp(ReduceOps.SUM, (LazyOp(BufferOps.CONST, (), ConstBuffer(1, dtypes.int, arange_input_st)), ), arange_axis) output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape)) out = arange-LazyOp.const(1, dtypes.int, output_shape) store = LazyOp(BufferOps.STORE, (out, ), MemBuffer(0, dtypes.int, st=ShapeTracker.from_shape(output_shape))) helper_linearizer_ast((store, ), [], wanna_output=[real_arange]) with Context(DEBUG=0, NOOPT=0): np.testing.assert_equal(tiny.numpy(), real_arange) @unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "very slow") def test_indexing_multireduce(self): arange_input_st = ShapeTracker(views=(View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False), \ View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False))) # TODO: do this arange broadcast in the scheduler arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384)) arange_axis = (3,) arange = LazyOp(ReduceOps.SUM, (LazyOp(BufferOps.CONST, (), ConstBuffer(1, dtypes.int, arange_input_st)), ), arange_axis) arange_out_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape)) arange = arange-LazyOp.const(1, dtypes.int, arange_out_shape) # p2: the indexing dataset = Tensor.rand(16384, 256).realize() data1 = MemBuffer(1, dataset.dtype, ShapeTracker.from_shape(dataset.shape).reshape((1, 16384, 256, 1)).expand(arange_out_shape)) idxs = Tensor([0,3,5,6]).realize() data2 = MemBuffer(2, dtypes.int, ShapeTracker.from_shape((4,)+(1,)*(len(arange_out_shape)-1)).expand(arange_out_shape)) reduce_input = LazyOp(BufferOps.LOAD, (), data1)*LazyOp(UnaryOps.CAST, (arange.eq(LazyOp(BufferOps.LOAD, (), data2)),), dataset.dtype) out = LazyOp(ReduceOps.SUM, (reduce_input, ), (1,)) output_shape = tuple(1 if i in out.arg else s for i,s in enumerate(arange_out_shape)) store = LazyOp(BufferOps.STORE, (out, ), MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape(output_shape))) real_index = dataset.numpy()[idxs.numpy()].reshape(4, 1, 256, 1) helper_linearizer_ast((store, ), [dataset, idxs], wanna_output=[real_index]) # AssertionError: repeated stores in uops def test_argmax_multireduce_axis0(self): t = Tensor.randn(10, 20).realize() t_max = t.max((0,)).realize() real_argmax = np.argmax(t.numpy(), axis=0, keepdims=False).reshape(1, 20, 1) ast = LazyOp(MetaOps.KERNEL, arg=None, src=( LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),))), src=( # noqa E501 LazyOp(BinaryOps.ADD, arg=None, src=( LazyOp(BinaryOps.ADD, arg=None, src=( LazyOp(BufferOps.CONST, arg=ConstBuffer(val=10, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()), # noqa E501 LazyOp(UnaryOps.NEG, arg=None, src=( LazyOp(ReduceOps.MAX, arg=(0,), src=( LazyOp(BinaryOps.MUL, arg=None, src=( LazyOp(UnaryOps.CAST, arg=dtypes.int, src=( LazyOp(BinaryOps.CMPNE, arg=None, src=( LazyOp(BinaryOps.CMPNE, arg=None, src=( LazyOp(BufferOps.LOAD, 
arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),))), src=()), # noqa E501 LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),))), src=()),)), # noqa E501 LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)), # noqa E501 LazyOp(BinaryOps.ADD, arg=None, src=( LazyOp(ReduceOps.SUM, arg=(2,), src=( LazyOp(BufferOps.CONST, arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False)))), src=()),)), # noqa E501 LazyOp(BufferOps.CONST, arg=ConstBuffer(val=10, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)),)),)), # noqa E501 LazyOp(BufferOps.CONST, arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa E501 helper_linearizer_ast(ast, [t, t_max], wanna_output=[real_argmax]) def test_argmax_multireduce_flat(self): t = Tensor.randn(10, 20).realize() t_max = t.max().realize() real_argmax = np.argmax(t.numpy()) ast = LazyOp(MetaOps.KERNEL, arg=None, src=( LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),))), src=( # noqa E501 LazyOp(BinaryOps.ADD, arg=None, src=( LazyOp(BinaryOps.ADD, arg=None, src=( LazyOp(BufferOps.CONST, arg=ConstBuffer(val=200, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),))), src=()), # noqa E501 LazyOp(UnaryOps.NEG, arg=None, src=( LazyOp(ReduceOps.MAX, arg=(0,), src=( LazyOp(BinaryOps.MUL, arg=None, src=( LazyOp(UnaryOps.CAST, arg=dtypes.int, src=( LazyOp(BinaryOps.CMPNE, arg=None, src=( LazyOp(BinaryOps.CMPNE, arg=None, src=( LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))), src=()), # noqa E501 LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)), # noqa E501 LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)), # noqa E501 LazyOp(BinaryOps.ADD, arg=None, src=( LazyOp(ReduceOps.SUM, arg=(1,), src=( LazyOp(BufferOps.CONST, arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False)))), src=()),)), # noqa E501 LazyOp(BufferOps.CONST, arg=ConstBuffer(val=200, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)),)),)), # noqa E501 LazyOp(BufferOps.CONST, arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 
1), strides=(0, 0), offset=0, mask=None, contiguous=True),))), src=()),)),)),)) # noqa E501 helper_linearizer_ast(ast, [t, t_max], wanna_output=[real_argmax]) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") def test_padto_sum_multireduce(self): Tensor.manual_seed(0) N = 17 x = Tensor.rand(N, N).realize() opts = [ [Opt(OptOps.PADTO, 0, 32)], [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], # TODO: multireduce pads # causes an issue because the acc won't be masked in the second reduce # [Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)] ] x_ld0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)))) x_ld1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((N, 1, N)))) r0 = LazyOp(ReduceOps.SUM, (x_ld0,), (1,)) r1 = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.ADD, (x_ld1, LazyOp(op=UnaryOps.NEG, src=(r0,), arg=None)),),), (0,)) store = LazyOp(BufferOps.STORE, (r1, ), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((1,1,N)))) helper_linearizer_ast((store,), [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=0, keepdims=True)).sum(axis=0).reshape(1,1,N)], opts=opts) x_ld0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)))) x_ld1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((N, N, 1)))) r0 = LazyOp(ReduceOps.SUM, (x_ld0,), (2,)) r1 = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.ADD, (x_ld1, LazyOp(op=UnaryOps.NEG, src=(r0,), arg=None)),),), (1,)) store = LazyOp(BufferOps.STORE, (r1, ), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((N,1,1)))) helper_linearizer_ast((store,), [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(N,1,1)], opts=opts) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") def test_padto_max_multireduce(self): Tensor.manual_seed(0) N = 17 x = Tensor.rand(N, N).realize() opts = [ [Opt(OptOps.PADTO, 0, 32)], [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),] ] x_ld0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)))) x_ld1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((N, 1, N)))) r0 = LazyOp(ReduceOps.MAX, (x_ld0,), (1,)) r1 = LazyOp(ReduceOps.MAX, (LazyOp(BinaryOps.ADD, (x_ld1, LazyOp(op=UnaryOps.NEG, src=(r0,), arg=None)),),), (0,)) store = LazyOp(BufferOps.STORE, (r1, ), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((1,1,N)))) helper_linearizer_ast((store,), [x], wanna_output=[(x.numpy()-x.numpy().max(axis=0, keepdims=True)).max(axis=0).reshape(1,1,N)], opts=opts) x_ld0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)))) x_ld1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((N, N, 1)))) r0 = LazyOp(ReduceOps.MAX, (x_ld0,), (2,)) r1 = LazyOp(ReduceOps.MAX, (LazyOp(BinaryOps.ADD, (x_ld1, LazyOp(op=UnaryOps.NEG, src=(r0,), arg=None)),),), (1,)) store = LazyOp(BufferOps.STORE, (r1, ), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((N,1,1)))) helper_linearizer_ast((store,), [x], wanna_output=[(x.numpy()-x.numpy().max(axis=1, keepdims=True)).max(axis=1).reshape(N,1,1)], opts=opts) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") def test_padto_where_multireduce(self): # ternary operators try to use both ridxs # we need to make sure the 
ternary operators nest properly N = 17 x = Tensor.rand(N, N).realize() a = Tensor.rand(1, 1).realize() b = Tensor.rand(1, 1).realize() opts = [[Opt(OptOps.PADTO, 0, 32)],[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],] # TODO: these large ASTs are suboptimal but we need this until the scheduler can fuse these wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=1,keepdims=True), a.numpy(), b.numpy())).sum(axis=1),0.0,1.0).reshape((N,1,1)) # noqa: E501 ld0 = x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)) ld1 = x.lazydata.st.reshape((N, N, 1)) ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.5*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld1)),LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.75*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,N,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld0)),), arg=(2,)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,N,1)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,N,1)))),)),)),), arg=(1,)),)),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,1)))),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,1)))),)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((N,1,1)))) # noqa: E501 helper_linearizer_ast((ast,), [x,a,b], opts=opts, wanna_output=[wanna_output]) ld0 = x.lazydata.st.reshape((1, N, N)).expand((N,N,N)) ld1 = x.lazydata.st.reshape((N, 1, N)) wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0).reshape(1,1,N) # noqa: E501 ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.5*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((1,1,N)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld1)),LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.75*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,N)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld0)),), arg=(1,)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,N)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((N,1,N)))),)),)),), arg=(0,)),)),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1)).expand((1,1,N)))),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, 
st=ShapeTracker.from_shape((1,1,1)).expand((1,1,N)))),)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,N)))) # noqa: E501 helper_linearizer_ast((ast,), [x,a,b], opts=opts, wanna_output=[wanna_output]) # # pad reduce axis # helper_linearizer_ast((ast,), [x,a,b], opts=[[Opt(OptOps.PADTO, 1, 32)],], wanna_output=[wanna_output]) # ld0 = x.lazydata.st.reshape((1,1,N,N)).expand((N,N,N,N)) # ld1 = x.lazydata.st.reshape((N,N,1,1)) # wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0).reshape((1,1,1,1))# noqa: E501 # ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.5*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld1)),LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.75*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)).expand((N,N,1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ld0)),), arg=(2,3,)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)).expand((N,N,1,1)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)).expand((N,N,1,1)))),)),)),), arg=(0,1,)),)),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)))),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)))),)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1,1,1)))) # noqa: E501 # helper_linearizer_ast((ast,), [x,a,b], opts=[[Opt(OptOps.PADTO, 0, 32)],], wanna_output=[wanna_output]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_end_local(self): load = MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker.from_shape((32,))) store = MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker.from_shape((1,))) ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, arg=load),), arg=(0,)),), arg=store), load_t = Tensor.full(load.st.shape, 1).contiguous().realize() k = helper_linearizer_ast(ast, [load_t], wanna_output=[load_t.numpy().sum()])[1] self.assertEqual(k.uops[-1].op, UOps.ENDIF) self.assertLess(k.uops.index([x for x in k.uops if x.op is UOps.STORE][-1]), k.uops.index(k.uops[-1])) def test_two_nested_range(self): a = Tensor.randn(2, ).realize() out = a.reshape(2, 1).expand(2, 3).sum() lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])[0] ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] assert len(ranges) == 1 # NOTE: it collapses now # RANGE -> LOAD -> RANGE -> PHI #assert any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]]) def test_three_nested_range(self): a = Tensor.randn(2, ).realize() out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum() lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])[0] ranges 
= [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] assert len(ranges) == 1 # NOTE: it collapses now # RANGE -> RANGE -> LOAD -> RANGE -> PHI # NOTE: nothing should toposort between the first two ranges #assert ranges[0]+1 == ranges[1] #assert any(x.op is UOps.LOAD for x in lin.uops[ranges[1]:ranges[2]]) def test_two_nested_range_alt_indexing(self): a = Tensor([2, 2]).realize() out = a.reshape(2, 1).pad(((1, 1), (1, 1)), 2).sum() lin = helper_linearizer_opt(out, wanna_output=[24])[0] ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] # RANGE -> ALU -> RANGE -> ALU + LOAD -> PHI assert any(x.op is UOps.ALU for x in lin.uops[ranges[0]:ranges[1]]) assert not any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]]) assert any(x.op in {UOps.ALU, UOps.LOAD} for x in lin.uops[ranges[1]:]) def test_range_outer_op_before_phi(self): a = Tensor.randn(4, 1).realize() b = Tensor.randn(1, 1).realize() out = (a + b[0]).sum() + b[0] lin = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])[0] ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] # LOAD -> RANGE -> LOAD -> PHI assert lin.uops[ranges[0]-2].op is UOps.LOAD def test_range_outer_op_before_phi_nested_range(self): a = Tensor.randn(2, ).realize() b = Tensor.randn(1, 1).realize() out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0] lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()])[0] ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] assert len(ranges) == 1 # NOTE: it collapses now #if getenv("PTX"): # LOAD -> RANGE -> CAST -> ALU -> ALU -> LOAD -> ALU -> RANGE -> ALU -> PHI # assert lin.uops[ranges[0]-2].op is UOps.LOAD # assert ranges[1] == ranges[0]+6 # assert [x.op for x in lin.uops[ranges[1]-2:ranges[1]]] == [UOps.LOAD, UOps.ALU] # LOAD -> RANGE -> LOAD -> ALU -> RANGE -> PHI #else: # assert lin.uops[ranges[0]-2].op is UOps.LOAD # assert ranges[1] == ranges[0]+3 # assert [x.op for x in lin.uops[ranges[1]-2:ranges[1]]] == [UOps.LOAD, UOps.ALU] def test_range_outer_op_after_phi(self): a = Tensor.randn(4, 1).realize() out = a.sum() * a.sum() lin = helper_linearizer_opt(out, wanna_output=[a.numpy().sum()*a.numpy().sum()])[0] # RANGE -> LOAD -> PHI -> ALU end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE) assert lin.uops[end+1].op is UOps.ALU def test_range_outer_op_after_phi_nested_range(self): a = Tensor.randn(2, ).realize() out = a.reshape(2, 1).expand(2, 3).sum() + a.reshape(2, 1).expand(2, 3).sum() lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3))).sum()*2])[0] # RANGE -> LOAD -> PHI -> ALU end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE) assert lin.uops[end+1].op is UOps.ALU def test_load_dedup(self): # for different leaves in the AST, the same loads may occur. a = Tensor.randn(4).realize() # these are of size 3 to avoid float4 coalesce r = a[:-1] + a[1:] k = Kernel(create_schedule([r.lazydata])[-1].ast) k.upcast() k.linearize() num_loads = len([uop for uop in k.uops if uop.op is UOps.LOAD]) assert num_loads <= 4, "more load uops than needed" assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?" 
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason") def test_load_cache_const_bufs(self): # make sure const buffers are differentiated from local and mem buffers ST, DT = ShapeTracker(views=(View(shape=((1,)), strides=(0, 0), offset=0, mask=None, contiguous=False),)), dtypes.int VAL = LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=2, dtype=DT, st=ST)) # data1[0] + VAL a = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=DT, st=ST)), VAL)) # (literal const 1) + VAL b = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=DT, st=ST)), VAL)) ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(a,b)),), arg=MemBuffer(idx=0, dtype=DT, st=ST)) lin = Kernel(ast) lin.linearize() assert len(lin.uops) <= 7, "too many uops" a_bufs = [u.op for u in lin.uops[-1].src[2].src] assert a_bufs == [UOps.LOAD, UOps.CONST] def test_upcast_cse(self): # when upcasting, within a subtree, there may be common expressions. a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize() r = a.expand([2]) + b.expand([2]) k = Kernel(create_schedule([r.lazydata])[-1].ast) k.upcast() k.linearize() num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU]) assert num_ops <= 1, "more alu uops than needed" @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_reduce_upcast(self): x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize() r = Tensor.conv2d(x,w,padding=1).relu() k = Kernel(create_schedule([r.lazydata])[-1].ast) k.upcast() k.upcast() k.linearize() accs = [u for u in k.uops if u.op is UOps.DEFINE_ACC] stores = [u for u in k.uops if u.op is UOps.STORE] assert len(accs) == 0 # it's removed now assert len(stores) == 1 assert stores[0].src[-1].dtype == dtypes.float.vec(4) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason") def test_upcast_with_locals(self): x, y = Tensor.rand(1,128), Tensor.rand(128, 128) r = (x@y).relu() k = Kernel(create_schedule([r.lazydata])[-1].ast) k.hand_coded_optimizations() k.linearize() accs = [u for u in k.uops if u.op is UOps.DEFINE_ACC] stores = [u for u in k.uops if u.op is UOps.STORE] # the first store is to lds and can be upcasted assert accs[0].dtype == stores[0].src[-1].dtype == dtypes.float.vec(4) assert stores[0].src[0].op is UOps.DEFINE_LOCAL # the second store is to gds with no upcasts assert stores[1].src[2].dtype == dtypes.float assert stores[1].src[0].op is UOps.DEFINE_GLOBAL def test_zero_fold(self): a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize() r = Tensor.stack(a, b) k = Kernel(create_schedule([r.lazydata])[-1].ast) k.upcast() k.linearize() num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU]) assert num_ops == 0, "more alu uops than needed" def test_sum_acc_dtype(self): for tensor_dtype, acc_dtype in ( (dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)): a = Tensor([1, 2, 3], dtype=tensor_dtype).sum() k = Kernel(create_schedule([a.lazydata])[-1].ast) k.linearize() local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC] assert local[0].dtype == acc_dtype def test_arg_acc_dtype(self): def helper_arg_acc_dtype(c: Tensor, 
expected_dtype:DType): k = Kernel(create_schedule([c.lazydata])[-1].ast) k.linearize() local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC] assert local[0].dtype == expected_dtype tests = ( (dtypes.float16, None, dtypes.float), (dtypes.bfloat16, None, dtypes.float), (dtypes.float, None, dtypes.float), (dtypes.float16, dtypes.float16, dtypes.float16), (dtypes.bfloat16, dtypes.bfloat16, dtypes.bfloat16), (dtypes.float, dtypes.float16, dtypes.float16), ) for tensor_dtype, acc_dtype, expected_dtype in tests: a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype) helper_arg_acc_dtype(a.sum(acc_dtype=acc_dtype), expected_dtype) helper_arg_acc_dtype(a.matmul(b, acc_dtype=acc_dtype), expected_dtype) helper_arg_acc_dtype(Tensor.einsum("ki,ij->kj", a, b, acc_dtype=acc_dtype), expected_dtype) d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype) helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype) @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_cores(self): for tc in Device[Device.DEFAULT].renderer.tensor_cores: if (getenv("EMULATE_CUDA") or getenv("EMULATE_INTEL")) and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue helper_tc_allclose(tc.dims[0], tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0) @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_cores_padded(self): for tc in Device[Device.DEFAULT].renderer.tensor_cores: if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue pad = 1 # check that TC is triggered for TC_OPT=2 helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=True) # check that TC is not triggered for TC_OPT<2 helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=1, ensure_triggered=False) helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=0, ensure_triggered=False) # check excessive padding doesn't trigger padded TC in TC_OPT=2 helper_tc_ensure_uops_and_opts_count(tc.dims[0]//4, tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False) helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1]//4, tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False) helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1], tc.dims[2]//4, tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False) # check correctness helper_tc_allclose(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2) @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_cores_multi_reduce(self): for tc in Device[Device.DEFAULT].renderer.tensor_cores: if tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16: continue # this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes golden_result = None for axis in range(9): a = Tensor.rand(16, 16, 29, 29, dtype=tc.dtype_in).realize() b = Tensor.rand(32, 16, 16, 16, dtype=tc.dtype_in).realize() c = a.conv2d(b, padding=1, acc_dtype=tc.dtype_out) realized_ast, real_bufs = 
helper_realized_ast(c) k = Kernel(realized_ast) k.apply_tensor_cores(1, axis=axis, tc_opt=2) k.linearize() assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered" assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included" prg = CompiledRunner(k.to_program()) real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=_to_np_dtype(real_bufs[0].dtype)).data) # Zero to check that all values are filled prg.exec(real_bufs) result = np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype)) # ensure the results for each choice of axis matches if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype)) np.testing.assert_allclose(result, golden_result, atol=0.1, rtol=0.15) # check that get_kernel_actions produces all 9 options from tinygrad.engine.search import get_kernel_actions tc_actions = [k for i, k in get_kernel_actions(Kernel(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC] assert len(tc_actions) == 9, f"get_kernel_actions should contain 9 possible TC actions, only got {len(tc_actions)}" @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_cores_unroll_phi(self): tc = Device[Device.DEFAULT].renderer.tensor_cores[0] x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in) r = x.matmul(y, acc_dtype=tc.dtype_out) k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1] for u in k.uops: if u.op is UOps.WMMA: assert u.src[-1].src[0].op != UOps.PHI @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_cores_unroll_casted_phi(self): tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0] x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in) r = x.matmul(y, acc_dtype=tc.dtype_out) k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1] for u in k.uops: if u.op is UOps.WMMA: #assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2])) assert u.src[-1].src[0].op != UOps.PHI @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_cores_unroll_casted_phi_with_children(self): # all PHI children are outside the loop tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0] x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in) r = x.matmul(y, acc_dtype=tc.dtype_out).relu() k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1] for u in k.uops: if u.op is UOps.WMMA: #assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2])) assert u.src[-1].src[0].op != UOps.PHI @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_simple_unroll_no_between_phi_dependencies(self): x, y = Tensor.rand(128, 128), Tensor.rand(128, 128) r = (x@y).relu() k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]])[-1] # the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x PHI -> ENDRANGE for u in k.uops: if u.op is UOps.PHI: assert u.src[1].op is UOps.ALU # children of PHI are placed after ENDRANGE if any(x.op is UOps.PHI for x in u.src): end_range = [i for i, x in 
enumerate(k.uops) if x.op is UOps.ENDRANGE][0] assert end_range < k.uops.index(u) def test_grouped_dims(self): def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes): idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims) loop_idxs = dedup(flatten([[y for y in sorted(list(x.sparents)) if y.op is UOps.SPECIAL] for x in idxs])) loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg[0]) sizes = [x.arg[1] for x in loop_idxs] assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}" assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}" assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}" # TODO: add these back after uop symbolic # for i in range(len(dims)): # assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}" # for i in range(len(loop_idxs)): # assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}" # assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}" # no-op _assert_grouped_dims("gidx", (2,), (16,16,16), False, [2]) _assert_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3]) # check reverse dims _assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2]) _assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4]) # test splitting globals # _assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4]) # _assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,4,12]) # _assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [12,16,4]) # _assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,4,24]) # collapse on onto the left most axis _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5]) _assert_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2]) # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)]) # collapse on left-most available axis (the left most is too small) _assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5]) _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2]) # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5]) # # dim too large and not factorable # with self.assertRaises(AssertionError): # get_grouped_dims("gidx", (23,), (16,16,16), False,) # with self.assertRaises(AssertionError): # get_grouped_dims("gidx", (128,3,4), (16,4,23), False,) # too large for sizes with self.assertRaises(RuntimeError): get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16)) # # variable too large # with self.assertRaises(AssertionError): # get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") def test_default_global_reversed(self): # shrink so that the dims do not collapse t = Tensor.ones(5, 6, 7).contiguous().realize().shrink(((0, 4), (0, 5), (0, 6))) k = helper_linearizer_opt(t+1)[0] idxs = dedup([uop for uop in k.uops if uop.op is UOps.SPECIAL]) idxs = sorted(idxs, key=lambda uop: uop.arg[0]) assert idxs[0].arg == ('gidx0', 6), idxs[0].arg assert idxs[1].arg == ('gidx1', 5), idxs[1].arg assert idxs[2].arg == ('gidx2', 4), idxs[2].arg def test_div_collapse(self): def helper(t, msg, max_ops=0): sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is UOps.SINK] assert len(sched) == 1 lin = 
Kernel(sched[0].ast) assert sum(u.arg is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg a = Tensor.empty((4,4)) b = Tensor.empty((4,4)) d = Tensor.empty((4,4)) c = (a*b)/b helper(c, "found UnaryOps.RECIP in (a*b)/b operation") c = a/a helper(c, "found UnaryOps.RECIP in (a/a) operation") c = (a/b)/d helper(c, "found multiple UnaryOps.RECIP in (a/b)/d operation", 1) def test_sum_collapse(self): t = Tensor([2]).reshape(1, 1).expand(256, 256).sum() sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is UOps.SINK] assert len(sched) == 1 lin = Kernel(sched[0].ast) assert not any(u.op is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse" def test_assign_fold(self): a = Tensor.ones(4, 4).contiguous().realize() m = Tensor.ones(4, 4).shrink(((1, 2), None)).pad(((1, 2), None)) a.assign(a+m) a.realize() np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.]) def test_where_fold(self): a = Tensor.ones(4, 4).contiguous().realize() b = a.shrink(((1, 2), None)).pad(((1, 2), None)) a.assign(b.where(2, a)) sched = create_schedule([a.lazydata]) assert len(sched) == 1 sched_copy = sched[:] run_schedule(sched) np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.]) lin = Kernel(sched_copy[-1].ast) lin.hand_coded_optimizations() lin.linearize() assert not any(u.arg == TernaryOps.WHERE for u in lin.uops), "found a WHERE op where it should have been folded" def test_phi_simplification(self): def helper(t, max_ops=0): k = helper_linearizer_opt(t)[-1] uops = list(k.linearize().uops) # ignore kernel optimized IF statements for now if if_op:=next((u for u in uops if u.op is UOps.IF), None): uops = uops[:uops.index(if_op)] assert len(set([u.op for u in uops if u.op in {UOps.RANGE, UOps.SPECIAL}])) == 1, "has either specials or ranges, not both" assert len([u for u in uops if u.op is UOps.PHI]) == 0, "PHI should have been simplified" # TODO: once uops track min/max this will be fixed #assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops" helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2) helper(Tensor.arange(-1, -100, -5), max_ops=2) # NOTE: both of these split the reduce (this just wasn't tracked before) #helper(Tensor.arange(-3.2, 6.7, 0.64), max_ops=2) #helper(Tensor.arange(256), max_ops=2) helper(Tensor.arange(255), max_ops=2) @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_grouped_store_phis(self): """ float4 acc0 = float4(0.0,0.0,0.0,0.0); { acc0 = // ... 
} *((device float4*)(data0+alu2)) = float4(acc0.x,acc0.y,acc0.z,acc0.w); simplifies to: *((device float4*)(data0+alu2)) = acc0; """ x, y = Tensor.randn(64,64), Tensor.randn(64,64) out = x.matmul(y) k = helper_linearizer_opt(out)[-1] # check that the float4 cast collapses store_vals = [u.src[-1] for u in k.uops if u.op is UOps.STORE] for val in store_vals: assert val.dtype == dtypes.float.vec(4) and val.op is not UOps.VECTORIZE @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_arange_opts(self): a = Tensor.arange(128) helper_linearizer_opt(a, [ [Opt(OptOps.GROUP, 0, 32)], [Opt(OptOps.GROUPTOP, 0, 32)], [Opt(op=OptOps.LOCAL, axis=0, amt=8)], [Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=0, amt=0)], [Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=8)], [Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UNROLL, axis=1, amt=4)], # noqa: E501 ]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_grouped_store_values(self): x = Tensor.randn((4,3,6,6)).realize() out = x.flip((0,1)).contiguous() k = helper_linearizer_opt(out)[-1] store_val = [u.src[-1] for u in k.uops if u.op is UOps.STORE][0] assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not UOps.VECTORIZE @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_grouped_store_locals_and_globals(self): x, y = Tensor.rand(128, 128), Tensor.rand(128, 128) out = x@y opt = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces k = helper_linearizer_opt(out, opts=[opt])[-1] def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src]) local_stores = [u for u in k.uops if u.op is UOps.STORE and any(x.op is UOps.DEFINE_LOCAL for x in get_recursive(u.src[0]))] global_stores = [u for u in k.uops if u.op is UOps.STORE and any(x.op is UOps.DEFINE_GLOBAL for x in get_recursive(u.src[0]))] barrier = [u for u in k.uops if u.op is UOps.BARRIER][0] # check that the float4 cast collapses for all stores for store in local_stores+global_stores: assert store.src[2].dtype.count > 1 and store.src[2].op is not UOps.VECTORIZE # # check the children's vins # TODO: src ALU are not the same, should it? 
# assert barrier.src == tuple(local_stores) assert len([u for u in k.uops if u.op is UOps.IF and u.src[-1] == barrier]) == 1 @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_grouped_store_local_only(self): x, y = Tensor.rand(1,128), Tensor.rand(128, 128) r = (x@y).relu() k = helper_linearizer_opt(r)[-1] stores = [u for u in k.uops if u.op is UOps.STORE] # the float4 value stores directly in lds and we skip upcast assert stores[0].src[-1].dtype == dtypes.float.vec(4) assert stores[0].src[-1].op is not UOps.VECTORIZE # the global store doesn't change assert stores[1].src[2].dtype == dtypes.float @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_skip_unmatching_upcasts(self): Tensor.manual_seed(0) ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501 opt = [ Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2) ] k = helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])[-1] out = [u for u in k.uops if u.op is UOps.STORE][0] assert out.src[-1].op is UOps.VECTORIZE and out.src[-1].dtype == dtypes.float.vec(4) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_skip_unmatching_upcasts_with_gep(self): Tensor.manual_seed(0) ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501 opt = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)] k = helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])[-1] out = [u for u in k.uops if u.op is UOps.STORE][0] assert out.src[-1].op is UOps.VECTORIZE and out.src[-1].dtype.count != 1 @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4") class TestFloat4(unittest.TestCase): @staticmethod def count_float4(k): return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.float.vec(4)]), len([uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == dtypes.float.vec(4)])) @staticmethod def count_half4(k): return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.half.vec(4)]), 
len([uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == dtypes.half.vec(4)])) # TODO: express opts below as auto opts def test_float4_basic(self): a = Tensor.rand(2, 8).realize() b = Tensor.rand(2, 8).realize() c = a + b s = create_schedule([c.lazydata])[0] k = Kernel(s.ast) k.hand_coded_optimizations() k.linearize() assert TestFloat4.count_float4(k) == (2, 1) def test_float4_multidim(self): a = Tensor.rand(2, 8).realize() b = Tensor.rand(2, 8).realize() c = a + b s = create_schedule([c.lazydata])[0] k = Kernel(s.ast) k.shift_to(0, 4) # float4 dimension k.shift_to(0, 2, insert_before=k.shape_len-1) k.upcast() k.upcast() k.local_dims += 1 k.linearize() assert TestFloat4.count_float4(k) == (4, 2) def test_float4_unaligned_load(self): a = Tensor.rand(9).realize().shrink(((1, 9),)) b = Tensor.rand(9).realize().shrink(((1, 9),)) c = a + b s = create_schedule([c.lazydata])[0] k = Kernel(s.ast) k.hand_coded_optimizations() # implicit trigger float4 dim k.linearize() assert TestFloat4.count_float4(k) == (0, 1) def test_float4_multidim_unaligned_load(self): a = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),)) b = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),)) c = a + b s = create_schedule([c.lazydata])[0] k = Kernel(s.ast) k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim k.upcast() k.shift_to(len(k.full_unupcasted_shape)-1, 2, insert_before=k.shape_len-1) k.upcast() k.local_dims += 1 k.linearize() assert TestFloat4.count_float4(k) == (0, 2) def test_float4_sometimes_unaligned(self): a = Tensor.rand(1, 1, 8).realize() b = Tensor.rand(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5))) c = a.conv2d(b) # only the first and last conv dot products are aligned in a, and b is never aligned, so no # float4 should be emitted (the reduce axis of size 4 is the float4 axis here) s = create_schedule([c.lazydata])[0] k = Kernel(s.ast) k.upcast() k.linearize() assert TestFloat4.count_float4(k) == (0, 0) def test_float4_multidim_sometimes_unaligned(self): a = Tensor.rand(1, 1, 7).realize() b = Tensor.rand(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5))) c = a.conv2d(b) # the first conv dot product is aligned in a. If we upcast the output and reduce # dimension, then we could do float4 for only that one set of loads, but we currently # don't. # UPDATE: now we do this fusion s = create_schedule([c.lazydata])[0] k = Kernel(s.ast) k.upcast() k.upcast() k.linearize() assert TestFloat4.count_float4(k) in {(0,1), (1,1)} def test_float4_noncontiguous(self): a = Tensor.rand(4, 2).realize() b = Tensor.rand(4, 2).realize() c = a + b # we will upcast the top axis of sz 4. they should not be coalesced into float4, # since the top axis is not contiguous. s = create_schedule([c.lazydata])[0] k = Kernel(s.ast) k.shift_to(0, 4, top=True) # top axes are float4 axes k.upcast() k.linearize() assert TestFloat4.count_float4(k) == (0, 0) def test_float4_expand(self): a = Tensor.rand(9).realize().shrink(((1, 9),)) b = Tensor.rand(2).realize().reshape((2, 1)).expand((2,4)).reshape((8,)) c = a + b # we will upcast the top axis of sz 4. they should not be coalesced into float4, # since the top axis is not contiguous. 
s = create_schedule([c.lazydata])[0] k = Kernel(s.ast) k.shift_to(0, 4) # float4 axis k.upcast() k.linearize() assert TestFloat4.count_float4(k) == (0, 1) def test_float4_heterogeneous(self): a = Tensor.rand(8).realize() b = Tensor.rand(9).realize().shrink(((1, 9),)) c = a + b # should float4 b but not a s = create_schedule([c.lazydata])[0] k = Kernel(s.ast) k.shift_to(0, 4) # float4 axis k.upcast() k.linearize() assert TestFloat4.count_float4(k) == (1, 1) def test_half4_load_unrolled(self): # from llama 7B shard 4 gpus ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501 # TODO: fix this, expected might change but should be positive for expected, opts in [ ((7, 0), [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=3), Opt(op=OptOps.UNROLL, axis=0, amt=4)]), ((5, 0), [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)]), ((2, 0), [Opt(op=OptOps.UNROLL, axis=0, amt=4)]), ]: k = Kernel(ast) for opt in opts: k.apply_opt(opt) k.linearize() count = TestFloat4.count_half4(k) assert count == expected, f"{count=}, {expected=}" def test_float4_acc(self): # from float32 stable diffusion red tinybox ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(5, 6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501 for expected, opts in [ (1, [Opt(op=OptOps.UPCAST, axis=2, amt=4)]), (4, [Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)]), ]: k = Kernel(ast) for opt in opts: k.apply_opt(opt) k.linearize() count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)]) assert count == expected, f"{count=}, {expected=}" @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires 
locals") def test_float2_acc(self): # from resnet ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True))))),), arg=dtypes.float),), arg=(4, 6)),), arg=dtypes.half),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 for expected, opts in [ (16, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4)]), # noqa: E501 (4, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2)]), ]: k = Kernel(ast) for opt in opts: k.apply_opt(opt) k.linearize() count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(2)]) assert count == expected, f"{count=}, {expected=}" class TestHandCodedOpts(unittest.TestCase): def test_masked_upcast(self): layer_1 = Tensor.cat(*[Tensor.rand(5) for _ in range(4)]) layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20)) s = create_schedule([layer_2.lazydata])[-1] k = Kernel(s.ast) k.hand_coded_optimizations() assert len(k.bufs) == 6 # make sure all ops are done in one kernel # masked upcast should upcast masked axis of size 7 # masked upcast should not upcast large (20) last axis # float4/other hcopt shouldn't upcast last axis, since we already have 7 upcast, and the last axis is not very contiguous assert k.upcasted == 1 and k.full_shape[-1] == 7 def test_masked_upcast_wino(self): monster = Tensor.stack(*[Tensor.stack(*[Tensor.rand(16) for _ in range(6)]) for _ in range(6)]) s = create_schedule([monster.lazydata])[-1] k = Kernel(s.ast) k.hand_coded_optimizations() assert len(k.bufs) == 37 # make sure all ops are done in one kernel # should upcast the two Tensor.stacks assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2 def test_masked_upcast_wino_full(self): with Context(WINO=1): x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize() out = Tensor.conv2d(x,w, padding=1) upcasts = [] wino_schedule = create_schedule([out.lazydata]) # collect upcasts of tile transform kernels for i, si in enumerate(wino_schedule): k = Kernel(si.ast) k.hand_coded_optimizations() if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel) if len(k.bufs) < 36: continue # not a tile transform kernel (there's a permute kernel at the end) upcasts.append(tuple(k.full_shape[k.shape_len - k.upcasted:k.shape_len])) assert len(upcasts) == 3 # 3 transformation matrices assert 
len(wino_schedule) <= 4 # 4 kernels # this test case's inputs are too small, so one of the 4-stacks became a local, which is fine i guess assert upcasts.count((6, 6)) == 2 #and upcasts.count((4, 4)) == 1 out.mean().backward() backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata]) for si in backward_schedule: k = Kernel(si.ast) k.hand_coded_optimizations() k.linearize() if len(k.bufs) < 20: continue # not a tile transform kernel # heuristic number to make sure that at least some upcasts but not too many upcasts are being done assert 6 <= prod(k.full_shape[k.shape_len - k.upcasted:k.shape_len]) <= 216 assert len(backward_schedule) <= 13 # just the current number, but it could be better def test_masked_upcast_many(self): layer_1 = Tensor.cat(Tensor.rand(3, 4), Tensor.rand(4, 4)) layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 7, 4)) layer_3 = Tensor.cat(layer_2.unsqueeze(0), Tensor.rand(6, 7, 7, 4)) k = helper_linearizer_opt(layer_3)[-1] assert len(k.bufs) == 5 # make sure all ops are done in one kernel # check that we don't do too many upcasts assert prod(k.full_shape[k.shape_len-k.upcasted:k.shape_len]) <= 49 @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") def test_matvec(self): N = 128 a = Tensor.rand(1, N).realize() b = Tensor.rand(N, N).realize() c = a @ b k = helper_linearizer_opt(c)[-1] assert k.group_for_reduces == 1 assert k.local_dims == 1 assert k.upcasted == 1 def helper_linearizer_ast(ast:Union[Tuple[LazyOp, ...], LazyOp, UOp], inputs:List[Tensor], *args, **kwargs): if not isinstance(ast, LazyOp) and not isinstance(ast, UOp): ast = LazyOp(MetaOps.KERNEL, ast) inbufs = [x.lazydata.base.buffer for x in inputs] ast = to_uop(ast) if isinstance(ast, LazyOp) else ast outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, cast(DType,out.src[2].dtype)).allocate() \ for out in ast.src] return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs) def helper_linearizer_opt(r:Union[Tensor, List[Tensor]], *args, **kwargs): realized_ast, real_bufs = helper_realized_ast(r) return _helper_linearizer_opt_ast(realized_ast, real_bufs, *args, **kwargs) def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:List[Buffer], opts=[], apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> List[Kernel]: lins: List[Kernel] = [] outbufs = [(real_bufs[i], lop.st_arg.shape) for i,lop in enumerate(realized_ast.src)] def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT)) def check_opt(opts, create_k, expected_color_size): k = create_k() lins.append(k) if apply_tc: assert k.apply_tensor_cores(1, extra_opts=opts), "no tensor core triggered" else: for opt in opts: k.apply_opt(opt) if expected_color_size is not None: assert (cs:=list(zip(k.colors(), k.full_shape))) == expected_color_size, f"expected={expected_color_size} got={cs}" prg = get_prg(k) for buf,_ in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled prg.exec(real_bufs) for i, (buf,shape) in enumerate(outbufs): np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).reshape(shape), wanna_output[i], atol=atol, rtol=rtol) # Get baseline if it is not provided, which is not optimized at all. 
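  # The returned list is ordered [unoptimized baseline, hand_coded_optimizations kernel, one kernel per entry in opts];
  # tests above index [0] for the baseline and [-1] for the last (most optimized) kernel.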
k = Kernel(realized_ast) lins.append(k) prg = get_prg(k) prg.exec(real_bufs) if len(wanna_output) == 0: wanna_output = [np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).reshape(shape).copy() for buf,shape in outbufs] else: for i, (buf,shape) in enumerate(outbufs): np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).reshape(shape), wanna_output[i], atol=atol, rtol=rtol) # Check correctness of handcoded optimizations. k = Kernel(realized_ast) lins.append(k) k.hand_coded_optimizations() prg = get_prg(k) for buf,_ in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled prg.exec(real_bufs) for i, (buf,shape) in enumerate(outbufs): np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).reshape(shape), wanna_output[i], atol=atol, rtol=rtol) for i, x in enumerate(opts): # Check custom transformations if any. check_opt(x, lambda: Kernel(realized_ast), color_sizes[i] if i < len(color_sizes) else None) return lins class TestKernelOpts(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_local_and_grouped_reduce(self): N = 128 Tensor.manual_seed(1882) a = Tensor.rand(4, 4, N, N) b = Tensor.rand(4, 4, N) r = (b.sqrt() + ((a+1).sum(axis=3).exp())) helper_linearizer_opt(r, [ [Opt(OptOps.LOCAL, 0, 2)], [Opt(OptOps.LOCAL, 0, 8)], [Opt(OptOps.LOCAL, 0, 16)], # Checking how it works with locals [Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.GROUPTOP, 0, 32)], [Opt(OptOps.GROUPTOP, 0, 64)], # Checking how it works with grouped reduce [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.GROUPTOP, 0, 16)], [Opt(OptOps.LOCAL, 0, 32), Opt(OptOps.GROUPTOP, 0, 2)], # Checking how it works with locals + grouped reduce [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 64)], # Checking how it works with locals + grouped reduce + upcasts [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.UNROLL, 1, 4)], # many local + many group [Opt(OptOps.GROUP, 0, 2)] * 4, [Opt(OptOps.LOCAL, 0, 2)] * 4, [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)] * 4, ]) def test_upcasts(self): N = 16 Tensor.manual_seed(1772) a = Tensor.rand(N, N) b = Tensor.rand(N, N) r = (a+b).sqrt() * ((a+1).exp()) helper_linearizer_opt(r, [ [Opt(OptOps.UPCAST, 0, 2)], [Opt(OptOps.UPCAST, 0, 4)], [Opt(OptOps.UPCAST, 0, 8)], # Checking how it works with upcasts ]) def test_full_upcast(self): Tensor.manual_seed(1772) a = Tensor.rand(4) b = Tensor.rand(4) r = (a+b).sqrt() * ((a+1).exp()) helper_linearizer_opt(r, [ [Opt(OptOps.UPCAST, 0, 4)], # Checking how it works with upcasts ]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_matmul(self): N = 128 Tensor.manual_seed(1552) a = Tensor.rand(N, N) b = Tensor.rand(N, N) r = a@b helper_linearizer_opt(r, [ [Opt(OptOps.UPCAST, 0, 2)], [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # Checking how it works with upcasts [Opt(OptOps.LOCAL, 0, 2)], [Opt(OptOps.LOCAL, 1, 32)], [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4)], [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 32)], [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.LOCAL, 1, 8)], # Checking how it works with locals [Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.GROUPTOP, 0, 
32)], [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.UNROLL, 0, 4)], # Checking how it works with grouped_reduce [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 32)], [Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 32)], [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 4)], # Checking how it works with local+grouped_reduce # Checking all together [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)], # Full global upcast + local [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 8)], ]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_double_reduce(self): N = 128 Tensor.manual_seed(1552) a = Tensor.rand(8, N, 8, N) r = a.sum(axis=(1,3)) helper_linearizer_opt(r, [ # openCL / GPU=1 is 256 max threads [Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.GROUPTOP, 0, 32)], [Opt(OptOps.GROUPTOP, 1, 2)], [Opt(OptOps.GROUPTOP, 1, 32)], # Checking how it works with 1 grouped_reduce. [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2)], [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 64)], # Checking how it works with 2 grouped_reduces. [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 0, 4)], [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 2, 4)], # Checking how it works with 2 grouped_reduces + upcasts. [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4)], # Checking how it works with 2 grouped_reduces + upcasts + locals. [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 1, 4)], [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2)], [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)], # Checking how it works with 2 grouped_reduces + upcasts + locals. 
[Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 0, 2)], # No globals ]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_invalid_tensor_core_extra_opts(self): N = 128 Tensor.manual_seed(1552) a = Tensor.rand(N, N) b = Tensor.rand(N, N) realized_ast, _ = helper_realized_ast(a@b) invalid_opts = [ [Opt(OptOps.LOCAL, 2, 2)], [Opt(OptOps.UPCAST, 2, 2)], [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 2, 2)], ] for x in invalid_opts: k = Kernel(realized_ast) with self.assertRaises(AssertionError): assert k.apply_tensor_cores(use_tensor_cores=1, extra_opts=x), "no valid tensor core" # for METAL in runners @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_buf_index_not_found_tensor_core(self): ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPNE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) with self.assertRaises(KernelOptError): k.apply_opt(Opt(OptOps.TC, 0, 1)) @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_core_opts(self): N = 128 Tensor.manual_seed(1552) for tc in Device[Device.DEFAULT].renderer.tensor_cores: # bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices. 
if tc.dtype_in == dtypes.bfloat16: continue a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in) r = a.matmul(b, acc_dtype=tc.dtype_out) (atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4) helper_linearizer_opt(r, [ [], [Opt(OptOps.UPCAST, 0, 4)], [Opt(OptOps.UPCAST, 1, 4)], [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts [Opt(OptOps.UNROLL, 0, 2)], # check unroll [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of unroll and local [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)], [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4)], # check permutations [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)], [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4)], [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], # [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC) ], apply_tc=True, atol=atol, rtol=rtol) @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") def test_tensor_core_opts_locals(self): N = 128 Tensor.manual_seed(1552) for tc in Device[Device.DEFAULT].renderer.tensor_cores: # bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices. if tc.dtype_in == dtypes.bfloat16: continue a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in) r = a.matmul(b, acc_dtype=tc.dtype_out) (atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4) helper_linearizer_opt(r, [ [Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals [Opt(OptOps.LOCAL, 0, 4)], # check local [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)], [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)], ], apply_tc=True, atol=atol, rtol=rtol) def test_padto_matmul(self): if CI and Device.DEFAULT in ["AMD", "NV", "CUDA"]: self.skipTest("super slow on CUDA and AMD because of the big grid dims") N = 17 * 17 Tensor.manual_seed(289) a = Tensor.rand(N, N) b = Tensor.rand(N, N) helper_linearizer_opt(a@b, [ [Opt(OptOps.PADTO, 0, 32)], [Opt(OptOps.PADTO, 1, 32)], [Opt(OptOps.PADTO, 2, 32)], [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32)], [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)], # can optimize further post PADTO [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2),], ]) def test_padto_upcasted_not_ok(self): N = 4 a = Tensor.rand(N, N) b = Tensor.rand(N, N) helper_linearizer_opt(a@b, [ [Opt(OptOps.UPCAST, 0, 0)], [Opt(OptOps.UPCAST, 1, 0)], [Opt(OptOps.UNROLL, 0, 0)], [Opt(OptOps.PADTO, 0, 8)], [Opt(OptOps.PADTO, 1, 8)], [Opt(OptOps.PADTO, 2, 8)], ]) with self.assertRaises(KernelOptError): helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 2, 8)]]) with self.assertRaises(KernelOptError): helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 1, 0), Opt(OptOps.PADTO, 2, 8)]]) with self.assertRaises(KernelOptError): helper_linearizer_opt(a@b, [[Opt(OptOps.UNROLL, 0, 0), 
Opt(OptOps.PADTO, 2, 8)]]) def test_padto_sum_ok(self): N = 18 * 18 # NOTE: this setup prevents 17 * 17 contiguous merged into one dimension a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))) * 100 b = (Tensor.rand(N, N) < 0.5).realize().shrink(((0, 17), (0, 17))) helper_linearizer_opt(a.sum(0), [ [Opt(OptOps.PADTO, 0, 32)], [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], ]) helper_linearizer_opt(a.sum(1), [ [Opt(OptOps.PADTO, 0, 32)], [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], ]) # can pad sum reduce axis if there's no unsafe ops prior to sum for axis in (0, 1): helper_linearizer_opt(a.sum(), [[Opt(OptOps.PADTO, axis, 32)],]) helper_linearizer_opt(a.sum(0), [[Opt(OptOps.PADTO, axis, 32)],]) helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, axis, 32)],]) helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, axis, 32)],]) helper_linearizer_opt(b.sum(acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],]) helper_linearizer_opt(b.sum(0, acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],]) helper_linearizer_opt(b.sum(1, acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],]) # having unsafe ops after sum is fine helper_linearizer_opt(a.sum().exp(), [[Opt(OptOps.PADTO, 0, 32)],]) helper_linearizer_opt(a.sum(0).exp(), [[Opt(OptOps.PADTO, 1, 32)],]) def test_padto_sum_not_ok(self): N = 18 * 18 # NOTE: this setup prevents 17 * 17 contiguous merged into one dimension a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))).exp() # exp is not safe to pad with self.assertRaises(KernelOptError): helper_linearizer_opt(a.exp().sum(), [[Opt(OptOps.PADTO, 0, 32)],]) with self.assertRaises(KernelOptError): helper_linearizer_opt(a.exp().sum(0), [[Opt(OptOps.PADTO, 1, 32)],]) b = a < -1 # lt is not safe to pad with self.assertRaises(KernelOptError): helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, 0, 32)],]) with self.assertRaises(KernelOptError): helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, 1, 32)],]) def test_padto_max(self): N = 18 * 18 # NOTE: this setup prevents 17 * 17 contiguous merged into one axis a = -Tensor.rand(N, N).shrink(((0, 17), (0, 17))) * 100 helper_linearizer_opt(a.max(0), [ [Opt(OptOps.PADTO, 0, 32)], [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], ]) helper_linearizer_opt(a.max(1), [ [Opt(OptOps.PADTO, 0, 32)], [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], ]) # cannot pad max kernel on reduce with self.assertRaises(KernelOptError): helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],]) with self.assertRaises(KernelOptError): helper_linearizer_opt(a.max(0), [[Opt(OptOps.PADTO, 1, 32)],]) def test_padto_where(self): Tensor.manual_seed(0) N = 17 * 17 a = (Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1).where(1, 0) helper_linearizer_opt(a.max(0), [ [Opt(OptOps.PADTO, 0, 32)], [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], ]) def test_padto_where_multioutput(self): Tensor.manual_seed(0) N = 17 * 17 r = Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1 a0 = r.where(1, 0) a1 = r.where(2, 0) helper_linearizer_opt([a0.max(0), a1.max(0)], [ [Opt(OptOps.PADTO, 0, 32)], [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], ]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_padto_group(self): Tensor.manual_seed(0) ld0 = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 
0, 3, 0, 1), offset=0, mask=None, contiguous=False),)))) # noqa: E501 ld1 = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))) # noqa: E501 ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(ld0, ld1)),), arg=(0, 2, 4, 6)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize() data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize() helper_linearizer_ast((ast, ), [data1, data2], opts=[ [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.GROUP, 0, 4)], [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8)], [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.GROUP, 0, 4)] ]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_color_shapes_with_local(self): N = 32 Tensor.manual_seed(1552) a = Tensor.rand(N, N) b = Tensor.rand(N, N) r = a@b opts_shapes = [ ([Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("red",32)]), ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",2),("red",16)]), # check to ensure local_dims are stable for full UNROLL of first_reduce ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), ([Opt(OptOps.UNROLL, 0, 0),Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), # check behavior for full UNROLL on an existing GROUP ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",16),("magenta",2)]), ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), ([Opt(OptOps.GROUP, 0, 0),Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]), ([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]), ] helper_linearizer_opt(r, [x[0] for x in opts_shapes], color_sizes=[x[1] for x in opts_shapes]) if __name__ == '__main__': unittest.main()
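# Illustrative sketch (not part of the test suite): the manual flow most tests above rely on, written out once using
# only names already imported or defined in this file. The function name and the particular opt choices are arbitrary
# examples, and it assumes the default device can compile and run kernels.
def _example_manual_opt():
  x, y = Tensor.rand(64, 64), Tensor.rand(64, 64)
  realized_ast, real_bufs = helper_realized_ast(x@y)  # schedule the matmul and allocate its output buffer
  k = Kernel(realized_ast)
  for opt in [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)]: k.apply_opt(opt)  # example opt sequence
  k.linearize()  # produces k.uops, the list the tests above inspect
  CompiledRunner(k.to_program()).exec(real_bufs)  # compile and execute on the allocated buffers
  return np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype)).reshape(64, 64)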