# Generic PDF text extractor - no external dependencies
import re, zlib
from tinygrad.helpers import fetch, merge_dicts

PDF_URLS = {
  "rdna3": "https://docs.amd.com/api/khub/documents/UVVZM22UN7tMUeiW_4ShTQ/content",
  "rdna4": "https://docs.amd.com/api/khub/documents/uQpkEvk3pv~kfAb2x~j4uw/content",
  "cdna": "https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-cdna4-instruction-set-architecture.pdf",
}

# ═══════════════════════════════════════════════════════════════════════════════
# Generic PDF extraction tools
# ═══════════════════════════════════════════════════════════════════════════════

def extract(url: str) -> list[list[tuple[float, float, str, str]]]:
  """Extract positioned text from PDF. Returns list of text elements (x, y, text, font) per page."""
  data = fetch(url).read_bytes()

  # Parse xref table to locate objects
  xref: dict[int, int] = {}
  pos = int(re.search(rb'startxref\s+(\d+)', data).group(1)) + 4
  while data[pos:pos+7] != b'trailer':
    while data[pos:pos+1] in b' \r\n': pos += 1
    line_end = data.find(b'\n', pos)
    start_obj, count = map(int, data[pos:line_end].split()[:2])
    pos = line_end + 1
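    # Each xref entry is a fixed 20-byte record: 10-digit byte offset, 5-digit generation
    # number, and an 'n' (in use) or 'f' (free) flag, e.g. b"0000001234 00000 n".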
    for i in range(count):
      if data[pos+17:pos+18] == b'n' and (off := int(data[pos:pos+10])) > 0: xref[start_obj + i] = off
      pos += 20
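
  # An object body looks like b"12 0 obj << /Filter /FlateDecode >> stream\n...\nendstream"
  # (object number illustrative); get_stream returns the bytes between the stream keywords,
  # inflated when /FlateDecode is flagged.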
  def get_stream(n: int) -> bytes:
    obj = data[xref[n]:data.find(b'endobj', xref[n])]
    raw = obj[obj.find(b'stream\n') + 7:obj.find(b'\nendstream')]
    return zlib.decompress(raw) if b'/FlateDecode' in obj else raw

  # Find page content streams and extract text
  pages = []
  for n in sorted(xref):
    if b'/Type /Page' not in data[xref[n]:xref[n]+500]: continue
    if not (m := re.search(rb'/Contents (\d+) 0 R', data[xref[n]:xref[n]+500])): continue
    stream = get_stream(int(m.group(1))).decode('latin-1')
    elements, font = [], ''
    for bt in re.finditer(r'BT(.*?)ET', stream, re.S):
      x, y = 0.0, 0.0
      for m in re.finditer(r'(/F[\d.]+) [\d.]+ Tf|([\d.+-]+) ([\d.+-]+) Td|[\d.+-]+ [\d.+-]+ [\d.+-]+ [\d.+-]+ ([\d.+-]+) ([\d.+-]+) Tm|<([0-9A-Fa-f]+)>.*?Tj|\[([^\]]+)\] TJ', bt.group(1)):
        if m.group(1): font = m.group(1)
        elif m.group(2): x, y = x + float(m.group(2)), y + float(m.group(3))
        elif m.group(4): x, y = float(m.group(4)), float(m.group(5))
        elif m.group(6) and (t := bytes.fromhex(m.group(6)).decode('latin-1')).strip(): elements.append((x, y, t, font))
        elif m.group(7) and (t := ''.join(bytes.fromhex(h).decode('latin-1') for h in re.findall(r'<([0-9A-Fa-f]+)>', m.group(7)))).strip(): elements.append((x, y, t, font))
    pages.append(sorted(elements, key=lambda e: (-e[1], e[0])))
  return pages
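
# The operator regex above covers the common text operators (Tf, Td, Tm, Tj, TJ); as a
# sketch, a content-stream snippet "BT /F5.0 9 Tf 1 0 0 1 56.0 700.0 Tm <56414C55> Tj ET"
# (hex 56414C55 = 'VALU') yields the element (56.0, 700.0, 'VALU', '/F5.0').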

def extract_tables(pages: list[list[tuple[float, float, str, str]]]) -> dict[int, tuple[str, list[list[str]]]]:
  """Extract numbered tables from PDF pages. Returns {table_num: (title, rows)} where rows is list of cells per row."""
  def group_by_y(texts, key=lambda y: round(y)):
    by_y: dict[int, list[tuple[float, float, str]]] = {}
    for x, y, t, _ in texts:
      by_y.setdefault(key(y), []).append((x, y, t))
    return by_y

  # Find all table headers by merging text on same line
  table_positions = []
  for page_idx, texts in enumerate(pages):
    for items in group_by_y(texts).values():
      line = ''.join(t for _, t in sorted((x, t) for x, _, t in items))
      if m := re.search(r'Table (\d+)\. (.+)', line):
        table_positions.append((int(m.group(1)), m.group(2).strip(), page_idx, items[0][1]))
  table_positions.sort(key=lambda t: (t[2], -t[3]))

  # For each table, find rows with matching X positions
  result: dict[int, tuple[str, list[list[str]]]] = {}
  for num, title, start_page, header_y in table_positions:
    rows, col_xs = [], None
    for page_idx in range(start_page, len(pages)):
      page_texts = [(x, y, t) for x, y, t, _ in pages[page_idx] if 30 < y < 760 and (page_idx > start_page or y < header_y)]
      for items in sorted(group_by_y([(x, y, t, '') for x, y, t in page_texts], key=lambda y: round(y / 5)).values(), key=lambda items: -items[0][1]):
        xs = tuple(sorted(round(x) for x, _, _ in items))
        if col_xs is None:
          if len(xs) < 2: continue  # Skip single-column rows before table starts
          col_xs = xs
        elif len(xs) == 1 and xs[0] in col_xs: continue  # Skip continuation rows at known column positions
        elif not any(c in xs for c in col_xs[:2]): break  # Row missing first columns = end of table
        rows.append([t for _, t in sorted((x, t) for x, _, t in items)])
      else: continue
      break
    if rows: result[num] = (title, rows)
  return result
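
# Sketch of the result for a hypothetical "Table 7. SOP2 Opcodes" whose rows pair value
# and name columns: {7: ('SOP2 Opcodes', [['0', 'S_ADD_U32', '1', 'S_SUB_U32'], ...])}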

# ═══════════════════════════════════════════════════════════════════════════════
# AMD specific extraction
# ═══════════════════════════════════════════════════════════════════════════════

def extract_enums(tables: dict[int, tuple[str, list[list[str]]]]) -> dict[str, dict[int, str]]:
  """Extract all enums from tables. Returns {enum_name: {value: name}}."""
  enums: dict[str, dict[int, str]] = {}
  for num, (title, rows) in tables.items():
    # Opcode enums from "XXX Opcodes" tables
    if m := re.match(r'(\w+) (?:Y-)?Opcodes', title):
      fmt_name = 'VOPD' if 'Y-Opcodes' in title else m.group(1)
      ops: dict[int, str] = {}
      for row in rows:
        for i in range(0, len(row) - 1, 2):
          if row[i].isdigit() and re.match(r'^[A-Z][A-Z0-9_]+$', row[i + 1]):
            ops[int(row[i])] = row[i + 1]
      if ops: enums[fmt_name] = ops
    # BufFmt from "Data Format" tables
    if 'Data Format' in title:
      for row in rows:
        for i in range(0, len(row) - 1, 2):
          if row[i].isdigit() and re.match(r'^[\dA-Z_]+$', row[i + 1]) and 'INVALID' not in row[i + 1]:
            enums.setdefault('BufFmt', {})[int(row[i])] = row[i + 1]
  return enums
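
# Continuing the sketch above: the row ['0', 'S_ADD_U32', '1', 'S_SUB_U32'] from a
# "SOP2 Opcodes" table becomes enums['SOP2'] = {0: 'S_ADD_U32', 1: 'S_SUB_U32'}.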

def extract_ins(tables: dict[int, tuple[str, list[list[str]]]]) -> tuple[dict[str, list[tuple[str, int, int]]], dict[str, str]]:
  """Extract formats and encodings from 'XXX Fields' tables. Returns (formats, encodings)."""
  formats: dict[str, list[tuple[str, int, int]]] = {}
  encodings: dict[str, str] = {}
  for num, (title, rows) in tables.items():
    if not (m := re.match(r'(\w+) Fields$', title)): continue
    fmt_name = m.group(1)
    fields = []
    for row in rows:
      if len(row) < 2: continue
      if (bits := re.match(r'\[?(\d+):(\d+)\]?$', row[1])) or (bits := re.match(r'\[(\d+)\]$', row[1])):
        field_name = row[0].lower()
        hi, lo = int(bits.group(1)), int(bits.group(2)) if bits.lastindex >= 2 else int(bits.group(1))
        if field_name == 'encoding' and len(row) >= 3:
          enc_bits = None
          if "'b" in row[2]: enc_bits = row[2].split("'b")[-1].replace('_', '')
          elif (enc := re.search(r':\s*([01_]+)', row[2])): enc_bits = enc.group(1).replace('_', '')
          if enc_bits:
            # If encoding bits exceed field width, extend field to match (AMD docs sometimes have this)
            declared_width, actual_width = hi - lo + 1, len(enc_bits)
            if actual_width > declared_width: lo = hi - actual_width + 1
            encodings[fmt_name] = enc_bits
        fields.append((field_name, hi, lo))
    if fields: formats[fmt_name] = fields
  return formats, encodings
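
# e.g. a "SOP2 Fields" row of ['ENCODING', '[31:30]', "Must be: 2'b10"] (cell text is a
# sketch) yields the field ('encoding', 31, 30) and encodings['SOP2'] = '10'.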

def extract_pcode(pages: list[list[tuple[float, float, str, str]]], enums: dict[str, dict[int, str]]) -> dict[tuple[str, int], str]:
  """Extract pseudocode for instructions. Returns {(name, opcode): pseudocode}."""
  # Build lookup from instruction name to opcode
  name_to_op = {name: op for ops in enums.values() for op, name in ops.items()}

  # First pass: find all instruction headers across all pages
  all_instructions: list[tuple[int, float, str, int]] = []  # (page_idx, y, name, opcode)
  for page_idx, page in enumerate(pages):
    by_y: dict[int, list[tuple[float, str]]] = {}
    for x, y, t, _ in page:
      by_y.setdefault(round(y), []).append((x, t))
    for y, items in sorted(by_y.items(), reverse=True):
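      # Instruction headers put the mnemonic near the left margin and its opcode near the
      # right margin; the x windows below are empirical for these AMD ISA PDF layouts.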
      left = [(x, t) for x, t in items if 55 < x < 65]
      right = [(x, t) for x, t in items if 535 < x < 550]
      if left and right and left[0][1] in name_to_op and right[0][1].isdigit():
        all_instructions.append((page_idx, y, left[0][1], int(right[0][1])))

  # Second pass: extract pseudocode between consecutive instructions
  pcode: dict[tuple[str, int], str] = {}
  for i, (page_idx, y, name, opcode) in enumerate(all_instructions):
    # Get end boundary from next instruction
    if i + 1 < len(all_instructions):
      next_page, next_y = all_instructions[i + 1][0], all_instructions[i + 1][1]
    else:
      next_page, next_y = page_idx, 0
    # Collect F6 text from current position to next instruction
    lines = []
    for p in range(page_idx, next_page + 1):
      start_y = y if p == page_idx else 800
      end_y = next_y if p == next_page else 0
      lines.extend((p, y2, t) for x, y2, t, f in pages[p] if f in ('/F6.0', '/F7.0') and end_y < y2 < start_y)
    if lines:
      # Sort by page first, then by y descending within each page (higher y = earlier text in PDF)
      pcode_lines = [t.replace('Ê', '').strip() for _, _, t in sorted(lines, key=lambda x: (x[0], -x[1]))]
      if pcode_lines: pcode[(name, opcode)] = '\n'.join(pcode_lines)
  return pcode
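
# The result maps (mnemonic, opcode) to the pseudocode text printed under each
# instruction header, e.g. {('S_ADD_U32', 0): "tmp = S0.u + S1.u;\n..."} (a rough sketch).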

# ═══════════════════════════════════════════════════════════════════════════════
# Write autogen files
# ═══════════════════════════════════════════════════════════════════════════════

def write_enums(enums: dict[str, dict[int, str]], arch: str, path: str):
  """Write enum.py file from extracted enums."""
  lines = ["# autogenerated from AMD ISA PDF by pdf.py - do not edit", "from enum import IntEnum", ""]
  for name, values in sorted(enums.items()):
    suffix = "Op" if name not in ('Src', 'BufFmt') else ("Enum" if name == 'Src' else "")
    prefix = "BUF_FMT_" if name == 'BufFmt' else ""
    lines.append(f"class {name}{suffix}(IntEnum):")
    for val, member in sorted(values.items()):
      lines.append(f"  {prefix}{member} = {val}")
    lines.append("")
  with open(path, "w") as f:
    f.write("\n".join(lines))
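
# A sketch of the generated enum.py (names and values depend on the extracted tables):
#   class SOP2Op(IntEnum):
#     S_ADD_U32 = 0
#     S_SUB_U32 = 1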

def write_ins(formats: dict[str, list[tuple[str, int, int]]], encodings: dict[str, str], enums: dict[str, dict[int, str]], arch: str, path: str):
  """Write ins.py file from extracted formats and enums."""
  # Field types and ordering
  def field_type(name, fmt):
    if name == 'op' and fmt in enums: return f'Annotated[BitField, {fmt}Op]'
    if name in ('opx', 'opy'): return 'Annotated[BitField, VOPDOp]'
    if name == 'vdsty': return 'VDSTYEnc'
    if name in ('vdst', 'vsrc1', 'vaddr', 'vdata', 'data', 'data0', 'data1', 'addr', 'vsrc0', 'vsrc2', 'vsrc3'): return 'VGPRField'
    if name in ('sdst', 'sbase', 'sdata', 'srsrc', 'ssamp'): return 'SGPRField'
    if name.startswith('ssrc') or name in ('saddr', 'soffset'): return 'SSrc'
    if name in ('src0', 'srcx0', 'srcy0') or name.startswith('src') and name[3:].isdigit(): return 'Src'
    if name.startswith('simm'): return 'SImm'
    if name == 'offset' or name.startswith('imm'): return 'Imm'
    return None
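  # (BitField, VGPRField, SGPRField, SSrc, Src, SImm, Imm, VDSTYEnc are assumed to be the
  # descriptor types the generated file pulls in via "from extra.assembly.amd.dsl import *")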
  field_priority = ['encoding', 'op', 'opx', 'opy', 'vdst', 'vdstx', 'vdsty', 'sdst', 'vdata', 'sdata', 'addr', 'vaddr', 'data', 'data0', 'data1',
                    'src0', 'srcx0', 'srcy0', 'vsrc0', 'ssrc0', 'src1', 'vsrc1', 'vsrcx1', 'vsrcy1', 'ssrc1', 'src2', 'vsrc2', 'src3', 'vsrc3',
                    'saddr', 'sbase', 'srsrc', 'ssamp', 'soffset', 'offset', 'simm16', 'en', 'target', 'attr', 'attr_chan',
                    'omod', 'neg', 'neg_hi', 'abs', 'clmp', 'opsel', 'opsel_hi', 'waitexp', 'wait_va',
                    'dmask', 'dim', 'seg', 'format', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe', 'unrm', 'done', 'row']
  def sort_fields(fields):
    order = {name: i for i, name in enumerate(field_priority)}
    return sorted(fields, key=lambda f: (order.get(f[0], 1000), f[2]))

  # Generate format classes
  lines = ["# autogenerated from AMD ISA PDF by pdf.py - do not edit", "# ruff: noqa: F401,F403",
           "from typing import Annotated",
           "from extra.assembly.amd.dsl import *",
           f"from extra.assembly.amd.autogen.{arch}.enum import *", "import functools", ""]
  for fmt_name, fields in sorted(formats.items()):
    lines.append(f"class {fmt_name}(Inst):")
    for name, hi, lo in sort_fields(fields):
      bits_str = f"bits[{hi}:{lo}]" if hi != lo else f"bits[{hi}]"
      if name == 'encoding' and fmt_name in encodings: lines.append(f"  encoding = {bits_str} == 0b{encodings[fmt_name]}")
      else:
        ftype = field_type(name, fmt_name)
        lines.append(f"  {name}{f':{ftype}' if ftype else ''} = {bits_str}")
    lines.append("")
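
  # A sketch of one generated format class (field layout illustrative):
  #   class SOP2(Inst):
  #     encoding = bits[31:30] == 0b10
  #     op:Annotated[BitField, SOP2Op] = bits[29:23]
  #     sdst:SGPRField = bits[22:16]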

  # Generate instruction helpers
  lines.append("# instruction helpers")
  for fmt_name, ops in sorted(enums.items()):
    seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=1"}.get(fmt_name, "")
    tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt_name, f"{fmt_name}, {fmt_name}Op")
    suffix = "_e32" if fmt_name in ("VOP1", "VOP2", "VOPC") else "_e64" if fmt_name == "VOP3" and len(ops) > 0 else ""
    if fmt_name in formats or fmt_name in ("GLOBAL", "SCRATCH"):
      for op_val, name in sorted(ops.items()):
        fn_suffix = suffix if fmt_name != "VOP3" or op_val < 512 else ""
        lines.append(f"{name.lower()}{fn_suffix} = functools.partial({tgt}.{name}{seg})")

  with open(path, "w") as f:
    f.write("\n".join(lines))
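
# Each helper line pairs a lowercase mnemonic with its format class and opcode, e.g.
# "s_add_u32 = functools.partial(SOP2, SOP2Op.S_ADD_U32)" (a sketch of the output).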

def write_pcode(pcode: dict[tuple[str, int], str], enums: dict[str, dict[int, str]], arch: str, path: str):
  """Write str_pcode.py file from extracted pseudocode."""
  # Group pseudocode by enum class
  by_enum: dict[str, list[tuple[str, int, str]]] = {}
  for fmt_name, ops in enums.items():
    for opcode, name in ops.items():
      if (name, opcode) in pcode: by_enum.setdefault(f"{fmt_name}Op", []).append((name, opcode, pcode[(name, opcode)]))
  # Generate file
  enum_names = sorted(by_enum.keys())
  lines = ["# autogenerated by pdf.py - do not edit", "# to regenerate: python -m extra.assembly.amd.pdf",
           "# ruff: noqa: E501", f"from extra.assembly.amd.autogen.{arch}.enum import {', '.join(enum_names)}", ""]
  for enum_name in enum_names:
    lines.append(f"{enum_name}_PCODE = {{")
    for name, opcode, code in sorted(by_enum[enum_name], key=lambda x: x[1]):
      lines.append(f"  {enum_name}.{name}: {code!r},")
    lines.append("}\n")
  lines.append(f"PSEUDOCODE_STRINGS = {{{', '.join(f'{e}: {e}_PCODE' for e in enum_names)}}}")
  with open(path, "w") as f:
    f.write("\n".join(lines))
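
# A sketch of the generated str_pcode.py:
#   SOP2Op_PCODE = {
#     SOP2Op.S_ADD_U32: "...",
#   }
#   PSEUDOCODE_STRINGS = {SOP2Op: SOP2Op_PCODE}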

if __name__ == "__main__":
  import pathlib
  for arch, url in PDF_URLS.items():
    print(f"Processing {arch}...")
    pages = extract(url)
    tables = extract_tables(pages)
    enums = extract_enums(tables)
    formats, encodings = extract_ins(tables)
    pcode = extract_pcode(pages, enums)
    # Fix known PDF errors
    if arch == 'rdna3':
      fixes = {'SOPP': {8: 'S_WAITCNT_DEPCTR', 58: 'S_TTRACEDATA', 59: 'S_TTRACEDATA_IMM'},
               'SOPK': {22: 'S_SUBVECTOR_LOOP_BEGIN', 23: 'S_SUBVECTOR_LOOP_END'},
               'SMEM': {34: 'S_ATC_PROBE', 35: 'S_ATC_PROBE_BUFFER'},
               'DS': {24: 'DS_GWS_SEMA_RELEASE_ALL', 25: 'DS_GWS_INIT', 26: 'DS_GWS_SEMA_V', 27: 'DS_GWS_SEMA_BR', 28: 'DS_GWS_SEMA_P', 29: 'DS_GWS_BARRIER'},
               'FLAT': {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}}
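      # merge_dicts (from tinygrad.helpers) asserts that duplicate keys agree, so these
      # fixes can only fill opcodes missing from the PDF; a conflict would assert.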
      for fmt, ops in fixes.items(): enums[fmt] = merge_dicts([enums[fmt], ops])
    if arch in ('rdna3', 'rdna4'):
      # RDNA SMEM: PDF says DLC=[14], GLC=[16] but hardware uses DLC=[13], GLC=[14]
      if 'SMEM' in formats:
        formats['SMEM'] = [(n, 13 if n == 'dlc' else 14 if n == 'glc' else h, 13 if n == 'dlc' else 14 if n == 'glc' else l)
                           for n, h, l in formats['SMEM']]
    if arch == 'cdna':
      # CDNA DS: PDF is missing the GDS field (bit 16)
      if 'DS' in formats and not any(n == 'gds' for n, _, _ in formats['DS']):
        formats['DS'].append(('gds', 16, 16))
      # CDNA DPP/SDWA: PDF only documents modifier fields (bits[63:32]), need to add VOP overlay fields (bits[31:0])
      vop_overlay = [('encoding', 8, 0), ('vop_op', 16, 9), ('vdst', 24, 17), ('vop2_op', 31, 25)]
      if 'DPP' in formats and not any(n == 'encoding' for n, _, _ in formats['DPP']):
        formats['DPP'] = vop_overlay + [('bc' if n == 'bound_ctrl' else n, h, l) for n, h, l in formats['DPP']]
        encodings['DPP'] = '11111010'
      if 'SDWA' in formats and not any(n == 'encoding' for n, _, _ in formats['SDWA']):
        formats['SDWA'] = vop_overlay + [(n, h, l) for n, h, l in formats['SDWA']]
        encodings['SDWA'] = '11111001'
    base = pathlib.Path(__file__).parent / "autogen" / arch
    write_enums(enums, arch, base / "enum.py")
    write_ins(formats, encodings, enums, arch, base / "ins.py")
    write_pcode(pcode, enums, arch, base / "str_pcode.py")
    print(f"  {len(tables)} tables, {len(pcode)} pcode -> {base}")