add SHAPED_WMMA (#15400)

* add SHAPED_WMMA

* shaped wmma

* less bad
This commit is contained in:
George Hotz
2026-03-21 16:16:03 +08:00
committed by GitHub
parent 41a9b09683
commit c13d9d29ff
6 changed files with 24 additions and 13 deletions

View File

@@ -72,17 +72,8 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
a_frag = A_local.reshape(WAVES_M, TM // WMMA_ACC, WMMA_M, BLOCK_K // WMMA_K, WMMA_K)[wave_m, tile_m, lane_n, k]
b_frag = B_local.reshape(WAVES_N, TN, WMMA_N, BLOCK_K // WMMA_K, WMMA_K)[wave_n, tile_n, lane_n, k]
# TODO: remove unneeded CONTRACTS
k_upcast_a = UOp.range(WMMA_K, 301, axis_type=AxisType.UPCAST)
k_upcast_b = UOp.range(WMMA_K, 311, axis_type=AxisType.UPCAST)
acc_upcast = UOp.range(WMMA_ACC, 302, axis_type=AxisType.UPCAST)
wmma_arg = ('WMMA_16_16_16_half_float', (16, 16, 16), dtypes.half, dtypes.float, 'AMD', 32,
(((301, 16),), ((311, 16),), ((302, WMMA_ACC),)), ())
out = UOp(Ops.WMMA, dtypes.float.vec(WMMA_ACC), (a_frag[k_upcast_a].contract(k_upcast_a),
b_frag[k_upcast_b].contract(k_upcast_b),
acc_frag.after(k)[acc_upcast].contract(acc_upcast)), arg=wmma_arg)
acc_store = UOp.group(*[acc_frag[e].store(out.gep(e)) for e in range(WMMA_ACC)]).end(tile_m, tile_n)
wmma = UOp(Ops.SHAPED_WMMA, dtypes.float, (a_frag, b_frag, acc_frag.after(k)), arg=((16, 16, 16), 'AMD', 32))
acc_store = acc_frag.store(wmma).end(tile_m, tile_n)
else:
# registers for LOCAL -> REG
a_frag = UOp.placeholder((TM//UNROLL_M, UNROLL_M), dtypes.float, slot=0, addrspace=AddrSpace.REG)

View File

@@ -22,6 +22,17 @@ def add_ranges_to_store(ctx, x):
idxs = [UOp.range(r, next(ctx), AxisType.LOOP) for r in x.src[0].shape]
return UOp.store(x.src[0].index(*idxs), x.src[1].index(*idxs)).end(*idxs)
def lower_shaped_wmma(ctx, x):
dims, device, threads = x.arg
dtype_in, dtype_out = x.src[0].dtype.base, x.dtype
upcasts = [(s, UOp.range(s.shape[-1], next(ctx), axis_type=AxisType.UPCAST)) for s in x.src]
tc_upcast_axes = tuple(((u.arg[0], s.shape[-1]),) for s, u in upcasts)
name = f"WMMA_{'_'.join(map(str, dims))}_{dtype_in.name}_{dtype_out.name}"
wmma_arg = (name, dims, dtype_in, dtype_out, device, threads, tc_upcast_axes, ())
wmma = UOp(Ops.WMMA, dtype_out.vec(x.src[2].shape[-1]), tuple(s[u].contract(u) for s, u in upcasts), arg=wmma_arg)
tmp = UOp.placeholder((x.src[2].shape[-1],), dtype_out, slot=next(ctx), addrspace=AddrSpace.REG)
return tmp.after(UOp.group(*[tmp[e].store(wmma.gep(e)) for e in range(x.src[2].shape[-1])]))
pm_store_ranges = PatternMatcher([
(UPat(Ops.STORE, name="x"), add_ranges_to_store),
])
@@ -63,6 +74,8 @@ pm_mops = PatternMatcher([
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)
if a.src[0]._shape is not None and not any(s.op is Ops.STORE and s.src[0]._shape is not None for s in a.src[1:]) else None),
(UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
# lower SHAPED_WMMA to WMMA with CONTRACT/UNROLL
(UPat(Ops.SHAPED_WMMA, name="x"), lower_shaped_wmma),
])
# *****************

View File

@@ -53,7 +53,7 @@ class Ops(FastEnum):
# ** 4 -- math **
# tensor core math op, not elementwise
WMMA = auto()
WMMA = auto(); SHAPED_WMMA = auto()
# UnaryOps
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto()

View File

@@ -249,6 +249,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if len(self.src) >= 1: return tuple(self.src[0].sgep(i) for i in range(self.src[0].dtype.count))
return None
# SHAPED_WMMA output shape = accumulator shape (src[2])
case Ops.SHAPED_WMMA: return self.src[2]._shape
# passthrough ops
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
return self.src[0]._shape

View File

@@ -205,6 +205,10 @@ kernel_spec = PatternMatcher([
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# SHAPED_WMMA has <a, b, acc> with shaped inputs, arg=((M,N,K), device, threads), lowered to WMMA+CONTRACT later
(UPat(Ops.SHAPED_WMMA, src=(UPat(), UPat(), UPat()), name="x"),
lambda x: isinstance(x.arg, tuple) and len(x.arg) == 3 and isinstance(x.arg[0], tuple)),
# END can end multiple axes here
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True), lambda: True),

View File

@@ -44,7 +44,7 @@ from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphE
from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B", Ops.SHAPED_WMMA: "#FF5B5B",
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",