remove views from the kernel graph minimum diff (#9385)

* remove views from the kernel graph

* notes
qazal
2025-03-08 11:14:42 +02:00
committed by GitHub
parent 0d2762c010
commit ecfccdea8e
3 changed files with 21 additions and 16 deletions


@@ -230,6 +230,11 @@ class KernelContext:
   realizes: dict[UOp, None]
   ops_metadata: dict[UOp, Metadata]
 
+def create_kernel(x:UOp, b:UOp):
+  kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x))
+  buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
+  return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)
+
 DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
 def append_to_kernel(ctx:KernelContext, x:UOp):
   new_srcs: list[UOp] = []
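
The new create_kernel helper is where output views get collapsed: the kernel writes into the base buffer when the assigned buffer spans its whole base, and otherwise into a single BUFFER_VIEW carrying an explicit (size, offset); the reshape back to x.shape is applied to the resulting ASSIGN instead of being kept as a view inside the kernel graph. A minimal standalone sketch of that decision, using plain dataclasses rather than tinygrad's UOp (Buf, BufView and pick_output are illustrative names, not the real API):

from dataclasses import dataclass

@dataclass(frozen=True)
class Buf:          # stand-in for a base Ops.BUFFER
  size: int

@dataclass(frozen=True)
class BufView:      # stand-in for Ops.BUFFER_VIEW with arg=(size, offset)
  base: Buf
  size: int
  offset: int

def pick_output(base: Buf, view_size: int, view_offset: int):
  # mirrors the b.size == b.base.size branch: full-size views write straight
  # into the base buffer, partial views go through a (size, offset) window
  if view_size == base.size: return base
  return BufView(base, view_size, view_offset)

assert pick_output(Buf(16), 16, 0) == Buf(16)
assert pick_output(Buf(16), 8, 4) == BufView(Buf(16), 8, 4)
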
@@ -245,10 +250,10 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
 create_kernels = merge_views+PatternMatcher([
   # always give assign a kernel
-  (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), lambda x,b: b.assign(UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x)))),
+  (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
   # otherwise check if need to assign this UOp to a new buffer
-  (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"), lambda ctx,x: UOp(Ops.ASSIGN, x.dtype, (b:=UOp.new_buffer(x.device, x.size, x.dtype).view(x.st),\
-    UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x)))) if x in ctx.realizes else None),
+  (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"),
+   lambda ctx,x: create_kernel(x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
   # walk back the local graph until we reach a buffer/assign parent
   (UPat(Ops.KERNEL, name="x"), append_to_kernel),
   # remove CONST/BIND from SINK
@@ -364,8 +369,8 @@ fix_kernel_ops = PatternMatcher([
 def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
   assert k.op is Ops.KERNEL, f"kernel isn't kernel, it's {k}"
-  # substitute kernel sources for the target buffer
-  ast = k.arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in k.src if s.op is Ops.ASSIGN}).sink()
+  # substitute kernel sources for the target buffer + apply reshapes
+  ast = k.arg.ast.substitute({(ast:=s.src[1].arg.ast):s.src[0].view(unwrap(ast.st)) for s in k.src if s.op is Ops.ASSIGN}).sink()
   # add buffer ops
   ast = graph_rewrite(ast, add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
   if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
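
With views gone from the kernel sources, fix_kernel_ast reapplies them at substitution time: each source kernel's AST maps to its output buffer viewed with that AST's ShapeTracker, and the assignment expression lets the key be reused inside the value. A tiny hedged illustration of that walrus-in-a-dict-comprehension idiom (the data is made up and unrelated to UOps):

pairs = [("a", 3), ("b", 4)]
# the key expression binds `key`, which the value expression then reuses,
# analogous to {(ast:=s.src[1].arg.ast): s.src[0].view(unwrap(ast.st)) ...}
mapping = {(key:=name.upper()): (key, n*n) for name, n in pairs}
assert mapping == {"A": ("A", 9), "B": ("B", 16)}
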
@@ -413,8 +418,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   becomes_map: dict[UOp, UOp] = {}
   for k,v in tensor_map.items():
     # if we created a KERNEL for this tensor, map it to the assigned buffer
-    if (a:=kernel_map.get(v.base)) is not None and a.op is Ops.ASSIGN:
-      becomes_map[k] = a.src[0] if v is v.base else a.src[0].view(unwrap(v.st))
+    if (a:=kernel_map.get(v.base)) is not None and (a:=a.base).op is Ops.ASSIGN:
+      becomes_map[k] = a.src[0] if a.src[0].st == v.st else a.src[0].view(unwrap(v.st))
     # tensors can also simplify to an existing buffer/const
     else:
       if k is v: continue
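
Because create_kernel now returns a reshaped ASSIGN, the kernel_map entry is first unwrapped with .base, and the tensor maps straight to the assigned buffer only when their ShapeTrackers already agree; otherwise a view with the tensor's tracker is recorded. A small sketch of that mapping rule, with tuples standing in for ShapeTrackers (becomes and its tuple encoding are illustrative, not tinygrad code):

def becomes(assigned_buf: str, assigned_st: tuple, tensor_st: tuple):
  # mirrors: a.src[0] if a.src[0].st == v.st else a.src[0].view(unwrap(v.st))
  return assigned_buf if assigned_st == tensor_st else ("view", assigned_buf, tensor_st)

assert becomes("buf0", (4, 4), (4, 4)) == "buf0"                   # trackers match: use the buffer as-is
assert becomes("buf0", (16,), (4, 4)) == ("view", "buf0", (4, 4))  # flat buffer, shaped tensor: wrap in a view
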
@@ -458,8 +463,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
     # TODO: move this to create_kernels
     k = fix_kernel_ast(u.src[1], var_vals)
     schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
-    # increment the refcount of the target buf (this is required by the JIT and memory planner)
-    u.buf_uop.buffer.ref(1)
+    # increment the refcount of the target buf (this is required by the JIT and memory planner) TODO: this does not belong here
+    k.src[0].buffer.ref(1)
     for x in children.get(u, []):
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(x)


@@ -292,7 +292,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
       return ShapeTracker.from_shape(
         tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
     if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return ShapeTracker.from_shape((self.size,))
-    if self.op is Ops.KERNEL: return ShapeTracker.from_shape(self.arg.ast.shape)
+    if self.op is Ops.KERNEL: return ShapeTracker.from_shape((self.arg.ast.size,))
     # these ops define a ShapeTracker from the arg
     if self.op is Ops.VIEW: return self.arg
     if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
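
A KERNEL now reports a flat one-dimensional ShapeTracker over its element count, the same convention BUFFER and BUFFER_VIEW already use, instead of the AST's full shape. The size is just a product over the shape axes; a quick hedged check (flat_st_shape is an illustrative helper, not tinygrad code):

from math import prod

def flat_st_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
  # the 1-D shape a flat tracker is built from: one axis of prod(shape) elements
  return (prod(shape),)

assert flat_st_shape((4, 8, 2)) == (64,)
assert flat_st_shape(()) == (1,)  # a scalar still occupies one element
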


@@ -1,7 +1,7 @@
 from typing import cast
 from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
 from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
-from tinygrad.helpers import all_same, all_int, dedup, prod
+from tinygrad.helpers import all_same, dedup, prod
 
 buffer_spec = PatternMatcher([
   (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
@@ -9,7 +9,7 @@ buffer_spec = PatternMatcher([
   (UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"),
    lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
   (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
-   lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all_int(buf_view.arg)),
+   lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
 ])
 
 # *** this is the spec of a Tensor in UOp ***
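
buffer_spec's BUFFER_VIEW rule is loosened from all-int to int-or-UOp entries, presumably so a symbolic size or offset can appear in the (size, offset) arg. A standalone sketch of the loosened predicate (Sym stands in for a UOp; the function name is illustrative):

class Sym:  # stand-in for a symbolic UOp value
  pass

def valid_buffer_view_arg(arg) -> bool:
  # a 2-tuple (size, offset), each entry either a concrete int or symbolic
  return isinstance(arg, tuple) and len(arg) == 2 and all(isinstance(x, (int, Sym)) for x in arg)

assert valid_buffer_view_arg((8, 4))
assert valid_buffer_view_arg((Sym(), 0))
assert not valid_buffer_view_arg((8,))        # wrong arity
assert not valid_buffer_view_arg((8, 4.0))    # floats are still rejected
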
@@ -126,10 +126,10 @@ spec = PatternMatcher([
 # *** this is the spec of a Kernel in UOp ***
 
 kernel_spec = buffer_spec+PatternMatcher([
-  (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
-  # assign has a buffer view and kernel source, it can optionally depend on other assigns
-  (UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
-  (UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}), lambda: False),
+  (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN))), lambda: True),
+  # assign has a buffer and kernel source, it can optionally depend on other assigns
+  (UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
+  (UPat(GroupOp.All-{Ops.SINK}), lambda: False),
 ])
 
 # *** this is the UOp shape spec ***
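
kernel_spec now spells out the point of the whole change: sources in the kernel graph may only be BUFFER, BUFFER_VIEW, KERNEL or ASSIGN, and Ops.VIEW no longer gets a pass anywhere (it is dropped both from the ASSIGN sources and from the SINK escape hatch). A hedged sketch of that membership check, using a local Enum as a stand-in for tinygrad's Ops:

from enum import Enum, auto

class Op(Enum):  # stand-in for the relevant subset of Ops
  BUFFER = auto(); BUFFER_VIEW = auto(); KERNEL = auto(); ASSIGN = auto(); SINK = auto(); VIEW = auto()

# what an ASSIGN in the kernel graph may take as sources after this change
ASSIGN_SRCS = {Op.BUFFER, Op.BUFFER_VIEW, Op.KERNEL, Op.ASSIGN}

def assign_srcs_ok(srcs: tuple) -> bool:
  return all(s in ASSIGN_SRCS for s in srcs)

assert assign_srcs_ok((Op.BUFFER, Op.KERNEL))
assert not assign_srcs_ok((Op.VIEW, Op.KERNEL))  # views are no longer legal kernel-graph sources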