mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
revert buffer_view change (#9311)
* Revert "BUFFER_VIEW is a node in the kernel graph + delete ViewOp (#9298)" This reverts commit3210b656b6. * Revert "substitute ast from kernel op [pr] (#9293)" This reverts commit5a9c788ae6.
This commit is contained in:
@@ -71,7 +71,7 @@ def loader_process(q_in, q_out, X:Tensor, seed):
|
||||
#storage_tensor._copyin(img_tensor.numpy())
|
||||
|
||||
# faster
|
||||
X[idx].contiguous().realize().lazydata.base.buffer.ensure_allocated().as_buffer(force_zero_copy=True)[:] = img.tobytes()
|
||||
X[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
|
||||
|
||||
# ideal
|
||||
#X[idx].assign(img.tobytes()) # NOTE: this is slow!
|
||||
@@ -261,8 +261,8 @@ def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tens
|
||||
x = random_brightness_augmentation(x)
|
||||
x = gaussian_noise(x)
|
||||
|
||||
X[idx].contiguous().realize().lazydata.base.buffer.ensure_allocated().as_buffer(force_zero_copy=True)[:] = x.tobytes()
|
||||
Y[idx].contiguous().realize().lazydata.base.buffer.ensure_allocated().as_buffer(force_zero_copy=True)[:] = y.tobytes()
|
||||
X[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = x.tobytes()
|
||||
Y[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = y.tobytes()
|
||||
|
||||
queue_out.put(idx)
|
||||
queue_out.put(None)
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.device import Buffer, Compiled, Device
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.ops import UOp, Variable, sym_infer, Ops
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
|
||||
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
|
||||
from tinygrad.engine.memory import _internal_memory_planner
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from dataclasses import dataclass
|
||||
@@ -39,6 +39,7 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
|
||||
current_device = None
|
||||
|
||||
for ji in jit_cache:
|
||||
if isinstance(ji.prg, ViewOp): continue
|
||||
ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
|
||||
if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
|
||||
elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
|
||||
|
||||
@@ -66,6 +66,11 @@ class CompiledRunner(Runner):
|
||||
assert len(local_size) == 3, "local size must have len 3"
|
||||
return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
|
||||
|
||||
class ViewOp(Runner):
|
||||
def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
|
||||
def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
|
||||
assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"
|
||||
|
||||
class BufferCopy(Runner):
|
||||
def __init__(self, total_sz, dest_device, src_device):
|
||||
if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
|
||||
@@ -138,6 +143,7 @@ class ExecItem:
|
||||
# NOTE: ctx is the buffers
|
||||
si_lowerer = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])),
|
||||
(UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))),
|
||||
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
|
||||
if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
|
||||
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import sys, atexit, pickle
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites
|
||||
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, buffers
|
||||
from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv, pluralize
|
||||
@@ -139,8 +139,8 @@ def realize_before_view(ctx:GrouperContext, view:UOp, src:UOp) -> None:
|
||||
do_realize = PatternMatcher([
|
||||
# always realize SINK parents
|
||||
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})),
|
||||
# always realize ASSIGN/CONTIGUOUS/COPY
|
||||
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY}, name="tr"), realize),
|
||||
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
|
||||
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
|
||||
# realize before expand or unsafe pad ops
|
||||
(UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}, name="src"),)), realize_before_view),
|
||||
# realize before COPY
|
||||
@@ -238,7 +238,7 @@ class KernelContext:
|
||||
realizes: dict[UOp, None]
|
||||
ops_metadata: dict[UOp, Metadata]
|
||||
|
||||
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.BUFFER_VIEW}
|
||||
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
|
||||
def append_to_kernel(ctx:KernelContext, x:UOp):
|
||||
new_srcs: list[UOp] = []
|
||||
metadata = dict.fromkeys(x.arg.metadata)
|
||||
@@ -274,9 +274,9 @@ def load_buf(ctx:list[UOp], x:UOp):
|
||||
|
||||
add_buffer_ops = PatternMatcher([
|
||||
# LOAD
|
||||
(UPat((Ops.BUFFER, Ops.BUFFER_VIEW), name="x"), load_buf),
|
||||
# STORE (except for COPY)
|
||||
(UPat(Ops.SINK, src=(UPat(Ops.COPY, name="x"),)), lambda x:x),
|
||||
(UPat(Ops.BUFFER, name="x"), load_buf),
|
||||
# STORE (except for COPY/BUFFER_VIEW)
|
||||
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
|
||||
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
|
||||
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
|
||||
])
|
||||
@@ -380,18 +380,20 @@ class ScheduleItem:
|
||||
bufs: tuple[Buffer, ...]
|
||||
metadata: tuple[Metadata, ...]
|
||||
|
||||
def schedule_uop(kernel:UOp, buffer_map:dict[UOp, UOp], var_vals:dict[Variable, int]) -> ScheduleItem:
|
||||
assert kernel.op is Ops.KERNEL, f"kernel isn't kernel, it's {kernel.op}"
|
||||
def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
|
||||
assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"
|
||||
# substitute kernel sources for the target buffer
|
||||
ast = kernel.arg.ast.substitute({k:v for k,v in buffer_map.items() if k is not kernel.arg.ast}).sink()
|
||||
ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
|
||||
# add buffer ops
|
||||
ast = graph_rewrite(ast, add_buffer_ops, bufs:=[kernel.src[0].buf_uop], bottom_up=True)
|
||||
ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
|
||||
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
|
||||
# unbind_vars + push views to edges
|
||||
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
|
||||
# fix_kernel_ops
|
||||
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
|
||||
return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), kernel.arg.metadata)
|
||||
# create subbuffer
|
||||
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
||||
return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata)
|
||||
|
||||
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
@@ -425,21 +427,18 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
# if we created a KERNEL for this tensor, map it to the assigned buffer
|
||||
if (a:=kernel_map.get(v.base)) is not None and a.op is Ops.ASSIGN:
|
||||
becomes_map[k] = a.src[0] if v is v.base else a.src[0].view(unwrap(v.st))
|
||||
elif a is not None and a.op is Ops.BUFFER_VIEW and a.src[0].op is Ops.ASSIGN: becomes_map[k] = a.replace(src=(a.src[0].buf_uop,))
|
||||
# tensors can also simplify to an existing buffer/const
|
||||
else:
|
||||
if k is v: continue
|
||||
if v.base.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: becomes_map[k] = v
|
||||
if v.base.op is Ops.BUFFER: becomes_map[k] = v
|
||||
if v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
|
||||
|
||||
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
|
||||
kernel_assign: dict[UOp, UOp] = {}
|
||||
buffer_map: dict[UOp, UOp] = {}
|
||||
assign_rep: dict[UOp, UOp] = {}
|
||||
for u in sched_sink.toposort:
|
||||
if u.op is not Ops.ASSIGN: continue
|
||||
kernel_assign[u.buf_uop] = u
|
||||
buffer_map[u.src[1].arg.ast] = u.src[0]
|
||||
for s in u.src[1].src:
|
||||
if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
|
||||
if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort):
|
||||
@@ -456,8 +455,10 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
children: dict[UOp, list[UOp]] = {}
|
||||
in_degree: dict[UOp, int] = {}
|
||||
for u in sched_sink.toposort:
|
||||
if u.op is not Ops.ASSIGN: continue
|
||||
in_degree[u] = 0
|
||||
for s in u.src:
|
||||
for s in u.src[1].src:
|
||||
if s.op is not Ops.ASSIGN: continue
|
||||
children.setdefault(s, []).append(u)
|
||||
in_degree[u] += 1
|
||||
|
||||
@@ -466,16 +467,15 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
var_vals: dict[Variable, int] = {}
|
||||
while queue:
|
||||
u = queue.popleft()
|
||||
if u.op is Ops.ASSIGN:
|
||||
schedule.append(schedule_uop(u.src[1], buffer_map, var_vals))
|
||||
# increment the refcount of the target buf (this is required by the JIT and memory planner)
|
||||
u.buf_uop.buffer.ref(1)
|
||||
schedule.append(schedule_uop(u, var_vals))
|
||||
# increment the refcount of the target buf (this is required by the JIT and memory planner)
|
||||
u.buf_uop.buffer.ref(1)
|
||||
for x in children.get(u, []):
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
# confirm everything was scheduled correctly
|
||||
if len(schedule) != (kc:=len(buffer_map)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
|
||||
if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
|
||||
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
|
||||
# capture process replay
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
|
||||
@@ -513,7 +513,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
|
||||
@property
|
||||
def buf_uop(self) -> UOp:
|
||||
if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return self
|
||||
assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
|
||||
return self.src[0].base
|
||||
@property
|
||||
@@ -521,17 +520,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self is not self.base:
|
||||
assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
|
||||
return self.src[0].buffer
|
||||
assert self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}, f"must be BUFFER {self.op}"
|
||||
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
|
||||
if (cret:=buffers.get(self)) is not None: return cret
|
||||
if self.op is Ops.BUFFER_VIEW:
|
||||
buffers[self] = ret = (base:=self.src[0].buffer).view(self.size, self.dtype, self.arg[1]*base.dtype.itemsize)
|
||||
return ret
|
||||
from tinygrad.device import Buffer
|
||||
assert isinstance(self.device, str), f"buffer not supported on multi {self.device}"
|
||||
buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
|
||||
return ret
|
||||
@property
|
||||
def realized(self) -> Optional[Buffer]: return self.buffer if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW} and self.buffer.is_allocated() else None
|
||||
def realized(self) -> Optional[Buffer]: return self.buffer if self.op is Ops.BUFFER and self.buffer.is_allocated() else None
|
||||
@property
|
||||
def is_realized(self) -> bool:
|
||||
return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None
|
||||
|
||||
@@ -8,7 +8,7 @@ buffer_spec = PatternMatcher([
|
||||
(UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)),
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"),
|
||||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.BUFFER, Ops.BUFFER_VIEW)),), name="buf_view"),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
|
||||
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all_int(buf_view.arg)),
|
||||
])
|
||||
|
||||
@@ -21,7 +21,7 @@ tensor_uop_spec = buffer_spec+PatternMatcher([
|
||||
# "make things that can't be images not images" can change the buffer dtype
|
||||
# this is fine as long as it's a realized buffer and base dtypes match.
|
||||
((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.base.op is Ops.BUFFER)),
|
||||
(UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.DEVICE}),)), lambda: False),
|
||||
(UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}),)), lambda: False),
|
||||
|
||||
# Tensor variable bindings
|
||||
(UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
|
||||
@@ -125,11 +125,9 @@ spec = PatternMatcher([
|
||||
# *** this is the spec of a Kernel in UOp ***
|
||||
|
||||
kernel_spec = buffer_spec+PatternMatcher([
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN))), lambda: True),
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
|
||||
# assign has a buffer view and kernel source, it can optionally depend on other assigns
|
||||
(UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
|
||||
# we also allow for a BUFFER_VIEW to depend on ASSIGN in the kernel graph
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.ASSIGN),)), lambda: True),
|
||||
(UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
|
||||
(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}), lambda: False),
|
||||
])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user