tk reg local store (#13689)

This commit is contained in:
wozeparrot
2025-12-14 23:07:30 -08:00
committed by GitHub
parent 572ca80046
commit 7ef7ce2856
2 changed files with 71 additions and 5 deletions

View File

@@ -404,7 +404,29 @@ class Group:
dst, src = cast(UOp, dst), cast(UOp, src)
assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
dst_dtype, src_dtype = dst.dtype, src.dtype
if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL and isinstance(src, RT):
if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.LOCAL:
laneid = self.ker.laneid
st, rt = cast(ST, dst), cast(RT, src)
elements_per_thread = rt.base_shape.elements_per_thread
for height in self.ker.range(src.shape[-3], track=False):
for width in self.ker.range(src.shape[-2], track=False):
for inner in self.ker.range(elements_per_thread, track=False):
if rt.layout != st.layout:
row = rt.base_shape.stride * (laneid // rt.base_shape.cols) + inner
col = laneid % rt.base_shape.cols
else:
row = laneid % rt.base_shape.rows
col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
srow, scol = cast(ST, dst).swizzle(row, col)
src_load = src[*src_idxs, height, width, inner]
if src.dtype.base != dst.dtype.base:
src_load = src_load.cast(dst.dtype.base)
dst_store = dst[*idxs[:-2], height, width, srow, scol].store(src_load)
dst_store = dst_store.end(height, width, inner)
elif src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL and isinstance(src, RT):
dstf = dst.flatten()
row_stride = prod(dst.shape[axis+1:])