diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py
index 586b37e6a7..e98f2d4b68 100644
--- a/tinygrad/codegen/devectorizer.py
+++ b/tinygrad/codegen/devectorizer.py
@@ -3,7 +3,7 @@ import functools, operator, itertools
 from collections import defaultdict
 from dataclasses import dataclass
 from tinygrad.device import is_dtype_supported
-from tinygrad.dtype import dtypes, ImageDType, PtrDType, promo_lattice
+from tinygrad.dtype import dtypes, ImageDType, PtrDType, promo_lattice, DType
 from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve, graph_rewrite, GroupOp, identity_element
 from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, gep_pushing
 from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE, partition
@@ -315,15 +315,17 @@ pm_render = PatternMatcher([
 class ReduceContext:
   acc_num: int = 0
 
+def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]:
+  # if this has a horizontal reduction component, do that first
+  if inp.dtype != out_dtype:
+    # NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
+    horizontal_amount = inp.dtype.count//out_dtype.count
+    return [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
+  return [inp]
+
 def reduce_to_acc(ctx:ReduceContext, red:UOp):
   inp, reduce_range = red.src[0], red.src[1:]
-  # if this has a horizontal reduction component, do that first
-  if inp.dtype != red.dtype:
-    # NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
-    horizontal_amount = inp.dtype.count//red.dtype.count
-    lst = [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
-  else:
-    lst = [inp]
+  lst = horizontal_reduce(inp, red.dtype)
   assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
   # if we have a range
   if len(reduce_range) != 0:
@@ -335,10 +337,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
 
 def no_vectorized_reduce(inp:UOp, red:UOp):
   if inp.dtype != red.dtype:
-    # NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
-    horizontal_amount = inp.dtype.count//red.dtype.count
-    lst = [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
-    red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), lst),)+red.src[1:])
+    red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), horizontal_reduce(inp, red.dtype)),)+red.src[1:])
     if red.dtype.vcount == 1: return red
   # no_vectorize_alu ignoring ranges
   if red.dtype.vcount == 1: return None
@@ -367,34 +366,34 @@ pm_reduce_collapse = PatternMatcher([
   # lift x+y out of reduce on ne
   ((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None),
   # fold the range
-  ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat(Ops.CONST, arg=0), UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
+  ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
    lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
   ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True),
    lambda r,cut,val: cut.maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
-  # devectorize REDUCE
-  (UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
   # REDUCE on ADD
   ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
    lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)),
   # MUL casted bool
-  ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)),
+  ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")),
+   lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
   # WHERE on LOAD (works on max too)
   (UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
    lambda buf,idx,gate: buf.index(idx, gate).load()),
-  (UPat.var("gate").where(UPat(Ops.CONST, arg=0),
-    UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
+  (UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
    lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
   # INDEX on RANGE / gated RANGE
   (UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())),
    lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))),
-  # index/load. TODO: this is more aggressive than needed
-  (UPat((Ops.INDEX, Ops.LOAD), name="alu"), no_vectorized_alu),
   # AND on WHERE
   ((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
    .where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
    lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)),
   # remove REDUCEs that no longer have a RANGE in the src
   (UPat(Ops.REDUCE, name="red"), reduce_rangeless),
+  # devectorize REDUCE
+  (UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
+  # index/load/where. TODO: this is more aggressive than needed
+  (UPat((Ops.INDEX, Ops.LOAD, Ops.WHERE), name="alu"), no_vectorized_alu),
 ])+sym
 
 def reduce_collapse(red:UOp):
@@ -406,7 +405,7 @@
       if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}:
         replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax))
   collapse_fxn = red.substitute(replaces)
-  sink = graph_rewrite(collapse_fxn, pm_reduce_collapse+devectorize, name="reduce_collapse")
+  sink = graph_rewrite(collapse_fxn, pm_reduce_collapse, name="reduce_collapse")
   # TODO: why is REDUCE needed here and just RANGE isn't enough?
   if any(x.op in {Ops.REDUCE, Ops.RANGE} for x in sink.toposort()): return None
   return sink.substitute({v:k for k,v in replaces.items()})
diff --git a/tinygrad/codegen/symbolic.py b/tinygrad/codegen/symbolic.py
index 41bb3e6703..653210bcdc 100644
--- a/tinygrad/codegen/symbolic.py
+++ b/tinygrad/codegen/symbolic.py
@@ -283,7 +283,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
   ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"),
    lambda x,c,a,d: (x+a*c)//(c*d) if x.vmin>=0 or x.vmax<=0 else None),  # (x//c+a)//d -> (x+a*c)//(c*d)
   (UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
-  ((UPat.var("x", dtypes.sints)+UPat.cvar("c")).named("n")//UPat.cvar("d"),
+  ((UPat.var("x", dtypes.sints)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
    lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
   # ** mod **
   # mod folding
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index aa2970d716..b047cdbcc0 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -81,7 +81,10 @@ class MathTrait(SimpleMathTrait):
 
   def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
   def minimum(self, x): return -(-self).maximum(-x)
-  def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y))
+  def where(self, x, y):
+    if type(self) is type(x): return self.alu(Ops.WHERE, x, x.ufix(y))
+    if type(self) is type(y): return self.alu(Ops.WHERE, y.ufix(x), y)
+    raise RuntimeError("where needs at least one UOp arg")
   def threefry(self, seed): return self.alu(Ops.THREEFRY, seed)
   def reciprocal(self): return self.alu(Ops.RECIP)
   def sqrt(self): return self.alu(Ops.SQRT)
@@ -424,7 +427,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
     return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
   def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
-  def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs)
+  def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
   def contiguous(self): return self.alu(Ops.CONTIGUOUS)
   def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
   def fuse(self): return self.alu(Ops.FUSE)
@@ -780,6 +783,7 @@ class UPat(MathTrait):
   def assign(self, x:UPat, **kwargs): return UPat(Ops.ASSIGN, self.dtype, (self,x), **kwargs)
   def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs)
   def fuse(self): return self.alu(Ops.FUSE)
+  def or_broadcasted(self, **kwargs): return UPat.any(self, UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs))
   def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
   def alu(self, op:Ops, *src:UPat):
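Note on the new horizontal_reduce helper: the snippet below is a minimal standalone sketch (plain Python, not tinygrad code; horizontal_groups is a hypothetical name) that reproduces the gep() index math, so you can see which lanes land in each slice before the slices are combined elementwise.

# hypothetical sketch of the index math inside horizontal_reduce
# each tuple is one gep() slice; the slices are later reduced together elementwise
def horizontal_groups(in_count: int, out_count: int) -> list[tuple[int, ...]]:
  horizontal_amount = in_count // out_count
  return [tuple(range(i, in_count, horizontal_amount)) for i in range(horizontal_amount)]

assert horizontal_groups(8, 4) == [(0, 2, 4, 6), (1, 3, 5, 7)]  # vec8 -> vec4: two vec4 slices
assert horizontal_groups(8, 1) == [(i,) for i in range(8)]      # vec8 -> vec1: eight single-lane slices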
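The MathTrait.where change is what allows patterns such as .where(0, UPat.cvar("val")) above: a plain Python constant is now accepted on either branch as long as the other branch is already a UOp/UPat it can be ufix'd against. A hedged sketch of the new behavior, assuming the patched tree is on PYTHONPATH (the import paths and dtypes here are picked for illustration):

from tinygrad.ops import UOp
from tinygrad.dtype import dtypes

gate = UOp.const(dtypes.bool, True)
val = UOp.const(dtypes.float, 2.0)
a = gate.where(val, 0)  # worked before: 0 is ufix'd against val, the UOp in the first branch
b = gate.where(0, val)  # works after this change: 0 is ufix'd against the UOp in the second branch
# gate.where(0, 1) now raises RuntimeError("where needs at least one UOp arg")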