mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 14:43:57 -05:00
viz: amdgpu assembly basic block graph (#13755)
This commit is contained in:
@@ -5,28 +5,7 @@ from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEven
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
|
||||
from tinygrad.runtime.autogen import llvm, rocprof
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
|
||||
def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
  """Disassemble the `.text` section of an AMDGPU ELF image with LLVM.

  Args:
    arch: CPU name handed to LLVM together with the `amdgcn-amd-amdhsa` triple.
    lib: raw bytes of the ELF object.

  Returns:
    A mapping from instruction offset (counted from the section's `sh_addr`) to `(assembly text, size in bytes)`.

  Raises:
    RuntimeError: if LLVM reports a zero-sized instruction, i.e. the bytes at that offset could not be decoded.
  """
  llvm.LLVMInitializeAMDGPUTargetInfo()
  llvm.LLVMInitializeAMDGPUTargetMC()
  llvm.LLVMInitializeAMDGPUAsmParser()
  llvm.LLVMInitializeAMDGPUDisassembler()
  # pass NULL to callbacks (a tuple keeps the positional order explicit, set iteration order is not guaranteed)
  cbs = [ctypes.cast(0, llvm.LLVMCreateDisasmCPUFeatures.argtypes[i]) for i in (5, 6)]
  ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, *cbs)
  image, sections, _ = elf_loader(lib)
  text = next((sh.header for sh in sections if sh.name == ".text"), None)
  off, sz = unwrap(text).sh_addr, unwrap(text).sh_size
  end = off + sz

  addr_table:dict[int, tuple[str, int]] = {}
  out = ctypes.create_string_buffer(128)
  cur_off = off
  while cur_off < end:
    view = (ctypes.c_ubyte * (end - cur_off)).from_buffer_copy(memoryview(image)[cur_off:])
    instr_sz = llvm.LLVMDisasmInstruction(ctx, view, ctypes.c_uint64(len(view)), ctypes.c_uint64(0), out, ctypes.c_size_t(128))
    # a size of zero means no instruction was decoded; without this check the offset never advances and the loop spins forever
    if instr_sz == 0: raise RuntimeError(f"failed to disassemble instruction at offset {cur_off:#x}")
    addr_table[cur_off] = (out.value.decode("utf-8", "replace").strip(), instr_sz)
    cur_off += instr_sz
  return addr_table
