no early return + allow childless const/bind/var in kernel graph [pr] (#9202)

This commit is contained in:
qazal
2025-02-22 20:28:22 +02:00
committed by GitHub
parent 97bc723538
commit b711c6343a
2 changed files with 2 additions and 3 deletions

View File

@@ -395,7 +395,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
elif isinstance(k.metadata, Metadata): ops_metadata[v] = k.metadata
# create kernels
if len(realize_map) == 0: return [], {}, becomes_map
kernel_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
sched_sink = kernel_map[sink]
type_verify(list(sched_sink.toposort), kernel_spec)

View File

@@ -126,8 +126,8 @@ kernel_spec = buffer_spec+PatternMatcher([
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
# assign has a buffer view and kernel source, it can optionally depend on other assigns
(UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
# view/sink/const can also exist in the kernel graph
(UPat((Ops.VIEW, Ops.SINK, Ops.CONST)), lambda: True),
# view/sink/const/bind/var can also exist in the kernel graph
(UPat((Ops.VIEW, Ops.SINK, Ops.CONST, Ops.BIND, Ops.DEFINE_VAR)), lambda: True),
(UPat(GroupOp.All), lambda: False),
])