mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
use new style amd compiler in viz (#13848)
* working version, handcode gfx1100 arch
* get target from device properties
* lib in cfg test program spec
This commit is contained in:
@@ -55,9 +55,9 @@ class _ROCParseCtx:
|
|||||||
self.occ_events:dict[RunKey, list[OccEvent]] = {}
|
self.occ_events:dict[RunKey, list[OccEvent]] = {}
|
||||||
|
|
||||||
for prog in prog_evs:
|
for prog in prog_evs:
|
||||||
arch = "gfx%d%x%x" % ((trgt:=unwrap(dev_evs[prog.device].props)['gfx_target_version']) // 10000, (trgt // 100) % 100, trgt % 100)
|
|
||||||
base = unwrap(prog.base)
|
base = unwrap(prog.base)
|
||||||
self.disasms[prog.name] = asm = {base+addr:info for addr,info in llvm_disasm(arch, unwrap(prog.lib)).items()}
|
target = unwrap(dev_evs[prog.device].props)['gfx_target_version']
|
||||||
|
self.disasms[prog.name] = asm = {base+addr:info for addr,info in llvm_disasm(target, unwrap(prog.lib)).items()}
|
||||||
|
|
||||||
def next_sqtt(self):
|
def next_sqtt(self):
|
||||||
x = next(self.sqtt_evs, None)
|
x = next(self.sqtt_evs, None)
|
||||||
|
|||||||
@@ -60,8 +60,8 @@ amdhsa.kernels:
|
|||||||
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret))
|
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, ret=ret))
|
||||||
def run_asm(name:str, insts:list) -> ProgramSpec:
|
def run_asm(name:str, insts:list) -> ProgramSpec:
|
||||||
src = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts])
|
src = "\n".join([inst if isinstance(inst, str) else inst.disasm() for inst in insts])
|
||||||
prg = ProgramSpec(name, template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
|
prg = ProgramSpec(name, src:=template.replace("fn_name", name).replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
|
||||||
global_size=[1, 1, 1], local_size=[1, 1, 1], globals=[0])
|
lib=Device[Device.DEFAULT].compiler.compile(src), global_size=[1, 1, 1], local_size=[1, 1, 1], globals=[0])
|
||||||
ei = ExecItem(UOp(Ops.SINK), [Tensor.empty(1).uop.buffer.ensure_allocated()], prg=CompiledRunner(prg))
|
ei = ExecItem(UOp(Ops.SINK), [Tensor.empty(1).uop.buffer.ensure_allocated()], prg=CompiledRunner(prg))
|
||||||
ei.run()
|
ei.run()
|
||||||
return prg
|
return prg
|
||||||
|
|||||||
@@ -140,6 +140,8 @@ def option(s:int|None) -> int: return 0 if s is None else s+1
|
|||||||
device_ts_diffs:dict[str, tuple[Decimal, Decimal]] = {}
|
device_ts_diffs:dict[str, tuple[Decimal, Decimal]] = {}
|
||||||
def cpu_ts_diff(device:str, thread=0) -> Decimal: return device_ts_diffs.get(device, (Decimal(0),))[thread]
|
def cpu_ts_diff(device:str, thread=0) -> Decimal: return device_ts_diffs.get(device, (Decimal(0),))[thread]
|
||||||
|
|
||||||
|
device_props:dict[str, dict] = {}
|
||||||
|
|
||||||
DevEvent = ProfileRangeEvent|ProfileGraphEntry|ProfilePointEvent
|
DevEvent = ProfileRangeEvent|ProfileGraphEntry|ProfilePointEvent
|
||||||
def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
|
def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
|
||||||
for e in profile:
|
for e in profile:
|
||||||
@@ -309,6 +311,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
|
|||||||
for ev in profile:
|
for ev in profile:
|
||||||
if isinstance(ev, ProfileDeviceEvent):
|
if isinstance(ev, ProfileDeviceEvent):
|
||||||
device_ts_diffs[ev.device] = (ev.comp_tdiff,ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff)
|
device_ts_diffs[ev.device] = (ev.comp_tdiff,ev.copy_tdiff if ev.copy_tdiff is not None else ev.comp_tdiff)
|
||||||
|
if ev.props is not None: device_props[ev.device] = ev.props
|
||||||
if (d:=ev.device.split(":")[0]) == "AMD": device_decoders[d] = load_counters
|
if (d:=ev.device.split(":")[0]) == "AMD": device_decoders[d] = load_counters
|
||||||
# load device specific counters
|
# load device specific counters
|
||||||
for fxn in device_decoders.values(): fxn(profile)
|
for fxn in device_decoders.values(): fxn(profile)
|
||||||
@@ -358,7 +361,7 @@ def amd_readelf(lib:bytes) -> list[dict]:
|
|||||||
".group_segment_fixed_size":"LDS size", ".private_segment_fixed_size":"Scratch size"}
|
".group_segment_fixed_size":"LDS size", ".private_segment_fixed_size":"Scratch size"}
|
||||||
return [{"label":label, "value":v} for k,label in keys.items() if (v:=notes["amdhsa.kernels"][0][k]) > 0]
|
return [{"label":label, "value":v} for k,label in keys.items() if (v:=notes["amdhsa.kernels"][0][k]) > 0]
|
||||||
|
|
||||||
def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
|
def llvm_disasm(target:int, lib:bytes) -> dict[int, tuple[str, int]]:
|
||||||
from tinygrad.runtime.autogen import llvm
|
from tinygrad.runtime.autogen import llvm
|
||||||
from tinygrad.runtime.support.elf import elf_loader
|
from tinygrad.runtime.support.elf import elf_loader
|
||||||
llvm.LLVMInitializeAMDGPUTargetInfo()
|
llvm.LLVMInitializeAMDGPUTargetInfo()
|
||||||
@@ -367,6 +370,7 @@ def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
|
|||||||
llvm.LLVMInitializeAMDGPUDisassembler()
|
llvm.LLVMInitializeAMDGPUDisassembler()
|
||||||
# pass NULL to callbacks
|
# pass NULL to callbacks
|
||||||
cbs = [ctypes.cast(0, llvm.LLVMCreateDisasmCPUFeatures.argtypes[i]) for i in {5,6}]
|
cbs = [ctypes.cast(0, llvm.LLVMCreateDisasmCPUFeatures.argtypes[i]) for i in {5,6}]
|
||||||
|
arch = "gfx%d%x%x" % (target // 10000, (target // 100) % 100, target % 100)
|
||||||
ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, *cbs)
|
ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, *cbs)
|
||||||
image, sections, _ = elf_loader(lib)
|
image, sections, _ = elf_loader(lib)
|
||||||
text = next((sh.header for sh in sections if sh.name == ".text"), None)
|
text = next((sh.header for sh in sections if sh.name == ".text"), None)
|
||||||
@@ -392,9 +396,9 @@ def parse_branch(asm:str) -> int|None:
|
|||||||
|
|
||||||
COND_TAKEN, COND_NOT_TAKEN, UNCOND = range(3)
|
COND_TAKEN, COND_NOT_TAKEN, UNCOND = range(3)
|
||||||
cfg_colors = {COND_TAKEN: "#3f7564", COND_NOT_TAKEN: "#7a4540", UNCOND: "#3b5f7e"}
|
cfg_colors = {COND_TAKEN: "#3f7564", COND_NOT_TAKEN: "#7a4540", UNCOND: "#3b5f7e"}
|
||||||
def amdgpu_cfg(lib:bytes, arch:str) -> dict:
|
def amdgpu_cfg(lib:bytes, target:int) -> dict:
|
||||||
# disassemble
|
# disassemble
|
||||||
pc_table = llvm_disasm(arch, lib)
|
pc_table = llvm_disasm(target, lib)
|
||||||
# get leaders
|
# get leaders
|
||||||
leaders:set[int] = {next(iter(pc_table))}
|
leaders:set[int] = {next(iter(pc_table))}
|
||||||
for pc, (asm, sz) in pc_table.items():
|
for pc, (asm, sz) in pc_table.items():
|
||||||
@@ -427,13 +431,12 @@ def get_render(i:int, j:int, fmt:str) -> dict:
|
|||||||
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(data.uops or [])), "lang":"txt"}
|
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(data.uops or [])), "lang":"txt"}
|
||||||
if fmt == "code": return {"src":data.src, "lang":"cpp"}
|
if fmt == "code": return {"src":data.src, "lang":"cpp"}
|
||||||
if fmt == "asm":
|
if fmt == "asm":
|
||||||
compiler = Device[data.device].compiler
|
|
||||||
ret:dict = {"metadata":[]}
|
ret:dict = {"metadata":[]}
|
||||||
if data.device.startswith("AMD"):
|
if data.device.startswith("AMD") and data.lib is not None:
|
||||||
with soft_err(lambda err: ret.update(err)):
|
with soft_err(lambda err: ret.update(err)):
|
||||||
ret["data"] = amdgpu_cfg(lib:=compiler.compile(data.src), getattr(compiler, "arch"))
|
ret["data"] = amdgpu_cfg(lib:=data.lib, device_props[data.device]["gfx_target_version"])
|
||||||
with soft_err(lambda err: ret["metadata"].append(err)): ret["metadata"].append(amd_readelf(lib))
|
with soft_err(lambda err: ret["metadata"].append(err)): ret["metadata"].append(amd_readelf(lib))
|
||||||
else: ret["src"] = get_stdout(lambda: compiler.disassemble(compiler.compile(data.src)))
|
else: ret["src"] = get_stdout(lambda: (compiler:=Device[data.device].compiler).disassemble(compiler.compile(data.src)))
|
||||||
return ret
|
return ret
|
||||||
if fmt == "all-pmc":
|
if fmt == "all-pmc":
|
||||||
durations, pmc = data
|
durations, pmc = data
|
||||||
|
|||||||
Reference in New Issue
Block a user