viz: amdgpu assembly basic block graph (#13755)

This commit is contained in:
qazal
2025-12-23 00:17:16 +09:00
committed by GitHub
parent df0f9d6860
commit 389f01c7f4
6 changed files with 281 additions and 42 deletions

View File

@@ -5,28 +5,7 @@ from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEven
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
from tinygrad.runtime.autogen import llvm, rocprof
from tinygrad.runtime.support.elf import elf_loader
def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
  """Disassemble the `.text` section of an AMDGPU ELF with the LLVM disassembler.

  `arch` is passed to LLVM as the CPU name for the `amdgcn-amd-amdhsa` triple and `lib` is the raw ELF.
  Returns a dict mapping each instruction's offset in the loaded image (starting at the section's `sh_addr`)
  to a `(disassembly text, size in bytes)` tuple, in address order.

  Raises RuntimeError if LLVM cannot decode an instruction.
  """
  llvm.LLVMInitializeAMDGPUTargetInfo()
  llvm.LLVMInitializeAMDGPUTargetMC()
  llvm.LLVMInitializeAMDGPUAsmParser()
  llvm.LLVMInitializeAMDGPUDisassembler()
  # pass NULL to callbacks (a tuple keeps the two argtype indices in a guaranteed order)
  cbs = [ctypes.cast(0, llvm.LLVMCreateDisasmCPUFeatures.argtypes[i]) for i in (5, 6)]
  ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, *cbs)
  image, sections, _ = elf_loader(lib)
  text = next((sh.header for sh in sections if sh.name == ".text"), None)
  off, sz = unwrap(text).sh_addr, unwrap(text).sh_size
  addr_table:dict[int, tuple[str, int]] = {}
  out = ctypes.create_string_buffer(128)
  cur_off = off
  while cur_off < sz + off:
    # hand LLVM a copy of the not-yet-decoded tail of the section
    view = (ctypes.c_ubyte * ((sz + off) - cur_off)).from_buffer_copy(memoryview(image)[cur_off:])
    instr_sz = llvm.LLVMDisasmInstruction(ctx, view, ctypes.c_uint64(len(view)), ctypes.c_uint64(0), out, ctypes.c_size_t(128))
    # LLVM reports 0 bytes when there is no valid instruction; without this check cur_off never advances
    if instr_sz == 0: raise RuntimeError(f"failed to disassemble instruction at offset {cur_off:#x}")
    addr_table[cur_off] = (out.value.decode("utf-8", "replace").strip(), instr_sz)
    cur_off += instr_sz
  return addr_table
from tinygrad.viz.serve import llvm_disasm
@dataclasses.dataclass(frozen=True)
class InstExec:

View File

@@ -0,0 +1,172 @@
import unittest
import textwrap
from tinygrad import Device, Tensor
from tinygrad.uop.ops import UOp, Ops, track_rewrites
from tinygrad.renderer import ProgramSpec
from tinygrad.helpers import TracingKey
from tinygrad.engine.realize import ExecItem, CompiledRunner
# TODO: use the RDNA3 renderer when it's in master
template = """.text
.globl fn_name
.p2align 8
.type fn_name,@function
fn_name:
INSTRUCTION
.rodata
.p2align 6
.amdhsa_kernel fn_name
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_next_free_vgpr .amdgcn.next_free_vgpr
.amdhsa_next_free_sgpr .amdgcn.next_free_sgpr
.amdhsa_wavefront_size32 1
.end_amdhsa_kernel
.amdgpu_metadata
---
amdhsa.version:
- 1
- 0
amdhsa.kernels:
- .name: fn_name
.symbol: fn_name.kd
.group_segment_fixed_size: 0
.private_segment_fixed_size: 0
.wavefront_size: 32
.sgpr_count: 8
.vgpr_count: 8
.max_flat_workgroup_size: 1024
.kernarg_segment_align: 8
.kernarg_segment_size: 8
.args:
- .address_space: global
.name: a
.offset: 0
.size: 8
.type_name: 'float*'
.value_kind: global_buffer
...
.end_amdgpu_metadata
"""
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret))
def run_asm(name:str, src:str) -> ProgramSpec:
  """Fill the kernel template with `src`, run it once on the default device, and return its ProgramSpec.

  `name` replaces every `fn_name` in the template and the dedented `src` replaces `INSTRUCTION`.
  The kernel is launched with a single one-element buffer as its only argument.
  """
  asm_src = template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src))
  prg = ProgramSpec(name, asm_src, Device.DEFAULT, UOp(Ops.SINK))
  sink = UOp(Ops.SINK)
  bufs = [Tensor.empty(1).uop.buffer.ensure_allocated()]
  runner = CompiledRunner(prg)
  ExecItem(sink, bufs, prg=runner).run()
  return prg
