tk: flat global -> local load (#14033)

This commit is contained in:
wozeparrot
2026-01-06 02:35:53 -05:00
committed by GitHub
parent 3170365a5b
commit 21d0f6bb76

View File

@@ -312,28 +312,31 @@ class Group:
idxs = tuple(idx * st.cols if i == 3 else idx for i, idx in enumerate(idxs))
src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3]
for height in self.ker.range(dst.shape[-4], track=False):
for width in self.ker.range(dst.shape[-3], track=False):
elements_per_thread = st.base_shape.elements_per_thread
memcpy_per_row = st.base_shape.cols // elements_per_thread
total_calls = st.base_shape.num_elements // (self.group_threads * elements_per_thread)
elements_per_thread = st.base_shape.elements_per_thread
memcpy_per_row = st.cols // elements_per_thread
total_calls = (dst.shape[-4] * dst.shape[-3] * st.base_shape.num_elements) // (self.group_threads * elements_per_thread)
for outer in self.ker.range(total_calls, track=False):
for inner in self.ker.range(elements_per_thread, axis_type=AxisType.UPCAST, track=False):
load_idx = outer * self.group_threads + self.laneid
row = load_idx // memcpy_per_row
col = (load_idx * elements_per_thread) % st.base_shape.cols + inner
for outer in self.ker.range(total_calls, track=False):
for inner in self.ker.range(elements_per_thread, axis_type=AxisType.UPCAST, track=False):
load_idx = outer * self.group_threads + self.laneid
row = load_idx // memcpy_per_row
col = (load_idx * elements_per_thread) % st.cols + inner
height = row // st.base_shape.rows
width = col // st.base_shape.cols
srow, scol = cast(ST, dst).swizzle(row, col)
row = row % st.base_shape.rows
col = col % st.base_shape.cols
src_i += height * st.base_shape.rows * row_stride + width * st.base_shape.cols
src_i += row * row_stride + col
srow, scol = cast(ST, dst).swizzle(row, col)
src_load = srcf[src_i]
if src.dtype.base != dst.dtype.base:
src_load = src_load.cast(dst.dtype.base)
dst_store = dst[*dst_idxs, height, width, srow, scol].store(src_load)
dst_store = dst_store.end(height, width, outer, inner).barrier()
src_i += height * st.base_shape.rows * row_stride + width * st.base_shape.cols
src_i += row * row_stride + col
src_load = srcf[src_i]
if src.dtype.base != dst.dtype.base:
src_load = src_load.cast(dst.dtype.base)
dst_store = dst[*dst_idxs, height, width, srow, scol].store(src_load)
dst_store = dst_store.end(height, width, outer, inner).barrier()
elif dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.GLOBAL and isinstance(dst, RT):
srcf = src.flatten()
row_stride = prod(src.shape[axis+1:])