Files
tinygrad/tinygrad/viz/serve.py
qazal a2da61d096 use new style amd compiler in viz (#13848)
* working version, handcode gfx1100 arch

* get target from device properties

* lib in cfg test program spec
2025-12-27 23:59:30 +09:00

562 lines
30 KiB
Python
Executable File

#!/usr/bin/env python3
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, functools, codecs, io, struct
import ctypes, pathlib, traceback, itertools
from contextlib import redirect_stdout, redirect_stderr, contextmanager
from decimal import Decimal
from urllib.parse import parse_qs, urlparse
from typing import Any, TypedDict, TypeVar, Generator, Callable
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
from tinygrad.helpers import printable, TCPServerWithReuse, HTTPRequestHandler
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender
from tinygrad.uop.ops import print_uops, range_start, multirange_str
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device, ProfileProgramEvent
from tinygrad.renderer import ProgramSpec
from tinygrad.dtype import dtypes
# fill colors for UOp graph nodes in the browser UI, keyed by op; ops absent here render white (see uop_to_json)
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
               Ops.DEFINE_GLOBAL:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
               Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
               Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.CUSTOM_KERNEL: "#3ebf55",
               **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
               Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
               Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
               Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
# VIZ API
# A step is a lightweight descriptor for a trace entry
# Includes a name, metadata and a URL path for fetching the full data
def create_step(name:str, query:tuple[str, int, int], data=None, depth:int=0, **kwargs) -> dict:
  """Build a lightweight trace-entry descriptor; query is (endpoint path, ctx index, step index)."""
  path, ctx, step = query
  return dict(name=name, query=f"{path}?ctx={ctx}&step={step}", data=data, depth=depth, **kwargs)
# ** list all saved rewrites
# maps tracing keys to the index of the ctx they were codegen'd in
ref_map:dict[Any, int] = {}
def get_rewrites(t:RewriteTrace) -> list[dict]:
  """Turn the pickled RewriteTrace into the /ctxs listing, populating ref_map (key -> ctx index)."""
  out:list[dict] = []
  for i,(k,v) in enumerate(zip(t.keys, t.rewrites)):
    steps = []
    for j,s in enumerate(v):
      steps.append(create_step(s.name, ("/graph-rewrites", i, j), loc=s.loc, match_count=len(s.matches), code_line=printable(s.loc),
                               trace=k.tb if j==0 else None, depth=s.depth))
    # traces that produced a program get extra viewer steps
    if isinstance(k.ret, ProgramSpec):
      for label, endpoint in (("View UOp List", "/uops"), ("View Program", "/code"), ("View Disassembly", "/asm")):
        steps.append(create_step(label, (endpoint, i, len(steps)), k.ret))
    for key in k.keys: ref_map[key] = i
    out.append({"name":k.display_name, "steps":steps})
  return out
# ** get the complete UOp graphs for one rewrite
class GraphRewriteDetails(TypedDict):
  """Payload for one step streamed by the /graph-rewrites event stream."""
  graph: dict # JSON serialized UOp for this rewrite step
  uop: str # stringified UOp for this rewrite step
  diff: list[str]|None # diff of the single UOp that changed
  change: list[int]|None # the new UOp id + all its parents ids
  upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
def shape_to_str(s:tuple[sint, ...]) -> str:
  """Render a (possibly symbolic) shape tuple as "(d0,d1,...)"."""
  return "(" + ','.join(map(srender, s)) + ")"
def mask_to_str(s:tuple[tuple[sint, sint], ...]) -> str:
  """Render a mask (a tuple of (begin,end) pairs) as "((a,b),(c,d),...)"."""
  return "(" + ','.join(map(shape_to_str, s)) + ")"
def pystr(u:UOp) -> str:
  """Best-effort python rendering of u: pyrender may raise (it can check for shape mismatch), so fall back to str()."""
  try:
    return pyrender(u)
  except Exception:
    return str(u)
def uop_to_json(x:UOp) -> dict[int, dict]:
  """Serialize the UOp graph rooted at x into {id(uop): node dict} for the VIZ frontend.

  Trivial ops (DEVICE/CONST/UNIQUE/...) are excluded as standalone nodes and instead folded
  into their consumer's label. Each node carries a display label, edges to its non-excluded
  sources, a fill color, an optional ref to the ctx that codegen'd a KERNEL's AST, and the tag.
  """
  assert isinstance(x, UOp)
  graph: dict[int, dict] = {}
  excluded: set[UOp] = set()
  for u in (toposort:=x.toposort()):
    # always exclude DEVICE/CONST/UNIQUE (unless it is the root being viewed)
    if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
    if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
    if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
  for u in toposort:
    if u in excluded: continue
    argst = codecs.decode(str(u.arg), "unicode_escape")
    if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)
    if u.op is Ops.KERNEL:
      ast_str = f"SINK{tuple(s.op for s in u.arg.ast.src)}" if u.arg.ast.op is Ops.SINK else repr(u.arg.ast.op)
      argst = f"<Kernel {len(list(u.arg.ast.toposort()))} {ast_str} {[str(m) for m in u.arg.metadata]}>"
    label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
    if u.dtype != dtypes.void: label += f"\n{u.dtype}"
    # fold excluded sources into this node's label (loop var renamed from x, which shadowed the root parameter)
    for idx,s in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):
      if s in excluded:
        arg = f"{s.arg:g}" if s.op is Ops.CONST and dtypes.is_float(s.dtype) else f"{s.arg}"
        label += f"\n{s.op.name}{idx} {arg}" + (f" {s.src[0].op}" if len(s.src) else "")
    try:
      if len(rngs:=u.ranges):
        label += f"\n({multirange_str(rngs, color=True)})"
      if u._shape is not None:
        label += f"\n{shape_to_str(u.shape)}"
      if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
        label += f"\n{u.render()}"
        ranges: list[UOp] = []
        for us in u.src[1:]: ranges += [r for r in us.toposort() if r.op in {Ops.RANGE, Ops.SPECIAL}]
        if ranges: label += "\n"+' '.join([f"{r.render()}={r.vmax+1}" for r in ranges])
      if u.op in {Ops.END, Ops.REDUCE} and len(trngs:=list(UOp.sink(*u.src[range_start[u.op]:]).ranges)):
        label += "\n"+' '.join([f"{range_str(r, color=True)}({r.vmax+1})" for r in trngs])
    except Exception:
      # label decoration is best-effort; rendering can fail mid-rewrite on partial graphs
      label += "\n<ISSUE GETTING LABEL>"
    if (ref:=ref_map.get(u.arg.ast) if u.op is Ops.KERNEL else None) is not None: label += f"\ncodegen@{ctxs[ref]['name']}"
    # NOTE: kernel already has metadata in arg
    if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.KERNEL: label += "\n"+str(u.metadata)
    graph[id(u)] = {"label":label, "src":[(i,id(s)) for i,s in enumerate(u.src) if s not in excluded], "color":uops_colors.get(u.op, "#ffffff"),
                    "ref":ref, "tag":repr(u.tag) if u.tag is not None else None}
  return graph
