Don't simplify gate in gate, fix FUSE_ARANGE=1 python test/test_ops.py TestOps.test_scatter_add (#10411)

* substitute out index

* Add test

* change comment
This commit is contained in:
Sieds Lykles
2025-05-19 19:16:21 +02:00
committed by GitHub
parent 116d9e6306
commit db09676250
2 changed files with 17 additions and 0 deletions

View File

@@ -124,6 +124,18 @@ class TestValidIdxSimplification(unittest.TestCase):
"(((ridx0*2)+(ridx3*-1))+1)",
"(ridx2<1)")
def test_load_in_valid(self):
# Regression test: the valid passed to simplify_valid itself contains a value produced by
# get_gated_load_uop, and simplify_valid is expected to return None for it.
# from FUSE_ARANGE=1 python test/test_ops.py TestOps.test_scatter_add
# can lead to OOB
ridx2 = Range(2, 4)
lidx0 = Special("lidx0", 3)
gidx0 = Special("gidx0", 2)
# index expression and the condition that is used as the gate for the load below
idx=(((lidx0+(gidx0*3))+(ridx2*5))+40)
valid = (lidx0+(gidx0*3)) < 5
# NOTE(review): get_gated_load_uop is defined elsewhere in this file; by its name it builds
# a load at idx gated by valid -- confirm against the helper
val7 = get_gated_load_uop(valid, idx)
# the outer valid combines the same gate with the negation of the loaded value cast to bool
valid2 = valid & val7.cast(dtypes.bool).logical_not()
# presumably None means "nothing was simplified" -- confirm against simplify_valid
self.assertIsNone(simplify_valid(valid2))
def test_valid_becomes_const1(self):
# from DSP mobilenetv2
ridx0 = Range(0, 30)

View File

@@ -324,6 +324,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
except ValueError: return uop # give up if we cannot parse the valid
bounds[expr][int(is_upper)] = c
# don't simplify inside any other gates (can lead to OOB): swap every INDEX for a NOOP placeholder here and substitute them back before returning
uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}))
# simplify uop given that valid is True
for expr,v in bounds.items():
v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
@@ -349,6 +352,8 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
elif all_same(newuops): uop = newuops[0]
# put the loads back in
uop = uop.substitute({v:k for k,v in load_subs.items()})
return uop
def _valid_priority(v: UOp, valids:list[UOp]):