assembly/amd: fix v_perm_b32 + PC fixes (#13897)

* assembly/amd: fix v_perm_b32

* add pc support
This commit is contained in:
George Hotz
2025-12-30 09:25:40 -05:00
committed by GitHub
parent 2b838dc1d8
commit 9c89be5235
6 changed files with 6134 additions and 3000 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -205,21 +205,11 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
compiled = _get_compiled()
inst_type = type(inst)
# SOPP: control flow (not ALU)
# SOPP: special cases for control flow that has no pseudocode
if inst_type is SOPP:
op = inst.op
if op == SOPPOp.S_ENDPGM: return -1
if op == SOPPOp.S_BARRIER: return -2
if op == SOPPOp.S_BRANCH: return _sext(inst.simm16, 16)
if op == SOPPOp.S_CBRANCH_SCC0: return _sext(inst.simm16, 16) if st.scc == 0 else 0
if op == SOPPOp.S_CBRANCH_SCC1: return _sext(inst.simm16, 16) if st.scc == 1 else 0
if op == SOPPOp.S_CBRANCH_VCCZ: return _sext(inst.simm16, 16) if (st.vcc & 0xffffffff) == 0 else 0
if op == SOPPOp.S_CBRANCH_VCCNZ: return _sext(inst.simm16, 16) if (st.vcc & 0xffffffff) != 0 else 0
if op == SOPPOp.S_CBRANCH_EXECZ: return _sext(inst.simm16, 16) if st.exec_mask == 0 else 0
if op == SOPPOp.S_CBRANCH_EXECNZ: return _sext(inst.simm16, 16) if st.exec_mask != 0 else 0
# Valid SOPP range is 0-61 (max defined opcode); anything above is invalid
if op > 61: raise NotImplementedError(f"Invalid SOPP opcode {op}")
return 0 # waits, hints, nops
# SMEM: memory loads (not ALU)
if inst_type is SMEM:
@@ -229,46 +219,39 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
for i in range(cnt): st.wsgpr(inst.sdata + i, mem_read((addr + i * 4) & 0xffffffffffffffff, 4))
return 0
# SOP1: special handling for ops not in pseudocode
if inst_type is SOP1:
op = SOP1Op(inst.op)
# S_GETPC_B64: Get program counter (PC is stored as byte offset, convert from words)
if op == SOP1Op.S_GETPC_B64:
pc_bytes = st.pc * 4 # PC is in words, convert to bytes
st.wsgpr64(inst.sdst, pc_bytes)
return 0
# S_SETPC_B64: Set program counter to source value (indirect jump)
# Returns delta such that st.pc + inst_words + delta = target_words
if op == SOP1Op.S_SETPC_B64:
target_bytes = st.rsrc64(inst.ssrc0, 0)
target_words = target_bytes // 4
inst_words = 1 # SOP1 is always 1 word
return target_words - st.pc - inst_words
# Get op enum and lookup compiled function
if inst_type is SOP1: op_cls, ssrc0, sdst = SOP1Op, inst.ssrc0, inst.sdst
elif inst_type is SOP2: op_cls, ssrc0, sdst = SOP2Op, inst.ssrc0, inst.sdst
elif inst_type is SOPC: op_cls, ssrc0, sdst = SOPCOp, inst.ssrc0, None
elif inst_type is SOPK: op_cls, ssrc0, sdst = SOPKOp, inst.sdst, inst.sdst # sdst is both src and dst
elif inst_type is SOPP: op_cls, ssrc0, sdst = SOPPOp, None, None
else: raise NotImplementedError(f"Unknown scalar type {inst_type}")
op = op_cls(inst.op)
# SOPP has gaps in the opcode enum - treat unknown opcodes as no-ops
try: op = op_cls(inst.op)
except ValueError:
if inst_type is SOPP: return 0
raise
fn = compiled.get(op_cls, {}).get(op)
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
if fn is None:
# SOPP instructions without pseudocode (waits, hints, nops) are no-ops
if inst_type is SOPP: return 0
raise NotImplementedError(f"{op.name} not in pseudocode")
# Build context - handle 64-bit ops that need 64-bit source reads
# 64-bit source ops: name ends with _B64, _I64, _U64 or contains _U64, _I64 before last underscore
is_64bit_s0 = op.name.endswith(('_B64', '_I64', '_U64')) or '_U64_' in op.name or '_I64_' in op.name
is_64bit_s0s1 = op_cls is SOPCOp and op in (SOPCOp.S_CMP_EQ_U64, SOPCOp.S_CMP_LG_U64)
s0 = st.rsrc64(ssrc0, 0) if is_64bit_s0 or is_64bit_s0s1 else (st.rsrc(ssrc0, 0) if inst_type != SOPK else st.rsgpr(inst.sdst))
s0 = st.rsrc64(ssrc0, 0) if is_64bit_s0 or is_64bit_s0s1 else (st.rsrc(ssrc0, 0) if inst_type not in (SOPK, SOPP) else (st.rsgpr(inst.sdst) if inst_type is SOPK else 0))
is_64bit_sop2 = is_64bit_s0 and inst_type is SOP2
s1 = st.rsrc64(inst.ssrc1, 0) if (is_64bit_sop2 or is_64bit_s0s1) else (st.rsrc(inst.ssrc1, 0) if inst_type in (SOP2, SOPC) else inst.simm16 if inst_type is SOPK else 0)
d0 = st.rsgpr64(sdst) if (is_64bit_s0 or is_64bit_s0s1) and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
exec_mask = st.exec_mask
literal = inst.simm16 if inst_type is SOPK else st.literal
literal = inst.simm16 if inst_type in (SOPK, SOPP) else st.literal
# Execute compiled function
result = fn(s0, s1, 0, d0, st.scc, st.vcc, 0, exec_mask, literal, None, {})
# Execute compiled function - pass PC in bytes for instructions that need it
pc_bytes = st.pc * 4
result = fn(s0, s1, 0, d0, st.scc, st.vcc, 0, exec_mask, literal, None, {}, pc=pc_bytes)
# Apply results
if sdst is not None:
@@ -278,7 +261,11 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
st.wsgpr(sdst, result['d0'])
if 'scc' in result: st.scc = result['scc']
if 'exec' in result: st.exec_mask = result['exec']
if 'pc_delta' in result: return result['pc_delta']
if 'new_pc' in result:
# Convert absolute byte address to word delta
# new_pc is where we want to go, st.pc is current position, inst._words will be added after
new_pc_words = result['new_pc'] // 4
return new_pc_words - st.pc - 1 # -1 because emulator adds inst_words (1 for scalar)
return 0
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None) -> None:
@@ -402,20 +389,6 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
op_cls, op, src0, src1, src2, vdst = VOPCOp, VOPCOp(inst.op), inst.src0, inst.src1, None, inst.vdst
else:
op_cls, op, src0, src1, src2, vdst = VOP3Op, VOP3Op(inst.op), inst.src0, inst.src1, inst.src2, inst.vdst
# V_PERM_B32: byte permutation - not in pseudocode PDF, implement directly
# D0[byte_i] = selector[byte_i] < 8 ? {src0, src1}[selector[byte_i]] : (selector[byte_i] >= 0xD ? 0xFF : 0x00)
if op == VOP3Op.V_PERM_B32:
s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane)
# Combine src1 and src0 into 8-byte value: src1 is bytes 0-3, src0 is bytes 4-7
combined = (s1 & 0xffffffff) | ((s0 & 0xffffffff) << 32)
result = 0
for i in range(4): # 4 result bytes
sel = (s2 >> (i * 8)) & 0xff # byte selector for this position
if sel <= 7: result |= (((combined >> (sel * 8)) & 0xff) << (i * 8)) # select byte from combined
elif sel >= 0xd: result |= (0xff << (i * 8)) # 0xD-0xF: constant 0xFF
# else 0x8-0xC: constant 0x00 (already 0)
V[vdst] = result & 0xffffffff
return
elif inst_type is VOPC:
op = VOPCOp(inst.op)
# For 16-bit VOPC, vsrc1 uses same encoding as VOP2 16-bit: bit 7 selects hi(1) or lo(0) half

