diff --git a/extra/gemm/amd_asm_matmul.py b/extra/gemm/amd_asm_matmul.py
index 4261396763..a97a7a0a84 100644
--- a/extra/gemm/amd_asm_matmul.py
+++ b/extra/gemm/amd_asm_matmul.py
@@ -291,7 +291,7 @@ def build_kernel(N, arch='gfx1100'):
   # MAIN GEMM LOOP
   # ===========================================================================
 
-  NO_DS, NO_GLOBAL = getenv("NO_DS", 0), getenv("NO_GLOBAL", 0)
+  NO_ALU, NO_DS, NO_GLOBAL = getenv("NO_ALU", 0), getenv("NO_DS", 0), getenv("NO_GLOBAL", 0)
 
   k.label('LOOP_INC')
   k.emit(s_add_i32(s[S_LOOP_CTR], s[S_LOOP_CTR], 8))
@@ -350,10 +350,11 @@ def build_kernel(N, arch='gfx1100'):
 
   # 64 dual FMACs
   k.waitcnt(lgkm=0)
-  k.emit(s_clause(simm16=len(FMAC_PATTERN)-1))
-  for i, (vdst_x, vdst_y, ax, bx, ay, by) in enumerate(FMAC_PATTERN):
-    k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32,
-                vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by]))
+  if not NO_ALU:
+    k.emit(s_clause(simm16=len(FMAC_PATTERN)-1))
+    for i, (vdst_x, vdst_y, ax, bx, ay, by) in enumerate(FMAC_PATTERN):
+      k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32,
+                  vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by]))
 
   # wait for all global loads to finish
   # then sync the warp so it's safe to store local
diff --git a/extra/gemm/amd_copy_matmul.py b/extra/gemm/amd_copy_matmul.py
index f1d6b39b84..85e6479447 100644
--- a/extra/gemm/amd_copy_matmul.py
+++ b/extra/gemm/amd_copy_matmul.py
@@ -12,26 +12,29 @@ BLOCK_K = getenv("BK", 16)
 assert N % BLOCK_N == 0 and M % BLOCK_M == 0 and K % BLOCK_K == 0
 
 use_wmma = getenv("WMMA")
-
 if use_wmma:
-  WMMA_M, WMMA_N, WMMA_K = 16, 16, 16
   WAVES_M, WAVES_N = 2, 2
   LANES_PER_WAVE_M, LANES_PER_WAVE_N = 2, 16
-  TM = BLOCK_M // (WAVES_M * WMMA_M)  # 4
-  TN = BLOCK_N // (WAVES_N * WMMA_N)  # 4
+  UNROLL_M, UNROLL_N = 1, 1
+
+  # wmma params
+  WMMA_M, WMMA_N, WMMA_K = 16, 16, 16
+  WMMA_ACC = WMMA_M // LANES_PER_WAVE_M
 else:
-  UNROLL_M, UNROLL_N = 4, 4
   WAVES_M, WAVES_N = 4, 1
   LANES_PER_WAVE_M, LANES_PER_WAVE_N = 4, 8
-  TM = BLOCK_M // (WAVES_M * LANES_PER_WAVE_M)  # 8
-  TN = BLOCK_N // (WAVES_N * LANES_PER_WAVE_N)  # 16
+  UNROLL_M, UNROLL_N = 4, 4
 
 # WARP_SIZE * total waves
 THREADS_PER_BLOCK = WARP_SIZE * WAVES_M * WAVES_N
 
+# accumulator size
+TM = BLOCK_M // (WAVES_M * LANES_PER_WAVE_M)
+TN = BLOCK_N // (WAVES_N * LANES_PER_WAVE_N)
+
 def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
-  wave_m = UOp.range(WAVES_M, 100, AxisType.LOCAL)
-  wave_n = UOp.range(WAVES_N, 101, AxisType.LOCAL)
+  wave_m = UOp.range(WAVES_M, 2, AxisType.LOCAL)
+  wave_n = UOp.range(WAVES_N, 3, AxisType.LOCAL)
   lane = UOp.range(WARP_SIZE, -1, AxisType.WARP)
   tid = (wave_m * WAVES_N + wave_n) * WARP_SIZE + lane
 