|
||||
from tinygrad.viz.serve import llvm_disasm
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class InstExec:
|
||||
|
||||
172
test/testextra/test_cfg_viz.py
Normal file
172
test/testextra/test_cfg_viz.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import unittest
|
||||
import textwrap
|
||||
|
||||
from tinygrad import Device, Tensor
|
||||
from tinygrad.uop.ops import UOp, Ops, track_rewrites
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.helpers import TracingKey
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
|
||||
# TODO: use the RDNA3 renderer when it's in master
|
||||
template = """.text
|
||||
.globl fn_name
|
||||
.p2align 8
|
||||
.type fn_name,@function
|
||||
fn_name:
|
||||
INSTRUCTION
|
||||
|
||||
.rodata
|
||||
.p2align 6
|
||||
.amdhsa_kernel fn_name
|
||||
.amdhsa_user_sgpr_kernarg_segment_ptr 1
|
||||
.amdhsa_next_free_vgpr .amdgcn.next_free_vgpr
|
||||
.amdhsa_next_free_sgpr .amdgcn.next_free_sgpr
|
||||
.amdhsa_wavefront_size32 1
|
||||
.end_amdhsa_kernel
|
||||
|
||||
.amdgpu_metadata
|
||||
---
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 0
|
||||
amdhsa.kernels:
|
||||
- .name: fn_name
|
||||
.symbol: fn_name.kd
|
||||
.group_segment_fixed_size: 0
|
||||
.private_segment_fixed_size: 0
|
||||
.wavefront_size: 32
|
||||
.sgpr_count: 8
|
||||
.vgpr_count: 8
|
||||
.max_flat_workgroup_size: 1024
|
||||
.kernarg_segment_align: 8
|
||||
.kernarg_segment_size: 8
|
||||
.args:
|
||||
- .address_space: global
|
||||
.name: a
|
||||
.offset: 0
|
||||
.size: 8
|
||||
.type_name: 'float*'
|
||||
.value_kind: global_buffer
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
"""
|
||||
|
||||
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret))
def run_asm(name:str, src:str) -> ProgramSpec:
  """Wrap `src` in the kernel template, run it once on the default device and return the program.

  Args:
    name: kernel name, substituted for every `fn_name` in the template.
    src: assembly body, dedented and substituted for `INSTRUCTION` in the template.

  Returns:
    The `ProgramSpec` that was built and executed.
  """
  asm = template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src))
  prg = ProgramSpec(name, asm, Device.DEFAULT, UOp(Ops.SINK))
  sink = UOp(Ops.SINK)
  # the template declares one global buffer argument, so a single allocated buffer is passed
  arg = Tensor.empty(1).uop.buffer.ensure_allocated()
  ExecItem(sink, [arg], prg=CompiledRunner(prg)).run()
  return prg
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "AMD", "only on AMD")
class TestCfg(unittest.TestCase):
  """Runs small hand-written kernels, each with a different control-flow shape.

  The tests make no assertions of their own: each one passes if the kernel can be built and run through `run_asm`.
  NOTE(review): presumably the recorded programs are meant to be inspected in the VIZ basic block graph - confirm.
  """
  def setUp(self):
    """Skip the test unless the AMD device reports a gfx11 or gfx12 arch."""
    arch = Device["AMD"].arch
    if not any(arch.startswith(a) for a in {"gfx11", "gfx12"}):
      self.skipTest(f"tests written for RDNA, got arch {arch}")

  def test_simple(self):
    """An unconditional branch from `entry` into the block that ends the program."""
    run_asm("simple", """
      entry:
        s_branch bb1
      bb1:
        s_endpgm
    """)

  def test_diamond(self):
    """`entry` splits into `if` and `else`, which both reach `end`."""
    run_asm("diamond", """
      entry:
        s_cmp_eq_i32 s0, 0
        s_cbranch_scc1 if
        s_branch else
      if:
        s_nop 1
        s_branch end
      else:
        s_nop 0
      end:
        s_endpgm
    """)

  def test_loop(self):
    """A single loop: s1 starts at 4 and is decremented until it compares equal to 0."""
    run_asm("simple_loop", """
      entry:
        s_mov_b32 s1, 4
      loop:
        s_add_u32 s1, s1, -1
        s_cmp_eq_i32 s1, 0
        s_cbranch_scc0 loop
        s_endpgm
    """)

  def test_loop_branch(self):
    """A loop whose body contains a conditional block (`cond`) before the shared loop test (`cont`)."""
    run_asm("loop_if", """
      entry:
        s_mov_b32 s1, 4
      loop:
        s_add_u32 s1, s1, -1
        s_cmp_eq_i32 s1, 2
        s_cbranch_scc1 cond
        s_branch cont
      cond:
        s_add_u32 s1, s1, -2
      cont:
        s_cmp_eq_i32 s1, 0
        s_cbranch_scc0 loop
        s_endpgm
    """)

  def test_loop_break(self):
    """A loop with an early exit: it branches to `break` when s1 compares equal to 5."""
    run_asm("loop_break", """
      entry:
        s_mov_b32 s1, 8
      loop:
        s_add_u32 s1, s1, -1
        s_cmp_eq_i32 s1, 5
        s_cbranch_scc1 break
        s_cmp_eq_i32 s1, 0
        s_cbranch_scc0 loop
      break:
        s_endpgm
    """)

  def test_switch(self):
    """A three-way dispatch on s0 (`case0`, `case1`, `case2`) where every case branches to `join`."""
    run_asm("switch_case", """
      entry:
        s_cmp_eq_i32 s0, 0
        s_cbranch_scc1 case0
        s_cmp_eq_i32 s0, 1
        s_cbranch_scc1 case1
        s_branch case2
      case0:
        s_nop 0
        s_branch join
      case1:
        s_nop 1
        s_branch join
      case2:
        s_nop 2
        s_branch join
      join:
        s_endpgm
    """)

  def test_ping_pong(self):
    """Two blocks (`ping`, `pong`) that can each branch to the other before reaching `end`."""
    run_asm("ping_pong", """
      entry:
        s_cmp_eq_i32 s0, 0
        s_cbranch_scc1 ping
        s_branch pong
      ping:
        s_cmp_eq_i32 s1, 0
        s_cbranch_scc1 pong
        s_branch end
      pong:
        s_cmp_eq_i32 s2, 0
        s_cbranch_scc1 ping
      end:
        s_endpgm
    """)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -124,7 +124,6 @@