@unittest.skipUnless(Device.DEFAULT == "AMD", "only on AMD")
class TestCfg(unittest.TestCase):
  # Each test assembles and runs a small hand-written kernel with a particular branch shape via run_asm.

  def setUp(self):
    # the kernels below are only run on gfx11 / gfx12 architectures
    arch = Device["AMD"].arch
    if not arch.startswith(("gfx11", "gfx12")):
      self.skipTest(f"tests written for RDNA, got arch {arch}")

  def test_simple(self):
    # one unconditional branch into the exit block
    src = """
entry:
s_branch bb1
bb1:
s_endpgm
"""
    run_asm("simple", src)

  def test_diamond(self):
    # conditional branch to `if`, otherwise `else`; both reach `end`
    src = """
entry:
s_cmp_eq_i32 s0, 0
s_cbranch_scc1 if
s_branch else
if:
s_nop 1
s_branch end
else:
s_nop 0
end:
s_endpgm
"""
    run_asm("diamond", src)

  def test_loop(self):
    # counter loop with a single conditional back edge
    src = """
entry:
s_mov_b32 s1, 4
loop:
s_add_u32 s1, s1, -1
s_cmp_eq_i32 s1, 0
s_cbranch_scc0 loop
s_endpgm
"""
    run_asm("simple_loop", src)

  def test_loop_branch(self):
    # loop whose body contains a conditional block (`cond`) before the back edge
    src = """
entry:
s_mov_b32 s1, 4
loop:
s_add_u32 s1, s1, -1
s_cmp_eq_i32 s1, 2
s_cbranch_scc1 cond
s_branch cont
cond:
s_add_u32 s1, s1, -2
cont:
s_cmp_eq_i32 s1, 0
s_cbranch_scc0 loop
s_endpgm
"""
    run_asm("loop_if", src)

  def test_loop_break(self):
    # loop with a conditional branch out to `break` ahead of the back edge
    src = """
entry:
s_mov_b32 s1, 8
loop:
s_add_u32 s1, s1, -1
s_cmp_eq_i32 s1, 5
s_cbranch_scc1 break
s_cmp_eq_i32 s1, 0
s_cbranch_scc0 loop
break:
s_endpgm
"""
    run_asm("loop_break", src)

  def test_switch(self):
    # chain of compares dispatching to three cases that all branch to `join`
    src = """
entry:
s_cmp_eq_i32 s0, 0
s_cbranch_scc1 case0
s_cmp_eq_i32 s0, 1
s_cbranch_scc1 case1
s_branch case2
case0:
s_nop 0
s_branch join
case1:
s_nop 1
s_branch join
case2:
s_nop 2
s_branch join
join:
s_endpgm
"""
    run_asm("switch_case", src)

  def test_ping_pong(self):
    # `ping` and `pong` each hold a conditional branch to the other
    src = """
entry:
s_cmp_eq_i32 s0, 0
s_cbranch_scc1 ping
s_branch pong
ping:
s_cmp_eq_i32 s1, 0
s_cbranch_scc1 pong
s_branch end
pong:
s_cmp_eq_i32 s2, 0
s_cbranch_scc1 ping
end:
s_endpgm
"""
    run_asm("ping_pong", src)
if __name__ == "__main__":
unittest.main()

View File

@@ -124,7 +124,6 @@
fill: rgba(26, 27, 38, 0.5);
}
.edgePath {
stroke: #4a4b57;
fill: none;
stroke-width: 1.4px;
}

View File

