mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
tk: flat global -> local load (#14033)
This commit is contained in:
@@ -312,28 +312,31 @@ class Group:
|
||||
idxs = tuple(idx * st.cols if i == 3 else idx for i, idx in enumerate(idxs))
|
||||
src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3]
|
||||
|
||||
for height in self.ker.range(dst.shape[-4], track=False):
|
||||
for width in self.ker.range(dst.shape[-3], track=False):
|
||||
elements_per_thread = st.base_shape.elements_per_thread
|
||||
memcpy_per_row = st.base_shape.cols // elements_per_thread
|
||||
total_calls = st.base_shape.num_elements // (self.group_threads * elements_per_thread)
|
||||
elements_per_thread = st.base_shape.elements_per_thread
|
||||
memcpy_per_row = st.cols // elements_per_thread
|
||||
total_calls = (dst.shape[-4] * dst.shape[-3] * st.base_shape.num_elements) // (self.group_threads * elements_per_thread)
|
||||
|
||||
for outer in self.ker.range(total_calls, track=False):
|
||||
for inner in self.ker.range(elements_per_thread, axis_type=AxisType.UPCAST, track=False):
|
||||
load_idx = outer * self.group_threads + self.laneid
|
||||
row = load_idx // memcpy_per_row
|
||||
col = (load_idx * elements_per_thread) % st.base_shape.cols + inner
|
||||
for outer in self.ker.range(total_calls, track=False):
|
||||
for inner in self.ker.range(elements_per_thread, axis_type=AxisType.UPCAST, track=False):
|
||||
load_idx = outer * self.group_threads + self.laneid
|
||||
row = load_idx // memcpy_per_row
|
||||
col = (load_idx * elements_per_thread) % st.cols + inner
|
||||
height = row // st.base_shape.rows
|
||||
width = col // st.base_shape.cols
|
||||
|
||||
srow, scol = cast(ST, dst).swizzle(row, col)
|
||||
row = row % st.base_shape.rows
|
||||
col = col % st.base_shape.cols
|
||||
|
||||
src_i += height * st.base_shape.rows * row_stride + width * st.base_shape.cols
|
||||
src_i += row * row_stride + col
|
||||
srow, scol = cast(ST, dst).swizzle(row, col)
|
||||
|
||||
src_load = srcf[src_i]
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dst[*dst_idxs, height, width, srow, scol].store(src_load)
|
||||
dst_store = dst_store.end(height, width, outer, inner).barrier()
|
||||
src_i += height * st.base_shape.rows * row_stride + width * st.base_shape.cols
|
||||
src_i += row * row_stride + col
|
||||
|
||||
src_load = srcf[src_i]
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dst[*dst_idxs, height, width, srow, scol].store(src_load)
|
||||
dst_store = dst_store.end(height, width, outer, inner).barrier()
|
||||
elif dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.GLOBAL and isinstance(dst, RT):
|
||||
srcf = src.flatten()
|
||||
row_stride = prod(src.shape[axis+1:])
|
||||
|
||||
Reference in New Issue
Block a user