cleanup ctx usage in scheduler upats [pr] (#8205)

This commit is contained in:
qazal
2024-12-13 12:01:13 +02:00
committed by GitHub
parent 55b8c4e8bf
commit 4a617c84e1

View File

@@ -169,10 +169,10 @@ check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()),
to_si = PatternMatcher([
(UPat(Ops.VIEW, name="x"), _append_st_vars),
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))),
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
# don't need contiguous or assign anymore
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda ctx,x: x),
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
])
# ** fusion
@@ -333,7 +333,7 @@ def _as_const(u:UOp, val:ConstType) -> UOp:
st = (base:=ShapeTracker.from_shape(())).reshape((1,)*len(u.shape)).expand(u.shape)
return UOp(Ops.VIEW, u.dtype, (u.buf_uop, UOp.const(u.dtype, val)), base).view(st)
def simplify_reduceop(ctx, reduce:UOp, x:UOp) -> Optional[UOp]:
def simplify_reduceop(reduce:UOp, x:UOp) -> Optional[UOp]:
# remove reduce on unmasked const
if all_int(x.shape) and x.is_unrealized_unmasked_const():
prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])