hsa profiler (#3711)

* hsa profiler

* simpler

* profile

* copy -> is_copy

* print when saved

* faster

* do not create structs

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: nimlgen
Date: 2024-03-14 07:19:22 +03:00
Committed by: GitHub
Parent: 56b914fc8c
Commit: 0f050b1028
2 changed files with 55 additions and 7 deletions
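
This PR adds an opt-in HSA profiler: while work is queued, completion signals are remembered per device, their start/end timestamps are harvested after a sync, and an atexit hook dumps everything as a Chrome trace. A hand-written orientation sketch of that lifecycle (not tinygrad code; read_times stands in for the HSA timestamp queries in the diff):

import collections

# Sketch of the Profiler lifecycle introduced below: track -> process -> save.
tracked_signals = collections.defaultdict(list)  # filled while work is queued
collected_events = []                            # filled after each device sync

def track(signal, device, name, is_copy=False):
  tracked_signals[device].append((signal, name, is_copy))

def process(device, read_times):
  # Must run before the signals are reused; read_times is a stand-in for the
  # hsa_amd_profiling_get_*_time calls in the real diff.
  for sig, name, is_copy in tracked_signals.pop(device, []):
    start, end = read_times(sig)
    collected_events.append((device, is_copy, name, start, end))

track("sig0", 0, "kernel_0")
process(0, lambda sig: (1000, 2500))  # made-up timestamps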


@@ -1,9 +1,9 @@
 import ctypes, collections, time, itertools
-from typing import List, Any, Dict, cast, Optional, Union
+from typing import List, Any, Dict, cast, Optional, Union, Tuple
 from tinygrad.helpers import GraphException, init_c_var
 from tinygrad.device import Compiled, Buffer, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats
 from tinygrad.shape.symbolic import Variable
-from tinygrad.runtime.ops_hsa import HSADevice
+from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
 from tinygrad.features.jit import JitItem, get_input_replace, get_jit_stats, \
   get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
 import tinygrad.runtime.autogen.hsa as hsa
@@ -67,6 +67,7 @@ class HSAGraph(MultiDeviceJITGraph):
     self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, int]] = {}
     self.r_dependency_map: Dict[Any, List[Union[hsa.hsa_signal_t, int]]] = collections.defaultdict(list)
     signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
+    self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)

     # Special packet to wait for the world.
     self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {}
@@ -79,7 +80,11 @@ class HSAGraph(MultiDeviceJITGraph):
         for i in range(0, len(wait_signals), 5):
           self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals=wait_signals[i:i+5])
         self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
-        self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), ctypes.addressof(self.ji_kargs_structs[j])) #type:ignore
+        sync_signal = self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), #type:ignore
+                                                                        ctypes.addressof(self.ji_kargs_structs[j]), need_signal=PROFILE)
+        if PROFILE:
+          self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
+          self.signals_to_reset.append(sync_signal)
       elif isinstance(ji.prg, BufferXfer):
         dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
         dest_dev, src_dev = cast(HSADevice, dest.d), cast(HSADevice, src.d)
@@ -90,6 +95,7 @@ class HSAGraph(MultiDeviceJITGraph):
         wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
         self.transfers.append((dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
                                (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True))
+        if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True))

     # Wait for all active signals to finish the graph
     wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
@@ -145,6 +151,7 @@ class HSAGraph(MultiDeviceJITGraph):
     hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
     et = time.perf_counter() - st

+    for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata
     update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers),
                  jit=jit, num_kernels=len(self.jit_cache), device="HSA")
     return et
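
In the graph path, a batched kernel only gets its own completion signal when PROFILE is set (need_signal=PROFILE), and that signal is pushed onto signals_to_reset so it can be rearmed before the next replay; the (signal, name, is_copy) records collected at build time are therefore merged into the global Profiler after every run, before the reset invalidates the timestamps. A minimal sketch of that hand-off, with stand-in device keys and signal objects:

import collections

# Build time: one record per batched kernel/copy, kept for the graph's lifetime.
profile_info = collections.defaultdict(list)
profile_info["HSA:0"].append(("sig_k0", "kernel_0", False))  # (signal, name, is_copy)
profile_info["HSA:0"].append(("sig_x0", "transfer: HSA:0 -> HSA:1", True))

# After each replay: merge into the profiler so Profiler.process() can read
# the timestamps before the signals are reset for the next run.
tracked_signals = collections.defaultdict(list)
for dev, data in profile_info.items():
  tracked_signals[dev] += data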


@@ -1,13 +1,45 @@
 from __future__ import annotations
-import ctypes, functools, subprocess, io, atexit
-from typing import Tuple, TypeVar, List, Dict
+import ctypes, functools, subprocess, io, atexit, collections, json
+from typing import Tuple, TypeVar, List, Dict, Any
 import tinygrad.runtime.autogen.hsa as hsa
-from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t
+from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv
 from tinygrad.device import Compiled, LRUAllocator, BufferOptions
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.runtime.ops_hip import HIPCompiler
 from tinygrad.runtime.driver.hsa import check, scan_agents, find_memory_pool, AQLQueue

+PROFILE = getenv("PROFILE", 0)
+
+class HSAProfiler:
+  def __init__(self):
+    self.tracked_signals = collections.defaultdict(list)
+    self.collected_events: List[Tuple[Any, ...]] = []
+    self.copy_timings = hsa.hsa_amd_profiling_async_copy_time_t()
+    self.disp_timings = hsa.hsa_amd_profiling_dispatch_time_t()
+  def track(self, signal, device, name, is_copy=False): self.tracked_signals[device].append((signal, name, is_copy))
+  def process(self, device):
+    # Process all tracked signals; this must be called before any of the tracked signals are reused.
+    for sig,name,is_copy in self.tracked_signals[device]:
+      if is_copy: check(hsa.hsa_amd_profiling_get_async_copy_time(sig, ctypes.byref(timings := self.copy_timings)))
+      else: check(hsa.hsa_amd_profiling_get_dispatch_time(device.agent, sig, ctypes.byref(timings := self.disp_timings))) #type:ignore
+      self.collected_events.append((device.device_id, is_copy, name, timings.start, timings.end))
+    self.tracked_signals.pop(device)
+  def save(self, path):
+    mjson = []
+    for i in range(len(HSADevice.devices)):
+      mjson.append({"name": "process_name", "ph": "M", "pid": i, "args": {"name": "HSA"}})
+      mjson.append({"name": "thread_name", "ph": "M", "pid": i, "tid": 0, "args": {"name": "AQL"}})
+      mjson.append({"name": "thread_name", "ph": "M", "pid": i, "tid": 1, "args": {"name": "SDMA"}})
+    for dev_id,queue_id,name,st,et in self.collected_events:
+      mjson.append({"name": name, "ph": "B", "pid": dev_id, "tid": queue_id, "ts": st*1e-3})
+      mjson.append({"name": name, "ph": "E", "pid": dev_id, "tid": queue_id, "ts": et*1e-3})
+    with open(path, "w") as f: f.write(json.dumps({"traceEvents": mjson}))
+    print(f"Saved HSA profile to {path}")
+Profiler = HSAProfiler()
+
 class HSACompiler(HIPCompiler):
   linearizer_opts = LinearizerOptions("HSA", has_tensor_cores=True, shared_max=65536)
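
Profiler.save() emits the Chrome trace event format: metadata records ("ph": "M") name each device's process and its AQL/SDMA rows, and every kernel or copy becomes a begin/end pair ("ph": "B"/"E") with ts in microseconds. One quirk worth noting: collected_events carries is_copy in the queue-id slot, so event tid serializes as true/false while the metadata declares rows 0 and 1. A self-contained sketch of the structure with made-up timings; the output loads in chrome://tracing or ui.perfetto.dev:

import json

# Minimal Chrome trace: one process ("HSA"), two named rows, one kernel span.
events = [
  {"name": "process_name", "ph": "M", "pid": 0, "args": {"name": "HSA"}},
  {"name": "thread_name", "ph": "M", "pid": 0, "tid": 0, "args": {"name": "AQL"}},
  {"name": "thread_name", "ph": "M", "pid": 0, "tid": 1, "args": {"name": "SDMA"}},
  {"name": "my_kernel", "ph": "B", "pid": 0, "tid": 0, "ts": 10.0},   # begin, us
  {"name": "my_kernel", "ph": "E", "pid": 0, "tid": 0, "ts": 42.5},   # end, us
]
with open("/tmp/mini_trace.json", "w") as f:
  json.dump({"traceEvents": events}, f)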
@@ -51,7 +83,8 @@ class HSAProgram:
     for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i])

     self.device.flush_hdp()
-    signal = self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, need_signal=wait)
+    signal = self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, need_signal=(wait or PROFILE))
+    if PROFILE: Profiler.track(signal, self.device, self.name)
     if wait:
       hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
       check(hsa.hsa_amd_profiling_get_dispatch_time(self.device.agent, signal, ctypes.byref(timings := hsa.hsa_amd_profiling_dispatch_time_t())))
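
Throughout the diff the raw start/end values filled in by hsa_amd_profiling_get_dispatch_time are treated as nanoseconds (save() multiplies by 1e-3 to reach the microseconds Chrome tracing expects). Whether the values really are nanoseconds depends on the agent's profiling clock domain, so treat the conversion below as an assumption inherited from the diff; the timestamps are made up:

# Worked conversion under the nanosecond assumption.
start_ns, end_ns = 1_000_000, 4_500_000
elapsed_s = (end_ns - start_ns) * 1e-9   # 0.0035 s
trace_ts = start_ns * 1e-3               # 1000.0 us, as in Profiler.save()
print(f"{elapsed_s*1e6:.1f} us, ts={trace_ts}")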
@@ -91,6 +124,7 @@ class HSAAllocator(LRUAllocator):
                                                   1, ctypes.byref(sync_signal), copy_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True))
     self.device.hw_queue.submit_barrier(wait_signals=[copy_signal])
     self.device.delayed_free.append(mem)
+    if PROFILE: Profiler.track(copy_signal, self.device, f"copyin: CPU -> HSA:{self.device.device_id}", is_copy=True)

   def copy_from_fd(self, dest, fd, offset, size):
     sync_signal = self.device.hw_queue.submit_barrier(need_signal=True)
@@ -137,6 +171,7 @@ class HSAAllocator(LRUAllocator):
     check(hsa.hsa_amd_memory_async_copy(addr, HSADevice.cpu_agent, src, self.device.agent, dest.nbytes, 0, None, copy_signal))
     hsa.hsa_signal_wait_scacquire(copy_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
     check(hsa.hsa_amd_memory_unlock(from_mv(dest)))
+    if PROFILE: Profiler.track(copy_signal, self.device, f"copyout: HSA:{self.device.device_id} -> CPU", is_copy=True)

   def transfer(self, dest:T, src:T, sz:int, src_dev=None, dest_dev=None):
     copy_signal = dest_dev.alloc_signal(reusable=False)
@@ -146,6 +181,7 @@ class HSAAllocator(LRUAllocator):
     check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal, copy_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True)) # noqa: E501
     src_dev.hw_queue.submit_barrier(wait_signals=[copy_signal])
     dest_dev.hw_queue.submit_barrier(wait_signals=[copy_signal])
+    if PROFILE: Profiler.track(copy_signal, src_dev, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", is_copy=True)

 class HSADevice(Compiled):
   devices: List[HSADevice] = []
@@ -159,6 +195,7 @@ class HSADevice(Compiled):
       HSADevice.agents = scan_agents()
       HSADevice.cpu_agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_CPU][0]
       HSADevice.cpu_mempool = find_memory_pool(HSADevice.cpu_agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_CPU)
+      if PROFILE: check(hsa.hsa_amd_profiling_async_copy_enable(1))

     self.device_id = int(device.split(":")[1]) if ":" in device else 0
     self.agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU][self.device_id]
@@ -196,6 +233,7 @@ class HSADevice(Compiled):
     self.delayed_free.clear()

     self.kernarg_next_addr = self.kernarg_start_addr
+    Profiler.process(self)

   @staticmethod
   def synchronize_system():
@@ -226,6 +264,9 @@ class HSADevice(Compiled):
 def hsa_terminate():
   # Need to stop/delete the aql queues before hsa shutdown; otherwise this leads to gpu hangs.
   for dev in HSADevice.devices:
+    Profiler.process(dev)
     setattr(dev, 'synchronize', lambda: None) # some destructors might need to sync, but hw_queue is removed.
     del dev.hw_queue
   hsa.hsa_shut_down()
+
+  if Profiler.collected_events: Profiler.save("/tmp/profile.json")
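
Usage, as implied by the diff: run any HSA workload with PROFILE=1 in the environment (the flag is read with tinygrad's getenv, like DEBUG), and at interpreter exit the hook above prints "Saved HSA profile to /tmp/profile.json". A small sketch for inspecting the result; the field layout follows save() above:

import json

# Load the trace written by a PROFILE=1 run and list the kernel/copy spans.
with open("/tmp/profile.json") as f:
  trace = json.load(f)

for ev in trace["traceEvents"]:
  if ev.get("ph") == "B":  # each span also has a matching "E" event
    print(f"dev={ev['pid']} tid={ev['tid']} {ev['name']} @ {ev['ts']:.1f}us")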