define reg doesn't have init anymore (#11365)

* define reg doesn't have init anymore

* remove that

* no special logic for dr

* fix amd uop matmul
This commit is contained in:
George Hotz
2025-07-24 19:15:49 -07:00
committed by GitHub
parent 9da3f72495
commit 490a93902c
8 changed files with 18 additions and 33 deletions

View File

@@ -25,9 +25,8 @@ def hl_spec_kernel3():
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0)
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
junk = UOp.const(dtypes.float, 0)
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), src=(junk,), arg=0)
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), src=(junk,), arg=1)
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0)
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1)
# shape buffers. TODO: permutes
full_shape = (N//BM, nbIterWaveM, BM//(nbIterWaveM * TM), TM, N//BN, nbIterWaveN, BN//(nbIterWaveN * TN), TN, N//BK, BK)
@@ -91,14 +90,16 @@ def hand_spec_kernel3():
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1)
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)
junk = UOp.const(dtypes.float, 0) # TODO: remove this
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), src=(junk,), arg=0)
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), src=(junk,), arg=1)
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0)
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1)
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0)
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), src=(junk,), arg=2)
c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2)
i = UOp.range(dtypes.int, c_regs.dtype.size, 16)
init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0), i)
kId_range = UOp.range(dtypes.int, N//BK, 0)
kId = kId_range*BK
@@ -137,7 +138,7 @@ def hand_spec_kernel3():
x = iterWaveN * TN + xt
y = iterWaveM * TM + yt
c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
sink = c_regs_idx.store(c_regs_idx.load() + A_col[y].load(A_col_store) * B_row[x].load(B_row_store),
sink = c_regs_idx.store(c_regs_idx.load(init_store) + A_col[y].load(A_col_store) * B_row[x].load(B_row_store),
iterWaveM, iterWaveN, yt, xt, k, kId_range)
# store c_regs into c
@@ -148,7 +149,8 @@ def hand_spec_kernel3():
xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
indexC = N * (yOut + yt) + xOut + xt
sink = c[indexC].store(c_regs[TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)].load(sink), iterWaveM, iterWaveN, yt, xt)
sink = c[indexC].store(c_regs[TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)].load(sink),
iterWaveM, iterWaveN, yt, xt)
return sink.sink(arg=KernelInfo(name="tinygemm"))