From b46229ca517c1728424a7f68db6af113e625393f Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 31 Oct 2025 10:43:41 +0800 Subject: [PATCH] use shrink in amd_matmul_uop (#13026) * use shrink in amd_matmul_uop * colors --- extra/gemm/amd_uop_matmul.py | 12 ++++++------ tinygrad/codegen/__init__.py | 7 ++++--- tinygrad/uop/ops.py | 2 ++ tinygrad/uop/spec.py | 28 ++++++++++++++-------------- tinygrad/viz/serve.py | 2 +- 5 files changed, 27 insertions(+), 24 deletions(-) diff --git a/extra/gemm/amd_uop_matmul.py b/extra/gemm/amd_uop_matmul.py index 1528ef17f1..f269afdddd 100644 --- a/extra/gemm/amd_uop_matmul.py +++ b/extra/gemm/amd_uop_matmul.py @@ -73,7 +73,7 @@ def hand_spec_kernel3(): c = UOp.placeholder(dtypes.float, (N, N), slot=0) BM_As_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M - As = UOp.placeholder(dtypes.float, (BLOCK_K, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL) + As = UOp.placeholder(dtypes.float, (BLOCK_K, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M)) Bs = UOp.placeholder(dtypes.float, (BLOCK_K, BLOCK_N), slot=1, addrspace=AddrSpace.LOCAL) A_col = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM), slot=0, addrspace=AddrSpace.REG) @@ -113,15 +113,15 @@ def hand_spec_kernel3(): # --------------------------- # LOCAL -> REG (per-wave tiles) # --------------------------- + Bs_view = Bs.reshape((BLOCK_K, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)) iterWaveN = UOp.range(ITERS_PER_WAVE_N, 4) i = UOp.range(TN, 5) - index = waveIdx * WAVE_TILE_N + iterWaveN * N_PER_ITER + idxInWave * TN + i - B_row = B_row[iterWaveN, i].set(Bs[k, index], end=(iterWaveN, i)) + B_row = B_row[iterWaveN, i].set(Bs_view[k, waveIdx, iterWaveN, idxInWave, i], end=(iterWaveN, i)) + As_view = As.reshape((BLOCK_K, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM)) iterWaveM = UOp.range(ITERS_PER_WAVE_M, 6) i = UOp.range(TM, 7) - index = waveIdy * WAVE_TILE_M + iterWaveM * M_PER_ITER + idyInWave * TM + i - A_col = A_col[iterWaveM, i].set(As[k, index], end=(iterWaveM, i)) + A_col = A_col[iterWaveM, i].set(As_view[k, waveIdy, iterWaveM, idyInWave, i], end=(iterWaveM, i)) # --------------------------- # FMA: c_regs += A_col * B_row @@ -149,7 +149,7 @@ def hand_spec_kernel3(): sink = c_glbl_idx.store(c_regs.after(sink)[iterWaveM, yt, iterWaveN, xt]) sink = sink.end(iterWaveM, iterWaveN, yt, xt) - return sink.sink(arg=KernelInfo(opts_to_apply=())) + return sink.sink(arg=KernelInfo(opts_to_apply=())).simplify() if __name__ == "__main__": diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index c07a98f215..239a04c3db 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -1,7 +1,7 @@ from typing import cast import itertools from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC -from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat +from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, GroupOp from tinygrad.uop.spec import type_verify, program_spec, kernel_spec from tinygrad.renderer import Renderer from tinygrad.dtype import dtypes @@ -20,8 +20,9 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize pm_preprocess = PatternMatcher([ - (UPat(Ops.RESHAPE, name="r").after(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:]).reshape(r.shape)), - (UPat(Ops.RESHAPE, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])), + (UPat(GroupOp.Movement, name="r").after(name="a", allow_any_len=True), + lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)), + (UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])), ]) def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp: diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 11b62135d3..4b1b4bf2de 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -757,6 +757,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # *** uop high level syntactic sugar *** + def shrink_to(self, arg:tuple[sint, ...]): return self.shrink(tuple([(0,x) for x in arg])) + @staticmethod def placeholder(dtype:DType, shape:tuple[int, ...], slot:int, addrspace=AddrSpace.GLOBAL): lookup = {AddrSpace.GLOBAL: Ops.DEFINE_GLOBAL, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG} diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 96312d00f4..d9dbd32c53 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -44,7 +44,17 @@ shared_spec = PatternMatcher([ # ***** UOp spec in the Tensor graph ***** -tensor_spec = PatternMatcher([ +movement_ops = PatternMatcher([ + (UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True), + (UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True), + (UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)), + + # inputs to movement ops + (UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True), + (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True), +]) + +tensor_spec = movement_ops+PatternMatcher([ # buffer spec (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True), (UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d: @@ -67,14 +77,6 @@ tensor_spec = PatternMatcher([ # MSTACK combines buffers into multi (UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)), - (UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True), - (UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True), - (UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)), - - # inputs to movement ops - (UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True), - (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True), - # Tensor variable bindings (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True), @@ -152,11 +154,9 @@ shared_codegen_spec = PatternMatcher([ # ***** UOp spec in kernel graph ***** -kernel_spec = PatternMatcher([ - # RESHAPE (but only RESHAPE) is allowed here - (UPat(Ops.RESHAPE, name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True), - (UPat(Ops.AFTER, src=(UPat(Ops.RESHAPE),), allow_any_len=True), lambda: True), - (UPat(Ops.VCONST, dtype=dtypes.index), lambda: True), +kernel_spec = movement_ops+PatternMatcher([ + # AFTER on Movement Op + (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement),), allow_any_len=True), lambda: True), # index is allowed here (UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True), diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index e6bb9fedf0..9c40fa7260 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -14,7 +14,7 @@ from tinygrad.renderer import ProgramSpec from tinygrad.dtype import dtypes uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B", - **{x:"#f2cb91" for x in GroupOp.Defines}, Ops.REDUCE_AXIS: "#FF6B6B", + Ops.DEFINE_GLOBAL:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",