mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update uop_given_valid if a node becomes const (#9604)
* update uop_given_valid if a node becomes const * cleanup
This commit is contained in:
@@ -124,6 +124,20 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
"(((ridx0*2)+(ridx3*-1))+1)",
|
||||
"(ridx2<1)")
|
||||
|
||||
def test_valid_becomes_const1(self):
|
||||
# from DSP mobilenetv2
|
||||
ridx0 = Range(0, 30)
|
||||
ridx1 = Range(1, 7)
|
||||
ridx2 = Range(2, 2)
|
||||
alu11 = (ridx1+ridx2)
|
||||
alu15 = ((alu11+1)//7)
|
||||
idx = (alu15*-31)+(((((alu11+218)//224)+ridx0)%30)*1568)
|
||||
valid = (ridx2<1)&(ridx1<6)
|
||||
load = get_gated_load_uop(valid, idx)
|
||||
self.check(load,
|
||||
"(ridx0*1568)",
|
||||
"((ridx2<1)&(ridx1<6))")
|
||||
|
||||
class TestImageSimplification(unittest.TestCase):
|
||||
def check(self, load, svalid, sidx0, sidx1):
|
||||
load = full_graph_rewrite(load.sink()).src[0]
|
||||
|
||||
@@ -275,17 +275,20 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
|
||||
|
||||
# simplify uop given that valid is True
|
||||
for expr,v in bounds.items():
|
||||
v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
|
||||
# some expr has lower bound > upper bound -> valid is an empty set and we return None
|
||||
if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
|
||||
|
||||
if v0 > v1: return None
|
||||
# whole node became a const
|
||||
if v0 == v1:
|
||||
uop = uop.substitute({expr:expr.const_like(v0)}).simplify()
|
||||
continue
|
||||
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
|
||||
candidates = []
|
||||
if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
|
||||
if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
|
||||
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
|
||||
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
|
||||
# try checking the whole clause
|
||||
if expr in uop.toposort:
|
||||
candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
|
||||
if expr in uop.toposort: candidates.append([(expr, UOp.variable("fake", v0, v1, expr.dtype))])
|
||||
|
||||
for candidate in candidates:
|
||||
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
|
||||
|
||||
Reference in New Issue
Block a user