fix cifar while keeping openpilot fused (#13528)

* this works

* test now passes
This commit is contained in:
Roelof van Dijk
2025-12-02 21:05:56 +01:00
committed by GitHub
parent 0874ba8cc8
commit e329baffa7
2 changed files with 7 additions and 4 deletions

View File

@@ -96,7 +96,6 @@ class TestLinearizer(unittest.TestCase):
ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
assert len(ranges) == 1 # NOTE: it collapses now
@unittest.expectedFailure # TODO: investigate
def test_two_nested_range_alt_indexing(self):
a = Tensor([2, 2]).realize()
out = a.reshape(2, 1).pad(((1, 1), (1, 1)), value=2).sum()

View File

@@ -3,7 +3,7 @@ import math, operator, struct, functools
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_safe_cast, Invalid
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, dedup
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup
from tinygrad.uop.decompositions import xpow
from tinygrad.uop.divandmod import div_and_mod_symbolic
@@ -363,11 +363,15 @@ pm_move_where_on_load = PatternMatcher([
(UPat.var("c1").where(0, UPat.var("buf").index(UPat.var("x"))), lambda c1,buf,x: where_on_load(c1.logical_not(),buf,x)),
])
def gated_given_valid(cond:UOp, x:UOp, i:UOp) -> UOp|None:
# Skip if x contains DIV/MOD AND IMAGE mode is enabled -> image index e.g. openpilot
if IMAGE.value > 0 and x.op_in_backward_slice_with_self(Ops.IDIV, Ops.MOD): return None
return cond.where(uop_given_valid(cond, x, try_simplex=False), i)
pm_simplify_valid = PatternMatcher([
# simplify valid
(UPat(Ops.AND, name="valid"), simplify_valid),
# TODO: this regressed openpilot, not having this regressed cifar
# (invalid_gate, lambda cond,x,i: cond.where(uop_given_valid(cond, x, try_simplex=False), i)),
(invalid_gate, gated_given_valid),
])
# this is symbolic 2.0