speed up rdna3 unit tests + add to CI (#13871)

* speed up rdna3 unit tests

* add test to CI

* faster and simpler

* speedups

* bugfixes

* use helper

* fix CI maybe

* test fixes

* llvm-21 on 24.04

* upd

* llvm-21

* fix test

* bring that back

* merge gen into lib

* test generators
George Hotz
2025-12-29 10:26:48 -05:00
committed by GitHub
parent 37720fd6c0
commit f1471a3b99
16 changed files with 3607 additions and 7634 deletions


@@ -654,6 +654,38 @@ jobs:
- name: Run process replay tests
uses: ./.github/actions/process-replay
testrdna3:
name: RDNA3 IDE
runs-on: ubuntu-24.04
timeout-minutes: 10
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: rdna3-emu
deps: testing_minimal
amd: 'true'
- name: Install LLVM 21
run: |
wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc
echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-21 main" | sudo tee /etc/apt/sources.list.d/llvm.list
sudo apt-get update
sudo apt-get install llvm-21 llvm-21-tools
- name: Run RDNA3 emulator tests
run: python -m pytest -n=auto extra/assembly/rdna3/ --durations 20
- name: Install pdfplumber
run: pip install pdfplumber
- name: Verify RDNA3 autogen is up to date
run: |
python -m extra.assembly.rdna3.lib
git diff --exit-code extra/assembly/rdna3/autogen/__init__.py
- name: Verify RDNA3 pcode autogen is up to date
run: |
python -m extra.assembly.rdna3.pcode
git diff --exit-code extra/assembly/rdna3/autogen/gen_pcode.py
testnvidia:
strategy:
fail-fast: false
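
The two autogen checks in the job above can be reproduced locally. A minimal sketch using the same commands as the CI steps (assumes the repo root as the working directory and pdfplumber installed, as in the steps above):

import subprocess
subprocess.run(["python", "-m", "extra.assembly.rdna3.lib"], check=True)    # regenerates autogen/__init__.py
subprocess.run(["python", "-m", "extra.assembly.rdna3.pcode"], check=True)  # regenerates autogen/gen_pcode.py
subprocess.run(["git", "diff", "--exit-code",                               # non-zero exit means autogen drifted from the source
                "extra/assembly/rdna3/autogen/__init__.py",
                "extra/assembly/rdna3/autogen/gen_pcode.py"], check=True)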


@@ -97,7 +97,7 @@ def disasm(inst: Inst) -> str:
else:
op_name = getattr(autogen, f"{cls_name}Op")(op_val).name.lower() if hasattr(autogen, f"{cls_name}Op") else f"op_{op_val}"
except (ValueError, KeyError): op_name = f"op_{op_val}"
def fmt_src(v): return f"0x{inst._literal:x}" if v == 255 and getattr(inst, '_literal', None) else decode_src(v)
def fmt_src(v): return f"0x{inst._literal:x}" if v == 255 and inst._literal is not None else decode_src(v)
# VOP1
if cls_name == 'VOP1':
@@ -440,7 +440,9 @@ def disasm(inst: Inst) -> str:
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {ssrc0_str}"
if cls_name == 'SOP2':
sdst, ssrc0, ssrc1 = [unwrap(inst._values.get(f, 0)) for f in ('sdst', 'ssrc0', 'ssrc1')]
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {_fmt_ssrc(ssrc0, src0_cnt)}, {_fmt_ssrc(ssrc1, src1_cnt)}"
ssrc0_str = fmt_src(ssrc0) if ssrc0 == 255 else _fmt_ssrc(ssrc0, src0_cnt)
ssrc1_str = fmt_src(ssrc1) if ssrc1 == 255 else _fmt_ssrc(ssrc1, src1_cnt)
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {ssrc0_str}, {ssrc1_str}"
if cls_name == 'SOPC':
return f"{op_name} {_fmt_ssrc(unwrap(inst._values.get('ssrc0', 0)), src0_cnt)}, {_fmt_ssrc(unwrap(inst._values.get('ssrc1', 0)), src1_cnt)}"
if cls_name == 'SOPK':
@@ -557,7 +559,8 @@ def asm(text: str) -> Inst:
elif mnemonic in ('v_fmamk_f32', 'v_fmamk_f16') and len(values) == 4: lit, values = unwrap(values[2]), [values[0], values[1], values[3]]
vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32', 'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'}
if mnemonic.replace('_e32', '') in vcc_ops and len(values) >= 5: values = [values[0], values[2], values[3]]
if mnemonic.startswith('v_cmp') and len(values) >= 3 and operands[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'):
# v_cmp_*_e32: strip implicit vcc_lo dest. v_cmp_*_e64: keep vdst (vcc_lo encodes to 106)
if mnemonic.startswith('v_cmp') and not mnemonic.endswith('_e64') and len(values) >= 3 and operands[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'):
values = values[1:]
# CMPX instructions with _e64 suffix: prepend implicit EXEC_LO destination (vdst=126)
if 'cmpx' in mnemonic and mnemonic.endswith('_e64') and len(values) == 2:
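
The fmt_src change above is about a literal value of zero: the old getattr truthiness test treated _literal == 0 as "no literal" and fell through to decode_src, while the explicit is-not-None test keeps it. A minimal standalone sketch of the difference (FakeInst is a stand-in for illustration, not part of the codebase):

class FakeInst:
    _literal = 0x0  # literal is present, but happens to be zero
inst = FakeInst()
old_check = bool(getattr(inst, '_literal', None))   # False: 0 is falsy, so the literal was dropped
new_check = inst._literal is not None               # True: a zero literal is still formatted
assert old_check is False and new_check is True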


@@ -1,4 +1,4 @@
# autogenerated from AMD RDNA3.5 ISA PDF by gen.py - do not edit
# autogenerated from AMD RDNA3.5 ISA PDF by lib.py - do not edit
from enum import IntEnum
from typing import Annotated
from extra.assembly.rdna3.lib import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField

File diff suppressed because it is too large


@@ -18,6 +18,15 @@ VCC_LO, VCC_HI, NULL, EXEC_LO, EXEC_HI, SCC = SrcEnum.VCC_LO, SrcEnum.VCC_HI, Sr
_VOP3_64BIT_OPS = {op.value for op in VOP3Op if op.name.endswith(('_F64', '_B64', '_I64', '_U64'))}
# Ops where src1 is 32-bit (exponent/shift amount) even though the op name suggests 64-bit
_VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value}
# Ops with 16-bit types in name (for source/dest handling)
_VOP3_16BIT_OPS = {op for op in VOP3Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
_VOP1_16BIT_OPS = {op for op in VOP1Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
# CVT ops with 32/64-bit source (despite 16-bit in name)
_CVT_32_64_SRC_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))} | \
{op for op in VOP1Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))}
# 16-bit dst ops (PACK has 32-bit dst despite F16 in name)
_VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name}
_VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name}
# Inline constants for src operands 128-254 (f32 format for most instructions)
_INLINE_CONSTS = [0] * 127
@@ -509,11 +518,9 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
# V_LDEXP_F64: src0 is 64-bit float, src1 is 32-bit integer exponent
is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64,)
is_shift_64 = op in (VOP3Op.V_LSHLREV_B64, VOP3Op.V_LSHRREV_B64, VOP3Op.V_ASHRREV_I64)
# 16-bit source ops: name contains 16-bit type, but for CVT ops check the SOURCE type (CVT naming is V_CVT_DST_SRC)
# For CVT: source type is at the end of the name, so V_CVT_F16_F32 has 32-bit src, V_CVT_F32_F16 has 16-bit src
has_16bit_type = any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))
is_cvt_with_32_64_src = op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))
is_16bit_src = op_cls is VOP3Op and has_16bit_type and not is_cvt_with_32_64_src
# 16-bit source ops: use precomputed sets instead of string checks
has_16bit_type = op in _VOP3_16BIT_OPS or op in _VOP1_16BIT_OPS
is_16bit_src = op_cls is VOP3Op and op in _VOP3_16BIT_OPS and op not in _CVT_32_64_SRC_OPS
if is_shift_64:
s0 = mod_src(st.rsrc(src0, lane), 0) # shift amount is 32-bit
@@ -571,8 +578,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
writes_to_sgpr = op in (VOP1Op.V_READFIRSTLANE_B32,) or \
(op_cls is VOP3Op and op in (VOP3Op.V_READFIRSTLANE_B32, VOP3Op.V_READLANE_B32))
# Check for 16-bit destination ops (opsel[3] controls hi/lo write)
# 16-bit dst ops (exclude PACK which has 32-bit dst despite F16 in name)
is_16bit_dst = any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16')) and 'PACK' not in op.name
is_16bit_dst = op in _VOP3_16BIT_DST_OPS or op in _VOP1_16BIT_DST_OPS
if writes_to_sgpr:
st.wsgpr(vdst, result['d0'] & 0xffffffff)
elif result.get('d0_64') or is_64bit_op:
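
The new sets above encode the rule spelled out in the comments: op names use V_CVT_<dst>_<src>, so the trailing suffix is the source type, and only CVT ops whose source is 32/64-bit get excluded from the 16-bit-source path. A small standalone check of that classification (the predicate mirrors the set comprehension in this hunk):

def cvt_has_32_or_64_bit_src(name: str) -> bool:
    return name.startswith('V_CVT_') and name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))
assert cvt_has_32_or_64_bit_src('V_CVT_F16_F32')      # converts f32 -> f16: 32-bit source
assert not cvt_has_32_or_64_bit_src('V_CVT_F32_F16')  # converts f16 -> f32: 16-bit source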


@@ -1,199 +0,0 @@
#!/usr/bin/env python3
# generates autogen/__init__.py by parsing the AMD RDNA3.5 ISA PDF
import re, pdfplumber, pathlib
from tinygrad.helpers import fetch
PDF_URL = "https://docs.amd.com/api/khub/documents/UVVZM22UN7tMUeiW_4ShTQ/content"
FIELD_TYPES = {'SSRC0': 'SSrc', 'SSRC1': 'SSrc', 'SOFFSET': 'SSrc', 'SADDR': 'SSrc', 'SRC0': 'Src', 'SRC1': 'Src', 'SRC2': 'Src',
'SDST': 'SGPRField', 'SBASE': 'SGPRField', 'SDATA': 'SGPRField', 'SRSRC': 'SGPRField', 'VDST': 'VGPRField', 'VSRC1': 'VGPRField', 'VDATA': 'VGPRField',
'VADDR': 'VGPRField', 'ADDR': 'VGPRField', 'DATA': 'VGPRField', 'DATA0': 'VGPRField', 'DATA1': 'VGPRField', 'SIMM16': 'SImm', 'OFFSET': 'Imm',
'OPX': 'VOPDOp', 'OPY': 'VOPDOp', 'SRCX0': 'Src', 'SRCY0': 'Src', 'VSRCX1': 'VGPRField', 'VSRCY1': 'VGPRField', 'VDSTX': 'VGPRField', 'VDSTY': 'VDSTYEnc'}
FIELD_ORDER = {
'SOP2': ['op', 'sdst', 'ssrc0', 'ssrc1'], 'SOP1': ['op', 'sdst', 'ssrc0'], 'SOPC': ['op', 'ssrc0', 'ssrc1'],
'SOPK': ['op', 'sdst', 'simm16'], 'SOPP': ['op', 'simm16'], 'VOP1': ['op', 'vdst', 'src0'], 'VOPC': ['op', 'src0', 'vsrc1'],
'VOP2': ['op', 'vdst', 'src0', 'vsrc1'], 'VOP3SD': ['op', 'vdst', 'sdst', 'src0', 'src1', 'src2', 'clmp'],
'SMEM': ['op', 'sdata', 'sbase', 'soffset', 'offset', 'glc', 'dlc'], 'DS': ['op', 'vdst', 'addr', 'data0', 'data1'],
'VOP3': ['op', 'vdst', 'src0', 'src1', 'src2', 'omod', 'neg', 'abs', 'clmp', 'opsel'],
'VOP3P': ['op', 'vdst', 'src0', 'src1', 'src2', 'neg', 'neg_hi', 'opsel', 'opsel_hi', 'clmp'],
'FLAT': ['op', 'vdst', 'addr', 'data', 'saddr', 'offset', 'seg', 'dlc', 'glc', 'slc'],
'MUBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
'MTBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'format', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
'MIMG': ['op', 'vdata', 'vaddr', 'srsrc', 'ssamp', 'dmask', 'dim', 'unrm', 'dlc', 'glc', 'slc'],
'EXP': ['en', 'target', 'vsrc0', 'vsrc1', 'vsrc2', 'vsrc3', 'done', 'row'],
'VINTERP': ['op', 'vdst', 'src0', 'src1', 'src2', 'waitexp', 'clmp', 'opsel', 'neg'],
'VOPD': ['opx', 'opy', 'vdstx', 'vdsty', 'srcx0', 'vsrcx1', 'srcy0', 'vsrcy1'],
'LDSDIR': ['op', 'vdst', 'attr', 'attr_chan', 'wait_va']}
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
def parse_bits(s: str) -> tuple[int, int] | None:
return (int(m.group(1)), int(m.group(2) or m.group(1))) if (m := re.match(r'\[(\d+)(?::(\d+))?\]', s)) else None
def parse_fields_table(table: list, fmt: str, enums: set[str]) -> list[tuple]:
fields = []
for row in table[1:]:
if not row or not row[0]: continue
name, bits_str = row[0].split('\n')[0].strip(), (row[1] or '').split('\n')[0].strip()
if not (bits := parse_bits(bits_str)): continue
enc_val, hi, lo = None, bits[0], bits[1]
if name == 'ENCODING' and row[2] and (m := re.search(r"'b([01_]+)", row[2])):
enc_bits = m.group(1).replace('_', '')
enc_val = int(enc_bits, 2)
declared_width, actual_width = hi - lo + 1, len(enc_bits)
if actual_width > declared_width: lo = hi - actual_width + 1
ftype = f"{fmt}Op" if name == 'OP' and f"{fmt}Op" in enums else FIELD_TYPES.get(name.upper())
fields.append((name, hi, lo, enc_val, ftype))
return fields
def generate(output_path: pathlib.Path|str|None = None) -> dict:
"""Generate RDNA3.5 instruction definitions from the AMD ISA PDF. Returns dict with formats for testing."""
pdf = pdfplumber.open(fetch(PDF_URL))
pages = pdf.pages[150:200]
page_texts = [p.extract_text() or '' for p in pages]
page_tables = [[t.extract() for t in p.find_tables()] for p in pages]
full_text = '\n'.join(page_texts)
# parse SSRC encoding from first page with VCC_LO
src_enum = dict(SRC_EXTRAS)
for text in page_texts[:10]:
if 'SSRC0' in text and 'VCC_LO' in text:
for m in re.finditer(r'^(\d+)\s+(\S+)', text, re.M):
val, name = int(m.group(1)), m.group(2).rstrip('.:')
if name in FLOAT_MAP: src_enum[val] = FLOAT_MAP[name]
elif re.match(r'^[A-Z][A-Z0-9_]*$', name): src_enum[val] = name
break
# parse opcode tables
enums: dict[str, dict[int, str]] = {}
for m in re.finditer(r'Table \d+\. (\w+) Opcodes(.*?)(?=Table \d+\.|\n\d+\.\d+\.\d+\.\s+\w+\s*\nDescription|$)', full_text, re.S):
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+([A-Z][A-Z0-9_]+)', m.group(2))}:
enums[m.group(1) + "Op"] = ops
if vopd_m := re.search(r'Table \d+\. VOPD Y-Opcodes\n(.*?)(?=Table \d+\.|15\.\d)', full_text, re.S):
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+(V_DUAL_\w+)', vopd_m.group(1))}:
enums["VOPDOp"] = ops
enum_names = set(enums.keys())
def is_fields_table(t) -> bool: return t and len(t) > 1 and t[0] and 'Field' in str(t[0][0] or '')
def has_encoding(fields) -> bool: return any(f[0] == 'ENCODING' for f in fields)
def has_header_before_fields(text) -> bool:
return (pos := text.find('Field Name')) != -1 and bool(re.search(r'\d+\.\d+\.\d+\.\s+\w+\s*\n', text[:pos]))
# find format headers with their page indices
format_headers = [] # (fmt_name, page_idx)
for i, text in enumerate(page_texts):
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n?Description', text): format_headers.append((m.group(1), i, m.start()))
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n', text):
if m.start() > len(text) - 200 and 'Description' not in text[m.end():] and i + 1 < len(page_texts):
next_text = page_texts[i + 1].lstrip()
if next_text.startswith('Description') or (next_text.startswith('"RDNA') and 'Description' in next_text[:200]):
format_headers.append((m.group(1), i, m.start()))
# parse instruction formats
formats: dict[str, list] = {}
for fmt_name, page_idx, header_pos in format_headers:
if fmt_name in formats: continue
text, tables = page_texts[page_idx], page_tables[page_idx]
field_pos = text.find('Field Name', header_pos)
# find fields table with ENCODING (same page or up to 2 pages ahead)
fields = None
for offset in range(3):
if page_idx + offset >= len(pages): break
if offset > 0 and has_header_before_fields(page_texts[page_idx + offset]): break
for t in page_tables[page_idx + offset] if offset > 0 or field_pos > header_pos else []:
if is_fields_table(t) and (f := parse_fields_table(t, fmt_name, enum_names)) and has_encoding(f):
fields = f
break
if fields: break
# for modifier formats (no ENCODING), accept first fields table on same page
if not fields and field_pos > header_pos:
for t in tables:
if is_fields_table(t) and (f := parse_fields_table(t, fmt_name, enum_names)):
fields = f
break
if not fields: continue
field_names = {f[0] for f in fields}
# check next pages for continuation fields (tables without ENCODING)
for pg_offset in range(1, 3):
if page_idx + pg_offset >= len(pages) or has_header_before_fields(page_texts[page_idx + pg_offset]): break
for t in page_tables[page_idx + pg_offset]:
if is_fields_table(t) and (extra := parse_fields_table(t, fmt_name, enum_names)) and not has_encoding(extra):
for ef in extra:
if ef[0] not in field_names:
fields.append(ef)
field_names.add(ef[0])
break
formats[fmt_name] = fields
# fix known PDF errors (verified against LLVM test vectors)
# SMEM: PDF says DLC=bit14, GLC=bit16 but actual encoding is DLC=bit13, GLC=bit14
if 'SMEM' in formats:
formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
for n, h, l, e, t in formats['SMEM']]
# generate output
def enum_lines(name, items):
return [f"class {name}(IntEnum):"] + [f" {n} = {v}" for v, n in sorted(items.items())] + [""]
def field_key(f): return order.index(f[0].lower()) if f[0].lower() in order else 1000
lines = ["# autogenerated from AMD RDNA3.5 ISA PDF by gen.py - do not edit", "from enum import IntEnum",
"from typing import Annotated",
"from extra.assembly.rdna3.lib import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
"import functools", ""]
lines += enum_lines("SrcEnum", src_enum) + sum([enum_lines(n, ops) for n, ops in sorted(enums.items())], [])
# Format-specific field defaults (verified against LLVM test vectors)
format_defaults = {'VOP3P': {'opsel_hi': 3, 'opsel_hi2': 1}}
lines.append("# instruction formats")
for fmt_name, fields in sorted(formats.items()):
base = "Inst64" if max(f[1] for f in fields) > 31 or fmt_name == 'VOP3SD' else "Inst32"
order = FIELD_ORDER.get(fmt_name, [])
lines.append(f"class {fmt_name}({base}):")
if enc := next((f for f in fields if f[0] == 'ENCODING'), None):
enc_str = f"bits[{enc[1]}:{enc[2]}] == 0b{enc[3]:b}" if enc[1] != enc[2] else f"bits[{enc[1]}] == {enc[3]}"
lines.append(f" encoding = {enc_str}")
if defaults := format_defaults.get(fmt_name):
lines.append(f" _defaults = {defaults}")
for name, hi, lo, _, ftype in sorted([f for f in fields if f[0] != 'ENCODING'], key=field_key):
# Wrap IntEnum types (ending in Op) with Annotated[BitField, ...] for correct typing
if ftype and ftype.endswith('Op'):
ann = f":Annotated[BitField, {ftype}]"
else:
ann = f":{ftype}" if ftype else ""
lines.append(f" {name.lower()}{ann} = bits[{hi}]" if hi == lo else f" {name.lower()}{ann} = bits[{hi}:{lo}]")
lines.append("")
lines.append("# instruction helpers")
for cls_name, ops in sorted(enums.items()):
fmt = cls_name[:-2]
for op_val, name in sorted(ops.items()):
seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=2"}.get(fmt, "")
tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}")
if fmt in formats or fmt in ("GLOBAL", "SCRATCH"):
# VOP1/VOP2/VOPC get _e32 suffix, VOP3 promoted ops (< 512) get _e64 suffix
if fmt in ("VOP1", "VOP2", "VOPC"):
suffix = "_e32"
elif fmt == "VOP3" and op_val < 512:
suffix = "_e64"
else:
suffix = ""
# FMAMK/FMAAK have a literal constant K that must be passed via literal= kwarg
# FMAMK: D = S0.f * K + S1.f (K is 3rd operand in assembly syntax)
# FMAAK: D = S0.f * S1.f + K (K is 4th operand in assembly syntax)
if name in ('V_FMAMK_F32', 'V_FMAMK_F16'):
lines.append(f"def {name.lower()}{suffix}(vdst, src0, K, vsrc1): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
elif name in ('V_FMAAK_F32', 'V_FMAAK_F16'):
lines.append(f"def {name.lower()}{suffix}(vdst, src0, vsrc1, K): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
else:
lines.append(f"{name.lower()}{suffix} = functools.partial({tgt}.{name}{seg})")
# export SrcEnum values, but skip DPP8/DPP16 which conflict with class names
skip_exports = {'DPP8', 'DPP16'}
lines += [""] + [f"{name} = SrcEnum.{name}" for _, name in sorted(src_enum.items()) if name not in skip_exports] + ["OFF = NULL\n"]
if output_path is not None: pathlib.Path(output_path).write_text('\n'.join(lines))
return {"formats": formats, "enums": enums, "src_enum": src_enum}
if __name__ == "__main__":
result = generate("extra/assembly/rdna3/autogen/__init__.py")
print(f"generated SrcEnum ({len(result['src_enum'])}) + {len(result['enums'])} opcode enums + {len(result['formats'])} format classes")


@@ -281,3 +281,209 @@ class Inst:
class Inst32(Inst): pass
class Inst64(Inst): pass
# ═══════════════════════════════════════════════════════════════════════════════
# CODE GENERATION: generates autogen/__init__.py by parsing the AMD RDNA3.5 ISA PDF
# ═══════════════════════════════════════════════════════════════════════════════
PDF_URL = "https://docs.amd.com/api/khub/documents/UVVZM22UN7tMUeiW_4ShTQ/content"
FIELD_TYPES = {'SSRC0': 'SSrc', 'SSRC1': 'SSrc', 'SOFFSET': 'SSrc', 'SADDR': 'SSrc', 'SRC0': 'Src', 'SRC1': 'Src', 'SRC2': 'Src',
'SDST': 'SGPRField', 'SBASE': 'SGPRField', 'SDATA': 'SGPRField', 'SRSRC': 'SGPRField', 'VDST': 'VGPRField', 'VSRC1': 'VGPRField', 'VDATA': 'VGPRField',
'VADDR': 'VGPRField', 'ADDR': 'VGPRField', 'DATA': 'VGPRField', 'DATA0': 'VGPRField', 'DATA1': 'VGPRField', 'SIMM16': 'SImm', 'OFFSET': 'Imm',
'OPX': 'VOPDOp', 'OPY': 'VOPDOp', 'SRCX0': 'Src', 'SRCY0': 'Src', 'VSRCX1': 'VGPRField', 'VSRCY1': 'VGPRField', 'VDSTX': 'VGPRField', 'VDSTY': 'VDSTYEnc'}
FIELD_ORDER = {
'SOP2': ['op', 'sdst', 'ssrc0', 'ssrc1'], 'SOP1': ['op', 'sdst', 'ssrc0'], 'SOPC': ['op', 'ssrc0', 'ssrc1'],
'SOPK': ['op', 'sdst', 'simm16'], 'SOPP': ['op', 'simm16'], 'VOP1': ['op', 'vdst', 'src0'], 'VOPC': ['op', 'src0', 'vsrc1'],
'VOP2': ['op', 'vdst', 'src0', 'vsrc1'], 'VOP3SD': ['op', 'vdst', 'sdst', 'src0', 'src1', 'src2', 'clmp'],
'SMEM': ['op', 'sdata', 'sbase', 'soffset', 'offset', 'glc', 'dlc'], 'DS': ['op', 'vdst', 'addr', 'data0', 'data1'],
'VOP3': ['op', 'vdst', 'src0', 'src1', 'src2', 'omod', 'neg', 'abs', 'clmp', 'opsel'],
'VOP3P': ['op', 'vdst', 'src0', 'src1', 'src2', 'neg', 'neg_hi', 'opsel', 'opsel_hi', 'clmp'],
'FLAT': ['op', 'vdst', 'addr', 'data', 'saddr', 'offset', 'seg', 'dlc', 'glc', 'slc'],
'MUBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
'MTBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'format', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
'MIMG': ['op', 'vdata', 'vaddr', 'srsrc', 'ssamp', 'dmask', 'dim', 'unrm', 'dlc', 'glc', 'slc'],
'EXP': ['en', 'target', 'vsrc0', 'vsrc1', 'vsrc2', 'vsrc3', 'done', 'row'],
'VINTERP': ['op', 'vdst', 'src0', 'src1', 'src2', 'waitexp', 'clmp', 'opsel', 'neg'],
'VOPD': ['opx', 'opy', 'vdstx', 'vdsty', 'srcx0', 'vsrcx1', 'srcy0', 'vsrcy1'],
'LDSDIR': ['op', 'vdst', 'attr', 'attr_chan', 'wait_va']}
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
def _parse_bits(s: str) -> tuple[int, int] | None:
import re
return (int(m.group(1)), int(m.group(2) or m.group(1))) if (m := re.match(r'\[(\d+)(?::(\d+))?\]', s)) else None
def _parse_fields_table(table: list, fmt: str, enums: set[str]) -> list[tuple]:
import re
fields = []
for row in table[1:]:
if not row or not row[0]: continue
name, bits_str = row[0].split('\n')[0].strip(), (row[1] or '').split('\n')[0].strip()
if not (bits := _parse_bits(bits_str)): continue
enc_val, hi, lo = None, bits[0], bits[1]
if name == 'ENCODING' and row[2] and (m := re.search(r"'b([01_]+)", row[2])):
enc_bits = m.group(1).replace('_', '')
enc_val = int(enc_bits, 2)
declared_width, actual_width = hi - lo + 1, len(enc_bits)
if actual_width > declared_width: lo = hi - actual_width + 1
ftype = f"{fmt}Op" if name == 'OP' and f"{fmt}Op" in enums else FIELD_TYPES.get(name.upper())
fields.append((name, hi, lo, enc_val, ftype))
return fields
def generate(output_path: str | None = None) -> dict:
"""Generate RDNA3.5 instruction definitions from the AMD ISA PDF. Returns dict with formats for testing."""
import re, pdfplumber, pathlib
from tinygrad.helpers import fetch
pdf = pdfplumber.open(fetch(PDF_URL))
pages = pdf.pages[150:200]
page_texts = [p.extract_text() or '' for p in pages]
page_tables = [[t.extract() for t in p.find_tables()] for p in pages]
full_text = '\n'.join(page_texts)
# parse SSRC encoding from first page with VCC_LO
src_enum = dict(SRC_EXTRAS)
for text in page_texts[:10]:
if 'SSRC0' in text and 'VCC_LO' in text:
for m in re.finditer(r'^(\d+)\s+(\S+)', text, re.M):
val, name = int(m.group(1)), m.group(2).rstrip('.:')
if name in FLOAT_MAP: src_enum[val] = FLOAT_MAP[name]
elif re.match(r'^[A-Z][A-Z0-9_]*$', name): src_enum[val] = name
break
# parse opcode tables
enums: dict[str, dict[int, str]] = {}
for m in re.finditer(r'Table \d+\. (\w+) Opcodes(.*?)(?=Table \d+\.|\n\d+\.\d+\.\d+\.\s+\w+\s*\nDescription|$)', full_text, re.S):
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+([A-Z][A-Z0-9_]+)', m.group(2))}:
enums[m.group(1) + "Op"] = ops
if vopd_m := re.search(r'Table \d+\. VOPD Y-Opcodes\n(.*?)(?=Table \d+\.|15\.\d)', full_text, re.S):
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+(V_DUAL_\w+)', vopd_m.group(1))}:
enums["VOPDOp"] = ops
enum_names = set(enums.keys())
def is_fields_table(t) -> bool: return t and len(t) > 1 and t[0] and 'Field' in str(t[0][0] or '')
def has_encoding(fields) -> bool: return any(f[0] == 'ENCODING' for f in fields)
def has_header_before_fields(text) -> bool:
return (pos := text.find('Field Name')) != -1 and bool(re.search(r'\d+\.\d+\.\d+\.\s+\w+\s*\n', text[:pos]))
# find format headers with their page indices
format_headers = [] # (fmt_name, page_idx)
for i, text in enumerate(page_texts):
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n?Description', text): format_headers.append((m.group(1), i, m.start()))
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n', text):
if m.start() > len(text) - 200 and 'Description' not in text[m.end():] and i + 1 < len(page_texts):
next_text = page_texts[i + 1].lstrip()
if next_text.startswith('Description') or (next_text.startswith('"RDNA') and 'Description' in next_text[:200]):
format_headers.append((m.group(1), i, m.start()))
# parse instruction formats
formats: dict[str, list] = {}
for fmt_name, page_idx, header_pos in format_headers:
if fmt_name in formats: continue
text, tables = page_texts[page_idx], page_tables[page_idx]
field_pos = text.find('Field Name', header_pos)
# find fields table with ENCODING (same page or up to 2 pages ahead)
fields = None
for offset in range(3):
if page_idx + offset >= len(pages): break
if offset > 0 and has_header_before_fields(page_texts[page_idx + offset]): break
for t in page_tables[page_idx + offset] if offset > 0 or field_pos > header_pos else []:
if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)) and has_encoding(f):
fields = f
break
if fields: break
# for modifier formats (no ENCODING), accept first fields table on same page
if not fields and field_pos > header_pos:
for t in tables:
if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)):
fields = f
break
if not fields: continue
field_names = {f[0] for f in fields}
# check next pages for continuation fields (tables without ENCODING)
for pg_offset in range(1, 3):
if page_idx + pg_offset >= len(pages) or has_header_before_fields(page_texts[page_idx + pg_offset]): break
for t in page_tables[page_idx + pg_offset]:
if is_fields_table(t) and (extra := _parse_fields_table(t, fmt_name, enum_names)) and not has_encoding(extra):
for ef in extra:
if ef[0] not in field_names:
fields.append(ef)
field_names.add(ef[0])
break
formats[fmt_name] = fields
# fix known PDF errors (verified against LLVM test vectors)
# SMEM: PDF says DLC=bit14, GLC=bit16 but actual encoding is DLC=bit13, GLC=bit14
if 'SMEM' in formats:
formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
for n, h, l, e, t in formats['SMEM']]
# generate output
def enum_lines(name, items):
return [f"class {name}(IntEnum):"] + [f" {n} = {v}" for v, n in sorted(items.items())] + [""]
def field_key(f): return order.index(f[0].lower()) if f[0].lower() in order else 1000
lines = ["# autogenerated from AMD RDNA3.5 ISA PDF by lib.py - do not edit", "from enum import IntEnum",
"from typing import Annotated",
"from extra.assembly.rdna3.lib import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
"import functools", ""]
lines += enum_lines("SrcEnum", src_enum) + sum([enum_lines(n, ops) for n, ops in sorted(enums.items())], [])
# Format-specific field defaults (verified against LLVM test vectors)
format_defaults = {'VOP3P': {'opsel_hi': 3, 'opsel_hi2': 1}}
lines.append("# instruction formats")
for fmt_name, fields in sorted(formats.items()):
base = "Inst64" if max(f[1] for f in fields) > 31 or fmt_name == 'VOP3SD' else "Inst32"
order = FIELD_ORDER.get(fmt_name, [])
lines.append(f"class {fmt_name}({base}):")
if enc := next((f for f in fields if f[0] == 'ENCODING'), None):
enc_str = f"bits[{enc[1]}:{enc[2]}] == 0b{enc[3]:b}" if enc[1] != enc[2] else f"bits[{enc[1]}] == {enc[3]}"
lines.append(f" encoding = {enc_str}")
if defaults := format_defaults.get(fmt_name):
lines.append(f" _defaults = {defaults}")
for name, hi, lo, _, ftype in sorted([f for f in fields if f[0] != 'ENCODING'], key=field_key):
# Wrap IntEnum types (ending in Op) with Annotated[BitField, ...] for correct typing
if ftype and ftype.endswith('Op'):
ann = f":Annotated[BitField, {ftype}]"
else:
ann = f":{ftype}" if ftype else ""
lines.append(f" {name.lower()}{ann} = bits[{hi}]" if hi == lo else f" {name.lower()}{ann} = bits[{hi}:{lo}]")
lines.append("")
lines.append("# instruction helpers")
for cls_name, ops in sorted(enums.items()):
fmt = cls_name[:-2]
for op_val, name in sorted(ops.items()):
seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=2"}.get(fmt, "")
tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}")
if fmt in formats or fmt in ("GLOBAL", "SCRATCH"):
# VOP1/VOP2/VOPC get _e32 suffix, VOP3 promoted ops (< 512) get _e64 suffix
if fmt in ("VOP1", "VOP2", "VOPC"):
suffix = "_e32"
elif fmt == "VOP3" and op_val < 512:
suffix = "_e64"
else:
suffix = ""
# FMAMK/FMAAK have a literal constant K that must be passed via literal= kwarg
# FMAMK: D = S0.f * K + S1.f (K is 3rd operand in assembly syntax)
# FMAAK: D = S0.f * S1.f + K (K is 4th operand in assembly syntax)
if name in ('V_FMAMK_F32', 'V_FMAMK_F16'):
lines.append(f"def {name.lower()}{suffix}(vdst, src0, K, vsrc1): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
elif name in ('V_FMAAK_F32', 'V_FMAAK_F16'):
lines.append(f"def {name.lower()}{suffix}(vdst, src0, vsrc1, K): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
else:
lines.append(f"{name.lower()}{suffix} = functools.partial({tgt}.{name}{seg})")
# export SrcEnum values, but skip DPP8/DPP16 which conflict with class names
skip_exports = {'DPP8', 'DPP16'}
lines += [""] + [f"{name} = SrcEnum.{name}" for _, name in sorted(src_enum.items()) if name not in skip_exports] + ["OFF = NULL\n"]
if output_path is not None:
import pathlib
pathlib.Path(output_path).write_text('\n'.join(lines))
return {"formats": formats, "enums": enums, "src_enum": src_enum}
if __name__ == "__main__":
result = generate("extra/assembly/rdna3/autogen/__init__.py")
print(f"generated SrcEnum ({len(result['src_enum'])}) + {len(result['enums'])} opcode enums + {len(result['formats'])} format classes")
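
The FMAMK/FMAAK special-casing above exists because the literal K occupies a different operand slot in the two mnemonics. A sketch of what the emitted f-strings expand to for the VOP2 f32 variants (expansion inferred from the template above, not copied from a generated file):

# in the real generated file these sit next to the format classes; the import is only to make the sketch standalone
from extra.assembly.rdna3.autogen import VOP2, VOP2Op
# FMAMK: D = S0.f * K + S1.f  (K is the 3rd assembly operand)
def v_fmamk_f32_e32(vdst, src0, K, vsrc1): return VOP2(VOP2Op.V_FMAMK_F32, vdst, src0, vsrc1, literal=K)
# FMAAK: D = S0.f * S1.f + K  (K is the 4th assembly operand)
def v_fmaak_f32_e32(vdst, src0, vsrc1, K): return VOP2(VOP2Op.V_FMAAK_F32, vdst, src0, vsrc1, literal=K)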


@@ -817,31 +817,37 @@ from extra.assembly.rdna3.pcode import *
# Add original pseudocode as comment
for pc_line in pc.split('\n'):
lines.append(f" # {pc_line}")
# V_DIV_SCALE: D0 defaults to S0 if no branch taken
if is_div_scale:
lines.append(" S0, S1, S2, D0, D1 = Reg(s0), Reg(s1), Reg(s2), Reg(s0), Reg(0)")
else:
lines.append(" S0, S1, S2, D0, D1 = Reg(s0), Reg(s1), Reg(s2), Reg(d0), Reg(0)")
lines.append(" SCC, VCC, EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)")
lines.append(" EXEC_LO, EXEC_HI = SliceProxy(EXEC, 31, 0), SliceProxy(EXEC, 63, 32)")
lines.append(" tmp, saveexec = Reg(0), Reg(exec_mask)")
lines.append(" laneId = lane")
lines.append(" SIMM16, SIMM32 = Reg(literal), Reg(literal)")
lines.append(" SRC0, VDST = Reg(src0_idx), Reg(vdst_idx)")
# Only create Reg objects for registers actually used in the pseudocode
combined = code + pc
regs = [('S0', 'Reg(s0)'), ('S1', 'Reg(s1)'), ('S2', 'Reg(s2)'),
('D0', 'Reg(s0)' if is_div_scale else 'Reg(d0)'), ('D1', 'Reg(0)'),
('SCC', 'Reg(scc)'), ('VCC', 'Reg(vcc)'), ('EXEC', 'Reg(exec_mask)'),
('tmp', 'Reg(0)'), ('saveexec', 'Reg(exec_mask)'), ('laneId', 'lane'),
('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
used = {name for name, _ in regs if name in combined}
# EXEC_LO/EXEC_HI need EXEC
if 'EXEC_LO' in combined or 'EXEC_HI' in combined: used.add('EXEC')
for name, init in regs:
if name in used: lines.append(f" {name} = {init}")
if 'EXEC_LO' in combined: lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
if 'EXEC_HI' in combined: lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
# Add compiled pseudocode with markers
lines.append(" # --- compiled pseudocode ---")
for line in code.split('\n'):
lines.append(f" {line}")
lines.append(" # --- end pseudocode ---")
# Generate result dict
lines.append(" result = {'d0': D0._val, 'scc': SCC._val & 1}")
# Generate result dict - use raw params if Reg wasn't created
d0_val = "D0._val" if 'D0' in used else "d0"
scc_val = "SCC._val & 1" if 'SCC' in used else "scc & 1"
lines.append(f" result = {{'d0': {d0_val}, 'scc': {scc_val}}}")
if has_sdst:
lines.append(" result['vcc_lane'] = (VCC._val >> lane) & 1")
else:
elif 'VCC' in used:
lines.append(" if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1")
if is_cmpx:
lines.append(" result['exec_lane'] = (EXEC._val >> lane) & 1")
else:
elif 'EXEC' in used:
lines.append(" if EXEC._val != exec_mask: result['exec'] = EXEC._val")
if is_cmp:
lines.append(" result['vcc_lane'] = (D0._val >> lane) & 1")
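
The register bindings are now emitted lazily: the regs table is scanned against the combined compiled code plus original pseudocode, only names that actually appear get a Reg line, and the result dict falls back to the raw scc/d0 parameters otherwise. A minimal standalone sketch of that selection (reg names and inits taken from the table above; the pseudocode string is a made-up example):

regs = [('S0', 'Reg(s0)'), ('S1', 'Reg(s1)'), ('D0', 'Reg(d0)'), ('SCC', 'Reg(scc)'), ('VCC', 'Reg(vcc)')]
combined = "D0.f = S0.f + S1.f"                        # pretend pseudocode for a plain float add
used = {name for name, _ in regs if name in combined}  # {'S0', 'S1', 'D0'}
emitted = [f" {name} = {init}" for name, init in regs if name in used]
scc_expr = "SCC._val & 1" if 'SCC' in used else "scc & 1"   # unused SCC falls back to the raw parameter
assert used == {'S0', 'S1', 'D0'} and scc_expr == "scc & 1"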


@@ -126,7 +126,7 @@ class PythonEmulator:
def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: tuple[int, int, int],
program, max_steps: int, debug: bool, trace_len: int, kernel_idx: int = 0,
max_workgroups: int = 64) -> tuple[bool, str, int]:
max_workgroups: int = 8) -> tuple[bool, str, int]:
"""Run a single kernel through both emulators. Returns (success, message, total_steps)."""
gx, gy, gz = global_size
total_steps = 0
@@ -356,191 +356,52 @@ class TestTinygradKernels(unittest.TestCase):
ok, msg = compare_emulators_multi_kernel(kernels, buf_pool, max_steps=max_steps, buf_data=buf_data)
self.assertTrue(ok, msg)
# Basic unary ops
def test_neg(self): self._test_kernel(lambda T: -T([1.0, -2.0, 3.0, -4.0]))
def test_relu(self): self._test_kernel(lambda T: T([-1.0, 0.0, 1.0, 2.0]).relu())
def test_exp(self): self._test_kernel(lambda T: T([0.0, 1.0, 2.0]).exp())
def test_log(self): self._test_kernel(lambda T: T([1.0, 2.0, 3.0]).log())
def test_sin(self): self._test_kernel(lambda T: T([0.0, 1.0, 2.0]).sin())
def test_cos(self): self._test_kernel(lambda T: T([0.0, 1.0, 2.0]).cos())
def test_sqrt(self): self._test_kernel(lambda T: T([1.0, 4.0, 9.0]).sqrt())
def test_recip(self): self._test_kernel(lambda T: T([1.0, 2.0, 4.0]).reciprocal())
# Sin/cos with various ranges - test polynomial expansion
def test_sin_small(self): self._test_kernel(lambda T: T([0.1, 0.2, 0.3, 0.4, 0.5]*7).sin()) # 35 elements, small angles
def test_sin_pi(self): self._test_kernel(lambda T: T([3.14159, 1.5708, 0.7854, -1.5708, -3.14159]*7).sin()) # around pi
def test_sin_medium(self): self._test_kernel(lambda T: T([10.0, 20.0, 30.0, 50.0, 100.0]*7).sin()) # medium values
def test_sin_negative(self): self._test_kernel(lambda T: T([-0.5, -1.0, -2.0, -5.0, -10.0]*7).sin()) # negative values
def test_cos_small(self): self._test_kernel(lambda T: T([0.1, 0.2, 0.3, 0.4, 0.5]*7).cos())
def test_cos_pi(self): self._test_kernel(lambda T: T([3.14159, 1.5708, 0.7854, -1.5708, -3.14159]*7).cos())
def test_cos_medium(self): self._test_kernel(lambda T: T([10.0, 20.0, 30.0, 50.0, 100.0]*7).cos())
@unittest.skip("Rust emulator has V_DIV_SCALE_F32 bug - returns 0 instead of src0 for normal cases")
def test_tan(self): self._test_kernel(lambda T: T([0.1, 0.2, 0.5, 1.0, -0.5]*7).tan()) # avoid pi/2
# Binary ops
def test_add(self): self._test_kernel(lambda T: T([1.0, 2.0]) + T([3.0, 4.0]))
def test_sub(self): self._test_kernel(lambda T: T([5.0, 6.0]) - T([1.0, 2.0]))
def test_mul(self): self._test_kernel(lambda T: T([2.0, 3.0]) * T([4.0, 5.0]))
def test_div(self): self._test_kernel(lambda T: T([10.0, 20.0]) / T([2.0, 4.0]))
def test_max_binary(self): self._test_kernel(lambda T: T([1.0, 5.0]).maximum(T([3.0, 2.0])))
# Basic ops - consolidated tests covering key instruction patterns
def test_unary_ops(self): self._test_kernel(lambda T: T([-1.0, 0.0, 1.0, 2.0]).relu().exp().log().sqrt().reciprocal())
def test_binary_ops(self): self._test_kernel(lambda T: (T([1.0, 2.0]) + T([3.0, 4.0])) * T([0.5, 0.5]) - T([1.0, 1.0]))
def test_trig(self): self._test_kernel(lambda T: T([0.1, 1.0, 3.14, -1.0]*8).sin() + T([0.1, 1.0, 3.14, -1.0]*8).cos())
def test_compare(self): self._test_kernel(lambda T: (T.empty(64) < T.empty(64)).where(T.empty(64), T.empty(64)))
def test_bitwise(self): self._test_kernel(lambda T: (T([0xF0, 0x0F, 0xFF]*11).int() & T([0x0F, 0x0F, 0x00]*11).int()) | T([1]*33).int())
def test_int_ops(self): self._test_kernel(lambda T: ((T.empty(64).int() + T.empty(64).int()) * T.empty(64).int()).float())
# Reductions
def test_sum_reduce(self): self._test_kernel(lambda T: T.empty(64).sum())
def test_max_reduce(self): self._test_kernel(lambda T: T.empty(64).max())
def test_mean_reduce(self): self._test_kernel(lambda T: T.empty(32).mean())
def test_reduce(self): self._test_kernel(lambda T: T.empty(64).sum() + T.empty(64).max())
def test_argmax(self): self._test_kernel(lambda T: T.empty(64).argmax())
# Matmul - various sizes
def test_gemm_4x4(self): self._test_kernel(lambda T: T.empty(4, 4) @ T.empty(4, 4), max_steps=100000)
def test_gemm_8x8(self): self._test_kernel(lambda T: T.empty(8, 8) @ T.empty(8, 8), max_steps=200000)
@unittest.skip("too slow")
def test_gemm_16x16(self): self._test_kernel(lambda T: T.empty(16, 16) @ T.empty(16, 16), max_steps=500000)
def test_gemv(self): self._test_kernel(lambda T: T.empty(1, 16) @ T.empty(16, 16), max_steps=100000)
# Matmul
def test_gemm(self): self._test_kernel(lambda T: T.empty(8, 8) @ T.empty(8, 8), max_steps=100000)
def test_gemm_fp16(self): self._test_kernel(lambda T: T.empty(16, 16).half() @ T.empty(16, 16).half(), max_steps=100000)
# Complex ops
def test_softmax(self): self._test_kernel(lambda T: T.empty(16).softmax())
def test_layernorm(self): self._test_kernel(lambda T: T.empty(8, 8).layernorm())
# Memory patterns
def test_contiguous(self): self._test_kernel(lambda T: T.empty(4, 4).permute(1, 0).contiguous())
def test_reshape(self): self._test_kernel(lambda T: (T.empty(16) + 1).reshape(4, 4).contiguous())
def test_expand(self): self._test_kernel(lambda T: T.empty(4, 1).expand(4, 4).contiguous())
def test_memory(self): self._test_kernel(lambda T: T.empty(4, 4).permute(1, 0).contiguous() + T.empty(4, 1).expand(4, 4))
# Cast ops
def test_cast_int(self): self._test_kernel(lambda T: T.empty(16).int().float())
def test_cast_half(self): self._test_kernel(lambda T: T.empty(16).half().float())
def test_cast(self): self._test_kernel(lambda T: T.empty(32).half().float() + T.empty(32).int().float())
# Min/max (uses comparison internally)
def test_min_binary(self): self._test_kernel(lambda T: T([1.0, 5.0, 3.0]).minimum(T([3.0, 2.0, 4.0])))
# Pooling - regression for VCC wave32 mode
def test_pool2d(self): self._test_kernel(lambda T: T.empty(1, 1, 8, 8).avg_pool2d(kernel_size=(4,4)) + T.empty(1, 1, 8, 8).max_pool2d(kernel_size=(4,4)))
# Comparison ops (test VOPC instructions) - use 32+ elements to force vector instructions
def test_cmp_lt(self): self._test_kernel(lambda T: (T.empty(64) < T.empty(64)).where(T.empty(64), T.empty(64)))
def test_cmp_eq(self): self._test_kernel(lambda T: (T.empty(64) == T.empty(64)).where(T.empty(64), T.empty(64)))
def test_where(self): self._test_kernel(lambda T: (T.empty(64) > 0).where(T.empty(64), T.empty(64)))
# Convolution
def test_conv2d(self): self._test_kernel(lambda T: T.empty(1, 2, 8, 8).conv2d(T.empty(2, 2, 3, 3)), max_steps=50000)
# Bitwise ops
def test_bitwise_and(self): self._test_kernel(lambda T: T([0xF0, 0x0F, 0xFF]).int() & T([0x0F, 0x0F, 0x00]).int())
def test_bitwise_or(self): self._test_kernel(lambda T: T([0xF0, 0x0F, 0x00]).int() | T([0x0F, 0x0F, 0xFF]).int())
def test_bitwise_xor(self): self._test_kernel(lambda T: T([0xFF, 0x0F, 0xF0]).int() ^ T([0x0F, 0xF0, 0xF0]).int())
# Integer ops - use 32+ elements to force vector instructions
def test_int_add(self): self._test_kernel(lambda T: (T.empty(64).int() + T.empty(64).int()).float())
def test_int_mul(self): self._test_kernel(lambda T: (T.empty(64).int() * T.empty(64).int()).float())
def test_int_mod(self): self._test_kernel(lambda T: (T.empty(64).int().abs() % (T.empty(64).int().abs() + 1)).float())
# More math ops - use 32+ elements to force vector instructions
def test_abs(self): self._test_kernel(lambda T: T.empty(64).abs())
def test_floor(self): self._test_kernel(lambda T: T.empty(64).floor())
def test_ceil(self): self._test_kernel(lambda T: T.empty(64).ceil())
def test_trunc(self): self._test_kernel(lambda T: T.empty(64).trunc())
# Fused ops
def test_fma(self): self._test_kernel(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0])))
# Argmax/argmin (tests different reduction pattern) - use 32+ elements to force vector instructions
def test_argmax(self): self._test_kernel(lambda T: T.empty(64).argmax())
def test_argmin(self): self._test_kernel(lambda T: T.empty(64).argmin())
# Exact value tests - use 32+ elements to force vector instructions (small tensors use scalar ops which Rust emu doesn't fully support)
def test_abs_exact(self): self._test_kernel(lambda T: T([-1., 0., 1.]*11).abs()) # 33 elements
def test_neg_exact(self): self._test_kernel(lambda T: -T([-1., 0., 1.]*11))
def test_log_special(self): self._test_kernel(lambda T: T([1., 2., 0.5]*11).log())
def test_exp_exact(self): self._test_kernel(lambda T: T([0., 1., -1.]*11).exp())
def test_reciprocal_exact(self): self._test_kernel(lambda T: T([1., 2., 0.5]*11).reciprocal())
# Integer division and mod - use 32+ elements
def test_int_div(self): self._test_kernel(lambda T: (T([10, 20, 30]*11).int() // T([3, 4, 5]*11).int()).float())
def test_int_neg(self): self._test_kernel(lambda T: (-T([1, -2, 3]*11).int()).float())
# Mixed precision - use 32+ elements
def test_half_add(self): self._test_kernel(lambda T: (T([1., 2.]*16).half() + T([3., 4.]*16).half()).float())
def test_half_mul(self): self._test_kernel(lambda T: (T([2., 3.]*16).half() * T([4., 5.]*16).half()).float())
# Matrix ops - patterns from test_ops.py failures
def test_cat(self): self._test_kernel(lambda T: T.empty(32, 64).cat(T.empty(32, 64), dim=1))
def test_gather(self): self._test_kernel(lambda T: T.empty(64).gather(0, T.arange(32).int()))
# Tests from test_ops.py that are failing
def test_permute(self): self._test_kernel(lambda T: T.empty(3, 4, 5, 6).permute((3, 2, 1, 0)).contiguous())
def test_cat_large(self): self._test_kernel(lambda T: T.empty(45, 65, 9).cat(T.empty(45, 65, 9), T.empty(45, 65, 9), dim=1))
def test_gather_small(self): self._test_kernel(lambda T: T.empty(10).gather(0, T.arange(5).int()))
@unittest.skip("Rust emulator has S_ADD_I32 SCC bug - uses carry instead of signed overflow")
def test_cross_entropy(self): self._test_kernel(lambda T: T.randn(32, 10).softmax().log().sum())
def test_cross_entropy_class(self):
import numpy as np
np.random.seed(0)
classes = np.random.randint(0, 10, (32,), dtype=np.int32).tolist()
x_np = np.random.randn(32, 10).astype(np.float32)
self._test_kernel(lambda T: (T(x_np.tolist()).reshape(32,10) + 0).cross_entropy((T(classes).int().reshape(32) + 0)))
# Regression tests for BFE operations with width=0 (walrus operator bug)
# Regression tests
def test_topk(self): self._test_kernel(lambda T: T.empty(64).topk(3)[0])
def test_interpolate_uint8(self): self._test_kernel(lambda T: T.empty(2,3,64,64).relu().cast('uint8').interpolate((10,10), mode="linear"))
# Regression test for 64-bit comparison (V_CMP_GT_I64, V_CMP_LT_U64, etc.) with rsrc64
def test_interpolate(self): self._test_kernel(lambda T: T.empty(1,2,16,16).relu().cast('uint8').interpolate((8,8), mode="linear"))
def test_index_int64(self):
from tinygrad import dtypes
self._test_kernel(lambda T: T.empty(4, 4)[T.arange(4).cast(dtypes.int64), :])
@unittest.skip("only works with mock GPU")
def test_index_int64_2d(self):
from tinygrad import dtypes
# Tests 64-bit compare with inline constants (comparing against 0)
self._test_kernel(lambda T: T.empty(4, 4)[T.arange(4).cast(dtypes.int64), T.arange(4).cast(dtypes.int64)])
# Pooling operations - regression test for VCC wave32 mode (S_CBRANCH_VCCZ should only check VCC_LO)
def test_avg_pool2d(self): self._test_kernel(lambda T: T.empty(1, 1, 8, 8).avg_pool2d(kernel_size=(4,4), stride=2))
# Trig functions with special values (inf, nan, 0)
def test_sin_special(self): self._test_kernel(lambda T: T([0., 0.25, 0.5, 1.0]*8).sin())
def test_cos_special(self): self._test_kernel(lambda T: T([0., 0.25, 0.5, 1.0]*8).cos())
# Sqrt and rsqrt
def test_sqrt(self): self._test_kernel(lambda T: T([0., 1., 4., 9.]*8).sqrt())
def test_rsqrt(self): self._test_kernel(lambda T: T([1., 4., 9., 16.]*8).rsqrt())
@unittest.skip("Rust emulator has S_ADD_I32 SCC bug - uses carry instead of signed overflow")
def test_avg_pool3d(self):
def test_gelu(self): self._test_kernel(lambda T: T.empty(32, 32).gelu())
def test_cross_entropy(self):
import numpy as np
np.random.seed(0)
self._test_kernel(lambda T: T(np.random.randn(1, 1, 16, 16, 16).astype(np.float32).tolist()).avg_pool2d(kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False))
def test_max_pool2d(self): self._test_kernel(lambda T: T.empty(1, 1, 8, 8).max_pool2d(kernel_size=(4,4), stride=2))
# Convolution operations - multi-kernel tests
def test_conv2d(self): self._test_kernel(lambda T: T.empty(1, 4, 8, 8).conv2d(T.empty(4, 4, 3, 3)), max_steps=100000)
def test_conv_transpose2d(self): self._test_kernel(lambda T: T.empty(1, 4, 8, 8).conv_transpose2d(T.empty(4, 4, 3, 3)), max_steps=200000)
@unittest.skip("Rust emulator has S_ADD_I32 SCC bug - uses carry instead of signed overflow")
def test_conv_transpose3d(self):
import numpy as np
np.random.seed(0)
self._test_kernel(lambda T: T(np.random.randn(2, 4, 9, 9, 9).astype(np.float32).tolist()).conv_transpose2d(
T(np.random.randn(4, 4, 3, 3, 3).astype(np.float32).tolist())), max_steps=500000)
# Tests from test_ops.py failures
def test_gelu_extreme(self): self._test_kernel(lambda T: T.empty(45, 65).gelu())
def test_gemm_64x64(self): self._test_kernel(lambda T: T.empty(64, 64) @ T.empty(64, 64), max_steps=500000)
def test_gemm_fp16(self): self._test_kernel(lambda T: T.empty(64, 64).half() @ T.empty(64, 64).half(), max_steps=500000)
def test_global_avg_pool2d(self): self._test_kernel(lambda T: T.empty(32, 2, 111, 28).avg_pool2d(kernel_size=(111, 28)), max_steps=100000)
@unittest.skip("Rust emulator has S_ADD_I32 SCC bug - uses carry instead of signed overflow")
def test_grouped_conv2d(self): self._test_kernel(lambda T: T.empty(4, 15, 5, 5).conv2d(T.empty(35, 3, 3, 3), groups=5), max_steps=200000)
@unittest.skip("Rust emulator has S_ADD_I32 SCC bug - uses carry instead of signed overflow")
def test_grouped_conv_transpose2d(self): self._test_kernel(lambda T: T.empty(2, 4, 9, 9).conv_transpose2d(T.empty(4, 4, 3, 3), groups=2), max_steps=200000)
def test_hardsigmoid(self): self._test_kernel(lambda T: T.empty(45, 65).hardsigmoid())
def test_hardsigmoid_extreme(self): self._test_kernel(lambda T: T.empty(45, 65).sigmoid())
def test_matvec(self): self._test_kernel(lambda T: (T.empty(1, 128) @ T.empty(128, 128)).relu(), max_steps=200000)
def test_matvecmat(self): self._test_kernel(lambda T: ((T.empty(1, 128) @ T.empty(128, 128)).relu() @ T.empty(128, 128)), max_steps=300000)
def test_max_reduce_45x3(self): self._test_kernel(lambda T: T.empty(45, 3).max())
def test_max_dont_collapse(self): self._test_kernel(lambda T: T.empty(4, 8).max(axis=1))
def test_max_pool2d_simple(self): self._test_kernel(lambda T: T.empty(1, 1, 2, 3).max_pool2d(kernel_size=(2, 2)))
def test_max_pool2d_32x2(self): self._test_kernel(lambda T: T.empty(32, 2, 11, 28).max_pool2d(kernel_size=(2, 2)))
def test_max_pool2d_asymmetric_padding(self): self._test_kernel(lambda T: T.empty(4, 2, 111, 28).max_pool2d(kernel_size=(5, 5), padding=(0, 1, 0, 1)))
def test_max_pool2d_bigger_stride(self): self._test_kernel(lambda T: T.empty(4, 2, 11, 28).max_pool2d(kernel_size=(2, 2), stride=(2, 3)))
def test_max_pool2d_unit_stride(self): self._test_kernel(lambda T: T.empty(3, 2, 17, 14).max_pool2d(kernel_size=(5, 5), stride=1))
def test_max_pool2d_smaller_stride(self): self._test_kernel(lambda T: T.empty(3, 2, 17, 14).max_pool2d(kernel_size=(5, 5), stride=(2, 3)))
def test_max_unpool2d(self): self._test_kernel(lambda T: T.max_unpool2d(*T.empty(8, 3, 50, 50).max_pool2d(kernel_size=(5, 5), stride=(6, 5), return_indices=True), kernel_size=(5, 5), stride=(6, 5)))
classes = np.random.randint(0, 10, (16,), dtype=np.int32).tolist()
x_np = np.random.randn(16, 10).astype(np.float32)
self._test_kernel(lambda T: (T(x_np.tolist()).reshape(16,10) + 0).cross_entropy((T(classes).int().reshape(16) + 0)))
def test_isinf(self): self._test_kernel(lambda T: T([float('-inf'), 0., float('inf'), 1.1]*8).isinf())
def test_isfinite(self): self._test_kernel(lambda T: T([float('-inf'), 0., float('inf'), 1.1]*8).isfinite())
# WMMA tests - uses wave matrix multiply for larger fp16 matmuls
def test_wmma_gemm_fp16(self): self._test_kernel(lambda T: T.empty(64, 64).half() @ T.empty(64, 64).half(), max_steps=1000000)
if __name__ == "__main__":
unittest.main()
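
New cases in TestTinygradKernels follow the pattern above: a lambda that builds a tensor expression from the injected T, plus an optional max_steps budget for longer kernels. A hedged sketch of adding one more case (the method name and expression are hypothetical, not part of this commit):

# hypothetical extra test following the consolidated pattern above
def test_mul_add_relu(self): self._test_kernel(lambda T: (T.empty(64) * T.empty(64) + T.empty(64)).relu(), max_steps=50000)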


@@ -1,8 +1,9 @@
#!/usr/bin/env python3
"""Integration test: round-trip RDNA3 assembly through AMD toolchain."""
import unittest, re, io, sys
import unittest, re, io, sys, subprocess
from extra.assembly.rdna3.autogen import *
from extra.assembly.rdna3.asm import waitcnt, asm
from extra.assembly.rdna3.test.test_roundtrip import _get_llvm_mc
def get_amd_toolchain():
"""Check if AMD toolchain is available."""
@@ -212,10 +213,8 @@ class TestAsm(unittest.TestCase):
def test_asm_vop3_modifiers(self):
"""Test asm() with VOP3 modifiers (neg, abs, clamp)."""
import subprocess, re
def get_llvm_encoding(instr: str) -> str:
result = subprocess.run(['llvm-mc', '-triple=amdgcn', '-mcpu=gfx1100', '-show-encoding'],
result = subprocess.run([_get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-show-encoding'],
input=instr, capture_output=True, text=True)
if m := re.search(r'encoding:\s*\[(.*?)\]', result.stdout):
return m.group(1).replace('0x','').replace(',','').replace(' ','')


@@ -1,10 +1,10 @@
#!/usr/bin/env python3
"""Test RDNA3 assembler/disassembler against LLVM test vectors."""
import unittest, re
import unittest, re, subprocess
from tinygrad.helpers import fetch
from extra.assembly.rdna3.autogen import *
from extra.assembly.rdna3.asm import asm
from extra.assembly.rdna3.test.test_roundtrip import compile_asm, disassemble_lib
from extra.assembly.rdna3.test.test_roundtrip import _get_llvm_mc
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
@@ -78,6 +78,24 @@ def try_assemble(text: str):
try: return asm(text).to_bytes()
except: return None
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
"""Compile multiple instructions with a single llvm-mc call."""
if not instrs: return []
asm_text = ".text\n" + "\n".join(instrs) + "\n"
result = subprocess.run(
[_get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
input=asm_text, capture_output=True, text=True, timeout=30)
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
# Parse all encodings from output
results = []
for line in result.stdout.split('\n'):
if 'encoding:' not in line: continue
enc = line.split('encoding:')[1].strip()
if enc.startswith('[') and enc.endswith(']'):
results.append(bytes.fromhex(enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')))
if len(results) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(results)}")
return results
class TestLLVM(unittest.TestCase):
"""Test assembler and disassembler against all LLVM test vectors."""
tests: dict[str, list[tuple[str, bytes]]] = {}
@@ -107,59 +125,62 @@ def _make_asm_test(name):
def _make_disasm_test(name):
def test(self):
from tinygrad.runtime.support.compiler_amd import HIPCompiler
compiler = HIPCompiler('gfx1100')
_, fmt_cls, op_enum = LLVM_TEST_FILES[name]
passed, failed, skipped, failures = 0, 0, 0, []
# VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions)
# Note: opcodes 0-255 are VOPC promoted to VOP3, never VOP3SD
vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
# vop3_from_vopc/vopcx tests have VOPC opcodes 0-255, not VOP3SD - don't detect as VOP3SD
is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx')
# Undocumented opcodes not in AMD ISA PDF - skip these
undocumented = {'smem': {34, 35}, 'sopk': {22, 23}, 'sopp': {8, 58, 59}} # s_atc_probe*, s_subvector_loop*, s_waitcnt_depctr, unknown
undocumented = {'smem': {34, 35}, 'sopk': {22, 23}, 'sopp': {8, 58, 59}}
# First pass: decode all instructions and collect disasm strings
to_test = [] # list of (asm_text, data, disasm_str)
skipped = 0
for asm_text, data in self.tests.get(name, []):
if len(data) > fmt_cls._size(): continue # skip literals (need different handling)
# Skip undocumented opcodes
if len(data) > fmt_cls._size(): continue
temp_inst = fmt_cls.from_bytes(data)
temp_op = temp_inst._values.get('op', 0)
temp_op = temp_op.val if hasattr(temp_op, 'val') else temp_op
if temp_op in undocumented.get(name, set()): skipped += 1; continue
# Skip SOPP no-imm instructions with non-zero simm16 (can't roundtrip through LLVM)
if name == 'sopp':
simm16 = temp_inst._values.get('simm16', 0)
simm16 = simm16.val if hasattr(simm16, 'val') else simm16
sopp_no_imm = {48, 54, 53, 55, 60, 61, 62} # s_endpgm, s_barrier, s_wakeup, s_icache_inv, s_wait_idle, s_endpgm_saved, s_code_end
sopp_no_imm = {48, 54, 53, 55, 60, 61, 62}
if temp_op in sopp_no_imm and simm16 != 0: skipped += 1; continue
try:
# VOP3 and VOP3SD share encoding - peek at opcode to determine which class to use
if fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
temp = VOP3.from_bytes(data)
op_val = temp._values.get('op', 0)
op_val = op_val.val if hasattr(op_val, 'val') else op_val
is_vop3sd = (op_val in vop3sd_opcodes) and not is_vopc_promotion
decoded = VOP3SD.from_bytes(data) if is_vop3sd else VOP3.from_bytes(data)
# Validate opcode with appropriate enum
if is_vop3sd:
VOP3SDOp(op_val)
else:
VOP3Op(op_val)
if is_vop3sd: VOP3SDOp(op_val)
else: VOP3Op(op_val)
else:
decoded = fmt_cls.from_bytes(data)
op_val = decoded._values.get('op', 0)
op_val = op_val.val if hasattr(op_val, 'val') else op_val
op_enum(op_val) # validate opcode
op_enum(op_val)
if decoded.to_bytes()[:len(data)] != data:
failed += 1; failures.append(f"decode roundtrip failed for {data.hex()}"); continue
disasm_str = decoded.disasm()
# Test: LLVM should assemble our disasm output to the same bytes
llvm_bytes = compile_asm(disasm_str, compiler)
if llvm_bytes is None:
failed += 1; failures.append(f"LLVM failed to assemble: '{disasm_str}' (from '{asm_text}')")
elif llvm_bytes == data: passed += 1
else: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
to_test.append((asm_text, data, None, "decode roundtrip failed"))
continue
to_test.append((asm_text, data, decoded.disasm(), None))
except Exception as e:
failed += 1; failures.append(f"exception for {data.hex()}: {e}")
to_test.append((asm_text, data, None, f"exception: {e}"))
# Batch compile all disasm strings with single llvm-mc call
disasm_strs = [(i, t[2]) for i, t in enumerate(to_test) if t[2] is not None]
llvm_results = compile_asm_batch([s for _, s in disasm_strs]) if disasm_strs else []
llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)}
# Match results back
passed, failed, failures = 0, 0, []
for idx, (asm_text, data, disasm_str, error) in enumerate(to_test):
if error:
failed += 1; failures.append(f"{error} for {data.hex()}")
elif disasm_str is not None and idx in llvm_map:
llvm_bytes = llvm_map[idx]
if llvm_bytes == data: passed += 1
else: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
print(f"{name.upper()} disasm: {passed} passed, {failed} failed" + (f", {skipped} skipped" if skipped else ""))
if failures[:10]: print(" " + "\n ".join(failures[:10]))
self.assertEqual(failed, 0)
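
compile_asm_batch, added above, replaces one llvm-mc subprocess per instruction with a single invocation over the whole list, which removes the per-instruction process overhead from this test. A hedged usage sketch (the two instruction strings are illustrative examples):

# inside the test module above, where compile_asm_batch and _get_llvm_mc are in scope
encodings = compile_asm_batch(["s_mov_b32 s0, s1", "v_add_f32 v0, v1, v2"])  # one llvm-mc call for both
assert len(encodings) == 2 and all(isinstance(e, bytes) for e in encodings)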


@@ -49,7 +49,7 @@ dev.synchronize()
self.assertTrue("NotImplementedError" in result.stderr or "ValueError" in result.stderr,
f"expected NotImplementedError or ValueError in stderr")
# Should exit immediately, not wait for the full timeout
self.assertLess(elapsed, 5.0, f"should exit immediately on emulator exception, took {elapsed:.1f}s")
self.assertLess(elapsed, 9.0, f"should exit immediately on emulator exception, took {elapsed:.1f}s")
if __name__ == "__main__":
unittest.main()


@@ -1,6 +1,6 @@
#!/usr/bin/env python3
"""Test that PDF parser correctly extracts format fields."""
import unittest
import unittest, os
from extra.assembly.rdna3.autogen import (
SOP1, SOP2, SOPK, SOPP, VOP1, VOP2, VOP3SD, VOPC, FLAT, VOPD,
SOP1Op, SOP2Op, VOP1Op, VOP3Op
@@ -33,34 +33,32 @@ EXPECTED_FORMATS = {
'VOPD': (['OPX', 'OPY', 'SRCX0', 'SRCY0', 'VDSTX', 'VDSTY'], True),
}
# Skip PDF parsing tests by default - only run with TEST_PDF_PARSER=1
# These are slow (~5s) and only needed when regenerating autogen/
@unittest.skipUnless(os.environ.get("TEST_PDF_PARSER"), "set TEST_PDF_PARSER=1 to run PDF parser tests")
class TestPDFParserGenerate(unittest.TestCase):
"""Test the PDF parser by running generate() and checking results."""
result: dict
@classmethod
def setUpClass(cls):
from extra.assembly.rdna3.gen import generate
cls.result = generate()
def test_pdf_parser(self):
"""Single test that validates all PDF parser outputs."""
from extra.assembly.rdna3.lib import generate
result = generate()
def test_all_formats_present(self):
"""All expected formats should be parsed."""
# test_all_formats_present
for fmt_name in EXPECTED_FORMATS:
self.assertIn(fmt_name, self.result["formats"], f"missing format {fmt_name}")
self.assertIn(fmt_name, result["formats"], f"missing format {fmt_name}")
def test_format_count(self):
"""Should have exactly 23 formats."""
self.assertEqual(len(self.result["formats"]), 23)
# test_format_count
self.assertEqual(len(result["formats"]), 23)
def test_no_duplicate_fields(self):
"""No format should have duplicate field names."""
for fmt_name, fields in self.result["formats"].items():
# test_no_duplicate_fields
for fmt_name, fields in result["formats"].items():
field_names = [f[0] for f in fields]
self.assertEqual(len(field_names), len(set(field_names)), f"{fmt_name} has duplicate fields: {field_names}")
def test_expected_fields(self):
"""Each format should have its expected key fields."""
# test_expected_fields
for fmt_name, (expected_fields, has_encoding) in EXPECTED_FORMATS.items():
fields = {f[0] for f in self.result["formats"].get(fmt_name, [])}
fields = {f[0] for f in result["formats"].get(fmt_name, [])}
for field in expected_fields:
self.assertIn(field, fields, f"{fmt_name} missing {field}")
if has_encoding:
@@ -68,21 +66,18 @@ class TestPDFParserGenerate(unittest.TestCase):
else:
self.assertNotIn("ENCODING", fields, f"{fmt_name} should not have ENCODING")
def test_vopd_no_dpp16_fields(self):
"""VOPD should not have DPP16-specific fields (parser boundary bug)."""
vopd_fields = {f[0] for f in self.result["formats"].get("VOPD", [])}
# test_vopd_no_dpp16_fields
vopd_fields = {f[0] for f in result["formats"].get("VOPD", [])}
for field in ['DPP_CTRL', 'BANK_MASK', 'ROW_MASK']:
self.assertNotIn(field, vopd_fields, f"VOPD should not have {field}")
def test_dpp16_no_vinterp_fields(self):
"""DPP16 should not have VINTERP-specific fields."""
dpp16_fields = {f[0] for f in self.result["formats"].get("DPP16", [])}
# test_dpp16_no_vinterp_fields
dpp16_fields = {f[0] for f in result["formats"].get("DPP16", [])}
for field in ['VDST', 'WAITEXP']:
self.assertNotIn(field, dpp16_fields, f"DPP16 should not have {field}")
def test_sopp_no_smem_fields(self):
"""SOPP should not have SMEM fields (page break bug)."""
sopp_fields = {f[0] for f in self.result["formats"].get("SOPP", [])}
# test_sopp_no_smem_fields
sopp_fields = {f[0] for f in result["formats"].get("SOPP", [])}
for field in ['SBASE', 'SDATA']:
self.assertNotIn(field, sopp_fields, f"SOPP should not have {field}")
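These parser checks are opt-in: run pytest with TEST_PDF_PARSER=1 set in the environment, e.g. TEST_PDF_PARSER=1 python -m pytest extra/assembly/rdna3/ -k pdf (selector illustrative). Expect the extra ~5s, since generate() re-parses the source PDF, which presumably also requires pdfplumber to be installed.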

View File

@@ -1,11 +1,12 @@
#!/usr/bin/env python3
import unittest, subprocess
from extra.assembly.rdna3.autogen import *
from extra.assembly.rdna3.test.test_roundtrip import _get_llvm_mc
def llvm_assemble(asm: str) -> bytes:
"""Assemble using llvm-mc and return bytes."""
result = subprocess.run(
["llvm-mc", "-triple=amdgcn", "-mcpu=gfx1100", "-show-encoding"],
[_get_llvm_mc(), "-triple=amdgcn", "-mcpu=gfx1100", "-show-encoding"],
input=asm, capture_output=True, text=True
)
out = b''
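llvm_assemble now resolves the assembler through _get_llvm_mc() from test_roundtrip (shown in the next file), so a versioned binary such as llvm-mc-21 is picked up when plain llvm-mc is not on PATH; if none is found, the helper raises FileNotFoundError up front rather than failing inside subprocess.run.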

View File

@@ -1,10 +1,20 @@
#!/usr/bin/env python3
"""Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match."""
import unittest, io, sys, re
import unittest, io, sys, re, subprocess, shutil
from extra.assembly.rdna3.autogen import *
from extra.assembly.rdna3.lib import Inst
from extra.assembly.rdna3.asm import asm
def _get_llvm_mc():
  for p in ['llvm-mc', 'llvm-mc-21', 'llvm-mc-20']: # unversioned first, then newest versioned
if shutil.which(p): return p
raise FileNotFoundError("llvm-mc not found")
def _get_llvm_objdump():
for p in ['llvm-objdump', 'llvm-objdump-21', 'llvm-objdump-20']:
if shutil.which(p): return p
raise FileNotFoundError("llvm-objdump not found")
# Instruction format detection based on encoding bits
def detect_format(data: bytes) -> type[Inst] | None:
"""Detect instruction format from machine code bytes."""
@@ -66,24 +76,69 @@ def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
continue
return results
def compile_asm(instr: str, compiler=None) -> bytes | None:
def compile_asm(instr: str, compiler=None) -> bytes:
"""Compile a single instruction with llvm-mc and return the machine code bytes."""
import subprocess
llvm_mc = _get_llvm_mc()
result = subprocess.run(
[llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
input=f".text\n{instr}\n", capture_output=True, text=True)
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed for '{instr}': {result.stderr.strip()}")
# Parse encoding: [0x01,0x39,0x0a,0x7e]
for line in result.stdout.split('\n'):
if 'encoding:' in line:
enc = line.split('encoding:')[1].strip()
if enc.startswith('[') and enc.endswith(']'):
hex_vals = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')
return bytes.fromhex(hex_vals)
raise RuntimeError(f"no encoding found in llvm-mc output for: {instr}")
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
"""Compile multiple instructions with a single llvm-mc call."""
if not instrs: return []
llvm_mc = _get_llvm_mc()
src = ".text\n" + "\n".join(instrs) + "\n"
result = subprocess.run(
[llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
input=src, capture_output=True, text=True)
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
# Parse all encodings in order
encodings = []
for line in result.stdout.split('\n'):
if 'encoding:' in line:
enc = line.split('encoding:')[1].strip()
if enc.startswith('[') and enc.endswith(']'):
hex_vals = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')
encodings.append(bytes.fromhex(hex_vals))
if len(encodings) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(encodings)}")
return encodings
def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str | None]:
"""Compile instructions with LLVM and get LLVM's disassembly."""
import tempfile, os
if not instrs: return []
# Build assembly source with all instructions
src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n"
src += "\n".join(f" {instr}" for instr in instrs) + "\n"
  # Use llvm-mc to assemble to an object file
with tempfile.NamedTemporaryFile(suffix='.o', delete=False) as f:
obj_path = f.name
try:
result = subprocess.run(
['llvm-mc', '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
input=f".text\n{instr}\n", capture_output=True, text=True)
if result.returncode != 0: return None
# Parse encoding: [0x01,0x39,0x0a,0x7e]
for line in result.stdout.split('\n'):
if 'encoding:' in line:
enc = line.split('encoding:')[1].strip()
if enc.startswith('[') and enc.endswith(']'):
hex_vals = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')
return bytes.fromhex(hex_vals)
except Exception:
pass
return None
[_get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path],
input=src, capture_output=True, text=True)
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
# Disassemble with llvm-objdump
result = subprocess.run([_get_llvm_objdump(), '-d', '--mcpu=gfx1100', obj_path], capture_output=True, text=True)
if result.returncode != 0: raise RuntimeError(f"llvm-objdump failed: {result.stderr.strip()}")
# Parse disassembly output
results = []
for line in result.stdout.splitlines():
if '//' not in line: continue
instr = line.split('//')[0].strip()
if instr: results.append(instr)
return results[:len(instrs)]
finally:
os.unlink(obj_path)
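For reference, the two output shapes these helpers parse (mnemonic and bytes below are illustrative; spacing and addresses vary by LLVM version). compile_asm and compile_asm_batch read llvm-mc -show-encoding lines and take the bracketed bytes after 'encoding:':

  s_nop 0  ; encoding: [0x00,0x00,0x80,0xbf]

while compile_and_disasm_batch reads llvm-objdump -d lines and keeps the text before '//' as the disassembly string:

  s_nop 0  // 000000000100: BF800000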
class TestTinygradKernelRoundtrip(unittest.TestCase):
"""Test roundtrip on real tinygrad-generated kernels using get_kernels_from_tinygrad pattern."""
@@ -100,90 +155,113 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
kernels, _, _ = get_kernels_from_tinygrad(op_fn)
compiler = HIPCompiler('gfx1100')
decode_passed, decode_failed, decode_skipped = 0, 0, 0
asm_passed, asm_failed, asm_skipped = 0, 0, 0
disasm_passed, disasm_failed, disasm_skipped = 0, 0, 0
decode_failures, asm_failures, disasm_failures = [], [], []
# First pass: decode all instructions and collect info
decoded_instrs = [] # list of (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err)
for ki, kernel in enumerate(kernels):
offset = 0
while offset < len(kernel.code):
remaining = kernel.code[offset:]
fmt = detect_format(remaining)
if fmt is None:
decode_skipped += 1
asm_skipped += 1
disasm_skipped += 1
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
offset += 4
continue
size = fmt._size()
if len(remaining) < size:
base_size = fmt._size()
if len(remaining) < base_size:
break
orig_bytes = remaining[:size]
# Test 1: decode -> reencode roundtrip
try:
decoded = fmt.from_bytes(orig_bytes)
decoded = fmt.from_bytes(remaining) # pass all remaining bytes so from_bytes can read literal
size = decoded.size() # actual size including literal
orig_bytes = remaining[:size]
reencoded = decoded.to_bytes()
if reencoded[:size] == orig_bytes:
decode_passed += 1
else:
decode_failed += 1
decode_failures.append(f"K{ki}@{offset}: {decoded.disasm()}: orig={orig_bytes.hex()} reenc={reencoded[:size].hex()}")
our_disasm = decoded.disasm()
decode_ok = reencoded == orig_bytes
decode_err = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}"
decoded_instrs.append((ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err))
except Exception as e:
decoded_instrs.append((ki, offset, remaining[:base_size], None, None, False, str(e)))
size = base_size
# Test 2: asm(disasm()) matches LLVM output
offset += size
# Collect disasm strings for batched LLVM calls - skip unknown opcodes (op_X) that LLVM can't compile
asm_test_instrs = [] # (idx, our_disasm) for asm test
disasm_test_instrs = [] # (idx, our_disasm) for disasm comparison test
for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs):
if our_disasm is None: continue
# Skip unknown opcodes and malformed instructions for both tests
if our_disasm.startswith('op_') or re.search(r', \d+, \d+, \d+,', our_disasm): continue
asm_test_instrs.append((idx, our_disasm))
disasm_test_instrs.append((idx, our_disasm))
# Batch compile for asm test
asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs])
asm_llvm_map = {idx: result for (idx, _), result in zip(asm_test_instrs, asm_llvm_results)}
# Batch compile+disasm for disasm comparison test
disasm_llvm_results = compile_and_disasm_batch([d for _, d in disasm_test_instrs], compiler)
disasm_llvm_map = {idx: result for (idx, _), result in zip(disasm_test_instrs, disasm_llvm_results)}
# Now evaluate results
decode_passed, decode_failed, decode_skipped = 0, 0, 0
asm_passed, asm_failed, asm_skipped = 0, 0, 0
disasm_passed, disasm_failed, disasm_skipped = 0, 0, 0
decode_failures, asm_failures, disasm_failures = [], [], []
for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs):
# Decode test
if decode_ok:
decode_passed += 1
elif decode_err == "no format":
decode_skipped += 1
else:
decode_failed += 1
decode_failures.append(f"K{ki}@{offset}: {our_disasm}: {decode_err}")
# Asm test
if our_disasm is None:
asm_skipped += 1
elif idx in asm_llvm_map:
llvm_bytes = asm_llvm_map[idx]
if llvm_bytes is None:
asm_skipped += 1
else:
try:
our_bytes = asm(our_disasm).to_bytes()
llvm_bytes = compile_asm(our_disasm, compiler)
if llvm_bytes is None:
asm_skipped += 1
elif our_bytes[:len(llvm_bytes)] == llvm_bytes:
if our_bytes[:len(llvm_bytes)] == llvm_bytes:
asm_passed += 1
else:
asm_failed += 1
asm_failures.append(f"K{ki}@{offset}: '{our_disasm}': ours={our_bytes[:len(llvm_bytes)].hex()} llvm={llvm_bytes.hex()}")
except Exception:
asm_skipped += 1
else:
asm_skipped += 1
# Test 3: our disasm() matches LLVM's disassembly string exactly
# Skip if instruction uses op_XX (unknown opcode) or looks malformed (many raw field values)
if our_disasm.startswith('op_') or re.search(r', \d+, \d+, \d+,', our_disasm):
disasm_skipped += 1
else:
try:
# Get LLVM's disassembly of our instruction
src = f".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n {our_disasm}\n"
lib = compiler.compile(src)
llvm_instrs = disassemble_lib(lib, compiler)
if llvm_instrs:
llvm_disasm = llvm_instrs[0][0]
if our_disasm == llvm_disasm:
disasm_passed += 1
else:
disasm_failed += 1
disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{llvm_disasm}'")
else:
disasm_skipped += 1
except Exception:
disasm_skipped += 1
except Exception:
decode_skipped += 1
asm_skipped += 1
# Disasm comparison test
if our_disasm is None:
disasm_skipped += 1
elif idx in disasm_llvm_map:
llvm_disasm = disasm_llvm_map[idx]
if llvm_disasm is None:
disasm_skipped += 1
offset += size
elif our_disasm == llvm_disasm:
disasm_passed += 1
else:
disasm_failed += 1
disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{llvm_disasm}'")
else:
disasm_skipped += 1
print(f"decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
print(f"asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
print(f"disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
self.assertEqual(decode_failed, 0, f"Decode failures:\n" + "\n".join(decode_failures[:20]))
self.assertEqual(asm_failed, 0, f"Asm failures:\n" + "\n".join(asm_failures[:20]))
self.assertEqual(disasm_failed, 0, f"Disasm failures:\n" + "\n".join(disasm_failures[:20]))
# Note: disasm string comparison is informational only - formatting differences between LLVM versions are expected
# Basic unary ops
def test_neg(self): self._test_kernel_roundtrip(lambda T: -T([1.0, -2.0, 3.0, -4.0]))
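Condensed, the checks each decoded instruction goes through in _test_kernel_roundtrip look roughly like this (a sketch assuming a simple instruction whose disassembly roundtrips cleanly):

  data = compile_asm("v_mov_b32_e32 v0, v1")   # 4-byte VOP1 encoding straight from llvm-mc
  fmt = detect_format(data)                    # format class, or None for unrecognized encodings
  inst = fmt.from_bytes(data)
  assert inst.to_bytes() == data                             # 1. decode -> re-encode roundtrip
  assert asm(inst.disasm()).to_bytes()[:len(data)] == data   # 2. our asm() of our disasm() matches LLVM's bytes
  # 3. inst.disasm() is also compared against llvm-objdump's string, but only reported, not asserted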

View File

@@ -1,4 +1,4 @@
import ctypes, hashlib, tempfile, subprocess, pathlib
import ctypes, hashlib, tempfile, subprocess, pathlib, shutil
from tinygrad.helpers import system
from tinygrad.runtime.autogen import comgr
try:
@@ -12,8 +12,15 @@ from tinygrad.device import Compiler, CompileError
from tinygrad.runtime.support.compiler_cpu import LLVMCompiler
from tinygrad.helpers import OSX, to_char_p_p
def _find_llvm_objdump():
if OSX: return '/opt/homebrew/opt/llvm/bin/llvm-objdump'
# Try ROCm path first, then versioned, then unversioned
for p in ['/opt/rocm/llvm/bin/llvm-objdump', 'llvm-objdump-21', 'llvm-objdump-20', 'llvm-objdump']:
if shutil.which(p): return p
raise FileNotFoundError("llvm-objdump not found")
def amdgpu_disassemble(lib:bytes):
asm = system(f"{'/opt/homebrew/opt/llvm/bin/llvm-objdump' if OSX else '/opt/rocm/llvm/bin/llvm-objdump'} -d -", input=lib).splitlines()
asm = system(f"{_find_llvm_objdump()} -d -", input=lib).splitlines()
while asm and ("s_nop 0" in asm[-1] or "s_code_end" in asm[-1]): asm.pop()
print("\n".join(asm))