mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Simplify valid in symbolic (#12104)
* cleanup cast_folding * from sym to symbolic * no more sym in dtype lowering * move around simplify_valid * update test
This commit is contained in:
@@ -154,7 +154,7 @@ class TestRealStrides(unittest.TestCase):
|
||||
View.create((1, 3, 22, 21), (0, 192, 16, 1), 0, ((0, 1), (0, 3), (0, 12), (0, 16))),
|
||||
View.create((3, 11, 7, 2, 3), (462, 21, 1, 231, 7), 0, None),
|
||||
))
|
||||
self.assertEqual(st.real_strides(), (132, None, None, None, None))
|
||||
self.assertEqual(st.real_strides(), (132, 12, None, None, None))
|
||||
|
||||
class TestRealSimplifies(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
|
||||
@@ -271,6 +271,91 @@ gep_pushing = PatternMatcher([
|
||||
(UPat(Ops.WMMA, name="wmma").f(Ops.GEP, name="gep"), gep_through_wmma),
|
||||
])
|
||||
|
||||
# ******** we take a small aside to "simplify_valid" to rewrite "and" clauses (valids) ********
|
||||
|
||||
def parse_valid(valid:UOp) -> tuple[UOp, bool, int]:
  # split one integer comparison into (X, is_upper, c):
  #   X <= c  ->  (X, True, c)
  #   X >= c  ->  (X, False, c)
  # raises ValueError when valid matches neither of the two shapes handled below

  # (X < c).ne(True) is X >= c, the lower bound returned is c.vmin
  if valid.op is Ops.CMPNE:
    negated, flag = valid.src[0], valid.src[1]
    if flag.op is Ops.CONST and flag.arg == 1 and negated.op is Ops.CMPLT and dtypes.is_int(negated.src[0].dtype):
      return negated.src[0], False, int(negated.src[1].vmin)
  # X < c is X <= c-1 on integers, the upper bound returned is c.vmax-1
  if valid.op is Ops.CMPLT and dtypes.is_int(valid.src[0].dtype):
    return valid.src[0], True, int(valid.src[1].vmax)-1
  raise ValueError(f"not able to parse {valid=}")
|
||||
|
||||
def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
  # simplify uop under the assumption that valid is True
  # returns None when the bounds in valid contradict each other (valid can never be True),
  # otherwise the simplified uop, which may be the input unchanged

  # step 1: collect {expr: [lower_bound, upper_bound]} from the AND clauses, None = no bound seen
  limits:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
  for clause in valid.split_uop(Ops.AND):
    try: expr, is_upper, bound = parse_valid(clause)
    except ValueError: return uop  # one unparseable clause and we give up, returning uop untouched
    limits[expr][int(is_upper)] = bound

  # step 2: hide every INDEX behind a NOOP placeholder so other gates are not simplified (can lead to OOB)
  # the placeholders are swapped back before returning
  load_subs = {u: UOp(Ops.NOOP, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}
  uop = uop.substitute(load_subs)

  # step 3: apply each bounded expression in turn
  for expr, (lo, hi) in limits.items():
    # a missing side falls back to the expression's own range
    if lo is None: lo = expr.vmin
    if hi is None: hi = expr.vmax
    expr = expr.substitute(load_subs)  # expr must have the same form it has inside uop
    # lower bound above upper bound: nothing satisfies valid
    if lo > hi: return None
    # the bounds pin expr to a single value, replace it with that constant
    if lo == hi:
      uop = uop.substitute({expr:expr.const_like(lo)}).simplify()
      continue

    # each candidate is a list of (original, constrained stand-in) pairs
    # uop is rewritten only if every pair of a candidate simplifies to the same result
    candidates = []
    if expr.op is Ops.ADD and lo == 1 and all(term.op in GroupOp.Irreducible for term in expr.split_uop(Ops.ADD)):
      # simplex constraint X0 + X1 + ... > 0: try each Xi > 0 separately
      candidates.append([(term, UOp.variable("fake", 1, term.vmax, term.dtype)) for term in expr.split_uop(Ops.ADD)])
    # the whole clause as a single bounded variable
    if expr in uop.toposort(): candidates.append([(expr, UOp.variable("fake", lo, hi, expr.dtype))])

    for candidate in candidates:
      # substitute the stand-in, simplify, substitute the original back, simplify again
      results = [uop.substitute({orig:fake}).simplify().substitute({fake:orig}).simplify() for orig,fake in candidate]
      if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
        # a two-element VECTORIZE has each source accepted independently
        if all_same([res.src[0] for res in results]): uop = uop.replace(src=(results[0].src[0], uop.src[1]))
        if all_same([res.src[1] for res in results]): uop = uop.replace(src=(uop.src[0], results[0].src[1]))
      elif all_same(results): uop = results[0]

  # step 4: put the hidden INDEX uops back
  return uop.substitute({placeholder:original for original,placeholder in load_subs.items()})
|
||||
|
||||
def _valid_priority(v: UOp, valids:list[UOp]):
  # sort key for simplify_valid: minus the number of entries in valids whose graph contains v's bounded expression
  # a valid that appears in other valids' parents sorts first, so the others are more likely to get simplified
  # returns 0 when v cannot be parsed
  try:
    # parse once: the bounded expression does not depend on which other valid we compare against
    expr = parse_valid(v)[0]
    return -sum(1 for other in valids if expr in other.toposort())
  except ValueError: return 0
|
||||
|
||||
def simplify_valid(valid:UOp) -> UOp|None:
  # rewrite an AND chain by simplifying each clause given the clauses already kept
  # returns the rebuilt AND chain, or None when no clause changed
  clauses = list(valid.split_uop(Ops.AND))
  kept:list[UOp] = []
  changed = False
  for clause in sorted(clauses, key=lambda v: _valid_priority(v, clauses)):
    # TODO: root cause this and test_simplify_valid_from_div
    if clause.op is Ops.CAST: return None
    rewritten = clause
    if kept:
      # the first clause has nothing to be simplified against, later ones use the AND of everything kept so far
      simplified = uop_given_valid(functools.reduce(operator.and_, kept), clause)
      # None from uop_given_valid keeps the clause as it was
      if simplified is not None: rewritten = simplified
    kept.append(rewritten)
    if rewritten is not clause: changed = True
  return functools.reduce(operator.and_, kept) if changed else None
|
||||
|
||||
# ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ********
|
||||
|
||||
def reduce_mul_chain(r:UOp):
  # move factors of r.src[0]'s MUL chain that do not depend on any of r.src[1:] outside of r
  # only applies when r.arg is ADD or MAX and r has the same dtype as r.src[0]
  # returns None when it does not apply or no factor can be moved
  if r.arg not in {Ops.ADD, Ops.MAX} or r.dtype != r.src[0].dtype: return None
  rest = r.src[1:]
  hoisted, kept = [], []
  for factor in r.src[0].split_uop(Ops.MUL):
    parents = factor.toposort()
    independent = all(dep not in parents for dep in rest)
    # for MAX a factor is only moved out when it is known to be non-negative
    if independent and (r.arg != Ops.MAX or factor.vmin >= 0): hoisted.append(factor)
    else: kept.append(factor)
  if not hoisted: return None
  # if every factor was moved out, a constant 1 is left inside
  inner = prod(kept) if kept else r.src[0].const_like(1)
  return r.replace(src=(inner,)+rest)*prod(hoisted)
|
||||
|
||||
commutative = PatternMatcher([
|
||||
# ** COMMUTATIVE flipping (only for index) **
|
||||
# NOTE: this can break merging vector math by only flipping some of them
|
||||
@@ -282,6 +367,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
|
||||
# TODO: make a more general or folder like simplify_valid
|
||||
(UPat.var("x", dtype=dtypes.bool) | UPat.var("x").logical_not(), lambda x: x.const_like(True)), # x|!x -> True
|
||||
# simplify valid
|
||||
(UPat(Ops.AND, name="valid"), simplify_valid),
|
||||
# ** combine terms **
|
||||
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
|
||||
((UPat.var("y") + UPat.var("x") * UPat.cvar("c0")) + UPat.var("x") * UPat.cvar("c1"), lambda x,y,c0,c1: y+x*(c0+c1)),
|
||||
@@ -375,97 +462,10 @@ symbolic_flat = symbolic+PatternMatcher([
|
||||
((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
|
||||
])
|
||||
|
||||
# ******** we take a small aside to "simplify_valid" to rewrite valids ********
|
||||
|
||||
def parse_valid(valid:UOp) -> tuple[UOp, bool, int]:
  # split one integer comparison into (X, is_upper, c):
  #   X <= c  ->  (X, True, c)
  #   X >= c  ->  (X, False, c)
  # raises ValueError when valid matches neither of the two shapes handled below

  # (X < c).ne(True) is X >= c, the lower bound returned is c.vmin
  if valid.op is Ops.CMPNE:
    negated, flag = valid.src[0], valid.src[1]
    if flag.op is Ops.CONST and flag.arg == 1 and negated.op is Ops.CMPLT and dtypes.is_int(negated.src[0].dtype):
      return negated.src[0], False, int(negated.src[1].vmin)
  # X < c is X <= c-1 on integers, the upper bound returned is c.vmax-1
  if valid.op is Ops.CMPLT and dtypes.is_int(valid.src[0].dtype):
    return valid.src[0], True, int(valid.src[1].vmax)-1
  raise ValueError(f"not able to parse {valid=}")
|
||||
|
||||
def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
  # simplify uop under the assumption that valid is True
  # returns None when the bounds in valid contradict each other (valid can never be True),
  # otherwise the simplified uop, which may be the input unchanged

  # step 1: collect {expr: [lower_bound, upper_bound]} from the AND clauses, None = no bound seen
  limits:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
  for clause in valid.split_uop(Ops.AND):
    try: expr, is_upper, bound = parse_valid(clause)
    except ValueError: return uop  # one unparseable clause and we give up, returning uop untouched
    limits[expr][int(is_upper)] = bound

  # step 2: hide every INDEX behind a NOOP placeholder so other gates are not simplified (can lead to OOB)
  # the placeholders are swapped back before returning
  load_subs = {u: UOp(Ops.NOOP, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}
  uop = uop.substitute(load_subs)

  # step 3: apply each bounded expression in turn
  for expr, (lo, hi) in limits.items():
    # a missing side falls back to the expression's own range
    if lo is None: lo = expr.vmin
    if hi is None: hi = expr.vmax
    expr = expr.substitute(load_subs)  # expr must have the same form it has inside uop
    # lower bound above upper bound: nothing satisfies valid
    if lo > hi: return None
    # the bounds pin expr to a single value, replace it with that constant
    if lo == hi:
      uop = uop.substitute({expr:expr.const_like(lo)}).simplify()
      continue

    # each candidate is a list of (original, constrained stand-in) pairs
    # uop is rewritten only if every pair of a candidate simplifies to the same result
    candidates = []
    if expr.op is Ops.ADD and lo == 1 and all(term.op in GroupOp.Irreducible for term in expr.split_uop(Ops.ADD)):
      # simplex constraint X0 + X1 + ... > 0: try each Xi > 0 separately
      candidates.append([(term, UOp.variable("fake", 1, term.vmax, term.dtype)) for term in expr.split_uop(Ops.ADD)])
    # the whole clause as a single bounded variable
    if expr in uop.toposort(): candidates.append([(expr, UOp.variable("fake", lo, hi, expr.dtype))])

    for candidate in candidates:
      # substitute the stand-in, simplify, substitute the original back, simplify again
      results = [uop.substitute({orig:fake}).simplify().substitute({fake:orig}).simplify() for orig,fake in candidate]
      if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
        # a two-element VECTORIZE has each source accepted independently
        if all_same([res.src[0] for res in results]): uop = uop.replace(src=(results[0].src[0], uop.src[1]))
        if all_same([res.src[1] for res in results]): uop = uop.replace(src=(uop.src[0], results[0].src[1]))
      elif all_same(results): uop = results[0]

  # step 4: put the hidden INDEX uops back
  return uop.substitute({placeholder:original for original,placeholder in load_subs.items()})
|
||||
|
||||
def _valid_priority(v: UOp, valids:list[UOp]):
  # sort key for simplify_valid: minus the number of entries in valids whose graph contains v's bounded expression
  # a valid that appears in other valids' parents sorts first, so the others are more likely to get simplified
  # returns 0 when v cannot be parsed
  try:
    # parse once: the bounded expression does not depend on which other valid we compare against
    expr = parse_valid(v)[0]
    return -sum(1 for other in valids if expr in other.toposort())
  except ValueError: return 0
|
||||
|
||||
def simplify_valid(valid:UOp) -> UOp|None:
  # rewrite an AND chain by simplifying each clause given the clauses already kept
  # returns the rebuilt AND chain, or None when no clause changed
  clauses = list(valid.split_uop(Ops.AND))
  kept:list[UOp] = []
  changed = False
  for clause in sorted(clauses, key=lambda v: _valid_priority(v, clauses)):
    # TODO: root cause this and test_simplify_valid_from_div
    if clause.op is Ops.CAST: return None
    rewritten = clause
    if kept:
      # the first clause has nothing to be simplified against, later ones use the AND of everything kept so far
      simplified = uop_given_valid(functools.reduce(operator.and_, kept), clause)
      # None from uop_given_valid keeps the clause as it was
      if simplified is not None: rewritten = simplified
    kept.append(rewritten)
    if rewritten is not clause: changed = True
  return functools.reduce(operator.and_, kept) if changed else None
|
||||
|
||||
# ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ********
|
||||
|
||||
def reduce_mul_chain(r:UOp):
  # move factors of r.src[0]'s MUL chain that do not depend on any of r.src[1:] outside of r
  # only applies when r.arg is ADD or MAX and r has the same dtype as r.src[0]
  # returns None when it does not apply or no factor can be moved
  if r.arg not in {Ops.ADD, Ops.MAX} or r.dtype != r.src[0].dtype: return None
  rest = r.src[1:]
  hoisted, kept = [], []
  for factor in r.src[0].split_uop(Ops.MUL):
    parents = factor.toposort()
    independent = all(dep not in parents for dep in rest)
    # for MAX a factor is only moved out when it is known to be non-negative
    if independent and (r.arg != Ops.MAX or factor.vmin >= 0): hoisted.append(factor)
    else: kept.append(factor)
  if not hoisted: return None
  # if every factor was moved out, a constant 1 is left inside
  inner = prod(kept) if kept else r.src[0].const_like(1)
  return r.replace(src=(inner,)+rest)*prod(hoisted)
|
||||
|
||||
# this is symbolic 2.0
|
||||
# NOTE(review): presumably the op kinds dropped from a SINK's sources; the code that reads this set is outside this view, confirm there
REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP}
|
||||
# NOTE(review): presumably the op kinds dropped from a BARRIER's sources; the code that reads this set is outside this view, confirm there
# compared to REMOVE_FROM_SINK this set has VECTORIZE and does not have UNROLL
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP}
|
||||
sym = symbolic_flat+PatternMatcher([
|
||||
# simplify valid
|
||||
(UPat(Ops.AND, name="valid"), simplify_valid),
|
||||
# LOAD/STORE -> NOOP
|
||||
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
|
||||
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
|
||||
|
||||
Reference in New Issue
Block a user