@@ -78,6 +78,7 @@ const drawGraph = (data) => {
const labels = nodes.selectAll("g.label").data(d => [d]).join("g").attr("class", "label");
labels.attr("transform", d => `translate(-${d.labelWidth/2}, -${d.labelHeight/2+STROKE_WIDTH*2})`);
labels.selectAll("text").data(d => {
if (Array.isArray(d.label)) return [d.label];
const ret = [[]];
for (const s of parseColors(d.label, defaultColor="initial")) {
const color = darkenHex(s.color, 25);
@@ -87,7 +88,7 @@ const drawGraph = (data) => {
}
return [ret];
}).join("text").selectAll("tspan").data(d => d).join("tspan").attr("x", "0").attr("dy", 14).selectAll("tspan").data(d => d).join("tspan")
.attr("fill", d => d.color).text(d => d.st).attr("xml:space", "preserve");
.attr("fill", d => d.color).text(d => d.st).attr("xml:space", "preserve").style("font-family", g.graph().font);
addTags(nodes.selectAll("g.tag").data(d => d.tag != null ? [d] : []).join("g").attr("class", "tag")
.attr("transform", d => `translate(${-d.width/2+8}, ${-d.height/2+8})`).datum(e => e.tag));
// draw edges
@@ -98,7 +99,7 @@ const drawGraph = (data) => {
points.unshift(intersectRect(g.node(e.v), points[0]));
points.push(intersectRect(g.node(e.w), points[points.length-1]));
return line(points);
}).attr("marker-end", "url(#arrowhead)");
}).attr("marker-end", "url(#arrowhead)").attr("stroke", e => g.edge(e).color || "#4a4b57");
}
// ** UOp graph
@@ -630,9 +631,6 @@ hljs.registerLanguage("cpp", (hljs) => ({
...hljs.getLanguage('cpp'),
contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains]
}));
hljs.registerLanguage("amdgpu", (hljs) => ({
contains: [hljs.COMMENT("//", "$"), { begin:/\b(?:s_|v_|global_|buffer_|scratch_|flat_|ds_)[a-z0-9_]*\b/, className:"code" }]
}));
async function fetchValue(path) {
const res = await fetch(path);
@@ -805,6 +803,7 @@ async function main() {
return table;
}
if (ret.cols != null) renderTable(root, ret);
else if (ret.data != null) renderDag(ret, { recenter:true });
else if (ret.src != null) root.append(() => codeBlock(ret.src, ret.lang));
ret.metadata?.forEach(m => {
if (Array.isArray(m)) return metadata.appendChild(tabulate(m.map(({ label, value, idx }) => {
@@ -836,7 +835,7 @@ async function main() {
if (ret.length === 0) return;
// ** center graph
const data = ret[currentRewrite];
const render = (opts) => renderDag({ graph:data.graph, change:data.change, opts }, { recenter:currentRewrite === 0 });
const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 });
render({ showIndexing:toggle.checked });
toggle.onchange = (e) => render({ showIndexing:e.target.checked });
// ** right sidebar metadata

View File

@@ -1,22 +1,52 @@
// space (in px) kept between a node's label and each side of its rectangle
const NODE_PADDING = 10;
// node dimensions for a label of size lw x lh: the padded outer box plus the raw label size
const rectDims = (lw, lh) => {
  const pad = NODE_PADDING*2;
  return { width:lw+pad, height:lh+pad, labelWidth:lw, labelHeight:lh };
};
const LINE_HEIGHT = 14;
const canvas = new OffscreenCanvas(0, 0);
const ctx = canvas.getContext("2d");
ctx.font = `350 ${LINE_HEIGHT}px sans-serif`;
onmessage = (e) => {
const { graph, change, opts } = e.data;
const g = new dagre.graphlib.Graph({ compound: true });
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
const { data, opts } = e.data;
const g = new dagre.graphlib.Graph({ compound: true }).setDefaultEdgeLabel(function() { return {}; });
(data.blocks != null ? layoutCfg : layoutUOp)(g, data, opts);
postMessage(dagre.graphlib.json.write(g));
self.close();
}
// Lay out a control flow graph: one node per basic block, one edge per path between blocks.
// blocks maps a block's leader to its member addresses, pc_table maps an address to [asm text, ...],
// paths maps a leader to {target: kind} and colors maps an edge kind to its stroke color.
const layoutCfg = (g, { blocks, paths, pc_table, colors }) => {
  g.setGraph({ rankdir:"TD", font:"monospace" });
  ctx.font = `350 ${LINE_HEIGHT}px ${g.graph().font}`;
  // basic blocks render the assembly in nodes, one text line per instruction
  for (const [lead, members] of Object.entries(blocks)) {
    const label = [];
    let widest = 0;
    members.forEach((m) => {
      const asm = pc_table[m][0];
      widest = Math.max(widest, ctx.measureText(asm).width);
      // the mnemonic and its operands get separate colors
      const [inst, ...operands] = asm.split(" ");
      label.push([{st:inst+" ", color:"#7aa2f7"}, {st:operands.join(" "), color:"#9aa5ce"}]);
    });
    const tall = members.length*LINE_HEIGHT;
    g.setNode(lead, { ...rectDims(widest, tall), label, id:lead, color:"#1a1b26" });
  }
  // paths become edges between basic blocks
  for (const [lead, targets] of Object.entries(paths)) {
    Object.entries(targets).forEach(([id, kind]) => {
      g.setEdge(lead, id, {label:{type:"port", text:""}, color:colors[kind]});
    });
  }
  dagre.layout(g);
}
const layoutUOp = (g, { graph, change }, opts) => {
g.setGraph({ rankdir: "LR", font:"sans-serif" });
ctx.font = `350 ${LINE_HEIGHT}px ${g.graph().font}`;
if (change?.length) g.setNode("overlay", {label:"", labelWidth:0, labelHeight:0, className:"overlay"});
for (const [k, {label, src, ref, ...rest }] of Object.entries(graph)) {
for (const [k, {label, src, ref, color, tag }] of Object.entries(graph)) {
// adjust node dims by label size (excluding escape codes) + add padding
let [width, height] = [0, 0];
for (line of label.replace(/\u001B\[(?:K|.*?m)/g, "").split("\n")) {
width = Math.max(width, ctx.measureText(line).width);
height += LINE_HEIGHT;
}
g.setNode(k, {width:width+NODE_PADDING*2, height:height+NODE_PADDING*2, label, labelHeight:height, labelWidth:width, ref, id:k, ...rest});
g.setNode(k, {...rectDims(width, height), label, ref, id:k, color, tag});
// add edges
const edgeCounts = {};
for (const [_, s] of src) edgeCounts[s] = (edgeCounts[s] || 0)+1;
@@ -33,6 +63,4 @@ onmessage = (e) => {
dagre.layout(g);
// remove overlay node if it's empty
if (!g.node("overlay")?.width) g.removeNode("overlay");
postMessage(dagre.graphlib.json.write(g));
self.close();
}

View File

@@ -378,6 +378,67 @@ def amd_readelf(lib:bytes) -> list[dict]:
".group_segment_fixed_size":"LDS size", ".private_segment_fixed_size":"Scratch size"}
return [{"label":label, "value":v} for k,label in keys.items() if (v:=notes["amdhsa.kernels"][0][k]) > 0]
def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
  """Disassemble the `.text` section of an AMDGPU ELF with the LLVM disassembler.

  `arch` is passed to LLVM as the CPU name for the `amdgcn-amd-amdhsa` triple and `lib` is the raw ELF.
  Returns a dict mapping each instruction's offset in the loaded image (starting at the section's `sh_addr`)
  to a `(disassembly text, size in bytes)` tuple, in address order.

  Raises RuntimeError if LLVM cannot decode an instruction.
  """
  from tinygrad.runtime.autogen import llvm
  from tinygrad.runtime.support.elf import elf_loader
  llvm.LLVMInitializeAMDGPUTargetInfo()
  llvm.LLVMInitializeAMDGPUTargetMC()
  llvm.LLVMInitializeAMDGPUAsmParser()
  llvm.LLVMInitializeAMDGPUDisassembler()
  # pass NULL to callbacks (a tuple keeps the two argtype indices in a guaranteed order)
  cbs = [ctypes.cast(0, llvm.LLVMCreateDisasmCPUFeatures.argtypes[i]) for i in (5, 6)]
  ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, *cbs)
  image, sections, _ = elf_loader(lib)
  text = next((sh.header for sh in sections if sh.name == ".text"), None)
  assert text is not None, "no .text section found in ELF"
  off, sz = text.sh_addr, text.sh_size
  addr_table:dict[int, tuple[str, int]] = {}
  out = ctypes.create_string_buffer(128)
  cur_off = off
  while cur_off < sz + off:
    # hand LLVM a copy of the not-yet-decoded tail of the section
    view = (ctypes.c_ubyte * ((sz + off) - cur_off)).from_buffer_copy(memoryview(image)[cur_off:])
    instr_sz = llvm.LLVMDisasmInstruction(ctx, view, ctypes.c_uint64(len(view)), ctypes.c_uint64(0), out, ctypes.c_size_t(128))
    # LLVM reports 0 bytes when there is no valid instruction; without this check cur_off never advances
    if instr_sz == 0: raise RuntimeError(f"failed to disassemble instruction at offset {cur_off:#x}")
    addr_table[cur_off] = (out.value.decode("utf-8", "replace").strip(), instr_sz)
    cur_off += instr_sz
  return addr_table
SOPP_INSTS = {"s_branch", "s_cbranch_scc0", "s_cbranch_scc1", "s_cbranch_vccz", "s_cbranch_vccnz", "s_cbranch_execz", "s_cbranch_execnz"}
def parse_branch(asm:str) -> int|None:
inst, *operands = asm.split(" ")
if inst in SOPP_INSTS:
x = int(operands[0]) & 0xffff
return (x - 0x10000 if x & 0x8000 else x)*4
return None
# edge kinds for paths between basic blocks, and the color each kind maps to
COND_TAKEN, COND_NOT_TAKEN, UNCOND = range(3)
cfg_colors = {COND_TAKEN: "#3f7564", COND_NOT_TAKEN: "#7a4540", UNCOND: "#3b5f7e"}
def amdgpu_cfg(lib:bytes, arch:str) -> dict:
  """Build the basic block control flow graph of an AMDGPU binary.

  Returns a dict with:
    "blocks":   leader address -> addresses of the instructions in that block
    "paths":    leader address -> {successor address: edge kind}
    "pc_table": the disassembly, address -> (asm text, size)
    "colors":   edge kind -> color
  """
  # disassemble
  pc_table = llvm_disasm(arch, lib)
  # the first instruction is a leader, and so are both possible successors of every branch
  leaders:set[int] = {next(iter(pc_table))}
  branch_offsets:dict[int, int|None] = {addr: parse_branch(text) for addr, (text, _) in pc_table.items()}
  for addr, (_, size) in pc_table.items():
    rel = branch_offsets[addr]
    if rel is None: continue
    fallthrough = addr + size
    leaders.add(fallthrough + rel)
    leaders.add(fallthrough)
  # walk the instructions in order, opening a new block at each leader
  blocks:dict[int, list[int]] = {}
  paths:dict[int, dict[int, int]] = {}
  curr:int|None = None
  for addr, (text, size) in pc_table.items():
    if addr in leaders:
      curr = addr
      paths[curr] = {}
      blocks[curr] = []
    else: assert curr is not None, f"no basic block found for {addr}"
    blocks[curr].append(addr)
    # control flow ends in endpgm
    # NOTE(review): instructions after the first s_endpgm never get a block, even if a path targets them - confirm intended
    if text == "s_endpgm": break
    # otherwise a basic block can have exactly one or two paths
    fallthrough = addr + size
    rel = branch_offsets[addr]
    if rel is None:
      # a non-branch only gets an edge when it runs into the start of the next block
      if fallthrough in leaders: paths[curr][fallthrough] = UNCOND
    elif text.startswith("s_branch"):
      paths[curr][fallthrough + rel] = UNCOND
    else:
      # taken edge first: if both successors are the same address the not-taken kind wins
      paths[curr][fallthrough + rel] = COND_TAKEN
      paths[curr][fallthrough] = COND_NOT_TAKEN
  return {"blocks":blocks, "paths":paths, "pc_table":pc_table, "colors":cfg_colors}
# ** Main render function to get the complete details about a trace event
def get_render(i:int, j:int, fmt:str) -> dict:
@@ -392,14 +453,15 @@ def get_render(i:int, j:int, fmt:str) -> dict:
if isinstance(compiler, LLVMCompiler):
return get_llvm_mca(disasm_str, ctypes.string_at(llvm.LLVMGetTargetMachineTriple(tm:=compiler.target_machine)).decode(),
ctypes.string_at(llvm.LLVMGetTargetMachineCPU(tm)).decode())
metadata:list = []
ret:dict = {"src":disasm_str}
if data.device.startswith("AMD"):
with soft_err(lambda err: metadata.append(err)):
metadata.append(amd_readelf(compiler.compile(data.src)))
return {"src":disasm_str, "lang":"amdgpu" if data.device.startswith("AMD") else None, "metadata":metadata}
with soft_err(lambda err: ret.update(err)):
metadata = amd_readelf(lib:=compiler.compile(data.src))
ret = {"data":amdgpu_cfg(lib, getattr(compiler, "arch")), "metadata":[metadata]}
return ret
if fmt == "all-pmc":
durations, pmc = data
ret:dict = {"cols":{}, "rows":[]}
ret = {"cols":{}, "rows":[]}
for (name, n, k),events in data[1].items():
pmc_table = unpack_pmc(events)
ret["cols"].update([(r[0], None) for r in pmc_table["rows"]])