From 772a8dfe31f8d2f4fcf1dcbbf4d4172bb07850f1 Mon Sep 17 00:00:00 2001 From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com> Date: Sat, 11 Oct 2025 17:02:54 +0200 Subject: [PATCH] reshape uses valid when simplifying (#12597) * reshape uses valid when simplifying * try with IGNORE_OOB=0 * is it this test? * skipif gpuocelot --- .github/workflows/test.yml | 2 +- test/test_linearizer.py | 6 ++++-- test/unit/test_winograd.py | 2 +- tinygrad/schedule/indexing.py | 8 ++++---- tinygrad/uop/symbolic.py | 13 ++++++++----- 5 files changed, 18 insertions(+), 13 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 578e76cc4b..f80262e65b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -377,7 +377,7 @@ jobs: llvm: 'true' - name: Test openpilot model kernel count and gate usage run: | - ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2041 ALLOWED_GATED_READ_IMAGE=41 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx + ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2092 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx - name: Test openpilot alt model correctness (float32) run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx - name: Test openpilot fastvits model correctness (float32) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index a0d6d67f67..23d42a5349 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -8,9 +8,11 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp from tinygrad.device import Device, Buffer, is_dtype_supported from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program -from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT +from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT, getenv from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace from tinygrad.renderer.ptx import PTXRenderer +from tinygrad.renderer.cstyle import CUDARenderer +MOCKGPU = getenv("MOCKGPU") class TestLinearizer(unittest.TestCase): def test_arg_dedup(self): @@ -314,7 +316,7 @@ class TestLinearizer(unittest.TestCase): a.realize() np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.]) - @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "PTX indexes differently. might be ok?") + @unittest.skipIf(MOCKGPU and isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, CUDARenderer)), "PTX indexes differently. might be ok?") def test_where_fold(self): a = Tensor.ones(4, 4).contiguous().realize() b = a.shrink(((1, 2), None)).pad(((1, 2), None)) diff --git a/test/unit/test_winograd.py b/test/unit/test_winograd.py index 7f419b838c..d8909f7620 100644 --- a/test/unit/test_winograd.py +++ b/test/unit/test_winograd.py @@ -42,7 +42,7 @@ class TestWinograd(unittest.TestCase): out = Tensor.conv2d(x,w, padding=1) out.mean().backward() backward_schedule = Tensor.schedule(x.grad, w.grad) - self.assertEqual(len(backward_schedule), 4) + self.assertEqual(len(backward_schedule), 5) def test_counters(self): IC, OC, X, Y = 4,4,9,9 diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 88d9394857..982babe2b0 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -3,7 +3,7 @@ import functools, operator, itertools from dataclasses import dataclass, field from tinygrad.dtype import dtypes, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType -from tinygrad.uop.symbolic import sym, symbolic +from tinygrad.uop.symbolic import symbolic, pm_simplify_valid from tinygrad.helpers import argsort, all_same, cpu_profile, TracingKey ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, @@ -112,8 +112,8 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO case Ops.EXPAND: rngs = tuple(a if in_sh == out_sh else a.const_like(0) for a,in_sh,out_sh in zip(rngs, in_shape, arg)) case Ops.PAD: # TODO: why is multiple graph_rewrites faster than one here? - rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))).where(r-s, UOp.invalid()), sym, name="pad") - for r,sh,(s,e) in zip(rngs, in_shape, arg)) + rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))).where(r-s, UOp.invalid()), + symbolic+pm_simplify_valid, name="pad") for r,sh,(s,e) in zip(rngs, in_shape, arg)) case Ops.RESHAPE: acc = 1 axes_in:list[UOp] = [] @@ -126,7 +126,7 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO axes_out.append(combined_axes % s) combined_axes //= s # this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code - rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic, name="reshape").src + rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid, name="reshape").src case _: raise RuntimeError(f"{op} is not a MovementOp") return rngs diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index d2f0618bff..1f7f9b18e7 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -437,7 +437,7 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp: # try all the valids together (but only the whole expressions) if (s_uop:=uop.substitute(sub_dict:=dict(all_candidates))) is not uop: - uop = s_uop.simplify(tracked=True).substitute({newX:X for X,newX in sub_dict.items()}).simplify(full_symbolic=False) + uop = s_uop.simplify().substitute({newX:X for X,newX in sub_dict.items()}).simplify(full_symbolic=False) # put the loads back in uop = uop.substitute({v:k for k,v in load_subs.items()}) return uop @@ -470,13 +470,16 @@ def reduce_mul_chain(r:UOp): if len(outside) == 0: return None return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside) -# this is symbolic 2.0 -REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP} -REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP} -sym = symbolic_flat+PatternMatcher([ +pm_simplify_valid = PatternMatcher([ # simplify valid (UPat(Ops.AND, name="valid"), simplify_valid), (UPat.var("c").where(UPat.var("x", dtype=dtypes.index), invalid_pat), lambda c,x,i: c.where(uop_given_valid(c, x, try_simplex=False), i)), +]) + +# this is symbolic 2.0 +REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP} +REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP} +sym = symbolic_flat+pm_simplify_valid+PatternMatcher([ # LOAD/STORE -> NOOP (UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]), (UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),