write out kernel 3 in uops (#11352)

* write out kernel 3 in uops

* matmul is correct

* gemm passes spec

* bugfix to match speed

* cleanups
George Hotz
2025-07-23 17:32:38 -07:00
committed by GitHub
parent 5b570196e4
commit b0dc97d1f7
6 changed files with 103 additions and 127 deletions
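This commit replaces the ShapeTracker-based hand_spec() in the gemm example with hand_spec_kernel3(), which writes the AMD "kernel 3 (registers)" matmul directly as a graph of uops. The flow the new code uses is: build a SINK of hand-constructed UOps, lower it with get_program, wrap it in a CompiledRunner, and launch it with ExecItem. Below is a minimal sketch of that flow on a trivial kernel; every call mirrors a pattern that appears in the diff, but treat it as illustrative rather than a stable API.

from tinygrad import Tensor, Device, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program

N = 16
out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N), arg=0)  # output buffer, bufs[0]
inp = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N), arg=1)  # input buffer, bufs[1]
i = UOp.range(dtypes.int, N, 0)                           # a loop over N
# out[i] = inp[i] * 2; the store is closed over range i
sink = out[i].store(inp[i].load() * 2, i).sink(arg=KernelInfo(name="double"))

prg = get_program(sink, Device.default.renderer)          # render + compile
runner = CompiledRunner(prg)
a = Tensor.arange(N, dtype=dtypes.float).realize()
b = Tensor.zeros(N).contiguous().realize()
# buffer order must match the DEFINE_GLOBAL arg indices
ExecItem(runner, [b.uop.buffer, a.uop.buffer]).run(wait=True)
print(b.tolist())  # expect [0.0, 2.0, 4.0, ...]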

View File

@@ -29,7 +29,7 @@ if __name__ == "__main__":
   c = Tensor.zeros(N, N).contiguous().realize()
   GlobalCounters.reset()
-  with Context(DEBUG=2, BEAM=4):
+  with Context(DEBUG=2):
     for _ in range(run_count): tc = (a@b).realize()
   GlobalCounters.reset()

View File

@@ -80,6 +80,8 @@ extern "C" __attribute__((global)) void kernel3_registers(float *a, float *b, fl
   // Iteration over BK blocks.
   for (int kId = 0; kId < N; kId += BK) {
+    __syncthreads();
     // We populate the Shared Memory with Ks row and columns
     for (int i = 0; i < nbReadsB; i++) {
       int index_x = BN * blockIdx.x + rBIdx;
@@ -123,7 +125,6 @@ extern "C" __attribute__((global)) void kernel3_registers(float *a, float *b, fl
         }
       }
     }
-    __syncthreads();
   }
   for (int iterWaveM = 0; iterWaveM < nbIterWaveM; iterWaveM++) {
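These two hunks move the __syncthreads() from the bottom of the K-loop to its top, so the barrier guards the shared-memory writes of the next tile instead of trailing the reads of the current one. In the uop version of the kernel (next file), the same ordering is expressed with a BARRIER uop that takes the shared-memory stores as sources, and loads that must wait take the barrier as an extra source. A toy sketch of that pattern, assembled from the same uop helpers the diff uses (assumed to compose this way outside the gemm; untested):

from tinygrad import Device, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.dtype import AddrSpace
from tinygrad.engine.realize import get_program

SZ = 32
out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(SZ), arg=0)
inp = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(SZ), arg=1)
lcl = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(SZ, AddrSpace.LOCAL), arg=0)
lidx = UOp(Ops.SPECIAL, dtypes.int, arg=("lidx0", SZ))
st = lcl[lidx].store(inp[lidx].load())   # each thread writes its own slot
bar = UOp(Ops.BARRIER, src=(st,))        # the __syncthreads() equivalent
rev = lcl[(SZ-1)-lidx].load(bar)         # read a slot another thread wrote
sink = out[lidx].store(rev).sink(arg=KernelInfo(name="reverse"))
print(get_program(sink, Device.default.renderer).src)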

View File

@@ -1,168 +1,145 @@
 from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
-from tinygrad.helpers import prod, unwrap
 from tinygrad.uop.ops import UOp, Ops, KernelInfo
-from tinygrad.opt.kernel import AxisType
 from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
-from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops, UOp, GroupOp
-from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
-from tinygrad.schedule.kernelize import merge_views
-from tinygrad.shape.view import View
-from tinygrad.uop.ops import Ops, UOp
 from tinygrad.dtype import AddrSpace
 N = 4096
 run_count = 5
-# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
-# src->r->view --> src->view->r
-def swizzle_reduceop(src:UOp, r:UOp, view:UOp):
-  if r.tag is not None: return None
-  # confirm the input is in order
-  # TODO: replace this with a UOp that allows for nothing else then remove this
-  permute = tuple(i for i in range(len(src.shape)) if i not in r.axis_arg)+r.axis_arg
-  assert permute == tuple(range(len(permute))), f"reduce axis must already be in order, {permute} isn't"
+def hand_spec_kernel3():
+  BLOCK_SIZE = 256
-  # append the reduce shape to each of the views
-  reduce_count = len(r.axis_arg)
-  prshape = prod(rshape:=src.shape[-reduce_count:])
-  rstrides = strides_for_shape(rshape)
-  nv = [View.create(v.shape[:-reduce_count]+rshape, tuple(x*prshape for x in v.strides[:-reduce_count])+rstrides, v.offset*prshape,
-                    v.mask[:-reduce_count]+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]
+  BN = 128
+  BM = 128
+  BK = 8
-  # no reshape required with shrinking REDUCE_AXIS
-  return UOp(Ops.REDUCE_AXIS, r.dtype, (src.view(ShapeTracker(tuple(nv))),),
-             (r.arg[0], tuple(range(len(view.shape)-reduce_count, len(view.shape)))))
-early_view_left = merge_views+PatternMatcher([
-  # view before elementwise and buffer ops
-  (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.VALID, Ops.STORE, Ops.LOAD}, name="e"),), name="view"),
-   lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src)) if e.tag is None else None),
-  # push a non contiguous ShapeTracker through reduceop
-  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
-])
-def hand_spec():
-  # Block Tile size: 128x128
-  # Thread Tile size: 4x4
-  # Wave Tile size: 128x32
-  # A wave is: 8x4
-  # ────── problem size and tiling params (mirror the C kernel) ───────────────────
-  BK = 8          # depth of K-tile
-  BN = BM = 128   # block-tile (output) sizes
-  # the real thread is 16x8 = 128 regs
-  TM = 4
+  nbIterWaveM = 2
   TN = 4
+  nbIterWaveN = 4
+  TM = 4
-  # ────── shared-memory tile sizes (unchanged) ───────────────────────────────────
-  LDS_A_SZ = BK * BM   # 1024 floats
-  LDS_B_SZ = BK * BN   # 1024 floats
+  nbWaves = BLOCK_SIZE // 32
+  WN = 64
+  WM = BN * BM // nbWaves // WN
-  bC = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0)   # output C
-  bA = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1)   # input A
-  bB = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)   # input B
+  nbWaveX = BN // WN
+  nbWaveY = BM // WM
-  # TODO: this should not be a string, just a number
-  lAs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_A_SZ, addrspace=AddrSpace.LOCAL), arg="As")
-  lBs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_B_SZ, addrspace=AddrSpace.LOCAL), arg="Bs")
+  threadIdx_x = UOp(Ops.SPECIAL, dtypes.int, arg=("lidx0", BLOCK_SIZE))
+  waveIndex = threadIdx_x // 32
+  waveIdx = waveIndex % nbWaveX
+  waveIdy = waveIndex // nbWaveX
+  indexInWave = threadIdx_x % 32
-  s0 = ShapeTracker.from_shape((N, N, N), (N, 0, 1))
-  s1 = ShapeTracker.from_shape((N, N, N), (0, 1, N))
-  s2 = ShapeTracker.from_shape((N, N, 1), (N, 1, 0))
+  nbThreadXPerWave = 8
+  nbThreadYPerWave = 4
-  ls0 = ShapeTracker.from_shape((BM, BK))
-  ls1 = ShapeTracker.from_shape((BN, BK))
+  idxInWave = indexInWave % nbThreadXPerWave
+  idyInWave = indexInWave // nbThreadXPerWave
-  buf_at = [AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.UPCAST, AxisType.UPCAST]
-  buf_bt = [AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.LOCAL, AxisType.UPCAST, AxisType.UPCAST]
-  axis_types = buf_at + buf_bt + [AxisType.REDUCE, AxisType.UNROLL, AxisType.UNROLL, AxisType.UNROLL]
+  nbIterWaveN = WN // (nbThreadXPerWave * TN)
+  nbIterWaveM = WM // (nbThreadYPerWave * TM)
-  # 128 x 128 x 8
-  full_shape = (N//BM, 2, 2, 2, 2, 2, 2, 2, N//BN, 2, 2, 2, 2, 2, 2, 2, N//BK, 2, 2, 2)
+  SUBWN = WN // nbIterWaveN
+  SUBWM = WM // nbIterWaveM
-  s0 = s0.reshape(full_shape)
-  s1 = s1.reshape(full_shape)
-  s2 = s2.reshape(full_shape[:-4] + (1,)*4)
+  # Thread mapping to read BKxBN block from A
+  rAIdx = threadIdx_x % BK
+  rAIdy = threadIdx_x // BK
+  # Thread mapping to read BNxBK block from B
+  rBIdx = threadIdx_x % BN
+  rBIdy = threadIdx_x // BN
-  ls0 = ls0.reshape((1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2)).expand(s0.shape)
-  ls1 = ls1.reshape((1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2)).expand(s1.shape)
-  assert ls0.real_size() == LDS_A_SZ
-  assert ls1.real_size() == LDS_B_SZ
+  strideReadB = BLOCK_SIZE // BN
+  strideReadA = BLOCK_SIZE // BK
+  nbReadsB = BN * BK // BLOCK_SIZE
+  nbReadsA = BM * BK // BLOCK_SIZE
   # BK is a loop of 8
   # each loop reads 8 in A, 16 in B
+  blockIdx_x = UOp(Ops.SPECIAL, dtypes.int, arg=("gidx0", N//BN))
+  blockIdx_y = UOp(Ops.SPECIAL, dtypes.int, arg=("gidx1", N//BM))
-  print(ls0)
-  print(ls1)
+  a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0)
+  b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1)
+  c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)
-  permaxis = []
-  for axis_order in [AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP, AxisType.UPCAST, AxisType.GROUP_REDUCE, AxisType.REDUCE, AxisType.UNROLL]:
-    permaxis += [i for i,a in enumerate(axis_types) if a == axis_order]
-  axis_types = [axis_types[x] for x in permaxis]
-  s0, s1, s2, ls0, ls1 = [x.permute(tuple(permaxis)) for x in [s0, s1, s2, ls0, ls1]]
-  print(axis_types)
+  junk = UOp.const(dtypes.float, 0)
+  A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), src=(junk,), arg=0)
+  B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), src=(junk,), arg=1)
-  lw0, lr0 = ls0, ls0
-  lw1, lr1 = ls1, ls1
+  As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0)
+  Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1)
-  # first round of permutes
+  c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), src=(junk,), arg=2)
-  permaxis = (0, 1, 19, 18, 17, 12, 11, 10, 5, 4, 3, 2, 6, 7, 8, 9, 16, 13, 14, 15)
-  s0 = s0.permute(permaxis)
-  lw0 = lw0.permute(permaxis)
+  kId_range = UOp.range(dtypes.int, N//BK, 0)
+  kId = kId_range*BK
-  permaxis = (0, 1, 15, 14, 9, 8, 7, 6, 13, 19, 18, 17, 5, 4, 3, 2, 16, 12, 11, 10)
-  s1 = s1.permute(permaxis)
-  lw1 = lw1.permute(permaxis)
+  # load from globals into locals
+  i = UOp.range(dtypes.int, nbReadsB, 1)
+  index_x = BN * blockIdx_x + rBIdx
+  index_y = rBIdy + i * strideReadB + kId
+  Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i)
-  # second round of permutes
-  #permaxis = (0, 1, 12, 11, 5, 4, 3, 2, 10, 6, 7, 8, 9, 13, 14, 15, 16, 17, 18, 19)
-  #lw0 = lw0.permute(permaxis)
-  #lr0 = lr0.permute(permaxis)
+  i = UOp.range(dtypes.int, nbReadsA, 2)
+  index_x = rAIdx + kId
+  index_y = BM * blockIdx_y + rAIdy + i * strideReadA
+  As_store = As[(index_x % BK) * BM + index_y % BM].store(a[N * index_y + index_x].load(), i)
-  from tinygrad.opt.kernel import axis_colors, colored
-  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s0.shape, s0.views[0].strides, axis_types)]))
-  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s1.shape, s1.views[0].strides, axis_types)]))
-  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(s2.shape, s2.views[0].strides, axis_types)]))
-  print("lw")
-  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lw0.shape, lw0.views[0].strides, axis_types)]))
-  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lw1.shape, lw1.views[0].strides, axis_types)]))
-  print("lr")
-  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lr0.shape, lr0.views[0].strides, axis_types)]))
-  print('_'.join([colored(f"{s}({st})", axis_colors[x]) for s,st,x in zip(lr1.shape, lr1.views[0].strides, axis_types)]))
+  barrier = UOp(Ops.BARRIER, src=(As_store, Bs_store))
-  # loads and stores
-  bs0 = bA.view(s0).load()
-  bs1 = bB.view(s1).load()
-  bs0 = lAs.view(lr0).load(lAs.view(lw0).store(bs0))
-  bs1 = lBs.view(lr1).load(lBs.view(lw1).store(bs1))
+  k = UOp.range(dtypes.int, BK, 3)
-  mat = (bs0 * bs1).r(Ops.ADD, tuple([i for i,a in enumerate(axis_types) if a in (AxisType.REDUCE, AxisType.UNROLL)]), permute=False)
-  st = bC.view(s2).store(mat)
+  # load from locals into registers
+  iterWave = UOp.range(dtypes.int, nbIterWaveN, 4)
+  i = UOp.range(dtypes.int, TN, 5)
+  index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i
+  B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i)
-  ast = st.sink(arg=KernelInfo(axis_types=tuple(axis_types), name="tinygemm"))
-  ast = graph_rewrite(ast, merge_views)
-  prg = get_program(ast, Device.default.renderer)
-  print(prg.src)
-  return prg
+  iterWave = UOp.range(dtypes.int, nbIterWaveM, 6)
+  i = UOp.range(dtypes.int, TM, 7)
+  index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i
+  A_col_store = A_col[iterWave*TM + i].store(As[k*BM + index].load(barrier), iterWave, i)
+  # do the GEMM math
+  iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 8)
+  iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 9)
+  yt = UOp.range(dtypes.int, TM, 10)
+  xt = UOp.range(dtypes.int, TN, 11)
+  x = iterWaveN * TN + xt
+  y = iterWaveM * TM + yt
+  c_regs_idx = c_regs[y * TN * nbIterWaveN + x]
+  sink = c_regs_idx.store(c_regs_idx.load() + A_col[y].load(A_col_store) * B_row[x].load(B_row_store),
+                          iterWaveM, iterWaveN, yt, xt, k, kId_range)
+  # store c_regs into c
+  iterWaveM = UOp.range(dtypes.int, nbIterWaveM, 12)
+  iterWaveN = UOp.range(dtypes.int, nbIterWaveN, 13)
+  xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave
+  yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave
+  yt = UOp.range(dtypes.int, TM, 14)
+  xt = UOp.range(dtypes.int, TN, 15)
+  indexC = N * (yOut + yt) + xOut + xt
+  sink = c[indexC].store(c_regs[TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)].load(sink), iterWaveM, iterWaveN, yt, xt)
+  return sink.sink(arg=KernelInfo(name="tinygemm"))
 if __name__ == "__main__":
-  hprg = hand_spec()
-  hrunner = CompiledRunner(hprg)
+  hprg = hand_spec_kernel3()
+  prg = get_program(hprg, Device.default.renderer)
+  print(prg.src)
+  hrunner = CompiledRunner(prg)
   a = Tensor.randn(N, N).realize()
   b = Tensor.randn(N, N).realize()
   hc = Tensor.zeros(N, N).contiguous().realize()
   GlobalCounters.reset()
-  with Context(DEBUG=2, BEAM=4):
+  with Context(DEBUG=2):
     for _ in range(run_count): tc = (a@b).realize()
   GlobalCounters.reset()
-  ei = ExecItem(hrunner, [hc.uop.buffer, a.uop.buffer, b.uop.buffer])
+  ei = ExecItem(hrunner, [a.uop.buffer, b.uop.buffer, hc.uop.buffer])
   with Context(DEBUG=2):
     for _ in range(run_count): ei.run(wait=True)
   err = (hc-tc).square().mean().item()
   print(f"hrunner {err}")
-  assert err < 1e-06
+  if err > 1e-06: raise RuntimeError("matmul is wrong!")
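Two details of the new __main__ are worth noting: get_program is now called by the caller rather than inside the spec function, and the ExecItem buffer list is reordered to [a, b, hc] because the kernel's DEFINE_GLOBAL args changed from (c=0, a=1, b=2) to (a=0, b=1, c=2); the buffer list must line up with those arg indices. The accumulation itself runs through a DEFINE_REG register file (c_regs) that is initialized from a const src and read-modify-written inside the ranges. A reduced, self-contained sketch of that accumulator pattern, a plain sum over one buffer with illustrative names:

from tinygrad import Device, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.dtype import AddrSpace
from tinygrad.engine.realize import get_program

SZ = 64
out = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1), arg=0)
inp = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(SZ), arg=1)
zero = UOp.const(dtypes.int, 0)
# one-element register file, initialized to 0.0 via its const src
acc = UOp(Ops.DEFINE_REG, dtypes.float.ptr(1, AddrSpace.REG), src=(UOp.const(dtypes.float, 0),), arg=0)
i = UOp.range(dtypes.int, SZ, 0)
# acc[0] += inp[i], with the store closed over range i
upd = acc[zero].store(acc[zero].load() + inp[i].load(), i)
# read the register back after the loop by depending on the store
sink = out[zero].store(acc[zero].load(upd)).sink(arg=KernelInfo(name="sumall"))
print(get_program(sink, Device.default.renderer).src)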

View File

@@ -211,6 +211,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def sink(self, *srcs:UOp|None, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
   def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
   def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
+  def __getitem__(self, idx): return self.index(idx)
   def const_like(self, b:ConstLike):
     # constants can optionally have a DEVICE source
     return UOp.const(self.dtype, b, device=self._device, shape=self.shape if self.st is not None else None)
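This one-line addition is what enables the buf[expr] indexing used throughout the gemm file: __getitem__ simply forwards to index(), so buf[i] builds the same INDEX uop as buf.index(i). A quick equivalence check; since UOps are interned by UOpMetaClass, the two should be the very same object, though treat the is-check itself as an assumption:

from tinygrad import dtypes
from tinygrad.uop.ops import UOp, Ops

buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1024), arg=0)
i = UOp.range(dtypes.int, 1024, 0)
assert buf[i] is buf.index(i)   # same interned INDEX uop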

View File

@@ -162,10 +162,8 @@ spec = PatternMatcher([
   (UPat(Ops.LOAD, src=(UPat(Ops.STORE),)), lambda: True),
   # LOAD takes a <bufidx, alt?, barrier?>
-  (UPat(Ops.LOAD, src=(index_pat,)), validate_index),
-  (UPat(Ops.LOAD, src=(index_pat, UPat(Ops.BARRIER))), validate_index),
-  (UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond"))), lambda idx,cond: validate_index(idx,cond.src[0])),
-  (UPat(Ops.LOAD, src=(index_pat, UPat.var("alt")), name="ld"), lambda ld,alt,idx: ld.dtype == alt.dtype and validate_index(idx)),
+  (UPat(Ops.LOAD, src=(index_pat, UPat(Ops.IF, name="cond")), allow_any_len=True), lambda idx,cond: validate_index(idx,cond.src[0])),
+  (UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True), validate_index),
   # STORE takes a <bufidx, val, gate?>
   (UPat(Ops.STORE, src=(index_pat, UPat(name="val"), UPat(Ops.IF, name="gate")), allow_any_len=True), validate_store),
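The LOAD spec collapses four fixed-arity patterns into two allow_any_len ones: a load is now a <bufidx>, optionally followed by an IF gate, plus any number of extra dependency sources. That is what lets the gemm kernel hang scheduling edges off loads, e.g. Bs[...].load(barrier) and A_col[y].load(A_col_store). A sketch of a matching pattern; index_pat here stands in for the spec's real index pattern, which is an assumption:

from tinygrad.uop.ops import UPat, Ops

index_pat = UPat(Ops.INDEX, name="idx")
# matches LOAD(index) and LOAD(index, dep, dep, ...) alike
load_pat = UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True)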

View File

@@ -437,7 +437,6 @@ sym = symbolic_flat+PatternMatcher([
   ((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
    lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))),
   # ** self folding **
-  (UPat(Ops.DEFINE_REG, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
   # x!=0 -> (bool)x
   (UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
   # ** where **
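The deleted fold rewrote a DEFINE_REG with a single const src into that const, which was sound back when a "DEFINE_ACC without ranges" could only ever hold its initial value. With this commit, DEFINE_REG is real register storage that gets stored into (A_col, B_row, c_regs above), so folding it away would drop those stores. An illustrative sketch of the shape that must now survive symbolic rewrite:

from tinygrad import dtypes
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import AddrSpace

# a register file initialized to 0.0; no longer foldable to the const,
# since later stores into it would be lost
regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(4, AddrSpace.REG), src=(UOp.const(dtypes.float, 0),), arg=0)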