mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
move reshape to MathTraits (#13054)
* move reshape to MathTraits * confirm it works in amd_uop_matmul
This commit is contained in:
@@ -88,15 +88,15 @@ def hand_spec_kernel3():
|
||||
# ---------------------------
|
||||
# GLOBAL -> LOCAL (As, Bs)
|
||||
# ---------------------------
|
||||
b = b.reshape((N // BLOCK_K, BLOCK_K,
|
||||
N // BLOCK_N, BLOCK_N))
|
||||
b = b.reshape(N // BLOCK_K, BLOCK_K,
|
||||
N // BLOCK_N, BLOCK_N)
|
||||
i = UOp.range(BLOCK_N * BLOCK_K // THREADS_PER_BLOCK, 1)
|
||||
index_x = tid % BLOCK_N
|
||||
index_y = (tid // BLOCK_N) + (THREADS_PER_BLOCK // BLOCK_N) * i
|
||||
Bs_store = Bs[index_y, index_x].store(b[k_tile_range, index_y, blockIdx_x, index_x]).end(i)
|
||||
|
||||
a = a.reshape((N // BLOCK_M, BLOCK_M,
|
||||
N // BLOCK_K, BLOCK_K))
|
||||
a = a.reshape(N // BLOCK_M, BLOCK_M,
|
||||
N // BLOCK_K, BLOCK_K)
|
||||
i = UOp.range(BLOCK_M * BLOCK_K // THREADS_PER_BLOCK, 2)
|
||||
index_x = tid % BLOCK_K
|
||||
index_y = (tid // BLOCK_K) + (THREADS_PER_BLOCK // BLOCK_K) * i
|
||||
@@ -113,12 +113,12 @@ def hand_spec_kernel3():
|
||||
# ---------------------------
|
||||
# LOCAL -> REG (per-wave tiles)
|
||||
# ---------------------------
|
||||
Bs_view = Bs.reshape((BLOCK_K, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN))
|
||||
Bs_view = Bs.reshape(BLOCK_K, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)
|
||||
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 4)
|
||||
i = UOp.range(TN, 5)
|
||||
B_row = B_row[iterWaveN, i].set(Bs_view[k, waveIdx, iterWaveN, idxInWave, i], end=(iterWaveN, i))
|
||||
|
||||
As_view = As.reshape((BLOCK_K, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM))
|
||||
As_view = As.reshape(BLOCK_K, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM)
|
||||
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 6)
|
||||
i = UOp.range(TM, 7)
|
||||
A_col = A_col[iterWaveM, i].set(As_view[k, waveIdy, iterWaveM, idyInWave, i], end=(iterWaveM, i))
|
||||
@@ -139,8 +139,8 @@ def hand_spec_kernel3():
|
||||
# ---------------------------
|
||||
# REG -> GLOBAL (epilogue)
|
||||
# ---------------------------
|
||||
c = c.reshape((N//BLOCK_M, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM,
|
||||
N//BLOCK_N, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN))
|
||||
c = c.reshape(N//BLOCK_M, WAVES_IN_BLOCK_Y, ITERS_PER_WAVE_M, LANES_PER_WAVE_Y, TM,
|
||||
N//BLOCK_N, WAVES_IN_BLOCK_X, ITERS_PER_WAVE_N, LANES_PER_WAVE_X, TN)
|
||||
iterWaveM = UOp.range(ITERS_PER_WAVE_M, 1000)
|
||||
yt = UOp.range(TM, 1001)
|
||||
iterWaveN = UOp.range(ITERS_PER_WAVE_N, 1002)
|
||||
|
||||
Reference in New Issue
Block a user