uop matmul work (#11388)

* uop matmul work

* works with locals
This commit is contained in:
George Hotz
2025-07-26 21:23:55 -07:00
committed by GitHub
parent 3923e78061
commit dfeee63d30
3 changed files with 46 additions and 10 deletions

View File

@@ -3,7 +3,7 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, AxisType
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.dtype import AddrSpace
from tinygrad.schedule.kernelize import merge_views
from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, colored
from tinygrad.shape.shapetracker import ShapeTracker
N = 4096
@@ -22,9 +22,9 @@ def hl_spec_kernel3():
# define buffers
# TODO: remove these views once the defines have a shape
a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N*N,)))
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N*N,)))
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N*N,)))
a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N,N)))
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N))).permute((1,0))
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N,N)))
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0).view(ShapeTracker.from_shape((BK*BM,)))
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1).view(ShapeTracker.from_shape((BK*BN,)))
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((nbIterWaveM * TM,)))
@@ -39,12 +39,19 @@ def hl_spec_kernel3():
Bs = Bs.reshape((1, 1, 1, 1, 1, nbIterWaveN, BN//(nbIterWaveN * TN), TN, 1, BK)).expand(full_shape)
A_col = A_col.reshape((1, nbIterWaveM, 1, TM, 1, 1, 1, 1, 1, 1)).expand(full_shape)
B_row = B_row.reshape((1, 1, 1, 1, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape)
out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9))
axis_types = [
#out = (a.load() * b.load()).r(Ops.ADD, (8, 9))
out = (As.load(As.store(a.load())) * Bs.load(Bs.store(b.load()))).r(Ops.ADD, (8, 9))
#out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9))
axis_types = (
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
AxisType.REDUCE, AxisType.UNROLL]
sink = c.store(out).sink(arg=KernelInfo(name="tinygemm", axis_types=tuple(axis_types)))
AxisType.REDUCE, AxisType.UNROLL)
from tinygrad.opt.kernel import axis_colors
shape = '_'.join([colored(str(s), axis_colors[at]) for s,at in zip(full_shape, axis_types)])
sink = c.store(out).sink(arg=KernelInfo(name="tg_"+shape, axis_types=axis_types))
sink = graph_rewrite(sink, merge_views)
return sink

28
test/test_define_reg.py Normal file
View File

@@ -0,0 +1,28 @@
import unittest
from tinygrad import dtypes, Device, Tensor, Context
from tinygrad.dtype import AddrSpace
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.realize import get_program, ExecItem, CompiledRunner
class TestDefineReg(unittest.TestCase):
  def test_simple(self):
    # Round-trip a 16x16 float buffer through a DEFINE_REG (register-space)
    # staging area, one row of N values at a time, and check it copies exactly.
    sz = 16
    # kernel-side buffers: arg=0 is the output, arg=1 is the input
    dst = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(sz*sz), arg=0).view(ShapeTracker.from_shape((sz,sz)))
    src = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(sz*sz), arg=1).view(ShapeTracker.from_shape((sz,sz)))
    # register scratch holding one row; shaped (1,N) so the N axis can be upcast
    reg = UOp(Ops.DEFINE_REG, dtypes.float.ptr(sz, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((1,sz)))
    # store the input row into registers, then load it back out
    staged = reg.store(src.load())
    sink = dst.store(reg.load(staged)).sink(arg=KernelInfo(name="regcopy", axis_types=(AxisType.LOOP, AxisType.UPCAST)))
    prg = get_program(sink, Device.default.renderer)

    with Context(DEBUG=0):
      inp = Tensor.randn(sz, sz).realize()
      outp = Tensor.empty(sz, sz).realize()
    runner = CompiledRunner(prg)
    ExecItem(runner, [outp.uop.buffer, inp.uop.buffer]).run(wait=True)
    with Context(DEBUG=0):
      # an exact copy leaves no difference anywhere
      self.assertEqual((outp-inp).mean().item(), 0.0)
# Allow running this test file directly (e.g. `python test/test_define_reg.py`).
if __name__ == '__main__':
  unittest.main()

View File

@@ -56,8 +56,9 @@ def add_gpudims(ctx:Renderer, s:UOp):
if not ki.global_dims and not ki.local_dims: return None
s_topo = list(s.toposort())
if any(x.op is Ops.SPECIAL for x in s_topo): return None
ranges = sorted([x for x in s_topo if x.op is Ops.RANGE and x.arg in (ki.global_dims+ki.local_dims)], key=lambda x: x.arg)
if not len(ranges): return None
all_ranges = {x.arg:x for x in s_topo if x.op is Ops.RANGE}
# NOTE: this supports globals/locals in any position
ranges = [all_ranges[r] for r in ki.global_dims+ki.local_dims]
global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg in ki.global_dims])
local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg in ki.local_dims])
if ki.dont_use_locals: