mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
speed up rdna3 unit tests + add to CI (#13871)
* speed up rdna3 unit tests * add test to CI * faster and simpler * speedups * bugfixes * use helper * fix CI maybe * test fixes * llvm-21 on 24.04 * upd * llvm-21 * fix test * bring that back * merge gen into lib * test generators
This commit is contained in:
@@ -97,7 +97,7 @@ def disasm(inst: Inst) -> str:
|
||||
else:
|
||||
op_name = getattr(autogen, f"{cls_name}Op")(op_val).name.lower() if hasattr(autogen, f"{cls_name}Op") else f"op_{op_val}"
|
||||
except (ValueError, KeyError): op_name = f"op_{op_val}"
|
||||
def fmt_src(v): return f"0x{inst._literal:x}" if v == 255 and getattr(inst, '_literal', None) else decode_src(v)
|
||||
def fmt_src(v): return f"0x{inst._literal:x}" if v == 255 and inst._literal is not None else decode_src(v)
|
||||
|
||||
# VOP1
|
||||
if cls_name == 'VOP1':
|
||||
@@ -440,7 +440,9 @@ def disasm(inst: Inst) -> str:
|
||||
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {ssrc0_str}"
|
||||
if cls_name == 'SOP2':
|
||||
sdst, ssrc0, ssrc1 = [unwrap(inst._values.get(f, 0)) for f in ('sdst', 'ssrc0', 'ssrc1')]
|
||||
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {_fmt_ssrc(ssrc0, src0_cnt)}, {_fmt_ssrc(ssrc1, src1_cnt)}"
|
||||
ssrc0_str = fmt_src(ssrc0) if ssrc0 == 255 else _fmt_ssrc(ssrc0, src0_cnt)
|
||||
ssrc1_str = fmt_src(ssrc1) if ssrc1 == 255 else _fmt_ssrc(ssrc1, src1_cnt)
|
||||
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {ssrc0_str}, {ssrc1_str}"
|
||||
if cls_name == 'SOPC':
|
||||
return f"{op_name} {_fmt_ssrc(unwrap(inst._values.get('ssrc0', 0)), src0_cnt)}, {_fmt_ssrc(unwrap(inst._values.get('ssrc1', 0)), src1_cnt)}"
|
||||
if cls_name == 'SOPK':
|
||||
@@ -557,7 +559,8 @@ def asm(text: str) -> Inst:
|
||||
elif mnemonic in ('v_fmamk_f32', 'v_fmamk_f16') and len(values) == 4: lit, values = unwrap(values[2]), [values[0], values[1], values[3]]
|
||||
vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32', 'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'}
|
||||
if mnemonic.replace('_e32', '') in vcc_ops and len(values) >= 5: values = [values[0], values[2], values[3]]
|
||||
if mnemonic.startswith('v_cmp') and len(values) >= 3 and operands[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'):
|
||||
# v_cmp_*_e32: strip implicit vcc_lo dest. v_cmp_*_e64: keep vdst (vcc_lo encodes to 106)
|
||||
if mnemonic.startswith('v_cmp') and not mnemonic.endswith('_e64') and len(values) >= 3 and operands[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'):
|
||||
values = values[1:]
|
||||
# CMPX instructions with _e64 suffix: prepend implicit EXEC_LO destination (vdst=126)
|
||||
if 'cmpx' in mnemonic and mnemonic.endswith('_e64') and len(values) == 2:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# autogenerated from AMD RDNA3.5 ISA PDF by gen.py - do not edit
|
||||
# autogenerated from AMD RDNA3.5 ISA PDF by lib.py - do not edit
|
||||
from enum import IntEnum
|
||||
from typing import Annotated
|
||||
from extra.assembly.rdna3.lib import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -18,6 +18,15 @@ VCC_LO, VCC_HI, NULL, EXEC_LO, EXEC_HI, SCC = SrcEnum.VCC_LO, SrcEnum.VCC_HI, Sr
|
||||
_VOP3_64BIT_OPS = {op.value for op in VOP3Op if op.name.endswith(('_F64', '_B64', '_I64', '_U64'))}
|
||||
# Ops where src1 is 32-bit (exponent/shift amount) even though the op name suggests 64-bit
|
||||
_VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value}
|
||||
# Ops with 16-bit types in name (for source/dest handling)
|
||||
_VOP3_16BIT_OPS = {op for op in VOP3Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
|
||||
_VOP1_16BIT_OPS = {op for op in VOP1Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
|
||||
# CVT ops with 32/64-bit source (despite 16-bit in name)
|
||||
_CVT_32_64_SRC_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))} | \
|
||||
{op for op in VOP1Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))}
|
||||
# 16-bit dst ops (PACK has 32-bit dst despite F16 in name)
|
||||
_VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name}
|
||||
_VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name}
|
||||
|
||||
# Inline constants for src operands 128-254 (f32 format for most instructions)
|
||||
_INLINE_CONSTS = [0] * 127
|
||||
@@ -509,11 +518,9 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
# V_LDEXP_F64: src0 is 64-bit float, src1 is 32-bit integer exponent
|
||||
is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64,)
|
||||
is_shift_64 = op in (VOP3Op.V_LSHLREV_B64, VOP3Op.V_LSHRREV_B64, VOP3Op.V_ASHRREV_I64)
|
||||
# 16-bit source ops: name contains 16-bit type, but for CVT ops check the SOURCE type (CVT naming is V_CVT_DST_SRC)
|
||||
# For CVT: source type is at the end of the name, so V_CVT_F16_F32 has 32-bit src, V_CVT_F32_F16 has 16-bit src
|
||||
has_16bit_type = any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))
|
||||
is_cvt_with_32_64_src = op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))
|
||||
is_16bit_src = op_cls is VOP3Op and has_16bit_type and not is_cvt_with_32_64_src
|
||||
# 16-bit source ops: use precomputed sets instead of string checks
|
||||
has_16bit_type = op in _VOP3_16BIT_OPS or op in _VOP1_16BIT_OPS
|
||||
is_16bit_src = op_cls is VOP3Op and op in _VOP3_16BIT_OPS and op not in _CVT_32_64_SRC_OPS
|
||||
|
||||
if is_shift_64:
|
||||
s0 = mod_src(st.rsrc(src0, lane), 0) # shift amount is 32-bit
|
||||
@@ -571,8 +578,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
writes_to_sgpr = op in (VOP1Op.V_READFIRSTLANE_B32,) or \
|
||||
(op_cls is VOP3Op and op in (VOP3Op.V_READFIRSTLANE_B32, VOP3Op.V_READLANE_B32))
|
||||
# Check for 16-bit destination ops (opsel[3] controls hi/lo write)
|
||||
# 16-bit dst ops (exclude PACK which has 32-bit dst despite F16 in name)
|
||||
is_16bit_dst = any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16')) and 'PACK' not in op.name
|
||||
is_16bit_dst = op in _VOP3_16BIT_DST_OPS or op in _VOP1_16BIT_DST_OPS
|
||||
if writes_to_sgpr:
|
||||
st.wsgpr(vdst, result['d0'] & 0xffffffff)
|
||||
elif result.get('d0_64') or is_64bit_op:
|
||||
|
||||
@@ -1,199 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# generates autogen/__init__.py by parsing the AMD RDNA3.5 ISA PDF
|
||||
import re, pdfplumber, pathlib
|
||||
from tinygrad.helpers import fetch
|
||||
|
||||
# source document: the AMD RDNA3.5 ISA reference PDF
PDF_URL = "https://docs.amd.com/api/khub/documents/UVVZM22UN7tMUeiW_4ShTQ/content"

# field name (upper-case, as printed in the PDF tables) -> BitField wrapper type emitted in autogen
FIELD_TYPES = {
  'SSRC0': 'SSrc', 'SSRC1': 'SSrc', 'SOFFSET': 'SSrc', 'SADDR': 'SSrc',
  'SRC0': 'Src', 'SRC1': 'Src', 'SRC2': 'Src',
  'SDST': 'SGPRField', 'SBASE': 'SGPRField', 'SDATA': 'SGPRField', 'SRSRC': 'SGPRField',
  'VDST': 'VGPRField', 'VSRC1': 'VGPRField', 'VDATA': 'VGPRField', 'VADDR': 'VGPRField',
  'ADDR': 'VGPRField', 'DATA': 'VGPRField', 'DATA0': 'VGPRField', 'DATA1': 'VGPRField',
  'SIMM16': 'SImm', 'OFFSET': 'Imm',
  'OPX': 'VOPDOp', 'OPY': 'VOPDOp', 'SRCX0': 'Src', 'SRCY0': 'Src',
  'VSRCX1': 'VGPRField', 'VSRCY1': 'VGPRField', 'VDSTX': 'VGPRField', 'VDSTY': 'VDSTYEnc'}

# positional field order used when emitting each format's class (unlisted fields sort last)
FIELD_ORDER = {
  'SOP2': ['op', 'sdst', 'ssrc0', 'ssrc1'], 'SOP1': ['op', 'sdst', 'ssrc0'], 'SOPC': ['op', 'ssrc0', 'ssrc1'],
  'SOPK': ['op', 'sdst', 'simm16'], 'SOPP': ['op', 'simm16'], 'VOP1': ['op', 'vdst', 'src0'], 'VOPC': ['op', 'src0', 'vsrc1'],
  'VOP2': ['op', 'vdst', 'src0', 'vsrc1'], 'VOP3SD': ['op', 'vdst', 'sdst', 'src0', 'src1', 'src2', 'clmp'],
  'SMEM': ['op', 'sdata', 'sbase', 'soffset', 'offset', 'glc', 'dlc'], 'DS': ['op', 'vdst', 'addr', 'data0', 'data1'],
  'VOP3': ['op', 'vdst', 'src0', 'src1', 'src2', 'omod', 'neg', 'abs', 'clmp', 'opsel'],
  'VOP3P': ['op', 'vdst', 'src0', 'src1', 'src2', 'neg', 'neg_hi', 'opsel', 'opsel_hi', 'clmp'],
  'FLAT': ['op', 'vdst', 'addr', 'data', 'saddr', 'offset', 'seg', 'dlc', 'glc', 'slc'],
  'MUBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
  'MTBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'format', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
  'MIMG': ['op', 'vdata', 'vaddr', 'srsrc', 'ssamp', 'dmask', 'dim', 'unrm', 'dlc', 'glc', 'slc'],
  'EXP': ['en', 'target', 'vsrc0', 'vsrc1', 'vsrc2', 'vsrc3', 'done', 'row'],
  'VINTERP': ['op', 'vdst', 'src0', 'src1', 'src2', 'waitexp', 'clmp', 'opsel', 'neg'],
  'VOPD': ['opx', 'opy', 'vdstx', 'vdsty', 'srcx0', 'vsrcx1', 'srcy0', 'vsrcy1'],
  'LDSDIR': ['op', 'vdst', 'attr', 'attr_chan', 'wait_va']}

# src-operand encodings that are not listed in the PDF's SSRC value table
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}

# float literal spellings as printed in the PDF -> SrcEnum member names
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
             '4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
