mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
assembly/amd: add pcode ds ops (#13939)
* assembly/amd: add pcode ds ops * refactors * fix ds op * update autogen * fix flat bug * more tests * fix emu test * that's a hack * generic * fix all tests * two tests * fix test failure * better * remove __all__
This commit is contained in:
3
.github/workflows/test.yml
vendored
3
.github/workflows/test.yml
vendored
@@ -684,6 +684,9 @@ jobs:
|
||||
run: AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=auto test/test_dtype_alu.py test/test_dtype.py
|
||||
- name: Run RDNA3 dtype tests (AMD_LLVM=1)
|
||||
run: AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=auto test/test_dtype_alu.py test/test_dtype.py
|
||||
# TODO: run all once emulator is faster
|
||||
- name: Run RDNA3 ops tests
|
||||
run: SKIP_SLOW_TEST=1 AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=auto test/test_ops.py -k "test_sparse_categorical_crossentropy or test_tril"
|
||||
|
||||
testamdautogen:
|
||||
name: AMD autogen
|
||||
|
||||
@@ -11,6 +11,8 @@ Test with `PYTHONPATH="." pytest -n12 extra/assembly/amd/`
|
||||
|
||||
The code should be as readable and deduplicated as possible. asm and emu shouldn't be required for dsl.
|
||||
|
||||
The autogen folder is autogenerated from the AMD PDFs with `python3 -m extra.assembly.amd.pdf --arch all`
|
||||
|
||||
test_emu.py has a good set of instruction tests for the emulation, with USE_HW=1 it will compare to real hardware.
|
||||
Whenever an instruction is fixed, regression tests should be added here and confirmed with real hardware.
|
||||
|
||||
@@ -26,6 +28,12 @@ The ops tests also pass, but they are very slow, so you should run them one at a
|
||||
`SKIP_SLOW_TEST=1 PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=12 test/test_ops.py`
|
||||
`SKIP_SLOW_TEST=1 PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=12 test/test_ops.py`
|
||||
|
||||
When something is caught by main tinygrad tests, a local regression test should be added to `extra/assembly/amd/test`. While working with tinygrad, you can dump the assembly with `DEBUG=7`. These tests all pass on real hardware, so if a test is failing with `AMD=1 PYTHON_REMU=1 MOCKGPU=1` it's likely because an instruction is emulated incorrectly. You can test without `MOCKGPU=1` to test on real hardware, if it works on real hardware there's a bug in the emulator.
|
||||
When something is caught by main tinygrad tests, a local regression test should be added to `extra/assembly/amd/test`.
|
||||
While working with tinygrad, you can dump the assembly with `DEBUG=7`. These tests all pass on real hardware.
|
||||
If a test is failing with `AMD=1 PYTHON_REMU=1 MOCKGPU=1` it's because an instruction is emulated incorrectly.
|
||||
You can run without `MOCKGPU=1` to test on real hardware; if the test passes on real hardware, the bug is in the emulator.
|
||||
IMPORTANT: if a test is failing in the emulator, it's an instruction bug. Use DEBUG=7, get the instructions, and debug.
|
||||
|
||||
Currently, only RDNA3 is well supported, but when finished, this will support RDNA3+RDNA4+CDNA in ~2000 lines.
|
||||
Get line count with `cloc --by-file extra/assembly/amd/*.py`
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -9,7 +9,7 @@ from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3S
|
||||
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
|
||||
|
||||
# Common masks and bit conversion functions
|
||||
MASK32, MASK64 = 0xffffffff, 0xffffffffffffffff
|
||||
MASK32, MASK64, MASK128 = 0xffffffff, 0xffffffffffffffff, (1 << 128) - 1
|
||||
_struct_f, _struct_I = struct.Struct("<f"), struct.Struct("<I")
|
||||
_struct_e, _struct_H = struct.Struct("<e"), struct.Struct("<H")
|
||||
_struct_d, _struct_Q = struct.Struct("<d"), struct.Struct("<Q")
|
||||
|
||||
@@ -7,7 +7,7 @@ from extra.assembly.amd.pcode import Reg
|
||||
from extra.assembly.amd.asm import detect_format
|
||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
|
||||
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, VOPDOp)
|
||||
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
|
||||
|
||||
Program = dict[int, Inst]
|
||||
WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256
|
||||
@@ -29,33 +29,65 @@ def _dst16(cur: int, val: int, is_hi: bool) -> int: return (cur & 0x0000ffff) |
|
||||
def _vgpr_hi(src: int) -> bool: return src >= 256 and ((src - 256) & 0x80) != 0
|
||||
def _vgpr_masked(src: int) -> int: return ((src - 256) & 0x7f) + 256 if src >= 256 else src
|
||||
|
||||
# Helper: get number of dwords from memory op name
|
||||
def _op_ndwords(name: str) -> int:
|
||||
if '_B128' in name: return 4
|
||||
if '_B96' in name: return 3
|
||||
if any(s in name for s in ('_B64', '_U64', '_I64', '_F64')): return 2
|
||||
return 1
|
||||
|
||||
# Helper: build multi-dword Reg from consecutive VGPRs
|
||||
def _vgpr_read(V: list, base: int, ndwords: int) -> Reg:
  """Combine `ndwords` consecutive entries of `V`, starting at `base`, into one Reg.

  Entry `base + i` is shifted left by `32 * i` bits before being added in,
  so the first entry ends up in the lowest dword.
  """
  packed = 0
  for i in range(ndwords):
    packed += V[base + i] << (32 * i)
  return Reg(packed)
