Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-07 22:23:55 -05:00)
assembly/amd: fix v_perm_b32 + PC fixes (#13897)
* assembly/amd: fix v_perm_b32
* add pc support
(Three file diffs suppressed because they are too large; only the expanded hunks below are shown.)
@@ -205,21 +205,11 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
   compiled = _get_compiled()
   inst_type = type(inst)

-  # SOPP: control flow (not ALU)
+  # SOPP: special cases for control flow that has no pseudocode
   if inst_type is SOPP:
     op = inst.op
     if op == SOPPOp.S_ENDPGM: return -1
     if op == SOPPOp.S_BARRIER: return -2
-    if op == SOPPOp.S_BRANCH: return _sext(inst.simm16, 16)
-    if op == SOPPOp.S_CBRANCH_SCC0: return _sext(inst.simm16, 16) if st.scc == 0 else 0
-    if op == SOPPOp.S_CBRANCH_SCC1: return _sext(inst.simm16, 16) if st.scc == 1 else 0
-    if op == SOPPOp.S_CBRANCH_VCCZ: return _sext(inst.simm16, 16) if (st.vcc & 0xffffffff) == 0 else 0
-    if op == SOPPOp.S_CBRANCH_VCCNZ: return _sext(inst.simm16, 16) if (st.vcc & 0xffffffff) != 0 else 0
-    if op == SOPPOp.S_CBRANCH_EXECZ: return _sext(inst.simm16, 16) if st.exec_mask == 0 else 0
-    if op == SOPPOp.S_CBRANCH_EXECNZ: return _sext(inst.simm16, 16) if st.exec_mask != 0 else 0
-    # Valid SOPP range is 0-61 (max defined opcode); anything above is invalid
-    if op > 61: raise NotImplementedError(f"Invalid SOPP opcode {op}")
     return 0  # waits, hints, nops

   # SMEM: memory loads (not ALU)
   if inst_type is SMEM:
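The branch cases removed above returned _sext(inst.simm16, 16), a signed word delta relative to the next instruction; with PC support in the compiled pseudocode (see the hunks below) they no longer need special-casing. For reference, a minimal sign-extension helper consistent with how _sext is used here (a sketch, not the emulator's actual definition):

def _sext(x: int, bits: int) -> int:
  """Interpret the low `bits` bits of x as a two's-complement signed integer."""
  x &= (1 << bits) - 1
  return x - (1 << bits) if x & (1 << (bits - 1)) else x

assert _sext(0xFFFD, 16) == -3  # branch three words backward
assert _sext(0x0002, 16) == 2   # skip two words forward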
@@ -229,46 +219,39 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
     for i in range(cnt): st.wsgpr(inst.sdata + i, mem_read((addr + i * 4) & 0xffffffffffffffff, 4))
     return 0

-  # SOP1: special handling for ops not in pseudocode
-  if inst_type is SOP1:
-    op = SOP1Op(inst.op)
-    # S_GETPC_B64: Get program counter (PC is stored as byte offset, convert from words)
-    if op == SOP1Op.S_GETPC_B64:
-      pc_bytes = st.pc * 4  # PC is in words, convert to bytes
-      st.wsgpr64(inst.sdst, pc_bytes)
-      return 0
-    # S_SETPC_B64: Set program counter to source value (indirect jump)
-    # Returns delta such that st.pc + inst_words + delta = target_words
-    if op == SOP1Op.S_SETPC_B64:
-      target_bytes = st.rsrc64(inst.ssrc0, 0)
-      target_words = target_bytes // 4
-      inst_words = 1  # SOP1 is always 1 word
-      return target_words - st.pc - inst_words

   # Get op enum and lookup compiled function
   if inst_type is SOP1: op_cls, ssrc0, sdst = SOP1Op, inst.ssrc0, inst.sdst
   elif inst_type is SOP2: op_cls, ssrc0, sdst = SOP2Op, inst.ssrc0, inst.sdst
   elif inst_type is SOPC: op_cls, ssrc0, sdst = SOPCOp, inst.ssrc0, None
   elif inst_type is SOPK: op_cls, ssrc0, sdst = SOPKOp, inst.sdst, inst.sdst  # sdst is both src and dst
   elif inst_type is SOPP: op_cls, ssrc0, sdst = SOPPOp, None, None
   else: raise NotImplementedError(f"Unknown scalar type {inst_type}")

-  op = op_cls(inst.op)
+  # SOPP has gaps in the opcode enum - treat unknown opcodes as no-ops
+  try: op = op_cls(inst.op)
+  except ValueError:
+    if inst_type is SOPP: return 0
+    raise
   fn = compiled.get(op_cls, {}).get(op)
-  if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
+  if fn is None:
+    # SOPP instructions without pseudocode (waits, hints, nops) are no-ops
+    if inst_type is SOPP: return 0
+    raise NotImplementedError(f"{op.name} not in pseudocode")

   # Build context - handle 64-bit ops that need 64-bit source reads
   # 64-bit source ops: name ends with _B64, _I64, _U64 or contains _U64, _I64 before last underscore
   is_64bit_s0 = op.name.endswith(('_B64', '_I64', '_U64')) or '_U64_' in op.name or '_I64_' in op.name
   is_64bit_s0s1 = op_cls is SOPCOp and op in (SOPCOp.S_CMP_EQ_U64, SOPCOp.S_CMP_LG_U64)
-  s0 = st.rsrc64(ssrc0, 0) if is_64bit_s0 or is_64bit_s0s1 else (st.rsrc(ssrc0, 0) if inst_type != SOPK else st.rsgpr(inst.sdst))
+  s0 = st.rsrc64(ssrc0, 0) if is_64bit_s0 or is_64bit_s0s1 else (st.rsrc(ssrc0, 0) if inst_type not in (SOPK, SOPP) else (st.rsgpr(inst.sdst) if inst_type is SOPK else 0))
   is_64bit_sop2 = is_64bit_s0 and inst_type is SOP2
   s1 = st.rsrc64(inst.ssrc1, 0) if (is_64bit_sop2 or is_64bit_s0s1) else (st.rsrc(inst.ssrc1, 0) if inst_type in (SOP2, SOPC) else inst.simm16 if inst_type is SOPK else 0)
   d0 = st.rsgpr64(sdst) if (is_64bit_s0 or is_64bit_s0s1) and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
   exec_mask = st.exec_mask
-  literal = inst.simm16 if inst_type is SOPK else st.literal
+  literal = inst.simm16 if inst_type in (SOPK, SOPP) else st.literal

-  # Execute compiled function
-  result = fn(s0, s1, 0, d0, st.scc, st.vcc, 0, exec_mask, literal, None, {})
+  # Execute compiled function - pass PC in bytes for instructions that need it
+  pc_bytes = st.pc * 4
+  result = fn(s0, s1, 0, d0, st.scc, st.vcc, 0, exec_mask, literal, None, {}, pc=pc_bytes)

   # Apply results
   if sdst is not None:
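The try/except added here relies on IntEnum raising ValueError for values that fall into a gap of the enum; SOPP's opcode space is sparse, so undefined opcodes become no-ops instead of crashes. A self-contained demonstration of the pattern (DemoOp is an invented stand-in, not the real SOPPOp):

from enum import IntEnum

class DemoOp(IntEnum):  # invented stand-in for SOPPOp, whose opcode space has gaps
  S_NOP = 0
  S_ENDPGM = 48

def decode(raw: int) -> DemoOp | None:
  try: return DemoOp(raw)
  except ValueError: return None  # value falls in a gap: caller treats it as a no-op

assert decode(48) is DemoOp.S_ENDPGM
assert decode(7) is None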
@@ -278,7 +261,11 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
       st.wsgpr(sdst, result['d0'])
   if 'scc' in result: st.scc = result['scc']
   if 'exec' in result: st.exec_mask = result['exec']
-  if 'pc_delta' in result: return result['pc_delta']
+  if 'new_pc' in result:
+    # Convert absolute byte address to word delta
+    # new_pc is where we want to go, st.pc is current position, inst._words will be added after
+    new_pc_words = result['new_pc'] // 4
+    return new_pc_words - st.pc - 1  # -1 because emulator adds inst_words (1 for scalar)
   return 0
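To see that the new new_pc protocol reproduces the old word-delta behavior, here is a worked example under the conventions the comments state (st.pc counts 32-bit words, the compiled function receives PC in bytes, and S_BRANCH's ISA pseudocode is assumed to compute PC = PC + signext(SIMM16 * 4) + 4):

st_pc = 100                 # emulator PC at the branch, in words
pc_bytes = st_pc * 4        # what exec_scalar now passes to the compiled fn
simm16 = -3                 # branch offset in words, relative to the next instruction

new_pc = pc_bytes + simm16 * 4 + 4       # pseudocode result: absolute byte address
delta = new_pc // 4 - st_pc - 1          # the conversion done above

assert delta == simm16                   # same delta the old SOPP special case returned
assert st_pc + 1 + delta == new_pc // 4  # after inst_words is added, PC lands on the target word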
@@ -402,20 +389,6 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None) -> None:
       op_cls, op, src0, src1, src2, vdst = VOPCOp, VOPCOp(inst.op), inst.src0, inst.src1, None, inst.vdst
     else:
       op_cls, op, src0, src1, src2, vdst = VOP3Op, VOP3Op(inst.op), inst.src0, inst.src1, inst.src2, inst.vdst
-      # V_PERM_B32: byte permutation - not in pseudocode PDF, implement directly
-      # D0[byte_i] = selector[byte_i] < 8 ? {src0, src1}[selector[byte_i]] : (selector[byte_i] >= 0xD ? 0xFF : 0x00)
-      if op == VOP3Op.V_PERM_B32:
-        s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane)
-        # Combine src1 and src0 into 8-byte value: src1 is bytes 0-3, src0 is bytes 4-7
-        combined = (s1 & 0xffffffff) | ((s0 & 0xffffffff) << 32)
-        result = 0
-        for i in range(4):  # 4 result bytes
-          sel = (s2 >> (i * 8)) & 0xff  # byte selector for this position
-          if sel <= 7: result |= (((combined >> (sel * 8)) & 0xff) << (i * 8))  # select byte from combined
-          elif sel >= 0xd: result |= (0xff << (i * 8))  # 0xD-0xF: constant 0xFF
-          # else 0x8-0xC: constant 0x00 (already 0)
-        V[vdst] = result & 0xffffffff
-        return
   elif inst_type is VOPC:
     op = VOPCOp(inst.op)
     # For 16-bit VOPC, vsrc1 uses same encoding as VOP2 16-bit: bit 7 selects hi(1) or lo(0) half
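The block deleted above is the buggy inline V_PERM_B32 this commit fixes: it mapped selectors 0x8-0xC to constant 0x00 and never sign-extended, while selectors 8-11 actually produce the sign (0x00 or 0xFF) of bytes 1, 3, 5, 7 of the combined value, 12 produces 0x00, and 13-15 produce 0xFF. A reference sketch of the corrected semantics, consistent with the test_v_perm_b32_sign_extend test added at the bottom of this diff (the shipped fix itself lives in the suppressed file diffs, so this is an approximation, not the committed code):

def v_perm_b32_ref(s0: int, s1: int, s2: int) -> int:
  # Byte pool {S0, S1}: S1 supplies bytes 0-3, S0 supplies bytes 4-7
  combined = (s1 & 0xffffffff) | ((s0 & 0xffffffff) << 32)
  out = 0
  for i in range(4):  # one selector byte per result byte
    sel = (s2 >> (i * 8)) & 0xff
    if sel <= 7:    b = (combined >> (sel * 8)) & 0xff  # plain byte select
    elif sel <= 11: b = 0xff if combined & (1 << ((sel - 8) * 16 + 15)) else 0x00  # sign of byte 1/3/5/7
    elif sel == 12: b = 0x00
    else:           b = 0xff  # selectors 13-15
    out |= b << (i * 8)
  return out

# Values from the new test: combined = 0x00008000_80000080, selectors pick the signs of bytes 1,3,5,7
assert v_perm_b32_ref(0x00008000, 0x80000080, 0x08090A0B) == 0x00FFFF00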
@@ -341,7 +341,7 @@ def F(x):
   if isinstance(x, int): return _f32(x)  # int -> interpret as f32 bits
   if isinstance(x, TypedView): return x  # preserve TypedView for bit-pattern checks
   return float(x)  # already a float or float-like
-signext = lambda x: x
+signext = lambda x: int(x)  # sign-extend to full width - already handled by Python's arbitrary precision ints
 pack = lambda hi, lo: ((int(hi) & 0xffff) << 16) | (int(lo) & 0xffff)
 pack32 = lambda hi, lo: ((int(hi) & 0xffffffff) << 32) | (int(lo) & 0xffffffff)
 _pack, _pack32 = pack, pack32  # Aliases for internal use
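The signext change above matters because the pseudocode can hand it a register view rather than a plain int; int(x) collapses the view, and since Python ints are arbitrary-precision and already signed, no further widening is needed. A toy illustration (FakeView is hypothetical, standing in for a TypedView-like operand):

class FakeView:  # hypothetical stand-in for a TypedView-like operand
  def __init__(self, v): self._v = v
  def __int__(self): return self._v

signext = lambda x: int(x)
assert signext(FakeView(-12)) == -12  # already sign-correct once converted; no masking needed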
@@ -519,6 +519,17 @@ class TypedView:

   def __bool__(s): return bool(int(s))

+  # Allow chained type access like jump_addr.i64 when jump_addr is already a TypedView
+  # These just return self or convert appropriately
+  @property
+  def i64(s): return s if s._bits == 64 and s._signed else int(s)
+  @property
+  def u64(s): return s if s._bits == 64 and not s._signed else int(s) & MASK64
+  @property
+  def i32(s): return s if s._bits == 32 and s._signed else _sext(int(s) & MASK32, 32)
+  @property
+  def u32(s): return s if s._bits == 32 and not s._signed else int(s) & MASK32
+
 class Reg:
   """GPU register: D0.f32 = S0.f32 + S1.f32 just works."""
   __slots__ = ('_val',)
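These properties exist so that compiled pseudocode can chain typed accesses (e.g. jump_addr.i64 where jump_addr already evaluated to a TypedView): return self when the view already has the requested width and signedness, otherwise collapse to a plain int with the right masking or sign extension. A minimal stand-in showing just that rule (View here is a sketch, not the real TypedView):

MASK32 = (1 << 32) - 1

def _sext(x: int, bits: int) -> int:
  x &= (1 << bits) - 1
  return x - (1 << bits) if x & (1 << (bits - 1)) else x

class View:  # sketch: only the chaining rule, none of TypedView's other behavior
  def __init__(s, val, bits, signed=False): s._val, s._bits, s._signed = val, bits, signed
  def __int__(s): return s._val
  @property
  def i32(s): return s if s._bits == 32 and s._signed else _sext(int(s) & MASK32, 32)

assert View(0xFFFFFFFF, 64).i32 == -1  # wider view: narrowed and sign-extended to an int
v = View(-1, 32, signed=True)
assert v.i32 is v                      # already an i32 view: chained access returns itself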
@@ -542,6 +553,7 @@ class Reg:
   bf16 = property(lambda s: TypedView(s, 16, is_float=True, is_bf16=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _ibf16(float(v))) & 0xffff)))
   u8 = property(lambda s: TypedView(s, 8))
   i8 = property(lambda s: TypedView(s, 8, signed=True))
+  u1 = property(lambda s: TypedView(s, 1))  # single bit

   def __getitem__(s, key):
     if isinstance(key, slice): return SliceProxy(s, int(key.start), int(key.stop))
@@ -664,7 +676,7 @@ def compile_pseudocode(pseudocode: str) -> str:

 def _assign(lhs: str, rhs: str) -> str:
   """Generate assignment. Bare tmp/SCC/etc get wrapped in Reg()."""
-  if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec'):
+  if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
     return f"{lhs} = Reg({rhs})"
   return f"{lhs} = {rhs}"
@@ -801,14 +813,14 @@ INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)

 # Patterns that can't be handled by the DSL (require special handling in emu.py)
 UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
-               'PC =', 'PC=', 'PC+', '= PC', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
+               'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
                'CVT_OFF_TABLE', 'ThreadMask',
                'S1[i', 'C.i32', 'S[i]', 'in[', '2.0 / PI',
                'if n.', 'DST.u32', 'addrd = DST', 'addr = DST']  # Malformed pseudocode from PDF

 def extract_pseudocode(text: str) -> str | None:
   """Extract pseudocode from an instruction description snippet."""
-  lines, result, depth = text.split('\n'), [], 0
+  lines, result, depth, in_lambda = text.split('\n'), [], 0, 0
   for line in lines:
     s = line.strip()
     if not s: continue
@@ -817,12 +829,17 @@ def extract_pseudocode(text: str) -> str | None:
     # Skip document headers (RDNA or CDNA)
     if s.startswith('"RDNA') or s.startswith('AMD ') or s.startswith('CDNA'): continue
     if s.startswith('Notes') or s.startswith('Functional examples'): break
+    # Track lambda definitions (e.g., BYTE_PERMUTE = lambda(data, sel) (...))
+    if '= lambda(' in s: in_lambda += 1; continue
+    if in_lambda > 0:
+      if s.endswith(');'): in_lambda -= 1
+      continue
     if s.startswith('if '): depth += 1
     elif s.startswith('endif'): depth = max(0, depth - 1)
     if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
     if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
     is_code = (
-      any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =']) or
+      any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =']) or
       any(p in s for p in ['D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
       s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
       re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s)
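The new in_lambda counter lets the extractor skip multi-line helper definitions such as the BYTE_PERMUTE lambda that precedes V_PERM_B32's pseudocode in the ISA PDF. A self-contained trace of the skip logic (the snippet text is an invented approximation of the PDF, not a quote):

snippet = """BYTE_PERMUTE = lambda(data, sel) (
  ... byte-select and sign-extension cases ...
);
D0.u32 = BYTE_PERMUTE(tmp, S2.u32)"""

in_lambda, kept = 0, []
for line in snippet.split('\n'):
  s = line.strip()
  if '= lambda(' in s: in_lambda += 1; continue  # enter the lambda body
  if in_lambda > 0:
    if s.endswith(');'): in_lambda -= 1  # closing line of the lambda
    continue  # body lines are dropped
  kept.append(s)

assert kept == ['D0.u32 = BYTE_PERMUTE(tmp, S2.u32)']  # only real pseudocode survives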
@@ -1043,10 +1060,12 @@ from extra.assembly.amd.pcode import *
   is_div_scale = 'DIV_SCALE' in op.name
   # VOP3SD instructions that write VCC per-lane (either via VCC.u64[laneId] or by setting VCC = 0/1)
   has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
+  # Instructions that use/modify PC
+  has_pc = 'PC' in pc

   # Generate function with indented body
   fn_name = f"_{cls_name}_{op.name}"
-  lines.append(f"def {fn_name}(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):")
+  lines.append(f"def {fn_name}(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0):")
   # Add original pseudocode as comment
   for pc_line in pc.split('\n'):
     lines.append(f" # {pc_line}")
@@ -1057,14 +1076,21 @@ from extra.assembly.amd.pcode import *
           ('SCC', 'Reg(scc)'), ('VCC', 'Reg(vcc)'), ('EXEC', 'Reg(exec_mask)'),
           ('tmp', 'Reg(0)'), ('saveexec', 'Reg(exec_mask)'), ('laneId', 'lane'),
           ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
-          ('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
+          ('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)'),
+          ('PC', 'Reg(pc)')]  # PC is passed in as byte address
   used = {name for name, _ in regs if name in combined}
   # EXEC_LO/EXEC_HI need EXEC
   if 'EXEC_LO' in combined or 'EXEC_HI' in combined: used.add('EXEC')
+  # VCCZ/EXECZ need VCC/EXEC
+  if 'VCCZ' in combined: used.add('VCC')
+  if 'EXECZ' in combined: used.add('EXEC')
   for name, init in regs:
     if name in used: lines.append(f" {name} = {init}")
   if 'EXEC_LO' in combined: lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
   if 'EXEC_HI' in combined: lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
+  # VCCZ = 1 if VCC == 0, EXECZ = 1 if EXEC == 0
+  if 'VCCZ' in combined: lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
+  if 'EXECZ' in combined: lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
   # Add compiled pseudocode with markers
   lines.append(" # --- compiled pseudocode ---")
   for line in code.split('\n'):
@@ -1088,6 +1114,11 @@ from extra.assembly.amd.pcode import *
     lines.append(" result['d0_64'] = True")
   if has_d1:
     lines.append(" result['d1'] = D1._val & 1")
+  if has_pc:
+    # Return new PC as absolute byte address, emulator will compute delta
+    # Handle negative values (backward jumps): PC._val is stored as unsigned, convert to signed
+    lines.append(" _pc = PC._val if PC._val < 0x8000000000000000 else PC._val - 0x10000000000000000")
+    lines.append(" result['new_pc'] = _pc  # absolute byte address")
   lines.append(" return result")
   lines.append("")
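The generated epilogue's fixup is ordinary 64-bit two's-complement decoding: as the comment says, PC._val is stored as unsigned, so a backward jump that wraps past zero must be reinterpreted as negative before the emulator computes a delta. Equivalent standalone form:

def to_signed64(v: int) -> int:
  return v if v < 0x8000000000000000 else v - 0x10000000000000000

assert to_signed64(392) == 392                          # ordinary forward address unchanged
assert to_signed64((0 - 8) & 0xffffffffffffffff) == -8  # wrapped backward jump decoded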
@@ -2590,6 +2590,30 @@ class TestNewPcodeHelpers(unittest.TestCase):
     # byte 3: sel=0x0C = 12 -> 0x00
     self.assertEqual(result, 0x00FFFFFF, f"Expected 0x00FFFFFF, got 0x{result:08x}")

+  def test_v_perm_b32_sign_extend(self):
+    """V_PERM_B32: Test sign extension selectors 8-11."""
+    # Combined = {S0, S1} where S1 is bytes 0-3, S0 is bytes 4-7
+    # s0 = 0x00008000 -> byte 5 (0x80) has its sign bit set
+    # s1 = 0x80000080 -> bytes 0 and 3 (0x80) have their sign bits set, byte 1 (0x00) does not
+    # Combined = 0x00008000_80000080
+    # selector = 0x08090A0B -> sign of bytes 1,3,5,7
+    # byte 0: sel=0x0B -> sign of byte 7 (0x00) -> 0x00
+    # byte 1: sel=0x0A -> sign of byte 5 (0x80) -> 0xFF
+    # byte 2: sel=0x09 -> sign of byte 3 (0x80) -> 0xFF
+    # byte 3: sel=0x08 -> sign of byte 1 (0x00) -> 0x00
+    instructions = [
+      s_mov_b32(s[0], 0x00008000),
+      s_mov_b32(s[1], 0x80000080),
+      s_mov_b32(s[2], 0x08090A0B),
+      v_mov_b32_e32(v[0], s[0]),
+      v_mov_b32_e32(v[1], s[1]),
+      v_mov_b32_e32(v[2], s[2]),
+      v_perm_b32(v[3], v[0], v[1], v[2]),
+    ]
+    st = run_program(instructions, n_lanes=1)
+    result = st.vgpr[0][3]
+    self.assertEqual(result, 0x00FFFF00, f"Expected 0x00FFFF00, got 0x{result:08x}")
+
   def test_v_dot2_f32_bf16_basic(self):
     """V_DOT2_F32_BF16: Dot product of two bf16 pairs accumulated into f32."""
     from extra.assembly.amd.pcode import _ibf16