add asserts for KERNEL op ast [pr] (#9868)

This commit is contained in:
qazal
2025-04-13 16:50:18 +08:00
committed by GitHub
parent 5ee9c343e6
commit 7191f88551
2 changed files with 8 additions and 2 deletions

View File

@@ -226,7 +226,8 @@ class Kernel:
ast: UOp
metadata: tuple[Metadata, ...] = ()
def __repr__(self):
return f"<Kernel {len(list(self.ast.toposort))} {[s.op for s in self.ast.src] if self.ast.op is Ops.SINK else self.ast.op} {self.metadata}>"
ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op)
return f"<Kernel {len(list(self.ast.toposort))} {ast_rep} {self.metadata}>"
@dataclass(frozen=True)
class KernelContext:

View File

@@ -141,8 +141,13 @@ spec = PatternMatcher([
# *** this is the spec of a Kernel in UOp ***
def validate_kernel(k:UOp):
assert k.arg.ast.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.SINK}, f"must end with SINK/COPY/BUFFER_VIEW {k.arg}"
if k.arg.ast.op is Ops.SINK: assert all(s.op is Ops.STORE for s in k.arg.ast.src), f"SINK must end with STORE {k.arg.ast}"
return True
kernel_spec = buffer_spec+PatternMatcher([
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN))), lambda: True),
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN)), name="k"), validate_kernel),
# assign has a buffer and kernel source, it can optionally depend on other assigns
(UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
(UPat(GroupOp.All-{Ops.SINK}), lambda: False),