assembly/amd: split test_emu into hw tests (#13966)

* assmebly/amd: split test_emu into hw tests

* hw tests

* bugfixes

* more tests and fix
This commit is contained in:
George Hotz
2026-01-02 11:04:56 -05:00
committed by GitHub
parent 2e2b5fed12
commit 0e282025ff
14 changed files with 6772 additions and 5772 deletions

View File

@@ -13,12 +13,21 @@ MASK32, MASK64, MASK128 = 0xffffffff, 0xffffffffffffffff, (1 << 128) - 1
_struct_f, _struct_I = struct.Struct("<f"), struct.Struct("<I")
_struct_e, _struct_H = struct.Struct("<e"), struct.Struct("<H")
_struct_d, _struct_Q = struct.Struct("<d"), struct.Struct("<Q")
def _f32(i): return _struct_f.unpack(_struct_I.pack(i & MASK32))[0]
def _f32(i):
i = i & MASK32
# RDNA3 default mode: flush f32 denormals to zero (FTZ)
# Denormal: exponent=0 (bits 23-30) and mantissa!=0 (bits 0-22)
if (i & 0x7f800000) == 0 and (i & 0x007fffff) != 0: return 0.0
return _struct_f.unpack(_struct_I.pack(i))[0]
def _i32(f):
if isinstance(f, int): f = float(f)
if math.isnan(f): return 0xffc00000 if math.copysign(1.0, f) < 0 else 0x7fc00000
if math.isinf(f): return 0x7f800000 if f > 0 else 0xff800000
try: return _struct_I.unpack(_struct_f.pack(f))[0]
try:
bits = _struct_I.unpack(_struct_f.pack(f))[0]
# RDNA3 default mode: flush f32 denormals to zero (FTZ)
if (bits & 0x7f800000) == 0 and (bits & 0x007fffff) != 0: return 0x80000000 if bits & 0x80000000 else 0
return bits
except (OverflowError, struct.error): return 0x7f800000 if f > 0 else 0xff800000
def _sext(v, b): return v - (1 << b) if v & (1 << (b - 1)) else v
def _f16(i): return _struct_e.unpack(_struct_H.pack(i & 0xffff))[0]
@@ -333,6 +342,8 @@ class Inst:
def __init__(self, *args, literal: int | None = None, **kwargs):
self._values, self._literal = dict(self._defaults), None
field_names = [n for n in self._fields if n != 'encoding']
# Map Python-friendly names to actual field names (abs_ -> abs for Python reserved word)
if 'abs_' in kwargs: kwargs['abs'] = kwargs.pop('abs_')
orig_args = dict(zip(field_names, args)) | kwargs
self._values.update(orig_args)
self._validate(orig_args)

View File

@@ -35,7 +35,15 @@ def _gt_neg_zero(a, b): return (a > b) or (a == 0 and b == 0 and not math.copysi
def _lt_neg_zero(a, b): return (a < b) or (a == 0 and b == 0 and math.copysign(1, a) < 0 and not math.copysign(1, b) < 0)
def _fma(a, b, c): return a * b + c
def _signext(v): return v
def _fpop(fn): return lambda x: (x := float(x), x if math.isnan(x) or math.isinf(x) else float(fn(x)))[1]
def _fpop(fn):
def wrapper(x):
x = float(x)
if math.isnan(x) or math.isinf(x): return x
result = float(fn(x))
# Preserve sign of zero (IEEE 754: ceil(-0.0) = -0.0, ceil(-0.1) = -0.0)
if result == 0.0: return math.copysign(0.0, x)
return result
return wrapper
trunc, floor, ceil = _fpop(math.trunc), _fpop(math.floor), _fpop(math.ceil)
class _SafeFloat(float):
"""Float subclass that uses _div for division to handle 0/inf correctly."""
@@ -75,7 +83,11 @@ def _trig(fn, x):
# V_SIN/COS_F32: hardware does frac on input cycles before computing
if math.isinf(x) or math.isnan(x): return float("nan")
frac_cycles = fract(x / (2 * math.pi))
return fn(frac_cycles * 2 * math.pi)
result = fn(frac_cycles * 2 * math.pi)
# Hardware returns exactly 0 for cos(π/2), sin(π), etc. due to lookup table
# Round very small results (below f32 precision) to exactly 0
if abs(result) < 1e-7: return 0.0
return result
def sin(x): return _trig(math.sin, x)
def cos(x): return _trig(math.cos, x)
def pow(a, b):

View File

@@ -0,0 +1 @@
"""Hardware-validated emulator tests for RDNA3 instructions."""

View File

@@ -0,0 +1,200 @@
"""Test infrastructure for hardware-validated RDNA3 emulator tests.
Uses run_asm() with memory output, so tests can run on both emulator and real hardware.
Set USE_HW=1 to run on both emulator and real hardware, comparing results.
"""
import ctypes, os, struct
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.dsl import RawImm
from extra.assembly.amd.emu import WaveState, run_asm, set_valid_mem_ranges
from extra.assembly.amd.pcode import _i32, _f32
VCC = SrcEnum.VCC_LO # For VOP3SD sdst field
USE_HW = os.environ.get("USE_HW", "0") == "1"
FLOAT_TOLERANCE = 1e-5
# Output buffer layout: vgpr[16][32], sgpr[16], vcc, scc
N_VGPRS, N_SGPRS, WAVE_SIZE = 16, 16, 32
VGPR_BYTES = N_VGPRS * WAVE_SIZE * 4 # 16 regs * 32 lanes * 4 bytes = 2048
SGPR_BYTES = N_SGPRS * 4 # 16 regs * 4 bytes = 64
OUT_BYTES = VGPR_BYTES + SGPR_BYTES + 8 # + vcc + scc
# Float conversion helpers
def f2i(f: float) -> int: return _i32(f)
def i2f(i: int) -> float: return _f32(i)
def f2i64(f: float) -> int: return struct.unpack('<Q', struct.pack('<d', f))[0]
def i642f(i: int) -> float: return struct.unpack('<d', struct.pack('<Q', i))[0]
def assemble(instructions: list) -> bytes:
return b''.join(inst.to_bytes() for inst in instructions)
def get_prologue_epilogue(n_lanes: int) -> tuple[list, list]:
"""Generate prologue and epilogue instructions for state capture."""
prologue = [
s_mov_b32(s[80], s[0]),
s_mov_b32(s[81], s[1]),
v_mov_b32_e32(v[255], v[0]),
]
for i in range(N_VGPRS):
prologue.append(v_mov_b32_e32(v[i], 0))
for i in range(N_SGPRS):
prologue.append(s_mov_b32(s[i], 0))
prologue.append(s_mov_b32(s[SrcEnum.VCC_LO - 128], 0))
epilogue = [
s_mov_b32(s[90], SrcEnum.VCC_LO),
s_cselect_b32(s[91], 1, 0),
s_load_b64(s[92:93], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_lshlrev_b32_e32(v[240], 2, v[255]),
]
for i in range(N_VGPRS):
epilogue.append(global_store_b32(addr=v[240], data=v[i], saddr=s[92], offset=i * WAVE_SIZE * 4))
epilogue.append(v_mov_b32_e32(v[241], 0))
epilogue.append(v_cmp_eq_u32_e32(v[255], v[241]))
epilogue.append(s_and_saveexec_b32(s[94], SrcEnum.VCC_LO))
epilogue.append(v_mov_b32_e32(v[240], 0))
for i in range(N_SGPRS):
epilogue.append(v_mov_b32_e32(v[243], s[i]))
epilogue.append(global_store_b32(addr=v[240], data=v[243], saddr=s[92], offset=VGPR_BYTES + i * 4))
epilogue.append(v_mov_b32_e32(v[243], s[90]))
epilogue.append(global_store_b32(addr=v[240], data=v[243], saddr=s[92], offset=VGPR_BYTES + SGPR_BYTES))
epilogue.append(v_mov_b32_e32(v[243], s[91]))
epilogue.append(global_store_b32(addr=v[240], data=v[243], saddr=s[92], offset=VGPR_BYTES + SGPR_BYTES + 4))
epilogue.append(s_mov_b32(s[SrcEnum.EXEC_LO - 128], s[94]))
epilogue.append(s_endpgm())
return prologue, epilogue
def parse_output(out_buf: bytes, n_lanes: int) -> WaveState:
"""Parse output buffer into WaveState."""
st = WaveState()
for i in range(N_VGPRS):
for lane in range(n_lanes):
off = i * WAVE_SIZE * 4 + lane * 4
st.vgpr[lane][i] = struct.unpack_from('<I', out_buf, off)[0]
for i in range(N_SGPRS):
st.sgpr[i] = struct.unpack_from('<I', out_buf, VGPR_BYTES + i * 4)[0]
st.vcc = struct.unpack_from('<I', out_buf, VGPR_BYTES + SGPR_BYTES)[0]
st.scc = struct.unpack_from('<I', out_buf, VGPR_BYTES + SGPR_BYTES + 4)[0]
return st
def run_program_emu(instructions: list, n_lanes: int = 1) -> WaveState:
"""Run instructions via emulator run_asm, dump state to memory, return WaveState."""
out_buf = (ctypes.c_uint8 * OUT_BYTES)(*([0] * OUT_BYTES))
out_addr = ctypes.addressof(out_buf)
prologue, epilogue = get_prologue_epilogue(n_lanes)
code = assemble(prologue + instructions + epilogue)
args = (ctypes.c_uint64 * 1)(out_addr)
args_ptr = ctypes.addressof(args)
kernel_buf = (ctypes.c_char * len(code)).from_buffer_copy(code)
lib_ptr = ctypes.addressof(kernel_buf)
set_valid_mem_ranges({(out_addr, OUT_BYTES), (args_ptr, 8)})
result = run_asm(lib_ptr, len(code), 1, 1, 1, n_lanes, 1, 1, args_ptr)
assert result == 0, f"run_asm failed with {result}"
return parse_output(bytes(out_buf), n_lanes)
def run_program_hw(instructions: list, n_lanes: int = 1) -> WaveState:
"""Run instructions on real AMD hardware via HIPCompiler and AMDProgram."""
from tinygrad.device import Device
from tinygrad.runtime.ops_amd import AMDProgram
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad.helpers import flat_mv
dev = Device["AMD"]
compiler = HIPCompiler(dev.arch)
prologue, epilogue = get_prologue_epilogue(n_lanes)
code = assemble(prologue + instructions + epilogue)
byte_str = ', '.join(f'0x{b:02x}' for b in code)
asm_src = f""".text
.globl test
.p2align 8
.type test,@function
test:
.byte {byte_str}
.rodata
.p2align 6
.amdhsa_kernel test
.amdhsa_next_free_vgpr 256
.amdhsa_next_free_sgpr 96
.amdhsa_wavefront_size32 1
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_kernarg_size 8
.amdhsa_group_segment_fixed_size 65536
.end_amdhsa_kernel
.amdgpu_metadata
---
amdhsa.version:
- 1
- 0
amdhsa.kernels:
- .name: test
.symbol: test.kd
.kernarg_segment_size: 8
.group_segment_fixed_size: 65536
.private_segment_fixed_size: 0
.kernarg_segment_align: 8
.wavefront_size: 32
.sgpr_count: 96
.vgpr_count: 256
.max_flat_workgroup_size: 1024
...
.end_amdgpu_metadata
"""
lib = compiler.compile(asm_src)
prg = AMDProgram(dev, "test", lib)
out_gpu = dev.allocator.alloc(OUT_BYTES)
prg(out_gpu, global_size=(1, 1, 1), local_size=(n_lanes, 1, 1), wait=True)
out_buf = bytearray(OUT_BYTES)
dev.allocator._copyout(flat_mv(memoryview(out_buf)), out_gpu)
return parse_output(bytes(out_buf), n_lanes)
def compare_wave_states(emu_st: WaveState, hw_st: WaveState, n_lanes: int, n_vgprs: int = N_VGPRS) -> list[str]:
"""Compare two WaveStates and return list of differences."""
import math
diffs = []
for i in range(n_vgprs):
for lane in range(n_lanes):
emu_val = emu_st.vgpr[lane][i]
hw_val = hw_st.vgpr[lane][i]
if emu_val != hw_val:
emu_f, hw_f = _f32(emu_val), _f32(hw_val)
if math.isnan(emu_f) and math.isnan(hw_f):
continue
diffs.append(f"v[{i}] lane {lane}: emu=0x{emu_val:08x} ({emu_f:.6g}) hw=0x{hw_val:08x} ({hw_f:.6g})")
for i in range(N_SGPRS):
emu_val = emu_st.sgpr[i]
hw_val = hw_st.sgpr[i]
if emu_val != hw_val:
diffs.append(f"s[{i}]: emu=0x{emu_val:08x} hw=0x{hw_val:08x}")
if emu_st.vcc != hw_st.vcc:
diffs.append(f"vcc: emu=0x{emu_st.vcc:08x} hw=0x{hw_st.vcc:08x}")
if emu_st.scc != hw_st.scc:
diffs.append(f"scc: emu={emu_st.scc} hw={hw_st.scc}")
return diffs
def run_program(instructions: list, n_lanes: int = 1) -> WaveState:
"""Run instructions and return WaveState.
If USE_HW=1, runs on both emulator and hardware, compares results, and raises if they differ.
Otherwise, runs only on emulator.
"""
emu_st = run_program_emu(instructions, n_lanes)
if USE_HW:
hw_st = run_program_hw(instructions, n_lanes)
diffs = compare_wave_states(emu_st, hw_st, n_lanes)
if diffs:
raise AssertionError(f"Emulator vs Hardware mismatch:\n" + "\n".join(diffs))
return hw_st
return emu_st

View File

@@ -0,0 +1,629 @@
"""Tests for DS instructions - data share (LDS) operations.
Includes: ds_store_b32, ds_load_b32, ds_store_2addr_*, ds_load_2addr_*,
ds_add_*, ds_max_*, ds_min_*, ds_and_*, ds_or_*, ds_xor_*,
ds_inc_*, ds_dec_*, ds_cmpstore_*, ds_storexchg_*
"""
import unittest
from extra.assembly.amd.test.hw.helpers import *
class TestDS2Addr(unittest.TestCase):
"""Tests for DS_*_2ADDR instructions."""
def test_ds_store_load_2addr_b32(self):
"""DS_STORE_2ADDR_B32 and DS_LOAD_2ADDR_B32 with offset * 4."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[0], 0xAAAAAAAA),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[0], 0xBBBBBBBB),
v_mov_b32_e32(v[1], s[0]),
DS(DSOp.DS_STORE_2ADDR_B32, addr=v[10], data0=v[0], data1=v[1], vdst=v[0], offset0=0, offset1=1),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_2ADDR_B32, addr=v[10], vdst=v[2], offset0=0, offset1=1),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0xAAAAAAAA)
self.assertEqual(st.vgpr[0][3], 0xBBBBBBBB)
def test_ds_store_load_2addr_b64(self):
"""DS_STORE_2ADDR_B64 and DS_LOAD_2ADDR_B64."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[0], 0xDEADBEEF),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[0], 0xCAFEBABE),
v_mov_b32_e32(v[1], s[0]),
s_mov_b32(s[0], 0x12345678),
v_mov_b32_e32(v[2], s[0]),
s_mov_b32(s[0], 0x9ABCDEF0),
v_mov_b32_e32(v[3], s[0]),
DS(DSOp.DS_STORE_2ADDR_B64, addr=v[10], data0=v[0], data1=v[2], vdst=v[0], offset0=0, offset1=2),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_2ADDR_B64, addr=v[10], vdst=v[4], offset0=0, offset1=2),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 0xDEADBEEF)
self.assertEqual(st.vgpr[0][5], 0xCAFEBABE)
self.assertEqual(st.vgpr[0][6], 0x12345678)
self.assertEqual(st.vgpr[0][7], 0x9ABCDEF0)
class TestDS2AddrMore(unittest.TestCase):
"""Additional DS_*_2ADDR tests."""
def test_ds_store_load_2addr_b32_nonzero_offsets(self):
"""DS_STORE_2ADDR_B32 with non-zero offsets (offset*4 scaling)."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 0x11111111),
v_mov_b32_e32(v[0], s[2]),
s_mov_b32(s[2], 0x22222222),
v_mov_b32_e32(v[1], s[2]),
DS(DSOp.DS_STORE_2ADDR_B32, addr=v[10], data0=v[0], data1=v[1], vdst=v[0], offset0=2, offset1=5),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_2ADDR_B32, addr=v[10], vdst=v[2], offset0=2, offset1=5),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0x11111111, "v2 should have value from offset 8 (2*4)")
self.assertEqual(st.vgpr[0][3], 0x22222222, "v3 should have value from offset 20 (5*4)")
def test_ds_2addr_b64_no_overlap(self):
"""DS_LOAD_2ADDR_B64 with adjacent offsets should not overlap."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 0x11111111),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_mov_b32(s[2], 0x22222222),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=4),
s_mov_b32(s[2], 0x33333333),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=8),
s_mov_b32(s[2], 0x44444444),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=12),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_2ADDR_B64, addr=v[10], vdst=v[4], offset0=0, offset1=1),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 0x11111111, "v4 should be 0x11111111")
self.assertEqual(st.vgpr[0][5], 0x22222222, "v5 should be 0x22222222")
self.assertEqual(st.vgpr[0][6], 0x33333333, "v6 should be 0x33333333")
self.assertEqual(st.vgpr[0][7], 0x44444444, "v7 should be 0x44444444")
def test_ds_load_2addr_b32_no_overwrite(self):
"""DS_LOAD_2ADDR_B32 should only write 2 VGPRs."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 0xAAAAAAAA),
v_mov_b32_e32(v[0], s[2]),
s_mov_b32(s[2], 0xBBBBBBBB),
v_mov_b32_e32(v[1], s[2]),
DS(DSOp.DS_STORE_2ADDR_B32, addr=v[10], data0=v[0], data1=v[1], vdst=v[0], offset0=0, offset1=1),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 0xDEADBEEF),
v_mov_b32_e32(v[4], s[2]), # Sentinel
DS(DSOp.DS_LOAD_2ADDR_B32, addr=v[10], vdst=v[2], offset0=0, offset1=1),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0xAAAAAAAA)
self.assertEqual(st.vgpr[0][3], 0xBBBBBBBB)
self.assertEqual(st.vgpr[0][4], 0xDEADBEEF, "v4 should be untouched")
def test_ds_load_b64_no_overwrite(self):
"""DS_LOAD_B64 should only write 2 VGPRs."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 0xDEADBEEF),
v_mov_b32_e32(v[0], s[2]),
s_mov_b32(s[2], 0xCAFEBABE),
v_mov_b32_e32(v[1], s[2]),
ds_store_b64(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 0x12345678),
v_mov_b32_e32(v[4], s[2]), # Sentinel
ds_load_b64(addr=v[10], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0xDEADBEEF)
self.assertEqual(st.vgpr[0][3], 0xCAFEBABE)
self.assertEqual(st.vgpr[0][4], 0x12345678, "v4 should be untouched")
class TestDSAtomic(unittest.TestCase):
"""Tests for DS atomic operations."""
def test_ds_max_rtn_u32(self):
"""DS_MAX_RTN_U32: atomically store max and return old value."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 100),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 200),
v_mov_b32_e32(v[1], s[2]),
ds_max_rtn_u32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 100, "v2 should have old value (100)")
self.assertEqual(st.vgpr[0][3], 200, "v3 should have max(100, 200) = 200")
def test_ds_min_rtn_u32(self):
"""DS_MIN_RTN_U32: atomically store min and return old value."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 200),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 100),
v_mov_b32_e32(v[1], s[2]),
ds_min_rtn_u32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 200)
self.assertEqual(st.vgpr[0][3], 100)
def test_ds_and_rtn_b32(self):
"""DS_AND_RTN_B32: atomically AND and return old value."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 0xFF00FF00),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 0xFFFF0000),
v_mov_b32_e32(v[1], s[2]),
ds_and_rtn_b32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0xFF00FF00)
self.assertEqual(st.vgpr[0][3], 0xFF000000)
def test_ds_or_rtn_b32(self):
"""DS_OR_RTN_B32: atomically OR and return old value."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 0x00FF0000),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 0x000000FF),
v_mov_b32_e32(v[1], s[2]),
ds_or_rtn_b32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0x00FF0000)
self.assertEqual(st.vgpr[0][3], 0x00FF00FF)
def test_ds_xor_rtn_b32(self):
"""DS_XOR_RTN_B32: atomically XOR and return old value."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 0xAAAAAAAA),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 0xFFFFFFFF),
v_mov_b32_e32(v[1], s[2]),
ds_xor_rtn_b32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0xAAAAAAAA)
self.assertEqual(st.vgpr[0][3], 0x55555555)
def test_ds_inc_rtn_u32(self):
"""DS_INC_RTN_U32: increment with wrap."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 5),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 10), # limit
v_mov_b32_e32(v[1], s[2]),
ds_inc_rtn_u32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 5)
self.assertEqual(st.vgpr[0][3], 6)
def test_ds_dec_rtn_u32(self):
"""DS_DEC_RTN_U32: decrement with wrap."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 5),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 10), # limit
v_mov_b32_e32(v[1], s[2]),
ds_dec_rtn_u32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 5)
self.assertEqual(st.vgpr[0][3], 4)
def test_ds_cmpstore_b32_match(self):
"""DS_CMPSTORE_B32: conditional store when compare matches."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 100),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 200),
v_mov_b32_e32(v[1], s[2]), # new value
s_mov_b32(s[2], 100),
v_mov_b32_e32(v[2], s[2]), # compare = 100 (matches)
ds_cmpstore_b32(addr=v[10], data0=v[1], data1=v[2], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[4], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 200)
def test_ds_cmpstore_b32_no_match(self):
"""DS_CMPSTORE_B32: no store when compare doesn't match."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 100),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 200),
v_mov_b32_e32(v[1], s[2]), # new value
s_mov_b32(s[2], 50),
v_mov_b32_e32(v[2], s[2]), # compare = 50 (doesn't match)
ds_cmpstore_b32(addr=v[10], data0=v[1], data1=v[2], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[4], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 100)
def test_ds_max_u32_no_rtn(self):
"""DS_MAX_U32 (no RTN): atomically store max, no return value."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 100),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 200),
v_mov_b32_e32(v[1], s[2]),
ds_max_u32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][3], 200, "v3 should have max(100, 200) = 200")
def test_ds_add_u32_no_rtn_preserves_vdst(self):
"""DS_ADD_U32 (no RTN) should NOT write to vdst."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 0xDEADBEEF),
v_mov_b32_e32(v[2], s[2]), # sentinel
s_mov_b32(s[2], 100),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 50),
v_mov_b32_e32(v[1], s[2]),
ds_add_u32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0xDEADBEEF, "v2 should preserve sentinel")
self.assertEqual(st.vgpr[0][3], 150, "v3 should have 100 + 50 = 150")
def test_ds_add_rtn_u32_writes_vdst(self):
"""DS_ADD_RTN_U32 should write old value to vdst."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 0xDEADBEEF),
v_mov_b32_e32(v[2], s[2]), # sentinel
s_mov_b32(s[2], 100),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 50),
v_mov_b32_e32(v[1], s[2]),
ds_add_rtn_u32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 100, "v2 should have old value (100)")
self.assertEqual(st.vgpr[0][3], 150, "v3 should have 100 + 50 = 150")
def test_ds_dec_rtn_u32_wrap(self):
"""DS_DEC_RTN_U32: decrement wraps when value is 0 or > limit."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[2], 0), # Start at 0
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[2], 10), # limit
v_mov_b32_e32(v[1], s[2]),
ds_dec_rtn_u32(addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0, "v2 should have old value (0)")
# When mem == 0 or mem > limit, result = limit
self.assertEqual(st.vgpr[0][3], 10, "v3 should wrap to limit (10)")
class TestDSStorexchg(unittest.TestCase):
"""Tests for DS_STOREXCHG instructions."""
def test_ds_storexchg_rtn_b32(self):
"""DS_STOREXCHG_RTN_B32: exchange value and return old."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[0], 0xAAAAAAAA),
v_mov_b32_e32(v[0], s[0]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[0], 0xBBBBBBBB),
v_mov_b32_e32(v[1], s[0]),
DS(DSOp.DS_STOREXCHG_RTN_B32, addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[10], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0xAAAAAAAA)
self.assertEqual(st.vgpr[0][3], 0xBBBBBBBB)
class TestDSRegisterWidth(unittest.TestCase):
"""Regression tests: DS loads should only write correct number of VGPRs."""
def test_ds_load_b32_no_overwrite(self):
"""DS_LOAD_B32 should only write 1 VGPR."""
instructions = [
v_mov_b32_e32(v[0], 0),
s_mov_b32(s[0], 0xDEADBEEF),
v_mov_b32_e32(v[1], s[0]),
s_mov_b32(s[0], 0x11111111),
v_mov_b32_e32(v[2], s[0]), # sentinel
ds_store_b32(addr=v[0], data0=v[1], offset0=0),
s_waitcnt(lgkmcnt=0),
ds_load_b32(addr=v[0], vdst=v[1], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][1], 0xDEADBEEF)
self.assertEqual(st.vgpr[0][2], 0x11111111, "v2 should be untouched")
class TestDS2AddrStride64(unittest.TestCase):
"""Tests for DS_*_2ADDR_STRIDE64 (offset * 256 for B32, offset * 512 for B64)."""
def test_ds_store_load_2addr_stride64_b32(self):
"""DS_STORE_2ADDR_STRIDE64_B32: stores at ADDR + offset*256."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[0], 0xAAAAAAAA),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[0], 0xBBBBBBBB),
v_mov_b32_e32(v[1], s[0]),
DS(DSOp.DS_STORE_2ADDR_STRIDE64_B32, addr=v[10], data0=v[0], data1=v[1], vdst=v[0], offset0=1, offset1=2),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_2ADDR_STRIDE64_B32, addr=v[10], vdst=v[2], offset0=1, offset1=2),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0xAAAAAAAA, "v2 from addr 256")
self.assertEqual(st.vgpr[0][3], 0xBBBBBBBB, "v3 from addr 512")
def test_ds_store_load_2addr_stride64_b64(self):
"""DS_STORE_2ADDR_STRIDE64_B64: stores at ADDR + offset*512."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[0], 0xDEADBEEF),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[0], 0xCAFEBABE),
v_mov_b32_e32(v[1], s[0]),
s_mov_b32(s[0], 0x12345678),
v_mov_b32_e32(v[2], s[0]),
s_mov_b32(s[0], 0x9ABCDEF0),
v_mov_b32_e32(v[3], s[0]),
DS(DSOp.DS_STORE_2ADDR_STRIDE64_B64, addr=v[10], data0=v[0], data1=v[2], vdst=v[0], offset0=1, offset1=2),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_2ADDR_STRIDE64_B64, addr=v[10], vdst=v[4], offset0=1, offset1=2),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 0xDEADBEEF)
self.assertEqual(st.vgpr[0][5], 0xCAFEBABE)
self.assertEqual(st.vgpr[0][6], 0x12345678)
self.assertEqual(st.vgpr[0][7], 0x9ABCDEF0)
def test_ds_storexchg_2addr_rtn_b32(self):
"""DS_STOREXCHG_2ADDR_RTN_B32: exchange at two addresses."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[0], 0x11111111),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[0], 0x22222222),
v_mov_b32_e32(v[1], s[0]),
DS(DSOp.DS_STORE_2ADDR_B32, addr=v[10], data0=v[0], data1=v[1], vdst=v[0], offset0=0, offset1=1),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[0], 0xAAAAAAAA),
v_mov_b32_e32(v[2], s[0]),
s_mov_b32(s[0], 0xBBBBBBBB),
v_mov_b32_e32(v[3], s[0]),
DS(DSOp.DS_STOREXCHG_2ADDR_RTN_B32, addr=v[10], data0=v[2], data1=v[3], vdst=v[4], offset0=0, offset1=1),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_2ADDR_B32, addr=v[10], vdst=v[6], offset0=0, offset1=1),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 0x11111111, "old val 0")
self.assertEqual(st.vgpr[0][5], 0x22222222, "old val 1")
self.assertEqual(st.vgpr[0][6], 0xAAAAAAAA, "new val 0")
self.assertEqual(st.vgpr[0][7], 0xBBBBBBBB, "new val 1")
def test_ds_storexchg_rtn_b64(self):
"""DS_STOREXCHG_RTN_B64: exchange 64-bit value and return old."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[0], 0xDEADBEEF),
v_mov_b32_e32(v[0], s[0]), # initial low
s_mov_b32(s[0], 0xCAFEBABE),
v_mov_b32_e32(v[1], s[0]), # initial high
DS(DSOp.DS_STORE_B64, addr=v[10], data0=v[0], vdst=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[0], 0x12345678),
v_mov_b32_e32(v[2], s[0]), # new low
s_mov_b32(s[0], 0x9ABCDEF0),
v_mov_b32_e32(v[3], s[0]), # new high
DS(DSOp.DS_STOREXCHG_RTN_B64, addr=v[10], data0=v[2], vdst=v[4], offset0=0),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_B64, addr=v[10], vdst=v[6], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 0xDEADBEEF, "v4 should have old low dword")
self.assertEqual(st.vgpr[0][5], 0xCAFEBABE, "v5 should have old high dword")
self.assertEqual(st.vgpr[0][6], 0x12345678, "v6 should have new low dword")
self.assertEqual(st.vgpr[0][7], 0x9ABCDEF0, "v7 should have new high dword")
def test_ds_store_load_2addr_stride64_b64_roundtrip(self):
"""DS_STORE_2ADDR_STRIDE64_B64 followed by DS_LOAD_2ADDR_STRIDE64_B64 works correctly."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[0], 0x11111111),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[0], 0x22222222),
v_mov_b32_e32(v[1], s[0]),
DS(DSOp.DS_STORE_2ADDR_STRIDE64_B64, addr=v[10], data0=v[0], data1=v[0], vdst=v[0], offset0=1, offset1=2),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_2ADDR_STRIDE64_B64, addr=v[10], vdst=v[2], offset0=1, offset1=2),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0x11111111, "v2 should have val1 low")
self.assertEqual(st.vgpr[0][3], 0x22222222, "v3 should have val1 high")
self.assertEqual(st.vgpr[0][4], 0x11111111, "v4 should have val2 low")
self.assertEqual(st.vgpr[0][5], 0x22222222, "v5 should have val2 high")
def test_ds_storexchg_2addr_stride64_rtn_b32(self):
"""DS_STOREXCHG_2ADDR_STRIDE64_RTN_B32: exchange at two addresses (offset*256)."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[0], 0x11111111),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[0], 0x22222222),
v_mov_b32_e32(v[1], s[0]),
DS(DSOp.DS_STORE_2ADDR_STRIDE64_B32, addr=v[10], data0=v[0], data1=v[1], vdst=v[0], offset0=1, offset1=2),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[0], 0xAAAAAAAA),
v_mov_b32_e32(v[2], s[0]),
s_mov_b32(s[0], 0xBBBBBBBB),
v_mov_b32_e32(v[3], s[0]),
DS(DSOp.DS_STOREXCHG_2ADDR_STRIDE64_RTN_B32, addr=v[10], data0=v[2], data1=v[3], vdst=v[4], offset0=1, offset1=2),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_2ADDR_STRIDE64_B32, addr=v[10], vdst=v[6], offset0=1, offset1=2),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 0x11111111, "v4 should have old value")
self.assertEqual(st.vgpr[0][5], 0x22222222, "v5 should have old value")
self.assertEqual(st.vgpr[0][6], 0xAAAAAAAA, "v6 should have new value")
self.assertEqual(st.vgpr[0][7], 0xBBBBBBBB, "v7 should have new value")
def test_ds_storexchg_2addr_stride64_rtn_b64_returns_old(self):
"""DS_STOREXCHG_2ADDR_STRIDE64_RTN_B64: returns old values correctly."""
instructions = [
v_mov_b32_e32(v[10], 0),
s_mov_b32(s[0], 0x11111111),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[0], 0x22222222),
v_mov_b32_e32(v[1], s[0]),
DS(DSOp.DS_STORE_2ADDR_STRIDE64_B64, addr=v[10], data0=v[0], data1=v[0], vdst=v[0], offset0=1, offset1=2),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[0], 0xAAAAAAAA),
v_mov_b32_e32(v[6], s[0]),
s_mov_b32(s[0], 0xBBBBBBBB),
v_mov_b32_e32(v[7], s[0]),
DS(DSOp.DS_STOREXCHG_2ADDR_STRIDE64_RTN_B64, addr=v[10], data0=v[6], data1=v[6], vdst=v[8], offset0=1, offset1=2),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][8], 0x11111111, "v8 should have old val1 low")
self.assertEqual(st.vgpr[0][9], 0x22222222, "v9 should have old val1 high")
self.assertEqual(st.vgpr[0][10], 0x11111111, "v10 should have old val2 low")
self.assertEqual(st.vgpr[0][11], 0x22222222, "v11 should have old val2 high")
class TestAtomicOrdering(unittest.TestCase):
"""Tests for atomic operation return values and ordering."""
def test_ds_add_rtn_sequence(self):
"""DS_ADD_RTN returns correct old values in sequence."""
instructions = [
v_mov_b32_e32(v[10], 0),
v_mov_b32_e32(v[0], 100),
DS(DSOp.DS_STORE_B32, addr=v[10], data0=v[0], vdst=v[0], offset0=0),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[1], 25),
DS(DSOp.DS_ADD_RTN_U32, addr=v[10], data0=v[1], vdst=v[2], offset0=0),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_ADD_RTN_U32, addr=v[10], data0=v[1], vdst=v[3], offset0=0),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_B32, addr=v[10], vdst=v[4], offset0=0),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 100, "First add should return 100")
self.assertEqual(st.vgpr[0][3], 125, "Second add should return 125")
self.assertEqual(st.vgpr[0][4], 150, "Final value should be 150")
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,363 @@
"""Tests for FLAT instructions - flat memory operations.
Includes: flat_load_*, flat_store_*, flat_atomic_*
"""
import unittest
from extra.assembly.amd.test.hw.helpers import *
class TestFlatAtomic(unittest.TestCase):
"""Tests for FLAT atomic instructions."""
def _make_test(self, setup_instrs, atomic_instr, check_fn, test_offset=2000):
"""Helper to create atomic test instructions."""
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
] + setup_instrs + [atomic_instr, s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[1], 0),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
check_fn(st)
def test_flat_atomic_add_u32(self):
"""FLAT_ATOMIC_ADD_U32 adds to memory and returns old value."""
TEST_OFFSET = 2000
setup = [
s_mov_b32(s[0], 100),
v_mov_b32_e32(v[2], s[0]),
global_store_b32(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[0], 50),
v_mov_b32_e32(v[3], s[0]),
]
atomic = FLAT(FLATOp.FLAT_ATOMIC_ADD_U32, addr=v[0], data=v[3], vdst=v[4], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1)
def check(st):
self.assertEqual(st.vgpr[0][4], 100)
self._make_test(setup, atomic, check, TEST_OFFSET)
def test_flat_atomic_swap_b32(self):
"""FLAT_ATOMIC_SWAP_B32 swaps memory value and returns old value."""
TEST_OFFSET = 2000
setup = [
s_mov_b32(s[0], 0xAAAAAAAA),
v_mov_b32_e32(v[2], s[0]),
global_store_b32(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[0], 0xBBBBBBBB),
v_mov_b32_e32(v[3], s[0]),
]
atomic = FLAT(FLATOp.FLAT_ATOMIC_SWAP_B32, addr=v[0], data=v[3], vdst=v[4], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1)
def check(st):
self.assertEqual(st.vgpr[0][4], 0xAAAAAAAA)
self._make_test(setup, atomic, check, TEST_OFFSET)
def test_flat_atomic_and_b32(self):
"""FLAT_ATOMIC_AND_B32 ANDs with memory and returns old value."""
TEST_OFFSET = 2000
setup = [
s_mov_b32(s[0], 0xFF00FF00),
v_mov_b32_e32(v[2], s[0]),
global_store_b32(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[0], 0xFFFF0000),
v_mov_b32_e32(v[3], s[0]),
]
atomic = FLAT(FLATOp.FLAT_ATOMIC_AND_B32, addr=v[0], data=v[3], vdst=v[4], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1)
def check(st):
self.assertEqual(st.vgpr[0][4], 0xFF00FF00)
self._make_test(setup, atomic, check, TEST_OFFSET)
def test_flat_atomic_or_b32(self):
"""FLAT_ATOMIC_OR_B32 ORs with memory and returns old value."""
TEST_OFFSET = 2000
setup = [
s_mov_b32(s[0], 0x00FF0000),
v_mov_b32_e32(v[2], s[0]),
global_store_b32(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[0], 0x0000FF00),
v_mov_b32_e32(v[3], s[0]),
]
atomic = FLAT(FLATOp.FLAT_ATOMIC_OR_B32, addr=v[0], data=v[3], vdst=v[4], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1)
def check(st):
self.assertEqual(st.vgpr[0][4], 0x00FF0000)
self._make_test(setup, atomic, check, TEST_OFFSET)
def test_flat_atomic_inc_u32(self):
"""FLAT_ATOMIC_INC_U32 increments and returns old value."""
TEST_OFFSET = 2000
setup = [
s_mov_b32(s[0], 10),
v_mov_b32_e32(v[2], s[0]),
global_store_b32(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[0], 100), # threshold
v_mov_b32_e32(v[3], s[0]),
]
atomic = FLAT(FLATOp.FLAT_ATOMIC_INC_U32, addr=v[0], data=v[3], vdst=v[4], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1)
def check(st):
self.assertEqual(st.vgpr[0][4], 10)
self._make_test(setup, atomic, check, TEST_OFFSET)
def test_flat_atomic_dec_u32(self):
"""FLAT_ATOMIC_DEC_U32 decrements and returns old value."""
TEST_OFFSET = 2000
setup = [
s_mov_b32(s[0], 10),
v_mov_b32_e32(v[2], s[0]),
global_store_b32(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[0], 100),
v_mov_b32_e32(v[3], s[0]),
]
atomic = FLAT(FLATOp.FLAT_ATOMIC_DEC_U32, addr=v[0], data=v[3], vdst=v[4], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1)
def check(st):
self.assertEqual(st.vgpr[0][4], 10)
self._make_test(setup, atomic, check, TEST_OFFSET)
def test_flat_atomic_sub_u32(self):
"""FLAT_ATOMIC_SUB_U32 subtracts from memory and returns old value."""
TEST_OFFSET = 2000
setup = [
s_mov_b32(s[0], 100),
v_mov_b32_e32(v[2], s[0]),
global_store_b32(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[0], 30),
v_mov_b32_e32(v[3], s[0]), # sub 30
]
atomic = FLAT(FLATOp.FLAT_ATOMIC_SUB_U32, addr=v[0], data=v[3], vdst=v[4], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1)
def check(st):
self.assertEqual(st.vgpr[0][4], 100, "v4 should have old value (100)")
self._make_test(setup, atomic, check, TEST_OFFSET)
def test_flat_atomic_xor_b32(self):
"""FLAT_ATOMIC_XOR_B32 XORs with memory and returns old value."""
TEST_OFFSET = 2000
setup = [
s_mov_b32(s[0], 0xAAAAAAAA),
v_mov_b32_e32(v[2], s[0]),
global_store_b32(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[0], 0xFFFFFFFF),
v_mov_b32_e32(v[3], s[0]), # XOR mask
]
atomic = FLAT(FLATOp.FLAT_ATOMIC_XOR_B32, addr=v[0], data=v[3], vdst=v[4], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1)
def check(st):
self.assertEqual(st.vgpr[0][4], 0xAAAAAAAA, "v4 should have old value")
self._make_test(setup, atomic, check, TEST_OFFSET)
def test_flat_atomic_min_u32(self):
"""FLAT_ATOMIC_MIN_U32 stores min and returns old value."""
TEST_OFFSET = 2000
setup = [
s_mov_b32(s[0], 100),
v_mov_b32_e32(v[2], s[0]),
global_store_b32(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[0], 50),
v_mov_b32_e32(v[3], s[0]), # compare value (smaller)
]
atomic = FLAT(FLATOp.FLAT_ATOMIC_MIN_U32, addr=v[0], data=v[3], vdst=v[4], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1)
def check(st):
self.assertEqual(st.vgpr[0][4], 100, "v4 should have old value (100)")
self._make_test(setup, atomic, check, TEST_OFFSET)
def test_flat_atomic_max_u32(self):
"""FLAT_ATOMIC_MAX_U32 stores max and returns old value."""
TEST_OFFSET = 2000
setup = [
s_mov_b32(s[0], 50),
v_mov_b32_e32(v[2], s[0]),
global_store_b32(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[0], 100),
v_mov_b32_e32(v[3], s[0]), # compare value (larger)
]
atomic = FLAT(FLATOp.FLAT_ATOMIC_MAX_U32, addr=v[0], data=v[3], vdst=v[4], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1)
def check(st):
self.assertEqual(st.vgpr[0][4], 50, "v4 should have old value (50)")
self._make_test(setup, atomic, check, TEST_OFFSET)
def test_flat_atomic_inc_u64_returns_old_value(self):
"""FLAT_ATOMIC_INC_U64 should return full 64-bit old value."""
TEST_OFFSET = 2000
setup = [
# Store initial 64-bit value: 0xCAFEBABE_DEADBEEF
s_mov_b32(s[0], 0xDEADBEEF),
v_mov_b32_e32(v[2], s[0]),
s_mov_b32(s[0], 0xCAFEBABE),
v_mov_b32_e32(v[3], s[0]),
global_store_b64(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
# Threshold: 0xFFFFFFFF_FFFFFFFF
s_mov_b32(s[0], 0xFFFFFFFF),
v_mov_b32_e32(v[4], s[0]),
v_mov_b32_e32(v[5], s[0]),
]
atomic = FLAT(FLATOp.FLAT_ATOMIC_INC_U64, addr=v[0], data=v[4], vdst=v[6], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1)
def check(st):
self.assertEqual(st.vgpr[0][6], 0xDEADBEEF, "v6 should have old value low dword")
self.assertEqual(st.vgpr[0][7], 0xCAFEBABE, "v7 should have old value high dword")
self._make_test(setup, atomic, check, TEST_OFFSET)
def test_flat_atomic_add_u64(self):
"""FLAT_ATOMIC_ADD_U64 adds 64-bit value and returns old value."""
TEST_OFFSET = 2000
setup = [
s_mov_b32(s[0], 0x11111111),
v_mov_b32_e32(v[2], s[0]),
s_mov_b32(s[0], 0x22222222),
v_mov_b32_e32(v[3], s[0]),
global_store_b64(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[0], 0x00000001), # add 1
v_mov_b32_e32(v[4], s[0]),
s_mov_b32(s[0], 0x00000000),
v_mov_b32_e32(v[5], s[0]),
]
atomic = FLAT(FLATOp.FLAT_ATOMIC_ADD_U64, addr=v[0], data=v[4], vdst=v[6], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1)
def check(st):
self.assertEqual(st.vgpr[0][6], 0x11111111, "v6 should have old value low")
self.assertEqual(st.vgpr[0][7], 0x22222222, "v7 should have old value high")
self._make_test(setup, atomic, check, TEST_OFFSET)
def test_flat_atomic_swap_b64(self):
"""FLAT_ATOMIC_SWAP_B64 swaps 64-bit value and returns old value."""
TEST_OFFSET = 2000
setup = [
s_mov_b32(s[0], 0xAAAAAAAA),
v_mov_b32_e32(v[2], s[0]),
s_mov_b32(s[0], 0xBBBBBBBB),
v_mov_b32_e32(v[3], s[0]),
global_store_b64(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[0], 0xCCCCCCCC),
v_mov_b32_e32(v[4], s[0]),
s_mov_b32(s[0], 0xDDDDDDDD),
v_mov_b32_e32(v[5], s[0]),
]
atomic = FLAT(FLATOp.FLAT_ATOMIC_SWAP_B64, addr=v[0], data=v[4], vdst=v[6], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1)
def check(st):
self.assertEqual(st.vgpr[0][6], 0xAAAAAAAA, "v6 should have old value low")
self.assertEqual(st.vgpr[0][7], 0xBBBBBBBB, "v7 should have old value high")
self._make_test(setup, atomic, check, TEST_OFFSET)
class TestFlatLoad(unittest.TestCase):
"""Tests for FLAT load instructions."""
def test_flat_load_b32(self):
"""FLAT_LOAD_B32 loads 32-bit value correctly."""
TEST_OFFSET = 2000
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
s_mov_b32(s[0], 0xDEADBEEF),
v_mov_b32_e32(v[2], s[0]),
global_store_b32(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
FLAT(FLATOp.FLAT_LOAD_B32, addr=v[0], vdst=v[4], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[1], 0),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 0xDEADBEEF)
def test_flat_load_b64(self):
"""FLAT_LOAD_B64 loads 64-bit value correctly."""
TEST_OFFSET = 2000
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
s_mov_b32(s[0], 0xDEADBEEF),
v_mov_b32_e32(v[2], s[0]),
s_mov_b32(s[0], 0xCAFEBABE),
v_mov_b32_e32(v[3], s[0]),
global_store_b64(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
FLAT(FLATOp.FLAT_LOAD_B64, addr=v[0], vdst=v[4], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[1], 0),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 0xDEADBEEF)
self.assertEqual(st.vgpr[0][5], 0xCAFEBABE)
def test_flat_load_b96(self):
"""FLAT_LOAD_B96 loads 96-bit (3 dword) value correctly."""
TEST_OFFSET = 2000
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
s_mov_b32(s[0], 0x11111111),
v_mov_b32_e32(v[2], s[0]),
s_mov_b32(s[0], 0x22222222),
v_mov_b32_e32(v[3], s[0]),
s_mov_b32(s[0], 0x33333333),
v_mov_b32_e32(v[4], s[0]),
global_store_b96(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
FLAT(FLATOp.FLAT_LOAD_B96, addr=v[0], vdst=v[5], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[1], 0),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][5], 0x11111111)
self.assertEqual(st.vgpr[0][6], 0x22222222)
self.assertEqual(st.vgpr[0][7], 0x33333333)
def test_flat_load_b128(self):
"""FLAT_LOAD_B128 loads 128-bit value correctly."""
TEST_OFFSET = 2000
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
s_mov_b32(s[0], 0x11111111),
v_mov_b32_e32(v[2], s[0]),
s_mov_b32(s[0], 0x22222222),
v_mov_b32_e32(v[3], s[0]),
s_mov_b32(s[0], 0x33333333),
v_mov_b32_e32(v[4], s[0]),
s_mov_b32(s[0], 0x44444444),
v_mov_b32_e32(v[5], s[0]),
global_store_b128(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
FLAT(FLATOp.FLAT_LOAD_B128, addr=v[0], vdst=v[6], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[1], 0),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][6], 0x11111111)
self.assertEqual(st.vgpr[0][7], 0x22222222)
self.assertEqual(st.vgpr[0][8], 0x33333333)
self.assertEqual(st.vgpr[0][9], 0x44444444)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,364 @@
"""Tests for GLOBAL instructions - global memory operations.
Includes: global_load_*, global_store_*, global_atomic_*, global_load_d16_*
"""
import unittest
from extra.assembly.amd.test.hw.helpers import *
class TestGlobalAtomic(unittest.TestCase):
"""Tests for GLOBAL atomic instructions."""
def _make_test(self, setup_instrs, atomic_instr, check_fn, test_offset=2000):
"""Helper to create atomic test instructions."""
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
] + setup_instrs + [atomic_instr, s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[1], 0),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
check_fn(st)
def test_global_atomic_add_u32(self):
"""GLOBAL_ATOMIC_ADD_U32 adds to memory and returns old value."""
TEST_OFFSET = 2000
setup = [
s_mov_b32(s[0], 100),
v_mov_b32_e32(v[2], s[0]),
global_store_b32(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[0], 50),
v_mov_b32_e32(v[3], s[0]),
]
atomic = FLAT(GLOBALOp.GLOBAL_ATOMIC_ADD_U32, addr=v[0], data=v[3], vdst=v[4], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1, seg=2)
def check(st):
self.assertEqual(st.vgpr[0][4], 100)
self._make_test(setup, atomic, check, TEST_OFFSET)
def test_global_atomic_add_u64(self):
"""GLOBAL_ATOMIC_ADD_U64 adds 64-bit value and returns old value."""
TEST_OFFSET = 2000
setup = [
s_mov_b32(s[0], 0xFFFFFFFF),
v_mov_b32_e32(v[2], s[0]),
s_mov_b32(s[0], 0x00000000),
v_mov_b32_e32(v[3], s[0]),
global_store_b64(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[0], 0x00000001),
v_mov_b32_e32(v[4], s[0]),
s_mov_b32(s[0], 0x00000000),
v_mov_b32_e32(v[5], s[0]),
]
atomic = FLAT(GLOBALOp.GLOBAL_ATOMIC_ADD_U64, addr=v[0], data=v[4], vdst=v[6], saddr=SrcEnum.NULL, offset=TEST_OFFSET, glc=1, seg=2)
def check(st):
self.assertEqual(st.vgpr[0][6], 0xFFFFFFFF)
self.assertEqual(st.vgpr[0][7], 0x00000000)
self._make_test(setup, atomic, check, TEST_OFFSET)
class TestGlobalLoad(unittest.TestCase):
"""Tests for GLOBAL load instructions."""
def test_global_load_b96(self):
"""GLOBAL_LOAD_B96 loads 96-bit value correctly."""
TEST_OFFSET = 2000
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
s_mov_b32(s[0], 0xAAAAAAAA),
v_mov_b32_e32(v[2], s[0]),
s_mov_b32(s[0], 0xBBBBBBBB),
v_mov_b32_e32(v[3], s[0]),
s_mov_b32(s[0], 0xCCCCCCCC),
v_mov_b32_e32(v[4], s[0]),
global_store_b96(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
FLAT(GLOBALOp.GLOBAL_LOAD_B96, addr=v[0], vdst=v[5], saddr=SrcEnum.NULL, offset=TEST_OFFSET, seg=2),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[1], 0),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][5], 0xAAAAAAAA)
self.assertEqual(st.vgpr[0][6], 0xBBBBBBBB)
self.assertEqual(st.vgpr[0][7], 0xCCCCCCCC)
def test_global_load_b128(self):
"""GLOBAL_LOAD_B128 loads 128-bit value correctly."""
TEST_OFFSET = 2000
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
s_mov_b32(s[0], 0xDEADBEEF),
v_mov_b32_e32(v[2], s[0]),
s_mov_b32(s[0], 0xCAFEBABE),
v_mov_b32_e32(v[3], s[0]),
s_mov_b32(s[0], 0x12345678),
v_mov_b32_e32(v[4], s[0]),
s_mov_b32(s[0], 0x9ABCDEF0),
v_mov_b32_e32(v[5], s[0]),
global_store_b128(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
FLAT(GLOBALOp.GLOBAL_LOAD_B128, addr=v[0], vdst=v[6], saddr=SrcEnum.NULL, offset=TEST_OFFSET, seg=2),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[1], 0),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][6], 0xDEADBEEF)
self.assertEqual(st.vgpr[0][7], 0xCAFEBABE)
self.assertEqual(st.vgpr[0][8], 0x12345678)
self.assertEqual(st.vgpr[0][9], 0x9ABCDEF0)
class TestGlobalStore(unittest.TestCase):
"""Tests for GLOBAL store instructions."""
def test_global_store_b64_basic(self):
"""GLOBAL_STORE_B64 stores 8 bytes from v[n:n+1] to memory."""
TEST_OFFSET = 256
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[4], 0xDEADBEEF),
s_mov_b32(s[5], 0xCAFEBABE),
v_mov_b32_e32(v[2], s[4]),
v_mov_b32_e32(v[3], s[5]),
v_mov_b32_e32(v[0], 0),
global_store_b64(addr=v[0], data=v[2], saddr=s[2], offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
FLAT(GLOBALOp.GLOBAL_LOAD_B64, addr=v[0], vdst=v[4], data=v[4], saddr=s[2], offset=TEST_OFFSET, seg=2),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], v[4]),
v_mov_b32_e32(v[1], v[5]),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 0xDEADBEEF)
self.assertEqual(st.vgpr[0][1], 0xCAFEBABE)
class TestD16HiLoads(unittest.TestCase):
"""Tests for D16_HI load instructions that load into high 16 bits."""
def test_global_load_d16_hi_b16_preserves_low_bits(self):
"""GLOBAL_LOAD_D16_HI_B16 must preserve low 16 bits of destination."""
TEST_OFFSET = 256
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
v_mov_b32_e32(v[0], s[2]),
v_mov_b32_e32(v[1], s[3]),
s_mov_b32(s[4], 0xCAFE),
v_mov_b32_e32(v[2], s[4]),
global_store_b16(addr=v[0], data=v[2], saddr=SrcEnum.NULL, offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[4], 0x0000BEEF),
v_mov_b32_e32(v[3], s[4]),
FLAT(GLOBALOp.GLOBAL_LOAD_D16_HI_B16, addr=v[0], vdst=v[3], data=v[3], saddr=SrcEnum.NULL, offset=TEST_OFFSET, seg=2),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], v[3]),
v_mov_b32_e32(v[1], 0),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][0]
self.assertEqual(result, 0xCAFEBEEF, f"Expected 0xCAFEBEEF, got 0x{result:08x}")
def test_global_load_d16_hi_b16_data_differs_from_vdst(self):
"""GLOBAL_LOAD_D16_HI_B16 where data field differs from vdst."""
TEST_OFFSET = 256
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[4], 0xCAFE),
v_mov_b32_e32(v[2], s[4]),
v_mov_b32_e32(v[3], 0),
global_store_b16(addr=v[3], data=v[2], saddr=s[2], offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[4], 0x0000DEAD),
v_mov_b32_e32(v[0], s[4]), # data field - should NOT affect result
v_mov_b32_e32(v[1], 0), # vdst - low bits should be preserved
FLAT(GLOBALOp.GLOBAL_LOAD_D16_HI_B16, addr=v[1], vdst=v[1], data=v[0], saddr=s[2], offset=TEST_OFFSET, seg=2),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], v[1]),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][0]
self.assertEqual(result, 0xCAFE0000, f"Expected 0xCAFE0000, got 0x{result:08x}")
def test_global_load_d16_hi_u8_data_differs_from_vdst(self):
"""GLOBAL_LOAD_D16_HI_U8 where data field differs from vdst."""
TEST_OFFSET = 256
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[4], 0xAB),
v_mov_b32_e32(v[2], s[4]),
v_mov_b32_e32(v[3], 0),
global_store_b8(addr=v[3], data=v[2], saddr=s[2], offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[4], 0x0000DEAD),
v_mov_b32_e32(v[4], s[4]), # data field
s_mov_b32(s[4], 0x0000BEEF),
v_mov_b32_e32(v[5], s[4]), # vdst
v_mov_b32_e32(v[3], 0),
FLAT(GLOBALOp.GLOBAL_LOAD_D16_HI_U8, addr=v[3], vdst=v[5], data=v[4], saddr=s[2], offset=TEST_OFFSET, seg=2),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], v[5]),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][0]
self.assertEqual(result, 0x00ABBEEF, f"Expected 0x00ABBEEF, got 0x{result:08x}")
def test_global_load_d16_hi_b16_same_addr_and_dst_zero_addr(self):
"""GLOBAL_LOAD_D16_HI_B16 with same register for addr and vdst, addr value=0."""
TEST_OFFSET = 256
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[4], 0xCAFE),
v_mov_b32_e32(v[2], s[4]),
v_mov_b32_e32(v[3], 0),
global_store_b16(addr=v[3], data=v[2], saddr=s[2], offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[1], 0),
FLAT(GLOBALOp.GLOBAL_LOAD_D16_HI_B16, addr=v[1], vdst=v[1], data=v[1], saddr=s[2], offset=TEST_OFFSET, seg=2),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], v[1]),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][0]
self.assertEqual(result, 0xCAFE0000, f"Expected 0xCAFE0000, got 0x{result:08x}")
def test_global_load_d16_hi_b16_tril_exact_pattern(self):
"""Exact pattern from tril() failure: data=v0 differs from vdst=v1."""
TEST_OFFSET = 256
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[4], 0x01010101),
v_mov_b32_e32(v[10], s[4]),
v_mov_b32_e32(v[3], 0),
global_store_b32(addr=v[3], data=v[10], saddr=s[2], offset=TEST_OFFSET),
global_store_b32(addr=v[3], data=v[10], saddr=s[2], offset=TEST_OFFSET+4),
s_waitcnt(vmcnt=0),
# Set v[0] to 0x0101 (simulating prior u16 load result)
s_mov_b32(s[4], 0x0101),
v_mov_b32_e32(v[0], s[4]),
# Set v[1] to 0
v_mov_b32_e32(v[1], 0),
# Load using v[1] as addr AND vdst, but v[0] as data
FLAT(GLOBALOp.GLOBAL_LOAD_D16_HI_B16, addr=v[1], vdst=v[1], data=v[0], saddr=s[2], offset=TEST_OFFSET+6, seg=2),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], v[1]),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][0]
# Expected: hi=0x0101 (loaded), lo=0x0000 (from v1) -> 0x01010000
self.assertEqual(result, 0x01010000, f"Expected 0x01010000, got 0x{result:08x}")
def test_global_load_d16_hi_i8_data_differs_from_vdst(self):
"""GLOBAL_LOAD_D16_HI_I8 where data field differs from vdst."""
TEST_OFFSET = 256
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[4], 0x80), # negative signed byte = -128
v_mov_b32_e32(v[2], s[4]),
v_mov_b32_e32(v[3], 0),
global_store_b8(addr=v[3], data=v[2], saddr=s[2], offset=TEST_OFFSET),
s_waitcnt(vmcnt=0),
s_mov_b32(s[4], 0x0000DEAD),
v_mov_b32_e32(v[4], s[4]), # data field
s_mov_b32(s[4], 0x0000BEEF),
v_mov_b32_e32(v[5], s[4]), # vdst
v_mov_b32_e32(v[3], 0),
FLAT(GLOBALOp.GLOBAL_LOAD_D16_HI_I8, addr=v[3], vdst=v[5], data=v[4], saddr=s[2], offset=TEST_OFFSET, seg=2),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], v[5]),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][0]
# 0x80 sign-extended = 0xFF80, lo=0xBEEF -> 0xFF80BEEF
self.assertEqual(result, 0xFF80BEEF, f"Expected 0xFF80BEEF, got 0x{result:08x}")
def test_global_store_b64_tril_pattern(self):
"""Test the exact pattern from tril() kernel that was failing."""
TEST_OFFSET = 256
instructions = [
s_load_b64(s[2:3], s[80], 0, soffset=SrcEnum.NULL),
s_waitcnt(lgkmcnt=0),
s_mov_b32(s[4], 0x01010101),
v_mov_b32_e32(v[10], s[4]),
v_mov_b32_e32(v[11], s[4]),
s_mov_b32(s[4], 0x01),
v_mov_b32_e32(v[12], s[4]),
v_mov_b32_e32(v[0], 0),
global_store_b64(addr=v[0], data=v[10], saddr=s[2], offset=TEST_OFFSET),
global_store_b8(addr=v[0], data=v[12], saddr=s[2], offset=TEST_OFFSET+8),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[2], 0),
v_mov_b32_e32(v[1], 0),
FLAT(GLOBALOp.GLOBAL_LOAD_U16, addr=v[2], vdst=v[0], data=v[0], saddr=s[2], offset=TEST_OFFSET+3, seg=2),
FLAT(GLOBALOp.GLOBAL_LOAD_D16_HI_B16, addr=v[1], vdst=v[1], data=v[1], saddr=s[2], offset=TEST_OFFSET+6, seg=2),
FLAT(GLOBALOp.GLOBAL_LOAD_U8, addr=v[2], vdst=v[3], data=v[3], saddr=s[2], offset=TEST_OFFSET, seg=2),
FLAT(GLOBALOp.GLOBAL_LOAD_U8, addr=v[2], vdst=v[4], data=v[4], saddr=s[2], offset=TEST_OFFSET+8, seg=2),
s_waitcnt(vmcnt=0),
v_and_b32_e32(v[5], 0xffff, v[0]),
v_lshlrev_b32_e32(v[0], 24, v[0]),
v_lshrrev_b32_e32(v[5], 8, v[5]),
v_or_b32_e32(v[0], v[3], v[0]),
v_or_b32_e32(v[1], v[5], v[1]),
global_store_b64(addr=v[2], data=v[0], saddr=s[2], offset=TEST_OFFSET+16),
s_waitcnt(vmcnt=0),
FLAT(GLOBALOp.GLOBAL_LOAD_B64, addr=v[2], vdst=v[6], data=v[6], saddr=s[2], offset=TEST_OFFSET+16, seg=2),
s_waitcnt(vmcnt=0),
v_mov_b32_e32(v[0], v[6]),
v_mov_b32_e32(v[1], v[7]),
s_mov_b32(s[2], 0),
s_mov_b32(s[3], 0),
]
st = run_program(instructions, n_lanes=1)
v0 = st.vgpr[0][0]
v1 = st.vgpr[0][1]
self.assertEqual(v0, 0x01000001, f"v0: expected 0x01000001, got 0x{v0:08x}")
self.assertEqual(v1, 0x01010001, f"v1: expected 0x01010001, got 0x{v1:08x}")
byte5 = (v1 >> 8) & 0xff
self.assertEqual(byte5, 0x00, f"byte5: expected 0x00, got 0x{byte5:02x}")
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,205 @@
"""Tests for SOP instructions - scalar operations.
Includes: s_add_u32, s_mov_b32, s_and_b32, s_or_b32, s_quadmask_b32, s_wqm_b32,
s_cbranch_vccnz, s_cbranch_vccz
"""
import unittest
from extra.assembly.amd.test.hw.helpers import *
class TestBasicScalar(unittest.TestCase):
"""Tests for basic scalar operations."""
def test_s_add_u32(self):
"""S_ADD_U32 adds two scalar values."""
instructions = [
s_mov_b32(s[0], 100),
s_mov_b32(s[1], 200),
s_add_u32(s[2], s[0], s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[2], 300)
def test_s_add_u32_carry(self):
"""S_ADD_U32 sets SCC on overflow."""
instructions = [
s_mov_b32(s[0], 64),
s_not_b32(s[0], s[0]), # ~64 = 0xffffffbf
s_mov_b32(s[1], 64),
s_add_u32(s[2], s[0], s[1]), # 0xffffffbf + 64 = 0xffffffff
s_mov_b32(s[3], 1),
s_add_u32(s[4], s[2], s[3]), # 0xffffffff + 1 = overflow
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[4], 0)
self.assertEqual(st.scc, 1)
class TestQuadmaskWqm(unittest.TestCase):
"""Tests for S_QUADMASK_B32 and S_WQM_B32."""
def test_s_quadmask_b32_all_quads_active(self):
"""S_QUADMASK_B32 with all quads active."""
instructions = [
s_mov_b32(s[0], 0xFFFFFFFF), # All lanes active
s_quadmask_b32(s[1], s[0]),
]
st = run_program(instructions, n_lanes=1)
# Each quad (4 lanes) with any bit set -> 1 bit in result
# 32 lanes = 8 quads, all active -> 0xFF
self.assertEqual(st.sgpr[1], 0xFF)
def test_s_quadmask_b32_alternating_quads(self):
"""S_QUADMASK_B32 with alternating quads active."""
instructions = [
s_mov_b32(s[0], 0x0F0F0F0F), # Quads 0,2,4,6 active
s_quadmask_b32(s[1], s[0]),
]
st = run_program(instructions, n_lanes=1)
# Quads 0,2,4,6 have at least one bit -> 0b01010101 = 0x55
self.assertEqual(st.sgpr[1], 0x55)
def test_s_quadmask_b32_no_quads_active(self):
"""S_QUADMASK_B32 with no quads active."""
instructions = [
s_mov_b32(s[0], 0),
s_quadmask_b32(s[1], s[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 0)
def test_s_quadmask_b32_single_lane_per_quad(self):
"""S_QUADMASK_B32 with single lane active in each quad."""
instructions = [
s_mov_b32(s[0], 0x11111111), # Bit 0 of each nibble
s_quadmask_b32(s[1], s[0]),
]
st = run_program(instructions, n_lanes=1)
# All 8 quads have at least one lane -> 0xFF
self.assertEqual(st.sgpr[1], 0xFF)
def test_s_wqm_b32_all_active(self):
"""S_WQM_B32 with all lanes active returns all 1s."""
instructions = [
s_mov_b32(s[0], 0xFFFFFFFF),
s_wqm_b32(s[1], s[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 0xFFFFFFFF)
def test_s_wqm_b32_alternating_quads(self):
"""S_WQM_B32 with single lane per quad expands to full quads."""
instructions = [
s_mov_b32(s[0], 0x11111111), # One lane per quad
s_wqm_b32(s[1], s[0]),
]
st = run_program(instructions, n_lanes=1)
# Each quad with any bit expands to all 4 bits
self.assertEqual(st.sgpr[1], 0xFFFFFFFF)
def test_s_wqm_b32_zero(self):
"""S_WQM_B32 with zero input returns zero."""
instructions = [
s_mov_b32(s[0], 0),
s_wqm_b32(s[1], s[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 0)
class TestBranch(unittest.TestCase):
"""Tests for branch instructions."""
def test_cbranch_vccnz_ignores_vcc_hi(self):
"""S_CBRANCH_VCCNZ should only check VCC_LO in wave32."""
instructions = [
# Set VCC_LO = 0, VCC_HI = 1
s_mov_b32(s[SrcEnum.VCC_LO - 128], 0),
s_mov_b32(s[SrcEnum.VCC_HI - 128], 1),
v_mov_b32_e32(v[0], 0),
# If VCC_HI is incorrectly used, branch will be taken
s_cbranch_vccnz(1), # Skip next instruction if VCC != 0
v_mov_b32_e32(v[0], 42), # This should execute
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 42, "Branch should NOT be taken (VCC_LO is 0)")
def test_cbranch_vccz_ignores_vcc_hi(self):
"""S_CBRANCH_VCCZ should only check VCC_LO in wave32."""
instructions = [
# Set VCC_LO = 1, VCC_HI = 0
s_mov_b32(s[SrcEnum.VCC_LO - 128], 1),
s_mov_b32(s[SrcEnum.VCC_HI - 128], 0),
v_mov_b32_e32(v[0], 0),
# If VCC_HI is incorrectly used, branch will be taken
s_cbranch_vccz(1), # Skip next instruction if VCC == 0
v_mov_b32_e32(v[0], 42), # This should execute
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 42, "Branch should NOT be taken (VCC_LO is 1)")
def test_cbranch_vccnz_branches_on_vcc_lo(self):
"""S_CBRANCH_VCCNZ branches when VCC_LO is non-zero."""
instructions = [
s_mov_b32(s[SrcEnum.VCC_LO - 128], 1),
v_mov_b32_e32(v[0], 0),
s_cbranch_vccnz(1), # Skip next instruction if VCC != 0
v_mov_b32_e32(v[0], 42), # This should be skipped
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][0], 0, "Branch should be taken (VCC_LO is 1)")
class Test64BitLiterals(unittest.TestCase):
"""Tests for 64-bit literal encoding in instructions."""
def test_64bit_literal_negative_encoding(self):
"""64-bit literal -2^32 encodes correctly."""
lit = -4294967296.0 # -2^32
lit_bits = f2i64(lit)
instructions = [
s_mov_b32(s[0], lit_bits & 0xffffffff),
s_mov_b32(s[1], lit_bits >> 32),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
]
st = run_program(instructions, n_lanes=1)
result = i642f(st.vgpr[0][0] | (st.vgpr[0][1] << 32))
self.assertAlmostEqual(result, -4294967296.0, places=5)
def test_64bit_literal_positive_encoding(self):
"""64-bit instruction encodes large positive literals correctly."""
large_val = 0x12345678
inst = v_add_f64(v[2], v[0], large_val)
self.assertIsNotNone(inst._literal, "Literal should be set")
actual_lit = (inst._literal >> 32) & 0xffffffff
self.assertEqual(actual_lit, large_val, f"Literal should be {large_val:#x}, got {actual_lit:#x}")
class TestSCCBehavior(unittest.TestCase):
"""Tests for SCC condition code behavior."""
def test_scc_from_s_cmp(self):
"""SCC should be set by scalar compare."""
instructions = [
s_mov_b32(s[0], 10),
s_cmp_eq_u32(s[0], 10),
s_cselect_b32(s[1], 1, 0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 1, "SCC should be true")
self.assertEqual(st.scc, 1)
def test_scc_clear(self):
"""SCC should be cleared by failing compare."""
instructions = [
s_mov_b32(s[0], 10),
s_cmp_eq_u32(s[0], 20),
s_cselect_b32(s[1], 1, 0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[1], 0, "SCC should be false")
self.assertEqual(st.scc, 0)
if __name__ == '__main__':
unittest.main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,451 @@
"""Tests for VOP2 instructions - two operand vector operations.
Includes: v_add_f32, v_mul_f32, v_and_b32, v_or_b32, v_xor_b32,
v_lshrrev_b32, v_lshlrev_b32, v_fmac_f32, v_fmaak_f32, v_fmamk_f32,
v_add_nc_u32, v_cndmask_b32, v_add_f16, v_mul_f16
"""
import unittest
from extra.assembly.amd.test.hw.helpers import *
class TestBasicArithmetic(unittest.TestCase):
"""Tests for basic arithmetic VOP2 instructions."""
def test_v_add_f32(self):
"""V_ADD_F32 adds two floats."""
instructions = [
v_mov_b32_e32(v[0], 1.0),
v_mov_b32_e32(v[1], 2.0),
v_add_f32_e32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 3.0, places=5)
def test_v_mul_f32(self):
"""V_MUL_F32 multiplies two floats."""
instructions = [
v_mov_b32_e32(v[0], 2.0),
v_mov_b32_e32(v[1], 4.0),
v_mul_f32_e32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 8.0, places=5)
def test_v_fmac_f32(self):
"""V_FMAC_F32: d = d + a*b using inline constants."""
instructions = [
v_mov_b32_e32(v[0], 2.0),
v_mov_b32_e32(v[1], 4.0),
v_mov_b32_e32(v[2], 1.0),
v_fmac_f32_e32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 9.0, places=5)
def test_v_fmaak_f32(self):
"""V_FMAAK_F32: d = a * b + K using inline constants."""
instructions = [
v_mov_b32_e32(v[0], 2.0),
v_mov_b32_e32(v[1], 4.0),
v_fmaak_f32_e32(v[2], v[0], v[1], 0x3f800000),
]
st = run_program(instructions, n_lanes=1)
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 9.0, places=5)
def test_v_fmamk_f32_basic(self):
"""V_FMAMK_F32: d = a * K + b."""
instructions = [
v_mov_b32_e32(v[0], 2.0),
v_mov_b32_e32(v[1], 1.0),
v_fmamk_f32_e32(v[2], v[0], 0x40800000, v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 9.0, places=5)
def test_v_fmamk_f32_small_constant(self):
"""V_FMAMK_F32 with small constant."""
instructions = [
v_mov_b32_e32(v[0], 4.0),
v_mov_b32_e32(v[1], 1.0),
v_fmamk_f32_e32(v[2], v[0], f2i(0.5), v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 3.0, places=5)
class TestBitManipulation(unittest.TestCase):
"""Tests for bit manipulation VOP2 instructions."""
def test_v_and_b32(self):
"""V_AND_B32 bitwise and."""
instructions = [
s_mov_b32(s[0], 0xff),
s_mov_b32(s[1], 0x0f),
v_mov_b32_e32(v[0], s[0]),
v_and_b32_e32(v[1], s[1], v[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][1], 0x0f)
def test_v_and_b32_quadrant(self):
"""V_AND_B32 for quadrant extraction (n & 3)."""
instructions = [
s_mov_b32(s[0], 15915),
v_mov_b32_e32(v[0], s[0]),
v_and_b32_e32(v[1], 3, v[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][1], 15915 & 3)
def test_v_lshrrev_b32(self):
"""V_LSHRREV_B32 logical shift right."""
instructions = [
s_mov_b32(s[0], 0xff00),
v_mov_b32_e32(v[0], s[0]),
v_lshrrev_b32_e32(v[1], 8, v[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][1], 0xff)
def test_v_lshlrev_b32(self):
"""V_LSHLREV_B32 logical shift left."""
instructions = [
s_mov_b32(s[0], 0xff),
v_mov_b32_e32(v[0], s[0]),
v_lshlrev_b32_e32(v[1], 8, v[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][1], 0xff00)
def test_v_xor_b32(self):
"""V_XOR_B32 bitwise xor (used in sin for sign)."""
instructions = [
s_mov_b32(s[0], 0x80000000),
s_mov_b32(s[1], f2i(1.0)),
v_mov_b32_e32(v[0], s[1]),
v_xor_b32_e32(v[1], s[0], v[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertAlmostEqual(i2f(st.vgpr[0][1]), -1.0, places=5)
def test_v_xor_b32_sign_flip(self):
"""V_XOR_B32 for sign flip pattern."""
instructions = [
s_mov_b32(s[0], 0x80000000),
v_mov_b32_e32(v[0], -2.0),
v_xor_b32_e32(v[1], s[0], v[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertAlmostEqual(i2f(st.vgpr[0][1]), 2.0, places=5)
class TestSpecialValues(unittest.TestCase):
"""Tests for special float values - inf, nan, zero handling."""
def test_v_mul_f32_zero_times_inf(self):
"""V_MUL_F32: 0 * inf = NaN."""
import math
instructions = [
v_mov_b32_e32(v[0], 0),
s_mov_b32(s[0], 0x7f800000),
v_mov_b32_e32(v[1], s[0]),
v_mul_f32_e32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertTrue(math.isnan(i2f(st.vgpr[0][2])))
def test_v_add_f32_inf_minus_inf(self):
"""V_ADD_F32: inf + (-inf) = NaN."""
import math
instructions = [
s_mov_b32(s[0], 0x7f800000),
s_mov_b32(s[1], 0xff800000),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_add_f32_e32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertTrue(math.isnan(i2f(st.vgpr[0][2])))
class TestF16Ops(unittest.TestCase):
"""Tests for 16-bit VOP2 operations."""
def test_v_add_f16_basic(self):
"""V_ADD_F16 adds two f16 values."""
instructions = [
s_mov_b32(s[0], 0x3c00), # f16 1.0
s_mov_b32(s[1], 0x4000), # f16 2.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_add_f16_e32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2] & 0xffff
self.assertEqual(result, 0x4200, f"Expected 0x4200 (f16 3.0), got 0x{result:04x}")
def test_v_add_f16_negative(self):
"""V_ADD_F16 with negative values."""
instructions = [
s_mov_b32(s[0], 0x3c00), # f16 1.0
s_mov_b32(s[1], 0xc000), # f16 -2.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_add_f16_e32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2] & 0xffff
self.assertEqual(result, 0xbc00, f"Expected 0xbc00 (f16 -1.0), got 0x{result:04x}")
def test_v_mul_f16_basic(self):
"""V_MUL_F16 multiplies two f16 values."""
instructions = [
s_mov_b32(s[0], 0x4000), # f16 2.0
s_mov_b32(s[1], 0x4200), # f16 3.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mul_f16_e32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2] & 0xffff
self.assertEqual(result, 0x4600, f"Expected 0x4600 (f16 6.0), got 0x{result:04x}")
def test_v_mul_f16_by_zero(self):
"""V_MUL_F16 by zero."""
instructions = [
s_mov_b32(s[0], 0x4000), # f16 2.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], 0),
v_mul_f16_e32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2] & 0xffff
self.assertEqual(result, 0x0000, f"Expected 0x0000 (f16 0.0), got 0x{result:04x}")
def test_v_fmac_f16_basic(self):
"""V_FMAC_F16: d = d + a*b."""
instructions = [
s_mov_b32(s[0], 0x4000), # f16 2.0
s_mov_b32(s[1], 0x4200), # f16 3.0
s_mov_b32(s[2], 0x3c00), # f16 1.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
v_fmac_f16_e32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2] & 0xffff
# 2.0 * 3.0 + 1.0 = 7.0, f16 7.0 = 0x4700
self.assertEqual(result, 0x4700, f"Expected 0x4700 (f16 7.0), got 0x{result:04x}")
def test_v_fmaak_f16_basic(self):
"""V_FMAAK_F16: d = a * b + K."""
instructions = [
s_mov_b32(s[0], 0x4000), # f16 2.0
s_mov_b32(s[1], 0x4200), # f16 3.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_fmaak_f16_e32(v[2], v[0], v[1], 0x3c00), # + f16 1.0
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2] & 0xffff
# 2.0 * 3.0 + 1.0 = 7.0, f16 7.0 = 0x4700
self.assertEqual(result, 0x4700, f"Expected 0x4700 (f16 7.0), got 0x{result:04x}")
class TestHiHalfOps(unittest.TestCase):
"""Tests for VOP2 16-bit operations with hi-half operands."""
def test_v_add_f16_src0_hi_fold(self):
"""V_ADD_F16 with src0 hi-half fold (same register, different halves)."""
instructions = [
s_mov_b32(s[0], 0x40003c00), # lo=f16(1.0), hi=f16(2.0)
v_mov_b32_e32(v[0], s[0]),
VOP3(VOP3Op.V_ADD_F16, vdst=v[1], src0=v[0], src1=v[0], opsel=0b0001),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][1] & 0xffff
self.assertEqual(result, 0x4200, f"Expected f16(3.0)=0x4200, got 0x{result:04x}")
def test_v_add_f16_src0_hi_different_reg(self):
"""V_ADD_F16 with src0 hi-half from different register."""
instructions = [
s_mov_b32(s[0], 0x40000000), # hi=f16(2.0), lo=0
s_mov_b32(s[1], 0x00003c00), # hi=0, lo=f16(1.0)
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
VOP3(VOP3Op.V_ADD_F16, vdst=v[2], src0=v[0], src1=v[1], opsel=0b0001),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2] & 0xffff
self.assertEqual(result, 0x4200, f"Expected f16(3.0)=0x4200, got 0x{result:04x}")
def test_v_mul_f16_src0_hi(self):
"""V_MUL_F16 with src0 from high half."""
instructions = [
s_mov_b32(s[0], 0x40000000), # hi=f16(2.0), lo=0
s_mov_b32(s[1], 0x00004200), # hi=0, lo=f16(3.0)
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
VOP3(VOP3Op.V_MUL_F16, vdst=v[2], src0=v[0], src1=v[1], opsel=0b0001),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2] & 0xffff
self.assertEqual(result, 0x4600, f"Expected f16(6.0)=0x4600, got 0x{result:04x}")
def test_v_mul_f16_hi_half(self):
"""V_MUL_F16 reading from high half."""
instructions = [
s_mov_b32(s[0], 0x40003c00), # lo=1.0, hi=2.0
v_mov_b32_e32(v[0], s[0]),
VOP3(VOP3Op.V_MUL_F16, vdst=v[1], src0=v[0], src1=v[0], opsel=0b0011),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][1] & 0xffff
self.assertEqual(result, 0x4400, f"Expected f16(4.0)=0x4400, got 0x{result:04x}")
def test_v_fma_f16_hi_dest(self):
"""V_FMA_F16 writing to high half with opsel.
Uses V_FMA_F16 (not V_FMAC_F16) because it has explicit src2 operand
which makes opsel handling clearer.
"""
instructions = [
s_mov_b32(s[0], 0x3c000000), # hi=f16(1.0), lo=0
s_mov_b32(s[1], 0x4000), # f16(2.0) in lo
s_mov_b32(s[2], 0x4200), # f16(3.0) in lo
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
# V_FMA_F16: dst = src0 * src1 + src2
# opsel=0b1100: bit2=src2 hi, bit3=dst hi
# So: v[0].hi = v[1].lo * v[2].lo + v[0].hi = 2.0 * 3.0 + 1.0 = 7.0
VOP3(VOP3Op.V_FMA_F16, vdst=v[0], src0=v[1], src1=v[2], src2=v[0], opsel=0b1100),
]
st = run_program(instructions, n_lanes=1)
hi = (st.vgpr[0][0] >> 16) & 0xffff
# 2.0 * 3.0 + 1.0 = 7.0, f16 7.0 = 0x4700
self.assertEqual(hi, 0x4700, f"Expected f16(7.0)=0x4700 in hi, got 0x{hi:04x}")
def test_v_add_f16_multilane(self):
"""V_ADD_F16 with multiple lanes."""
instructions = [
s_mov_b32(s[0], 0x3c00), # f16 1.0
s_mov_b32(s[1], 0x4000), # f16 2.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_add_f16_e32(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=4)
for lane in range(4):
result = st.vgpr[lane][2] & 0xffff
self.assertEqual(result, 0x4200, f"Lane {lane}: expected 0x4200, got 0x{result:04x}")
class TestCndmask(unittest.TestCase):
"""Tests for V_CNDMASK_B32 and V_CNDMASK_B16."""
def test_v_cndmask_b16_select_src0(self):
"""V_CNDMASK_B16 selects src0 when VCC bit is 0."""
instructions = [
s_mov_b32(s[SrcEnum.VCC_LO - 128], 0), # VCC = 0
s_mov_b32(s[0], 0x3c00), # f16 1.0
s_mov_b32(s[1], 0x4000), # f16 2.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_cndmask_b16(v[2], v[0], v[1], VCC),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2] & 0xffff
self.assertEqual(result, 0x3c00, f"Expected src0=0x3c00, got 0x{result:04x}")
def test_v_cndmask_b16_select_src1(self):
"""V_CNDMASK_B16 selects src1 when VCC bit is 1."""
instructions = [
s_mov_b32(s[SrcEnum.VCC_LO - 128], 1), # VCC = 1
s_mov_b32(s[0], 0x3c00), # f16 1.0
s_mov_b32(s[1], 0x4000), # f16 2.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_cndmask_b16(v[2], v[0], v[1], VCC),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2] & 0xffff
self.assertEqual(result, 0x4000, f"Expected src1=0x4000, got 0x{result:04x}")
def test_v_cndmask_b16_write_hi(self):
"""V_CNDMASK_B16 can write to high 16 bits with opsel."""
instructions = [
s_mov_b32(s[0], 0x3c003800), # src0: hi=1.0, lo=0.5
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[1], 0x4000c000), # src1: hi=2.0, lo=-2.0
v_mov_b32_e32(v[1], s[1]),
s_mov_b32(s[2], 0xDEAD0000), # v2 initial: hi=0xDEAD, lo=0
v_mov_b32_e32(v[2], s[2]),
s_mov_b32(s[SrcEnum.VCC_LO - 128], 0), # vcc = 0, select src0
# opsel=0b1011: bit0=src0 hi, bit1=src1 hi, bit3=dst hi
VOP3(VOP3Op.V_CNDMASK_B16, vdst=v[2], src0=v[0], src1=v[1], src2=SrcEnum.VCC_LO, opsel=0b1011),
]
st = run_program(instructions, n_lanes=1)
hi = (st.vgpr[0][2] >> 16) & 0xffff
lo = st.vgpr[0][2] & 0xffff
# vcc=0 selects src0.h = 1.0 = 0x3c00, writes to hi
self.assertEqual(hi, 0x3c00, f"Expected hi=0x3c00 (1.0), got 0x{hi:04x}")
self.assertEqual(lo, 0x0000, f"Expected lo preserved as 0, got 0x{lo:04x}")
class TestSpecialFloatValues(unittest.TestCase):
"""Tests for special float value handling in VOP2 instructions."""
def test_neg_zero_add(self):
"""-0.0 + 0.0 = +0.0 (IEEE 754)."""
neg_zero = 0x80000000
instructions = [
s_mov_b32(s[0], neg_zero),
v_mov_b32_e32(v[0], s[0]),
v_add_f32_e32(v[1], 0.0, v[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][1], 0x00000000, "Should be +0.0")
def test_neg_zero_mul(self):
"""-0.0 * -1.0 = +0.0."""
neg_zero = 0x80000000
instructions = [
s_mov_b32(s[0], neg_zero),
v_mov_b32_e32(v[0], s[0]),
v_mul_f32_e32(v[1], -1.0, v[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][1], 0x00000000, "Should be +0.0")
def test_inf_minus_inf(self):
"""+inf - inf = NaN."""
import math
pos_inf = 0x7f800000
neg_inf = 0xff800000
instructions = [
s_mov_b32(s[0], pos_inf),
s_mov_b32(s[1], neg_inf),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_sub_f32_e32(v[2], v[0], v[1]), # inf - (-inf) = inf
v_add_f32_e32(v[3], v[0], v[1]), # inf + (-inf) = NaN
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], pos_inf, "inf - (-inf) = inf")
self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "inf + (-inf) = NaN")
def test_denormal_f32_mul_ftz(self):
"""Denormal * normal - RDNA3 flushes denormals to zero (FTZ mode)."""
smallest_denorm = 0x00000001 # Smallest positive denormal
instructions = [
s_mov_b32(s[0], smallest_denorm),
v_mov_b32_e32(v[0], s[0]),
v_mul_f32_e32(v[1], 2.0, v[0]), # Denormal input gets flushed to 0
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][1], 0x00000000)
if __name__ == '__main__':
unittest.main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,538 @@
"""Tests for VOP3P instructions - packed 16-bit vector operations.
Includes: v_pk_add_f16, v_pk_mul_f16, v_pk_fma_f16, v_pack_b32_f16, v_wmma_*, v_dot2_*
"""
import unittest
from extra.assembly.amd.test.hw.helpers import *
class TestPackInstructions(unittest.TestCase):
"""Tests for pack instructions."""
def test_v_pack_b32_f16(self):
"""V_PACK_B32_F16 packs two f16 values into one 32-bit register."""
instructions = [
s_mov_b32(s[0], 0x3c00), # f16 1.0
s_mov_b32(s[1], 0x4000), # f16 2.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_pack_b32_f16(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
self.assertEqual(result, 0x40003c00, f"Expected 0x40003c00, got 0x{result:08x}")
def test_v_pack_b32_f16_opsel_hi_hi(self):
"""V_PACK_B32_F16 with opsel to read high halves."""
inst = v_pack_b32_f16(v[2], v[0], v[1])
inst._values['opsel'] = 0b0011
instructions = [
s_mov_b32(s[0], 0x40003c00), # hi=2.0, lo=1.0
s_mov_b32(s[1], 0x44004200), # hi=4.0, lo=3.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
inst,
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
self.assertEqual(result, 0x44004000, f"Expected 0x44004000, got 0x{result:08x}")
class TestPackMore(unittest.TestCase):
"""Additional pack instruction tests."""
def test_v_pack_b32_f16_basic(self):
"""V_PACK_B32_F16 packs two f16 values."""
instructions = [
s_mov_b32(s[0], 0x3c00), # f16 1.0
s_mov_b32(s[1], 0x4000), # f16 2.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_pack_b32_f16(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
self.assertEqual(result, 0x40003c00, f"Expected 0x40003c00, got 0x{result:08x}")
def test_v_pack_b32_f16_with_cvt(self):
"""V_PACK_B32_F16 after V_CVT_F16_F32 conversions."""
instructions = [
s_mov_b32(s[0], 0x3f800000),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[0]),
v_cvt_f16_f32_e32(v[2], v[0]),
v_cvt_f16_f32_e32(v[3], v[1]),
v_pack_b32_f16(v[4], v[2], v[3]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][4]
self.assertEqual(result, 0x3c003c00, f"Expected 0x3c003c00, got 0x{result:08x}")
def test_v_pack_b32_f16_packed_sources(self):
"""V_PACK_B32_F16 with packed f16 sources (reads lo halves)."""
instructions = [
s_mov_b32(s[0], 0x40003c00), # hi=2.0, lo=1.0
s_mov_b32(s[1], 0x44004200), # hi=4.0, lo=3.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_pack_b32_f16(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
# Expected: hi=v1.lo=0x4200 (3.0), lo=v0.lo=0x3c00 (1.0) -> 0x42003c00
self.assertEqual(result, 0x42003c00, f"Expected 0x42003c00, got 0x{result:08x}")
def test_v_pack_b32_f16_opsel_lo_hi(self):
"""V_PACK_B32_F16 with opsel=0b0010 to read lo from src0, hi from src1."""
inst = v_pack_b32_f16(v[2], v[0], v[1])
inst._values['opsel'] = 0b0010
instructions = [
s_mov_b32(s[0], 0x40003c00),
s_mov_b32(s[1], 0x44004200),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
inst,
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
self.assertEqual(result, 0x44003c00, f"Expected 0x44003c00, got 0x{result:08x}")
def test_v_pack_b32_f16_opsel_hi_lo(self):
"""V_PACK_B32_F16 with opsel=0b0001 to read hi from src0, lo from src1."""
inst = v_pack_b32_f16(v[2], v[0], v[1])
inst._values['opsel'] = 0b0001
instructions = [
s_mov_b32(s[0], 0x40003c00),
s_mov_b32(s[1], 0x44004200),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
inst,
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
self.assertEqual(result, 0x42004000, f"Expected 0x42004000, got 0x{result:08x}")
def test_v_pack_b32_f16_zeros(self):
"""V_PACK_B32_F16 with zero values."""
instructions = [
v_mov_b32_e32(v[0], 0),
v_mov_b32_e32(v[1], 0),
v_pack_b32_f16(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0)
def test_v_pack_b32_f16_both_positive(self):
"""V_PACK_B32_F16 with positive f16 values."""
instructions = [
s_mov_b32(s[0], 0x4200), # f16 3.0
s_mov_b32(s[1], 0x4400), # f16 4.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_pack_b32_f16(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
self.assertEqual(result, 0x44004200, f"Expected 0x44004200, got 0x{result:08x}")
class TestFmaMix(unittest.TestCase):
"""Tests for V_FMA_MIX_F32 and V_FMA_MIXLO_F16."""
def test_v_fma_mix_f32_all_f32_sources(self):
"""V_FMA_MIX_F32 with all f32 sources."""
instructions = [
s_mov_b32(s[0], f2i(2.0)),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[1], f2i(3.0)),
v_mov_b32_e32(v[1], s[1]),
s_mov_b32(s[2], f2i(1.0)),
v_mov_b32_e32(v[2], s[2]),
VOP3P(VOP3POp.V_FMA_MIX_F32, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=0),
]
st = run_program(instructions, n_lanes=1)
result = i2f(st.vgpr[0][3])
self.assertAlmostEqual(result, 7.0, places=5)
def test_v_fma_mix_f32_src2_f16_lo(self):
"""V_FMA_MIX_F32 with src2 as f16 from lo bits."""
from extra.assembly.amd.pcode import f32_to_f16
f16_2 = f32_to_f16(2.0)
instructions = [
s_mov_b32(s[0], f2i(1.0)),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[1], f2i(3.0)),
v_mov_b32_e32(v[1], s[1]),
s_mov_b32(s[2], f16_2),
v_mov_b32_e32(v[2], s[2]),
VOP3P(VOP3POp.V_FMA_MIX_F32, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=1),
]
st = run_program(instructions, n_lanes=1)
result = i2f(st.vgpr[0][3])
self.assertAlmostEqual(result, 5.0, places=5)
def test_v_fma_mix_f32_src2_f16_hi(self):
"""V_FMA_MIX_F32 with src2 as f16 from hi bits."""
from extra.assembly.amd.pcode import f32_to_f16
f16_2 = f32_to_f16(2.0)
val = (f16_2 << 16) | 0
instructions = [
s_mov_b32(s[0], f2i(1.0)),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[1], f2i(3.0)),
v_mov_b32_e32(v[1], s[1]),
s_mov_b32(s[2], val),
v_mov_b32_e32(v[2], s[2]),
VOP3P(VOP3POp.V_FMA_MIX_F32, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=4, opsel_hi=0, opsel_hi2=1),
]
st = run_program(instructions, n_lanes=1)
result = i2f(st.vgpr[0][3])
self.assertAlmostEqual(result, 5.0, places=5)
def test_v_fma_mix_f32_with_abs(self):
"""V_FMA_MIX_F32 with abs modifier on src2."""
instructions = [
s_mov_b32(s[0], f2i(2.0)),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[1], f2i(3.0)),
v_mov_b32_e32(v[1], s[1]),
s_mov_b32(s[2], f2i(-1.0)),
v_mov_b32_e32(v[2], s[2]),
VOP3P(VOP3POp.V_FMA_MIX_F32, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=0, neg_hi=4),
]
st = run_program(instructions, n_lanes=1)
result = i2f(st.vgpr[0][3])
self.assertAlmostEqual(result, 7.0, places=5)
def test_v_fma_mixlo_f16(self):
"""V_FMA_MIXLO_F16 writes to low 16 bits of destination."""
from extra.assembly.amd.pcode import _f16
instructions = [
s_mov_b32(s[0], f2i(2.0)),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[1], f2i(3.0)),
v_mov_b32_e32(v[1], s[1]),
s_mov_b32(s[2], f2i(1.0)),
v_mov_b32_e32(v[2], s[2]),
s_mov_b32(s[3], 0xdead0000),
v_mov_b32_e32(v[3], s[3]),
VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=0),
]
st = run_program(instructions, n_lanes=1)
lo = _f16(st.vgpr[0][3] & 0xffff)
hi = (st.vgpr[0][3] >> 16) & 0xffff
self.assertAlmostEqual(lo, 7.0, places=1)
self.assertEqual(hi, 0xdead, f"hi should be preserved, got 0x{hi:04x}")
def test_v_fma_mixlo_f16_all_f32_sources(self):
"""V_FMA_MIXLO_F16 with all f32 sources."""
from extra.assembly.amd.pcode import _f16
instructions = [
s_mov_b32(s[0], f2i(1.0)),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[1], f2i(2.0)),
v_mov_b32_e32(v[1], s[1]),
s_mov_b32(s[2], f2i(3.0)),
v_mov_b32_e32(v[2], s[2]),
v_mov_b32_e32(v[3], 0),
VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=0),
]
st = run_program(instructions, n_lanes=1)
lo = _f16(st.vgpr[0][3] & 0xffff)
# 1*2+3 = 5
self.assertAlmostEqual(lo, 5.0, places=1)
def test_v_fma_mixlo_f16_sin_case(self):
"""V_FMA_MIXLO_F16 case from sin kernel."""
from extra.assembly.amd.pcode import _f16
instructions = [
s_mov_b32(s[0], 0x3f800000), # f32 1.0
v_mov_b32_e32(v[3], s[0]),
s_mov_b32(s[1], 0xaf05a309), # f32 tiny negative
s_mov_b32(s[6], s[1]),
s_mov_b32(s[2], 0xc0490fdb), # f32 -π
v_mov_b32_e32(v[5], s[2]),
s_mov_b32(s[3], 0x3f800000),
v_mov_b32_e32(v[3], s[3]),
VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[3], src1=s[6], src2=v[5], opsel=0, opsel_hi=0, opsel_hi2=0),
]
st = run_program(instructions, n_lanes=1)
lo = _f16(st.vgpr[0][3] & 0xffff)
self.assertAlmostEqual(lo, -3.14159, delta=0.01)
class TestVOP3P(unittest.TestCase):
"""Tests for VOP3P packed 16-bit operations."""
def test_v_pk_add_f16_basic(self):
"""V_PK_ADD_F16 adds two packed f16 values."""
from extra.assembly.amd.pcode import _f16
instructions = [
s_mov_b32(s[0], 0x40003c00), # hi=2.0, lo=1.0
s_mov_b32(s[1], 0x44004200), # hi=4.0, lo=3.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_pk_add_f16(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo = _f16(result & 0xffff)
hi = _f16((result >> 16) & 0xffff)
self.assertAlmostEqual(lo, 4.0, places=2)
self.assertAlmostEqual(hi, 6.0, places=2)
def test_v_pk_mul_f16_basic(self):
"""V_PK_MUL_F16 multiplies two packed f16 values."""
from extra.assembly.amd.pcode import _f16
instructions = [
s_mov_b32(s[0], 0x42004000), # hi=3.0, lo=2.0
s_mov_b32(s[1], 0x45004400), # hi=5.0, lo=4.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_pk_mul_f16(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo = _f16(result & 0xffff)
hi = _f16((result >> 16) & 0xffff)
self.assertAlmostEqual(lo, 8.0, places=1)
self.assertAlmostEqual(hi, 15.0, places=1)
def test_v_pk_fma_f16_basic(self):
"""V_PK_FMA_F16: D = A * B + C for packed f16."""
from extra.assembly.amd.pcode import _f16
instructions = [
s_mov_b32(s[0], 0x42004000), # A: hi=3.0, lo=2.0
s_mov_b32(s[1], 0x45004400), # B: hi=5.0, lo=4.0
s_mov_b32(s[2], 0x3c003c00), # C: hi=1.0, lo=1.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
v_pk_fma_f16(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][3]
lo = _f16(result & 0xffff)
hi = _f16((result >> 16) & 0xffff)
self.assertAlmostEqual(lo, 9.0, places=1) # 2*4+1
self.assertAlmostEqual(hi, 16.0, places=0) # 3*5+1
def test_v_pk_add_f16_with_inline_constant(self):
"""V_PK_ADD_F16 with inline constant POS_ONE (1.0).
Inline constants for VOP3P are f16 values in the low 16 bits only.
hi half of inline constant is 0, so hi result = v0.hi + 0 = 1.0.
"""
from extra.assembly.amd.pcode import _f16
instructions = [
s_mov_b32(s[0], 0x3c003c00), # packed f16: hi=1.0, lo=1.0
v_mov_b32_e32(v[0], s[0]),
v_pk_add_f16(v[1], v[0], SrcEnum.POS_ONE), # Add inline constant 1.0
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][1]
lo = _f16(result & 0xffff)
hi = _f16((result >> 16) & 0xffff)
# lo = 1.0 + 1.0 = 2.0, hi = 1.0 + 0.0 = 1.0 (inline const hi half is 0)
self.assertAlmostEqual(lo, 2.0, places=2)
self.assertAlmostEqual(hi, 1.0, places=2)
def test_v_pk_mul_f16_with_inline_constant(self):
"""V_PK_MUL_F16 with inline constant POS_TWO (2.0).
Inline constant has value only in low 16 bits, hi is 0.
"""
from extra.assembly.amd.pcode import _f16
# v0 = packed (3.0, 4.0), multiply by POS_TWO
# lo = 3.0 * 2.0 = 6.0, hi = 4.0 * 0.0 = 0.0 (inline const hi is 0)
instructions = [
s_mov_b32(s[0], 0x44004200), # packed f16: hi=4.0, lo=3.0
v_mov_b32_e32(v[0], s[0]),
v_pk_mul_f16(v[1], v[0], SrcEnum.POS_TWO),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][1]
lo = _f16(result & 0xffff)
hi = _f16((result >> 16) & 0xffff)
self.assertAlmostEqual(lo, 6.0, places=1)
self.assertAlmostEqual(hi, 0.0, places=1)
class TestWMMA(unittest.TestCase):
"""Tests for WMMA (Wave Matrix Multiply-Accumulate) instructions."""
def test_v_wmma_f32_16x16x16_f16_all_ones(self):
"""V_WMMA_F32_16X16X16_F16 with all ones produces 16.0."""
instructions = []
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
for i in range(16, 32):
instructions.append(v_mov_b32_e32(v[i], s[0]))
for i in range(8):
instructions.append(v_mov_b32_e32(v[i], 0))
instructions.append(v_wmma_f32_16x16x16_f16(v[0], v[16], v[24], v[0]))
st = run_program(instructions, n_lanes=32)
expected = f2i(16.0)
for lane in range(32):
for reg in range(8):
result = st.vgpr[lane][reg]
self.assertEqual(result, expected, f"v[{reg}] lane {lane}: expected 16.0, got {i2f(result)}")
def test_v_wmma_f32_16x16x16_f16_with_accumulator(self):
"""V_WMMA_F32_16X16X16_F16 with non-zero accumulator."""
instructions = []
instructions.append(s_mov_b32(s[0], 0x3c003c00))
instructions.append(s_mov_b32(s[1], f2i(5.0)))
for i in range(16, 32):
instructions.append(v_mov_b32_e32(v[i], s[0]))
for i in range(8):
instructions.append(v_mov_b32_e32(v[i], s[1]))
instructions.append(v_wmma_f32_16x16x16_f16(v[0], v[16], v[24], v[0]))
st = run_program(instructions, n_lanes=32)
expected = f2i(21.0) # 16 + 5
for lane in range(32):
for reg in range(8):
result = st.vgpr[lane][reg]
self.assertEqual(result, expected, f"v[{reg}] lane {lane}: expected 21.0, got {i2f(result)}")
class TestSpecialOps(unittest.TestCase):
"""Tests for special operations (SAD, PERM, DOT2)."""
def test_v_sad_u8_basic(self):
"""V_SAD_U8 computes sum of absolute differences."""
instructions = [
s_mov_b32(s[0], 0x04030201), # bytes: 1, 2, 3, 4
s_mov_b32(s[1], 0x05040302), # bytes: 2, 3, 4, 5
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_sad_u8(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
# |1-2| + |2-3| + |3-4| + |4-5| = 1 + 1 + 1 + 1 = 4
self.assertEqual(st.vgpr[0][3], 4)
def test_v_sad_u8_identical_bytes(self):
"""V_SAD_U8 with identical inputs returns accumulator."""
instructions = [
s_mov_b32(s[0], 0x04030201),
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[1], 10),
v_mov_b32_e32(v[2], s[1]),
v_sad_u8(v[3], v[0], v[0], v[2]),
]
st = run_program(instructions, n_lanes=1)
# Same inputs -> SAD = 0, result = accumulator = 10
self.assertEqual(st.vgpr[0][3], 10)
def test_v_sad_u16_basic(self):
"""V_SAD_U16 computes sum of absolute differences of u16 pairs."""
instructions = [
s_mov_b32(s[0], 0x00030001), # hi=3, lo=1
s_mov_b32(s[1], 0x00050002), # hi=5, lo=2
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_sad_u16(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
# |1-2| + |3-5| = 1 + 2 = 3
self.assertEqual(st.vgpr[0][3], 3)
def test_v_sad_u32_basic(self):
"""V_SAD_U32 computes absolute difference of u32 values."""
instructions = [
s_mov_b32(s[0], 100),
s_mov_b32(s[1], 70),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_sad_u32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
# |100-70| = 30
self.assertEqual(st.vgpr[0][3], 30)
def test_v_msad_u8_masked(self):
"""V_MSAD_U8 masked SAD operation."""
instructions = [
s_mov_b32(s[0], 0x04030201),
s_mov_b32(s[1], 0x05040302),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_msad_u8(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
# V_MSAD_U8 skips bytes where src0 is 0
# Since no bytes are 0, result same as V_SAD_U8 = 4
self.assertEqual(st.vgpr[0][3], 4)
def test_v_perm_b32_select_bytes(self):
"""V_PERM_B32 selects bytes from two sources.
V_PERM_B32 concatenates {S1, S0} as a 64-bit value with S1 in low 32 bits.
Selector byte values 0-3 select from S1, values 4-7 select from S0.
"""
instructions = [
s_mov_b32(s[0], 0x44332211), # src0: bytes 4-7 in 64-bit view
s_mov_b32(s[1], 0x88776655), # src1: bytes 0-3 in 64-bit view
s_mov_b32(s[2], 0x07060504), # select bytes 4,5,6,7 (from src0)
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_perm_b32(v[2], v[0], v[1], s[2]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0x44332211)
def test_v_dot2_f32_bf16_basic(self):
"""V_DOT2_F32_BF16 computes dot product of bf16 pairs."""
# bf16 1.0 = 0x3f80, bf16 2.0 = 0x4000
instructions = [
s_mov_b32(s[0], 0x3f803f80), # packed bf16: 1.0, 1.0
s_mov_b32(s[1], 0x40003f80), # packed bf16: 2.0, 1.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_dot2_f32_bf16(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
# 1.0*1.0 + 1.0*2.0 + 0 = 3.0
result = i2f(st.vgpr[0][3])
self.assertAlmostEqual(result, 3.0, places=4)
class TestPackedMixedSigns(unittest.TestCase):
"""Tests for packed operations with mixed sign values."""
def test_pk_add_f16_mixed_signs(self):
"""V_PK_ADD_F16 with mixed positive/negative values."""
from extra.assembly.amd.pcode import _f16
instructions = [
s_mov_b32(s[0], 0xc0003c00), # packed: hi=-2.0, lo=1.0
s_mov_b32(s[1], 0x3c003c00), # packed: hi=1.0, lo=1.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_pk_add_f16(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
lo = _f16(result & 0xffff)
hi = _f16((result >> 16) & 0xffff)
self.assertAlmostEqual(lo, 2.0, places=2) # 1.0 + 1.0
self.assertAlmostEqual(hi, -1.0, places=2) # -2.0 + 1.0
def test_pk_mul_f16_zero(self):
"""V_PK_MUL_F16 with zero."""
from extra.assembly.amd.pcode import _f16
instructions = [
s_mov_b32(s[0], 0x40004000), # packed: 2.0, 2.0
s_mov_b32(s[1], 0x00000000), # packed: 0.0, 0.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_pk_mul_f16(v[2], v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
self.assertEqual(result, 0x00000000, "2.0 * 0.0 should be 0.0")
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,486 @@
"""Tests for VOPC instructions - vector compare operations.
Includes: v_cmp_class_f32, v_cmp_class_f16, v_cmp_eq_*, v_cmp_lt_*, v_cmp_gt_*
"""
import unittest
from extra.assembly.amd.test.hw.helpers import *
VCC = 106 # SGPR index for VCC_LO
class TestCmpClass(unittest.TestCase):
"""Tests for V_CMP_CLASS_F32 float classification."""
def test_cmp_class_quiet_nan(self):
"""V_CMP_CLASS_F32 detects quiet NaN."""
quiet_nan = 0x7fc00000
instructions = [
s_mov_b32(s[0], quiet_nan),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], 0b0000000010), # bit 1 = quiet NaN
v_cmp_class_f32_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect quiet NaN")
def test_cmp_class_signaling_nan(self):
"""V_CMP_CLASS_F32 detects signaling NaN."""
signal_nan = 0x7f800001
instructions = [
s_mov_b32(s[0], signal_nan),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], 0b0000000001), # bit 0 = signaling NaN
v_cmp_class_f32_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect signaling NaN")
def test_cmp_class_positive_inf(self):
"""V_CMP_CLASS_F32 detects +inf."""
pos_inf = 0x7f800000
instructions = [
s_mov_b32(s[0], pos_inf),
s_mov_b32(s[1], 0b1000000000), # bit 9 = +inf
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_cmp_class_f32_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect +inf")
def test_cmp_class_negative_inf(self):
"""V_CMP_CLASS_F32 detects -inf."""
neg_inf = 0xff800000
instructions = [
s_mov_b32(s[0], neg_inf),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], 0b0000000100), # bit 2 = -inf
v_cmp_class_f32_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect -inf")
def test_cmp_class_normal_positive(self):
"""V_CMP_CLASS_F32 detects positive normal."""
instructions = [
v_mov_b32_e32(v[0], 1.0),
s_mov_b32(s[1], 0b0100000000), # bit 8 = positive normal
v_mov_b32_e32(v[1], s[1]),
v_cmp_class_f32_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect positive normal")
def test_cmp_class_normal_negative(self):
"""V_CMP_CLASS_F32 detects negative normal."""
instructions = [
v_mov_b32_e32(v[0], -1.0),
v_mov_b32_e32(v[1], 0b0000001000), # bit 3 = negative normal
v_cmp_class_f32_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect negative normal")
def test_cmp_class_quiet_nan_not_signaling(self):
"""Quiet NaN does not match signaling NaN mask."""
quiet_nan = 0x7fc00000
instructions = [
s_mov_b32(s[0], quiet_nan),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], 0b0000000001), # bit 0 = signaling NaN only
v_cmp_class_f32_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 0, "Quiet NaN should not match signaling mask")
def test_cmp_class_signaling_nan_not_quiet(self):
"""Signaling NaN does not match quiet NaN mask."""
signal_nan = 0x7f800001
instructions = [
s_mov_b32(s[0], signal_nan),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], 0b0000000010), # bit 1 = quiet NaN only
v_cmp_class_f32_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 0, "Signaling NaN should not match quiet mask")
def test_v_cmp_sets_vcc_bits(self):
"""V_CMP_EQ sets VCC bits based on per-lane comparison."""
instructions = [
s_mov_b32(s[0], 5),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[0]),
v_cmp_eq_u32_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=4)
self.assertEqual(st.vcc & 0xf, 0xf, "All lanes should match")
class TestCmpClassF16(unittest.TestCase):
"""Tests for V_CMP_CLASS_F16 float classification.
Class bit mapping:
bit 0 = signaling NaN
bit 1 = quiet NaN
bit 2 = -infinity
bit 3 = -normal
bit 4 = -denormal
bit 5 = -zero
bit 6 = +zero
bit 7 = +denormal
bit 8 = +normal
bit 9 = +infinity
"""
def test_cmp_class_f16_positive_zero(self):
"""V_CMP_CLASS_F16: +zero matches bit 6."""
instructions = [
v_mov_b32_e32(v[0], 0x0000), # f16 +0.0
v_mov_b32_e32(v[1], 0x40), # bit 6 = +zero
v_cmp_class_f16_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect positive zero")
def test_cmp_class_f16_negative_zero(self):
"""V_CMP_CLASS_F16: -zero matches bit 5."""
instructions = [
s_mov_b32(s[0], 0x8000), # f16 -0.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], 0x20), # bit 5 = -zero
v_cmp_class_f16_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect negative zero")
def test_cmp_class_f16_positive_normal(self):
"""V_CMP_CLASS_F16: +1.0 (normal) matches bit 8."""
instructions = [
s_mov_b32(s[0], 0x3c00), # f16 +1.0
s_mov_b32(s[1], 0x100), # bit 8 = +normal
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_cmp_class_f16_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect positive normal")
def test_cmp_class_f16_negative_normal(self):
"""V_CMP_CLASS_F16: -1.0 (normal) matches bit 3."""
instructions = [
s_mov_b32(s[0], 0xbc00), # f16 -1.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], 0x08), # bit 3 = -normal
v_cmp_class_f16_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect negative normal")
def test_cmp_class_f16_positive_infinity(self):
"""V_CMP_CLASS_F16: +inf matches bit 9."""
instructions = [
s_mov_b32(s[0], 0x7c00), # f16 +inf
s_mov_b32(s[1], 0x200), # bit 9 = +inf
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_cmp_class_f16_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect positive infinity")
def test_cmp_class_f16_negative_infinity(self):
"""V_CMP_CLASS_F16: -inf matches bit 2."""
instructions = [
s_mov_b32(s[0], 0xfc00), # f16 -inf
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], 0x04), # bit 2 = -inf
v_cmp_class_f16_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect negative infinity")
def test_cmp_class_f16_quiet_nan(self):
"""V_CMP_CLASS_F16: quiet NaN matches bit 1."""
instructions = [
s_mov_b32(s[0], 0x7e00), # f16 quiet NaN
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], 0x02), # bit 1 = quiet NaN
v_cmp_class_f16_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect quiet NaN")
def test_cmp_class_f16_signaling_nan(self):
"""V_CMP_CLASS_F16: signaling NaN matches bit 0."""
instructions = [
s_mov_b32(s[0], 0x7c01), # f16 signaling NaN
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], 0x01), # bit 0 = signaling NaN
v_cmp_class_f16_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect signaling NaN")
def test_cmp_class_f16_positive_denormal(self):
"""V_CMP_CLASS_F16: positive denormal matches bit 7."""
instructions = [
v_mov_b32_e32(v[0], 1), # f16 +denormal (0x0001)
v_mov_b32_e32(v[1], 0x80), # bit 7 = +denormal
v_cmp_class_f16_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect positive denormal")
def test_cmp_class_f16_negative_denormal(self):
"""V_CMP_CLASS_F16: negative denormal matches bit 4."""
instructions = [
s_mov_b32(s[0], 0x8001), # f16 -denormal
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], 0x10), # bit 4 = -denormal
v_cmp_class_f16_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Should detect negative denormal")
def test_cmp_class_f16_combined_mask_zeros(self):
"""V_CMP_CLASS_F16: mask 0x60 covers both +zero and -zero."""
instructions = [
v_mov_b32_e32(v[0], 0), # f16 +0.0
v_mov_b32_e32(v[1], 0x60), # bits 5 and 6 (+-zero)
v_cmp_class_f16_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for +zero with mask 0x60")
def test_cmp_class_f16_combined_mask_1f8(self):
"""V_CMP_CLASS_F16: mask 0x1f8 covers -normal,-denorm,-zero,+zero,+denorm,+normal.
This is the exact mask used in the f16 sin kernel at PC=46.
"""
instructions = [
v_mov_b32_e32(v[0], 0), # f16 +0.0
s_mov_b32(s[0], 0x1f8),
v_mov_b32_e32(v[1], s[0]), # mask 0x1f8
v_cmp_class_f16_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for +zero with mask 0x1f8")
def test_cmp_class_f16_vop3_encoding(self):
"""V_CMP_CLASS_F16 in VOP3 encoding (v_cmp_class_f16_e64)."""
instructions = [
v_mov_b32_e32(v[0], 0), # f16 +0.0
s_mov_b32(s[0], 0x1f8), # class mask
VOP3(VOP3Op.V_CMP_CLASS_F16, vdst=RawImm(VCC), src0=v[0], src1=s[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for +zero with VOP3 encoding")
def test_cmp_class_f16_vop3_normal_positive(self):
"""V_CMP_CLASS_F16 VOP3 encoding with +1.0 (normal)."""
instructions = [
s_mov_b32(s[0], 0x3c00), # f16 +1.0
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[1], 0x1f8), # class mask
VOP3(VOP3Op.V_CMP_CLASS_F16, vdst=RawImm(VCC), src0=v[0], src1=s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for +1.0 (normal) with mask 0x1f8")
def test_cmp_class_f16_vop3_nan_fails_mask(self):
"""V_CMP_CLASS_F16 VOP3: NaN should NOT match mask 0x1f8 (no NaN bits set)."""
instructions = [
s_mov_b32(s[0], 0x7e00), # f16 quiet NaN
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[1], 0x1f8), # class mask
VOP3(VOP3Op.V_CMP_CLASS_F16, vdst=RawImm(VCC), src0=v[0], src1=s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 0, "VCC should be 0 for NaN with mask 0x1f8 (no NaN bits)")
def test_cmp_class_f16_vop3_inf_fails_mask(self):
"""V_CMP_CLASS_F16 VOP3: +inf should NOT match mask 0x1f8 (no inf bits set)."""
instructions = [
s_mov_b32(s[0], 0x7c00), # f16 +inf
v_mov_b32_e32(v[0], s[0]),
s_mov_b32(s[1], 0x1f8), # class mask
VOP3(VOP3Op.V_CMP_CLASS_F16, vdst=RawImm(VCC), src0=v[0], src1=s[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 0, "VCC should be 0 for +inf with mask 0x1f8 (no inf bits)")
class TestCmpInt(unittest.TestCase):
"""Tests for integer comparison operations."""
def test_v_cmp_eq_u32(self):
"""V_CMP_EQ_U32 sets VCC bits based on per-lane comparison."""
instructions = [
s_mov_b32(s[0], 5),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[0]),
v_cmp_eq_u32_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=4)
self.assertEqual(st.vcc & 0xf, 0xf, "All lanes should match")
def test_cmp_eq_u16_opsel_lo_lo(self):
"""V_CMP_EQ_U16 comparing lo halves."""
instructions = [
s_mov_b32(s[0], 0x12340005), # lo=5, hi=0x1234
s_mov_b32(s[1], 0xABCD0005), # lo=5, hi=0xABCD
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_cmp_eq_u16_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Lo halves should be equal")
def test_cmp_eq_u16_opsel_hi_hi(self):
"""V_CMP_EQ_U16 comparing hi halves with VOP3 opsel.
VOPC doesn't have opsel, so we use VOP3 form for hi-half comparisons.
VOP3 compares write result to SGPR via vdst field.
"""
instructions = [
s_mov_b32(s[2], 0x00051234), # hi=5, lo=0x1234
v_mov_b32_e32(v[0], s[2]),
s_mov_b32(s[2], 0x0005ABCD), # hi=5, lo=0xABCD
v_mov_b32_e32(v[1], s[2]),
# opsel=3 means compare hi halves, vdst=v[0] actually writes to s[0]
VOP3(VOP3Op.V_CMP_EQ_U16, vdst=v[0], src0=v[0], src1=v[1], opsel=3),
]
st = run_program(instructions, n_lanes=1)
# Result is in sgpr[0], not vcc
self.assertEqual(st.sgpr[0] & 1, 1, "Hi halves should be equal: 5==5")
def test_cmp_eq_u16_opsel_hi_hi_equal(self):
"""V_CMP_EQ_U16 VOP3 with opsel=3 compares hi halves (equal case)."""
instructions = [
s_mov_b32(s[2], 0x12340005), # lo=5, hi=0x1234
v_mov_b32_e32(v[0], s[2]),
s_mov_b32(s[2], 0x12340009), # lo=9, hi=0x1234
v_mov_b32_e32(v[1], s[2]),
VOP3(VOP3Op.V_CMP_EQ_U16, vdst=v[0], src0=v[0], src1=v[1], opsel=3),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[0] & 1, 1, "hi==hi should be true: 0x1234==0x1234")
def test_cmp_gt_u16_opsel_hi(self):
"""V_CMP_GT_U16 VOP3 with opsel=3 compares hi halves."""
instructions = [
s_mov_b32(s[2], 0x99990005), # lo=5, hi=0x9999
v_mov_b32_e32(v[0], s[2]),
s_mov_b32(s[2], 0x12340005), # lo=5, hi=0x1234
v_mov_b32_e32(v[1], s[2]),
VOP3(VOP3Op.V_CMP_GT_U16, vdst=v[0], src0=v[0], src1=v[1], opsel=3),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.sgpr[0] & 1, 1, "hi>hi should be true: 0x9999>0x1234")
class TestCmpFloat(unittest.TestCase):
"""Tests for float comparison operations."""
def test_v_cmp_lt_f16_vsrc1_hi(self):
"""V_CMP_LT_F16 with both operands from high half using VOP3 opsel."""
instructions = [
s_mov_b32(s[2], 0x3c000000), # hi=1.0 (f16), lo=0
v_mov_b32_e32(v[0], s[2]),
s_mov_b32(s[2], 0x40000000), # hi=2.0 (f16), lo=0
v_mov_b32_e32(v[1], s[2]),
# opsel=3 means read hi halves for both src0 and src1
VOP3(VOP3Op.V_CMP_LT_F16, vdst=v[0], src0=v[0], src1=v[1], opsel=3),
]
st = run_program(instructions, n_lanes=1)
# Result is in sgpr[0]
self.assertEqual(st.sgpr[0] & 1, 1, "1.0 < 2.0 should be true")
def test_v_cmp_gt_f16_vsrc1_hi(self):
"""V_CMP_GT_F16 with both operands from high half using VOP3 opsel."""
instructions = [
s_mov_b32(s[2], 0x40000000), # hi=2.0 (f16), lo=0
v_mov_b32_e32(v[0], s[2]),
s_mov_b32(s[2], 0x3c000000), # hi=1.0 (f16), lo=0
v_mov_b32_e32(v[1], s[2]),
# opsel=3 means read hi halves for both src0 and src1
VOP3(VOP3Op.V_CMP_GT_F16, vdst=v[0], src0=v[0], src1=v[1], opsel=3),
]
st = run_program(instructions, n_lanes=1)
# Result is in sgpr[0]
self.assertEqual(st.sgpr[0] & 1, 1, "2.0 > 1.0 should be true")
def test_v_cmp_eq_f16_vsrc1_hi_equal(self):
"""v_cmp_eq_f16 with equal low and high halves."""
instructions = [
s_mov_b32(s[0], 0x42004200), # hi=3.0 (0x4200), lo=3.0 (0x4200)
v_mov_b32_e32(v[0], s[0]),
v_cmp_eq_f16_e32(v[0], v[0].h),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Expected vcc=1 (3.0 == 3.0)")
def test_v_cmp_neq_f16_vsrc1_hi(self):
"""v_cmp_neq_f16 with different low and high halves."""
instructions = [
s_mov_b32(s[0], 0x40003c00), # hi=2.0 (0x4000), lo=1.0 (0x3c00)
v_mov_b32_e32(v[0], s[0]),
v_cmp_lg_f16_e32(v[0], v[0].h),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "Expected vcc=1 (1.0 != 2.0)")
def test_v_cmp_nge_f16_inf_self(self):
"""v_cmp_nge_f16 comparing -inf with itself (unordered less than).
Regression test: -inf < -inf should be false (IEEE 754).
"""
instructions = [
s_mov_b32(s[0], 0xFC00FC00), # both halves = -inf (0xFC00)
v_mov_b32_e32(v[0], s[0]),
v_cmp_nge_f16_e32(v[0], v[0].h),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 0, "Expected vcc=0 (-inf >= -inf)")
def test_v_cmp_f16_multilane(self):
"""v_cmp_lt_f16 with vsrc1=v128 across multiple lanes."""
instructions = [
# Lane 0: v0 = 0x40003c00 (hi=2.0, lo=1.0) -> 1.0 < 2.0 = true
# Lane 1: v0 = 0x3c004000 (hi=1.0, lo=2.0) -> 2.0 < 1.0 = false
v_mov_b32_e32(v[0], 0x40003c00), # default
v_cmp_eq_u32_e32(1, v[255]), # vcc = (lane == 1)
v_cndmask_b32_e64(v[0], v[0], 0x3c004000, SrcEnum.VCC_LO),
v_cmp_lt_f16_e32(v[0], v[0].h),
]
st = run_program(instructions, n_lanes=2)
self.assertEqual(st.vcc & 1, 1, "Lane 0: expected vcc=1 (1.0 < 2.0)")
self.assertEqual((st.vcc >> 1) & 1, 0, "Lane 1: expected vcc=0 (2.0 < 1.0)")
class TestVCCBehavior(unittest.TestCase):
"""Tests for VCC condition code behavior."""
def test_vcc_all_lanes_true(self):
"""VCC should have all bits set when all lanes compare true."""
instructions = [
v_mov_b32_e32(v[0], 5),
v_mov_b32_e32(v[1], 5),
v_cmp_eq_u32_e32(v[0], v[1]),
]
st = run_program(instructions, n_lanes=32)
self.assertEqual(st.vcc, 0xFFFFFFFF, "All 32 lanes should be true")
def test_vcc_lane_dependent(self):
"""VCC should differ per lane based on lane_id comparison."""
instructions = [
v_mov_b32_e32(v[0], 16),
v_cmp_lt_u32_e32(v[255], v[0]), # lanes 0-15 are < 16
]
st = run_program(instructions, n_lanes=32)
self.assertEqual(st.vcc & 0xFFFF, 0xFFFF, "Lanes 0-15 should be true")
self.assertEqual(st.vcc >> 16, 0x0000, "Lanes 16-31 should be false")
if __name__ == '__main__':
unittest.main()

File diff suppressed because it is too large Load Diff