mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix cifar while keeping openpilot fused (#13528)
* this works * test now passes
This commit is contained in:
@@ -96,7 +96,6 @@ class TestLinearizer(unittest.TestCase):
|
||||
ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
|
||||
assert len(ranges) == 1 # NOTE: it collapses now
|
||||
|
||||
@unittest.expectedFailure # TODO: investigate
|
||||
def test_two_nested_range_alt_indexing(self):
|
||||
a = Tensor([2, 2]).realize()
|
||||
out = a.reshape(2, 1).pad(((1, 1), (1, 1)), value=2).sum()
|
||||
|
||||
@@ -3,7 +3,7 @@ import math, operator, struct, functools
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_safe_cast, Invalid
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, dedup
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup
|
||||
from tinygrad.uop.decompositions import xpow
|
||||
from tinygrad.uop.divandmod import div_and_mod_symbolic
|
||||
|
||||
@@ -363,11 +363,15 @@ pm_move_where_on_load = PatternMatcher([
|
||||
(UPat.var("c1").where(0, UPat.var("buf").index(UPat.var("x"))), lambda c1,buf,x: where_on_load(c1.logical_not(),buf,x)),
|
||||
])
|
||||
|
||||
def gated_given_valid(cond:UOp, x:UOp, i:UOp) -> UOp|None:
|
||||
# Skip if x contains DIV/MOD AND IMAGE mode is enabled -> image index e.g. openpilot
|
||||
if IMAGE.value > 0 and x.op_in_backward_slice_with_self(Ops.IDIV, Ops.MOD): return None
|
||||
return cond.where(uop_given_valid(cond, x, try_simplex=False), i)
|
||||
|
||||
pm_simplify_valid = PatternMatcher([
|
||||
# simplify valid
|
||||
(UPat(Ops.AND, name="valid"), simplify_valid),
|
||||
# TODO: this regressed openpilot, not having this regressed cifar
|
||||
# (invalid_gate, lambda cond,x,i: cond.where(uop_given_valid(cond, x, try_simplex=False), i)),
|
||||
(invalid_gate, gated_given_valid),
|
||||
])
|
||||
|
||||
# this is symbolic 2.0
|
||||
|
||||
Reference in New Issue
Block a user