|
||||
|
||||
# Helper: write multi-dword value to consecutive VGPRs
|
||||
def _vgpr_write(V: list, base: int, val: int, ndwords: int):
  """Split `val` into 32-bit pieces and store them in `V[base]` .. `V[base + ndwords - 1]`.

  The lowest 32 bits go to `V[base]`; each following entry takes the next 32 bits.
  """
  remaining = val
  for reg in range(base, base + ndwords):
    V[reg] = remaining & MASK32
    remaining >>= 32  # expose the next dword for the following register
|
||||
|
||||
# Memory access
|
||||
_valid_mem_ranges: list[tuple[int, int]] = []
|
||||
def set_valid_mem_ranges(ranges: set[tuple[int, int]]) -> None: _valid_mem_ranges.clear(); _valid_mem_ranges.extend(ranges)
|
||||
def _mem_valid(addr: int, size: int) -> bool:
|
||||
return not _valid_mem_ranges or any(s <= addr and addr + size <= s + z for s, z in _valid_mem_ranges)
|
||||
def _ctypes_at(addr: int, size: int): return (ctypes.c_uint8 if size == 1 else ctypes.c_uint16 if size == 2 else ctypes.c_uint32).from_address(addr)
|
||||
def _ctypes_at(addr: int, size: int): return (ctypes.c_uint8 if size == 1 else ctypes.c_uint16 if size == 2 else ctypes.c_uint64 if size == 8 else ctypes.c_uint32).from_address(addr)
|
||||
def mem_read(addr: int, size: int) -> int:
  """Read an unsigned value of `size` bytes at `addr`; return 0 when `_mem_valid` rejects the access."""
  if not _mem_valid(addr, size): return 0
  return _ctypes_at(addr, size).value
|
||||
def mem_write(addr: int, size: int, val: int) -> None:
  """Store `val` as a `size`-byte value at `addr`; do nothing when `_mem_valid` rejects the access."""
  if not _mem_valid(addr, size): return
  _ctypes_at(addr, size).value = val
|
||||
|
||||
# Memory op tables (not pseudocode - these are format descriptions)
|
||||
def _mem_ops(ops, suffix_map):
|
||||
return {getattr(e, f"{p}_{s}"): v for e in ops for s, v in suffix_map.items() for p in [e.__name__.replace("Op", "")]}
|
||||
_LOAD_MAP = {'LOAD_B32': (1,4,0), 'LOAD_B64': (2,4,0), 'LOAD_B96': (3,4,0), 'LOAD_B128': (4,4,0), 'LOAD_U8': (1,1,0), 'LOAD_I8': (1,1,1), 'LOAD_U16': (1,2,0), 'LOAD_I16': (1,2,1)}
|
||||
_STORE_MAP = {'STORE_B32': (1,4), 'STORE_B64': (2,4), 'STORE_B96': (3,4), 'STORE_B128': (4,4), 'STORE_B8': (1,1), 'STORE_B16': (1,2)}
|
||||
FLAT_LOAD, FLAT_STORE = _mem_ops([GLOBALOp, FLATOp], _LOAD_MAP), _mem_ops([GLOBALOp, FLATOp], _STORE_MAP)
|
||||
# D16 ops: load/store 16-bit to lower or upper half of VGPR. Format: (size, sign, hi) where hi=1 means upper 16 bits
|
||||
_D16_LOAD_MAP = {'LOAD_D16_U8': (1,0,0), 'LOAD_D16_I8': (1,1,0), 'LOAD_D16_B16': (2,0,0),
|
||||
'LOAD_D16_HI_U8': (1,0,1), 'LOAD_D16_HI_I8': (1,1,1), 'LOAD_D16_HI_B16': (2,0,1)}
|
||||
_D16_STORE_MAP = {'STORE_D16_HI_B8': (1,1), 'STORE_D16_HI_B16': (2,1)} # (size, hi)
|
||||
FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP)
|
||||
FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP)
|
||||
DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
|
||||
DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
|
||||
# 2ADDR ops: load/store two values using offset0 and offset1
|
||||
DS_LOAD_2ADDR = {DSOp.DS_LOAD_2ADDR_B32: 4, DSOp.DS_LOAD_2ADDR_B64: 8}
|
||||
DS_STORE_2ADDR = {DSOp.DS_STORE_2ADDR_B32: 4, DSOp.DS_STORE_2ADDR_B64: 8}
|
||||
def _make_mem_accessor(read_fn, write_fn):
|
||||
"""Create a memory accessor class with the given read/write functions."""
|
||||
class _MemAccessor:
|
||||
__slots__ = ('_addr',)
|
||||
def __init__(self, addr: int): self._addr = int(addr)
|
||||
u8 = property(lambda s: read_fn(s._addr, 1), lambda s, v: write_fn(s._addr, 1, int(v)))
|
||||
u16 = property(lambda s: read_fn(s._addr, 2), lambda s, v: write_fn(s._addr, 2, int(v)))
|
||||
u32 = property(lambda s: read_fn(s._addr, 4), lambda s, v: write_fn(s._addr, 4, int(v)))
|
||||
u64 = property(lambda s: read_fn(s._addr, 8), lambda s, v: write_fn(s._addr, 8, int(v)))
|
||||
i8 = property(lambda s: _sext(read_fn(s._addr, 1), 8), lambda s, v: write_fn(s._addr, 1, int(v)))
|
||||
i16 = property(lambda s: _sext(read_fn(s._addr, 2), 16), lambda s, v: write_fn(s._addr, 2, int(v)))
|
||||
i32 = property(lambda s: _sext(read_fn(s._addr, 4), 32), lambda s, v: write_fn(s._addr, 4, int(v)))
|
||||
i64 = property(lambda s: _sext(read_fn(s._addr, 8), 64), lambda s, v: write_fn(s._addr, 8, int(v)))
|
||||
b8, b16, b32, b64 = u8, u16, u32, u64
|
||||
return _MemAccessor
|
||||
|
||||
_GlobalMemAccessor = _make_mem_accessor(mem_read, mem_write)
|
||||
|
||||
class _GlobalMem:
  """Indexable facade over global memory, enabling `MEM[addr].u32` style access."""
  def __getitem__(self, addr) -> _GlobalMemAccessor:
    # Each lookup hands back a new accessor bound to the requested address.
    return _GlobalMemAccessor(addr)