@functools.cache
def _reconstruct(a:int):
  """Recursively rebuild (and memoize) the UOp stored at index a of the pickled trace's uop_fields table."""
  op, dtype, src, arg, *rest = trace.uop_fields[a]
  # a KERNEL arg references its AST by index; materialize it before constructing the UOp
  if op is Ops.KERNEL: arg = type(arg)(_reconstruct(arg.ast), arg.metadata)
  return UOp(op, dtype, tuple(map(_reconstruct, src)), arg, *rest)
def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
  """Yield one GraphRewriteDetails per step: first the initial sink, then one per matched UPat."""
  next_sink = _reconstruct(ctx.sink)
  # in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink)
  yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
  replaces: dict[UOp, UOp] = {}
  for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
    # u0 is the matched subtree, u1 its replacement
    replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
    # substitute may raise RuntimeError; surface the message as a NOOP node instead of killing the stream
    try: new_sink = next_sink.substitute(replaces)
    except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
    match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
    yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json],
           "diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)}
    # bottom_up rewrites keep substituting into the original sink; otherwise each step builds on the previous
    if not ctx.bottom_up: next_sink = new_sink
# encoder helpers
def enum_str(s, cache:dict[str, int]) -> int:
if (cret:=cache.get(s)) is not None: return cret
cache[s] = ret = len(cache)
return ret
def option(s:int|None) -> int: return 0 if s is None else s+1
# Profiler API
# per-device (compute clock diff, copy clock diff) used to align device timestamps to CPU time
device_ts_diffs:dict[str, tuple[Decimal, Decimal]] = {}
def cpu_ts_diff(device:str, thread=0) -> Decimal:
  """Clock offset for a device's compute (thread=0) or copy (thread=1) queue; 0 for unknown devices.

  The fallback must have both elements: the old single-element default (Decimal(0),)
  raised IndexError for thread=1 (is_copy) events on a device with no ProfileDeviceEvent.
  """
  return device_ts_diffs.get(device, (Decimal(0), Decimal(0)))[thread]
