use shrink in amd_matmul_uop (#13026)

* use shrink in amd_matmul_uop

* colors
This commit is contained in:
George Hotz
2025-10-31 10:43:41 +08:00
committed by GitHub
parent 78f7650eec
commit b46229ca51
5 changed files with 27 additions and 24 deletions

View File

@@ -73,7 +73,7 @@ def hand_spec_kernel3():
c = UOp.placeholder(dtypes.float, (N, N), slot=0)
BM_As_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M
As = UOp.placeholder(dtypes.float, (BLOCK_K, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL)
As = UOp.placeholder(dtypes.float, (BLOCK_K, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
Bs = UOp.placeholder(dtypes.float, (BLOCK_K, BLOCK_N), slot=1, addrspace=AddrSpace.LOCAL)
A_col = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM), slot=0, addrspace=AddrSpace.REG)
@@ -113,15 +113,15 @@ def hand_spec_kernel3():
# ---------------------------
# LOCAL -> REG (per-wave tiles)
# ---------------------------
Bs_view = Bs.reshape((BLOCK_K, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN))
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 4)
i = UOp.range(TN, 5)
index = waveIdx * WAVE_TILE_N + iterWaveN * N_PER_ITER + idxInWave * TN + i
B_row = B_row[iterWaveN, i].set(Bs[k, index], end=(iterWaveN, i))
B_row = B_row[iterWaveN, i].set(Bs_view[k, waveIdx, iterWaveN, idxInWave, i], end=(iterWaveN, i))
As_view = As.reshape((BLOCK_K, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM))
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 6)
i = UOp.range(TM, 7)
index = waveIdy * WAVE_TILE_M + iterWaveM * M_PER_ITER + idyInWave * TM + i
A_col = A_col[iterWaveM, i].set(As[k, index], end=(iterWaveM, i))
A_col = A_col[iterWaveM, i].set(As_view[k, waveIdy, iterWaveM, idyInWave, i], end=(iterWaveM, i))
# ---------------------------
# FMA: c_regs += A_col * B_row
@@ -149,7 +149,7 @@ def hand_spec_kernel3():
sink = c_glbl_idx.store(c_regs.after(sink)[iterWaveM, yt, iterWaveN, xt])
sink = sink.end(iterWaveM, iterWaveN, yt, xt)
return sink.sink(arg=KernelInfo(opts_to_apply=()))
return sink.sink(arg=KernelInfo(opts_to_apply=())).simplify()
if __name__ == "__main__":

View File

@@ -1,7 +1,7 @@
from typing import cast
import itertools
from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, GroupOp
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
from tinygrad.renderer import Renderer
from tinygrad.dtype import dtypes
@@ -20,8 +20,9 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
pm_preprocess = PatternMatcher([
(UPat(Ops.RESHAPE, name="r").after(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:]).reshape(r.shape)),
(UPat(Ops.RESHAPE, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
(UPat(GroupOp.Movement, name="r").after(name="a", allow_any_len=True),
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)),
(UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
])
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:

View File

@@ -757,6 +757,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# *** uop high level syntactic sugar ***
def shrink_to(self, arg:tuple[sint, ...]): return self.shrink(tuple([(0,x) for x in arg]))
@staticmethod
def placeholder(dtype:DType, shape:tuple[int, ...], slot:int, addrspace=AddrSpace.GLOBAL):
lookup = {AddrSpace.GLOBAL: Ops.DEFINE_GLOBAL, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}

View File

@@ -44,7 +44,17 @@ shared_spec = PatternMatcher([
# ***** UOp spec in the Tensor graph *****
tensor_spec = PatternMatcher([
movement_ops = PatternMatcher([
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
# inputs to movement ops
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
])
tensor_spec = movement_ops+PatternMatcher([
# buffer spec
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
(UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d:
@@ -67,14 +77,6 @@ tensor_spec = PatternMatcher([
# MSTACK combines buffers into multi
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
# inputs to movement ops
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
# Tensor variable bindings
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),
@@ -152,11 +154,9 @@ shared_codegen_spec = PatternMatcher([
# ***** UOp spec in kernel graph *****
kernel_spec = PatternMatcher([
# RESHAPE (but only RESHAPE) is allowed here
(UPat(Ops.RESHAPE, name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
(UPat(Ops.AFTER, src=(UPat(Ops.RESHAPE),), allow_any_len=True), lambda: True),
(UPat(Ops.VCONST, dtype=dtypes.index), lambda: True),
kernel_spec = movement_ops+PatternMatcher([
# AFTER on Movement Op
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement),), allow_any_len=True), lambda: True),
# index is allowed here
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),

View File

@@ -14,7 +14,7 @@ from tinygrad.renderer import ProgramSpec
from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
**{x:"#f2cb91" for x in GroupOp.Defines}, Ops.REDUCE_AXIS: "#FF6B6B",
Ops.DEFINE_GLOBAL:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",