diff --git a/extra/thunder/tiny/tk/group.py b/extra/thunder/tiny/tk/group.py index 25acfe94a1..5ab1e4a01b 100644 --- a/extra/thunder/tiny/tk/group.py +++ b/extra/thunder/tiny/tk/group.py @@ -7,7 +7,7 @@ from tinygrad.dtype import AddrSpace, PtrDType from tinygrad.helpers import getenv, prod from extra.thunder.tiny.tk import WARP_THREADS -from extra.thunder.tiny.tk.tiles import TILE_ROW_DIM, TILE_COL_DIM, RT_BASE_TILE_NEPT, slots +from extra.thunder.tiny.tk.tiles import RT class Group: def __init__(self, warps:int, ker): @@ -126,11 +126,8 @@ class Group: def row_reduce(self, vec:UOp, src:UOp, op:Callable[[UOp, UOp], UOp]): assert self.warps == 1 - red_local = UOp.placeholder((self.group_threads, 2), src.dtype.base, addrspace=AddrSpace.LOCAL, slot=slots.shared_slot) - slots.shared_slot += 1 - - red_reg = UOp.placeholder((2,), src.dtype.base, addrspace=AddrSpace.REG, slot=slots.register_slot) - slots.register_slot += 1 + red_local = self.ker.alloc((self.group_threads, 2), src.dtype.base, AddrSpace.LOCAL) + red_reg = self.ker.alloc((2,), src.dtype.base, AddrSpace.REG) for height in self.ker.range(src.shape[-3], track=False): i = UOp.range(red_reg.size, Group.clear_rid) @@ -177,7 +174,7 @@ class Group: load_i_height = UOp.range(dst.shape[-3], Group.load_rid) load_i_width = UOp.range(dst.shape[-2], Group.load_rid+1) - load_i_inner = UOp.range(RT_BASE_TILE_NEPT, Group.load_rid+2) + load_i_inner = UOp.range(RT.BASE_TILE_NEPT, Group.load_rid+2) Group.load_rid += 3 if self.warps % 4 == 0: local_warpid = (self.warpid // 4) + (self.warpid % 4) * (self.warps // 4) @@ -185,14 +182,14 @@ class Group: warp_laneid = self.threadIdx_x % WARP_THREADS if not transpose: - row = (local_warpid * dst.shape[-3] + load_i_height) * TILE_ROW_DIM + (warp_laneid // 4) - col = load_i_width * TILE_COL_DIM + 2 * (warp_laneid % 4) + row = (local_warpid * dst.shape[-3] + load_i_height) * RT.TILE_ROW_DIM + (warp_laneid // 4) + col = load_i_width * RT.TILE_COL_DIM + 2 * (warp_laneid % 4) row_offset = 
((load_i_inner % 4) // 2) * 8 col_offset = (load_i_inner % 2) + (load_i_inner // 4) * 8 else: - row = (local_warpid * dst.shape[-3] + load_i_height) * TILE_ROW_DIM + 2 * (warp_laneid % 4) - col = load_i_width * TILE_COL_DIM + (warp_laneid // 4) + row = (local_warpid * dst.shape[-3] + load_i_height) * RT.TILE_ROW_DIM + 2 * (warp_laneid % 4) + col = load_i_width * RT.TILE_COL_DIM + (warp_laneid // 4) row_offset = (load_i_inner % 2) + (load_i_inner // 4) * 8 col_offset = ((load_i_inner % 4) // 2) * 8 @@ -241,15 +238,15 @@ class Group: store_i_height = UOp.range(src.shape[-3], Group.store_rid) store_i_width = UOp.range(src.shape[-2], Group.store_rid+1) - store_i_inner = UOp.range(RT_BASE_TILE_NEPT, Group.store_rid+2) + store_i_inner = UOp.range(RT.BASE_TILE_NEPT, Group.store_rid+2) Group.store_rid += 3 if self.warps % 4 == 0: local_warpid = (self.warpid // 4) + (self.warpid % 4) * (self.warps // 4) else: local_warpid = self.warpid warp_laneid = self.threadIdx_x % WARP_THREADS - row = (local_warpid * src.shape[-3] + store_i_height) * TILE_ROW_DIM + (warp_laneid // 4) - col = store_i_width * TILE_COL_DIM + 2 * (warp_laneid % 4) + row = (local_warpid * src.shape[-3] + store_i_height) * RT.TILE_ROW_DIM + (warp_laneid // 4) + col = store_i_width * RT.TILE_COL_DIM + 2 * (warp_laneid % 4) row_offset = ((store_i_inner % 4) // 2) * 8 col_offset = (store_i_inner % 2) + (store_i_inner // 4) * 8 diff --git a/extra/thunder/tiny/tk/kernel.py b/extra/thunder/tiny/tk/kernel.py index 8fab1ee905..68d89ea3b2 100644 --- a/extra/thunder/tiny/tk/kernel.py +++ b/extra/thunder/tiny/tk/kernel.py @@ -1,7 +1,8 @@ from contextlib import AbstractContextManager -from tinygrad.uop.ops import UOp, KernelInfo, AxisType +from tinygrad.uop.ops import UOp, KernelInfo, AxisType, AddrSpace from extra.thunder.tiny.tk import WARP_THREADS from extra.thunder.tiny.tk.group import Group +from extra.thunder.tiny.tk.tiles import GL, ST, RT, RV class _tk_range: user_rid = 0 @@ -25,6 +26,11 @@ class 
Kernel(AbstractContextManager): self.range_stack = [] self.store_stack = [] + self.global_slot = 0 + self.shared_slot = 0 + self.register_slot = 0 + self.allocs = {} + @property def warpid(self): return self.threadIdx_x // WARP_THREADS @@ -42,6 +48,31 @@ class Kernel(AbstractContextManager): if track: self.range_stack.append(rng) return rng + def alloc(self, shape, dtype, addrspace:AddrSpace, name:str|None=None): + if name and (name, shape) in self.allocs: return self.allocs[(name, shape)] + match addrspace: + case AddrSpace.GLOBAL: + slot = self.global_slot + self.global_slot += 1 + case AddrSpace.LOCAL: + slot = self.shared_slot + self.shared_slot += 1 + case AddrSpace.REG: + slot = self.register_slot + self.register_slot += 1 + + uop = UOp.placeholder(shape, dtype, slot=slot, addrspace=addrspace) + + if name: + self.allocs[(name, shape)] = uop + + return uop + + def gl(self, shape, dtype): return GL(shape, dtype, self)._uop + def st(self, shape, dtype): return ST(shape, dtype, self)._uop + def rt(self, shape, dtype): return RT(shape, dtype, self)._uop + def rv(self, length, dtype, layout="naive"): return RV(length, dtype, layout, self)._uop + def push_store(self, store:UOp, uop:UOp): self.store_stack.append((store, uop)) def finish(self): diff --git a/extra/thunder/tiny/tk/tiles.py b/extra/thunder/tiny/tk/tiles.py index c936dfd199..0a8ecc987f 100644 --- a/extra/thunder/tiny/tk/tiles.py +++ b/extra/thunder/tiny/tk/tiles.py @@ -1,52 +1,45 @@ -import math -from typing import cast, Callable -from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes -from tinygrad.uop.ops import AxisType, UOp, KernelInfo, Ops -from tinygrad.engine.realize import ExecItem, get_runner -from tinygrad.dtype import AddrSpace, PtrDType -from tinygrad.helpers import getenv, prod +from tinygrad.dtype import AddrSpace from extra.thunder.tiny.tk import WARP_THREADS -class _Slots: - def __init__(self): - self.global_slot = 0 - self.shared_slot = 0 - self.register_slot = 0 -slots = _Slots()
+class GL: + def __init__(self, shape, dtype, ker): + self.shape, self.dtype = shape, dtype + self._uop = ker.alloc(shape, dtype, AddrSpace.GLOBAL) -def gl(shape, dtype): - slots.global_slot += 1 - return UOp.placeholder(shape, dtype, slot=slots.global_slot-1) +class ST: + def __init__(self, shape, dtype, ker): + self.shape, self.dtype = shape, dtype + self._uop = ker.alloc(shape, dtype, AddrSpace.LOCAL) -shared_slot = 0 -def st(shape, dtype): - slots.shared_slot += 1 - return UOp.placeholder(shape, dtype, addrspace=AddrSpace.LOCAL, slot=slots.shared_slot-1) +class RT: + TILE_ROW_DIM, TILE_COL_DIM = 16, 16 + BASE_TILE_NE = TILE_ROW_DIM * TILE_COL_DIM + BASE_TILE_NEPT = BASE_TILE_NE // WARP_THREADS -TILE_ROW_DIM, TILE_COL_DIM = 16, 16 -RT_BASE_TILE_NE = TILE_ROW_DIM * TILE_COL_DIM -RT_BASE_TILE_NEPT = RT_BASE_TILE_NE // WARP_THREADS -register_slot = 0 -def rt(shape, dtype): - assert len(shape) == 2 + def __init__(self, shape, dtype, ker): + assert len(shape) == 2 + assert shape[0] % RT.TILE_ROW_DIM == 0 + assert shape[1] % RT.TILE_COL_DIM == 0 - height = shape[0] // TILE_ROW_DIM - width = shape[1] // TILE_COL_DIM + height = shape[0] // RT.TILE_ROW_DIM + width = shape[1] // RT.TILE_COL_DIM - slots.register_slot += 1 - return UOp.placeholder((height, width, RT_BASE_TILE_NEPT), dtype, addrspace=AddrSpace.REG, slot=slots.register_slot-1) + self.shape, self.dtype = (height, width, self.BASE_TILE_NEPT), dtype + self._uop = ker.alloc(self.shape, dtype, AddrSpace.REG) -def rv(length, dtype, layout="naive"): - tiles = length // TILE_ROW_DIM - match layout: - case "naive": - inner_dim = 1 - outer_dim = (tiles + 1) // 2 - case "ortho": - inner_dim = 1 - outer_dim = tiles - case _: raise NotImplementedError(f"rv layout {layout} not implemented") +class RV: + def __init__(self, length, dtype, layout, ker): + tiles = length // RT.TILE_ROW_DIM - slots.register_slot += 1 - return UOp.placeholder((outer_dim, inner_dim, 2), dtype, addrspace=AddrSpace.REG, slot=slots.register_slot-1) 
+ match layout: + case "naive": + inner_dim = 1 + outer_dim = (tiles + 1) // 2 + case "ortho": + inner_dim = 1 + outer_dim = tiles + case _: raise NotImplementedError(f"rv layout {layout} not implemented") + + self.shape, self.dtype = (outer_dim, inner_dim, 2), dtype + self._uop = ker.alloc(self.shape, dtype, AddrSpace.REG) diff --git a/test/external/external_test_tk.py b/test/external/external_test_tk.py index 6215394c8d..a0d099b2d6 100644 --- a/test/external/external_test_tk.py +++ b/test/external/external_test_tk.py @@ -6,7 +6,6 @@ import numpy as np from extra.thunder.tiny.tk import WARP_THREADS from extra.thunder.tiny.tk.kernel import Kernel -from extra.thunder.tiny.tk.tiles import gl, st, rt, rv class TestTK(unittest.TestCase): def test_simple_matmul(self): @@ -15,17 +14,17 @@ class TestTK(unittest.TestCase): with Kernel((N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS) as ker: warp = ker.warp - c = gl((1, 1, N, N), dtypes.float32) - a = gl((1, 1, N, N), dtypes.bfloat16) - b = gl((1, 1, N, N), dtypes.bfloat16) + c = ker.gl((1, 1, N, N), dtypes.float32) + a = ker.gl((1, 1, N, N), dtypes.bfloat16) + b = ker.gl((1, 1, N, N), dtypes.bfloat16) - a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) - b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) - c_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + c_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) - a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) - b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) - c_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + c_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) col, row = ker.blockIdx_x, ker.blockIdx_y @@ -65,17 +64,17 @@ class TestTK(unittest.TestCase): with Kernel((N // BLOCK_SIZE, N // 
BLOCK_SIZE, 1), WARP_THREADS) as ker: warp = ker.warp - c = gl((1, 1, N, N), dtypes.float32) - a = gl((1, 1, N, N), dtypes.bfloat16) - b = gl((1, 1, N, N), dtypes.bfloat16) + c = ker.gl((1, 1, N, N), dtypes.float32) + a = ker.gl((1, 1, N, N), dtypes.bfloat16) + b = ker.gl((1, 1, N, N), dtypes.bfloat16) - a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) - b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) - c_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + c_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) - a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) - b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) - c_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.bfloat16) + c_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) col, row = ker.blockIdx_x, ker.blockIdx_y @@ -115,14 +114,14 @@ class TestTK(unittest.TestCase): with Kernel((N // BLOCK_SIZE, N // BLOCK_SIZE, 1), WARP_THREADS) as ker: warp = ker.warp - b = gl((1, 1, N, N), dtypes.float32) - a = gl((1, 1, N, N), dtypes.float32) + b = ker.gl((1, 1, N, N), dtypes.float32) + a = ker.gl((1, 1, N, N), dtypes.float32) - a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) - b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) - a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) - b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) col, row = ker.blockIdx_x, ker.blockIdx_y @@ -153,16 +152,16 @@ class TestTK(unittest.TestCase): with Kernel((1, 1, 1), WARP_THREADS) as ker: warp = ker.warp - b = gl((1, 1, N, N), 
dtypes.float32) - a = gl((1, 1, N, N), dtypes.float32) + b = ker.gl((1, 1, N, N), dtypes.float32) + a = ker.gl((1, 1, N, N), dtypes.float32) - a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) - b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) - a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) - b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) - max_reg = rv(BLOCK_SIZE, dtypes.float32, "ortho") + max_reg = ker.rv(BLOCK_SIZE, dtypes.float32, "ortho") for tile_row in ker.range(N // BLOCK_SIZE): max_reg = warp.neg_inf(max_reg.after(tile_row)) @@ -200,16 +199,16 @@ class TestTK(unittest.TestCase): with Kernel((1, 1, 1), WARP_THREADS) as ker: warp = ker.warp - b = gl((1, 1, N, M), dtypes.float32) - a = gl((1, 1, N, M), dtypes.float32) + b = ker.gl((1, 1, N, M), dtypes.float32) + a = ker.gl((1, 1, N, M), dtypes.float32) - a_smem = st((BLOCK_N, BLOCK_M), dtypes.float32) - b_smem = st((BLOCK_N, BLOCK_M), dtypes.float32) + a_smem = ker.st((BLOCK_N, BLOCK_M), dtypes.float32) + b_smem = ker.st((BLOCK_N, BLOCK_M), dtypes.float32) - a_reg = rt((BLOCK_N, BLOCK_M), dtypes.float32) - b_reg = rt((BLOCK_N, BLOCK_M), dtypes.float32) + a_reg = ker.rt((BLOCK_N, BLOCK_M), dtypes.float32) + b_reg = ker.rt((BLOCK_N, BLOCK_M), dtypes.float32) - max_reg = rv(BLOCK_N, dtypes.float32, "ortho") + max_reg = ker.rv(BLOCK_N, dtypes.float32, "ortho") for tile_row in ker.range(N // BLOCK_N): max_reg = warp.neg_inf(max_reg.after(tile_row)) @@ -247,16 +246,16 @@ class TestTK(unittest.TestCase): with Kernel((1, 1, 1), WARP_THREADS) as ker: warp = ker.warp - b = gl((1, 1, N, N), dtypes.float32) - a = gl((1, 1, N, N), dtypes.float32) + b = ker.gl((1, 1, N, N), dtypes.float32) + a = ker.gl((1, 1, N, N), dtypes.float32) - a_smem = st((BLOCK_SIZE, 
BLOCK_SIZE), dtypes.float32) - b_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + b_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) - a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) - b_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + b_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) - sum_reg = rv(BLOCK_SIZE, dtypes.float32, "ortho") + sum_reg = ker.rv(BLOCK_SIZE, dtypes.float32, "ortho") for tile_row in ker.range(N // BLOCK_SIZE): sum_reg = warp.zero(sum_reg.after(tile_row)) @@ -294,16 +293,16 @@ class TestTK(unittest.TestCase): with Kernel((1, 1, 1), WARP_THREADS) as ker: warp = ker.warp - b = gl((1, 1, N, M), dtypes.float32) - a = gl((1, 1, N, M), dtypes.float32) + b = ker.gl((1, 1, N, M), dtypes.float32) + a = ker.gl((1, 1, N, M), dtypes.float32) - a_smem = st((BLOCK_N, BLOCK_M), dtypes.float32) - b_smem = st((BLOCK_N, BLOCK_M), dtypes.float32) + a_smem = ker.st((BLOCK_N, BLOCK_M), dtypes.float32) + b_smem = ker.st((BLOCK_N, BLOCK_M), dtypes.float32) - a_reg = rt((BLOCK_N, BLOCK_M), dtypes.float32) - b_reg = rt((BLOCK_N, BLOCK_M), dtypes.float32) + a_reg = ker.rt((BLOCK_N, BLOCK_M), dtypes.float32) + b_reg = ker.rt((BLOCK_N, BLOCK_M), dtypes.float32) - sum_reg = rv(BLOCK_N, dtypes.float32, "ortho") + sum_reg = ker.rv(BLOCK_N, dtypes.float32, "ortho") for tile_row in ker.range(N // BLOCK_N): sum_reg = warp.zero(sum_reg.after(tile_row)) @@ -341,16 +340,16 @@ class TestTK(unittest.TestCase): with Kernel((1, 1, 1), WARP_THREADS) as ker: warp = ker.warp - b = gl((1, 1, BLOCK_SIZE, N), dtypes.float32) - a = gl((1, 1, BLOCK_SIZE, N), dtypes.float32) + b = ker.gl((1, 1, BLOCK_SIZE, N), dtypes.float32) + a = ker.gl((1, 1, BLOCK_SIZE, N), dtypes.float32) - a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + a_smem = ker.st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) - a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) + 
a_reg = ker.rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32) - max_vec_last = rv(BLOCK_SIZE, dtypes.float32, "ortho") - max_vec = rv(BLOCK_SIZE, dtypes.float32, "ortho") - norm_vec = rv(BLOCK_SIZE, dtypes.float32, "ortho") + max_vec_last = ker.rv(BLOCK_SIZE, dtypes.float32, "ortho") + max_vec = ker.rv(BLOCK_SIZE, dtypes.float32, "ortho") + norm_vec = ker.rv(BLOCK_SIZE, dtypes.float32, "ortho") max_vec = warp.neg_inf(max_vec) norm_vec = warp.zero(norm_vec)