|
||||
GlobalMem = _GlobalMem()
|
||||
|
||||
class LDSMem:
  """LDS memory wrapper that supports MEM[addr].u32 style access.

  Wraps a bytearray. Addresses are masked to 16 bits before use. A read that
  would run past the end of the buffer returns 0, and a write that would run
  past the end is dropped. Values are stored little-endian.
  """
  __slots__ = ('_lds', '_accessor')
  def __init__(self, lds: bytearray):
    self._lds = lds
    # Build the accessor class once per LDSMem. Creating a new class on every
    # MEM[addr] lookup is needlessly expensive in the emulator's per-lane path.
    self._accessor = _make_mem_accessor(self._read, self._write)
  def _read(self, addr: int, size: int) -> int:
    """Return the little-endian unsigned value of `size` bytes at `addr`, or 0 if out of bounds."""
    addr = addr & 0xffff
    end = addr + size
    return int.from_bytes(self._lds[addr:end], 'little') if end <= len(self._lds) else 0
  def _write(self, addr: int, size: int, val: int):
    """Store the low `size` bytes of `val` at `addr` (little-endian); ignore out-of-bounds writes."""
    addr = addr & 0xffff
    end = addr + size
    if end <= len(self._lds): self._lds[addr:end] = (int(val) & ((1 << (size*8)) - 1)).to_bytes(size, 'little')
  def __getitem__(self, addr): return self._accessor(addr)
