handle some fancier reduces (#10057)

* reduce_unparented

* handle fancier reduces

* fold more

* bugfix
This commit is contained in:
George Hotz
2025-04-26 11:20:15 -04:00
committed by GitHub
parent e08270c1ba
commit c80fe6d5fc

View File

@@ -332,16 +332,34 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
return acc.assign(ret) if len(reduce_range) != 0 else ret
def reduce_unparented(red:UOp):
  """Drop sources of a REDUCE that the reduced value does not depend on.

  `red.src[0]` is the value being reduced; every entry of `red.src[1:]` is
  kept only when it appears in `red.src[0].sparents`.

  Returns None (no rewrite) when `red.arg` is neither ADD nor MAX, or when
  nothing can be dropped. Otherwise returns the narrowed value: the bare
  `red.src[0]` when no source remains and the dtypes already agree, else a
  REDUCE over the remaining sources. For ADD the result is additionally
  multiplied by `src[1]-src[0]` of every dropped source (presumably the
  extent of that range -- confirm against the RANGE source layout).
  """
  if red.arg not in {Ops.ADD, Ops.MAX}: return None
  body = red.src[0]
  used, unused = partition(red.src[1:], lambda rng: rng in body.sparents)
  if len(unused) == 0: return None
  # a REDUCE node is still needed if sources remain or if it changes the dtype
  if len(used) or red.dtype != body.dtype:
    out = red.replace(src=(body,)+tuple(used))
  else:
    out = body
  if red.arg is Ops.ADD:
    # summing a value that is constant w.r.t. a dropped source scales it instead
    for rng in unused:
      out = out * (rng.src[1]-rng.src[0]).cast(out.dtype.scalar()).broadcast(out.dtype.count)
  return out
def no_vectorized_reduce(inp:UOp, red:UOp):
# Rewrite a REDUCE `red` whose input `inp` is vectorized into per-lane pieces.
# NOTE(review): this block appears to contain both the "before" and the "after"
# lines of a diff (markers and indentation were lost); the notes below flag the
# places where that matters. Confirm the intended final form before relying on it.
if inp.dtype != red.dtype:
# input and result dtypes differ: combine groups of input lanes with red.arg first
# NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
# NOTE(review): the gep index tuples below are range(i, count, horizontal_amount),
# so entry i picks lanes i, i+horizontal_amount, ... -- confirm that this matches
# the lane pairing shown in the example above.
horizontal_amount = inp.dtype.count//red.dtype.count
lst = [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
# NOTE(review): the next two statements both rebind `red`. They differ only in
# that the second appends red.src[1:] to the new source tuple; they look like the
# old and new version of a single line -- confirm which one is meant to remain.
red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), lst),))
red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), lst),)+red.src[1:])
# a scalar result needs no further splitting
if red.dtype.vcount == 1: return red
# NOTE(review): as written this unconditional return makes everything below
# unreachable; presumably it is the old tail of the function -- confirm.
return no_vectorized_alu(red)
# same idea as no_vectorized_alu, but the extra sources red.src[1:] (presumably
# the ranges) are passed through unchanged to every scalar REDUCE
if red.dtype.vcount == 1: return None
# one scalar REDUCE per lane of red.src[0], each keeping red.src[1:] and red.arg
alus = tuple(UOp(red.op, red.dtype.scalar(), (red.src[0].gep(i),)+red.src[1:], red.arg) for i in range(red.dtype.vcount))
# reassemble the scalar results into a vector of the original dtype
return UOp(Ops.VECTORIZE, red.dtype, alus)
def range_fold(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
def range_fold_lo(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
  """Closed form of a gated sum over a range.

  Pseudo code: sum(val if i < cut else 0 for i in range(lo, hi, st))

  The result is `val` multiplied by the number of iterations whose index is
  below `cut`, with that count clamped to [0, total iterations] and cast to
  `val.dtype`.
  """
  # number of iterations the range really performs (rounded-up division)
  n_iters = (hi-lo+st-1) // st
  # iterations before the cut, then clamped so it is never negative and never above n_iters
  n_before_cut = (cut-lo+st-1) // st
  n_taken = n_before_cut.maximum(0).minimum(n_iters)
  return n_taken.cast(val.dtype) * val
def range_fold_hi(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
# pseudo code: sum(val if i >= cut else 0 for i in range(lo, hi, st))
# TODO: this function is so tricky and still probably wrong. test it
total = (hi-lo+st-1) // st # real count in the range
@@ -360,19 +378,22 @@ pm_reduce_collapse = PatternMatcher([
(UPat.var("x") * UPat(Ops.RANGE, name="r"), lambda x,r: r.replace(src=(r.src[0]*x, r.src[1]*x, r.src[2]*x))),
# add to range
(UPat.var("x") + UPat(Ops.RANGE, name="r"), lambda x,r: r.replace(src=(r.src[0]+x, r.src[1]+x, r.src[2]))),
# 0 is "true" arg in where. fold the range
# fold the range with 0 in either the true or false slot
((UPat(Ops.RANGE, src=(UPat.var("lo"), UPat.var("hi"), UPat.var("st"))) < UPat.cvar("cut")) \
.where(UPat(Ops.CONST, arg=0), UPat.cvar("val")).reduce(arg=Ops.ADD), range_fold),
.where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True), range_fold_lo),
((UPat(Ops.RANGE, src=(UPat.var("lo"), UPat.var("hi"), UPat.var("st"))) < UPat.cvar("cut")) \
.where(UPat(Ops.CONST, arg=0), UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True), range_fold_hi),
# devectorize REDUCE
(UPat(Ops.VECTORIZE, name="inp").reduce(name="red"), no_vectorized_reduce),
(UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
# REDUCE on ADD
((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD), lambda x,y: x.reduce(arg=Ops.ADD) + y.reduce(arg=Ops.ADD)),
((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True), lambda x,y: x.reduce(arg=Ops.ADD) + y.reduce(arg=Ops.ADD)),
# MUL casted bool
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)),
# WHERE on LOAD (works on max too)
(UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD),
(UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
lambda buf,idx,gate: buf.index(idx, gate).load()),
(UPat.var("gate").where(UPat(Ops.CONST, arg=0), UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD),
(UPat.var("gate").where(UPat(Ops.CONST, arg=0),
UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
# INDEX on RANGE / gated RANGE
(UPat.var("buf").index(UPat(Ops.RANGE, name="r"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r2"))), index_fold),
@@ -380,12 +401,14 @@ pm_reduce_collapse = PatternMatcher([
(UPat((Ops.INDEX, Ops.LOAD), name="alu"), no_vectorized_alu),
# cast on RANGE (fix torch indexing)
(UPat(Ops.RANGE, name="r").cast(name="c"), lambda r,c: r.replace(src=tuple([x.cast(c.dtype) for x in r.src]), dtype=c.dtype)),
# remove any ranges from a REDUCE that aren't referenced in the reduce source
(UPat(Ops.REDUCE, name="red"), reduce_unparented),
])+sym
def reduce_collapse(red:UOp):
included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))
if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None
replaces = {red:red.replace(src=red.src[0:1])}
replaces: dict[UOp, UOp] = {}
for u in included:
for s in u.src:
if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}:
@@ -395,15 +418,6 @@ def reduce_collapse(red:UOp):
if any(x.op in {Ops.REDUCE, Ops.RANGE} for x in sink.toposort()): return None
return sink.substitute({v:k for k,v in replaces.items()})
def reduce_unparented(red:UOp):
  """Strip sources from a REDUCE that its reduced value never references.

  Entries of `red.src[1:]` are kept only if they occur in
  `red.src[0].sparents`. Returns None (no rewrite) unless `red.arg` is ADD
  or MAX and at least one entry can be stripped. The returned node is always
  a REDUCE over `red.src[0]` plus the kept entries; for ADD it is also
  multiplied by `src[1]-src[0]` of each stripped entry (presumably that
  range's extent -- confirm against the RANGE source layout).
  """
  if red.arg not in {Ops.ADD, Ops.MAX}: return None
  kept, dropped = partition(red.src[1:], lambda rng: rng in red.src[0].sparents)
  if len(dropped) == 0: return None
  narrowed = red.replace(src=(red.src[0],)+tuple(kept))
  # only ADD needs a correction factor; other accepted ops return the narrowed REDUCE as is
  if red.arg is not Ops.ADD: return narrowed
  for rng in dropped:
    extent = rng.src[1]-rng.src[0]
    narrowed = narrowed * extent.cast(narrowed.dtype.scalar()).broadcast(narrowed.dtype.count)
  return narrowed
pm_reduce = PatternMatcher([
# remove any ranges from a REDUCE that aren't referenced in the reduce source
(UPat(Ops.REDUCE, name="red"), reduce_unparented),