mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
update indexing with UPat.any [run_process_replay] (#6605)
This commit is contained in:
@@ -261,7 +261,7 @@ def loop_collapse(compval, idx, mval, multconst, rng:UOp, reduce, idx2=None, idx
|
||||
if extra is not None: ret = ret + UOp(UOps.REDUCE, reduce.dtype, (extra,) + reduce.src[1:], reduce.arg)
|
||||
return ret
|
||||
|
||||
def index_collapse(idx,rng,buf,add,mul,ld,reduce):
|
||||
def index_collapse(idx,rng,buf,ld,reduce,add=UOp.const(dtypes.int, 0),mul=UOp.const(dtypes.int, 1)):
|
||||
if rng not in reduce.src: return None
|
||||
return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx, ld.const_like(0), idx.ge(rng.src[0]) & idx.lt(rng.src[1]))),)+
|
||||
tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
|
||||
@@ -337,17 +337,13 @@ constant_folder = PatternMatcher([
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# unrolled arange div folding
|
||||
(UPat.var("divs") + UPat.cvar("c"), fold_unrolled_divs),
|
||||
# indexing (with a multiply offset)!
|
||||
# indexing, with cast or where
|
||||
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*
|
||||
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng")), name="ld"),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
|
||||
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*
|
||||
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat(UOps.RANGE, name="rng")), name="ld"),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
|
||||
lambda **kwargs: index_collapse(add=UOp.const(dtypes.int, 0), mul=UOp.const(dtypes.int, 1), **kwargs)),
|
||||
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))),
|
||||
name="ld"),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
|
||||
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(
|
||||
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng")), name="ld"), UPat.const(None, 0.0)),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
|
||||
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))),
|
||||
name="ld"), UPat.const(None, 0.0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
|
||||
# max folding
|
||||
(UPat.max(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
|
||||
# GEP/CAST const rules
|
||||
|
||||
Reference in New Issue
Block a user