no const/view in schedule sink after sym [pr]

This commit is contained in:
qazal
2025-03-11 10:58:23 +01:00
parent 68f062c8be
commit fa69fd3afc

View File

@@ -106,6 +106,9 @@ sym = symbolic_simple+PatternMatcher([
# put CAST after expanding BUFFER
(UPat(Ops.VIEW, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="v"), lambda x,v: x.view(x.st+v.st).cast(v.dtype) if getenv("CAST_AFTER_EXPAND")
and x.base.op is Ops.BUFFER and resolve(prod(v.shape) > prod(x.shape)) else None),
# remove CONST/BIND/VIEW from SINK
(UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None),
])
# **** UOp realization
@@ -259,9 +262,8 @@ create_kernels = merge_views+PatternMatcher([
lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
# walk back the local graph until we reach a buffer/assign parent
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
# remove CONST/BIND from SINK
(UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST,Ops.BIND}))) != x.src else None),
# remove downstream reshapes from SINK
(UPat(Ops.SINK, name="x"), lambda x:x.replace(src=tuple(s.base for s in x.src)) if any(s.op is Ops.VIEW for s in x.src) else None),
])
# **** fix kernel AST