split_uop is a method (#11984)

This commit is contained in:
George Hotz
2025-09-03 10:46:17 -07:00
committed by GitHub
parent 1877eddde4
commit 55e4bdd353
4 changed files with 28 additions and 26 deletions

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
from dataclasses import dataclass
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
from tinygrad.helpers import getenv, flatten, AMX, prod, partition
from tinygrad.renderer import Renderer
@@ -19,13 +19,13 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
# can drop valid if idx is out of bound when valid is False
drop_stmt = []
for stmt in split_uop(valid, Ops.AND):
for stmt in valid.split_uop(Ops.AND):
try: X, is_upper_bound, c = parse_valid(stmt)
except ValueError: return None
# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx)
if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in X.split_uop(Ops.ADD)):
testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), X.split_uop(Ops.ADD), idx)
testidx = testidx.simplify()
if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
drop_stmt.append(stmt)
@@ -42,7 +42,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
break
if not drop_stmt and idx is start_idx: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
return buf.index(idx, new_valid)
def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:

View File

@@ -7,7 +7,7 @@ from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, unravel
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context, PatternMatcher, UPat, GroupOp
from tinygrad.uop.symbolic import split_uop, symbolic_flat, uop_given_valid, simplify_valid
from tinygrad.uop.symbolic import symbolic_flat, uop_given_valid, simplify_valid
# If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation,
# or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
@@ -43,7 +43,7 @@ def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[
if len(views) == 1 and views[-1].mask is None: return views[-1].strides
ret: list[sint|None] = [None] * len(views[-1].shape)
idx, valid = views_to_indexed_uops(views)
for c in split_uop(idx, Ops.ADD):
for c in idx.split_uop(Ops.ADD):
if c.op is Ops.RANGE: ret[c.arg[0]] = 1
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg

View File

@@ -329,6 +329,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
def overflows(self, dtype:DType) -> bool: return self.vmin < dtype.min or dtype.max < self.vmax
# *** ShapeTracker helpers ***
def split_uop(self:UOp, sep:Ops):
if self.op is sep:
for s in self.src: yield from s.split_uop(sep)
else: yield self
# *** from MultiLazyBuffer ***
def multi(self, axis:int|None):

View File

@@ -93,16 +93,11 @@ symbolic_simple = PatternMatcher([
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
def split_uop(x:UOp, sep:Ops):
if x.op is sep:
for s in x.src: yield from split_uop(s, sep)
else: yield x
def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
# div pattern in unrolled arange
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
seen_const, ans = [], None
for u in split_uop(divs, Ops.ADD):
for u in divs.split_uop(Ops.ADD):
if fac!=1:
if u.op is not Ops.MUL or u.src[1].op is not Ops.CONST or u.src[1].arg != fac: return None
u = u.src[0]
@@ -125,7 +120,7 @@ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
return None
def lt_folding(x:UOp, c:int) -> UOp|None:
p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
p, np = partition(x.split_uop(Ops.ADD), lambda u: u.const_factor() == 1)
if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
return cast(UOp, functools.reduce(operator.add, np).divides(d))<(c//d)
return None
@@ -134,7 +129,7 @@ def canonicalize_simplex(X:UOp) -> UOp|None:
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
# returns x0 + x1 + ... in such case, or None if not
changed, ret = False, []
for u in split_uop(X, Ops.ADD):
for u in X.split_uop(Ops.ADD):
# assumed the const is the last src of MUL
if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
changed = True
@@ -158,7 +153,7 @@ def remove_nested_mod(m: UOp, x: UOp, y: UOp) -> UOp|None:
if ((c := y.arg) < 0) or x.vmin<0: return None
new_xs = []
something_changed = False
for u in split_uop(x, Ops.ADD):
for u in x.split_uop(Ops.ADD):
if u.op is Ops.MOD:
if u.src[1].divides(c) is not None:
something_changed = True
@@ -172,7 +167,7 @@ def fold_binary_numerator(d: UOp, x: UOp, y: UOp) -> UOp|None:
# we can fold if the expression has only one non-constant term and this term can only take on two values
if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
x,const = x.pop_const()
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)])
if len(terms)==1 and (v:=terms[0]).vmax-v.vmin == 1:
y1 = cmod(factors[0]*v.vmin+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmin+const, c) # type: ignore
y2 = cmod(factors[0]*v.vmax+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmax+const, c) # type: ignore
@@ -183,7 +178,7 @@ def fold_divmod_congruence(d: UOp, x: UOp, y: UOp) -> UOp|None:
# within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
if (x.vmin<0 and CORRECT_DIVMOD_FOLDING) or ((c := y.arg) < 0) or (x.dtype.count > 1): return None
x,const = x.pop_const()
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)])
# a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
rems = [min((r:=f%c), r-c, key=abs) for f in factors]
if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c!=rem.vmax//c: return None
@@ -192,7 +187,7 @@ def fold_divmod_congruence(d: UOp, x: UOp, y: UOp) -> UOp|None:
def divide_by_gcd(d: UOp, x: UOp, y: UOp) -> UOp|None:
# x//y -> (x//gcd)//(y//gcd) or x%y -> gcd*(x//gcd)%(y//gcd)
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)])
if (gcd := math.gcd(y.arg, *factors)) == 1: return None
ret = sum(f//gcd * v for f,v in zip(factors, terms)).alu(d.op, y.const_like(y.arg//gcd))
return ret*gcd if d.op is Ops.MOD else ret
@@ -200,7 +195,7 @@ def divide_by_gcd(d: UOp, x: UOp, y: UOp) -> UOp|None:
def nest_div_by_smallest_factor(d: UOp, x: UOp, y: UOp) -> UOp|None:
# we try and nest the div and see if it allows the numerator to be simplified
if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
factors = [u.const_factor() for u in split_uop(x.pop_const()[0], Ops.ADD)]
factors = [u.const_factor() for u in x.pop_const()[0].split_uop(Ops.ADD)]
# div is the smallest factor of the denominator (greater than 1) out of all "factors"
# TODO: there are better ways to pick `div`, this sometimes adds extra divisions
# TODO: add same optimization for mod
@@ -212,7 +207,7 @@ def simplify_remainder(d: UOp, x: UOp, y: UOp) -> UOp|None:
# we try and take out the quotient and see if it allows the numerator to be simplified
if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
x_no_const,const = x.pop_const()
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x_no_const, Ops.ADD)])
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x_no_const.split_uop(Ops.ADD)])
quotients, remainders = zip(*[divmod(f, c) for f in factors])
gcd = math.gcd(c, *remainders) # gcd without const!
if const%c==const and gcd==1 and not any(r==0 or (r!=f and d.op is Ops.MOD) for r,f in zip(remainders, factors)): return None
@@ -385,7 +380,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
# first, parse valid into {expr: (lower_bound, upper_bound)}
bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
for stmt in split_uop(valid, Ops.AND):
for stmt in valid.split_uop(Ops.AND):
try: expr, is_upper, c = parse_valid(stmt)
except ValueError: return uop # give up if we cannot parse the valid
bounds[expr][int(is_upper)] = c
@@ -404,9 +399,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
continue
# every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
candidates = []
if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)])
# try checking the whole clause
if expr in uop.toposort(): candidates.append([(expr, UOp.variable("fake", v0, v1, expr.dtype))])
@@ -430,7 +425,7 @@ def _valid_priority(v: UOp, valids:list[UOp]):
def simplify_valid(valid:UOp) -> UOp|None:
ret:list[UOp] = []
something_changed = False
valids = list(split_uop(valid, Ops.AND))
valids = list(valid.split_uop(Ops.AND))
for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
# TODO: root cause this and test_simplify_valid_from_div
if stmt.op is Ops.CAST: return None
@@ -444,7 +439,7 @@ def reduce_mul_chain(r:UOp):
if r.arg not in {Ops.ADD, Ops.MAX}: return None
if r.dtype != r.src[0].dtype: return None
inside, outside = [], []
for m in split_uop(r.src[0], Ops.MUL):
for m in r.src[0].split_uop(Ops.MUL):
m_parents = m.toposort()
if all(r not in m_parents for r in r.src[1:]) and (r.arg != Ops.MAX or m.vmin >= 0): outside.append(m)
else: inside.append(m)