mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
no load substitute in uop_given_valid [pr] (#13333)
This commit is contained in:
@@ -407,14 +407,10 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
|
||||
expr, is_upper, c = res
|
||||
bounds[expr][int(is_upper)] = c
|
||||
|
||||
# don't simplify any other gates, can lead to OOB, we substitute them back later
|
||||
uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, dtype=u.dtype, arg=u) for u in uop.toposort() if u.op is Ops.INDEX}))
|
||||
|
||||
# simplify uop given that valid is True
|
||||
all_candidates = []
|
||||
for i,(expr,v) in enumerate(bounds.items()):
|
||||
v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
|
||||
expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop
|
||||
# try checking the whole clause
|
||||
all_candidates.append((expr, UOp.variable(f"fake{i}", v0, v1, expr.dtype)))
|
||||
|
||||
@@ -438,8 +434,6 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
|
||||
# try all the valids together (but only the whole expressions)
|
||||
if (s_uop:=uop.substitute(sub_dict:=dict(all_candidates))) is not uop:
|
||||
uop = s_uop.simplify().substitute({newX:X for X,newX in sub_dict.items()}).simplify(full_symbolic=False)
|
||||
# put the loads back in
|
||||
uop = uop.substitute({v:k for k,v in load_subs.items()})
|
||||
return uop
|
||||
|
||||
def _valid_priority(v: UOp, valids:list[UOp]):
|
||||
|
||||
Reference in New Issue
Block a user