assembly/amd: add new instruction support to pcode (#13885)

* assembly/amd: add new instruction support

* more

* regen all
This commit is contained in:
George Hotz
2025-12-29 17:30:17 -05:00
committed by GitHub
parent 0d326f5b9b
commit 7322d9ec4a
7 changed files with 1440 additions and 13 deletions

View File

@@ -82,6 +82,51 @@ def _SOP1Op_S_NOT_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
result['d0_64'] = True
return result
def _SOP1Op_S_WQM_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # S_WQM_B32 ("whole quad mode"): bit i of D0 is set iff any bit in the 4-bit
  # quad containing bit i of S0 (bits i&~3 .. (i&~3)+3) is set. SCC = D0 != 0.
  # tmp = 0U;
  # declare i : 6'U;
  # for i in 6'0U : 6'31U do
  # tmp[i] = S0.u32[i & 6'60U +: 6'4U] != 0U
  # endfor;
  # D0.u32 = tmp;
  # SCC = D0.u32 != 0U
  S0 = Reg(s0)
  D0 = Reg(d0)
  SCC = Reg(scc)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(0)
  for i in range(0, int(31)+1):
    # Reg slicing is [hi:lo]; (i & 60) == i & ~3 rounds i down to its quad base.
    tmp[i] = S0.u32[(i & 60) + (4) - 1 : (i & 60)] != 0
  D0.u32 = tmp
  SCC = Reg(D0.u32 != 0)
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': SCC._val & 1}
  return result
def _SOP1Op_S_WQM_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # 64-bit variant of S_WQM: bit i of D0 set iff any bit of S0's 4-bit quad
  # containing bit i is set; SCC = D0 != 0. Result flagged as 64-bit ('d0_64').
  # tmp = 0ULL;
  # declare i : 6'U;
  # for i in 6'0U : 6'63U do
  # tmp[i] = S0.u64[i & 6'60U +: 6'4U] != 0ULL
  # endfor;
  # D0.u64 = tmp;
  # SCC = D0.u64 != 0ULL
  S0 = Reg(s0)
  D0 = Reg(d0)
  SCC = Reg(scc)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(0)
  for i in range(0, int(63)+1):
    # (i & 60) masks i to a multiple of 4 for i in 0..63 (quad base).
    tmp[i] = S0.u64[(i & 60) + (4) - 1 : (i & 60)] != 0
  D0.u64 = tmp
  SCC = Reg(D0.u64 != 0)
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': SCC._val & 1}
  result['d0_64'] = True
  return result
def _SOP1Op_S_BREV_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# D0.u32[31 : 0] = S0.u32[0 : 31]
S0 = Reg(s0)
@@ -619,6 +664,49 @@ def _SOP1Op_S_XNOR_SAVEEXEC_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, liter
result['d0_64'] = True
return result
def _SOP1Op_S_QUADMASK_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # S_QUADMASK_B32: compress S0's 8 quads into 8 result bits — bit i of D0 is
  # set iff quad i (bits 4i..4i+3) of S0 is nonzero. SCC = D0 != 0.
  # tmp = 0U;
  # for i in 0 : 7 do
  # tmp[i] = S0.u32[i * 4 +: 4] != 0U
  # endfor;
  # D0.u32 = tmp;
  # SCC = D0.u32 != 0U
  S0 = Reg(s0)
  D0 = Reg(d0)
  SCC = Reg(scc)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(0)
  for i in range(0, int(7)+1):
    # Reg slicing is [hi:lo]: selects the 4-bit quad starting at bit i*4.
    tmp[i] = S0.u32[(i * 4) + (4) - 1 : (i * 4)] != 0
  D0.u32 = tmp
  SCC = Reg(D0.u32 != 0)
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': SCC._val & 1}
  return result
def _SOP1Op_S_QUADMASK_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # 64-bit quadmask: compress S0's 16 quads into 16 result bits; SCC = D0 != 0.
  # Result flagged as 64-bit ('d0_64').
  # tmp = 0ULL;
  # for i in 0 : 15 do
  # tmp[i] = S0.u64[i * 4 +: 4] != 0ULL
  # endfor;
  # D0.u64 = tmp;
  # SCC = D0.u64 != 0ULL
  S0 = Reg(s0)
  D0 = Reg(d0)
  SCC = Reg(scc)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(0)
  for i in range(0, int(15)+1):
    tmp[i] = S0.u64[(i * 4) + (4) - 1 : (i * 4)] != 0
  D0.u64 = tmp
  SCC = Reg(D0.u64 != 0)
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': SCC._val & 1}
  result['d0_64'] = True
  return result
def _SOP1Op_S_ABS_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# D0.i32 = S0.i32 < 0 ? -S0.i32 : S0.i32;
# SCC = D0.i32 != 0
@@ -755,6 +843,8 @@ SOP1Op_FUNCTIONS = {
SOP1Op.S_CMOV_B64: _SOP1Op_S_CMOV_B64,
SOP1Op.S_NOT_B32: _SOP1Op_S_NOT_B32,
SOP1Op.S_NOT_B64: _SOP1Op_S_NOT_B64,
SOP1Op.S_WQM_B32: _SOP1Op_S_WQM_B32,
SOP1Op.S_WQM_B64: _SOP1Op_S_WQM_B64,
SOP1Op.S_BREV_B32: _SOP1Op_S_BREV_B32,
SOP1Op.S_BREV_B64: _SOP1Op_S_BREV_B64,
SOP1Op.S_BCNT0_I32_B32: _SOP1Op_S_BCNT0_I32_B32,
@@ -783,6 +873,8 @@ SOP1Op_FUNCTIONS = {
SOP1Op.S_NAND_SAVEEXEC_B64: _SOP1Op_S_NAND_SAVEEXEC_B64,
SOP1Op.S_NOR_SAVEEXEC_B64: _SOP1Op_S_NOR_SAVEEXEC_B64,
SOP1Op.S_XNOR_SAVEEXEC_B64: _SOP1Op_S_XNOR_SAVEEXEC_B64,
SOP1Op.S_QUADMASK_B32: _SOP1Op_S_QUADMASK_B32,
SOP1Op.S_QUADMASK_B64: _SOP1Op_S_QUADMASK_B64,
SOP1Op.S_ABS_I32: _SOP1Op_S_ABS_I32,
SOP1Op.S_SET_GPR_IDX_IDX: _SOP1Op_S_SET_GPR_IDX_IDX,
SOP1Op.S_ANDN1_SAVEEXEC_B64: _SOP1Op_S_ANDN1_SAVEEXEC_B64,
@@ -3495,6 +3587,24 @@ def _VOP2Op_V_XOR_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP2Op_V_DOT2C_F32_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Dot product with accumulate: D0.f32 += dot(S0.bf16x2, S1.bf16x2).
  # The destination is also the accumulator (tmp starts from D0.f32).
  # tmp = D0.f32;
  # tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16);
  # tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16);
  # D0.f32 = tmp
  S0 = Reg(s0)
  S1 = Reg(s1)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(D0.f32)
  tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16)
  tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16)
  D0.f32 = tmp
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP2Op_V_FMAMK_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# D0.f32 = fma(S0.f32, SIMM32.f32, S1.f32)
S0 = Reg(s0)
@@ -4119,6 +4229,7 @@ VOP2Op_FUNCTIONS = {
VOP2Op.V_AND_B32: _VOP2Op_V_AND_B32,
VOP2Op.V_OR_B32: _VOP2Op_V_OR_B32,
VOP2Op.V_XOR_B32: _VOP2Op_V_XOR_B32,
VOP2Op.V_DOT2C_F32_BF16: _VOP2Op_V_DOT2C_F32_BF16,
VOP2Op.V_FMAMK_F32: _VOP2Op_V_FMAMK_F32,
VOP2Op.V_FMAAK_F32: _VOP2Op_V_FMAAK_F32,
VOP2Op.V_ADD_CO_U32: _VOP2Op_V_ADD_CO_U32,
@@ -4482,6 +4593,25 @@ def _VOP3POp_V_PK_MAX_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3POp_V_DOT2_F32_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # D0.f32 = dot(S0.bf16x2, S1.bf16x2) + S2.f32 — two bf16 products widened to
  # f32 (F == bf16->f32 convert here) plus a separate f32 addend.
  # tmp = 32'F(S0[15 : 0].bf16) * 32'F(S1[15 : 0].bf16);
  # tmp += 32'F(S0[31 : 16].bf16) * 32'F(S1[31 : 16].bf16);
  # tmp += S2.f32;
  # D0.f32 = tmp
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(F(S0[15 : 0].bf16) * F(S1[15 : 0].bf16))
  tmp += F(S0[31 : 16].bf16) * F(S1[31 : 16].bf16)
  tmp += S2.f32
  D0.f32 = tmp
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3POp_V_PK_MINIMUM3_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# tmp[31 : 16].f16 = 16'F(v_minimum3_f16(S0[31 : 16].f16, S1[31 : 16].f16, S2[31 : 16].f16));
# tmp[15 : 0].f16 = 16'F(v_minimum3_f16(S0[15 : 0].f16, S1[15 : 0].f16, S2[15 : 0].f16));
@@ -4774,6 +4904,7 @@ VOP3POp_FUNCTIONS = {
VOP3POp.V_PK_MUL_F16: _VOP3POp_V_PK_MUL_F16,
VOP3POp.V_PK_MIN_F16: _VOP3POp_V_PK_MIN_F16,
VOP3POp.V_PK_MAX_F16: _VOP3POp_V_PK_MAX_F16,
VOP3POp.V_DOT2_F32_BF16: _VOP3POp_V_DOT2_F32_BF16,
VOP3POp.V_PK_MINIMUM3_F16: _VOP3POp_V_PK_MINIMUM3_F16,
VOP3POp.V_PK_MAXIMUM3_F16: _VOP3POp_V_PK_MAXIMUM3_F16,
VOP3POp.V_DOT2_F32_F16: _VOP3POp_V_DOT2_F32_F16,
@@ -13632,6 +13763,24 @@ def _VOP3AOp_V_XOR_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3AOp_V_DOT2C_F32_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # VOP3A encoding of V_DOT2C_F32_BF16: D0.f32 += dot(S0.bf16x2, S1.bf16x2);
  # the destination doubles as the accumulator.
  # tmp = D0.f32;
  # tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16);
  # tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16);
  # D0.f32 = tmp
  S0 = Reg(s0)
  S1 = Reg(s1)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(D0.f32)
  tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16)
  tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16)
  D0.f32 = tmp
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3AOp_V_ADD_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# D0.f16 = S0.f16 + S1.f16
S0 = Reg(s0)
@@ -14521,6 +14670,18 @@ def _VOP3AOp_V_SAD_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3AOp_V_SAD_HI_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Byte sum-of-absolute-differences, result placed in the high half:
  # D0 = (v_sad_u8(S0, S1, 0) << 16) + S2.
  # D0.u32 = (32'U(v_sad_u8(S0, S1, 0U)) << 16U) + S2.u32
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  # --- compiled pseudocode ---
  # presumably the .u32 store masks the shifted value back to 32 bits,
  # matching the 32'U cast in the pseudocode — TODO confirm Reg semantics.
  D0.u32 = ((v_sad_u8(S0, S1, 0)) << 16) + S2.u32
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3AOp_V_SAD_U16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# // UNSIGNED comparison
# tmp = S2.u32;
@@ -14747,6 +14908,71 @@ def _VOP3AOp_V_MSAD_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3AOp_V_QSAD_PK_U16_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Quad SAD, packed u16 results: four byte-SADs of sliding 32-bit windows of
  # S0 against S1, each accumulated with the matching u16 lane of S2, packed
  # into a 64-bit result. Reg slicing is [hi:lo].
  # tmp[63 : 48] = 16'B(v_sad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32));
  # tmp[47 : 32] = 16'B(v_sad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32));
  # tmp[31 : 16] = 16'B(v_sad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32));
  # tmp[15 : 0] = 16'B(v_sad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32));
  # D0.b64 = tmp.b64
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  # NOTE(review): the 16'B truncation casts were dropped by the compiler;
  # presumably the 16-bit slice store truncates — TODO confirm Reg semantics.
  tmp[63 : 48] = (v_sad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32))
  tmp[47 : 32] = (v_sad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32))
  tmp[31 : 16] = (v_sad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32))
  tmp[15 : 0] = (v_sad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32))
  D0.b64 = tmp.b64
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  result['d0_64'] = True
  return result
def _VOP3AOp_V_MQSAD_PK_U16_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Masked quad SAD, packed u16 results: same windowing as V_QSAD_PK_U16_U8
  # but using v_msad_u8 (bytes of S1 equal to 0 are masked out of the SAD).
  # tmp[63 : 48] = 16'B(v_msad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32));
  # tmp[47 : 32] = 16'B(v_msad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32));
  # tmp[31 : 16] = 16'B(v_msad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32));
  # tmp[15 : 0] = 16'B(v_msad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32));
  # D0.b64 = tmp.b64
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp[63 : 48] = (v_msad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32))
  tmp[47 : 32] = (v_msad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32))
  tmp[31 : 16] = (v_msad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32))
  tmp[15 : 0] = (v_msad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32))
  D0.b64 = tmp.b64
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  result['d0_64'] = True
  return result
