assembly/amd: add pcode ds ops (#13939)

* assembly/amd: add pcode ds ops

* refactors

* fix ds op

* update autogen

* fix flat bug

* more tests

* fix emu test

* that's a hack

* generic

* fix all tests

* two tests

* fix test failure

* better

* remove __all__
This commit is contained in:
George Hotz
2026-01-01 16:24:13 -05:00
committed by GitHub
parent cb7c76a3bd
commit dfb813b760
11 changed files with 7881 additions and 368 deletions

View File

@@ -684,6 +684,9 @@ jobs:
run: AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=auto test/test_dtype_alu.py test/test_dtype.py
- name: Run RDNA3 dtype tests (AMD_LLVM=1)
run: AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=auto test/test_dtype_alu.py test/test_dtype.py
# TODO: run all once emulator is faster
- name: Run RDNA3 ops tests
run: SKIP_SLOW_TEST=1 AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=auto test/test_ops.py -k "test_sparse_categorical_crossentropy or test_tril"
testamdautogen:
name: AMD autogen

View File

@@ -11,6 +11,8 @@ Test with `PYTHONPATH="." pytest -n12 extra/assembly/amd/`
The code should be as readable and deduplicated as possible. asm and emu shouldn't be required for dsl.
The autogen folder is autogenerated from the AMD PDFs with `python3 -m extra.assembly.amd.pdf --arch all`
test_emu.py has a good set of instruction tests for the emulation, with USE_HW=1 it will compare to real hardware.
Whenever an instruction is fixed, regression tests should be added here and confirmed with real hardware.
@@ -26,6 +28,12 @@ The ops tests also pass, but they are very slow, so you should run them one at a
`SKIP_SLOW_TEST=1 PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=12 test/test_ops.py`
`SKIP_SLOW_TEST=1 PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=12 test/test_ops.py`
When something is caught by main tinygrad tests, a local regression test should be added to `extra/assembly/amd/test`. While working with tinygrad, you can dump the assembly with `DEBUG=7`. These tests all pass on real hardware, so if a test is failing with `AMD=1 PYTHON_REMU=1 MOCKGPU=1` it's likely because an instruction is emulated incorrectly. You can test without `MOCKGPU=1` to test on real hardware, if it works on real hardware there's a bug in the emulator.
When something is caught by main tinygrad tests, a local regression test should be added to `extra/assembly/amd/test`.
While working with tinygrad, you can dump the assembly with `DEBUG=7`. These tests all pass on real hardware
If a test is failing with `AMD=1 PYTHON_REMU=1 MOCKGPU=1` it's because an instruction is emulated incorrectly.
You can test without `MOCKGPU=1` to test on real hardware, if it works on real hardware there's a bug in the emulator.
IMPORTANT: if a test is failing in the emulator, it's an instruction bug. Use DEBUG=7, get the instructions, and debug.
Currently, only RDNA3 is well supported, but when finished, this will support RDNA3+RDNA4+CDNA in ~2000 lines.
Get line count with `cloc --by-file extra/assembly/amd/*.py`

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -9,7 +9,7 @@ from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3S
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
# Common masks and bit conversion functions
MASK32, MASK64 = 0xffffffff, 0xffffffffffffffff
MASK32, MASK64, MASK128 = 0xffffffff, 0xffffffffffffffff, (1 << 128) - 1
_struct_f, _struct_I = struct.Struct("<f"), struct.Struct("<I")
_struct_e, _struct_H = struct.Struct("<e"), struct.Struct("<H")
_struct_d, _struct_Q = struct.Struct("<d"), struct.Struct("<Q")

View File

@@ -7,7 +7,7 @@ from extra.assembly.amd.pcode import Reg
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, VOPDOp)
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
Program = dict[int, Inst]
WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256
@@ -29,33 +29,65 @@ def _dst16(cur: int, val: int, is_hi: bool) -> int: return (cur & 0x0000ffff) |
def _vgpr_hi(src: int) -> bool: return src >= 256 and ((src - 256) & 0x80) != 0
def _vgpr_masked(src: int) -> int: return ((src - 256) & 0x7f) + 256 if src >= 256 else src
# Helper: get number of dwords from memory op name
def _op_ndwords(name: str) -> int:
if '_B128' in name: return 4
if '_B96' in name: return 3
if any(s in name for s in ('_B64', '_U64', '_I64', '_F64')): return 2
return 1
# Helper: build multi-dword Reg from consecutive VGPRs
def _vgpr_read(V: list, base: int, ndwords: int) -> Reg: return Reg(sum(V[base + i] << (32 * i) for i in range(ndwords)))
# Helper: write multi-dword value to consecutive VGPRs
def _vgpr_write(V: list, base: int, val: int, ndwords: int):
for i in range(ndwords): V[base + i] = (val >> (32 * i)) & MASK32
# Memory access
_valid_mem_ranges: list[tuple[int, int]] = []
def set_valid_mem_ranges(ranges: set[tuple[int, int]]) -> None: _valid_mem_ranges.clear(); _valid_mem_ranges.extend(ranges)
def _mem_valid(addr: int, size: int) -> bool:
return not _valid_mem_ranges or any(s <= addr and addr + size <= s + z for s, z in _valid_mem_ranges)
def _ctypes_at(addr: int, size: int): return (ctypes.c_uint8 if size == 1 else ctypes.c_uint16 if size == 2 else ctypes.c_uint32).from_address(addr)
def _ctypes_at(addr: int, size: int): return (ctypes.c_uint8 if size == 1 else ctypes.c_uint16 if size == 2 else ctypes.c_uint64 if size == 8 else ctypes.c_uint32).from_address(addr)
def mem_read(addr: int, size: int) -> int: return _ctypes_at(addr, size).value if _mem_valid(addr, size) else 0
def mem_write(addr: int, size: int, val: int) -> None:
if _mem_valid(addr, size): _ctypes_at(addr, size).value = val
# Memory op tables (not pseudocode - these are format descriptions)
def _mem_ops(ops, suffix_map):
return {getattr(e, f"{p}_{s}"): v for e in ops for s, v in suffix_map.items() for p in [e.__name__.replace("Op", "")]}
_LOAD_MAP = {'LOAD_B32': (1,4,0), 'LOAD_B64': (2,4,0), 'LOAD_B96': (3,4,0), 'LOAD_B128': (4,4,0), 'LOAD_U8': (1,1,0), 'LOAD_I8': (1,1,1), 'LOAD_U16': (1,2,0), 'LOAD_I16': (1,2,1)}
_STORE_MAP = {'STORE_B32': (1,4), 'STORE_B64': (2,4), 'STORE_B96': (3,4), 'STORE_B128': (4,4), 'STORE_B8': (1,1), 'STORE_B16': (1,2)}
FLAT_LOAD, FLAT_STORE = _mem_ops([GLOBALOp, FLATOp], _LOAD_MAP), _mem_ops([GLOBALOp, FLATOp], _STORE_MAP)
# D16 ops: load/store 16-bit to lower or upper half of VGPR. Format: (size, sign, hi) where hi=1 means upper 16 bits
_D16_LOAD_MAP = {'LOAD_D16_U8': (1,0,0), 'LOAD_D16_I8': (1,1,0), 'LOAD_D16_B16': (2,0,0),
'LOAD_D16_HI_U8': (1,0,1), 'LOAD_D16_HI_I8': (1,1,1), 'LOAD_D16_HI_B16': (2,0,1)}
_D16_STORE_MAP = {'STORE_D16_HI_B8': (1,1), 'STORE_D16_HI_B16': (2,1)} # (size, hi)
FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP)
FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP)
DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
# 2ADDR ops: load/store two values using offset0 and offset1
DS_LOAD_2ADDR = {DSOp.DS_LOAD_2ADDR_B32: 4, DSOp.DS_LOAD_2ADDR_B64: 8}
DS_STORE_2ADDR = {DSOp.DS_STORE_2ADDR_B32: 4, DSOp.DS_STORE_2ADDR_B64: 8}
def _make_mem_accessor(read_fn, write_fn):
"""Create a memory accessor class with the given read/write functions."""
class _MemAccessor:
__slots__ = ('_addr',)
def __init__(self, addr: int): self._addr = int(addr)
u8 = property(lambda s: read_fn(s._addr, 1), lambda s, v: write_fn(s._addr, 1, int(v)))
u16 = property(lambda s: read_fn(s._addr, 2), lambda s, v: write_fn(s._addr, 2, int(v)))
u32 = property(lambda s: read_fn(s._addr, 4), lambda s, v: write_fn(s._addr, 4, int(v)))
u64 = property(lambda s: read_fn(s._addr, 8), lambda s, v: write_fn(s._addr, 8, int(v)))
i8 = property(lambda s: _sext(read_fn(s._addr, 1), 8), lambda s, v: write_fn(s._addr, 1, int(v)))
i16 = property(lambda s: _sext(read_fn(s._addr, 2), 16), lambda s, v: write_fn(s._addr, 2, int(v)))
i32 = property(lambda s: _sext(read_fn(s._addr, 4), 32), lambda s, v: write_fn(s._addr, 4, int(v)))
i64 = property(lambda s: _sext(read_fn(s._addr, 8), 64), lambda s, v: write_fn(s._addr, 8, int(v)))
b8, b16, b32, b64 = u8, u16, u32, u64
return _MemAccessor
_GlobalMemAccessor = _make_mem_accessor(mem_read, mem_write)
class _GlobalMem:
"""Global memory wrapper that supports MEM[addr].u32 style access."""
def __getitem__(self, addr) -> _GlobalMemAccessor: return _GlobalMemAccessor(addr)
GlobalMem = _GlobalMem()
class LDSMem:
"""LDS memory wrapper that supports MEM[addr].u32 style access."""
__slots__ = ('_lds',)
def __init__(self, lds: bytearray): self._lds = lds
def _read(self, addr: int, size: int) -> int:
addr = addr & 0xffff
return int.from_bytes(self._lds[addr:addr+size], 'little') if addr + size <= len(self._lds) else 0
def _write(self, addr: int, size: int, val: int):
addr = addr & 0xffff
if addr + size <= len(self._lds): self._lds[addr:addr+size] = (int(val) & ((1 << (size*8)) - 1)).to_bytes(size, 'little')
def __getitem__(self, addr): return _make_mem_accessor(self._read, self._write)(addr)
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
@@ -197,60 +229,28 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
return new_pc_words - st.pc - 1 # -1 because emulator adds inst_words (1 for scalar)
return 0
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None) -> None:
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None) -> None:
"""Execute vector instruction for one lane."""
compiled = _get_compiled()
V = st.vgpr[lane]
# Memory ops (not ALU pseudocode)
if isinstance(inst, FLAT):
op, addr_reg, data_reg, vdst, offset, saddr = inst.op, inst.addr, inst.data, inst.vdst, _sext(inst.offset, 13), inst.saddr
addr = V[addr_reg] | (V[addr_reg+1] << 32)
addr = (st.rsgpr64(saddr) + V[addr_reg] + offset) & MASK64 if saddr not in (NULL, 0x7f) else (addr + offset) & MASK64
if op in FLAT_LOAD:
cnt, sz, sign = FLAT_LOAD[op]
for i in range(cnt): val = mem_read(addr + i * sz, sz); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
elif op in FLAT_STORE:
cnt, sz = FLAT_STORE[op]
for i in range(cnt): mem_write(addr + i * sz, sz, V[data_reg + i] & ((1 << (sz * 8)) - 1))
elif op in FLAT_D16_LOAD:
sz, sign, hi = FLAT_D16_LOAD[op]
val = mem_read(addr, sz)
if sign: val = _sext(val, sz * 8) & 0xffff
V[vdst] = _dst16(V[vdst], val, hi)
elif op in FLAT_D16_STORE:
sz, hi = FLAT_D16_STORE[op]
mem_write(addr, sz, _src16(V[data_reg], hi) & ((1 << (sz * 8)) - 1))
else: raise NotImplementedError(f"FLAT op {op}")
return
if isinstance(inst, DS):
op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
if op in DS_LOAD:
cnt, sz, sign = DS_LOAD[op]
for i in range(cnt): val = int.from_bytes(lds[addr0+i*sz:addr0+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
elif op in DS_STORE:
cnt, sz = DS_STORE[op]
for i in range(cnt): lds[addr0+i*sz:addr0+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little')
elif op in DS_LOAD_2ADDR:
# Load two values from addr+offset0*sz and addr+offset1*sz into vdst (B32: 1 dword each, B64: 2 dwords each)
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
sz = DS_LOAD_2ADDR[op]
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
cnt = sz // 4 # 1 for B32, 2 for B64
for i in range(cnt): V[vdst + i] = int.from_bytes(lds[addr0+i*4:addr0+i*4+4], 'little')
for i in range(cnt): V[vdst + cnt + i] = int.from_bytes(lds[addr1+i*4:addr1+i*4+4], 'little')
elif op in DS_STORE_2ADDR:
# Store two values from data0 and data1 to addr+offset0*sz and addr+offset1*sz
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
sz = DS_STORE_2ADDR[op]
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
cnt = sz // 4
for i in range(cnt): lds[addr0+i*4:addr0+i*4+4] = (V[inst.data0 + i] & MASK32).to_bytes(4, 'little')
for i in range(cnt): lds[addr1+i*4:addr1+i*4+4] = (V[inst.data1 + i] & MASK32).to_bytes(4, 'little')
else: raise NotImplementedError(f"DS op {op}")
# Memory ops (FLAT/GLOBAL/SCRATCH and DS) - use generated pcode
if isinstance(inst, (FLAT, DS)):
op, vdst, op_name = inst.op, inst.vdst, inst.op.name
fn, ndwords = compiled[type(op)][op], _op_ndwords(op_name)
if isinstance(inst, FLAT):
addr = V[inst.addr] | (V[inst.addr + 1] << 32)
ADDR = (st.rsgpr64(inst.saddr) + V[inst.addr] + _sext(inst.offset, 13)) & MASK64 if inst.saddr not in (NULL, 0x7f) else (addr + _sext(inst.offset, 13)) & MASK64
# For loads, VDATA comes from vdst (preserves unwritten bits); for stores, from inst.data
vdata_src = vdst if 'LOAD' in op_name else inst.data
result = fn(GlobalMem, ADDR, _vgpr_read(V, vdata_src, ndwords), Reg(V[vdst]), Reg(0))
if 'VDATA' in result: _vgpr_write(V, vdst, result['VDATA']._val, ndwords)
if 'RETURN_DATA' in result: _vgpr_write(V, vdst, result['RETURN_DATA']._val, ndwords)
else: # DS
DATA0, DATA1 = _vgpr_read(V, inst.data0, ndwords), _vgpr_read(V, inst.data1, ndwords) if inst.data1 is not None else Reg(0)
result = fn(lds, Reg(V[inst.addr]), DATA0, DATA1, Reg(inst.offset0), Reg(inst.offset1), Reg(0))
if 'RETURN_DATA' in result and ('_RTN' in op_name or '_LOAD' in op_name):
_vgpr_write(V, vdst, result['RETURN_DATA']._val, ndwords * 2 if '_2ADDR_' in op_name else ndwords)
return
# VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)
@@ -423,7 +423,7 @@ def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None:
# MAIN EXECUTION LOOP
# ═══════════════════════════════════════════════════════════════════════════════
def step_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int:
def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int:
inst = program.get(st.pc)
if inst is None: return 1
inst_words, st.literal = inst._words, getattr(inst, '_literal', None) or 0
@@ -443,7 +443,7 @@ def step_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) ->
st.pc += inst_words
return 0
def exec_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int:
def exec_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int:
while st.pc in program:
result = step_wave(program, st, lds, n_lanes)
if result == -1: return 0
@@ -453,7 +453,7 @@ def exec_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) ->
def exec_workgroup(program: Program, workgroup_id: tuple[int, int, int], local_size: tuple[int, int, int], args_ptr: int,
wg_id_sgpr_base: int, wg_id_enables: tuple[bool, bool, bool]) -> None:
lx, ly, lz = local_size
total_threads, lds = lx * ly * lz, bytearray(65536)
total_threads, lds = lx * ly * lz, LDSMem(bytearray(65536))
waves: list[tuple[WaveState, int, int]] = []
for wave_start in range(0, total_threads, WAVE_SIZE):
n_lanes, st = min(WAVE_SIZE, total_threads - wave_start), WaveState()

View File

@@ -1,6 +1,6 @@
# DSL for RDNA3 pseudocode - makes pseudocode expressions work directly as Python
import struct, math
from extra.assembly.amd.dsl import MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.dsl import MASK32, MASK64, MASK128, _f32, _i32, _sext, _f16, _i16, _f64, _i64
# ═══════════════════════════════════════════════════════════════════════════════
# HELPER FUNCTIONS
@@ -206,47 +206,6 @@ def signext_from_bit(val, bit):
if val & (1 << (bit - 1)): return val - (1 << bit)
return val
# ═══════════════════════════════════════════════════════════════════════════════
# DSL EXPORTS
# ═══════════════════════════════════════════════════════════════════════════════
__all__ = [
# Classes
'Reg', 'SliceProxy', 'TypedView',
# Pack functions
'_pack', '_pack32', 'pack', 'pack32',
# Constants
'WAVE32', 'WAVE64', 'MASK32', 'MASK64', 'WAVE_MODE', 'DENORM', 'OVERFLOW_F32', 'UNDERFLOW_F32',
'OVERFLOW_F64', 'UNDERFLOW_F64', 'MAX_FLOAT_F32', 'ROUND_MODE', 'cvtToQuietNAN', 'DST', 'INF', 'PI',
'TWO_OVER_PI_1201',
# Aliases for pseudocode
's_ff1_i32_b32', 's_ff1_i32_b64', 'GT_NEG_ZERO', 'LT_NEG_ZERO',
'isNAN', 'isQuietNAN', 'isSignalNAN', 'fma', 'ldexp', 'sign', 'exponent', 'F', 'signext',
# Conversion functions
'_f32', '_i32', '_f16', '_i16', '_f64', '_i64', '_sext', '_to_f16_bits', '_f16_to_f32_bits',
'i32_to_f32', 'u32_to_f32', 'i32_to_f64', 'u32_to_f64', 'f32_to_f64', 'f64_to_f32',
'f32_to_i32', 'f32_to_u32', 'f64_to_i32', 'f64_to_u32', 'f32_to_f16', 'f16_to_f32',
'i16_to_f16', 'u16_to_f16', 'f16_to_i16', 'f16_to_u16', 'u32_to_u16', 'i32_to_i16',
'f16_to_snorm', 'f16_to_unorm', 'f32_to_snorm', 'f32_to_unorm', 'v_cvt_i16_f32', 'v_cvt_u16_f32',
'SAT8', 'f32_to_u8', 'u8_to_u32', 'u4_to_u32',
# BF16 conversion functions
'_bf16', '_ibf16', 'bf16_to_f32', 'f32_to_bf16',
# Math functions
'trunc', 'floor', 'ceil', 'sqrt', 'log2', 'sin', 'cos', 'pow', 'fract', 'isEven', 'mantissa',
# Min/max functions
'v_min_f32', 'v_max_f32', 'v_min_i32', 'v_max_i32', 'v_min_u32', 'v_max_u32',
'v_min_f16', 'v_max_f16', 'v_min_i16', 'v_max_i16', 'v_min_u16', 'v_max_u16',
'v_min3_f32', 'v_max3_f32', 'v_min3_i32', 'v_max3_i32', 'v_min3_u32', 'v_max3_u32',
'v_min3_f16', 'v_max3_f16', 'v_min3_i16', 'v_max3_i16', 'v_min3_u16', 'v_max3_u16',
'ABSDIFF',
# Byte/SAD helper functions
'BYTE_PERMUTE', 'v_sad_u8', 'v_msad_u8',
# Bit manipulation
'_brev32', '_brev64', '_ctz32', '_ctz64', '_exponent', '_is_denorm_f32', '_is_denorm_f64',
'_sign', '_mantissa_f32', '_div', '_isnan', '_isquietnan', '_issignalnan', '_gt_neg_zero', '_lt_neg_zero', '_fma', '_ldexp', '_signext',
'signext_from_bit',
]
# Aliases used in pseudocode
s_ff1_i32_b32, s_ff1_i32_b64 = _ctz32, _ctz64
GT_NEG_ZERO, LT_NEG_ZERO = _gt_neg_zero, _lt_neg_zero
@@ -341,12 +300,6 @@ class _Denorm:
f64 = _DenormChecker(64)
DENORM = _Denorm()
def _brev(v, bits):
"""Bit-reverse a value."""
result = 0
for i in range(bits): result |= ((v >> i) & 1) << (bits - 1 - i)
return result
class SliceProxy:
"""Proxy for D0[31:16] that supports .f16/.u16 etc getters and setters."""
__slots__ = ('_reg', '_high', '_low', '_reversed')
@@ -474,9 +427,9 @@ class TypedView:
def u32(s): return s if s._bits == 32 and not s._signed else int(s) & MASK32
class Reg:
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works."""
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works. Supports up to 128 bits for DS_LOAD_B128."""
__slots__ = ('_val',)
def __init__(self, val=0): self._val = int(val) & MASK64
def __init__(self, val=0): self._val = int(val) & MASK128
# Typed views
u64 = property(lambda s: TypedView(s, 64), lambda s, v: setattr(s, '_val', int(v) & MASK64))

View File

@@ -36,7 +36,7 @@ FIELD_ORDER = {
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+|(?:FLAT|GLOBAL|SCRATCH)_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
# Patterns that can't be handled by the DSL (require special handling in emu.py)
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
@@ -46,7 +46,8 @@ UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST',
'BARRIER_STATE', 'ReallocVgprs',
'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
'fp6', 'bf6'] # Malformed pseudocode from PDF
'fp6', 'bf6', 'GS_REGS', 'M0.base', 'DS_DATA', '= 0..', 'sign(src', 'if no LDS', 'gds_base', 'vector mask',
'SGPR_ADDR', 'INST_OFFSET', 'laneID'] # FLAT ops with non-standard vars
# ═══════════════════════════════════════════════════════════════════════════════
# COMPILER: pseudocode -> Python (minimal transforms)
@@ -68,8 +69,8 @@ def compile_pseudocode(pseudocode: str) -> str:
lines = []
indent, need_pass, in_first_match_loop = 0, False, False
for line in joined_lines:
line = line.strip()
if not line or line.startswith('//'): continue
line = line.split('//')[0].strip() # Strip C-style comments
if not line: continue
if line.startswith('if '):
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
indent += 1
@@ -351,8 +352,9 @@ def _extract_pseudocode(text: str) -> str | None:
for line in lines:
s = line.strip()
if not s or re.match(r'^\d+ of \d+$', s) or re.match(r'^\d+\.\d+\..*Instructions', s): continue
if s.startswith(('Notes', 'Functional examples')): break
if s.startswith(('Notes', 'Functional examples', '', '-')): break # Stop at notes/bullets
if s.startswith(('"RDNA', 'AMD ', 'CDNA')): continue
if '' in s or '' in s: continue # Skip lines with bullets/dashes
if '= lambda(' in s: in_lambda += 1; continue
if in_lambda > 0:
if s.endswith(');'): in_lambda -= 1
@@ -362,7 +364,8 @@ def _extract_pseudocode(text: str) -> str | None:
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
is_code = (any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =',
'D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
'D0[', 'D1[', 'S0[', 'S1[', 'S2[', 'MEM[', 'RETURN_DATA',
'VADDR', 'VDATA', 'VDST', 'SADDR', 'OFFSET']) or
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s))
if is_code: result.append(s)
@@ -448,28 +451,23 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
# Get op enums for this arch (import from .ins which re-exports from .enum)
import importlib
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp'] if hasattr(autogen, name)]
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp', 'FLATOp', 'GLOBALOp', 'SCRATCHOp'] if hasattr(autogen, name)]
# Build defined ops mapping
defined_ops: dict[tuple, list] = {}
for enum_cls in OP_ENUMS:
for op in enum_cls:
if op.name.startswith(('S_', 'V_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
if op.name.startswith(('S_', 'V_', 'DS_', 'FLAT_', 'GLOBAL_', 'SCRATCH_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
enum_names = [e.__name__ for e in OP_ENUMS]
lines = [f'''# autogenerated by pdf.py - do not edit
# to regenerate: python -m extra.assembly.amd.pdf --arch {arch}
# ruff: noqa: E501,F405,F403
# mypy: ignore-errors
from extra.assembly.amd.autogen.{arch}.enum import {", ".join(enum_names)}
from extra.assembly.amd.pcode import *
''']
instructions: dict = {cls: {} for cls in OP_ENUMS}
for key, pc in pseudocode.items():
if key in defined_ops:
for enum_cls, enum_val in defined_ops[key]: instructions[enum_cls][enum_val] = pc
# First pass: generate all function code
fn_lines: list[str] = []
all_fn_entries: dict = {}
for enum_cls in OP_ENUMS:
cls_name = enum_cls.__name__
if not instructions.get(enum_cls): continue
@@ -480,28 +478,44 @@ from extra.assembly.amd.pcode import *
code = compile_pseudocode(pc)
code = _apply_pseudocode_fixes(op, code)
fn_name, fn_code = _generate_function(cls_name, op, pc, code)
lines.append(fn_code)
fn_lines.append(fn_code)
fn_entries.append((op, fn_name))
except Exception as e: print(f" Warning: Failed to compile {op.name}: {e}")
if fn_entries:
lines.append(f'{cls_name}_FUNCTIONS = {{')
for op, fn_name in fn_entries: lines.append(f" {cls_name}.{op.name}: {fn_name},")
lines.append('}\n')
all_fn_entries[enum_cls] = fn_entries
fn_lines.append(f'{cls_name}_FUNCTIONS = {{')
for op, fn_name in fn_entries: fn_lines.append(f" {cls_name}.{op.name}: {fn_name},")
fn_lines.append('}\n')
# Add V_WRITELANE_B32 if VOP3Op exists
if 'VOP3Op' in enum_names:
lines.append('''
fn_lines.append('''
# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode)
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):
wr_lane = s1 & 0x1f
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
''')
lines.append('COMPILED_FUNCTIONS = {')
fn_lines.append('COMPILED_FUNCTIONS = {')
for enum_cls in OP_ENUMS:
if instructions.get(enum_cls): lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_FUNCTIONS,')
lines.append('}\n\ndef get_compiled_functions(): return COMPILED_FUNCTIONS')
if all_fn_entries.get(enum_cls): fn_lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_FUNCTIONS,')
fn_lines.append('}\n\ndef get_compiled_functions(): return COMPILED_FUNCTIONS')
# Second pass: scan generated code for pcode imports
fn_code_str = '\n'.join(fn_lines)
import extra.assembly.amd.pcode as pcode_module
pcode_exports = [name for name in dir(pcode_module) if not name.startswith('_') or name.startswith('_') and not name.startswith('__')]
used_imports = sorted(name for name in pcode_exports if re.search(rf'\b{re.escape(name)}\b', fn_code_str))
# Build final output with explicit imports
lines = [f'''# autogenerated by pdf.py - do not edit
# to regenerate: python -m extra.assembly.amd.pdf --arch {arch}
# ruff: noqa: E501
# mypy: ignore-errors
from extra.assembly.amd.autogen.{arch}.enum import {", ".join(enum_names)}
from extra.assembly.amd.pcode import {", ".join(used_imports)}
'''] + fn_lines
return '\n'.join(lines)
def _apply_pseudocode_fixes(op, code: str) -> str:
@@ -541,19 +555,32 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
is_div_scale = 'DIV_SCALE' in op.name
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
is_ds = cls_name == 'DSOp'
is_flat = cls_name in ('FLATOp', 'GLOBALOp', 'SCRATCHOp')
combined = code + pc
fn_name = f"_{cls_name}_{op.name}"
# Function accepts Reg objects directly (uppercase names), laneId is passed directly as int
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]
# DSOp functions get additional MEM and offset parameters
# FLAT/GLOBAL ops get MEM, vaddr, vdata, saddr, offset parameters
if is_ds:
lines = [f"def {fn_name}(MEM, ADDR, DATA0, DATA1, OFFSET0, OFFSET1, RETURN_DATA):"]
elif is_flat:
lines = [f"def {fn_name}(MEM, ADDR, VDATA, VDST, RETURN_DATA):"]
else:
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]
# Registers that need special handling (not passed directly)
# Only init if used but not first assigned as `name = Reg(...)` in the compiled code
# Registers that need special handling (aliases or init)
def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
special_regs = []
if is_ds: special_regs = [('DATA', 'DATA0'), ('DATA2', 'DATA1'), ('OFFSET', 'OFFSET0'), ('ADDR_BASE', 'ADDR')]
elif is_flat: special_regs = [('DATA', 'VDATA')]
else:
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
used = {name for name, _ in special_regs if name in combined}
# Detect which registers are modified (not just read) - look for assignments
@@ -562,6 +589,10 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
modifies_pc = bool(re.search(r'\bPC\s*=', combined))
# DS/FLAT ops: detect memory writes (MEM[...] = ...)
modifies_mem = (is_ds or is_flat) and bool(re.search(r'MEM\[.*\]\.[a-z0-9]+\s*=', combined))
# FLAT ops: detect VDST writes
modifies_vdst = is_flat and bool(re.search(r'VDST[\.\[].*=', combined))
# Build init code for special registers
init_lines = []
@@ -587,6 +618,15 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
if modifies_exec: result_items.append("'EXEC': EXEC")
if has_d1: result_items.append("'D1': D1")
if modifies_pc: result_items.append("'PC': PC")
# DS ops: return RETURN_DATA if it was written (left side of assignment)
if is_ds and 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'RETURN_DATA': RETURN_DATA")
# FLAT ops: return RETURN_DATA for atomics, VDATA for loads (only if written to)
if is_flat:
if 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'RETURN_DATA': RETURN_DATA")
if re.search(r'^\s*VDATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'VDATA': VDATA")
lines.append(f" return {{{', '.join(result_items)}}}\n")
return fn_name, '\n'.join(lines)

View File

@@ -9,7 +9,7 @@ os.environ["AMD"] = "1"
os.environ["MOCKGPU"] = "1"
os.environ["PYTHON_REMU"] = "1"
from extra.assembly.amd.emu import WaveState, decode_program, step_wave, WAVE_SIZE, set_valid_mem_ranges
from extra.assembly.amd.emu import WaveState, decode_program, step_wave, WAVE_SIZE, set_valid_mem_ranges, LDSMem
from extra.assembly.amd.test.helpers import KernelInfo
REMU_PATH = Path(__file__).parents[3] / "remu/target/release/libremu.so"
@@ -99,7 +99,7 @@ class PythonEmulator:
self.program = decode_program(kernel)
self.state = WaveState()
self.state.exec_mask = (1 << n_lanes) - 1
self.lds = bytearray(65536)
self.lds = LDSMem(bytearray(65536))
self.n_lanes = n_lanes
def step(self) -> int:

File diff suppressed because it is too large Load Diff