|
||||
fill: rgba(26, 27, 38, 0.5);
|
||||
}
|
||||
/* Edge paths between graph nodes: stroked lines with no fill. */
.edgePath {
  stroke: #4a4b57;
  fill: none;
  stroke-width: 1.4px;
}
|
||||
|
||||
@@ -78,6 +78,7 @@ const drawGraph = (data) => {
|
||||
const labels = nodes.selectAll("g.label").data(d => [d]).join("g").attr("class", "label");
|
||||
labels.attr("transform", d => `translate(-${d.labelWidth/2}, -${d.labelHeight/2+STROKE_WIDTH*2})`);
|
||||
labels.selectAll("text").data(d => {
|
||||
if (Array.isArray(d.label)) return [d.label];
|
||||
const ret = [[]];
|
||||
for (const s of parseColors(d.label, defaultColor="initial")) {
|
||||
const color = darkenHex(s.color, 25);
|
||||
@@ -87,7 +88,7 @@ const drawGraph = (data) => {
|
||||
}
|
||||
return [ret];
|
||||
}).join("text").selectAll("tspan").data(d => d).join("tspan").attr("x", "0").attr("dy", 14).selectAll("tspan").data(d => d).join("tspan")
|
||||
.attr("fill", d => d.color).text(d => d.st).attr("xml:space", "preserve");
|
||||
.attr("fill", d => d.color).text(d => d.st).attr("xml:space", "preserve").style("font-family", g.graph().font);
|
||||
addTags(nodes.selectAll("g.tag").data(d => d.tag != null ? [d] : []).join("g").attr("class", "tag")
|
||||
.attr("transform", d => `translate(${-d.width/2+8}, ${-d.height/2+8})`).datum(e => e.tag));
|
||||
// draw edges
|
||||
@@ -98,7 +99,7 @@ const drawGraph = (data) => {
|
||||
points.unshift(intersectRect(g.node(e.v), points[0]));
|
||||
points.push(intersectRect(g.node(e.w), points[points.length-1]));
|
||||
return line(points);
|
||||
}).attr("marker-end", "url(#arrowhead)");
|
||||
}).attr("marker-end", "url(#arrowhead)").attr("stroke", e => g.edge(e).color || "#4a4b57");
|
||||
}
|
||||
|
||||
// ** UOp graph
|
||||
@@ -630,9 +631,6 @@ hljs.registerLanguage("cpp", (hljs) => ({
|
||||
...hljs.getLanguage('cpp'),
|
||||
contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains]
|
||||
}));
|
||||
// AMDGPU assembly highlighting: `//` line comments plus mnemonics that start with one of the listed instruction prefixes.
hljs.registerLanguage("amdgpu", (hljs) => ({
  contains: [hljs.COMMENT("//", "$"), { begin:/\b(?:s_|v_|global_|buffer_|scratch_|flat_|ds_)[a-z0-9_]*\b/, className:"code" }]
}));
|
||||
|
||||
async function fetchValue(path) {
|
||||
const res = await fetch(path);
|
||||
@@ -805,6 +803,7 @@ async function main() {
|
||||
return table;
|
||||
}
|
||||
if (ret.cols != null) renderTable(root, ret);
|
||||
else if (ret.data != null) renderDag(ret, { recenter:true });
|
||||
else if (ret.src != null) root.append(() => codeBlock(ret.src, ret.lang));
|
||||
ret.metadata?.forEach(m => {
|
||||
if (Array.isArray(m)) return metadata.appendChild(tabulate(m.map(({ label, value, idx }) => {
|
||||
@@ -836,7 +835,7 @@ async function main() {
|
||||
if (ret.length === 0) return;
|
||||
// ** center graph
|
||||
const data = ret[currentRewrite];
|
||||
const render = (opts) => renderDag({ graph:data.graph, change:data.change, opts }, { recenter:currentRewrite === 0 });
|
||||
const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 });
|
||||
render({ showIndexing:toggle.checked });
|
||||
toggle.onchange = (e) => render({ showIndexing:e.target.checked });
|
||||
// ** right sidebar metadata
|
||||
|
||||
@@ -1,22 +1,52 @@
|
||||
// Padding (px) added on every side of a node's label.
const NODE_PADDING = 10;
// Node dimensions for a label of width lw and height lh: the padded box size plus the raw label size.
const rectDims = (lw, lh) => ({ width:lw+NODE_PADDING*2, height:lh+NODE_PADDING*2, labelWidth:lw, labelHeight:lh });

// Height (px) counted per line of label text; also the font size used when measuring.
const LINE_HEIGHT = 14;
// 2D context of a zero-sized offscreen canvas; the layout functions use it to measure text width.
const canvas = new OffscreenCanvas(0, 0);
const ctx = canvas.getContext("2d");
ctx.font = `350 ${LINE_HEIGHT}px sans-serif`;
|
||||
|
||||
// Worker entry point: lay out the posted graph, send the serialized result back, then shut the worker down.
onmessage = (e) => {
  const { data, opts } = e.data;
  const g = new dagre.graphlib.Graph({ compound: true }).setDefaultEdgeLabel(function() { return {}; });
  // payloads carrying `blocks` are assembly control-flow graphs, everything else is laid out as a UOp graph
  (data.blocks != null ? layoutCfg : layoutUOp)(g, data, opts);
  postMessage(dagre.graphlib.json.write(g));
  self.close();
}
|
||||
|
||||
// Lay out a control-flow graph: one node per basic block (its assembly is the label) and one edge per path.
const layoutCfg = (g, { blocks, paths, pc_table, colors }) => {
  g.setGraph({ rankdir:"TD", font:"monospace" });
  ctx.font = `350 ${LINE_HEIGHT}px ${g.graph().font}`;
  // every basic block becomes a node whose label has one row per instruction
  Object.entries(blocks).forEach(([leader, pcs]) => {
    const rows = pcs.map((pc) => pc_table[pc][0]);
    const labelWidth = rows.reduce((widest, row) => Math.max(widest, ctx.measureText(row).width), 0);
    const labelHeight = rows.length*LINE_HEIGHT;
    // the mnemonic and its operands are colored separately
    const label = rows.map((row) => {
      const [mnemonic, ...rest] = row.split(" ");
      return [{st:mnemonic+" ", color:"#7aa2f7"}, {st:rest.join(" "), color:"#9aa5ce"}];
    });
    g.setNode(leader, { ...rectDims(labelWidth, labelHeight), label, id:leader, color:"#1a1b26" });
  });
  // every recorded path becomes an edge, colored by its kind
  Object.entries(paths).forEach(([from, targets]) => {
    Object.entries(targets).forEach(([to, kind]) => g.setEdge(from, to, {label:{type:"port", text:""}, color:colors[kind]}));
  });
  dagre.layout(g);
}
|
||||
|
||||
const layoutUOp = (g, { graph, change }, opts) => {
|
||||
g.setGraph({ rankdir: "LR", font:"sans-serif" });
|
||||
ctx.font = `350 ${LINE_HEIGHT}px ${g.graph().font}`;
|
||||
if (change?.length) g.setNode("overlay", {label:"", labelWidth:0, labelHeight:0, className:"overlay"});
|
||||
for (const [k, {label, src, ref, ...rest }] of Object.entries(graph)) {
|
||||
for (const [k, {label, src, ref, color, tag }] of Object.entries(graph)) {
|
||||
// adjust node dims by label size (excluding escape codes) + add padding
|
||||
let [width, height] = [0, 0];
|
||||
for (line of label.replace(/\u001B\[(?:K|.*?m)/g, "").split("\n")) {
|
||||
width = Math.max(width, ctx.measureText(line).width);
|
||||
height += LINE_HEIGHT;
|
||||
}
|
||||
g.setNode(k, {width:width+NODE_PADDING*2, height:height+NODE_PADDING*2, label, labelHeight:height, labelWidth:width, ref, id:k, ...rest});
|
||||
g.setNode(k, {...rectDims(width, height), label, ref, id:k, color, tag});
|
||||
// add edges
|
||||
const edgeCounts = {};
|
||||
for (const [_, s] of src) edgeCounts[s] = (edgeCounts[s] || 0)+1;
|
||||
@@ -33,6 +63,4 @@ onmessage = (e) => {
|
||||
dagre.layout(g);
|
||||
// remove overlay node if it's empty
|
||||
if (!g.node("overlay")?.width) g.removeNode("overlay");
|
||||
postMessage(dagre.graphlib.json.write(g));
|
||||
self.close();
|
||||
}
|
||||
|
||||
@@ -378,6 +378,67 @@ def amd_readelf(lib:bytes) -> list[dict]:
|
||||
".group_segment_fixed_size":"LDS size", ".private_segment_fixed_size":"Scratch size"}
|
||||
return [{"label":label, "value":v} for k,label in keys.items() if (v:=notes["amdhsa.kernels"][0][k]) > 0]
|
||||
|
||||
def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
  """Disassemble the `.text` section of an AMDGPU ELF image with LLVM.

  Args:
    arch: CPU name handed to LLVM together with the `amdgcn-amd-amdhsa` triple.
    lib: raw bytes of the ELF object.

  Returns:
    A mapping from instruction offset (counted from the section's `sh_addr`) to `(assembly text, size in bytes)`.

  Raises:
    AssertionError: if the ELF has no `.text` section.
    RuntimeError: if LLVM reports a zero-sized instruction, i.e. the bytes at that offset could not be decoded.
  """
  from tinygrad.runtime.autogen import llvm
  from tinygrad.runtime.support.elf import elf_loader
  llvm.LLVMInitializeAMDGPUTargetInfo()
  llvm.LLVMInitializeAMDGPUTargetMC()
  llvm.LLVMInitializeAMDGPUAsmParser()
  llvm.LLVMInitializeAMDGPUDisassembler()
  # pass NULL to callbacks (a tuple keeps the positional order explicit, set iteration order is not guaranteed)
  cbs = [ctypes.cast(0, llvm.LLVMCreateDisasmCPUFeatures.argtypes[i]) for i in (5, 6)]
  ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, *cbs)
  image, sections, _ = elf_loader(lib)
  text = next((sh.header for sh in sections if sh.name == ".text"), None)
  assert text is not None, "no .text section found in ELF"
  off, sz = text.sh_addr, text.sh_size
  end = off + sz
  addr_table:dict[int, tuple[str, int]] = {}
  out = ctypes.create_string_buffer(128)
  cur_off = off
  while cur_off < end:
    view = (ctypes.c_ubyte * (end - cur_off)).from_buffer_copy(memoryview(image)[cur_off:])
    instr_sz = llvm.LLVMDisasmInstruction(ctx, view, ctypes.c_uint64(len(view)), ctypes.c_uint64(0), out, ctypes.c_size_t(128))
    # a size of zero means no instruction was decoded; without this check the offset never advances and the loop spins forever
    if instr_sz == 0: raise RuntimeError(f"failed to disassemble instruction at offset {cur_off:#x}")
    addr_table[cur_off] = (out.value.decode("utf-8", "replace").strip(), instr_sz)
    cur_off += instr_sz
  return addr_table
|
||||
|
||||
SOPP_INSTS = {"s_branch", "s_cbranch_scc0", "s_cbranch_scc1", "s_cbranch_vccz", "s_cbranch_vccnz", "s_cbranch_execz", "s_cbranch_execnz"}
|
||||
def parse_branch(asm:str) -> int|None:
|
||||
inst, *operands = asm.split(" ")
|
||||
if inst in SOPP_INSTS:
|
||||
x = int(operands[0]) & 0xffff
|
||||
return (x - 0x10000 if x & 0x8000 else x)*4
|
||||
return None
|
||||
|
||||
COND_TAKEN, COND_NOT_TAKEN, UNCOND = range(3)
|
||||
cfg_colors = {COND_TAKEN: "#3f7564", COND_NOT_TAKEN: "#7a4540", UNCOND: "#3b5f7e"}
|
||||
def amdgpu_cfg(lib:bytes, arch:str) -> dict:
|
||||
# disassemble
|
||||
pc_table = llvm_disasm(arch, lib)
|
||||
# get leaders
|
||||
leaders:set[int] = {next(iter(pc_table))}
|
||||
for pc, (asm, sz) in pc_table.items():
|
||||
if (offset:=parse_branch(asm)) is not None: leaders.update((pc+sz+offset, pc+sz))
|
||||
# build the cfg
|
||||
curr:int|None = None
|
||||
blocks:dict[int, list[int]] = {}
|
||||
paths:dict[int, dict[int, int]] = {}
|
||||
for pc, (asm, sz) in pc_table.items():
|
||||
if pc in leaders:
|
||||
paths[curr:=pc] = {}
|
||||
blocks[pc] = []
|
||||
else: assert curr is not None, f"no basic block found for {pc}"
|
||||
blocks[curr].append(pc)
|
||||
# control flow ends in endpgm
|
||||
if asm == "s_endpgm": break
|
||||
# otherwise a basic block can have exactly one or two paths
|
||||
nx = pc+sz
|
||||
if (offset:=parse_branch(asm)) is not None:
|
||||
if asm.startswith("s_branch"): paths[curr][nx+offset] = UNCOND
|
||||
else: paths[curr].update([(nx+offset, COND_TAKEN), (nx, COND_NOT_TAKEN)])
|
||||
elif nx in leaders: paths[curr][nx] = UNCOND
|
||||
return {"blocks":blocks, "paths":paths, "pc_table":pc_table, "colors":cfg_colors}
|
||||
|
||||
# ** Main render function to get the complete details about a trace event
|
||||
|
||||
def get_render(i:int, j:int, fmt:str) -> dict:
|
||||
@@ -392,14 +453,15 @@ def get_render(i:int, j:int, fmt:str) -> dict:
|
||||
if isinstance(compiler, LLVMCompiler):
|
||||
return get_llvm_mca(disasm_str, ctypes.string_at(llvm.LLVMGetTargetMachineTriple(tm:=compiler.target_machine)).decode(),
|
||||
ctypes.string_at(llvm.LLVMGetTargetMachineCPU(tm)).decode())
|
||||
metadata:list = []
|
||||
ret:dict = {"src":disasm_str}
|
||||
if data.device.startswith("AMD"):
|
||||
with soft_err(lambda err: metadata.append(err)):
|
||||
metadata.append(amd_readelf(compiler.compile(data.src)))
|
||||
return {"src":disasm_str, "lang":"amdgpu" if data.device.startswith("AMD") else None, "metadata":metadata}
|
||||
with soft_err(lambda err: ret.update(err)):
|
||||
metadata = amd_readelf(lib:=compiler.compile(data.src))
|
||||
ret = {"data":amdgpu_cfg(lib, getattr(compiler, "arch")), "metadata":[metadata]}
|
||||
return ret
|
||||
if fmt == "all-pmc":
|
||||
durations, pmc = data
|
||||
ret:dict = {"cols":{}, "rows":[]}
|
||||
ret = {"cols":{}, "rows":[]}
|
||||
for (name, n, k),events in data[1].items():
|
||||
pmc_table = unpack_pmc(events)
|
||||
ret["cols"].update([(r[0], None) for r in pmc_table["rows"]])
|
||||
|
||||
Reference in New Issue
Block a user