|
||||
|
||||
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
|
||||
|
||||
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
|
||||
@@ -197,60 +229,28 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
|
||||
return new_pc_words - st.pc - 1 # -1 because emulator adds inst_words (1 for scalar)
|
||||
return 0
|
||||
|
||||
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None) -> None:
|
||||
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None) -> None:
|
||||
"""Execute vector instruction for one lane."""
|
||||
compiled = _get_compiled()
|
||||
V = st.vgpr[lane]
|
||||
|
||||
# Memory ops (not ALU pseudocode)
|
||||
if isinstance(inst, FLAT):
|
||||
op, addr_reg, data_reg, vdst, offset, saddr = inst.op, inst.addr, inst.data, inst.vdst, _sext(inst.offset, 13), inst.saddr
|
||||
addr = V[addr_reg] | (V[addr_reg+1] << 32)
|
||||
addr = (st.rsgpr64(saddr) + V[addr_reg] + offset) & MASK64 if saddr not in (NULL, 0x7f) else (addr + offset) & MASK64
|
||||
if op in FLAT_LOAD:
|
||||
cnt, sz, sign = FLAT_LOAD[op]
|
||||
for i in range(cnt): val = mem_read(addr + i * sz, sz); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
|
||||
elif op in FLAT_STORE:
|
||||
cnt, sz = FLAT_STORE[op]
|
||||
for i in range(cnt): mem_write(addr + i * sz, sz, V[data_reg + i] & ((1 << (sz * 8)) - 1))
|
||||
elif op in FLAT_D16_LOAD:
|
||||
sz, sign, hi = FLAT_D16_LOAD[op]
|
||||
val = mem_read(addr, sz)
|
||||
if sign: val = _sext(val, sz * 8) & 0xffff
|
||||
V[vdst] = _dst16(V[vdst], val, hi)
|
||||
elif op in FLAT_D16_STORE:
|
||||
sz, hi = FLAT_D16_STORE[op]
|
||||
mem_write(addr, sz, _src16(V[data_reg], hi) & ((1 << (sz * 8)) - 1))
|
||||
else: raise NotImplementedError(f"FLAT op {op}")
|
||||
return
|
||||
|
||||
if isinstance(inst, DS):
|
||||
op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
|
||||
if op in DS_LOAD:
|
||||
cnt, sz, sign = DS_LOAD[op]
|
||||
for i in range(cnt): val = int.from_bytes(lds[addr0+i*sz:addr0+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
|
||||
elif op in DS_STORE:
|
||||
cnt, sz = DS_STORE[op]
|
||||
for i in range(cnt): lds[addr0+i*sz:addr0+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little')
|
||||
elif op in DS_LOAD_2ADDR:
|
||||
# Load two values from addr+offset0*sz and addr+offset1*sz into vdst (B32: 1 dword each, B64: 2 dwords each)
|
||||
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
|
||||
sz = DS_LOAD_2ADDR[op]
|
||||
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
|
||||
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
|
||||
cnt = sz // 4 # 1 for B32, 2 for B64
|
||||
for i in range(cnt): V[vdst + i] = int.from_bytes(lds[addr0+i*4:addr0+i*4+4], 'little')
|
||||
for i in range(cnt): V[vdst + cnt + i] = int.from_bytes(lds[addr1+i*4:addr1+i*4+4], 'little')
|
||||
elif op in DS_STORE_2ADDR:
|
||||
# Store two values from data0 and data1 to addr+offset0*sz and addr+offset1*sz
|
||||
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
|
||||
sz = DS_STORE_2ADDR[op]
|
||||
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
|
||||
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
|
||||
cnt = sz // 4
|
||||
for i in range(cnt): lds[addr0+i*4:addr0+i*4+4] = (V[inst.data0 + i] & MASK32).to_bytes(4, 'little')
|
||||
for i in range(cnt): lds[addr1+i*4:addr1+i*4+4] = (V[inst.data1 + i] & MASK32).to_bytes(4, 'little')
|
||||
else: raise NotImplementedError(f"DS op {op}")
|
||||
# Memory ops (FLAT/GLOBAL/SCRATCH and DS) - use generated pcode
|
||||
if isinstance(inst, (FLAT, DS)):
|
||||
op, vdst, op_name = inst.op, inst.vdst, inst.op.name
|
||||
fn, ndwords = compiled[type(op)][op], _op_ndwords(op_name)
|
||||
if isinstance(inst, FLAT):
|
||||
addr = V[inst.addr] | (V[inst.addr + 1] << 32)
|
||||
ADDR = (st.rsgpr64(inst.saddr) + V[inst.addr] + _sext(inst.offset, 13)) & MASK64 if inst.saddr not in (NULL, 0x7f) else (addr + _sext(inst.offset, 13)) & MASK64
|
||||
# For loads, VDATA comes from vdst (preserves unwritten bits); for stores, from inst.data
|
||||
vdata_src = vdst if 'LOAD' in op_name else inst.data
|
||||
result = fn(GlobalMem, ADDR, _vgpr_read(V, vdata_src, ndwords), Reg(V[vdst]), Reg(0))
|
||||
if 'VDATA' in result: _vgpr_write(V, vdst, result['VDATA']._val, ndwords)
|
||||
if 'RETURN_DATA' in result: _vgpr_write(V, vdst, result['RETURN_DATA']._val, ndwords)
|
||||
else: # DS
|
||||
DATA0, DATA1 = _vgpr_read(V, inst.data0, ndwords), _vgpr_read(V, inst.data1, ndwords) if inst.data1 is not None else Reg(0)
|
||||
result = fn(lds, Reg(V[inst.addr]), DATA0, DATA1, Reg(inst.offset0), Reg(inst.offset1), Reg(0))
|
||||
if 'RETURN_DATA' in result and ('_RTN' in op_name or '_LOAD' in op_name):
|
||||
_vgpr_write(V, vdst, result['RETURN_DATA']._val, ndwords * 2 if '_2ADDR_' in op_name else ndwords)
|
||||
return
|
||||
|
||||
# VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)
|
||||
@@ -423,7 +423,7 @@ def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None:
|
||||
# MAIN EXECUTION LOOP
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def step_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int:
|
||||
def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int:
|
||||
inst = program.get(st.pc)
|
||||
if inst is None: return 1
|
||||
inst_words, st.literal = inst._words, getattr(inst, '_literal', None) or 0
|
||||
@@ -443,7 +443,7 @@ def step_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) ->
|
||||
st.pc += inst_words
|
||||
return 0
|
||||
|
||||
def exec_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int:
|
||||
def exec_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int:
|
||||
while st.pc in program:
|
||||
result = step_wave(program, st, lds, n_lanes)
|
||||
if result == -1: return 0
|
||||
@@ -453,7 +453,7 @@ def exec_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) ->
|
||||
def exec_workgroup(program: Program, workgroup_id: tuple[int, int, int], local_size: tuple[int, int, int], args_ptr: int,
|
||||
wg_id_sgpr_base: int, wg_id_enables: tuple[bool, bool, bool]) -> None:
|
||||
lx, ly, lz = local_size
|
||||
total_threads, lds = lx * ly * lz, bytearray(65536)
|
||||
total_threads, lds = lx * ly * lz, LDSMem(bytearray(65536))
|
||||
waves: list[tuple[WaveState, int, int]] = []
|
||||
for wave_start in range(0, total_threads, WAVE_SIZE):
|
||||
n_lanes, st = min(WAVE_SIZE, total_threads - wave_start), WaveState()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# DSL for RDNA3 pseudocode - makes pseudocode expressions work directly as Python
|
||||
import struct, math
|
||||
from extra.assembly.amd.dsl import MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||
from extra.assembly.amd.dsl import MASK32, MASK64, MASK128, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# HELPER FUNCTIONS
|
||||
@@ -206,47 +206,6 @@ def signext_from_bit(val, bit):
|
||||
if val & (1 << (bit - 1)): return val - (1 << bit)
|
||||
return val
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# DSL EXPORTS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
__all__ = [
|
||||
# Classes
|
||||
'Reg', 'SliceProxy', 'TypedView',
|
||||
# Pack functions
|
||||
'_pack', '_pack32', 'pack', 'pack32',
|
||||
# Constants
|
||||
'WAVE32', 'WAVE64', 'MASK32', 'MASK64', 'WAVE_MODE', 'DENORM', 'OVERFLOW_F32', 'UNDERFLOW_F32',
|
||||
'OVERFLOW_F64', 'UNDERFLOW_F64', 'MAX_FLOAT_F32', 'ROUND_MODE', 'cvtToQuietNAN', 'DST', 'INF', 'PI',
|
||||
'TWO_OVER_PI_1201',
|
||||
# Aliases for pseudocode
|
||||
's_ff1_i32_b32', 's_ff1_i32_b64', 'GT_NEG_ZERO', 'LT_NEG_ZERO',
|
||||
'isNAN', 'isQuietNAN', 'isSignalNAN', 'fma', 'ldexp', 'sign', 'exponent', 'F', 'signext',
|
||||
# Conversion functions
|
||||
'_f32', '_i32', '_f16', '_i16', '_f64', '_i64', '_sext', '_to_f16_bits', '_f16_to_f32_bits',
|
||||
'i32_to_f32', 'u32_to_f32', 'i32_to_f64', 'u32_to_f64', 'f32_to_f64', 'f64_to_f32',
|
||||
'f32_to_i32', 'f32_to_u32', 'f64_to_i32', 'f64_to_u32', 'f32_to_f16', 'f16_to_f32',
|
||||
'i16_to_f16', 'u16_to_f16', 'f16_to_i16', 'f16_to_u16', 'u32_to_u16', 'i32_to_i16',
|
||||
'f16_to_snorm', 'f16_to_unorm', 'f32_to_snorm', 'f32_to_unorm', 'v_cvt_i16_f32', 'v_cvt_u16_f32',
|
||||
'SAT8', 'f32_to_u8', 'u8_to_u32', 'u4_to_u32',
|
||||
# BF16 conversion functions
|
||||
'_bf16', '_ibf16', 'bf16_to_f32', 'f32_to_bf16',
|
||||
# Math functions
|
||||
'trunc', 'floor', 'ceil', 'sqrt', 'log2', 'sin', 'cos', 'pow', 'fract', 'isEven', 'mantissa',
|
||||
# Min/max functions
|
||||
'v_min_f32', 'v_max_f32', 'v_min_i32', 'v_max_i32', 'v_min_u32', 'v_max_u32',
|
||||
'v_min_f16', 'v_max_f16', 'v_min_i16', 'v_max_i16', 'v_min_u16', 'v_max_u16',
|
||||
'v_min3_f32', 'v_max3_f32', 'v_min3_i32', 'v_max3_i32', 'v_min3_u32', 'v_max3_u32',
|
||||
'v_min3_f16', 'v_max3_f16', 'v_min3_i16', 'v_max3_i16', 'v_min3_u16', 'v_max3_u16',
|
||||
'ABSDIFF',
|
||||
# Byte/SAD helper functions
|
||||
'BYTE_PERMUTE', 'v_sad_u8', 'v_msad_u8',
|
||||
# Bit manipulation
|
||||
'_brev32', '_brev64', '_ctz32', '_ctz64', '_exponent', '_is_denorm_f32', '_is_denorm_f64',
|
||||
'_sign', '_mantissa_f32', '_div', '_isnan', '_isquietnan', '_issignalnan', '_gt_neg_zero', '_lt_neg_zero', '_fma', '_ldexp', '_signext',
|
||||
'signext_from_bit',
|
||||
]
|
||||
|
||||
# Aliases used in pseudocode
|
||||
s_ff1_i32_b32, s_ff1_i32_b64 = _ctz32, _ctz64
|
||||
GT_NEG_ZERO, LT_NEG_ZERO = _gt_neg_zero, _lt_neg_zero
|
||||
@@ -341,12 +300,6 @@ class _Denorm:
|
||||
f64 = _DenormChecker(64)
|
||||
DENORM = _Denorm()
|
||||
|
||||
def _brev(v, bits):
|
||||
"""Bit-reverse a value."""
|
||||
result = 0
|
||||
for i in range(bits): result |= ((v >> i) & 1) << (bits - 1 - i)
|
||||
return result
|
||||
|
||||
class SliceProxy:
|
||||
"""Proxy for D0[31:16] that supports .f16/.u16 etc getters and setters."""
|
||||
__slots__ = ('_reg', '_high', '_low', '_reversed')
|
||||
@@ -474,9 +427,9 @@ class TypedView:
|
||||
def u32(s): return s if s._bits == 32 and not s._signed else int(s) & MASK32
|
||||
|
||||
class Reg:
|
||||
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works."""
|
||||
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works. Supports up to 128 bits for DS_LOAD_B128."""
|
||||
__slots__ = ('_val',)
|
||||
def __init__(self, val=0): self._val = int(val) & MASK64
|
||||
def __init__(self, val=0): self._val = int(val) & MASK128
|
||||
|
||||
# Typed views
|
||||
u64 = property(lambda s: TypedView(s, 64), lambda s, v: setattr(s, '_val', int(v) & MASK64))
|
||||
|
||||
@@ -36,7 +36,7 @@ FIELD_ORDER = {
|
||||
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
|
||||
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
|
||||
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
|
||||
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
||||
INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+|(?:FLAT|GLOBAL|SCRATCH)_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
||||
|
||||
# Patterns that can't be handled by the DSL (require special handling in emu.py)
|
||||
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
|
||||
@@ -46,7 +46,8 @@ UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
|
||||
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST',
|
||||
'BARRIER_STATE', 'ReallocVgprs',
|
||||
'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
|
||||
'fp6', 'bf6'] # Malformed pseudocode from PDF
|
||||
'fp6', 'bf6', 'GS_REGS', 'M0.base', 'DS_DATA', '= 0..', 'sign(src', 'if no LDS', 'gds_base', 'vector mask',
|
||||
'SGPR_ADDR', 'INST_OFFSET', 'laneID'] # FLAT ops with non-standard vars
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# COMPILER: pseudocode -> Python (minimal transforms)
|
||||
@@ -68,8 +69,8 @@ def compile_pseudocode(pseudocode: str) -> str:
|
||||
lines = []
|
||||
indent, need_pass, in_first_match_loop = 0, False, False
|
||||
for line in joined_lines:
|
||||
line = line.strip()
|
||||
if not line or line.startswith('//'): continue
|
||||
line = line.split('//')[0].strip() # Strip C-style comments
|
||||
if not line: continue
|
||||
if line.startswith('if '):
|
||||
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
|
||||
indent += 1
|
||||
@@ -351,8 +352,9 @@ def _extract_pseudocode(text: str) -> str | None:
|
||||
for line in lines:
|
||||
s = line.strip()
|
||||
if not s or re.match(r'^\d+ of \d+$', s) or re.match(r'^\d+\.\d+\..*Instructions', s): continue
|
||||
if s.startswith(('Notes', 'Functional examples')): break
|
||||
if s.startswith(('Notes', 'Functional examples', '•', '-')): break # Stop at notes/bullets
|
||||
if s.startswith(('"RDNA', 'AMD ', 'CDNA')): continue
|
||||
if '•' in s or '–' in s: continue # Skip lines with bullets/dashes
|
||||
if '= lambda(' in s: in_lambda += 1; continue
|
||||
if in_lambda > 0:
|
||||
if s.endswith(');'): in_lambda -= 1
|
||||
@@ -362,7 +364,8 @@ def _extract_pseudocode(text: str) -> str | None:
|
||||
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
|
||||
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
|
||||
is_code = (any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =',
|
||||
'D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
|
||||
'D0[', 'D1[', 'S0[', 'S1[', 'S2[', 'MEM[', 'RETURN_DATA',
|
||||
'VADDR', 'VDATA', 'VDST', 'SADDR', 'OFFSET']) or
|
||||
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
|
||||
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s))
|
||||
if is_code: result.append(s)
|
||||
@@ -448,28 +451,23 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
|
||||
# Get op enums for this arch (import from .ins which re-exports from .enum)
|
||||
import importlib
|
||||
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
|
||||
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp'] if hasattr(autogen, name)]
|
||||
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp', 'FLATOp', 'GLOBALOp', 'SCRATCHOp'] if hasattr(autogen, name)]
|
||||
|
||||
# Build defined ops mapping
|
||||
defined_ops: dict[tuple, list] = {}
|
||||
for enum_cls in OP_ENUMS:
|
||||
for op in enum_cls:
|
||||
if op.name.startswith(('S_', 'V_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
|
||||
if op.name.startswith(('S_', 'V_', 'DS_', 'FLAT_', 'GLOBAL_', 'SCRATCH_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
|
||||
|
||||
enum_names = [e.__name__ for e in OP_ENUMS]
|
||||
lines = [f'''# autogenerated by pdf.py - do not edit
|
||||
# to regenerate: python -m extra.assembly.amd.pdf --arch {arch}
|
||||
# ruff: noqa: E501,F405,F403
|
||||
# mypy: ignore-errors
|
||||
from extra.assembly.amd.autogen.{arch}.enum import {", ".join(enum_names)}
|
||||
from extra.assembly.amd.pcode import *
|
||||
''']
|
||||
|
||||
instructions: dict = {cls: {} for cls in OP_ENUMS}
|
||||
for key, pc in pseudocode.items():
|
||||
if key in defined_ops:
|
||||
for enum_cls, enum_val in defined_ops[key]: instructions[enum_cls][enum_val] = pc
|
||||
|
||||
# First pass: generate all function code
|
||||
fn_lines: list[str] = []
|
||||
all_fn_entries: dict = {}
|
||||
for enum_cls in OP_ENUMS:
|
||||
cls_name = enum_cls.__name__
|
||||
if not instructions.get(enum_cls): continue
|
||||
@@ -480,28 +478,44 @@ from extra.assembly.amd.pcode import *
|
||||
code = compile_pseudocode(pc)
|
||||
code = _apply_pseudocode_fixes(op, code)
|
||||
fn_name, fn_code = _generate_function(cls_name, op, pc, code)
|
||||
lines.append(fn_code)
|
||||
fn_lines.append(fn_code)
|
||||
fn_entries.append((op, fn_name))
|
||||
except Exception as e: print(f" Warning: Failed to compile {op.name}: {e}")
|
||||
if fn_entries:
|
||||
lines.append(f'{cls_name}_FUNCTIONS = {{')
|
||||
for op, fn_name in fn_entries: lines.append(f" {cls_name}.{op.name}: {fn_name},")
|
||||
lines.append('}\n')
|
||||
all_fn_entries[enum_cls] = fn_entries
|
||||
fn_lines.append(f'{cls_name}_FUNCTIONS = {{')
|
||||
for op, fn_name in fn_entries: fn_lines.append(f" {cls_name}.{op.name}: {fn_name},")
|
||||
fn_lines.append('}\n')
|
||||
|
||||
# Add V_WRITELANE_B32 if VOP3Op exists
|
||||
if 'VOP3Op' in enum_names:
|
||||
lines.append('''
|
||||
fn_lines.append('''
|
||||
# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode)
|
||||
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
|
||||
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):
|
||||
wr_lane = s1 & 0x1f
|
||||
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
|
||||
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
|
||||
''')
|
||||
|
||||
lines.append('COMPILED_FUNCTIONS = {')
|
||||
fn_lines.append('COMPILED_FUNCTIONS = {')
|
||||
for enum_cls in OP_ENUMS:
|
||||
if instructions.get(enum_cls): lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_FUNCTIONS,')
|
||||
lines.append('}\n\ndef get_compiled_functions(): return COMPILED_FUNCTIONS')
|
||||
if all_fn_entries.get(enum_cls): fn_lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_FUNCTIONS,')
|
||||
fn_lines.append('}\n\ndef get_compiled_functions(): return COMPILED_FUNCTIONS')
|
||||
|
||||
# Second pass: scan generated code for pcode imports
|
||||
fn_code_str = '\n'.join(fn_lines)
|
||||
import extra.assembly.amd.pcode as pcode_module
|
||||
pcode_exports = [name for name in dir(pcode_module) if not name.startswith('_') or name.startswith('_') and not name.startswith('__')]
|
||||
used_imports = sorted(name for name in pcode_exports if re.search(rf'\b{re.escape(name)}\b', fn_code_str))
|
||||
|
||||
# Build final output with explicit imports
|
||||
lines = [f'''# autogenerated by pdf.py - do not edit
|
||||
# to regenerate: python -m extra.assembly.amd.pdf --arch {arch}
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
from extra.assembly.amd.autogen.{arch}.enum import {", ".join(enum_names)}
|
||||
from extra.assembly.amd.pcode import {", ".join(used_imports)}
|
||||
'''] + fn_lines
|
||||
return '\n'.join(lines)
|
||||
|
||||
def _apply_pseudocode_fixes(op, code: str) -> str:
|
||||
@@ -541,19 +555,32 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
|
||||
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
|
||||
is_div_scale = 'DIV_SCALE' in op.name
|
||||
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
|
||||
is_ds = cls_name == 'DSOp'
|
||||
is_flat = cls_name in ('FLATOp', 'GLOBALOp', 'SCRATCHOp')
|
||||
combined = code + pc
|
||||
|
||||
fn_name = f"_{cls_name}_{op.name}"
|
||||
# Function accepts Reg objects directly (uppercase names), laneId is passed directly as int
|
||||
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]
|
||||
# DSOp functions get additional MEM and offset parameters
|
||||
# FLAT/GLOBAL ops get MEM, vaddr, vdata, saddr, offset parameters
|
||||
if is_ds:
|
||||
lines = [f"def {fn_name}(MEM, ADDR, DATA0, DATA1, OFFSET0, OFFSET1, RETURN_DATA):"]
|
||||
elif is_flat:
|
||||
lines = [f"def {fn_name}(MEM, ADDR, VDATA, VDST, RETURN_DATA):"]
|
||||
else:
|
||||
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]
|
||||
|
||||
# Registers that need special handling (not passed directly)
|
||||
# Only init if used but not first assigned as `name = Reg(...)` in the compiled code
|
||||
# Registers that need special handling (aliases or init)
|
||||
def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
|
||||
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
|
||||
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
|
||||
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
|
||||
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
|
||||
special_regs = []
|
||||
if is_ds: special_regs = [('DATA', 'DATA0'), ('DATA2', 'DATA1'), ('OFFSET', 'OFFSET0'), ('ADDR_BASE', 'ADDR')]
|
||||
elif is_flat: special_regs = [('DATA', 'VDATA')]
|
||||
else:
|
||||
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
|
||||
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
|
||||
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
|
||||
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
|
||||
|
||||
used = {name for name, _ in special_regs if name in combined}
|
||||
|
||||
# Detect which registers are modified (not just read) - look for assignments
|
||||
@@ -562,6 +589,10 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
|
||||
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
|
||||
modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
|
||||
modifies_pc = bool(re.search(r'\bPC\s*=', combined))
|
||||
# DS/FLAT ops: detect memory writes (MEM[...] = ...)
|
||||
modifies_mem = (is_ds or is_flat) and bool(re.search(r'MEM\[.*\]\.[a-z0-9]+\s*=', combined))
|
||||
# FLAT ops: detect VDST writes
|
||||
modifies_vdst = is_flat and bool(re.search(r'VDST[\.\[].*=', combined))
|
||||
|
||||
# Build init code for special registers
|
||||
init_lines = []
|
||||
@@ -587,6 +618,15 @@ def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]
|
||||
if modifies_exec: result_items.append("'EXEC': EXEC")
|
||||
if has_d1: result_items.append("'D1': D1")
|
||||
if modifies_pc: result_items.append("'PC': PC")
|
||||
# DS ops: return RETURN_DATA if it was written (left side of assignment)
|
||||
if is_ds and 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
|
||||
result_items.append("'RETURN_DATA': RETURN_DATA")
|
||||
# FLAT ops: return RETURN_DATA for atomics, VDATA for loads (only if written to)
|
||||
if is_flat:
|
||||
if 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
|
||||
result_items.append("'RETURN_DATA': RETURN_DATA")
|
||||
if re.search(r'^\s*VDATA[\.\[].*=', code, re.MULTILINE):
|
||||
result_items.append("'VDATA': VDATA")
|
||||
lines.append(f" return {{{', '.join(result_items)}}}\n")
|
||||
return fn_name, '\n'.join(lines)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ os.environ["AMD"] = "1"
|
||||
os.environ["MOCKGPU"] = "1"
|
||||
os.environ["PYTHON_REMU"] = "1"
|
||||
|
||||
from extra.assembly.amd.emu import WaveState, decode_program, step_wave, WAVE_SIZE, set_valid_mem_ranges
|
||||
from extra.assembly.amd.emu import WaveState, decode_program, step_wave, WAVE_SIZE, set_valid_mem_ranges, LDSMem
|
||||
from extra.assembly.amd.test.helpers import KernelInfo
|
||||
|
||||
REMU_PATH = Path(__file__).parents[3] / "remu/target/release/libremu.so"
|
||||
@@ -99,7 +99,7 @@ class PythonEmulator:
|
||||
self.program = decode_program(kernel)
|
||||
self.state = WaveState()
|
||||
self.state.exec_mask = (1 << n_lanes) - 1
|
||||
self.lds = bytearray(65536)
|
||||
self.lds = LDSMem(bytearray(65536))
|
||||
self.n_lanes = n_lanes
|
||||
|
||||
def step(self) -> int:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user