From 4a741e836486bf8ecfed9fad96e6702a0b937005 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 30 Oct 2025 17:02:38 +0800 Subject: [PATCH] modernize amd uop matmul (#13011) * modernize amd uop matmul * progress * comment * more comments * revert that * mac cleanups * fix estimates * format --- extra/gemm/amd_uop_matmul.py | 399 ++++++++-------------------- tinygrad/codegen/late/linearizer.py | 2 +- tinygrad/renderer/__init__.py | 2 +- tinygrad/uop/ops.py | 7 +- tinygrad/uop/spec.py | 5 +- 5 files changed, 124 insertions(+), 291 deletions(-) diff --git a/extra/gemm/amd_uop_matmul.py b/extra/gemm/amd_uop_matmul.py index 0b1f534789..febc6b2098 100644 --- a/extra/gemm/amd_uop_matmul.py +++ b/extra/gemm/amd_uop_matmul.py @@ -1,145 +1,51 @@ from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes -from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, AxisType, PatternMatcher, UPat -from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program +from tinygrad.uop.ops import UOp, KernelInfo +from tinygrad.engine.realize import ExecItem, get_runner from tinygrad.dtype import AddrSpace -from tinygrad.helpers import getenv, colored, prod, unwrap -from tinygrad.shape.shapetracker import ShapeTracker, View -from tinygrad.shape.view import strides_for_shape -from tinygrad.codegen.opt.kernel import axis_colors, Opt, OptOps -from tinygrad.codegen.opt.swizzler import merge_views, view_left - -def to_colored(full_shape, axis_types): return '_'.join([colored(str(s), axis_colors[at]) for s,at in zip(full_shape, axis_types)]) +from tinygrad.helpers import getenv N = 4096 run_count = 5 +# block for locals BN = 128 BM = 128 BK = 8 +# t for registers TN = 4 TM = 4 -# NOTE: this is from testgrad -# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape. 
-# src->r->view --> src->view->r -def swizzle_reduceop(src:UOp, r:UOp, view:UOp): - if r.tag is not None: return None - # confirm the input is in order - # TODO: replace this with a UOp that allows for nothing else then remove this - permute = tuple(i for i in range(len(src.shape)) if i not in r.axis_arg)+r.axis_arg - assert permute == tuple(range(len(permute))), f"reduce axis must already be in order, {permute} isn't" - # append the reduce shape to each of the views - prshape = prod(rshape:=src.shape[-len(r.axis_arg):]) - rstrides = strides_for_shape(rshape) - nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+rstrides, v.offset*prshape, - v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views] +def hand_spec_kernel3(kernel5=getenv("K5", 0)): + # --------------------------- + # launch/config constants + # --------------------------- - # no reshape required with shrinking REDUCE_AXIS - return UOp(Ops.REDUCE_AXIS, r.dtype, (src.view(ShapeTracker(tuple(nv))),), - (r.arg[0], tuple(range(len(view.shape), len(view.shape) + len(r.axis_arg))))) - -pm = PatternMatcher([ - (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop), -]) - -def rangeify_kernel3(): - a = Tensor.empty(N,N) - b = Tensor.empty(N,N) - c = a@b - #c = c.reshape((32,2,16,4,32,2,16,4)).contiguous() - sink = c.schedule()[-1].ast - #print(sink) - - opts = [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.UPCAST, 0, 2)] - opts += [Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.LOCAL, 1, 16), Opt(OptOps.UPCAST, 1, 2)] - opts += [Opt(OptOps.UNROLL, 0, 8)] - - return sink.replace(arg=KernelInfo(opts_to_apply=tuple(opts))) - -def top_spec_kernel3(): - a = Tensor.empty(N,N) - b = Tensor.empty(N,N) - c = a@b - sink = c.schedule()[-1].ast - L = 16 - sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(N//BM, 0), 2:UOp.range(N//BN, 1)}) - sink = graph_rewrite(sink, view_left+pm) - axis_types = (AxisType.GLOBAL, AxisType.LOCAL, AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE) - return sink.replace(arg=KernelInfo(name="top_"+to_colored(sink.full_shape, axis_types), axis_types=axis_types)) - -def hl_spec_kernel3(): - nbIterWaveM = 2 - nbIterWaveN = 2 - - # define buffers - # TODO: remove these views once the defines have a shape - a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N))) - b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N,N))).permute((1,0)) - c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N,N))) - As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0).view(ShapeTracker.from_shape((BK, BM))).permute((1,0)) - Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1).view(ShapeTracker.from_shape((BK, BN))).permute((1,0)) - A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((nbIterWaveM * TM,))) - B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1).view(ShapeTracker.from_shape((nbIterWaveN * TN,))) - - # shape buffers. 
TODO: permutes - full_shape = (N//BM, nbIterWaveM, BM//(nbIterWaveM * TM), TM, N//BN, nbIterWaveN, BN//(nbIterWaveN * TN), TN, N//BK, BK) - a = a.reshape((N//BM, nbIterWaveM, BM//(nbIterWaveM * TM), TM, 1, 1, 1, 1, N//BK, BK)).expand(full_shape) - b = b.reshape((1, 1, 1, 1, N//BN, nbIterWaveN, BN//(nbIterWaveN * TN), TN, N//BK, BK)).expand(full_shape) - c = c.reshape((N//BM, nbIterWaveM, BM//(nbIterWaveM * TM), TM, N//BN, nbIterWaveN, BN//(nbIterWaveN * TN), TN, 1, 1)) - As = As.reshape((1, nbIterWaveM, BM//(nbIterWaveM * TM), TM, 1, 1, 1, 1, 1, BK)).expand(full_shape) - Bs = Bs.reshape((1, 1, 1, 1, 1, nbIterWaveN, BN//(nbIterWaveN * TN), TN, 1, BK)).expand(full_shape) - A_col = A_col.reshape((1, nbIterWaveM, 1, TM, 1, 1, 1, 1, 1, 1)).expand(full_shape) - B_row = B_row.reshape((1, 1, 1, 1, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape) - - # U1 L2 L3 L4 L5 U6 U7 U9 L10 L11 L12 L13 U14 U15 U17 U18 U19 - expanded_shape = (32, 2, 2, 2, 2, 2, 2, 2, 32, 2, 2, 2, 2, 2, 2, 2, 512, 2, 2, 2) - assert len(expanded_shape) == 20 - permute_a = list(range(len(expanded_shape))) - permute_b = permute_a[:] - - # this makes all the global loads match - # this can also be more simply done by rebinding the RANGEs - # but sadly, rebinding the RANGEs doesn't work to change the order of the local axes - permute_a[17:20] = [11,12,13] - permute_a[11:14] = [17,18,19] - permute_a[7], permute_a[10] = permute_a[10], permute_a[7] - permute_a[2:7] = [3,4,5,6,2] - - permute_b[2:16] = [19,9,10,11,17,18,8,2,12,13,14,15,3,4] - permute_b[17:20] = [5,6,7] - - a_permute = a.reshape(expanded_shape).permute(tuple(permute_a)).reshape(full_shape) - As_permute = As.reshape(expanded_shape).permute(tuple(permute_a)).reshape(full_shape) - - b_permute = b.reshape(expanded_shape).permute(tuple(permute_b)).reshape(full_shape) - Bs_permute = Bs.reshape(expanded_shape).permute(tuple(permute_b)).reshape(full_shape) - - #out = (a.load() * b.load()).r(Ops.ADD, (8, 9)) - out = (As.load(As_permute.store(a_permute.load())) * Bs.load(Bs_permute.store(b_permute.load()))).r(Ops.ADD, (8, 9)) - #out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9)) - - axis_types = ( - AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST, - AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST, - AxisType.REDUCE, AxisType.REDUCE) - - sink = c.store(out).sink(arg=KernelInfo(name="tg_"+to_colored(full_shape, axis_types), axis_types=axis_types)) - sink = graph_rewrite(sink, merge_views) - return sink - -def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)): BLOCK_SIZE = 128 if kernel5 else 256 nbWaves = BLOCK_SIZE // 32 WN = 128 if kernel5 else 64 WM = BN * BM // nbWaves // WN + # Sanity checks (fail fast if shapes/tiles misalign) + assert BN % WN == 0, "BN must be a multiple of WN" + assert BM % WM == 0, "BM must be a multiple of WM" nbWaveX = BN // WN nbWaveY = BM // WM - threadIdx_x = UOp(Ops.SPECIAL, dtypes.int, arg=("lidx0", BLOCK_SIZE)) + assert BLOCK_SIZE % BN == 0, "BLOCK_SIZE must be divisible by BN" + assert BLOCK_SIZE % BK == 0, "BLOCK_SIZE must be divisible by BK" + + assert (BN * BK) % BLOCK_SIZE == 0 + assert (BM * BK) % BLOCK_SIZE == 0 + + # --------------------------- + # per-thread read mapping + # --------------------------- + # A: read BK x BN tiles; B: read BN x BK tiles + + threadIdx_x = UOp.special(BLOCK_SIZE, "lidx0") waveIndex = threadIdx_x // 32 waveIdx = waveIndex % nbWaveX waveIdy = waveIndex // nbWaveX @@ -157,197 
+63,122 @@ def hand_spec_kernel3(kernel4=getenv("K4", 0), kernel5=getenv("K5", 0)): SUBWN = WN // nbIterWaveN SUBWM = WM // nbIterWaveM - # Thread mapping to read BKxBN block from A - rAIdx = threadIdx_x % BK - rAIdy = threadIdx_x // BK - # Thread mapping to read BNxBK block from B - rBIdx = threadIdx_x % BN - rBIdy = threadIdx_x // BN + # --------------------------- + # block indices & placeholders + # --------------------------- + blockIdx_x = UOp.special(N // BN, "gidx0") + blockIdx_y = UOp.special(N // BM, "gidx1") - strideReadB = BLOCK_SIZE // BN - strideReadA = BLOCK_SIZE // BK - nbReadsB = BN * BK // BLOCK_SIZE - nbReadsA = BM * BK // BLOCK_SIZE + a = UOp.placeholder(dtypes.float, (N, N), slot=1) + b = UOp.placeholder(dtypes.float, (N, N), slot=2) + c = UOp.placeholder(dtypes.float, (N, N), slot=0) - blockIdx_x = UOp(Ops.SPECIAL, dtypes.int, arg=("gidx0", N//BN)) - blockIdx_y = UOp(Ops.SPECIAL, dtypes.int, arg=("gidx1", N//BM)) + BM_As_stride = (BM + 4) if kernel5 else BM + As = UOp.placeholder(dtypes.float, (BK, BM_As_stride), slot=0, addrspace=AddrSpace.LOCAL) + Bs = UOp.placeholder(dtypes.float, (BK, BN), slot=1, addrspace=AddrSpace.LOCAL) - a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1) - b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2) - c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0) - - A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0) - B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1) - - BM_As_stride = (BM+4) if kernel5 else BM - As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM_As_stride, AddrSpace.LOCAL), arg=0) - Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1) - - c_regs = UOp(Ops.DEFINE_REG, dtypes.float.ptr(TM * nbIterWaveM * TN * nbIterWaveN), arg=2) + A_col = UOp.placeholder(dtypes.float, (nbIterWaveM, TM), slot=0, addrspace=AddrSpace.REG) + B_row = UOp.placeholder(dtypes.float, (nbIterWaveN, TN), slot=1, addrspace=AddrSpace.REG) + c_regs = UOp.placeholder(dtypes.float, (nbIterWaveM, TM, nbIterWaveN, TN), slot=2, addrspace=AddrSpace.REG) i = UOp.range(c_regs.dtype.size, 16) - init_store = c_regs[i].store(UOp.const(dtypes.float, 0.0), i) + c_regs = c_regs[i].set(0.0, end=i) - if kernel4: - regA = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbReadsA, AddrSpace.REG), arg=3) - regB = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbReadsB, AddrSpace.REG), arg=4) + kId_range = UOp.range(N // BK, 0) + kId = kId_range * BK - # initial load from globals into locals (0) - kId = 0 + # --------------------------- + # GLOBAL -> LOCAL (As, Bs) + # --------------------------- + nbReadsB = BN * BK // BLOCK_SIZE + i = UOp.range(nbReadsB, 1) + rBIdx = threadIdx_x % BN + rBIdy = threadIdx_x // BN + strideReadB = BLOCK_SIZE // BN + index_x = BN * blockIdx_x + rBIdx + index_y = rBIdy + i * strideReadB + kId + Bs_store = Bs[index_y % BK, index_x % BN].store(b[index_y, index_x]).end(i) - # load from globals into locals - i = UOp.range(nbReadsB, 0) - index_x = BN * blockIdx_x + rBIdx - index_y = rBIdy + i * strideReadB + kId - Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i) + nbReadsA = BM * BK // BLOCK_SIZE + i = UOp.range(nbReadsA, 2) + rAIdx = threadIdx_x % BK + rAIdy = threadIdx_x // BK + strideReadA = BLOCK_SIZE // BK + index_x = rAIdx + kId + index_y = BM * blockIdx_y + rAIdy + i * strideReadA + As_store = As[index_x % BK, index_y % BM].store(a[index_y, index_x]).end(i) - i = UOp.range(nbReadsA, 1) - index_x = rAIdx + kId - index_y 
= BM * blockIdx_y + rAIdy + i * strideReadA - As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i) + # TODO: can we automate barrier? + barrier = UOp.barrier(As_store, Bs_store) + Bs = Bs.after(barrier) + As = As.after(barrier) - # iterate over the middle chunk - kId_range = UOp.range(N//BK-1, 2) - kId = kId_range*BK + # open inner k range + k = UOp.range(BK, 3) - barrier = UOp.barrier(As_store, Bs_store) + # --------------------------- + # LOCAL -> REG (per-wave tiles) + # --------------------------- + iterWave = UOp.range(nbIterWaveN, 4) + i = UOp.range(TN, 5) + index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i + B_row = B_row[iterWave, i].set(Bs[k, index], end=(iterWave, i)) - # load from globals into registers (next round) - i = UOp.range(nbReadsB, 3) - index_x = BN * blockIdx_x + rBIdx - index_y = rBIdy + i * strideReadB + kId + BK - regB_store = regB[i].store(b[N * index_y + index_x].load(), i) + iterWave = UOp.range(nbIterWaveM, 6) + i = UOp.range(TM, 7) + index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i + A_col = A_col[iterWave, i].set(As[k, index], end=(iterWave, i)) - i = UOp.range(nbReadsA, 4) - index_x = rAIdx + kId + BK - index_y = BM * blockIdx_y + rAIdy + i * strideReadA - regA_store = regA[i].store(a[N * index_y + index_x].load(), i) + # --------------------------- + # FMA: c_regs += A_col * B_row + # --------------------------- + iterWaveM = UOp.range(nbIterWaveM, 8) + yt = UOp.range(TM, 9) + iterWaveN = UOp.range(nbIterWaveN, 10) + xt = UOp.range(TN, 12) + c_idx = c_regs.after(k, kId_range)[iterWaveM, yt, iterWaveN, xt] + sink = c_idx.store(c_idx + A_col[iterWaveM, yt] * B_row[iterWaveN, xt]).end(iterWaveM, iterWaveN, yt, xt) - def inner_loop(first_range, inp_dep=()): - # inner unroll - k = UOp.range(BK, first_range+0) + # Close k, sync, and close K tiles + sink = sink.end(k).barrier().end(kId_range) - # load from locals into registers - iterWave = UOp.range(nbIterWaveN, first_range+1) - i = UOp.range(TN, first_range+2) - index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i - B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(*inp_dep), iterWave, i) - - iterWave = UOp.range(nbIterWaveM, first_range+3) - i = UOp.range(TM, first_range+4) - index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i - A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(*inp_dep), iterWave, i) - - # do the GEMM math - iterWaveM = UOp.range(nbIterWaveM, first_range+5) - yt = UOp.range(TM, first_range+6) - iterWaveN = UOp.range(nbIterWaveN, first_range+7) - xt = UOp.range(TN, first_range+8) - x = iterWaveN * TN + xt - y = iterWaveM * TM + yt - c_regs_idx = c_regs[y * TN * nbIterWaveN + x] - # sketchy, this should end the kId_range but it doesn't - sink = c_regs_idx.store(c_regs_idx.load(init_store) + A_col[y].load(A_col_store) * B_row[x].load(B_row_store), - iterWaveM, iterWaveN, yt, xt, k) - return sink - - # TODO: kId_range should endrange after a barrier - sink = inner_loop(5, (barrier, regB_store, regA_store)).barrier() - - # load from registers into locals - i = UOp.range(nbReadsB, 14) - index_x = BN * blockIdx_x + rBIdx - index_y = rBIdy + i * strideReadB + kId + BK - Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(regB[i].load(sink), i, kId_range) - - i = UOp.range(nbReadsA, 15) - index_x = rAIdx + kId + BK - index_y = BM * blockIdx_y + rAIdy + i * strideReadA - As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(regA[i].load(sink), i, kId_range) - - 
# final iteration without the copy - sink = inner_loop(16, (UOp.barrier(Bs_store, As_store),)) - else: - kId_range = UOp.range(N//BK, 0) - kId = kId_range*BK - - # load from globals into locals - i = UOp.range(nbReadsB, 1) - index_x = BN * blockIdx_x + rBIdx - index_y = rBIdy + i * strideReadB + kId - Bs_store = Bs[(index_y % BK) * BN + index_x % BN].store(b[N * index_y + index_x].load(), i) - - i = UOp.range(nbReadsA, 2) - index_x = rAIdx + kId - index_y = BM * blockIdx_y + rAIdy + i * strideReadA - As_store = As[(index_x % BK) * BM_As_stride + index_y % BM].store(a[N * index_y + index_x].load(), i) - - barrier = UOp.barrier(As_store, Bs_store) - - k = UOp.range(BK, 3) - - # load from locals into registers - iterWave = UOp.range(nbIterWaveN, 4) - i = UOp.range(TN, 5) - index = waveIdx * WN + iterWave * SUBWN + TN * idxInWave + i - B_row_store = B_row[iterWave*TN + i].store(Bs[k*BN + index].load(barrier), iterWave, i) - - iterWave = UOp.range(nbIterWaveM, 6) - i = UOp.range(TM, 7) - index = waveIdy * WM + iterWave * SUBWM + TM * idyInWave + i - A_col_store = A_col[iterWave*TM + i].store(As[k*BM_As_stride + index].load(barrier), iterWave, i) - - # do the GEMM math - iterWaveM = UOp.range(nbIterWaveM, 8) - yt = UOp.range(TM, 9) - iterWaveN = UOp.range(nbIterWaveN, 10) - xt = UOp.range(TN, 12) - x = iterWaveN * TN + xt - y = iterWaveM * TM + yt - c_regs_idx = c_regs[y * TN * nbIterWaveN + x] - sink = c_regs_idx.store(c_regs_idx.load(init_store) + A_col[y].load(A_col_store) * B_row[x].load(B_row_store), - iterWaveM, iterWaveN, yt, xt, k, kId_range) - - # store c_regs into c + # --------------------------- + # REG -> GLOBAL (epilogue) + # --------------------------- iterWaveM = UOp.range(nbIterWaveM, 1000) yt = UOp.range(TM, 1001) iterWaveN = UOp.range(nbIterWaveN, 1002) xt = UOp.range(TN, 1003) xOut = blockIdx_x * BN + waveIdx * WN + iterWaveN * SUBWN + TN * idxInWave yOut = blockIdx_y * BM + waveIdy * WM + iterWaveM * SUBWM + TM * idyInWave - indexC = N * (yOut + yt) + xOut + xt - sink = c[indexC].store(c_regs[TN * nbIterWaveN * (iterWaveM * TM + yt) + (iterWaveN * TN + xt)].load(sink), - iterWaveM, iterWaveN, yt, xt) + sink = c[yOut + yt, xOut + xt].store(c_regs.after(sink)[iterWaveM, yt, iterWaveN, xt]) + sink = sink.end(iterWaveM, iterWaveN, yt, xt) + + return sink.sink(arg=KernelInfo(opts_to_apply=())) - return sink.sink(arg=KernelInfo(name="tinygemm")) if __name__ == "__main__": - HL = getenv("HL") - if HL == 3: hprg = rangeify_kernel3() - elif HL == 2: hprg = top_spec_kernel3() - elif HL == 1: hprg = hl_spec_kernel3() - else: hprg = hand_spec_kernel3() - if HL == 3: - prg = get_program(hprg, Device.default.renderer) - else: - prg = get_program(hprg, Device.default.renderer) - print(prg.src) - if getenv("SRC"): exit(0) - hrunner = CompiledRunner(prg) + with Context(DEBUG=0): + a = Tensor.randn(N, N) + b = Tensor.randn(N, N) + hc = Tensor.empty(N, N) + Tensor.realize(a, b, hc) - a = Tensor.randn(N, N).realize() - b = Tensor.randn(N, N).realize() - hc = Tensor.zeros(N, N).contiguous().realize() + sink = hand_spec_kernel3() + ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in [hc, a, b]]) + + GlobalCounters.reset() + ets = [] + with Context(DEBUG=2): + for _ in range(run_count): + ets.append(ei.run(wait=True)) + print(f"REAL TFLOPS {N * N * N * 2 / min(ets) * 1e-12:.2f}") GlobalCounters.reset() with Context(DEBUG=2): - for _ in range(run_count): tc = (a@b).realize() - - GlobalCounters.reset() - buffers = [hc.uop.buffer, a.uop.buffer, b.uop.buffer] - ei = 
ExecItem(hrunner, buffers) - with Context(DEBUG=2): - for _ in range(run_count): ei.run(wait=True) - err = (hc-tc).square().mean().item() - print(f"hrunner {err}") - if err > 1e-06: raise RuntimeError("matmul is wrong!") + tc = (a @ b).realize() + with Context(DEBUG=0): + err = (hc - tc).square().mean().item() + print(f"mean squared error {err}") + if err > 1e-06: + raise RuntimeError("matmul is wrong!") diff --git a/tinygrad/codegen/late/linearizer.py b/tinygrad/codegen/late/linearizer.py index a12fd0b744..45b53204fb 100644 --- a/tinygrad/codegen/late/linearizer.py +++ b/tinygrad/codegen/late/linearizer.py @@ -75,7 +75,7 @@ pm_add_control_flow = PatternMatcher([ def do_split_ends(e:UOp): ret = e.src[0] - for r in list(UOp.sink(*e.src[1:]).ranges)[::-1]: ret = ret.end(r) + for r in sorted(UOp.sink(*e.src[1:]).ranges, key=lambda x: x.arg, reverse=True): ret = ret.end(r) return ret pm_split_ends = PatternMatcher([ diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 439615f6a7..71d86cae84 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -30,7 +30,7 @@ class Estimates: if ignore_indexing: def range_gate(x): return x.op is not Ops.RANGE for u in uops: - if u.op in {Ops.LOAD, Ops.STORE} and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG): + if u.op in {Ops.LOAD, Ops.STORE}: # if u.src[0] is INDEX, we have to include the buffer since it might be an AFTER dont_count = dont_count.union((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate)) # TODO: is this correct? this all needs to be cleaned up diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 8ac476bcc5..ae47d756a9 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from enum import Enum, auto from tinygrad.uop import Ops, GroupOp from tinygrad.uop.mathtraits import MathTrait -from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType +from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CI from tinygrad.helpers import strip_parens, colored @@ -756,8 +756,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # *** uop high level syntactic sugar *** @staticmethod - def placeholder(dtype:DType, shape:tuple[int, ...], slot:int): - ret = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(prod(shape)), arg=slot) + def placeholder(dtype:DType, shape:tuple[int, ...], slot:int, addrspace=AddrSpace.GLOBAL): + lookup = {AddrSpace.GLOBAL: Ops.DEFINE_GLOBAL, AddrSpace.LOCAL: Ops.DEFINE_LOCAL, AddrSpace.REG: Ops.DEFINE_REG} + ret = UOp(lookup[addrspace], dtype.ptr(prod(shape), addrspace), arg=slot) if len(shape) > 1: ret = ret.reshape(shape) return ret def placeholder_like(self, slot:int): diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index de9226121f..96312d00f4 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -146,8 +146,8 @@ shared_codegen_spec = PatternMatcher([ # SPECIAL (UPat(Ops.SPECIAL, src=(UPat.var("x", (dtypes.index, dtypes.int32)),), name="s"), lambda s,x: s.dtype == x.dtype and isinstance(s.arg, str)), - # 
BARRIER - (UPat(Ops.BARRIER, dtypes.void, src=(UPat(),)), lambda: True), + # BARRIER (on any length) + (UPat(Ops.BARRIER, dtypes.void), lambda: True), ]) # ***** UOp spec in kernel graph ***** @@ -156,6 +156,7 @@ kernel_spec = PatternMatcher([ # RESHAPE (but only RESHAPE) is allowed here (UPat(Ops.RESHAPE, name="mv", src=(UPat.var("x"), UPat(dtype=dtypes.index))), lambda mv,x: True), (UPat(Ops.AFTER, src=(UPat(Ops.RESHAPE),), allow_any_len=True), lambda: True), + (UPat(Ops.VCONST, dtype=dtypes.index), lambda: True), # index is allowed here (UPat(GroupOp.Elementwise|{Ops.CONST, Ops.RANGE, Ops.DEFINE_VAR}, dtype=dtypes.index), lambda: True),
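
As a quick sanity check on the tile bookkeeping in hand_spec_kernel3, the launch constants it asserts on can be re-derived with plain arithmetic. The following is a standalone sketch using only values visible in the diff for the default (kernel5=0) path; it is not code from the commit, and nbIterWaveM/N / SUBWM/SUBWN are left out because they come from unchanged context lines not shown here.

# Standalone arithmetic check of the hand_spec_kernel3 constants (kernel5=0 path); illustrative only.
BN, BM, BK = 128, 128, 8      # LOCAL (shared memory) block tile
TN, TM = 4, 4                 # per-thread register tile
BLOCK_SIZE = 256              # threads per workgroup when kernel5=0

nbWaves = BLOCK_SIZE // 32            # 8 wavefronts of 32 lanes
WN = 64                               # wave tile width when kernel5=0
WM = BN * BM // nbWaves // WN         # 128*128 // 8 // 64 = 32
nbWaveX, nbWaveY = BN // WN, BM // WM # 2 x 4 waves tile the 128x128 block
nbReadsB = BN * BK // BLOCK_SIZE      # 4 elements of B loaded per thread per K step
nbReadsA = BM * BK // BLOCK_SIZE      # 4 elements of A loaded per thread per K step

assert BN % WN == 0 and BM % WM == 0
assert BLOCK_SIZE % BN == 0 and BLOCK_SIZE % BK == 0
print(WM, nbWaveX, nbWaveY, nbReadsA, nbReadsB)   # -> 32 2 4 4 4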
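
The modernized kernel is written with UOp.placeholder, UOp.range, indexed store/set, end, barrier and after, instead of raw DEFINE_GLOBAL/DEFINE_LOCAL/DEFINE_REG UOps with explicit loads and ShapeTracker views. Below is a minimal sketch of that authoring pattern, assuming the APIs used in the patch (UOp.placeholder, UOp.range, .store/.end, get_runner, ExecItem) behave for a trivial 1-D kernel as they do for the matmul; the vector-add kernel, N=1024, and all names in it are illustrative, not part of the commit.

# A minimal sketch of the UOp authoring style this patch moves to; not part of the commit.
from tinygrad import Tensor, Device, Context, dtypes
from tinygrad.uop.ops import UOp, KernelInfo
from tinygrad.engine.realize import ExecItem, get_runner

N = 1024

def vector_add_kernel():
  # slot numbers follow the buffer order handed to ExecItem below: output first
  out = UOp.placeholder(dtypes.float, (N,), slot=0)
  x = UOp.placeholder(dtypes.float, (N,), slot=1)
  y = UOp.placeholder(dtypes.float, (N,), slot=2)
  i = UOp.range(N, 0)                      # one explicit loop over the flat index
  sink = out[i].store(x[i] + y[i]).end(i)  # indexed reads are implicit loads; close the range after the store
  return sink.sink(arg=KernelInfo(opts_to_apply=()))

if __name__ == "__main__":
  with Context(DEBUG=0):
    a, b = Tensor.randn(N), Tensor.randn(N)
    c = Tensor.empty(N)
    Tensor.realize(a, b, c)
  ei = ExecItem(get_runner(Device.DEFAULT, vector_add_kernel()), [t.uop.buffer for t in [c, a, b]])
  ei.run(wait=True)
  print((c - (a + b)).square().mean().item())  # expect ~0

In this style the loop structure is carried by the range arg values and the explicit .end()/.barrier() calls, which appears to be why the linearizer change in this patch sorts range ends by their arg before splitting them.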