mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
split_uop is a method (#11984)
This commit is contained in:
@@ -4,7 +4,7 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace
|
||||
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
|
||||
from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
|
||||
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
|
||||
from tinygrad.helpers import getenv, flatten, AMX, prod, partition
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
@@ -19,13 +19,13 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
|
||||
|
||||
# can drop valid if idx is out of bound when valid is False
|
||||
drop_stmt = []
|
||||
for stmt in split_uop(valid, Ops.AND):
|
||||
for stmt in valid.split_uop(Ops.AND):
|
||||
try: X, is_upper_bound, c = parse_valid(stmt)
|
||||
except ValueError: return None
|
||||
|
||||
# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
|
||||
if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
|
||||
testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx)
|
||||
if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in X.split_uop(Ops.ADD)):
|
||||
testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), X.split_uop(Ops.ADD), idx)
|
||||
testidx = testidx.simplify()
|
||||
if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
|
||||
drop_stmt.append(stmt)
|
||||
@@ -42,7 +42,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
|
||||
break
|
||||
|
||||
if not drop_stmt and idx is start_idx: return None
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
|
||||
return buf.index(idx, new_valid)
|
||||
|
||||
def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.helpers import merge_dicts, getenv
|
||||
from tinygrad.shape.view import View, unravel
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context, PatternMatcher, UPat, GroupOp
|
||||
from tinygrad.uop.symbolic import split_uop, symbolic_flat, uop_given_valid, simplify_valid
|
||||
from tinygrad.uop.symbolic import symbolic_flat, uop_given_valid, simplify_valid
|
||||
|
||||
# If a node overflows, its srcs need to be checked to see if this overflow is the result of an ALU operation,
|
||||
# or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
|
||||
@@ -43,7 +43,7 @@ def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[
|
||||
if len(views) == 1 and views[-1].mask is None: return views[-1].strides
|
||||
ret: list[sint|None] = [None] * len(views[-1].shape)
|
||||
idx, valid = views_to_indexed_uops(views)
|
||||
for c in split_uop(idx, Ops.ADD):
|
||||
for c in idx.split_uop(Ops.ADD):
|
||||
if c.op is Ops.RANGE: ret[c.arg[0]] = 1
|
||||
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
|
||||
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg
|
||||
|
||||
@@ -329,6 +329,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
|
||||
def overflows(self, dtype:DType) -> bool: return self.vmin < dtype.min or dtype.max < self.vmax
|
||||
|
||||
# *** ShapeTracker helpers ***
|
||||
|
||||
def split_uop(self:UOp, sep:Ops):
  """Yield the terms of a chain of `sep` operations, left to right.

  A node whose `op` is `sep` is expanded into its `src` operands; any other node is
  yielded unchanged, so calling this on a node that is not a `sep` yields only that node.

  The walk keeps an explicit stack instead of delegating with a recursive `yield from`:
  chain depth is then not bounded by the interpreter recursion limit, and each term is
  produced once rather than being passed up through one generator per nesting level.
  """
  pending = [self]
  while pending:
    node = pending.pop()
    # operands are pushed reversed so the leftmost one is popped (and emitted) first
    if node.op is sep: pending.extend(reversed(node.src))
    else: yield node
|
||||
|
||||
# *** from MultiLazyBuffer ***
|
||||
|
||||
def multi(self, axis:int|None):
|
||||
|
||||
@@ -93,16 +93,11 @@ symbolic_simple = PatternMatcher([
|
||||
|
||||
# ******** phase 2 builds on phase 1, it includes the old "symbolic", rules that match deeper ********
|
||||
|
||||
def split_uop(x:UOp, sep:Ops):
  """Yield the terms of a chain of `sep` operations rooted at `x`, left to right.

  A node whose `op` is `sep` is expanded into its `src` operands; any other node is
  yielded unchanged, so a root that is not a `sep` yields only itself.

  An explicit stack replaces the recursive `yield from`, so deep chains do not hit the
  interpreter recursion limit and each term is produced once instead of being re-yielded
  through one generator per nesting level.
  """
  pending = [x]
  while pending:
    node = pending.pop()
    # operands are pushed reversed so the leftmost one is popped (and emitted) first
    if node.op is sep: pending.extend(reversed(node.src))
    else: yield node
|
||||
|
||||
def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
|
||||
# div pattern in unrolled arange
|
||||
# example: x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
|
||||
seen_const, ans = [], None
|
||||
for u in split_uop(divs, Ops.ADD):
|
||||
for u in divs.split_uop(Ops.ADD):
|
||||
if fac!=1:
|
||||
if u.op is not Ops.MUL or u.src[1].op is not Ops.CONST or u.src[1].arg != fac: return None
|
||||
u = u.src[0]
|
||||
@@ -125,7 +120,7 @@ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
|
||||
return None
|
||||
|
||||
def lt_folding(x:UOp, c:int) -> UOp|None:
|
||||
p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
|
||||
p, np = partition(x.split_uop(Ops.ADD), lambda u: u.const_factor() == 1)
|
||||
if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
|
||||
return cast(UOp, functools.reduce(operator.add, np).divides(d))<(c//d)
|
||||
return None
|
||||
@@ -134,7 +129,7 @@ def canonicalize_simplex(X:UOp) -> UOp|None:
|
||||
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
|
||||
# returns x0 + x1 + ... in such case, or None if not
|
||||
changed, ret = False, []
|
||||
for u in split_uop(X, Ops.ADD):
|
||||
for u in X.split_uop(Ops.ADD):
|
||||
# assumed the const is the last src of MUL
|
||||
if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
|
||||
changed = True
|
||||
@@ -158,7 +153,7 @@ def remove_nested_mod(m: UOp, x: UOp, y: UOp) -> UOp|None:
|
||||
if ((c := y.arg) < 0) or x.vmin<0: return None
|
||||
new_xs = []
|
||||
something_changed = False
|
||||
for u in split_uop(x, Ops.ADD):
|
||||
for u in x.split_uop(Ops.ADD):
|
||||
if u.op is Ops.MOD:
|
||||
if u.src[1].divides(c) is not None:
|
||||
something_changed = True
|
||||
@@ -172,7 +167,7 @@ def fold_binary_numerator(d: UOp, x: UOp, y: UOp) -> UOp|None:
|
||||
# we can fold if the expression has only one non-constant term and this term can only take on two values
|
||||
if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
|
||||
x,const = x.pop_const()
|
||||
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
|
||||
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)])
|
||||
if len(terms)==1 and (v:=terms[0]).vmax-v.vmin == 1:
|
||||
y1 = cmod(factors[0]*v.vmin+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmin+const, c) # type: ignore
|
||||
y2 = cmod(factors[0]*v.vmax+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmax+const, c) # type: ignore
|
||||
@@ -183,7 +178,7 @@ def fold_divmod_congruence(d: UOp, x: UOp, y: UOp) -> UOp|None:
|
||||
# within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
|
||||
if (x.vmin<0 and CORRECT_DIVMOD_FOLDING) or ((c := y.arg) < 0) or (x.dtype.count > 1): return None
|
||||
x,const = x.pop_const()
|
||||
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
|
||||
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)])
|
||||
# a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
|
||||
rems = [min((r:=f%c), r-c, key=abs) for f in factors]
|
||||
if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c!=rem.vmax//c: return None
|
||||
@@ -192,7 +187,7 @@ def fold_divmod_congruence(d: UOp, x: UOp, y: UOp) -> UOp|None:
|
||||
|
||||
def divide_by_gcd(d: UOp, x: UOp, y: UOp) -> UOp|None:
|
||||
# x//y -> (x//gcd)//(y//gcd) or x%y -> gcd*(x//gcd)%(y//gcd)
|
||||
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
|
||||
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x.split_uop(Ops.ADD)])
|
||||
if (gcd := math.gcd(y.arg, *factors)) == 1: return None
|
||||
ret = sum(f//gcd * v for f,v in zip(factors, terms)).alu(d.op, y.const_like(y.arg//gcd))
|
||||
return ret*gcd if d.op is Ops.MOD else ret
|
||||
@@ -200,7 +195,7 @@ def divide_by_gcd(d: UOp, x: UOp, y: UOp) -> UOp|None:
|
||||
def nest_div_by_smallest_factor(d: UOp, x: UOp, y: UOp) -> UOp|None:
|
||||
# we try and nest the div and see if it allows the numerator to be simplified
|
||||
if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
|
||||
factors = [u.const_factor() for u in split_uop(x.pop_const()[0], Ops.ADD)]
|
||||
factors = [u.const_factor() for u in x.pop_const()[0].split_uop(Ops.ADD)]
|
||||
# div is the smallest factor of the denominator (greater than 1) out of all "factors"
|
||||
# TODO: there are better ways to pick `div`, this sometimes adds extra divisions
|
||||
# TODO: add same optimization for mod
|
||||
@@ -212,7 +207,7 @@ def simplify_remainder(d: UOp, x: UOp, y: UOp) -> UOp|None:
|
||||
# we try and take out the quotient and see if it allows the numerator to be simplified
|
||||
if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
|
||||
x_no_const,const = x.pop_const()
|
||||
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x_no_const, Ops.ADD)])
|
||||
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in x_no_const.split_uop(Ops.ADD)])
|
||||
quotients, remainders = zip(*[divmod(f, c) for f in factors])
|
||||
gcd = math.gcd(c, *remainders) # gcd without const!
|
||||
if const%c==const and gcd==1 and not any(r==0 or (r!=f and d.op is Ops.MOD) for r,f in zip(remainders, factors)): return None
|
||||
@@ -385,7 +380,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
|
||||
|
||||
# first, parse valid into {expr: (lower_bound, upper_bound)}
|
||||
bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
|
||||
for stmt in split_uop(valid, Ops.AND):
|
||||
for stmt in valid.split_uop(Ops.AND):
|
||||
try: expr, is_upper, c = parse_valid(stmt)
|
||||
except ValueError: return uop # give up if we cannot parse the valid
|
||||
bounds[expr][int(is_upper)] = c
|
||||
@@ -404,9 +399,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
|
||||
continue
|
||||
# every candidate is a set of constrained UOps based on valid, and if every item in a set simplifies the uop into the same output, we rewrite uop
|
||||
candidates = []
|
||||
if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
|
||||
if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)):
|
||||
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
|
||||
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
|
||||
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)])
|
||||
# try checking the whole clause
|
||||
if expr in uop.toposort(): candidates.append([(expr, UOp.variable("fake", v0, v1, expr.dtype))])
|
||||
|
||||
@@ -430,7 +425,7 @@ def _valid_priority(v: UOp, valids:list[UOp]):
|
||||
def simplify_valid(valid:UOp) -> UOp|None:
|
||||
ret:list[UOp] = []
|
||||
something_changed = False
|
||||
valids = list(split_uop(valid, Ops.AND))
|
||||
valids = list(valid.split_uop(Ops.AND))
|
||||
for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
|
||||
# TODO: root cause this and test_simplify_valid_from_div
|
||||
if stmt.op is Ops.CAST: return None
|
||||
@@ -444,7 +439,7 @@ def reduce_mul_chain(r:UOp):
|
||||
if r.arg not in {Ops.ADD, Ops.MAX}: return None
|
||||
if r.dtype != r.src[0].dtype: return None
|
||||
inside, outside = [], []
|
||||
for m in split_uop(r.src[0], Ops.MUL):
|
||||
for m in r.src[0].split_uop(Ops.MUL):
|
||||
m_parents = m.toposort()
|
||||
if all(r not in m_parents for r in r.src[1:]) and (r.arg != Ops.MAX or m.vmin >= 0): outside.append(m)
|
||||
else: inside.append(m)
|
||||
|
||||
Reference in New Issue
Block a user