# device properties from ProfileDeviceEvent.props, keyed by device name (read e.g. for the AMD disassembly target)
device_props:dict[str, dict] = {}
# the event types that can appear on a single device's timeline
DevEvent = ProfileRangeEvent|ProfileGraphEntry|ProfilePointEvent
def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
  """Yield (start, end, event) in CPU time for every device-level event, flattening graph events.

  Range events get the per-device clock diff applied; a graph event yields one synthetic
  "<device> Graph" range spanning all its entries, followed by each entry.
  """
  for e in profile:
    # an open-ended range (en is None) is treated as zero-duration
    if isinstance(e, ProfileRangeEvent): yield (e.st+(diff:=cpu_ts_diff(e.device, e.is_copy)), (e.en if e.en is not None else e.st)+diff, e)
    elif isinstance(e, ProfilePointEvent): yield (e.ts, e.ts, e)
    elif isinstance(e, ProfileGraphEvent):
      # collect [start, end] CPU timestamps for every entry, interleaved
      cpu_ts = []
      for ent in e.ents: cpu_ts += [e.sigs[ent.st_id]+(diff:=cpu_ts_diff(ent.device, ent.is_copy)), e.sigs[ent.en_id]+diff]
      yield (st:=min(cpu_ts)), (et:=max(cpu_ts)), ProfileRangeEvent(f"{e.ents[0].device.split(':')[0]} Graph", f"batched {len(e.ents)}", st, et)
      for i,ent in enumerate(e.ents): yield (cpu_ts[i*2], cpu_ts[i*2+1], ent)
# normalize event timestamps and attach kernel metadata
def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None:
  """Encode one device's timeline track: a (tag 0, count) header then one fixed-size record per event.

  Record fields: name string id, optional ctx ref, optional exec key, start offset, duration, stats string id.
  Returns None when the device has no nonzero-duration events.
  """
  events:list[bytes] = []
  exec_points:dict[str, ProfilePointEvent] = {}
  for st,et,dur,e in dev_events:
    # remember the latest "exec" point per kernel name so ranges below can pull var_vals/metadata from it
    if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.arg["name"]] = e
    if dur == 0: continue
    name, fmt, key = e.name, [], None
    if (ref:=ref_map.get(name)) is not None:
      name = ctxs[ref]["name"]
      if isinstance(p:=trace.keys[ref].ret, ProgramSpec) and (ei:=exec_points.get(p.name)) is not None:
        # derive FLOPS and memory/LDS bandwidth from the program's estimates and this run's duration
        flops = sym_infer(p.estimates.ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6)
        membw, ldsbw = sym_infer(p.estimates.mem, var_vals)/t, sym_infer(p.estimates.lds, var_vals)/t
        fmt = [f"{flops*1e-9:.0f} GFLOPS" if flops < 1e14 else f"{flops*1e-12:.0f} TFLOPS",
               (f"{membw*1e-9:.0f} GB/s" if membw < 1e13 else f"{membw*1e-12:.0f} TB/s")+" mem",
               (f"{ldsbw*1e-9:.0f} GB/s" if ldsbw < 1e15 else f"{ldsbw*1e-12:.0f} TB/s")+" lds"]
        if (metadata_str:=",".join([str(m) for m in (ei.arg['metadata'] or ())])): fmt.append(metadata_str)
        if isinstance(e, ProfileGraphEntry): fmt.append("(batched)")
        key = ei.key
    elif isinstance(e.name, TracingKey):
      name = e.name.display_name
      # take the first ctx any of this key's sub-keys was traced in
      ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)
    events.append(struct.pack("<IIIIfI", enum_str(name, scache), option(ref), option(key), st-start_ts, dur, enum_str("\n".join(fmt), scache)))
  return struct.pack("<BI", 0, len(events))+b"".join(events) if events else None
def encode_mem_free(key:int, ts:int, execs:list[ProfilePointEvent], scache:dict) -> bytes:
  """Pack a buffer-free record: a (tag 0, timestamp, buffer key, count) header, then one entry per exec that used the buffer.

  Entry layout <[u32, u32, u32, u8]: run id, display name, buffer number and mode (2 = r/w, 1 = w, 0 = r).
  """
  entries:list[bytes] = []
  for e in execs:
    buf_num = next(i for i,b in enumerate(e.arg["bufs"]) if b == key)
    if buf_num in e.arg["outputs"]: mode = 2 if buf_num in e.arg["inputs"] else 1
    else: mode = 0
    entries.append(struct.pack("<IIIB", e.key, enum_str(e.arg["name"], scache), buf_num, mode))
  return struct.pack("<BIII", 0, ts, key, len(entries))+b"".join(entries)
