tk reg local store (#13689)

This commit is contained in:
wozeparrot
2025-12-14 23:07:30 -08:00
committed by GitHub
parent 572ca80046
commit 7ef7ce2856
2 changed files with 71 additions and 5 deletions

View File

@@ -404,7 +404,29 @@ class Group:
dst, src = cast(UOp, dst), cast(UOp, src)
assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
dst_dtype, src_dtype = dst.dtype, src.dtype
if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL and isinstance(src, RT):
if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.LOCAL:
laneid = self.ker.laneid
st, rt = cast(ST, dst), cast(RT, src)
elements_per_thread = rt.base_shape.elements_per_thread
for height in self.ker.range(src.shape[-3], track=False):
for width in self.ker.range(src.shape[-2], track=False):
for inner in self.ker.range(elements_per_thread, track=False):
if rt.layout != st.layout:
row = rt.base_shape.stride * (laneid // rt.base_shape.cols) + inner
col = laneid % rt.base_shape.cols
else:
row = laneid % rt.base_shape.rows
col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
srow, scol = cast(ST, dst).swizzle(row, col)
src_load = src[*src_idxs, height, width, inner]
if src.dtype.base != dst.dtype.base:
src_load = src_load.cast(dst.dtype.base)
dst_store = dst[*idxs[:-2], height, width, srow, scol].store(src_load)
dst_store = dst_store.end(height, width, inner)
elif src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL and isinstance(src, RT):
dstf = dst.flatten()
row_stride = prod(dst.shape[axis+1:])

View File

@@ -31,14 +31,16 @@ class TestTK(unittest.TestCase):
a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
c_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
c_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32, TileLayout.COL)
c_reg_col = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32, TileLayout.COL)
c_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
col, row = ker.blockIdx_x, ker.blockIdx_y
c_reg = warp.zero(c_reg)
c_reg_col = warp.zero(c_reg_col)
for tile in ker.range(N // BLOCK_SIZE):
a_smem = warp.load(a_smem, a, (), (0, 0, row, tile), axis=2)
b_smem = warp.load(b_smem, b, (), (0, 0, tile, col), axis=2)
@@ -46,8 +48,11 @@ class TestTK(unittest.TestCase):
a_reg = warp.load(a_reg, a_smem)
b_reg = warp.load(b_reg, b_smem)
c_reg = warp.mma_AB(c_reg, a_reg, b_reg)
c_reg = ker.endrange()
c_reg_col = warp.mma_AB(c_reg_col, a_reg, b_reg)
c_reg_col = ker.endrange()
c_smem = warp.store(c_smem, c_reg_col)
c_reg = warp.load(c_reg, c_smem)
c = warp.store(c, c_reg, (0, 0, row, col), (), axis=2)
@@ -152,6 +157,45 @@ class TestTK(unittest.TestCase):
np.testing.assert_allclose(b.numpy(), ref.numpy())
def test_load_store_local_hop(self):
# Round-trip test: a tile of `a` is moved through ker.st and ker.rt tiles, written back
# out through a second ker.st tile (the register -> local store path), reloaded, and
# stored into `b`. The final check requires `b` to equal `a`.
# NOTE(review): leading indentation is not visible in this view, so the extent of the
# two `with` blocks below cannot be confirmed from here.
N = 64
BLOCK_SIZE = 32
# Kernel is constructed with a (N // BLOCK_SIZE, N // BLOCK_SIZE, 1) tuple and WARP_THREADS;
# presumably the launch grid and threads per block -- confirm against Kernel.
with Kernel("load_store_local_hop", (N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS) as ker:
warp = ker.warp
# `b` is declared before `a`; the buffer list handed to ExecItem below uses the same (b, a) order.
b = ker.gl((1, 1, N, N), dtypes.float32)
a = ker.gl((1, 1, N, N), dtypes.float32)
# One BLOCK_SIZE x BLOCK_SIZE float32 tile of each kind per operand
# (presumably st = shared/local tile, rt = register tile -- confirm).
a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
# blockIdx_x is used as the tile column and blockIdx_y as the tile row.
col, row = ker.blockIdx_x, ker.blockIdx_y
# a -> a_smem -> a_reg
a_smem = warp.load(a_smem, a, (), (0, 0, row, col), axis=2)
a_reg = warp.load(a_reg, a_smem)
# a_reg -> b_reg
b_reg = warp.copy(b_reg, a_reg)
# The "local hop": b_reg is stored into b_smem and immediately loaded back into b_reg.
b_smem = warp.store(b_smem, b_reg)
b_reg = warp.load(b_reg, b_smem)
# b_reg -> b, at the same (row, col) tile position that was read from `a`.
b = warp.store(b, b_reg, (0, 0, row, col), (), axis=2)
sink = ker.finish()
with Context(DEBUG=0):
# From here on `a` and `b` are rebound from kernel handles to Tensors of shape (1, 1, N, N).
a = Tensor.rand(1, 1, N, N, dtype="float32").contiguous()
b = Tensor.empty(1, 1, N, N, dtype="float32")
Tensor.realize(a, b)
# Buffers are passed as (b, a); presumably this must match the kernel's argument order -- confirm.
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
# The kernel is run five times, waiting on each run.
for _ in range(5): ei.run(wait=True)
# Both tensors were created as float32 above, so these .float() calls are presumably no-ops.
b = b.float()
ref = a.float()
# Output must match the input; no rtol/atol is passed, so assert_allclose uses its defaults.
np.testing.assert_allclose(b.numpy(), ref.numpy())
@unittest.skip("TODO")
def test_load_store_group(self):
N = 256