simpler simplify_valid [pr] (#13514)

dedup instead of getting a True clause which is removed later
This commit is contained in:
chenyu
2025-12-01 17:36:33 -05:00
committed by GitHub
parent a5ec3b24be
commit 0b92fd30f5

View File

@@ -3,7 +3,7 @@ import math, operator, struct, functools
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_safe_cast, Invalid
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, dedup
from tinygrad.uop.decompositions import xpow
from tinygrad.uop.divandmod import div_and_mod_symbolic
@@ -320,12 +320,12 @@ def _valid_priority(v: UOp, valids:list[UOp]):
def simplify_valid(valid:UOp) -> UOp|None:
if valid.op_in_backward_slice_with_self(Ops.INDEX): return None # this should only be for indexing, skip if there's a INDEX
ret:list[UOp] = []
something_changed = False
valids = list(valid.split_uop(Ops.AND))
for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
ret.append(uop_given_valid(UOp.prod(*ret), stmt) if ret else stmt)
if ret[-1] is not stmt: something_changed = True
return UOp.prod(*ret) if something_changed else None
valids = sorted(valids, key=lambda v: _valid_priority(v, valids))
for stmt in dedup(valids):
if ret: stmt = uop_given_valid(UOp.prod(*ret), stmt)
ret.append(stmt)
return UOp.prod(*ret) if ret != valids else None
# ******** phase 3 is the complete symbolic ********