diff --git a/extra/thunder/tiny/tk/group.py b/extra/thunder/tiny/tk/group.py index 66e3187b87..27cc2fefa8 100644 --- a/extra/thunder/tiny/tk/group.py +++ b/extra/thunder/tiny/tk/group.py @@ -312,28 +312,31 @@ class Group: idxs = tuple(idx * st.cols if i == 3 else idx for i, idx in enumerate(idxs)) src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3] - for height in self.ker.range(dst.shape[-4], track=False): - for width in self.ker.range(dst.shape[-3], track=False): - elements_per_thread = st.base_shape.elements_per_thread - memcpy_per_row = st.base_shape.cols // elements_per_thread - total_calls = st.base_shape.num_elements // (self.group_threads * elements_per_thread) + elements_per_thread = st.base_shape.elements_per_thread + memcpy_per_row = st.cols // elements_per_thread + total_calls = (dst.shape[-4] * dst.shape[-3] * st.base_shape.num_elements) // (self.group_threads * elements_per_thread) - for outer in self.ker.range(total_calls, track=False): - for inner in self.ker.range(elements_per_thread, axis_type=AxisType.UPCAST, track=False): - load_idx = outer * self.group_threads + self.laneid - row = load_idx // memcpy_per_row - col = (load_idx * elements_per_thread) % st.base_shape.cols + inner + for outer in self.ker.range(total_calls, track=False): + for inner in self.ker.range(elements_per_thread, axis_type=AxisType.UPCAST, track=False): + load_idx = outer * self.group_threads + self.laneid + row = load_idx // memcpy_per_row + col = (load_idx * elements_per_thread) % st.cols + inner + height = row // st.base_shape.rows + width = col // st.base_shape.cols - srow, scol = cast(ST, dst).swizzle(row, col) + row = row % st.base_shape.rows + col = col % st.base_shape.cols - src_i += height * st.base_shape.rows * row_stride + width * st.base_shape.cols - src_i += row * row_stride + col + srow, scol = cast(ST, dst).swizzle(row, col) - src_load = srcf[src_i] - if src.dtype.base != dst.dtype.base: - src_load = src_load.cast(dst.dtype.base) - dst_store = dst[*dst_idxs, height, width, srow, scol].store(src_load) - dst_store = dst_store.end(height, width, outer, inner).barrier() + src_i += height * st.base_shape.rows * row_stride + width * st.base_shape.cols + src_i += row * row_stride + col + + src_load = srcf[src_i] + if src.dtype.base != dst.dtype.base: + src_load = src_load.cast(dst.dtype.base) + dst_store = dst[*dst_idxs, height, width, srow, scol].store(src_load) + dst_store = dst_store.end(height, width, outer, inner).barrier() elif dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.GLOBAL and isinstance(dst, RT): srcf = src.flatten() row_stride = prod(src.shape[axis+1:])