HL=2 top matmul (#11406)

* HL=2 top matmul

* top colored
This commit is contained in:
George Hotz
2025-07-28 12:32:38 -07:00
committed by GitHub
parent c7b4ab86e4
commit fddc645668
3 changed files with 59 additions and 17 deletions

View File

@@ -1,10 +1,13 @@
from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, AxisType
from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, AxisType, PatternMatcher, UPat
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.dtype import AddrSpace
from tinygrad.schedule.kernelize import merge_views
from tinygrad.helpers import getenv, colored
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.schedule.kernelize import merge_views, view_left
from tinygrad.helpers import getenv, colored, prod, unwrap
from tinygrad.shape.shapetracker import ShapeTracker, View, strides_for_shape
from tinygrad.opt.kernel import axis_colors
def to_colored(full_shape, axis_types):
  """Return the dims of `full_shape` as colored strings joined by '_'.

  Each dim is converted with str() and colored using the entry of
  `axis_colors` keyed by the axis type at the same position. The two
  sequences are paired with zip, so extra entries in the longer one are ignored.
  """
  pieces = []
  for dim, axis_type in zip(full_shape, axis_types):
    pieces.append(colored(str(dim), axis_colors[axis_type]))
  return '_'.join(pieces)
N = 4096
run_count = 5
@@ -16,15 +19,50 @@ BK = 8
TN = 4
TM = 4
# NOTE: this is from testgrad
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
# src->r->view --> src->view->r
def swizzle_reduceop(src:UOp, r:UOp, view:UOp):
  """Rewrite src->r->view into src->view->r by moving the VIEW below the REDUCE_AXIS.

  Args:
    src: the input of the reduce.
    r: the REDUCE_AXIS UOp being viewed.
    view: the VIEW UOp applied on top of `r`.

  Returns:
    None when `r` has a tag (no rewrite), otherwise a new REDUCE_AXIS over `src`
    viewed through an extended ShapeTracker, reducing the axes appended after
    the view's own dims.
  """
  if r.tag is not None: return None
  reduce_axes = r.axis_arg
  # confirm the input is in order: the non-reduced axes followed by the reduced axes must be the identity permutation
  # TODO: replace this with a UOp that allows for nothing else, then remove this check
  kept_axes = tuple(i for i in range(len(src.shape)) if i not in reduce_axes)
  permute = kept_axes + reduce_axes
  assert permute == tuple(range(len(permute))), f"reduce axis must already be in order, {permute} isn't"
  # the reduced dims are the trailing dims of src
  rshape = src.shape[-len(reduce_axes):]
  prshape = prod(rshape)
  rstrides = strides_for_shape(rshape)
  # append the reduce shape to each of the views; the existing strides and offset are scaled by the reduce size
  new_views = []
  for v in unwrap(view.st).views:
    if v.mask is None:
      new_mask = None
    else:
      new_mask = v.mask + tuple((0,s) for s in rshape)
    new_strides = tuple(x*prshape for x in v.strides) + rstrides
    new_views.append(View.create(v.shape+rshape, new_strides, v.offset*prshape, new_mask))
  # no reshape required with shrinking REDUCE_AXIS
  first_axis = len(view.shape)
  new_axes = tuple(range(first_axis, first_axis + len(reduce_axes)))
  swizzled_src = src.view(ShapeTracker(tuple(new_views)))
  return UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_src,), (r.arg[0], new_axes))