def mem_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, end_ts:int, peaks:list[int], dtype_size:dict[str, int],
               scache:dict[str, int]) -> bytes|None:
  """Encode the memory (alloc/exec/free) track for one device.

  Appends the device's peak live bytes to peaks and records dtype itemsizes into dtype_size.
  Returns None when the device has no memory events.
  """
  peak, mem = 0, 0
  # live buffer sizes in bytes by buffer key (renamed from `temp`, which shadowed tinygrad.helpers.temp)
  live:dict[int, int] = {}
  events:list[bytes] = []
  buf_ei:dict[int, list[ProfilePointEvent]] = {}
  for st,_,_,e in dev_events:
    if not isinstance(e, ProfilePointEvent): continue
    if e.name == "alloc":
      # clamp absurd element counts so the u64 encoding stays sane
      safe_sz = min(1_000_000_000_000, e.arg["sz"])
      events.append(struct.pack("<BIIIQ", 1, int(e.ts)-start_ts, e.key, enum_str(e.arg["dtype"].name, scache), safe_sz))
      dtype_size.setdefault(e.arg["dtype"].name, e.arg["dtype"].itemsize)
      live[e.key] = nbytes = safe_sz*e.arg["dtype"].itemsize
      mem += nbytes
      if mem > peak: peak = mem
    if e.name == "exec" and e.arg["bufs"]:
      # track which execs touched each buffer so the free record can list them
      for b in e.arg["bufs"]: buf_ei.setdefault(b, []).append(e)
    if e.name == "free":
      events.append(encode_mem_free(e.key, int(e.ts) - start_ts, buf_ei.pop(e.key, []), scache))
      mem -= live.pop(e.key)
  # buffers never freed get a synthetic free at the end of the trace
  for t in live: events.append(encode_mem_free(t, end_ts-start_ts, buf_ei.pop(t, []), scache))
  peaks.append(peak)
  return struct.pack("<BIQ", 1, len(events), peak)+b"".join(events) if events else None
# by default, VIZ does not start when there is an error
# use this to instead display the traceback to the user
@contextmanager
def soft_err(fn:Callable):
  """Context manager that swallows any exception from its body, passing {"src": formatted traceback} to fn."""
  try:
    yield
  except Exception as e:
    fn({"src": "".join(traceback.format_exception(type(e), e, e.__traceback__))})
def row_tuple(row:str) -> tuple[int, ...]:
  """Numeric sort key for rows like "SE:0 SA:1 WGP:2": the int after each field's colon."""
  fields = row.split()
  return tuple(int(f.split(":")[1]) for f in fields)
# *** Performance counters
def unpack_pmc(e) -> dict:
  """Aggregate a PMC event's blob of u64 samples into a table: one row per scheduled counter.

  Each row is [name, sum over units, nested per-unit sample table].
  """
  samples, idx = memoryview(e.blob).cast('Q'), 0
  table:list[list] = []
  for s in e.sched:
    detail:dict = {"cols":["XCC", "INST", "SE", "SA", "WGP", "Value"], "rows":[]}
    total = 0
    # one sample per hardware unit, in (xcc, inst, se, sa, wgp) order
    for loc in itertools.product(range(s.xcc), range(s.inst), range(s.se), range(s.sa), range(s.wgp)):
      v = int(samples[idx])
      idx += 1
      total += v
      detail["rows"].append(loc+(v,))
    table.append([s.name, total, detail])
  return {"rows":table, "cols":["Name", "Sum"]}
# ** on startup, list all the performance counter traces
def load_counters(profile:list[ProfileEvent]) -> None:
  """Group AMD PMC/SQTT counter events by (kernel, exec tag) and append one ctx entry per execution.

  Also prepends an "All Counters" summary ctx whose data dict is filled in as kernels are processed.
  """
  from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
  counter_events:dict[tuple[str, int], dict] = {}
  durations:dict[str, list[float]] = {}
  prg_events:dict[str, ProfileProgramEvent] = {}
  dev_events:dict[str, ProfileDeviceEvent] = {}
  for e in profile:
    if isinstance(e, (ProfilePMCEvent, ProfileSQTTEvent)): counter_events.setdefault((e.kern, e.exec_tag), {}).setdefault(type(e), []).append(e)
    if isinstance(e, ProfileRangeEvent) and e.device.startswith("AMD") and e.en is not None:
      durations.setdefault(str(e.name), []).append(float(e.en-e.st))
    if isinstance(e, ProfileProgramEvent): prg_events[str(e.name)] = e
    if isinstance(e, ProfileDeviceEvent): dev_events[e.device] = e
  if len(counter_events) == 0: return None
  # all_counters is mutated below as each kernel's PMC events are found
  ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), (durations, all_counters:={}))]})
  run_number = {n:0 for n,_ in counter_events}
  for (k, tag),v in counter_events.items():
    # use the colored name if it exists
    name = trace.keys[r].ret.name if (r:=ref_map.get(k)) is not None else k
    run_number[k] += 1
    steps:list[dict] = []
    if (pmc:=v.get(ProfilePMCEvent)):
      steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc))
      all_counters[(name, run_number[k], k)] = pmc[0]
    if (sqtt:=v.get(ProfileSQTTEvent)):
      # to decode a SQTT trace, we need the raw stream, program binary and device properties
      steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), [*sqtt, prg_events[k], dev_events[sqtt[0].device]])))
      if getenv("SQTT_PARSE"):
        # run our decoder on startup, we don't use this since it only works on gfx11
        from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
        for e in sqtt: parse_sqtt_print_packets(e.blob)
    ctxs.append({"name":f"Exec {name} n{run_number[k]}", "steps":steps})
