diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ad381a6cd0..77d17cc65b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -377,7 +377,7 @@ jobs: llvm: 'true' - name: Test openpilot model kernel count and gate usage run: | - ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2092 ALLOWED_GATED_READ_IMAGE=55 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx + ALLOWED_KERNEL_COUNT=190 ALLOWED_READ_IMAGE=2081 ALLOWED_GATED_READ_IMAGE=28 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx - name: Test openpilot alt model correctness (float32) run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx - name: Test openpilot fastvits model correctness (float32) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 23d42a5349..7af6294c83 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -70,7 +70,8 @@ class TestLinearizer(unittest.TestCase): ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE] # RANGE -> ALU -> RANGE -> ALU + LOAD -> STORE assert any(x.op in GroupOp.ALU for x in uops[ranges[0]:ranges[1]]) - assert not any(x.op is Ops.LOAD for x in uops[ranges[0]:ranges[1]]) + # the index of the load doesn't depend on the second range + assert any(x.op is Ops.LOAD for x in uops[ranges[0]:ranges[1]]) assert any(x.op in {*GroupOp.ALU, Ops.LOAD} for x in uops[ranges[1]:]) def test_range_outer_op_before_phi(self): diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 982babe2b0..4c057a3cf1 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -3,7 +3,7 @@ import functools, operator, itertools from dataclasses 
import dataclass, field from tinygrad.dtype import dtypes, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType -from tinygrad.uop.symbolic import symbolic, pm_simplify_valid +from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses from tinygrad.helpers import argsort, all_same, cpu_profile, TracingKey ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, @@ -112,8 +112,10 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO case Ops.EXPAND: rngs = tuple(a if in_sh == out_sh else a.const_like(0) for a,in_sh,out_sh in zip(rngs, in_shape, arg)) case Ops.PAD: # TODO: why is multiple graph_rewrites faster than one here? - rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))).where(r-s, UOp.invalid()), - symbolic+pm_simplify_valid, name="pad") for r,sh,(s,e) in zip(rngs, in_shape, arg)) + # TODO: the .where(r-s, i) is not inside the graph_rewrite so that `convert_pad_to_where_to_keep_behavior_local` + # wraps the pad with only the newly added valid + rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))), + symbolic+pm_simplify_valid, name="pad").where(r-s, UOp.invalid()) for r,sh,(s,e) in zip(rngs, in_shape, arg)) case Ops.RESHAPE: acc = 1 axes_in:list[UOp] = [] @@ -126,7 +128,8 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO axes_out.append(combined_axes % s) combined_axes //= s # this simplify is doing a lot of heavy lifting. 
this is the replacement for the reshape view merging code - rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid, name="reshape").src + rngs = graph_rewrite(graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid, name="reshape"), + pm_drop_and_clauses, name="reshape drop ands").src case _: raise RuntimeError(f"{op} is not a MovementOp") return rngs diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 9d0d383a0d..3e43d13161 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -470,6 +470,11 @@ def reduce_mul_chain(r:UOp): if len(outside) == 0: return None return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside) +def drop_and_clauses(cond:UOp, x:UOp, i:UOp) -> UOp|None: + if not (dropped_clauses:=[c for c in cond.split_uop(Ops.AND) if not any(r in x.ranges for r in c.ranges)]): return None + return functools.reduce(operator.and_, [c for c in cond.split_uop(Ops.AND) if c not in dropped_clauses], UOp.const(dtypes.bool, True)).where(x, i) +pm_drop_and_clauses = PatternMatcher([(UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), drop_and_clauses)]) + pm_simplify_valid = PatternMatcher([ # simplify valid (UPat(Ops.AND, name="valid"), simplify_valid), @@ -514,8 +519,8 @@ sym = symbolic_flat+pm_simplify_valid+PatternMatcher([ (UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"), lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. 
invalid load produces 0 # # Where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer - (UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat().index(UPat.var("c2").where(UPat(), invalid_pat)).or_casted(),), allow_any_len=True, name="l"), 0), - lambda c1,c2,l,i: l.replace(src=(l.src[0],)+l.src[1:]) if any(c in list(c2.split_uop(Ops.AND)) for c in c1.split_uop(Ops.AND)) else None), + (UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat().index(UPat.var("c2").where(UPat(), invalid_pat)).or_casted(),), name="l"), 0), + lambda c1,c2,l,i: l.replace(src=(l.src[0],)+l.src[1:]) if all(c in list(c2.split_uop(Ops.AND)) for c in c1.split_uop(Ops.AND)) else None), # remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels (UPat(Ops.BARRIER, name="root"), lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)