mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
amd_uop_matmul more cleanups (#15240)
This commit is contained in:
@@ -29,11 +29,10 @@ assert WAVES_PER_BLOCK_M*REG_TILES_PER_WAVE_M*LANES_PER_WAVE_M*TM == BLOCK_M, "M
|
||||
assert WAVES_PER_BLOCK_N*REG_TILES_PER_WAVE_N*LANES_PER_WAVE_N*TN == BLOCK_N, "N reshape is wrong"
|
||||
|
||||
def rngs_for_shape(shape:tuple[sint, ...], rng:int, axis_type=AxisType.LOOP): return [UOp.range(s, rng+i, axis_type) for i,s in enumerate(shape)]
|
||||
def copy(dest:UOp, src:UOp, rng:int, set=False, upcast=False):
|
||||
def copy(dest:UOp, src:UOp, rng:int, upcast=False):
|
||||
assert dest.shape == src.shape
|
||||
rngs = rngs_for_shape(src.shape, rng, AxisType.UPCAST if upcast else AxisType.LOOP)
|
||||
copy = dest[*rngs].store(src[*rngs]).end(*rngs)
|
||||
return dest.after(copy) if set else copy
|
||||
return dest[*rngs].store(src[*rngs]).end(*rngs)
|
||||
|
||||
def hand_spec_kernel3(c:UOp, a:UOp, b:UOp) -> UOp:
|
||||
# ---------------------------
|
||||
@@ -77,21 +76,18 @@ def hand_spec_kernel3(c:UOp, a:UOp, b:UOp) -> UOp:
|
||||
# ---------------------------
|
||||
# LOCAL -> REG (per-wave tiles)
|
||||
# ---------------------------
|
||||
waveIdx = (tid // WARP_SIZE) % WAVES_PER_BLOCK_N
|
||||
waveIdy = (tid // WARP_SIZE) // WAVES_PER_BLOCK_N
|
||||
assert waveIdy.vmax+1 == WAVES_PER_BLOCK_M
|
||||
|
||||
laneIdx = (tid % WARP_SIZE) % LANES_PER_WAVE_N
|
||||
laneIdy = (tid % WARP_SIZE) // LANES_PER_WAVE_N
|
||||
assert laneIdy.vmax+1 == LANES_PER_WAVE_M
|
||||
warp, lane = tid // WARP_SIZE, tid % WARP_SIZE
|
||||
waveIdx, waveIdy = warp % WAVES_PER_BLOCK_N, warp // WAVES_PER_BLOCK_N
|
||||
laneIdx, laneIdy = lane % LANES_PER_WAVE_N, lane // LANES_PER_WAVE_N
|
||||
assert waveIdy.vmax+1 == WAVES_PER_BLOCK_M and laneIdy.vmax+1 == LANES_PER_WAVE_M
|
||||
|
||||
A_col = UOp.placeholder((REG_TILES_PER_WAVE_M, TM), dtypes.float, slot=0, addrspace=AddrSpace.REG)
|
||||
A_local_slice = A_local[k, :].reshape(WAVES_PER_BLOCK_M, REG_TILES_PER_WAVE_M, LANES_PER_WAVE_M, TM)[waveIdy, :, laneIdy, :]
|
||||
A_col = copy(A_col, A_local_slice, 300, set=True, upcast=True)
|
||||
A_col = A_col.after(copy(A_col, A_local_slice, 300, upcast=True))
|
||||
|
||||
B_row = UOp.placeholder((REG_TILES_PER_WAVE_N, TN), dtypes.float, slot=1, addrspace=AddrSpace.REG)
|
||||
B_local_slice = B_local[k, :].reshape(WAVES_PER_BLOCK_N, REG_TILES_PER_WAVE_N, LANES_PER_WAVE_N, TN)[waveIdx, :, laneIdx, :]
|
||||
B_row = copy(B_row, B_local_slice, 400, set=True, upcast=True)
|
||||
B_row = B_row.after(copy(B_row, B_local_slice, 400, upcast=True))
|
||||
|
||||
# ---------------------------
|
||||
# FMA: c_regs += A_col * B_row
|
||||
|
||||
Reference in New Issue
Block a user