update uop_given_valid if a node becomes const (#9604)

* update uop_given_valid if a node becomes const

* cleanup
This commit is contained in:
chenyu
2025-03-27 14:57:46 -04:00
committed by GitHub
parent a187dfd3df
commit 5358b0904b
2 changed files with 22 additions and 5 deletions

View File

@@ -124,6 +124,20 @@ class TestValidIdxSimplification(unittest.TestCase):
"(((ridx0*2)+(ridx3*-1))+1)",
"(ridx2<1)")
def test_valid_becomes_const1(self):
# from DSP mobilenetv2
ridx0 = Range(0, 30)
ridx1 = Range(1, 7)
ridx2 = Range(2, 2)
alu11 = (ridx1+ridx2)
alu15 = ((alu11+1)//7)
idx = (alu15*-31)+(((((alu11+218)//224)+ridx0)%30)*1568)
valid = (ridx2<1)&(ridx1<6)
load = get_gated_load_uop(valid, idx)
self.check(load,
"(ridx0*1568)",
"((ridx2<1)&(ridx1<6))")
class TestImageSimplification(unittest.TestCase):
def check(self, load, svalid, sidx0, sidx1):
load = full_graph_rewrite(load.sink()).src[0]

View File

@@ -275,17 +275,20 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
# simplify uop given that valid is True
for expr,v in bounds.items():
v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
# some expr has lower bound > upper bound -> valid is an empty set and we return None
if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
if v0 > v1: return None
# whole node became a const
if v0 == v1:
uop = uop.substitute({expr:expr.const_like(v0)}).simplify()
continue
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
candidates = []
if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
# try checking the whole clause
if expr in uop.toposort:
candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
if expr in uop.toposort: candidates.append([(expr, UOp.variable("fake", v0, v1, expr.dtype))])
for candidate in candidates:
# if every branch in candidate gives the same simplified uop, we can rewrite the uop