mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 13:28:06 -05:00
delete KernelContext dataclass [pr] (#10236)
This commit is contained in:
@@ -236,11 +236,7 @@ class Kernel:
|
||||
ast_rep = f"SINK{tuple(s.op for s in self.ast.src)}" if self.ast.op is Ops.SINK else repr(self.ast.op)
|
||||
return f"<Kernel {len(list(self.ast.toposort()))} {ast_rep} {self.metadata}>"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class KernelContext:
|
||||
realizes: dict[UOp, None]
|
||||
|
||||
def create_kernel(ctx:KernelContext, x:UOp, b:UOp|None=None):
|
||||
def create_kernel(x:UOp, b:UOp|None=None):
|
||||
if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
|
||||
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), (m,) if (m:=x.metadata) else ()))
|
||||
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
|
||||
@@ -258,14 +254,14 @@ insert_kernels = merge_views+PatternMatcher([
|
||||
(UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
|
||||
if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST, Ops.BIND}))) != x.src else None),
|
||||
# otherwise check the context if we're realizing this UOp
|
||||
(UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"), lambda ctx,x: create_kernel(ctx, x) if x in ctx.realizes else None),
|
||||
(UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"), lambda ctx,x: create_kernel(x) if x in ctx else None),
|
||||
])
|
||||
|
||||
def append_to_kernel(ctx:KernelContext, x:UOp):
|
||||
def append_to_kernel(ctx:dict[UOp, None], x:UOp):
|
||||
new_srcs: list[UOp] = []
|
||||
metadata = dict.fromkeys(x.arg.metadata)
|
||||
for s in x.src:
|
||||
if s.op in DONT_PLACE_IN_KERNEL or s in ctx.realizes: new_srcs.append(s)
|
||||
if s.op in DONT_PLACE_IN_KERNEL or s in ctx: new_srcs.append(s)
|
||||
else:
|
||||
new_srcs.extend(s.src)
|
||||
if s.base.op not in {Ops.CONST, Ops.DEVICE} and (m:=s.metadata): metadata[m] = None
|
||||
@@ -490,8 +486,7 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
|
||||
# group into kernels
|
||||
realize_map = group_realizes(tensor_map[big_sink])
|
||||
tensor_map = graph_rewrite_map(tensor_map[big_sink], create_kernels, ctx=KernelContext(realize_map),
|
||||
bottom_up=True, input_map=tensor_map, name="create_kernels")
|
||||
tensor_map = graph_rewrite_map(tensor_map[big_sink], create_kernels, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="create_kernels")
|
||||
|
||||
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
|
||||
kernel_assign: dict[UOp, UOp] = {}
|
||||
|
||||
Reference in New Issue
Block a user