mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
delete rules from sym [pr] (#13339)
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
import math, operator, struct, functools
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast, Invalid
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_safe_cast, Invalid
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap
|
||||
from tinygrad.uop.decompositions import xpow
|
||||
|
||||
@@ -450,7 +450,7 @@ def simplify_valid(valid:UOp) -> UOp|None:
|
||||
if ret[-1] is not stmt: something_changed = True
|
||||
return UOp.prod(*ret) if something_changed else None
|
||||
|
||||
# ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ********
|
||||
# ******** phase 3 is the complete symbolic ********
|
||||
|
||||
def reduce_mul_chain(r:UOp):
|
||||
if r.arg not in {Ops.ADD, Ops.MAX}: return None
|
||||
@@ -479,6 +479,8 @@ def where_on_load(c1, buf, x):
|
||||
# aditionally we can drop the clause on the where if it already exists in the load
|
||||
remaining_clause = UOp.const(dtypes.bool, True).prod(*[c for c in c1.split_uop(Ops.AND) if c not in removed])
|
||||
return remaining_clause.where(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2))), 0)
|
||||
|
||||
# where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer
|
||||
pm_move_where_on_load = PatternMatcher([
|
||||
(UPat.var("c1").where(UPat.var("buf").index(UPat.var("x")), 0), where_on_load),
|
||||
(UPat.var("c1").where(0, UPat.var("buf").index(UPat.var("x"))), lambda c1,buf,x: where_on_load(c1.logical_not(),buf,x)),
|
||||
@@ -494,9 +496,6 @@ pm_simplify_valid = PatternMatcher([
|
||||
# this is symbolic 2.0
|
||||
REMOVE_FROM_SINK_LIKE = {Ops.UNROLL, Ops.NOOP, Ops.VECTORIZE, Ops.SINK}
|
||||
sym = symbolic+pm_simplify_valid+PatternMatcher([
|
||||
# LOAD/STORE -> NOOP
|
||||
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
|
||||
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
|
||||
# VECTORIZE/GEP
|
||||
(UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat.var("x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
|
||||
# reorder ALU/VECTORIZE
|
||||
@@ -525,7 +524,6 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
|
||||
# fold gated LOAD/STORE
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0
|
||||
# # Where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer
|
||||
((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
|
||||
((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
|
||||
((UPat.var("x") * UPat.cvar("c")).reciprocal(), lambda x,c: x.reciprocal()*c.reciprocal()), # 1/(x*c) -> (1/c)*(1/x)
|
||||
|
||||
Reference in New Issue
Block a user