# ** SQTT OCC only unpacks wave start, end time and SIMD location
def unpack_sqtt(key:tuple[str, int], profile:list[ProfileEvent]) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]:
  """Decode SQTT events for one (kernel, exec tag) into per-CU wave timelines.

  Returns (cu_events: CU location -> wave range events, units: OCC wave locations seen,
  wave_insts: CU location -> wave label -> instruction-level detail for INST-traced waves).
  """
  # * init decoder
  from extra.sqtt.roc import decode
  rctx = decode(profile)
  disasm = rctx.disasms[key[0]]
  cu_events:dict[str, list[ProfileEvent]] = {}
  # * INST waves
  wave_insts:dict[str, dict[str, dict]] = {}
  inst_units:dict[str, itertools.count] = {}
  for w in rctx.inst_execs.get(key, []):
    # number repeated waves at the same wave location with a per-location counter
    if (u:=w.wave_loc) not in inst_units: inst_units[u] = itertools.count(0)
    n = next(inst_units[u])
    if (events:=cu_events.get(w.cu_loc)) is None: cu_events[w.cu_loc] = events = []
    events.append(ProfileRangeEvent(w.simd_loc, loc:=f"INST WAVE:{w.wave_id} N:{n}", Decimal(w.begin_time), Decimal(w.end_time)))
    wave_insts.setdefault(w.cu_loc, {})[f"{u} N:{n}"] = {"wave":w, "disasm":disasm, "run_number":n, "loc":loc}
  # * OCC waves
  units:dict[str, itertools.count] = {}
  wave_start:dict[str, int] = {}
  for occ in rctx.occ_events.get(key, []):
    if (u:=occ.wave_loc) not in units: units[u] = itertools.count(0)
    # skip waves already covered by the more detailed INST stream
    if u in inst_units: continue
    # OCC events arrive as start/end pairs; match them up by wave location
    if occ.start: wave_start[u] = occ.time
    else:
      if (events:=cu_events.get(occ.cu_loc)) is None: cu_events[occ.cu_loc] = events = []
      events.append(ProfileRangeEvent(occ.simd_loc, f"OCC WAVE:{occ.wave_id} N:{next(units[u])}", Decimal(wave_start.pop(u)), Decimal(occ.time)))
  return cu_events, list(units), wave_insts
def device_sort_fn(k:str) -> tuple[int, str, int]:
  """Sort key for timeline track names: known device families first, then name, then label length
  (so a device's own track sorts before its "<device> Memory" track)."""
  order = {"GC": 0, "USER": 1, "TINY": 2, "DISK": 999}
  dname = k.split()[0]
  # prefix-match against the known families; unknown devices rank after every known one.
  # generator variables renamed: the old `for k,v in order.items()` shadowed the parameter k
  dev_rank = next((rank for prefix,rank in order.items() if dname.startswith(prefix)), len(order))
  return (dev_rank, dname, len(k))
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn) -> bytes|None:
  """Encode the whole profile into the binary blob served at /get_profile.

  Layout: header (total duration, peak memory, index size, track count), a JSON index
  (string table, dtype sizes, markers), then one length-prefixed named track per device.
  Returns None when the profile has no events.
  """
  # start by getting the time diffs
  device_decoders:dict[str, Callable[[list[ProfileEvent]], None]] = {}
  for ev in profile:
    if isinstance(ev, ProfileDeviceEvent):
      device_ts_diffs[ev.device] = (ev.comp_tdiff,ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff)
      if ev.props is not None: device_props[ev.device] = ev.props
      if (d:=ev.device.split(":")[0]) == "AMD": device_decoders[d] = load_counters
  # load device specific counters
  for fxn in device_decoders.values(): fxn(profile)
  # map events per device
  dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {}
  markers:list[ProfilePointEvent] = []
  start_ts:int|None = None
  end_ts:int|None = None
  for ts,en,e in flatten_events(profile):
    dev_events.setdefault(e.device,[]).append((st:=int(ts), et:=int(en), float(en-ts), e))
    if start_ts is None or st < start_ts: start_ts = st
    if end_ts is None or et > end_ts: end_ts = et
    if isinstance(e, ProfilePointEvent) and e.name == "marker": markers.append(e)
  if start_ts is None: return None
  # return layout of per device events; each device gets a timeline track and a memory track
  layout:dict[str, bytes|None] = {}
  scache:dict[str, int] = {}
  peaks:list[int] = []
  dtype_size:dict[str, int] = {}
  for k,v in dev_events.items():
    v.sort(key=lambda e:e[0])
    layout[k] = timeline_layout(v, start_ts, scache)
    layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtype_size, scache)
  sorted_layout = sorted([k for k,v in layout.items() if v is not None], key=sort_fn)
  ret = [b"".join([struct.pack("<B", len(k)), k.encode(), unwrap(layout[k])]) for k in sorted_layout]
  index = json.dumps({"strings":list(scache), "dtypeSize":dtype_size, "markers":[{"ts":int(e.ts-start_ts), **e.arg} for e in markers]}).encode()
  return struct.pack("<IQII", unwrap(end_ts)-start_ts, max(peaks,default=0), len(index), len(ret))+index+b"".join(ret)
