From 8c47cf43237c93fbd717efd676ca77b48e865341 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 29 Oct 2025 13:06:43 +0800 Subject: [PATCH] pcontig double matmul works (#12899) * pcontig double matmul works * tests * contract * closer * works-ish * add that broadcast * 2 more work * something * disable broken ones * llvm * align 16 --- test/test_rangeify.py | 43 ++++++++++++++++++++++++++++++- tinygrad/codegen/__init__.py | 3 --- tinygrad/codegen/late/expander.py | 3 +++ tinygrad/codegen/opt/postrange.py | 2 +- tinygrad/renderer/llvmir.py | 4 ++- tinygrad/schedule/rangeify.py | 8 +++++- tinygrad/uop/ops.py | 2 +- tinygrad/uop/spec.py | 3 +++ 8 files changed, 60 insertions(+), 8 deletions(-) diff --git a/test/test_rangeify.py b/test/test_rangeify.py index 72317f0984..e8de1d6513 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -1,10 +1,51 @@ import unittest from tinygrad import Tensor, nn, Device -from tinygrad.helpers import Context, GlobalCounters, CI, getenv, PCONTIG +from tinygrad.helpers import Context, GlobalCounters, CI, getenv, PCONTIG, DEBUG from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops +from tinygrad.codegen.opt import OptOps, Opt from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.nir import NIRRenderer +@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, NIRRenderer), "broken in LVP") +class TestDoubleMatmul(unittest.TestCase): + def setUp(self): + with Context(DEBUG=0): + self.a, self.b, self.c = [Tensor.randn(16, 16).contiguous().realize() for _ in range(3)] + self.cmp = (self.a @ self.b @ self.c).realize() + + def _test(self, opts): + with Context(PCONTIG=2, DEBUG=max(2, DEBUG.value)): + out = (self.a @ self.b @ self.c).contiguous(arg=opts).realize() + + with Context(DEBUG=0): + err = (out-self.cmp).square() + self.assertLess(err.max().item(), 1e-4) + self.assertLess(err.mean().item(), 1e-6) + + def test_baseline(self): self._test(()) + def test_upcast_0(self): self._test((Opt(OptOps.UPCAST, 0, 4),)) + def test_upcast_1(self): self._test((Opt(OptOps.UPCAST, 1, 4),)) + def test_upcast_2(self): self._test((Opt(OptOps.UPCAST, 2, 4),)) + @unittest.skip("doesn't work") + def test_upcast_01(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4))) + def test_upcast_02(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 2, 4))) + def test_upcast_12(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 2, 4))) + + def test_unroll_0(self): self._test((Opt(OptOps.UNROLL, 0, 4),)) + def test_unroll_1(self): self._test((Opt(OptOps.UNROLL, 1, 4),)) + def test_unroll_01(self): self._test((Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4))) + + def test_upcast_0_unroll_0(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4))) + def test_upcast_1_unroll_0(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4))) + def test_upcast_2_unroll_0(self): self._test((Opt(OptOps.UPCAST, 2, 4), Opt(OptOps.UNROLL, 0, 4))) + + @unittest.skip("doesn't work") + def test_upcast_01_unroll_01(self): + self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4))) + @unittest.skip("doesn't work") + def test_upcast_12_unroll_01(self): + self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 2, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4))) + class TestRangeifyAssign(unittest.TestCase): def test_assign_permuted(self): A = Tensor.empty(4, 4, dtype='int') diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 61cb13cc21..d595444943 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -25,9 +25,6 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) - # first we optimize if optimize: - # TODO: fix expander and remove this - sink = graph_rewrite(sink, pm_add_buffers_local, name="add locals early") - # collapse loads reduce (indexing by a tensor) sink = graph_rewrite(sink, pm_load_collapse, name="load collapse") diff --git a/tinygrad/codegen/late/expander.py b/tinygrad/codegen/late/expander.py index 1f270394e6..ce028f1492 100644 --- a/tinygrad/codegen/late/expander.py +++ b/tinygrad/codegen/late/expander.py @@ -82,6 +82,9 @@ def do_contract(con:UOp): return UOp(Ops.UNROLL, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args) expander = PatternMatcher([ + # BUFFERIZE puts UNROLLs for ranges as contract + (UPat(Ops.BUFFERIZE, src=(UPat(Ops.UNROLL), UPat(Ops.UNROLL)), name="x"), + lambda x: x.replace(src=tuple(UOp(Ops.CONTRACT, dtype=s.dtype.vec(x.src[1].src[0].dtype.count), src=(s,), arg=x.src[1].arg) for s in x.src))), # double expand (UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)), lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)), diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index f0e00004ea..c2a82d5c85 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -334,6 +334,6 @@ def apply_opts(ast:UOp, ren:Renderer) -> UOp: elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()): from tinygrad.codegen.opt.heuristic import hand_coded_optimizations # NOTE: hand_coded_optimizations doesn't support multiblock opts yet - if not any(u.op is Ops.AFTER and u.src[0].op is Ops.DEFINE_LOCAL for u in ast.backward_slice): + if not any(u.op is Ops.BUFFERIZE for u in ast.backward_slice): k = hand_coded_optimizations(k) return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index e83e44364f..684b12d654 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -187,8 +187,10 @@ class LLVMRenderer(Renderer): elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG): r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}" assert isinstance(u.dtype, PtrDType) - if self.device == "CPU" or u.op is Ops.DEFINE_REG: + if u.op is Ops.DEFINE_REG: kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}]") + elif self.device == "CPU" and u.op is Ops.DEFINE_LOCAL: + kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}], align 16") else: local_args.append(f"@{r[u][1:]} = internal unnamed_addr addrspace(3) global [{u.dtype.size} x {ldt(u.dtype)}] undef, align 16") kernel.append(f" {r[u]} = addrspacecast [{u.dtype.size} x {ldt(u.dtype)}] addrspace(3)* @{r[u][1:]} to [{u.dtype.size} x {ldt(u.dtype)}]*") diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 64fea35b26..83218d3989 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -332,7 +332,7 @@ def bufferize_to_store(x:UOp, idx:UOp, allow_locals=True): tag = x.arg.device if tag is None: tag = UOp.unique().arg # TODO: hack buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag) - do_store = buf.index(idx, dtype=sdtype).store(x.src[0]).end(*rngs) + do_store = buf.broadcast(x.src[1].dtype.count).index(idx, dtype=sdtype).store(x.src[0]).end(*rngs) return buf.after(do_store.barrier()) # collapse any BUFFERIZE to single input BUFFERIZE. move the tag to a reshape @@ -438,6 +438,12 @@ rangeify_codegen = PatternMatcher([ (UPat.any(UPat(Ops.DEFINE_GLOBAL, name="dg"), UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True, name="dg")) .f(Ops.INDEX, name="idx", allow_any_len=True), lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else idx.replace(dtype=dg.dtype, arg=None).load()), + + # fix broadcast dtype + (UPat(Ops.AFTER, name="a").broadcast(name="b"), lambda a,b: a.broadcast(len(b.src))), + (UPat(Ops.DEFINE_LOCAL).f(Ops.AFTER, allow_any_len=True).broadcast(name="dg").f(Ops.INDEX, name="idx", allow_any_len=True), + lambda dg,idx: None if isinstance(idx.dtype, (PtrDType, ImageDType)) else + idx.replace(dtype=dg.dtype, arg=None).load(dtype=dg.dtype.base.scalar().vec(dg.dtype.vcount))), ]) def remove_metadata_tags(ctx:LocalAddBufferContext, x:UOp): diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 0e257fe1af..263269a222 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -346,7 +346,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # constants can optionally have a DEVICE source return UOp.const(self.dtype, b, device=self._device, shape=self._shape) def broadcast(self, count:int): - assert self.dtype.count == 1 + assert self.dtype.vcount == 1 if count == 1: return self return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count) def cast(self, dtype:DType): diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 815f48aa19..e97db78ef5 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -223,6 +223,9 @@ full_spec = PatternMatcher([ # in progress MSTACK may lose device (UPat((Ops.MSELECT, Ops.MSTACK), name="x"), lambda x: True), + # temp VECTORIZEs during rewrite have the wrong dtype + (UPat(Ops.VECTORIZE), lambda: True), + # all loads/stores (UPat((Ops.LOAD, Ops.STORE)), lambda: True), # DEFINE_VAR to deal with the floats used in reduce collapse