buffer in create_kernel is optional [pr] (#10218)

* buffer in create_kernel is optional [pr]

* pylint
This commit is contained in:
qazal
2025-05-08 22:35:55 +03:00
committed by GitHub
parent 40560e77c2
commit ff2aa6d0b2

View File

@@ -241,7 +241,8 @@ class KernelContext:
realizes: dict[UOp, None]
metadata: dict[UOp, Metadata|None]
# NOTE(review): this is a diff hunk shown without +/- markers. The first `def` line below looks like the
# pre-change signature (b required) and the second its replacement (b optional) -- confirm against the repository.
# Wraps x in an Ops.KERNEL UOp and returns an Ops.ASSIGN of that kernel into a buffer, reshaped to x.shape.
def create_kernel(ctx:KernelContext, x:UOp, b:UOp):
def create_kernel(ctx:KernelContext, x:UOp, b:UOp|None=None):
# no buffer supplied: make one with UOp.new_buffer from x's own device, size and dtype
if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
# kernel sources are the buffer followed by x's sources; the metadata tuple is empty when ctx.metadata has no truthy entry for x
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), (m,) if (m:=ctx.metadata.get(x)) else ()))
# use b.base directly when b has the same size as its base, otherwise build a BUFFER_VIEW carrying b's size and the offset of its first view
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)
@@ -251,15 +252,14 @@ DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
# Rule table: merge_views extended with patterns that hand matched UOps to create_kernel.
# NOTE(review): this is a diff hunk shown without +/- markers. Where two near-identical rules follow each other,
# the lambda form looks like the pre-change line and the shorter form its replacement -- confirm against the repository.
insert_kernels = merge_views+PatternMatcher([
# always give assign/contiguous a kernel
# the pattern names "b" and "x" match create_kernel's parameter names, so the assign target is presumably passed as the buffer -- TODO confirm
(UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
(UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype))),
(UPat(Ops.CONTIGUOUS, name="x"), create_kernel),
# create a buffer for COPY on the new device
# NOTE(review): the lambda form builds the buffer on d.arg, while the bare create_kernel form falls back to x.device
# when b is None -- presumably the same device for a COPY; confirm.
(UPat(Ops.COPY, src=(UPat(), UPat(Ops.DEVICE, name="d")), name="x"), lambda ctx,d,x: create_kernel(ctx, x, UOp.new_buffer(d.arg, x.size, x.dtype))),
(UPat(Ops.COPY, src=(UPat(), UPat(Ops.DEVICE)), name="x"), create_kernel),
# remove CONST/BIND/VIEW from SINK
# new_src is the deduplicated .base of every source that is not CONST or BIND; the rule yields None when that equals the current src
(UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST, Ops.BIND}))) != x.src else None),
# otherwise check the context if we're realizing this UOp
# only UOps found in ctx.realizes get a kernel; for any other match the lambda yields None
(UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"),
lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
(UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"), lambda ctx,x: create_kernel(ctx, x) if x in ctx.realizes else None),
])
def append_to_kernel(ctx:KernelContext, x:UOp):