mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
tk reg local store (#13689)
This commit is contained in:
@@ -404,7 +404,29 @@ class Group:
|
||||
dst, src = cast(UOp, dst), cast(UOp, src)
|
||||
assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
|
||||
dst_dtype, src_dtype = dst.dtype, src.dtype
|
||||
if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL and isinstance(src, RT):
|
||||
if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.LOCAL:
|
||||
laneid = self.ker.laneid
|
||||
st, rt = cast(ST, dst), cast(RT, src)
|
||||
elements_per_thread = rt.base_shape.elements_per_thread
|
||||
|
||||
for height in self.ker.range(src.shape[-3], track=False):
|
||||
for width in self.ker.range(src.shape[-2], track=False):
|
||||
for inner in self.ker.range(elements_per_thread, track=False):
|
||||
if rt.layout != st.layout:
|
||||
row = rt.base_shape.stride * (laneid // rt.base_shape.cols) + inner
|
||||
col = laneid % rt.base_shape.cols
|
||||
else:
|
||||
row = laneid % rt.base_shape.rows
|
||||
col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
|
||||
|
||||
srow, scol = cast(ST, dst).swizzle(row, col)
|
||||
|
||||
src_load = src[*src_idxs, height, width, inner]
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dst[*idxs[:-2], height, width, srow, scol].store(src_load)
|
||||
dst_store = dst_store.end(height, width, inner)
|
||||
elif src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL and isinstance(src, RT):
|
||||
dstf = dst.flatten()
|
||||
row_stride = prod(dst.shape[axis+1:])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user