experiment with symbolic here

This commit is contained in:
qazal
2024-12-07 13:15:12 +02:00
parent 2a7d258a2b
commit cb87d61f7a
2 changed files with 15 additions and 2 deletions

View File

@@ -177,9 +177,22 @@ def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
return x.replace(op=Ops.LOAD)
check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),])
def collapse_reduceop(ctx, reduce:UOp, x:UOp) -> Optional[UOp]:
ret = x.arg
prshape = reduce.arg[2]
match reduce.arg[0]:
case Ops.ADD: ret *= prshape
case Ops.MUL: ret **= prshape
case Ops.MAX: pass
case _: return None
return UOp.const(x.dtype, ret)
to_si = symbolic_flat+PatternMatcher([
# if we're loading something that ended up collapsing, just embed the CONST!
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat.cvar("x"))), lambda ctx,x:x),
(UPat(Ops.VIEW, src=(UPat.cvar("x"),)), lambda ctx,x:x),
# can fold reduce of CONST
(UPat(Ops.REDUCE_AXIS, src=(UPat.cvar("x"),), name="reduce"), collapse_reduceop),
(UPat(Ops.VIEW, name="x"), _append_st_vars),
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))),
# unmasked VALID is just CONST

View File

@@ -350,7 +350,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
def r(self, op:Ops, axis:Tuple[int, ...]):
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis, prod(self.shape[i] for i in axis)))
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x), None if self.st is None or self.st.contiguous else self.st)
def contiguous(self, allow_buffer_view=True):
if not unwrap(self.st).contiguous or self.size != self.base.size or self.is_unrealized_const():
@@ -913,7 +913,7 @@ spec = PatternMatcher([
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) in {2, 3} and x.arg[0] in {Ops.ADD, Ops.MUL, Ops.MAX}),
(UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),