|
||||
|
||||
def parse_bits(s: str) -> tuple[int, int] | None:
|
||||
return (int(m.group(1)), int(m.group(2) or m.group(1))) if (m := re.match(r'\[(\d+)(?::(\d+))?\]', s)) else None
|
||||
|
||||
def parse_fields_table(table: list, fmt: str, enums: set[str]) -> list[tuple]:
  """Extract (name, hi, lo, enc_val, ftype) tuples from a format page's "Field Name" table.

  The header row is skipped, as is any row whose bits column does not parse as a bit range.
  For the ENCODING row the fixed binary pattern (e.g. 'b10_1111) in the description column is
  decoded; if the printed pattern is wider than the declared range, the range is widened to fit.
  """
  out: list[tuple] = []
  for row in table[1:]:  # row 0 is the column header
    if not row or not row[0]: continue
    name = row[0].split('\n')[0].strip()
    bits_str = (row[1] or '').split('\n')[0].strip()
    bits = parse_bits(bits_str)
    if not bits: continue
    hi, lo = bits
    enc_val = None
    if name == 'ENCODING' and row[2] and (m := re.search(r"'b([01_]+)", row[2])):
      pattern = m.group(1).replace('_', '')
      enc_val = int(pattern, 2)
      # trust the pattern width over the declared range when they disagree
      if len(pattern) > hi - lo + 1: lo = hi - len(pattern) + 1
    # the OP field maps to the per-format opcode enum when one was parsed; others use FIELD_TYPES
    ftype = f"{fmt}Op" if name == 'OP' and f"{fmt}Op" in enums else FIELD_TYPES.get(name.upper())
    out.append((name, hi, lo, enc_val, ftype))
  return out
