mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
remove dtype from range, it will be dtypes.index soon [pr] (#11914)
* remove dtype from range, it will be dtypes.index soon [pr] * a few more
This commit is contained in:
@@ -65,7 +65,7 @@ def top_spec_kernel3():
|
||||
c = a@b
|
||||
sink = c.schedule()[-1].ast
|
||||
L = 16
|
||||
sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(dtypes.int, N//BM, 0), 2:UOp.range(dtypes.int, N//BN, 1)})
|
||||
sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(N//BM, 0), 2:UOp.range(N//BN, 1)})
|
||||
sink = graph_rewrite(sink, view_left+pm)
|
||||
axis_types = (AxisType.GLOBAL, AxisType.LOCAL, AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE)
|
||||
return sink.replace(arg=KernelInfo(name="top_"+to_colored(sink.full_shape, axis_types), axis_types=axis_types))
|
||||
@@ -186,7 +186,7 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||
|
||||
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2)
|
||||
|
||||
i = UOp.range(dtypes.int, c_regs.dtype.size, 16)
|
||||
i = UOp.range(c_regs.dtype.size, 16)
|
||||
init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0), i)
|
||||
|
||||
if kernel4:
|
||||
@@ -197,53 +197,53 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||
kId = 0
|
||||
|
||||
# load from globals into locals
|
||||
i = UOp.range(dtypes.int, nbReadsB, 0)
|
||||
i = UOp.range(nbReadsB, 0)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId
|
||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
||||
|
||||
i = UOp.range(dtypes.int, nbReadsA, 1)
|
||||
i = UOp.range(nbReadsA, 1)
|
||||
index_x = rAIdx + kId
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
||||
|
||||
# iterate over the middle chunk
|
||||
kId_range = UOp.range(dtypes.int, N//BK-1, 2)
|
||||
kId_range = UOp.range(N//BK-1, 2)
|
||||
kId = kId_range*BK
|
||||
|
||||
barrier = UOp.barrier(As_store, Bs_store)
|
||||
|
||||
# load from globals into registers (next round)
|
||||
i = UOp.range(dtypes.int, nbReadsB, 3)
|
||||
i = UOp.range(nbReadsB, 3)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId + BK
|
||||
regB_store = regB[i].store(b[N * index_y + index_x].load(), i)
|
||||
|
||||
i = UOp.range(dtypes.int, nbReadsA, 4)
|
||||
i = UOp.range(nbReadsA, 4)
|
||||
index_x = rAIdx + kId + BK
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
regA_store = regA[i].store(a[N * index_y + index_x].load(), i)
|
||||
|
||||
def inner_loop(first_range, inp_dep=()):
|
||||
# inner unroll
|
||||
k = UOp.range(dtypes.int, BK, first_range+0)
|
||||
k = UOp.range(BK, first_range+0)
|
||||
|
||||
# load from locals into registers
|
||||
iterWave = UOp.range(dtypes.int, nbIterWaveN, first_range+1)
|
||||
i = UOp.range(dtypes.int, TN, first_range+2)
|
||||
iterWave = UOp.range(nbIterWaveN, first_range+1)
|
||||
i = UOp.range(TN, first_range+2)
|
||||
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(*inp_dep), iterWave, i)
|
||||
|
||||
iterWave = UOp.range(dtypes.int, nbIterWaveM, first_range+3)
|
||||
i = UOp.range(dtypes.int, TM, first_range+4)
|
||||
iterWave = UOp.range(nbIterWaveM, first_range+3)
|
||||
i = UOp.range(TM, first_range+4)
|
||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(*inp_dep), iterWave, i)
|
||||
|
||||
# do the GEMM math
|
||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, first_range+5)
|
||||
yt = UOp.range(dtypes.int, TM, first_range+6)
|
||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, first_range+7)
|
||||
xt = UOp.range(dtypes.int, TN, first_range+8)
|
||||
iterWaveM = UOp.range(nbIterWaveM, first_range+5)
|
||||
yt = UOp.range(TM, first_range+6)
|
||||
iterWaveN = UOp.range(nbIterWaveN, first_range+7)
|
||||
xt = UOp.range(TN, first_range+8)
|
||||
x = iterWaveN * TN + xt
|
||||
y = iterWaveM * TM + yt
|
||||
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
||||
@@ -256,12 +256,12 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||
sink = inner_loop(5, (barrier, regB_store, regA_store)).barrier()
|
||||
|
||||
# load from registers into locals
|
||||
i = UOp.range(dtypes.int, nbReadsB, 14)
|
||||
i = UOp.range(nbReadsB, 14)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId + BK
|
||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(regB[i].load(sink), i, kId_range)
|
||||
|
||||
i = UOp.range(dtypes.int, nbReadsA, 15)
|
||||
i = UOp.range(nbReadsA, 15)
|
||||
index_x = rAIdx + kId + BK
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(regA[i].load(sink), i, kId_range)
|
||||
@@ -269,40 +269,40 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||
# final iteration without the copy
|
||||
sink = inner_loop(16, (UOp.barrier(Bs_store, As_store),))
|
||||
else:
|
||||
kId_range = UOp.range(dtypes.int, N//BK, 0)
|
||||
kId_range = UOp.range(N//BK, 0)
|
||||
kId = kId_range*BK
|
||||
|
||||
# load from globals into locals
|
||||
i = UOp.range(dtypes.int, nbReadsB, 1)
|
||||
i = UOp.range(nbReadsB, 1)
|
||||
index_x = BN * blockIdx_x + rBIdx
|
||||
index_y = rBIdy + i * strideReadB + kId
|
||||
Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
|
||||
|
||||
i = UOp.range(dtypes.int, nbReadsA, 2)
|
||||
i = UOp.range(nbReadsA, 2)
|
||||
index_x = rAIdx + kId
|
||||
index_y = BM * blockIdx_y + rAIdy + i * strideReadA
|
||||
As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i)
|
||||
|
||||
barrier = UOp.barrier(As_store, Bs_store)
|
||||
|
||||
k = UOp.range(dtypes.int, BK, 3)
|
||||
k = UOp.range(BK, 3)
|
||||
|
||||
# load from locals into registers
|
||||
iterWave = UOp.range(dtypes.int, nbIterWaveN, 4)
|
||||
i = UOp.range(dtypes.int, TN, 5)
|
||||
iterWave = UOp.range(nbIterWaveN, 4)
|
||||
i = UOp.range(TN, 5)
|
||||
index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
|
||||
B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i)
|
||||
|
||||
iterWave = UOp.range(dtypes.int, nbIterWaveM, 6)
|
||||
i = UOp.range(dtypes.int, TM, 7)
|
||||
iterWave = UOp.range(nbIterWaveM, 6)
|
||||
i = UOp.range(TM, 7)
|
||||
index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
|
||||
A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(barrier), iterWave, i)
|
||||
|
||||
# do the GEMM math
|
||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
|
||||
yt = UOp.range(dtypes.int, TM, 9)
|
||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 10)
|
||||
xt = UOp.range(dtypes.int, TN, 12)
|
||||
iterWaveM = UOp.range(nbIterWaveM, 8)
|
||||
yt = UOp.range(TM, 9)
|
||||
iterWaveN = UOp.range(nbIterWaveN, 10)
|
||||
xt = UOp.range(TN, 12)
|
||||
x = iterWaveN * TN + xt
|
||||
y = iterWaveM * TM + yt
|
||||
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
||||
@@ -310,10 +310,10 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)):
|
||||
iterWaveM, iterWaveN, yt, xt, k, kId_range)
|
||||
|
||||
# store c_regs into c
|
||||
iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 1000)
|
||||
yt = UOp.range(dtypes.int, TM, 1001)
|
||||
iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 1002)
|
||||
xt = UOp.range(dtypes.int, TN, 1003)
|
||||
iterWaveM = UOp.range(nbIterWaveM, 1000)
|
||||
yt = UOp.range(TM, 1001)
|
||||
iterWaveN = UOp.range(nbIterWaveN, 1002)
|
||||
xt = UOp.range(TN, 1003)
|
||||
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
|
||||
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
|
||||
indexC = N * (yOut + yt) + xOut + xt
|
||||
|
||||
Reference in New Issue
Block a user