mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
LDS noop and spec (#9669)
* init lds noop and lds_0 spec
* refactor lds helper test
* fix typo
* test all lds at the same time
* change comment
* comment
* start test_lds_full
* test_lds_tc
* add tc spec
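
At this commit, OptOps.LDS is argument validation plus a spec: test_lds_args in the diff below is expected to pass, while every test that checks for real DEFINE_LOCAL staging buffers is marked expectedFailure. A minimal sketch of the accepted and rejected axes, mirroring test_lds_args (the import paths are assumptions for this tree; helper_realized_ast is the harness helper used throughout this test file):

from tinygrad import Tensor
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError  # assumed import path

# axis 0 names the output buffer, axes 1..n the inputs; anything else is rejected
ast, _ = helper_realized_ast(Tensor.rand(4, 4) @ Tensor.rand(4, 4))
Kernel(ast).apply_opt(Opt(OptOps.LDS, 0, None))    # accepted: stage the output through shared memory
Kernel(ast).apply_opt(Opt(OptOps.LDS, 1, None))    # accepted: stage the first input
try: Kernel(ast).apply_opt(Opt(OptOps.LDS, 3, None))
except KernelOptError: pass                        # rejected: a two-input matmul has no buffer 3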
@@ -2218,5 +2218,228 @@ class TestKernelOpts(unittest.TestCase):
    ]
    helper_linearizer_opt(r, [x[0] for x in opts_shapes], color_sizes=[x[1] for x in opts_shapes])

if __name__ == '__main__':
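
# runs an (M,K) @ (K,N) matmul with the given opts, checks the result numerically, then asserts the kernel
# defines exactly the expected local buffers: each (buf, sz) in expected_bufs is (global buffer index, LDS tile element count)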
def helper_lds_allclose(opts:list[Opt], expected_bufs, N=16, M=16, K=16, dtype_in=dtypes.float, acc_dtype=dtypes.float):
  with Context(DEBUG=0): a, b = Tensor.rand(M, K, dtype=dtype_in).realize(), Tensor.rand(K, N, dtype=dtype_in).realize()
  realized_ast, bufs = helper_realized_ast(a.matmul(b, dtype=acc_dtype))
  k = Kernel(realized_ast)
  for opt in opts:
    k.apply_opt(opt)
  prg = k.to_program()
  CompiledRunner(replace(prg, device=Device.DEFAULT)).exec(bufs)

  atol, rtol = 1e-4, 1e-4
  if dtype_in == dtypes.half: atol, rtol = 1e-2, 1e-2
  np.testing.assert_allclose(bufs[0].numpy().reshape((M,N)), a.numpy() @ b.numpy(), atol=atol, rtol=rtol)

  local_buffers = [uop for uop in k.uops if uop.op is Ops.DEFINE_LOCAL]
  assert len(local_buffers) == len(expected_bufs), f"Expected exactly {len(expected_bufs)} local buffers, got {len(local_buffers)}"
  for i,(buf, sz) in enumerate(expected_bufs):
    assert local_buffers[i].arg == buf, f"Expected buffer argument index {buf}, got {local_buffers[i].arg}"
    expected_dtype = (acc_dtype if buf == 0 else dtype_in).ptr(sz, local=True)
    assert local_buffers[i].dtype == expected_dtype, f"Expected buffer dtype {expected_dtype}, got {local_buffers[i].dtype} for {opts=}"
  # TODO: check that all accesses to the global buffer are proxied through the local buffer

@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
class TestLDS(unittest.TestCase):
  # lds tile sizes for the inputs are the same as the memory accessed by each thread inside the reduce loop
  # test no reshape opt after lds? true for lds_swap
  # test TC3?
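  # e.g. with UNROLL sz on the reduce axis each thread reads sz elements of each input per loop trip, so the input tiles are sz elements (test_lds_unroll)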

  def test_lds_args(self):
    realized_ast, _ = helper_realized_ast(Tensor.rand(4, 4) @ Tensor.rand(4, 4))
    k = Kernel(realized_ast)
    valid_opts = [Opt(OptOps.LDS, 0, None),
                  Opt(OptOps.LDS, 1, None),
                  Opt(OptOps.LDS, 2, None)]
    for opt in valid_opts:
      k = Kernel(realized_ast)
      k.apply_opt(opt)

    invalid_opts = [Opt(OptOps.LDS, -1, None),
                    Opt(OptOps.LDS, 3, None)]
    for opt in invalid_opts:
      k = Kernel(realized_ast)
      with self.assertRaises(KernelOptError):
        k.apply_opt(opt)

  @unittest.expectedFailure
  def test_lds_output_basic(self):
    helper_lds_allclose(opts=[Opt(OptOps.LDS, 0, None)], expected_bufs=[(0,1)])

  @unittest.expectedFailure
  def test_lds_input_basic(self):
    helper_lds_allclose(opts=[Opt(OptOps.LDS, 1, None)], expected_bufs=[(1,1)])
    helper_lds_allclose(opts=[Opt(OptOps.LDS, 2, None)], expected_bufs=[(2,1)])

  @unittest.expectedFailure
  def test_lds_multi_basic(self):
    helper_lds_allclose(opts=[Opt(OptOps.LDS, 0, None), Opt(OptOps.LDS, 1, None)], expected_bufs=[(0,1),(1,1)])
    helper_lds_allclose(opts=[Opt(OptOps.LDS, 0, None), Opt(OptOps.LDS, 1, None), Opt(OptOps.LDS, 2, None)], expected_bufs=[(0,1),(1,1),(2,1)])

  @unittest.expectedFailure
  def test_lds_unroll(self):
    # unroll doesn't change local output buffer size
    for sz in [2,4,8]:
      helper_lds_allclose(opts=[Opt(OptOps.UNROLL, 0, sz), Opt(OptOps.LDS, 0, None)], expected_bufs=[(0,1)])
      helper_lds_allclose(opts=[Opt(OptOps.UNROLL, 0, sz), Opt(OptOps.LDS, 1, None)], expected_bufs=[(1,sz)])
      helper_lds_allclose(opts=[Opt(OptOps.UNROLL, 0, sz), Opt(OptOps.LDS, 2, None)], expected_bufs=[(2,sz)])

  @unittest.expectedFailure
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  def test_lds_local(self):
    # if only locals are applied, local buffer size for output should be prod(locals)

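    # e.g. LOCAL(0,2) then LOCAL(0,8) stack on the same output axis, giving an output tile of 2*8 = 16 elements (multi_local_opts below)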
    basic_local_opts = [Opt(OptOps.LOCAL, 0, 2),
                        Opt(OptOps.LDS, 0, None),
                        Opt(OptOps.LDS, 1, None),
                        Opt(OptOps.LDS, 2, None)]
    helper_lds_allclose(opts=basic_local_opts, expected_bufs=[(0,2),(1,2),(2,1)])

    multi_local_opts = [Opt(OptOps.LOCAL, 0, 2),
                        Opt(OptOps.LOCAL, 0, 8),
                        Opt(OptOps.LDS, 0, None),
                        Opt(OptOps.LDS, 1, None),
                        Opt(OptOps.LDS, 2, None)]
    helper_lds_allclose(opts=multi_local_opts, expected_bufs=[(0,16),(1,16),(2,1)])

    multi_axis_local_opts = [Opt(OptOps.LOCAL, 1, 4),
                             Opt(OptOps.LOCAL, 0, 2),
                             Opt(OptOps.LDS, 0, None),
                             Opt(OptOps.LDS, 1, None),
                             Opt(OptOps.LDS, 2, None)]
    helper_lds_allclose(opts=multi_axis_local_opts, expected_bufs=[(0,8),(1,2),(2,4)])

    full_local_opts = [Opt(OptOps.LOCAL, 0, 16),
                       Opt(OptOps.LOCAL, 0, 16),
                       Opt(OptOps.LDS, 0, None),
                       Opt(OptOps.LDS, 1, None),
                       Opt(OptOps.LDS, 2, None)]
    helper_lds_allclose(opts=full_local_opts, expected_bufs=[(0,256),(1,16),(2,16)])

  @unittest.expectedFailure
  def test_lds_upcast(self):
    # if only upcasts are applied, local buffer size for output should be prod(upcasts)

    basic_upcast_opts = [Opt(OptOps.UPCAST, 0, 2),
                         Opt(OptOps.LDS, 0, None),
                         Opt(OptOps.LDS, 1, None),
                         Opt(OptOps.LDS, 2, None)]
    helper_lds_allclose(opts=basic_upcast_opts, expected_bufs=[(0,2),(1,2),(2,1)])

    multi_upcast_opts = [Opt(OptOps.UPCAST, 0, 2),
                         Opt(OptOps.UPCAST, 0, 8),
                         Opt(OptOps.LDS, 0, None),
                         Opt(OptOps.LDS, 1, None),
                         Opt(OptOps.LDS, 2, None)]
    helper_lds_allclose(opts=multi_upcast_opts, expected_bufs=[(0,16),(1,16),(2,1)])

    multi_axis_upcast_opts = [Opt(OptOps.UPCAST, 1, 4),
                              Opt(OptOps.UPCAST, 0, 2),
                              Opt(OptOps.LDS, 0, None),
                              Opt(OptOps.LDS, 1, None),
                              Opt(OptOps.LDS, 2, None)]
    helper_lds_allclose(opts=multi_axis_upcast_opts, expected_bufs=[(0,8),(1,2),(2,4)])

    full_upcast_opts = [Opt(OptOps.UPCAST, 0, 16),
                        Opt(OptOps.UPCAST, 0, 16),
                        Opt(OptOps.LDS, 0, None),
                        Opt(OptOps.LDS, 1, None),
                        Opt(OptOps.LDS, 2, None)]
    helper_lds_allclose(opts=full_upcast_opts, expected_bufs=[(0,256),(1,16),(2,16)])

  @unittest.expectedFailure
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_lds_tc(self):
    for tc in Device[Device.DEFAULT].renderer.tensor_cores:
      if tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16: continue
      (N, M, K) = tc.dims
      opts = [Opt(OptOps.TC, 0, (-1, 0)),
              Opt(OptOps.LDS, 0, None),
              Opt(OptOps.LDS, 1, None),
              Opt(OptOps.LDS, 2, None)]
      helper_lds_allclose(opts=opts, expected_bufs=[(0,N*M),(1,M*K),(2,K*N)], N=N, M=M, K=K, dtype_in=tc.dtype_in, acc_dtype=tc.dtype_out)

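      # in the remaining cases: LOCAL/UPCAST on an output axis scales the output tile and the tile of the input sharing that axis,
      # while UNROLL on the reduce axis scales only the input tiles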
      opts = [Opt(OptOps.TC, 0, (-1, 0)),
              Opt(OptOps.LOCAL, 0, 2),
              Opt(OptOps.UPCAST, 1, 2),
              Opt(OptOps.LDS, 0, None),
              Opt(OptOps.LDS, 1, None),
              Opt(OptOps.LDS, 2, None)]
      helper_lds_allclose(opts=opts, expected_bufs=[(0,N*M*4),(1,M*K*2),(2,K*N*2)], N=N*4, M=M*4, K=K*4, dtype_in=tc.dtype_in, acc_dtype=tc.dtype_out)

      opts = [Opt(OptOps.TC, 0, (-1, 0)),
              Opt(OptOps.UNROLL, 0, 2),
              Opt(OptOps.LDS, 0, None),
              Opt(OptOps.LDS, 1, None),
              Opt(OptOps.LDS, 2, None)]
      helper_lds_allclose(opts=opts, expected_bufs=[(0,N*M),(1,M*K*2),(2,K*N*2)], N=N*4, M=M*4, K=K*4, dtype_in=tc.dtype_in, acc_dtype=tc.dtype_out)

      opts = [Opt(OptOps.TC, 0, (-1, 0)),
              Opt(OptOps.UNROLL, 0, 2),
              Opt(OptOps.UPCAST, 1, 2),
              Opt(OptOps.LDS, 0, None),
              Opt(OptOps.LDS, 1, None),
              Opt(OptOps.LDS, 2, None)]
      helper_lds_allclose(opts=opts, expected_bufs=[(0,N*M*2),(1,M*K*2),(2,K*N*4)], N=N*4, M=M*4, K=K*4, dtype_in=tc.dtype_in, acc_dtype=tc.dtype_out)

  @unittest.expectedFailure
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_lds_tc_padded(self):
    for tc in Device[Device.DEFAULT].renderer.tensor_cores:
      if tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16: continue
      (N, M, K) = tc.dims
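      # shapes are tc.dims + 3, so the tensor core has to pad (presumably what the 2 in the TC opt arg enables); the LDS tiles still match tc.dims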
      opts = [Opt(OptOps.TC, 0, (-1, 2)),
              Opt(OptOps.LDS, 0, None),
              Opt(OptOps.LDS, 1, None),
              Opt(OptOps.LDS, 2, None)]
      helper_lds_allclose(opts=opts, expected_bufs=[(0,N*M),(1,M*K),(2,K*N)], N=N+3, M=M+3, K=K+3, dtype_in=tc.dtype_in, acc_dtype=tc.dtype_out)

  @unittest.expectedFailure
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  def test_lds_full(self):
    opts = [Opt(OptOps.LOCAL, 0, 2),
            Opt(OptOps.UPCAST, 1, 2),
            Opt(OptOps.LDS, 0, None),
            Opt(OptOps.LDS, 1, None),
            Opt(OptOps.LDS, 2, None)]
    helper_lds_allclose(opts=opts, expected_bufs=[(0,4),(1,2),(2,2)])

    opts = [Opt(OptOps.LOCAL, 0, 2),
            Opt(OptOps.UPCAST, 0, 4),
            Opt(OptOps.LOCAL, 1, 8),
            Opt(OptOps.LDS, 0, None),
            Opt(OptOps.LDS, 1, None),
            Opt(OptOps.LDS, 2, None)]
    helper_lds_allclose(opts=opts, expected_bufs=[(0,64),(1,8),(2,8)])

    opts = [Opt(OptOps.LOCAL, 0, 16),
            Opt(OptOps.UPCAST, 1, 2),
            Opt(OptOps.LDS, 0, None),
            Opt(OptOps.LDS, 1, None),
            Opt(OptOps.LDS, 2, None)]
    helper_lds_allclose(opts=opts, expected_bufs=[(0,16),(1,16),(2,1)])

    opts = [Opt(OptOps.LOCAL, 0, 16),
            Opt(OptOps.UPCAST, 0, 16),
            Opt(OptOps.LDS, 0, None),
            Opt(OptOps.LDS, 1, None),
            Opt(OptOps.LDS, 2, None)]
    helper_lds_allclose(opts=opts, expected_bufs=[(0,256),(1,16),(2,16)])

    opts = [Opt(OptOps.LOCAL, 1, 16),
            Opt(OptOps.UPCAST, 1, 16),
            Opt(OptOps.LDS, 0, None),
            Opt(OptOps.LDS, 1, None),
            Opt(OptOps.LDS, 2, None)]
    helper_lds_allclose(opts=opts, expected_bufs=[(0,16),(1,1),(2,16)])

    opts = [Opt(OptOps.LOCAL, 1, 4),
            Opt(OptOps.UNROLL, 0, 2),
            Opt(OptOps.UPCAST, 0, 2),
            Opt(OptOps.LDS, 0, None),
            Opt(OptOps.LDS, 1, None),
            Opt(OptOps.LDS, 2, None)]
    helper_lds_allclose(opts=opts, expected_bufs=[(0,8),(1,4),(2,8)])

if __name__ == "__main__":
  unittest.main()