From 62e2fc5108bf10744cd4d2e6348471c10267c1ff Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Thu, 4 Dec 2025 17:23:48 -0800 Subject: [PATCH] tk: global load/store rv (#13577) --- extra/thunder/tiny/tk/group.py | 50 ++++++++++++++++++++++++++++++---- extra/thunder/tiny/tk/tiles.py | 9 +++--- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/extra/thunder/tiny/tk/group.py b/extra/thunder/tiny/tk/group.py index c511d9dc05..d5702eff6b 100644 --- a/extra/thunder/tiny/tk/group.py +++ b/extra/thunder/tiny/tk/group.py @@ -7,7 +7,7 @@ from tinygrad.dtype import AddrSpace, PtrDType from tinygrad.helpers import getenv, prod from extra.thunder.tiny.tk import WARP_THREADS -from extra.thunder.tiny.tk.tiles import ALL_TILES, GL, RT_16X16, RT_16X32, ST, RT, RV, TileLayout +from extra.thunder.tiny.tk.tiles import ALL_TILES, GL, RT_16X16, RT_16X32, ST, RT, RV, TileLayout, VecLayout class Group: def __init__(self, warps:int, ker): @@ -338,7 +338,7 @@ class Group: src_load = src_load.cast(dst.dtype.base) dst_store = dst[*dst_idxs, height, width, srow, scol].store(src_load) dst_store = dst_store.end(height, width, outer, inner).barrier() - elif dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace ==AddrSpace.GLOBAL: + elif dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.GLOBAL and isinstance(dst, RT): srcf = src.flatten() row_stride = prod(src.shape[axis+1:]) @@ -371,8 +371,28 @@ class Group: if src.dtype.base != dst.dtype.base: src_load = src_load.cast(dst.dtype.base) dst_store = dst[*dst_idxs, height, width, inner].store(src_load).end(height, width, inner) + elif dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.GLOBAL and isinstance(dst, RV): + srcf = src.flatten() + row_stride = prod(src.shape[axis+1:]) + + laneid = self.ker.laneid + rv = cast(RV, dst) + reductions = rv.base_shape.rows + + assert rv.layout == VecLayout.ORTHO, "only ortho layout supported" + + idxs = tuple(idx * rv.length if i == 3 else idx for i, idx in enumerate(idxs)) + src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3] + + for outer in self.ker.range(dst.shape[-2]): + src_i += outer * reductions + (laneid % reductions) + + src_load = srcf[src_i] + if src.dtype.base != dst.dtype.base: + src_load = src_load.cast(dst.dtype.base) + dst_store = dst[outer, 0].store(src_load).end(outer) else: - raise NotImplementedError(f"load from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented") + raise NotImplementedError(f"load from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented for {type(dst)=}") self.ker.push_store(dst_store, dst) return dst.after(dst_store).reshape(dst.shape) @@ -381,7 +401,7 @@ class Group: dst, src = cast(UOp, dst), cast(UOp, src) assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType) dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype) - if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL: + if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL and isinstance(src, RT): dstf = dst.flatten() row_stride = prod(dst.shape[axis+1:]) @@ -414,8 +434,28 @@ class Group: if src.dtype.base != dst.dtype.base: src_load = src_load.cast(dst.dtype.base) dst_store = dstf[dst_i].store(src_load).end(height, width, inner) + elif src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL and isinstance(src, RV): + dstf = dst.flatten() + row_stride = prod(dst.shape[axis+1:]) + + laneid = self.ker.laneid + rv = cast(RV, src) + reductions = rv.base_shape.rows + + assert rv.layout == VecLayout.ORTHO, "only ortho layout supported" + + idxs = tuple(idx * rv.length if i == 3 else idx for i, idx in enumerate(idxs)) + dst_i = ((idxs[0] * dst.shape[-3] + idxs[1]) * dst.shape[-2] + idxs[2]) * dst.shape[-1] + idxs[3] + + for outer in self.ker.range(src.shape[-2]): + dst_i += outer * reductions + (laneid % reductions) + + src_load = src[outer, 0] + if src.dtype.base != dst.dtype.base: + src_load = src_load.cast(dst.dtype.base) + dst_store = dstf[dst_i].store(src_load).end(outer) else: - raise NotImplementedError(f"store from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented") + raise NotImplementedError(f"store from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented for {type(src)=}") self.ker.push_store(dst_store, dst) return dst.after(dst_store).reshape(dst.shape) diff --git a/extra/thunder/tiny/tk/tiles.py b/extra/thunder/tiny/tk/tiles.py index 6cacb644d6..58f0ce1349 100644 --- a/extra/thunder/tiny/tk/tiles.py +++ b/extra/thunder/tiny/tk/tiles.py @@ -250,11 +250,12 @@ class RT(TileMathMixin): @autowrap(UOp) class RV(TileMathMixin): - def __init__(self, uop:UOp, layout:VecLayout, ker): - self._uop, self.layout, self.ker = uop, layout, ker + def __init__(self, uop:UOp, length:int, layout:VecLayout, base_shape:RTBaseShape, ker): + self._uop, self.ker = uop, ker + self.length, self.layout, self.base_shape = length, layout, base_shape def ruop(self, uop:UOp): - return RV(uop, self.layout, self.ker) + return RV(uop, self.length, self.layout, self.base_shape, self.ker) @classmethod def create(cls, length, dtype:DType, layout:VecLayout, base_shape:RTBaseShape, ker): @@ -266,6 +267,6 @@ class RV(TileMathMixin): outer_dim = tiles uop = ker.alloc((outer_dim, inner_dim), dtype, AddrSpace.REG) - return RV(uop, layout, ker) + return RV(uop, length, layout, base_shape, ker) ALL_TILES = UOp | GL | ST | RT | RV