def _VOP3AOp_V_MQSAD_U32_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Masked quad SAD with 32-bit accumulators: four masked byte-SADs of sliding
  # windows of S0 against S1, accumulated into the four u32 lanes of S2,
  # producing a 128-bit packed result.
  # tmp[127 : 96] = 32'B(v_msad_u8(S0[55 : 24], S1[31 : 0], S2[127 : 96].u32));
  # tmp[95 : 64] = 32'B(v_msad_u8(S0[47 : 16], S1[31 : 0], S2[95 : 64].u32));
  # tmp[63 : 32] = 32'B(v_msad_u8(S0[39 : 8], S1[31 : 0], S2[63 : 32].u32));
  # tmp[31 : 0] = 32'B(v_msad_u8(S0[31 : 0], S1[31 : 0], S2[31 : 0].u32));
  # D0.b128 = tmp.b128
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp[127 : 96] = (v_msad_u8(S0[55 : 24], S1[31 : 0], S2[127 : 96].u32))
  tmp[95 : 64] = (v_msad_u8(S0[47 : 16], S1[31 : 0], S2[95 : 64].u32))
  tmp[63 : 32] = (v_msad_u8(S0[39 : 8], S1[31 : 0], S2[63 : 32].u32))
  tmp[31 : 0] = (v_msad_u8(S0[31 : 0], S1[31 : 0], S2[31 : 0].u32))
  D0.b128 = tmp.b128
  # --- end pseudocode ---
  # NOTE(review): unlike the 64-bit SAD variants, no 'd0_64' (or 128-bit)
  # flag is set here even though the result is 128 bits — TODO confirm how
  # the caller consumes wide results.
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3AOp_V_MAD_LEGACY_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# tmp = S0.f16 * S1.f16 + S2.f16;
# if OPSEL.u4[3] then
@@ -15601,6 +15827,76 @@ def _VOP3AOp_V_CVT_SCALEF32_SR_BF8_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask
result = {'d0': d0, 'scc': scc & 1}
return result
def _VOP3AOp_V_CVT_SCALEF32_PK_FP8_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Convert the two bf16 halves of S0 to fp8 with scale = exponent(S1.f32).
  # scale = 32'U(exponent(S1.f32));
  # tmp0 = bf16_to_fp8_scale(S0[15 : 0].bf16, scale.u8);
  # tmp1 = bf16_to_fp8_scale(S0[31 : 16].bf16, scale.u8);
  # dstword = OPSEL[3].i32 * 16;
  # // Other destination bits are preserved
  S0 = Reg(s0)
  S1 = Reg(s1)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  scale = (exponent(S1.f32))
  tmp0 = bf16_to_fp8_scale(S0[15 : 0].bf16, scale.u8)
  tmp1 = bf16_to_fp8_scale(S0[31 : 16].bf16, scale.u8)
  # NOTE(review): OPSEL is not defined in this scope, and tmp0/tmp1 are never
  # packed into the destination — d0 is returned unchanged. Looks like the
  # generated write-back step is missing; confirm against the pcode regen.
  dstword = OPSEL[3].i32 * 16
  # --- end pseudocode ---
  result = {'d0': d0, 'scc': scc & 1}
  return result
def _VOP3AOp_V_CVT_SCALEF32_PK_BF8_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Convert the two bf16 halves of S0 to bf8 with scale = exponent(S1.f32).
  # scale = 32'U(exponent(S1.f32));
  # tmp0 = bf16_to_bf8_scale(S0[15 : 0].bf16, scale.u8);
  # tmp1 = bf16_to_bf8_scale(S0[31 : 16].bf16, scale.u8);
  # dstword = OPSEL[3].i32 * 16;
  # // Other destination bits are preserved
  S0 = Reg(s0)
  S1 = Reg(s1)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  scale = (exponent(S1.f32))
  tmp0 = bf16_to_bf8_scale(S0[15 : 0].bf16, scale.u8)
  tmp1 = bf16_to_bf8_scale(S0[31 : 16].bf16, scale.u8)
  # NOTE(review): OPSEL is undefined here and tmp0/tmp1 are never written to
  # the destination (d0 returned unchanged) — generated write-back missing.
  dstword = OPSEL[3].i32 * 16
  # --- end pseudocode ---
  result = {'d0': d0, 'scc': scc & 1}
  return result
def _VOP3AOp_V_CVT_SCALEF32_SR_FP8_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Stochastic-rounding bf16 -> fp8 conversion: S1 supplies the random bits,
  # scale = exponent(S2.f32).
  # scale = 32'U(exponent(S2.f32));
  # tmp = bf16_to_fp8_sr_scale(S0.bf16, S1.u32, scale.u8);
  # dstbyte = OPSEL[3 : 2].i32 * 8;
  # // Other destination bits are preserved
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  scale = (exponent(S2.f32))
  tmp = Reg(bf16_to_fp8_sr_scale(S0.bf16, S1.u32, scale.u8))
  # NOTE(review): OPSEL is undefined in this scope and tmp is never merged
  # into the destination (d0 returned unchanged) — write-back appears missing.
  dstbyte = OPSEL[3 : 2].i32 * 8
  # --- end pseudocode ---
  result = {'d0': d0, 'scc': scc & 1}
  return result
def _VOP3AOp_V_CVT_SCALEF32_SR_BF8_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Stochastic-rounding bf16 -> bf8 conversion: S1 supplies the random bits,
  # scale = exponent(S2.f32).
  # scale = 32'U(exponent(S2.f32));
  # tmp = bf16_to_bf8_sr_scale(S0.bf16, S1.u32, scale.u8);
  # dstbyte = OPSEL[3 : 2].i32 * 8;
  # // Other destination bits are preserved
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  scale = (exponent(S2.f32))
  tmp = Reg(bf16_to_bf8_sr_scale(S0.bf16, S1.u32, scale.u8))
  # NOTE(review): OPSEL is undefined in this scope and tmp is never merged
  # into the destination (d0 returned unchanged) — write-back appears missing.
  dstbyte = OPSEL[3 : 2].i32 * 8
  # --- end pseudocode ---
  result = {'d0': d0, 'scc': scc & 1}
  return result
def _VOP3AOp_V_CVT_SCALEF32_PK_F16_FP8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# scale = 32'U(exponent(S1.f32));
# srcword = OPSEL[0].i32 * 16;
@@ -15699,6 +15995,24 @@ def _VOP3AOp_V_CVT_SCALEF32_PK_FP4_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask
result = {'d0': d0, 'scc': scc & 1}
return result
def _VOP3AOp_V_CVT_SCALEF32_PK_FP4_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Convert the two bf16 halves of S0 to fp4 with scale = exponent(S1.f32).
  # scale = 32'U(exponent(S1.f32));
  # tmp0 = bf16_to_fp4_scale(S0[15 : 0].bf16, scale.u8);
  # tmp1 = bf16_to_fp4_scale(S0[31 : 16].bf16, scale.u8);
  # dstbyte = OPSEL[3 : 2].i32 * 8;
  # // Other destination bits are preserved
  S0 = Reg(s0)
  S1 = Reg(s1)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  scale = (exponent(S1.f32))
  tmp0 = bf16_to_fp4_scale(S0[15 : 0].bf16, scale.u8)
  tmp1 = bf16_to_fp4_scale(S0[31 : 16].bf16, scale.u8)
  # NOTE(review): OPSEL is undefined here and tmp0/tmp1 are never written to
  # the destination (d0 returned unchanged) — generated write-back missing.
  dstbyte = OPSEL[3 : 2].i32 * 8
  # --- end pseudocode ---
  result = {'d0': d0, 'scc': scc & 1}
  return result
