diff --git a/extra/assembly/rdna3/generate.py b/extra/assembly/rdna3/generate.py
new file mode 100644
index 0000000000..9c0a2ff7e1
--- /dev/null
+++ b/extra/assembly/rdna3/generate.py
@@ -0,0 +1,136 @@
+import os, sys, struct
+sys.path.append(os.getcwd())
+# PROFILE=1 to use
+#os.environ["PROFILE"] = "1"
+os.environ["SQTT"] = "1"
+os.environ["SQTT_ITRACE_SE_MASK"] = "1"
+os.environ["SQTT_LIMIT_SE"] = "1"
+import xml.etree.ElementTree as ET
+
+from tinygrad import nn, Tensor, Device
+from tinygrad.helpers import get_single_element
+from tinygrad.engine.realize import lower_schedule
+from tinygrad.runtime.support.elf import elf_loader
+from tinygrad.runtime.ops_amd import ProfileSQTTEvent
+from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
+
+def disassemble(text, root:ET.Element):
+  i = 0
+  while i < len(text):
+    ins = struct.unpack("I", text[i:i+4])[0]
+
+    # 1. Get the encoding
+    did_match = False
+    for enc_el in root.findall("./ISA/Encodings/Encoding"):
+      mask = enc_el.findtext("EncodingIdentifierMask")
+      assert len(mask)%32 == 0
+      bit_mask = int(mask, 2)
+      iden = [int(x.text, 2) for x in enc_el.find("EncodingIdentifiers").findall("EncodingIdentifier")]
+      for ide in iden:
+        if ins&bit_mask == ide:
+          did_match = True
+          break
+      if did_match: break
+    if not did_match: raise RuntimeError(f"unknown instruction {ins:08X}")
+    if len(mask) >= 64: ins = (struct.unpack("I", text[i+4:i+8])[0]<<32) | ins
+    if len(mask) >= 96: ins = (struct.unpack("I", text[i+8:i+12])[0]<<64) | ins
+    encoding_name = enc_el.findtext("EncodingName")
+
+    #print(ET.tostring(enc_el).decode())
+
+    # 2. Parse the Fields for this Encoding
+    field_data = {}
+    for field in enc_el.findall("MicrocodeFormat/BitMap/Field"):
+      # Fields can be split into multiple ranges (RangeCount > 1)
+      ranges = sorted(field.findall("BitLayout/Range"), key=lambda x: int(x.attrib.get('Order')))
+      val = 0
+      current_shift = 0
+      for rng in ranges:
+        width = int(rng.find("BitCount").text)
+        chunk = (ins >> int(rng.find("BitOffset").text)) & ((1 << width) - 1)
+        val |= (chunk << current_shift)
+        current_shift += width
+      field_data[field.find("FieldName").text] = val
+    # ENCODING was already consumed to pick the encoding above; tolerate bitmaps without it
+    field_data.pop("ENCODING", None)
+
+    # 3. Extract the instruction
+    did_match = False
+    for ins_el in root.findall("./ISA/Instructions/Instruction"):
+      ins_name = ins_el.findtext("InstructionName")
+      for ins_enc in ins_el.findall("InstructionEncodings/InstructionEncoding"):
+        if ins_enc.findtext("EncodingName") == encoding_name:
+          opcode = int(ins_enc.findtext("Opcode"))
+          if "OP" in field_data and opcode == field_data["OP"]:
+            did_match = True
+            del field_data["OP"]
+            break
+        if did_match: break
+      if did_match: break
+
+    #print(ET.tostring(ins_enc).decode())
+    #print()
+    #print(field_data)
+    if not did_match:
+      print(f"{i:4X} : {ins:16x} -- {encoding_name}")
+    else:
+      params = []
+      #print(ET.tostring(ins_el).decode())
+
+      # 4. Extract the opcodes
+      for op_ins in ins_enc.findall("Operands/Operand"):
+        op_type = op_ins.findtext("OperandType")
+        op_size = op_ins.findtext("OperandSize")
+        op_fmt = op_ins.findtext("DataFormatName")
+        op_field_name = op_ins.findtext("FieldName")
+        if op_field_name is None: continue
+        assert op_field_name in field_data
+        # loop through operands for compare
+        for op_el in root.findall("./ISA/OperandTypes/OperandType"):
+          test_op_type = op_el.findtext("OperandTypeName")
+          val_dict = {}
+          for op_val in op_el.findall("OperandPredefinedValues/PredefinedValue"):
+            val_dict[int(op_val.findtext("Value"))] = op_val.findtext("Name")
+          if op_type == test_op_type:
+            if field_data[op_field_name] in val_dict:
+              print(op_type, op_size, op_fmt)
+              params.append(val_dict[field_data[op_field_name]])
+            else:
+              params.append(f"{op_type}({field_data[op_field_name]})")
+            del field_data[op_field_name]
+        #print(op_type, op_size, op_fmt, op_el, op_field_name,
+        #      field_data[op_field_name],
+        #      val_dict.get(field_data[op_field_name], ""))
+        #print(ET.tostring(op_el).decode())
+
+      print(f"{i:4X} : {ins:16x} -- {ins_name.lower()} {', '.join(params)}", field_data)
+
+    # advance
+    i += len(mask) // 8
+
+  #print(ET.tostring(root).decode())
+
+if __name__ == "__main__":
+  # human readable manual at https://docs.amd.com/v/u/en-US/rdna35_instruction_set_architecture
+  fns = nn.state.zip_extract(Tensor.from_url("https://gpuopen.com/download/machine-readable-isa/latest/"))
+  xml_str = fns['amdgpu_isa_rdna3_5.xml'].to("CPU").data()
+  with open("/tmp/rdna35.xml", "wb") as f: f.write(bytes(xml_str))
+  root = ET.fromstring(xml_str)
+
+  a = Tensor.empty(16)+1
+  for si, ei in lower_schedule(a.schedule()):
+    # get text
+    _, hdr, _ = elf_loader(ei.prg.lib)
+    text = get_single_element([x for x in hdr if x.name==".text"]).content
+
+    # llvm disassembler
+    Device["AMD"].compiler.disassemble(ei.prg.lib)
+
+    # run program
+    ei.run()
+
+    sqtt_events = [e for e in Device["AMD"].profile_events if isinstance(e, ProfileSQTTEvent)]
+    for e in sqtt_events[0:1]: # only the first SE
+      parse_sqtt_print_packets(e.blob)
+
+    disassemble(text[:0x40], root)
diff --git a/extra/assembly/rdna3/parse.py b/extra/assembly/rdna3/parse.py
new file mode 100644
index 0000000000..ea2d19bef9
--- /dev/null
+++ b/extra/assembly/rdna3/parse.py
@@ -0,0 +1,15 @@
+from tinygrad import Tensor, nn
+import xml.etree.ElementTree as ET
+
+if __name__ == "__main__":
+  # human readable manual at https://docs.amd.com/v/u/en-US/rdna35_instruction_set_architecture
+  fns = nn.state.zip_extract(Tensor.from_url("https://gpuopen.com/download/machine-readable-isa/latest/"))
+  xml_str = fns['amdgpu_isa_rdna3_5.xml'].to("CPU").data()
+  root = ET.fromstring(xml_str)
+
+  for op_el in root.findall("./ISA/OperandTypes/OperandType"):
+    op_name = op_el.findtext("OperandTypeName")
+    val_dict = {}
+    for op_val in op_el.findall("OperandPredefinedValues/PredefinedValue"):
+      val_dict[int(op_val.findtext("Value"))] = op_val.findtext("Name")
+    print(op_name, val_dict)
diff --git a/extra/sqtt/attempt_sqtt_parse.py b/extra/sqtt/attempt_sqtt_parse.py
index 9bf5d1a62f..788ca6cf6f 100644
--- a/extra/sqtt/attempt_sqtt_parse.py
+++ b/extra/sqtt/attempt_sqtt_parse.py
@@ -119,7 +119,6 @@ OPNAME = {
   0xb: "VALU",
   0xd: "VALU",
   0xe: "VALU",
-  0x10: "__END",
   0x21: "VMEM_LOAD",
   0x22: "VMEM_LOAD",
   0x24: "VMEM_STORE",
@@ -480,6 +479,7 @@ def parse_sqtt_print_packets(data: bytes, filter=DEFAULT_FILTER, verbose=True) -
       nib = (byte >> (offset & 4)) & 0xF
       reg = ((reg >> 4) | (nib << 60)) & ((1 << 64) - 1)
       offset += 4
+      if offset != target: break # don't parse past the end
 
       # 2) Decode token from low 8 bits
       opcode = STATE_TO_OPCODE[reg & 0xFF]
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index 5c1ccc8e45..af30f445ef 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -1,4 +1,4 @@
-import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
+import json, pathlib, zipfile, pickle, tarfile, struct, functools, io, zlib
 from collections import OrderedDict
 from typing import Any, Callable, BinaryIO, Iterable, cast
 from tinygrad.tensor import Tensor
@@ -161,6 +161,27 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
     ret.append(v)
   return ret
 
+@accept_filename
+def zip_extract(t: Tensor) -> dict[str, Tensor]:
+  files: dict[str, Tensor] = {}
+  file_offsets: dict[str, tuple[Tensor, int, int]] = {}
+  with zipfile.ZipFile(TensorIO(t), "r") as myzip:
+    for zi in myzip.filelist:
+      file_offset = zi.header_offset+30+t[zi.header_offset+26:zi.header_offset+30].bitcast(dtypes.uint16).to("CPU").sum()
+      file_offsets[zi.filename] = (file_offset, zi.compress_size, zi.compress_type)
+  # sadly, the extra length needs to be read from the local header of each file. this is a limitation of the zip file format
+  Tensor.realize(*[x[0] for x in file_offsets.values()])
+  for filename, (file_offset, compress_size, compress_type) in file_offsets.items():
+    # possible to remove this realize/item? it's slow
+    file_offset_int = int(file_offset.item())
+    files[filename] = t[file_offset_int:file_offset_int+compress_size]
+    match compress_type:
+      case zipfile.ZIP_STORED: pass
+      # TODO: we need a zlib UOp so this can be lazy
+      case zipfile.ZIP_DEFLATED: files[filename] = Tensor(zlib.decompress(files[filename].data(), -15))
+      case _: raise NotImplementedError(f"compression {compress_type} not supported")
+  return files
+
 @accept_filename
 def tar_extract(t: Tensor) -> dict[str, Tensor]:
   """
@@ -179,6 +200,7 @@ def tar_extract(t: Tensor) -> dict[str, Tensor]:
 
 
 # torch support!
+# TODO: this should use tar_extract and zip_extract
 @accept_filename
 def torch_load(t:Tensor) -> dict[str, Tensor]:
   """