diff --git a/extra/gemm/amd_uop_matmul.py b/extra/gemm/amd_uop_matmul.py index 7959c4d909..de88871f67 100644 --- a/extra/gemm/amd_uop_matmul.py +++ b/extra/gemm/amd_uop_matmul.py @@ -29,11 +29,10 @@ assert WAVES_PER_BLOCK_M*REG_TILES_PER_WAVE_M*LANES_PER_WAVE_M*TM == BLOCK_M, "M assert WAVES_PER_BLOCK_N*REG_TILES_PER_WAVE_N*LANES_PER_WAVE_N*TN == BLOCK_N, "N reshape is wrong" def rngs_for_shape(shape:tuple[sint, ...], rng:int, axis_type=AxisType.LOOP): return [UOp.range(s, rng+i, axis_type) for i,s in enumerate(shape)] -def copy(dest:UOp, src:UOp, rng:int, set=False, upcast=False): +def copy(dest:UOp, src:UOp, rng:int, upcast=False): assert dest.shape == src.shape rngs = rngs_for_shape(src.shape, rng, AxisType.UPCAST if upcast else AxisType.LOOP) - copy = dest[*rngs].store(src[*rngs]).end(*rngs) - return dest.after(copy) if set else copy + return dest[*rngs].store(src[*rngs]).end(*rngs) def hand_spec_kernel3(c:UOp, a:UOp, b:UOp) -> UOp: # --------------------------- @@ -77,21 +76,18 @@ def hand_spec_kernel3(c:UOp, a:UOp, b:UOp) -> UOp: # --------------------------- # LOCAL -> REG (per-wave tiles) # --------------------------- - waveIdx = (tid // WARP_SIZE) % WAVES_PER_BLOCK_N - waveIdy = (tid // WARP_SIZE) // WAVES_PER_BLOCK_N - assert waveIdy.vmax+1 == WAVES_PER_BLOCK_M - - laneIdx = (tid % WARP_SIZE) % LANES_PER_WAVE_N - laneIdy = (tid % WARP_SIZE) // LANES_PER_WAVE_N - assert laneIdy.vmax+1 == LANES_PER_WAVE_M + warp, lane = tid // WARP_SIZE, tid % WARP_SIZE + waveIdx, waveIdy = warp % WAVES_PER_BLOCK_N, warp // WAVES_PER_BLOCK_N + laneIdx, laneIdy = lane % LANES_PER_WAVE_N, lane // LANES_PER_WAVE_N + assert waveIdy.vmax+1 == WAVES_PER_BLOCK_M and laneIdy.vmax+1 == LANES_PER_WAVE_M A_col = UOp.placeholder((REG_TILES_PER_WAVE_M, TM), dtypes.float, slot=0, addrspace=AddrSpace.REG) A_local_slice = A_local[k, :].reshape(WAVES_PER_BLOCK_M, REG_TILES_PER_WAVE_M, LANES_PER_WAVE_M, TM)[waveIdy, :, laneIdy, :] - A_col = copy(A_col, A_local_slice, 300, set=True, upcast=True) + A_col = A_col.after(copy(A_col, A_local_slice, 300, upcast=True)) B_row = UOp.placeholder((REG_TILES_PER_WAVE_N, TN), dtypes.float, slot=1, addrspace=AddrSpace.REG) B_local_slice = B_local[k, :].reshape(WAVES_PER_BLOCK_N, REG_TILES_PER_WAVE_N, LANES_PER_WAVE_N, TN)[waveIdx, :, laneIdx, :] - B_row = copy(B_row, B_local_slice, 400, set=True, upcast=True) + B_row = B_row.after(copy(B_row, B_local_slice, 400, upcast=True)) # --------------------------- # FMA: c_regs += A_col * B_row