mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 07:18:40 -05:00
remove views from the kernel graph minimum diff (#9385)
* remove views from the kernel graph * notes
This commit is contained in:
@@ -230,6 +230,11 @@ class KernelContext:
|
||||
realizes: dict[UOp, None]
|
||||
ops_metadata: dict[UOp, Metadata]
|
||||
|
||||
def create_kernel(x:UOp, b:UOp):
|
||||
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x))
|
||||
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
|
||||
return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)
|
||||
|
||||
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
|
||||
def append_to_kernel(ctx:KernelContext, x:UOp):
|
||||
new_srcs: list[UOp] = []
|
||||
@@ -245,10 +250,10 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
|
||||
|
||||
create_kernels = merge_views+PatternMatcher([
|
||||
# always give assign a kernel
|
||||
(UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), lambda x,b: b.assign(UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x)))),
|
||||
(UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
|
||||
# otherwise check if need to assign this UOp to a new buffer
|
||||
(UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"), lambda ctx,x: UOp(Ops.ASSIGN, x.dtype, (b:=UOp.new_buffer(x.device, x.size, x.dtype).view(x.st),\
|
||||
UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x)))) if x in ctx.realizes else None),
|
||||
(UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"),
|
||||
lambda ctx,x: create_kernel(x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
|
||||
# walk back the local graph until we reach a buffer/assign parent
|
||||
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
|
||||
# remove CONST/BIND from SINK
|
||||
@@ -364,8 +369,8 @@ fix_kernel_ops = PatternMatcher([
|
||||
|
||||
def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
|
||||
assert k.op is Ops.KERNEL, f"kernel isn't kernel, it's {k}"
|
||||
# substitute kernel sources for the target buffer
|
||||
ast = k.arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in k.src if s.op is Ops.ASSIGN}).sink()
|
||||
# substitute kernel sources for the target buffer + apply reshapes
|
||||
ast = k.arg.ast.substitute({(ast:=s.src[1].arg.ast):s.src[0].view(unwrap(ast.st)) for s in k.src if s.op is Ops.ASSIGN}).sink()
|
||||
# add buffer ops
|
||||
ast = graph_rewrite(ast, add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
|
||||
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
|
||||
@@ -413,8 +418,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
becomes_map: dict[UOp, UOp] = {}
|
||||
for k,v in tensor_map.items():
|
||||
# if we created a KERNEL for this tensor, map it to the assigned buffer
|
||||
if (a:=kernel_map.get(v.base)) is not None and a.op is Ops.ASSIGN:
|
||||
becomes_map[k] = a.src[0] if v is v.base else a.src[0].view(unwrap(v.st))
|
||||
if (a:=kernel_map.get(v.base)) is not None and (a:=a.base).op is Ops.ASSIGN:
|
||||
becomes_map[k] = a.src[0] if a.src[0].st == v.st else a.src[0].view(unwrap(v.st))
|
||||
# tensors can also simplify to an existing buffer/const
|
||||
else:
|
||||
if k is v: continue
|
||||
@@ -458,8 +463,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
# TODO: move this to create_kernels
|
||||
k = fix_kernel_ast(u.src[1], var_vals)
|
||||
schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
|
||||
# increment the refcount of the target buf (this is required by the JIT and memory planner)
|
||||
u.buf_uop.buffer.ref(1)
|
||||
# increment the refcount of the target buf (this is required by the JIT and memory planner) TODO: this does not belong here
|
||||
k.src[0].buffer.ref(1)
|
||||
for x in children.get(u, []):
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
@@ -292,7 +292,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return ShapeTracker.from_shape(
|
||||
tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
|
||||
if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return ShapeTracker.from_shape((self.size,))
|
||||
if self.op is Ops.KERNEL: return ShapeTracker.from_shape(self.arg.ast.shape)
|
||||
if self.op is Ops.KERNEL: return ShapeTracker.from_shape((self.arg.ast.size,))
|
||||
# these ops define a ShapeTracker from the arg
|
||||
if self.op is Ops.VIEW: return self.arg
|
||||
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import cast
|
||||
from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
|
||||
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
|
||||
from tinygrad.helpers import all_same, all_int, dedup, prod
|
||||
from tinygrad.helpers import all_same, dedup, prod
|
||||
|
||||
buffer_spec = PatternMatcher([
|
||||
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
|
||||
@@ -9,7 +9,7 @@ buffer_spec = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"),
|
||||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
|
||||
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all_int(buf_view.arg)),
|
||||
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
|
||||
])
|
||||
|
||||
# *** this is the spec of a Tensor in UOp ***
|
||||
@@ -126,10 +126,10 @@ spec = PatternMatcher([
|
||||
# *** this is the spec of a Kernel in UOp ***
|
||||
|
||||
kernel_spec = buffer_spec+PatternMatcher([
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
|
||||
# assign has a buffer view and kernel source, it can optionally depend on other assigns
|
||||
(UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
|
||||
(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}), lambda: False),
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN))), lambda: True),
|
||||
# assign has a buffer and kernel source, it can optionally depend on other assigns
|
||||
(UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
|
||||
(UPat(GroupOp.All-{Ops.SINK}), lambda: False),
|
||||
])
|
||||
|
||||
# *** this is the UOp shape spec ***
|
||||
|
||||
Reference in New Issue
Block a user