From dfeee63d30550a90cb1bd54802529da535ed95bd Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sat, 26 Jul 2025 21:23:55 -0700
Subject: [PATCH] uop matmul work (#11388)

* uop matmul work

* works with locals
---
 extra/gemm/amd_uop_matmul.py | 23 +++++++++++++++--------
 test/test_define_reg.py     | 28 ++++++++++++++++++++++++++++
 tinygrad/codegen/gpudims.py |  5 +++--
 3 files changed, 46 insertions(+), 10 deletions(-)
 create mode 100644 test/test_define_reg.py

diff --git a/extra/gemm/amd_uop_matmul.py b/extra/gemm/amd_uop_matmul.py
index ac5c71ca38..1c4b6f217a 100644
--- a/extra/gemm/amd_uop_matmul.py
+++ b/extra/gemm/amd_uop_matmul.py
@@ -3,7 +3,7 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, AxisType
 from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
 from tinygrad.dtype import AddrSpace
 from tinygrad.schedule.kernelize import merge_views
-from tinygrad.helpers import getenv
+from tinygrad.helpers import getenv, colored
 from tinygrad.shape.shapetracker import ShapeTracker
 
 N = 4096
@@ -22,9 +22,9 @@ def hl_spec_kernel3():
 
   # define buffers
   # TODO: remove these views once the defines have a shape
-  a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N*N,)))
-  b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N*N,)))
-  c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N*N,)))
+  a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N,N)))
+  b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N))).permute((1,0))
+  c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N,N)))
   As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0).view(ShapeTracker.from_shape((BK*BM,)))
   Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1).view(ShapeTracker.from_shape((BK*BN,)))
   A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((nbIterWaveM * TM,)))
@@ -39,12 +39,19 @@ def hl_spec_kernel3():
   Bs = Bs.reshape((1, 1, 1, 1, 1, nbIterWaveN, BN//(nbIterWaveN * TN), TN, 1, BK)).expand(full_shape)
   A_col = A_col.reshape((1, nbIterWaveM, 1, TM, 1, 1, 1, 1, 1, 1)).expand(full_shape)
   B_row = B_row.reshape((1, 1, 1, 1, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape)
-  out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9))
-  axis_types = [
+
+  #out = (a.load() * b.load()).r(Ops.ADD, (8, 9))
+  out = (As.load(As.store(a.load())) * Bs.load(Bs.store(b.load()))).r(Ops.ADD, (8, 9))
+  #out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9))
+
+  axis_types = (
     AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
     AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
-    AxisType.REDUCE, AxisType.UNROLL]
-  sink = c.store(out).sink(arg=KernelInfo(name="tinygemm", axis_types=tuple(axis_types)))
+    AxisType.REDUCE, AxisType.UNROLL)
+
+  from tinygrad.opt.kernel import axis_colors
+  shape = '_'.join([colored(str(s), axis_colors[at]) for s,at in zip(full_shape, axis_types)])
+  sink = c.store(out).sink(arg=KernelInfo(name="tg_"+shape, axis_types=axis_types))
   sink = graph_rewrite(sink, merge_views)
   return sink
 
diff --git a/test/test_define_reg.py b/test/test_define_reg.py
new file mode 100644
index 0000000000..10ab277dd5
--- /dev/null
+++ b/test/test_define_reg.py
@@ -0,0 +1,28 @@
+import unittest
+from tinygrad import dtypes, Device, Tensor, Context
+from tinygrad.dtype import AddrSpace
+from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.engine.realize import get_program, ExecItem, CompiledRunner
+
+class TestDefineReg(unittest.TestCase):
+  def test_simple(self):
+    N = 16
+    bout = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N,N)))
+    a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N)))
+    a_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(N, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((1,N)))
+
+    out = a_col.load(a_col.store(a.load()))
+    sink = bout.store(out).sink(arg=KernelInfo(name="regcopy", axis_types=(AxisType.LOOP, AxisType.UPCAST)))
+    prg = get_program(sink, Device.default.renderer)
+
+    with Context(DEBUG=0):
+      a = Tensor.randn(N, N).realize()
+      b = Tensor.empty(N, N).realize()
+    hrunner = CompiledRunner(prg)
+    ExecItem(hrunner, [b.uop.buffer, a.uop.buffer]).run(wait=True)
+    with Context(DEBUG=0):
+      self.assertEqual((b-a).mean().item(), 0.0)
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/tinygrad/codegen/gpudims.py b/tinygrad/codegen/gpudims.py
index 9e4d98dedf..09b67e8489 100644
--- a/tinygrad/codegen/gpudims.py
+++ b/tinygrad/codegen/gpudims.py
@@ -56,8 +56,9 @@ def add_gpudims(ctx:Renderer, s:UOp):
   if not ki.global_dims and not ki.local_dims: return None
   s_topo = list(s.toposort())
   if any(x.op is Ops.SPECIAL for x in s_topo): return None
-  ranges = sorted([x for x in s_topo if x.op is Ops.RANGE and x.arg in (ki.global_dims+ki.local_dims)], key=lambda x: x.arg)
-  if not len(ranges): return None
+  all_ranges = {x.arg:x for x in s_topo if x.op is Ops.RANGE}
+  # NOTE: this supports globals/locals in any position
+  ranges = [all_ranges[r] for r in ki.global_dims+ki.local_dims]
   global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg in ki.global_dims])
   local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg in ki.local_dims])
   if ki.dont_use_locals:
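
Note: the gpudims change is what allows hl_spec_kernel3 to interleave its axes as (GLOBAL, UPCAST, LOCAL, UPCAST, GLOBAL, ...) rather than requiring all globals before all locals, and the new KernelInfo name encodes the full shape with per-axis colors via axis_colors. Below is a minimal driver sketch for the kernel spec, following the same compile-and-run pattern as test_define_reg.py above; the buffer setup and the a @ b.T reference are illustrative assumptions, not code from this patch.

# hypothetical driver, not part of the patch: compile hl_spec_kernel3 and run it
# assumes extra/ is importable and the default device supports local memory
from tinygrad import Tensor, Device, Context
from tinygrad.engine.realize import get_program, CompiledRunner, ExecItem
from extra.gemm.amd_uop_matmul import hl_spec_kernel3, N

prg = get_program(hl_spec_kernel3(), Device.default.renderer)
print(prg.src)  # inspect the rendered kernel source
with Context(DEBUG=0):
  a = Tensor.randn(N, N).realize()
  b = Tensor.randn(N, N).realize()
  c = Tensor.empty(N, N).realize()
# buffers are passed in DEFINE_GLOBAL arg order: a (arg=0), b (arg=1), c (arg=2)
ExecItem(CompiledRunner(prg), [a.uop.buffer, b.uop.buffer, c.uop.buffer]).run(wait=True)
# b carries a permute((1,0)) view in the spec, so a @ b.T is the assumed reference
with Context(DEBUG=0):
  print((c - a @ b.T).abs().max().item())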