# ** Assembly static analyzers
def get_stdout(f: Callable) -> str:
  """Call f, capturing everything it writes to stdout/stderr; if it raises, the traceback text is captured instead."""
  out = io.StringIO()
  with redirect_stdout(out), redirect_stderr(out):
    try: f()
    except Exception: traceback.print_exc(file=out)
  return out.getvalue()
def amd_readelf(lib:bytes) -> list[dict]:
  """Extract kernel resource usage (SGPRs/VGPRs/LDS/...) from the msgpack metadata in an AMD ELF's .note section."""
  from tinygrad.runtime.support.elf import elf_loader
  import msgpack
  _, sections, __ = elf_loader(lib)
  data = next((s for s in sections if s.name.startswith(".note"))).content
  # ELF note header is (namesz, descsz, type); the msgpack descriptor starts after the 4-byte-aligned name
  namesz, descsz, typ = struct.unpack_from(hdr:="<III", data, 0)
  offset = (struct.calcsize(hdr)+namesz+3) & -4
  notes = msgpack.unpackb(data[offset:offset+descsz])
  keys = {".sgpr_count":"SGPRs", ".vgpr_count":"VGPRs", ".max_flat_workgroup_size":"Max WGP size",
          ".group_segment_fixed_size":"LDS size", ".private_segment_fixed_size":"Scratch size"}
  # only report nonzero values, for the first kernel in the image
  return [{"label":label, "value":v} for k,label in keys.items() if (v:=notes["amdhsa.kernels"][0][k]) > 0]
def llvm_disasm(target:int, lib:bytes) -> dict[int, tuple[str, int]]:
  """Disassemble the .text section of an AMD GPU ELF with LLVM's C disassembler API.

  target is the numeric gfx target version (e.g. 110000 -> "gfx1100", per the format string below).
  Returns {pc offset: (asm text, instruction size in bytes)}.
  """
  from tinygrad.runtime.autogen import llvm
  from tinygrad.runtime.support.elf import elf_loader
  llvm.LLVMInitializeAMDGPUTargetInfo()
  llvm.LLVMInitializeAMDGPUTargetMC()
  llvm.LLVMInitializeAMDGPUAsmParser()
  llvm.LLVMInitializeAMDGPUDisassembler()
  # pass NULL to callbacks
  cbs = [ctypes.cast(0, llvm.LLVMCreateDisasmCPUFeatures.argtypes[i]) for i in {5,6}]
  arch = "gfx%d%x%x" % (target // 10000, (target // 100) % 100, target % 100)
  ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, *cbs)
  image, sections, _ = elf_loader(lib)
  text = next((sh.header for sh in sections if sh.name == ".text"), None)
  assert text is not None, "no .text section found in ELF"
  off, sz = text.sh_addr, text.sh_size
  addr_table:dict[int, tuple[str, int]] = {}
  out = ctypes.create_string_buffer(128)
  cur_off = off
  while cur_off < sz + off:
    view = (ctypes.c_ubyte * ((sz + off) - cur_off)).from_buffer_copy(memoryview(image)[cur_off:])
    instr_sz = llvm.LLVMDisasmInstruction(ctx, view, ctypes.c_uint64(len(view)), ctypes.c_uint64(0), out, ctypes.c_size_t(128))
    addr_table[cur_off] = (out.value.decode("utf-8", "replace").strip(), instr_sz)
    # LLVMDisasmInstruction returns 0 when it can't decode an instruction; without this
    # guard cur_off never advances and the loop spins forever
    if instr_sz == 0: break
    cur_off += instr_sz
  return addr_table
