mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
minimal vec in amd_copy_matmul (#15398)
* minimal vec in amd_copy_matmul
* unified
* unify
* reshape/permute
* cleanups
* simpler
* move index
* cleanups
* more shared
This commit is contained in:
@@ -291,7 +291,7 @@ def build_kernel(N, arch='gfx1100'):
|
||||
# MAIN GEMM LOOP
|
||||
# ===========================================================================
|
||||
|
||||
NO_DS, NO_GLOBAL = getenv("NO_DS", 0), getenv("NO_GLOBAL", 0)
|
||||
NO_ALU, NO_DS, NO_GLOBAL = getenv("NO_ALU", 0), getenv("NO_DS", 0), getenv("NO_GLOBAL", 0)
|
||||
|
||||
k.label('LOOP_INC')
|
||||
k.emit(s_add_i32(s[S_LOOP_CTR], s[S_LOOP_CTR], 8))
|
||||
@@ -350,10 +350,11 @@ def build_kernel(N, arch='gfx1100'):
|
||||
|
||||
# 64 dual FMACs
|
||||
k.waitcnt(lgkm=0)
|
||||
k.emit(s_clause(simm16=len(FMAC_PATTERN)-1))
|
||||
for i, (vdst_x, vdst_y, ax, bx, ay, by) in enumerate(FMAC_PATTERN):
|
||||
k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32,
|
||||
vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by]))
|
||||
if not NO_ALU:
|
||||
k.emit(s_clause(simm16=len(FMAC_PATTERN)-1))
|
||||
for i, (vdst_x, vdst_y, ax, bx, ay, by) in enumerate(FMAC_PATTERN):
|
||||
k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32,
|
||||
vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by]))
|
||||
|
||||
# wait for all global loads to finish
|
||||
# then sync the warp so it's safe to store local
|
||||
|
||||
@@ -12,26 +12,29 @@ BLOCK_K = getenv("BK", 16)
|
||||
assert N % BLOCK_N == 0 and M % BLOCK_M == 0 and K % BLOCK_K == 0
|
||||
|
||||
use_wmma = getenv("WMMA")
|
||||
|
||||
if use_wmma:
|
||||
WMMA_M, WMMA_N, WMMA_K = 16, 16, 16
|
||||
WAVES_M, WAVES_N = 2, 2
|
||||
LANES_PER_WAVE_M, LANES_PER_WAVE_N = 2, 16
|
||||
TM = BLOCK_M // (WAVES_M * WMMA_M) # 4
|
||||
TN = BLOCK_N // (WAVES_N * WMMA_N) # 4
|
||||
UNROLL_M, UNROLL_N = 1, 1
|
||||
|
||||
# wmma params
|
||||
WMMA_M, WMMA_N, WMMA_K = 16, 16, 16
|
||||
WMMA_ACC = WMMA_M // LANES_PER_WAVE_M
|
||||
else:
|
||||
UNROLL_M, UNROLL_N = 4, 4
|
||||
WAVES_M, WAVES_N = 4, 1
|
||||
LANES_PER_WAVE_M, LANES_PER_WAVE_N = 4, 8
|
||||
TM = BLOCK_M // (WAVES_M * LANES_PER_WAVE_M) # 8
|
||||
TN = BLOCK_N // (WAVES_N * LANES_PER_WAVE_N) # 16
|
||||
UNROLL_M, UNROLL_N = 4, 4
|
||||
|
||||
# WARP_SIZE * total waves
|
||||
THREADS_PER_BLOCK = WARP_SIZE * WAVES_M * WAVES_N
|
||||
|
||||
# accumulator size
|
||||
TM = BLOCK_M // (WAVES_M * LANES_PER_WAVE_M)
|
||||
TN = BLOCK_N // (WAVES_N * LANES_PER_WAVE_N)
|
||||
|
||||
def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
|
||||
wave_m = UOp.range(WAVES_M, 100, AxisType.LOCAL)
|
||||
wave_n = UOp.range(WAVES_N, 101, AxisType.LOCAL)
|
||||
wave_m = UOp.range(WAVES_M, 2, AxisType.LOCAL)
|
||||
wave_n = UOp.range(WAVES_N, 3, AxisType.LOCAL)
|
||||
lane = UOp.range(WARP_SIZE, -1, AxisType.WARP)
|
||||
tid = (wave_m * WAVES_N + wave_n) * WARP_SIZE + lane
|
||||
|
||||
@@ -43,7 +46,7 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
|
||||
|
||||
a = a.reshape(K // BLOCK_K, BLOCK_K, BLOCK_M)
|
||||
b = b.reshape(K // BLOCK_K, BLOCK_K, BLOCK_N)
|
||||
k_tile = UOp.range(K // BLOCK_K, 3, AxisType.REDUCE)
|
||||
k_tile = UOp.range(K // BLOCK_K, 100, AxisType.REDUCE)
|
||||
|
||||
# copy with transpose for wmma (input is k×spatial, LDS is spatial×k)
|
||||
A_copy = A_local.permute((1,0)) if use_wmma else A_local
|
||||
@@ -56,60 +59,52 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
|
||||
# -- COMPUTE --
|
||||
lane_m, lane_n = lane // LANES_PER_WAVE_N, lane % LANES_PER_WAVE_N
|
||||
|
||||
# accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
|
||||
acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
|
||||
acc = acc.after(acc.store(UOp.const(dtypes.float, 0).reshape((1,)*len(acc.shape)).expand(acc.shape)))
|
||||
|
||||
if use_wmma:
|
||||
# accumulator
|
||||
acc = UOp.placeholder((TM, TN), dtypes.float.vec(8), slot=2, addrspace=AddrSpace.REG)
|
||||
zi = UOp.range(TM, 200); zj = UOp.range(TN, 201)
|
||||
acc = acc[zi, zj].set(UOp.const(dtypes.float.vec(8), 0.0), end=(zi, zj))
|
||||
k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)
|
||||
tile_m = UOp.range(TM // WMMA_ACC, 200, AxisType.LOOP)
|
||||
tile_n = UOp.range(TN, 201, AxisType.LOOP)
|
||||
|
||||
A_tiles = A_local.reshape(WAVES_M, TM, WMMA_M, BLOCK_K // WMMA_K, WMMA_K)
|
||||
B_tiles = B_local.reshape(WAVES_N, TN, WMMA_N, BLOCK_K // WMMA_K, WMMA_K)
|
||||
|
||||
k = UOp.range(BLOCK_K // WMMA_K, 4, AxisType.REDUCE)
|
||||
tile_m = UOp.range(TM, 5, AxisType.LOOP)
|
||||
tile_n = UOp.range(TN, 6, AxisType.LOOP)
|
||||
acc_frag = acc.reshape(TM // WMMA_ACC, WMMA_ACC, TN).permute(0,2,1)[tile_m, tile_n]
|
||||
a_frag = A_local.reshape(WAVES_M, TM // WMMA_ACC, WMMA_M, BLOCK_K // WMMA_K, WMMA_K)[wave_m, tile_m, lane_n, k]
|
||||
b_frag = B_local.reshape(WAVES_N, TN, WMMA_N, BLOCK_K // WMMA_K, WMMA_K)[wave_n, tile_n, lane_n, k]
|
||||
|
||||
# TODO: remove unneeded CONTRACTS
|
||||
k_upcast_a = UOp.range(WMMA_K, 301, axis_type=AxisType.UPCAST)
|
||||
a_frag = A_tiles[wave_m, tile_m, lane_n, k, k_upcast_a].contract(k_upcast_a)
|
||||
k_upcast_b = UOp.range(WMMA_K, 311, axis_type=AxisType.UPCAST)
|
||||
b_frag = B_tiles[wave_n, tile_n, lane_n, k, k_upcast_b].contract(k_upcast_b)
|
||||
|
||||
acc_load = acc.after(k_tile, k, tile_m, tile_n)[tile_m, tile_n]
|
||||
acc_upcast = UOp.range(WMMA_ACC, 302, axis_type=AxisType.UPCAST)
|
||||
wmma_arg = ('WMMA_16_16_16_half_float', (16, 16, 16), dtypes.half, dtypes.float, 'AMD', 32,
|
||||
(((301, 16),), ((311, 16),), ()), ())
|
||||
out = UOp(Ops.WMMA, dtypes.float.vec(8), (a_frag, b_frag, acc_load), arg=wmma_arg)
|
||||
acc = acc.after(acc[tile_m, tile_n].store(out).end(tile_m, tile_n).end(k).barrier().end(k_tile))
|
||||
(((301, 16),), ((311, 16),), ((302, WMMA_ACC),)), ())
|
||||
out = UOp(Ops.WMMA, dtypes.float.vec(WMMA_ACC), (a_frag[k_upcast_a].contract(k_upcast_a),
|
||||
b_frag[k_upcast_b].contract(k_upcast_b),
|
||||
acc_frag.after(k)[acc_upcast].contract(acc_upcast)), arg=wmma_arg)
|
||||
|
||||
# store accumulator to output
|
||||
c = c.reshape(WAVES_M, TM, WMMA_M,
|
||||
WAVES_N, TN, WMMA_N)
|
||||
st_m = UOp.range(TM, 9, AxisType.LOOP)
|
||||
st_n = UOp.range(TN, 10, AxisType.LOOP)
|
||||
stores = [c[wave_m, st_m, e*2 + lane_m, wave_n, st_n, lane_n].store(acc[st_m, st_n].gep(e)) for e in range(8)]
|
||||
return UOp.group(*stores).end(st_m, st_n, wave_m, wave_n, lane)
|
||||
acc_store = UOp.group(*[acc_frag[e].store(out.gep(e)) for e in range(WMMA_ACC)]).end(tile_m, tile_n)
|
||||
else:
|
||||
# accumulator
|
||||
acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
|
||||
acc = acc.after(acc.store(UOp.const(dtypes.float, 0).reshape((1,)*len(acc.shape)).expand(acc.shape)))
|
||||
|
||||
# registers for LOCAL -> REG
|
||||
a_frag = UOp.placeholder((TM//UNROLL_M, UNROLL_M), dtypes.float, slot=0, addrspace=AddrSpace.REG)
|
||||
b_frag = UOp.placeholder((TN//UNROLL_N, UNROLL_N), dtypes.float, slot=1, addrspace=AddrSpace.REG)
|
||||
|
||||
k = UOp.range(BLOCK_K, 4, AxisType.REDUCE)
|
||||
k = UOp.range(BLOCK_K, 101, AxisType.REDUCE)
|
||||
a_frag = a_frag.after(a_frag.store(A_local[k].reshape(WAVES_M, TM//UNROLL_M, LANES_PER_WAVE_M, UNROLL_M)[wave_m, :, lane_m, :]))
|
||||
b_frag = b_frag.after(b_frag.store(B_local[k].reshape(WAVES_N, TN//UNROLL_N, LANES_PER_WAVE_N, UNROLL_N)[wave_n, :, lane_n, :]))
|
||||
|
||||
# FMA
|
||||
a_frag = a_frag.reshape(TM, 1).expand(TM, TN)
|
||||
b_frag = b_frag.reshape(1, TN).expand(TM, TN)
|
||||
acc = acc.after(acc.store(acc.after(k) + (a_frag * b_frag)).end(k).barrier().end(k_tile))
|
||||
acc_store = acc.store(acc.after(k) + (a_frag * b_frag))
|
||||
|
||||
# store accumulator to output
|
||||
c = c.reshape(WAVES_M, TM//UNROLL_M, LANES_PER_WAVE_M, UNROLL_M,
|
||||
WAVES_N, TN//UNROLL_N, LANES_PER_WAVE_N, UNROLL_N)
|
||||
c = c.permute((0,4,2,6, 1,3,5,7)).reshape(THREADS_PER_BLOCK, TM, TN)
|
||||
return c[tid].store(acc).end(wave_m, wave_n, lane)
|
||||
# store accumulator and loop
|
||||
acc = acc.after(acc_store.end(k).barrier().end(k_tile))
|
||||
|
||||
# store accumulator to output (unified)
|
||||
c = c.reshape(WAVES_M, TM//UNROLL_M, LANES_PER_WAVE_M, UNROLL_M,
|
||||
WAVES_N, TN//UNROLL_N, LANES_PER_WAVE_N, UNROLL_N)
|
||||
c = c.permute((0,4,2,6, 1,3,5,7)).reshape(THREADS_PER_BLOCK, TM, TN)
|
||||
return c[tid].store(acc).end(wave_m, wave_n, lane)
|
||||
|
||||
def amd_copy_matmul(c:UOp, a:UOp, b:UOp) -> UOp:
|
||||
block_id_m = UOp.range(M // BLOCK_M, 0, AxisType.GLOBAL)
|
||||
|
||||
@@ -56,6 +56,16 @@ def do_expand(root:UOp):
|
||||
# repeat the arg
|
||||
new_srcs.append(src.broadcast(expand_sz))
|
||||
|
||||
# for non-PtrDType INDEX on REG buffers, expand into individual scalar INDEXes instead of one vectorized INDEX
|
||||
# this avoids creating a VECTORIZE of REG pointers which the devectorizer can't resolve
|
||||
if root.op is Ops.INDEX and not isinstance(root.dtype, PtrDType) and \
|
||||
isinstance(root.src[0].dtype, PtrDType) and root.src[0].dtype.addrspace == AddrSpace.REG:
|
||||
idxs = []
|
||||
for j in range(expand_sz):
|
||||
idx_srcs = tuple(s.gep(j) if isinstance(s.dtype, PtrDType) or s.dtype.count > 1 else s for s in new_srcs)
|
||||
idxs.append(UOp(Ops.INDEX, root.dtype, idx_srcs, root.arg))
|
||||
return UOp(Ops.UNROLL, root.dtype, (UOp(Ops.VECTORIZE, root.dtype.vec(expand_sz), tuple(idxs)),), expand_args)
|
||||
|
||||
new_arg = root.arg
|
||||
if root.op is Ops.GEP:
|
||||
assert root.dtype.count == 1
|
||||
|
||||
@@ -58,8 +58,8 @@ pm_mops = PatternMatcher([
|
||||
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
|
||||
lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)
|
||||
if r.src[0]._shape is not None and len(idx.src[1:]) == len(r.shape) else None),
|
||||
# move movement ops after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
|
||||
(UPat(GroupOp.Movement, name="r").after(name="a", allow_any_len=True),
|
||||
# move movement ops and INDEX after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
|
||||
(UPat(GroupOp.Movement|{Ops.INDEX}, name="r").after(name="a", allow_any_len=True),
|
||||
lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)
|
||||
if a.src[0]._shape is not None and not any(s.op is Ops.STORE and s.src[0]._shape is not None for s in a.src[1:]) else None),
|
||||
(UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
|
||||
|
||||
@@ -77,9 +77,9 @@ movement_ops = PatternMatcher([
|
||||
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.weakint), lambda: True),
|
||||
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.weakint), lambda: True),
|
||||
|
||||
# AFTER on Movement Op, BUFFER, COPY, or BITCAST
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.BUFFER, Ops.BITCAST, Ops.COPY})),), allow_any_len=True),
|
||||
lambda: True),
|
||||
# AFTER on Movement Op, INDEX, BUFFER, COPY, or BITCAST
|
||||
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.INDEX, Ops.MULTI, Ops.CONTIGUOUS, Ops.BUFFER, Ops.BITCAST, Ops.COPY})),),
|
||||
allow_any_len=True), lambda: True),
|
||||
])
|
||||
|
||||
_tensor_spec = PatternMatcher([
|
||||
|
||||
Reference in New Issue
Block a user