mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
define reg doesn't have init anymore (#11365)
* define reg doesn't have init anymore
* remove that
* no special logic for dr
* fix amd uop matmul
This commit is contained in:
@@ -25,9 +25,8 @@ def hl_spec_kernel3():
|
||||
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)
|
||||
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0)
|
||||
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
|
||||
junk = UOp.const(dtypes.float, 0)
|
||||
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), src=(junk,), arg=0)
|
||||
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), src=(junk,), arg=1)
|
||||
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0)
|
||||
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1)
|
||||
|
||||
# shape buffers. TODO: permutes
|
||||
full_shape = (N//BM, nbIterWaveM, BM//(nbIterWaveM * TM), TM, N//BN, nbIterWaveN, BN//(nbIterWaveN * TN), TN, N//BK, BK)
|
||||
@@ -91,14 +90,16 @@ def hand_spec_kernel3():
|
||||
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1)
|
||||
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)
|
||||
|
||||
junk = UOp.const(dtypes.float, 0) # TODO: remove this
|
||||
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), src=(junk,), arg=0)
|
||||
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), src=(junk,), arg=1)
|
||||
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0)
|
||||
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1)
|
||||
|
||||
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0)
|
||||
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
|
||||
|
||||
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), src=(junk,), arg=2)
|
||||
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2)
|
||||
|
||||
i = UOp.range(dtypes.int, c_regs.dtype.size, 16)
|
||||
init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0), i)
|
||||
|
||||
kId_range = UOp.range(dtypes.int, N//BK, 0)
|
||||
kId = kId_range*BK
|
||||
@@ -137,7 +138,7 @@ def hand_spec_kernel3():
|
||||
x = iterWaveN * TN + xt
|
||||
y = iterWaveM * TM + yt
|
||||
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
|
||||
sink = c_regs_idx.store(c_regs_idx.load() + A_col[y].load(A_col_store) * B_row[x].load(B_row_store),
|
||||
sink = c_regs_idx.store(c_regs_idx.load(init_store) + A_col[y].load(A_col_store) * B_row[x].load(B_row_store),
|
||||
iterWaveM, iterWaveN, yt, xt, k, kId_range)
|
||||
|
||||
# store c_regs into c
|
||||
@@ -148,7 +149,8 @@ def hand_spec_kernel3():
|
||||
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
|
||||
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
|
||||
indexC = N * (yOut + yt) + xOut + xt
|
||||
sink = c[indexC].store(c_regs[TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)].load(sink), iterWaveM, iterWaveN, yt, xt)
|
||||
sink = c[indexC].store(c_regs[TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)].load(sink),
|
||||
iterWaveM, iterWaveN, yt, xt)
|
||||
|
||||
return sink.sink(arg=KernelInfo(name="tinygemm"))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user