@@ -43,7 +46,7 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
   a = a.reshape(K // BLOCK_K, BLOCK_K, BLOCK_M)
   b = b.reshape(K // BLOCK_K, BLOCK_K, BLOCK_N)
 
-  k_tile = UOp.range(K // BLOCK_K, 3, AxisType.REDUCE)
+  k_tile = UOp.range(K // BLOCK_K, 100, AxisType.REDUCE)
 
   # copy with transpose for wmma (input is k×spatial, LDS is spatial×k)
   A_copy = A_local.permute((1,0)) if use_wmma else A_local
@@ -56,60 +59,52 @@ def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
   # -- COMPUTE --
   lane_m, lane_n = lane // LANES_PER_WAVE_N, lane % LANES_PER_WAVE_N
 
+  # accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
+  acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
+  acc = acc.after(acc.store(UOp.const(dtypes.float, 0).reshape((1,)*len(acc.shape)).expand(acc.shape)))
+
   if use_wmma:
-    # accumulator
-    acc = UOp.placeholder((TM, TN), dtypes.float.vec(8), slot=2, addrspace=AddrSpace.REG)
-    zi = UOp.range(TM, 200); zj = UOp.range(TN, 201)
-    acc = acc[zi, zj].set(UOp.const(dtypes.float.vec(8), 0.0), end=(zi, zj))
+    k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)
+    tile_m = UOp.range(TM // WMMA_ACC, 200, AxisType.LOOP)
+    tile_n = UOp.range(TN, 201, AxisType.LOOP)
 
-    A_tiles = A_local.reshape(WAVES_M, TM, WMMA_M, BLOCK_K // WMMA_K, WMMA_K)
-    B_tiles = B_local.reshape(WAVES_N, TN, WMMA_N, BLOCK_K // WMMA_K, WMMA_K)
-
-    k = UOp.range(BLOCK_K // WMMA_K, 4, AxisType.REDUCE)
-    tile_m = UOp.range(TM, 5, AxisType.LOOP)
-    tile_n = UOp.range(TN, 6, AxisType.LOOP)
+    acc_frag = acc.reshape(TM // WMMA_ACC, WMMA_ACC, TN).permute(0,2,1)[tile_m, tile_n]
+    a_frag = A_local.reshape(WAVES_M, TM // WMMA_ACC, WMMA_M, BLOCK_K // WMMA_K, WMMA_K)[wave_m, tile_m, lane_n, k]
+    b_frag = B_local.reshape(WAVES_N, TN, WMMA_N, BLOCK_K // WMMA_K, WMMA_K)[wave_n, tile_n, lane_n, k]
 
+    # TODO: remove unneeded CONTRACTS
     k_upcast_a = UOp.range(WMMA_K, 301, axis_type=AxisType.UPCAST)
-    a_frag = A_tiles[wave_m, tile_m, lane_n, k, k_upcast_a].contract(k_upcast_a)
     k_upcast_b = UOp.range(WMMA_K, 311, axis_type=AxisType.UPCAST)
-    b_frag = B_tiles[wave_n, tile_n, lane_n, k, k_upcast_b].contract(k_upcast_b)
-
-    acc_load = acc.after(k_tile, k, tile_m, tile_n)[tile_m, tile_n]
+    acc_upcast = UOp.range(WMMA_ACC, 302, axis_type=AxisType.UPCAST)
     wmma_arg = ('WMMA_16_16_16_half_float', (16, 16, 16), dtypes.half, dtypes.float, 'AMD', 32,
-                (((301, 16),), ((311, 16),), ()), ())
-    out = UOp(Ops.WMMA, dtypes.float.vec(8), (a_frag, b_frag, acc_load), arg=wmma_arg)
-    acc = acc.after(acc[tile_m, tile_n].store(out).end(tile_m, tile_n).end(k).barrier().end(k_tile))
+                (((301, 16),), ((311, 16),), ((302, WMMA_ACC),)), ())
+    out = UOp(Ops.WMMA, dtypes.float.vec(WMMA_ACC), (a_frag[k_upcast_a].contract(k_upcast_a),
+                                                     b_frag[k_upcast_b].contract(k_upcast_b),
+                                                     acc_frag.after(k)[acc_upcast].contract(acc_upcast)), arg=wmma_arg)
 
-    # store accumulator to output
-    c = c.reshape(WAVES_M, TM, WMMA_M,
-                  WAVES_N, TN, WMMA_N)
-    st_m = UOp.range(TM, 9, AxisType.LOOP)
-    st_n = UOp.range(TN, 10, AxisType.LOOP)
-    stores = [c[wave_m, st_m, e*2 + lane_m, wave_n, st_n, lane_n].store(acc[st_m, st_n].gep(e)) for e in range(8)]
-    return UOp.group(*stores).end(st_m, st_n, wave_m, wave_n, lane)
+    acc_store = UOp.group(*[acc_frag[e].store(out.gep(e)) for e in range(WMMA_ACC)]).end(tile_m, tile_n)
   else:
-    # accumulator
-    acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
-    acc = acc.after(acc.store(UOp.const(dtypes.float, 0).reshape((1,)*len(acc.shape)).expand(acc.shape)))
-
     # registers for LOCAL -> REG
     a_frag = UOp.placeholder((TM//UNROLL_M, UNROLL_M), dtypes.float, slot=0, addrspace=AddrSpace.REG)
     b_frag = UOp.placeholder((TN//UNROLL_N, UNROLL_N), dtypes.float, slot=1, addrspace=AddrSpace.REG)
 
-    k = UOp.range(BLOCK_K, 4, AxisType.REDUCE)
+    k = UOp.range(BLOCK_K, 101, AxisType.REDUCE)
     a_frag = a_frag.after(a_frag.store(A_local[k].reshape(WAVES_M, TM//UNROLL_M, LANES_PER_WAVE_M, UNROLL_M)[wave_m, :, lane_m, :]))
     b_frag = b_frag.after(b_frag.store(B_local[k].reshape(WAVES_N, TN//UNROLL_N, LANES_PER_WAVE_N, UNROLL_N)[wave_n, :, lane_n, :]))
 
     # FMA
     a_frag = a_frag.reshape(TM, 1).expand(TM, TN)
     b_frag = b_frag.reshape(1, TN).expand(TM, TN)
-    acc = acc.after(acc.store(acc.after(k) + (a_frag * b_frag)).end(k).barrier().end(k_tile))
+    acc_store = acc.store(acc.after(k) + (a_frag * b_frag))
 
-    # store accumulator to output
-    c = c.reshape(WAVES_M, TM//UNROLL_M, LANES_PER_WAVE_M, UNROLL_M,
-                  WAVES_N, TN//UNROLL_N, LANES_PER_WAVE_N, UNROLL_N)
-    c = c.permute((0,4,2,6, 1,3,5,7)).reshape(THREADS_PER_BLOCK, TM, TN)
-    return c[tid].store(acc).end(wave_m, wave_n, lane)
+  # store accumulator and loop
+  acc = acc.after(acc_store.end(k).barrier().end(k_tile))
+
+  # store accumulator to output (unified)
+  c = c.reshape(WAVES_M, TM//UNROLL_M, LANES_PER_WAVE_M, UNROLL_M,
+                WAVES_N, TN//UNROLL_N, LANES_PER_WAVE_N, UNROLL_N)
+  c = c.permute((0,4,2,6, 1,3,5,7)).reshape(THREADS_PER_BLOCK, TM, TN)
+  return c[tid].store(acc).end(wave_m, wave_n, lane)
 
 def amd_copy_matmul(c:UOp, a:UOp, b:UOp) -> UOp:
   block_id_m = UOp.range(M // BLOCK_M, 0, AxisType.GLOBAL)
diff --git a/tinygrad/codegen/late/expander.py b/tinygrad/codegen/late/expander.py
index 9ccd700723..4ae8c675b5 100644
--- a/tinygrad/codegen/late/expander.py
+++ b/tinygrad/codegen/late/expander.py
@@ -56,6 +56,16 @@ def do_expand(root:UOp):
       # repeat the arg
       new_srcs.append(src.broadcast(expand_sz))
 
+  # for non-PtrDType INDEX on REG buffers, expand into individual scalar INDEXes instead of one vectorized INDEX
+  # this avoids creating a VECTORIZE of REG pointers which the devectorizer can't resolve
+  if root.op is Ops.INDEX and not isinstance(root.dtype, PtrDType) and \
+     isinstance(root.src[0].dtype, PtrDType) and root.src[0].dtype.addrspace == AddrSpace.REG:
+    idxs = []
+    for j in range(expand_sz):
+      idx_srcs = tuple(s.gep(j) if isinstance(s.dtype, PtrDType) or s.dtype.count > 1 else s for s in new_srcs)
+      idxs.append(UOp(Ops.INDEX, root.dtype, idx_srcs, root.arg))
+    return UOp(Ops.UNROLL, root.dtype, (UOp(Ops.VECTORIZE, root.dtype.vec(expand_sz), tuple(idxs)),), expand_args)
+
   new_arg = root.arg
   if root.op is Ops.GEP:
     assert root.dtype.count == 1
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 69719445ff..d3ce0fcdcf 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -58,8 +58,8 @@ pm_mops = PatternMatcher([
   (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)
      if r.src[0]._shape is not None and len(idx.src[1:]) == len(r.shape) else None),
-  # move movement ops after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
-  (UPat(GroupOp.Movement, name="r").after(name="a", allow_any_len=True),
+  # move movement ops and INDEX after AFTER (but not when AFTER has a raw STORE with shaped children — from replace_contig_with_store_after)
+  (UPat(GroupOp.Movement|{Ops.INDEX}, name="r").after(name="a", allow_any_len=True),
    lambda r,a: UOp(r.op, r.dtype, (a.replace(src=(r.src[0],)+a.src[1:]),)+r.src[1:], r.arg)
      if a.src[0]._shape is not None and not any(s.op is Ops.STORE and s.src[0]._shape is not None for s in a.src[1:]) else None),
   (UPat(GroupOp.Movement, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index f67808b192..bde80ab2f7 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -77,9 +77,9 @@ movement_ops = PatternMatcher([
   (UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.weakint), lambda: True),
   (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.weakint), lambda: True),
 
-  # AFTER on Movement Op, BUFFER, COPY, or BITCAST
-  (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.BUFFER, Ops.BITCAST, Ops.COPY})),), allow_any_len=True),
-   lambda: True),
+  # AFTER on Movement Op, INDEX, BUFFER, COPY, or BITCAST
+  (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.INDEX, Ops.MULTI, Ops.CONTIGUOUS, Ops.BUFFER, Ops.BITCAST, Ops.COPY})),),
+    allow_any_len=True), lambda: True),
 ])
 
 _tensor_spec = PatternMatcher([