# single rewrite rule: a VIEW whose only source is a REDUCE_AXIS is handed to swizzle_reduceop,
# with the reduce bound as "r", its input as "src" and the view itself as "view"
pm = PatternMatcher([
  (
    UPat(Ops.VIEW, name="view", src=(UPat(Ops.REDUCE_AXIS, name="r", src=(UPat.var("src"),)),)),
    swizzle_reduceop,
  ),
])
def top_spec_kernel3():
  """Build the matmul kernel AST starting from Tensor ops.

  Schedules `a@b` for two empty (N,N) tensors, takes the AST of the last
  schedule item, reshapes it to (N//16, 16, N//16, 16), pushes the views
  through with `view_left+pm`, and returns it with a KernelInfo carrying a
  colored "top_" name and the axis types.
  """
  lhs = Tensor.empty(N,N)
  rhs = Tensor.empty(N,N)
  ast = (lhs@rhs).schedule()[-1].ast
  tile = 16
  # split each N-sized dim into (N//tile, tile)
  # disabled follow-up: .lift({0:UOp.range(dtypes.int, N//BM, 0), 2:UOp.range(dtypes.int, N//BN, 1)})
  ast = ast.reshape((N//tile, tile, N//tile, tile))
  ast = graph_rewrite(ast, view_left+pm)
  axis_types = (AxisType.GLOBAL, AxisType.LOCAL, AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE)
  kernel_name = "top_"+to_colored(ast.full_shape, axis_types)
  return ast.replace(arg=KernelInfo(name=kernel_name, axis_types=axis_types))
def hl_spec_kernel3():
nbIterWaveM = 2
nbIterWaveN = 2
# define buffers
# TODO: remove these views once the defines have a shape
a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N,N)))
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N))).permute((1,0))
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N,N)))
a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N)))
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N,N))).permute((1,0))
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N,N)))
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0).view(ShapeTracker.from_shape((BK*BM,)))
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1).view(ShapeTracker.from_shape((BK*BN,)))
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((nbIterWaveM * TM,)))
@@ -49,9 +87,7 @@ def hl_spec_kernel3():
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
AxisType.REDUCE, AxisType.UNROLL)
from tinygrad.opt.kernel import axis_colors
shape = '_'.join([colored(str(s), axis_colors[at]) for s,at in zip(full_shape, axis_types)])
sink = c.store(out).sink(arg=KernelInfo(name="tg_"+shape, axis_types=axis_types))
sink = c.store(out).sink(arg=KernelInfo(name="tg_"+to_colored(full_shape, axis_types), axis_types=axis_types))
sink = graph_rewrite(sink, merge_views)
return sink
@@ -98,9 +134,9 @@ def hand_spec_kernel3():
blockIdx_x = UOp(Ops.SPECIAL, dtypes.int, arg=("gidx0", N//BN))
blockIdx_y = UOp(Ops.SPECIAL, dtypes.int, arg=("gidx1", N//BM))
a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0)
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1)
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)
a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1)
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0)
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0)
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1)
@@ -167,9 +203,13 @@ def hand_spec_kernel3():
return sink.sink(arg=KernelInfo(name="tinygemm"))
if __name__ == "__main__":
hprg = hl_spec_kernel3() if getenv("HL") else hand_spec_kernel3()
HL = getenv("HL")
if HL == 2: hprg = top_spec_kernel3()
elif HL == 1: hprg = hl_spec_kernel3()
else: hprg = hand_spec_kernel3()
prg = get_program(hprg, Device.default.renderer)
print(prg.src)
if getenv("SRC"): exit(0)
hrunner = CompiledRunner(prg)
a = Tensor.randn(N, N).realize()
@@ -181,7 +221,8 @@ if __name__ == "__main__":
for _ in range(run_count): tc = (a@b).realize()
GlobalCounters.reset()
ei = ExecItem(hrunner, [a.uop.buffer, b.uop.buffer, hc.uop.buffer])
buffers = [hc.uop.buffer, a.uop.buffer, b.uop.buffer]
ei = ExecItem(hrunner, buffers)
with Context(DEBUG=2):
for _ in range(run_count): ei.run(wait=True)
err = (hc-tc).square().mean().item()

View File

@@ -52,7 +52,8 @@ def lower_load(ctx: IndexContext, x: UOp, buf: UOp):
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"
# TODO: reenable after REDUCE_AXIS is fixed
#assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"
idx, valid = x.st_arg.to_indexed_uops(ctx.idxs)
if cast(PtrDType, buf.dtype).addrspace == AddrSpace.GLOBAL:
# NOTE: only store the local reduceop in the threads that are actually doing the reduce

View File

@@ -188,7 +188,7 @@ def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
view_left = merge_views+PatternMatcher([
# view before elementwise and buffer ops
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.LOAD, Ops.STORE, Ops.VALID}, name="e"),), name="view"),
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.LOAD, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
# if there's ones added after reduce, put this before the reduce
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),