diff --git a/extra/thunder/tiny/tk/group.py b/extra/thunder/tiny/tk/group.py index 27cc2fefa8..a2ed4888ba 100644 --- a/extra/thunder/tiny/tk/group.py +++ b/extra/thunder/tiny/tk/group.py @@ -296,9 +296,20 @@ class Group: row = laneid % rt.base_shape.rows col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner + sheight = height + swidth = width + if len(idxs) == 2: + row_idx = idxs[0] * dst.shape[-3] * rt.base_shape.rows + col_idx = idxs[1] * dst.shape[-2] * rt.base_shape.cols + + row += row_idx % st.base_shape.rows + col += col_idx % st.base_shape.cols + sheight += row_idx // st.base_shape.rows + swidth += col_idx // st.base_shape.cols + srow, scol = cast(ST, src).swizzle(row, col) - src_load = src[*idxs[:-2], height, width, srow, scol] + src_load = src[*idxs[:-2], sheight, swidth, srow, scol] if src.dtype.base != dst.dtype.base: src_load = src_load.cast(dst.dtype.base) dst_store = dst[*dst_idxs, height, width, inner].store(src_load)