def _VOP3AOp_V_CVT_SCALEF32_SR_PK_FP4_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# scale = 32'U(exponent(S2.f32));
# randomVal = S1.u32;
@@ -15720,6 +16034,27 @@ def _VOP3AOp_V_CVT_SCALEF32_SR_PK_FP4_F16(s0, s1, s2, d0, scc, vcc, lane, exec_m
result = {'d0': d0, 'scc': scc & 1}
return result
def _VOP3AOp_V_CVT_SCALEF32_SR_PK_FP4_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Stochastic-rounding packed bf16 -> fp4: both halves of S0 converted with
  # the same random value (S1) and scale = exponent(S2.f32).
  # scale = 32'U(exponent(S2.f32));
  # randomVal = S1.u32;
  # tmp0 = bf16_to_fp4_sr_scale(S0[15 : 0].bf16, randomVal, scale.u8);
  # tmp1 = bf16_to_fp4_sr_scale(S0[31 : 16].bf16, randomVal, scale.u8);
  # dstbyte = OPSEL[3 : 2].i32 * 8;
  # // Other destination bits are preserved
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  scale = (exponent(S2.f32))
  randomVal = S1.u32
  tmp0 = bf16_to_fp4_sr_scale(S0[15 : 0].bf16, randomVal, scale.u8)
  tmp1 = bf16_to_fp4_sr_scale(S0[31 : 16].bf16, randomVal, scale.u8)
  # NOTE(review): OPSEL is undefined here and tmp0/tmp1 are never written to
  # the destination (d0 returned unchanged) — generated write-back missing.
  dstbyte = OPSEL[3 : 2].i32 * 8
  # --- end pseudocode ---
  result = {'d0': d0, 'scc': scc & 1}
  return result
def _VOP3AOp_V_CVT_SCALEF32_PK_F16_FP4(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# scale = 32'U(exponent(S1.f32));
# srcbyte = OPSEL[1 : 0].i32 * 8;
@@ -15741,6 +16076,27 @@ def _VOP3AOp_V_CVT_SCALEF32_PK_F16_FP4(s0, s1, s2, d0, scc, vcc, lane, exec_mask
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3AOp_V_CVT_SCALEF32_PK_BF16_FP4(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Unpack two fp4 values (byte of SRC0 selected by OPSEL) to bf16 with scale
  # = exponent(S1.f32), writing the two bf16 results into D0's halves.
  # scale = 32'U(exponent(S1.f32));
  # srcbyte = OPSEL[1 : 0].i32 * 8;
  # src = VGPR[laneId][SRC0.u32][srcbyte + 7 : srcbyte].b8;
  # D0[15 : 0].bf16 = tmp0;
  # D0[31 : 16].bf16 = tmp1
  S1 = Reg(s1)
  D0 = Reg(d0)
  tmp = Reg(0)
  laneId = lane
  SRC0 = Reg(src0_idx)
  # --- compiled pseudocode ---
  scale = (exponent(S1.f32))
  # NOTE(review): OPSEL, tmp0 and tmp1 are not defined in this scope, and the
  # fp4->bf16 conversion step itself is absent — this body will raise
  # NameError if executed. Looks like the pcode compiler dropped part of the
  # pseudocode; confirm against the regen.
  srcbyte = OPSEL[1 : 0].i32 * 8
  src = VGPR[laneId][SRC0.u32][srcbyte + 7 : srcbyte].b8
  D0[15 : 0].bf16 = tmp0
  D0[31 : 16].bf16 = tmp1
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3AOp_V_CVT_SCALEF32_2XPK16_FP6_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# scale = 32'U(exponent(S2.f32));
# declare tmp : 192'B;
@@ -15873,6 +16229,26 @@ def _VOP3AOp_V_CVT_SCALEF32_PK32_F32_BF6(s0, s1, s2, d0, scc, vcc, lane, exec_ma
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3AOp_V_CVT_SCALEF32_PK32_FP6_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Convert 32 packed bf16 values in S0 to fp6 with scale = exponent(S1.f32),
  # packing the 32 six-bit results into the low 192 bits of D0.
  # scale = 32'U(exponent(S1.f32));
  # declare tmp : 192'B;
  # for pass in 0 : 31 do
  # tmp[dOffset + 5 : dOffset].fp6 = bf16_to_fp6_scale(S0[sOffset + 15 : sOffset].bf16, scale.u8)
  # endfor;
  # D0[191 : 0] = tmp.b192
  S0 = Reg(s0)
  S1 = Reg(s1)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  scale = (exponent(S1.f32))
  # BUGFIX: the generated loop used the reserved keyword `pass` as its loop
  # variable (a SyntaxError) and never defined dOffset/sOffset. Each pass
  # reads one 16-bit bf16 lane of S0 and writes one 6-bit fp6 lane of tmp.
  for p in range(0, int(31)+1):
    dOffset = p * 6   # fp6 lanes are 6 bits wide in the 192-bit result
    sOffset = p * 16  # bf16 lanes are 16 bits wide in the source
    tmp[dOffset + 5 : dOffset].fp6 = bf16_to_fp6_scale(S0[sOffset + 15 : sOffset].bf16, scale.u8)
  D0[191 : 0] = tmp.b192
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3AOp_V_CVT_SCALEF32_PK32_BF6_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# scale = 32'U(exponent(S1.f32));
# declare tmp : 192'B;
@@ -15893,6 +16269,26 @@ def _VOP3AOp_V_CVT_SCALEF32_PK32_BF6_F16(s0, s1, s2, d0, scc, vcc, lane, exec_ma
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3AOp_V_CVT_SCALEF32_PK32_BF6_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Convert 32 packed bf16 values in S0 to bf6 with scale = exponent(S1.f32),
  # packing the 32 six-bit results into the low 192 bits of D0.
  # scale = 32'U(exponent(S1.f32));
  # declare tmp : 192'B;
  # for pass in 0 : 31 do
  # tmp[dOffset + 5 : dOffset].bf6 = bf16_to_bf6_scale(S0[sOffset + 15 : sOffset].bf16, scale.u8)
  # endfor;
  # D0[191 : 0] = tmp.b192
  S0 = Reg(s0)
  S1 = Reg(s1)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  scale = (exponent(S1.f32))
  # BUGFIX: the generated loop used the reserved keyword `pass` as its loop
  # variable (a SyntaxError) and never defined dOffset/sOffset. Each pass
  # reads one 16-bit bf16 lane of S0 and writes one 6-bit bf6 lane of tmp.
  for p in range(0, int(31)+1):
    dOffset = p * 6   # bf6 lanes are 6 bits wide in the 192-bit result
    sOffset = p * 16  # bf16 lanes are 16 bits wide in the source
    tmp[dOffset + 5 : dOffset].bf6 = bf16_to_bf6_scale(S0[sOffset + 15 : sOffset].bf16, scale.u8)
  D0[191 : 0] = tmp.b192
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3AOp_V_CVT_SCALEF32_SR_PK32_FP6_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# scale = 32'U(exponent(S2.f32));
# randomVal = S1.u32;
@@ -15915,6 +16311,28 @@ def _VOP3AOp_V_CVT_SCALEF32_SR_PK32_FP6_F16(s0, s1, s2, d0, scc, vcc, lane, exec
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3AOp_V_CVT_SCALEF32_SR_PK32_FP6_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Stochastic-rounding conversion of 32 packed bf16 values in S0 to fp6,
  # using S1 as the random value and scale = exponent(S2.f32); the 32 six-bit
  # results are packed into the low 192 bits of D0.
  # scale = 32'U(exponent(S2.f32));
  # randomVal = S1.u32;
  # declare tmp : 192'B;
  # for pass in 0 : 31 do
  # tmp[dOffset + 5 : dOffset].fp6 = bf16_to_fp6_sr_scale(S0[sOffset + 15 : sOffset].bf16, randomVal, scale.u8)
  # endfor;
  # D0[191 : 0] = tmp.b192
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  scale = (exponent(S2.f32))
  randomVal = S1.u32
  # BUGFIX: the generated loop used the reserved keyword `pass` as its loop
  # variable, and the pcode compiler fused the `endfor;` and the D0 store
  # into the call's argument list, producing a SyntaxError. Restored the loop
  # body (third argument is scale.u8, matching the sibling *_sr_scale ops)
  # and moved the D0 store after the loop.
  for p in range(0, int(31)+1):
    dOffset = p * 6   # fp6 lanes are 6 bits wide in the 192-bit result
    sOffset = p * 16  # bf16 lanes are 16 bits wide in the source
    tmp[dOffset + 5 : dOffset].fp6 = bf16_to_fp6_sr_scale(S0[sOffset + 15 : sOffset].bf16, randomVal, scale.u8)
  D0[191 : 0] = tmp.b192
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3AOp_V_CVT_SCALEF32_SR_PK32_BF6_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# scale = 32'U(exponent(S2.f32));
# randomVal = S1.u32;
@@ -15937,6 +16355,28 @@ def _VOP3AOp_V_CVT_SCALEF32_SR_PK32_BF6_F16(s0, s1, s2, d0, scc, vcc, lane, exec
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3AOp_V_CVT_SCALEF32_SR_PK32_BF6_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Stochastic-rounding conversion of 32 packed bf16 values in S0 to bf6,
  # using S1 as the random value and scale = exponent(S2.f32); the 32 six-bit
  # results are packed into the low 192 bits of D0.
  # scale = 32'U(exponent(S2.f32));
  # randomVal = S1.u32;
  # declare tmp : 192'B;
  # for pass in 0 : 31 do
  # tmp[dOffset + 5 : dOffset].bf6 = bf16_to_bf6_sr_scale(S0[sOffset + 15 : sOffset].bf16, randomVal, scale.u8)
  # endfor;
  # D0[191 : 0] = tmp.b192
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  scale = (exponent(S2.f32))
  randomVal = S1.u32
  # BUGFIX: the generated loop used the reserved keyword `pass` as its loop
  # variable, and the pcode compiler fused the `endfor;` and the D0 store
  # into the call's argument list, producing a SyntaxError. Restored the loop
  # body (third argument is scale.u8, matching the sibling *_sr_scale ops)
  # and moved the D0 store after the loop.
  for p in range(0, int(31)+1):
    dOffset = p * 6   # bf6 lanes are 6 bits wide in the 192-bit result
    sOffset = p * 16  # bf16 lanes are 16 bits wide in the source
    tmp[dOffset + 5 : dOffset].bf6 = bf16_to_bf6_sr_scale(S0[sOffset + 15 : sOffset].bf16, randomVal, scale.u8)
  D0[191 : 0] = tmp.b192
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3AOp_V_CVT_SCALEF32_PK32_F16_FP6(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# scale = 32'U(exponent(S1.f32));
# declare tmp : 512'B;
@@ -15957,6 +16397,26 @@ def _VOP3AOp_V_CVT_SCALEF32_PK32_F16_FP6(s0, s1, s2, d0, scc, vcc, lane, exec_ma
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3AOp_V_CVT_SCALEF32_PK32_BF16_FP6(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Convert 32 packed fp6 values in S0 to bf16 with scale = exponent(S1.f32),
  # packing the 32 sixteen-bit results into the 512-bit D0.
  # scale = 32'U(exponent(S1.f32));
  # declare tmp : 512'B;
  # for pass in 0 : 31 do
  # tmp[dOffset + 15 : dOffset].bf16 = fp6_to_bf16_scale(S0[sOffset + 5 : sOffset].fp6, scale.u8)
  # endfor;
  # D0[511 : 0] = tmp.b512
  S0 = Reg(s0)
  S1 = Reg(s1)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  scale = (exponent(S1.f32))
  # BUGFIX: the generated loop used the reserved keyword `pass` as its loop
  # variable (a SyntaxError) and never defined dOffset/sOffset. Each pass
  # reads one 6-bit fp6 lane of S0 and writes one 16-bit bf16 lane of tmp.
  for p in range(0, int(31)+1):
    dOffset = p * 16  # bf16 lanes are 16 bits wide in the 512-bit result
    sOffset = p * 6   # fp6 lanes are 6 bits wide in the source
    tmp[dOffset + 15 : dOffset].bf16 = fp6_to_bf16_scale(S0[sOffset + 5 : sOffset].fp6, scale.u8)
  D0[511 : 0] = tmp.b512
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3AOp_V_CVT_SCALEF32_PK32_F16_BF6(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# scale = 32'U(exponent(S1.f32));
# declare tmp : 512'B;
@@ -15977,6 +16437,26 @@ def _VOP3AOp_V_CVT_SCALEF32_PK32_F16_BF6(s0, s1, s2, d0, scc, vcc, lane, exec_ma
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3AOp_V_CVT_SCALEF32_PK32_BF16_BF6(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Convert 32 packed bf6 values in S0 to bf16 with scale = exponent(S1.f32),
  # packing the 32 sixteen-bit results into the 512-bit D0.
  # scale = 32'U(exponent(S1.f32));
  # declare tmp : 512'B;
  # for pass in 0 : 31 do
  # tmp[dOffset + 15 : dOffset].bf16 = bf6_to_bf16_scale(S0[sOffset + 5 : sOffset].bf6, scale.u8)
  # endfor;
  # D0[511 : 0] = tmp.b512
  S0 = Reg(s0)
  S1 = Reg(s1)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  scale = (exponent(S1.f32))
  # BUGFIX: the generated loop used the reserved keyword `pass` as its loop
  # variable (a SyntaxError) and never defined dOffset/sOffset. Each pass
  # reads one 6-bit bf6 lane of S0 and writes one 16-bit bf16 lane of tmp.
  for p in range(0, int(31)+1):
    dOffset = p * 16  # bf16 lanes are 16 bits wide in the 512-bit result
    sOffset = p * 6   # bf6 lanes are 6 bits wide in the source
    tmp[dOffset + 15 : dOffset].bf16 = bf6_to_bf16_scale(S0[sOffset + 5 : sOffset].bf6, scale.u8)
  D0[511 : 0] = tmp.b512
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3AOp_V_ASHR_PK_I8_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# if n <= -128 then
# elsif n >= 127 then
@@ -16048,6 +16528,63 @@ def _VOP3AOp_V_CVT_PK_F16_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal
result = {'d0': d0, 'scc': scc & 1}
return result
def _VOP3AOp_V_CVT_PK_BF16_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Pack two f32 -> bf16 conversions (S0 into the low half, S1 into the high
  # half of a 32-bit result).
  # prev_mode = ROUND_MODE;
  # tmp[15 : 0].bf16 = f32_to_bf16(S0.f32);
  # tmp[31 : 16].bf16 = f32_to_bf16(S1.f32);
  S0 = Reg(s0)
  S1 = Reg(s1)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  # NOTE(review): ROUND_MODE is not defined in this scope (NameError if
  # executed), and tmp is never written to the destination — the result dict
  # returns d0 unchanged. The generated write-back appears to be missing;
  # confirm against the pcode regen.
  prev_mode = ROUND_MODE
  tmp[15 : 0].bf16 = f32_to_bf16(S0.f32)
  tmp[31 : 16].bf16 = f32_to_bf16(S1.f32)
  # --- end pseudocode ---
  result = {'d0': d0, 'scc': scc & 1}
  return result
def _VOP3AOp_V_CVT_SCALEF32_PK_BF16_FP8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Unpack two fp8 values (16-bit word of SRC0 selected by OPSEL) to bf16 with
  # scale = exponent(S1.f32), writing the two bf16 results into D0's halves.
  # scale = 32'U(exponent(S1.f32));
  # srcword = OPSEL[0].i32 * 16;
  # src = VGPR[laneId][SRC0.u32][srcword + 15 : srcword].b16;
  # D0[15 : 0].bf16 = tmp0.bf16;
  # D0[31 : 16].bf16 = tmp1.bf16
  S1 = Reg(s1)
  D0 = Reg(d0)
  tmp = Reg(0)
  laneId = lane
  SRC0 = Reg(src0_idx)
  # --- compiled pseudocode ---
  scale = (exponent(S1.f32))
  # NOTE(review): OPSEL, tmp0 and tmp1 are not defined in this scope, and the
  # fp8->bf16 conversion step is absent — this body will raise NameError if
  # executed. Confirm against the pcode regen.
  srcword = OPSEL[0].i32 * 16
  src = VGPR[laneId][SRC0.u32][srcword + 15 : srcword].b16
  D0[15 : 0].bf16 = tmp0.bf16
  D0[31 : 16].bf16 = tmp1.bf16
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3AOp_V_CVT_SCALEF32_PK_BF16_BF8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # Unpack two bf8 values (16-bit word of SRC0 selected by OPSEL) to bf16 with
  # scale = exponent(S1.f32), writing the two bf16 results into D0's halves.
  # scale = 32'U(exponent(S1.f32));
  # srcword = OPSEL[0].i32 * 16;
  # src = VGPR[laneId][SRC0.u32][srcword + 15 : srcword].b16;
  # D0[15 : 0].bf16 = tmp0.bf16;
  # D0[31 : 16].bf16 = tmp1.bf16
  S1 = Reg(s1)
  D0 = Reg(d0)
  tmp = Reg(0)
  laneId = lane
  SRC0 = Reg(src0_idx)
  # --- compiled pseudocode ---
  scale = (exponent(S1.f32))
  # NOTE(review): OPSEL, tmp0 and tmp1 are not defined in this scope, and the
  # bf8->bf16 conversion step is absent — this body will raise NameError if
  # executed. Confirm against the pcode regen.
  srcword = OPSEL[0].i32 * 16
  src = VGPR[laneId][SRC0.u32][srcword + 15 : srcword].b16
  D0[15 : 0].bf16 = tmp0.bf16
  D0[31 : 16].bf16 = tmp1.bf16
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3AOp_V_ADD_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# D0.f64 = S0.f64 + S1.f64
S0 = Reg(s0)
@@ -16771,6 +17308,7 @@ VOP3AOp_FUNCTIONS = {
VOP3AOp.V_AND_B32: _VOP3AOp_V_AND_B32,
VOP3AOp.V_OR_B32: _VOP3AOp_V_OR_B32,
VOP3AOp.V_XOR_B32: _VOP3AOp_V_XOR_B32,
VOP3AOp.V_DOT2C_F32_BF16: _VOP3AOp_V_DOT2C_F32_BF16,
VOP3AOp.V_ADD_F16: _VOP3AOp_V_ADD_F16,
VOP3AOp.V_SUB_F16: _VOP3AOp_V_SUB_F16,
VOP3AOp.V_SUBREV_F16: _VOP3AOp_V_SUBREV_F16,
@@ -16824,6 +17362,7 @@ VOP3AOp_FUNCTIONS = {
VOP3AOp.V_MED3_I32: _VOP3AOp_V_MED3_I32,
VOP3AOp.V_MED3_U32: _VOP3AOp_V_MED3_U32,
VOP3AOp.V_SAD_U8: _VOP3AOp_V_SAD_U8,
VOP3AOp.V_SAD_HI_U8: _VOP3AOp_V_SAD_HI_U8,
VOP3AOp.V_SAD_U16: _VOP3AOp_V_SAD_U16,
VOP3AOp.V_SAD_U32: _VOP3AOp_V_SAD_U32,
VOP3AOp.V_CVT_PK_U8_F32: _VOP3AOp_V_CVT_PK_U8_F32,
@@ -16832,6 +17371,9 @@ VOP3AOp_FUNCTIONS = {
VOP3AOp.V_DIV_FMAS_F32: _VOP3AOp_V_DIV_FMAS_F32,
VOP3AOp.V_DIV_FMAS_F64: _VOP3AOp_V_DIV_FMAS_F64,
VOP3AOp.V_MSAD_U8: _VOP3AOp_V_MSAD_U8,
VOP3AOp.V_QSAD_PK_U16_U8: _VOP3AOp_V_QSAD_PK_U16_U8,
VOP3AOp.V_MQSAD_PK_U16_U8: _VOP3AOp_V_MQSAD_PK_U16_U8,
VOP3AOp.V_MQSAD_U32_U8: _VOP3AOp_V_MQSAD_U32_U8,
VOP3AOp.V_MAD_LEGACY_F16: _VOP3AOp_V_MAD_LEGACY_F16,
VOP3AOp.V_MAD_LEGACY_U16: _VOP3AOp_V_MAD_LEGACY_U16,
VOP3AOp.V_MAD_LEGACY_I16: _VOP3AOp_V_MAD_LEGACY_I16,
@@ -16879,27 +17421,43 @@ VOP3AOp_FUNCTIONS = {
VOP3AOp.V_CVT_SCALEF32_PK_BF8_F16: _VOP3AOp_V_CVT_SCALEF32_PK_BF8_F16,
VOP3AOp.V_CVT_SCALEF32_SR_FP8_F16: _VOP3AOp_V_CVT_SCALEF32_SR_FP8_F16,
VOP3AOp.V_CVT_SCALEF32_SR_BF8_F16: _VOP3AOp_V_CVT_SCALEF32_SR_BF8_F16,
VOP3AOp.V_CVT_SCALEF32_PK_FP8_BF16: _VOP3AOp_V_CVT_SCALEF32_PK_FP8_BF16,
VOP3AOp.V_CVT_SCALEF32_PK_BF8_BF16: _VOP3AOp_V_CVT_SCALEF32_PK_BF8_BF16,
VOP3AOp.V_CVT_SCALEF32_SR_FP8_BF16: _VOP3AOp_V_CVT_SCALEF32_SR_FP8_BF16,
VOP3AOp.V_CVT_SCALEF32_SR_BF8_BF16: _VOP3AOp_V_CVT_SCALEF32_SR_BF8_BF16,
VOP3AOp.V_CVT_SCALEF32_PK_F16_FP8: _VOP3AOp_V_CVT_SCALEF32_PK_F16_FP8,
VOP3AOp.V_CVT_SCALEF32_PK_F16_BF8: _VOP3AOp_V_CVT_SCALEF32_PK_F16_BF8,
VOP3AOp.V_CVT_SCALEF32_F16_FP8: _VOP3AOp_V_CVT_SCALEF32_F16_FP8,
VOP3AOp.V_CVT_SCALEF32_F16_BF8: _VOP3AOp_V_CVT_SCALEF32_F16_BF8,
VOP3AOp.V_CVT_SCALEF32_PK_FP4_F16: _VOP3AOp_V_CVT_SCALEF32_PK_FP4_F16,
VOP3AOp.V_CVT_SCALEF32_PK_FP4_BF16: _VOP3AOp_V_CVT_SCALEF32_PK_FP4_BF16,
VOP3AOp.V_CVT_SCALEF32_SR_PK_FP4_F16: _VOP3AOp_V_CVT_SCALEF32_SR_PK_FP4_F16,
VOP3AOp.V_CVT_SCALEF32_SR_PK_FP4_BF16: _VOP3AOp_V_CVT_SCALEF32_SR_PK_FP4_BF16,
VOP3AOp.V_CVT_SCALEF32_PK_F16_FP4: _VOP3AOp_V_CVT_SCALEF32_PK_F16_FP4,
VOP3AOp.V_CVT_SCALEF32_PK_BF16_FP4: _VOP3AOp_V_CVT_SCALEF32_PK_BF16_FP4,
VOP3AOp.V_CVT_SCALEF32_2XPK16_FP6_F32: _VOP3AOp_V_CVT_SCALEF32_2XPK16_FP6_F32,
VOP3AOp.V_CVT_SCALEF32_2XPK16_BF6_F32: _VOP3AOp_V_CVT_SCALEF32_2XPK16_BF6_F32,
VOP3AOp.V_CVT_SCALEF32_SR_PK32_FP6_F32: _VOP3AOp_V_CVT_SCALEF32_SR_PK32_FP6_F32,
VOP3AOp.V_CVT_SCALEF32_SR_PK32_BF6_F32: _VOP3AOp_V_CVT_SCALEF32_SR_PK32_BF6_F32,
VOP3AOp.V_CVT_SCALEF32_PK32_F32_FP6: _VOP3AOp_V_CVT_SCALEF32_PK32_F32_FP6,
VOP3AOp.V_CVT_SCALEF32_PK32_F32_BF6: _VOP3AOp_V_CVT_SCALEF32_PK32_F32_BF6,
VOP3AOp.V_CVT_SCALEF32_PK32_FP6_BF16: _VOP3AOp_V_CVT_SCALEF32_PK32_FP6_BF16,
VOP3AOp.V_CVT_SCALEF32_PK32_BF6_F16: _VOP3AOp_V_CVT_SCALEF32_PK32_BF6_F16,
VOP3AOp.V_CVT_SCALEF32_PK32_BF6_BF16: _VOP3AOp_V_CVT_SCALEF32_PK32_BF6_BF16,
VOP3AOp.V_CVT_SCALEF32_SR_PK32_FP6_F16: _VOP3AOp_V_CVT_SCALEF32_SR_PK32_FP6_F16,
VOP3AOp.V_CVT_SCALEF32_SR_PK32_FP6_BF16: _VOP3AOp_V_CVT_SCALEF32_SR_PK32_FP6_BF16,
VOP3AOp.V_CVT_SCALEF32_SR_PK32_BF6_F16: _VOP3AOp_V_CVT_SCALEF32_SR_PK32_BF6_F16,
VOP3AOp.V_CVT_SCALEF32_SR_PK32_BF6_BF16: _VOP3AOp_V_CVT_SCALEF32_SR_PK32_BF6_BF16,
VOP3AOp.V_CVT_SCALEF32_PK32_F16_FP6: _VOP3AOp_V_CVT_SCALEF32_PK32_F16_FP6,
VOP3AOp.V_CVT_SCALEF32_PK32_BF16_FP6: _VOP3AOp_V_CVT_SCALEF32_PK32_BF16_FP6,
VOP3AOp.V_CVT_SCALEF32_PK32_F16_BF6: _VOP3AOp_V_CVT_SCALEF32_PK32_F16_BF6,
VOP3AOp.V_CVT_SCALEF32_PK32_BF16_BF6: _VOP3AOp_V_CVT_SCALEF32_PK32_BF16_BF6,
VOP3AOp.V_ASHR_PK_I8_I32: _VOP3AOp_V_ASHR_PK_I8_I32,
VOP3AOp.V_ASHR_PK_U8_I32: _VOP3AOp_V_ASHR_PK_U8_I32,
VOP3AOp.V_CVT_PK_F16_F32: _VOP3AOp_V_CVT_PK_F16_F32,
VOP3AOp.V_CVT_PK_BF16_F32: _VOP3AOp_V_CVT_PK_BF16_F32,
VOP3AOp.V_CVT_SCALEF32_PK_BF16_FP8: _VOP3AOp_V_CVT_SCALEF32_PK_BF16_FP8,
VOP3AOp.V_CVT_SCALEF32_PK_BF16_BF8: _VOP3AOp_V_CVT_SCALEF32_PK_BF16_BF8,
VOP3AOp.V_ADD_F64: _VOP3AOp_V_ADD_F64,
VOP3AOp.V_MUL_F64: _VOP3AOp_V_MUL_F64,
VOP3AOp.V_MIN_F64: _VOP3AOp_V_MIN_F64,

View File

@@ -394,6 +394,94 @@ def _SOP1Op_S_BCNT1_I32_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
result['d0_64'] = True
return result
def _SOP1Op_S_QUADMASK_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # S_QUADMASK_B32: compress S0's 8 quads into 8 result bits — bit i of D0 is
  # set iff quad i (bits 4i..4i+3) of S0 is nonzero. SCC = D0 != 0.
  # tmp = 0U;
  # for i in 0 : 7 do
  # tmp[i] = S0.u32[i * 4 +: 4] != 0U
  # endfor;
  # D0.u32 = tmp;
  # SCC = D0.u32 != 0U
  S0 = Reg(s0)
  D0 = Reg(d0)
  SCC = Reg(scc)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(0)
  for i in range(0, int(7)+1):
    # Reg slicing is [hi:lo]: selects the 4-bit quad starting at bit i*4.
    tmp[i] = S0.u32[(i * 4) + (4) - 1 : (i * 4)] != 0
  D0.u32 = tmp
  SCC = Reg(D0.u32 != 0)
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': SCC._val & 1}
  return result
def _SOP1Op_S_QUADMASK_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # 64-bit quadmask: compress S0's 16 quads into 16 result bits; SCC = D0 != 0.
  # Result flagged as 64-bit ('d0_64').
  # tmp = 0ULL;
  # for i in 0 : 15 do
  # tmp[i] = S0.u64[i * 4 +: 4] != 0ULL
  # endfor;
  # D0.u64 = tmp;
  # SCC = D0.u64 != 0ULL
  S0 = Reg(s0)
  D0 = Reg(d0)
  SCC = Reg(scc)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(0)
  for i in range(0, int(15)+1):
    tmp[i] = S0.u64[(i * 4) + (4) - 1 : (i * 4)] != 0
  D0.u64 = tmp
  SCC = Reg(D0.u64 != 0)
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': SCC._val & 1}
  result['d0_64'] = True
  return result
def _SOP1Op_S_WQM_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  # S_WQM_B32 ("whole quad mode"): bit i of D0 is set iff any bit in the 4-bit
  # quad containing bit i of S0 (bits i&~3 .. (i&~3)+3) is set. SCC = D0 != 0.
  # tmp = 0U;
  # declare i : 6'U;
  # for i in 6'0U : 6'31U do
  # tmp[i] = S0.u32[i & 6'60U +: 6'4U] != 0U
  # endfor;
  # D0.u32 = tmp;
  # SCC = D0.u32 != 0U
  S0 = Reg(s0)
  D0 = Reg(d0)
  SCC = Reg(scc)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(0)
  for i in range(0, int(31)+1):
    # Reg slicing is [hi:lo]; (i & 60) == i & ~3 rounds i down to its quad base.
    tmp[i] = S0.u32[(i & 60) + (4) - 1 : (i & 60)] != 0
  D0.u32 = tmp
  SCC = Reg(D0.u32 != 0)
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': SCC._val & 1}
  return result
def _SOP1Op_S_WQM_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """S_WQM_B64 (whole-quad mode, 64-bit): D0 bit i is set iff the 4-bit quad containing bit i of S0 has any
  bit set; SCC = (D0 != 0). 'd0_64' flags the 64-bit destination write."""
  # tmp = 0ULL;
  # declare i : 6'U;
  # for i in 6'0U : 6'63U do
  #   tmp[i] = S0.u64[i & 6'60U +: 6'4U] != 0ULL
  # endfor;
  # D0.u64 = tmp;
  # SCC = D0.u64 != 0ULL
  S0 = Reg(s0)
  D0 = Reg(d0)
  SCC = Reg(scc)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(0)
  for i in range(0, int(63)+1):
    tmp[i] = S0.u64[(i & 60) + (4) - 1 : (i & 60)] != 0  # i & 60 clears the low 2 bits -> start of i's quad
  D0.u64 = tmp
  SCC = Reg(D0.u64 != 0)
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': SCC._val & 1}
  result['d0_64'] = True
  return result
def _SOP1Op_S_NOT_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# D0.u32 = ~S0.u32;
# SCC = D0.u32 != 0U
@@ -1178,6 +1266,10 @@ SOP1Op_FUNCTIONS = {
SOP1Op.S_BCNT0_I32_B64: _SOP1Op_S_BCNT0_I32_B64,
SOP1Op.S_BCNT1_I32_B32: _SOP1Op_S_BCNT1_I32_B32,
SOP1Op.S_BCNT1_I32_B64: _SOP1Op_S_BCNT1_I32_B64,
SOP1Op.S_QUADMASK_B32: _SOP1Op_S_QUADMASK_B32,
SOP1Op.S_QUADMASK_B64: _SOP1Op_S_QUADMASK_B64,
SOP1Op.S_WQM_B32: _SOP1Op_S_WQM_B32,
SOP1Op.S_WQM_B64: _SOP1Op_S_WQM_B64,
SOP1Op.S_NOT_B32: _SOP1Op_S_NOT_B32,
SOP1Op.S_NOT_B64: _SOP1Op_S_NOT_B64,
SOP1Op.S_AND_SAVEEXEC_B32: _SOP1Op_S_AND_SAVEEXEC_B32,
@@ -10159,6 +10251,18 @@ def _VOP3Op_V_SAD_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3Op_V_SAD_HI_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_SAD_HI_U8: byte sum-of-absolute-differences of S0 vs S1, shifted into the high 16 bits, plus S2."""
  # D0.u32 = (32'U(v_sad_u8(S0, S1, 0U)) << 16U) + S2.u32
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  # --- compiled pseudocode ---
  D0.u32 = ((v_sad_u8(S0, S1, 0)) << 16) + S2.u32  # .u32 setter truncates any overflow to 32 bits
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3Op_V_SAD_U16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# // UNSIGNED comparison
# tmp = S2.u32;
@@ -10385,6 +10489,71 @@ def _VOP3Op_V_MSAD_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3Op_V_QSAD_PK_U16_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_QSAD_PK_U16_U8: four byte-SADs of sliding 4-byte windows of S0 against S1[31:0], packed into four
  u16 lanes, each accumulated from the matching u16 lane of S2; 64-bit result ('d0_64')."""
  # tmp[63 : 48] = 16'B(v_sad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32));
  # tmp[47 : 32] = 16'B(v_sad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32));
  # tmp[31 : 16] = 16'B(v_sad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32));
  # tmp[15 : 0] = 16'B(v_sad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32));
  # D0.b64 = tmp.b64
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp[63 : 48] = (v_sad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32))  # window shifted by 3 bytes
  tmp[47 : 32] = (v_sad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32))  # window shifted by 2 bytes
  tmp[31 : 16] = (v_sad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32))  # window shifted by 1 byte
  tmp[15 : 0] = (v_sad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32))  # unshifted window
  D0.b64 = tmp.b64
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  result['d0_64'] = True
  return result
def _VOP3Op_V_MQSAD_PK_U16_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_MQSAD_PK_U16_U8: like V_QSAD_PK_U16_U8 but uses the masked SAD (byte lanes where the S1 reference
  byte is 0 are skipped); 64-bit result ('d0_64')."""
  # tmp[63 : 48] = 16'B(v_msad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32));
  # tmp[47 : 32] = 16'B(v_msad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32));
  # tmp[31 : 16] = 16'B(v_msad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32));
  # tmp[15 : 0] = 16'B(v_msad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32));
  # D0.b64 = tmp.b64
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp[63 : 48] = (v_msad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32))  # window shifted by 3 bytes
  tmp[47 : 32] = (v_msad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32))  # window shifted by 2 bytes
  tmp[31 : 16] = (v_msad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32))  # window shifted by 1 byte
  tmp[15 : 0] = (v_msad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32))  # unshifted window
  D0.b64 = tmp.b64
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  result['d0_64'] = True
  return result
def _VOP3Op_V_MQSAD_U32_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_MQSAD_U32_U8: masked byte-SADs of four sliding 4-byte windows of S0 against S1[31:0], with 32-bit
  accumulators from S2, packed into a 128-bit result."""
  # tmp[127 : 96] = 32'B(v_msad_u8(S0[55 : 24], S1[31 : 0], S2[127 : 96].u32));
  # tmp[95 : 64] = 32'B(v_msad_u8(S0[47 : 16], S1[31 : 0], S2[95 : 64].u32));
  # tmp[63 : 32] = 32'B(v_msad_u8(S0[39 : 8], S1[31 : 0], S2[63 : 32].u32));
  # tmp[31 : 0] = 32'B(v_msad_u8(S0[31 : 0], S1[31 : 0], S2[31 : 0].u32));
  # D0.b128 = tmp.b128
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp[127 : 96] = (v_msad_u8(S0[55 : 24], S1[31 : 0], S2[127 : 96].u32))  # window shifted by 3 bytes
  tmp[95 : 64] = (v_msad_u8(S0[47 : 16], S1[31 : 0], S2[95 : 64].u32))  # window shifted by 2 bytes
  tmp[63 : 32] = (v_msad_u8(S0[39 : 8], S1[31 : 0], S2[63 : 32].u32))  # window shifted by 1 byte
  tmp[31 : 0] = (v_msad_u8(S0[31 : 0], S1[31 : 0], S2[31 : 0].u32))  # unshifted window
  D0.b128 = tmp.b128
  # --- end pseudocode ---
  # NOTE(review): result is 128-bit (b128) but only 'd0' is returned with no width flag (cf. 'd0_64'
  # elsewhere) — confirm the dispatcher writes all four destination dwords.
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3Op_V_XOR3_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# D0.u32 = (S0.u32 ^ S1.u32 ^ S2.u32)
S0 = Reg(s0)
@@ -10860,6 +11029,25 @@ def _VOP3Op_V_DOT2_F16_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3Op_V_DOT2_BF16_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_DOT2_BF16_BF16: D0.bf16 = S2.bf16 + S0.lo*S1.lo + S0.hi*S1.hi over the packed bf16 lanes."""
  # tmp = S2.bf16;
  # tmp += S0[15 : 0].bf16 * S1[15 : 0].bf16;
  # tmp += S0[31 : 16].bf16 * S1[31 : 16].bf16;
  # D0.bf16 = tmp
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(S2.bf16)
  tmp += S0[15 : 0].bf16 * S1[15 : 0].bf16  # low-lane product
  tmp += S0[31 : 16].bf16 * S1[31 : 16].bf16  # high-lane product
  D0.bf16 = tmp
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3Op_V_ADD_NC_U16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# D0.u16 = S0.u16 + S1.u16
S0 = Reg(s0)
@@ -11793,6 +11981,7 @@ VOP3Op_FUNCTIONS = {
VOP3Op.V_MED3_I32: _VOP3Op_V_MED3_I32,
VOP3Op.V_MED3_U32: _VOP3Op_V_MED3_U32,
VOP3Op.V_SAD_U8: _VOP3Op_V_SAD_U8,
VOP3Op.V_SAD_HI_U8: _VOP3Op_V_SAD_HI_U8,
VOP3Op.V_SAD_U16: _VOP3Op_V_SAD_U16,
VOP3Op.V_SAD_U32: _VOP3Op_V_SAD_U32,
VOP3Op.V_CVT_PK_U8_F32: _VOP3Op_V_CVT_PK_U8_F32,
@@ -11801,6 +11990,9 @@ VOP3Op_FUNCTIONS = {
VOP3Op.V_DIV_FMAS_F32: _VOP3Op_V_DIV_FMAS_F32,
VOP3Op.V_DIV_FMAS_F64: _VOP3Op_V_DIV_FMAS_F64,
VOP3Op.V_MSAD_U8: _VOP3Op_V_MSAD_U8,
VOP3Op.V_QSAD_PK_U16_U8: _VOP3Op_V_QSAD_PK_U16_U8,
VOP3Op.V_MQSAD_PK_U16_U8: _VOP3Op_V_MQSAD_PK_U16_U8,
VOP3Op.V_MQSAD_U32_U8: _VOP3Op_V_MQSAD_U32_U8,
VOP3Op.V_XOR3_B32: _VOP3Op_V_XOR3_B32,
VOP3Op.V_MAD_U16: _VOP3Op_V_MAD_U16,
VOP3Op.V_XAD_U32: _VOP3Op_V_XAD_U32,
@@ -11834,6 +12026,7 @@ VOP3Op_FUNCTIONS = {
VOP3Op.V_MAXMIN_I32: _VOP3Op_V_MAXMIN_I32,
VOP3Op.V_MINMAX_I32: _VOP3Op_V_MINMAX_I32,
VOP3Op.V_DOT2_F16_F16: _VOP3Op_V_DOT2_F16_F16,
VOP3Op.V_DOT2_BF16_BF16: _VOP3Op_V_DOT2_BF16_BF16,
VOP3Op.V_ADD_NC_U16: _VOP3Op_V_ADD_NC_U16,
VOP3Op.V_SUB_NC_U16: _VOP3Op_V_SUB_NC_U16,
VOP3Op.V_MUL_LO_U16: _VOP3Op_V_MUL_LO_U16,
@@ -12552,6 +12745,25 @@ def _VOP3POp_V_DOT8_U32_U4(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3POp_V_DOT2_F32_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_DOT2_F32_BF16: f32 dot product of the two packed bf16 pairs in S0 and S1, accumulated onto S2.f32."""
  # tmp = S2.f32;
  # tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16);
  # tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16);
  # D0.f32 = tmp
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(S2.f32)
  tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16)  # low-lane product in f32
  tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16)  # high-lane product in f32
  D0.f32 = tmp
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
VOP3POp_FUNCTIONS = {
VOP3POp.V_PK_MAD_I16: _VOP3POp_V_PK_MAD_I16,
VOP3POp.V_PK_MUL_LO_U16: _VOP3POp_V_PK_MUL_LO_U16,
@@ -12575,6 +12787,7 @@ VOP3POp_FUNCTIONS = {
VOP3POp.V_DOT2_F32_F16: _VOP3POp_V_DOT2_F32_F16,
VOP3POp.V_DOT4_U32_U8: _VOP3POp_V_DOT4_U32_U8,
VOP3POp.V_DOT8_U32_U4: _VOP3POp_V_DOT8_U32_U4,
VOP3POp.V_DOT2_F32_BF16: _VOP3POp_V_DOT2_F32_BF16,
}
def _VOPCOp_V_CMP_F_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):

View File

@@ -394,6 +394,94 @@ def _SOP1Op_S_BCNT1_I32_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
result['d0_64'] = True
return result
def _SOP1Op_S_QUADMASK_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """S_QUADMASK_B32: D0 bit i (i=0..7) is set iff 4-bit group i of S0 is nonzero; SCC = (D0 != 0).

  Generated from ISA pseudocode (shown in the comments below) — do not hand-edit the compiled section.
  """
  # tmp = 0U;
  # for i in 0 : 7 do
  #   tmp[i] = S0.u32[i * 4 +: 4] != 0U
  # endfor;
  # D0.u32 = tmp;
  # SCC = D0.u32 != 0U
  S0 = Reg(s0)
  D0 = Reg(d0)
  SCC = Reg(scc)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(0)
  for i in range(0, int(7)+1):
    tmp[i] = S0.u32[(i * 4) + (4) - 1 : (i * 4)] != 0  # one output bit per 4-bit quad
  D0.u32 = tmp
  SCC = Reg(D0.u32 != 0)
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': SCC._val & 1}
  return result
def _SOP1Op_S_QUADMASK_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """S_QUADMASK_B64: D0 bit i (i=0..15) is set iff 4-bit group i of S0 is nonzero; SCC = (D0 != 0).

  64-bit variant of S_QUADMASK_B32; 'd0_64' flags the 64-bit destination write.
  """
  # tmp = 0ULL;
  # for i in 0 : 15 do
  #   tmp[i] = S0.u64[i * 4 +: 4] != 0ULL
  # endfor;
  # D0.u64 = tmp;
  # SCC = D0.u64 != 0ULL
  S0 = Reg(s0)
  D0 = Reg(d0)
  SCC = Reg(scc)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(0)
  for i in range(0, int(15)+1):
    tmp[i] = S0.u64[(i * 4) + (4) - 1 : (i * 4)] != 0  # one output bit per 4-bit quad
  D0.u64 = tmp
  SCC = Reg(D0.u64 != 0)
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': SCC._val & 1}
  result['d0_64'] = True
  return result
def _SOP1Op_S_WQM_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """S_WQM_B32 (whole-quad mode): D0 bit i is set iff the 4-bit quad containing bit i of S0 has any bit set;
  SCC = (D0 != 0)."""
  # tmp = 0U;
  # declare i : 6'U;
  # for i in 6'0U : 6'31U do
  #   tmp[i] = S0.u32[i & 6'60U +: 6'4U] != 0U
  # endfor;
  # D0.u32 = tmp;
  # SCC = D0.u32 != 0U
  S0 = Reg(s0)
  D0 = Reg(d0)
  SCC = Reg(scc)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(0)
  for i in range(0, int(31)+1):
    tmp[i] = S0.u32[(i & 60) + (4) - 1 : (i & 60)] != 0  # i & 60 clears the low 2 bits -> start of i's quad
  D0.u32 = tmp
  SCC = Reg(D0.u32 != 0)
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': SCC._val & 1}
  return result
def _SOP1Op_S_WQM_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """S_WQM_B64 (whole-quad mode, 64-bit): D0 bit i is set iff the 4-bit quad containing bit i of S0 has any
  bit set; SCC = (D0 != 0). 'd0_64' flags the 64-bit destination write."""
  # tmp = 0ULL;
  # declare i : 6'U;
  # for i in 6'0U : 6'63U do
  #   tmp[i] = S0.u64[i & 6'60U +: 6'4U] != 0ULL
  # endfor;
  # D0.u64 = tmp;
  # SCC = D0.u64 != 0ULL
  S0 = Reg(s0)
  D0 = Reg(d0)
  SCC = Reg(scc)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(0)
  for i in range(0, int(63)+1):
    tmp[i] = S0.u64[(i & 60) + (4) - 1 : (i & 60)] != 0  # i & 60 clears the low 2 bits -> start of i's quad
  D0.u64 = tmp
  SCC = Reg(D0.u64 != 0)
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': SCC._val & 1}
  result['d0_64'] = True
  return result
def _SOP1Op_S_NOT_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# D0.u32 = ~S0.u32;
# SCC = D0.u32 != 0U
@@ -1264,6 +1352,10 @@ SOP1Op_FUNCTIONS = {
SOP1Op.S_BCNT0_I32_B64: _SOP1Op_S_BCNT0_I32_B64,
SOP1Op.S_BCNT1_I32_B32: _SOP1Op_S_BCNT1_I32_B32,
SOP1Op.S_BCNT1_I32_B64: _SOP1Op_S_BCNT1_I32_B64,
SOP1Op.S_QUADMASK_B32: _SOP1Op_S_QUADMASK_B32,
SOP1Op.S_QUADMASK_B64: _SOP1Op_S_QUADMASK_B64,
SOP1Op.S_WQM_B32: _SOP1Op_S_WQM_B32,
SOP1Op.S_WQM_B64: _SOP1Op_S_WQM_B64,
SOP1Op.S_NOT_B32: _SOP1Op_S_NOT_B32,
SOP1Op.S_NOT_B64: _SOP1Op_S_NOT_B64,
SOP1Op.S_AND_SAVEEXEC_B32: _SOP1Op_S_AND_SAVEEXEC_B32,
@@ -10034,6 +10126,18 @@ def _VOP3Op_V_SAD_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3Op_V_SAD_HI_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_SAD_HI_U8: byte sum-of-absolute-differences of S0 vs S1, shifted into the high 16 bits, plus S2."""
  # D0.u32 = (32'U(v_sad_u8(S0, S1, 0U)) << 16U) + S2.u32
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  # --- compiled pseudocode ---
  D0.u32 = ((v_sad_u8(S0, S1, 0)) << 16) + S2.u32  # .u32 setter truncates any overflow to 32 bits
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3Op_V_SAD_U16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# // UNSIGNED comparison
# tmp = S2.u32;
@@ -10410,6 +10514,71 @@ def _VOP3Op_V_MSAD_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR,
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3Op_V_QSAD_PK_U16_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_QSAD_PK_U16_U8: four byte-SADs of sliding 4-byte windows of S0 against S1[31:0], packed into four
  u16 lanes, each accumulated from the matching u16 lane of S2; 64-bit result ('d0_64')."""
  # tmp[63 : 48] = 16'B(v_sad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32));
  # tmp[47 : 32] = 16'B(v_sad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32));
  # tmp[31 : 16] = 16'B(v_sad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32));
  # tmp[15 : 0] = 16'B(v_sad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32));
  # D0.b64 = tmp.b64
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp[63 : 48] = (v_sad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32))  # window shifted by 3 bytes
  tmp[47 : 32] = (v_sad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32))  # window shifted by 2 bytes
  tmp[31 : 16] = (v_sad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32))  # window shifted by 1 byte
  tmp[15 : 0] = (v_sad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32))  # unshifted window
  D0.b64 = tmp.b64
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  result['d0_64'] = True
  return result
def _VOP3Op_V_MQSAD_PK_U16_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_MQSAD_PK_U16_U8: like V_QSAD_PK_U16_U8 but uses the masked SAD (byte lanes where the S1 reference
  byte is 0 are skipped); 64-bit result ('d0_64')."""
  # tmp[63 : 48] = 16'B(v_msad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32));
  # tmp[47 : 32] = 16'B(v_msad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32));
  # tmp[31 : 16] = 16'B(v_msad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32));
  # tmp[15 : 0] = 16'B(v_msad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32));
  # D0.b64 = tmp.b64
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp[63 : 48] = (v_msad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32))  # window shifted by 3 bytes
  tmp[47 : 32] = (v_msad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32))  # window shifted by 2 bytes
  tmp[31 : 16] = (v_msad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32))  # window shifted by 1 byte
  tmp[15 : 0] = (v_msad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32))  # unshifted window
  D0.b64 = tmp.b64
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  result['d0_64'] = True
  return result
def _VOP3Op_V_MQSAD_U32_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_MQSAD_U32_U8: masked byte-SADs of four sliding 4-byte windows of S0 against S1[31:0], with 32-bit
  accumulators from S2, packed into a 128-bit result."""
  # tmp[127 : 96] = 32'B(v_msad_u8(S0[55 : 24], S1[31 : 0], S2[127 : 96].u32));
  # tmp[95 : 64] = 32'B(v_msad_u8(S0[47 : 16], S1[31 : 0], S2[95 : 64].u32));
  # tmp[63 : 32] = 32'B(v_msad_u8(S0[39 : 8], S1[31 : 0], S2[63 : 32].u32));
  # tmp[31 : 0] = 32'B(v_msad_u8(S0[31 : 0], S1[31 : 0], S2[31 : 0].u32));
  # D0.b128 = tmp.b128
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp[127 : 96] = (v_msad_u8(S0[55 : 24], S1[31 : 0], S2[127 : 96].u32))  # window shifted by 3 bytes
  tmp[95 : 64] = (v_msad_u8(S0[47 : 16], S1[31 : 0], S2[95 : 64].u32))  # window shifted by 2 bytes
  tmp[63 : 32] = (v_msad_u8(S0[39 : 8], S1[31 : 0], S2[63 : 32].u32))  # window shifted by 1 byte
  tmp[31 : 0] = (v_msad_u8(S0[31 : 0], S1[31 : 0], S2[31 : 0].u32))  # unshifted window
  D0.b128 = tmp.b128
  # --- end pseudocode ---
  # NOTE(review): result is 128-bit (b128) but only 'd0' is returned with no width flag (cf. 'd0_64'
  # elsewhere) — confirm the dispatcher writes all four destination dwords.
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3Op_V_XOR3_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# D0.u32 = (S0.u32 ^ S1.u32 ^ S2.u32)
S0 = Reg(s0)
@@ -10786,6 +10955,25 @@ def _VOP3Op_V_DOT2_F16_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3Op_V_DOT2_BF16_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_DOT2_BF16_BF16: D0.bf16 = S2.bf16 + S0.lo*S1.lo + S0.hi*S1.hi over the packed bf16 lanes."""
  # tmp = S2.bf16;
  # tmp += S0[15 : 0].bf16 * S1[15 : 0].bf16;
  # tmp += S0[31 : 16].bf16 * S1[31 : 16].bf16;
  # D0.bf16 = tmp
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(S2.bf16)
  tmp += S0[15 : 0].bf16 * S1[15 : 0].bf16  # low-lane product
  tmp += S0[31 : 16].bf16 * S1[31 : 16].bf16  # high-lane product
  D0.bf16 = tmp
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3Op_V_MINMAX_NUM_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# D0.f32 = v_max_num_f32(v_min_num_f32(S0.f32, S1.f32), S2.f32)
S0 = Reg(s0)
@@ -11993,6 +12181,7 @@ VOP3Op_FUNCTIONS = {
VOP3Op.V_MED3_I32: _VOP3Op_V_MED3_I32,
VOP3Op.V_MED3_U32: _VOP3Op_V_MED3_U32,
VOP3Op.V_SAD_U8: _VOP3Op_V_SAD_U8,
VOP3Op.V_SAD_HI_U8: _VOP3Op_V_SAD_HI_U8,
VOP3Op.V_SAD_U16: _VOP3Op_V_SAD_U16,
VOP3Op.V_SAD_U32: _VOP3Op_V_SAD_U32,
VOP3Op.V_CVT_PK_U8_F32: _VOP3Op_V_CVT_PK_U8_F32,
@@ -12011,6 +12200,9 @@ VOP3Op_FUNCTIONS = {
VOP3Op.V_DIV_FMAS_F32: _VOP3Op_V_DIV_FMAS_F32,
VOP3Op.V_DIV_FMAS_F64: _VOP3Op_V_DIV_FMAS_F64,
VOP3Op.V_MSAD_U8: _VOP3Op_V_MSAD_U8,
VOP3Op.V_QSAD_PK_U16_U8: _VOP3Op_V_QSAD_PK_U16_U8,
VOP3Op.V_MQSAD_PK_U16_U8: _VOP3Op_V_MQSAD_PK_U16_U8,
VOP3Op.V_MQSAD_U32_U8: _VOP3Op_V_MQSAD_U32_U8,
VOP3Op.V_XOR3_B32: _VOP3Op_V_XOR3_B32,
VOP3Op.V_MAD_U16: _VOP3Op_V_MAD_U16,
VOP3Op.V_XAD_U32: _VOP3Op_V_XAD_U32,
@@ -12037,6 +12229,7 @@ VOP3Op_FUNCTIONS = {
VOP3Op.V_MAXMIN_I32: _VOP3Op_V_MAXMIN_I32,
VOP3Op.V_MINMAX_I32: _VOP3Op_V_MINMAX_I32,
VOP3Op.V_DOT2_F16_F16: _VOP3Op_V_DOT2_F16_F16,
VOP3Op.V_DOT2_BF16_BF16: _VOP3Op_V_DOT2_BF16_BF16,
VOP3Op.V_MINMAX_NUM_F32: _VOP3Op_V_MINMAX_NUM_F32,
VOP3Op.V_MAXMIN_NUM_F32: _VOP3Op_V_MAXMIN_NUM_F32,
VOP3Op.V_MINMAX_NUM_F16: _VOP3Op_V_MINMAX_NUM_F16,
@@ -12754,6 +12947,25 @@ def _VOP3POp_V_DOT8_U32_U4(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3POp_V_DOT2_F32_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_DOT2_F32_BF16: f32 dot product of the two packed bf16 pairs in S0 and S1, accumulated onto S2.f32."""
  # tmp = S2.f32;
  # tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16);
  # tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16);
  # D0.f32 = tmp
  S0 = Reg(s0)
  S1 = Reg(s1)
  S2 = Reg(s2)
  D0 = Reg(d0)
  tmp = Reg(0)
  # --- compiled pseudocode ---
  tmp = Reg(S2.f32)
  tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16)  # low-lane product in f32
  tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16)  # high-lane product in f32
  D0.f32 = tmp
  # --- end pseudocode ---
  result = {'d0': D0._val, 'scc': scc & 1}
  return result
def _VOP3POp_V_PK_MIN_NUM_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
# declare tmp : 32'B;
# tmp[15 : 0].f16 = v_min_num_f16(S0[15 : 0].f16, S1[15 : 0].f16);
@@ -12935,6 +13147,7 @@ VOP3POp_FUNCTIONS = {
VOP3POp.V_DOT2_F32_F16: _VOP3POp_V_DOT2_F32_F16,
VOP3POp.V_DOT4_U32_U8: _VOP3POp_V_DOT4_U32_U8,
VOP3POp.V_DOT8_U32_U4: _VOP3POp_V_DOT8_U32_U4,
VOP3POp.V_DOT2_F32_BF16: _VOP3POp_V_DOT2_F32_BF16,
VOP3POp.V_PK_MIN_NUM_F16: _VOP3POp_V_PK_MIN_NUM_F16,
VOP3POp.V_PK_MAX_NUM_F16: _VOP3POp_V_PK_MAX_NUM_F16,
VOP3POp.V_PK_MINIMUM_F16: _VOP3POp_V_PK_MINIMUM_F16,

View File

@@ -20,7 +20,8 @@ _VOP3_64BIT_OPS = {op.value for op in VOP3Op if op.name.endswith(('_F64', '_B64'
# Ops where src1 is 32-bit (exponent/shift amount) even though the op name suggests 64-bit
_VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value}
# Ops with 16-bit types in name (for source/dest handling)
_VOP3_16BIT_OPS = {op for op in VOP3Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
# Exception: SAD/MSAD ops take 32-bit packed sources and extract 16-bit/8-bit chunks internally
_VOP3_16BIT_OPS = {op for op in VOP3Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16')) and 'SAD' not in op.name}
_VOP1_16BIT_OPS = {op for op in VOP1Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
_VOP2_16BIT_OPS = {op for op in VOP2Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
# CVT ops with 32/64-bit source (despite 16-bit in name)
@@ -382,11 +383,11 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
else:
op_cls, op, src0, src1, src2, vdst = VOP3Op, VOP3Op(inst.op), inst.src0, inst.src1, inst.src2, inst.vdst
# V_PERM_B32: byte permutation - not in pseudocode PDF, implement directly
# D0[byte_i] = selector[byte_i] < 8 ? {src1, src0}[selector[byte_i]] : (selector[byte_i] >= 0xD ? 0xFF : 0x00)
# D0[byte_i] = selector[byte_i] < 8 ? {src0, src1}[selector[byte_i]] : (selector[byte_i] >= 0xD ? 0xFF : 0x00)
if op == VOP3Op.V_PERM_B32:
s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane)
# Combine src0 and src1 into 8-byte value: src0 is bytes 0-3, src1 is bytes 4-7
combined = (s0 & 0xffffffff) | ((s1 & 0xffffffff) << 32)
# Combine src1 and src0 into 8-byte value: src1 is bytes 0-3, src0 is bytes 4-7
combined = (s1 & 0xffffffff) | ((s0 & 0xffffffff) << 32)
result = 0
for i in range(4): # 4 result bytes
sel = (s2 >> (i * 8)) & 0xff # byte selector for this position

View File

@@ -195,7 +195,59 @@ v_min3_i16 = v_min3_i32
v_max3_i16 = v_max3_i32
def v_min3_u16(a, b, c):
  """Minimum of three values interpreted as unsigned 16-bit."""
  return min(v & 0xffff for v in (a, b, c))
def v_max3_u16(a, b, c):
  """Maximum of three values interpreted as unsigned 16-bit."""
  return max(v & 0xffff for v in (a, b, c))
def ABSDIFF(a, b): return abs(a - b)
def ABSDIFF(a, b):
  """Absolute difference of two operands, coercing each to int first so register-view objects work."""
  x, y = int(a), int(b)
  return x - y if x >= y else y - x
# BF16 (bfloat16) conversion functions
def _bf16(i):
"""Convert bf16 bits to float. BF16 is just the top 16 bits of f32."""
return struct.unpack("<f", struct.pack("<I", (i & 0xffff) << 16))[0]
def _ibf16(f):
"""Convert float to bf16 bits (truncate to top 16 bits of f32)."""
if math.isnan(f): return 0x7fc0 # bf16 quiet NaN
if math.isinf(f): return 0x7f80 if f > 0 else 0xff80 # bf16 ±infinity
try: return (struct.unpack("<I", struct.pack("<f", float(f)))[0] >> 16) & 0xffff
except (OverflowError, struct.error): return 0x7f80 if f > 0 else 0xff80
def bf16_to_f32(v):
  """Interpret v as bf16 bits when it is an int; otherwise pass it through as a float."""
  if isinstance(v, int):
    return _bf16(v)
  return float(v)
def f32_to_bf16(f):
  """Encode a float as bf16 bits (truncating f32 conversion); alias over _ibf16 for API symmetry."""
  return _ibf16(f)
# BYTE_PERMUTE for V_PERM_B32 - select bytes from 64-bit data based on selector
def BYTE_PERMUTE(data, sel):
  """Select one result byte from 64-bit data per V_PERM_B32 selector semantics.

  sel 0-7  : byte `sel` of data ({S0,S1} packing: S1 occupies bytes 0-3, S0 bytes 4-7)
  sel 8-11 : 0xFF/0x00 replicating the sign bit of byte 1/3/5/7 respectively
  sel 12   : constant 0x00
  sel >=13 : constant 0xFF
  """
  sel, data = int(sel) & 0xff, int(data)
  if sel <= 7:
    return (data >> (sel * 8)) & 0xff
  if sel <= 11:
    sign_bit = {8: 15, 9: 31, 10: 47, 11: 63}[sel]  # MSB of bytes 1, 3, 5, 7
    return 0xff if (data >> sign_bit) & 1 else 0x00
  return 0x00 if sel == 12 else 0xff
# v_sad_u8 helper for V_SAD instructions (sum of absolute differences of 4 bytes)
def v_sad_u8(s0, s1, s2):
  """V_SAD_U8: sum of |s0.byte[k] - s1.byte[k]| over the 4 byte lanes, added to accumulator s2 (mod 2^32)."""
  a, b, acc = int(s0), int(s1), int(s2)
  diffs = (abs(((a >> sh) & 0xff) - ((b >> sh) & 0xff)) for sh in (0, 8, 16, 24))
  return (acc + sum(diffs)) & 0xffffffff
# v_msad_u8 helper (masked SAD - skip when reference byte is 0)
def v_msad_u8(s0, s1, s2):
  """V_MSAD_U8: like v_sad_u8 but byte lanes whose reference byte in s1 is zero are masked out."""
  a, b, acc = int(s0), int(s1), int(s2)
  for sh in (0, 8, 16, 24):
    ref = (b >> sh) & 0xff
    if ref:  # zero reference byte -> lane does not contribute
      acc += abs(((a >> sh) & 0xff) - ref)
  return acc & 0xffffffff
def f16_to_snorm(f):
  """Map a float, clamped to [-1, 1], to a signed 16-bit normalized integer (scale 32767)."""
  clamped = min(1.0, max(-1.0, f))
  return min(32767, max(-32768, int(round(clamped * 32767))))
def f16_to_unorm(f):
  """Map a float, clamped to [0, 1], to an unsigned 16-bit normalized integer (scale 65535)."""
  clamped = min(1.0, max(0.0, f))
  return min(65535, max(0, int(round(clamped * 65535))))
def f32_to_snorm(f):
  """Map a float, clamped to [-1, 1], to a signed 16-bit normalized integer (f32 variant, scale 32767)."""
  clamped = -1.0 if f < -1.0 else (1.0 if f > 1.0 else f)
  return max(-32768, min(32767, int(round(clamped * 32767))))
@@ -240,6 +292,8 @@ __all__ = [
'i16_to_f16', 'u16_to_f16', 'f16_to_i16', 'f16_to_u16', 'u32_to_u16', 'i32_to_i16',
'f16_to_snorm', 'f16_to_unorm', 'f32_to_snorm', 'f32_to_unorm', 'v_cvt_i16_f32', 'v_cvt_u16_f32',
'SAT8', 'f32_to_u8', 'u8_to_u32', 'u4_to_u32',
# BF16 conversion functions
'_bf16', '_ibf16', 'bf16_to_f32', 'f32_to_bf16',
# Math functions
'trunc', 'floor', 'ceil', 'sqrt', 'log2', 'sin', 'cos', 'pow', 'fract', 'isEven', 'mantissa',
# Min/max functions
@@ -248,6 +302,8 @@ __all__ = [
'v_min3_f32', 'v_max3_f32', 'v_min3_i32', 'v_max3_i32', 'v_min3_u32', 'v_max3_u32',
'v_min3_f16', 'v_max3_f16', 'v_min3_i16', 'v_max3_i16', 'v_min3_u16', 'v_max3_u16',
'ABSDIFF',
# Byte/SAD helper functions
'BYTE_PERMUTE', 'v_sad_u8', 'v_msad_u8',
# Bit manipulation
'_brev32', '_brev64', '_ctz32', '_ctz64', '_exponent', '_is_denorm_f32', '_is_denorm_f64',
'_sign', '_mantissa_f32', '_div', '_isnan', '_isquietnan', '_issignalnan', '_gt_neg_zero', '_lt_neg_zero', '_fma', '_ldexp', '_signext',
@@ -354,16 +410,25 @@ class SliceProxy:
i32 = property(lambda s: _sext(s._get() & MASK32, 32), lambda s, v: s._set(v))
f16 = property(lambda s: _f16(s._get()), lambda s, v: s._set(v if isinstance(v, int) else _i16(float(v))))
f32 = property(lambda s: _f32(s._get()), lambda s, v: s._set(_i32(float(v))))
bf16 = property(lambda s: _bf16(s._get()), lambda s, v: s._set(v if isinstance(v, int) else _ibf16(float(v))))
b16, b32 = u16, u32
def __int__(self): return self._get()
def __index__(self): return self._get()
# Comparison operators (compare as integers)
def __eq__(s, o): return s._get() == int(o)
def __ne__(s, o): return s._get() != int(o)
def __lt__(s, o): return s._get() < int(o)
def __le__(s, o): return s._get() <= int(o)
def __gt__(s, o): return s._get() > int(o)
def __ge__(s, o): return s._get() >= int(o)
class TypedView:
"""View for S0.u32 that supports [4:0] slicing and [bit] access."""
__slots__ = ('_reg', '_bits', '_signed', '_float')
def __init__(self, reg, bits, signed=False, is_float=False):
self._reg, self._bits, self._signed, self._float = reg, bits, signed, is_float
__slots__ = ('_reg', '_bits', '_signed', '_float', '_bf16')
def __init__(self, reg, bits, signed=False, is_float=False, is_bf16=False):
self._reg, self._bits, self._signed, self._float, self._bf16 = reg, bits, signed, is_float, is_bf16
@property
def _val(self):
@@ -390,6 +455,7 @@ class TypedView:
def __trunc__(self): return int(float(self)) if self._float else int(self)
def __float__(self):
if self._float:
if self._bf16: return _bf16(self._val) # bf16 uses different conversion
return _f16(self._val) if self._bits == 16 else _f32(self._val) if self._bits == 32 else _f64(self._val)
return float(int(self))
@@ -454,6 +520,7 @@ class Reg:
i16 = property(lambda s: TypedView(s, 16, signed=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
b16 = property(lambda s: TypedView(s, 16), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
f16 = property(lambda s: TypedView(s, 16, is_float=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _i16(float(v))) & 0xffff)))
bf16 = property(lambda s: TypedView(s, 16, is_float=True, is_bf16=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _ibf16(float(v))) & 0xffff)))
u8 = property(lambda s: TypedView(s, 8))
i8 = property(lambda s: TypedView(s, 8, signed=True))
@@ -610,6 +677,14 @@ def _expr(e: str) -> str:
e = e.replace('+INF', 'INF').replace('-INF', '(-INF)')
e = re.sub(r'NAN\.f\d+', 'float("nan")', e)
# Verilog bit slice syntax: [start +: width] -> extract width bits starting at start
# Convert to Python slice: [start + width - 1 : start]
def convert_verilog_slice(m):
start, width = m.group(1).strip(), m.group(2).strip()
# Convert to high:low slice format
return f'[({start}) + ({width}) - 1 : ({start})]'
e = re.sub(r'\[([^:\[\]]+)\s*\+:\s*([^:\[\]]+)\]', convert_verilog_slice, e)
# Recursively process bracket contents to handle nested ternaries like S1.u32[x ? a : b]
def process_brackets(s):
result, i = [], 0
@@ -706,10 +781,10 @@ from extra.assembly.amd.dsl import PDF_URLS
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
# Patterns that can't be handled by the DSL (require special handling in emu.py)
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'BYTE_PERMUTE', 'FATAL_HALT', 'HW_REGISTERS',
'PC =', 'PC=', 'PC+', '= PC', 'v_sad', '+:', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', '.bf16', 'ThreadMask',
'S1[i', 'C.i32', 'v_msad_u8', 'S[i]', 'in[', '2.0 / PI',
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
'PC =', 'PC=', 'PC+', '= PC', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', 'ThreadMask',
'S1[i', 'C.i32', 'S[i]', 'in[', '2.0 / PI',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST'] # Malformed pseudocode from PDF
def extract_pseudocode(text: str) -> str | None:

View File

@@ -2357,5 +2357,246 @@ class TestF64Conversions(unittest.TestCase):
self.assertEqual(result, -8, f"Expected -8, got {result} (lo=0x{lo:08x}, hi=0x{hi:08x})")
class TestNewPcodeHelpers(unittest.TestCase):
  """Tests for newly added pcode helper functions (SAD, BYTE_PERMUTE, BF16)."""
  def test_v_sad_u8_basic(self):
    """V_SAD_U8: Sum of absolute differences of 4 bytes."""
    # s0 = 0x05040302, s1 = 0x04030201, s2 = 10 -> diff = 1+1+1+1 = 4, result = 14
    instructions = [
      s_mov_b32(s[0], 0x05040302),
      s_mov_b32(s[1], 0x04030201),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_mov_b32_e32(v[2], 10),
      v_sad_u8(v[3], v[0], v[1], v[2]),
    ]
    st = run_program(instructions, n_lanes=1)
    result = st.vgpr[0][3]
    self.assertEqual(result, 14, f"Expected 14, got {result}")
  def test_v_sad_u8_identical_bytes(self):
    """V_SAD_U8: When both operands are identical, SAD = 0 + accumulator."""
    instructions = [
      s_mov_b32(s[0], 0xDEADBEEF),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[0]),  # Same as v0
      v_mov_b32_e32(v[2], 42),  # Accumulator
      v_sad_u8(v[3], v[0], v[1], v[2]),
    ]
    st = run_program(instructions, n_lanes=1)
    result = st.vgpr[0][3]
    self.assertEqual(result, 42, f"Expected 42, got {result}")
  def test_v_sad_u16_basic(self):
    """V_SAD_U16: Sum of absolute differences of 2 half-words."""
    # s0 = 0x00020003, s1 = 0x00010001 -> diff = |2-1| + |3-1| = 1 + 2 = 3
    instructions = [
      s_mov_b32(s[0], 0x00020003),
      s_mov_b32(s[1], 0x00010001),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_mov_b32_e32(v[2], 0),
      v_sad_u16(v[3], v[0], v[1], v[2]),
    ]
    st = run_program(instructions, n_lanes=1)
    result = st.vgpr[0][3]
    self.assertEqual(result, 3, f"Expected 3, got {result}")
  def test_v_sad_u32_basic(self):
    """V_SAD_U32: Absolute difference of 32-bit values."""
    # s0 = 100, s1 = 30 -> diff = 70, s2 = 5 -> result = 75
    instructions = [
      v_mov_b32_e32(v[0], 100),
      v_mov_b32_e32(v[1], 30),
      v_mov_b32_e32(v[2], 5),
      v_sad_u32(v[3], v[0], v[1], v[2]),
    ]
    st = run_program(instructions, n_lanes=1)
    result = st.vgpr[0][3]
    self.assertEqual(result, 75, f"Expected 75, got {result}")
  def test_v_msad_u8_masked(self):
    """V_MSAD_U8: Skip bytes where reference (s1) is 0."""
    # s0 = 0x10101010, s1 = 0x00010001, s2 = 0
    # Only bytes 0 and 2 of s1 are non-zero, so only those contribute
    # diff = |0x10-0x01| + |0x10-0x01| = 15 + 15 = 30
    instructions = [
      s_mov_b32(s[0], 0x10101010),
      s_mov_b32(s[1], 0x00010001),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_mov_b32_e32(v[2], 0),
      v_msad_u8(v[3], v[0], v[1], v[2]),
    ]
    st = run_program(instructions, n_lanes=1)
    result = st.vgpr[0][3]
    self.assertEqual(result, 30, f"Expected 30, got {result}")
  def test_v_perm_b32_select_bytes(self):
    """V_PERM_B32: Select bytes from combined {s0, s1}."""
    # Combined = {S0, S1} where S1 is bytes 0-3, S0 is bytes 4-7
    # s0 = 0x03020100 -> bytes 4-7 of combined
    # s1 = 0x07060504 -> bytes 0-3 of combined
    # Combined = 0x03020100_07060504
    # selector = 0x00010203 -> select bytes 3,2,1,0 from combined = 0x04,0x05,0x06,0x07
    instructions = [
      s_mov_b32(s[0], 0x03020100),
      s_mov_b32(s[1], 0x07060504),
      s_mov_b32(s[2], 0x00010203),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_mov_b32_e32(v[2], s[2]),
      v_perm_b32(v[3], v[0], v[1], v[2]),
    ]
    st = run_program(instructions, n_lanes=1)
    result = st.vgpr[0][3]
    self.assertEqual(result, 0x04050607, f"Expected 0x04050607, got 0x{result:08x}")
  def test_v_perm_b32_select_high_bytes(self):
    """V_PERM_B32: Select bytes from high word (s0)."""
    # Combined = {S0, S1} where S1 is bytes 0-3, S0 is bytes 4-7
    # s0 = 0x03020100 -> bytes 4-7 of combined
    # s1 = 0x07060504 -> bytes 0-3 of combined
    # selector = 0x04050607 -> select bytes 7,6,5,4 from combined = 0x00,0x01,0x02,0x03
    instructions = [
      s_mov_b32(s[0], 0x03020100),
      s_mov_b32(s[1], 0x07060504),
      s_mov_b32(s[2], 0x04050607),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_mov_b32_e32(v[2], s[2]),
      v_perm_b32(v[3], v[0], v[1], v[2]),
    ]
    st = run_program(instructions, n_lanes=1)
    result = st.vgpr[0][3]
    self.assertEqual(result, 0x00010203, f"Expected 0x00010203, got 0x{result:08x}")
  def test_v_perm_b32_constant_values(self):
    """V_PERM_B32: Test constant 0x00 (sel=12) and 0xFF (sel>=13)."""
    # selector = 0x0C0D0E0F -> bytes: 12=0x00, 13=0xFF, 14=0xFF, 15=0xFF
    instructions = [
      s_mov_b32(s[0], 0x12345678),
      s_mov_b32(s[1], 0xABCDEF01),
      s_mov_b32(s[2], 0x0C0D0E0F),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_mov_b32_e32(v[2], s[2]),
      v_perm_b32(v[3], v[0], v[1], v[2]),
    ]
    st = run_program(instructions, n_lanes=1)
    result = st.vgpr[0][3]
    # byte 0: sel=0x0F >= 13 -> 0xFF
    # byte 1: sel=0x0E >= 13 -> 0xFF
    # byte 2: sel=0x0D >= 13 -> 0xFF
    # byte 3: sel=0x0C = 12 -> 0x00
    self.assertEqual(result, 0x00FFFFFF, f"Expected 0x00FFFFFF, got 0x{result:08x}")
  def test_v_dot2_f32_bf16_basic(self):
    """V_DOT2_F32_BF16: Dot product of two bf16 pairs accumulated into f32."""
    from extra.assembly.amd.pcode import _ibf16
    # A = packed (2.0, 3.0) as bf16, B = packed (4.0, 5.0) as bf16
    # Result = 2*4 + 3*5 + acc = 8 + 15 + 0 = 23.0
    a_lo, a_hi = _ibf16(2.0), _ibf16(3.0)
    b_lo, b_hi = _ibf16(4.0), _ibf16(5.0)
    # bf16 pair packing: high lane in bits 31:16, low lane in bits 15:0
    a_packed = (a_hi << 16) | a_lo
    b_packed = (b_hi << 16) | b_lo
    instructions = [
      s_mov_b32(s[0], a_packed),
      s_mov_b32(s[1], b_packed),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_mov_b32_e32(v[2], 0),  # accumulator = 0
      v_dot2_f32_bf16(v[3], v[0], v[1], v[2]),
    ]
    st = run_program(instructions, n_lanes=1)
    result = i2f(st.vgpr[0][3])
    self.assertAlmostEqual(result, 23.0, places=1, msg=f"Expected 23.0, got {result}")
class TestQuadmaskWqm(unittest.TestCase):
    """Tests for S_QUADMASK and S_WQM instructions."""
    def _run(self, op, src):
        # Load `src` into s0, execute `op s1, s0`, and return the final state.
        return run_program([s_mov_b32(s[0], src), op(s[1], s[0])], n_lanes=1)
    def test_s_quadmask_b32_all_quads_active(self):
        """S_QUADMASK_B32: All quads have at least one active lane."""
        # 0xFFFFFFFF: every lane bit set -> all 8 quads active -> 0xFF.
        state = self._run(s_quadmask_b32, 0xFFFFFFFF)
        value = state.sgpr[1]
        self.assertEqual(value, 0xFF, f"Expected 0xFF, got 0x{value:x}")
        self.assertEqual(state.scc, 1, "SCC should be 1 (result != 0)")
    def test_s_quadmask_b32_alternating_quads(self):
        """S_QUADMASK_B32: Every other quad has lanes active."""
        # 0x0F0F0F0F: quads 0,2,4,6 active (bits 0-3, 8-11, 16-19, 24-27),
        # so output bits 0,2,4,6 are set -> 0x55.
        state = self._run(s_quadmask_b32, 0x0F0F0F0F)
        value = state.sgpr[1]
        self.assertEqual(value, 0x55, f"Expected 0x55, got 0x{value:x}")
    def test_s_quadmask_b32_no_quads_active(self):
        """S_QUADMASK_B32: No quads have active lanes."""
        state = self._run(s_quadmask_b32, 0)
        value = state.sgpr[1]
        self.assertEqual(value, 0, f"Expected 0, got 0x{value:x}")
        self.assertEqual(state.scc, 0, "SCC should be 0 (result == 0)")
    def test_s_quadmask_b32_single_lane_per_quad(self):
        """S_QUADMASK_B32: Single lane active in each quad."""
        # 0x11111111: bit 0 of every nibble set -> all 8 quads active.
        state = self._run(s_quadmask_b32, 0x11111111)
        value = state.sgpr[1]
        self.assertEqual(value, 0xFF, f"Expected 0xFF, got 0x{value:x}")
    def test_s_wqm_b32_all_active(self):
        """S_WQM_B32: Whole quad mode - if any lane in quad is active, activate all."""
        # 0x11111111: one lane in each quad -> every quad fully activated.
        state = self._run(s_wqm_b32, 0x11111111)
        value = state.sgpr[1]
        self.assertEqual(value, 0xFFFFFFFF, f"Expected 0xFFFFFFFF, got 0x{value:x}")
        self.assertEqual(state.scc, 1, "SCC should be 1 (result != 0)")
    def test_s_wqm_b32_alternating_quads(self):
        """S_WQM_B32: Only some quads have active lanes."""
        # 0x00000001: a single lane in quad 0 -> quad 0 fully active = 0x0000000F.
        state = self._run(s_wqm_b32, 0x00000001)
        value = state.sgpr[1]
        self.assertEqual(value, 0x0000000F, f"Expected 0x0000000F, got 0x{value:x}")
    def test_s_wqm_b32_zero(self):
        """S_WQM_B32: No lanes active."""
        state = self._run(s_wqm_b32, 0)
        value = state.sgpr[1]
        self.assertEqual(value, 0, f"Expected 0, got 0x{value:x}")
        self.assertEqual(state.scc, 0, "SCC should be 0 (result == 0)")
# Run all tests above when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()

View File

@@ -1,7 +1,9 @@
#!/usr/bin/env python3
"""Tests for the RDNA3 pseudocode DSL."""
import unittest
from extra.assembly.amd.pcode import Reg, TypedView, SliceProxy, ExecContext, compile_pseudocode, _expr, MASK32, MASK64, _f32, _i32, _f16, _i16, f32_to_f16, _isnan
from extra.assembly.amd.pcode import (Reg, TypedView, SliceProxy, ExecContext, compile_pseudocode, _expr, MASK32, MASK64,
_f32, _i32, _f16, _i16, f32_to_f16, _isnan, _bf16, _ibf16, bf16_to_f32, f32_to_bf16,
BYTE_PERMUTE, v_sad_u8, v_msad_u8)
from extra.assembly.amd.autogen.rdna3.gen_pcode import _VOP3SDOp_V_DIV_SCALE_F32, _VOPCOp_V_CMP_CLASS_F32
class TestReg(unittest.TestCase):
@@ -265,5 +267,129 @@ class TestPseudocodeRegressions(unittest.TestCase):
self.assertFalse(_isnan(normal_reg.f32), "_isnan should return False for normal TypedView")
self.assertFalse(_isnan(inf_reg.f32), "_isnan should return False for inf TypedView")
class TestBF16(unittest.TestCase):
    """Tests for BF16 (bfloat16) support."""
    def test_bf16_conversion(self):
        """Test bf16 <-> f32 conversion."""
        # bf16 is simply the top 16 bits of the f32 bit pattern.
        for val, bits in [(1.0, 0x3f80), (2.0, 0x4000), (-1.0, 0xbf80)]:
            self.assertAlmostEqual(_bf16(bits), val, places=2)
            self.assertEqual(_ibf16(val), bits)
    def test_bf16_special_values(self):
        """Test bf16 special values (inf, nan)."""
        import math
        # Infinities keep the f32 exponent pattern, truncated to 16 bits.
        for val, bits in [(float('inf'), 0x7f80), (float('-inf'), 0xff80)]:
            self.assertTrue(math.isinf(_bf16(bits)))
            self.assertEqual(_ibf16(val), bits)
        # NaN encodes as the quiet NaN pattern 0x7fc0.
        self.assertTrue(math.isnan(_bf16(0x7fc0)))
        self.assertEqual(_ibf16(float('nan')), 0x7fc0)
    def test_bf16_register_property(self):
        """Test Reg.bf16 property."""
        reg = Reg(0)
        reg.bf16 = 3.0  # 3.0f = 0x40400000, truncated bf16 = 0x4040
        self.assertEqual(reg._val & 0xffff, 0x4040)
        self.assertAlmostEqual(float(reg.bf16), 3.0, places=1)
    def test_bf16_slice_property(self):
        """Test SliceProxy.bf16 property."""
        reg = Reg(0x40404040)  # two packed bf16 values of 3.0
        self.assertAlmostEqual(reg[15:0].bf16, 3.0, places=1)
        self.assertAlmostEqual(reg[31:16].bf16, 3.0, places=1)
class TestBytePermute(unittest.TestCase):
    """Tests for BYTE_PERMUTE helper function (V_PERM_B32)."""
    def test_byte_select_0_to_7(self):
        """Test selecting bytes 0-7 from 64-bit data."""
        # Combined {s0, s1} data where byte k holds the value k, so selector
        # k must return k for every k in 0..7.
        packed = 0x0706050403020100
        for i in range(8):
            self.assertEqual(BYTE_PERMUTE(packed, i), i, f"byte {i} should be {i}")
    def test_sign_extend_bytes(self):
        """Test sign extension selectors 8-11."""
        # Selectors 8..11 replicate the sign bit of bytes 1, 3, 5, 7
        # (bits 15:8, 31:24, 47:40, 63:56) across the whole output byte.
        negatives = 0x8000800080008000  # sign bit set in every odd byte
        for sel in (8, 9, 10, 11):
            self.assertEqual(BYTE_PERMUTE(negatives, sel), 0xff)
        positives = 0x7f007f007f007f00  # sign bit clear in every odd byte
        for sel in (8, 9, 10, 11):
            self.assertEqual(BYTE_PERMUTE(positives, sel), 0x00)
    def test_constant_zero(self):
        """Test selector 12 returns 0x00."""
        self.assertEqual(BYTE_PERMUTE(0xffffffffffffffff, 12), 0x00)
    def test_constant_ff(self):
        """Test selectors >= 13 return 0xFF."""
        for sel in (13, 14, 15, 255):
            self.assertEqual(BYTE_PERMUTE(0, sel), 0xff, f"sel {sel} should be 0xff")
class TestSADHelpers(unittest.TestCase):
    """Tests for V_SAD_U8 and V_MSAD_U8 helper functions."""
    def test_v_sad_u8_basic(self):
        """Test v_sad_u8 with simple values."""
        # Identical operands: every byte difference is zero.
        self.assertEqual(v_sad_u8(0x04030201, 0x04030201, 0), 0)
        # Each byte differs by one: 1+1+1+1 = 4.
        self.assertEqual(v_sad_u8(0x05040302, 0x04030201, 0), 4)
    def test_v_sad_u8_with_accumulator(self):
        """Test v_sad_u8 with non-zero accumulator."""
        # Byte diffs sum to 4; accumulator 100 brings the total to 104.
        self.assertEqual(v_sad_u8(0x05040302, 0x04030201, 100), 104)
    def test_v_sad_u8_large_diff(self):
        """Test v_sad_u8 with maximum byte differences."""
        # |0xff - 0x00| in all four bytes: 255 * 4 = 1020.
        self.assertEqual(v_sad_u8(0xffffffff, 0x00000000, 0), 1020)
    def test_v_msad_u8_basic(self):
        """Test v_msad_u8 masks when reference byte is 0."""
        # Every reference byte is zero -> all positions masked -> 0.
        self.assertEqual(v_msad_u8(0x10101010, 0x00000000, 0), 0)
        # All reference bytes non-zero: |0x10-0x01| * 4 = 15 * 4 = 60.
        self.assertEqual(v_msad_u8(0x10101010, 0x01010101, 0), 60)
    def test_v_msad_u8_partial_mask(self):
        """Test v_msad_u8 with partial masking."""
        # Reference 0x00010001 masks bytes 1 and 3; the remaining two
        # positions each contribute |0x10-0x01| = 15 -> total 30.
        self.assertEqual(v_msad_u8(0x10101010, 0x00010001, 0), 30)
    def test_v_msad_u8_with_accumulator(self):
        """Test v_msad_u8 with non-zero accumulator."""
        self.assertEqual(v_msad_u8(0x10101010, 0x01010101, 50), 110)  # 60 + 50
# Run all tests above when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()