BUFFER_VIEW is a node in the kernel graph + delete ViewOp (#9298)

qazal
2025-02-28 12:15:04 +02:00
committed by GitHub
parent 5a9c788ae6
commit 3210b656b6
6 changed files with 25 additions and 28 deletions

View File

@@ -71,7 +71,7 @@ def loader_process(q_in, q_out, X:Tensor, seed):
#storage_tensor._copyin(img_tensor.numpy())
# faster
-X[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
+X[idx].contiguous().realize().lazydata.base.buffer.ensure_allocated().as_buffer(force_zero_copy=True)[:] = img.tobytes()
# ideal
#X[idx].assign(img.tobytes()) # NOTE: this is slow!
@@ -261,8 +261,8 @@ def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tens
x = random_brightness_augmentation(x)
x = gaussian_noise(x)
-X[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = x.tobytes()
-Y[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = y.tobytes()
+X[idx].contiguous().realize().lazydata.base.buffer.ensure_allocated().as_buffer(force_zero_copy=True)[:] = x.tobytes()
+Y[idx].contiguous().realize().lazydata.base.buffer.ensure_allocated().as_buffer(force_zero_copy=True)[:] = y.tobytes()
queue_out.put(idx)
queue_out.put(None)
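
Both dataloader hunks make the same substitution: the removed lines reached the allocated Buffer through .realized, the added lines go through .buffer.ensure_allocated(), which allocates on first use. A minimal sketch of that write path, assuming a hypothetical staging tensor X on a device whose allocator supports zero-copy host views (the real loaders back X with shared memory, so the realized slice aliases the parent's storage):

import numpy as np
from tinygrad import Tensor, dtypes

X = Tensor.empty(4, 8, dtype=dtypes.uint8)   # hypothetical staging tensor
img = np.arange(8, dtype=np.uint8)           # hypothetical sample, same byte size as one slot
slot = X[0].contiguous().realize()           # realize the slot so a Buffer sits behind it
# ensure_allocated() replaces the old .realized access; as_buffer(force_zero_copy=True) exposes writable memory
slot.lazydata.base.buffer.ensure_allocated().as_buffer(force_zero_copy=True)[:] = img.tobytes()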

View File

@@ -6,7 +6,7 @@ from tinygrad.device import Buffer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.ops import UOp, Variable, sym_infer, Ops
from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
+from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.nn.state import get_parameters
from dataclasses import dataclass
@@ -39,7 +39,6 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
current_device = None
for ji in jit_cache:
-if isinstance(ji.prg, ViewOp): continue
ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
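
With ViewOp removed from the import and the skip deleted, apply_graph_to_jit only has to decide a graph device for compiled kernels and fast transfers; everything else stays ungraphed. A sketch of that device selection as a standalone helper (an illustrative restatement of the branch above, not the real function):

from typing import Optional
from tinygrad.device import Compiled, Device
from tinygrad.engine.realize import CompiledRunner, BufferXfer

def graph_device(ji) -> Optional[Compiled]:
  # compiled kernels graph on the device they were compiled for
  if isinstance(ji.prg, CompiledRunner): return ji.prg.dev
  # transfers only graph on backends with native peer-to-peer support
  if isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
    return Device[ji.bufs[0].device]
  return None  # not graphed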

View File

@@ -66,11 +66,6 @@ class CompiledRunner(Runner):
assert len(local_size) == 3, "local size must have len 3"
return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
-class ViewOp(Runner):
-  def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
-  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
-    assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"
class BufferCopy(Runner):
def __init__(self, total_sz, dest_device, src_device):
if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
@@ -143,7 +138,6 @@ class ExecItem:
# NOTE: ctx is the buffers
si_lowerer = PatternMatcher([
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])),
-(UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))),
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
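
Deleting the BUFFER_VIEW entry leaves si_lowerer with two jobs: a SINK lowers to a compiled kernel runner, and a COPY lowers to a fast transfer when the allocator supports it, or a generic copy otherwise; a BUFFER_VIEW never reaches lowering anymore. The same dispatch restated as plain branches (a sketch of the two remaining rules, not the PatternMatcher form):

from tinygrad.ops import Ops
from tinygrad.device import Device
from tinygrad.helpers import all_same
from tinygrad.engine.realize import get_runner, BufferCopy, BufferXfer

def lower_item(ast, bufs):
  if ast.op is Ops.SINK:
    # compile a kernel for the output device; its globals pick which buffers it binds
    runner = get_runner(bufs[0].device, ast)
    return runner, [bufs[x] for x in runner.p.globals]
  if ast.op is Ops.COPY:
    # fast device-to-device transfer when available, else a plain copy through the host
    fast = hasattr(Device[bufs[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in bufs])
    return (BufferXfer if fast else BufferCopy)(bufs[0].nbytes, bufs[0].device, bufs[1].device), list(bufs)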

View File

@@ -1,7 +1,7 @@
import sys, atexit, pickle
from collections import defaultdict, deque
from dataclasses import dataclass
-from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
+from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites
from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv, pluralize
@@ -139,8 +139,8 @@ def realize_before_view(ctx:GrouperContext, view:UOp, src:UOp) -> None:
do_realize = PatternMatcher([
# always realize SINK parents
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})),
-# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
-(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
+# always realize ASSIGN/CONTIGUOUS/COPY
+(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY}, name="tr"), realize),
# realize before expand or unsafe pad ops
(UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}, name="src"),)), realize_before_view),
# realize before COPY
@@ -238,7 +238,7 @@ class KernelContext:
realizes: dict[UOp, None]
ops_metadata: dict[UOp, Metadata]
-DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
+DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.BUFFER_VIEW}
def append_to_kernel(ctx:KernelContext, x:UOp):
new_srcs: list[UOp] = []
metadata = dict.fromkeys(x.arg.metadata)
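
Adding BUFFER_VIEW to DONT_PLACE_IN_KERNEL is what makes it a node in the kernel graph: the grouper keeps it as an external source of a KERNEL, like a BUFFER or an ASSIGN, instead of folding it into the kernel body. The stop-set idea behind append_to_kernel as a toy walk (illustrative only; the real function also merges metadata and rebuilds the KERNEL arg):

def external_srcs(u, dont_place):
  # anything whose op is in the stop set stays a kernel input; everything else is inlined and walked further
  out = []
  for s in u.src:
    if s.op in dont_place: out.append(s)
    else: out.extend(external_srcs(s, dont_place))
  return out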
@@ -274,9 +274,9 @@ def load_buf(ctx:list[UOp], x:UOp):
add_buffer_ops = PatternMatcher([
# LOAD
-(UPat(Ops.BUFFER, name="x"), load_buf),
-# STORE (except for COPY/BUFFER_VIEW)
-(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
+(UPat((Ops.BUFFER, Ops.BUFFER_VIEW), name="x"), load_buf),
+# STORE (except for COPY)
+(UPat(Ops.SINK, src=(UPat(Ops.COPY, name="x"),)), lambda x:x),
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
])
@@ -391,8 +391,6 @@ def schedule_uop(kernel:UOp, buffer_map:dict[UOp, UOp], var_vals:dict[Variable,
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
# fix_kernel_ops
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
-# create subbuffer
-if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), kernel.arg.metadata)
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
@@ -427,10 +425,11 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
# if we created a KERNEL for this tensor, map it to the assigned buffer
if (a:=kernel_map.get(v.base)) is not None and a.op is Ops.ASSIGN:
becomes_map[k] = a.src[0] if v is v.base else a.src[0].view(unwrap(v.st))
+elif a is not None and a.op is Ops.BUFFER_VIEW and a.src[0].op is Ops.ASSIGN: becomes_map[k] = a.replace(src=(a.src[0].buf_uop,))
# tensors can also simplify to an existing buffer/const
else:
if k is v: continue
-if v.base.op is Ops.BUFFER: becomes_map[k] = v
+if v.base.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: becomes_map[k] = v
if v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign

View File

@@ -513,7 +513,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
@property
def buf_uop(self) -> UOp:
-if self.op is Ops.BUFFER: return self
+if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return self
assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
return self.src[0].base
@property
@@ -521,14 +521,17 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self is not self.base:
assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
return self.src[0].buffer
-assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
+assert self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}, f"must be BUFFER {self.op}"
if (cret:=buffers.get(self)) is not None: return cret
+if self.op is Ops.BUFFER_VIEW:
+  buffers[self] = ret = (base:=self.src[0].buffer).view(self.size, self.dtype, self.arg[1]*base.dtype.itemsize)
+  return ret
from tinygrad.device import Buffer
assert isinstance(self.device, str), f"buffer not supported on multi {self.device}"
buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
return ret
@property
-def realized(self) -> Optional[Buffer]: return self.buffer if self.op is Ops.BUFFER and self.buffer.is_allocated() else None
+def realized(self) -> Optional[Buffer]: return self.buffer if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW} and self.buffer.is_allocated() else None
@property
def is_realized(self) -> bool:
return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None
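
This is the heart of the change: the sub-buffer that schedule_uop used to create at schedule time is now built lazily the first time .buffer is read on a BUFFER_VIEW, via Buffer.view(size, dtype, offset), with the element offset from the arg converted to bytes through the base dtype's itemsize. A small sketch of that view arithmetic on its own, assuming a device whose allocator supports sub-buffer offsets (CPU-style backends do):

from tinygrad import Device, dtypes
from tinygrad.device import Buffer

base = Buffer(Device.DEFAULT, 16, dtypes.float32).ensure_allocated()        # 16 float32 elements, 64 bytes
size, offset_elems = 4, 2                                                    # hypothetical BUFFER_VIEW arg: (size, offset in elements)
view = base.view(size, dtypes.float32, offset_elems * base.dtype.itemsize)  # Buffer.view takes the offset in bytes
view.ensure_allocated()
print(view.nbytes, view.offset)   # 16 bytes of the base allocation, starting 8 bytes in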

View File

@@ -8,7 +8,7 @@ buffer_spec = PatternMatcher([
(UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)),
(UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
-(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
+(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.BUFFER, Ops.BUFFER_VIEW)),), name="buf_view"),
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all_int(buf_view.arg)),
])
@@ -21,7 +21,7 @@ tensor_uop_spec = buffer_spec+PatternMatcher([
# "make things that can't be images not images" can change the buffer dtype
# this is fine as long as it's a realized buffer and base dtypes match.
((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.base.op is Ops.BUFFER)),
-(UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}),)), lambda: False),
+(UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.DEVICE}),)), lambda: False),
# Tensor variable bindings
(UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
@@ -125,9 +125,11 @@ spec = PatternMatcher([
# *** this is the spec of a Kernel in UOp ***
kernel_spec = buffer_spec+PatternMatcher([
-(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
+(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN))), lambda: True),
# assign has a buffer view and kernel source, it can optionally depend on other assigns
-(UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
+(UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
+# we also allow for a BUFFER_VIEW to depend on ASSIGN in the kernel graph
+(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.ASSIGN),)), lambda: True),
(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}), lambda: False),
])
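
Taken together, the spec now allows BUFFER_VIEW wherever a realized buffer could appear: as the source of another BUFFER_VIEW, as a source of a KERNEL or an ASSIGN, and, via the new last rule, sitting directly on an ASSIGN. An illustrative construction of the UOp shapes buffer_spec accepts, assuming the UOp.new_buffer helper for the base buffer (the arg is the (size, offset) pair the spec checks with all_int):

from tinygrad import Device, dtypes
from tinygrad.ops import UOp, Ops

base = UOp.new_buffer(Device.DEFAULT, 64, dtypes.uint8)       # BUFFER with (DEVICE, UNIQUE) srcs and an int size arg
bv   = UOp(Ops.BUFFER_VIEW, dtypes.uint8, (base,), (16, 4))   # view of 16 elements starting at element 4
bv2  = UOp(Ops.BUFFER_VIEW, dtypes.uint8, (bv,), (8, 0))      # a BUFFER_VIEW of a BUFFER_VIEW is now legal too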