mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add asserts for KERNEL op ast [pr] (#9868)
This commit is contained in:
@@ -226,7 +226,8 @@ class Kernel:
|
||||
ast: UOp
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
def __repr__(self):
|
||||
return f"<Kernel {len(list(self.ast.toposort))} {[s.op for s in self.ast.src] if self.ast.op is Ops.SINK else self.ast.op} {self.metadata}>"
|
||||
ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op)
|
||||
return f"<Kernel {len(list(self.ast.toposort))} {ast_rep} {self.metadata}>"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class KernelContext:
|
||||
|
||||
@@ -141,8 +141,13 @@ spec = PatternMatcher([
|
||||
|
||||
# *** this is the spec of a Kernel in UOp ***
|
||||
|
||||
def validate_kernel(k:UOp):
|
||||
assert k.arg.ast.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.SINK}, f"must end with SINK/COPY/BUFFER_VIEW {k.arg}"
|
||||
if k.arg.ast.op is Ops.SINK: assert all(s.op is Ops.STORE for s in k.arg.ast.src), f"SINK must end with STORE {k.arg.ast}"
|
||||
return True
|
||||
|
||||
kernel_spec = buffer_spec+PatternMatcher([
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN))), lambda: True),
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN)), name="k"), validate_kernel),
|
||||
# assign has a buffer and kernel source, it can optionally depend on other assigns
|
||||
(UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
|
||||
(UPat(GroupOp.All-{Ops.SINK}), lambda: False),
|
||||
|
||||
Reference in New Issue
Block a user