From 514d2a07746081cc866cd747defc99b49ba43d27 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 7 Oct 2025 13:57:58 +0800 Subject: [PATCH] merge tagless reshapes (#12474) * merge tagless reshapes * cleanup --- extra/thunder/gemm.py | 11 +++++++---- tinygrad/schedule/rangeify.py | 3 +++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/extra/thunder/gemm.py b/extra/thunder/gemm.py index 005627e7f8..61d3dae787 100644 --- a/extra/thunder/gemm.py +++ b/extra/thunder/gemm.py @@ -1,4 +1,5 @@ # include directory copied from https://github.com/HazyResearch/ThunderMittens +# https://hazyresearch.stanford.edu/blog/2024-11-28-tk-mlx gemm = """ #include @@ -41,10 +42,9 @@ kernel void matmul_naive(GEMM_PARAMS_DEF(T)) { instantiate_matmul_custom(float32, float); """ -from tinygrad import Device, Tensor +from tinygrad import Device, Tensor, Context if __name__ == "__main__": - # TODO: why isn't this type inferred? device = Device["METAL"] lib = device.compiler.compile(gemm) prg = device.runtime("matmul_custom_float32", lib) @@ -65,7 +65,10 @@ if __name__ == "__main__": global_size=gsz, local_size=(32,1,1), vals=(N, N, N), wait=True) print(f"{N*N*N*2/(et*1e9):2f} GFLOPS") - val = ((a@b).contiguous()-c).mean() - print(val.item()) + for _ in range(5): + with Context(DEBUG=2): + ref = (a@b).realize() + + print((ref-c).mean().item()) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 2e47656d9a..1d46836dac 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -46,6 +46,9 @@ earliest_rewrites = PatternMatcher([ # just removing it works... (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]), + # merge adjacent RESHAPES, safe because they are not tagged + (UPat(Ops.RESHAPE, name="x2").f(Ops.RESHAPE, name="x"), lambda x,x2: x.replace(src=(x2.src[0],)) if x.tag is None and x2.tag is None else None), + # remove CONTIGUOUS if the BUFFER is already contiguous (UPat(Ops.BUFFER).f(Ops.RESHAPE, name="r").f(Ops.CONTIGUOUS, name="c"), lambda r,c: r.replace(tag=c.tag)),