mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
simpler simplify_valid [pr] (#13514)
dedup instead of getting a True clause which is removed later
This commit is contained in:
@@ -3,7 +3,7 @@ import math, operator, struct, functools
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_safe_cast, Invalid
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, dedup
|
||||
from tinygrad.uop.decompositions import xpow
|
||||
from tinygrad.uop.divandmod import div_and_mod_symbolic
|
||||
|
||||
@@ -320,12 +320,12 @@ def _valid_priority(v: UOp, valids:list[UOp]):
|
||||
def simplify_valid(valid:UOp) -> UOp|None:
|
||||
if valid.op_in_backward_slice_with_self(Ops.INDEX): return None # this should only be for indexing, skip if there's a INDEX
|
||||
ret:list[UOp] = []
|
||||
something_changed = False
|
||||
valids = list(valid.split_uop(Ops.AND))
|
||||
for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
|
||||
ret.append(uop_given_valid(UOp.prod(*ret), stmt) if ret else stmt)
|
||||
if ret[-1] is not stmt: something_changed = True
|
||||
return UOp.prod(*ret) if something_changed else None
|
||||
valids = sorted(valids, key=lambda v: _valid_priority(v, valids))
|
||||
for stmt in dedup(valids):
|
||||
if ret: stmt = uop_given_valid(UOp.prod(*ret), stmt)
|
||||
ret.append(stmt)
|
||||
return UOp.prod(*ret) if ret != valids else None
|
||||
|
||||
# ******** phase 3 is the complete symbolic ********
|
||||
|
||||
|
||||
Reference in New Issue
Block a user