ds compiled

This commit is contained in:
George Hotz
2025-12-31 14:54:39 -05:00
parent f022a7d8a7
commit aec4d65241
3 changed files with 128 additions and 36 deletions

View File

@@ -2,7 +2,7 @@
# to regenerate: python -m extra.assembly.amd.pdf --arch rdna3
# ruff: noqa: E501,F405,F403
# mypy: ignore-errors
from extra.assembly.amd.autogen.rdna3.enum import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp
from extra.assembly.amd.autogen.rdna3.enum import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp
from extra.assembly.amd.pcode import *
def _SOP1Op_S_MOV_B32(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):
@@ -6360,6 +6360,119 @@ def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
# ═══════════════════════════════════════════════════════════════════════════════
# DS (Data Share) INSTRUCTIONS
# DS instructions operate on LDS (Local Data Share) memory
# They receive: lds (LDS byte buffer), addr (address), data0/data1 (lists of data dwords), vdst, offset0/offset1 (byte offsets; 2ADDR ops scale them by the data size)
# They return: {'vdst': [...]} for loads; stores write into lds in place and return {}
# ═══════════════════════════════════════════════════════════════════════════════
def _ds_load(lds, addr, offset, size, sign_extend=False):
"""Load from LDS memory. Returns list of 32-bit values."""
a = (addr + offset) & 0xffff
if size <= 4:
val = int.from_bytes(lds[a:a+size], 'little')
if sign_extend and size < 4:
# Sign extend from size*8 bits to 32 bits
sign_bit = 1 << (size * 8 - 1)
if val & sign_bit: val |= ~((1 << (size * 8)) - 1)
return [val & 0xffffffff]
# Multi-dword load
return [int.from_bytes(lds[a+i*4:a+i*4+4], 'little') for i in range(size // 4)]
def _ds_store(lds, addr, offset, values, size):
"""Store to LDS memory. values is list of 32-bit dwords, size is bytes per element."""
a = (addr + offset) & 0xffff
if size <= 4:
lds[a:a+size] = (values[0] & ((1 << (size * 8)) - 1)).to_bytes(size, 'little')
else:
for i, v in enumerate(values):
lds[a+i*4:a+i*4+4] = (v & 0xffffffff).to_bytes(4, 'little')
# Load operations: DS_LOAD_B32, DS_LOAD_B64, DS_LOAD_B128, DS_LOAD_U8, DS_LOAD_I8, DS_LOAD_U16, DS_LOAD_I16
def _DSOp_DS_LOAD_B32(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_LOAD_B32: load 4 bytes from LDS at addr + offset0; result goes under 'vdst'."""
    dwords = _ds_load(lds, addr, offset0, 4)
    return {'vdst': dwords}
def _DSOp_DS_LOAD_B64(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_LOAD_B64: load 8 bytes (two dwords) from LDS at addr + offset0; result goes under 'vdst'."""
    dwords = _ds_load(lds, addr, offset0, 8)
    return {'vdst': dwords}
def _DSOp_DS_LOAD_B128(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_LOAD_B128: load 16 bytes (four dwords) from LDS at addr + offset0; result goes under 'vdst'."""
    dwords = _ds_load(lds, addr, offset0, 16)
    return {'vdst': dwords}
def _DSOp_DS_LOAD_U8(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_LOAD_U8: load one byte from LDS at addr + offset0 without sign extension."""
    byte_val = _ds_load(lds, addr, offset0, 1, False)
    return {'vdst': byte_val}
def _DSOp_DS_LOAD_I8(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_LOAD_I8: load one byte from LDS at addr + offset0, sign-extended to 32 bits."""
    byte_val = _ds_load(lds, addr, offset0, 1, True)
    return {'vdst': byte_val}
def _DSOp_DS_LOAD_U16(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_LOAD_U16: load two bytes from LDS at addr + offset0 without sign extension."""
    half_val = _ds_load(lds, addr, offset0, 2, False)
    return {'vdst': half_val}
def _DSOp_DS_LOAD_I16(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_LOAD_I16: load two bytes from LDS at addr + offset0, sign-extended to 32 bits."""
    half_val = _ds_load(lds, addr, offset0, 2, True)
    return {'vdst': half_val}
# Store operations: DS_STORE_B32, DS_STORE_B64, DS_STORE_B128, DS_STORE_B8, DS_STORE_B16
def _DSOp_DS_STORE_B32(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_STORE_B32: store data0 to LDS at addr + offset0 via _ds_store with size 4.

    Mutates lds in place; returns an empty dict (nothing to write back to registers).
    """
    _ds_store(lds, addr, offset0, data0, size=4)
    return dict()
def _DSOp_DS_STORE_B64(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_STORE_B64: store data0 to LDS at addr + offset0 via _ds_store with size 8.

    Mutates lds in place; returns an empty dict (nothing to write back to registers).
    """
    _ds_store(lds, addr, offset0, data0, size=8)
    return dict()
def _DSOp_DS_STORE_B128(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_STORE_B128: store data0 to LDS at addr + offset0 via _ds_store with size 16.

    Mutates lds in place; returns an empty dict (nothing to write back to registers).
    """
    _ds_store(lds, addr, offset0, data0, size=16)
    return dict()
def _DSOp_DS_STORE_B8(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_STORE_B8: store the low byte of data0[0] to LDS at addr + offset0.

    Mutates lds in place; returns an empty dict (nothing to write back to registers).
    """
    _ds_store(lds, addr, offset0, data0, size=1)
    return dict()
def _DSOp_DS_STORE_B16(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_STORE_B16: store the low two bytes of data0[0] to LDS at addr + offset0.

    Mutates lds in place; returns an empty dict (nothing to write back to registers).
    """
    _ds_store(lds, addr, offset0, data0, size=2)
    return dict()
# 2-address operations: DS_LOAD_2ADDR_B32, DS_LOAD_2ADDR_B64, DS_STORE_2ADDR_B32, DS_STORE_2ADDR_B64
# Note: offsets are scaled by data size (4 for B32, 8 for B64)
def _DSOp_DS_LOAD_2ADDR_B32(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_LOAD_2ADDR_B32: load 4 bytes from each of addr + offset0*4 and addr + offset1*4.

    The 'vdst' list holds the offset0 result followed by the offset1 result.
    """
    first, second = (_ds_load(lds, addr, off * 4, 4) for off in (offset0, offset1))
    return {'vdst': first + second}
def _DSOp_DS_LOAD_2ADDR_B64(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_LOAD_2ADDR_B64: load 8 bytes from each of addr + offset0*8 and addr + offset1*8.

    The 'vdst' list holds the offset0 dwords followed by the offset1 dwords.
    """
    first, second = (_ds_load(lds, addr, off * 8, 8) for off in (offset0, offset1))
    return {'vdst': first + second}
def _DSOp_DS_STORE_2ADDR_B32(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_STORE_2ADDR_B32: store data0 at addr + offset0*4, then data1 at addr + offset1*4.

    Each store goes through _ds_store with size 4. Mutates lds in place; returns an empty dict.
    """
    for src, off in ((data0, offset0), (data1, offset1)):
        _ds_store(lds, addr, off * 4, src, 4)
    return dict()
def _DSOp_DS_STORE_2ADDR_B64(lds, addr, data0, data1, vdst, offset0, offset1):
    """DS_STORE_2ADDR_B64: store data0 at addr + offset0*8, then data1 at addr + offset1*8.

    Each store goes through _ds_store with size 8. Mutates lds in place; returns an empty dict.
    """
    for src, off in ((data0, offset0), (data1, offset1)):
        _ds_store(lds, addr, off * 8, src, 8)
    return dict()
# Dispatch table: DS opcode -> Python handler implementing it.
DSOp_FUNCTIONS = dict((
    # plain loads
    (DSOp.DS_LOAD_B32, _DSOp_DS_LOAD_B32),
    (DSOp.DS_LOAD_B64, _DSOp_DS_LOAD_B64),
    (DSOp.DS_LOAD_B128, _DSOp_DS_LOAD_B128),
    (DSOp.DS_LOAD_U8, _DSOp_DS_LOAD_U8),
    (DSOp.DS_LOAD_I8, _DSOp_DS_LOAD_I8),
    (DSOp.DS_LOAD_U16, _DSOp_DS_LOAD_U16),
    (DSOp.DS_LOAD_I16, _DSOp_DS_LOAD_I16),
    # plain stores
    (DSOp.DS_STORE_B32, _DSOp_DS_STORE_B32),
    (DSOp.DS_STORE_B64, _DSOp_DS_STORE_B64),
    (DSOp.DS_STORE_B128, _DSOp_DS_STORE_B128),
    (DSOp.DS_STORE_B8, _DSOp_DS_STORE_B8),
    (DSOp.DS_STORE_B16, _DSOp_DS_STORE_B16),
    # two-address loads and stores
    (DSOp.DS_LOAD_2ADDR_B32, _DSOp_DS_LOAD_2ADDR_B32),
    (DSOp.DS_LOAD_2ADDR_B64, _DSOp_DS_LOAD_2ADDR_B64),
    (DSOp.DS_STORE_2ADDR_B32, _DSOp_DS_STORE_2ADDR_B32),
    (DSOp.DS_STORE_2ADDR_B64, _DSOp_DS_STORE_2ADDR_B64),
))
COMPILED_FUNCTIONS = {
SOP1Op: SOP1Op_FUNCTIONS,
SOP2Op: SOP2Op_FUNCTIONS,
@@ -6372,6 +6485,7 @@ COMPILED_FUNCTIONS = {
VOP3SDOp: VOP3SDOp_FUNCTIONS,
VOP3POp: VOP3POp_FUNCTIONS,
VOPCOp: VOPCOp_FUNCTIONS,
DSOp: DSOp_FUNCTIONS,
}
def get_compiled_functions():
    """Return the module-level COMPILED_FUNCTIONS mapping (the same object, not a copy)."""
    return COMPILED_FUNCTIONS

View File

@@ -51,11 +51,6 @@ _D16_LOAD_MAP = {'LOAD_D16_U8': (1,0,0), 'LOAD_D16_I8': (1,1,0), 'LOAD_D16_B16':
_D16_STORE_MAP = {'STORE_D16_HI_B8': (1,1), 'STORE_D16_HI_B16': (2,1)} # (size, hi)
FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP)
FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP)
DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
# 2ADDR ops: load/store two values using offset0 and offset1
DS_LOAD_2ADDR = {DSOp.DS_LOAD_2ADDR_B32: 4, DSOp.DS_LOAD_2ADDR_B64: 8}
DS_STORE_2ADDR = {DSOp.DS_STORE_2ADDR_B32: 4, DSOp.DS_STORE_2ADDR_B64: 8}
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
@@ -225,32 +220,15 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
return
if isinstance(inst, DS):
op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
if op in DS_LOAD:
cnt, sz, sign = DS_LOAD[op]
for i in range(cnt): val = int.from_bytes(lds[addr0+i*sz:addr0+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
elif op in DS_STORE:
cnt, sz = DS_STORE[op]
for i in range(cnt): lds[addr0+i*sz:addr0+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little')
elif op in DS_LOAD_2ADDR:
# Load two values from addr+offset0*sz and addr+offset1*sz into vdst (B32: 1 dword each, B64: 2 dwords each)
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
sz = DS_LOAD_2ADDR[op]
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
cnt = sz // 4 # 1 for B32, 2 for B64
for i in range(cnt): V[vdst + i] = int.from_bytes(lds[addr0+i*4:addr0+i*4+4], 'little')
for i in range(cnt): V[vdst + cnt + i] = int.from_bytes(lds[addr1+i*4:addr1+i*4+4], 'little')
elif op in DS_STORE_2ADDR:
# Store two values from data0 and data1 to addr+offset0*sz and addr+offset1*sz
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
sz = DS_STORE_2ADDR[op]
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
cnt = sz // 4
for i in range(cnt): lds[addr0+i*4:addr0+i*4+4] = (V[inst.data0 + i] & MASK32).to_bytes(4, 'little')
for i in range(cnt): lds[addr1+i*4:addr1+i*4+4] = (V[inst.data1 + i] & MASK32).to_bytes(4, 'little')
else: raise NotImplementedError(f"DS op {op}")
fn = compiled.get(DSOp, {}).get(inst.op)
if fn is None: raise NotImplementedError(f"DS op {inst.op.name} not in pseudocode")
# Prepare data registers as lists of dwords
data0 = [V[inst.data0 + i] for i in range(4)] # up to 4 dwords
data1 = [V[inst.data1 + i] for i in range(4)] if inst.data1 else [0, 0, 0, 0]
result = fn(lds, V[inst.addr], data0, data1, inst.vdst, inst.offset0, inst.offset1)
# Write results for loads
if 'vdst' in result:
for i, val in enumerate(result['vdst']): V[inst.vdst + i] = val & MASK32
return
# VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)

View File

@@ -36,7 +36,7 @@ FIELD_ORDER = {
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
# Patterns that can't be handled by the DSL (require special handling in emu.py)
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
@@ -381,7 +381,7 @@ def _extract_pseudocode(text: str) -> str | None:
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
is_code = (any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =',
'D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
'D0[', 'D1[', 'S0[', 'S1[', 'S2[', 'MEM[', 'RETURN_DATA', 'DATA.', 'DATA0', 'DATA1', 'ADDR']) or
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s))
if is_code: result.append(s)
@@ -467,13 +467,13 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
# Get op enums for this arch (import from .ins which re-exports from .enum)
import importlib
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp'] if hasattr(autogen, name)]
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp'] if hasattr(autogen, name)]
# Build defined ops mapping
defined_ops: dict[tuple, list] = {}
for enum_cls in OP_ENUMS:
for op in enum_cls:
if op.name.startswith(('S_', 'V_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
if op.name.startswith(('S_', 'V_', 'DS_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
enum_names = [e.__name__ for e in OP_ENUMS]
lines = [f'''# autogenerated by pdf.py - do not edit