delete rules from sym [pr] (#13339)

This commit is contained in:
chenyu
2025-11-18 14:57:35 -05:00
committed by GitHub
parent 9c59b3d19e
commit 46cb65e692

View File

@@ -2,7 +2,7 @@
import math, operator, struct, functools
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast, Invalid
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_safe_cast, Invalid
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap
from tinygrad.uop.decompositions import xpow
@@ -450,7 +450,7 @@ def simplify_valid(valid:UOp) -> UOp|None:
if ret[-1] is not stmt: something_changed = True
return UOp.prod(*ret) if something_changed else None
# ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ********
# ******** phase 3 is the complete symbolic ********
def reduce_mul_chain(r:UOp):
if r.arg not in {Ops.ADD, Ops.MAX}: return None
@@ -479,6 +479,8 @@ def where_on_load(c1, buf, x):
# aditionally we can drop the clause on the where if it already exists in the load
remaining_clause = UOp.const(dtypes.bool, True).prod(*[c for c in c1.split_uop(Ops.AND) if c not in removed])
return remaining_clause.where(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2))), 0)
# where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer
pm_move_where_on_load = PatternMatcher([
(UPat.var("c1").where(UPat.var("buf").index(UPat.var("x")), 0), where_on_load),
(UPat.var("c1").where(0, UPat.var("buf").index(UPat.var("x"))), lambda c1,buf,x: where_on_load(c1.logical_not(),buf,x)),
@@ -494,9 +496,6 @@ pm_simplify_valid = PatternMatcher([
# this is symbolic 2.0
REMOVE_FROM_SINK_LIKE = {Ops.UNROLL, Ops.NOOP, Ops.VECTORIZE, Ops.SINK}
sym = symbolic+pm_simplify_valid+PatternMatcher([
# LOAD/STORE -> NOOP
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
# VECTORIZE/GEP
(UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat.var("x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
# reorder ALU/VECTORIZE
@@ -525,7 +524,6 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
# fold gated LOAD/STORE
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0
# # Where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer
((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
((UPat.var("x") * UPat.cvar("c")).reciprocal(), lambda x,c: x.reciprocal()*c.reciprocal()), # 1/(x*c) -> (1/c)*(1/x)