View File

@@ -341,7 +341,7 @@ def F(x):
if isinstance(x, int): return _f32(x) # int -> interpret as f32 bits
if isinstance(x, TypedView): return x # preserve TypedView for bit-pattern checks
return float(x) # already a float or float-like
signext = lambda x: x
signext = lambda x: int(x) # sign-extend to full width - already handled by Python's arbitrary precision ints
pack = lambda hi, lo: ((int(hi) & 0xffff) << 16) | (int(lo) & 0xffff)
pack32 = lambda hi, lo: ((int(hi) & 0xffffffff) << 32) | (int(lo) & 0xffffffff)
_pack, _pack32 = pack, pack32 # Aliases for internal use
@@ -519,6 +519,17 @@ class TypedView:
def __bool__(s): return bool(int(s))
# Allow chained type access like jump_addr.i64 when jump_addr is already a TypedView
# These just return self or convert appropriately
@property
def i64(s): return s if s._bits == 64 and s._signed else int(s)
@property
def u64(s): return s if s._bits == 64 and not s._signed else int(s) & MASK64
@property
def i32(s): return s if s._bits == 32 and s._signed else _sext(int(s) & MASK32, 32)
@property
def u32(s): return s if s._bits == 32 and not s._signed else int(s) & MASK32
class Reg:
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works."""
__slots__ = ('_val',)
@@ -542,6 +553,7 @@ class Reg:
bf16 = property(lambda s: TypedView(s, 16, is_float=True, is_bf16=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _ibf16(float(v))) & 0xffff)))
u8 = property(lambda s: TypedView(s, 8))
i8 = property(lambda s: TypedView(s, 8, signed=True))
u1 = property(lambda s: TypedView(s, 1)) # single bit
def __getitem__(s, key):
if isinstance(key, slice): return SliceProxy(s, int(key.start), int(key.stop))
@@ -664,7 +676,7 @@ def compile_pseudocode(pseudocode: str) -> str:
def _assign(lhs: str, rhs: str) -> str:
"""Generate assignment. Bare tmp/SCC/etc get wrapped in Reg()."""
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec'):
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
return f"{lhs} = Reg({rhs})"
return f"{lhs} = {rhs}"
@@ -801,14 +813,14 @@ INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
# Patterns that can't be handled by the DSL (require special handling in emu.py)
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
'PC =', 'PC=', 'PC+', '= PC', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', 'ThreadMask',
'S1[i', 'C.i32', 'S[i]', 'in[', '2.0 / PI',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST'] # Malformed pseudocode from PDF
def extract_pseudocode(text: str) -> str | None:
"""Extract pseudocode from an instruction description snippet."""
lines, result, depth = text.split('\n'), [], 0
lines, result, depth, in_lambda = text.split('\n'), [], 0, 0
for line in lines:
s = line.strip()
if not s: continue
@@ -817,12 +829,17 @@ def extract_pseudocode(text: str) -> str | None:
# Skip document headers (RDNA or CDNA)
if s.startswith('"RDNA') or s.startswith('AMD ') or s.startswith('CDNA'): continue
if s.startswith('Notes') or s.startswith('Functional examples'): break
# Track lambda definitions (e.g., BYTE_PERMUTE = lambda(data, sel) (...))
if '= lambda(' in s: in_lambda += 1; continue
if in_lambda > 0:
if s.endswith(');'): in_lambda -= 1
continue
if s.startswith('if '): depth += 1
elif s.startswith('endif'): depth = max(0, depth - 1)
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
is_code = (
any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =']) or
any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =']) or
any(p in s for p in ['D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s)
@@ -1043,10 +1060,12 @@ from extra.assembly.amd.pcode import *
is_div_scale = 'DIV_SCALE' in op.name
# VOP3SD instructions that write VCC per-lane (either via VCC.u64[laneId] or by setting VCC = 0/1)
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
# Instructions that use/modify PC
has_pc = 'PC' in pc
# Generate function with indented body
fn_name = f"_{cls_name}_{op.name}"
lines.append(f"def {fn_name}(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):")
lines.append(f"def {fn_name}(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0):")
# Add original pseudocode as comment
for pc_line in pc.split('\n'):
lines.append(f" # {pc_line}")
@@ -1057,14 +1076,21 @@ from extra.assembly.amd.pcode import *
('SCC', 'Reg(scc)'), ('VCC', 'Reg(vcc)'), ('EXEC', 'Reg(exec_mask)'),
('tmp', 'Reg(0)'), ('saveexec', 'Reg(exec_mask)'), ('laneId', 'lane'),
('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)'),
('PC', 'Reg(pc)')] # PC is passed in as byte address
used = {name for name, _ in regs if name in combined}
# EXEC_LO/EXEC_HI need EXEC
if 'EXEC_LO' in combined or 'EXEC_HI' in combined: used.add('EXEC')
# VCCZ/EXECZ need VCC/EXEC
if 'VCCZ' in combined: used.add('VCC')
if 'EXECZ' in combined: used.add('EXEC')
for name, init in regs:
if name in used: lines.append(f" {name} = {init}")
if 'EXEC_LO' in combined: lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
if 'EXEC_HI' in combined: lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
# VCCZ = 1 if VCC == 0, EXECZ = 1 if EXEC == 0
if 'VCCZ' in combined: lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
if 'EXECZ' in combined: lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
# Add compiled pseudocode with markers
lines.append(" # --- compiled pseudocode ---")
for line in code.split('\n'):
@@ -1088,6 +1114,11 @@ from extra.assembly.amd.pcode import *
lines.append(" result['d0_64'] = True")
if has_d1:
lines.append(" result['d1'] = D1._val & 1")
if has_pc:
# Return new PC as absolute byte address, emulator will compute delta
# Handle negative values (backward jumps): PC._val is stored as unsigned, convert to signed
lines.append(" _pc = PC._val if PC._val < 0x8000000000000000 else PC._val - 0x10000000000000000")
lines.append(" result['new_pc'] = _pc # absolute byte address")
lines.append(" return result")
lines.append("")

View File

@@ -2590,6 +2590,30 @@ class TestNewPcodeHelpers(unittest.TestCase):
# byte 3: sel=0x0C = 12 -> 0x00
self.assertEqual(result, 0x00FFFFFF, f"Expected 0x00FFFFFF, got 0x{result:08x}")
def test_v_perm_b32_sign_extend(self):
"""V_PERM_B32: Test sign extension selectors 8-11."""
# Combined = {S0, S1} where S1 is bytes 0-3, S0 is bytes 4-7
# s0 = 0x00008000 -> byte 5 (0x80) has sign bit set
# s1 = 0x80000080 -> bytes 1 (0x00) and 3 (0x80) have sign bits, byte 0 (0x80) has sign bit
# Combined = 0x00008000_80000080
# selector = 0x08090A0B -> sign of bytes 1,3,5,7
# byte 0: sel=0x0B -> sign of byte 7 (0x00) -> 0x00
# byte 1: sel=0x0A -> sign of byte 5 (0x80) -> 0xFF
# byte 2: sel=0x09 -> sign of byte 3 (0x80) -> 0xFF
# byte 3: sel=0x08 -> sign of byte 1 (0x00) -> 0x00
instructions = [
s_mov_b32(s[0], 0x00008000),
s_mov_b32(s[1], 0x80000080),
s_mov_b32(s[2], 0x08090A0B),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
v_perm_b32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][3]
self.assertEqual(result, 0x00FFFF00, f"Expected 0x00FFFF00, got 0x{result:08x}")
def test_v_dot2_f32_bf16_basic(self):
"""V_DOT2_F32_BF16: Dot product of two bf16 pairs accumulated into f32."""
from extra.assembly.amd.pcode import _ibf16