diff --git a/tinygrad/codegen/symbolic.py b/tinygrad/codegen/symbolic.py index 0b2db0178c..abd005fb72 100644 --- a/tinygrad/codegen/symbolic.py +++ b/tinygrad/codegen/symbolic.py @@ -417,6 +417,17 @@ index_load = UPat.var("buf").index(rng_aug).load(name="ld") arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug)) arange_m = (arange_augrng= 0): outside.append(m) + else: inside.append(m) + if len(outside) == 0: return None + return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside) + # this is symbolic 2.0 sym = symbolic_flat+PatternMatcher([ # self ASSIGN is just self @@ -487,7 +498,9 @@ sym = symbolic_flat+PatternMatcher([ (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x) (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)), (UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y), - # move const multiply after REDUCE + # move const multiply after REDUCE (NOTE: the mul chain can do this, but only if it's a same dtype reduce) (UPat(Ops.REDUCE, src=(UPat.var("x")*UPat.cvar("c", vec=False),), arg=Ops.ADD, name="r", allow_any_len=True), lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg), + # reduce mul chain, move muls after the reduce + (UPat(Ops.REDUCE, src=(UPat(Ops.MUL),), name="r", allow_any_len=True), reduce_mul_chain), ])