diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index 844418fd40..64517d64fc 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -381,7 +381,7 @@ def decode_program(data: bytes) -> dict[int, Inst]: # Compile pcode for instructions that use it (not VOPD which has _fnx/_fny, not special dispatches) # Try ucode first (UOp-based), fall back to pcode (Python exec-based) def _compile_op(cls_name, op_name, pcode): - return compile_uop(cls_name, op_name, pcode) or compile_pseudocode(cls_name, op_name, pcode) + return compile_uop(op_name, pcode) or compile_pseudocode(cls_name, op_name, pcode) # VOPD needs separate functions for X and Y ops if isinstance(inst, VOPD): def _compile_vopd_op(op): return _compile_op(type(op).__name__, op.name, PSEUDOCODE_STRINGS[type(op)][op]) diff --git a/extra/assembly/amd/ucode.py b/extra/assembly/amd/ucode.py index ce7f99444b..ef0e77ce21 100644 --- a/extra/assembly/amd/ucode.py +++ b/extra/assembly/amd/ucode.py @@ -54,9 +54,19 @@ def _expr(node, ctx: Ctx, hint: DType = None) -> UOp: case Var(name): if name == 'PI': return UOp.const(hint or dtypes.float64, math.pi) - if name in ('INF', '+INF', '-INF'): return UOp.const(hint or dtypes.float64, float('-inf') if name == '-INF' else float('inf')) + if 'INF' in name and name.replace('+', '').replace('-', '').replace('.f16', '').replace('.f32', '').replace('.f64', '') == 'INF': + dt = dtypes.float16 if '.f16' in name else dtypes.float32 if '.f32' in name else hint or dtypes.float64 + return UOp.const(dt, float('-inf') if name.startswith('-') else float('inf')) if name in ('WAVE_MODE.IEEE', 'WAVE32'): return UOp.const(dtypes.uint32, 1) - if name in ('WAVE64', 'ROUND_MODE'): return UOp.const(dtypes.uint32, 0) + if name in ('WAVE64', 'ROUND_MODE', 'WAVE_STATUS.COND_DBG_SYS', 'WAVE_STATUS.COND_DBG_USER'): return UOp.const(dtypes.uint32, 0) + if name == 'MAX_FLOAT_F32': return UOp.const(dtypes.float32, 3.402823466e+38) + if name == 'OVERFLOW_F32': return UOp.const(dtypes.float32, float('inf')) + if name == 'OVERFLOW_F64': return UOp.const(dtypes.float64, float('inf')) + if name == 'UNDERFLOW_F32': return UOp.const(dtypes.float32, 0.0) + if name == 'UNDERFLOW_F64': return UOp.const(dtypes.float64, 0.0) + if name == 'DENORM.f32': return UOp.const(dtypes.float32, 1.17549435e-38) + if name == 'DENORM.f64': return UOp.const(dtypes.float64, 2.2250738585072014e-308) + if name == 'NAN.f32': return UOp.const(dtypes.float32, float('nan')) if name in ('VCCZ', 'EXECZ'): return _cast(UOp(Ops.CMPEQ, dtypes.bool, (ctx.vars.get('VCC' if name == 'VCCZ' else 'EXEC'), UOp.const(dtypes.uint64, 0))), dtypes.uint32) if name.startswith('eval '): return ctx.vars.get('_eval', UOp.const(dtypes.uint32, 0)) @@ -68,6 +78,7 @@ def _expr(node, ctx: Ctx, hint: DType = None) -> UOp: if isinstance(expr, Var): if expr.name in ('VCCZ', 'EXECZ'): return _cast(UOp(Ops.CMPEQ, dtypes.bool, (ctx.vars.get('VCC' if expr.name == 'VCCZ' else 'EXEC'), UOp.const(dtypes.uint64, 0))), dtypes.uint32) + if expr.name.startswith('WAVE_STATUS.COND_DBG'): return UOp.const(dtypes.uint32, 0) vn = expr.name + '_64' if qdt in (QDType.F64, QDType.U64, QDType.I64, QDType.B64) and expr.name.isupper() else expr.name base = ctx.vars.get(vn) if vn in ctx.vars else ctx.vars.get(expr.name) if base is None: raise ValueError(f"Unknown variable: {expr.name}") @@ -249,18 +260,34 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp: dt, clamp_neg = CVT_MAP[name] v = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, 0.0))), UOp.const(a[0].dtype, 0.0), a[0])) if clamp_neg else a[0] return UOp(Ops.CAST, dt, (v,)) - if name in ('f16_to_snorm', 'f16_to_unorm'): - lo, scale, out_dt = (-1.0, 32767.0, dtypes.int16) if name == 'f16_to_snorm' else (0.0, 65535.0, dtypes.uint16) + if name in ('f16_to_snorm', 'f16_to_unorm', 'f32_to_snorm', 'f32_to_unorm'): + lo, scale, out_dt = (-1.0, 32767.0, dtypes.int16) if 'snorm' in name else (0.0, 65535.0, dtypes.uint16) clamped = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (a[0], UOp.const(a[0].dtype, lo))), UOp.const(a[0].dtype, lo), a[0])) clamped = UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(a[0].dtype, 1.0), clamped)), UOp.const(a[0].dtype, 1.0), clamped)) return UOp(Ops.CAST, out_dt, (UOp(Ops.MUL, a[0].dtype, (clamped, UOp.const(a[0].dtype, scale))),)) + if name == 'u32_to_u16': return UOp(Ops.AND, dtypes.uint32, (a[0], UOp.const(dtypes.uint32, 0xffff))) + if name == 'i32_to_i16': return _cast(UOp(Ops.AND, dtypes.uint32, (_cast(a[0], dtypes.uint32), UOp.const(dtypes.uint32, 0xffff))), dtypes.int16) + if name in ('LT_NEG_ZERO', 'GT_NEG_ZERO'): int_dt = {dtypes.float64: dtypes.int64, dtypes.float16: dtypes.int16}.get(a[0].dtype, dtypes.int32) a_bits, b_bits = UOp(Ops.BITCAST, int_dt, (a[0],)), UOp(Ops.BITCAST, int_dt, (a[1],)) return UOp(Ops.CMPLT, dtypes.bool, ((a_bits, b_bits) if name == 'LT_NEG_ZERO' else (b_bits, a_bits))) if name.startswith('v_min_') or name.startswith('v_max_'): return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, ((a[0], a[1]) if 'min' in name else (a[1], a[0]))), a[0], a[1])) - + if name.startswith('v_max3_') or name.startswith('v_min3_'): + cmp = lambda x, y: UOp(Ops.CMPLT, dtypes.bool, ((x, y) if 'min' in name else (y, x))) + m01 = UOp(Ops.WHERE, a[0].dtype, (cmp(a[0], a[1]), a[0], a[1])) + return UOp(Ops.WHERE, a[0].dtype, (cmp(m01, a[2]), m01, a[2])) + if name in ('v_sad_u8', 'v_msad_u8'): # sum of absolute differences + result = a[2] if len(a) > 2 else UOp.const(dtypes.uint32, 0) + for i in range(4): + byte_a = UOp(Ops.AND, dtypes.uint32, (UOp(Ops.SHR, dtypes.uint32, (a[0], UOp.const(dtypes.uint32, i*8))), UOp.const(dtypes.uint32, 0xff))) + byte_b = UOp(Ops.AND, dtypes.uint32, (UOp(Ops.SHR, dtypes.uint32, (a[1], UOp.const(dtypes.uint32, i*8))), UOp.const(dtypes.uint32, 0xff))) + diff = UOp(Ops.SUB, dtypes.uint32, (byte_a, byte_b)) + abs_diff = UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPLT, dtypes.bool, (diff, UOp.const(dtypes.uint32, 0x80000000))), diff, + UOp(Ops.SUB, dtypes.uint32, (UOp.const(dtypes.uint32, 0), diff)))) + result = UOp(Ops.ADD, dtypes.uint32, (result, abs_diff)) + return result raise ValueError(f"Unknown function: {name}") def _get_lhs_info(lhs, ctx: Ctx) -> tuple[str, DType, int|None, int|None, str|None]: @@ -397,80 +424,54 @@ def _make_uop_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: di return result return fn -SUPPORTED_OPS: set[str] = { - 'V_ADD3_U32', 'V_ADD_CO_CI_U32', 'V_ADD_CO_U32', 'V_ADD_F16', 'V_ADD_F32', 'V_ADD_F64', 'V_ADD_LSHL_U32', 'V_ADD_NC_I16', 'V_ADD_NC_I32', 'V_ADD_NC_U16', 'V_ADD_NC_U32', - 'V_ALIGNBIT_B32', 'V_ALIGNBYTE_B32', 'V_AND_B16', 'V_AND_B32', 'V_AND_OR_B32', 'V_ASHRREV_I16', 'V_ASHRREV_I32', 'V_ASHRREV_I64', - 'V_BFE_I32', 'V_BFE_U32', 'V_BFI_B32', 'V_BFM_B32', 'V_CNDMASK_B16', 'V_CNDMASK_B32', 'V_COS_F16', 'V_COS_F32', 'V_CUBEID_F32', 'V_CUBESC_F32', - 'V_CVT_F16_F32', 'V_CVT_F32_F16', 'V_CVT_F32_I32', 'V_CVT_F32_U32', 'V_CVT_F32_UBYTE0', 'V_CVT_F32_UBYTE1', 'V_CVT_F32_UBYTE2', 'V_CVT_F32_UBYTE3', - 'V_CVT_FLOOR_I32_F32', 'V_CVT_I32_F32', 'V_CVT_I32_I16', 'V_CVT_NEAREST_I32_F32', 'V_CVT_PK_I16_F32', 'V_CVT_PK_U16_F32', 'V_CVT_PK_U8_F32', 'V_CVT_U32_F32', 'V_CVT_U32_U16', - 'V_DOT2_F16_F16', 'V_DOT2_F32_F16', 'V_DOT2ACC_F32_F16', 'V_FMA_DX9_ZERO_F32', 'V_FMA_F16', 'V_FMA_F32', 'V_FMA_F64', 'V_FMAAK_F16', 'V_FMAAK_F32', - 'V_FMAC_DX9_ZERO_F32', 'V_FMAC_F16', 'V_FMAC_F32', 'V_FMAMK_F16', 'V_FMAMK_F32', 'V_FREXP_EXP_I16_F16', 'V_FREXP_EXP_I32_F32', 'V_FREXP_EXP_I32_F64', - 'V_LERP_U8', 'V_LOG_F16', 'V_LOG_F32', 'V_LSHL_ADD_U32', 'V_LSHL_OR_B32', 'V_LSHLREV_B16', 'V_LSHLREV_B32', 'V_LSHLREV_B64', 'V_LSHRREV_B16', 'V_LSHRREV_B32', 'V_LSHRREV_B64', - 'V_MAD_I16', 'V_MAD_I32_I16', 'V_MAD_I32_I24', 'V_MAD_U16', 'V_MAD_U32_U16', 'V_MAD_U32_U24', 'V_MAX_I16', 'V_MAX_I32', 'V_MAX_U16', 'V_MAX_U32', - 'V_MIN_I16', 'V_MIN_I32', 'V_MIN_U16', 'V_MIN_U32', 'V_MOV_B16', 'V_MOV_B32', 'V_MSAD_U8', 'V_MUL_DX9_ZERO_F32', 'V_MUL_F16', 'V_MUL_F32', 'V_MUL_F64', - 'V_MUL_HI_I32', 'V_MUL_HI_I32_I24', 'V_MUL_HI_U32', 'V_MUL_HI_U32_U24', 'V_MUL_I32_I24', 'V_MUL_LO_U16', 'V_MUL_LO_U32', 'V_MUL_U32_U24', - 'V_NOT_B16', 'V_NOT_B32', 'V_OR3_B32', 'V_OR_B16', 'V_OR_B32', 'V_PACK_B32_F16', 'V_PK_FMAC_F16', 'V_RCP_F16', 'V_RCP_F32', 'V_RCP_F64', 'V_RCP_IFLAG_F32', - 'V_RSQ_F16', 'V_RSQ_F32', 'V_RSQ_F64', 'V_PK_ADD_F16', 'V_PK_ADD_I16', 'V_PK_ADD_U16', 'V_PK_ASHRREV_I16', 'V_PK_FMA_F16', 'V_PK_LSHLREV_B16', 'V_PK_LSHRREV_B16', - 'V_PK_MAD_I16', 'V_PK_MAD_U16', 'V_PK_MAX_I16', 'V_PK_MAX_U16', 'V_PK_MIN_I16', 'V_PK_MIN_U16', 'V_PK_MUL_F16', 'V_PK_MUL_LO_U16', 'V_PK_SUB_I16', 'V_PK_SUB_U16', - 'V_RNDNE_F16', 'V_RNDNE_F32', 'V_RNDNE_F64', 'V_SAD_U8', 'V_SAD_U16', 'V_SAD_U32', 'V_SIN_F16', 'V_SIN_F32', 'V_SQRT_F16', 'V_SQRT_F32', 'V_SQRT_F64', - 'V_CVT_F32_F64', 'V_CVT_F64_F32', 'V_CVT_F64_I32', 'V_CVT_F64_U32', 'V_CVT_I32_F64', 'V_CVT_U32_F64', 'V_CVT_NORM_I16_F16', 'V_CVT_NORM_U16_F16', - 'V_CVT_PK_NORM_I16_F16', 'V_CVT_PK_NORM_U16_F16', 'V_CVT_PK_RTZ_F16_F32', 'V_SUB_CO_CI_U32', 'V_SUB_CO_U32', 'V_SUB_F16', 'V_SUB_F32', - 'V_SUB_NC_I16', 'V_SUB_NC_I32', 'V_SUB_NC_U16', 'V_SUB_NC_U32', 'V_SUBREV_CO_CI_U32', 'V_SUBREV_CO_U32', 'V_SUBREV_F16', 'V_SUBREV_F32', 'V_SUBREV_NC_U32', - 'V_SWAP_B16', 'V_SWAP_B32', 'V_TRUNC_F16', 'V_TRUNC_F32', 'V_TRUNC_F64', 'V_WRITELANE_B32', 'V_XAD_U32', 'V_XNOR_B32', 'V_XOR3_B32', 'V_XOR_B16', 'V_XOR_B32', - 'V_CVT_F16_I16', 'V_CVT_F16_U16', 'V_CVT_I16_F16', 'V_CVT_U16_F16', 'V_EXP_F16', 'V_EXP_F32', 'V_LDEXP_F16', 'V_LDEXP_F32', 'V_LDEXP_F64', - 'V_CUBEMA_F32', 'V_CUBETC_F32', 'V_SAT_PK_U8_I16', 'V_MAX3_I16', 'V_MAX3_I32', 'V_MAX3_U16', 'V_MAX3_U32', 'V_MIN3_I16', 'V_MIN3_I32', 'V_MIN3_U16', 'V_MIN3_U32', - 'V_MAXMIN_I32', 'V_MAXMIN_U32', 'V_MINMAX_I32', 'V_MINMAX_U32', - 'V_CMP_EQ_F16', 'V_CMP_EQ_F32', 'V_CMP_EQ_F64', 'V_CMP_EQ_I16', 'V_CMP_EQ_I32', 'V_CMP_EQ_I64', 'V_CMP_EQ_U16', 'V_CMP_EQ_U32', 'V_CMP_EQ_U64', - 'V_CMP_F_F16', 'V_CMP_F_F32', 'V_CMP_F_F64', 'V_CMP_F_I32', 'V_CMP_F_I64', 'V_CMP_F_U32', 'V_CMP_F_U64', - 'V_CMP_GE_F16', 'V_CMP_GE_F32', 'V_CMP_GE_F64', 'V_CMP_GE_I16', 'V_CMP_GE_I32', 'V_CMP_GE_I64', 'V_CMP_GE_U16', 'V_CMP_GE_U32', 'V_CMP_GE_U64', - 'V_CMP_GT_F16', 'V_CMP_GT_F32', 'V_CMP_GT_F64', 'V_CMP_GT_I16', 'V_CMP_GT_I32', 'V_CMP_GT_I64', 'V_CMP_GT_U16', 'V_CMP_GT_U32', 'V_CMP_GT_U64', - 'V_CMP_LE_F16', 'V_CMP_LE_F32', 'V_CMP_LE_F64', 'V_CMP_LE_I16', 'V_CMP_LE_I32', 'V_CMP_LE_I64', 'V_CMP_LE_U16', 'V_CMP_LE_U32', 'V_CMP_LE_U64', - 'V_CMP_LG_F16', 'V_CMP_LG_F32', 'V_CMP_LG_F64', 'V_CMP_LT_F16', 'V_CMP_LT_F32', 'V_CMP_LT_F64', 'V_CMP_LT_I16', 'V_CMP_LT_I32', 'V_CMP_LT_I64', - 'V_CMP_LT_U16', 'V_CMP_LT_U32', 'V_CMP_LT_U64', 'V_CMP_NE_I16', 'V_CMP_NE_I32', 'V_CMP_NE_I64', 'V_CMP_NE_U16', 'V_CMP_NE_U32', 'V_CMP_NE_U64', - 'V_CMP_NEQ_F16', 'V_CMP_NEQ_F32', 'V_CMP_NEQ_F64', 'V_CMP_NGE_F16', 'V_CMP_NGE_F32', 'V_CMP_NGE_F64', 'V_CMP_NGT_F16', 'V_CMP_NGT_F32', 'V_CMP_NGT_F64', - 'V_CMP_NLE_F16', 'V_CMP_NLE_F32', 'V_CMP_NLE_F64', 'V_CMP_NLG_F16', 'V_CMP_NLG_F32', 'V_CMP_NLG_F64', 'V_CMP_NLT_F16', 'V_CMP_NLT_F32', 'V_CMP_NLT_F64', - 'V_CMP_O_F16', 'V_CMP_O_F32', 'V_CMP_O_F64', 'V_CMP_T_F16', 'V_CMP_T_F32', 'V_CMP_T_F64', 'V_CMP_T_I32', 'V_CMP_T_I64', 'V_CMP_T_U32', 'V_CMP_T_U64', - 'V_CMP_U_F16', 'V_CMP_U_F32', 'V_CMP_U_F64', - 'V_CMPX_EQ_F16', 'V_CMPX_EQ_F32', 'V_CMPX_EQ_F64', 'V_CMPX_EQ_I16', 'V_CMPX_EQ_I32', 'V_CMPX_EQ_I64', 'V_CMPX_EQ_U16', 'V_CMPX_EQ_U32', 'V_CMPX_EQ_U64', - 'V_CMPX_F_F16', 'V_CMPX_F_F32', 'V_CMPX_F_F64', 'V_CMPX_F_I32', 'V_CMPX_F_I64', 'V_CMPX_F_U32', 'V_CMPX_F_U64', - 'V_CMPX_GE_F16', 'V_CMPX_GE_F32', 'V_CMPX_GE_F64', 'V_CMPX_GE_I16', 'V_CMPX_GE_I32', 'V_CMPX_GE_I64', 'V_CMPX_GE_U16', 'V_CMPX_GE_U32', 'V_CMPX_GE_U64', - 'V_CMPX_GT_F16', 'V_CMPX_GT_F32', 'V_CMPX_GT_F64', 'V_CMPX_GT_I16', 'V_CMPX_GT_I32', 'V_CMPX_GT_I64', 'V_CMPX_GT_U16', 'V_CMPX_GT_U32', 'V_CMPX_GT_U64', - 'V_CMPX_LE_F16', 'V_CMPX_LE_F32', 'V_CMPX_LE_F64', 'V_CMPX_LE_I16', 'V_CMPX_LE_I32', 'V_CMPX_LE_I64', 'V_CMPX_LE_U16', 'V_CMPX_LE_U32', 'V_CMPX_LE_U64', - 'V_CMPX_LG_F16', 'V_CMPX_LG_F32', 'V_CMPX_LG_F64', 'V_CMPX_LT_F16', 'V_CMPX_LT_F32', 'V_CMPX_LT_F64', 'V_CMPX_LT_I16', 'V_CMPX_LT_I32', 'V_CMPX_LT_I64', - 'V_CMPX_LT_U16', 'V_CMPX_LT_U32', 'V_CMPX_LT_U64', 'V_CMPX_NE_I16', 'V_CMPX_NE_I32', 'V_CMPX_NE_I64', 'V_CMPX_NE_U16', 'V_CMPX_NE_U32', 'V_CMPX_NE_U64', - 'V_CMPX_NEQ_F16', 'V_CMPX_NEQ_F32', 'V_CMPX_NEQ_F64', 'V_CMPX_NGE_F16', 'V_CMPX_NGE_F32', 'V_CMPX_NGE_F64', 'V_CMPX_NGT_F16', 'V_CMPX_NGT_F32', 'V_CMPX_NGT_F64', - 'V_CMPX_NLE_F16', 'V_CMPX_NLE_F32', 'V_CMPX_NLE_F64', 'V_CMPX_NLG_F16', 'V_CMPX_NLG_F32', 'V_CMPX_NLG_F64', 'V_CMPX_NLT_F16', 'V_CMPX_NLT_F32', 'V_CMPX_NLT_F64', - 'V_CMPX_O_F16', 'V_CMPX_O_F32', 'V_CMPX_O_F64', 'V_CMPX_T_F16', 'V_CMPX_T_F32', 'V_CMPX_T_F64', 'V_CMPX_T_I32', 'V_CMPX_T_I64', 'V_CMPX_T_U32', 'V_CMPX_T_U64', - 'V_CMPX_U_F16', 'V_CMPX_U_F32', 'V_CMPX_U_F64', - 'S_ABSDIFF_I32', 'S_ABS_I32', 'S_ADD_F16', 'S_ADD_F32', 'S_ADD_I32', 'S_ADD_U32', 'S_ADDC_U32', 'S_ADDK_I32', 'S_AND_B32', 'S_AND_B64', - 'S_AND_NOT0_SAVEEXEC_B32', 'S_AND_NOT0_SAVEEXEC_B64', 'S_AND_NOT0_WREXEC_B32', 'S_AND_NOT0_WREXEC_B64', 'S_AND_NOT1_B32', 'S_AND_NOT1_B64', - 'S_AND_NOT1_SAVEEXEC_B32', 'S_AND_NOT1_SAVEEXEC_B64', 'S_AND_NOT1_WREXEC_B32', 'S_AND_NOT1_WREXEC_B64', 'S_AND_SAVEEXEC_B32', 'S_AND_SAVEEXEC_B64', - 'S_ASHR_I32', 'S_ASHR_I64', 'S_BCNT0_I32_B32', 'S_BCNT0_I32_B64', 'S_BCNT1_I32_B32', 'S_BCNT1_I32_B64', 'S_BFE_I32', 'S_BFE_I64', 'S_BFE_U32', 'S_BFE_U64', - 'S_BFM_B32', 'S_BFM_B64', 'S_BITSET0_B32', 'S_BITSET0_B64', 'S_BITSET1_B32', 'S_BITSET1_B64', 'S_CMOVK_I32', 'S_CMOV_B32', 'S_CMOV_B64', - 'S_CSELECT_B32', 'S_CSELECT_B64', 'S_CVT_F16_F32', 'S_CVT_F32_F16', 'S_CVT_F32_I32', 'S_CVT_F32_U32', 'S_CVT_HI_F32_F16', 'S_CVT_I32_F32', - 'S_CVT_PK_RTZ_F16_F32', 'S_CVT_U32_F32', 'S_DELAY_ALU', 'S_FMAAK_F32', 'S_FMAC_F16', 'S_FMAC_F32', 'S_FMAMK_F32', 'S_LSHL_B32', 'S_LSHL_B64', - 'S_LSHL1_ADD_U32', 'S_LSHL2_ADD_U32', 'S_LSHL3_ADD_U32', 'S_LSHL4_ADD_U32', 'S_LSHR_B32', 'S_LSHR_B64', 'S_MAX_I32', 'S_MAX_U32', 'S_MIN_I32', 'S_MIN_U32', - 'S_MOVK_I32', 'S_MOV_B32', 'S_MOV_B64', 'S_MULK_I32', 'S_MUL_F16', 'S_MUL_F32', 'S_MUL_HI_I32', 'S_MUL_HI_U32', 'S_MUL_I32', 'S_NAND_B32', 'S_NAND_B64', - 'S_NAND_SAVEEXEC_B32', 'S_NAND_SAVEEXEC_B64', 'S_NOP', 'S_NOR_B32', 'S_NOR_B64', 'S_NOR_SAVEEXEC_B32', 'S_NOR_SAVEEXEC_B64', 'S_NOT_B32', 'S_NOT_B64', - 'S_OR_B32', 'S_OR_B64', 'S_OR_NOT0_SAVEEXEC_B32', 'S_OR_NOT0_SAVEEXEC_B64', 'S_OR_NOT1_B32', 'S_OR_NOT1_B64', 'S_OR_NOT1_SAVEEXEC_B32', 'S_OR_NOT1_SAVEEXEC_B64', - 'S_OR_SAVEEXEC_B32', 'S_OR_SAVEEXEC_B64', 'S_PACK_HH_B32_B16', 'S_PACK_HL_B32_B16', 'S_PACK_LH_B32_B16', 'S_PACK_LL_B32_B16', 'S_RFE_B64', - 'S_RNDNE_F16', 'S_RNDNE_F32', 'S_SENDMSG_RTN_B32', 'S_SENDMSG_RTN_B64', 'S_SETPC_B64', 'S_SEXT_I32_I16', 'S_SEXT_I32_I8', - 'S_SUB_F16', 'S_SUB_F32', 'S_SUB_I32', 'S_SUB_U32', 'S_SUBB_U32', 'S_TRUNC_F16', 'S_TRUNC_F32', 'S_VERSION', 'S_BITCMP0_B32', 'S_BITCMP0_B64', - 'S_BITCMP1_B32', 'S_BITCMP1_B64', 'S_MAX_F16', 'S_MAX_F32', 'S_MIN_F16', 'S_MIN_F32', 'S_WAITCNT_EXPCNT', 'S_WAITCNT_LGKMCNT', 'S_WAITCNT_VMCNT', - 'S_WAITCNT_VSCNT', 'S_BRANCH', 'S_CALL_B64', 'S_CBRANCH_EXECNZ', 'S_CBRANCH_EXECZ', 'S_CBRANCH_SCC0', 'S_CBRANCH_SCC1', 'S_CBRANCH_VCCNZ', 'S_CBRANCH_VCCZ', - 'S_GETPC_B64', 'S_XNOR_B32', 'S_XNOR_B64', 'S_XNOR_SAVEEXEC_B32', 'S_XNOR_SAVEEXEC_B64', 'S_XOR_B32', 'S_XOR_B64', 'S_XOR_SAVEEXEC_B32', 'S_XOR_SAVEEXEC_B64', - 'S_CMPK_EQ_I32', 'S_CMPK_EQ_U32', 'S_CMPK_GE_I32', 'S_CMPK_GE_U32', 'S_CMPK_GT_I32', 'S_CMPK_GT_U32', 'S_CMPK_LE_I32', 'S_CMPK_LE_U32', - 'S_CMPK_LG_I32', 'S_CMPK_LG_U32', 'S_CMPK_LT_I32', 'S_CMPK_LT_U32', 'S_CMP_EQ_F16', 'S_CMP_EQ_F32', 'S_CMP_EQ_I32', 'S_CMP_EQ_U32', 'S_CMP_EQ_U64', - 'S_CMP_GE_F16', 'S_CMP_GE_F32', 'S_CMP_GE_I32', 'S_CMP_GE_U32', 'S_CMP_GT_F16', 'S_CMP_GT_F32', 'S_CMP_GT_I32', 'S_CMP_GT_U32', - 'S_CMP_LE_F16', 'S_CMP_LE_F32', 'S_CMP_LE_I32', 'S_CMP_LE_U32', 'S_CMP_LG_F16', 'S_CMP_LG_F32', 'S_CMP_LG_I32', 'S_CMP_LG_U32', 'S_CMP_LG_U64', - 'S_CMP_LT_F16', 'S_CMP_LT_F32', 'S_CMP_LT_I32', 'S_CMP_LT_U32', 'S_CMP_NEQ_F16', 'S_CMP_NEQ_F32', 'S_CMP_NGE_F16', 'S_CMP_NGE_F32', - 'S_CMP_NGT_F16', 'S_CMP_NGT_F32', 'S_CMP_NLE_F16', 'S_CMP_NLE_F32', 'S_CMP_NLG_F16', 'S_CMP_NLG_F32', 'S_CMP_NLT_F16', 'S_CMP_NLT_F32', - 'S_CMP_O_F16', 'S_CMP_O_F32', 'S_CMP_U_F16', 'S_CMP_U_F32', +# Ops with known issues (subtle float semantics, register array access, unimplemented features) +_SKIP_OPS = { + # Float ops with subtle semantics (neg zero, NaN handling) + 'V_CEIL_F16', 'V_CEIL_F32', 'V_CEIL_F64', 'V_FLOOR_F16', 'V_FLOOR_F32', 'V_FLOOR_F64', + 'V_FRACT_F16', 'V_FRACT_F32', 'V_FRACT_F64', 'V_DIV_FMAS_F32', 'V_DIV_FMAS_F64', + 'V_CMP_CLASS_F16', 'V_CMP_CLASS_F32', 'V_CMP_CLASS_F64', 'V_CMPX_CLASS_F16', 'V_CMPX_CLASS_F32', 'V_CMPX_CLASS_F64', + 'V_FREXP_MANT_F16', 'V_FREXP_MANT_F32', 'V_FREXP_MANT_F64', + 'V_DIV_FIXUP_F16', 'V_DIV_FIXUP_F32', 'V_DIV_SCALE_F32', 'V_DIV_SCALE_F64', # complex NaN/inf/denorm handling + 'V_TRIG_PREOP_F64', # lookup table for 2/PI mantissa bits + # Bit manipulation ops (need CLZ/CTZ/BREV intrinsics) + 'V_CLZ_I32_U32', 'V_CTZ_I32_B32', 'S_BREV_B32', 'S_BREV_B64', + 'S_FF0_I32_B32', 'S_FF0_I32_B64', 'S_FF1_I32_B32', 'S_FF1_I32_B64', + 'S_FLBIT_I32', 'S_FLBIT_I32_B32', 'S_FLBIT_I32_B64', 'S_FLBIT_I32_I32', 'S_FLBIT_I32_I64', + 'S_BITREPLICATE_B64_B32', # bit replication loop + 'S_BITSET0_B32', 'S_BITSET0_B64', 'S_BITSET1_B32', 'S_BITSET1_B64', # D0[S0[n:0]] = val (dynamic index into output) + 'S_QUADMASK_B32', 'S_QUADMASK_B64', 'S_WQM_B32', 'S_WQM_B64', # quad/wave mask ops with +: slice syntax + # Register array access (SRC0, DST, VGPR[], SGPR[]) + 'S_GETREG_B32', 'S_SETREG_B32', 'S_SETREG_IMM32_B32', + 'S_MOVRELD_B32', 'S_MOVRELD_B64', 'S_MOVRELS_B32', 'S_MOVRELS_B64', 'S_MOVRELSD_2_B32', + 'V_MOVRELD_B32', 'V_MOVRELS_B32', 'V_MOVRELSD_B32', 'V_MOVRELSD_2_B32', + 'V_READFIRSTLANE_B32', 'V_WRITELANE_B32', 'V_READLANE_B32', + 'V_PERMLANE16_B32', 'V_PERMLANE64_B32', 'V_PERMLANEX16_B32', + 'V_MBCNT_HI_U32_B32', 'V_MBCNT_LO_U32_B32', + 'S_SENDMSG_RTN_B32', 'S_SENDMSG_RTN_B64', + 'V_SWAPREL_B32', # VGPR[laneId][addr] register array access + 'V_PERM_B32', # BYTE_PERMUTE function not implemented + # Control flow / special ops (no actual computation) + 'S_NOP', 'S_SETHALT', 'S_TRAP', + # 65-bit intermediate results / multi-output with carry + 'V_MAD_U64_U32', 'V_MAD_I64_I32', # { D1.u1, D0.u64 } = 65-bit result + # Dot product ops (need bf16/u4 conversion functions or array declarations) + 'V_DOT2_F32_BF16', # bf16_to_f32 not implemented + 'V_DOT4_I32_IU8', 'V_DOT4_U32_U8', # u8_to_u32 array access pattern + 'V_DOT8_I32_IU4', 'V_DOT8_U32_U4', # u4_to_u32 array access pattern + # VOP3P mixed precision (S[i] array access in loop) + 'V_FMA_MIX_F32', 'V_FMA_MIXLO_F16', 'V_FMA_MIXHI_F16', + # Lookup table ops + 'V_CVT_OFF_F32_I4', # CVT_OFF_TABLE lookup } +# Memory/runtime patterns that can't be compiled to pure UOps +_MEM_PATTERNS = ('MEM[', 'LDS[', 'LDS(', 'SMEM[', 'FLAT[', 'GLOBAL[', 'DS[', 'SCRATCH[', 'VGPR[', 'SGPR[', 'GPR[', 'GS_REGS', 'ADDR', + 'thread_in[', 'thread_out[', 'thread_valid[') # DS_SWIZZLE patterns + @functools.cache -def compile_uop(cls_name: str, op_name: str, pseudocode: str): +def compile_uop(op_name: str, pseudocode: str): """Compile pseudocode to UOp-based function. Returns None if unsupported.""" - if op_name not in SUPPORTED_OPS: return None + if op_name in _SKIP_OPS: return None + if any(p in pseudocode for p in _MEM_PATTERNS): return None # memory ops need pcode sink, output_info, input_vars = _compile_pseudocode(pseudocode) return _make_uop_fn(sink, output_info, input_vars)