From 3210b656b6d3d05f569a36987968ef7ee3ecc0af Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 28 Feb 2025 12:15:04 +0200 Subject: [PATCH] BUFFER_VIEW is a node in the kernel graph + delete ViewOp (#9298) --- examples/mlperf/dataloader.py | 6 +++--- tinygrad/engine/jit.py | 3 +-- tinygrad/engine/realize.py | 6 ------ tinygrad/engine/schedule.py | 19 +++++++++---------- tinygrad/ops.py | 9 ++++++--- tinygrad/spec.py | 10 ++++++---- 6 files changed, 25 insertions(+), 28 deletions(-) diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py index 833915f190..12924b5500 100644 --- a/examples/mlperf/dataloader.py +++ b/examples/mlperf/dataloader.py @@ -71,7 +71,7 @@ def loader_process(q_in, q_out, X:Tensor, seed): #storage_tensor._copyin(img_tensor.numpy()) # faster - X[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes() + X[idx].contiguous().realize().lazydata.base.buffer.ensure_allocated().as_buffer(force_zero_copy=True)[:] = img.tobytes() # ideal #X[idx].assign(img.tobytes()) # NOTE: this is slow! @@ -261,8 +261,8 @@ def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tens x = random_brightness_augmentation(x) x = gaussian_noise(x) - X[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = x.tobytes() - Y[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = y.tobytes() + X[idx].contiguous().realize().lazydata.base.buffer.ensure_allocated().as_buffer(force_zero_copy=True)[:] = x.tobytes() + Y[idx].contiguous().realize().lazydata.base.buffer.ensure_allocated().as_buffer(force_zero_copy=True)[:] = y.tobytes() queue_out.put(idx) queue_out.put(None) diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 98e9d43e2d..61f9256bd6 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -6,7 +6,7 @@ from tinygrad.device import Buffer, Compiled, Device from tinygrad.dtype import DType from tinygrad.ops import UOp, Variable, sym_infer, Ops from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates +from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates from tinygrad.engine.memory import _internal_memory_planner from tinygrad.nn.state import get_parameters from dataclasses import dataclass @@ -39,7 +39,6 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer] current_device = None for ji in jit_cache: - if isinstance(ji.prg, ViewOp): continue ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None. if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}: diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index 22c946a080..3ca7ecba56 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -66,11 +66,6 @@ class CompiledRunner(Runner): assert len(local_size) == 3, "local size must have len 3" return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait) -class ViewOp(Runner): - def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device) - def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False): - assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}" - class BufferCopy(Runner): def __init__(self, total_sz, dest_device, src_device): if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}" @@ -143,7 +138,6 @@ class ExecItem: # NOTE: ctx is the buffers si_lowerer = PatternMatcher([ (UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])), - (UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))), (UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \ if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \ else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))), diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 54285126e0..a8a4988a2e 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,7 +1,7 @@ import sys, atexit, pickle from collections import defaultdict, deque from dataclasses import dataclass -from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers +from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views from tinygrad.codegen.symbolic import symbolic_simple from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv, pluralize @@ -139,8 +139,8 @@ def realize_before_view(ctx:GrouperContext, view:UOp, src:UOp) -> None: do_realize = PatternMatcher([ # always realize SINK parents (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})), - # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW - (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize), + # always realize ASSIGN/CONTIGUOUS/COPY + (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY}, name="tr"), realize), # realize before expand or unsafe pad ops (UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}, name="src"),)), realize_before_view), # realize before COPY @@ -238,7 +238,7 @@ class KernelContext: realizes: dict[UOp, None] ops_metadata: dict[UOp, Metadata] -DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER} +DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.BUFFER_VIEW} def append_to_kernel(ctx:KernelContext, x:UOp): new_srcs: list[UOp] = [] metadata = dict.fromkeys(x.arg.metadata) @@ -274,9 +274,9 @@ def load_buf(ctx:list[UOp], x:UOp): add_buffer_ops = PatternMatcher([ # LOAD - (UPat(Ops.BUFFER, name="x"), load_buf), - # STORE (except for COPY/BUFFER_VIEW) - (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x), + (UPat((Ops.BUFFER, Ops.BUFFER_VIEW), name="x"), load_buf), + # STORE (except for COPY) + (UPat(Ops.SINK, src=(UPat(Ops.COPY, name="x"),)), lambda x:x), (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)), lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()), ]) @@ -391,8 +391,6 @@ def schedule_uop(kernel:UOp, buffer_map:dict[UOp, UOp], var_vals:dict[Variable, ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right) # fix_kernel_ops ast = graph_rewrite(ast, fix_kernel_ops, var_vals) - # create subbuffer - if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize) return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), kernel.arg.metadata) PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {} @@ -427,10 +425,11 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va # if we created a KERNEL for this tensor, map it to the assigned buffer if (a:=kernel_map.get(v.base)) is not None and a.op is Ops.ASSIGN: becomes_map[k] = a.src[0] if v is v.base else a.src[0].view(unwrap(v.st)) + elif a is not None and a.op is Ops.BUFFER_VIEW and a.src[0].op is Ops.ASSIGN: becomes_map[k] = a.replace(src=(a.src[0].buf_uop,)) # tensors can also simplify to an existing buffer/const else: if k is v: continue - if v.base.op is Ops.BUFFER: becomes_map[k] = v + if v.base.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: becomes_map[k] = v if v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign diff --git a/tinygrad/ops.py b/tinygrad/ops.py index fdbbc6644f..44b377eef7 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -513,7 +513,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None @property def buf_uop(self) -> UOp: - if self.op is Ops.BUFFER: return self + if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return self assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}" return self.src[0].base @property @@ -521,14 +521,17 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self is not self.base: assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous" return self.src[0].buffer - assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}" + assert self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}, f"must be BUFFER {self.op}" if (cret:=buffers.get(self)) is not None: return cret + if self.op is Ops.BUFFER_VIEW: + buffers[self] = ret = (base:=self.src[0].buffer).view(self.size, self.dtype, self.arg[1]*base.dtype.itemsize) + return ret from tinygrad.device import Buffer assert isinstance(self.device, str), f"buffer not supported on multi {self.device}" buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base) return ret @property - def realized(self) -> Optional[Buffer]: return self.buffer if self.op is Ops.BUFFER and self.buffer.is_allocated() else None + def realized(self) -> Optional[Buffer]: return self.buffer if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW} and self.buffer.is_allocated() else None @property def is_realized(self) -> bool: return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None diff --git a/tinygrad/spec.py b/tinygrad/spec.py index e45c3b57b4..bf2c751554 100644 --- a/tinygrad/spec.py +++ b/tinygrad/spec.py @@ -8,7 +8,7 @@ buffer_spec = PatternMatcher([ (UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)), (UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"), lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))), - (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"), + (UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.BUFFER, Ops.BUFFER_VIEW)),), name="buf_view"), lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all_int(buf_view.arg)), ]) @@ -21,7 +21,7 @@ tensor_uop_spec = buffer_spec+PatternMatcher([ # "make things that can't be images not images" can change the buffer dtype # this is fine as long as it's a realized buffer and base dtypes match. ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.base.op is Ops.BUFFER)), - (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}),)), lambda: False), + (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.DEVICE}),)), lambda: False), # Tensor variable bindings (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True), @@ -125,9 +125,11 @@ spec = PatternMatcher([ # *** this is the spec of a Kernel in UOp *** kernel_spec = buffer_spec+PatternMatcher([ - (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True), + (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN))), lambda: True), # assign has a buffer view and kernel source, it can optionally depend on other assigns - (UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True), + (UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True), + # we also allow for a BUFFER_VIEW to depend on ASSIGN in the kernel graph + (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.ASSIGN),)), lambda: True), (UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}), lambda: False), ])