From 7322d9ec4a36948fd36f24631a76d0c998b1054f Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 29 Dec 2025 17:30:17 -0500 Subject: [PATCH] assembly/amd: add new instruction support to pcode (#13885) * assembly/amd: add new instruction support * more * regen all --- extra/assembly/amd/autogen/cdna/gen_pcode.py | 558 ++++++++++++++++++ extra/assembly/amd/autogen/rdna3/gen_pcode.py | 213 +++++++ extra/assembly/amd/autogen/rdna4/gen_pcode.py | 213 +++++++ extra/assembly/amd/emu.py | 9 +- extra/assembly/amd/pcode.py | 91 ++- extra/assembly/amd/test/test_emu.py | 241 ++++++++ extra/assembly/amd/test/test_pcode.py | 128 +++- 7 files changed, 1440 insertions(+), 13 deletions(-) diff --git a/extra/assembly/amd/autogen/cdna/gen_pcode.py b/extra/assembly/amd/autogen/cdna/gen_pcode.py index 908678c82b..53a39ddfce 100644 --- a/extra/assembly/amd/autogen/cdna/gen_pcode.py +++ b/extra/assembly/amd/autogen/cdna/gen_pcode.py @@ -82,6 +82,51 @@ def _SOP1Op_S_NOT_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, result['d0_64'] = True return result +def _SOP1Op_S_WQM_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = 0U; + # declare i : 6'U; + # for i in 6'0U : 6'31U do + # tmp[i] = S0.u32[i & 6'60U +: 6'4U] != 0U + # endfor; + # D0.u32 = tmp; + # SCC = D0.u32 != 0U + S0 = Reg(s0) + D0 = Reg(d0) + SCC = Reg(scc) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(0) + for i in range(0, int(31)+1): + tmp[i] = S0.u32[(i & 60) + (4) - 1 : (i & 60)] != 0 + D0.u32 = tmp + SCC = Reg(D0.u32 != 0) + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': SCC._val & 1} + return result + +def _SOP1Op_S_WQM_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = 0ULL; + # declare i : 6'U; + # for i in 6'0U : 6'63U do + # tmp[i] = S0.u64[i & 6'60U +: 6'4U] != 0ULL + # endfor; + # D0.u64 = tmp; + # SCC = D0.u64 != 0ULL + S0 = Reg(s0) + D0 = Reg(d0) + SCC = Reg(scc) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(0) + for i in range(0, int(63)+1): + tmp[i] = S0.u64[(i & 60) + (4) - 1 : (i & 60)] != 0 + D0.u64 = tmp + SCC = Reg(D0.u64 != 0) + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': SCC._val & 1} + result['d0_64'] = True + return result + def _SOP1Op_S_BREV_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # D0.u32[31 : 0] = S0.u32[0 : 31] S0 = Reg(s0) @@ -619,6 +664,49 @@ def _SOP1Op_S_XNOR_SAVEEXEC_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, liter result['d0_64'] = True return result +def _SOP1Op_S_QUADMASK_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = 0U; + # for i in 0 : 7 do + # tmp[i] = S0.u32[i * 4 +: 4] != 0U + # endfor; + # D0.u32 = tmp; + # SCC = D0.u32 != 0U + S0 = Reg(s0) + D0 = Reg(d0) + SCC = Reg(scc) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(0) + for i in range(0, int(7)+1): + tmp[i] = S0.u32[(i * 4) + (4) - 1 : (i * 4)] != 0 + D0.u32 = tmp + SCC = Reg(D0.u32 != 0) + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': SCC._val & 1} + return result + +def _SOP1Op_S_QUADMASK_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = 0ULL; + # for i in 0 : 15 do + # tmp[i] = S0.u64[i * 4 +: 4] != 0ULL + # endfor; + # D0.u64 = tmp; + # SCC = D0.u64 != 0ULL + S0 = Reg(s0) + D0 = Reg(d0) + SCC = Reg(scc) + tmp = Reg(0) + # --- compiled pseudocode --- 
+ tmp = Reg(0) + for i in range(0, int(15)+1): + tmp[i] = S0.u64[(i * 4) + (4) - 1 : (i * 4)] != 0 + D0.u64 = tmp + SCC = Reg(D0.u64 != 0) + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': SCC._val & 1} + result['d0_64'] = True + return result + def _SOP1Op_S_ABS_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # D0.i32 = S0.i32 < 0 ? -S0.i32 : S0.i32; # SCC = D0.i32 != 0 @@ -755,6 +843,8 @@ SOP1Op_FUNCTIONS = { SOP1Op.S_CMOV_B64: _SOP1Op_S_CMOV_B64, SOP1Op.S_NOT_B32: _SOP1Op_S_NOT_B32, SOP1Op.S_NOT_B64: _SOP1Op_S_NOT_B64, + SOP1Op.S_WQM_B32: _SOP1Op_S_WQM_B32, + SOP1Op.S_WQM_B64: _SOP1Op_S_WQM_B64, SOP1Op.S_BREV_B32: _SOP1Op_S_BREV_B32, SOP1Op.S_BREV_B64: _SOP1Op_S_BREV_B64, SOP1Op.S_BCNT0_I32_B32: _SOP1Op_S_BCNT0_I32_B32, @@ -783,6 +873,8 @@ SOP1Op_FUNCTIONS = { SOP1Op.S_NAND_SAVEEXEC_B64: _SOP1Op_S_NAND_SAVEEXEC_B64, SOP1Op.S_NOR_SAVEEXEC_B64: _SOP1Op_S_NOR_SAVEEXEC_B64, SOP1Op.S_XNOR_SAVEEXEC_B64: _SOP1Op_S_XNOR_SAVEEXEC_B64, + SOP1Op.S_QUADMASK_B32: _SOP1Op_S_QUADMASK_B32, + SOP1Op.S_QUADMASK_B64: _SOP1Op_S_QUADMASK_B64, SOP1Op.S_ABS_I32: _SOP1Op_S_ABS_I32, SOP1Op.S_SET_GPR_IDX_IDX: _SOP1Op_S_SET_GPR_IDX_IDX, SOP1Op.S_ANDN1_SAVEEXEC_B64: _SOP1Op_S_ANDN1_SAVEEXEC_B64, @@ -3495,6 +3587,24 @@ def _VOP2Op_V_XOR_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP2Op_V_DOT2C_F32_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = D0.f32; + # tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16); + # tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16); + # D0.f32 = tmp + S0 = Reg(s0) + S1 = Reg(s1) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(D0.f32) + tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16) + tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16) + D0.f32 = tmp + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP2Op_V_FMAMK_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # D0.f32 = fma(S0.f32, SIMM32.f32, S1.f32) S0 = Reg(s0) @@ -4119,6 +4229,7 @@ VOP2Op_FUNCTIONS = { VOP2Op.V_AND_B32: _VOP2Op_V_AND_B32, VOP2Op.V_OR_B32: _VOP2Op_V_OR_B32, VOP2Op.V_XOR_B32: _VOP2Op_V_XOR_B32, + VOP2Op.V_DOT2C_F32_BF16: _VOP2Op_V_DOT2C_F32_BF16, VOP2Op.V_FMAMK_F32: _VOP2Op_V_FMAMK_F32, VOP2Op.V_FMAAK_F32: _VOP2Op_V_FMAAK_F32, VOP2Op.V_ADD_CO_U32: _VOP2Op_V_ADD_CO_U32, @@ -4482,6 +4593,25 @@ def _VOP3POp_V_PK_MAX_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3POp_V_DOT2_F32_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = 32'F(S0[15 : 0].bf16) * 32'F(S1[15 : 0].bf16); + # tmp += 32'F(S0[31 : 16].bf16) * 32'F(S1[31 : 16].bf16); + # tmp += S2.f32; + # D0.f32 = tmp + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(F(S0[15 : 0].bf16) * F(S1[15 : 0].bf16)) + tmp += F(S0[31 : 16].bf16) * F(S1[31 : 16].bf16) + tmp += S2.f32 + D0.f32 = tmp + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3POp_V_PK_MINIMUM3_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # tmp[31 : 16].f16 = 16'F(v_minimum3_f16(S0[31 : 16].f16, S1[31 : 16].f16, S2[31 : 16].f16)); # tmp[15 : 
0].f16 = 16'F(v_minimum3_f16(S0[15 : 0].f16, S1[15 : 0].f16, S2[15 : 0].f16)); @@ -4774,6 +4904,7 @@ VOP3POp_FUNCTIONS = { VOP3POp.V_PK_MUL_F16: _VOP3POp_V_PK_MUL_F16, VOP3POp.V_PK_MIN_F16: _VOP3POp_V_PK_MIN_F16, VOP3POp.V_PK_MAX_F16: _VOP3POp_V_PK_MAX_F16, + VOP3POp.V_DOT2_F32_BF16: _VOP3POp_V_DOT2_F32_BF16, VOP3POp.V_PK_MINIMUM3_F16: _VOP3POp_V_PK_MINIMUM3_F16, VOP3POp.V_PK_MAXIMUM3_F16: _VOP3POp_V_PK_MAXIMUM3_F16, VOP3POp.V_DOT2_F32_F16: _VOP3POp_V_DOT2_F32_F16, @@ -13632,6 +13763,24 @@ def _VOP3AOp_V_XOR_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3AOp_V_DOT2C_F32_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = D0.f32; + # tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16); + # tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16); + # D0.f32 = tmp + S0 = Reg(s0) + S1 = Reg(s1) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(D0.f32) + tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16) + tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16) + D0.f32 = tmp + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3AOp_V_ADD_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # D0.f16 = S0.f16 + S1.f16 S0 = Reg(s0) @@ -14521,6 +14670,18 @@ def _VOP3AOp_V_SAD_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3AOp_V_SAD_HI_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # D0.u32 = (32'U(v_sad_u8(S0, S1, 0U)) << 16U) + S2.u32 + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + # --- compiled pseudocode --- + D0.u32 = ((v_sad_u8(S0, S1, 0)) << 16) + S2.u32 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3AOp_V_SAD_U16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # // UNSIGNED comparison # tmp = S2.u32; @@ -14747,6 +14908,71 @@ def _VOP3AOp_V_MSAD_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3AOp_V_QSAD_PK_U16_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp[63 : 48] = 16'B(v_sad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32)); + # tmp[47 : 32] = 16'B(v_sad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32)); + # tmp[31 : 16] = 16'B(v_sad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32)); + # tmp[15 : 0] = 16'B(v_sad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32)); + # D0.b64 = tmp.b64 + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp[63 : 48] = (v_sad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32)) + tmp[47 : 32] = (v_sad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32)) + tmp[31 : 16] = (v_sad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32)) + tmp[15 : 0] = (v_sad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32)) + D0.b64 = tmp.b64 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + result['d0_64'] = True + return result + +def _VOP3AOp_V_MQSAD_PK_U16_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp[63 : 48] = 16'B(v_msad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32)); + # tmp[47 : 32] = 16'B(v_msad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 
32].u32)); + # tmp[31 : 16] = 16'B(v_msad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32)); + # tmp[15 : 0] = 16'B(v_msad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32)); + # D0.b64 = tmp.b64 + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp[63 : 48] = (v_msad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32)) + tmp[47 : 32] = (v_msad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32)) + tmp[31 : 16] = (v_msad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32)) + tmp[15 : 0] = (v_msad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32)) + D0.b64 = tmp.b64 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + result['d0_64'] = True + return result + +def _VOP3AOp_V_MQSAD_U32_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp[127 : 96] = 32'B(v_msad_u8(S0[55 : 24], S1[31 : 0], S2[127 : 96].u32)); + # tmp[95 : 64] = 32'B(v_msad_u8(S0[47 : 16], S1[31 : 0], S2[95 : 64].u32)); + # tmp[63 : 32] = 32'B(v_msad_u8(S0[39 : 8], S1[31 : 0], S2[63 : 32].u32)); + # tmp[31 : 0] = 32'B(v_msad_u8(S0[31 : 0], S1[31 : 0], S2[31 : 0].u32)); + # D0.b128 = tmp.b128 + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp[127 : 96] = (v_msad_u8(S0[55 : 24], S1[31 : 0], S2[127 : 96].u32)) + tmp[95 : 64] = (v_msad_u8(S0[47 : 16], S1[31 : 0], S2[95 : 64].u32)) + tmp[63 : 32] = (v_msad_u8(S0[39 : 8], S1[31 : 0], S2[63 : 32].u32)) + tmp[31 : 0] = (v_msad_u8(S0[31 : 0], S1[31 : 0], S2[31 : 0].u32)) + D0.b128 = tmp.b128 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3AOp_V_MAD_LEGACY_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # tmp = S0.f16 * S1.f16 + S2.f16; # if OPSEL.u4[3] then @@ -15601,6 +15827,76 @@ def _VOP3AOp_V_CVT_SCALEF32_SR_BF8_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask result = {'d0': d0, 'scc': scc & 1} return result +def _VOP3AOp_V_CVT_SCALEF32_PK_FP8_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S1.f32)); + # tmp0 = bf16_to_fp8_scale(S0[15 : 0].bf16, scale.u8); + # tmp1 = bf16_to_fp8_scale(S0[31 : 16].bf16, scale.u8); + # dstword = OPSEL[3].i32 * 16; + # // Other destination bits are preserved + S0 = Reg(s0) + S1 = Reg(s1) + tmp = Reg(0) + # --- compiled pseudocode --- + scale = (exponent(S1.f32)) + tmp0 = bf16_to_fp8_scale(S0[15 : 0].bf16, scale.u8) + tmp1 = bf16_to_fp8_scale(S0[31 : 16].bf16, scale.u8) + dstword = OPSEL[3].i32 * 16 + # --- end pseudocode --- + result = {'d0': d0, 'scc': scc & 1} + return result + +def _VOP3AOp_V_CVT_SCALEF32_PK_BF8_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S1.f32)); + # tmp0 = bf16_to_bf8_scale(S0[15 : 0].bf16, scale.u8); + # tmp1 = bf16_to_bf8_scale(S0[31 : 16].bf16, scale.u8); + # dstword = OPSEL[3].i32 * 16; + # // Other destination bits are preserved + S0 = Reg(s0) + S1 = Reg(s1) + tmp = Reg(0) + # --- compiled pseudocode --- + scale = (exponent(S1.f32)) + tmp0 = bf16_to_bf8_scale(S0[15 : 0].bf16, scale.u8) + tmp1 = bf16_to_bf8_scale(S0[31 : 16].bf16, scale.u8) + dstword = OPSEL[3].i32 * 16 + # --- end pseudocode --- + result = {'d0': d0, 'scc': scc & 1} + return result + +def _VOP3AOp_V_CVT_SCALEF32_SR_FP8_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S2.f32)); + # tmp = 
bf16_to_fp8_sr_scale(S0.bf16, S1.u32, scale.u8); + # dstbyte = OPSEL[3 : 2].i32 * 8; + # // Other destination bits are preserved + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + tmp = Reg(0) + # --- compiled pseudocode --- + scale = (exponent(S2.f32)) + tmp = Reg(bf16_to_fp8_sr_scale(S0.bf16, S1.u32, scale.u8)) + dstbyte = OPSEL[3 : 2].i32 * 8 + # --- end pseudocode --- + result = {'d0': d0, 'scc': scc & 1} + return result + +def _VOP3AOp_V_CVT_SCALEF32_SR_BF8_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S2.f32)); + # tmp = bf16_to_bf8_sr_scale(S0.bf16, S1.u32, scale.u8); + # dstbyte = OPSEL[3 : 2].i32 * 8; + # // Other destination bits are preserved + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + tmp = Reg(0) + # --- compiled pseudocode --- + scale = (exponent(S2.f32)) + tmp = Reg(bf16_to_bf8_sr_scale(S0.bf16, S1.u32, scale.u8)) + dstbyte = OPSEL[3 : 2].i32 * 8 + # --- end pseudocode --- + result = {'d0': d0, 'scc': scc & 1} + return result + def _VOP3AOp_V_CVT_SCALEF32_PK_F16_FP8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # scale = 32'U(exponent(S1.f32)); # srcword = OPSEL[0].i32 * 16; @@ -15699,6 +15995,24 @@ def _VOP3AOp_V_CVT_SCALEF32_PK_FP4_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask result = {'d0': d0, 'scc': scc & 1} return result +def _VOP3AOp_V_CVT_SCALEF32_PK_FP4_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S1.f32)); + # tmp0 = bf16_to_fp4_scale(S0[15 : 0].bf16, scale.u8); + # tmp1 = bf16_to_fp4_scale(S0[31 : 16].bf16, scale.u8); + # dstbyte = OPSEL[3 : 2].i32 * 8; + # // Other destination bits are preserved + S0 = Reg(s0) + S1 = Reg(s1) + tmp = Reg(0) + # --- compiled pseudocode --- + scale = (exponent(S1.f32)) + tmp0 = bf16_to_fp4_scale(S0[15 : 0].bf16, scale.u8) + tmp1 = bf16_to_fp4_scale(S0[31 : 16].bf16, scale.u8) + dstbyte = OPSEL[3 : 2].i32 * 8 + # --- end pseudocode --- + result = {'d0': d0, 'scc': scc & 1} + return result + def _VOP3AOp_V_CVT_SCALEF32_SR_PK_FP4_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # scale = 32'U(exponent(S2.f32)); # randomVal = S1.u32; @@ -15720,6 +16034,27 @@ def _VOP3AOp_V_CVT_SCALEF32_SR_PK_FP4_F16(s0, s1, s2, d0, scc, vcc, lane, exec_m result = {'d0': d0, 'scc': scc & 1} return result +def _VOP3AOp_V_CVT_SCALEF32_SR_PK_FP4_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S2.f32)); + # randomVal = S1.u32; + # tmp0 = bf16_to_fp4_sr_scale(S0[15 : 0].bf16, randomVal, scale.u8); + # tmp1 = bf16_to_fp4_sr_scale(S0[31 : 16].bf16, randomVal, scale.u8); + # dstbyte = OPSEL[3 : 2].i32 * 8; + # // Other destination bits are preserved + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + tmp = Reg(0) + # --- compiled pseudocode --- + scale = (exponent(S2.f32)) + randomVal = S1.u32 + tmp0 = bf16_to_fp4_sr_scale(S0[15 : 0].bf16, randomVal, scale.u8) + tmp1 = bf16_to_fp4_sr_scale(S0[31 : 16].bf16, randomVal, scale.u8) + dstbyte = OPSEL[3 : 2].i32 * 8 + # --- end pseudocode --- + result = {'d0': d0, 'scc': scc & 1} + return result + def _VOP3AOp_V_CVT_SCALEF32_PK_F16_FP4(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # scale = 32'U(exponent(S1.f32)); # srcbyte = OPSEL[1 : 0].i32 * 8; @@ -15741,6 +16076,27 @@ def _VOP3AOp_V_CVT_SCALEF32_PK_F16_FP4(s0, s1, s2, d0, scc, vcc, lane, exec_mask result = {'d0': 
D0._val, 'scc': scc & 1} return result +def _VOP3AOp_V_CVT_SCALEF32_PK_BF16_FP4(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S1.f32)); + # srcbyte = OPSEL[1 : 0].i32 * 8; + # src = VGPR[laneId][SRC0.u32][srcbyte + 7 : srcbyte].b8; + # D0[15 : 0].bf16 = tmp0; + # D0[31 : 16].bf16 = tmp1 + S1 = Reg(s1) + D0 = Reg(d0) + tmp = Reg(0) + laneId = lane + SRC0 = Reg(src0_idx) + # --- compiled pseudocode --- + scale = (exponent(S1.f32)) + srcbyte = OPSEL[1 : 0].i32 * 8 + src = VGPR[laneId][SRC0.u32][srcbyte + 7 : srcbyte].b8 + D0[15 : 0].bf16 = tmp0 + D0[31 : 16].bf16 = tmp1 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3AOp_V_CVT_SCALEF32_2XPK16_FP6_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # scale = 32'U(exponent(S2.f32)); # declare tmp : 192'B; @@ -15873,6 +16229,26 @@ def _VOP3AOp_V_CVT_SCALEF32_PK32_F32_BF6(s0, s1, s2, d0, scc, vcc, lane, exec_ma result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3AOp_V_CVT_SCALEF32_PK32_FP6_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S1.f32)); + # declare tmp : 192'B; + # for pass in 0 : 31 do + # tmp[dOffset + 5 : dOffset].fp6 = bf16_to_fp6_scale(S0[sOffset + 15 : sOffset].bf16, scale.u8) + # endfor; + # D0[191 : 0] = tmp.b192 + S0 = Reg(s0) + S1 = Reg(s1) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + scale = (exponent(S1.f32)) + for pass in range(0, int(31)+1): + tmp[dOffset + 5 : dOffset].fp6 = bf16_to_fp6_scale(S0[sOffset + 15 : sOffset].bf16, scale.u8) + D0[191 : 0] = tmp.b192 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3AOp_V_CVT_SCALEF32_PK32_BF6_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # scale = 32'U(exponent(S1.f32)); # declare tmp : 192'B; @@ -15893,6 +16269,26 @@ def _VOP3AOp_V_CVT_SCALEF32_PK32_BF6_F16(s0, s1, s2, d0, scc, vcc, lane, exec_ma result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3AOp_V_CVT_SCALEF32_PK32_BF6_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S1.f32)); + # declare tmp : 192'B; + # for pass in 0 : 31 do + # tmp[dOffset + 5 : dOffset].bf6 = bf16_to_bf6_scale(S0[sOffset + 15 : sOffset].bf16, scale.u8) + # endfor; + # D0[191 : 0] = tmp.b192 + S0 = Reg(s0) + S1 = Reg(s1) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + scale = (exponent(S1.f32)) + for pass in range(0, int(31)+1): + tmp[dOffset + 5 : dOffset].bf6 = bf16_to_bf6_scale(S0[sOffset + 15 : sOffset].bf16, scale.u8) + D0[191 : 0] = tmp.b192 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3AOp_V_CVT_SCALEF32_SR_PK32_FP6_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # scale = 32'U(exponent(S2.f32)); # randomVal = S1.u32; @@ -15915,6 +16311,28 @@ def _VOP3AOp_V_CVT_SCALEF32_SR_PK32_FP6_F16(s0, s1, s2, d0, scc, vcc, lane, exec result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3AOp_V_CVT_SCALEF32_SR_PK32_FP6_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S2.f32)); + # randomVal = S1.u32; + # declare tmp : 192'B; + # for pass in 0 : 31 do + # tmp[dOffset + 5 : dOffset].fp6 = bf16_to_fp6_sr_scale(S0[sOffset + 15 : 
sOffset].bf16, randomVal, + # endfor; + # D0[191 : 0] = tmp.b192 + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + scale = (exponent(S2.f32)) + randomVal = S1.u32 + for pass in range(0, int(31)+1): + tmp[dOffset + 5 : dOffset].fp6 = bf16_to_fp6_sr_scale(S0[sOffset + 15 : sOffset].bf16, randomVal, endfor; D0[191 : 0] = tmp.b192 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3AOp_V_CVT_SCALEF32_SR_PK32_BF6_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # scale = 32'U(exponent(S2.f32)); # randomVal = S1.u32; @@ -15937,6 +16355,28 @@ def _VOP3AOp_V_CVT_SCALEF32_SR_PK32_BF6_F16(s0, s1, s2, d0, scc, vcc, lane, exec result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3AOp_V_CVT_SCALEF32_SR_PK32_BF6_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S2.f32)); + # randomVal = S1.u32; + # declare tmp : 192'B; + # for pass in 0 : 31 do + # tmp[dOffset + 5 : dOffset].bf6 = bf16_to_bf6_sr_scale(S0[sOffset + 15 : sOffset].bf16, randomVal, + # endfor; + # D0[191 : 0] = tmp.b192 + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + scale = (exponent(S2.f32)) + randomVal = S1.u32 + for pass in range(0, int(31)+1): + tmp[dOffset + 5 : dOffset].bf6 = bf16_to_bf6_sr_scale(S0[sOffset + 15 : sOffset].bf16, randomVal, endfor; D0[191 : 0] = tmp.b192 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3AOp_V_CVT_SCALEF32_PK32_F16_FP6(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # scale = 32'U(exponent(S1.f32)); # declare tmp : 512'B; @@ -15957,6 +16397,26 @@ def _VOP3AOp_V_CVT_SCALEF32_PK32_F16_FP6(s0, s1, s2, d0, scc, vcc, lane, exec_ma result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3AOp_V_CVT_SCALEF32_PK32_BF16_FP6(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S1.f32)); + # declare tmp : 512'B; + # for pass in 0 : 31 do + # tmp[dOffset + 15 : dOffset].bf16 = fp6_to_bf16_scale(S0[sOffset + 5 : sOffset].fp6, scale.u8) + # endfor; + # D0[511 : 0] = tmp.b512 + S0 = Reg(s0) + S1 = Reg(s1) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + scale = (exponent(S1.f32)) + for pass in range(0, int(31)+1): + tmp[dOffset + 15 : dOffset].bf16 = fp6_to_bf16_scale(S0[sOffset + 5 : sOffset].fp6, scale.u8) + D0[511 : 0] = tmp.b512 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3AOp_V_CVT_SCALEF32_PK32_F16_BF6(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # scale = 32'U(exponent(S1.f32)); # declare tmp : 512'B; @@ -15977,6 +16437,26 @@ def _VOP3AOp_V_CVT_SCALEF32_PK32_F16_BF6(s0, s1, s2, d0, scc, vcc, lane, exec_ma result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3AOp_V_CVT_SCALEF32_PK32_BF16_BF6(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S1.f32)); + # declare tmp : 512'B; + # for pass in 0 : 31 do + # tmp[dOffset + 15 : dOffset].bf16 = bf6_to_bf16_scale(S0[sOffset + 5 : sOffset].bf6, scale.u8) + # endfor; + # D0[511 : 0] = tmp.b512 + S0 = Reg(s0) + S1 = Reg(s1) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + scale = (exponent(S1.f32)) + for pass in range(0, 
int(31)+1): + tmp[dOffset + 15 : dOffset].bf16 = bf6_to_bf16_scale(S0[sOffset + 5 : sOffset].bf6, scale.u8) + D0[511 : 0] = tmp.b512 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3AOp_V_ASHR_PK_I8_I32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # if n <= -128 then # elsif n >= 127 then @@ -16048,6 +16528,63 @@ def _VOP3AOp_V_CVT_PK_F16_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal result = {'d0': d0, 'scc': scc & 1} return result +def _VOP3AOp_V_CVT_PK_BF16_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # prev_mode = ROUND_MODE; + # tmp[15 : 0].bf16 = f32_to_bf16(S0.f32); + # tmp[31 : 16].bf16 = f32_to_bf16(S1.f32); + S0 = Reg(s0) + S1 = Reg(s1) + tmp = Reg(0) + # --- compiled pseudocode --- + prev_mode = ROUND_MODE + tmp[15 : 0].bf16 = f32_to_bf16(S0.f32) + tmp[31 : 16].bf16 = f32_to_bf16(S1.f32) + # --- end pseudocode --- + result = {'d0': d0, 'scc': scc & 1} + return result + +def _VOP3AOp_V_CVT_SCALEF32_PK_BF16_FP8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S1.f32)); + # srcword = OPSEL[0].i32 * 16; + # src = VGPR[laneId][SRC0.u32][srcword + 15 : srcword].b16; + # D0[15 : 0].bf16 = tmp0.bf16; + # D0[31 : 16].bf16 = tmp1.bf16 + S1 = Reg(s1) + D0 = Reg(d0) + tmp = Reg(0) + laneId = lane + SRC0 = Reg(src0_idx) + # --- compiled pseudocode --- + scale = (exponent(S1.f32)) + srcword = OPSEL[0].i32 * 16 + src = VGPR[laneId][SRC0.u32][srcword + 15 : srcword].b16 + D0[15 : 0].bf16 = tmp0.bf16 + D0[31 : 16].bf16 = tmp1.bf16 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + +def _VOP3AOp_V_CVT_SCALEF32_PK_BF16_BF8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # scale = 32'U(exponent(S1.f32)); + # srcword = OPSEL[0].i32 * 16; + # src = VGPR[laneId][SRC0.u32][srcword + 15 : srcword].b16; + # D0[15 : 0].bf16 = tmp0.bf16; + # D0[31 : 16].bf16 = tmp1.bf16 + S1 = Reg(s1) + D0 = Reg(d0) + tmp = Reg(0) + laneId = lane + SRC0 = Reg(src0_idx) + # --- compiled pseudocode --- + scale = (exponent(S1.f32)) + srcword = OPSEL[0].i32 * 16 + src = VGPR[laneId][SRC0.u32][srcword + 15 : srcword].b16 + D0[15 : 0].bf16 = tmp0.bf16 + D0[31 : 16].bf16 = tmp1.bf16 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3AOp_V_ADD_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # D0.f64 = S0.f64 + S1.f64 S0 = Reg(s0) @@ -16771,6 +17308,7 @@ VOP3AOp_FUNCTIONS = { VOP3AOp.V_AND_B32: _VOP3AOp_V_AND_B32, VOP3AOp.V_OR_B32: _VOP3AOp_V_OR_B32, VOP3AOp.V_XOR_B32: _VOP3AOp_V_XOR_B32, + VOP3AOp.V_DOT2C_F32_BF16: _VOP3AOp_V_DOT2C_F32_BF16, VOP3AOp.V_ADD_F16: _VOP3AOp_V_ADD_F16, VOP3AOp.V_SUB_F16: _VOP3AOp_V_SUB_F16, VOP3AOp.V_SUBREV_F16: _VOP3AOp_V_SUBREV_F16, @@ -16824,6 +17362,7 @@ VOP3AOp_FUNCTIONS = { VOP3AOp.V_MED3_I32: _VOP3AOp_V_MED3_I32, VOP3AOp.V_MED3_U32: _VOP3AOp_V_MED3_U32, VOP3AOp.V_SAD_U8: _VOP3AOp_V_SAD_U8, + VOP3AOp.V_SAD_HI_U8: _VOP3AOp_V_SAD_HI_U8, VOP3AOp.V_SAD_U16: _VOP3AOp_V_SAD_U16, VOP3AOp.V_SAD_U32: _VOP3AOp_V_SAD_U32, VOP3AOp.V_CVT_PK_U8_F32: _VOP3AOp_V_CVT_PK_U8_F32, @@ -16832,6 +17371,9 @@ VOP3AOp_FUNCTIONS = { VOP3AOp.V_DIV_FMAS_F32: _VOP3AOp_V_DIV_FMAS_F32, VOP3AOp.V_DIV_FMAS_F64: _VOP3AOp_V_DIV_FMAS_F64, VOP3AOp.V_MSAD_U8: _VOP3AOp_V_MSAD_U8, + VOP3AOp.V_QSAD_PK_U16_U8: _VOP3AOp_V_QSAD_PK_U16_U8, + 
VOP3AOp.V_MQSAD_PK_U16_U8: _VOP3AOp_V_MQSAD_PK_U16_U8, + VOP3AOp.V_MQSAD_U32_U8: _VOP3AOp_V_MQSAD_U32_U8, VOP3AOp.V_MAD_LEGACY_F16: _VOP3AOp_V_MAD_LEGACY_F16, VOP3AOp.V_MAD_LEGACY_U16: _VOP3AOp_V_MAD_LEGACY_U16, VOP3AOp.V_MAD_LEGACY_I16: _VOP3AOp_V_MAD_LEGACY_I16, @@ -16879,27 +17421,43 @@ VOP3AOp_FUNCTIONS = { VOP3AOp.V_CVT_SCALEF32_PK_BF8_F16: _VOP3AOp_V_CVT_SCALEF32_PK_BF8_F16, VOP3AOp.V_CVT_SCALEF32_SR_FP8_F16: _VOP3AOp_V_CVT_SCALEF32_SR_FP8_F16, VOP3AOp.V_CVT_SCALEF32_SR_BF8_F16: _VOP3AOp_V_CVT_SCALEF32_SR_BF8_F16, + VOP3AOp.V_CVT_SCALEF32_PK_FP8_BF16: _VOP3AOp_V_CVT_SCALEF32_PK_FP8_BF16, + VOP3AOp.V_CVT_SCALEF32_PK_BF8_BF16: _VOP3AOp_V_CVT_SCALEF32_PK_BF8_BF16, + VOP3AOp.V_CVT_SCALEF32_SR_FP8_BF16: _VOP3AOp_V_CVT_SCALEF32_SR_FP8_BF16, + VOP3AOp.V_CVT_SCALEF32_SR_BF8_BF16: _VOP3AOp_V_CVT_SCALEF32_SR_BF8_BF16, VOP3AOp.V_CVT_SCALEF32_PK_F16_FP8: _VOP3AOp_V_CVT_SCALEF32_PK_F16_FP8, VOP3AOp.V_CVT_SCALEF32_PK_F16_BF8: _VOP3AOp_V_CVT_SCALEF32_PK_F16_BF8, VOP3AOp.V_CVT_SCALEF32_F16_FP8: _VOP3AOp_V_CVT_SCALEF32_F16_FP8, VOP3AOp.V_CVT_SCALEF32_F16_BF8: _VOP3AOp_V_CVT_SCALEF32_F16_BF8, VOP3AOp.V_CVT_SCALEF32_PK_FP4_F16: _VOP3AOp_V_CVT_SCALEF32_PK_FP4_F16, + VOP3AOp.V_CVT_SCALEF32_PK_FP4_BF16: _VOP3AOp_V_CVT_SCALEF32_PK_FP4_BF16, VOP3AOp.V_CVT_SCALEF32_SR_PK_FP4_F16: _VOP3AOp_V_CVT_SCALEF32_SR_PK_FP4_F16, + VOP3AOp.V_CVT_SCALEF32_SR_PK_FP4_BF16: _VOP3AOp_V_CVT_SCALEF32_SR_PK_FP4_BF16, VOP3AOp.V_CVT_SCALEF32_PK_F16_FP4: _VOP3AOp_V_CVT_SCALEF32_PK_F16_FP4, + VOP3AOp.V_CVT_SCALEF32_PK_BF16_FP4: _VOP3AOp_V_CVT_SCALEF32_PK_BF16_FP4, VOP3AOp.V_CVT_SCALEF32_2XPK16_FP6_F32: _VOP3AOp_V_CVT_SCALEF32_2XPK16_FP6_F32, VOP3AOp.V_CVT_SCALEF32_2XPK16_BF6_F32: _VOP3AOp_V_CVT_SCALEF32_2XPK16_BF6_F32, VOP3AOp.V_CVT_SCALEF32_SR_PK32_FP6_F32: _VOP3AOp_V_CVT_SCALEF32_SR_PK32_FP6_F32, VOP3AOp.V_CVT_SCALEF32_SR_PK32_BF6_F32: _VOP3AOp_V_CVT_SCALEF32_SR_PK32_BF6_F32, VOP3AOp.V_CVT_SCALEF32_PK32_F32_FP6: _VOP3AOp_V_CVT_SCALEF32_PK32_F32_FP6, VOP3AOp.V_CVT_SCALEF32_PK32_F32_BF6: _VOP3AOp_V_CVT_SCALEF32_PK32_F32_BF6, + VOP3AOp.V_CVT_SCALEF32_PK32_FP6_BF16: _VOP3AOp_V_CVT_SCALEF32_PK32_FP6_BF16, VOP3AOp.V_CVT_SCALEF32_PK32_BF6_F16: _VOP3AOp_V_CVT_SCALEF32_PK32_BF6_F16, + VOP3AOp.V_CVT_SCALEF32_PK32_BF6_BF16: _VOP3AOp_V_CVT_SCALEF32_PK32_BF6_BF16, VOP3AOp.V_CVT_SCALEF32_SR_PK32_FP6_F16: _VOP3AOp_V_CVT_SCALEF32_SR_PK32_FP6_F16, + VOP3AOp.V_CVT_SCALEF32_SR_PK32_FP6_BF16: _VOP3AOp_V_CVT_SCALEF32_SR_PK32_FP6_BF16, VOP3AOp.V_CVT_SCALEF32_SR_PK32_BF6_F16: _VOP3AOp_V_CVT_SCALEF32_SR_PK32_BF6_F16, + VOP3AOp.V_CVT_SCALEF32_SR_PK32_BF6_BF16: _VOP3AOp_V_CVT_SCALEF32_SR_PK32_BF6_BF16, VOP3AOp.V_CVT_SCALEF32_PK32_F16_FP6: _VOP3AOp_V_CVT_SCALEF32_PK32_F16_FP6, + VOP3AOp.V_CVT_SCALEF32_PK32_BF16_FP6: _VOP3AOp_V_CVT_SCALEF32_PK32_BF16_FP6, VOP3AOp.V_CVT_SCALEF32_PK32_F16_BF6: _VOP3AOp_V_CVT_SCALEF32_PK32_F16_BF6, + VOP3AOp.V_CVT_SCALEF32_PK32_BF16_BF6: _VOP3AOp_V_CVT_SCALEF32_PK32_BF16_BF6, VOP3AOp.V_ASHR_PK_I8_I32: _VOP3AOp_V_ASHR_PK_I8_I32, VOP3AOp.V_ASHR_PK_U8_I32: _VOP3AOp_V_ASHR_PK_U8_I32, VOP3AOp.V_CVT_PK_F16_F32: _VOP3AOp_V_CVT_PK_F16_F32, + VOP3AOp.V_CVT_PK_BF16_F32: _VOP3AOp_V_CVT_PK_BF16_F32, + VOP3AOp.V_CVT_SCALEF32_PK_BF16_FP8: _VOP3AOp_V_CVT_SCALEF32_PK_BF16_FP8, + VOP3AOp.V_CVT_SCALEF32_PK_BF16_BF8: _VOP3AOp_V_CVT_SCALEF32_PK_BF16_BF8, VOP3AOp.V_ADD_F64: _VOP3AOp_V_ADD_F64, VOP3AOp.V_MUL_F64: _VOP3AOp_V_MUL_F64, VOP3AOp.V_MIN_F64: _VOP3AOp_V_MIN_F64, diff --git a/extra/assembly/amd/autogen/rdna3/gen_pcode.py b/extra/assembly/amd/autogen/rdna3/gen_pcode.py index bc7da76621..8ce42c1cc9 100644 --- 
a/extra/assembly/amd/autogen/rdna3/gen_pcode.py +++ b/extra/assembly/amd/autogen/rdna3/gen_pcode.py @@ -394,6 +394,94 @@ def _SOP1Op_S_BCNT1_I32_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, result['d0_64'] = True return result +def _SOP1Op_S_QUADMASK_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = 0U; + # for i in 0 : 7 do + # tmp[i] = S0.u32[i * 4 +: 4] != 0U + # endfor; + # D0.u32 = tmp; + # SCC = D0.u32 != 0U + S0 = Reg(s0) + D0 = Reg(d0) + SCC = Reg(scc) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(0) + for i in range(0, int(7)+1): + tmp[i] = S0.u32[(i * 4) + (4) - 1 : (i * 4)] != 0 + D0.u32 = tmp + SCC = Reg(D0.u32 != 0) + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': SCC._val & 1} + return result + +def _SOP1Op_S_QUADMASK_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = 0ULL; + # for i in 0 : 15 do + # tmp[i] = S0.u64[i * 4 +: 4] != 0ULL + # endfor; + # D0.u64 = tmp; + # SCC = D0.u64 != 0ULL + S0 = Reg(s0) + D0 = Reg(d0) + SCC = Reg(scc) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(0) + for i in range(0, int(15)+1): + tmp[i] = S0.u64[(i * 4) + (4) - 1 : (i * 4)] != 0 + D0.u64 = tmp + SCC = Reg(D0.u64 != 0) + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': SCC._val & 1} + result['d0_64'] = True + return result + +def _SOP1Op_S_WQM_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = 0U; + # declare i : 6'U; + # for i in 6'0U : 6'31U do + # tmp[i] = S0.u32[i & 6'60U +: 6'4U] != 0U + # endfor; + # D0.u32 = tmp; + # SCC = D0.u32 != 0U + S0 = Reg(s0) + D0 = Reg(d0) + SCC = Reg(scc) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(0) + for i in range(0, int(31)+1): + tmp[i] = S0.u32[(i & 60) + (4) - 1 : (i & 60)] != 0 + D0.u32 = tmp + SCC = Reg(D0.u32 != 0) + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': SCC._val & 1} + return result + +def _SOP1Op_S_WQM_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = 0ULL; + # declare i : 6'U; + # for i in 6'0U : 6'63U do + # tmp[i] = S0.u64[i & 6'60U +: 6'4U] != 0ULL + # endfor; + # D0.u64 = tmp; + # SCC = D0.u64 != 0ULL + S0 = Reg(s0) + D0 = Reg(d0) + SCC = Reg(scc) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(0) + for i in range(0, int(63)+1): + tmp[i] = S0.u64[(i & 60) + (4) - 1 : (i & 60)] != 0 + D0.u64 = tmp + SCC = Reg(D0.u64 != 0) + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': SCC._val & 1} + result['d0_64'] = True + return result + def _SOP1Op_S_NOT_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # D0.u32 = ~S0.u32; # SCC = D0.u32 != 0U @@ -1178,6 +1266,10 @@ SOP1Op_FUNCTIONS = { SOP1Op.S_BCNT0_I32_B64: _SOP1Op_S_BCNT0_I32_B64, SOP1Op.S_BCNT1_I32_B32: _SOP1Op_S_BCNT1_I32_B32, SOP1Op.S_BCNT1_I32_B64: _SOP1Op_S_BCNT1_I32_B64, + SOP1Op.S_QUADMASK_B32: _SOP1Op_S_QUADMASK_B32, + SOP1Op.S_QUADMASK_B64: _SOP1Op_S_QUADMASK_B64, + SOP1Op.S_WQM_B32: _SOP1Op_S_WQM_B32, + SOP1Op.S_WQM_B64: _SOP1Op_S_WQM_B64, SOP1Op.S_NOT_B32: _SOP1Op_S_NOT_B32, SOP1Op.S_NOT_B64: _SOP1Op_S_NOT_B64, SOP1Op.S_AND_SAVEEXEC_B32: _SOP1Op_S_AND_SAVEEXEC_B32, @@ -10159,6 +10251,18 @@ def _VOP3Op_V_SAD_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _ result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3Op_V_SAD_HI_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, 
src0_idx=0, vdst_idx=0): + # D0.u32 = (32'U(v_sad_u8(S0, S1, 0U)) << 16U) + S2.u32 + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + # --- compiled pseudocode --- + D0.u32 = ((v_sad_u8(S0, S1, 0)) << 16) + S2.u32 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3Op_V_SAD_U16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # // UNSIGNED comparison # tmp = S2.u32; @@ -10385,6 +10489,71 @@ def _VOP3Op_V_MSAD_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3Op_V_QSAD_PK_U16_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp[63 : 48] = 16'B(v_sad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32)); + # tmp[47 : 32] = 16'B(v_sad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32)); + # tmp[31 : 16] = 16'B(v_sad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32)); + # tmp[15 : 0] = 16'B(v_sad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32)); + # D0.b64 = tmp.b64 + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp[63 : 48] = (v_sad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32)) + tmp[47 : 32] = (v_sad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32)) + tmp[31 : 16] = (v_sad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32)) + tmp[15 : 0] = (v_sad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32)) + D0.b64 = tmp.b64 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + result['d0_64'] = True + return result + +def _VOP3Op_V_MQSAD_PK_U16_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp[63 : 48] = 16'B(v_msad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32)); + # tmp[47 : 32] = 16'B(v_msad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32)); + # tmp[31 : 16] = 16'B(v_msad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32)); + # tmp[15 : 0] = 16'B(v_msad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32)); + # D0.b64 = tmp.b64 + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp[63 : 48] = (v_msad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32)) + tmp[47 : 32] = (v_msad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32)) + tmp[31 : 16] = (v_msad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32)) + tmp[15 : 0] = (v_msad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32)) + D0.b64 = tmp.b64 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + result['d0_64'] = True + return result + +def _VOP3Op_V_MQSAD_U32_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp[127 : 96] = 32'B(v_msad_u8(S0[55 : 24], S1[31 : 0], S2[127 : 96].u32)); + # tmp[95 : 64] = 32'B(v_msad_u8(S0[47 : 16], S1[31 : 0], S2[95 : 64].u32)); + # tmp[63 : 32] = 32'B(v_msad_u8(S0[39 : 8], S1[31 : 0], S2[63 : 32].u32)); + # tmp[31 : 0] = 32'B(v_msad_u8(S0[31 : 0], S1[31 : 0], S2[31 : 0].u32)); + # D0.b128 = tmp.b128 + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp[127 : 96] = (v_msad_u8(S0[55 : 24], S1[31 : 0], S2[127 : 96].u32)) + tmp[95 : 64] = (v_msad_u8(S0[47 : 16], S1[31 : 0], S2[95 : 64].u32)) + tmp[63 : 32] = (v_msad_u8(S0[39 : 8], S1[31 : 0], S2[63 : 32].u32)) + tmp[31 : 0] = (v_msad_u8(S0[31 : 0], S1[31 : 0], S2[31 : 0].u32)) + D0.b128 = tmp.b128 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3Op_V_XOR3_B32(s0, s1, s2, d0, scc, vcc, 
lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # D0.u32 = (S0.u32 ^ S1.u32 ^ S2.u32) S0 = Reg(s0) @@ -10860,6 +11029,25 @@ def _VOP3Op_V_DOT2_F16_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3Op_V_DOT2_BF16_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = S2.bf16; + # tmp += S0[15 : 0].bf16 * S1[15 : 0].bf16; + # tmp += S0[31 : 16].bf16 * S1[31 : 16].bf16; + # D0.bf16 = tmp + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(S2.bf16) + tmp += S0[15 : 0].bf16 * S1[15 : 0].bf16 + tmp += S0[31 : 16].bf16 * S1[31 : 16].bf16 + D0.bf16 = tmp + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3Op_V_ADD_NC_U16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # D0.u16 = S0.u16 + S1.u16 S0 = Reg(s0) @@ -11793,6 +11981,7 @@ VOP3Op_FUNCTIONS = { VOP3Op.V_MED3_I32: _VOP3Op_V_MED3_I32, VOP3Op.V_MED3_U32: _VOP3Op_V_MED3_U32, VOP3Op.V_SAD_U8: _VOP3Op_V_SAD_U8, + VOP3Op.V_SAD_HI_U8: _VOP3Op_V_SAD_HI_U8, VOP3Op.V_SAD_U16: _VOP3Op_V_SAD_U16, VOP3Op.V_SAD_U32: _VOP3Op_V_SAD_U32, VOP3Op.V_CVT_PK_U8_F32: _VOP3Op_V_CVT_PK_U8_F32, @@ -11801,6 +11990,9 @@ VOP3Op_FUNCTIONS = { VOP3Op.V_DIV_FMAS_F32: _VOP3Op_V_DIV_FMAS_F32, VOP3Op.V_DIV_FMAS_F64: _VOP3Op_V_DIV_FMAS_F64, VOP3Op.V_MSAD_U8: _VOP3Op_V_MSAD_U8, + VOP3Op.V_QSAD_PK_U16_U8: _VOP3Op_V_QSAD_PK_U16_U8, + VOP3Op.V_MQSAD_PK_U16_U8: _VOP3Op_V_MQSAD_PK_U16_U8, + VOP3Op.V_MQSAD_U32_U8: _VOP3Op_V_MQSAD_U32_U8, VOP3Op.V_XOR3_B32: _VOP3Op_V_XOR3_B32, VOP3Op.V_MAD_U16: _VOP3Op_V_MAD_U16, VOP3Op.V_XAD_U32: _VOP3Op_V_XAD_U32, @@ -11834,6 +12026,7 @@ VOP3Op_FUNCTIONS = { VOP3Op.V_MAXMIN_I32: _VOP3Op_V_MAXMIN_I32, VOP3Op.V_MINMAX_I32: _VOP3Op_V_MINMAX_I32, VOP3Op.V_DOT2_F16_F16: _VOP3Op_V_DOT2_F16_F16, + VOP3Op.V_DOT2_BF16_BF16: _VOP3Op_V_DOT2_BF16_BF16, VOP3Op.V_ADD_NC_U16: _VOP3Op_V_ADD_NC_U16, VOP3Op.V_SUB_NC_U16: _VOP3Op_V_SUB_NC_U16, VOP3Op.V_MUL_LO_U16: _VOP3Op_V_MUL_LO_U16, @@ -12552,6 +12745,25 @@ def _VOP3POp_V_DOT8_U32_U4(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3POp_V_DOT2_F32_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = S2.f32; + # tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16); + # tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16); + # D0.f32 = tmp + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(S2.f32) + tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16) + tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16) + D0.f32 = tmp + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + VOP3POp_FUNCTIONS = { VOP3POp.V_PK_MAD_I16: _VOP3POp_V_PK_MAD_I16, VOP3POp.V_PK_MUL_LO_U16: _VOP3POp_V_PK_MUL_LO_U16, @@ -12575,6 +12787,7 @@ VOP3POp_FUNCTIONS = { VOP3POp.V_DOT2_F32_F16: _VOP3POp_V_DOT2_F32_F16, VOP3POp.V_DOT4_U32_U8: _VOP3POp_V_DOT4_U32_U8, VOP3POp.V_DOT8_U32_U4: _VOP3POp_V_DOT8_U32_U4, + VOP3POp.V_DOT2_F32_BF16: _VOP3POp_V_DOT2_F32_BF16, } def _VOPCOp_V_CMP_F_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): diff --git a/extra/assembly/amd/autogen/rdna4/gen_pcode.py b/extra/assembly/amd/autogen/rdna4/gen_pcode.py index 
f06aad1f58..e7cc670a9b 100644 --- a/extra/assembly/amd/autogen/rdna4/gen_pcode.py +++ b/extra/assembly/amd/autogen/rdna4/gen_pcode.py @@ -394,6 +394,94 @@ def _SOP1Op_S_BCNT1_I32_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, result['d0_64'] = True return result +def _SOP1Op_S_QUADMASK_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = 0U; + # for i in 0 : 7 do + # tmp[i] = S0.u32[i * 4 +: 4] != 0U + # endfor; + # D0.u32 = tmp; + # SCC = D0.u32 != 0U + S0 = Reg(s0) + D0 = Reg(d0) + SCC = Reg(scc) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(0) + for i in range(0, int(7)+1): + tmp[i] = S0.u32[(i * 4) + (4) - 1 : (i * 4)] != 0 + D0.u32 = tmp + SCC = Reg(D0.u32 != 0) + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': SCC._val & 1} + return result + +def _SOP1Op_S_QUADMASK_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = 0ULL; + # for i in 0 : 15 do + # tmp[i] = S0.u64[i * 4 +: 4] != 0ULL + # endfor; + # D0.u64 = tmp; + # SCC = D0.u64 != 0ULL + S0 = Reg(s0) + D0 = Reg(d0) + SCC = Reg(scc) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(0) + for i in range(0, int(15)+1): + tmp[i] = S0.u64[(i * 4) + (4) - 1 : (i * 4)] != 0 + D0.u64 = tmp + SCC = Reg(D0.u64 != 0) + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': SCC._val & 1} + result['d0_64'] = True + return result + +def _SOP1Op_S_WQM_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = 0U; + # declare i : 6'U; + # for i in 6'0U : 6'31U do + # tmp[i] = S0.u32[i & 6'60U +: 6'4U] != 0U + # endfor; + # D0.u32 = tmp; + # SCC = D0.u32 != 0U + S0 = Reg(s0) + D0 = Reg(d0) + SCC = Reg(scc) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(0) + for i in range(0, int(31)+1): + tmp[i] = S0.u32[(i & 60) + (4) - 1 : (i & 60)] != 0 + D0.u32 = tmp + SCC = Reg(D0.u32 != 0) + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': SCC._val & 1} + return result + +def _SOP1Op_S_WQM_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = 0ULL; + # declare i : 6'U; + # for i in 6'0U : 6'63U do + # tmp[i] = S0.u64[i & 6'60U +: 6'4U] != 0ULL + # endfor; + # D0.u64 = tmp; + # SCC = D0.u64 != 0ULL + S0 = Reg(s0) + D0 = Reg(d0) + SCC = Reg(scc) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(0) + for i in range(0, int(63)+1): + tmp[i] = S0.u64[(i & 60) + (4) - 1 : (i & 60)] != 0 + D0.u64 = tmp + SCC = Reg(D0.u64 != 0) + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': SCC._val & 1} + result['d0_64'] = True + return result + def _SOP1Op_S_NOT_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # D0.u32 = ~S0.u32; # SCC = D0.u32 != 0U @@ -1264,6 +1352,10 @@ SOP1Op_FUNCTIONS = { SOP1Op.S_BCNT0_I32_B64: _SOP1Op_S_BCNT0_I32_B64, SOP1Op.S_BCNT1_I32_B32: _SOP1Op_S_BCNT1_I32_B32, SOP1Op.S_BCNT1_I32_B64: _SOP1Op_S_BCNT1_I32_B64, + SOP1Op.S_QUADMASK_B32: _SOP1Op_S_QUADMASK_B32, + SOP1Op.S_QUADMASK_B64: _SOP1Op_S_QUADMASK_B64, + SOP1Op.S_WQM_B32: _SOP1Op_S_WQM_B32, + SOP1Op.S_WQM_B64: _SOP1Op_S_WQM_B64, SOP1Op.S_NOT_B32: _SOP1Op_S_NOT_B32, SOP1Op.S_NOT_B64: _SOP1Op_S_NOT_B64, SOP1Op.S_AND_SAVEEXEC_B32: _SOP1Op_S_AND_SAVEEXEC_B32, @@ -10034,6 +10126,18 @@ def _VOP3Op_V_SAD_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _ result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3Op_V_SAD_HI_U8(s0, s1, s2, d0, scc, vcc, lane, 
exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # D0.u32 = (32'U(v_sad_u8(S0, S1, 0U)) << 16U) + S2.u32 + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + # --- compiled pseudocode --- + D0.u32 = ((v_sad_u8(S0, S1, 0)) << 16) + S2.u32 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3Op_V_SAD_U16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # // UNSIGNED comparison # tmp = S2.u32; @@ -10410,6 +10514,71 @@ def _VOP3Op_V_MSAD_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3Op_V_QSAD_PK_U16_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp[63 : 48] = 16'B(v_sad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32)); + # tmp[47 : 32] = 16'B(v_sad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32)); + # tmp[31 : 16] = 16'B(v_sad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32)); + # tmp[15 : 0] = 16'B(v_sad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32)); + # D0.b64 = tmp.b64 + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp[63 : 48] = (v_sad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32)) + tmp[47 : 32] = (v_sad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32)) + tmp[31 : 16] = (v_sad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32)) + tmp[15 : 0] = (v_sad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32)) + D0.b64 = tmp.b64 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + result['d0_64'] = True + return result + +def _VOP3Op_V_MQSAD_PK_U16_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp[63 : 48] = 16'B(v_msad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32)); + # tmp[47 : 32] = 16'B(v_msad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32)); + # tmp[31 : 16] = 16'B(v_msad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32)); + # tmp[15 : 0] = 16'B(v_msad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32)); + # D0.b64 = tmp.b64 + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp[63 : 48] = (v_msad_u8(S0[55 : 24], S1[31 : 0], S2[63 : 48].u32)) + tmp[47 : 32] = (v_msad_u8(S0[47 : 16], S1[31 : 0], S2[47 : 32].u32)) + tmp[31 : 16] = (v_msad_u8(S0[39 : 8], S1[31 : 0], S2[31 : 16].u32)) + tmp[15 : 0] = (v_msad_u8(S0[31 : 0], S1[31 : 0], S2[15 : 0].u32)) + D0.b64 = tmp.b64 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + result['d0_64'] = True + return result + +def _VOP3Op_V_MQSAD_U32_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp[127 : 96] = 32'B(v_msad_u8(S0[55 : 24], S1[31 : 0], S2[127 : 96].u32)); + # tmp[95 : 64] = 32'B(v_msad_u8(S0[47 : 16], S1[31 : 0], S2[95 : 64].u32)); + # tmp[63 : 32] = 32'B(v_msad_u8(S0[39 : 8], S1[31 : 0], S2[63 : 32].u32)); + # tmp[31 : 0] = 32'B(v_msad_u8(S0[31 : 0], S1[31 : 0], S2[31 : 0].u32)); + # D0.b128 = tmp.b128 + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp[127 : 96] = (v_msad_u8(S0[55 : 24], S1[31 : 0], S2[127 : 96].u32)) + tmp[95 : 64] = (v_msad_u8(S0[47 : 16], S1[31 : 0], S2[95 : 64].u32)) + tmp[63 : 32] = (v_msad_u8(S0[39 : 8], S1[31 : 0], S2[63 : 32].u32)) + tmp[31 : 0] = (v_msad_u8(S0[31 : 0], S1[31 : 0], S2[31 : 0].u32)) + D0.b128 = tmp.b128 + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def 
_VOP3Op_V_XOR3_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # D0.u32 = (S0.u32 ^ S1.u32 ^ S2.u32) S0 = Reg(s0) @@ -10786,6 +10955,25 @@ def _VOP3Op_V_DOT2_F16_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3Op_V_DOT2_BF16_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = S2.bf16; + # tmp += S0[15 : 0].bf16 * S1[15 : 0].bf16; + # tmp += S0[31 : 16].bf16 * S1[31 : 16].bf16; + # D0.bf16 = tmp + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(S2.bf16) + tmp += S0[15 : 0].bf16 * S1[15 : 0].bf16 + tmp += S0[31 : 16].bf16 * S1[31 : 16].bf16 + D0.bf16 = tmp + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3Op_V_MINMAX_NUM_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # D0.f32 = v_max_num_f32(v_min_num_f32(S0.f32, S1.f32), S2.f32) S0 = Reg(s0) @@ -11993,6 +12181,7 @@ VOP3Op_FUNCTIONS = { VOP3Op.V_MED3_I32: _VOP3Op_V_MED3_I32, VOP3Op.V_MED3_U32: _VOP3Op_V_MED3_U32, VOP3Op.V_SAD_U8: _VOP3Op_V_SAD_U8, + VOP3Op.V_SAD_HI_U8: _VOP3Op_V_SAD_HI_U8, VOP3Op.V_SAD_U16: _VOP3Op_V_SAD_U16, VOP3Op.V_SAD_U32: _VOP3Op_V_SAD_U32, VOP3Op.V_CVT_PK_U8_F32: _VOP3Op_V_CVT_PK_U8_F32, @@ -12011,6 +12200,9 @@ VOP3Op_FUNCTIONS = { VOP3Op.V_DIV_FMAS_F32: _VOP3Op_V_DIV_FMAS_F32, VOP3Op.V_DIV_FMAS_F64: _VOP3Op_V_DIV_FMAS_F64, VOP3Op.V_MSAD_U8: _VOP3Op_V_MSAD_U8, + VOP3Op.V_QSAD_PK_U16_U8: _VOP3Op_V_QSAD_PK_U16_U8, + VOP3Op.V_MQSAD_PK_U16_U8: _VOP3Op_V_MQSAD_PK_U16_U8, + VOP3Op.V_MQSAD_U32_U8: _VOP3Op_V_MQSAD_U32_U8, VOP3Op.V_XOR3_B32: _VOP3Op_V_XOR3_B32, VOP3Op.V_MAD_U16: _VOP3Op_V_MAD_U16, VOP3Op.V_XAD_U32: _VOP3Op_V_XAD_U32, @@ -12037,6 +12229,7 @@ VOP3Op_FUNCTIONS = { VOP3Op.V_MAXMIN_I32: _VOP3Op_V_MAXMIN_I32, VOP3Op.V_MINMAX_I32: _VOP3Op_V_MINMAX_I32, VOP3Op.V_DOT2_F16_F16: _VOP3Op_V_DOT2_F16_F16, + VOP3Op.V_DOT2_BF16_BF16: _VOP3Op_V_DOT2_BF16_BF16, VOP3Op.V_MINMAX_NUM_F32: _VOP3Op_V_MINMAX_NUM_F32, VOP3Op.V_MAXMIN_NUM_F32: _VOP3Op_V_MAXMIN_NUM_F32, VOP3Op.V_MINMAX_NUM_F16: _VOP3Op_V_MINMAX_NUM_F16, @@ -12754,6 +12947,25 @@ def _VOP3POp_V_DOT8_U32_U4(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3POp_V_DOT2_F32_BF16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = S2.f32; + # tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16); + # tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16); + # D0.f32 = tmp + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(S2.f32) + tmp += bf16_to_f32(S0[15 : 0].bf16) * bf16_to_f32(S1[15 : 0].bf16) + tmp += bf16_to_f32(S0[31 : 16].bf16) * bf16_to_f32(S1[31 : 16].bf16) + D0.f32 = tmp + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + def _VOP3POp_V_PK_MIN_NUM_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): # declare tmp : 32'B; # tmp[15 : 0].f16 = v_min_num_f16(S0[15 : 0].f16, S1[15 : 0].f16); @@ -12935,6 +13147,7 @@ VOP3POp_FUNCTIONS = { VOP3POp.V_DOT2_F32_F16: _VOP3POp_V_DOT2_F32_F16, VOP3POp.V_DOT4_U32_U8: _VOP3POp_V_DOT4_U32_U8, VOP3POp.V_DOT8_U32_U4: _VOP3POp_V_DOT8_U32_U4, + VOP3POp.V_DOT2_F32_BF16: _VOP3POp_V_DOT2_F32_BF16, VOP3POp.V_PK_MIN_NUM_F16: 
_VOP3POp_V_PK_MIN_NUM_F16, VOP3POp.V_PK_MAX_NUM_F16: _VOP3POp_V_PK_MAX_NUM_F16, VOP3POp.V_PK_MINIMUM_F16: _VOP3POp_V_PK_MINIMUM_F16, diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index 2a7c908ca5..cce704c55f 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -20,7 +20,8 @@ _VOP3_64BIT_OPS = {op.value for op in VOP3Op if op.name.endswith(('_F64', '_B64' # Ops where src1 is 32-bit (exponent/shift amount) even though the op name suggests 64-bit _VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value} # Ops with 16-bit types in name (for source/dest handling) -_VOP3_16BIT_OPS = {op for op in VOP3Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))} +# Exception: SAD/MSAD ops take 32-bit packed sources and extract 16-bit/8-bit chunks internally +_VOP3_16BIT_OPS = {op for op in VOP3Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16')) and 'SAD' not in op.name} _VOP1_16BIT_OPS = {op for op in VOP1Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))} _VOP2_16BIT_OPS = {op for op in VOP2Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))} # CVT ops with 32/64-bit source (despite 16-bit in name) @@ -382,11 +383,11 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No else: op_cls, op, src0, src1, src2, vdst = VOP3Op, VOP3Op(inst.op), inst.src0, inst.src1, inst.src2, inst.vdst # V_PERM_B32: byte permutation - not in pseudocode PDF, implement directly - # D0[byte_i] = selector[byte_i] < 8 ? {src1, src0}[selector[byte_i]] : (selector[byte_i] >= 0xD ? 0xFF : 0x00) + # D0[byte_i] = selector[byte_i] < 8 ? {src0, src1}[selector[byte_i]] : (selector[byte_i] >= 0xD ? 0xFF : 0x00) if op == VOP3Op.V_PERM_B32: s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane) - # Combine src0 and src1 into 8-byte value: src0 is bytes 0-3, src1 is bytes 4-7 - combined = (s0 & 0xffffffff) | ((s1 & 0xffffffff) << 32) + # Combine src1 and src0 into 8-byte value: src1 is bytes 0-3, src0 is bytes 4-7 + combined = (s1 & 0xffffffff) | ((s0 & 0xffffffff) << 32) result = 0 for i in range(4): # 4 result bytes sel = (s2 >> (i * 8)) & 0xff # byte selector for this position diff --git a/extra/assembly/amd/pcode.py b/extra/assembly/amd/pcode.py index 5ef1b2b0e0..dcf96a5a99 100644 --- a/extra/assembly/amd/pcode.py +++ b/extra/assembly/amd/pcode.py @@ -195,7 +195,59 @@ v_min3_i16 = v_min3_i32 v_max3_i16 = v_max3_i32 def v_min3_u16(a, b, c): return min(a & 0xffff, b & 0xffff, c & 0xffff) def v_max3_u16(a, b, c): return max(a & 0xffff, b & 0xffff, c & 0xffff) -def ABSDIFF(a, b): return abs(a - b) +def ABSDIFF(a, b): return abs(int(a) - int(b)) + +# BF16 (bfloat16) conversion functions +def _bf16(i): + """Convert bf16 bits to float. BF16 is just the top 16 bits of f32.""" + return struct.unpack(" 0 else 0xff80 # bf16 ±infinity + try: return (struct.unpack("> 16) & 0xffff + except (OverflowError, struct.error): return 0x7f80 if f > 0 else 0xff80 +def bf16_to_f32(v): return _bf16(v) if isinstance(v, int) else float(v) +def f32_to_bf16(f): return _ibf16(f) + +# BYTE_PERMUTE for V_PERM_B32 - select bytes from 64-bit data based on selector +def BYTE_PERMUTE(data, sel): + """Select a byte from 64-bit data based on selector value. 
+ sel 0-7: select byte from data (S1 is bytes 0-3, S0 is bytes 4-7 in {S0,S1}) + sel 8-11: sign-extend from specific bytes (8->byte1, 9->byte3, 10->byte5, 11->byte7) + sel 12: constant 0x00 + sel >= 13: constant 0xFF""" + sel = int(sel) & 0xff + if sel <= 7: return (int(data) >> (sel * 8)) & 0xff + if sel == 8: return 0xff if ((int(data) >> 15) & 1) else 0x00 # sign of byte 1 + if sel == 9: return 0xff if ((int(data) >> 31) & 1) else 0x00 # sign of byte 3 + if sel == 10: return 0xff if ((int(data) >> 47) & 1) else 0x00 # sign of byte 5 + if sel == 11: return 0xff if ((int(data) >> 63) & 1) else 0x00 # sign of byte 7 + if sel == 12: return 0x00 + return 0xff # sel >= 13 + +# v_sad_u8 helper for V_SAD instructions (sum of absolute differences of 4 bytes) +def v_sad_u8(s0, s1, s2): + """V_SAD_U8: Sum of absolute differences of 4 byte pairs plus accumulator.""" + s0, s1, s2 = int(s0), int(s1), int(s2) + result = s2 + for i in range(4): + a = (s0 >> (i * 8)) & 0xff + b = (s1 >> (i * 8)) & 0xff + result += abs(a - b) + return result & 0xffffffff + +# v_msad_u8 helper (masked SAD - skip when reference byte is 0) +def v_msad_u8(s0, s1, s2): + """V_MSAD_U8: Masked sum of absolute differences (skip if reference byte is 0).""" + s0, s1, s2 = int(s0), int(s1), int(s2) + result = s2 + for i in range(4): + a = (s0 >> (i * 8)) & 0xff + b = (s1 >> (i * 8)) & 0xff + if b != 0: # Only add diff if reference (s1) byte is non-zero + result += abs(a - b) + return result & 0xffffffff def f16_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767)))) def f16_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535)))) def f32_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767)))) @@ -240,6 +292,8 @@ __all__ = [ 'i16_to_f16', 'u16_to_f16', 'f16_to_i16', 'f16_to_u16', 'u32_to_u16', 'i32_to_i16', 'f16_to_snorm', 'f16_to_unorm', 'f32_to_snorm', 'f32_to_unorm', 'v_cvt_i16_f32', 'v_cvt_u16_f32', 'SAT8', 'f32_to_u8', 'u8_to_u32', 'u4_to_u32', + # BF16 conversion functions + '_bf16', '_ibf16', 'bf16_to_f32', 'f32_to_bf16', # Math functions 'trunc', 'floor', 'ceil', 'sqrt', 'log2', 'sin', 'cos', 'pow', 'fract', 'isEven', 'mantissa', # Min/max functions @@ -248,6 +302,8 @@ __all__ = [ 'v_min3_f32', 'v_max3_f32', 'v_min3_i32', 'v_max3_i32', 'v_min3_u32', 'v_max3_u32', 'v_min3_f16', 'v_max3_f16', 'v_min3_i16', 'v_max3_i16', 'v_min3_u16', 'v_max3_u16', 'ABSDIFF', + # Byte/SAD helper functions + 'BYTE_PERMUTE', 'v_sad_u8', 'v_msad_u8', # Bit manipulation '_brev32', '_brev64', '_ctz32', '_ctz64', '_exponent', '_is_denorm_f32', '_is_denorm_f64', '_sign', '_mantissa_f32', '_div', '_isnan', '_isquietnan', '_issignalnan', '_gt_neg_zero', '_lt_neg_zero', '_fma', '_ldexp', '_signext', @@ -354,16 +410,25 @@ class SliceProxy: i32 = property(lambda s: _sext(s._get() & MASK32, 32), lambda s, v: s._set(v)) f16 = property(lambda s: _f16(s._get()), lambda s, v: s._set(v if isinstance(v, int) else _i16(float(v)))) f32 = property(lambda s: _f32(s._get()), lambda s, v: s._set(_i32(float(v)))) + bf16 = property(lambda s: _bf16(s._get()), lambda s, v: s._set(v if isinstance(v, int) else _ibf16(float(v)))) b16, b32 = u16, u32 def __int__(self): return self._get() def __index__(self): return self._get() + # Comparison operators (compare as integers) + def __eq__(s, o): return s._get() == int(o) + def __ne__(s, o): return s._get() != int(o) + def __lt__(s, o): return s._get() < int(o) + def __le__(s, o): return s._get() <= int(o) + def __gt__(s, o): return s._get() > 
int(o) + def __ge__(s, o): return s._get() >= int(o) + class TypedView: """View for S0.u32 that supports [4:0] slicing and [bit] access.""" - __slots__ = ('_reg', '_bits', '_signed', '_float') - def __init__(self, reg, bits, signed=False, is_float=False): - self._reg, self._bits, self._signed, self._float = reg, bits, signed, is_float + __slots__ = ('_reg', '_bits', '_signed', '_float', '_bf16') + def __init__(self, reg, bits, signed=False, is_float=False, is_bf16=False): + self._reg, self._bits, self._signed, self._float, self._bf16 = reg, bits, signed, is_float, is_bf16 @property def _val(self): @@ -390,6 +455,7 @@ class TypedView: def __trunc__(self): return int(float(self)) if self._float else int(self) def __float__(self): if self._float: + if self._bf16: return _bf16(self._val) # bf16 uses different conversion return _f16(self._val) if self._bits == 16 else _f32(self._val) if self._bits == 32 else _f64(self._val) return float(int(self)) @@ -454,6 +520,7 @@ class Reg: i16 = property(lambda s: TypedView(s, 16, signed=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff))) b16 = property(lambda s: TypedView(s, 16), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff))) f16 = property(lambda s: TypedView(s, 16, is_float=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _i16(float(v))) & 0xffff))) + bf16 = property(lambda s: TypedView(s, 16, is_float=True, is_bf16=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _ibf16(float(v))) & 0xffff))) u8 = property(lambda s: TypedView(s, 8)) i8 = property(lambda s: TypedView(s, 8, signed=True)) @@ -610,6 +677,14 @@ def _expr(e: str) -> str: e = e.replace('+INF', 'INF').replace('-INF', '(-INF)') e = re.sub(r'NAN\.f\d+', 'float("nan")', e) + # Verilog bit slice syntax: [start +: width] -> extract width bits starting at start + # Convert to Python slice: [start + width - 1 : start] + def convert_verilog_slice(m): + start, width = m.group(1).strip(), m.group(2).strip() + # Convert to high:low slice format + return f'[({start}) + ({width}) - 1 : ({start})]' + e = re.sub(r'\[([^:\[\]]+)\s*\+:\s*([^:\[\]]+)\]', convert_verilog_slice, e) + # Recursively process bracket contents to handle nested ternaries like S1.u32[x ? 
a : b] def process_brackets(s): result, i = [], 0 @@ -706,10 +781,10 @@ from extra.assembly.amd.dsl import PDF_URLS INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M) # Patterns that can't be handled by the DSL (require special handling in emu.py) -UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'BYTE_PERMUTE', 'FATAL_HALT', 'HW_REGISTERS', - 'PC =', 'PC=', 'PC+', '= PC', 'v_sad', '+:', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt', - 'CVT_OFF_TABLE', '.bf16', 'ThreadMask', - 'S1[i', 'C.i32', 'v_msad_u8', 'S[i]', 'in[', '2.0 / PI', +UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS', + 'PC =', 'PC=', 'PC+', '= PC', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt', + 'CVT_OFF_TABLE', 'ThreadMask', + 'S1[i', 'C.i32', 'S[i]', 'in[', '2.0 / PI', 'if n.', 'DST.u32', 'addrd = DST', 'addr = DST'] # Malformed pseudocode from PDF def extract_pseudocode(text: str) -> str | None: diff --git a/extra/assembly/amd/test/test_emu.py b/extra/assembly/amd/test/test_emu.py index f1edc38ae2..9f78e854dd 100644 --- a/extra/assembly/amd/test/test_emu.py +++ b/extra/assembly/amd/test/test_emu.py @@ -2357,5 +2357,246 @@ class TestF64Conversions(unittest.TestCase): self.assertEqual(result, -8, f"Expected -8, got {result} (lo=0x{lo:08x}, hi=0x{hi:08x})") +class TestNewPcodeHelpers(unittest.TestCase): + """Tests for newly added pcode helper functions (SAD, BYTE_PERMUTE, BF16).""" + + def test_v_sad_u8_basic(self): + """V_SAD_U8: Sum of absolute differences of 4 bytes.""" + # s0 = 0x05040302, s1 = 0x04030201, s2 = 10 -> diff = 1+1+1+1 = 4, result = 14 + instructions = [ + s_mov_b32(s[0], 0x05040302), + s_mov_b32(s[1], 0x04030201), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 10), + v_sad_u8(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][3] + self.assertEqual(result, 14, f"Expected 14, got {result}") + + def test_v_sad_u8_identical_bytes(self): + """V_SAD_U8: When both operands are identical, SAD = 0 + accumulator.""" + instructions = [ + s_mov_b32(s[0], 0xDEADBEEF), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[0]), # Same as v0 + v_mov_b32_e32(v[2], 42), # Accumulator + v_sad_u8(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][3] + self.assertEqual(result, 42, f"Expected 42, got {result}") + + def test_v_sad_u16_basic(self): + """V_SAD_U16: Sum of absolute differences of 2 half-words.""" + # s0 = 0x00020003, s1 = 0x00010001 -> diff = |2-1| + |3-1| = 1 + 2 = 3 + instructions = [ + s_mov_b32(s[0], 0x00020003), + s_mov_b32(s[1], 0x00010001), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 0), + v_sad_u16(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][3] + self.assertEqual(result, 3, f"Expected 3, got {result}") + + def test_v_sad_u32_basic(self): + """V_SAD_U32: Absolute difference of 32-bit values.""" + # s0 = 100, s1 = 30 -> diff = 70, s2 = 5 -> result = 75 + instructions = [ + v_mov_b32_e32(v[0], 100), + v_mov_b32_e32(v[1], 30), + v_mov_b32_e32(v[2], 5), + v_sad_u32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][3] + self.assertEqual(result, 75, f"Expected 75, got {result}") + + def test_v_msad_u8_masked(self): + """V_MSAD_U8: Skip bytes where reference (s1) is 0.""" + # s0 = 0x10101010, s1 = 0x00010001, s2 = 0 + # Only bytes 0 and 2 of s1 are non-zero, so only those contribute + # diff = |0x10-0x01| + |0x10-0x01| = 15 + 15 = 30 + instructions = 
[ + s_mov_b32(s[0], 0x10101010), + s_mov_b32(s[1], 0x00010001), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 0), + v_msad_u8(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][3] + self.assertEqual(result, 30, f"Expected 30, got {result}") + + def test_v_perm_b32_select_bytes(self): + """V_PERM_B32: Select bytes from combined {s0, s1}.""" + # Combined = {S0, S1} where S1 is bytes 0-3, S0 is bytes 4-7 + # s0 = 0x03020100 -> bytes 4-7 of combined + # s1 = 0x07060504 -> bytes 0-3 of combined + # Combined = 0x03020100_07060504 + # selector = 0x00010203 -> select bytes 3,2,1,0 from combined = 0x04,0x05,0x06,0x07 + instructions = [ + s_mov_b32(s[0], 0x03020100), + s_mov_b32(s[1], 0x07060504), + s_mov_b32(s[2], 0x00010203), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], s[2]), + v_perm_b32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][3] + self.assertEqual(result, 0x04050607, f"Expected 0x04050607, got 0x{result:08x}") + + def test_v_perm_b32_select_high_bytes(self): + """V_PERM_B32: Select bytes from high word (s0).""" + # Combined = {S0, S1} where S1 is bytes 0-3, S0 is bytes 4-7 + # s0 = 0x03020100 -> bytes 4-7 of combined + # s1 = 0x07060504 -> bytes 0-3 of combined + # selector = 0x04050607 -> select bytes 7,6,5,4 from combined = 0x00,0x01,0x02,0x03 + instructions = [ + s_mov_b32(s[0], 0x03020100), + s_mov_b32(s[1], 0x07060504), + s_mov_b32(s[2], 0x04050607), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], s[2]), + v_perm_b32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][3] + self.assertEqual(result, 0x00010203, f"Expected 0x00010203, got 0x{result:08x}") + + def test_v_perm_b32_constant_values(self): + """V_PERM_B32: Test constant 0x00 (sel=12) and 0xFF (sel>=13).""" + # selector = 0x0C0D0E0F -> bytes: 12=0x00, 13=0xFF, 14=0xFF, 15=0xFF + instructions = [ + s_mov_b32(s[0], 0x12345678), + s_mov_b32(s[1], 0xABCDEF01), + s_mov_b32(s[2], 0x0C0D0E0F), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], s[2]), + v_perm_b32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][3] + # byte 0: sel=0x0F >= 13 -> 0xFF + # byte 1: sel=0x0E >= 13 -> 0xFF + # byte 2: sel=0x0D >= 13 -> 0xFF + # byte 3: sel=0x0C = 12 -> 0x00 + self.assertEqual(result, 0x00FFFFFF, f"Expected 0x00FFFFFF, got 0x{result:08x}") + + def test_v_dot2_f32_bf16_basic(self): + """V_DOT2_F32_BF16: Dot product of two bf16 pairs accumulated into f32.""" + from extra.assembly.amd.pcode import _ibf16 + # A = packed (2.0, 3.0) as bf16, B = packed (4.0, 5.0) as bf16 + # Result = 2*4 + 3*5 + acc = 8 + 15 + 0 = 23.0 + a_lo, a_hi = _ibf16(2.0), _ibf16(3.0) + b_lo, b_hi = _ibf16(4.0), _ibf16(5.0) + a_packed = (a_hi << 16) | a_lo + b_packed = (b_hi << 16) | b_lo + instructions = [ + s_mov_b32(s[0], a_packed), + s_mov_b32(s[1], b_packed), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 0), # accumulator = 0 + v_dot2_f32_bf16(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = i2f(st.vgpr[0][3]) + self.assertAlmostEqual(result, 23.0, places=1, msg=f"Expected 23.0, got {result}") + + +class TestQuadmaskWqm(unittest.TestCase): + """Tests for S_QUADMASK and S_WQM instructions.""" + + def test_s_quadmask_b32_all_quads_active(self): + """S_QUADMASK_B32: All quads have at least one 
active lane.""" + # Input: 0xFFFFFFFF (all bits set) -> all 8 quads active -> result = 0xFF + instructions = [ + s_mov_b32(s[0], 0xFFFFFFFF), + s_quadmask_b32(s[1], s[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.sgpr[1] + self.assertEqual(result, 0xFF, f"Expected 0xFF, got 0x{result:x}") + self.assertEqual(st.scc, 1, "SCC should be 1 (result != 0)") + + def test_s_quadmask_b32_alternating_quads(self): + """S_QUADMASK_B32: Every other quad has lanes active.""" + # Input: 0x0F0F0F0F -> quads 0,2,4,6 active (bits 0-3, 8-11, 16-19, 24-27) + # Result: bits 0,2,4,6 set = 0x55 + instructions = [ + s_mov_b32(s[0], 0x0F0F0F0F), + s_quadmask_b32(s[1], s[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.sgpr[1] + self.assertEqual(result, 0x55, f"Expected 0x55, got 0x{result:x}") + + def test_s_quadmask_b32_no_quads_active(self): + """S_QUADMASK_B32: No quads have active lanes.""" + instructions = [ + s_mov_b32(s[0], 0), + s_quadmask_b32(s[1], s[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.sgpr[1] + self.assertEqual(result, 0, f"Expected 0, got 0x{result:x}") + self.assertEqual(st.scc, 0, "SCC should be 0 (result == 0)") + + def test_s_quadmask_b32_single_lane_per_quad(self): + """S_QUADMASK_B32: Single lane active in each quad.""" + # Input: 0x11111111 -> bit 0 of each nibble set -> all 8 quads active + instructions = [ + s_mov_b32(s[0], 0x11111111), + s_quadmask_b32(s[1], s[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.sgpr[1] + self.assertEqual(result, 0xFF, f"Expected 0xFF, got 0x{result:x}") + + def test_s_wqm_b32_all_active(self): + """S_WQM_B32: Whole quad mode - if any lane in quad is active, activate all.""" + # Input: 0x11111111 -> one lane per quad -> output all quads fully active = 0xFFFFFFFF + instructions = [ + s_mov_b32(s[0], 0x11111111), + s_wqm_b32(s[1], s[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.sgpr[1] + self.assertEqual(result, 0xFFFFFFFF, f"Expected 0xFFFFFFFF, got 0x{result:x}") + self.assertEqual(st.scc, 1, "SCC should be 1 (result != 0)") + + def test_s_wqm_b32_alternating_quads(self): + """S_WQM_B32: Only some quads have active lanes.""" + # Input: 0x0000000F -> only quad 0 has lanes -> output = 0x0000000F (quad 0 all active) + instructions = [ + s_mov_b32(s[0], 0x00000001), # single lane in quad 0 + s_wqm_b32(s[1], s[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.sgpr[1] + self.assertEqual(result, 0x0000000F, f"Expected 0x0000000F, got 0x{result:x}") + + def test_s_wqm_b32_zero(self): + """S_WQM_B32: No lanes active.""" + instructions = [ + s_mov_b32(s[0], 0), + s_wqm_b32(s[1], s[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.sgpr[1] + self.assertEqual(result, 0, f"Expected 0, got 0x{result:x}") + self.assertEqual(st.scc, 0, "SCC should be 0 (result == 0)") + + if __name__ == '__main__': unittest.main() diff --git a/extra/assembly/amd/test/test_pcode.py b/extra/assembly/amd/test/test_pcode.py index b9b9c0395c..141b938baa 100644 --- a/extra/assembly/amd/test/test_pcode.py +++ b/extra/assembly/amd/test/test_pcode.py @@ -1,7 +1,9 @@ #!/usr/bin/env python3 """Tests for the RDNA3 pseudocode DSL.""" import unittest -from extra.assembly.amd.pcode import Reg, TypedView, SliceProxy, ExecContext, compile_pseudocode, _expr, MASK32, MASK64, _f32, _i32, _f16, _i16, f32_to_f16, _isnan +from extra.assembly.amd.pcode import (Reg, TypedView, SliceProxy, ExecContext, compile_pseudocode, _expr, MASK32, MASK64, + _f32, _i32, _f16, _i16, f32_to_f16, 
_isnan, _bf16, _ibf16, bf16_to_f32, f32_to_bf16, + BYTE_PERMUTE, v_sad_u8, v_msad_u8) from extra.assembly.amd.autogen.rdna3.gen_pcode import _VOP3SDOp_V_DIV_SCALE_F32, _VOPCOp_V_CMP_CLASS_F32 class TestReg(unittest.TestCase): @@ -265,5 +267,129 @@ class TestPseudocodeRegressions(unittest.TestCase): self.assertFalse(_isnan(normal_reg.f32), "_isnan should return False for normal TypedView") self.assertFalse(_isnan(inf_reg.f32), "_isnan should return False for inf TypedView") +class TestBF16(unittest.TestCase): + """Tests for BF16 (bfloat16) support.""" + + def test_bf16_conversion(self): + """Test bf16 <-> f32 conversion.""" + # bf16 is just the top 16 bits of f32 + # 1.0f = 0x3f800000, bf16 = 0x3f80 + self.assertAlmostEqual(_bf16(0x3f80), 1.0, places=2) + self.assertEqual(_ibf16(1.0), 0x3f80) + # 2.0f = 0x40000000, bf16 = 0x4000 + self.assertAlmostEqual(_bf16(0x4000), 2.0, places=2) + self.assertEqual(_ibf16(2.0), 0x4000) + # -1.0f = 0xbf800000, bf16 = 0xbf80 + self.assertAlmostEqual(_bf16(0xbf80), -1.0, places=2) + self.assertEqual(_ibf16(-1.0), 0xbf80) + + def test_bf16_special_values(self): + """Test bf16 special values (inf, nan).""" + import math + # +inf: f32 = 0x7f800000, bf16 = 0x7f80 + self.assertTrue(math.isinf(_bf16(0x7f80))) + self.assertEqual(_ibf16(float('inf')), 0x7f80) + # -inf: f32 = 0xff800000, bf16 = 0xff80 + self.assertTrue(math.isinf(_bf16(0xff80))) + self.assertEqual(_ibf16(float('-inf')), 0xff80) + # NaN: quiet NaN bf16 = 0x7fc0 + self.assertTrue(math.isnan(_bf16(0x7fc0))) + self.assertEqual(_ibf16(float('nan')), 0x7fc0) + + def test_bf16_register_property(self): + """Test Reg.bf16 property.""" + r = Reg(0) + r.bf16 = 3.0 # 3.0f = 0x40400000, bf16 = 0x4040 + self.assertEqual(r._val & 0xffff, 0x4040) + self.assertAlmostEqual(float(r.bf16), 3.0, places=1) + + def test_bf16_slice_property(self): + """Test SliceProxy.bf16 property.""" + r = Reg(0x40404040) # Two bf16 3.0 values + self.assertAlmostEqual(r[15:0].bf16, 3.0, places=1) + self.assertAlmostEqual(r[31:16].bf16, 3.0, places=1) + +class TestBytePermute(unittest.TestCase): + """Tests for BYTE_PERMUTE helper function (V_PERM_B32).""" + + def test_byte_select_0_to_7(self): + """Test selecting bytes 0-7 from 64-bit data.""" + # data = {s0, s1} where s0 is bytes 0-3, s1 is bytes 4-7 + # Combined: 0x0706050403020100 (byte 0 = 0x00, byte 7 = 0x07) + data = 0x0706050403020100 + for i in range(8): + self.assertEqual(BYTE_PERMUTE(data, i), i, f"byte {i} should be {i}") + + def test_sign_extend_bytes(self): + """Test sign extension selectors 8-11.""" + # sel 8: sign of byte 1 (bits 15:8) + # sel 9: sign of byte 3 (bits 31:24) + # sel 10: sign of byte 5 (bits 47:40) + # sel 11: sign of byte 7 (bits 63:56) + data = 0x8000800080008000 # All relevant bytes have sign bit set + self.assertEqual(BYTE_PERMUTE(data, 8), 0xff) + self.assertEqual(BYTE_PERMUTE(data, 9), 0xff) + self.assertEqual(BYTE_PERMUTE(data, 10), 0xff) + self.assertEqual(BYTE_PERMUTE(data, 11), 0xff) + data = 0x7f007f007f007f00 # No sign bits set + self.assertEqual(BYTE_PERMUTE(data, 8), 0x00) + self.assertEqual(BYTE_PERMUTE(data, 9), 0x00) + self.assertEqual(BYTE_PERMUTE(data, 10), 0x00) + self.assertEqual(BYTE_PERMUTE(data, 11), 0x00) + + def test_constant_zero(self): + """Test selector 12 returns 0x00.""" + self.assertEqual(BYTE_PERMUTE(0xffffffffffffffff, 12), 0x00) + + def test_constant_ff(self): + """Test selectors >= 13 return 0xFF.""" + for sel in [13, 14, 15, 255]: + self.assertEqual(BYTE_PERMUTE(0, sel), 0xff, f"sel {sel} should be 0xff") + +class 
TestSADHelpers(unittest.TestCase): + """Tests for V_SAD_U8 and V_MSAD_U8 helper functions.""" + + def test_v_sad_u8_basic(self): + """Test v_sad_u8 with simple values.""" + # s0 = 0x04030201, s1 = 0x04030201 -> diff = 0 for all bytes + result = v_sad_u8(0x04030201, 0x04030201, 0) + self.assertEqual(result, 0) + # s0 = 0x05040302, s1 = 0x04030201 -> diff = 1+1+1+1 = 4 + result = v_sad_u8(0x05040302, 0x04030201, 0) + self.assertEqual(result, 4) + + def test_v_sad_u8_with_accumulator(self): + """Test v_sad_u8 with non-zero accumulator.""" + # s0 = 0x05040302, s1 = 0x04030201, s2 = 100 -> 4 + 100 = 104 + result = v_sad_u8(0x05040302, 0x04030201, 100) + self.assertEqual(result, 104) + + def test_v_sad_u8_large_diff(self): + """Test v_sad_u8 with maximum byte differences.""" + # s0 = 0xffffffff, s1 = 0x00000000 -> diff = 255*4 = 1020 + result = v_sad_u8(0xffffffff, 0x00000000, 0) + self.assertEqual(result, 1020) + + def test_v_msad_u8_basic(self): + """Test v_msad_u8 masks when reference byte is 0.""" + # s0 = 0x10101010, s1 = 0x00000000 -> all masked, result = 0 + result = v_msad_u8(0x10101010, 0x00000000, 0) + self.assertEqual(result, 0) + # s0 = 0x10101010, s1 = 0x01010101 -> diff = |0x10-0x01|*4 = 15*4 = 60 + result = v_msad_u8(0x10101010, 0x01010101, 0) + self.assertEqual(result, 60) + + def test_v_msad_u8_partial_mask(self): + """Test v_msad_u8 with partial masking.""" + # s0 = 0x10101010, s1 = 0x00010001 -> bytes 1 and 3 masked + # diff = |0x10-0x01| + |0x10-0x01| = 15 + 15 = 30 + result = v_msad_u8(0x10101010, 0x00010001, 0) + self.assertEqual(result, 30) + + def test_v_msad_u8_with_accumulator(self): + """Test v_msad_u8 with non-zero accumulator.""" + result = v_msad_u8(0x10101010, 0x01010101, 50) + self.assertEqual(result, 110) # 60 + 50 + if __name__ == '__main__': unittest.main()
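
For reference, below is a minimal standalone sketch of the byte-select, SAD/MSAD, and bf16 semantics the new helpers implement. It assumes only the {S0,S1} byte layout and masking rules described in the docstrings above; the names byte_permute, sad_u8, and f32_to_bf16 are illustrative and are not the patch's identifiers.

# Illustrative sketch only; mirrors the behavior exercised by test_pcode.py above.
import struct

def byte_permute(data, sel):
    # sel 0-7: pick that byte of the 64-bit {S0,S1} value
    # sel 8-11: sign-extend byte 1/3/5/7 (sign bits 15/31/47/63) to 0x00 or 0xFF
    # sel 12: constant 0x00; sel >= 13: constant 0xFF
    sel &= 0xff
    if sel <= 7: return (data >> (sel * 8)) & 0xff
    if sel <= 11: return 0xff if (data >> ((sel - 8) * 16 + 15)) & 1 else 0x00
    return 0x00 if sel == 12 else 0xff

def sad_u8(s0, s1, s2, masked=False):
    # Sum of absolute byte differences plus accumulator; with masked=True,
    # bytes where the reference (s1) byte is zero are skipped (MSAD behavior).
    acc = s2
    for i in range(4):
        a, b = (s0 >> (i * 8)) & 0xff, (s1 >> (i * 8)) & 0xff
        if not (masked and b == 0): acc += abs(a - b)
    return acc & 0xffffffff

def f32_to_bf16(f):
    # Truncating f32 -> bf16 (keep the top 16 bits of the f32 pattern);
    # the in-tree _ibf16 may round differently, so this is a simplification.
    return struct.unpack('<I', struct.pack('<f', f))[0] >> 16

if __name__ == "__main__":
    combined = 0x0302010007060504        # {S0=0x03020100, S1=0x07060504}
    assert byte_permute(combined, 3) == 0x07
    assert byte_permute(combined, 12) == 0x00
    assert sad_u8(0x05040302, 0x04030201, 10) == 14
    assert sad_u8(0x10101010, 0x00010001, 0, masked=True) == 30
    assert f32_to_bf16(2.0) == 0x4000

The checks above correspond to the unit tests in test_pcode.py; the emulator-level tests in test_emu.py exercise the same behavior through the V_PERM_B32, V_SAD_*, V_MSAD_U8, and V_DOT2_F32_BF16 instructions.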