From 9764c6cdee188569a47388ef0c2b0355230a322e Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 7 Aug 2025 07:57:58 -0700
Subject: [PATCH] fix mismatch reduce, try 2 (#11560)

* fix mismatch reduce, try 2

* fix heuristic

* delete that test

* don't start allowing ones
---
 test/test_schedule.py       |  1 -
 test/test_softmax_fusion.py |  1 -
 test/unit/test_uop_spec.py  | 11 -----------
 tinygrad/opt/heuristic.py   | 25 ++++++++++++++-----------
 tinygrad/opt/kernel.py      |  5 ++++-
 tinygrad/uop/spec.py        | 11 +----------
 6 files changed, 19 insertions(+), 35 deletions(-)

diff --git a/test/test_schedule.py b/test/test_schedule.py
index 4538fe395a..39f70161ef 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -2066,7 +2066,6 @@ class TestSwizzle(unittest.TestCase):
     np.testing.assert_allclose(t.numpy(), x.numpy().sum(axis=1)+y.numpy().sum(axis=1), atol=1e-6, rtol=1e-3)
 
   # kernels can only have 1 or n in each dim
-  @unittest.expectedFailure
   def test_dont_parallelize_different_n(self):
     Tensor.manual_seed(0)
     x = Tensor.randn(4, 2, 2).realize()
diff --git a/test/test_softmax_fusion.py b/test/test_softmax_fusion.py
index 729c02a473..9a7efbb651 100644
--- a/test/test_softmax_fusion.py
+++ b/test/test_softmax_fusion.py
@@ -111,7 +111,6 @@ class TestFuse(unittest.TestCase):
     with Context(NOOPT=1):
       self._test_fuse(Tensor.scaled_dot_product_attention, q, k, v, atol=1e-5)
 
-  @unittest.expectedFailure
   def test_mismatch_reduce(self):
     a = Tensor.ones(16, 10).contiguous().realize()
     b = Tensor.ones(16, 20).contiguous().realize()
diff --git a/test/unit/test_uop_spec.py b/test/unit/test_uop_spec.py
index 8cda3f8130..0244ba3531 100644
--- a/test/unit/test_uop_spec.py
+++ b/test/unit/test_uop_spec.py
@@ -34,17 +34,6 @@ class TestUOpSpec(unittest.TestCase):
     store = UOp(Ops.STORE, dtypes.void, (buf_0.view(ShapeTracker.from_shape((32, 1))), a+b))
     helper_test_verify_ast(store)
 
-  def test_exactly_one_full_shape(self):
-    dtype = dtypes.int
-    bufs = [UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), i) for i in range(6)]
-    a = UOp(Ops.LOAD, dtype, (bufs[2].view(ShapeTracker.from_shape((32, 1))),))
-    b = UOp(Ops.LOAD, dtype, (bufs[3].view(ShapeTracker.from_shape((32, 1))),))
-    st0 = UOp.store(bufs[0].view(ShapeTracker.from_shape((32, 1))), a+b)
-    a = UOp(Ops.LOAD, dtype, (bufs[4].view(ShapeTracker.from_shape((32, 32))),))
-    b = UOp(Ops.LOAD, dtype, (bufs[5].view(ShapeTracker.from_shape((32, 32))),))
-    st1 = UOp.store(bufs[1].view(ShapeTracker.from_shape((32, 32))), a+b)
-    with self.assertRaises(InvalidASTException): helper_test_verify_ast(st0, st1)
-
   def test_no_implicit_broadcasting(self):
     bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
     a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker.from_shape((4, 32))),))
diff --git a/tinygrad/opt/heuristic.py b/tinygrad/opt/heuristic.py
index 13f259c073..f2979a7b0c 100644
--- a/tinygrad/opt/heuristic.py
+++ b/tinygrad/opt/heuristic.py
@@ -80,18 +80,21 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
     else: break
 
   # if last reduce dim is small(ish), loop unroll the reduce
-  upcast_size = prod(k.full_shape[a] for a in k.axes_of(AxisType.UPCAST, AxisType.UNROLL))
-  if k.unrollable_dims and (upcast_size <= 4 or not k.axes_of(AxisType.UNROLL)) and (upcast_size < 64):
-    if (s:=k.full_shape[k.unrollable_dims[-1]]) <= 32:
-      k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
-      # if it's small, upcast a second reduce dimension too
-      if k.unrollable_dims and s <= 3 and k.full_shape[k.unrollable_dims[-1]] <= 3:
+  # NOTE: this can fail on multireduce with mismatching dimensions, this is okay
+  try:
+    upcast_size = prod(k.full_shape[a] for a in k.axes_of(AxisType.UPCAST, AxisType.UNROLL))
+    if k.unrollable_dims and (upcast_size <= 4 or not k.axes_of(AxisType.UNROLL)) and (upcast_size < 64):
+      if (s:=k.full_shape[k.unrollable_dims[-1]]) <= 32:
         k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
-    else:
-      for splits in [4]:
-        if k.full_shape[axis:=k.unrollable_dims[-1]]%splits == 0:
-          k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, splits))
-          break
+        # if it's small, upcast a second reduce dimension too
+        if k.unrollable_dims and s <= 3 and k.full_shape[k.unrollable_dims[-1]] <= 3:
+          k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, 0))
+      else:
+        for splits in [4]:
+          if k.full_shape[axis:=k.unrollable_dims[-1]]%splits == 0:
+            k.apply_opt(Opt(OptOps.UNROLL, len(k.unrollable_dims)-1, splits))
+            break
+  except KernelOptError: pass
 
   # if nothing at all is upcasted and it's easy to, do an upcast
   for splits in [4]:
diff --git a/tinygrad/opt/kernel.py b/tinygrad/opt/kernel.py
index 0ae4c17ad2..aeb46376d6 100644
--- a/tinygrad/opt/kernel.py
+++ b/tinygrad/opt/kernel.py
@@ -271,7 +271,10 @@ class Kernel:
       check(isinstance(opt.arg, int), "arg should be int")
       amt = arg if (arg:=cast(int, opt.arg)) != 0 else self.full_shape[axis]
       check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless")
-      if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
+      if opt.op is not OptOps.PADTO:
+        # we check both the full_shape and each shape
+        check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}")
+        for st in self.sts: check(st.shape[axis] == 1 or st.shape[axis] % amt == 0, f"no longer valid shift {st.shape[axis]=}, {amt=}")
     else:
       amt = -1
     if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index fe683e81dd..6c9f9e7daa 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -1,5 +1,5 @@
 from typing import cast, Callable
-from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite, resolve
+from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite
 from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace
 from tinygrad.helpers import all_same, prod, DEBUG, ContextVar, Context
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -207,16 +207,7 @@ spec = PatternMatcher([
 
 # *** this is the UOp AST spec ***
 
-def verify_sink_dims(sink:UOp):
-  if not all_same([s.shape for s in sink.src]): return False
-  for dims in zip(*[x.shape for x in sink.toposort() if x.op is Ops.VIEW]):
-    if len(n_dims:={s for s in dims if resolve(s!=1)}) > 1:
-      print(f"# INVALID KERNEL DIMS: can only have 1 or n in each dimension: {n_dims}")
-      return False
-
 ast_spec = PatternMatcher([
-  # shapes must have either 1 or n in each dimension
-  (UPat(Ops.SINK, src=UPat(Ops.STORE), name="sink"), verify_sink_dims),
   # VIEW can only exist in the edges
   (UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL),))), lambda: True),
   (UPat(Ops.VIEW, name="view"), lambda view: len(view.src) == 0),
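
Note (not part of the patch): a minimal sketch of the mismatched-reduce pattern that test_mismatch_reduce now runs without @unittest.expectedFailure. Only the two buffer definitions are visible in the hunk above, so the sum-then-add combination below is an assumption about the test body; the point is that the two reduces have different reduce-axis sizes (10 vs 20), which the deleted verify_sink_dims spec check rejected wholesale, and which the new per-ShapeTracker shift check in Kernel.apply_opt plus the KernelOptError guard in hand_coded_optimizations now handle per-optimization instead. Whether the two reduces actually land in one kernel depends on scheduler context and flags.

# sketch only; the sum+add pairing is assumed, not copied from the test
from tinygrad import Tensor

a = Tensor.ones(16, 10).contiguous().realize()
b = Tensor.ones(16, 20).contiguous().realize()
# two reduces over different-sized axes feeding one output: a fused kernel
# would see both 10 and 20 in the same dimension ("1 or n" rule territory)
out = (a.sum(axis=1) + b.sum(axis=1)).realize()
print(out.numpy())  # expected: sixteen 30.0s (10 ones + 20 ones)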