diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index 1ea1c295a8..c9e7e9132a 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -261,7 +261,7 @@ def loop_collapse(compval, idx, mval, multconst, rng:UOp, reduce, idx2=None, idx
   if extra is not None: ret = ret + UOp(UOps.REDUCE, reduce.dtype, (extra,) + reduce.src[1:], reduce.arg)
   return ret
 
-def index_collapse(idx,rng,buf,add,mul,ld,reduce):
+def index_collapse(idx,rng,buf,ld,reduce,add=UOp.const(dtypes.int, 0),mul=UOp.const(dtypes.int, 1)):
   if rng not in reduce.src: return None
   return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx, ld.const_like(0), idx.ge(rng.src[0]) & idx.lt(rng.src[1]))),)+
                     tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
@@ -337,17 +337,13 @@ constant_folder = PatternMatcher([
        arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
   # unrolled arange div folding
   (UPat.var("divs") + UPat.cvar("c"), fold_unrolled_divs),
-  # indexing (with a multiply offset)!
+  # indexing, with cast or where
   (UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*
-    UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng")), name="ld"),),
-    arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
-  (UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*
-    UPat(UOps.LOAD, src=(UPat.var("buf"), UPat(UOps.RANGE, name="rng")), name="ld"),),
-    arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
-    lambda **kwargs: index_collapse(add=UOp.const(dtypes.int, 0), mul=UOp.const(dtypes.int, 1), **kwargs)),
+    UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))),
+    name="ld"),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
   (UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(
-    UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng")), name="ld"), UPat.const(None, 0.0)),),
-    arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
+    UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))),
+    name="ld"), UPat.const(None, 0.0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
   # max folding
   (UPat.max(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
   # GEP/CAST const rules