SOPP_INSTS = {"s_branch", "s_cbranch_scc0", "s_cbranch_scc1", "s_cbranch_vccz", "s_cbranch_vccnz", "s_cbranch_execz", "s_cbranch_execnz"}
def parse_branch(asm:str) -> int|None:
inst, *operands = asm.split(" ")
if inst in SOPP_INSTS:
x = int(operands[0]) & 0xffff
return (x - 0x10000 if x & 0x8000 else x)*4
return None
# edge kinds for CFG paths: conditional branch taken, conditional fall-through, unconditional jump
COND_TAKEN, COND_NOT_TAKEN, UNCOND = range(3)
# edge colors for each kind in the CFG view
cfg_colors = {COND_TAKEN: "#3f7564", COND_NOT_TAKEN: "#7a4540", UNCOND: "#3b5f7e"}
def amdgpu_cfg(lib:bytes, target:int) -> dict:
  """Build a basic-block control flow graph from an AMD GPU binary's disassembly."""
  # disassemble
  pc_table = llvm_disasm(target, lib)
  # get leaders: the entry pc plus, for every branch, its target and its fall-through pc
  leaders:set[int] = {next(iter(pc_table))}
  for pc, (asm, sz) in pc_table.items():
    if (offset:=parse_branch(asm)) is not None: leaders.update((pc+sz+offset, pc+sz))
  # build the cfg: blocks maps each leader to its member pcs, paths maps leaders to {successor: edge kind}
  curr:int|None = None
  blocks:dict[int, list[int]] = {}
  paths:dict[int, dict[int, int]] = {}
  for pc, (asm, sz) in pc_table.items():
    if pc in leaders:
      paths[curr:=pc] = {}
      blocks[pc] = []
    else: assert curr is not None, f"no basic block found for {pc}"
    blocks[curr].append(pc)
    # control flow ends in endpgm
    if asm == "s_endpgm": break
    # otherwise a basic block can have exactly one or two paths
    nx = pc+sz
    if (offset:=parse_branch(asm)) is not None:
      if asm.startswith("s_branch"): paths[curr][nx+offset] = UNCOND
      else: paths[curr].update([(nx+offset, COND_TAKEN), (nx, COND_NOT_TAKEN)])
    elif nx in leaders: paths[curr][nx] = UNCOND
  return {"blocks":blocks, "paths":paths, "pc_table":pc_table, "colors":cfg_colors}
# ** Main render function to get the complete details about a trace event
def get_render(i:int, j:int, fmt:str) -> dict:
  """Render the full payload for step j of ctx i; fmt is the URL path the step was registered with."""
  data = ctxs[i]["steps"][j]["data"]
  if fmt == "graph-rewrites": return {"value":get_full_rewrite(trace.rewrites[i][j]), "content_type":"text/event-stream"}
  if fmt == "uops": return {"src":get_stdout(lambda: print_uops(data.uops or [])), "lang":"txt"}
  if fmt == "code": return {"src":data.src, "lang":"cpp"}
  if fmt == "asm":
    ret:dict = {"metadata":[]}
    # AMD binaries get a CFG plus a readelf summary; other devices fall back to the compiler's disassembler
    if data.device.startswith("AMD") and data.lib is not None:
      with soft_err(lambda err: ret.update(err)):
        ret["data"] = amdgpu_cfg(lib:=data.lib, device_props[data.device]["gfx_target_version"])
      with soft_err(lambda err: ret["metadata"].append(err)): ret["metadata"].append(amd_readelf(lib))
    else: ret["src"] = get_stdout(lambda: (compiler:=Device[data.device].compiler).disassemble(compiler.compile(data.src)))
    return ret
  if fmt == "all-pmc":
    # data is (durations by kernel name, {(name, run number, kernel key): PMC events})
    durations, pmc = data
    ret = {"cols":{}, "rows":[]}
    for (name, n, k),events in data[1].items():
      pmc_table = unpack_pmc(events)
      # cols is used as an ordered set of counter names across all kernels
      ret["cols"].update([(r[0], None) for r in pmc_table["rows"]])
      ret["rows"].append((name, durations[k][n-1], *[r[1] for r in pmc_table["rows"]]))
    ret["cols"] = ["Kernel", "Duration", *ret["cols"]]
    return ret
  if fmt == "prg-pmc": return unpack_pmc(data[0])
  if fmt == "prg-sqtt":
    ret = {}
    # decode SQTT lazily on first request: the per-CU/per-wave child steps are appended after this step
    if len((steps:=ctxs[i]["steps"])[j+1:]) == 0:
      with soft_err(lambda err: ret.update(err)):
        cu_events, units, wave_insts = unpack_sqtt(*data)
        for cu in sorted(cu_events, key=row_tuple):
          steps.append(create_step(f"{cu} {len(cu_events[cu])}", ("/cu-sqtt", i, len(steps)), depth=1,
                                   data=[ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+cu_events[cu]))
          # NOTE: the walrus below rebinds the local `data` to this wave's detail dict
          for k in sorted(wave_insts.get(cu, []), key=row_tuple):
            steps.append(create_step(k.replace(cu, ""), ("/sqtt-insts", i, len(steps)), loc=(data:=wave_insts[cu][k])["loc"], depth=2, data=data))
    return {**ret, "steps":[{k:v for k,v in s.items() if k != "data"} for s in steps[j+1:]]}
  if fmt == "cu-sqtt": return {"value":get_profile(data, sort_fn=row_tuple), "content_type":"application/octet-stream"}
  if fmt == "sqtt-insts":
    columns = ["PC", "Instruction", "Hits", "Cycles", "Stall", "Type"]
    inst_columns = ["N", "Clk", "Idle", "Dur", "Stall"]
    # Idle: The total time gap between the completion of previous instruction and the beginning of the current instruction.
    # The idle time can be caused by:
    # * Arbiter loss
    # * Source or destination register dependency
    # * Instruction cache miss
    # Stall: The total number of cycles the hardware pipe couldn't issue an instruction.
    # Duration: Total latency in cycles, defined as "Stall time + Issue time" for gfx9 or "Stall time + Execute time" for gfx10+.
    prev_instr = (w:=data["wave"]).begin_time
    pc_to_inst = data["disasm"]
    start_pc = None
    rows:dict[int, dict] = {}
    for e in w.unpack_insts():
      # PCs are displayed relative to the first executed instruction
      if start_pc is None: start_pc = e.pc
      if (inst:=rows.get(e.pc)) is None:
        rows[e.pc] = inst = {"pc":e.pc-start_pc, "inst":pc_to_inst[e.pc][0], "hit_count":0, "dur":0, "stall":0, "type":str(e.typ).split("_")[-1],
                             "hits":{"cols":inst_columns, "rows":[]}}
      inst["hit_count"] += 1
      inst["dur"] += e.dur
      inst["stall"] += e.stall
      inst["hits"]["rows"].append((inst["hit_count"]-1, e.time, max(0, e.time-prev_instr), e.dur, e.stall))
      prev_instr = max(prev_instr, e.time + e.dur)
    summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu},
               {"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}]
    return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary]}
  # unknown fmt: serve the step's raw data as-is
  return data
