From 6252831ceb63c16cc3819747d984673e0b167003 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Sun, 9 Nov 2025 22:54:29 -0800 Subject: [PATCH] feat: initial tk library (#13160) --- extra/thunder/tiny/tk/__init__.py | 1 + extra/thunder/tiny/tk/group.py | 272 +++++++++++++++++++++++ extra/thunder/tiny/tk/kernel.py | 57 +++++ extra/thunder/tiny/tk/tiles.py | 52 +++++ test/external/external_test_tk.py | 345 ++++++++++++++++++++++++++++++ 5 files changed, 727 insertions(+) create mode 100644 extra/thunder/tiny/tk/__init__.py create mode 100644 extra/thunder/tiny/tk/group.py create mode 100644 extra/thunder/tiny/tk/kernel.py create mode 100644 extra/thunder/tiny/tk/tiles.py create mode 100644 test/external/external_test_tk.py diff --git a/extra/thunder/tiny/tk/__init__.py b/extra/thunder/tiny/tk/__init__.py new file mode 100644 index 0000000000..27dfca23e2 --- /dev/null +++ b/extra/thunder/tiny/tk/__init__.py @@ -0,0 +1 @@ +WARP_THREADS = 32 diff --git a/extra/thunder/tiny/tk/group.py b/extra/thunder/tiny/tk/group.py new file mode 100644 index 0000000000..3df0f47ed2 --- /dev/null +++ b/extra/thunder/tiny/tk/group.py @@ -0,0 +1,272 @@ +import math, functools +from typing import cast, Callable +from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes +from tinygrad.uop.ops import AxisType, UOp, KernelInfo, Ops +from tinygrad.engine.realize import ExecItem, get_runner +from tinygrad.dtype import AddrSpace, PtrDType +from tinygrad.helpers import getenv, prod + +from extra.thunder.tiny.tk import WARP_THREADS +from extra.thunder.tiny.tk.tiles import TILE_ROW_DIM, TILE_COL_DIM, RT_BASE_TILE_NEPT, slots + +class Group: + def __init__(self, warps:int, ker): + self.warps = warps + self.group_threads = warps * WARP_THREADS + self.threadIdx_x = ker.threadIdx_x + self.ker = ker + + # helpers + @property + def laneid(self): return self.threadIdx_x % self.group_threads + @property + def warpid(self): return self.laneid // WARP_THREADS + @property + def groupid(self): return self.threadIdx_x // self.group_threads + + # ops that only work on a single warp + + clear_rid = 1000 + def clear(self, reg:UOp, value:float=0): + assert self.warps == 1 + + i = UOp.range(reg.size, Group.clear_rid) + Group.clear_rid += 1 + return reg.reshape((reg.size,))[i].set(value, end=i).after(reg).reshape(reg.shape) + + def zero(self, reg:UOp): return self.clear(reg, 0) + def neg_inf(self, reg:UOp): return self.clear(reg, -math.inf) + + copy_rid = 300 + def copy(self, dst:UOp, src:UOp): + assert self.warps == 1 + + assert dst.shape == src.shape + assert cast(PtrDType, dst.dtype).addrspace == AddrSpace.REG + assert cast(PtrDType, src.dtype).addrspace == AddrSpace.REG + + rngs_for_shape = tuple(UOp.range(dim, Group.copy_rid + i) for i, dim in enumerate(dst.shape)) + Group.copy_rid += len(dst.shape) + + dst_store = dst[*rngs_for_shape].store(src[*rngs_for_shape].cast(dst.dtype.base)).end(*rngs_for_shape) + + self.ker.push_store(dst_store, dst) + return dst.after(dst_store).reshape(dst.shape) + + mma_rid = 600 + def mma_AB(self, c:UOp, a:UOp, b:UOp, after=True): + assert self.warps == 1 + + mma_i_height = UOp.range(c.shape[-3], Group.mma_rid) + mma_i_width = UOp.range(c.shape[-2], Group.mma_rid+1) + mma_i_inner = UOp.range(a.shape[-2], Group.mma_rid+2, AxisType.REDUCE) + Group.mma_rid += 3 + + wmma_arg = ("WMMA_8_16_16_bfloat16_float", (8, 16, 16), dtypes.bfloat16, dtypes.float, "CUDA", 32, (((4, 2), (3, 2), (8, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ()) + + a_in = UOp.vectorize(*[a[mma_i_height, mma_i_inner, i] for i in 
range(8)]) + b_in1 = UOp.vectorize(*([b[mma_i_inner, mma_i_width, i] for i in range(2)] + [b[mma_i_inner, mma_i_width, 4+i] for i in range(2)])) + c_out1 = UOp.vectorize(*[c[mma_i_height, mma_i_width, i] for i in range(4)]) + b_in2 = UOp.vectorize(*([b[mma_i_inner, mma_i_width, 2+i] for i in range(2)] + [b[mma_i_inner, mma_i_width, 6+i] for i in range(2)])) + c_out2 = UOp.vectorize(*[c[mma_i_height, mma_i_width, 4+i] for i in range(4)]) + + out1 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in1, c_out1), arg=wmma_arg) + out2 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in2, c_out2), arg=wmma_arg) + c_i = [c[mma_i_height, mma_i_width, i].store(out1.gep(i)) for i in range(4)] + [c[mma_i_height, mma_i_width, 4+i].store(out2.gep(i)) for i in range(4)] + c_store = UOp.group(*c_i).end(mma_i_height, mma_i_width, mma_i_inner) + + self.ker.push_store(c_store, c) + return c.after(c_store).reshape(c.shape) if after else c_store + + def mma_ABt(self, c:UOp, a:UOp, b:UOp, after=True): + assert self.warps == 1 + + mma_i_height = UOp.range(c.shape[-3], Group.mma_rid) + mma_i_width = UOp.range(c.shape[-2], Group.mma_rid+1) + mma_i_inner = UOp.range(a.shape[-2], Group.mma_rid+2, AxisType.REDUCE) + Group.mma_rid += 3 + + wmma_arg = ("WMMA_8_16_16_bfloat16_float", (8, 16, 16), dtypes.bfloat16, dtypes.float, "CUDA", 32, (((4, 2), (3, 2), (8, 2)), ((4, 2), (3, 2)), ((4, 2), (3, 2))), ()) + + a_in = UOp.vectorize(*[a[mma_i_height, mma_i_inner, i] for i in range(8)]) + b_in1 = UOp.vectorize(*([b[mma_i_width, mma_i_inner, i] for i in range(2)] + [b[mma_i_width, mma_i_inner, 4+i] for i in range(2)])) + c_out1 = UOp.vectorize(*[c[mma_i_height, mma_i_width, i] for i in range(4)]) + b_in2 = UOp.vectorize(*([b[mma_i_width, mma_i_inner, 2+i] for i in range(2)] + [b[mma_i_width, mma_i_inner, 6+i] for i in range(2)])) + c_out2 = UOp.vectorize(*[c[mma_i_height, mma_i_width, 4+i] for i in range(4)]) + + out1 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in1, c_out1), arg=wmma_arg) + out2 = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in2, c_out2), arg=wmma_arg) + c_i = [c[mma_i_height, mma_i_width, i].store(out1.gep(i)) for i in range(4)] + [c[mma_i_height, mma_i_width, 4+i].store(out2.gep(i)) for i in range(4)] + c_store = UOp.group(*c_i).end(mma_i_height, mma_i_width, mma_i_inner) + + self.ker.push_store(c_store, c) + return c.after(c_store).reshape(c.shape) if after else c_store + + map_rid = 400 + def map(self, a:UOp, op:Callable[[UOp], UOp]|Callable[[UOp, tuple], UOp]): + assert self.warps == 1 + + rngs_for_shape = tuple(UOp.range(dim, Group.map_rid + i) for i, dim in enumerate(a.shape)) + Group.map_rid += len(a.shape) + + if op.__code__.co_argcount == 1: + to_store = op(a[*rngs_for_shape]) + else: + to_store = op(a[*rngs_for_shape], rngs_for_shape) + + a_store = a[*rngs_for_shape].store(to_store).end(*rngs_for_shape) + + self.ker.push_store(a_store, a) + return a.after(a_store).reshape(a.shape) + + def row_reduce(self, vec:UOp, src:UOp, op:Callable[[UOp, UOp], UOp]): + assert self.warps == 1 + + red_local = UOp.placeholder((self.group_threads, 2), src.dtype.base, addrspace=AddrSpace.LOCAL, slot=slots.shared_slot) + slots.shared_slot += 1 + + for height in self.ker.range(src.shape[-3], track=False): + for i_outer in self.ker.range(2, track=False): + for width in self.ker.range(src.shape[-2], AxisType.REDUCE, track=False): + for i_inner in self.ker.range(4, AxisType.REDUCE, track=False): + elem_index = i_inner + 2 * (i_inner // 2) + i_outer * 2 + vec_store = vec[height, 0, i_outer].store(op(vec[height, 0, 
i_outer], src[height, width, elem_index])).end(width, i_inner, i_outer) + vec = vec.after(vec_store).reshape(vec.shape) + + # store to shared memory + for i_outer in self.ker.range(2, track=False): + red_local_store = red_local[self.laneid, i_outer].store(vec[height, 0, i_outer]).end(i_outer) + red_local = red_local.after(red_local_store).reshape(red_local.shape) + + # reduce from shared memory + for i_outer in self.ker.range(2, track=False): + for i_inner in self.ker.range(3, AxisType.REDUCE, track=False): + offset = (self.laneid // 4) * 4 + ((self.laneid + 1 + i_inner) % 4) + vec_store = vec[height, 0, i_outer].store(op(vec[height, 0, i_outer], red_local[offset, i_outer])).end(i_inner, i_outer) + + self.ker.push_store(vec_store, vec) + return vec.after(vec_store).reshape(vec.shape) + + # ops that can work across multiple warps + + LOAD_INNER = 8 + load_rid = 100 + def load(self, dst:UOp, src:UOp, dst_idxs:tuple[UOp|int,...]=(), idxs:tuple[UOp|int,...]=(), axis:int=0, transpose:bool=False): + assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType) + dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype) + if dst_dtype.addrspace == AddrSpace.REG and src_dtype.addrspace == AddrSpace.LOCAL: + srcf = src.flatten(-2) + + load_i_height = UOp.range(dst.shape[-3], Group.load_rid) + load_i_width = UOp.range(dst.shape[-2], Group.load_rid+1) + load_i_inner = UOp.range(RT_BASE_TILE_NEPT, Group.load_rid+2) + Group.load_rid += 3 + + if self.warps % 4 == 0: local_warpid = (self.warpid // 4) + (self.warpid % 4) * (self.warps // 4) + else: local_warpid = self.warpid + warp_laneid = self.threadIdx_x % WARP_THREADS + + if not transpose: + row = (local_warpid * dst.shape[-3] + load_i_height) * TILE_ROW_DIM + (warp_laneid // 4) + col = load_i_width * TILE_COL_DIM + 2 * (warp_laneid % 4) + + row_offset = ((load_i_inner % 4) // 2) * 8 + col_offset = (load_i_inner % 2) + (load_i_inner // 4) * 8 + else: + row = (local_warpid * dst.shape[-3] + load_i_height) * TILE_ROW_DIM + 2 * (warp_laneid % 4) + col = load_i_width * TILE_COL_DIM + (warp_laneid // 4) + + row_offset = (load_i_inner % 2) + (load_i_inner // 4) * 8 + col_offset = ((load_i_inner % 4) // 2) * 8 + + src_i_last = (row + row_offset) * src.shape[-1] + col + col_offset + + dst_store = dst[*dst_idxs, load_i_height, load_i_width, load_i_inner].store(srcf[*idxs[:-2], src_i_last]) + dst_store = dst_store.end(load_i_height, load_i_width, load_i_inner) + elif dst_dtype.addrspace == AddrSpace.LOCAL and src_dtype.addrspace == AddrSpace.GLOBAL: + dstf = dst.flatten(-2) + + srcf = src.flatten() + row_stride = prod(src.shape[axis+1:]) + + idxs = tuple(idx * dst.shape[-2] if i == axis else idx for i, idx in enumerate(idxs)) + idxs = tuple(idx * dst.shape[-1] if i == 3 else idx for i, idx in enumerate(idxs)) + src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3] + + memcpy_per_row = dst.shape[-1] // Group.LOAD_INNER + total_calls = prod(dst.shape[-2:]) // (self.group_threads * Group.LOAD_INNER) + + load_i_outer = UOp.range(total_calls, Group.load_rid) + load_i_inner = UOp.range(Group.LOAD_INNER, Group.load_rid+1) + Group.load_rid += 2 + + load_idx = load_i_outer * self.group_threads + self.laneid + row = load_idx // memcpy_per_row + col = (load_idx * Group.LOAD_INNER) % dst.shape[-1] + + dst_i = row * dst.shape[-1] + col + load_i_inner + src_i += row * row_stride + col + load_i_inner + + dst_store = dstf[*dst_idxs, dst_i].store(srcf[src_i]).end(load_i_outer, load_i_inner) + else: + 
raise NotImplementedError(f"load from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented") + + return dst.after(dst_store.barrier()).reshape(dst.shape) + + STORE_INNER = 8 + store_rid = 200 + def store(self, dst:UOp, src:UOp, idxs:tuple[UOp|int,...]=(), src_idxs:tuple[UOp|int,...]=(), axis=0, after=True): + assert isinstance(dst.dtype, PtrDType) and isinstance(src.dtype, PtrDType) + dst_dtype, src_dtype = cast(PtrDType, dst.dtype), cast(PtrDType, src.dtype) + if src_dtype.addrspace == AddrSpace.REG and dst_dtype.addrspace == AddrSpace.LOCAL: + dstf = dst.flatten(-2) + + store_i_height = UOp.range(src.shape[-3], Group.store_rid) + store_i_width = UOp.range(src.shape[-2], Group.store_rid+1) + store_i_inner = UOp.range(RT_BASE_TILE_NEPT, Group.store_rid+2) + Group.store_rid += 3 + + if self.warps % 4 == 0: local_warpid = (self.warpid // 4) + (self.warpid % 4) * (self.warps // 4) + else: local_warpid = self.warpid + warp_laneid = self.threadIdx_x % WARP_THREADS + + row = (local_warpid * src.shape[-3] + store_i_height) * TILE_ROW_DIM + (warp_laneid // 4) + col = store_i_width * TILE_COL_DIM + 2 * (warp_laneid % 4) + + row_offset = ((store_i_inner % 4) // 2) * 8 + col_offset = (store_i_inner % 2) + (store_i_inner // 4) * 8 + + dst_i_last = (row + row_offset) * dst.shape[-1] + col + col_offset + + dst_store = dstf[*idxs[:-2], dst_i_last].store(src[*src_idxs, store_i_height, store_i_width, store_i_inner]) + dst_store = dst_store.end(store_i_height, store_i_width, store_i_inner) + elif src_dtype.addrspace == AddrSpace.LOCAL and dst_dtype.addrspace == AddrSpace.GLOBAL: + dstf = dst.flatten() + row_stride = prod(dst.shape[axis+1:]) + + idxs = tuple(idx * src.shape[-2] if i == axis else idx for i, idx in enumerate(idxs)) + idxs = tuple(idx * src.shape[-1] if i == 3 else idx for i, idx in enumerate(idxs)) + dst_i = ((idxs[0] * dst.shape[-3] + idxs[1]) * dst.shape[-2] + idxs[2]) * dst.shape[-1] + idxs[3] + + srcf = src.flatten(-2) + + memcpy_per_row = src.shape[-1] // Group.STORE_INNER + total_calls = prod(src.shape[-2:]) // (self.group_threads * Group.STORE_INNER) + + store_i_outer = UOp.range(total_calls, Group.store_rid) + store_i_inner = UOp.range(Group.STORE_INNER, Group.store_rid+1) + Group.store_rid += 2 + + load_idx = store_i_outer * self.group_threads + self.laneid + row = load_idx // memcpy_per_row + col = (load_idx * Group.STORE_INNER) % src.shape[-1] + + src_i = row * src.shape[-1] + col + store_i_inner + dst_i += row * row_stride + col + store_i_inner + + dst_store = dstf[dst_i].store(srcf[*src_idxs, src_i]).end(store_i_outer, store_i_inner) + else: + raise NotImplementedError(f"store from {src_dtype.addrspace} to {dst_dtype.addrspace} not implemented") + + self.ker.push_store(dst_store, dst) + return dst.after(dst_store.barrier()).reshape(dst.shape) if after else dst_store diff --git a/extra/thunder/tiny/tk/kernel.py b/extra/thunder/tiny/tk/kernel.py new file mode 100644 index 0000000000..8fab1ee905 --- /dev/null +++ b/extra/thunder/tiny/tk/kernel.py @@ -0,0 +1,57 @@ +from contextlib import AbstractContextManager +from tinygrad.uop.ops import UOp, KernelInfo, AxisType +from extra.thunder.tiny.tk import WARP_THREADS +from extra.thunder.tiny.tk.group import Group + +class _tk_range: + user_rid = 0 + def __init__(self, end:int, axis_type:AxisType): self.end, self.axis_type, self.done = end, axis_type, False + def __iter__(self): return self + def __next__(self): + if not self.done: + self.done = True + _tk_range.user_rid += 1 + self._rng = UOp.range(self.end, _tk_range.user_rid-1, 
axis_type=self.axis_type) + return self._rng + raise StopIteration + +class Kernel(AbstractContextManager): + def __init__(self, grid_size:tuple[int, int, int], block_size:int): + self.blockIdx_x = UOp.special(grid_size[0], "gidx0") + self.blockIdx_y = UOp.special(grid_size[1], "gidx1") + self.blockIdx_z = UOp.special(grid_size[2], "gidx2") + self.threadIdx_x = UOp.special(block_size, "lidx0") + + self.range_stack = [] + self.store_stack = [] + + @property + def warpid(self): return self.threadIdx_x // WARP_THREADS + + def __enter__(self): return self + def __exit__(self, exc_type, exc_value, traceback): pass + + def group(self, size:int): return Group(size, self) + @property + def warp(self): return self.group(1) + @property + def warpgroup(self): return self.group(4) + + def range(self, end:int, axis_type:AxisType=AxisType.LOOP, track:bool=True): + rng = _tk_range(end, axis_type) + if track: self.range_stack.append(rng) + return rng + + def push_store(self, store:UOp, uop:UOp): self.store_stack.append((store, uop)) + + def finish(self): + # end all ranges + rngs = [] + while self.range_stack: rngs.append(self.range_stack.pop(0)._rng) + + return self.store_stack.pop()[0].end(*rngs).sink(arg=KernelInfo(opts_to_apply=())).simplify() + + def endrange(self): + last_store = self.store_stack.pop() + last_range = self.range_stack.pop() + return last_store[1].after(last_store[0].barrier().end(last_range._rng)).reshape(last_store[1].shape) diff --git a/extra/thunder/tiny/tk/tiles.py b/extra/thunder/tiny/tk/tiles.py new file mode 100644 index 0000000000..c936dfd199 --- /dev/null +++ b/extra/thunder/tiny/tk/tiles.py @@ -0,0 +1,52 @@ +import math +from typing import cast, Callable +from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes +from tinygrad.uop.ops import AxisType, UOp, KernelInfo, Ops +from tinygrad.engine.realize import ExecItem, get_runner +from tinygrad.dtype import AddrSpace, PtrDType +from tinygrad.helpers import getenv, prod + +from extra.thunder.tiny.tk import WARP_THREADS + +class _Slots: + def __init__(self): + self.global_slot = 0 + self.shared_slot = 0 + self.register_slot = 0 +slots = _Slots() + +def gl(shape, dtype): + slots.global_slot += 1 + return UOp.placeholder(shape, dtype, slot=slots.global_slot-1) + +shared_slot = 0 +def st(shape, dtype): + slots.shared_slot += 1 + return UOp.placeholder(shape, dtype, addrspace=AddrSpace.LOCAL, slot=slots.shared_slot-1) + +TILE_ROW_DIM, TILE_COL_DIM = 16, 16 +RT_BASE_TILE_NE = TILE_ROW_DIM * TILE_COL_DIM +RT_BASE_TILE_NEPT = RT_BASE_TILE_NE // WARP_THREADS +register_slot = 0 +def rt(shape, dtype): + assert len(shape) == 2 + + height = shape[0] // TILE_ROW_DIM + width = shape[1] // TILE_COL_DIM + + slots.register_slot += 1 + return UOp.placeholder((height, width, RT_BASE_TILE_NEPT), dtype, addrspace=AddrSpace.REG, slot=slots.register_slot-1) + +def rv(length, dtype, layout="naive"): + tiles = length // TILE_ROW_DIM + match layout: + case "naive": + inner_dim = 1 + outer_dim = (tiles + 1) // 2 + case "ortho": + inner_dim = 1 + outer_dim = tiles + case _: raise NotImplementedError(f"rv layout {layout} not implemented") + + slots.register_slot += 1 + return UOp.placeholder((outer_dim, inner_dim, 2), dtype, addrspace=AddrSpace.REG, slot=slots.register_slot-1) diff --git a/test/external/external_test_tk.py b/test/external/external_test_tk.py new file mode 100644 index 0000000000..8c7ef65de8 --- /dev/null +++ b/test/external/external_test_tk.py @@ -0,0 +1,345 @@ +import unittest + +from tinygrad import Tensor, Device, dtypes, 
Context +from tinygrad.engine.realize import ExecItem, get_runner + +from extra.thunder.tiny.tk import WARP_THREADS +from extra.thunder.tiny.tk.kernel import Kernel +from extra.thunder.tiny.tk.tiles import gl, st, rt, rv + +class TestTK(unittest.TestCase): + @unittest.skip("store from float rt is wrong") + def test_simple_matmul(self): + N = 32 + BLOCK_SIZE = 16 + with Kernel((N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS) as ker: + warp = ker.warp + + c = gl((1, 1, N, N), dtypes.float32) + a = gl((1, 1, N, N), dtypes.bfloat16) + b = gl((1, 1, N, N), dtypes.bfloat16) + + a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + c_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + + a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + c_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + + col, row = ker.blockIdx_x, ker.blockIdx_y + + c_reg = warp.zero(c_reg) + for tile in ker.range(N // BLOCK_SIZE): + a_smem = warp.load(a_smem, a, (), (0, 0, row, tile), axis=2) + b_smem = warp.load(b_smem, b, (), (0, 0, tile, col), axis=2) + + a_reg = warp.load(a_reg, a_smem) + b_reg = warp.load(b_reg, b_smem, transpose=True) + + c_reg = warp.mma_AB(c_reg, a_reg, b_reg) + c_reg = ker.endrange() + + c_smem = warp.store(c_smem, c_reg) + c = warp.store(c, c_smem, (0, 0, row, col), (), axis=2) + + sink = ker.finish() + + with Context(DEBUG=0): + a = Tensor.rand(1, 1, N, N, dtype="bfloat16").contiguous() + b = Tensor.rand(1, 1, N, N, dtype="bfloat16").contiguous() + c = Tensor.empty(1, 1, N, N, dtype="float32") + Tensor.realize(a, b, c) + + ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (c, a, b)]) + for _ in range(5): ei.run(wait=True) + c = c.float() + + ref = a.matmul(b, dtype=dtypes.float32).float() + + assert ref.allclose(c) + + @unittest.skip("store from float rt is wrong") + def test_simple_matmul_transposed(self): + N = 32 + BLOCK_SIZE = 16 + with Kernel((N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS) as ker: + warp = ker.warp + + c = gl((1, 1, N, N), dtypes.float32) + a = gl((1, 1, N, N), dtypes.bfloat16) + b = gl((1, 1, N, N), dtypes.bfloat16) + + a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + c_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + + a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + c_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + + col, row = ker.blockIdx_x, ker.blockIdx_y + + c_reg = warp.zero(c_reg) + for tile in ker.range(N // BLOCK_SIZE): + a_smem = warp.load(a_smem, a, (), (0, 0, row, tile), axis=2) + b_smem = warp.load(b_smem, b, (), (0, 0, col, tile), axis=2) + + a_reg = warp.load(a_reg, a_smem) + b_reg = warp.load(b_reg, b_smem) + + c_reg = warp.mma_ABt(c_reg, a_reg, b_reg) + c_reg = ker.endrange() + + c_smem = warp.store(c_smem, c_reg) + c = warp.store(c, c_smem, (0, 0, row, col), (), axis=2) + + sink = ker.finish() + + with Context(DEBUG=0): + a = Tensor.rand(1, 1, N, N, dtype="bfloat16").contiguous() + b = Tensor.rand(1, 1, N, N, dtype="bfloat16").contiguous() + c = Tensor.empty(1, 1, N, N, dtype="float32") + Tensor.realize(a, b, c) + + ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (c, a, b)]) + for _ in range(5): ei.run(wait=True) + c = c.float() + + ref = a.matmul(b.transpose(2, 3), dtype=dtypes.float32).float() + + assert ref.allclose(c) + + def 
test_load_store(self): + N = 32 + BLOCK_SIZE = 16 + with Kernel((N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS) as ker: + warp = ker.warp + + b = gl((1, 1, N, N), dtypes.float32) + a = gl((1, 1, N, N), dtypes.float32) + + a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + + a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + + col, row = ker.blockIdx_x, ker.blockIdx_y + + a_smem = warp.load(a_smem, a, (), (0, 0, row, col), axis=2) + a_reg = warp.load(a_reg, a_smem) + b_reg = warp.copy(b_reg, a_reg) + b_smem = warp.store(b_smem, b_reg) + b = warp.store(b, b_smem, (0, 0, row, col), (), axis=2) + + sink = ker.finish() + + with Context(DEBUG=0): + a = Tensor.rand(1, 1, N, N, dtype="float32").contiguous() + b = Tensor.empty(1, 1, N, N, dtype="float32") + Tensor.realize(a, b) + + ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)]) + for _ in range(5): ei.run(wait=True) + b = b.float() + + ref = a.float() + + assert ref.allclose(b) + + def test_max(self): + N = 16 + BLOCK_SIZE = 16 + with Kernel((1, 1, 1), WARP_THREADS) as ker: + warp = ker.warp + + b = gl((1, 1, N, N), dtypes.float32) + a = gl((1, 1, N, N), dtypes.float32) + + a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + + a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + + max_reg = rv(BLOCK_SIZE, dtypes.float32, "ortho") + + max_reg = warp.neg_inf(max_reg) + + for tile_row in ker.range(N // BLOCK_SIZE): + for tile_col in ker.range(N // BLOCK_SIZE): + a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, tile_col), axis=2) + a_reg = warp.load(a_reg, a_smem) + max_reg = warp.row_reduce(max_reg, a_reg, lambda a, b: a.maximum(b)) + sum_reg = ker.endrange() + + b_reg = warp.zero(b_reg).after(tile_row) + b_reg = warp.map(b_reg, lambda _, idx: sum_reg[idx[0], 0, (idx[2]%4)//2]) + b_smem = warp.store(b_smem, b_reg) + + for tile_col in ker.range(N // BLOCK_SIZE): + b = warp.store(b, b_smem, (0, 0, tile_row, tile_col), (), axis=2) + + sink = ker.finish() + + with Context(DEBUG=0): + a = Tensor.rand(1, 1, N, N, dtype="float32").contiguous() + b = Tensor.empty(1, 1, N, N, dtype="float32") + Tensor.realize(a, b) + + ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)]) + for _ in range(5): ei.run(wait=True) + b = b.float() + + ref = a.float().max(axis=3, keepdim=True).expand(a.shape) + + assert ref.allclose(b) + + def test_max_nonsquare(self): + N, M = 16, 64 + BLOCK_N, BLOCK_M = 16, 64 + with Kernel((1, 1, 1), WARP_THREADS) as ker: + warp = ker.warp + + b = gl((1, 1, N, M), dtypes.float32) + a = gl((1, 1, N, M), dtypes.float32) + + a_smem = st((BLOCK_N, BLOCK_M), dtypes.float32) + b_smem = st((BLOCK_N, BLOCK_M), dtypes.float32) + + a_reg = rt((BLOCK_N, BLOCK_M), dtypes.float32) + b_reg = rt((BLOCK_N, BLOCK_M), dtypes.float32) + + max_reg = rv(BLOCK_N, dtypes.float32, "ortho") + + max_reg = warp.zero(max_reg) + + for tile_row in ker.range(N // BLOCK_N): + for tile_col in ker.range(M // BLOCK_M): + a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, tile_col), axis=2) + a_reg = warp.load(a_reg, a_smem) + sum_reg = warp.row_reduce(max_reg, a_reg, lambda a, b: a.maximum(b)) + sum_reg = ker.endrange() + + b_reg = warp.zero(b_reg).after(tile_row) + b_reg = warp.map(b_reg, lambda _, idx: sum_reg[idx[0], 0, (idx[2]%4)//2]) + b_smem = warp.store(b_smem, b_reg) + + for 
tile_col in ker.range(M // BLOCK_M): + b = warp.store(b, b_smem, (0, 0, tile_row, tile_col), (), axis=2) + + sink = ker.finish() + + with Context(DEBUG=0): + a = Tensor.rand(1, 1, N, M, dtype="float32").contiguous() + b = Tensor.empty(1, 1, N, M, dtype="float32") + Tensor.realize(a, b) + + ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)]) + for _ in range(5): ei.run(wait=True) + b = b.float() + + ref = a.float().max(axis=3, keepdim=True).expand(a.shape) + + assert ref.allclose(b) + + def test_sum(self): + N = 16 + BLOCK_SIZE = 16 + with Kernel((1, 1, 1), WARP_THREADS) as ker: + warp = ker.warp + + b = gl((1, 1, N, N), dtypes.float32) + a = gl((1, 1, N, N), dtypes.float32) + + a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + + a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + + sum_reg = rv(BLOCK_SIZE, dtypes.float32, "ortho") + + for tile_row in ker.range(N // BLOCK_SIZE): + sum_reg = warp.zero(sum_reg).after(tile_row) + + for tile_col in ker.range(N // BLOCK_SIZE): + a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, tile_col), axis=2) + a_reg = warp.load(a_reg, a_smem) + sum_reg = warp.row_reduce(sum_reg, a_reg, lambda a, b: a + b) + sum_reg = ker.endrange() + + b_reg = warp.zero(b_reg).after(tile_row) + b_reg = warp.map(b_reg, lambda _, idx: sum_reg[idx[0], 0, (idx[2]%4)//2]) + b_smem = warp.store(b_smem, b_reg) + + for tile_col in ker.range(N // BLOCK_SIZE): + b = warp.store(b, b_smem, (0, 0, tile_row, tile_col), (), axis=2) + + sink = ker.finish() + + with Context(DEBUG=0): + a = Tensor.rand(1, 1, N, N, dtype="float32").contiguous() + a = Tensor.arange(1 * 1 * N * N).reshape(1, 1, N, N).cast(dtypes.float32).contiguous() + b = Tensor.empty(1, 1, N, N, dtype="float32") + Tensor.realize(a, b) + + ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)]) + for _ in range(5): ei.run(wait=True) + b = b.float() + + ref = a.float().sum(axis=3, keepdim=True).expand(a.shape) + + assert ref.allclose(b) + + def test_sum_nonsquare(self): + N, M = 16, 64 + BLOCK_N, BLOCK_M = 16, 64 + with Kernel((1, 1, 1), WARP_THREADS) as ker: + warp = ker.warp + + b = gl((1, 1, N, M), dtypes.float32) + a = gl((1, 1, N, M), dtypes.float32) + + a_smem = st((BLOCK_N, BLOCK_M), dtypes.float32) + b_smem = st((BLOCK_N, BLOCK_M), dtypes.float32) + + a_reg = rt((BLOCK_N, BLOCK_M), dtypes.float32) + b_reg = rt((BLOCK_N, BLOCK_M), dtypes.float32) + + sum_reg = rv(BLOCK_N, dtypes.float32, "ortho") + + sum_reg = warp.zero(sum_reg) + + for tile_row in ker.range(N // BLOCK_N): + for tile_col in ker.range(M // BLOCK_M): + a_smem = warp.load(a_smem, a, (), (0, 0, tile_row, tile_col), axis=2) + a_reg = warp.load(a_reg, a_smem) + sum_reg = warp.row_reduce(sum_reg, a_reg, lambda a, b: a + b) + sum_reg = ker.endrange() + + b_reg = warp.zero(b_reg).after(tile_row) + b_reg = warp.map(b_reg, lambda _, idx: sum_reg[idx[0], 0, (idx[2]%4)//2]) + b_smem = warp.store(b_smem, b_reg) + + for tile_col in ker.range(M // BLOCK_M): + b = warp.store(b, b_smem, (0, 0, tile_row, tile_col), (), axis=2) + + sink = ker.finish() + + with Context(DEBUG=0): + a = Tensor.rand(1, 1, N, M, dtype="float32").contiguous() + b = Tensor.empty(1, 1, N, M, dtype="float32") + Tensor.realize(a, b) + + ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)]) + for _ in range(5): ei.run(wait=True) + b = b.float() + + ref = a.float().sum(axis=3, keepdim=True).expand(a.shape) 
+ + assert ref.allclose(b) + +if __name__ == "__main__": + unittest.main()
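
A quick sketch of the tile sizing introduced above (illustrative only, not part of the diff): it restates the arithmetic from extra/thunder/tiny/tk/tiles.py and tk/__init__.py so the (height, width, RT_BASE_TILE_NEPT) register-tile shapes and the rv vector shapes used throughout group.py and the tests are easy to check. All names come from the patch; the example sizes (a 32x48 tile, a length-32 vector) are made up for illustration.

    # With TILE_ROW_DIM = TILE_COL_DIM = 16 and WARP_THREADS = 32 (as defined in the patch),
    # a 16x16 base tile holds 256 elements, so each thread of a warp owns 256 // 32 = 8 of them.
    TILE_ROW_DIM, TILE_COL_DIM, WARP_THREADS = 16, 16, 32
    RT_BASE_TILE_NEPT = (TILE_ROW_DIM * TILE_COL_DIM) // WARP_THREADS   # = 8, as in tiles.py

    # rt((32, 48), dtype) -> a REG placeholder of shape (2, 3, 8): 2x3 base tiles, 8 elements per thread.
    rt_shape = (32 // TILE_ROW_DIM, 48 // TILE_COL_DIM, RT_BASE_TILE_NEPT)   # (2, 3, 8)

    # rv(32, dtype, "naive") -> (1, 1, 2) and rv(32, dtype, "ortho") -> (2, 1, 2):
    # (tiles + 1) // 2 vs. tiles outer entries, each holding two scalars per thread.
    tiles = 32 // TILE_ROW_DIM
    rv_naive_shape, rv_ortho_shape = ((tiles + 1) // 2, 1, 2), (tiles, 1, 2)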