mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
use shrink in amd_matmul_uop (#13026)
* use shrink in amd_matmul_uop * colors
This commit is contained in:
@@ -73,7 +73,7 @@ def hand_spec_kernel3():
|
||||
c = UOp.placeholder(dtypes.float, (N, N), slot=0)
|
||||
|
||||
BM_As_stride = (BLOCK_M + 4) if is_kernel5 else BLOCK_M
|
||||
As = UOp.placeholder(dtypes.float, (BLOCK_K, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL)
|
||||
As = UOp.placeholder(dtypes.float, (BLOCK_K, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL).shrink_to((BLOCK_K, BLOCK_M))
|
||||
Bs = UOp.placeholder(dtypes.float, (BLOCK_K, BLOCK_N), slot=1, addrspace=AddrSpace.LOCAL)
|
||||
|
||||
A_col = UOp.placeholder(dtypes.float, (ITERS_PER_WAVE_M, TM), slot=0, addrspace=AddrSpace.REG)
|
||||
@@ -113,15 +113,15 @@ def hand_spec_kernel3():
|
||||
# ---------------------------
|
||||
# LOCAL -> REG (per-wave tiles)
|
||||
# ---------------------------
|
||||
Bs_view = Bs.reshape((BLOCK_K, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN))
|
||||
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 4)
|
||||
i = UOp.range(TN, 5)
|
||||
index = waveIdx * WAVE_TILE_N + iterWaveN * N_PER_ITER + idxInWave * TN + i
|
||||
B_row = B_row[iterWaveN, i].set(Bs[k, index], end=(iterWaveN, i))
|
||||
B_row = B_row[iterWaveN, i].set(Bs_view[k, waveIdx, iterWaveN, idxInWave, i], end=(iterWaveN, i))
|
||||
|
||||
As_view = As.reshape((BLOCK_K, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM))
|
||||
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 6)
|
||||
i = UOp.range(TM, 7)
|
||||
index = waveIdy * WAVE_TILE_M + iterWaveM * M_PER_ITER + idyInWave * TM + i
|
||||
A_col = A_col[iterWaveM, i].set(As[k, index], end=(iterWaveM, i))
|
||||
A_col = A_col[iterWaveM, i].set(As_view[k, waveIdy, iterWaveM, idyInWave, i], end=(iterWaveM, i))
|
||||
|
||||
# ---------------------------
|
||||
# FMA: c_regs += A_col * B_row
|
||||
@@ -149,7 +149,7 @@ def hand_spec_kernel3():
|
||||
sink = c_glbl_idx.store(c_regs.after(sink)[iterWaveM, yt, iterWaveN, xt])
|
||||
sink = sink.end(iterWaveM, iterWaveN, yt, xt)
|
||||
|
||||
return sink.sink(arg=KernelInfo(opts_to_apply=()))
|
||||
return sink.sink(arg=KernelInfo(opts_to_apply=())).simplify()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import cast
|
||||
import itertools
|
||||
from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, GroupOp
|
||||
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -20,8 +20,9 @@ from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, p
|
||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
|
||||
pm_preprocess = PatternMatcher([
|
||||
(UPat(Ops.RESHAPE, name="r").after(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:]).reshape(r.shape)),
|
||||
(UPat(Ops.RESHAPE, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
|
||||
(UPat(GroupOp.Movement, name="r").after(name="a", allow_any_len=True),
|
||||
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)),
|
||||
(UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
|
||||
])
|
||||
|
||||
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
|
||||
|
||||
@@ -757,6 +757,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
# *** uop high level syntactic sugar ***
|
||||
|
||||
def shrink_to(self, arg:tuple[sint, ...]): return self.shrink(tuple([(0,x) for x in arg]))
|
||||
|
||||
@staticmethod
|
||||
def placeholder(dtype:DType, shape:tuple[int, ...], slot:int, addrspace=AddrSpace.GLOBAL):
|
||||
lookup = {AddrSpace.GLOBAL: Ops.DEFINE_GLOBAL, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG}
|
||||
|
||||
@@ -44,7 +44,17 @@ shared_spec = PatternMatcher([
|
||||
|
||||
# ***** UOp spec in the Tensor graph *****
|
||||
|
||||
tensor_spec = PatternMatcher([
|
||||
movement_ops = PatternMatcher([
|
||||
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
|
||||
|
||||
# inputs to movement ops
|
||||
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
|
||||
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
|
||||
])
|
||||
|
||||
tensor_spec = movement_ops+PatternMatcher([
|
||||
# buffer spec
|
||||
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
|
||||
(UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d:
|
||||
@@ -67,14 +77,6 @@ tensor_spec = PatternMatcher([
|
||||
# MSTACK combines buffers into multi
|
||||
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
|
||||
|
||||
(UPat((Ops.RESHAPE, Ops.EXPAND), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||
(UPat((Ops.PAD, Ops.SHRINK), name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat.var("x"),)), lambda mv,x: isinstance(mv.arg, tuple)),
|
||||
|
||||
# inputs to movement ops
|
||||
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
|
||||
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
|
||||
|
||||
# Tensor variable bindings
|
||||
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),
|
||||
|
||||
@@ -152,11 +154,9 @@ shared_codegen_spec = PatternMatcher([
|
||||
|
||||
# ***** UOp spec in kernel graph *****
|
||||
|
||||
kernel_spec = PatternMatcher([
|
||||
# RESHAPE (but only RESHAPE) is allowed here
|
||||
(UPat(Ops.RESHAPE, name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True),
|
||||
(UPat(Ops.AFTER, src=(UPat(Ops.RESHAPE),), allow_any_len=True), lambda: True),
|
||||
(UPat(Ops.VCONST, dtype=dtypes.index), lambda: True),
|
||||
kernel_spec = movement_ops+PatternMatcher([
|
||||
# AFTER on Movement Op
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement),), allow_any_len=True), lambda: True),
|
||||
|
||||
# index is allowed here
|
||||
(UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
|
||||
|
||||
@@ -14,7 +14,7 @@ from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
**{x:"#f2cb91" for x in GroupOp.Defines}, Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
Ops.DEFINE_GLOBAL:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
|
||||
Reference in New Issue
Block a user