diff --git a/extra/thunder/tiny/tk/group.py b/extra/thunder/tiny/tk/group.py
index 52b7e995d8..d1a853a1dd 100644
--- a/extra/thunder/tiny/tk/group.py
+++ b/extra/thunder/tiny/tk/group.py
@@ -404,7 +404,29 @@ class Group:
     dst, src = cast(UOp, dst), cast(UOp, src)
     assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
     dst_dtype, src_dtype = dst.dtype, src.dtype
-    if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL and isinstance(src, RT):
+    if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.LOCAL:
+      laneid = self.ker.laneid
+      st, rt = cast(ST, dst), cast(RT, src)
+      elements_per_thread = rt.base_shape.elements_per_thread
+
+      for height in self.ker.range(src.shape[-3], track=False):
+        for width in self.ker.range(src.shape[-2], track=False):
+          for inner in self.ker.range(elements_per_thread, track=False):
+            if rt.layout != st.layout:
+              row = rt.base_shape.stride * (laneid // rt.base_shape.cols) + inner
+              col = laneid % rt.base_shape.cols
+            else:
+              row = laneid % rt.base_shape.rows
+              col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
+
+            srow, scol = st.swizzle(row, col)
+
+            src_load = src[*src_idxs, height, width, inner]
+            if src.dtype.base != dst.dtype.base:
+              src_load = src_load.cast(dst.dtype.base)
+            dst_store = dst[*idxs[:-2], height, width, srow, scol].store(src_load)
+            dst_store = dst_store.end(height, width, inner)
+    elif src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL and isinstance(src, RT):
       dstf = dst.flatten()
       row_stride = prod(dst.shape[axis+1:])
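
A quick standalone check of the index math in the new REG -> LOCAL branch. This is a sketch, not the tinygrad API: the 16x16 base subtile, the 32-lane warp, and the derived stride of 8 are illustrative assumptions, not necessarily the repo's actual base_shape values. Under them, both branches send the 32 lanes x 8 register elements to 256 distinct slots of the shared tile:

    ROWS = COLS = 16                 # assumed base subtile shape
    LANES = 32                       # one warp
    STRIDE = (ROWS * COLS) // LANES  # elements_per_thread = 8 under these assumptions

    def coords(lane, inner, same_layout):
      # mirrors the two branches above: rt.layout == st.layout vs. mismatch
      if same_layout:
        return lane % ROWS, STRIDE * (lane // ROWS) + inner
      return STRIDE * (lane // COLS) + inner, lane % COLS

    for same_layout in (True, False):
      slots = {coords(lane, inner, same_layout) for lane in range(LANES) for inner in range(STRIDE)}
      assert len(slots) == ROWS * COLS  # no collisions and the whole tile is covered
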
diff --git a/test/testextra/test_tk.py b/test/testextra/test_tk.py
index 6cf9ef46f8..bbc9466d3b 100644
--- a/test/testextra/test_tk.py
+++ b/test/testextra/test_tk.py
@@ -31,14 +31,16 @@ class TestTK(unittest.TestCase):
 
       a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
       b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
+      c_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
 
       a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
       b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
-      c_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32, TileLayout.COL)
+      c_reg_col = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32, TileLayout.COL)
+      c_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
 
       col, row = ker.blockIdx_x, ker.blockIdx_y
 
-      c_reg = warp.zero(c_reg)
+      c_reg_col = warp.zero(c_reg_col)
       for tile in ker.range(N // BLOCK_SIZE):
         a_smem = warp.load(a_smem, a, (), (0, 0, row, tile), axis=2)
         b_smem = warp.load(b_smem, b, (), (0, 0, tile, col), axis=2)
@@ -46,8 +48,11 @@ class TestTK(unittest.TestCase):
         a_reg = warp.load(a_reg, a_smem)
         b_reg = warp.load(b_reg, b_smem)
 
-        c_reg = warp.mma_AB(c_reg, a_reg, b_reg)
-      c_reg = ker.endrange()
+        c_reg_col = warp.mma_AB(c_reg_col, a_reg, b_reg)
+      c_reg_col = ker.endrange()
+
+      c_smem = warp.store(c_smem, c_reg_col)
+      c_reg = warp.load(c_reg, c_smem)
 
       c = warp.store(c, c_reg, (0, 0, row, col), (), axis=2)
 
@@ -152,6 +157,45 @@ class TestTK(unittest.TestCase):
 
     np.testing.assert_allclose(b.numpy(), ref.numpy())
 
+  def test_load_store_local_hop(self):
+    N = 64
+    BLOCK_SIZE = 32
+    with Kernel("load_store_local_hop", (N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS) as ker:
+      warp = ker.warp
+
+      b = ker.gl((1, 1, N, N), dtypes.float32)
+      a = ker.gl((1, 1, N, N), dtypes.float32)
+
+      a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
+      b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
+
+      a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
+      b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
+
+      col, row = ker.blockIdx_x, ker.blockIdx_y
+
+      a_smem = warp.load(a_smem, a, (), (0, 0, row, col), axis=2)
+      a_reg = warp.load(a_reg, a_smem)
+      b_reg = warp.copy(b_reg, a_reg)
+      b_smem = warp.store(b_smem, b_reg)
+      b_reg = warp.load(b_reg, b_smem)
+      b = warp.store(b, b_reg, (0, 0, row, col), (), axis=2)
+
+    sink = ker.finish()
+
+    with Context(DEBUG=0):
+      a = Tensor.rand(1, 1, N, N, dtype="float32").contiguous()
+      b = Tensor.empty(1, 1, N, N, dtype="float32")
+      Tensor.realize(a, b)
+
+      ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
+      for _ in range(5): ei.run(wait=True)
+      b = b.float()
+
+      ref = a.float()
+
+      np.testing.assert_allclose(b.numpy(), ref.numpy())
+
   @unittest.skip("TODO")
   def test_load_store_group(self):
     N = 256
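
Why the matmul test now hops c through shared memory: warp.mma_AB accumulates into a COL-layout register tile, and the test only stores to global from the default-layout c_reg, which suggests the REG -> GLOBAL path assumes the default row layout. Storing with the layout-mismatch branch and reloading with the same-layout mapping (the same swizzle both ways) reshuffles elements across lanes. A standalone sketch under the same assumed 16x16/32-lane/stride-8 shapes, with a hypothetical XOR swizzle standing in for ST.swizzle and assuming the existing shared-to-register load mirrors the same-layout mapping:

    ROWS = COLS = 16
    LANES, STRIDE = 32, 8

    def swizzle(row, col):
      # hypothetical bank-conflict swizzle, bijective within each row
      return row, col ^ ((row % 2) * STRIDE)

    smem = [[None] * COLS for _ in range(ROWS)]

    # store: COL-layout registers -> shared tile (the rt.layout != st.layout branch)
    for lane in range(LANES):
      for inner in range(STRIDE):
        r, c = STRIDE * (lane // COLS) + inner, lane % COLS
        sr, sc = swizzle(r, c)
        smem[sr][sc] = (r, c)  # stash the logical coordinate as the value

    # load: shared tile -> row-layout registers (the rt.layout == st.layout mapping)
    for lane in range(LANES):
      for inner in range(STRIDE):
        r, c = lane % ROWS, STRIDE * (lane // ROWS) + inner
        sr, sc = swizzle(r, c)
        assert smem[sr][sc] == (r, c)  # same swizzle on both sides: exact round trip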