mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
handle some fancier reduces (#10057)
* reduce_unparented * handle fancier reduces * fold more * bugfix
This commit is contained in:
@@ -332,16 +332,34 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
||||
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
|
||||
return acc.assign(ret) if len(reduce_range) != 0 else ret
|
||||
|
||||
def reduce_unparented(red:UOp):
|
||||
if red.arg not in {Ops.ADD, Ops.MAX}: return None
|
||||
reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents)
|
||||
if len(reduce_unparented) == 0: return None
|
||||
ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0]
|
||||
if red.arg is Ops.ADD:
|
||||
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
|
||||
return ret
|
||||
|
||||
def no_vectorized_reduce(inp:UOp, red:UOp):
|
||||
if inp.dtype != red.dtype:
|
||||
# NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
|
||||
horizontal_amount = inp.dtype.count//red.dtype.count
|
||||
lst = [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
|
||||
red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), lst),))
|
||||
red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), lst),)+red.src[1:])
|
||||
if red.dtype.vcount == 1: return red
|
||||
return no_vectorized_alu(red)
|
||||
# no_vectorize_alu ignoring ranges
|
||||
if red.dtype.vcount == 1: return None
|
||||
alus = tuple(UOp(red.op, red.dtype.scalar(), (red.src[0].gep(i),)+red.src[1:], red.arg) for i in range(red.dtype.vcount))
|
||||
return UOp(Ops.VECTORIZE, red.dtype, alus)
|
||||
|
||||
def range_fold(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
|
||||
def range_fold_lo(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
|
||||
# psuedo code: sum(val if i < cut else 0) for i in range(lo, hi, st))
|
||||
total = (hi-lo+st-1) // st # real count in the range
|
||||
length = ((cut-lo+st-1) // st).maximum(0).minimum(total)
|
||||
return length.cast(val.dtype) * val
|
||||
|
||||
def range_fold_hi(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
|
||||
# psuedo code: sum(val if i >= cut else 0) for i in range(lo, hi, st))
|
||||
# TODO: this function is so tricky and still probably wrong. test it
|
||||
total = (hi-lo+st-1) // st # real count in the range
|
||||
@@ -360,19 +378,22 @@ pm_reduce_collapse = PatternMatcher([
|
||||
(UPat.var("x") * UPat(Ops.RANGE, name="r"), lambda x,r: r.replace(src=(r.src[0]*x, r.src[1]*x, r.src[2]*x))),
|
||||
# add to range
|
||||
(UPat.var("x") + UPat(Ops.RANGE, name="r"), lambda x,r: r.replace(src=(r.src[0]+x, r.src[1]+x, r.src[2]))),
|
||||
# 0 is "true" arg in where. fold the range
|
||||
# fold the range with 0 in either the true or false slot
|
||||
((UPat(Ops.RANGE, src=(UPat.var("lo"), UPat.var("hi"), UPat.var("st"))) < UPat.cvar("cut")) \
|
||||
.where(UPat(Ops.CONST, arg=0), UPat.cvar("val")).reduce(arg=Ops.ADD), range_fold),
|
||||
.where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True), range_fold_lo),
|
||||
((UPat(Ops.RANGE, src=(UPat.var("lo"), UPat.var("hi"), UPat.var("st"))) < UPat.cvar("cut")) \
|
||||
.where(UPat(Ops.CONST, arg=0), UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True), range_fold_hi),
|
||||
# devectorize REDUCE
|
||||
(UPat(Ops.VECTORIZE, name="inp").reduce(name="red"), no_vectorized_reduce),
|
||||
(UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
|
||||
# REDUCE on ADD
|
||||
((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD), lambda x,y: x.reduce(arg=Ops.ADD) + y.reduce(arg=Ops.ADD)),
|
||||
((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True), lambda x,y: x.reduce(arg=Ops.ADD) + y.reduce(arg=Ops.ADD)),
|
||||
# MUL casted bool
|
||||
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)),
|
||||
# WHERE on LOAD (works on max too)
|
||||
(UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD),
|
||||
(UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
|
||||
lambda buf,idx,gate: buf.index(idx, gate).load()),
|
||||
(UPat.var("gate").where(UPat(Ops.CONST, arg=0), UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD),
|
||||
(UPat.var("gate").where(UPat(Ops.CONST, arg=0),
|
||||
UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
|
||||
lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
|
||||
# INDEX on RANGE / gated RANGE
|
||||
(UPat.var("buf").index(UPat(Ops.RANGE, name="r"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r2"))), index_fold),
|
||||
@@ -380,12 +401,14 @@ pm_reduce_collapse = PatternMatcher([
|
||||
(UPat((Ops.INDEX, Ops.LOAD), name="alu"), no_vectorized_alu),
|
||||
# cast on RANGE (fix torch indexing)
|
||||
(UPat(Ops.RANGE, name="r").cast(name="c"), lambda r,c: r.replace(src=tuple([x.cast(c.dtype) for x in r.src]), dtype=c.dtype)),
|
||||
# remove any ranges from a REDUCE that aren't referenced in the reduce source
|
||||
(UPat(Ops.REDUCE, name="red"), reduce_unparented),
|
||||
])+sym
|
||||
|
||||
def reduce_collapse(red:UOp):
|
||||
included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))
|
||||
if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None
|
||||
replaces = {red:red.replace(src=red.src[0:1])}
|
||||
replaces: dict[UOp, UOp] = {}
|
||||
for u in included:
|
||||
for s in u.src:
|
||||
if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}:
|
||||
@@ -395,15 +418,6 @@ def reduce_collapse(red:UOp):
|
||||
if any(x.op in {Ops.REDUCE, Ops.RANGE} for x in sink.toposort()): return None
|
||||
return sink.substitute({v:k for k,v in replaces.items()})
|
||||
|
||||
def reduce_unparented(red:UOp):
|
||||
if red.arg not in {Ops.ADD, Ops.MAX}: return None
|
||||
reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents)
|
||||
if len(reduce_unparented) == 0: return None
|
||||
ret = red.replace(src=(red.src[0],)+tuple(reduce_parented))
|
||||
if red.arg is Ops.ADD:
|
||||
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
|
||||
return ret
|
||||
|
||||
pm_reduce = PatternMatcher([
|
||||
# remove any ranges from a REDUCE that aren't referenced in the reduce source
|
||||
(UPat(Ops.REDUCE, name="red"), reduce_unparented),
|
||||
|
||||
Reference in New Issue
Block a user