diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py
index d7b0ef30c3..4d8bccea9d 100644
--- a/tinygrad/codegen/devectorizer.py
+++ b/tinygrad/codegen/devectorizer.py
@@ -332,16 +332,34 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
   ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
   return acc.assign(ret) if len(reduce_range) != 0 else ret
 
+def reduce_unparented(red:UOp):
+  if red.arg not in {Ops.ADD, Ops.MAX}: return None
+  reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents)
+  if len(reduce_unparented) == 0: return None # every range is referenced by the reduce source, nothing to strip
+  ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0]
+  if red.arg is Ops.ADD: # scale by src[1]-src[0] of each dropped range. NOTE(review): the step src[2] is not used, confirm it is always 1
+    for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
+  return ret
+
 def no_vectorized_reduce(inp:UOp, red:UOp):
   if inp.dtype != red.dtype:
     # NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
     horizontal_amount = inp.dtype.count//red.dtype.count
     lst = [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
-    red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), lst),))
+    red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), lst),)+red.src[1:])
     if red.dtype.vcount == 1: return red
-  return no_vectorized_alu(red)
+  # same as no_vectorized_alu, but only src[0] is split per lane: the range sources in src[1:] are passed through unchanged
+  if red.dtype.vcount == 1: return None # already scalar, no rewrite needed
+  alus = tuple(UOp(red.op, red.dtype.scalar(), (red.src[0].gep(i),)+red.src[1:], red.arg) for i in range(red.dtype.vcount))
+  return UOp(Ops.VECTORIZE, red.dtype, alus)
 
-def range_fold(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
+def range_fold_lo(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
+  # pseudocode: sum(val if i < cut else 0 for i in range(lo, hi, st))
+  total = (hi-lo+st-1) // st # real count in the range
+  length = ((cut-lo+st-1) // st).maximum(0).minimum(total)
+  return length.cast(val.dtype) * val
+
+def range_fold_hi(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
   # psuedo code: sum(val if i >= cut else 0) for i in range(lo, hi, st))
   # TODO: this function is so tricky and still probably wrong. test it
   total = (hi-lo+st-1) // st # real count in the range
@@ -360,19 +378,22 @@ pm_reduce_collapse = PatternMatcher([
   (UPat.var("x") * UPat(Ops.RANGE, name="r"), lambda x,r: r.replace(src=(r.src[0]*x, r.src[1]*x, r.src[2]*x))),
   # add to range
   (UPat.var("x") + UPat(Ops.RANGE, name="r"), lambda x,r: r.replace(src=(r.src[0]+x, r.src[1]+x, r.src[2]))),
-  # 0 is "true" arg in where. fold the range
+  # fold the range with 0 in either the true or false slot
   ((UPat(Ops.RANGE, src=(UPat.var("lo"), UPat.var("hi"), UPat.var("st"))) < UPat.cvar("cut")) \
-    .where(UPat(Ops.CONST, arg=0), UPat.cvar("val")).reduce(arg=Ops.ADD), range_fold),
+    .where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True), range_fold_lo),
+  ((UPat(Ops.RANGE, src=(UPat.var("lo"), UPat.var("hi"), UPat.var("st"))) < UPat.cvar("cut")) \
+    .where(UPat(Ops.CONST, arg=0), UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True), range_fold_hi),
   # devectorize REDUCE
-  (UPat(Ops.VECTORIZE, name="inp").reduce(name="red"), no_vectorized_reduce),
+  (UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
   # REDUCE on ADD
-  ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD), lambda x,y: x.reduce(arg=Ops.ADD) + y.reduce(arg=Ops.ADD)),
+  ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True), lambda x,y: x.reduce(arg=Ops.ADD) + y.reduce(arg=Ops.ADD)),
   # MUL casted bool
   ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)),
   # WHERE on LOAD (works on max too)
-  (UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD),
+  (UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
    lambda buf,idx,gate: buf.index(idx, gate).load()),
-  (UPat.var("gate").where(UPat(Ops.CONST, arg=0), UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD),
+  (UPat.var("gate").where(UPat(Ops.CONST, arg=0),
+    UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
    lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
   # INDEX on RANGE / gated RANGE
   (UPat.var("buf").index(UPat(Ops.RANGE, name="r"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r2"))), index_fold),
@@ -380,12 +401,14 @@ pm_reduce_collapse = PatternMatcher([
   (UPat((Ops.INDEX, Ops.LOAD), name="alu"), no_vectorized_alu),
   # cast on RANGE (fix torch indexing)
   (UPat(Ops.RANGE, name="r").cast(name="c"), lambda r,c: r.replace(src=tuple([x.cast(c.dtype) for x in r.src]), dtype=c.dtype)),
+  # remove any ranges from a REDUCE that aren't referenced in the reduce source
+  (UPat(Ops.REDUCE, name="red"), reduce_unparented),
 ])+sym
 
 def reduce_collapse(red:UOp):
   included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))
   if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None
-  replaces = {red:red.replace(src=red.src[0:1])}
+  replaces: dict[UOp, UOp] = {}
   for u in included:
     for s in u.src:
       if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}:
@@ -395,15 +418,6 @@ def reduce_collapse(red:UOp):
   if any(x.op in {Ops.REDUCE, Ops.RANGE} for x in sink.toposort()): return None
   return sink.substitute({v:k for k,v in replaces.items()})
 
-def reduce_unparented(red:UOp):
-  if red.arg not in {Ops.ADD, Ops.MAX}: return None
-  reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents)
-  if len(reduce_unparented) == 0: return None
-  ret = red.replace(src=(red.src[0],)+tuple(reduce_parented))
-  if red.arg is Ops.ADD:
-    for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
-  return ret
-
 pm_reduce = PatternMatcher([
   # remove any ranges from a REDUCE that aren't referenced in the reduce source
   (UPat(Ops.REDUCE, name="red"), reduce_unparented),