update indexing with UPat.any [run_process_replay] (#6605)

This commit is contained in:
George Hotz
2024-09-19 17:40:17 +08:00
committed by GitHub
parent d148a62f8d
commit 224151a958

View File

@@ -261,7 +261,7 @@ def loop_collapse(compval, idx, mval, multconst, rng:UOp, reduce, idx2=None, idx
if extra is not None: ret = ret + UOp(UOps.REDUCE, reduce.dtype, (extra,) + reduce.src[1:], reduce.arg)
return ret
def index_collapse(idx,rng,buf,add,mul,ld,reduce):
def index_collapse(idx,rng,buf,ld,reduce,add=UOp.const(dtypes.int, 0),mul=UOp.const(dtypes.int, 1)):
if rng not in reduce.src: return None
return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx, ld.const_like(0), idx.ge(rng.src[0]) & idx.lt(rng.src[1]))),)+
tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
@@ -337,17 +337,13 @@ constant_folder = PatternMatcher([
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# unrolled arange div folding
(UPat.var("divs") + UPat.cvar("c"), fold_unrolled_divs),
# indexing (with a multiply offset)!
# indexing, with cast or where
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng")), name="ld"),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat(UOps.RANGE, name="rng")), name="ld"),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
lambda **kwargs: index_collapse(add=UOp.const(dtypes.int, 0), mul=UOp.const(dtypes.int, 1), **kwargs)),
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))),
name="ld"),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng")), name="ld"), UPat.const(None, 0.0)),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))),
name="ld"), UPat.const(None, 0.0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
# max folding
(UPat.max(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
# GEP/CAST const rules