mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add DEBUG_GC (#13465)
* add DEBUG_GC
* fixup create_schedule_with_vars
* work
@@ -1,11 +1,11 @@
 import time
 from typing import cast
 from dataclasses import dataclass, field, replace
-from collections import deque, defaultdict
+from collections import deque
 from tinygrad.uop.ops import UOp, Ops, buffers
 from tinygrad.uop.spec import type_verify, tensor_spec
-from tinygrad.device import Device, Buffer, MultiBuffer
-from tinygrad.helpers import Metadata, all_same, DEBUG, cpu_profile, TracingKey, SPEC, flatten
+from tinygrad.device import Buffer, MultiBuffer
+from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten, disable_gc
 
 # **** ScheduleItem return type
 
@@ -20,107 +20,100 @@ class ScheduleItem:
 
 # **** schedule linearizer
 
 def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[str, int]]:
-  # construct the KERNEL children graph based on assigns
-  children: defaultdict[UOp, list[UOp]] = defaultdict(list)
-  in_degree: dict[UOp, int] = {}
-  var_vals: dict[str, int] = {}
-  for u in sched_sink.toposort():
-    if u.op is Ops.RANGE:
-      in_degree.setdefault(u, 0)
-      continue
-    if u.op is not Ops.AFTER or u.src[1].op is Ops.RANGE: continue
-    k = u.src[1]
-    in_degree.setdefault(k, 0)
-    for s in k.src[0].src if k.op is Ops.END else k.src:
-      if s.op is Ops.AFTER:
-        children[s.src[1]].append(k)
-        in_degree[k] += 1
-      elif s.op in {Ops.MSELECT, Ops.MSTACK}:
-        for ss in s.src:
-          if ss.op is Ops.MSELECT: ss = ss.src[0]
-          if ss.op is not Ops.BUFFER:
-            assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
-            children[ss.src[1]].append(k)
-            in_degree[k] += 1
-      elif s.op is Ops.BUFFER:
-        pass # a BUFFER is already realized, nothing to do here
-      elif s.op is Ops.BIND:
-        # for RANGE this is in fixedvars
-        if s.src[1].op is not Ops.RANGE:
-          var, val = s.unbind()
-          assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
-          var_vals[var.expr] = val
-      else:
-        raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
-
-  # linearize KERNEL UOps into ScheduleItems in BFS order
-  def _heuristic(k: UOp):
-    if k.op is Ops.KERNEL and k.arg.ast.op is Ops.COPY and not all_same([Device[cast(Buffer, s.buf_uop.buffer).device].group_id for s in k.src]):
-      return 1000
-    return 0
-  last_heuristic: int = 0
-  queues: defaultdict[int, deque[UOp]] = defaultdict(deque)
-  last_queue: deque[UOp] = deque()
-  for k,v in in_degree.items():
-    if v == 0: queues[_heuristic(k)].append(k)
-
-  schedule: list[ScheduleItem|UOp] = []
-  while last_queue or any(queues.values()):
-    if not last_queue: last_heuristic, last_queue = min((it for it in queues.items() if it[1]), key=lambda x: abs(x[0]-last_heuristic))
-    k = rk = last_queue.popleft()
-    if k.op is Ops.END: k = k.src[0]
-    if k.op is Ops.RANGE: schedule.append(k)
-    elif k.op is Ops.KERNEL:
-      ast = k.arg.ast
-      # create subbuffers if needed
-      if ast.op is Ops.BUFFER_VIEW:
-        base = k.src[1].buf_uop.buffer
-        assert isinstance(base, Buffer), "base can't be MultiBuffer"
-        buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
-      ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND)
-      bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and s.src[1].op is Ops.RANGE)
-      if any(isinstance(x, MultiBuffer) for x in ubufs):
-        assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
-        dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
-        for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
-          schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0].expr:i} if len(dnums) else {}, bound_ranges=bound_ranges))
-      else:
-        # ONE -> ONE
-        schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata, bound_ranges=bound_ranges))
-      if rk.op is Ops.END: schedule.append(rk)
-    else:
-      raise RuntimeError(f"can't schedule {k.op}")
-    for x in children[rk]:
-      in_degree[x] -= 1
-      if in_degree[x] == 0: queues[_heuristic(x)].append(x)
-
-  # expand the ranges in the schedule
-  real_schedule: list[ScheduleItem] = []
-  sched_ptr = 0
-  in_ranges = {}
-  range_ptrs = {}
-  while sched_ptr < len(schedule):
-    si = schedule[sched_ptr]
-    if isinstance(si, UOp):
-      if si.op is Ops.RANGE:
-        in_ranges[si] = 0
-        range_ptrs[si] = sched_ptr + 1
-      elif si.op is Ops.END:
-        if in_ranges[si.src[1]] < si.src[1].vmax:
-          in_ranges[si.src[1]] += 1
-          sched_ptr = range_ptrs[si.src[1]]
-          continue
-    else:
-      real_schedule.append(replace(si, fixedvars=si.fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in si.bound_ranges}, bound_ranges=()))
-    sched_ptr += 1
+  with cpu_profile(TracingKey("toposort sched_sink")):
+    # construct the KERNEL children graph based on assigns
+    children: dict[UOp, list[UOp]] = {}
+    in_degree: dict[UOp, int] = {}
+    var_vals: dict[str, int] = {}
+    for u in sched_sink.toposort():
+      if u.op is Ops.RANGE:
+        in_degree.setdefault(u, 0)
+        continue
+      if u.op is not Ops.AFTER or u.src[1].op is Ops.RANGE: continue
+      k = u.src[1]
+      in_degree.setdefault(k, 0)
+      for s in k.src[0].src if k.op is Ops.END else k.src:
+        if s.op is Ops.AFTER:
+          children.setdefault(s.src[1], []).append(k)
+          in_degree[k] += 1
+        elif s.op in {Ops.MSELECT, Ops.MSTACK}:
+          for ss in s.src:
+            if ss.op is Ops.MSELECT: ss = ss.src[0]
+            if ss.op is not Ops.BUFFER:
+              assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
+              children.setdefault(ss.src[1], []).append(k)
+              in_degree[k] += 1
+        elif s.op is Ops.BUFFER:
+          pass # a BUFFER is already realized, nothing to do here
+        elif s.op is Ops.BIND:
+          # for RANGE this is in fixedvars
+          if s.src[1].op is not Ops.RANGE:
+            var, val = s.unbind()
+            assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
+            var_vals[var.expr] = val
+        else:
+          raise RuntimeError(f"input to kernel must be AFTER or BUFFER, not {s.op}")
+
+  with cpu_profile(TracingKey("linearize to ScheduleItem")):
+    queue: deque[UOp] = deque()
+    for k,v in in_degree.items():
+      if v == 0: queue.append(k)
+
+    schedule: list[ScheduleItem|UOp] = []
+    while len(queue):
+      k = rk = queue.popleft()
+      if k.op is Ops.END: k = k.src[0]
+      if k.op is Ops.RANGE: schedule.append(k)
+      elif k.op is Ops.KERNEL:
+        ast = k.arg.ast
+        # create subbuffers if needed
+        if ast.op is Ops.BUFFER_VIEW:
+          base = k.src[1].buf_uop.buffer
+          assert isinstance(base, Buffer), "base can't be MultiBuffer"
+          buffers[k.src[0]] = base.view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
+        ubufs = tuple(s.buf_uop.buffer for s in k.src if s.op is not Ops.BIND)
+        bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and s.src[1].op is Ops.RANGE)
+        if any(isinstance(x, MultiBuffer) for x in ubufs):
+          assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
+          dnums = [x for x in ast.variables() if x.arg[0] == '_device_num']
+          for i,bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
+            schedule.append(ScheduleItem(ast, bufs, k.arg.metadata, {dnums[0].expr:i} if len(dnums) else {}, bound_ranges=bound_ranges))
+        else:
+          # ONE -> ONE
+          schedule.append(ScheduleItem(ast, cast(tuple[Buffer, ...], ubufs), k.arg.metadata, bound_ranges=bound_ranges))
+        if rk.op is Ops.END: schedule.append(rk)
+      else:
+        raise RuntimeError(f"can't schedule {k.op}")
+      for x in children.get(rk, []):
+        in_degree[x] -= 1
+        if in_degree[x] == 0: queue.append(x)
+
+  with cpu_profile(TracingKey("expand ranges")):
+    real_schedule: list[ScheduleItem] = []
+    sched_ptr = 0
+    in_ranges = {}
+    range_ptrs = {}
+    while sched_ptr < len(schedule):
+      si = schedule[sched_ptr]
+      if isinstance(si, UOp):
+        if si.op is Ops.RANGE:
+          in_ranges[si] = 0
+          range_ptrs[si] = sched_ptr + 1
+        elif si.op is Ops.END:
+          if in_ranges[si.src[1]] < si.src[1].vmax:
+            in_ranges[si.src[1]] += 1
+            sched_ptr = range_ptrs[si.src[1]]
+            continue
+      else:
+        real_schedule.append(replace(si, fixedvars=si.fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in si.bound_ranges}, bound_ranges=()))
+      sched_ptr += 1
   return real_schedule, var_vals
 
 from tinygrad.engine.memory import memory_planner
 from tinygrad.schedule.rangeify import get_rangeify_map
 from tinygrad.schedule.multi import get_multi_map
 
+@disable_gc()
 def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
   # big_sink srcs are all the Tensors
   st = time.perf_counter()
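The hunk above also swaps the old _heuristic-driven multi-queue linearizer for a single FIFO, i.e. a plain Kahn's-algorithm BFS over the kernel dependency graph. A minimal standalone sketch of that pattern, with string names standing in for UOps (an illustration, not tinygrad code):

# A minimal sketch of the Kahn-style BFS the new linearizer uses: nodes with
# in-degree 0 enter a FIFO, and finishing a node unlocks its children.
from collections import deque

def kahn_order(children: dict[str, list[str]], in_degree: dict[str, int]) -> list[str]:
  queue = deque(k for k, v in in_degree.items() if v == 0)
  order: list[str] = []
  while queue:
    k = queue.popleft()
    order.append(k)
    for x in children.get(k, []):
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append(x)
  return order

# toy kernel graph: a unlocks b and c, both of which must finish before d
print(kahn_order({"a": ["b", "c"], "b": ["d"], "c": ["d"]},
                 {"a": 0, "b": 1, "c": 1, "d": 2}))  # ['a', 'b', 'c', 'd']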
@@ -140,7 +133,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
   big_sink = big_sink.substitute(tensor_map, name="Apply Kernelize Map")
 
   # create the schedule
-  with cpu_profile(TracingKey("toposort schedule")): schedule, var_vals = create_schedule_with_vars(big_sink)
+  schedule, var_vals = create_schedule_with_vars(big_sink)
  with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
 
   # remove all AFTERs, after scheduling, the tensors are just buffers
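For reference, the "expand ranges" pass at the end of create_schedule_with_vars replays everything between a RANGE and its END vmax+1 times by rewinding a schedule pointer. A toy replay of the same control flow, with tuples standing in for RANGE/END UOps and strings for ScheduleItems (illustration only, not tinygrad code):

# Toy replay of the expand-ranges loop: a pointer rewinds to the body start
# until the range counter exceeds its vmax, so the body runs vmax+1 times.
def expand(schedule):
  real, counters, body_start, vmax, ptr = [], {}, {}, {}, 0
  while ptr < len(schedule):
    si = schedule[ptr]
    if isinstance(si, tuple) and si[0] == "RANGE":
      _, rid, vm = si
      counters[rid], body_start[rid], vmax[rid] = 0, ptr + 1, vm
    elif isinstance(si, tuple) and si[0] == "END":
      rid = si[1]
      if counters[rid] < vmax[rid]:
        counters[rid] += 1
        ptr = body_start[rid]  # rewind to replay the loop body
        continue
    else:
      real.append((si, dict(counters)))  # item tagged with current loop indices
    ptr += 1
  return real

# a range r with vmax=2 runs its body for iterations 0, 1, 2
print(expand([("RANGE", "r", 2), "kernel", ("END", "r")]))
# [('kernel', {'r': 0}), ('kernel', {'r': 1}), ('kernel', {'r': 2})]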
@@ -291,6 +291,15 @@ def cpu_profile(name:str|TracingKey, device="TINY", is_copy=False, display=True)
 def profile_marker(name:str, color="gray") -> None:
   cpu_events.append(ProfilePointEvent("TINY", "marker", None, {"name":name, "color":color}))
 
+if getenv("DEBUG_GC"):
+  gc_start: decimal.Decimal = perf_counter_us()
+  def my_gc_callback(phase, info):
+    global gc_start
+    if phase == 'start': gc_start = perf_counter_us()
+    elif phase == "stop":
+      cpu_events.append(ProfileRangeEvent("GC", f"collected: {info['collected']} (gen {info['generation']})", gc_start, perf_counter_us()))
+  if PROFILE: gc.callbacks.append(my_gc_callback)
+
 # *** universal database cache ***
 
 cache_dir: str = os.path.join(getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")), "tinygrad")
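The DEBUG_GC block hooks CPython's documented gc.callbacks API: each registered callback is invoked with phase "start" before a collection and "stop" after, and the info dict carries the generation plus collected/uncollectable counts. A self-contained demo of the same mechanism, independent of tinygrad:

import gc, time

# CPython invokes everything in gc.callbacks around each collection; the
# "stop" call's info dict includes 'generation' and 'collected'.
t0 = 0.0
def report(phase, info):
  global t0
  if phase == "start": t0 = time.perf_counter()
  elif phase == "stop":
    print(f"gen {info['generation']}: collected {info['collected']} in {(time.perf_counter()-t0)*1e6:.1f}us")

gc.callbacks.append(report)
gc.collect()                 # forces a full collection; prints one report line
gc.callbacks.remove(report)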
@@ -5,7 +5,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _
 from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
 from tinygrad.uop.symbolic import symbolic
 from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
-from tinygrad.helpers import PCONTIG, partition, get_single_element, unwrap, disable_gc
+from tinygrad.helpers import PCONTIG, partition, get_single_element, unwrap
 from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
 from tinygrad.codegen.opt import Opt
 from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
@@ -538,7 +538,6 @@ replace_contiguous = PatternMatcher([
   (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
 ])
 
-@disable_gc()
 @track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True)
 def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
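This hunk moves disable_gc off get_rangeify_map (it now wraps complete_create_schedule_with_vars above). Its implementation is not part of this diff, so the following is only a plausible sketch of the shape such a decorator takes, pausing automatic collection around the wrapped call; the body below is an assumption, not tinygrad's code:

import functools, gc

# Hypothetical stand-in for tinygrad.helpers.disable_gc (not shown in this
# diff): disable automatic GC for the duration of the call, then restore it.
def disable_gc():
  def decorator(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      was_enabled = gc.isenabled()
      gc.disable()
      try:
        return fn(*args, **kwargs)
      finally:
        if was_enabled: gc.enable()
    return wrapper
  return decorator

@disable_gc()
def build_schedule():
  ...  # allocation-heavy work runs without GC pauses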
@@ -263,7 +263,7 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
   ctxs.append({"name":"Counters", "steps":steps})
 
 def device_sort_fn(k:str) -> tuple[int, str, int]:
-  order = {"USER": 0, "TINY": 1, "DISK": 999}
+  order = {"GC": 0, "USER": 1, "TINY": 2, "DISK": 999}
   dname = k.split()[0]
   dev_rank = next((v for k,v in order.items() if dname.startswith(k)), len(order))
   return (dev_rank, dname, len(k))
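With "GC" ranked first, the new GC timeline rows sort above USER and TINY in the profiler view. Copying the new device_sort_fn verbatim, a quick check of the resulting order ("NV:0" is a made-up device name for illustration; unknown devices fall back to rank len(order)):

# The new device_sort_fn plus a usage check of the ordering it produces.
def device_sort_fn(k:str) -> tuple[int, str, int]:
  order = {"GC": 0, "USER": 1, "TINY": 2, "DISK": 999}
  dname = k.split()[0]
  dev_rank = next((v for k,v in order.items() if dname.startswith(k)), len(order))
  return (dev_rank, dname, len(k))

print(sorted(["TINY", "NV:0", "USER", "GC", "DISK"], key=device_sort_fn))
# ['GC', 'USER', 'TINY', 'NV:0', 'DISK']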