mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
tk: global load/store rv (#13577)
This commit is contained in:
@@ -7,7 +7,7 @@ from tinygrad.dtype import AddrSpace, PtrDType
|
||||
from tinygrad.helpers import getenv, prod
|
||||
|
||||
from extra.thunder.tiny.tk import WARP_THREADS
|
||||
from extra.thunder.tiny.tk.tiles import ALL_TILES, GL, RT_16X16, RT_16X32, ST, RT, RV, TileLayout
|
||||
from extra.thunder.tiny.tk.tiles import ALL_TILES, GL, RT_16X16, RT_16X32, ST, RT, RV, TileLayout, VecLayout
|
||||
|
||||
class Group:
|
||||
def __init__(self, warps:int, ker):
|
||||
@@ -338,7 +338,7 @@ class Group:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dst[*dst_idxs, height, width, srow, scol].store(src_load)
|
||||
dst_store = dst_store.end(height, width, outer, inner).barrier()
|
||||
elif dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace ==AddrSpace.GLOBAL:
|
||||
elif dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.GLOBAL and isinstance(dst, RT):
|
||||
srcf = src.flatten()
|
||||
row_stride = prod(src.shape[axis+1:])
|
||||
|
||||
@@ -371,8 +371,28 @@ class Group:
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dst[*dst_idxs, height, width, inner].store(src_load).end(height, width, inner)
|
||||
elif dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.GLOBAL and isinstance(dst, RV):
|
||||
srcf = src.flatten()
|
||||
row_stride = prod(src.shape[axis+1:])
|
||||
|
||||
laneid = self.ker.laneid
|
||||
rv = cast(RV, dst)
|
||||
reductions = rv.base_shape.rows
|
||||
|
||||
assert rv.layout == VecLayout.ORTHO, "only ortho layout supported"
|
||||
|
||||
idxs = tuple(idx * rv.length if i == 3 else idx for i, idx in enumerate(idxs))
|
||||
src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3]
|
||||
|
||||
for outer in self.ker.range(dst.shape[-2]):
|
||||
src_i += outer * reductions + (laneid % reductions)
|
||||
|
||||
src_load = srcf[src_i]
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dst[outer, 0].store(src_load).end(outer)
|
||||
else:
|
||||
raise NotImplementedError(f"load from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented")
|
||||
raise NotImplementedError(f"load from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented for {type(dst)=}")
|
||||
|
||||
self.ker.push_store(dst_store, dst)
|
||||
return dst.after(dst_store).reshape(dst.shape)
|
||||
@@ -381,7 +401,7 @@ class Group:
|
||||
dst, src = cast(UOp, dst), cast(UOp, src)
|
||||
assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
|
||||
dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
|
||||
if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL:
|
||||
if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL and isinstance(src, RT):
|
||||
dstf = dst.flatten()
|
||||
row_stride = prod(dst.shape[axis+1:])
|
||||
|
||||
@@ -414,8 +434,28 @@ class Group:
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dstf[dst_i].store(src_load).end(height, width, inner)
|
||||
elif src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.GLOBAL and isinstance(src, RV):
|
||||
dstf = dst.flatten()
|
||||
row_stride = prod(dst.shape[axis+1:])
|
||||
|
||||
laneid = self.ker.laneid
|
||||
rv = cast(RV, src)
|
||||
reductions = rv.base_shape.rows
|
||||
|
||||
assert rv.layout == VecLayout.ORTHO, "only ortho layout supported"
|
||||
|
||||
idxs = tuple(idx * rv.length if i == 3 else idx for i, idx in enumerate(idxs))
|
||||
dst_i = ((idxs[0] * dst.shape[-3] + idxs[1]) * dst.shape[-2] + idxs[2]) * dst.shape[-1] + idxs[3]
|
||||
|
||||
for outer in self.ker.range(src.shape[-2]):
|
||||
dst_i += outer * reductions + (laneid % reductions)
|
||||
|
||||
src_load = src[outer, 0]
|
||||
if src.dtype.base != dst.dtype.base:
|
||||
src_load = src_load.cast(dst.dtype.base)
|
||||
dst_store = dstf[dst_i].store(src_load).end(outer)
|
||||
else:
|
||||
raise NotImplementedError(f"store from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented")
|
||||
raise NotImplementedError(f"store from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented for {type(src)=}")
|
||||
|
||||
self.ker.push_store(dst_store, dst)
|
||||
return dst.after(dst_store).reshape(dst.shape)
|
||||
|
||||
@@ -250,11 +250,12 @@ class RT(TileMathMixin):
|
||||
|
||||
@autowrap(UOp)
|
||||
class RV(TileMathMixin):
|
||||
def __init__(self, uop:UOp, layout:VecLayout, ker):
|
||||
self._uop, self.layout, self.ker = uop, layout, ker
|
||||
def __init__(self, uop:UOp, length:int, layout:VecLayout, base_shape:RTBaseShape, ker):
|
||||
self._uop, self.ker = uop, ker
|
||||
self.length, self.layout, self.base_shape = length, layout, base_shape
|
||||
|
||||
def ruop(self, uop:UOp):
|
||||
return RV(uop, self.layout, self.ker)
|
||||
return RV(uop, self.length, self.layout, self.base_shape, self.ker)
|
||||
|
||||
@classmethod
|
||||
def create(cls, length, dtype:DType, layout:VecLayout, base_shape:RTBaseShape, ker):
|
||||
@@ -266,6 +267,6 @@ class RV(TileMathMixin):
|
||||
outer_dim = tiles
|
||||
|
||||
uop = ker.alloc((outer_dim, inner_dim), dtype, AddrSpace.REG)
|
||||
return RV(uop, layout, ker)
|
||||
return RV(uop, length, layout, base_shape, ker)
|
||||
|
||||
ALL_TILES = UOp | GL | ST | RT | RV
|
||||
|
||||
Reference in New Issue
Block a user