mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
buffer in create_kernel is optional [pr] (#10218)
* buffer in create_kernel is optional [pr] * pylint
This commit is contained in:
@@ -241,7 +241,8 @@ class KernelContext:
|
||||
realizes: dict[UOp, None]
|
||||
metadata: dict[UOp, Metadata|None]
|
||||
|
||||
def create_kernel(ctx:KernelContext, x:UOp, b:UOp):
|
||||
def create_kernel(ctx:KernelContext, x:UOp, b:UOp|None=None):
|
||||
if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
|
||||
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), (m,) if (m:=ctx.metadata.get(x)) else ()))
|
||||
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
|
||||
return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)
|
||||
@@ -251,15 +252,14 @@ DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
|
||||
insert_kernels = merge_views+PatternMatcher([
|
||||
# always give assign/contiguous a kernel
|
||||
(UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
|
||||
(UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype))),
|
||||
(UPat(Ops.CONTIGUOUS, name="x"), create_kernel),
|
||||
# create a buffer for COPY on the new device
|
||||
(UPat(Ops.COPY, src=(UPat(), UPat(Ops.DEVICE, name="d")), name="x"), lambda ctx,d,x: create_kernel(ctx, x, UOp.new_buffer(d.arg, x.size, x.dtype))),
|
||||
(UPat(Ops.COPY, src=(UPat(), UPat(Ops.DEVICE)), name="x"), create_kernel),
|
||||
# remove CONST/BIND/VIEW from SINK
|
||||
(UPat(Ops.SINK, name="x"), lambda x: x.replace(src=new_src)
|
||||
if (new_src:=tuple(dedup(s.base for s in x.src if s.op not in {Ops.CONST, Ops.BIND}))) != x.src else None),
|
||||
# otherwise check the context if we're realizing this UOp
|
||||
(UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"),
|
||||
lambda ctx,x: create_kernel(ctx, x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
|
||||
(UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"), lambda ctx,x: create_kernel(ctx, x) if x in ctx.realizes else None),
|
||||
])
|
||||
|
||||
def append_to_kernel(ctx:KernelContext, x:UOp):
|
||||
|
||||
Reference in New Issue
Block a user