diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index 7cd6ccbd39..b7aa412764 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -427,7 +427,8 @@ class _Ctx: pcode = get_pcode(op) vcc_reg = sdst_reg if sdst_reg is not None else VCC_LO.offset if 'VCC' not in srcs: srcs['VCC'] = self.rsgpr_dyn(_c(vcc_reg)) - srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane}) + srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane, + 'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0)}) # rounding mode: 0=RNE, RTZ constant _, assigns = parse_pcode(pcode, srcs) raw_stores: list = [] @@ -796,10 +797,13 @@ def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp: return bits def build_remapped_src(src: UOp, opsel_lo_bit: int, opsel_hi_bit: int, neg_lo_bit: int, neg_hi_bit: int) -> UOp: return get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit)) | (get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit)) << UOp.const(dtypes.uint32, 16)) - s0_new = build_remapped_src(src0, opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1) - s1_new = build_remapped_src(src1, opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2) - s2_new = build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4) - srcs = {'S0': s0_new, 'S1': s1_new, 'S2': s2_new} + # DOT IU instructions use NEG bits for signed/unsigned selection, not fp16 negation + is_dot_iu = 'DOT' in op_name and 'IU' in op_name + n0, n1, n2, nh0, nh1, nh2 = (0, 0, 0, 0, 0, 0) if is_dot_iu else (neg & 1, neg & 2, neg & 4, neg_hi & 1, neg_hi & 2, neg_hi & 4) + srcs = {'S0': build_remapped_src(src0, opsel & 1, opsel_hi & 1, n0, nh0), + 'S1': build_remapped_src(src1, opsel & 2, opsel_hi & 2, n1, nh1), + 'S2': build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, n2, nh2)} + if is_dot_iu: srcs['NEG'] = UOp.const(dtypes.uint32, neg) return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask) def _compile_vopd(inst: VOPD, ctx: _Ctx) -> UOp: diff --git a/extra/assembly/amd/pcode.py b/extra/assembly/amd/pcode.py index 7103109d28..3f826f3260 100644 --- a/extra/assembly/amd/pcode.py +++ b/extra/assembly/amd/pcode.py @@ -94,13 +94,19 @@ def _trig_reduce(x, phase=0.0): return UOp(Ops.SIN, x.dtype, (x - n * _const(x.dtype, 6.283185307179586),)) def _signext(val: UOp) -> UOp: - for bits, mask, ext in [(8, 0xFF, 0xFFFFFF00), (16, 0xFFFF, 0xFFFF0000)]: + for bits, mask, ext in [(4, 0xF, 0xFFFFFFF0), (8, 0xFF, 0xFFFFFF00), (16, 0xFFFF, 0xFFFF0000)]: if (val.op == Ops.AND and len(val.src) == 2 and val.src[1].op == Ops.CONST and val.src[1].arg == mask) or val.dtype.itemsize == bits // 8: v32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val sb = (v32 >> _u32(bits - 1)) & _u32(1) return sb.ne(_u32(0)).where(v32 | _u32(ext), v32).cast(dtypes.int) return val.cast(dtypes.int64) if val.dtype in (dtypes.int, dtypes.int32) else val +def _signext_4bit(val: UOp) -> UOp: + """Sign extend a 4-bit value to 32-bit signed integer.""" + v32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val + sb = (v32 >> _u32(3)) & _u32(1) # sign bit at position 3 + return sb.ne(_u32(0)).where(v32 | _u32(0xFFFFFFF0), v32).bitcast(dtypes.int) + def _abs(val: UOp) -> UOp: if val.dtype not in (dtypes.float32, dtypes.float64, dtypes.half): return val _, _, _, _, shift = _float_info(val) @@ -227,11 +233,44 @@ _FUNCS: dict[str, Callable[..., UOp]] = { 'signext_from_bit': _signext_from_bit, 'ldexp': _ldexp, 'frexp_mant': _frexp_mant, 'mantissa': _frexp_mant, 'frexp_exp': _frexp_exp, 'trig_preop_result': _trig_preop, 's_ff1_i32_b32': lambda a: _ff1(a, 32), 's_ff1_i32_b64': lambda a: _ff1(a, 64), + # Normalization conversions: map [-1,1] or [0,1] to integer range + # Use floor(x + 0.5) for round-to-nearest + # SNORM: round(value * 32767), range is [-32767, 32767] (hardware behavior) + 'f16_to_snorm': lambda a: _floor(_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16), + 'f16_to_unorm': lambda a: _floor(_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16), + 'f32_to_snorm': lambda a: _floor(a.bitcast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16), + 'f32_to_unorm': lambda a: _floor(a.bitcast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16), + 'f32_to_u8': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint8), + # Integer truncation conversions + 'i32_to_i16': lambda a: a.cast(dtypes.int).cast(dtypes.int16), + 'u32_to_u16': lambda a: a.cast(dtypes.uint32).cast(dtypes.uint16), + 'u16_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xFFFF)), + 'u8_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xFF)), + 'u4_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xF)), + # Signed extraction with sign extension for dot products + 'i16_to_i32': lambda a: _signext(a.cast(dtypes.uint32) & _u32(0xFFFF)), + 'i8_to_i32': lambda a: _signext(a.cast(dtypes.uint32) & _u32(0xFF)), + 'i4_to_i32': lambda a: _signext_4bit(a.cast(dtypes.uint32) & _u32(0xF)), + # Float to int16 conversions + 'v_cvt_i16_f32': lambda a: UOp(Ops.TRUNC, dtypes.float32, (a.bitcast(dtypes.float32),)).cast(dtypes.int16), + 'v_cvt_u16_f32': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint16), } for is_max, name in [(False, 'min'), (True, 'max')]: for dt, sfx in [(dtypes.float32, 'f32'), (dtypes.int, 'i32'), (dtypes.uint32, 'u32'), (dtypes.int16, 'i16'), (dtypes.uint16, 'u16')]: _FUNCS[f'v_{name}_{sfx}'] = lambda *a, im=is_max, d=dt: _minmax_reduce(im, d, *a) _FUNCS[f'v_{name}3_{sfx}'] = lambda *a, im=is_max, d=dt: _minmax_reduce(im, d, *a) +# f16 min/max/min3/max3/med3 +for is_max, name in [(False, 'min'), (True, 'max')]: + _FUNCS[f'v_{name}_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a]) + _FUNCS[f'v_{name}3_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a]) + _FUNCS[f'v_{name}_num_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a]) + _FUNCS[f'v_{name}_num_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a) + _FUNCS[f'v_{name}3_num_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a]) + _FUNCS[f'v_{name}3_num_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a) + _FUNCS[f'v_{name}imum_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a]) + _FUNCS[f'v_{name}imum_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a) + _FUNCS[f'v_{name}imum3_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a]) + _FUNCS[f'v_{name}imum3_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a) # ═══════════════════════════════════════════════════════════════════════════════ # TOKENIZER/PARSER @@ -239,7 +278,7 @@ for is_max, name in [(False, 'min'), (True, 'max')]: DTYPES = {'u32': dtypes.uint32, 'i32': dtypes.int, 'f32': dtypes.float32, 'b32': dtypes.uint32, 'u64': dtypes.uint64, 'i64': dtypes.int64, 'f64': dtypes.float64, 'b64': dtypes.uint64, 'u16': dtypes.uint16, 'i16': dtypes.short, 'f16': dtypes.half, 'b16': dtypes.uint16, - 'u8': dtypes.uint8, 'i8': dtypes.int8, 'b8': dtypes.uint8, 'u1': dtypes.uint32} + 'u8': dtypes.uint8, 'i8': dtypes.int8, 'b8': dtypes.uint8, 'u4': dtypes.uint8, 'i4': dtypes.int8, 'u1': dtypes.uint32} _BITS_DT = {8: dtypes.uint8, 16: dtypes.uint16, 32: dtypes.uint32, 64: dtypes.uint64} _NUM_SUFFIXES = ('ULL', 'LL', 'UL', 'U', 'L', 'F', 'f') def _strip_suffix(num: str) -> tuple[str, str]: @@ -432,14 +471,6 @@ class Parser: return elem if self.at('LBRACKET') and name not in self.vars: self.eat('LBRACKET') - if self.at('NUM'): - idx_num = int(self.peek().val) - if f'{name}{idx_num}' in self.vars: - self.eat('NUM') - self.eat('RBRACKET') - elem = self.vars[f'{name}{idx_num}'] - if self.try_eat('DOT'): return _cast_to(elem, DTYPES.get(self.eat('IDENT').val, dtypes.uint32)) - return elem first = self.parse() return self._handle_bracket_rest(first, _u32(0), name) if name in self.vars: @@ -467,6 +498,7 @@ class Parser: if dt == base.dtype: return base if dt.itemsize == 2 and base.dtype.itemsize == 4: return (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16) if dt == dtypes.uint16 else (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16).bitcast(dt) + if field == 'i4': return _signext_4bit(base) return _cast_to(base, dt) def _handle_bracket(self, base, var_name: str | None = None) -> UOp: @@ -509,8 +541,8 @@ class Parser: var_name = self._find_var_name(base) if first.op == Ops.CONST: idx = int(first.arg) - if var_name and f'{var_name}{idx}' in self.vars: - v = self.vars[f'{var_name}{idx}'] + if var_name and f'{var_name}@{idx}' in self.vars: + v = self.vars[f'{var_name}@{idx}'] return _cast_to(v, dt_suffix) if dt_suffix else v dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32 base_cast = base.cast(dt) if base.dtype != dt else base @@ -518,7 +550,7 @@ class Parser: return _cast_to(result, dt_suffix) if dt_suffix else result if var_name: idx_u32 = _to_u32(first) - elems = [(i, self.vars[f'{var_name}{i}']) for i in range(256) if f'{var_name}{i}' in self.vars] + elems = [(i, self.vars[f'{var_name}@{i}']) for i in range(256) if f'{var_name}@{i}' in self.vars] if elems: result = elems[-1][1] for ei, ev in reversed(elems[:-1]): @@ -537,7 +569,7 @@ class Parser: self.eat('RBRACE') var_name = self._find_var_name(base) if var_name: - elem = self.vars.get(f'{var_name}{idx}', _u32(0)) + elem = self.vars.get(f'{var_name}@{idx}', _u32(0)) # use @ to avoid collision with temps like A4 if self.try_eat('DOT'): dt_name = self.eat('IDENT').val return _cast_to(elem, DTYPES.get(dt_name, dtypes.uint32)) @@ -787,7 +819,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di if found_var: vars[found_var] = block_assigns[found_var] = _const(dtypes.bool, False) for loop_i in range(start_val, end_val + 1): subst_lines = [_subst_loop_var(bl, loop_var, loop_i) for bl in body_lines if not (has_break and bl.strip().lower() == 'break')] - _, iter_assigns, _ = parse_block(subst_lines, 0, vars, funcs, assigns) + _, iter_assigns, _ = parse_block(subst_lines, 0, {**vars, **block_assigns}, funcs, assigns) if has_break: assert found_var is not None found = block_assigns.get(found_var, vars.get(found_var)) @@ -944,7 +976,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di if existing is not None and isinstance(existing, UOp): block_assigns[var] = vars[var] = _set_bit(existing, _u32(idx), val) else: - block_assigns[f'{var}{idx}'] = vars[f'{var}{idx}'] = val + block_assigns[f'{var}@{idx}'] = vars[f'{var}@{idx}'] = val i += 1; continue # Compound assignment: var += or var -= diff --git a/extra/assembly/amd/test/hw/helpers.py b/extra/assembly/amd/test/hw/helpers.py index 839ae87cde..5bc1227c46 100644 --- a/extra/assembly/amd/test/hw/helpers.py +++ b/extra/assembly/amd/test/hw/helpers.py @@ -13,7 +13,7 @@ def _i32(f: float) -> int: return struct.unpack(' float: return struct.unpack(' float: return struct.unpack(' float: return struct.unpack(' int: f = float(f) if math.isnan(f): return 0x7e00 diff --git a/extra/assembly/amd/test/hw/test_vop1.py b/extra/assembly/amd/test/hw/test_vop1.py index 215a2a2fbe..12d518a066 100644 --- a/extra/assembly/amd/test/hw/test_vop1.py +++ b/extra/assembly/amd/test/hw/test_vop1.py @@ -255,7 +255,6 @@ class TestF16Conversions(unittest.TestCase): def test_v_cvt_f16_f32_small(self): """V_CVT_F16_F32 converts small f32 value.""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16 instructions = [ v_mov_b32_e32(v[0], 0.5), v_cvt_f16_f32_e32(v[1], v[0]), @@ -293,7 +292,6 @@ class TestF16Conversions(unittest.TestCase): def test_v_cvt_f16_f32_reads_full_32bit_source(self): """V_CVT_F16_F32 must read full 32-bit f32 source.""" - from extra.assembly.amd.test.hw.helpers import _f16 instructions = [ s_mov_b32(s[0], 0x3fc00000), # f32 1.5 v_mov_b32_e32(v[0], s[0]), @@ -302,7 +300,7 @@ class TestF16Conversions(unittest.TestCase): st = run_program(instructions, n_lanes=1) result = st.vgpr[0][1] lo_bits = result & 0xffff - self.assertEqual(lo_bits, 0x3e00, f"Expected f16(1.5)=0x3e00, got 0x{lo_bits:04x} ({_f16(lo_bits)})") + self.assertEqual(lo_bits, 0x3e00, f"Expected f16(1.5)=0x3e00, got 0x{lo_bits:04x} ({f16(lo_bits)})") def test_v_cvt_i16_f16_zero(self): """V_CVT_I16_F16 converts f16 zero to i16 zero.""" @@ -696,7 +694,6 @@ class TestCvtF16Modifiers(unittest.TestCase): def test_v_cvt_f32_f16_abs_negative(self): """V_CVT_F32_F16 with |abs| on negative value.""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16 f16_neg1 = f32_to_f16(-1.0) # 0xbc00 instructions = [ s_mov_b32(s[0], f16_neg1), @@ -709,7 +706,6 @@ class TestCvtF16Modifiers(unittest.TestCase): def test_v_cvt_f32_f16_abs_positive(self): """V_CVT_F32_F16 with |abs| on positive value (should stay positive).""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16 f16_2 = f32_to_f16(2.0) # 0x4000 instructions = [ s_mov_b32(s[0], f16_2), @@ -722,7 +718,6 @@ class TestCvtF16Modifiers(unittest.TestCase): def test_v_cvt_f32_f16_neg_positive(self): """V_CVT_F32_F16 with neg on positive value.""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16 f16_2 = f32_to_f16(2.0) # 0x4000 instructions = [ s_mov_b32(s[0], f16_2), @@ -735,7 +730,6 @@ class TestCvtF16Modifiers(unittest.TestCase): def test_v_cvt_f32_f16_neg_negative(self): """V_CVT_F32_F16 with neg on negative value (double negative).""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16 f16_neg2 = f32_to_f16(-2.0) # 0xc000 instructions = [ s_mov_b32(s[0], f16_neg2), @@ -748,7 +742,6 @@ class TestCvtF16Modifiers(unittest.TestCase): def test_v_cvt_f16_f32_then_pack_for_wmma(self): """CVT F32->F16 followed by pack (common WMMA pattern).""" - from extra.assembly.amd.test.hw.helpers import _f16 f32_val = 3.5 instructions = [ s_mov_b32(s[0], f2i(f32_val)), @@ -757,8 +750,8 @@ class TestCvtF16Modifiers(unittest.TestCase): v_pack_b32_f16(v[2], v[1], v[1]), # Pack same value ] st = run_program(instructions, n_lanes=1) - lo = _f16(st.vgpr[0][2] & 0xffff) - hi = _f16((st.vgpr[0][2] >> 16) & 0xffff) + lo = f16(st.vgpr[0][2] & 0xffff) + hi = f16((st.vgpr[0][2] >> 16) & 0xffff) self.assertAlmostEqual(lo, f32_val, places=1) self.assertAlmostEqual(hi, f32_val, places=1) @@ -804,7 +797,6 @@ class TestConversionRounding(unittest.TestCase): def test_f16_to_f32_precision(self): """F16 to F32 conversion precision.""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16 f16_val = f32_to_f16(1.5) instructions = [ s_mov_b32(s[0], f16_val), @@ -816,7 +808,6 @@ class TestConversionRounding(unittest.TestCase): def test_f16_denormal_to_f32(self): """F16 denormal converts to small positive f32.""" - from extra.assembly.amd.test.hw.helpers import _f16 f16_denorm = 0x0001 # Smallest positive f16 denormal instructions = [ v_mov_b32_e32(v[0], f16_denorm), @@ -1512,5 +1503,63 @@ class TestReciprocalF16(unittest.TestCase): self.assertAlmostEqual(result, 0.25, places=2, msg="1/4.0 should be 0.25") +class TestCvtNormF16(unittest.TestCase): + """Tests for V_CVT_NORM_I16_F16 and V_CVT_NORM_U16_F16.""" + + def test_cvt_norm_i16_f16_positive(self): + """V_CVT_NORM_I16_F16: f16 1.0 -> i16 max (32767).""" + instructions = [ + s_mov_b32(s[0], f32_to_f16(1.0)), + v_mov_b32_e32(v[0], s[0]), + v_cvt_norm_i16_f16_e32(v[1], v[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][1] & 0xffff + self.assertEqual(result, 32767) + + def test_cvt_norm_i16_f16_negative(self): + """V_CVT_NORM_I16_F16: f16 -1.0 -> i16 -32767 (0x8001).""" + instructions = [ + s_mov_b32(s[0], f32_to_f16(-1.0)), + v_mov_b32_e32(v[0], s[0]), + v_cvt_norm_i16_f16_e32(v[1], v[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][1] & 0xffff + self.assertEqual(result, 0x8001) # -32767, hardware uses symmetric range + + def test_cvt_norm_i16_f16_zero(self): + """V_CVT_NORM_I16_F16: f16 0.0 -> i16 0.""" + instructions = [ + v_mov_b32_e32(v[0], 0), + v_cvt_norm_i16_f16_e32(v[1], v[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][1] & 0xffff + self.assertEqual(result, 0) + + def test_cvt_norm_u16_f16_one(self): + """V_CVT_NORM_U16_F16: f16 1.0 -> u16 max (65535).""" + instructions = [ + s_mov_b32(s[0], f32_to_f16(1.0)), + v_mov_b32_e32(v[0], s[0]), + v_cvt_norm_u16_f16_e32(v[1], v[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][1] & 0xffff + self.assertEqual(result, 65535) + + def test_cvt_norm_u16_f16_half(self): + """V_CVT_NORM_U16_F16: f16 0.5 -> u16 ~32768.""" + instructions = [ + s_mov_b32(s[0], f32_to_f16(0.5)), + v_mov_b32_e32(v[0], s[0]), + v_cvt_norm_u16_f16_e32(v[1], v[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][1] & 0xffff + self.assertAlmostEqual(result, 32768, delta=1) + + if __name__ == '__main__': unittest.main() diff --git a/extra/assembly/amd/test/hw/test_vop3.py b/extra/assembly/amd/test/hw/test_vop3.py index be8ecdde6e..596fcd4711 100644 --- a/extra/assembly/amd/test/hw/test_vop3.py +++ b/extra/assembly/amd/test/hw/test_vop3.py @@ -857,7 +857,6 @@ class TestF16Modifiers(unittest.TestCase): def test_v_fma_f16_inline_const_1_0(self): """V_FMA_F16: a*b + 1.0 should use f16 inline constant.""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16 f16_a = f32_to_f16(0.325928) # ~0x3537 f16_b = f32_to_f16(-0.486572) # ~0xb7c9 instructions = [ @@ -868,13 +867,12 @@ class TestF16Modifiers(unittest.TestCase): v_fma_f16(v[4], v[4], v[6], 1.0), # 1.0 is inline constant ] st = run_program(instructions, n_lanes=1) - result = _f16(st.vgpr[0][4] & 0xffff) + result = f16(st.vgpr[0][4] & 0xffff) expected = 0.325928 * (-0.486572) + 1.0 self.assertAlmostEqual(result, expected, delta=0.01) def test_v_fma_f16_inline_const_0_5(self): """V_FMA_F16: a*b + 0.5 should use f16 inline constant.""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16 f16_a = f32_to_f16(2.0) f16_b = f32_to_f16(3.0) instructions = [ @@ -885,13 +883,12 @@ class TestF16Modifiers(unittest.TestCase): v_fma_f16(v[2], v[0], v[1], 0.5), # 0.5 is inline constant ] st = run_program(instructions, n_lanes=1) - result = _f16(st.vgpr[0][2] & 0xffff) + result = f16(st.vgpr[0][2] & 0xffff) expected = 2.0 * 3.0 + 0.5 self.assertAlmostEqual(result, expected, delta=0.01) def test_v_fma_f16_inline_const_neg_1_0(self): """V_FMA_F16: a*b + (-1.0) should use f16 inline constant.""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16 f16_a = f32_to_f16(2.0) f16_b = f32_to_f16(3.0) instructions = [ @@ -902,13 +899,12 @@ class TestF16Modifiers(unittest.TestCase): v_fma_f16(v[2], v[0], v[1], -1.0), # -1.0 is inline constant ] st = run_program(instructions, n_lanes=1) - result = _f16(st.vgpr[0][2] & 0xffff) + result = f16(st.vgpr[0][2] & 0xffff) expected = 2.0 * 3.0 + (-1.0) self.assertAlmostEqual(result, expected, delta=0.01) def test_v_add_f16_abs_both(self): """V_ADD_F16 with abs on both operands.""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16 f16_neg2 = f32_to_f16(-2.0) f16_neg3 = f32_to_f16(-3.0) instructions = [ @@ -919,12 +915,11 @@ class TestF16Modifiers(unittest.TestCase): v_add_f16_e64(v[2], abs(v[0]), abs(v[1])), # |-2| + |-3| = 5 ] st = run_program(instructions, n_lanes=1) - result = _f16(st.vgpr[0][2] & 0xffff) + result = f16(st.vgpr[0][2] & 0xffff) self.assertAlmostEqual(result, 5.0, delta=0.01) def test_v_mul_f16_neg_abs(self): """V_MUL_F16 with neg on one operand and abs on another.""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16 f16_2 = f32_to_f16(2.0) f16_neg3 = f32_to_f16(-3.0) instructions = [ @@ -935,7 +930,7 @@ class TestF16Modifiers(unittest.TestCase): v_mul_f16_e64(v[2], -v[0], abs(v[1])), # -(2) * |-3| = -6 ] st = run_program(instructions, n_lanes=1) - result = _f16(st.vgpr[0][2] & 0xffff) + result = f16(st.vgpr[0][2] & 0xffff) self.assertAlmostEqual(result, -6.0, delta=0.01) def test_v_fmac_f16_hi_dest(self): @@ -943,7 +938,6 @@ class TestF16Modifiers(unittest.TestCase): This tests the case from AMD_LLVM sin(0) where V_FMAC_F16 writes to v0.h. """ - from extra.assembly.amd.test.hw.helpers import _f16 instructions = [ s_mov_b32(s[0], 0x38003c00), # v0 = {hi=0.5, lo=1.0} v_mov_b32_e32(v[0], s[0]), @@ -954,8 +948,8 @@ class TestF16Modifiers(unittest.TestCase): ] st = run_program(instructions, n_lanes=1) v0 = st.vgpr[0][0] - result_hi = _f16((v0 >> 16) & 0xffff) - result_lo = _f16(v0 & 0xffff) + result_hi = f16((v0 >> 16) & 0xffff) + result_lo = f16(v0 & 0xffff) self.assertAlmostEqual(result_hi, 0.5, delta=0.01, msg=f"Expected hi=0.5, got {result_hi}") self.assertAlmostEqual(result_lo, 1.0, delta=0.01, msg=f"Expected lo=1.0, got {result_lo}") @@ -2955,5 +2949,318 @@ class TestVOP3Clamp(unittest.TestCase): self.assertAlmostEqual(i2f(st.vgpr[3][1]), 1.0, places=5, msg="lane 3: 2.5 should clamp to 1.0") +class TestCvtPkF16(unittest.TestCase): + """Tests for V_CVT_PK_RTZ_F16_F32 - pack two f32 to f16 with round toward zero.""" + + def test_cvt_pk_rtz_f16_f32_basic(self): + """V_CVT_PK_RTZ_F16_F32: basic pack of two f32 values.""" + instructions = [ + v_mov_b32_e32(v[0], 1.0), + v_mov_b32_e32(v[1], 2.0), + v_cvt_pk_rtz_f16_f32_e64(v[2], v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] + lo_f16 = f16(result & 0xffff) + hi_f16 = f16((result >> 16) & 0xffff) + self.assertAlmostEqual(lo_f16, 1.0, delta=0.01) + self.assertAlmostEqual(hi_f16, 2.0, delta=0.01) + + +class TestCvtPkNorm(unittest.TestCase): + """Tests for V_CVT_PK_NORM_I16_F32 and V_CVT_PK_NORM_U16_F32.""" + + def test_cvt_pk_norm_i16_f32_basic(self): + """V_CVT_PK_NORM_I16_F32: pack two f32 to normalized i16.""" + instructions = [ + v_mov_b32_e32(v[0], 1.0), + v_mov_b32_e32(v[1], -1.0), + v_cvt_pk_norm_i16_f32(v[2], v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] + lo = result & 0xffff + hi = (result >> 16) & 0xffff + self.assertEqual(lo, 32767) + self.assertEqual(hi, 0x8001) # -32767, hardware uses symmetric range + + def test_cvt_pk_norm_u16_f32_basic(self): + """V_CVT_PK_NORM_U16_F32: pack two f32 to normalized u16.""" + instructions = [ + v_mov_b32_e32(v[0], 1.0), + v_mov_b32_e32(v[1], 0.5), + v_cvt_pk_norm_u16_f32(v[2], v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] + lo = result & 0xffff + hi = (result >> 16) & 0xffff + self.assertEqual(lo, 65535) + self.assertAlmostEqual(hi, 32768, delta=1) + + +class TestCvtPkInt(unittest.TestCase): + """Tests for V_CVT_PK_I16_I32, V_CVT_PK_U16_U32, V_CVT_PK_I16_F32, V_CVT_PK_U16_F32.""" + + def test_cvt_pk_i16_i32_basic(self): + """V_CVT_PK_I16_I32: pack two i32 to i16.""" + instructions = [ + s_mov_b32(s[0], 100), + s_mov_b32(s[1], -100 & 0xffffffff), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_cvt_pk_i16_i32(v[2], v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] + lo = result & 0xffff + hi = (result >> 16) & 0xffff + lo_signed = lo if lo < 32768 else lo - 65536 + hi_signed = hi if hi < 32768 else hi - 65536 + self.assertEqual(lo_signed, 100) + self.assertEqual(hi_signed, -100) + + def test_cvt_pk_u16_u32_basic(self): + """V_CVT_PK_U16_U32: pack two u32 to u16.""" + instructions = [ + s_mov_b32(s[0], 1000), + s_mov_b32(s[1], 2000), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_cvt_pk_u16_u32(v[2], v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] + lo = result & 0xffff + hi = (result >> 16) & 0xffff + self.assertEqual(lo, 1000) + self.assertEqual(hi, 2000) + + def test_cvt_pk_i16_f32_basic(self): + """V_CVT_PK_I16_F32: convert two f32 to packed i16.""" + instructions = [ + v_mov_b32_e32(v[0], 100.5), + v_mov_b32_e32(v[1], -50.7), + v_cvt_pk_i16_f32(v[2], v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] + lo = result & 0xffff + hi = (result >> 16) & 0xffff + lo_signed = lo if lo < 32768 else lo - 65536 + hi_signed = hi if hi < 32768 else hi - 65536 + self.assertEqual(lo_signed, 100) + self.assertEqual(hi_signed, -50) + + def test_cvt_pk_u16_f32_basic(self): + """V_CVT_PK_U16_F32: convert two f32 to packed u16.""" + instructions = [ + v_mov_b32_e32(v[0], 100.9), + v_mov_b32_e32(v[1], 200.1), + v_cvt_pk_u16_f32(v[2], v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] + lo = result & 0xffff + hi = (result >> 16) & 0xffff + self.assertEqual(lo, 100) + self.assertEqual(hi, 200) + + def test_cvt_pk_u8_f32_basic(self): + """V_CVT_PK_U8_F32: convert f32 to u8 and pack at byte position.""" + instructions = [ + v_mov_b32_e32(v[0], 128.5), + v_mov_b32_e32(v[1], 0), + v_mov_b32_e32(v[2], 0), + v_cvt_pk_u8_f32(v[2], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] + byte0 = result & 0xff + self.assertEqual(byte0, 128) + + +class TestDotProduct(unittest.TestCase): + """Tests for dot product instructions V_DOT4_U32_U8, V_DOT8_U32_U4.""" + + def test_v_dot4_u32_u8_basic(self): + """V_DOT4_U32_U8: 4-element dot product of u8 vectors.""" + src0 = 0x04030201 # {4, 3, 2, 1} + src1 = 0x01010101 # {1, 1, 1, 1} + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], src1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 0), + v_dot4_u32_u8(v[2], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] + self.assertEqual(result, 10) + + def test_v_dot4_u32_u8_with_accumulator(self): + """V_DOT4_U32_U8 with non-zero accumulator.""" + src0 = 0x02020202 # {2, 2, 2, 2} + src1 = 0x03030303 # {3, 3, 3, 3} + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], src1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 100), + v_dot4_u32_u8(v[2], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] + self.assertEqual(result, 124) + + def test_v_dot8_u32_u4_basic(self): + """V_DOT8_U32_U4: 8-element dot product of u4 vectors.""" + # src0 = 8 nibbles: {1,2,3,4,5,6,7,8} packed as 0x87654321 + # src1 = 8 nibbles: {1,1,1,1,1,1,1,1} packed as 0x11111111 + # result = 1+2+3+4+5+6+7+8 = 36 + src0 = 0x87654321 + src1 = 0x11111111 + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], src1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 0), + v_dot8_u32_u4(v[2], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] + self.assertEqual(result, 36) + + +class TestMinMaxF16Vop3(unittest.TestCase): + """Tests for V_MIN3_F16, V_MAX3_F16, V_MED3_F16, V_MINMAX_F16, V_MAXMIN_F16.""" + + def test_v_min3_f16_basic(self): + """V_MIN3_F16: minimum of three f16 values.""" + instructions = [ + s_mov_b32(s[0], f32_to_f16(3.0)), + s_mov_b32(s[1], f32_to_f16(1.0)), + s_mov_b32(s[2], f32_to_f16(2.0)), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], s[2]), + v_min3_f16(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = f16(st.vgpr[0][3] & 0xffff) + self.assertAlmostEqual(result, 1.0, delta=0.01) + + def test_v_max3_f16_basic(self): + """V_MAX3_F16: maximum of three f16 values.""" + instructions = [ + s_mov_b32(s[0], f32_to_f16(1.0)), + s_mov_b32(s[1], f32_to_f16(3.0)), + s_mov_b32(s[2], f32_to_f16(2.0)), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], s[2]), + v_max3_f16(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = f16(st.vgpr[0][3] & 0xffff) + self.assertAlmostEqual(result, 3.0, delta=0.01) + + def test_v_med3_f16_basic(self): + """V_MED3_F16: median of three f16 values.""" + instructions = [ + s_mov_b32(s[0], f32_to_f16(3.0)), + s_mov_b32(s[1], f32_to_f16(1.0)), + s_mov_b32(s[2], f32_to_f16(2.0)), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], s[2]), + v_med3_f16(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = f16(st.vgpr[0][3] & 0xffff) + self.assertAlmostEqual(result, 2.0, delta=0.01) + + def test_v_minmax_f16_basic(self): + """V_MINMAX_F16: clamp(src0, min=src1, max=src2).""" + instructions = [ + s_mov_b32(s[0], f32_to_f16(2.5)), + s_mov_b32(s[1], f32_to_f16(1.0)), + s_mov_b32(s[2], f32_to_f16(2.0)), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], s[2]), + v_minmax_f16(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = f16(st.vgpr[0][3] & 0xffff) + self.assertAlmostEqual(result, 2.0, delta=0.01) + + def test_v_maxmin_f16_basic(self): + """V_MAXMIN_F16: clamp(src0, min=src2, max=src1).""" + instructions = [ + s_mov_b32(s[0], f32_to_f16(0.5)), + s_mov_b32(s[1], f32_to_f16(2.0)), + s_mov_b32(s[2], f32_to_f16(1.0)), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], s[2]), + v_maxmin_f16(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = f16(st.vgpr[0][3] & 0xffff) + self.assertAlmostEqual(result, 1.0, delta=0.01) + + def test_v_min3_f16_with_neg(self): + """V_MIN3_F16 with neg modifier: min(-3, 1, 2) = -3.""" + instructions = [ + s_mov_b32(s[0], f32_to_f16(3.0)), + s_mov_b32(s[1], f32_to_f16(1.0)), + s_mov_b32(s[2], f32_to_f16(2.0)), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], s[2]), + v_min3_f16(v[3], -v[0], v[1], v[2]), # neg on first operand + ] + st = run_program(instructions, n_lanes=1) + result = f16(st.vgpr[0][3] & 0xffff) + self.assertAlmostEqual(result, -3.0, delta=0.01) + + def test_v_max3_f16_with_abs(self): + """V_MAX3_F16 with abs modifier: max(|-3|, 1, 2) = 3.""" + instructions = [ + s_mov_b32(s[0], f32_to_f16(-3.0)), + s_mov_b32(s[1], f32_to_f16(1.0)), + s_mov_b32(s[2], f32_to_f16(2.0)), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], s[2]), + v_max3_f16(v[3], abs(v[0]), v[1], v[2]), # abs on first operand + ] + st = run_program(instructions, n_lanes=1) + result = f16(st.vgpr[0][3] & 0xffff) + self.assertAlmostEqual(result, 3.0, delta=0.01) + + def test_v_med3_f16_opsel_hi(self): + """V_MED3_F16 with opsel reading from hi half.""" + # Pack two f16 values: hi=5.0, lo=1.0 + packed = (f32_to_f16(5.0) << 16) | f32_to_f16(1.0) + instructions = [ + s_mov_b32(s[0], packed), + s_mov_b32(s[1], f32_to_f16(3.0)), + s_mov_b32(s[2], f32_to_f16(4.0)), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], s[2]), + # Read hi half of v[0] (5.0), med3(5, 3, 4) = 4 + v_med3_f16(v[3], v[0].h, v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = f16(st.vgpr[0][3] & 0xffff) + self.assertAlmostEqual(result, 4.0, delta=0.01) + + if __name__ == '__main__': unittest.main() diff --git a/extra/assembly/amd/test/hw/test_vop3p.py b/extra/assembly/amd/test/hw/test_vop3p.py index 9c1ead9a9d..62671b4ad8 100644 --- a/extra/assembly/amd/test/hw/test_vop3p.py +++ b/extra/assembly/amd/test/hw/test_vop3p.py @@ -149,7 +149,6 @@ class TestFmaMix(unittest.TestCase): def test_v_fma_mix_f32_src2_f16_lo(self): """V_FMA_MIX_F32 with src2 as f16 from lo bits.""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16 f16_2 = f32_to_f16(2.0) instructions = [ s_mov_b32(s[0], f2i(1.0)), @@ -166,7 +165,6 @@ class TestFmaMix(unittest.TestCase): def test_v_fma_mix_f32_src2_f16_hi(self): """V_FMA_MIX_F32 with src2 as f16 from hi bits.""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16 f16_2 = f32_to_f16(2.0) val = (f16_2 << 16) | 0 instructions = [ @@ -199,7 +197,6 @@ class TestFmaMix(unittest.TestCase): def test_v_fma_mix_f32_with_abs_f16_src2_lo(self): """V_FMA_MIX_F32 with abs modifier on f16 src2 (lo half). Regression test for sin(1.0) bug.""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16 f16_neg1 = f32_to_f16(-1.0) # 0xbc00 instructions = [ s_mov_b32(s[0], f2i(0.0)), # src0 = 0.0 (f32) @@ -217,7 +214,6 @@ class TestFmaMix(unittest.TestCase): def test_v_fma_mix_f32_with_neg_f16_src2_lo(self): """V_FMA_MIX_F32 with neg modifier on f16 src2 (lo half).""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16 f16_1 = f32_to_f16(1.0) # 0x3c00 instructions = [ s_mov_b32(s[0], f2i(0.0)), # src0 = 0.0 (f32) @@ -235,7 +231,6 @@ class TestFmaMix(unittest.TestCase): def test_v_fma_mix_f32_with_abs_f16_src2_hi(self): """V_FMA_MIX_F32 with abs modifier on f16 src2 (hi half).""" - from extra.assembly.amd.test.hw.helpers import f32_to_f16 f16_neg1 = f32_to_f16(-1.0) # 0xbc00 val = (f16_neg1 << 16) | 0 # -1.0 in hi, 0 in lo instructions = [ @@ -254,7 +249,6 @@ class TestFmaMix(unittest.TestCase): def test_v_fma_mixlo_f16(self): """V_FMA_MIXLO_F16 writes to low 16 bits of destination.""" - from extra.assembly.amd.test.hw.helpers import _f16 instructions = [ s_mov_b32(s[0], f2i(2.0)), v_mov_b32_e32(v[0], s[0]), @@ -267,14 +261,13 @@ class TestFmaMix(unittest.TestCase): VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=0), ] st = run_program(instructions, n_lanes=1) - lo = _f16(st.vgpr[0][3] & 0xffff) + lo = f16(st.vgpr[0][3] & 0xffff) hi = (st.vgpr[0][3] >> 16) & 0xffff self.assertAlmostEqual(lo, 7.0, places=1) self.assertEqual(hi, 0xdead, f"hi should be preserved, got 0x{hi:04x}") def test_v_fma_mixlo_f16_all_f32_sources(self): """V_FMA_MIXLO_F16 with all f32 sources.""" - from extra.assembly.amd.test.hw.helpers import _f16 instructions = [ s_mov_b32(s[0], f2i(1.0)), v_mov_b32_e32(v[0], s[0]), @@ -286,13 +279,12 @@ class TestFmaMix(unittest.TestCase): VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=0), ] st = run_program(instructions, n_lanes=1) - lo = _f16(st.vgpr[0][3] & 0xffff) + lo = f16(st.vgpr[0][3] & 0xffff) # 1*2+3 = 5 self.assertAlmostEqual(lo, 5.0, places=1) def test_v_fma_mixlo_f16_sin_case(self): """V_FMA_MIXLO_F16 case from sin kernel.""" - from extra.assembly.amd.test.hw.helpers import _f16 instructions = [ s_mov_b32(s[0], 0x3f800000), # f32 1.0 v_mov_b32_e32(v[3], s[0]), @@ -305,7 +297,7 @@ class TestFmaMix(unittest.TestCase): VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[3], src1=s[6], src2=v[5], opsel=0, opsel_hi=0, opsel_hi2=0), ] st = run_program(instructions, n_lanes=1) - lo = _f16(st.vgpr[0][3] & 0xffff) + lo = f16(st.vgpr[0][3] & 0xffff) self.assertAlmostEqual(lo, -3.14159, delta=0.01) @@ -314,7 +306,6 @@ class TestVOP3P(unittest.TestCase): def test_v_pk_add_f16_basic(self): """V_PK_ADD_F16 adds two packed f16 values.""" - from extra.assembly.amd.test.hw.helpers import _f16 instructions = [ s_mov_b32(s[0], 0x40003c00), # hi=2.0, lo=1.0 s_mov_b32(s[1], 0x44004200), # hi=4.0, lo=3.0 @@ -324,14 +315,13 @@ class TestVOP3P(unittest.TestCase): ] st = run_program(instructions, n_lanes=1) result = st.vgpr[0][2] - lo = _f16(result & 0xffff) - hi = _f16((result >> 16) & 0xffff) + lo = f16(result & 0xffff) + hi = f16((result >> 16) & 0xffff) self.assertAlmostEqual(lo, 4.0, places=2) self.assertAlmostEqual(hi, 6.0, places=2) def test_v_pk_mul_f16_basic(self): """V_PK_MUL_F16 multiplies two packed f16 values.""" - from extra.assembly.amd.test.hw.helpers import _f16 instructions = [ s_mov_b32(s[0], 0x42004000), # hi=3.0, lo=2.0 s_mov_b32(s[1], 0x45004400), # hi=5.0, lo=4.0 @@ -341,14 +331,13 @@ class TestVOP3P(unittest.TestCase): ] st = run_program(instructions, n_lanes=1) result = st.vgpr[0][2] - lo = _f16(result & 0xffff) - hi = _f16((result >> 16) & 0xffff) + lo = f16(result & 0xffff) + hi = f16((result >> 16) & 0xffff) self.assertAlmostEqual(lo, 8.0, places=1) self.assertAlmostEqual(hi, 15.0, places=1) def test_v_pk_fma_f16_basic(self): """V_PK_FMA_F16: D = A * B + C for packed f16.""" - from extra.assembly.amd.test.hw.helpers import _f16 instructions = [ s_mov_b32(s[0], 0x42004000), # A: hi=3.0, lo=2.0 s_mov_b32(s[1], 0x45004400), # B: hi=5.0, lo=4.0 @@ -360,8 +349,8 @@ class TestVOP3P(unittest.TestCase): ] st = run_program(instructions, n_lanes=1) result = st.vgpr[0][3] - lo = _f16(result & 0xffff) - hi = _f16((result >> 16) & 0xffff) + lo = f16(result & 0xffff) + hi = f16((result >> 16) & 0xffff) self.assertAlmostEqual(lo, 9.0, places=1) # 2*4+1 self.assertAlmostEqual(hi, 16.0, places=0) # 3*5+1 @@ -370,7 +359,6 @@ class TestVOP3P(unittest.TestCase): Inline constants for VOP3P are f16 values in the low 16 bits only. hi half of inline constant is 0, so hi result = v0.hi + 0 = 1.0. """ - from extra.assembly.amd.test.hw.helpers import _f16 instructions = [ s_mov_b32(s[0], 0x3c003c00), # packed f16: hi=1.0, lo=1.0 v_mov_b32_e32(v[0], s[0]), @@ -378,8 +366,8 @@ class TestVOP3P(unittest.TestCase): ] st = run_program(instructions, n_lanes=1) result = st.vgpr[0][1] - lo = _f16(result & 0xffff) - hi = _f16((result >> 16) & 0xffff) + lo = f16(result & 0xffff) + hi = f16((result >> 16) & 0xffff) # lo = 1.0 + 1.0 = 2.0, hi = 1.0 + 0.0 = 1.0 (inline const hi half is 0) self.assertAlmostEqual(lo, 2.0, places=2) self.assertAlmostEqual(hi, 1.0, places=2) @@ -388,7 +376,6 @@ class TestVOP3P(unittest.TestCase): """V_PK_MUL_F16 with inline constant POS_TWO (2.0). Inline constant has value only in low 16 bits, hi is 0. """ - from extra.assembly.amd.test.hw.helpers import _f16 # v0 = packed (3.0, 4.0), multiply by POS_TWO # lo = 3.0 * 2.0 = 6.0, hi = 4.0 * 0.0 = 0.0 (inline const hi is 0) instructions = [ @@ -398,8 +385,8 @@ class TestVOP3P(unittest.TestCase): ] st = run_program(instructions, n_lanes=1) result = st.vgpr[0][1] - lo = _f16(result & 0xffff) - hi = _f16((result >> 16) & 0xffff) + lo = f16(result & 0xffff) + hi = f16((result >> 16) & 0xffff) self.assertAlmostEqual(lo, 6.0, places=1) self.assertAlmostEqual(hi, 0.0, places=1) @@ -413,7 +400,6 @@ class TestWMMAF16(unittest.TestCase): def test_v_wmma_f16_16x16x16_f16_all_ones(self): """V_WMMA_F16_16X16X16_F16 with all ones produces 16.0 in f16.""" - from extra.assembly.amd.test.hw.helpers import _f16 instructions = [] instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0 # Initialize A matrix in v[16:23] (8 regs) @@ -432,13 +418,12 @@ class TestWMMAF16(unittest.TestCase): for lane in range(32): for reg in range(8): result = st.vgpr[lane][reg] - lo = _f16(result & 0xffff) + lo = f16(result & 0xffff) self.assertAlmostEqual(lo, 16.0, places=1, msg=f"v[{reg}] lane {lane}: expected 16.0, got {lo}") self.assertEqual(result >> 16, 0, msg=f"v[{reg}] lane {lane}: hi bits should be 0") def test_v_wmma_f16_16x16x16_f16_with_accumulator(self): """V_WMMA_F16_16X16X16_F16 with non-zero accumulator.""" - from extra.assembly.amd.test.hw.helpers import _f16 instructions = [] instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0 instructions.append(s_mov_b32(s[1], 0x4500)) # f16 5.0 in lo bits only @@ -458,7 +443,7 @@ class TestWMMAF16(unittest.TestCase): for lane in range(32): for reg in range(8): result = st.vgpr[lane][reg] - lo = _f16(result & 0xffff) + lo = f16(result & 0xffff) self.assertAlmostEqual(lo, 21.0, places=0, msg=f"v[{reg}] lane {lane}: expected 21.0, got {lo}") self.assertEqual(result >> 16, 0, msg=f"v[{reg}] lane {lane}: hi bits should be 0") @@ -468,7 +453,6 @@ class TestWMMAF16(unittest.TestCase): Regression test: WMMA was using static register indices instead of dynamic. This test uses v[64:71] for A, v[80:87] for B, v[96:103] for C/D. """ - from extra.assembly.amd.test.hw.helpers import _f16 instructions = [] instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0 # Initialize A matrix in v[64:71] (8 regs) @@ -490,7 +474,7 @@ class TestWMMAF16(unittest.TestCase): for lane in range(32): for reg in range(8): result = st.vgpr[lane][reg] - lo = _f16(result & 0xffff) + lo = f16(result & 0xffff) self.assertAlmostEqual(lo, 16.0, places=1, msg=f"v[{reg}] lane {lane}: expected 16.0, got {lo}") self.assertEqual(result >> 16, 0, msg=f"v[{reg}] lane {lane}: hi bits should be 0") @@ -713,7 +697,6 @@ class TestPackedMixedSigns(unittest.TestCase): def test_pk_add_f16_mixed_signs(self): """V_PK_ADD_F16 with mixed positive/negative values.""" - from extra.assembly.amd.test.hw.helpers import _f16 instructions = [ s_mov_b32(s[0], 0xc0003c00), # packed: hi=-2.0, lo=1.0 s_mov_b32(s[1], 0x3c003c00), # packed: hi=1.0, lo=1.0 @@ -723,14 +706,13 @@ class TestPackedMixedSigns(unittest.TestCase): ] st = run_program(instructions, n_lanes=1) result = st.vgpr[0][2] - lo = _f16(result & 0xffff) - hi = _f16((result >> 16) & 0xffff) + lo = f16(result & 0xffff) + hi = f16((result >> 16) & 0xffff) self.assertAlmostEqual(lo, 2.0, places=2) # 1.0 + 1.0 self.assertAlmostEqual(hi, -1.0, places=2) # -2.0 + 1.0 def test_pk_mul_f16_zero(self): """V_PK_MUL_F16 with zero.""" - from extra.assembly.amd.test.hw.helpers import _f16 instructions = [ s_mov_b32(s[0], 0x40004000), # packed: 2.0, 2.0 s_mov_b32(s[1], 0x00000000), # packed: 0.0, 0.0 @@ -743,5 +725,276 @@ class TestPackedMixedSigns(unittest.TestCase): self.assertEqual(result, 0x00000000, "2.0 * 0.0 should be 0.0") +class TestDot2F32F16(unittest.TestCase): + """Tests for V_DOT2_F32_F16 - dot product of f16 pairs producing f32.""" + + def test_v_dot2_f32_f16_basic(self): + """V_DOT2_F32_F16: dot product of two packed f16 pairs -> f32.""" + # src0 = {hi=2.0, lo=1.0}, src1 = {hi=4.0, lo=3.0} + # result = 1.0*3.0 + 2.0*4.0 + 0 = 3 + 8 = 11.0 + src0 = (f32_to_f16(2.0) << 16) | f32_to_f16(1.0) + src1 = (f32_to_f16(4.0) << 16) | f32_to_f16(3.0) + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], src1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 0), + v_dot2_f32_f16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1), + ] + st = run_program(instructions, n_lanes=1) + result = i2f(st.vgpr[0][3]) + self.assertAlmostEqual(result, 11.0, places=2) + + def test_v_dot2_f32_f16_with_accumulator(self): + """V_DOT2_F32_F16 with non-zero f32 accumulator.""" + # src0 = {hi=1.0, lo=1.0}, src1 = {hi=1.0, lo=1.0}, acc = 5.0 + # result = 1.0*1.0 + 1.0*1.0 + 5.0 = 7.0 + src0 = (f32_to_f16(1.0) << 16) | f32_to_f16(1.0) + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], f2i(5.0)), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[0]), # same as src0 + v_mov_b32_e32(v[2], s[1]), + v_dot2_f32_f16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1), + ] + st = run_program(instructions, n_lanes=1) + result = i2f(st.vgpr[0][3]) + self.assertAlmostEqual(result, 7.0, places=2) + + def test_v_dot2_f32_f16_negative_values(self): + """V_DOT2_F32_F16 with negative f16 values.""" + # src0 = {hi=-2.0, lo=3.0}, src1 = {hi=1.0, lo=2.0} + # result = 3.0*2.0 + (-2.0)*1.0 + 0 = 6 - 2 = 4.0 + src0 = (f32_to_f16(-2.0) << 16) | f32_to_f16(3.0) + src1 = (f32_to_f16(1.0) << 16) | f32_to_f16(2.0) + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], src1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 0), + v_dot2_f32_f16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1), + ] + st = run_program(instructions, n_lanes=1) + result = i2f(st.vgpr[0][3]) + self.assertAlmostEqual(result, 4.0, places=2) + + +class TestDot2F16F16(unittest.TestCase): + """Tests for V_DOT2_F16_F16 - dot product of f16 pairs producing f16.""" + + def test_v_dot2_f16_f16_basic(self): + """V_DOT2_F16_F16: dot product of two packed f16 pairs -> f16.""" + # src0 = {hi=2.0, lo=1.0}, src1 = {hi=3.0, lo=2.0} + # result = 1.0*2.0 + 2.0*3.0 + 0 = 2 + 6 = 8.0 (f16) + src0 = (f32_to_f16(2.0) << 16) | f32_to_f16(1.0) + src1 = (f32_to_f16(3.0) << 16) | f32_to_f16(2.0) + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], src1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 0), + v_dot2_f16_f16(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = f16(st.vgpr[0][3] & 0xffff) + self.assertAlmostEqual(result, 8.0, places=1) + + def test_v_dot2_f16_f16_with_accumulator(self): + """V_DOT2_F16_F16 with non-zero f16 accumulator.""" + # src0 = {hi=1.0, lo=1.0}, src1 = {hi=1.0, lo=1.0}, acc = 3.0 (f16) + # result = 1.0*1.0 + 1.0*1.0 + 3.0 = 5.0 (f16) + src0 = (f32_to_f16(1.0) << 16) | f32_to_f16(1.0) + acc = f32_to_f16(3.0) + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[2], acc), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[0]), # same as src0 + v_mov_b32_e32(v[2], s[2]), + v_dot2_f16_f16(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + result = f16(st.vgpr[0][3] & 0xffff) + self.assertAlmostEqual(result, 5.0, places=1) + + +class TestSignedDotProducts(unittest.TestCase): + """Tests for V_DOT4_I32_IU8 and V_DOT8_I32_IU4 with signed inputs.""" + + def test_v_dot4_i32_iu8_signed_both(self): + """V_DOT4_I32_IU8 with both inputs signed (neg=0b011).""" + # src0 = {-1, -2, 3, 4} as i8 = {0xff, 0xfe, 0x03, 0x04} + # src1 = {1, 1, 1, 1} as i8 + # result = (-1)*1 + (-2)*1 + 3*1 + 4*1 = -1 - 2 + 3 + 4 = 4 + src0 = (0xff << 24) | (0xfe << 16) | (0x03 << 8) | 0x04 # -1, -2, 3, 4 + src1 = 0x01010101 # 1, 1, 1, 1 + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], src1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 0), + v_dot4_i32_iu8(v[3], v[0], v[1], v[2], neg=0b011), # both signed + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][3] + # Result is i32, interpret as signed + if result >= 0x80000000: + result = result - 0x100000000 + self.assertEqual(result, 4) + + def test_v_dot4_i32_iu8_src0_signed(self): + """V_DOT4_I32_IU8 with only src0 signed (neg=0b001).""" + # src0 = {-1, -1, -1, -1} as i8 = {0xff, 0xff, 0xff, 0xff} + # src1 = {2, 2, 2, 2} as u8 + # result = (-1)*2 + (-1)*2 + (-1)*2 + (-1)*2 = -8 + src0 = 0xffffffff # -1, -1, -1, -1 (as i8) + src1 = 0x02020202 # 2, 2, 2, 2 (as u8) + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], src1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 0), + v_dot4_i32_iu8(v[3], v[0], v[1], v[2], neg=0b001), # src0 signed + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][3] + if result >= 0x80000000: + result = result - 0x100000000 + self.assertEqual(result, -8) + + def test_v_dot4_i32_iu8_src1_signed(self): + """V_DOT4_I32_IU8 with only src1 signed (neg=0b010).""" + # src0 = {2, 2, 2, 2} as u8 + # src1 = {-1, -1, -1, -1} as i8 = {0xff, 0xff, 0xff, 0xff} + # result = 2*(-1) + 2*(-1) + 2*(-1) + 2*(-1) = -8 + src0 = 0x02020202 # 2, 2, 2, 2 (as u8) + src1 = 0xffffffff # -1, -1, -1, -1 (as i8) + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], src1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 0), + v_dot4_i32_iu8(v[3], v[0], v[1], v[2], neg=0b010), # src1 signed + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][3] + if result >= 0x80000000: + result = result - 0x100000000 + self.assertEqual(result, -8) + + def test_v_dot4_i32_iu8_unsigned_as_reference(self): + """V_DOT4_I32_IU8 with both unsigned (neg=0) - same as V_DOT4_U32_U8.""" + # src0 = {0xff, 0xff, 0xff, 0xff} = 255 each as u8 + # src1 = {1, 1, 1, 1} + # result = 255*1 + 255*1 + 255*1 + 255*1 = 1020 + src0 = 0xffffffff + src1 = 0x01010101 + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], src1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 0), + v_dot4_i32_iu8(v[3], v[0], v[1], v[2], neg=0), # both unsigned + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vgpr[0][3], 1020) + + def test_v_dot8_i32_iu4_signed_both(self): + """V_DOT8_I32_IU4 with both inputs signed (neg=0b011).""" + # src0 = 8 nibbles: {-1, -2, 3, 4, -1, -2, 3, 4} as i4 + # i4 -1 = 0xf, -2 = 0xe, 3 = 0x3, 4 = 0x4 + # src0 = 0xfe34fe34 + # src1 = {1, 1, 1, 1, 1, 1, 1, 1} as i4 = 0x11111111 + # result = 2 * ((-1)*1 + (-2)*1 + 3*1 + 4*1) = 2 * 4 = 8 + src0 = 0xfe34fe34 + src1 = 0x11111111 + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], src1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 0), + v_dot8_i32_iu4(v[3], v[0], v[1], v[2], neg=0b011), # both signed + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][3] + if result >= 0x80000000: + result = result - 0x100000000 + self.assertEqual(result, 8) + + def test_v_dot8_i32_iu4_all_negative(self): + """V_DOT8_I32_IU4 with all negative signed values.""" + # src0 = 8 nibbles all -1 (0xf) = 0xffffffff + # src1 = 8 nibbles all 1 = 0x11111111 + # result = 8 * ((-1)*1) = -8 + src0 = 0xffffffff # all -1 as i4 + src1 = 0x11111111 # all 1 + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], src1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_mov_b32_e32(v[2], 0), + v_dot8_i32_iu4(v[3], v[0], v[1], v[2], neg=0b011), # both signed + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][3] + if result >= 0x80000000: + result = result - 0x100000000 + self.assertEqual(result, -8) + + +class TestPkMinMaxF16(unittest.TestCase): + """Tests for V_PK_MIN_F16 and V_PK_MAX_F16.""" + + def test_v_pk_min_f16_basic(self): + """V_PK_MIN_F16: packed min of two f16 pairs.""" + # src0 = {hi=3.0, lo=1.0}, src1 = {hi=2.0, lo=4.0} + # result = {min(3,2)=2, min(1,4)=1} + src0 = (f32_to_f16(3.0) << 16) | f32_to_f16(1.0) + src1 = (f32_to_f16(2.0) << 16) | f32_to_f16(4.0) + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], src1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_pk_min_f16(v[2], v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] + lo = f16(result & 0xffff) + hi = f16((result >> 16) & 0xffff) + self.assertAlmostEqual(lo, 1.0, delta=0.01) + self.assertAlmostEqual(hi, 2.0, delta=0.01) + + def test_v_pk_max_f16_basic(self): + """V_PK_MAX_F16: packed max of two f16 pairs.""" + # src0 = {hi=3.0, lo=1.0}, src1 = {hi=2.0, lo=4.0} + # result = {max(3,2)=3, max(1,4)=4} + src0 = (f32_to_f16(3.0) << 16) | f32_to_f16(1.0) + src1 = (f32_to_f16(2.0) << 16) | f32_to_f16(4.0) + instructions = [ + s_mov_b32(s[0], src0), + s_mov_b32(s[1], src1), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_pk_max_f16(v[2], v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] + lo = f16(result & 0xffff) + hi = f16((result >> 16) & 0xffff) + self.assertAlmostEqual(lo, 4.0, delta=0.01) + self.assertAlmostEqual(hi, 3.0, delta=0.01) + + if __name__ == '__main__': unittest.main() diff --git a/extra/assembly/amd/test/test_emu2_pcode.py b/extra/assembly/amd/test/test_emu2_pcode.py index 9667c3f421..4c25b4931f 100644 --- a/extra/assembly/amd/test/test_emu2_pcode.py +++ b/extra/assembly/amd/test/test_emu2_pcode.py @@ -294,7 +294,7 @@ class TestAllPcode(unittest.TestCase): 'ADDR': u32(), 'ADDR_BASE': u32(), 'TADDR': u32(), 'DATA': u32(), 'DATA0': u32(), 'DATA1': u32(), 'DATA2': u32(), 'VDATA': u32(), 'VDATA0': u32(), 'VDATA1': u32(), 'VDATA2': u32(), 'VDATA3': u32(), 'OPSEL': u32(), 'OPSEL_HI': u32(), 'NEG': u32(), 'NEG_HI': u32(), 'CLAMP': u32(), - 'M0': u32(), 'PC': u64(), 'DENORM': u32(1), 'ROUND_MODE': u32(), 'WAVE_STATUS': u32(), + 'M0': u32(), 'PC': u64(), 'DENORM': u32(1), 'ROUND_MODE': u32(), 'ROUND_TOWARD_ZERO': u32(), 'ROUND_NEAREST_EVEN': u32(), 'WAVE_STATUS': u32(), 'MAX_FLOAT_F32': u32(0x7f7fffff), 'Unsigned': u32(1), 'clampedLOD': u32(), '_lds': lds, '_vmem': lds, '_active': UOp.const(dtypes.bool, True)}