diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 6837c3387b..792eca4d71 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -230,6 +230,11 @@ class KernelContext:
   realizes: dict[UOp, None]
   ops_metadata: dict[UOp, Metadata]
 
+def create_kernel(x:UOp, b:UOp):
+  kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x))
+  buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
+  return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)
+
 DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
 def append_to_kernel(ctx:KernelContext, x:UOp):
   new_srcs: list[UOp] = []
@@ -245,10 +250,10 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
 
 create_kernels = merge_views+PatternMatcher([
   # always give assign a kernel
-  (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), lambda x,b: b.assign(UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x)))),
+  (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
   # otherwise check if need to assign this UOp to a new buffer
-  (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"), lambda ctx,x: UOp(Ops.ASSIGN, x.dtype, (b:=UOp.new_buffer(x.device, x.size, x.dtype).view(x.st),\
-    UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x)))) if x in ctx.realizes else None),
+  (UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"),
+   lambda ctx,x: create_kernel(x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
   # walk back the local graph until we reach a buffer/assign parent
   (UPat(Ops.KERNEL, name="x"), append_to_kernel),
   # remove CONST/BIND from SINK
@@ -364,8 +369,8 @@ fix_kernel_ops = PatternMatcher([
 
 def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
   assert k.op is Ops.KERNEL, f"kernel isn't kernel, it's {k}"
-  # substitute kernel sources for the target buffer
-  ast = k.arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in k.src if s.op is Ops.ASSIGN}).sink()
+  # substitute kernel sources for the target buffer + apply reshapes
+  ast = k.arg.ast.substitute({(ast:=s.src[1].arg.ast):s.src[0].view(unwrap(ast.st)) for s in k.src if s.op is Ops.ASSIGN}).sink()
   # add buffer ops
   ast = graph_rewrite(ast, add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
   if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
@@ -413,8 +418,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   becomes_map: dict[UOp, UOp] = {}
   for k,v in tensor_map.items():
     # if we created a KERNEL for this tensor, map it to the assigned buffer
-    if (a:=kernel_map.get(v.base)) is not None and a.op is Ops.ASSIGN:
-      becomes_map[k] = a.src[0] if v is v.base else a.src[0].view(unwrap(v.st))
+    if (a:=kernel_map.get(v.base)) is not None and (a:=a.base).op is Ops.ASSIGN:
+      becomes_map[k] = a.src[0] if a.src[0].st == v.st else a.src[0].view(unwrap(v.st))
     # tensors can also simplify to an existing buffer/const
     else:
       if k is v: continue
@@ -458,8 +463,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
       # TODO: move this to create_kernels
       k = fix_kernel_ast(u.src[1], var_vals)
       schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
-      # increment the refcount of the target buf (this is required by the JIT and memory planner)
-      u.buf_uop.buffer.ref(1)
+      # increment the refcount of the target buf (this is required by the JIT and memory planner) TODO: this does not belong here
+      k.src[0].buffer.ref(1)
     for x in children.get(u, []):
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(x)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index f2576f2598..2df87e2281 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -292,7 +292,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
       return ShapeTracker.from_shape(
         tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
     if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return ShapeTracker.from_shape((self.size,))
-    if self.op is Ops.KERNEL: return ShapeTracker.from_shape(self.arg.ast.shape)
+    if self.op is Ops.KERNEL: return ShapeTracker.from_shape((self.arg.ast.size,))
     # these ops define a ShapeTracker from the arg
     if self.op is Ops.VIEW: return self.arg
     if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
diff --git a/tinygrad/spec.py b/tinygrad/spec.py
index d790728c9f..c94677abcc 100644
--- a/tinygrad/spec.py
+++ b/tinygrad/spec.py
@@ -1,7 +1,7 @@
 from typing import cast
 from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
 from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
-from tinygrad.helpers import all_same, all_int, dedup, prod
+from tinygrad.helpers import all_same, dedup, prod
 
 buffer_spec = PatternMatcher([
   (UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
@@ -9,7 +9,7 @@ buffer_spec = PatternMatcher([
   (UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"),
    lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
   (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
-   lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all_int(buf_view.arg)),
+   lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
 ])
 
 # *** this is the spec of a Tensor in UOp ***
@@ -126,10 +126,10 @@ spec = PatternMatcher([
 # *** this is the spec of a Kernel in UOp ***
 
 kernel_spec = buffer_spec+PatternMatcher([
-  (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
-  # assign has a buffer view and kernel source, it can optionally depend on other assigns
-  (UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
-  (UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}), lambda: False),
+  (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN))), lambda: True),
+  # assign has a buffer and kernel source, it can optionally depend on other assigns
+  (UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
+  (UPat(GroupOp.All-{Ops.SINK}), lambda: False),
 ])
 
 # *** this is the UOp shape spec ***
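
Context for the new create_kernel helper in schedule.py above: it routes the kernel's output through the base buffer when the ASSIGN target views the whole buffer, and through an Ops.BUFFER_VIEW carrying (size, offset) otherwise. Below is a minimal standalone sketch of just that selection rule; Buf, BufView, and pick_output_buffer are hypothetical stand-ins for illustration, not real tinygrad UOps or APIs.

from dataclasses import dataclass

@dataclass(frozen=True)
class Buf:  # hypothetical stand-in for a base Ops.BUFFER
  size: int

@dataclass(frozen=True)
class BufView:  # hypothetical stand-in for Ops.BUFFER_VIEW with arg=(size, offset)
  base: Buf
  size: int
  offset: int

def pick_output_buffer(base:Buf, view_size:int, view_offset:int):
  # full-size view: the kernel can ASSIGN straight into the base buffer
  if view_size == base.size: return base
  # shrunk view: wrap the base so downstream ops only see (size, offset)
  return BufView(base, view_size, view_offset)

assert pick_output_buffer(Buf(16), 16, 0) == Buf(16)
assert pick_output_buffer(Buf(16), 4, 8) == BufView(Buf(16), 4, 8)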