mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
tk reg local store (#13689)
This commit is contained in:
@@ -404,7 +404,29 @@ class Group:
|
||||
dst, src = cast(UOp, dst), cast(UOp, src)
|
||||
assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
|
||||
dst_dtype, src_dtype = dst.dtype, src.dtype
|
||||
if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL and isinstance(src, RT):
|
||||
if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.LOCAL:
|
||||
laneid = self.ker.laneid
|
||||
st, rt = cast(ST, dst), cast(RT, src)
|
||||
elements_per_thread = rt.base_shape.elements_per_thread
|
||||
|
||||
for height in self.ker.range(src.shape[-3], track=False):
|
||||
for width in self.ker.range(src.shape[-2], track=False):
|
||||
for inner in self.ker.range(elements_per_thread, track=False):
|
||||
if rt.layout != st.layout:
|
||||
row = rt.base_shape.stride * (laneid // rt.base_shape.cols) + inner
|
||||
col = laneid % rt.base_shape.cols
|
||||
else:
|
||||
row = laneid % rt.base_shape.rows
|
||||
col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
|
||||
|
||||
srow, scol = cast(ST, dst).swizzle(row, col)
|
||||
|
||||
src_load = src[*src_idxs, height, width, inner]
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dst[*idxs[:-2], height, width, srow, scol].store(src_load)
|
||||
dst_store = dst_store.end(height, width, inner)
|
||||
elif src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL and isinstance(src, RT):
|
||||
dstf = dst.flatten()
|
||||
row_stride = prod(dst.shape[axis+1:])
|
||||
|
||||
|
||||
@@ -31,14 +31,16 @@ class TestTK(unittest.TestCase):
|
||||
|
||||
a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
|
||||
b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
|
||||
c_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
|
||||
|
||||
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16)
|
||||
b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
|
||||
c_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32, TileLayout.COL)
|
||||
c_reg_col = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32, TileLayout.COL)
|
||||
c_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
|
||||
|
||||
col, row = ker.blockIdx_x, ker.blockIdx_y
|
||||
|
||||
c_reg = warp.zero(c_reg)
|
||||
c_reg_col = warp.zero(c_reg_col)
|
||||
for tile in ker.range(N // BLOCK_SIZE):
|
||||
a_smem = warp.load(a_smem, a, (), (0, 0, row, tile), axis=2)
|
||||
b_smem = warp.load(b_smem, b, (), (0, 0, tile, col), axis=2)
|
||||
@@ -46,8 +48,11 @@ class TestTK(unittest.TestCase):
|
||||
a_reg = warp.load(a_reg, a_smem)
|
||||
b_reg = warp.load(b_reg, b_smem)
|
||||
|
||||
c_reg = warp.mma_AB(c_reg, a_reg, b_reg)
|
||||
c_reg = ker.endrange()
|
||||
c_reg_col = warp.mma_AB(c_reg_col, a_reg, b_reg)
|
||||
c_reg_col = ker.endrange()
|
||||
|
||||
c_smem = warp.store(c_smem, c_reg_col)
|
||||
c_reg = warp.load(c_reg, c_smem)
|
||||
|
||||
c = warp.store(c, c_reg, (0, 0, row, col), (), axis=2)
|
||||
|
||||
@@ -152,6 +157,45 @@ class TestTK(unittest.TestCase):
|
||||
|
||||
np.testing.assert_allclose(b.numpy(), ref.numpy())
|
||||
|
||||
def test_load_store_local_hop(self):
|
||||
N = 64
|
||||
BLOCK_SIZE = 32
|
||||
with Kernel("load_store_local_hop", (N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS) as ker:
|
||||
warp = ker.warp
|
||||
|
||||
b = ker.gl((1, 1, N, N), dtypes.float32)
|
||||
a = ker.gl((1, 1, N, N), dtypes.float32)
|
||||
|
||||
a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
|
||||
b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
|
||||
|
||||
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
|
||||
b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
|
||||
|
||||
col, row = ker.blockIdx_x, ker.blockIdx_y
|
||||
|
||||
a_smem = warp.load(a_smem, a, (), (0, 0, row, col), axis=2)
|
||||
a_reg = warp.load(a_reg, a_smem)
|
||||
b_reg = warp.copy(b_reg, a_reg)
|
||||
b_smem = warp.store(b_smem, b_reg)
|
||||
b_reg = warp.load(b_reg, b_smem)
|
||||
b = warp.store(b, b_reg, (0, 0, row, col), (), axis=2)
|
||||
|
||||
sink = ker.finish()
|
||||
|
||||
with Context(DEBUG=0):
|
||||
a = Tensor.rand(1, 1, N, N, dtype="float32").contiguous()
|
||||
b = Tensor.empty(1, 1, N, N, dtype="float32")
|
||||
Tensor.realize(a, b)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
|
||||
for _ in range(5): ei.run(wait=True)
|
||||
b = b.float()
|
||||
|
||||
ref = a.float()
|
||||
|
||||
np.testing.assert_allclose(b.numpy(), ref.numpy())
|
||||
|
||||
@unittest.skip("TODO")
|
||||
def test_load_store_group(self):
|
||||
N = 256
|
||||
|
||||
Reference in New Issue
Block a user