mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
tk: support sliced local -> reg load (#14034)
This commit is contained in:
@@ -296,9 +296,20 @@ class Group:
|
||||
row = laneid % rt.base_shape.rows
|
||||
col = rt.base_shape.stride * (laneid // rt.base_shape.rows) + inner
|
||||
|
||||
sheight = height
|
||||
swidth = width
|
||||
if len(idxs) == 2:
|
||||
row_idx = idxs[0] * dst.shape[-3] * rt.base_shape.rows
|
||||
col_idx = idxs[1] * dst.shape[-2] * rt.base_shape.cols
|
||||
|
||||
row += row_idx % st.base_shape.rows
|
||||
col += col_idx % st.base_shape.cols
|
||||
sheight += row_idx // st.base_shape.rows
|
||||
swidth += col_idx // st.base_shape.cols
|
||||
|
||||
srow, scol = cast(ST, src).swizzle(row, col)
|
||||
|
||||
src_load = src[*idxs[:-2], height, width, srow, scol]
|
||||
src_load = src[*idxs[:-2], sheight, swidth, srow, scol]
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dst[*dst_idxs, height, width, inner].store(src_load)
|
||||
|
||||
Reference in New Issue
Block a user