mirror of https://github.com/tinygrad/tinygrad.git
nv/amd profiler (#4718)
* nv/amd profiler
* fix
* fix
* profile copies
* profile logger
* fixes
* more fixes
* less lines and fixes
* fixes
* some linter
* back sync, no related change
* fix gpu2cpu time def
* simpler
* linter
* linter
* docs
* add add_event api
docs/env_vars.md
@@ -47,3 +47,4 @@ DEFAULT_FLOAT | [HALF, ...]| specify the default float dtype (FLOAT32, HAL
 IMAGE | [1-2] | enable 2d specific optimizations
 FLOAT16 | [1] | use float16 for images instead of float32
 PTX | [1] | enable the specialized [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/) assembler for Nvidia GPUs. If not set, defaults to generic CUDA codegen backend.
+PROFILE | [1] | enable output of [perfetto](https://ui.perfetto.dev/) compatible profile. This feature is supported in NV and AMD backends.
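Note: a minimal way to exercise the new flag, sketched under the assumption that the default backend resolves to NV or AMD (output path comes from `ProfileLogger` below and is overridable via `PROFILE_OUTPUT_FILE`):

```python
# minimal sketch, assuming an NV or AMD device is the default backend
import os
os.environ["PROFILE"] = "1"   # must be set before importing tinygrad (PROFILE is a ContextVar)

from tinygrad import Tensor, Device

a, b = Tensor.rand(512, 512), Tensor.rand(512, 512)
(a @ b).realize()
Device[Device.DEFAULT].synchronize()  # flushes queued work; profile events are processed here
# on interpreter exit, ProfileLogger writes a traceEvents JSON (default: tinygrad_profile.json
# in the temp dir) that https://ui.perfetto.dev/ can open
```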
tinygrad/runtime/ops_hsa.py
@@ -2,15 +2,13 @@ from __future__ import annotations
 import ctypes, functools, subprocess, io, atexit, collections, json
 from typing import Tuple, TypeVar, List, Dict, Any
 import tinygrad.runtime.autogen.hsa as hsa
-from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv
+from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv, PROFILE
 from tinygrad.device import Compiled, Compiler, CompileError, BufferOptions, LRUAllocator
 from tinygrad.renderer.cstyle import HIPRenderer
 from tinygrad.runtime.driver.hsa import check, scan_agents, find_memory_pool, AQLQueue
 from tinygrad.runtime.driver.hip_comgr import compile_hip
 if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401

-PROFILE = getenv("PROFILE", 0)
-
 class HSAProfiler:
   def __init__(self):
     self.tracked_signals = collections.defaultdict(list)
tinygrad/device.py
@@ -3,8 +3,8 @@ import multiprocessing
 from dataclasses import dataclass
 from collections import defaultdict
 from typing import List, Optional, Dict, Tuple, Any, cast
-import importlib, inspect, functools, pathlib, os, ctypes
-from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
+import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib
+from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
 from tinygrad.dtype import DType, ImageDType
 from tinygrad.renderer import Renderer
@@ -186,11 +186,23 @@ class Compiled:

 # **************** for HCQ Compatible Devices ****************

+@contextlib.contextmanager
+def hcq_profile(dev, queue_type, enabled, desc):
+  st, en = (dev._get_signal(), dev._get_signal()) if enabled else (None, None)
+  if enabled: queue_type().timestamp(st).submit(dev)
+  try: yield (st, en)
+  finally:
+    if enabled: queue_type().timestamp(en).submit(dev)
+    if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
+
 class HCQCompatCompiled(Compiled):
   def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, comp_queue_t, copy_queue_t, timeline_signals):
     self.hw_compute_queue_t, self.hw_copy_queue_t = comp_queue_t, copy_queue_t
     self.timeline_value: int = 1
     self.timeline_signal, self._shadow_timeline_signal = timeline_signals
+    self.sig_prof_records: List[Tuple[Any, Any, str, bool]] = []
+    self.raw_prof_records: List[Tuple[int, int, str, bool]] = []
+    if PROFILE: self._prof_setup()

     from tinygrad.runtime.graph.hcq import HCQGraph
     super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
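Note: `hcq_profile` brackets whatever the body enqueues with two device-side timestamps on the same hardware queue, so the measured interval is queue execution time rather than Python time. A usage sketch with names from this diff:

```python
# sketch only: dev is an HCQCompatCompiled device, PROFILE is the new ContextVar
with hcq_profile(dev, dev.hw_copy_queue_t, enabled=PROFILE, desc="CPU -> GPU") as (st, en):
    ...  # enqueue copy/compute work here
# the (st, en) signals go into dev.sig_prof_records and are read back later by
# _prof_process_events, after the queue has actually executed the timestamps
```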
@@ -198,6 +210,9 @@ class HCQCompatCompiled(Compiled):
   @classmethod
   def _read_signal(self, sig): raise NotImplementedError("need _read_signal") # reads a value for a signal

+  @classmethod
+  def _read_timestamp(self, sig): raise NotImplementedError("need _read_timestamp") # reads a timestamp for a signal
+
   @classmethod
   def _set_signal(self, sig, value): raise NotImplementedError("need _set_signal") # sets a value for a signal
@@ -207,6 +222,32 @@ class HCQCompatCompiled(Compiled):
   @classmethod
   def _wait_signal(self, signal, value=0, timeout=10000): raise NotImplementedError("need _wait_signal") # waits for a signal value

+  def _gpu2cpu_time(self, gpu_time, is_copy): raise NotImplementedError("need _gpu2cpu_time")
+
+  def _prof_setup(self):
+    self.profile_logger = ProfileLogger()
+
+    def _sync_queue(q_t):
+      q_t().timestamp(self.timeline_signal).signal(self.timeline_signal, self.timeline_value).submit(self)
+      self.timeline_value += 1
+      cpu_start_time = time.perf_counter_ns() / 1e3
+      self._wait_signal(self.timeline_signal, self.timeline_value - 1)
+      return cpu_start_time, self._read_timestamp(self.timeline_signal)
+    self.cpu_start_time, self.gpu_start_time = _sync_queue(self.hw_compute_queue_t)
+    self.copy_cpu_start_time, self.copy_gpu_start_time = _sync_queue(self.hw_copy_queue_t)
+
+    atexit.register(self._prof_finalize)
+
+  def _prof_process_events(self):
+    self.raw_prof_records += [(self._read_timestamp(st), self._read_timestamp(en), name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
+    for st, en, _, _ in self.sig_prof_records: self.signals_pool += [st, en] # type: ignore
+    self.sig_prof_records = []
+
+  def _prof_finalize(self):
+    for st, en, name, is_cp in self.raw_prof_records:
+      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, ["COMPUTE", "DMA"][is_cp])]
+    del self.profile_logger
+
   def _wrap_timeline_signal(self):
     self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
     self._set_signal(self.timeline_signal, 0)
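Note: `_prof_setup` anchors the two clocks. `_sync_queue` submits a timestamp+signal, records `time.perf_counter_ns()` (scaled to microseconds) around the wait, and keeps the resulting (cpu, gpu) pair per queue; any later GPU timestamp is projected onto the CPU timeline by offsetting from that anchor. A sketch of the projection with a hypothetical tick rate (the per-backend `_gpu2cpu_time` overrides below hard-code the real one):

```python
def gpu2cpu(gpu_ts: int, cpu_start_us: float, gpu_start_ts: int, ticks_per_us: float) -> float:
    # project a raw GPU timestamp onto the CPU (perf_counter) timeline, in microseconds
    return cpu_start_us + (gpu_ts - gpu_start_ts) / ticks_per_us

# e.g. with a hypothetical 100MHz counter (100 ticks/us), a kernel starting
# 5_000_000 ticks after the anchor lands 50_000us after cpu_start on the trace
assert gpu2cpu(5_000_000, 0.0, 0, 100) == 50_000.0
```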
@@ -220,15 +261,16 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
     super().__init__()

   def copyin(self, dest, src: memoryview):
-    for i in range(0, src.nbytes, self.b[0].size):
-      self.b_next = (self.b_next + 1) % len(self.b)
-      self.device._wait_signal(self.device.timeline_signal, self.b_timeline[self.b_next])
-      ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
-      self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
-                                   .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
-                                   .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
-      self.b_timeline[self.b_next] = self.device.timeline_value
-      self.device.timeline_value += 1
+    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
+      for i in range(0, src.nbytes, self.b[0].size):
+        self.b_next = (self.b_next + 1) % len(self.b)
+        self.device._wait_signal(self.device.timeline_signal, self.b_timeline[self.b_next])
+        ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
+        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
+                                     .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
+                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
+        self.b_timeline[self.b_next] = self.device.timeline_value
+        self.device.timeline_value += 1

   def copy_from_disk(self, dest, src, size):
     def _get_temp_buf():
@@ -238,31 +280,36 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
         return (self.b[self.b_next].va_addr, self.b_next)
       return None

-    for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
-      self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
-                                   .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
-                                   .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
-      self.b_timeline[batch_info[1]] = self.device.timeline_value
-      self.device.timeline_value += 1
+    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
+      for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
+        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
+                                     .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
+                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
+        self.b_timeline[batch_info[1]] = self.device.timeline_value
+        self.device.timeline_value += 1
   def copyout(self, dest:memoryview, src):
     self.device.synchronize()
-    for i in range(0, dest.nbytes, self.b[0].size):
-      self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
-                                   .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
-                                   .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
-      self.device._wait_signal(self.device.timeline_signal, self.device.timeline_value)
-      self.device.timeline_value += 1
-
-      ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
+    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
+      for i in range(0, dest.nbytes, self.b[0].size):
+        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
+                                     .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
+                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
+        self.device._wait_signal(self.device.timeline_signal, self.device.timeline_value)
+        self.device.timeline_value += 1
+
+        ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
   def transfer(self, dest, src, sz: int, src_dev, dest_dev):
     src_dev._gpu_map(dest)
-    self.device.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
-                                 .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
-                                 .copy(dest.va_addr, src.va_addr, sz) \
-                                 .signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
-    self.device.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value).submit(dest_dev)
-    src_dev.timeline_value += 1
+
+    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
+      src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
+                               .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
+                               .copy(dest.va_addr, src.va_addr, sz) \
+                               .signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
+      dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value).submit(dest_dev)
+      src_dev.timeline_value += 1

   def offset(self, buf, size:int, offset:int): return type(buf)(base=buf.base + offset, va_addr=buf.va_addr + offset, length=size, size=size)
tinygrad/helpers.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes, sys
-import itertools, urllib.request, subprocess, shutil, math
+import itertools, urllib.request, subprocess, shutil, math, json
 from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
 if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
   from typing_extensions import TypeGuard
@@ -103,7 +103,7 @@ class ContextVar:
 DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
 WINO, THREEFRY, CACHECOLLECTING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CACHECOLLECTING", 1)
 GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
-MULTIOUTPUT = ContextVar("MULTIOUTPUT", 1)
+MULTIOUTPUT, PROFILE = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0)

 # **************** global state Counters ****************
@@ -144,6 +144,35 @@ class Profiling(contextlib.ContextDecorator):
       colored(_format_fcn(fcn), "yellow") + " "*(50-len(_format_fcn(fcn))),
       colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else '')

+class ProfileLogger:
+  writers: int = 0
+  mjson: List[Dict] = []
+  actors: Dict[str, int] = {}
+  subactors: Dict[Tuple[str, str], int] = {}
+  path = getenv("PROFILE_OUTPUT_FILE", temp("tinygrad_profile.json"))
+
+  def __init__(self): self.events, ProfileLogger.writers = [], ProfileLogger.writers + 1
+
+  def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]
+
+  def __del__(self):
+    for name,st,et,actor_name,subactor_name in self.events:
+      if actor_name not in self.actors:
+        self.actors[actor_name] = (pid:=len(self.actors))
+        self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
+
+      if (subactor_key:=(actor_name,subactor_name)) not in self.subactors:
+        self.subactors[subactor_key] = (tid:=len(self.subactors))
+        self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid": tid, "args": {"name": subactor_name}})
+
+      self.mjson.append({"name": name, "ph": "B", "pid": self.actors[actor_name], "tid": self.subactors.get(subactor_key, -1), "ts": st})
+      self.mjson.append({"name": name, "ph": "E", "pid": self.actors[actor_name], "tid": self.subactors.get(subactor_key, -1), "ts": et})
+
+    ProfileLogger.writers -= 1
+    if ProfileLogger.writers == 0:
+      with open(self.path, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
+      print(f"Saved profile to {self.path}. Use https://ui.perfetto.dev/ to open it.")
+
 # *** universal database cache ***

 _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))
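Note: the file this class writes is Chrome trace-event JSON, which Perfetto reads. `"M"` metadata rows name the process/thread tracks, and `"B"`/`"E"` pairs open and close a span; the event names below are hypothetical:

```python
example = {"traceEvents": [
  {"name": "process_name", "ph": "M", "pid": 0, "args": {"name": "AMD"}},
  {"name": "thread_name", "ph": "M", "pid": 0, "tid": 0, "args": {"name": "COMPUTE"}},
  {"name": "r_256_16_8", "ph": "B", "pid": 0, "tid": 0, "ts": 1000.0},  # span begins at 1000us
  {"name": "r_256_16_8", "ph": "E", "pid": 0, "tid": 0, "ts": 1042.5},  # ...and ends 42.5us later
]}
```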
tinygrad/runtime/graph/hcq.py
@@ -1,6 +1,6 @@
 import collections, array, time
 from typing import List, Any, Dict, cast, Optional, Tuple, Set
-from tinygrad.helpers import round_up, to_mv
+from tinygrad.helpers import round_up, to_mv, PROFILE
 from tinygrad.device import Buffer, BufferOptions, Compiled, Device
 from tinygrad.shape.symbolic import Variable
 from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
@@ -49,7 +49,7 @@ class HCQGraph(MultiGraphRunner):
     self.kickoff_value = 0
     self.graph_timeline = {dev: 0 for dev in self.devices}

-    signal_scheduling: Dict[int, Tuple[List, Optional[int]]] = {}
+    self.signal_sched: Dict[int, Tuple[List, Optional[int], Optional[Tuple]]] = {} # Dict[ji_idx, (deps, output sigval, (prof_st_sig, prof_en_sig))]
     self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}
     self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
@@ -66,14 +66,16 @@ class HCQGraph(MultiGraphRunner):
         self.comp_signal_val[dev] = sig_val
       elif isinstance(ji.prg, BufferXfer):
         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
-        deps = self.access_resources([src], [dest], (self.copy_signal[Device[src.device]], sig_val:=j+1))
+        deps = self.access_resources([src], [dest], (self.copy_signal[(dev:=Device[src.device])], sig_val:=j+1))
         deps = [x for x in deps if id(x[0]) != id(self.copy_signal[Device[src.device]])]
         self.copy_signal_val[Device[src.device]] = sig_val
         self.copy_to_devs[Device[dest.device]].add(Device[src.device])

       # When running compute, set up lazy signals, since no dependencies might be there. Copies always have signals to sync.
-      signal_scheduling[j] = (deps, None if isinstance(ji.prg, CompiledRunner) else j + 1)
-      for sig, val in deps: signal_scheduling[val - 1] = (signal_scheduling[val - 1][0], val) # set need output for signal, as it has deps.
+      prof_ji_desc = ji.prg.clprg.name if isinstance(ji.prg, CompiledRunner) else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
+      prof_info = (dev._get_signal(), dev._get_signal(), dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)) if PROFILE else None
+      self.signal_sched[j] = (deps, None if isinstance(ji.prg, CompiledRunner) else j + 1, prof_info)
+      for sig, val in deps: self.signal_sched[val - 1] = (self.signal_sched[val - 1][0], val, self.signal_sched[val - 1][2]) # set need output for signal, as it has deps.

     # Building hardware queues
     for dev in self.devices:
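Note: as a reading aid, the layout of one `signal_sched` entry after this change, per the inline comment above (the third slot stays `None` unless `PROFILE` is set):

```python
# signal_sched[j] == (deps, output_sigval, prof_info) where:
#   deps:          [(signal, value), ...] this jit-cache item must wait on
#   output_sigval: j + 1 if the item signals on completion, else None (lazy for compute)
#   prof_info:     (st_sig, en_sig, dev, desc, is_copy) timestamp signals bracketing the item
```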
@@ -81,21 +83,27 @@ class HCQGraph(MultiGraphRunner):
       self.copy_queues[dev].wait(dev.timeline_signal, dev.timeline_value - 1).wait(self.kickoff_signal, self.kickoff_value)

     for j,ji in enumerate(self.jit_cache):
-      deps, signal_value = signal_scheduling[j]
+      deps, signal_value, prof_info = self.signal_sched[j]
+      enqueue_queue = self.copy_queues[Device[ji.bufs[1].device]] if isinstance(ji.prg, BufferXfer) else self.comp_queues[ji.prg.device] #type:ignore
+
+      # Encode waits and start profile timestamp (if needed).
+      for sig, val in deps: enqueue_queue.wait(sig, val)
+      if prof_info: enqueue_queue.timestamp(prof_info[0])
+
+      # Encode main commands based on ji type.
       if isinstance(ji.prg, CompiledRunner):
-        for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
         self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
                                              signal=self.comp_signal[ji.prg.device] if signal_value is not None else None, signal_value=signal_value)
         self.exec_ptrs[j] = (self.comp_queues[ji.prg.device], len(self.comp_queues[ji.prg.device]) - 1)
       elif isinstance(ji.prg, BufferXfer):
         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
         Device[src.device]._gpu_map(dest._buf) #type: ignore
-
-        for sig,val in deps: self.copy_queues[Device[src.device]].wait(sig, val)
         self.copy_queues[Device[src.device]].copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) \
                                             .signal(self.copy_signal[Device[src.device]], signal_value)

+      # Encode finish profile timestamp (if needed).
+      if prof_info: enqueue_queue.timestamp(prof_info[1])
+
     for dev in self.devices:
       if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev])
       for dep_dev in self.copy_to_devs[dev]: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])
@@ -113,6 +121,10 @@ class HCQGraph(MultiGraphRunner):
       dev._set_signal(self.copy_signal[dev], 0)
     self.devices[0]._set_signal(self.kickoff_signal, self.kickoff_value)

+    if PROFILE and self.kickoff_value > 1:
+      for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): #type: ignore
+        dev.raw_prof_records += [(dev._read_timestamp(st), dev._read_timestamp(en), desc, is_cp)]
+
     # Update rawbuffers
     for (j,i),input_idx in self.input_replace.items(): self.ji_args_bufs[j][i] = input_rawbuffers[input_idx]._buf.va_addr
@@ -145,5 +157,9 @@ class HCQGraph(MultiGraphRunner):
     return [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()]

   def __del__(self):
+    # Graph is destructed. No need to keep signals any more, so return them as part of profiling.
+    if PROFILE and self.kickoff_value > 1:
+      for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): dev.sig_prof_records += [(st, en, desc, is_cp)] #type: ignore
+
     self.devices[0].signals_pool += [self.kickoff_signal] + list(self.copy_signal.values()) + list(self.comp_signal.values()) # type: ignore
-    for dev,buf in self.kernargs_bufs.items(): dev.allocator._free(buf, BufferOptions(cpu_access=True))
+    for dev, buf in self.kernargs_bufs.items(): dev.allocator._free(buf, BufferOptions(cpu_access=True))
tinygrad/runtime/ops_amd.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 from typing import Tuple, List, Any
 import os, fcntl, ctypes, ctypes.util, functools, re, pathlib, mmap, struct, errno, subprocess, time, array
 from tinygrad.device import HCQCompatCompiled, HCQCompatAllocator, Compiler, CompileError, BufferOptions
-from tinygrad.helpers import getenv, init_c_struct_t, to_mv, round_up, DEBUG
+from tinygrad.helpers import getenv, init_c_struct_t, to_mv, round_up, DEBUG, PROFILE
 from tinygrad.renderer.cstyle import AMDRenderer
 from tinygrad.runtime.driver.hip_comgr import compile_hip
 import tinygrad.runtime.autogen.kfd as kfd

@@ -49,7 +49,7 @@ def ioctls_from_header():
   return type("KIO", (object, ), fxns)
 kio = ioctls_from_header()

-SIGNAL_SIZE, SIGNAL_COUNT = ctypes.sizeof(hsa.amd_signal_t), 16384
+SIGNAL_SIZE, SIGNAL_COUNT = ctypes.sizeof(hsa.amd_signal_t), 65536
 SIGNAL_VALUE_OFFSET = getattr(hsa.amd_signal_t, 'value').offset

 regBIF_BX_PF1_GPU_HDP_FLUSH_REQ = 0x0106
@@ -170,8 +170,9 @@ class HWPM4Queue(HWQueue):
       amd_gpu.PACKET3_RELEASE_MEM_DATA_SEL(mem_data_sel) | amd_gpu.PACKET3_RELEASE_MEM_INT_SEL(mem_int_sel) | amd_gpu.PACKET3_RELEASE_MEM_DST_SEL(0),
       address & 0xffffffff, address >> 32, value & 0xffffffff, value >> 32, cst]

-  def timestamp(self, addr):
-    self._release_mem(CACHE_FLUSH_AND_INV_TS_EVENT, mem_data_sel=3, mem_int_sel=0, address=addr)
+  def timestamp(self, sig):
+    self._release_mem(CACHE_FLUSH_AND_INV_TS_EVENT, mem_data_sel=3, mem_int_sel=0,
+                      address=ctypes.addressof(sig) + getattr(hsa.amd_signal_t, 'start_ts').offset)
     return self._mark_command_end()

   def signal(self, signal:hsa.amd_signal_t, value=0):
@@ -270,6 +271,11 @@ class HWCopyQueue(HWQueue):
     if value is not None: self.q[self.cmd_offsets[cmd_idx] + 3] = value
     return self

+  def timestamp(self, sig: hsa.amd_signal_t):
+    self._q([amd_gpu.SDMA_OP_TIMESTAMP | amd_gpu.SDMA_PKT_TIMESTAMP_GET_HEADER_SUB_OP(amd_gpu.SDMA_SUBOP_TIMESTAMP_GET_GLOBAL),
+             *data64_le(ctypes.addressof(sig) + getattr(hsa.amd_signal_t, 'start_ts').offset)])
+    return self._mark_command_end()
+
   def submit(self, device:AMDDevice):
     read_ptr = device.sdma_read_pointer[0]
     if (device.sdma_doorbell_value-read_ptr) > device.sdma_ring.size: raise RuntimeError("SDMA queue overrun")
@@ -364,18 +370,21 @@ class AMDProgram:
     for i in range(len(args)): args_st.__setattr__(f'f{i}', args[i].va_addr)
     for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i])

+    sig_st, sig_en = (self.device._get_signal(), self.device._get_signal()) if PROFILE else (self.device.time_event_st, self.device.time_event_en)
+
     q = HWPM4Queue()
     q.wait(self.device.timeline_signal, self.device.timeline_value - 1).memory_barrier()
-    if wait: q.timestamp(ctypes.addressof(self.device.timeline_signal) + getattr(hsa.amd_signal_t, 'start_ts').offset)
+    if wait or PROFILE: q.timestamp(sig_st)
     q.exec(self, self.device.kernargs_ptr, global_size, local_size)
-    if wait: q.timestamp(ctypes.addressof(self.device.timeline_signal) + getattr(hsa.amd_signal_t, 'end_ts').offset)
+    if wait or PROFILE: q.timestamp(sig_en)
     q.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
     self.device.timeline_value += 1
     self.device.kernargs_ptr += self.kernargs_alloc_size

+    if PROFILE: self.device.sig_prof_records.append((sig_st, sig_en, self.name, False))
     if wait:
       self.device._wait_signal(self.device.timeline_signal, self.device.timeline_value - 1)
-      return (self.device.timeline_signal.end_ts - self.device.timeline_signal.start_ts) / 1e8
+      return (sig_en.start_ts - sig_st.start_ts) / 1e8

 class AMDAllocator(HCQCompatAllocator):
   def __init__(self, device:AMDDevice): super().__init__(device, batch_size=SDMA_MAX_COPY_SIZE)
@@ -432,6 +441,9 @@ class AMDDevice(HCQCompatCompiled):
   @classmethod
   def _read_signal(self, sig): return sig.value

+  @classmethod
+  def _read_timestamp(self, sig): return sig.start_ts
+
   @classmethod
   def _set_signal(self, sig, value): sig.value = value
@@ -479,6 +491,8 @@ class AMDDevice(HCQCompatCompiled):
       self._gpu_map(AMDDevice.event_page)
       sync_event = kio.create_event(AMDDevice.kfd, auto_reset=1)

+    self.time_event_st, self.time_event_en = AMDDevice._get_signal(), AMDDevice._get_signal()
+
     self.kernargs = self._gpu_alloc(0x1000000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)
     self.kernargs_ptr = self.kernargs.va_addr
@@ -526,9 +540,14 @@ class AMDDevice(HCQCompatCompiled):
     super().__init__(device, AMDAllocator(self), AMDRenderer(), AMDCompiler(self.arch), functools.partial(AMDProgram, self), HWPM4Queue, HWCopyQueue,
                      timeline_signals=[self._get_signal(sync_event=sync_event), self._get_signal(sync_event=kio.create_event(AMDDevice.kfd, auto_reset=1))])

+  def _gpu2cpu_time(self, gpu_time, is_copy):
+    if is_copy: return self.copy_cpu_start_time + (gpu_time - self.copy_gpu_start_time) / 1e2
+    return self.cpu_start_time + (gpu_time - self.gpu_start_time) / 1e2
+
   def synchronize(self):
     AMDDevice._wait_signal(self.timeline_signal, self.timeline_value - 1)

     # reset kernargs
     self.kernargs_ptr = self.kernargs.va_addr
     if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
+    if PROFILE: self._prof_process_events()
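Note on units, inferred from the divisors rather than stated in the diff: the AMD timestamps read here appear to tick at 100MHz, so `/1e2` converts ticks to the microseconds `ProfileLogger` expects, consistent with the `/1e8` (ticks to seconds) in `AMDProgram.__call__` above:

```python
ticks = 250_000     # hypothetical start_ts delta between two signals
us = ticks / 1e2    # 2500.0us on the ProfileLogger timeline
secs = ticks / 1e8  # 0.0025s, the value a timed __call__(wait=True) returns
```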
tinygrad/runtime/ops_nv.py
@@ -3,7 +3,7 @@ import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashli
 from typing import Tuple, List, Any
 from dataclasses import dataclass
 from tinygrad.device import HCQCompatCompiled, HCQCompatAllocator, Compiler, CompileError, BufferOptions
-from tinygrad.helpers import getenv, from_mv, mv_address, init_c_struct_t, to_mv, round_up, to_char_p_p, DEBUG, prod
+from tinygrad.helpers import getenv, from_mv, mv_address, init_c_struct_t, to_mv, round_up, to_char_p_p, DEBUG, prod, PROFILE
 from tinygrad.renderer.cstyle import NVRenderer
 from tinygrad.runtime.ops_cuda import check as cuda_check, _get_bytes, CUDACompiler
 import tinygrad.runtime.autogen.cuda as cuda
@@ -103,9 +103,11 @@ class HWQueue:
               (3 << 0) | (1 << 24)] # ACQUIRE | PAYLOAD_SIZE_64BIT
     return self._mark_command_end()

+  def timestamp(self, signal): return HWQueue.signal(self, signal, timestamp=True)
+
   def signal(self, signal, value=0, timestamp=False):
     self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
-               (1 << 0) | (1 << 20) | (1 << 24) | ((1 << 25) if timestamp else 0)] # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
+               (1 << 0) | (1 << 20) | (1 << 24) | ((1 << 25) if timestamp else 0)] # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
     self.q += [nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0]
     return self._mark_command_end()
@@ -316,19 +318,22 @@ class NVProgram:
     if MOCKGPU: self.constbuffer_0[0:2] = [len(args), len(vals)]
     kernargs = [arg_half for arg in args for arg_half in nvdata64_le(arg.base)] + [val for val in vals]

+    sig_st, sig_en = (self.device._get_signal(), self.device._get_signal()) if PROFILE else (self.device.time_event_st, self.device.time_event_en)
+
     queue = HWComputeQueue()
     queue.wait(self.device.timeline_signal, self.device.timeline_value - 1)
-    if wait: queue.signal(self.device.time_event_st, timestamp=True)
+    if wait or PROFILE: queue.timestamp(sig_st)
     queue.copy_from_cpu(self.device.kernargs_ptr, self.constbuffer_0 + kernargs)
     queue.exec(self, self.device.kernargs_ptr, global_size, local_size)
-    if wait: queue.signal(self.device.time_event_en, timestamp=True)
+    if wait or PROFILE: queue.timestamp(sig_en)
     queue.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
     self.device.timeline_value += 1
     self.device.kernargs_ptr += self.kernargs_alloc_size

+    if PROFILE: self.device.sig_prof_records.append((sig_st, sig_en, self.name, False))
     if wait:
       self.device._wait_signal(self.device.timeline_signal, self.device.timeline_value - 1)
-      return (self.device.time_event_en[1] - self.device.time_event_st[1]) / 1e9
+      return (sig_en[1] - sig_st[1]) / 1e9

 class NVAllocator(HCQCompatAllocator):
   def __init__(self, device:NVDevice): super().__init__(device)
@@ -498,7 +503,7 @@ class NVDevice(HCQCompatCompiled):
       uvm.enable_peer_access(self.fd_uvm, gpuUuidA=nv_gpu.struct_nv_uuid(uuid=self.gpu_uuid), gpuUuidB=nv_gpu.struct_nv_uuid(uuid=dev.gpu_uuid))

     if NVDevice.signals_page is None:
-      NVDevice.signals_page = self._gpu_system_alloc(0x10000, map_to_cpu=True)
+      NVDevice.signals_page = self._gpu_system_alloc(16 * 65536, map_to_cpu=True)
       NVDevice.signals_pool = [to_mv(self.signals_page.base + off, 16).cast("Q") for off in range(0, NVDevice.signals_page.length, 16)]
     else: self._gpu_map(NVDevice.signals_page)
@@ -535,15 +540,12 @@ class NVDevice(HCQCompatCompiled):

     NVDevice.devices.append(self)

-  def synchronize(self):
-    NVDevice._wait_signal(self.timeline_signal, self.timeline_value - 1)
-    self.cmdq_wptr = 0
-
-    if self.timeline_value > (1 << 63): self._wrap_timeline_signal()
-
   @classmethod
   def _read_signal(self, sig): return sig[0]

+  @classmethod
+  def _read_timestamp(self, sig): return sig[1]
+
   @classmethod
   def _set_signal(self, sig, value): sig[0] = value

@@ -559,6 +561,15 @@ class NVDevice(HCQCompatCompiled):
       if signal[0] >= value: return
     raise RuntimeError(f"wait_result: {timeout} ms TIMEOUT!")

+  def _gpu2cpu_time(self, gpu_time, is_copy): return self.cpu_start_time + (gpu_time - self.gpu_start_time) / 1e3
+
+  def synchronize(self):
+    NVDevice._wait_signal(self.timeline_signal, self.timeline_value - 1)
+    self.cmdq_wptr = 0
+
+    if self.timeline_value > (1 << 63): self._wrap_timeline_signal()
+    if PROFILE: self._prof_process_events()
+
   def _new_gpu_fifo(self, gpfifo_area, ctxshare, channel_group, offset=0, entries=0x400) -> GPFifo:
     notifier = self._gpu_system_alloc(48 << 20)
     params = nv_gpu.NV_CHANNELGPFIFO_ALLOCATION_PARAMETERS(hObjectError=notifier.hMemory, hObjectBuffer=gpfifo_area.hMemory,
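Note: the same unit exercise for NV, where timestamps land in nanoseconds (`NVProgram.__call__` divides signal deltas by 1e9 to return seconds), so `_gpu2cpu_time` divides by 1e3 to get microseconds; the NV override ignores `is_copy`, using one clock anchor for both queues:

```python
ns = 3_500_000    # hypothetical sig_en[1] - sig_st[1]
us = ns / 1e3     # 3500.0us on the trace
secs = ns / 1e9   # 0.0035s returned by a timed __call__(wait=True)
```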