rdna3 asm + zip_extract (#13499)

* rdna3 asm + zip_extract

* include sqtt

* fix end parsing

* disassembler working

* parsing fields

* instruction

* op

* more parsing
This commit is contained in:
George Hotz
2025-12-02 22:56:01 -08:00
committed by GitHub
parent 6bd355fa26
commit ddf3f2d0c4
4 changed files with 175 additions and 2 deletions

View File

@@ -0,0 +1,136 @@
import os, sys, struct
sys.path.append(os.getcwd())
# PROFILE=1 to use
#os.environ["PROFILE"] = "1"
os.environ["SQTT"] = "1"
os.environ["SQTT_ITRACE_SE_MASK"] = "1"
os.environ["SQTT_LIMIT_SE"] = "1"
import xml.etree.ElementTree as ET
from tinygrad import nn, Tensor, Device
from tinygrad.helpers import get_single_element
from tinygrad.engine.realize import lower_schedule
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
def disassemble(text, root:ET.Element):
i = 0
while i < len(text):
ins = struct.unpack("I", text[i:i+4])[0]
# 1. Get the encoding
did_match = False
for enc_el in root.findall("./ISA/Encodings/Encoding"):
mask = enc_el.findtext("EncodingIdentifierMask")
assert len(mask)%32 == 0
bit_mask = int(mask, 2)
iden = [int(x.text, 2) for x in enc_el.find("EncodingIdentifiers").findall("EncodingIdentifier")]
for ide in iden:
if ins&bit_mask == ide:
did_match = True
break
if did_match: break
if not did_match: raise RuntimeError(f"unknown instruction {ins:08X}")
if len(mask) >= 64: ins = (struct.unpack("I", text[i+4:i+8])[0]<<32) | ins
if len(mask) >= 96: ins = (struct.unpack("I", text[i+8:i+12])[0]<<64) | ins
encoding_name = enc_el.findtext("EncodingName")
#print(ET.tostring(enc_el).decode())
# 2. Parse the Fields for this Encoding
field_data = {}
for field in enc_el.findall("MicrocodeFormat/BitMap/Field"):
# Fields can be split into multiple ranges (RangeCount > 1)
ranges = sorted(field.findall("BitLayout/Range"), key=lambda x: int(x.attrib.get('Order')))
val = 0
current_shift = 0
for rng in ranges:
width = int(rng.find("BitCount").text)
chunk = (ins >> int(rng.find("BitOffset").text)) & ((1 << width) - 1)
val |= (chunk << current_shift)
current_shift += width
field_data[field.find("FieldName").text] = val
# this is already used
del field_data["ENCODING"]
# 3. Extract the instruction
did_match = False
for ins_el in root.findall("./ISA/Instructions/Instruction"):
ins_name = ins_el.findtext("InstructionName")
for ins_enc in ins_el.findall("InstructionEncodings/InstructionEncoding"):
if ins_enc.findtext("EncodingName") == encoding_name:
opcode = int(ins_enc.findtext("Opcode"))
if "OP" in field_data and opcode == field_data["OP"]:
did_match = True
del field_data["OP"]
break
if did_match: break
if did_match: break
#print(ET.tostring(ins_enc).decode())
#print()
#print(field_data)
if not did_match:
print(f"{i:4X} : {ins:16x} -- {encoding_name}")
elif did_match:
params = []
#print(ET.tostring(ins_el).decode())
# 4. Extract the opcodes
for op_ins in ins_enc.findall("Operands/Operand"):
op_type = op_ins.findtext("OperandType")
op_size = op_ins.findtext("OperandSize")
op_fmt = op_ins.findtext("DataFormatName")
op_field_name = op_ins.findtext("FieldName")
if op_field_name is None: continue
assert op_field_name in field_data
# loop through operands for compare
for op_el in root.findall("./ISA/OperandTypes/OperandType"):
test_op_type = op_el.findtext("OperandTypeName")
val_dict = {}
for op_val in op_el.findall("OperandPredefinedValues/PredefinedValue"):
val_dict[int(op_val.findtext("Value"))] = op_val.findtext("Name")
if op_type == test_op_type:
if field_data[op_field_name] in val_dict:
print(op_type, op_size, op_fmt)
params.append(val_dict[field_data[op_field_name]])
else:
params.append(f"{op_type}({field_data[op_field_name]})")
del field_data[op_field_name]
#print(op_type, op_size, op_fmt, op_el, op_field_name,
# field_data[op_field_name],
# val_dict.get(field_data[op_field_name], "<UNK>"))
#print(ET.tostring(op_el).decode())
print(f"{i:4X} : {ins:16x} -- {ins_name.lower()} {', '.join(params)}", field_data)
# advance
i += len(mask) // 8
#print(ET.tostring(root).decode())
if __name__ == "__main__":
# human readable manual at https://docs.amd.com/v/u/en-US/rdna35_instruction_set_architecture
fns = nn.state.zip_extract(Tensor.from_url("https://gpuopen.com/download/machine-readable-isa/latest/"))
xml_str = fns['amdgpu_isa_rdna3_5.xml'].to("CPU").data()
with open("/tmp/rdna35.xml", "wb") as f: f.write(bytes(xml_str))
root = ET.fromstring(xml_str)
a = Tensor.empty(16)+1
for si, ei in lower_schedule(a.schedule()):
# get text
_, hdr, _ = elf_loader(ei.prg.lib)
text = get_single_element([x for x in hdr if x.name==".text"]).content
# llvm disassembler
Device["AMD"].compiler.disassemble(ei.prg.lib)
# run program
ei.run()
sqtt_events = [e for e in Device["AMD"].profile_events if isinstance(e, ProfileSQTTEvent)]
for e in sqtt_events[0:1]: # only the first SE
parse_sqtt_print_packets(e.blob)
disassemble(text[:0x40], root)

View File

@@ -0,0 +1,15 @@
from tinygrad import Tensor, nn
import xml.etree.ElementTree as ET
if __name__ == "__main__":
# human readable manual at https://docs.amd.com/v/u/en-US/rdna35_instruction_set_architecture
fns = nn.state.zip_extract(Tensor.from_url("https://gpuopen.com/download/machine-readable-isa/latest/"))
xml_str = fns['amdgpu_isa_rdna3_5.xml'].to("CPU").data()
root = ET.fromstring(xml_str)
for op_el in root.findall("./ISA/OperandTypes/OperandType"):
op_name = op_el.findtext("OperandTypeName")
val_dict = {}
for op_val in op_el.findall("OperandPredefinedValues/PredefinedValue"):
val_dict[int(op_val.findtext("Value"))] = op_val.findtext("Name")
print(op_name, val_dict)

View File

@@ -119,7 +119,6 @@ OPNAME = {
0xb: "VALU",
0xd: "VALU",
0xe: "VALU",
0x10: "__END",
0x21: "VMEM_LOAD",
0x22: "VMEM_LOAD",
0x24: "VMEM_STORE",
@@ -480,6 +479,7 @@ def parse_sqtt_print_packets(data: bytes, filter=DEFAULT_FILTER, verbose=True) -
nib = (byte >> (offset & 4)) & 0xF
reg = ((reg >> 4) | (nib << 60)) & ((1 << 64) - 1)
offset += 4
if offset != target: break # don't parse past the end
# 2) Decode token from low 8 bits
opcode = STATE_TO_OPCODE[reg & 0xFF]