From 7eb0d8e744b4c597f7f2b818169358e1365b2c7f Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Thu, 13 Nov 2025 16:52:52 -0800
Subject: [PATCH] feat: mixins on tiles (#13246)

---
 extra/thunder/tiny/tk/group.py  |  30 +++++---
 extra/thunder/tiny/tk/kernel.py |  10 +--
 extra/thunder/tiny/tk/tiles.py  | 126 ++++++++++++++++++++++++++++----
 test/testextra/test_tk.py       |  52 +++++++++++--
 4 files changed, 181 insertions(+), 37 deletions(-)

diff --git a/extra/thunder/tiny/tk/group.py b/extra/thunder/tiny/tk/group.py
index 0ee4ec407d..5b8234cd15 100644
--- a/extra/thunder/tiny/tk/group.py
+++ b/extra/thunder/tiny/tk/group.py
@@ -7,7 +7,7 @@
 from tinygrad.dtype import AddrSpace, PtrDType
 from tinygrad.helpers import getenv, prod
 from extra.thunder.tiny.tk import WARP_THREADS
-from extra.thunder.tiny.tk.tiles import RT
+from extra.thunder.tiny.tk.tiles import ALL_TILES, GL, ST, RT, RV
 
 class Group:
   def __init__(self, warps:int, ker):
@@ -27,7 +27,8 @@ class Group:
 
   # ops that only work on a single warp
   clear_rid = 1000
-  def clear(self, reg:UOp, value:float=0):
+  def clear(self, reg:ALL_TILES, value:float=0):
+    reg = cast(UOp, reg)
     assert self.warps == 1
 
     rngs_for_shape = tuple(UOp.range(dim, Group.clear_rid + i) for i, dim in enumerate(reg.shape))
@@ -38,11 +39,12 @@ class Group:
     self.ker.push_store(reg_store, reg)
     return reg.after(reg_store).reshape(reg.shape)
 
-  def zero(self, reg:UOp): return self.clear(reg, 0)
-  def neg_inf(self, reg:UOp): return self.clear(reg, -math.inf)
+  def zero(self, reg:ALL_TILES): return self.clear(reg, 0)
+  def neg_inf(self, reg:ALL_TILES): return self.clear(reg, -math.inf)
 
   copy_rid = 300
-  def copy(self, dst:UOp, src:UOp):
+  def copy(self, dst:ALL_TILES, src:ALL_TILES):
+    dst, src = cast(UOp, dst), cast(UOp, src)
     assert self.warps == 1
     assert dst.shape == src.shape
 
@@ -54,7 +56,8 @@ class Group:
     self.ker.push_store(dst_store, dst)
     return dst.after(dst_store).reshape(dst.shape)
 
-  def mma_AB(self, c:UOp, a:UOp, b:UOp, after=True):
+  def mma_AB(self, c:UOp|RT, a:UOp|RT, b:UOp|RT, after=True):
+    c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
     assert self.warps == 1
 
     for height in self.ker.range(c.shape[-3], track=False):
@@ -76,7 +79,8 @@ class Group:
     self.ker.push_store(c_store, c)
     return c.after(c_store).reshape(c.shape) if after else c_store
 
-  def mma_ABt(self, c:UOp, a:UOp, b:UOp, after=True):
+  def mma_ABt(self, c:UOp|RT, a:UOp|RT, b:UOp|RT, after=True):
+    c, a, b = cast(UOp, c), cast(UOp, a), cast(UOp, b)
     assert self.warps == 1
 
     for height in self.ker.range(c.shape[-3], track=False):
@@ -99,7 +103,8 @@ class Group:
     return c.after(c_store).reshape(c.shape) if after else c_store
 
   map_rid = 400
-  def map(self, a:UOp, op:Callable[[UOp], UOp]|Callable[[UOp, tuple], UOp]):
+  def map(self, a:ALL_TILES, op:Callable[[UOp], UOp]|Callable[[UOp, tuple], UOp]):
+    a = cast(UOp, a)
     assert self.warps == 1
 
     rngs_for_shape = tuple(UOp.range(dim, Group.map_rid + i) for i, dim in enumerate(a.shape))
@@ -115,7 +120,8 @@ class Group:
     self.ker.push_store(a_store, a)
     return a.after(a_store).reshape(a.shape)
 
-  def row_reduce(self, vec:UOp, src:UOp, op:Callable[[UOp, UOp], UOp]):
+  def row_reduce(self, vec:UOp|RV, src:UOp|RT, op:Callable[[UOp, UOp], UOp]):
+    vec, src = cast(UOp, vec), cast(UOp, src)
     assert self.warps == 1
 
     red_local = self.ker.alloc((self.group_threads, 2), src.dtype.base, AddrSpace.LOCAL)
@@ -157,7 +163,8 @@ class Group:
 
   # ops that can work across multiple warps
   LOAD_INNER = 8
-  def load(self, dst:UOp, src:UOp, dst_idxs:tuple[UOp|int,...]=(), idxs:tuple[UOp|int,...]=(), axis:int=0, transpose:bool=False):
+  def load(self, dst:ALL_TILES, src:ALL_TILES, dst_idxs:tuple[UOp|int,...]=(), idxs:tuple[UOp|int,...]=(), axis:int=0, transpose:bool=False):
+    dst, src = cast(UOp, dst), cast(UOp, src)
     assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
     dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
     if dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.LOCAL:
@@ -216,7 +223,8 @@ class Group:
     return dst.after(dst_store.barrier()).reshape(dst.shape)
 
   STORE_INNER = 8
-  def store(self, dst:UOp, src:UOp, idxs:tuple[UOp|int,...]=(), src_idxs:tuple[UOp|int,...]=(), axis=0, after=True):
+  def store(self, dst:ALL_TILES, src:ALL_TILES, idxs:tuple[UOp|int,...]=(), src_idxs:tuple[UOp|int,...]=(), axis=0, after=True):
+    dst, src = cast(UOp, dst), cast(UOp, src)
     assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType)
     dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype)
     if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.LOCAL:
diff --git a/extra/thunder/tiny/tk/kernel.py b/extra/thunder/tiny/tk/kernel.py
index 68d89ea3b2..29df120e12 100644
--- a/extra/thunder/tiny/tk/kernel.py
+++ b/extra/thunder/tiny/tk/kernel.py
@@ -68,10 +68,10 @@ class Kernel(AbstractContextManager):
 
     return uop
 
-  def gl(self, shape, dtype): return GL(shape, dtype, self)._uop
-  def st(self, shape, dtype): return ST(shape, dtype, self)._uop
-  def rt(self, shape, dtype): return RT(shape, dtype, self)._uop
-  def rv(self, length, dtype, layout="naive"): return RV(length, dtype, layout, self)._uop
+  def gl(self, shape, dtype): return GL.create(shape, dtype, self)
+  def st(self, shape, dtype): return ST.create(shape, dtype, self)
+  def rt(self, shape, dtype): return RT.create(shape, dtype, self)
+  def rv(self, length, dtype, layout="naive"): return RV.create(length, dtype, layout, self)
 
   def push_store(self, store:UOp, uop:UOp):
     self.store_stack.append((store, uop))
@@ -80,7 +80,7 @@ class Kernel(AbstractContextManager):
     rngs = []
     while self.range_stack:
       rngs.append(self.range_stack.pop(0)._rng)
-    return self.store_stack.pop()[0].end(*rngs).sink(arg=KernelInfo(opts_to_apply=())).simplify()
+    return self.store_stack.pop()[0]._uop.end(*rngs).sink(arg=KernelInfo(opts_to_apply=())).simplify()
 
   def endrange(self):
     last_store = self.store_stack.pop()
diff --git a/extra/thunder/tiny/tk/tiles.py b/extra/thunder/tiny/tk/tiles.py
index 0a8ecc987f..1613e8dabe 100644
--- a/extra/thunder/tiny/tk/tiles.py
+++ b/extra/thunder/tiny/tk/tiles.py
@@ -1,23 +1,114 @@
+import functools
 from tinygrad.dtype import AddrSpace
+from tinygrad.mixin import MathMixin
+from tinygrad.uop.ops import UOp, Ops
 from extra.thunder.tiny.tk import WARP_THREADS
 
+def unwrap(x):
+  if hasattr(x, "_uop"): return x._uop
+  if isinstance(x, (list, tuple)): return type(x)(unwrap(y) for y in x)
+  if isinstance(x, dict): return {k: unwrap(v) for k,v in x.items()}
+  return x
+
+def wrap(x, ker, cls):
+  if isinstance(x, UOp): return cls(x, ker)
+  if isinstance(x, (list, tuple)): return type(x)(wrap(y, ker, cls) for y in x)
+  return x
+
+def autowrap(source_cls, blacklist=None):
+  if blacklist is None:
+    blacklist = {
+      "__init__", "__new__", "__str__", "__del__", "__repr__", "__dict__", "__getattribute__",
+      "__setattr__", "__delattr__", "__weakref__", "__slots__", "__class__",
+      "__reduce__", "__reduce_ex__", "__getstate__", "__setstate__", "__hash__"
+    }
+
+  def decorator(cls):
+    def __getattr__(self, name):
+      uop = object.__getattribute__(self, "_uop")
+      val = getattr(uop, name)
+      if callable(val):
+        @functools.wraps(val)
+        def proxy(*args, **kwargs):
+          return wrap(val(*unwrap(args), **unwrap(kwargs)), self.ker, cls)
+        return proxy
+      if name in UOp.__slots__: return val
+      return wrap(val, self.ker, cls)
+    cls.__getattr__ = __getattr__
+
+    for name in dir(source_cls):
+      if name in blacklist or not name.startswith("__"): continue
+
+      for base in cls.mro():
+        if base is source_cls: break
+        if name in base.__dict__: break
+      else:
+        original = getattr(source_cls, name)
+        if callable(original):
+          def make_proxy(op_name, func):
+            def proxy(self, *args, **kwargs):
+              return wrap(func(self._uop, *unwrap(args), **unwrap(kwargs)), self.ker, cls)
+            return proxy
+          setattr(cls, name, make_proxy(name, original))
+
+    return cls
+  return decorator
+
+class TileMathMixin(MathMixin):
+  def alu(self, op, *src, inner_op=lambda x:x):
+    assert isinstance(self, (RT, RV))
+    if len(src) == 0:
+      if self._uop._shape is None: uop = UOp.alu(self._uop, op)
+      else: uop = self.ker.warp.map(self._uop, lambda x: UOp.alu(x, op))
+    elif len(src) == 1:
+      if self._uop._shape is None: uop = UOp.alu(self._uop, op, inner_op(self._uop.ufix(src[0])))
+      elif isinstance(src[0], (int,float,bool)): uop = self.ker.warp.map(self._uop, lambda x: UOp.alu(x, op, inner_op(x.ufix(src[0]))))
+      elif src[0]._shape is None: uop = UOp.alu(self._uop, op, inner_op(self._uop.ufix(src[0])))
+      else:
+        if isinstance(self, RT) and isinstance(src[0], RV): uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[idx[0], 0, (idx[2]%4)//2])))
+        else: uop = self.ker.warp.map(self._uop, lambda x, idx: UOp.alu(x, op, inner_op(src[0]._uop[*idx])))
+    else: raise NotImplementedError
+    return type(self)(uop, self.ker)
+  def const_like(self, b): return b
+
+  # override ops that do compute on the src uop
+  def sub(self, x, reverse=False):
+    return self.ufix(x).alu(Ops.ADD, self, inner_op=lambda y: -y) if reverse else self.alu(Ops.ADD, self.ufix(x), inner_op=lambda y: -y)
+  def div(self, x, reverse=False):
+    return self.ufix(x).alu(Ops.MUL, self, inner_op=lambda y: 1/y) if reverse else self.alu(Ops.MUL, self.ufix(x), inner_op=lambda y: 1/y)
+
+@autowrap(UOp)
 class GL:
-  def __init__(self, shape, dtype, ker):
-    self.shape, self.dtype = shape, dtype
-    self._uop = ker.alloc(shape, dtype, AddrSpace.GLOBAL)
+  def __init__(self, uop, ker):
+    self._uop, self.ker = uop, ker
 
+  @classmethod
+  def create(cls, shape, dtype, ker):
+    uop = ker.alloc(shape, dtype, AddrSpace.GLOBAL)
+    return cls(uop, ker)
+
+@autowrap(UOp)
 class ST:
-  def __init__(self, shape, dtype, ker):
-    self.shape, self.dtype = shape, dtype
-    self._uop = ker.alloc(shape, dtype, AddrSpace.LOCAL)
+  def __init__(self, uop, ker):
+    self._uop, self.ker = uop, ker
 
-class RT:
+  @classmethod
+  def create(cls, shape, dtype, ker):
+    uop = ker.alloc(shape, dtype, AddrSpace.LOCAL)
+    return cls(uop, ker)
+
+@autowrap(UOp)
+class RT(TileMathMixin):
   TILE_ROW_DIM, TILE_COL_DIM = 16, 16
   BASE_TILE_NE = TILE_ROW_DIM * TILE_COL_DIM
   BASE_TILE_NEPT = BASE_TILE_NE // WARP_THREADS
 
-  def __init__(self, shape, dtype, ker):
+  def __init__(self, uop, ker):
+    self._uop, self.ker = uop, ker
+
+  @classmethod
+  def create(cls, shape, dtype, ker):
     assert len(shape) == 2
     assert shape[0] % RT.TILE_ROW_DIM == 0
     assert shape[1] % RT.TILE_COL_DIM == 0
 
@@ -25,11 +116,16 @@ class RT:
     height = shape[0] // RT.TILE_ROW_DIM
     width = shape[1] // RT.TILE_COL_DIM
 
-    self.shape, self.dtype = (height, width, self.BASE_TILE_NEPT), dtype
-    self._uop = ker.alloc(self.shape, dtype, AddrSpace.REG)
+    uop = ker.alloc((height, width, RT.BASE_TILE_NEPT), dtype, AddrSpace.REG)
+    return cls(uop, ker)
 
-class RV:
-  def __init__(self, length, dtype, layout, ker):
+@autowrap(UOp)
+class RV(TileMathMixin):
+  def __init__(self, uop, ker):
+    self._uop, self.ker = uop, ker
+
+  @classmethod
+  def create(cls, length, dtype, layout, ker):
     tiles = length // RT.TILE_ROW_DIM
 
     match layout:
@@ -41,5 +137,7 @@ class RV:
         outer_dim = tiles
       case _: raise NotImplementedError(f"rv layout {layout} not implemented")
 
-    self.shape, self.dtype = (outer_dim, inner_dim, 2), dtype
-    self._uop = ker.alloc(self.shape, dtype, AddrSpace.REG)
+    uop = ker.alloc((outer_dim, inner_dim, 2), dtype, AddrSpace.REG)
+    return RV(uop, ker)
+
+ALL_TILES = UOp | GL | ST | RT | RV
diff --git a/test/testextra/test_tk.py b/test/testextra/test_tk.py
index ec2294bf0e..4888fa782b 100644
--- a/test/testextra/test_tk.py
+++ b/test/testextra/test_tk.py
@@ -152,6 +152,44 @@ class TestTK(unittest.TestCase):
 
     np.testing.assert_allclose(b.numpy(), ref.numpy())
 
+  def test_add(self):
+    N = 32
+    BLOCK_SIZE = 16
+    with Kernel((1, 1, 1), WARP_THREADS) as ker:
+      warp = ker.warp
+
+      b = ker.gl((1, 1, N, N), dtypes.float32)
+      a = ker.gl((1, 1, N, N), dtypes.float32)
+
+      a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
+
+      a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
+
+      for tile_row in ker.range(N // BLOCK_SIZE):
+        for tile_col in ker.range(N // BLOCK_SIZE):
+          a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, tile_col), axis=2)
+          a_reg = warp.load(a_reg, a_smem)
+
+          a_reg += 1
+
+          a_smem = warp.store(a_smem, a_reg)
+          b = warp.store(b, a_smem, (0, 0, tile_row, tile_col), (), axis=2)
+
+      sink = ker.finish()
+
+    with Context(DEBUG=0):
+      a = Tensor.rand(1, 1, N, N, dtype="float32").contiguous()
+      b = Tensor.empty(1, 1, N, N, dtype="float32")
+      Tensor.realize(a, b)
+
+    ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
+    for _ in range(5): ei.run(wait=True)
+    b = b.float()
+
+    ref = a.float() + 1
+
+    np.testing.assert_allclose(b.numpy(), ref.numpy())
+
   def test_max(self):
     N = 16
     BLOCK_SIZE = 16
@@ -365,13 +403,13 @@ class TestTK(unittest.TestCase):
         a_smem = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2)
         a_reg = warp.load(a_reg, a_smem)
 
-        a_reg = warp.map(a_reg, lambda x: x * (1.0 / math.log(2)))
+        a_reg *= 1.0 / math.log(2)
         max_vec_last = warp.copy(max_vec_last.after(tile_col), max_vec)
         max_vec = warp.row_reduce(max_vec, a_reg, lambda a, b: a.maximum(b))
 
-        a_reg = warp.map(a_reg, lambda x, idx: (x - max_vec[idx[0], 0, (idx[2]%4)//2]).exp2())
-        max_vec_last = warp.map(max_vec_last, lambda x, idx: (x - max_vec[*idx]).exp2())
-        norm_vec = warp.map(norm_vec, lambda x, idx: x * max_vec_last[*idx])
+        a_reg = (a_reg - max_vec).exp2()
+        max_vec_last = (max_vec_last - max_vec).exp2()
+        norm_vec *= max_vec_last
 
         norm_vec = warp.row_reduce(norm_vec, a_reg, lambda a, b: a + b)
       norm_vec = ker.endrange()
@@ -379,9 +417,9 @@ class TestTK(unittest.TestCase):
         a_smem = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2)
         a_reg = warp.load(a_reg, a_smem)
 
-        a_reg = warp.map(a_reg, lambda x: x * (1.0 / math.log(2)))
-        a_reg = warp.map(a_reg, lambda x, idx: (x - max_vec[idx[0], 0, (idx[2]%4)//2]).exp2())
-        a_reg = warp.map(a_reg, lambda x, idx: x / norm_vec[idx[0], 0, (idx[2]%4)//2])
+        a_reg *= 1.0 / math.log(2)
+        a_reg = (a_reg - max_vec).exp2()
+        a_reg /= norm_vec
 
         a_smem = warp.store(a_smem, a_reg)
         b = warp.store(b, a_smem, (0, 0, 0, tile_col), (), axis=2)
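
Reviewer note (not part of the patch): a minimal usage sketch of what the mixin change enables. RT and RV now inherit TileMathMixin and proxy UOp methods through autowrap, so elementwise math on register tiles reads as plain Python operators instead of warp.map lambdas. The sketch is adapted from test_add above and assumes the same test_tk.py harness (Kernel, WARP_THREADS, dtypes); the tile names and shapes are illustrative only.

    with Kernel((1, 1, 1), WARP_THREADS) as ker:
      warp = ker.warp
      a = ker.gl((1, 1, 32, 32), dtypes.float32)  # GL: global-memory tile
      b = ker.gl((1, 1, 32, 32), dtypes.float32)
      a_smem = ker.st((16, 16), dtypes.float32)   # ST: shared/local tile
      a_reg = ker.rt((16, 16), dtypes.float32)    # RT: register tile with math mixin

      for tile_row in ker.range(2):
        for tile_col in ker.range(2):
          a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, tile_col), axis=2)
          a_reg = warp.load(a_reg, a_smem)
          # before this patch: a_reg = warp.map(a_reg, lambda x: x + 1)
          a_reg += 1  # TileMathMixin.alu lowers this through warp.map
          a_smem = warp.store(a_smem, a_reg)
          b = warp.store(b, a_smem, (0, 0, tile_row, tile_col), (), axis=2)

      sink = ker.finish()

Mixed RT/RV expressions broadcast the row vector across the tile: in the softmax test, (a_reg - max_vec).exp2() replaces the hand-written warp.map lambda because TileMathMixin.alu indexes the RV operand with (idx[0], 0, (idx[2]%4)//2) itself.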