# ** HTTP server
def get_int(query:dict[str, list[str]], k:str) -> int:
  """Pull an int query parameter from a parse_qs dict, defaulting to 0 when absent."""
  vals = query.get(k, ["0"])
  return int(vals[0])
class Handler(HTTPRequestHandler):
  """Serves the VIZ frontend (index.html + static assets) and the JSON/binary data endpoints."""
  def do_GET(self):
    ret, status_code, content_type = b"", 200, "text/html"
    if (url:=urlparse(self.path)).path == "/":
      with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read()
    elif self.path.startswith(("/assets/", "/js/")) and '/..' not in self.path:
      # static files, with a basic path traversal check
      try:
        with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read()
        if url.path.endswith(".js"): content_type = "application/javascript"
        if url.path.endswith(".css"): content_type = "text/css"
      except FileNotFoundError: status_code = 404
    elif (query:=parse_qs(url.query)):
      # any request with ?ctx=&step= is dispatched to get_render with the path as fmt
      render_src = get_render(get_int(query, "ctx"), get_int(query, "step"), url.path.lstrip("/"))
      if "content_type" in render_src: ret, content_type = render_src["value"], render_src["content_type"]
      else: ret, content_type = json.dumps(render_src).encode(), "application/json"
      # event streams are sent incrementally; everything else goes out in one response
      if content_type == "text/event-stream": return self.stream_json(render_src["value"])
    elif url.path == "/ctxs":
      # step "data" stays server-side; clients fetch it through each step's query URL
      lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in ctxs]
      ret, content_type = json.dumps(lst).encode(), "application/json"
    elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
    else: status_code = 404
    return self.send_data(ret, content_type, status_code)
# ** main loop
def reloader():
  """Poll this file's mtime every 100ms and re-exec the whole process when it changes (dev auto-reload)."""
  baseline = os.stat(__file__).st_mtime
  while not stop_reloader.is_set():
    if os.stat(__file__).st_mtime != baseline:
      print("reloading server...")
      os.execv(sys.executable, [sys.executable] + sys.argv)
    time.sleep(0.1)
T = TypeVar("T")
def load_pickle(path:pathlib.Path, default:T) -> T:
  """Unpickle path's contents, returning default when the file doesn't exist."""
  if path.exists():
    with path.open("rb") as f:
      return pickle.load(f)
  return default
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument('--kernels', type=pathlib.Path, help='Path to kernels', default=pathlib.Path(temp("rewrites.pkl", append_user=True)))
  parser.add_argument('--profile', type=pathlib.Path, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True)))
  args = parser.parse_args()
  # bail out early if something is already listening on the port
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    if s.connect_ex(((HOST:="http://127.0.0.1").replace("http://", ""), PORT:=getenv("PORT", 8000))) == 0:
      raise RuntimeError(f"{HOST}:{PORT} is occupied! use PORT= to change.")
  stop_reloader = threading.Event()
  multiprocessing.current_process().name = "VizProcess" # disallow opening of devices
  st = time.perf_counter()
  print("*** viz is starting")
  # load the pickled rewrite trace and profile (paths default to tinygrad temp files); these
  # globals (ctxs, trace, profile_ret) are read by Handler and the render functions above
  ctxs:list[dict] = get_rewrites(trace:=load_pickle(args.kernels, default=RewriteTrace([], [], {})))
  profile_ret = get_profile(load_pickle(args.profile, default=[]))
  server = TCPServerWithReuse(('', PORT), Handler)
  reloader_thread = threading.Thread(target=reloader)
  reloader_thread.start()
  print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"), flush=True)
  if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}")
  try: server.serve_forever()
  except KeyboardInterrupt:
    print("*** viz is shutting down...")
    stop_reloader.set()