feat: mixins on tiles (#13246)

This commit is contained in:
wozeparrot
2025-11-13 16:52:52 -08:00
committed by GitHub
parent ba84d415fe
commit 7eb0d8e744
4 changed files with 181 additions and 37 deletions

View File

@@ -7,7 +7,7 @@ from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.helpers import getenv, prod
from extra.thunder.tiny.tk import WARP_THREADS
from extra.thunder.tiny.tk.tiles import RT
from extra.thunder.tiny.tk.tiles import ALL_TILES, GL, ST, RT, RV
class Group:
def __init__(self, warps:int, ker):
@@ -27,7 +27,8 @@ class Group:
# ops that only work on a single warp
clear_rid = 1000
def clear(self, reg:UOp, value:float=0):
def clear(self, reg:ALL_TILES, value:float=0):
reg = cast(UOp, reg)
assert self.warps == 1
rngs_for_shape = tuple(UOp.range(dim, Group.clear_rid + i) for i, dim in enumerate(reg.shape))
@@ -38,11 +39,12 @@ class Group:
self.ker.push_store(reg_store, reg)
return reg.after(reg_store).reshape(reg.shape)
def zero(self, reg:UOp): return self.clear(reg, 0)
def neg_inf(self, reg:UOp): return self.clear(reg, -math.inf)
def zero(self, reg:ALL_TILES): return self.clear(reg, 0)
def neg_inf(self, reg:ALL_TILES): return self.clear(reg, -math.inf)
copy_rid = 300
def copy(self, dst:UOp, src:UOp):
def copy(self, dst:ALL_TILES, src:ALL_TILES):
dst, src = cast(UOp, dst), cast(UOp, src)
assert self.warps == 1
assert dst.shape == src.shape
@@ -54,7 +56,8 @@ class Group:
self.ker.push_store(dst_store, dst)
return dst.after(dst_store).reshape(dst.shape)
def mma_AB(self, c:UOp, a:UOp, b:UOp, after=True):
def mma_AB(self, c:UOp|RT, a:UOp|RT, b:UOp|RT, after=True):
c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
assert self.warps == 1
for height in self.ker.range(c.shape[-3], track=False):
@@ -76,7 +79,8 @@ class Group:
self.ker.push_store(c_store, c)
return c.after(c_store).reshape(c.shape) if after else c_store
def mma_ABt(self, c:UOp, a:UOp, b:UOp, after=True):
def mma_ABt(self, c:UOp|RT, a:UOp|RT, b:UOp|RT, after=True):
c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
assert self.warps == 1
for height in self.ker.range(c.shape[-3], track=False):
@@ -99,7 +103,8 @@ class Group:
return c.after(c_store).reshape(c.shape) if after else c_store
map_rid = 400
def map(self, a:UOp, op:Callable[[UOp], UOp]|Callable[[UOp, tuple], UOp]):
def map(self, a:ALL_TILES, op:Callable[[UOp], UOp]|Callable[[UOp, tuple], UOp]):
a = cast(UOp, a)
assert self.warps == 1
rngs_for_shape = tuple(UOp.range(dim, Group.map_rid + i) for i, dim in enumerate(a.shape))
@@ -115,7 +120,8 @@ class Group:
self.ker.push_store(a_store, a)
return a.after(a_store).reshape(a.shape)
def row_reduce(self, vec:UOp, src:UOp, op:Callable[[UOp, UOp], UOp]):
def row_reduce(self, vec:UOp|RV, src:UOp|RT, op:Callable[[UOp, UOp], UOp]):
vec, src = cast(UOp, vec), cast(UOp, src)
assert self.warps == 1
red_local = self.ker.alloc((self.group_threads, 2), src.dtype.base, AddrSpace.LOCAL)
@@ -157,7 +163,8 @@ class Group:
# ops that can work across multiple warps
LOAD_INNER = 8
def load(self, dst:UOp, src:UOp, dst_idxs:tuple[UOp|int,...]=(), idxs:tuple[UOp|int,...]=(), axis:int=0, transpose:bool=False):
def load(self, dst:ALL_TILES, src:ALL_TILES, dst_idxs:tuple[UOp|int,...]=(), idxs:tuple[UOp|int,...]=(), axis:int=0, transpose:bool=False):
dst, src = cast(UOp, dst), cast(UOp, src)
assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
if dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.LOCAL:
@@ -216,7 +223,8 @@ class Group:
return dst.after(dst_store.barrier()).reshape(dst.shape)
STORE_INNER = 8
def store(self, dst:UOp, src:UOp, idxs:tuple[UOp|int,...]=(), src_idxs:tuple[UOp|int,...]=(), axis=0, after=True):
def store(self, dst:ALL_TILES, src:ALL_TILES, idxs:tuple[UOp|int,...]=(), src_idxs:tuple[UOp|int,...]=(), axis=0, after=True):
dst, src = cast(UOp, dst), cast(UOp, src)
assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.LOCAL:

View File

@@ -68,10 +68,10 @@ class Kernel(AbstractContextManager):
return uop
def gl(self, shape, dtype): return GL(shape, dtype, self)._uop
def st(self, shape, dtype): return ST(shape, dtype, self)._uop
def rt(self, shape, dtype): return RT(shape, dtype, self)._uop
def rv(self, length, dtype, layout="naive"): return RV(length, dtype, layout, self)._uop
def gl(self, shape, dtype): return GL.create(shape, dtype, self)
def st(self, shape, dtype): return ST.create(shape, dtype, self)
def rt(self, shape, dtype): return RT.create(shape, dtype, self)
def rv(self, length, dtype, layout="naive"): return RV.create(length, dtype, layout, self)
def push_store(self, store:UOp, uop:UOp): self.store_stack.append((store, uop))
@@ -80,7 +80,7 @@ class Kernel(AbstractContextManager):
rngs = []
while self.range_stack: rngs.append(self.range_stack.pop(0)._rng)
return self.store_stack.pop()[0].end(*rngs).sink(arg=KernelInfo(opts_to_apply=())).simplify()
return self.store_stack.pop()[0]._uop.end(*rngs).sink(arg=KernelInfo(opts_to_apply=())).simplify()
def endrange(self):
last_store = self.store_stack.pop()

View File

@@ -1,23 +1,114 @@
import functools
from tinygrad.dtype import AddrSpace
from tinygrad.mixin import MathMixin
from tinygrad.uop.ops import UOp, Ops
from extra.thunder.tiny.tk import WARP_THREADS
def unwrap(x):
if hasattr(x, "_uop"): return x._uop
if isinstance(x, (list, tuple)): return type(x)(unwrap(y) for y in x)
if isinstance(x, dict): return {k: unwrap(v) for k,v in x.items()}
return x
def wrap(x, ker, cls):
if isinstance(x, UOp): return cls(x, ker)
if isinstance(x, (list, tuple)): return type(x)(wrap(y, ker, cls) for y in x)
return x
def autowrap(source_cls, blacklist=None):
if blacklist is None:
blacklist = {
"__init__", "__new__", "__str__", "__del__", "__repr__", "__dict__", "__getattribute__",
"__setattr__", "__delattr__", "__weakref__", "__slots__", "__class__",
"__reduce__", "__reduce_ex__", "__getstate__", "__setstate__", "__hash__"
}
def decorator(cls):
def __getattr__(self, name):
uop = object.__getattribute__(self, "_uop")
val = getattr(uop, name)
if callable(val):
@functools.wraps(val)
def proxy(*args, **kwargs):
return wrap(val(*unwrap(args), **unwrap(kwargs)), self.ker, cls)
return proxy
if name in UOp.__slots__: return val
return wrap(val, self.ker, cls)
cls.__getattr__ = __getattr__
for name in dir(source_cls):
if name in blacklist or not name.startswith("__"): continue
for base in cls.mro():
if base is source_cls: break
if name in base.__dict__: break
else:
original = getattr(source_cls, name)
if callable(original):
def make_proxy(op_name, func):
def proxy(self, *args, **kwargs):
return wrap(func(self._uop, *unwrap(args), **unwrap(kwargs)), self.ker, cls)
return proxy
setattr(cls, name, make_proxy(name, original))
return cls
return decorator
class TileMathMixin(MathMixin):
def alu(self, op, *src, inner_op=lambda x:x):
assert isinstance(self, (RT, RV))
if len(src) == 0:
if self._uop._shape is None: uop = UOp.alu(self._uop, op)
else: uop = self.ker.warp.map(self._uop, lambda x: UOp.alu(x, op))
elif len(src) == 1:
if self._uop._shape is None: uop = UOp.alu(self._uop, op, inner_op(self._uop.ufix(src[0])))
elif isinstance(src[0], (int,float,bool)): uop = self.ker.warp.map(self._uop, lambda x: UOp.alu(x, op, inner_op(x.ufix(src[0]))))
elif src[0]._shape is None: uop = UOp.alu(self._uop, op, inner_op(self._uop.ufix(src[0])))
else:
if isinstance(self, RT) and isinstance(src[0], RV): uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[idx[0], 0, (idx[2]%4)//2])))
else: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[*idx])))
else: raise NotImplementedError
return type(self)(uop, self.ker)
def const_like(self, b): return b
# override ops that do compute on the src uop
def sub(self, x, reverse=False):
return self.ufix(x).alu(Ops.ADD, self, inner_op=lambda y: -y) if reverse else self.alu(Ops.ADD, self.ufix(x), inner_op=lambda y: -y)
def div(self, x, reverse=False):
return self.ufix(x).alu(Ops.MUL, self, inner_op=lambda y: 1/y) if reverse else self.alu(Ops.MUL, self.ufix(x), inner_op=lambda y: 1/y)
@autowrap(UOp)
class GL:
def __init__(self, shape, dtype, ker):
self.shape, self.dtype = shape, dtype
self._uop = ker.alloc(shape, dtype, AddrSpace.GLOBAL)
def __init__(self, uop, ker):
self._uop, self.ker = uop, ker
@classmethod
def create(cls, shape, dtype, ker):
uop = ker.alloc(shape, dtype, AddrSpace.GLOBAL)
return cls(uop, ker)
@autowrap(UOp)
class ST:
def __init__(self, shape, dtype, ker):
self.shape, self.dtype = shape, dtype
self._uop = ker.alloc(shape, dtype, AddrSpace.LOCAL)
def __init__(self, uop, ker):
self._uop, self.ker = uop, ker
class RT:
@classmethod
def create(cls, shape, dtype, ker):
uop = ker.alloc(shape, dtype, AddrSpace.LOCAL)
return cls(uop, ker)
@autowrap(UOp)
class RT(TileMathMixin):
TILE_ROW_DIM, TILE_COL_DIM = 16, 16
BASE_TILE_NE = TILE_ROW_DIM * TILE_COL_DIM
BASE_TILE_NEPT = BASE_TILE_NE // WARP_THREADS
def __init__(self, shape, dtype, ker):
def __init__(self, uop, ker):
self._uop, self.ker = uop, ker
@classmethod
def create(cls, shape, dtype, ker):
assert len(shape) == 2
assert shape[0] % RT.TILE_ROW_DIM == 0
assert shape[1] % RT.TILE_COL_DIM == 0
@@ -25,11 +116,16 @@ class RT:
height = shape[0] // RT.TILE_ROW_DIM
width = shape[1] // RT.TILE_COL_DIM
self.shape, self.dtype = (height, width, self.BASE_TILE_NEPT), dtype
self._uop = ker.alloc(self.shape, dtype, AddrSpace.REG)
uop = ker.alloc((height, width, RT.BASE_TILE_NEPT), dtype, AddrSpace.REG)
return cls(uop, ker)
class RV:
def __init__(self, length, dtype, layout, ker):
@autowrap(UOp)
class RV(TileMathMixin):
def __init__(self, uop, ker):
self._uop, self.ker = uop, ker
@classmethod
def create(cls, length, dtype, layout, ker):
tiles = length // RT.TILE_ROW_DIM
match layout:
@@ -41,5 +137,7 @@ class RV:
outer_dim = tiles
case _: raise NotImplementedError(f"rv layout {layout} not implemented")
self.shape, self.dtype = (outer_dim, inner_dim, 2), dtype
self._uop = ker.alloc(self.shape, dtype, AddrSpace.REG)
uop = ker.alloc((outer_dim, inner_dim, 2), dtype, AddrSpace.REG)
return RV(uop, ker)
ALL_TILES = UOp | GL | ST | RT | RV