|
||||
|
||||
def generate(output_path: pathlib.Path|str|None = None) -> dict:
  """Generate RDNA3.5 instruction definitions from the AMD ISA PDF. Returns dict with formats for testing."""
  pdf = pdfplumber.open(fetch(PDF_URL))
  pages = pdf.pages[150:200]  # instruction-set chapters live in this page range
  page_texts = [p.extract_text() or '' for p in pages]
  page_tables = [[t.extract() for t in p.find_tables()] for p in pages]
  full_text = '\n'.join(page_texts)

  # parse SSRC encoding from first page with VCC_LO
  src_enum = dict(SRC_EXTRAS)
  for text in page_texts[:10]:
    if 'SSRC0' in text and 'VCC_LO' in text:
      for m in re.finditer(r'^(\d+)\s+(\S+)', text, re.M):
        val, name = int(m.group(1)), m.group(2).rstrip('.:')
        if name in FLOAT_MAP: src_enum[val] = FLOAT_MAP[name]
        elif re.match(r'^[A-Z][A-Z0-9_]*$', name): src_enum[val] = name
      break

  # parse opcode tables
  enums: dict[str, dict[int, str]] = {}
  for m in re.finditer(r'Table \d+\. (\w+) Opcodes(.*?)(?=Table \d+\.|\n\d+\.\d+\.\d+\.\s+\w+\s*\nDescription|$)', full_text, re.S):
    if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+([A-Z][A-Z0-9_]+)', m.group(2))}:
      enums[m.group(1) + "Op"] = ops
  # VOPD Y-opcodes have their own table that the generic pattern above misses
  if vopd_m := re.search(r'Table \d+\. VOPD Y-Opcodes\n(.*?)(?=Table \d+\.|15\.\d)', full_text, re.S):
    if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+(V_DUAL_\w+)', vopd_m.group(1))}:
      enums["VOPDOp"] = ops
  enum_names = set(enums.keys())

  def is_fields_table(t) -> bool: return t and len(t) > 1 and t[0] and 'Field' in str(t[0][0] or '')
  def has_encoding(fields) -> bool: return any(f[0] == 'ENCODING' for f in fields)
  def has_header_before_fields(text) -> bool:
    return (pos := text.find('Field Name')) != -1 and bool(re.search(r'\d+\.\d+\.\d+\.\s+\w+\s*\n', text[:pos]))

  # find format headers with their page indices
  format_headers = []  # (fmt_name, page_idx, header_pos)
  for i, text in enumerate(page_texts):
    for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n?Description', text): format_headers.append((m.group(1), i, m.start()))
    # a header at the bottom of a page may have its "Description" on the following page
    for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n', text):
      if m.start() > len(text) - 200 and 'Description' not in text[m.end():] and i + 1 < len(page_texts):
        next_text = page_texts[i + 1].lstrip()
        if next_text.startswith('Description') or (next_text.startswith('"RDNA') and 'Description' in next_text[:200]):
          format_headers.append((m.group(1), i, m.start()))

  # parse instruction formats
  formats: dict[str, list] = {}
  for fmt_name, page_idx, header_pos in format_headers:
    if fmt_name in formats: continue
    text, tables = page_texts[page_idx], page_tables[page_idx]
    field_pos = text.find('Field Name', header_pos)

    # find fields table with ENCODING (same page or up to 2 pages ahead)
    fields = None
    for offset in range(3):
      if page_idx + offset >= len(pages): break
      if offset > 0 and has_header_before_fields(page_texts[page_idx + offset]): break
      for t in page_tables[page_idx + offset] if offset > 0 or field_pos > header_pos else []:
        if is_fields_table(t) and (f := parse_fields_table(t, fmt_name, enum_names)) and has_encoding(f):
          fields = f
          break
      if fields: break

    # for modifier formats (no ENCODING), accept first fields table on same page
    if not fields and field_pos > header_pos:
      for t in tables:
        if is_fields_table(t) and (f := parse_fields_table(t, fmt_name, enum_names)):
          fields = f
          break

    if not fields: continue
    field_names = {f[0] for f in fields}

    # check next pages for continuation fields (tables without ENCODING)
    for pg_offset in range(1, 3):
      if page_idx + pg_offset >= len(pages) or has_header_before_fields(page_texts[page_idx + pg_offset]): break
      for t in page_tables[page_idx + pg_offset]:
        if is_fields_table(t) and (extra := parse_fields_table(t, fmt_name, enum_names)) and not has_encoding(extra):
          for ef in extra:
            if ef[0] not in field_names:
              fields.append(ef)
              field_names.add(ef[0])
          break
    formats[fmt_name] = fields

  # fix known PDF errors (verified against LLVM test vectors)
  # SMEM: PDF says DLC=bit14, GLC=bit16 but actual encoding is DLC=bit13, GLC=bit14
  if 'SMEM' in formats:
    formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
                       for n, h, l, e, t in formats['SMEM']]

  # generate output
  def enum_lines(name, items):
    return [f"class {name}(IntEnum):"] + [f" {n} = {v}" for v, n in sorted(items.items())] + [""]
  # NOTE: field_key closes over `order`, which is rebound per-format in the loop below
  def field_key(f): return order.index(f[0].lower()) if f[0].lower() in order else 1000
  lines = ["# autogenerated from AMD RDNA3.5 ISA PDF by gen.py - do not edit", "from enum import IntEnum",
           "from typing import Annotated",
           "from extra.assembly.rdna3.lib import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
           "import functools", ""]
  lines += enum_lines("SrcEnum", src_enum) + sum([enum_lines(n, ops) for n, ops in sorted(enums.items())], [])
  # Format-specific field defaults (verified against LLVM test vectors)
  format_defaults = {'VOP3P': {'opsel_hi': 3, 'opsel_hi2': 1}}
  lines.append("# instruction formats")
  for fmt_name, fields in sorted(formats.items()):
    base = "Inst64" if max(f[1] for f in fields) > 31 or fmt_name == 'VOP3SD' else "Inst32"
    order = FIELD_ORDER.get(fmt_name, [])
    lines.append(f"class {fmt_name}({base}):")
    if enc := next((f for f in fields if f[0] == 'ENCODING'), None):
      enc_str = f"bits[{enc[1]}:{enc[2]}] == 0b{enc[3]:b}" if enc[1] != enc[2] else f"bits[{enc[1]}] == {enc[3]}"
      lines.append(f" encoding = {enc_str}")
    if defaults := format_defaults.get(fmt_name):
      lines.append(f" _defaults = {defaults}")
    for name, hi, lo, _, ftype in sorted([f for f in fields if f[0] != 'ENCODING'], key=field_key):
      # Wrap IntEnum types (ending in Op) with Annotated[BitField, ...] for correct typing
      ann = f":Annotated[BitField, {ftype}]" if ftype and ftype.endswith('Op') else (f":{ftype}" if ftype else "")
      lines.append(f" {name.lower()}{ann} = bits[{hi}]" if hi == lo else f" {name.lower()}{ann} = bits[{hi}:{lo}]")
    lines.append("")
  lines.append("# instruction helpers")
  for cls_name, ops in sorted(enums.items()):
    fmt = cls_name[:-2]
    for op_val, name in sorted(ops.items()):
      seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=2"}.get(fmt, "")
      tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}")
      if fmt in formats or fmt in ("GLOBAL", "SCRATCH"):
        # VOP1/VOP2/VOPC get _e32 suffix, VOP3 promoted ops (< 512) get _e64 suffix
        if fmt in ("VOP1", "VOP2", "VOPC"): suffix = "_e32"
        elif fmt == "VOP3" and op_val < 512: suffix = "_e64"
        else: suffix = ""
        # FMAMK/FMAAK have a literal constant K that must be passed via literal= kwarg
        # FMAMK: D = S0.f * K + S1.f (K is 3rd operand in assembly syntax)
        # FMAAK: D = S0.f * S1.f + K (K is 4th operand in assembly syntax)
        if name in ('V_FMAMK_F32', 'V_FMAMK_F16'):
          lines.append(f"def {name.lower()}{suffix}(vdst, src0, K, vsrc1): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
        elif name in ('V_FMAAK_F32', 'V_FMAAK_F16'):
          lines.append(f"def {name.lower()}{suffix}(vdst, src0, vsrc1, K): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
        else:
          lines.append(f"{name.lower()}{suffix} = functools.partial({tgt}.{name}{seg})")
  # export SrcEnum values, but skip DPP8/DPP16 which conflict with class names
  skip_exports = {'DPP8', 'DPP16'}
  lines += [""] + [f"{name} = SrcEnum.{name}" for _, name in sorted(src_enum.items()) if name not in skip_exports] + ["OFF = NULL\n"]

  if output_path is not None: pathlib.Path(output_path).write_text('\n'.join(lines))
  return {"formats": formats, "enums": enums, "src_enum": src_enum}
|
||||
|
||||
if __name__ == "__main__":
  # regenerate the autogen package in-tree and report what was produced
  info = generate("extra/assembly/rdna3/autogen/__init__.py")
  print(f"generated SrcEnum ({len(info['src_enum'])}) + {len(info['enums'])} opcode enums + {len(info['formats'])} format classes")
|
||||
@@ -281,3 +281,209 @@ class Inst:
|
||||
|
||||
class Inst32(Inst): pass
|
||||
class Inst64(Inst): pass
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# CODE GENERATION: generates autogen/__init__.py by parsing the AMD RDNA3.5 ISA PDF
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# the AMD RDNA3.5 ISA reference PDF used as the single source of truth
PDF_URL = "https://docs.amd.com/api/khub/documents/UVVZM22UN7tMUeiW_4ShTQ/content"

# PDF table field name (upper-case) -> BitField wrapper type emitted in the autogen module
FIELD_TYPES = {
  'SSRC0': 'SSrc', 'SSRC1': 'SSrc', 'SOFFSET': 'SSrc', 'SADDR': 'SSrc',
  'SRC0': 'Src', 'SRC1': 'Src', 'SRC2': 'Src',
  'SDST': 'SGPRField', 'SBASE': 'SGPRField', 'SDATA': 'SGPRField', 'SRSRC': 'SGPRField',
  'VDST': 'VGPRField', 'VSRC1': 'VGPRField', 'VDATA': 'VGPRField', 'VADDR': 'VGPRField',
  'ADDR': 'VGPRField', 'DATA': 'VGPRField', 'DATA0': 'VGPRField', 'DATA1': 'VGPRField',
  'SIMM16': 'SImm', 'OFFSET': 'Imm',
  'OPX': 'VOPDOp', 'OPY': 'VOPDOp', 'SRCX0': 'Src', 'SRCY0': 'Src',
  'VSRCX1': 'VGPRField', 'VSRCY1': 'VGPRField', 'VDSTX': 'VGPRField', 'VDSTY': 'VDSTYEnc'}

# positional field order used when emitting each format class (fields not listed sort last)
FIELD_ORDER = {
  'SOP2': ['op', 'sdst', 'ssrc0', 'ssrc1'], 'SOP1': ['op', 'sdst', 'ssrc0'], 'SOPC': ['op', 'ssrc0', 'ssrc1'],
  'SOPK': ['op', 'sdst', 'simm16'], 'SOPP': ['op', 'simm16'], 'VOP1': ['op', 'vdst', 'src0'], 'VOPC': ['op', 'src0', 'vsrc1'],
  'VOP2': ['op', 'vdst', 'src0', 'vsrc1'], 'VOP3SD': ['op', 'vdst', 'sdst', 'src0', 'src1', 'src2', 'clmp'],
  'SMEM': ['op', 'sdata', 'sbase', 'soffset', 'offset', 'glc', 'dlc'], 'DS': ['op', 'vdst', 'addr', 'data0', 'data1'],
  'VOP3': ['op', 'vdst', 'src0', 'src1', 'src2', 'omod', 'neg', 'abs', 'clmp', 'opsel'],
  'VOP3P': ['op', 'vdst', 'src0', 'src1', 'src2', 'neg', 'neg_hi', 'opsel', 'opsel_hi', 'clmp'],
  'FLAT': ['op', 'vdst', 'addr', 'data', 'saddr', 'offset', 'seg', 'dlc', 'glc', 'slc'],
  'MUBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
  'MTBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'format', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
  'MIMG': ['op', 'vdata', 'vaddr', 'srsrc', 'ssamp', 'dmask', 'dim', 'unrm', 'dlc', 'glc', 'slc'],
  'EXP': ['en', 'target', 'vsrc0', 'vsrc1', 'vsrc2', 'vsrc3', 'done', 'row'],
  'VINTERP': ['op', 'vdst', 'src0', 'src1', 'src2', 'waitexp', 'clmp', 'opsel', 'neg'],
  'VOPD': ['opx', 'opy', 'vdstx', 'vdsty', 'srcx0', 'vsrcx1', 'srcy0', 'vsrcy1'],
  'LDSDIR': ['op', 'vdst', 'attr', 'attr_chan', 'wait_va']}

# src-operand encodings missing from the PDF's SSRC value table
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}

# float literal spellings as printed in the PDF -> SrcEnum member names
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
             '4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
|
||||
|
||||
def _parse_bits(s: str) -> tuple[int, int] | None:
|
||||
import re
|
||||
return (int(m.group(1)), int(m.group(2) or m.group(1))) if (m := re.match(r'\[(\d+)(?::(\d+))?\]', s)) else None
|
||||
|
||||
def _parse_fields_table(table: list, fmt: str, enums: set[str]) -> list[tuple]:
  """Extract (name, hi, lo, enc_val, ftype) tuples from a format page's "Field Name" table.

  Skips the header row and rows without a parseable bit range. For the ENCODING row the fixed
  binary pattern (e.g. 'b10_1111) in the description column is decoded, widening the declared
  range when the printed pattern is longer than it.
  """
  import re  # deferred import, matching the rest of the codegen helpers
  out: list[tuple] = []
  for row in table[1:]:  # row 0 is the column header
    if not row or not row[0]: continue
    name = row[0].split('\n')[0].strip()
    bits_str = (row[1] or '').split('\n')[0].strip()
    bits = _parse_bits(bits_str)
    if not bits: continue
    hi, lo = bits
    enc_val = None
    if name == 'ENCODING' and row[2] and (m := re.search(r"'b([01_]+)", row[2])):
      pattern = m.group(1).replace('_', '')
      enc_val = int(pattern, 2)
      # trust the pattern width over the declared range when they disagree
      if len(pattern) > hi - lo + 1: lo = hi - len(pattern) + 1
    # the OP field maps to the per-format opcode enum when one was parsed; others use FIELD_TYPES
    ftype = f"{fmt}Op" if name == 'OP' and f"{fmt}Op" in enums else FIELD_TYPES.get(name.upper())
    out.append((name, hi, lo, enc_val, ftype))
  return out
|
||||
|
||||
def generate(output_path: str | None = None) -> dict:
  """Generate RDNA3.5 instruction definitions from the AMD ISA PDF.

  Downloads and parses the ISA PDF (opcode tables, instruction-format field tables, SSRC
  encodings), emits the autogen module source, and writes it to output_path when given
  (a string path; anything pathlib.Path accepts also works). Returns a dict with the parsed
  "formats", "enums" and "src_enum" for testing.
  """
  # heavy deps are imported lazily so importing lib.py stays cheap
  import re, pdfplumber, pathlib
  from tinygrad.helpers import fetch

  pdf = pdfplumber.open(fetch(PDF_URL))
  pages = pdf.pages[150:200]  # instruction-set chapters live in this page range
  page_texts = [p.extract_text() or '' for p in pages]
  page_tables = [[t.extract() for t in p.find_tables()] for p in pages]
  full_text = '\n'.join(page_texts)

  # parse SSRC encoding from first page with VCC_LO
  src_enum = dict(SRC_EXTRAS)
  for text in page_texts[:10]:
    if 'SSRC0' in text and 'VCC_LO' in text:
      for m in re.finditer(r'^(\d+)\s+(\S+)', text, re.M):
        val, name = int(m.group(1)), m.group(2).rstrip('.:')
        if name in FLOAT_MAP: src_enum[val] = FLOAT_MAP[name]
        elif re.match(r'^[A-Z][A-Z0-9_]*$', name): src_enum[val] = name
      break

  # parse opcode tables
  enums: dict[str, dict[int, str]] = {}
  for m in re.finditer(r'Table \d+\. (\w+) Opcodes(.*?)(?=Table \d+\.|\n\d+\.\d+\.\d+\.\s+\w+\s*\nDescription|$)', full_text, re.S):
    if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+([A-Z][A-Z0-9_]+)', m.group(2))}:
      enums[m.group(1) + "Op"] = ops
  # VOPD Y-opcodes have their own table that the generic pattern above misses
  if vopd_m := re.search(r'Table \d+\. VOPD Y-Opcodes\n(.*?)(?=Table \d+\.|15\.\d)', full_text, re.S):
    if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+(V_DUAL_\w+)', vopd_m.group(1))}:
      enums["VOPDOp"] = ops
  enum_names = set(enums.keys())

  def is_fields_table(t) -> bool: return t and len(t) > 1 and t[0] and 'Field' in str(t[0][0] or '')
  def has_encoding(fields) -> bool: return any(f[0] == 'ENCODING' for f in fields)
  def has_header_before_fields(text) -> bool:
    return (pos := text.find('Field Name')) != -1 and bool(re.search(r'\d+\.\d+\.\d+\.\s+\w+\s*\n', text[:pos]))

  # find format headers with their page indices
  format_headers = []  # (fmt_name, page_idx, header_pos)
  for i, text in enumerate(page_texts):
    for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n?Description', text): format_headers.append((m.group(1), i, m.start()))
    # a header at the bottom of a page may have its "Description" on the following page
    for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n', text):
      if m.start() > len(text) - 200 and 'Description' not in text[m.end():] and i + 1 < len(page_texts):
        next_text = page_texts[i + 1].lstrip()
        if next_text.startswith('Description') or (next_text.startswith('"RDNA') and 'Description' in next_text[:200]):
          format_headers.append((m.group(1), i, m.start()))

  # parse instruction formats
  formats: dict[str, list] = {}
  for fmt_name, page_idx, header_pos in format_headers:
    if fmt_name in formats: continue
    text, tables = page_texts[page_idx], page_tables[page_idx]
    field_pos = text.find('Field Name', header_pos)

    # find fields table with ENCODING (same page or up to 2 pages ahead)
    fields = None
    for offset in range(3):
      if page_idx + offset >= len(pages): break
      if offset > 0 and has_header_before_fields(page_texts[page_idx + offset]): break
      for t in page_tables[page_idx + offset] if offset > 0 or field_pos > header_pos else []:
        if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)) and has_encoding(f):
          fields = f
          break
      if fields: break

    # for modifier formats (no ENCODING), accept first fields table on same page
    if not fields and field_pos > header_pos:
      for t in tables:
        if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)):
          fields = f
          break

    if not fields: continue
    field_names = {f[0] for f in fields}

    # check next pages for continuation fields (tables without ENCODING)
    for pg_offset in range(1, 3):
      if page_idx + pg_offset >= len(pages) or has_header_before_fields(page_texts[page_idx + pg_offset]): break
      for t in page_tables[page_idx + pg_offset]:
        if is_fields_table(t) and (extra := _parse_fields_table(t, fmt_name, enum_names)) and not has_encoding(extra):
          for ef in extra:
            if ef[0] not in field_names:
              fields.append(ef)
              field_names.add(ef[0])
          break
    formats[fmt_name] = fields

  # fix known PDF errors (verified against LLVM test vectors)
  # SMEM: PDF says DLC=bit14, GLC=bit16 but actual encoding is DLC=bit13, GLC=bit14
  if 'SMEM' in formats:
    formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
                       for n, h, l, e, t in formats['SMEM']]

  # generate output
  def enum_lines(name, items):
    return [f"class {name}(IntEnum):"] + [f" {n} = {v}" for v, n in sorted(items.items())] + [""]
  # NOTE: field_key closes over `order`, which is rebound per-format in the loop below
  def field_key(f): return order.index(f[0].lower()) if f[0].lower() in order else 1000
  lines = ["# autogenerated from AMD RDNA3.5 ISA PDF by lib.py - do not edit", "from enum import IntEnum",
           "from typing import Annotated",
           "from extra.assembly.rdna3.lib import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
           "import functools", ""]
  lines += enum_lines("SrcEnum", src_enum) + sum([enum_lines(n, ops) for n, ops in sorted(enums.items())], [])
  # Format-specific field defaults (verified against LLVM test vectors)
  format_defaults = {'VOP3P': {'opsel_hi': 3, 'opsel_hi2': 1}}
  lines.append("# instruction formats")
  for fmt_name, fields in sorted(formats.items()):
    base = "Inst64" if max(f[1] for f in fields) > 31 or fmt_name == 'VOP3SD' else "Inst32"
    order = FIELD_ORDER.get(fmt_name, [])
    lines.append(f"class {fmt_name}({base}):")
    if enc := next((f for f in fields if f[0] == 'ENCODING'), None):
      enc_str = f"bits[{enc[1]}:{enc[2]}] == 0b{enc[3]:b}" if enc[1] != enc[2] else f"bits[{enc[1]}] == {enc[3]}"
      lines.append(f" encoding = {enc_str}")
    if defaults := format_defaults.get(fmt_name):
      lines.append(f" _defaults = {defaults}")
    for name, hi, lo, _, ftype in sorted([f for f in fields if f[0] != 'ENCODING'], key=field_key):
      # Wrap IntEnum types (ending in Op) with Annotated[BitField, ...] for correct typing
      ann = f":Annotated[BitField, {ftype}]" if ftype and ftype.endswith('Op') else (f":{ftype}" if ftype else "")
      lines.append(f" {name.lower()}{ann} = bits[{hi}]" if hi == lo else f" {name.lower()}{ann} = bits[{hi}:{lo}]")
    lines.append("")
  lines.append("# instruction helpers")
  for cls_name, ops in sorted(enums.items()):
    fmt = cls_name[:-2]
    for op_val, name in sorted(ops.items()):
      seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=2"}.get(fmt, "")
      tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}")
      if fmt in formats or fmt in ("GLOBAL", "SCRATCH"):
        # VOP1/VOP2/VOPC get _e32 suffix, VOP3 promoted ops (< 512) get _e64 suffix
        if fmt in ("VOP1", "VOP2", "VOPC"): suffix = "_e32"
        elif fmt == "VOP3" and op_val < 512: suffix = "_e64"
        else: suffix = ""
        # FMAMK/FMAAK have a literal constant K that must be passed via literal= kwarg
        # FMAMK: D = S0.f * K + S1.f (K is 3rd operand in assembly syntax)
        # FMAAK: D = S0.f * S1.f + K (K is 4th operand in assembly syntax)
        if name in ('V_FMAMK_F32', 'V_FMAMK_F16'):
          lines.append(f"def {name.lower()}{suffix}(vdst, src0, K, vsrc1): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
        elif name in ('V_FMAAK_F32', 'V_FMAAK_F16'):
          lines.append(f"def {name.lower()}{suffix}(vdst, src0, vsrc1, K): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
        else:
          lines.append(f"{name.lower()}{suffix} = functools.partial({tgt}.{name}{seg})")
  # export SrcEnum values, but skip DPP8/DPP16 which conflict with class names
  skip_exports = {'DPP8', 'DPP16'}
  lines += [""] + [f"{name} = SrcEnum.{name}" for _, name in sorted(src_enum.items()) if name not in skip_exports] + ["OFF = NULL\n"]

  # pathlib is already imported at the top of this function (the original re-imported it here)
  if output_path is not None: pathlib.Path(output_path).write_text('\n'.join(lines))
  return {"formats": formats, "enums": enums, "src_enum": src_enum}
|
||||
|
||||
if __name__ == "__main__":
  # regenerate the autogen package in-tree and report what was produced
  info = generate("extra/assembly/rdna3/autogen/__init__.py")
  print(f"generated SrcEnum ({len(info['src_enum'])}) + {len(info['enums'])} opcode enums + {len(info['formats'])} format classes")
|
||||
|
||||
@@ -817,31 +817,37 @@ from extra.assembly.rdna3.pcode import *
|
||||
# Add original pseudocode as comment
|
||||
for pc_line in pc.split('\n'):
|
||||
lines.append(f" # {pc_line}")
|
||||
# V_DIV_SCALE: D0 defaults to S0 if no branch taken
|
||||
if is_div_scale:
|
||||
lines.append(" S0, S1, S2, D0, D1 = Reg(s0), Reg(s1), Reg(s2), Reg(s0), Reg(0)")
|
||||
else:
|
||||
lines.append(" S0, S1, S2, D0, D1 = Reg(s0), Reg(s1), Reg(s2), Reg(d0), Reg(0)")
|
||||
lines.append(" SCC, VCC, EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)")
|
||||
lines.append(" EXEC_LO, EXEC_HI = SliceProxy(EXEC, 31, 0), SliceProxy(EXEC, 63, 32)")
|
||||
lines.append(" tmp, saveexec = Reg(0), Reg(exec_mask)")
|
||||
lines.append(" laneId = lane")
|
||||
lines.append(" SIMM16, SIMM32 = Reg(literal), Reg(literal)")
|
||||
lines.append(" SRC0, VDST = Reg(src0_idx), Reg(vdst_idx)")
|
||||
# Only create Reg objects for registers actually used in the pseudocode
|
||||
combined = code + pc
|
||||
regs = [('S0', 'Reg(s0)'), ('S1', 'Reg(s1)'), ('S2', 'Reg(s2)'),
|
||||
('D0', 'Reg(s0)' if is_div_scale else 'Reg(d0)'), ('D1', 'Reg(0)'),
|
||||
('SCC', 'Reg(scc)'), ('VCC', 'Reg(vcc)'), ('EXEC', 'Reg(exec_mask)'),
|
||||
('tmp', 'Reg(0)'), ('saveexec', 'Reg(exec_mask)'), ('laneId', 'lane'),
|
||||
('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
|
||||
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
|
||||
used = {name for name, _ in regs if name in combined}
|
||||
# EXEC_LO/EXEC_HI need EXEC
|
||||
if 'EXEC_LO' in combined or 'EXEC_HI' in combined: used.add('EXEC')
|
||||
for name, init in regs:
|
||||
if name in used: lines.append(f" {name} = {init}")
|
||||
if 'EXEC_LO' in combined: lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
|
||||
if 'EXEC_HI' in combined: lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
|
||||
# Add compiled pseudocode with markers
|
||||
lines.append(" # --- compiled pseudocode ---")
|
||||
for line in code.split('\n'):
|
||||
lines.append(f" {line}")
|
||||
lines.append(" # --- end pseudocode ---")
|
||||
# Generate result dict
|
||||
lines.append(" result = {'d0': D0._val, 'scc': SCC._val & 1}")
|
||||
# Generate result dict - use raw params if Reg wasn't created
|
||||
d0_val = "D0._val" if 'D0' in used else "d0"
|
||||
scc_val = "SCC._val & 1" if 'SCC' in used else "scc & 1"
|
||||
lines.append(f" result = {{'d0': {d0_val}, 'scc': {scc_val}}}")
|
||||
if has_sdst:
|
||||
lines.append(" result['vcc_lane'] = (VCC._val >> lane) & 1")
|
||||
else:
|
||||
elif 'VCC' in used:
|
||||
lines.append(" if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1")
|
||||
if is_cmpx:
|
||||
lines.append(" result['exec_lane'] = (EXEC._val >> lane) & 1")
|
||||
else:
|
||||
elif 'EXEC' in used:
|
||||
lines.append(" if EXEC._val != exec_mask: result['exec'] = EXEC._val")
|
||||
if is_cmp:
|
||||
lines.append(" result['vcc_lane'] = (D0._val >> lane) & 1")
|
||||
|
||||
@@ -126,7 +126,7 @@ class PythonEmulator:
|
||||
|
||||
def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: tuple[int, int, int],
|
||||
program, max_steps: int, debug: bool, trace_len: int, kernel_idx: int = 0,
|
||||
max_workgroups: int = 64) -> tuple[bool, str, int]:
|
||||
max_workgroups: int = 8) -> tuple[bool, str, int]:
|
||||
"""Run a single kernel through both emulators. Returns (success, message, total_steps)."""
|
||||
gx, gy, gz = global_size
|
||||
total_steps = 0
|
||||
@@ -356,191 +356,52 @@ class TestTinygradKernels(unittest.TestCase):
|
||||
ok, msg = compare_emulators_multi_kernel(kernels, buf_pool, max_steps=max_steps, buf_data=buf_data)
|
||||
self.assertTrue(ok, msg)
|
||||
|
||||
# Basic unary ops
|
||||
def test_neg(self): self._test_kernel(lambda T: -T([1.0, -2.0, 3.0, -4.0]))
|
||||
def test_relu(self): self._test_kernel(lambda T: T([-1.0, 0.0, 1.0, 2.0]).relu())
|
||||
def test_exp(self): self._test_kernel(lambda T: T([0.0, 1.0, 2.0]).exp())
|
||||
def test_log(self): self._test_kernel(lambda T: T([1.0, 2.0, 3.0]).log())
|
||||
def test_sin(self): self._test_kernel(lambda T: T([0.0, 1.0, 2.0]).sin())
|
||||
def test_cos(self): self._test_kernel(lambda T: T([0.0, 1.0, 2.0]).cos())
|
||||
def test_sqrt(self): self._test_kernel(lambda T: T([1.0, 4.0, 9.0]).sqrt())
|
||||
def test_recip(self): self._test_kernel(lambda T: T([1.0, 2.0, 4.0]).reciprocal())
|
||||
|
||||
# Sin/cos with various ranges - test polynomial expansion
|
||||
def test_sin_small(self): self._test_kernel(lambda T: T([0.1, 0.2, 0.3, 0.4, 0.5]*7).sin()) # 35 elements, small angles
|
||||
def test_sin_pi(self): self._test_kernel(lambda T: T([3.14159, 1.5708, 0.7854, -1.5708, -3.14159]*7).sin()) # around pi
|
||||
def test_sin_medium(self): self._test_kernel(lambda T: T([10.0, 20.0, 30.0, 50.0, 100.0]*7).sin()) # medium values
|
||||
def test_sin_negative(self): self._test_kernel(lambda T: T([-0.5, -1.0, -2.0, -5.0, -10.0]*7).sin()) # negative values
|
||||
def test_cos_small(self): self._test_kernel(lambda T: T([0.1, 0.2, 0.3, 0.4, 0.5]*7).cos())
|
||||
def test_cos_pi(self): self._test_kernel(lambda T: T([3.14159, 1.5708, 0.7854, -1.5708, -3.14159]*7).cos())
|
||||
def test_cos_medium(self): self._test_kernel(lambda T: T([10.0, 20.0, 30.0, 50.0, 100.0]*7).cos())
|
||||
@unittest.skip("Rust emulator has V_DIV_SCALE_F32 bug - returns 0 instead of src0 for normal cases")
|
||||
def test_tan(self): self._test_kernel(lambda T: T([0.1, 0.2, 0.5, 1.0, -0.5]*7).tan()) # avoid pi/2
|
||||
|
||||
# Binary ops
|
||||
def test_add(self): self._test_kernel(lambda T: T([1.0, 2.0]) + T([3.0, 4.0]))
|
||||
def test_sub(self): self._test_kernel(lambda T: T([5.0, 6.0]) - T([1.0, 2.0]))
|
||||
def test_mul(self): self._test_kernel(lambda T: T([2.0, 3.0]) * T([4.0, 5.0]))
|
||||
def test_div(self): self._test_kernel(lambda T: T([10.0, 20.0]) / T([2.0, 4.0]))
|
||||
def test_max_binary(self): self._test_kernel(lambda T: T([1.0, 5.0]).maximum(T([3.0, 2.0])))
|
||||
# Basic ops - consolidated tests covering key instruction patterns
|
||||
def test_unary_ops(self): self._test_kernel(lambda T: T([-1.0, 0.0, 1.0, 2.0]).relu().exp().log().sqrt().reciprocal())
|
||||
def test_binary_ops(self): self._test_kernel(lambda T: (T([1.0, 2.0]) + T([3.0, 4.0])) * T([0.5, 0.5]) - T([1.0, 1.0]))
|
||||
def test_trig(self): self._test_kernel(lambda T: T([0.1, 1.0, 3.14, -1.0]*8).sin() + T([0.1, 1.0, 3.14, -1.0]*8).cos())
|
||||
def test_compare(self): self._test_kernel(lambda T: (T.empty(64) < T.empty(64)).where(T.empty(64), T.empty(64)))
|
||||
def test_bitwise(self): self._test_kernel(lambda T: (T([0xF0, 0x0F, 0xFF]*11).int() & T([0x0F, 0x0F, 0x00]*11).int()) | T([1]*33).int())
|
||||
def test_int_ops(self): self._test_kernel(lambda T: ((T.empty(64).int() + T.empty(64).int()) * T.empty(64).int()).float())
|
||||
|
||||
# Reductions
|
||||
def test_sum_reduce(self): self._test_kernel(lambda T: T.empty(64).sum())
|
||||
def test_max_reduce(self): self._test_kernel(lambda T: T.empty(64).max())
|
||||
def test_mean_reduce(self): self._test_kernel(lambda T: T.empty(32).mean())
|
||||
def test_reduce(self): self._test_kernel(lambda T: T.empty(64).sum() + T.empty(64).max())
|
||||
def test_argmax(self): self._test_kernel(lambda T: T.empty(64).argmax())
|
||||
|
||||
# Matmul - various sizes
|
||||
def test_gemm_4x4(self): self._test_kernel(lambda T: T.empty(4, 4) @ T.empty(4, 4), max_steps=100000)
|
||||
def test_gemm_8x8(self): self._test_kernel(lambda T: T.empty(8, 8) @ T.empty(8, 8), max_steps=200000)
|
||||
@unittest.skip("too slow")
|
||||
def test_gemm_16x16(self): self._test_kernel(lambda T: T.empty(16, 16) @ T.empty(16, 16), max_steps=500000)
|
||||
def test_gemv(self): self._test_kernel(lambda T: T.empty(1, 16) @ T.empty(16, 16), max_steps=100000)
|
||||
# Matmul
|
||||
def test_gemm(self): self._test_kernel(lambda T: T.empty(8, 8) @ T.empty(8, 8), max_steps=100000)
|
||||
def test_gemm_fp16(self): self._test_kernel(lambda T: T.empty(16, 16).half() @ T.empty(16, 16).half(), max_steps=100000)
|
||||
|
||||
# Complex ops
|
||||
def test_softmax(self): self._test_kernel(lambda T: T.empty(16).softmax())
|
||||
def test_layernorm(self): self._test_kernel(lambda T: T.empty(8, 8).layernorm())
|
||||
|
||||
# Memory patterns
|
||||
def test_contiguous(self): self._test_kernel(lambda T: T.empty(4, 4).permute(1, 0).contiguous())
|
||||
def test_reshape(self): self._test_kernel(lambda T: (T.empty(16) + 1).reshape(4, 4).contiguous())
|
||||
def test_expand(self): self._test_kernel(lambda T: T.empty(4, 1).expand(4, 4).contiguous())
|
||||
def test_memory(self): self._test_kernel(lambda T: T.empty(4, 4).permute(1, 0).contiguous() + T.empty(4, 1).expand(4, 4))
|
||||
|
||||
# Cast ops
|
||||
def test_cast_int(self): self._test_kernel(lambda T: T.empty(16).int().float())
|
||||
def test_cast_half(self): self._test_kernel(lambda T: T.empty(16).half().float())
|
||||
def test_cast(self): self._test_kernel(lambda T: T.empty(32).half().float() + T.empty(32).int().float())
|
||||
|
||||
# Min/max (uses comparison internally)
|
||||
def test_min_binary(self): self._test_kernel(lambda T: T([1.0, 5.0, 3.0]).minimum(T([3.0, 2.0, 4.0])))
|
||||
# Pooling - regression for VCC wave32 mode
|
||||
def test_pool2d(self): self._test_kernel(lambda T: T.empty(1, 1, 8, 8).avg_pool2d(kernel_size=(4,4)) + T.empty(1, 1, 8, 8).max_pool2d(kernel_size=(4,4)))
|
||||
|
||||
# Comparison ops (test VOPC instructions) - use 32+ elements to force vector instructions
|
||||
def test_cmp_lt(self): self._test_kernel(lambda T: (T.empty(64) < T.empty(64)).where(T.empty(64), T.empty(64)))
|
||||
def test_cmp_eq(self): self._test_kernel(lambda T: (T.empty(64) == T.empty(64)).where(T.empty(64), T.empty(64)))
|
||||
def test_where(self): self._test_kernel(lambda T: (T.empty(64) > 0).where(T.empty(64), T.empty(64)))
|
||||
# Convolution
|
||||
def test_conv2d(self): self._test_kernel(lambda T: T.empty(1, 2, 8, 8).conv2d(T.empty(2, 2, 3, 3)), max_steps=50000)
|
||||
|
||||
# Bitwise ops
|
||||
def test_bitwise_and(self): self._test_kernel(lambda T: T([0xF0, 0x0F, 0xFF]).int() & T([0x0F, 0x0F, 0x00]).int())
|
||||
def test_bitwise_or(self): self._test_kernel(lambda T: T([0xF0, 0x0F, 0x00]).int() | T([0x0F, 0x0F, 0xFF]).int())
|
||||
def test_bitwise_xor(self): self._test_kernel(lambda T: T([0xFF, 0x0F, 0xF0]).int() ^ T([0x0F, 0xF0, 0xF0]).int())
|
||||
|
||||
# Integer ops - use 32+ elements to force vector instructions
|
||||
def test_int_add(self): self._test_kernel(lambda T: (T.empty(64).int() + T.empty(64).int()).float())
|
||||
def test_int_mul(self): self._test_kernel(lambda T: (T.empty(64).int() * T.empty(64).int()).float())
|
||||
def test_int_mod(self): self._test_kernel(lambda T: (T.empty(64).int().abs() % (T.empty(64).int().abs() + 1)).float())
|
||||
|
||||
# More math ops - use 32+ elements to force vector instructions
|
||||
def test_abs(self): self._test_kernel(lambda T: T.empty(64).abs())
|
||||
def test_floor(self): self._test_kernel(lambda T: T.empty(64).floor())
|
||||
def test_ceil(self): self._test_kernel(lambda T: T.empty(64).ceil())
|
||||
def test_trunc(self): self._test_kernel(lambda T: T.empty(64).trunc())
|
||||
|
||||
# Fused ops
|
||||
def test_fma(self): self._test_kernel(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0])))
|
||||
|
||||
# Argmax/argmin (tests different reduction pattern) - use 32+ elements to force vector instructions
|
||||
def test_argmax(self): self._test_kernel(lambda T: T.empty(64).argmax())
|
||||
def test_argmin(self): self._test_kernel(lambda T: T.empty(64).argmin())
|
||||
|
||||
# Exact value tests - use 32+ elements to force vector instructions (small tensors use scalar ops which Rust emu doesn't fully support)
|
||||
def test_abs_exact(self): self._test_kernel(lambda T: T([-1., 0., 1.]*11).abs()) # 33 elements
|
||||
def test_neg_exact(self): self._test_kernel(lambda T: -T([-1., 0., 1.]*11))
|
||||
def test_log_special(self): self._test_kernel(lambda T: T([1., 2., 0.5]*11).log())
|
||||
def test_exp_exact(self): self._test_kernel(lambda T: T([0., 1., -1.]*11).exp())
|
||||
def test_reciprocal_exact(self): self._test_kernel(lambda T: T([1., 2., 0.5]*11).reciprocal())
|
||||
|
||||
# Integer division and mod - use 32+ elements
|
||||
def test_int_div(self): self._test_kernel(lambda T: (T([10, 20, 30]*11).int() // T([3, 4, 5]*11).int()).float())
|
||||
def test_int_neg(self): self._test_kernel(lambda T: (-T([1, -2, 3]*11).int()).float())
|
||||
|
||||
# Mixed precision - use 32+ elements
|
||||
def test_half_add(self): self._test_kernel(lambda T: (T([1., 2.]*16).half() + T([3., 4.]*16).half()).float())
|
||||
def test_half_mul(self): self._test_kernel(lambda T: (T([2., 3.]*16).half() * T([4., 5.]*16).half()).float())
|
||||
|
||||
# Matrix ops - patterns from test_ops.py failures
|
||||
def test_cat(self): self._test_kernel(lambda T: T.empty(32, 64).cat(T.empty(32, 64), dim=1))
|
||||
def test_gather(self): self._test_kernel(lambda T: T.empty(64).gather(0, T.arange(32).int()))
|
||||
|
||||
# Tests from test_ops.py that are failing
|
||||
def test_permute(self): self._test_kernel(lambda T: T.empty(3, 4, 5, 6).permute((3, 2, 1, 0)).contiguous())
|
||||
def test_cat_large(self): self._test_kernel(lambda T: T.empty(45, 65, 9).cat(T.empty(45, 65, 9), T.empty(45, 65, 9), dim=1))
|
||||
def test_gather_small(self): self._test_kernel(lambda T: T.empty(10).gather(0, T.arange(5).int()))
|
||||
@unittest.skip("Rust emulator has S_ADD_I32 SCC bug - uses carry instead of signed overflow")
|
||||
def test_cross_entropy(self): self._test_kernel(lambda T: T.randn(32, 10).softmax().log().sum())
|
||||
def test_cross_entropy_class(self):
|
||||
import numpy as np
|
||||
np.random.seed(0)
|
||||
classes = np.random.randint(0, 10, (32,), dtype=np.int32).tolist()
|
||||
x_np = np.random.randn(32, 10).astype(np.float32)
|
||||
self._test_kernel(lambda T: (T(x_np.tolist()).reshape(32,10) + 0).cross_entropy((T(classes).int().reshape(32) + 0)))
|
||||
|
||||
# Regression tests for BFE operations with width=0 (walrus operator bug)
|
||||
# Regression tests
|
||||
def test_topk(self): self._test_kernel(lambda T: T.empty(64).topk(3)[0])
|
||||
def test_interpolate_uint8(self): self._test_kernel(lambda T: T.empty(2,3,64,64).relu().cast('uint8').interpolate((10,10), mode="linear"))
|
||||
|
||||
# Regression test for 64-bit comparison (V_CMP_GT_I64, V_CMP_LT_U64, etc.) with rsrc64
|
||||
def test_interpolate(self): self._test_kernel(lambda T: T.empty(1,2,16,16).relu().cast('uint8').interpolate((8,8), mode="linear"))
|
||||
def test_index_int64(self):
|
||||
from tinygrad import dtypes
|
||||
self._test_kernel(lambda T: T.empty(4, 4)[T.arange(4).cast(dtypes.int64), :])
|
||||
|
||||
@unittest.skip("only works with mock GPU")
|
||||
def test_index_int64_2d(self):
|
||||
from tinygrad import dtypes
|
||||
# Tests 64-bit compare with inline constants (comparing against 0)
|
||||
self._test_kernel(lambda T: T.empty(4, 4)[T.arange(4).cast(dtypes.int64), T.arange(4).cast(dtypes.int64)])
|
||||
|
||||
# Pooling operations - regression test for VCC wave32 mode (S_CBRANCH_VCCZ should only check VCC_LO)
|
||||
def test_avg_pool2d(self): self._test_kernel(lambda T: T.empty(1, 1, 8, 8).avg_pool2d(kernel_size=(4,4), stride=2))
|
||||
|
||||
# Trig functions with special values (inf, nan, 0)
|
||||
def test_sin_special(self): self._test_kernel(lambda T: T([0., 0.25, 0.5, 1.0]*8).sin())
|
||||
def test_cos_special(self): self._test_kernel(lambda T: T([0., 0.25, 0.5, 1.0]*8).cos())
|
||||
|
||||
# Sqrt and rsqrt
|
||||
def test_sqrt(self): self._test_kernel(lambda T: T([0., 1., 4., 9.]*8).sqrt())
|
||||
def test_rsqrt(self): self._test_kernel(lambda T: T([1., 4., 9., 16.]*8).rsqrt())
|
||||
@unittest.skip("Rust emulator has S_ADD_I32 SCC bug - uses carry instead of signed overflow")
|
||||
def test_avg_pool3d(self):
|
||||
def test_gelu(self): self._test_kernel(lambda T: T.empty(32, 32).gelu())
|
||||
def test_cross_entropy(self):
|
||||
import numpy as np
|
||||
np.random.seed(0)
|
||||
self._test_kernel(lambda T: T(np.random.randn(1, 1, 16, 16, 16).astype(np.float32).tolist()).avg_pool2d(kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False))
|
||||
def test_max_pool2d(self): self._test_kernel(lambda T: T.empty(1, 1, 8, 8).max_pool2d(kernel_size=(4,4), stride=2))
|
||||
|
||||
# Convolution operations - multi-kernel tests
|
||||
def test_conv2d(self): self._test_kernel(lambda T: T.empty(1, 4, 8, 8).conv2d(T.empty(4, 4, 3, 3)), max_steps=100000)
|
||||
def test_conv_transpose2d(self): self._test_kernel(lambda T: T.empty(1, 4, 8, 8).conv_transpose2d(T.empty(4, 4, 3, 3)), max_steps=200000)
|
||||
@unittest.skip("Rust emulator has S_ADD_I32 SCC bug - uses carry instead of signed overflow")
|
||||
def test_conv_transpose3d(self):
|
||||
import numpy as np
|
||||
np.random.seed(0)
|
||||
self._test_kernel(lambda T: T(np.random.randn(2, 4, 9, 9, 9).astype(np.float32).tolist()).conv_transpose2d(
|
||||
T(np.random.randn(4, 4, 3, 3, 3).astype(np.float32).tolist())), max_steps=500000)
|
||||
|
||||
# Tests from test_ops.py failures
|
||||
def test_gelu_extreme(self): self._test_kernel(lambda T: T.empty(45, 65).gelu())
|
||||
def test_gemm_64x64(self): self._test_kernel(lambda T: T.empty(64, 64) @ T.empty(64, 64), max_steps=500000)
|
||||
def test_gemm_fp16(self): self._test_kernel(lambda T: T.empty(64, 64).half() @ T.empty(64, 64).half(), max_steps=500000)
|
||||
def test_global_avg_pool2d(self): self._test_kernel(lambda T: T.empty(32, 2, 111, 28).avg_pool2d(kernel_size=(111, 28)), max_steps=100000)
|
||||
@unittest.skip("Rust emulator has S_ADD_I32 SCC bug - uses carry instead of signed overflow")
|
||||
def test_grouped_conv2d(self): self._test_kernel(lambda T: T.empty(4, 15, 5, 5).conv2d(T.empty(35, 3, 3, 3), groups=5), max_steps=200000)
|
||||
@unittest.skip("Rust emulator has S_ADD_I32 SCC bug - uses carry instead of signed overflow")
|
||||
def test_grouped_conv_transpose2d(self): self._test_kernel(lambda T: T.empty(2, 4, 9, 9).conv_transpose2d(T.empty(4, 4, 3, 3), groups=2), max_steps=200000)
|
||||
def test_hardsigmoid(self): self._test_kernel(lambda T: T.empty(45, 65).hardsigmoid())
|
||||
def test_hardsigmoid_extreme(self): self._test_kernel(lambda T: T.empty(45, 65).sigmoid())
|
||||
def test_matvec(self): self._test_kernel(lambda T: (T.empty(1, 128) @ T.empty(128, 128)).relu(), max_steps=200000)
|
||||
def test_matvecmat(self): self._test_kernel(lambda T: ((T.empty(1, 128) @ T.empty(128, 128)).relu() @ T.empty(128, 128)), max_steps=300000)
|
||||
def test_max_reduce_45x3(self): self._test_kernel(lambda T: T.empty(45, 3).max())
|
||||
def test_max_dont_collapse(self): self._test_kernel(lambda T: T.empty(4, 8).max(axis=1))
|
||||
def test_max_pool2d_simple(self): self._test_kernel(lambda T: T.empty(1, 1, 2, 3).max_pool2d(kernel_size=(2, 2)))
|
||||
def test_max_pool2d_32x2(self): self._test_kernel(lambda T: T.empty(32, 2, 11, 28).max_pool2d(kernel_size=(2, 2)))
|
||||
def test_max_pool2d_asymmetric_padding(self): self._test_kernel(lambda T: T.empty(4, 2, 111, 28).max_pool2d(kernel_size=(5, 5), padding=(0, 1, 0, 1)))
|
||||
def test_max_pool2d_bigger_stride(self): self._test_kernel(lambda T: T.empty(4, 2, 11, 28).max_pool2d(kernel_size=(2, 2), stride=(2, 3)))
|
||||
def test_max_pool2d_unit_stride(self): self._test_kernel(lambda T: T.empty(3, 2, 17, 14).max_pool2d(kernel_size=(5, 5), stride=1))
|
||||
def test_max_pool2d_smaller_stride(self): self._test_kernel(lambda T: T.empty(3, 2, 17, 14).max_pool2d(kernel_size=(5, 5), stride=(2, 3)))
|
||||
def test_max_unpool2d(self): self._test_kernel(lambda T: T.max_unpool2d(*T.empty(8, 3, 50, 50).max_pool2d(kernel_size=(5, 5), stride=(6, 5), return_indices=True), kernel_size=(5, 5), stride=(6, 5)))
|
||||
classes = np.random.randint(0, 10, (16,), dtype=np.int32).tolist()
|
||||
x_np = np.random.randn(16, 10).astype(np.float32)
|
||||
self._test_kernel(lambda T: (T(x_np.tolist()).reshape(16,10) + 0).cross_entropy((T(classes).int().reshape(16) + 0)))
|
||||
def test_isinf(self): self._test_kernel(lambda T: T([float('-inf'), 0., float('inf'), 1.1]*8).isinf())
|
||||
def test_isfinite(self): self._test_kernel(lambda T: T([float('-inf'), 0., float('inf'), 1.1]*8).isfinite())
|
||||
|
||||
# WMMA tests - uses wave matrix multiply for larger fp16 matmuls
|
||||
def test_wmma_gemm_fp16(self): self._test_kernel(lambda T: T.empty(64, 64).half() @ T.empty(64, 64).half(), max_steps=1000000)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Integration test: round-trip RDNA3 assembly through AMD toolchain."""
|
||||
import unittest, re, io, sys
|
||||
import unittest, re, io, sys, subprocess
|
||||
from extra.assembly.rdna3.autogen import *
|
||||
from extra.assembly.rdna3.asm import waitcnt, asm
|
||||
from extra.assembly.rdna3.test.test_roundtrip import _get_llvm_mc
|
||||
|
||||
def get_amd_toolchain():
|
||||
"""Check if AMD toolchain is available."""
|
||||
@@ -212,10 +213,8 @@ class TestAsm(unittest.TestCase):
|
||||
|
||||
def test_asm_vop3_modifiers(self):
|
||||
"""Test asm() with VOP3 modifiers (neg, abs, clamp)."""
|
||||
import subprocess, re
|
||||
|
||||
def get_llvm_encoding(instr: str) -> str:
|
||||
result = subprocess.run(['llvm-mc', '-triple=amdgcn', '-mcpu=gfx1100', '-show-encoding'],
|
||||
result = subprocess.run([_get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-show-encoding'],
|
||||
input=instr, capture_output=True, text=True)
|
||||
if m := re.search(r'encoding:\s*\[(.*?)\]', result.stdout):
|
||||
return m.group(1).replace('0x','').replace(',','').replace(' ','')
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test RDNA3 assembler/disassembler against LLVM test vectors."""
|
||||
import unittest, re
|
||||
import unittest, re, subprocess
|
||||
from tinygrad.helpers import fetch
|
||||
from extra.assembly.rdna3.autogen import *
|
||||
from extra.assembly.rdna3.asm import asm
|
||||
from extra.assembly.rdna3.test.test_roundtrip import compile_asm, disassemble_lib
|
||||
from extra.assembly.rdna3.test.test_roundtrip import _get_llvm_mc
|
||||
|
||||
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
|
||||
|
||||
@@ -78,6 +78,24 @@ def try_assemble(text: str):
|
||||
try: return asm(text).to_bytes()
|
||||
except: return None
|
||||
|
||||
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
|
||||
"""Compile multiple instructions with a single llvm-mc call."""
|
||||
if not instrs: return []
|
||||
asm_text = ".text\n" + "\n".join(instrs) + "\n"
|
||||
result = subprocess.run(
|
||||
[_get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=asm_text, capture_output=True, text=True, timeout=30)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
|
||||
# Parse all encodings from output
|
||||
results = []
|
||||
for line in result.stdout.split('\n'):
|
||||
if 'encoding:' not in line: continue
|
||||
enc = line.split('encoding:')[1].strip()
|
||||
if enc.startswith('[') and enc.endswith(']'):
|
||||
results.append(bytes.fromhex(enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')))
|
||||
if len(results) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(results)}")
|
||||
return results
|
||||
|
||||
class TestLLVM(unittest.TestCase):
|
||||
"""Test assembler and disassembler against all LLVM test vectors."""
|
||||
tests: dict[str, list[tuple[str, bytes]]] = {}
|
||||
@@ -107,59 +125,62 @@ def _make_asm_test(name):
|
||||
|
||||
def _make_disasm_test(name):
|
||||
def test(self):
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
compiler = HIPCompiler('gfx1100')
|
||||
_, fmt_cls, op_enum = LLVM_TEST_FILES[name]
|
||||
passed, failed, skipped, failures = 0, 0, 0, []
|
||||
# VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions)
|
||||
# Note: opcodes 0-255 are VOPC promoted to VOP3, never VOP3SD
|
||||
vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
|
||||
# vop3_from_vopc/vopcx tests have VOPC opcodes 0-255, not VOP3SD - don't detect as VOP3SD
|
||||
is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx')
|
||||
# Undocumented opcodes not in AMD ISA PDF - skip these
|
||||
undocumented = {'smem': {34, 35}, 'sopk': {22, 23}, 'sopp': {8, 58, 59}} # s_atc_probe*, s_subvector_loop*, s_waitcnt_depctr, unknown
|
||||
undocumented = {'smem': {34, 35}, 'sopk': {22, 23}, 'sopp': {8, 58, 59}}
|
||||
|
||||
# First pass: decode all instructions and collect disasm strings
|
||||
to_test = [] # list of (asm_text, data, disasm_str)
|
||||
skipped = 0
|
||||
for asm_text, data in self.tests.get(name, []):
|
||||
if len(data) > fmt_cls._size(): continue # skip literals (need different handling)
|
||||
# Skip undocumented opcodes
|
||||
if len(data) > fmt_cls._size(): continue
|
||||
temp_inst = fmt_cls.from_bytes(data)
|
||||
temp_op = temp_inst._values.get('op', 0)
|
||||
temp_op = temp_op.val if hasattr(temp_op, 'val') else temp_op
|
||||
if temp_op in undocumented.get(name, set()): skipped += 1; continue
|
||||
# Skip SOPP no-imm instructions with non-zero simm16 (can't roundtrip through LLVM)
|
||||
if name == 'sopp':
|
||||
simm16 = temp_inst._values.get('simm16', 0)
|
||||
simm16 = simm16.val if hasattr(simm16, 'val') else simm16
|
||||
sopp_no_imm = {48, 54, 53, 55, 60, 61, 62} # s_endpgm, s_barrier, s_wakeup, s_icache_inv, s_wait_idle, s_endpgm_saved, s_code_end
|
||||
sopp_no_imm = {48, 54, 53, 55, 60, 61, 62}
|
||||
if temp_op in sopp_no_imm and simm16 != 0: skipped += 1; continue
|
||||
try:
|
||||
# VOP3 and VOP3SD share encoding - peek at opcode to determine which class to use
|
||||
if fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
|
||||
temp = VOP3.from_bytes(data)
|
||||
op_val = temp._values.get('op', 0)
|
||||
op_val = op_val.val if hasattr(op_val, 'val') else op_val
|
||||
is_vop3sd = (op_val in vop3sd_opcodes) and not is_vopc_promotion
|
||||
decoded = VOP3SD.from_bytes(data) if is_vop3sd else VOP3.from_bytes(data)
|
||||
# Validate opcode with appropriate enum
|
||||
if is_vop3sd:
|
||||
VOP3SDOp(op_val)
|
||||
else:
|
||||
VOP3Op(op_val)
|
||||
if is_vop3sd: VOP3SDOp(op_val)
|
||||
else: VOP3Op(op_val)
|
||||
else:
|
||||
decoded = fmt_cls.from_bytes(data)
|
||||
op_val = decoded._values.get('op', 0)
|
||||
op_val = op_val.val if hasattr(op_val, 'val') else op_val
|
||||
op_enum(op_val) # validate opcode
|
||||
op_enum(op_val)
|
||||
if decoded.to_bytes()[:len(data)] != data:
|
||||
failed += 1; failures.append(f"decode roundtrip failed for {data.hex()}"); continue
|
||||
disasm_str = decoded.disasm()
|
||||
# Test: LLVM should assemble our disasm output to the same bytes
|
||||
llvm_bytes = compile_asm(disasm_str, compiler)
|
||||
if llvm_bytes is None:
|
||||
failed += 1; failures.append(f"LLVM failed to assemble: '{disasm_str}' (from '{asm_text}')")
|
||||
elif llvm_bytes == data: passed += 1
|
||||
else: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
|
||||
to_test.append((asm_text, data, None, "decode roundtrip failed"))
|
||||
continue
|
||||
to_test.append((asm_text, data, decoded.disasm(), None))
|
||||
except Exception as e:
|
||||
failed += 1; failures.append(f"exception for {data.hex()}: {e}")
|
||||
to_test.append((asm_text, data, None, f"exception: {e}"))
|
||||
|
||||
# Batch compile all disasm strings with single llvm-mc call
|
||||
disasm_strs = [(i, t[2]) for i, t in enumerate(to_test) if t[2] is not None]
|
||||
llvm_results = compile_asm_batch([s for _, s in disasm_strs]) if disasm_strs else []
|
||||
llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)}
|
||||
|
||||
# Match results back
|
||||
passed, failed, failures = 0, 0, []
|
||||
for idx, (asm_text, data, disasm_str, error) in enumerate(to_test):
|
||||
if error:
|
||||
failed += 1; failures.append(f"{error} for {data.hex()}")
|
||||
elif disasm_str is not None and idx in llvm_map:
|
||||
llvm_bytes = llvm_map[idx]
|
||||
if llvm_bytes == data: passed += 1
|
||||
else: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
|
||||
|
||||
print(f"{name.upper()} disasm: {passed} passed, {failed} failed" + (f", {skipped} skipped" if skipped else ""))
|
||||
if failures[:10]: print(" " + "\n ".join(failures[:10]))
|
||||
self.assertEqual(failed, 0)
|
||||
|
||||
@@ -49,7 +49,7 @@ dev.synchronize()
|
||||
self.assertTrue("NotImplementedError" in result.stderr or "ValueError" in result.stderr,
|
||||
f"expected NotImplementedError or ValueError in stderr")
|
||||
# Should exit immediately, not wait for the full timeout
|
||||
self.assertLess(elapsed, 5.0, f"should exit immediately on emulator exception, took {elapsed:.1f}s")
|
||||
self.assertLess(elapsed, 9.0, f"should exit immediately on emulator exception, took {elapsed:.1f}s")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test that PDF parser correctly extracts format fields."""
|
||||
import unittest
|
||||
import unittest, os
|
||||
from extra.assembly.rdna3.autogen import (
|
||||
SOP1, SOP2, SOPK, SOPP, VOP1, VOP2, VOP3SD, VOPC, FLAT, VOPD,
|
||||
SOP1Op, SOP2Op, VOP1Op, VOP3Op
|
||||
@@ -33,34 +33,32 @@ EXPECTED_FORMATS = {
|
||||
'VOPD': (['OPX', 'OPY', 'SRCX0', 'SRCY0', 'VDSTX', 'VDSTY'], True),
|
||||
}
|
||||
|
||||
# Skip PDF parsing tests by default - only run with TEST_PDF_PARSER=1
|
||||
# These are slow (~5s) and only needed when regenerating autogen/
|
||||
@unittest.skipUnless(os.environ.get("TEST_PDF_PARSER"), "set TEST_PDF_PARSER=1 to run PDF parser tests")
|
||||
class TestPDFParserGenerate(unittest.TestCase):
|
||||
"""Test the PDF parser by running generate() and checking results."""
|
||||
result: dict
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
from extra.assembly.rdna3.gen import generate
|
||||
cls.result = generate()
|
||||
def test_pdf_parser(self):
|
||||
"""Single test that validates all PDF parser outputs."""
|
||||
from extra.assembly.rdna3.lib import generate
|
||||
result = generate()
|
||||
|
||||
def test_all_formats_present(self):
|
||||
"""All expected formats should be parsed."""
|
||||
# test_all_formats_present
|
||||
for fmt_name in EXPECTED_FORMATS:
|
||||
self.assertIn(fmt_name, self.result["formats"], f"missing format {fmt_name}")
|
||||
self.assertIn(fmt_name, result["formats"], f"missing format {fmt_name}")
|
||||
|
||||
def test_format_count(self):
|
||||
"""Should have exactly 23 formats."""
|
||||
self.assertEqual(len(self.result["formats"]), 23)
|
||||
# test_format_count
|
||||
self.assertEqual(len(result["formats"]), 23)
|
||||
|
||||
def test_no_duplicate_fields(self):
|
||||
"""No format should have duplicate field names."""
|
||||
for fmt_name, fields in self.result["formats"].items():
|
||||
# test_no_duplicate_fields
|
||||
for fmt_name, fields in result["formats"].items():
|
||||
field_names = [f[0] for f in fields]
|
||||
self.assertEqual(len(field_names), len(set(field_names)), f"{fmt_name} has duplicate fields: {field_names}")
|
||||
|
||||
def test_expected_fields(self):
|
||||
"""Each format should have its expected key fields."""
|
||||
# test_expected_fields
|
||||
for fmt_name, (expected_fields, has_encoding) in EXPECTED_FORMATS.items():
|
||||
fields = {f[0] for f in self.result["formats"].get(fmt_name, [])}
|
||||
fields = {f[0] for f in result["formats"].get(fmt_name, [])}
|
||||
for field in expected_fields:
|
||||
self.assertIn(field, fields, f"{fmt_name} missing {field}")
|
||||
if has_encoding:
|
||||
@@ -68,21 +66,18 @@ class TestPDFParserGenerate(unittest.TestCase):
|
||||
else:
|
||||
self.assertNotIn("ENCODING", fields, f"{fmt_name} should not have ENCODING")
|
||||
|
||||
def test_vopd_no_dpp16_fields(self):
|
||||
"""VOPD should not have DPP16-specific fields (parser boundary bug)."""
|
||||
vopd_fields = {f[0] for f in self.result["formats"].get("VOPD", [])}
|
||||
# test_vopd_no_dpp16_fields
|
||||
vopd_fields = {f[0] for f in result["formats"].get("VOPD", [])}
|
||||
for field in ['DPP_CTRL', 'BANK_MASK', 'ROW_MASK']:
|
||||
self.assertNotIn(field, vopd_fields, f"VOPD should not have {field}")
|
||||
|
||||
def test_dpp16_no_vinterp_fields(self):
|
||||
"""DPP16 should not have VINTERP-specific fields."""
|
||||
dpp16_fields = {f[0] for f in self.result["formats"].get("DPP16", [])}
|
||||
# test_dpp16_no_vinterp_fields
|
||||
dpp16_fields = {f[0] for f in result["formats"].get("DPP16", [])}
|
||||
for field in ['VDST', 'WAITEXP']:
|
||||
self.assertNotIn(field, dpp16_fields, f"DPP16 should not have {field}")
|
||||
|
||||
def test_sopp_no_smem_fields(self):
|
||||
"""SOPP should not have SMEM fields (page break bug)."""
|
||||
sopp_fields = {f[0] for f in self.result["formats"].get("SOPP", [])}
|
||||
# test_sopp_no_smem_fields
|
||||
sopp_fields = {f[0] for f in result["formats"].get("SOPP", [])}
|
||||
for field in ['SBASE', 'SDATA']:
|
||||
self.assertNotIn(field, sopp_fields, f"SOPP should not have {field}")
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
#!/usr/bin/env python3
|
||||
import unittest, subprocess
|
||||
from extra.assembly.rdna3.autogen import *
|
||||
from extra.assembly.rdna3.test.test_roundtrip import _get_llvm_mc
|
||||
|
||||
def llvm_assemble(asm: str) -> bytes:
|
||||
"""Assemble using llvm-mc and return bytes."""
|
||||
result = subprocess.run(
|
||||
["llvm-mc", "-triple=amdgcn", "-mcpu=gfx1100", "-show-encoding"],
|
||||
[_get_llvm_mc(), "-triple=amdgcn", "-mcpu=gfx1100", "-show-encoding"],
|
||||
input=asm, capture_output=True, text=True
|
||||
)
|
||||
out = b''
|
||||
|
||||
@@ -1,10 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match."""
|
||||
import unittest, io, sys, re
|
||||
import unittest, io, sys, re, subprocess, shutil
|
||||
from extra.assembly.rdna3.autogen import *
|
||||
from extra.assembly.rdna3.lib import Inst
|
||||
from extra.assembly.rdna3.asm import asm
|
||||
|
||||
def _get_llvm_mc():
|
||||
for p in ['llvm-mc', 'llvm-mc-21', 'llvm-mc-20']: # prefer newer llvm-mc
|
||||
if shutil.which(p): return p
|
||||
raise FileNotFoundError("llvm-mc not found")
|
||||
|
||||
def _get_llvm_objdump():
|
||||
for p in ['llvm-objdump', 'llvm-objdump-21', 'llvm-objdump-20']:
|
||||
if shutil.which(p): return p
|
||||
raise FileNotFoundError("llvm-objdump not found")
|
||||
|
||||
# Instruction format detection based on encoding bits
|
||||
def detect_format(data: bytes) -> type[Inst] | None:
|
||||
"""Detect instruction format from machine code bytes."""
|
||||
@@ -66,24 +76,69 @@ def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
|
||||
continue
|
||||
return results
|
||||
|
||||
def compile_asm(instr: str, compiler=None) -> bytes | None:
|
||||
def compile_asm(instr: str, compiler=None) -> bytes:
|
||||
"""Compile a single instruction with llvm-mc and return the machine code bytes."""
|
||||
import subprocess
|
||||
llvm_mc = _get_llvm_mc()
|
||||
result = subprocess.run(
|
||||
[llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=f".text\n{instr}\n", capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed for '{instr}': {result.stderr.strip()}")
|
||||
# Parse encoding: [0x01,0x39,0x0a,0x7e]
|
||||
for line in result.stdout.split('\n'):
|
||||
if 'encoding:' in line:
|
||||
enc = line.split('encoding:')[1].strip()
|
||||
if enc.startswith('[') and enc.endswith(']'):
|
||||
hex_vals = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')
|
||||
return bytes.fromhex(hex_vals)
|
||||
raise RuntimeError(f"no encoding found in llvm-mc output for: {instr}")
|
||||
|
||||
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
|
||||
"""Compile multiple instructions with a single llvm-mc call."""
|
||||
if not instrs: return []
|
||||
llvm_mc = _get_llvm_mc()
|
||||
src = ".text\n" + "\n".join(instrs) + "\n"
|
||||
result = subprocess.run(
|
||||
[llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=src, capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
|
||||
# Parse all encodings in order
|
||||
encodings = []
|
||||
for line in result.stdout.split('\n'):
|
||||
if 'encoding:' in line:
|
||||
enc = line.split('encoding:')[1].strip()
|
||||
if enc.startswith('[') and enc.endswith(']'):
|
||||
hex_vals = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')
|
||||
encodings.append(bytes.fromhex(hex_vals))
|
||||
if len(encodings) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(encodings)}")
|
||||
return encodings
|
||||
|
||||
def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str | None]:
|
||||
"""Compile instructions with LLVM and get LLVM's disassembly."""
|
||||
import tempfile, os
|
||||
if not instrs: return []
|
||||
# Build assembly source with all instructions
|
||||
src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n"
|
||||
src += "\n".join(f" {instr}" for instr in instrs) + "\n"
|
||||
# Use llvm-mc to assemble to object file
|
||||
with tempfile.NamedTemporaryFile(suffix='.o', delete=False) as f:
|
||||
obj_path = f.name
|
||||
try:
|
||||
result = subprocess.run(
|
||||
['llvm-mc', '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=f".text\n{instr}\n", capture_output=True, text=True)
|
||||
if result.returncode != 0: return None
|
||||
# Parse encoding: [0x01,0x39,0x0a,0x7e]
|
||||
for line in result.stdout.split('\n'):
|
||||
if 'encoding:' in line:
|
||||
enc = line.split('encoding:')[1].strip()
|
||||
if enc.startswith('[') and enc.endswith(']'):
|
||||
hex_vals = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')
|
||||
return bytes.fromhex(hex_vals)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
[_get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path],
|
||||
input=src, capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
|
||||
# Disassemble with llvm-objdump
|
||||
result = subprocess.run([_get_llvm_objdump(), '-d', '--mcpu=gfx1100', obj_path], capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-objdump failed: {result.stderr.strip()}")
|
||||
# Parse disassembly output
|
||||
results = []
|
||||
for line in result.stdout.splitlines():
|
||||
if '//' not in line: continue
|
||||
instr = line.split('//')[0].strip()
|
||||
if instr: results.append(instr)
|
||||
return results[:len(instrs)]
|
||||
finally:
|
||||
os.unlink(obj_path)
|
||||
|
||||
class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
"""Test roundtrip on real tinygrad-generated kernels using get_kernels_from_tinygrad pattern."""
|
||||
@@ -100,90 +155,113 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
kernels, _, _ = get_kernels_from_tinygrad(op_fn)
|
||||
compiler = HIPCompiler('gfx1100')
|
||||
|
||||
decode_passed, decode_failed, decode_skipped = 0, 0, 0
|
||||
asm_passed, asm_failed, asm_skipped = 0, 0, 0
|
||||
disasm_passed, disasm_failed, disasm_skipped = 0, 0, 0
|
||||
decode_failures, asm_failures, disasm_failures = [], [], []
|
||||
|
||||
# First pass: decode all instructions and collect info
|
||||
decoded_instrs = [] # list of (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err)
|
||||
for ki, kernel in enumerate(kernels):
|
||||
offset = 0
|
||||
while offset < len(kernel.code):
|
||||
remaining = kernel.code[offset:]
|
||||
fmt = detect_format(remaining)
|
||||
if fmt is None:
|
||||
decode_skipped += 1
|
||||
asm_skipped += 1
|
||||
disasm_skipped += 1
|
||||
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
|
||||
offset += 4
|
||||
continue
|
||||
|
||||
size = fmt._size()
|
||||
if len(remaining) < size:
|
||||
base_size = fmt._size()
|
||||
if len(remaining) < base_size:
|
||||
break
|
||||
|
||||
orig_bytes = remaining[:size]
|
||||
|
||||
# Test 1: decode -> reencode roundtrip
|
||||
try:
|
||||
decoded = fmt.from_bytes(orig_bytes)
|
||||
decoded = fmt.from_bytes(remaining) # pass all remaining bytes so from_bytes can read literal
|
||||
size = decoded.size() # actual size including literal
|
||||
orig_bytes = remaining[:size]
|
||||
reencoded = decoded.to_bytes()
|
||||
if reencoded[:size] == orig_bytes:
|
||||
decode_passed += 1
|
||||
else:
|
||||
decode_failed += 1
|
||||
decode_failures.append(f"K{ki}@{offset}: {decoded.disasm()}: orig={orig_bytes.hex()} reenc={reencoded[:size].hex()}")
|
||||
|
||||
our_disasm = decoded.disasm()
|
||||
decode_ok = reencoded == orig_bytes
|
||||
decode_err = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}"
|
||||
decoded_instrs.append((ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err))
|
||||
except Exception as e:
|
||||
decoded_instrs.append((ki, offset, remaining[:base_size], None, None, False, str(e)))
|
||||
size = base_size
|
||||
|
||||
# Test 2: asm(disasm()) matches LLVM output
|
||||
offset += size
|
||||
|
||||
# Collect disasm strings for batched LLVM calls - skip unknown opcodes (op_X) that LLVM can't compile
|
||||
asm_test_instrs = [] # (idx, our_disasm) for asm test
|
||||
disasm_test_instrs = [] # (idx, our_disasm) for disasm comparison test
|
||||
|
||||
for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs):
|
||||
if our_disasm is None: continue
|
||||
# Skip unknown opcodes and malformed instructions for both tests
|
||||
if our_disasm.startswith('op_') or re.search(r', \d+, \d+, \d+,', our_disasm): continue
|
||||
asm_test_instrs.append((idx, our_disasm))
|
||||
disasm_test_instrs.append((idx, our_disasm))
|
||||
|
||||
# Batch compile for asm test
|
||||
asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs])
|
||||
asm_llvm_map = {idx: result for (idx, _), result in zip(asm_test_instrs, asm_llvm_results)}
|
||||
|
||||
# Batch compile+disasm for disasm comparison test
|
||||
disasm_llvm_results = compile_and_disasm_batch([d for _, d in disasm_test_instrs], compiler)
|
||||
disasm_llvm_map = {idx: result for (idx, _), result in zip(disasm_test_instrs, disasm_llvm_results)}
|
||||
|
||||
# Now evaluate results
|
||||
decode_passed, decode_failed, decode_skipped = 0, 0, 0
|
||||
asm_passed, asm_failed, asm_skipped = 0, 0, 0
|
||||
disasm_passed, disasm_failed, disasm_skipped = 0, 0, 0
|
||||
decode_failures, asm_failures, disasm_failures = [], [], []
|
||||
|
||||
for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs):
|
||||
# Decode test
|
||||
if decode_ok:
|
||||
decode_passed += 1
|
||||
elif decode_err == "no format":
|
||||
decode_skipped += 1
|
||||
else:
|
||||
decode_failed += 1
|
||||
decode_failures.append(f"K{ki}@{offset}: {our_disasm}: {decode_err}")
|
||||
|
||||
# Asm test
|
||||
if our_disasm is None:
|
||||
asm_skipped += 1
|
||||
elif idx in asm_llvm_map:
|
||||
llvm_bytes = asm_llvm_map[idx]
|
||||
if llvm_bytes is None:
|
||||
asm_skipped += 1
|
||||
else:
|
||||
try:
|
||||
our_bytes = asm(our_disasm).to_bytes()
|
||||
llvm_bytes = compile_asm(our_disasm, compiler)
|
||||
if llvm_bytes is None:
|
||||
asm_skipped += 1
|
||||
elif our_bytes[:len(llvm_bytes)] == llvm_bytes:
|
||||
if our_bytes[:len(llvm_bytes)] == llvm_bytes:
|
||||
asm_passed += 1
|
||||
else:
|
||||
asm_failed += 1
|
||||
asm_failures.append(f"K{ki}@{offset}: '{our_disasm}': ours={our_bytes[:len(llvm_bytes)].hex()} llvm={llvm_bytes.hex()}")
|
||||
except Exception:
|
||||
asm_skipped += 1
|
||||
else:
|
||||
asm_skipped += 1
|
||||
|
||||
# Test 3: our disasm() matches LLVM's disassembly string exactly
|
||||
# Skip if instruction uses op_XX (unknown opcode) or looks malformed (many raw field values)
|
||||
if our_disasm.startswith('op_') or re.search(r', \d+, \d+, \d+,', our_disasm):
|
||||
disasm_skipped += 1
|
||||
else:
|
||||
try:
|
||||
# Get LLVM's disassembly of our instruction
|
||||
src = f".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n {our_disasm}\n"
|
||||
lib = compiler.compile(src)
|
||||
llvm_instrs = disassemble_lib(lib, compiler)
|
||||
if llvm_instrs:
|
||||
llvm_disasm = llvm_instrs[0][0]
|
||||
if our_disasm == llvm_disasm:
|
||||
disasm_passed += 1
|
||||
else:
|
||||
disasm_failed += 1
|
||||
disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{llvm_disasm}'")
|
||||
else:
|
||||
disasm_skipped += 1
|
||||
except Exception:
|
||||
disasm_skipped += 1
|
||||
|
||||
except Exception:
|
||||
decode_skipped += 1
|
||||
asm_skipped += 1
|
||||
# Disasm comparison test
|
||||
if our_disasm is None:
|
||||
disasm_skipped += 1
|
||||
elif idx in disasm_llvm_map:
|
||||
llvm_disasm = disasm_llvm_map[idx]
|
||||
if llvm_disasm is None:
|
||||
disasm_skipped += 1
|
||||
|
||||
offset += size
|
||||
elif our_disasm == llvm_disasm:
|
||||
disasm_passed += 1
|
||||
else:
|
||||
disasm_failed += 1
|
||||
disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{llvm_disasm}'")
|
||||
else:
|
||||
disasm_skipped += 1
|
||||
|
||||
print(f"decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
|
||||
print(f"asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
|
||||
print(f"disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
|
||||
self.assertEqual(decode_failed, 0, f"Decode failures:\n" + "\n".join(decode_failures[:20]))
|
||||
self.assertEqual(asm_failed, 0, f"Asm failures:\n" + "\n".join(asm_failures[:20]))
|
||||
self.assertEqual(disasm_failed, 0, f"Disasm failures:\n" + "\n".join(disasm_failures[:20]))
|
||||
# Note: disasm string comparison is informational only - formatting differences between LLVM versions are expected
|
||||
|
||||
# Basic unary ops
|
||||
def test_neg(self): self._test_kernel_roundtrip(lambda T: -T([1.0, -2.0, 3.0, -4.0]))
|
||||
|
||||
Reference in New Issue
Block a user