mirror of https://github.com/tinygrad/tinygrad.git
hsa profiler (#3711)

* hsa profiler
* simpler
* profile
* copy -> is_copy
* print when saved
* faster
* do not create structs

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
@@ -1,9 +1,9 @@
 import ctypes, collections, time, itertools
-from typing import List, Any, Dict, cast, Optional, Union
+from typing import List, Any, Dict, cast, Optional, Union, Tuple
 from tinygrad.helpers import GraphException, init_c_var
 from tinygrad.device import Compiled, Buffer, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats
 from tinygrad.shape.symbolic import Variable
-from tinygrad.runtime.ops_hsa import HSADevice
+from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
 from tinygrad.features.jit import JitItem, get_input_replace, get_jit_stats, \
   get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
 import tinygrad.runtime.autogen.hsa as hsa
@@ -67,6 +67,7 @@ class HSAGraph(MultiDeviceJITGraph):
     self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, int]] = {}
     self.r_dependency_map: Dict[Any, List[Union[hsa.hsa_signal_t, int]]] = collections.defaultdict(list)
     signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
+    self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)

     # Special packet to wait for the world.
     self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {}
@@ -79,7 +80,11 @@ class HSAGraph(MultiDeviceJITGraph):
         for i in range(0, len(wait_signals), 5):
           self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals=wait_signals[i:i+5])
         self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
-        self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), ctypes.addressof(self.ji_kargs_structs[j])) #type:ignore
+        sync_signal = self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), #type:ignore
+                                                                        ctypes.addressof(self.ji_kargs_structs[j]), need_signal=PROFILE)
+        if PROFILE:
+          self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
+          self.signals_to_reset.append(sync_signal)
       elif isinstance(ji.prg, BufferXfer):
         dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
         dest_dev, src_dev = cast(HSADevice, dest.d), cast(HSADevice, src.d)
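When PROFILE is set, each kernel in the graph is submitted with need_signal=True, so the queue attaches a completion signal whose start/end timestamps can be read back later; those signals are collected in self.signals_to_reset so they can be re-armed before the graph runs again. A minimal sketch of what a re-arm looks like, assuming the same autogen bindings (hsa_signal_store_relaxed is the core HSA call for setting a signal's value; this is an illustration, not necessarily tinygrad's exact reset code):

  # Re-arm each profiling signal to 1 so it can record a fresh completion
  # the next time the graph is launched.
  for sig in self.signals_to_reset:
    hsa.hsa_signal_store_relaxed(sig, 1)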
@@ -90,6 +95,7 @@ class HSAGraph(MultiDeviceJITGraph):
         wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
         self.transfers.append((dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
                               (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True))
+        if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True))

     # Wait for all active signals to finish the graph
     wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
@@ -145,6 +151,7 @@ class HSAGraph(MultiDeviceJITGraph):
     hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
     et = time.perf_counter() - st

+    for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata
     update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers),
                  jit=jit, num_kernels=len(self.jit_cache), device="HSA")
     return et
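Note the handoff at the end of the graph launch: rather than calling Profiler.track() once per kernel, the per-device tuples accumulated in self.profile_info are appended to Profiler.tracked_signals in bulk. Since track() (defined below in ops_hsa.py) just appends a (signal, name, is_copy) tuple, the += above is equivalent to the sketch below, minus the per-call overhead inside the timed region:

  # Equivalent per-entry form of `Profiler.tracked_signals[profdev] += profdata`:
  for signal, name, is_copy in profdata:
    Profiler.track(signal, profdev, name, is_copy)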
@@ -1,13 +1,45 @@
 from __future__ import annotations
-import ctypes, functools, subprocess, io, atexit
-from typing import Tuple, TypeVar, List, Dict
+import ctypes, functools, subprocess, io, atexit, collections, json
+from typing import Tuple, TypeVar, List, Dict, Any
 import tinygrad.runtime.autogen.hsa as hsa
-from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t
+from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv
 from tinygrad.device import Compiled, LRUAllocator, BufferOptions
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.runtime.ops_hip import HIPCompiler
 from tinygrad.runtime.driver.hsa import check, scan_agents, find_memory_pool, AQLQueue

+PROFILE = getenv("PROFILE", 0)
+
+class HSAProfiler:
+  def __init__(self):
+    self.tracked_signals = collections.defaultdict(list)
+    self.collected_events: List[Tuple[Any, ...]] = []
+    self.copy_timings = hsa.hsa_amd_profiling_async_copy_time_t()
+    self.disp_timings = hsa.hsa_amd_profiling_dispatch_time_t()
+
+  def track(self, signal, device, name, is_copy=False): self.tracked_signals[device].append((signal, name, is_copy))
+  def process(self, device):
+    # Process all tracked signals; this must be called before any of the tracked signals are reused.
+    for sig,name,is_copy in self.tracked_signals[device]:
+      if is_copy: check(hsa.hsa_amd_profiling_get_async_copy_time(sig, ctypes.byref(timings := self.copy_timings)))
+      else: check(hsa.hsa_amd_profiling_get_dispatch_time(device.agent, sig, ctypes.byref(timings := self.disp_timings))) #type:ignore
+      self.collected_events.append((device.device_id, is_copy, name, timings.start, timings.end))
+    self.tracked_signals.pop(device)
+
+  def save(self, path):
+    mjson = []
+    for i in range(len(HSADevice.devices)):
+      mjson.append({"name": "process_name", "ph": "M", "pid": i, "args": {"name": "HSA"}})
+      mjson.append({"name": "thread_name", "ph": "M", "pid": i, "tid": 0, "args": {"name": "AQL"}})
+      mjson.append({"name": "thread_name", "ph": "M", "pid": i, "tid": 1, "args": {"name": "SDMA"}})

+    for dev_id,queue_id,name,st,et in self.collected_events:
+      mjson.append({"name": name, "ph": "B", "pid": dev_id, "tid": queue_id, "ts": st*1e-3})
+      mjson.append({"name": name, "ph": "E", "pid": dev_id, "tid": queue_id, "ts": et*1e-3})
+    with open(path, "w") as f: f.write(json.dumps({"traceEvents": mjson}))
+    print(f"Saved HSA profile to {path}")
+Profiler = HSAProfiler()
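save() emits the Chrome Trace Event Format: "M" metadata events label each device's process and its two rows (tid 0 for AQL kernel dispatches, tid 1 for SDMA copies; note the is_copy boolean stored in collected_events doubles as the tid), and each profiled event becomes a "B"/"E" begin/end pair with ts in microseconds (the *1e-3 scaling treats the raw timestamps as nanoseconds). A self-contained sketch of the same format with one hypothetical 150us kernel on device 0:

  # Minimal Chrome trace; open the resulting file in chrome://tracing
  # or https://ui.perfetto.dev to view it.
  import json
  events = [
    {"name": "process_name", "ph": "M", "pid": 0, "args": {"name": "HSA"}},
    {"name": "thread_name", "ph": "M", "pid": 0, "tid": 0, "args": {"name": "AQL"}},
    {"name": "my_kernel", "ph": "B", "pid": 0, "tid": 0, "ts": 100.0},  # begin at 100us
    {"name": "my_kernel", "ph": "E", "pid": 0, "tid": 0, "ts": 250.0},  # end at 250us
  ]
  with open("/tmp/example_trace.json", "w") as f: json.dump({"traceEvents": events}, f)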
 
 class HSACompiler(HIPCompiler):
   linearizer_opts = LinearizerOptions("HSA", has_tensor_cores=True, shared_max=65536)
@@ -51,7 +83,8 @@ class HSAProgram:
     for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i])
     self.device.flush_hdp()

-    signal = self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, need_signal=wait)
+    signal = self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, need_signal=(wait or PROFILE))
+    if PROFILE: Profiler.track(signal, self.device, self.name)
     if wait:
       hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
       check(hsa.hsa_amd_profiling_get_dispatch_time(self.device.agent, signal, ctypes.byref(timings := hsa.hsa_amd_profiling_dispatch_time_t())))
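The start/end values hsa_amd_profiling_get_dispatch_time writes into timings are in ticks of the HSA system clock, not guaranteed nanoseconds; converting to seconds divides by the frequency the runtime reports. A hedged sketch using helpers the file already imports, and assuming the autogen bindings expose the standard HSA_SYSTEM_INFO_TIMESTAMP_FREQUENCY enum:

  # Ticks-per-second of the HSA timestamp clock, queried once:
  freq = init_c_var(ctypes.c_uint64(), lambda x: check(
    hsa.hsa_system_get_info(hsa.HSA_SYSTEM_INFO_TIMESTAMP_FREQUENCY, ctypes.byref(x)))).value
  elapsed_sec = (timings.end - timings.start) / freq  # ticks -> seconds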
@@ -91,6 +124,7 @@ class HSAAllocator(LRUAllocator):
                                                 1, ctypes.byref(sync_signal), copy_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True))
     self.device.hw_queue.submit_barrier(wait_signals=[copy_signal])
     self.device.delayed_free.append(mem)
+    if PROFILE: Profiler.track(copy_signal, self.device, f"copyin: CPU -> HSA:{self.device.device_id}", is_copy=True)

   def copy_from_fd(self, dest, fd, offset, size):
     sync_signal = self.device.hw_queue.submit_barrier(need_signal=True)
@@ -137,6 +171,7 @@ class HSAAllocator(LRUAllocator):
     check(hsa.hsa_amd_memory_async_copy(addr, HSADevice.cpu_agent, src, self.device.agent, dest.nbytes, 0, None, copy_signal))
     hsa.hsa_signal_wait_scacquire(copy_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
     check(hsa.hsa_amd_memory_unlock(from_mv(dest)))
+    if PROFILE: Profiler.track(copy_signal, self.device, f"copyout: HSA:{self.device.device_id} -> CPU", is_copy=True)

   def transfer(self, dest:T, src:T, sz:int, src_dev=None, dest_dev=None):
     copy_signal = dest_dev.alloc_signal(reusable=False)
@@ -146,6 +181,7 @@ class HSAAllocator(LRUAllocator):
     check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal, copy_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True)) # noqa: E501
     src_dev.hw_queue.submit_barrier(wait_signals=[copy_signal])
     dest_dev.hw_queue.submit_barrier(wait_signals=[copy_signal])
+    if PROFILE: Profiler.track(copy_signal, src_dev, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", is_copy=True)

 class HSADevice(Compiled):
   devices: List[HSADevice] = []
@@ -159,6 +195,7 @@ class HSADevice(Compiled):
       HSADevice.agents = scan_agents()
       HSADevice.cpu_agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_CPU][0]
       HSADevice.cpu_mempool = find_memory_pool(HSADevice.cpu_agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_CPU)
+      if PROFILE: check(hsa.hsa_amd_profiling_async_copy_enable(1))

     self.device_id = int(device.split(":")[1]) if ":" in device else 0
     self.agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU][self.device_id]
@@ -196,6 +233,7 @@ class HSADevice(Compiled):
     self.delayed_free.clear()

     self.kernarg_next_addr = self.kernarg_start_addr
+    Profiler.process(self)

   @staticmethod
   def synchronize_system():
@@ -226,6 +264,9 @@ class HSADevice(Compiled):
 def hsa_terminate():
   # The AQL queues must be stopped/deleted before HSA shuts down, otherwise the GPU hangs.
   for dev in HSADevice.devices:
+    Profiler.process(dev)
     setattr(dev, 'synchronize', lambda: None) # some destructors might need to sync, but hw_queue is removed
     del dev.hw_queue
+
   hsa.hsa_shut_down()
+  if Profiler.collected_events: Profiler.save("/tmp/profile.json")
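For reference, a hedged end-to-end usage sketch (the workload is hypothetical; any tinygrad program running on an HSA device works). PROFILE is read via getenv at import time, so it must be set before tinygrad is imported, e.g. by exporting PROFILE=1 in the shell:

  import os
  os.environ["PROFILE"] = "1"  # must happen before tinygrad is imported
  from tinygrad import Tensor  # assumes an HSA-capable build and device

  (Tensor.rand(1024, 1024) @ Tensor.rand(1024, 1024)).realize()
  # On interpreter exit, hsa_terminate() runs via atexit: Profiler.process()
  # drains any remaining signals and, since events were collected, save()
  # writes the trace and prints "Saved HSA profile to /tmp/profile.json".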