mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
assembly/amd: test more instructions (#14365)
* assembly/amd: test more instructions * more * passing * revert * no const fold * remove junk * cleaner
This commit is contained in:
@@ -427,7 +427,8 @@ class _Ctx:
|
||||
pcode = get_pcode(op)
|
||||
vcc_reg = sdst_reg if sdst_reg is not None else VCC_LO.offset
|
||||
if 'VCC' not in srcs: srcs['VCC'] = self.rsgpr_dyn(_c(vcc_reg))
|
||||
srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane})
|
||||
srcs.update({'EXEC': exec_mask, 'SCC': self.rsgpr_dyn(_c(SCC.offset)), 'laneId': lane,
|
||||
'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0)}) # rounding mode: 0=RNE, RTZ constant
|
||||
_, assigns = parse_pcode(pcode, srcs)
|
||||
|
||||
raw_stores: list = []
|
||||
@@ -796,10 +797,13 @@ def _compile_vop3p(inst: VOP3P, ctx: _Ctx) -> UOp:
|
||||
return bits
|
||||
def build_remapped_src(src: UOp, opsel_lo_bit: int, opsel_hi_bit: int, neg_lo_bit: int, neg_hi_bit: int) -> UOp:
  """Build a 32-bit operand from two separately selected 16-bit halves of `src`.

  Each `*_bit` argument is only tested for truthiness, so callers may pass a
  masked field such as `opsel & 2` directly instead of a normalized 0/1.
  The half chosen by the `lo` flags lands in bits [15:0] of the result and the
  half chosen by the `hi` flags is shifted into bits [31:16].
  """
  low_half = get_half_bits(src, bool(opsel_lo_bit), bool(neg_lo_bit))
  high_half = get_half_bits(src, bool(opsel_hi_bit), bool(neg_hi_bit))
  return low_half | (high_half << UOp.const(dtypes.uint32, 16))
|
||||
s0_new = build_remapped_src(src0, opsel & 1, opsel_hi & 1, neg & 1, neg_hi & 1)
|
||||
s1_new = build_remapped_src(src1, opsel & 2, opsel_hi & 2, neg & 2, neg_hi & 2)
|
||||
s2_new = build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, neg & 4, neg_hi & 4)
|
||||
srcs = {'S0': s0_new, 'S1': s1_new, 'S2': s2_new}
|
||||
# DOT IU instructions use NEG bits for signed/unsigned selection, not fp16 negation
|
||||
is_dot_iu = 'DOT' in op_name and 'IU' in op_name
|
||||
n0, n1, n2, nh0, nh1, nh2 = (0, 0, 0, 0, 0, 0) if is_dot_iu else (neg & 1, neg & 2, neg & 4, neg_hi & 1, neg_hi & 2, neg_hi & 4)
|
||||
srcs = {'S0': build_remapped_src(src0, opsel & 1, opsel_hi & 1, n0, nh0),
|
||||
'S1': build_remapped_src(src1, opsel & 2, opsel_hi & 2, n1, nh1),
|
||||
'S2': build_remapped_src(src2, opsel & 4, 1 if opsel_hi2 else 0, n2, nh2)}
|
||||
if is_dot_iu: srcs['NEG'] = UOp.const(dtypes.uint32, neg)
|
||||
return ctx.compile_vop_pcode(inst.op, srcs, lane, vdst_reg, exec_mask)
|
||||
|
||||
def _compile_vopd(inst: VOPD, ctx: _Ctx) -> UOp:
|
||||
|
||||
@@ -94,13 +94,19 @@ def _trig_reduce(x, phase=0.0):
|
||||
return UOp(Ops.SIN, x.dtype, (x - n * _const(x.dtype, 6.283185307179586),))
|
||||
|
||||
def _signext(val: UOp) -> UOp:
  """Sign-extend a narrow integer UOp to a signed 32-bit value.

  A value is treated as N bits wide (N tried in the order 4, 8, 16) when either
    * it is an AND whose second source is the constant mask for that width
      (0xF, 0xFF, 0xFFFF), or
    * its dtype is exactly N // 8 bytes wide.
  The first width that matches wins: bit N-1 is tested and, when set, the bits
  above it are filled with ones before the result is cast to `dtypes.int`.

  When no width matches, a value that is already `dtypes.int` is widened to
  `dtypes.int64`; anything else is returned unchanged.
  """
  # For the 4-bit entry `bits // 8` is 0, which no dtype itemsize equals, so a
  # nibble is only recognised through its 0xF mask.
  for bits, mask, ext in [(4, 0xF, 0xFFFFFFF0), (8, 0xFF, 0xFFFFFF00), (16, 0xFFFF, 0xFFFF0000)]:
    if (val.op == Ops.AND and len(val.src) == 2 and val.src[1].op == Ops.CONST and val.src[1].arg == mask) or val.dtype.itemsize == bits // 8:
      v32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val
      sb = (v32 >> _u32(bits - 1)) & _u32(1)  # sign bit of the narrow value
      return sb.ne(_u32(0)).where(v32 | _u32(ext), v32).cast(dtypes.int)
  return val.cast(dtypes.int64) if val.dtype in (dtypes.int, dtypes.int32) else val
|
||||
|
||||
def _signext_4bit(val: UOp) -> UOp:
  """Sign extend a 4-bit value to a 32-bit signed integer.

  The input is widened to uint32 if it is not already. Bit 3 is the sign bit:
  when it is set, bits [31:4] are OR'd with ones; otherwise the word is passed
  through as is. The result is reinterpreted (bitcast) as `dtypes.int`.
  Bits above bit 3 of the input are not cleared here.
  """
  if val.dtype == dtypes.uint32:
    word = val
  else:
    word = val.cast(dtypes.uint32)
  is_negative = ((word >> _u32(3)) & _u32(1)).ne(_u32(0))
  extended = word | _u32(0xFFFFFFF0)
  return is_negative.where(extended, word).bitcast(dtypes.int)
|
||||
|
||||
def _abs(val: UOp) -> UOp:
|
||||
if val.dtype not in (dtypes.float32, dtypes.float64, dtypes.half): return val
|
||||
_, _, _, _, shift = _float_info(val)
|
||||
@@ -227,11 +233,44 @@ _FUNCS: dict[str, Callable[..., UOp]] = {
|
||||
'signext_from_bit': _signext_from_bit, 'ldexp': _ldexp, 'frexp_mant': _frexp_mant, 'mantissa': _frexp_mant,
|
||||
'frexp_exp': _frexp_exp, 'trig_preop_result': _trig_preop,
|
||||
's_ff1_i32_b32': lambda a: _ff1(a, 32), 's_ff1_i32_b64': lambda a: _ff1(a, 64),
|
||||
# Normalization conversions: map [-1,1] or [0,1] to integer range
|
||||
# Use floor(x + 0.5) for round-to-nearest
|
||||
# SNORM: round(value * 32767), range is [-32767, 32767] (hardware behavior)
|
||||
'f16_to_snorm': lambda a: _floor(_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
|
||||
'f16_to_unorm': lambda a: _floor(_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
|
||||
'f32_to_snorm': lambda a: _floor(a.bitcast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
|
||||
'f32_to_unorm': lambda a: _floor(a.bitcast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
|
||||
'f32_to_u8': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint8),
|
||||
# Integer truncation conversions
|
||||
'i32_to_i16': lambda a: a.cast(dtypes.int).cast(dtypes.int16),
|
||||
'u32_to_u16': lambda a: a.cast(dtypes.uint32).cast(dtypes.uint16),
|
||||
'u16_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xFFFF)),
|
||||
'u8_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xFF)),
|
||||
'u4_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xF)),
|
||||
# Signed extraction with sign extension for dot products
|
||||
'i16_to_i32': lambda a: _signext(a.cast(dtypes.uint32) & _u32(0xFFFF)),
|
||||
'i8_to_i32': lambda a: _signext(a.cast(dtypes.uint32) & _u32(0xFF)),
|
||||
'i4_to_i32': lambda a: _signext_4bit(a.cast(dtypes.uint32) & _u32(0xF)),
|
||||
# Float to int16 conversions
|
||||
'v_cvt_i16_f32': lambda a: UOp(Ops.TRUNC, dtypes.float32, (a.bitcast(dtypes.float32),)).cast(dtypes.int16),
|
||||
'v_cvt_u16_f32': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint16),
|
||||
}
|
||||
# Register v_min_* / v_max_* and their three-operand forms for each 32-bit and 16-bit scalar type.
# The lambda defaults freeze the current loop values so each entry keeps its own direction and dtype.
for is_max, name in [(False, 'min'), (True, 'max')]:
  for dt, sfx in [(dtypes.float32, 'f32'), (dtypes.int, 'i32'), (dtypes.uint32, 'u32'), (dtypes.int16, 'i16'), (dtypes.uint16, 'u16')]:
    _FUNCS.update({f'v_{name}{arity}_{sfx}': (lambda *a, im=is_max, d=dt: _minmax_reduce(im, d, *a)) for arity in ('', '3')})
|
||||
# f16 min/max and min3/max3, plus the *_num and *imum families in both f16 and f32.
# f16 entries pass every operand through _f16_extract before reducing; f32 entries reduce the operands as given.
for is_max, name in [(False, 'min'), (True, 'max')]:
  _FUNCS.update({
    f'v_{name}{variant}':
      (lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a])) if variant.endswith('_f16')
      else (lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a))
    for variant in ('_f16', '3_f16', '_num_f16', '_num_f32', '3_num_f16', '3_num_f32',
                    'imum_f16', 'imum_f32', 'imum3_f16', 'imum3_f32')})
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# TOKENIZER/PARSER
|
||||
@@ -239,7 +278,7 @@ for is_max, name in [(False, 'min'), (True, 'max')]:
|
||||
|
||||
# Maps a pcode type suffix (the part after the dot, e.g. `.u32`) to the dtype used for it.
# The 4-bit suffixes 'u4'/'i4' are carried in 8-bit dtypes and 'u1' in uint32.
DTYPES = {'u32': dtypes.uint32, 'i32': dtypes.int, 'f32': dtypes.float32, 'b32': dtypes.uint32, 'u64': dtypes.uint64, 'i64': dtypes.int64,
          'f64': dtypes.float64, 'b64': dtypes.uint64, 'u16': dtypes.uint16, 'i16': dtypes.short, 'f16': dtypes.half, 'b16': dtypes.uint16,
          'u8': dtypes.uint8, 'i8': dtypes.int8, 'b8': dtypes.uint8, 'u4': dtypes.uint8, 'i4': dtypes.int8, 'u1': dtypes.uint32}
|
||||
_BITS_DT = {8: dtypes.uint8, 16: dtypes.uint16, 32: dtypes.uint32, 64: dtypes.uint64}
|
||||
_NUM_SUFFIXES = ('ULL', 'LL', 'UL', 'U', 'L', 'F', 'f')
|
||||
def _strip_suffix(num: str) -> tuple[str, str]:
|
||||
@@ -432,14 +471,6 @@ class Parser:
|
||||
return elem
|
||||
if self.at('LBRACKET') and name not in self.vars:
|
||||
self.eat('LBRACKET')
|
||||
if self.at('NUM'):
|
||||
idx_num = int(self.peek().val)
|
||||
if f'{name}{idx_num}' in self.vars:
|
||||
self.eat('NUM')
|
||||
self.eat('RBRACKET')
|
||||
elem = self.vars[f'{name}{idx_num}']
|
||||
if self.try_eat('DOT'): return _cast_to(elem, DTYPES.get(self.eat('IDENT').val, dtypes.uint32))
|
||||
return elem
|
||||
first = self.parse()
|
||||
return self._handle_bracket_rest(first, _u32(0), name)
|
||||
if name in self.vars:
|
||||
@@ -467,6 +498,7 @@ class Parser:
|
||||
if dt == base.dtype: return base
|
||||
if dt.itemsize == 2 and base.dtype.itemsize == 4:
|
||||
return (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16) if dt == dtypes.uint16 else (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16).bitcast(dt)
|
||||
if field == 'i4': return _signext_4bit(base)
|
||||
return _cast_to(base, dt)
|
||||
|
||||
def _handle_bracket(self, base, var_name: str | None = None) -> UOp:
|
||||
@@ -509,8 +541,8 @@ class Parser:
|
||||
var_name = self._find_var_name(base)
|
||||
if first.op == Ops.CONST:
|
||||
idx = int(first.arg)
|
||||
if var_name and f'{var_name}{idx}' in self.vars:
|
||||
v = self.vars[f'{var_name}{idx}']
|
||||
if var_name and f'{var_name}@{idx}' in self.vars:
|
||||
v = self.vars[f'{var_name}@{idx}']
|
||||
return _cast_to(v, dt_suffix) if dt_suffix else v
|
||||
dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
|
||||
base_cast = base.cast(dt) if base.dtype != dt else base
|
||||
@@ -518,7 +550,7 @@ class Parser:
|
||||
return _cast_to(result, dt_suffix) if dt_suffix else result
|
||||
if var_name:
|
||||
idx_u32 = _to_u32(first)
|
||||
elems = [(i, self.vars[f'{var_name}{i}']) for i in range(256) if f'{var_name}{i}' in self.vars]
|
||||
elems = [(i, self.vars[f'{var_name}@{i}']) for i in range(256) if f'{var_name}@{i}' in self.vars]
|
||||
if elems:
|
||||
result = elems[-1][1]
|
||||
for ei, ev in reversed(elems[:-1]):
|
||||
@@ -537,7 +569,7 @@ class Parser:
|
||||
self.eat('RBRACE')
|
||||
var_name = self._find_var_name(base)
|
||||
if var_name:
|
||||
elem = self.vars.get(f'{var_name}{idx}', _u32(0))
|
||||
elem = self.vars.get(f'{var_name}@{idx}', _u32(0)) # use @ to avoid collision with temps like A4
|
||||
if self.try_eat('DOT'):
|
||||
dt_name = self.eat('IDENT').val
|
||||
return _cast_to(elem, DTYPES.get(dt_name, dtypes.uint32))
|
||||
@@ -787,7 +819,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
if found_var: vars[found_var] = block_assigns[found_var] = _const(dtypes.bool, False)
|
||||
for loop_i in range(start_val, end_val + 1):
|
||||
subst_lines = [_subst_loop_var(bl, loop_var, loop_i) for bl in body_lines if not (has_break and bl.strip().lower() == 'break')]
|
||||
_, iter_assigns, _ = parse_block(subst_lines, 0, vars, funcs, assigns)
|
||||
_, iter_assigns, _ = parse_block(subst_lines, 0, {**vars, **block_assigns}, funcs, assigns)
|
||||
if has_break:
|
||||
assert found_var is not None
|
||||
found = block_assigns.get(found_var, vars.get(found_var))
|
||||
@@ -944,7 +976,7 @@ def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: di
|
||||
if existing is not None and isinstance(existing, UOp):
|
||||
block_assigns[var] = vars[var] = _set_bit(existing, _u32(idx), val)
|
||||
else:
|
||||
block_assigns[f'{var}{idx}'] = vars[f'{var}{idx}'] = val
|
||||
block_assigns[f'{var}@{idx}'] = vars[f'{var}@{idx}'] = val
|
||||
i += 1; continue
|
||||
|
||||
# Compound assignment: var += or var -=
|
||||
|
||||
@@ -13,7 +13,7 @@ def _i32(f: float) -> int: return struct.unpack('<I', struct.pack('<f', f))[0]
|
||||
def _f32(i: int) -> float: return struct.unpack('<f', struct.pack('<I', i & 0xFFFFFFFF))[0]
|
||||
|
||||
# f16 conversion helpers
|
||||
def _f16(i: int) -> float: return struct.unpack('<e', struct.pack('<H', i & 0xFFFF))[0]
|
||||
def f16(i: int) -> float:
  """Reinterpret the low 16 bits of `i` as a little-endian IEEE-754 half float."""
  raw = struct.pack('<H', i & 0xFFFF)
  (value,) = struct.unpack('<e', raw)
  return value
|
||||
def f32_to_f16(f: float) -> int:
|
||||
f = float(f)
|
||||
if math.isnan(f): return 0x7e00
|
||||
|
||||
@@ -255,7 +255,6 @@ class TestF16Conversions(unittest.TestCase):
|
||||
|
||||
def test_v_cvt_f16_f32_small(self):
|
||||
"""V_CVT_F16_F32 converts small f32 value."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0.5),
|
||||
v_cvt_f16_f32_e32(v[1], v[0]),
|
||||
@@ -293,7 +292,6 @@ class TestF16Conversions(unittest.TestCase):
|
||||
|
||||
def test_v_cvt_f16_f32_reads_full_32bit_source(self):
|
||||
"""V_CVT_F16_F32 must read full 32-bit f32 source."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x3fc00000), # f32 1.5
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
@@ -302,7 +300,7 @@ class TestF16Conversions(unittest.TestCase):
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][1]
|
||||
lo_bits = result & 0xffff
|
||||
self.assertEqual(lo_bits, 0x3e00, f"Expected f16(1.5)=0x3e00, got 0x{lo_bits:04x} ({_f16(lo_bits)})")
|
||||
self.assertEqual(lo_bits, 0x3e00, f"Expected f16(1.5)=0x3e00, got 0x{lo_bits:04x} ({f16(lo_bits)})")
|
||||
|
||||
def test_v_cvt_i16_f16_zero(self):
|
||||
"""V_CVT_I16_F16 converts f16 zero to i16 zero."""
|
||||
@@ -696,7 +694,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
|
||||
|
||||
def test_v_cvt_f32_f16_abs_negative(self):
|
||||
"""V_CVT_F32_F16 with |abs| on negative value."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_neg1 = f32_to_f16(-1.0) # 0xbc00
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f16_neg1),
|
||||
@@ -709,7 +706,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
|
||||
|
||||
def test_v_cvt_f32_f16_abs_positive(self):
|
||||
"""V_CVT_F32_F16 with |abs| on positive value (should stay positive)."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_2 = f32_to_f16(2.0) # 0x4000
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f16_2),
|
||||
@@ -722,7 +718,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
|
||||
|
||||
def test_v_cvt_f32_f16_neg_positive(self):
|
||||
"""V_CVT_F32_F16 with neg on positive value."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_2 = f32_to_f16(2.0) # 0x4000
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f16_2),
|
||||
@@ -735,7 +730,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
|
||||
|
||||
def test_v_cvt_f32_f16_neg_negative(self):
|
||||
"""V_CVT_F32_F16 with neg on negative value (double negative)."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_neg2 = f32_to_f16(-2.0) # 0xc000
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f16_neg2),
|
||||
@@ -748,7 +742,6 @@ class TestCvtF16Modifiers(unittest.TestCase):
|
||||
|
||||
def test_v_cvt_f16_f32_then_pack_for_wmma(self):
|
||||
"""CVT F32->F16 followed by pack (common WMMA pattern)."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
f32_val = 3.5
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(f32_val)),
|
||||
@@ -757,8 +750,8 @@ class TestCvtF16Modifiers(unittest.TestCase):
|
||||
v_pack_b32_f16(v[2], v[1], v[1]), # Pack same value
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
lo = _f16(st.vgpr[0][2] & 0xffff)
|
||||
hi = _f16((st.vgpr[0][2] >> 16) & 0xffff)
|
||||
lo = f16(st.vgpr[0][2] & 0xffff)
|
||||
hi = f16((st.vgpr[0][2] >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo, f32_val, places=1)
|
||||
self.assertAlmostEqual(hi, f32_val, places=1)
|
||||
|
||||
@@ -804,7 +797,6 @@ class TestConversionRounding(unittest.TestCase):
|
||||
|
||||
def test_f16_to_f32_precision(self):
|
||||
"""F16 to F32 conversion precision."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_val = f32_to_f16(1.5)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f16_val),
|
||||
@@ -816,7 +808,6 @@ class TestConversionRounding(unittest.TestCase):
|
||||
|
||||
def test_f16_denormal_to_f32(self):
|
||||
"""F16 denormal converts to small positive f32."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
f16_denorm = 0x0001 # Smallest positive f16 denormal
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], f16_denorm),
|
||||
@@ -1512,5 +1503,63 @@ class TestReciprocalF16(unittest.TestCase):
|
||||
self.assertAlmostEqual(result, 0.25, places=2, msg="1/4.0 should be 0.25")
|
||||
|
||||
|
||||
class TestCvtNormF16(unittest.TestCase):
  """Tests for V_CVT_NORM_I16_F16 and V_CVT_NORM_U16_F16."""

  @staticmethod
  def _f16_in_v0(value):
    """Return the instructions that load the f16 encoding of `value` into v0 by way of s0."""
    return [s_mov_b32(s[0], f32_to_f16(value)), v_mov_b32_e32(v[0], s[0])]

  def _low16_of_v1(self, program):
    """Run `program` on one lane and return the low 16 bits of lane 0's v1."""
    return run_program(program, n_lanes=1).vgpr[0][1] & 0xffff

  def test_cvt_norm_i16_f16_positive(self):
    """V_CVT_NORM_I16_F16: f16 1.0 -> i16 max (32767)."""
    program = self._f16_in_v0(1.0) + [v_cvt_norm_i16_f16_e32(v[1], v[0])]
    self.assertEqual(self._low16_of_v1(program), 32767)

  def test_cvt_norm_i16_f16_negative(self):
    """V_CVT_NORM_I16_F16: f16 -1.0 -> i16 -32767 (0x8001)."""
    program = self._f16_in_v0(-1.0) + [v_cvt_norm_i16_f16_e32(v[1], v[0])]
    # -32767 rather than -32768: the expected range is symmetric
    self.assertEqual(self._low16_of_v1(program), 0x8001)

  def test_cvt_norm_i16_f16_zero(self):
    """V_CVT_NORM_I16_F16: f16 0.0 -> i16 0."""
    program = [
      v_mov_b32_e32(v[0], 0),
      v_cvt_norm_i16_f16_e32(v[1], v[0]),
    ]
    self.assertEqual(self._low16_of_v1(program), 0)

  def test_cvt_norm_u16_f16_one(self):
    """V_CVT_NORM_U16_F16: f16 1.0 -> u16 max (65535)."""
    program = self._f16_in_v0(1.0) + [v_cvt_norm_u16_f16_e32(v[1], v[0])]
    self.assertEqual(self._low16_of_v1(program), 65535)

  def test_cvt_norm_u16_f16_half(self):
    """V_CVT_NORM_U16_F16: f16 0.5 -> u16 ~32768."""
    program = self._f16_in_v0(0.5) + [v_cvt_norm_u16_f16_e32(v[1], v[0])]
    self.assertAlmostEqual(self._low16_of_v1(program), 32768, delta=1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -857,7 +857,6 @@ class TestF16Modifiers(unittest.TestCase):
|
||||
|
||||
def test_v_fma_f16_inline_const_1_0(self):
|
||||
"""V_FMA_F16: a*b + 1.0 should use f16 inline constant."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
|
||||
f16_a = f32_to_f16(0.325928) # ~0x3537
|
||||
f16_b = f32_to_f16(-0.486572) # ~0xb7c9
|
||||
instructions = [
|
||||
@@ -868,13 +867,12 @@ class TestF16Modifiers(unittest.TestCase):
|
||||
v_fma_f16(v[4], v[4], v[6], 1.0), # 1.0 is inline constant
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = _f16(st.vgpr[0][4] & 0xffff)
|
||||
result = f16(st.vgpr[0][4] & 0xffff)
|
||||
expected = 0.325928 * (-0.486572) + 1.0
|
||||
self.assertAlmostEqual(result, expected, delta=0.01)
|
||||
|
||||
def test_v_fma_f16_inline_const_0_5(self):
|
||||
"""V_FMA_F16: a*b + 0.5 should use f16 inline constant."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
|
||||
f16_a = f32_to_f16(2.0)
|
||||
f16_b = f32_to_f16(3.0)
|
||||
instructions = [
|
||||
@@ -885,13 +883,12 @@ class TestF16Modifiers(unittest.TestCase):
|
||||
v_fma_f16(v[2], v[0], v[1], 0.5), # 0.5 is inline constant
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = _f16(st.vgpr[0][2] & 0xffff)
|
||||
result = f16(st.vgpr[0][2] & 0xffff)
|
||||
expected = 2.0 * 3.0 + 0.5
|
||||
self.assertAlmostEqual(result, expected, delta=0.01)
|
||||
|
||||
def test_v_fma_f16_inline_const_neg_1_0(self):
|
||||
"""V_FMA_F16: a*b + (-1.0) should use f16 inline constant."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
|
||||
f16_a = f32_to_f16(2.0)
|
||||
f16_b = f32_to_f16(3.0)
|
||||
instructions = [
|
||||
@@ -902,13 +899,12 @@ class TestF16Modifiers(unittest.TestCase):
|
||||
v_fma_f16(v[2], v[0], v[1], -1.0), # -1.0 is inline constant
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = _f16(st.vgpr[0][2] & 0xffff)
|
||||
result = f16(st.vgpr[0][2] & 0xffff)
|
||||
expected = 2.0 * 3.0 + (-1.0)
|
||||
self.assertAlmostEqual(result, expected, delta=0.01)
|
||||
|
||||
def test_v_add_f16_abs_both(self):
|
||||
"""V_ADD_F16 with abs on both operands."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
|
||||
f16_neg2 = f32_to_f16(-2.0)
|
||||
f16_neg3 = f32_to_f16(-3.0)
|
||||
instructions = [
|
||||
@@ -919,12 +915,11 @@ class TestF16Modifiers(unittest.TestCase):
|
||||
v_add_f16_e64(v[2], abs(v[0]), abs(v[1])), # |-2| + |-3| = 5
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = _f16(st.vgpr[0][2] & 0xffff)
|
||||
result = f16(st.vgpr[0][2] & 0xffff)
|
||||
self.assertAlmostEqual(result, 5.0, delta=0.01)
|
||||
|
||||
def test_v_mul_f16_neg_abs(self):
|
||||
"""V_MUL_F16 with neg on one operand and abs on another."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16, _f16
|
||||
f16_2 = f32_to_f16(2.0)
|
||||
f16_neg3 = f32_to_f16(-3.0)
|
||||
instructions = [
|
||||
@@ -935,7 +930,7 @@ class TestF16Modifiers(unittest.TestCase):
|
||||
v_mul_f16_e64(v[2], -v[0], abs(v[1])), # -(2) * |-3| = -6
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = _f16(st.vgpr[0][2] & 0xffff)
|
||||
result = f16(st.vgpr[0][2] & 0xffff)
|
||||
self.assertAlmostEqual(result, -6.0, delta=0.01)
|
||||
|
||||
def test_v_fmac_f16_hi_dest(self):
|
||||
@@ -943,7 +938,6 @@ class TestF16Modifiers(unittest.TestCase):
|
||||
|
||||
This tests the case from AMD_LLVM sin(0) where V_FMAC_F16 writes to v0.h.
|
||||
"""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x38003c00), # v0 = {hi=0.5, lo=1.0}
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
@@ -954,8 +948,8 @@ class TestF16Modifiers(unittest.TestCase):
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
v0 = st.vgpr[0][0]
|
||||
result_hi = _f16((v0 >> 16) & 0xffff)
|
||||
result_lo = _f16(v0 & 0xffff)
|
||||
result_hi = f16((v0 >> 16) & 0xffff)
|
||||
result_lo = f16(v0 & 0xffff)
|
||||
self.assertAlmostEqual(result_hi, 0.5, delta=0.01, msg=f"Expected hi=0.5, got {result_hi}")
|
||||
self.assertAlmostEqual(result_lo, 1.0, delta=0.01, msg=f"Expected lo=1.0, got {result_lo}")
|
||||
|
||||
@@ -2955,5 +2949,318 @@ class TestVOP3Clamp(unittest.TestCase):
|
||||
self.assertAlmostEqual(i2f(st.vgpr[3][1]), 1.0, places=5, msg="lane 3: 2.5 should clamp to 1.0")
|
||||
|
||||
|
||||
class TestCvtPkF16(unittest.TestCase):
  """Tests for V_CVT_PK_RTZ_F16_F32 - pack two f32 to f16 with round toward zero."""

  def test_cvt_pk_rtz_f16_f32_basic(self):
    """V_CVT_PK_RTZ_F16_F32: basic pack of two f32 values."""
    program = [
      v_mov_b32_e32(v[0], 1.0),
      v_mov_b32_e32(v[1], 2.0),
      v_cvt_pk_rtz_f16_f32_e64(v[2], v[0], v[1]),
    ]
    packed = run_program(program, n_lanes=1).vgpr[0][2]
    # first source is expected in the low half, second source in the high half
    low, high = f16(packed & 0xffff), f16((packed >> 16) & 0xffff)
    self.assertAlmostEqual(low, 1.0, delta=0.01)
    self.assertAlmostEqual(high, 2.0, delta=0.01)
|
||||
|
||||
|
||||
class TestCvtPkNorm(unittest.TestCase):
  """Tests for V_CVT_PK_NORM_I16_F32 and V_CVT_PK_NORM_U16_F32."""

  def _halves_of_v2(self, program):
    """Run `program` on one lane and split lane 0's v2 into (low 16 bits, high 16 bits)."""
    packed = run_program(program, n_lanes=1).vgpr[0][2]
    return packed & 0xffff, (packed >> 16) & 0xffff

  def test_cvt_pk_norm_i16_f32_basic(self):
    """V_CVT_PK_NORM_I16_F32: pack two f32 to normalized i16."""
    low, high = self._halves_of_v2([
      v_mov_b32_e32(v[0], 1.0),
      v_mov_b32_e32(v[1], -1.0),
      v_cvt_pk_norm_i16_f32(v[2], v[0], v[1]),
    ])
    self.assertEqual(low, 32767)
    # -32767 rather than -32768: the expected range is symmetric
    self.assertEqual(high, 0x8001)

  def test_cvt_pk_norm_u16_f32_basic(self):
    """V_CVT_PK_NORM_U16_F32: pack two f32 to normalized u16."""
    low, high = self._halves_of_v2([
      v_mov_b32_e32(v[0], 1.0),
      v_mov_b32_e32(v[1], 0.5),
      v_cvt_pk_norm_u16_f32(v[2], v[0], v[1]),
    ])
    self.assertEqual(low, 65535)
    self.assertAlmostEqual(high, 32768, delta=1)
|
||||
|
||||
|
||||
class TestCvtPkInt(unittest.TestCase):
  """Tests for V_CVT_PK_I16_I32, V_CVT_PK_U16_U32, V_CVT_PK_I16_F32, V_CVT_PK_U16_F32."""

  @staticmethod
  def _as_i16(bits):
    """Interpret a 16-bit pattern as a two's-complement signed value."""
    return bits if bits < 32768 else bits - 65536

  def _halves_of_v2(self, program):
    """Run `program` on one lane and split lane 0's v2 into (low 16 bits, high 16 bits)."""
    packed = run_program(program, n_lanes=1).vgpr[0][2]
    return packed & 0xffff, (packed >> 16) & 0xffff

  def test_cvt_pk_i16_i32_basic(self):
    """V_CVT_PK_I16_I32: pack two i32 to i16."""
    low, high = self._halves_of_v2([
      s_mov_b32(s[0], 100),
      s_mov_b32(s[1], -100 & 0xffffffff),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_cvt_pk_i16_i32(v[2], v[0], v[1]),
    ])
    self.assertEqual(self._as_i16(low), 100)
    self.assertEqual(self._as_i16(high), -100)

  def test_cvt_pk_u16_u32_basic(self):
    """V_CVT_PK_U16_U32: pack two u32 to u16."""
    low, high = self._halves_of_v2([
      s_mov_b32(s[0], 1000),
      s_mov_b32(s[1], 2000),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_cvt_pk_u16_u32(v[2], v[0], v[1]),
    ])
    self.assertEqual(low, 1000)
    self.assertEqual(high, 2000)

  def test_cvt_pk_i16_f32_basic(self):
    """V_CVT_PK_I16_F32: convert two f32 to packed i16."""
    low, high = self._halves_of_v2([
      v_mov_b32_e32(v[0], 100.5),
      v_mov_b32_e32(v[1], -50.7),
      v_cvt_pk_i16_f32(v[2], v[0], v[1]),
    ])
    # fractional parts are expected to be dropped: 100.5 -> 100, -50.7 -> -50
    self.assertEqual(self._as_i16(low), 100)
    self.assertEqual(self._as_i16(high), -50)

  def test_cvt_pk_u16_f32_basic(self):
    """V_CVT_PK_U16_F32: convert two f32 to packed u16."""
    low, high = self._halves_of_v2([
      v_mov_b32_e32(v[0], 100.9),
      v_mov_b32_e32(v[1], 200.1),
      v_cvt_pk_u16_f32(v[2], v[0], v[1]),
    ])
    self.assertEqual(low, 100)
    self.assertEqual(high, 200)

  def test_cvt_pk_u8_f32_basic(self):
    """V_CVT_PK_U8_F32: convert f32 to u8 and pack at byte position."""
    program = [
      v_mov_b32_e32(v[0], 128.5),
      v_mov_b32_e32(v[1], 0),
      v_mov_b32_e32(v[2], 0),
      v_cvt_pk_u8_f32(v[2], v[0], v[1], v[2]),
    ]
    packed = run_program(program, n_lanes=1).vgpr[0][2]
    self.assertEqual(packed & 0xff, 128)
|
||||
|
||||
|
||||
class TestDotProduct(unittest.TestCase):
  """Tests for dot product instructions V_DOT4_U32_U8, V_DOT8_U32_U4."""

  def _dot(self, dot_op, src0, src1, acc):
    """Load the packed operands into v0/v1 and `acc` into v2, run `dot_op`, and return lane 0's v2."""
    program = [
      s_mov_b32(s[0], src0),
      s_mov_b32(s[1], src1),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_mov_b32_e32(v[2], acc),
      dot_op(v[2], v[0], v[1], v[2]),
    ]
    return run_program(program, n_lanes=1).vgpr[0][2]

  def test_v_dot4_u32_u8_basic(self):
    """V_DOT4_U32_U8: 4-element dot product of u8 vectors."""
    # bytes {4, 3, 2, 1} dot {1, 1, 1, 1} = 10
    self.assertEqual(self._dot(v_dot4_u32_u8, 0x04030201, 0x01010101, 0), 10)

  def test_v_dot4_u32_u8_with_accumulator(self):
    """V_DOT4_U32_U8 with non-zero accumulator."""
    # bytes {2, 2, 2, 2} dot {3, 3, 3, 3} = 24, plus the accumulator 100
    self.assertEqual(self._dot(v_dot4_u32_u8, 0x02020202, 0x03030303, 100), 124)

  def test_v_dot8_u32_u4_basic(self):
    """V_DOT8_U32_U4: 8-element dot product of u4 vectors."""
    # src0 nibbles {1,2,3,4,5,6,7,8} packed as 0x87654321, src1 all ones packed as 0x11111111
    # expected: 1+2+3+4+5+6+7+8 = 36
    self.assertEqual(self._dot(v_dot8_u32_u4, 0x87654321, 0x11111111, 0), 36)
|
||||
|
||||
|
||||
class TestMinMaxF16Vop3(unittest.TestCase):
  """Tests for V_MIN3_F16, V_MAX3_F16, V_MED3_F16, V_MINMAX_F16, V_MAXMIN_F16.

  Each test copies three raw values into v[0], v[1], v[2] (through s[0], s[1], s[2]),
  executes one three-operand f16 instruction on a single lane, and compares the low
  16 bits of v[3], decoded as f16, against the expected value.
  """

  @staticmethod
  def _load_operands(a, b, c):
    """Return the instructions that copy raw values a, b, c into v[0], v[1], v[2]."""
    return [
      s_mov_b32(s[0], a),
      s_mov_b32(s[1], b),
      s_mov_b32(s[2], c),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_mov_b32_e32(v[2], s[2]),
    ]

  @staticmethod
  def _run_lo_f16(program):
    """Run program on one lane and return the low half of lane 0's v[3] decoded as f16."""
    st = run_program(program, n_lanes=1)
    return f16(st.vgpr[0][3] & 0xffff)

  def test_v_min3_f16_basic(self):
    """V_MIN3_F16: minimum of three f16 values."""
    program = self._load_operands(f32_to_f16(3.0), f32_to_f16(1.0), f32_to_f16(2.0)) + [
      v_min3_f16(v[3], v[0], v[1], v[2]),
    ]
    self.assertAlmostEqual(self._run_lo_f16(program), 1.0, delta=0.01)

  def test_v_max3_f16_basic(self):
    """V_MAX3_F16: maximum of three f16 values."""
    program = self._load_operands(f32_to_f16(1.0), f32_to_f16(3.0), f32_to_f16(2.0)) + [
      v_max3_f16(v[3], v[0], v[1], v[2]),
    ]
    self.assertAlmostEqual(self._run_lo_f16(program), 3.0, delta=0.01)

  def test_v_med3_f16_basic(self):
    """V_MED3_F16: median of three f16 values."""
    program = self._load_operands(f32_to_f16(3.0), f32_to_f16(1.0), f32_to_f16(2.0)) + [
      v_med3_f16(v[3], v[0], v[1], v[2]),
    ]
    self.assertAlmostEqual(self._run_lo_f16(program), 2.0, delta=0.01)

  def test_v_minmax_f16_basic(self):
    """V_MINMAX_F16: max(min(src0, src1), src2); here max(min(2.5, 1.0), 2.0) = 2.0."""
    program = self._load_operands(f32_to_f16(2.5), f32_to_f16(1.0), f32_to_f16(2.0)) + [
      v_minmax_f16(v[3], v[0], v[1], v[2]),
    ]
    self.assertAlmostEqual(self._run_lo_f16(program), 2.0, delta=0.01)

  def test_v_maxmin_f16_basic(self):
    """V_MAXMIN_F16: min(max(src0, src1), src2); here min(max(0.5, 2.0), 1.0) = 1.0."""
    program = self._load_operands(f32_to_f16(0.5), f32_to_f16(2.0), f32_to_f16(1.0)) + [
      v_maxmin_f16(v[3], v[0], v[1], v[2]),
    ]
    self.assertAlmostEqual(self._run_lo_f16(program), 1.0, delta=0.01)

  def test_v_min3_f16_with_neg(self):
    """V_MIN3_F16 with neg modifier: min(-3, 1, 2) = -3."""
    program = self._load_operands(f32_to_f16(3.0), f32_to_f16(1.0), f32_to_f16(2.0)) + [
      v_min3_f16(v[3], -v[0], v[1], v[2]),  # neg on first operand
    ]
    self.assertAlmostEqual(self._run_lo_f16(program), -3.0, delta=0.01)

  def test_v_max3_f16_with_abs(self):
    """V_MAX3_F16 with abs modifier: max(|-3|, 1, 2) = 3."""
    program = self._load_operands(f32_to_f16(-3.0), f32_to_f16(1.0), f32_to_f16(2.0)) + [
      v_max3_f16(v[3], abs(v[0]), v[1], v[2]),  # abs on first operand
    ]
    self.assertAlmostEqual(self._run_lo_f16(program), 3.0, delta=0.01)

  def test_v_med3_f16_opsel_hi(self):
    """V_MED3_F16 with opsel reading from hi half."""
    # Pack two f16 values into one register: hi=5.0, lo=1.0
    packed = (f32_to_f16(5.0) << 16) | f32_to_f16(1.0)
    program = self._load_operands(packed, f32_to_f16(3.0), f32_to_f16(4.0)) + [
      # Read hi half of v[0] (5.0), med3(5, 3, 4) = 4
      v_med3_f16(v[3], v[0].h, v[1], v[2]),
    ]
    self.assertAlmostEqual(self._run_lo_f16(program), 4.0, delta=0.01)
|
||||
|
||||
|
||||
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  unittest.main()
|
||||
|
||||
@@ -149,7 +149,6 @@ class TestFmaMix(unittest.TestCase):
|
||||
|
||||
def test_v_fma_mix_f32_src2_f16_lo(self):
|
||||
"""V_FMA_MIX_F32 with src2 as f16 from lo bits."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_2 = f32_to_f16(2.0)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(1.0)),
|
||||
@@ -166,7 +165,6 @@ class TestFmaMix(unittest.TestCase):
|
||||
|
||||
def test_v_fma_mix_f32_src2_f16_hi(self):
|
||||
"""V_FMA_MIX_F32 with src2 as f16 from hi bits."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_2 = f32_to_f16(2.0)
|
||||
val = (f16_2 << 16) | 0
|
||||
instructions = [
|
||||
@@ -199,7 +197,6 @@ class TestFmaMix(unittest.TestCase):
|
||||
|
||||
def test_v_fma_mix_f32_with_abs_f16_src2_lo(self):
|
||||
"""V_FMA_MIX_F32 with abs modifier on f16 src2 (lo half). Regression test for sin(1.0) bug."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_neg1 = f32_to_f16(-1.0) # 0xbc00
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(0.0)), # src0 = 0.0 (f32)
|
||||
@@ -217,7 +214,6 @@ class TestFmaMix(unittest.TestCase):
|
||||
|
||||
def test_v_fma_mix_f32_with_neg_f16_src2_lo(self):
|
||||
"""V_FMA_MIX_F32 with neg modifier on f16 src2 (lo half)."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_1 = f32_to_f16(1.0) # 0x3c00
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(0.0)), # src0 = 0.0 (f32)
|
||||
@@ -235,7 +231,6 @@ class TestFmaMix(unittest.TestCase):
|
||||
|
||||
def test_v_fma_mix_f32_with_abs_f16_src2_hi(self):
|
||||
"""V_FMA_MIX_F32 with abs modifier on f16 src2 (hi half)."""
|
||||
from extra.assembly.amd.test.hw.helpers import f32_to_f16
|
||||
f16_neg1 = f32_to_f16(-1.0) # 0xbc00
|
||||
val = (f16_neg1 << 16) | 0 # -1.0 in hi, 0 in lo
|
||||
instructions = [
|
||||
@@ -254,7 +249,6 @@ class TestFmaMix(unittest.TestCase):
|
||||
|
||||
def test_v_fma_mixlo_f16(self):
|
||||
"""V_FMA_MIXLO_F16 writes to low 16 bits of destination."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(2.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
@@ -267,14 +261,13 @@ class TestFmaMix(unittest.TestCase):
|
||||
VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=0),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
lo = _f16(st.vgpr[0][3] & 0xffff)
|
||||
lo = f16(st.vgpr[0][3] & 0xffff)
|
||||
hi = (st.vgpr[0][3] >> 16) & 0xffff
|
||||
self.assertAlmostEqual(lo, 7.0, places=1)
|
||||
self.assertEqual(hi, 0xdead, f"hi should be preserved, got 0x{hi:04x}")
|
||||
|
||||
def test_v_fma_mixlo_f16_all_f32_sources(self):
|
||||
"""V_FMA_MIXLO_F16 with all f32 sources."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(1.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
@@ -286,13 +279,12 @@ class TestFmaMix(unittest.TestCase):
|
||||
VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=0),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
lo = _f16(st.vgpr[0][3] & 0xffff)
|
||||
lo = f16(st.vgpr[0][3] & 0xffff)
|
||||
# 1*2+3 = 5
|
||||
self.assertAlmostEqual(lo, 5.0, places=1)
|
||||
|
||||
def test_v_fma_mixlo_f16_sin_case(self):
|
||||
"""V_FMA_MIXLO_F16 case from sin kernel."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x3f800000), # f32 1.0
|
||||
v_mov_b32_e32(v[3], s[0]),
|
||||
@@ -305,7 +297,7 @@ class TestFmaMix(unittest.TestCase):
|
||||
VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[3], src1=s[6], src2=v[5], opsel=0, opsel_hi=0, opsel_hi2=0),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
lo = _f16(st.vgpr[0][3] & 0xffff)
|
||||
lo = f16(st.vgpr[0][3] & 0xffff)
|
||||
self.assertAlmostEqual(lo, -3.14159, delta=0.01)
|
||||
|
||||
|
||||
@@ -314,7 +306,6 @@ class TestVOP3P(unittest.TestCase):
|
||||
|
||||
def test_v_pk_add_f16_basic(self):
|
||||
"""V_PK_ADD_F16 adds two packed f16 values."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x40003c00), # hi=2.0, lo=1.0
|
||||
s_mov_b32(s[1], 0x44004200), # hi=4.0, lo=3.0
|
||||
@@ -324,14 +315,13 @@ class TestVOP3P(unittest.TestCase):
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo = _f16(result & 0xffff)
|
||||
hi = _f16((result >> 16) & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
hi = f16((result >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo, 4.0, places=2)
|
||||
self.assertAlmostEqual(hi, 6.0, places=2)
|
||||
|
||||
def test_v_pk_mul_f16_basic(self):
|
||||
"""V_PK_MUL_F16 multiplies two packed f16 values."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x42004000), # hi=3.0, lo=2.0
|
||||
s_mov_b32(s[1], 0x45004400), # hi=5.0, lo=4.0
|
||||
@@ -341,14 +331,13 @@ class TestVOP3P(unittest.TestCase):
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo = _f16(result & 0xffff)
|
||||
hi = _f16((result >> 16) & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
hi = f16((result >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo, 8.0, places=1)
|
||||
self.assertAlmostEqual(hi, 15.0, places=1)
|
||||
|
||||
def test_v_pk_fma_f16_basic(self):
|
||||
"""V_PK_FMA_F16: D = A * B + C for packed f16."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x42004000), # A: hi=3.0, lo=2.0
|
||||
s_mov_b32(s[1], 0x45004400), # B: hi=5.0, lo=4.0
|
||||
@@ -360,8 +349,8 @@ class TestVOP3P(unittest.TestCase):
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][3]
|
||||
lo = _f16(result & 0xffff)
|
||||
hi = _f16((result >> 16) & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
hi = f16((result >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo, 9.0, places=1) # 2*4+1
|
||||
self.assertAlmostEqual(hi, 16.0, places=0) # 3*5+1
|
||||
|
||||
@@ -370,7 +359,6 @@ class TestVOP3P(unittest.TestCase):
|
||||
Inline constants for VOP3P are f16 values in the low 16 bits only.
|
||||
hi half of inline constant is 0, so hi result = v0.hi + 0 = 1.0.
|
||||
"""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x3c003c00), # packed f16: hi=1.0, lo=1.0
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
@@ -378,8 +366,8 @@ class TestVOP3P(unittest.TestCase):
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][1]
|
||||
lo = _f16(result & 0xffff)
|
||||
hi = _f16((result >> 16) & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
hi = f16((result >> 16) & 0xffff)
|
||||
# lo = 1.0 + 1.0 = 2.0, hi = 1.0 + 0.0 = 1.0 (inline const hi half is 0)
|
||||
self.assertAlmostEqual(lo, 2.0, places=2)
|
||||
self.assertAlmostEqual(hi, 1.0, places=2)
|
||||
@@ -388,7 +376,6 @@ class TestVOP3P(unittest.TestCase):
|
||||
"""V_PK_MUL_F16 with inline constant POS_TWO (2.0).
|
||||
Inline constant has value only in low 16 bits, hi is 0.
|
||||
"""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
# v0 = packed (3.0, 4.0), multiply by POS_TWO
|
||||
# lo = 3.0 * 2.0 = 6.0, hi = 4.0 * 0.0 = 0.0 (inline const hi is 0)
|
||||
instructions = [
|
||||
@@ -398,8 +385,8 @@ class TestVOP3P(unittest.TestCase):
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][1]
|
||||
lo = _f16(result & 0xffff)
|
||||
hi = _f16((result >> 16) & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
hi = f16((result >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo, 6.0, places=1)
|
||||
self.assertAlmostEqual(hi, 0.0, places=1)
|
||||
|
||||
@@ -413,7 +400,6 @@ class TestWMMAF16(unittest.TestCase):
|
||||
|
||||
def test_v_wmma_f16_16x16x16_f16_all_ones(self):
|
||||
"""V_WMMA_F16_16X16X16_F16 with all ones produces 16.0 in f16."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = []
|
||||
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
|
||||
# Initialize A matrix in v[16:23] (8 regs)
|
||||
@@ -432,13 +418,12 @@ class TestWMMAF16(unittest.TestCase):
|
||||
for lane in range(32):
|
||||
for reg in range(8):
|
||||
result = st.vgpr[lane][reg]
|
||||
lo = _f16(result & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
self.assertAlmostEqual(lo, 16.0, places=1, msg=f"v[{reg}] lane {lane}: expected 16.0, got {lo}")
|
||||
self.assertEqual(result >> 16, 0, msg=f"v[{reg}] lane {lane}: hi bits should be 0")
|
||||
|
||||
def test_v_wmma_f16_16x16x16_f16_with_accumulator(self):
|
||||
"""V_WMMA_F16_16X16X16_F16 with non-zero accumulator."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = []
|
||||
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
|
||||
instructions.append(s_mov_b32(s[1], 0x4500)) # f16 5.0 in lo bits only
|
||||
@@ -458,7 +443,7 @@ class TestWMMAF16(unittest.TestCase):
|
||||
for lane in range(32):
|
||||
for reg in range(8):
|
||||
result = st.vgpr[lane][reg]
|
||||
lo = _f16(result & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
self.assertAlmostEqual(lo, 21.0, places=0, msg=f"v[{reg}] lane {lane}: expected 21.0, got {lo}")
|
||||
self.assertEqual(result >> 16, 0, msg=f"v[{reg}] lane {lane}: hi bits should be 0")
|
||||
|
||||
@@ -468,7 +453,6 @@ class TestWMMAF16(unittest.TestCase):
|
||||
Regression test: WMMA was using static register indices instead of dynamic.
|
||||
This test uses v[64:71] for A, v[80:87] for B, v[96:103] for C/D.
|
||||
"""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = []
|
||||
instructions.append(s_mov_b32(s[0], 0x3c003c00)) # packed f16 1.0
|
||||
# Initialize A matrix in v[64:71] (8 regs)
|
||||
@@ -490,7 +474,7 @@ class TestWMMAF16(unittest.TestCase):
|
||||
for lane in range(32):
|
||||
for reg in range(8):
|
||||
result = st.vgpr[lane][reg]
|
||||
lo = _f16(result & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
self.assertAlmostEqual(lo, 16.0, places=1, msg=f"v[{reg}] lane {lane}: expected 16.0, got {lo}")
|
||||
self.assertEqual(result >> 16, 0, msg=f"v[{reg}] lane {lane}: hi bits should be 0")
|
||||
|
||||
@@ -713,7 +697,6 @@ class TestPackedMixedSigns(unittest.TestCase):
|
||||
|
||||
def test_pk_add_f16_mixed_signs(self):
|
||||
"""V_PK_ADD_F16 with mixed positive/negative values."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0xc0003c00), # packed: hi=-2.0, lo=1.0
|
||||
s_mov_b32(s[1], 0x3c003c00), # packed: hi=1.0, lo=1.0
|
||||
@@ -723,14 +706,13 @@ class TestPackedMixedSigns(unittest.TestCase):
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
lo = _f16(result & 0xffff)
|
||||
hi = _f16((result >> 16) & 0xffff)
|
||||
lo = f16(result & 0xffff)
|
||||
hi = f16((result >> 16) & 0xffff)
|
||||
self.assertAlmostEqual(lo, 2.0, places=2) # 1.0 + 1.0
|
||||
self.assertAlmostEqual(hi, -1.0, places=2) # -2.0 + 1.0
|
||||
|
||||
def test_pk_mul_f16_zero(self):
|
||||
"""V_PK_MUL_F16 with zero."""
|
||||
from extra.assembly.amd.test.hw.helpers import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x40004000), # packed: 2.0, 2.0
|
||||
s_mov_b32(s[1], 0x00000000), # packed: 0.0, 0.0
|
||||
@@ -743,5 +725,276 @@ class TestPackedMixedSigns(unittest.TestCase):
|
||||
self.assertEqual(result, 0x00000000, "2.0 * 0.0 should be 0.0")
|
||||
|
||||
|
||||
class TestDot2F32F16(unittest.TestCase):
  """Tests for V_DOT2_F32_F16 - dot product of f16 pairs producing f32."""

  @staticmethod
  def _pack_f16(hi, lo):
    """Pack two floats as f16 bit patterns into one 32-bit word (hi in bits 31:16)."""
    return (f32_to_f16(hi) << 16) | f32_to_f16(lo)

  @staticmethod
  def _run_dot2(setup):
    """Run setup followed by the dot2 instruction on one lane; return v[3] as f32."""
    program = setup + [v_dot2_f32_f16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1)]
    st = run_program(program, n_lanes=1)
    return i2f(st.vgpr[0][3])

  def test_v_dot2_f32_f16_basic(self):
    """V_DOT2_F32_F16: dot product of two packed f16 pairs -> f32."""
    # a = {hi=2.0, lo=1.0}, b = {hi=4.0, lo=3.0}
    a = self._pack_f16(2.0, 1.0)
    b = self._pack_f16(4.0, 3.0)
    got = self._run_dot2([
      s_mov_b32(s[0], a),
      s_mov_b32(s[1], b),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_mov_b32_e32(v[2], 0),
    ])
    # 1.0*3.0 + 2.0*4.0 + 0 = 3 + 8 = 11.0
    self.assertAlmostEqual(got, 11.0, places=2)

  def test_v_dot2_f32_f16_with_accumulator(self):
    """V_DOT2_F32_F16 with non-zero f32 accumulator."""
    # both sources are {hi=1.0, lo=1.0}; the accumulator is f32 5.0
    ones = self._pack_f16(1.0, 1.0)
    got = self._run_dot2([
      s_mov_b32(s[0], ones),
      s_mov_b32(s[1], f2i(5.0)),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[0]),  # second source reuses the first
      v_mov_b32_e32(v[2], s[1]),
    ])
    # 1.0*1.0 + 1.0*1.0 + 5.0 = 7.0
    self.assertAlmostEqual(got, 7.0, places=2)

  def test_v_dot2_f32_f16_negative_values(self):
    """V_DOT2_F32_F16 with negative f16 values."""
    # a = {hi=-2.0, lo=3.0}, b = {hi=1.0, lo=2.0}
    a = self._pack_f16(-2.0, 3.0)
    b = self._pack_f16(1.0, 2.0)
    got = self._run_dot2([
      s_mov_b32(s[0], a),
      s_mov_b32(s[1], b),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_mov_b32_e32(v[2], 0),
    ])
    # 3.0*2.0 + (-2.0)*1.0 + 0 = 6 - 2 = 4.0
    self.assertAlmostEqual(got, 4.0, places=2)
|
||||
|
||||
|
||||
class TestDot2F16F16(unittest.TestCase):
  """Tests for V_DOT2_F16_F16 - dot product of f16 pairs producing f16."""

  @staticmethod
  def _pack_f16(hi, lo):
    """Pack two floats as f16 bit patterns into one 32-bit word (hi in bits 31:16)."""
    return (f32_to_f16(hi) << 16) | f32_to_f16(lo)

  @staticmethod
  def _run_dot2(setup):
    """Run setup followed by the dot2 instruction on one lane; return low half of v[3] as f16."""
    st = run_program(setup + [v_dot2_f16_f16(v[3], v[0], v[1], v[2])], n_lanes=1)
    return f16(st.vgpr[0][3] & 0xffff)

  def test_v_dot2_f16_f16_basic(self):
    """V_DOT2_F16_F16: dot product of two packed f16 pairs -> f16."""
    # a = {hi=2.0, lo=1.0}, b = {hi=3.0, lo=2.0}
    a = self._pack_f16(2.0, 1.0)
    b = self._pack_f16(3.0, 2.0)
    got = self._run_dot2([
      s_mov_b32(s[0], a),
      s_mov_b32(s[1], b),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_mov_b32_e32(v[2], 0),
    ])
    # 1.0*2.0 + 2.0*3.0 + 0 = 2 + 6 = 8.0 (f16)
    self.assertAlmostEqual(got, 8.0, places=1)

  def test_v_dot2_f16_f16_with_accumulator(self):
    """V_DOT2_F16_F16 with non-zero f16 accumulator."""
    # both sources are {hi=1.0, lo=1.0}; the accumulator is f16 3.0 in the low half
    ones = self._pack_f16(1.0, 1.0)
    acc_bits = f32_to_f16(3.0)
    got = self._run_dot2([
      s_mov_b32(s[0], ones),
      s_mov_b32(s[2], acc_bits),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[0]),  # second source reuses the first
      v_mov_b32_e32(v[2], s[2]),
    ])
    # 1.0*1.0 + 1.0*1.0 + 3.0 = 5.0 (f16)
    self.assertAlmostEqual(got, 5.0, places=1)
|
||||
|
||||
|
||||
class TestSignedDotProducts(unittest.TestCase):
  """Tests for V_DOT4_I32_IU8 and V_DOT8_I32_IU4 with signed inputs.

  The tests pass neg bit 0 to mark src0 as signed and neg bit 1 to mark src1 as signed.
  """

  @staticmethod
  def _as_i32(raw):
    """Reinterpret an unsigned 32-bit register value as a signed integer."""
    return raw - 0x100000000 if raw >= 0x80000000 else raw

  @staticmethod
  def _run_dot(dot_op, src0, src1, neg):
    """Run dot_op with v[0]=src0, v[1]=src1, v[2]=0 on one lane and return the raw v[3]."""
    st = run_program([
      s_mov_b32(s[0], src0),
      s_mov_b32(s[1], src1),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_mov_b32_e32(v[2], 0),
      dot_op(v[3], v[0], v[1], v[2], neg=neg),
    ], n_lanes=1)
    return st.vgpr[0][3]

  def test_v_dot4_i32_iu8_signed_both(self):
    """V_DOT4_I32_IU8 with both inputs signed (neg=0b011)."""
    # lhs bytes 0xff, 0xfe, 0x03, 0x04 are -1, -2, 3, 4 as i8; rhs is four bytes of 1
    lhs = (0xff << 24) | (0xfe << 16) | (0x03 << 8) | 0x04
    rhs = 0x01010101
    raw = self._run_dot(v_dot4_i32_iu8, lhs, rhs, 0b011)  # both signed
    # (-1)*1 + (-2)*1 + 3*1 + 4*1 = 4; the register holds an i32, so decode it as signed
    self.assertEqual(self._as_i32(raw), 4)

  def test_v_dot4_i32_iu8_src0_signed(self):
    """V_DOT4_I32_IU8 with only src0 signed (neg=0b001)."""
    lhs = 0xffffffff  # -1, -1, -1, -1 (as i8)
    rhs = 0x02020202  # 2, 2, 2, 2 (as u8)
    raw = self._run_dot(v_dot4_i32_iu8, lhs, rhs, 0b001)  # src0 signed
    # four products of (-1)*2 sum to -8
    self.assertEqual(self._as_i32(raw), -8)

  def test_v_dot4_i32_iu8_src1_signed(self):
    """V_DOT4_I32_IU8 with only src1 signed (neg=0b010)."""
    lhs = 0x02020202  # 2, 2, 2, 2 (as u8)
    rhs = 0xffffffff  # -1, -1, -1, -1 (as i8)
    raw = self._run_dot(v_dot4_i32_iu8, lhs, rhs, 0b010)  # src1 signed
    # four products of 2*(-1) sum to -8
    self.assertEqual(self._as_i32(raw), -8)

  def test_v_dot4_i32_iu8_unsigned_as_reference(self):
    """V_DOT4_I32_IU8 with both unsigned (neg=0) - same as V_DOT4_U32_U8."""
    lhs = 0xffffffff  # 255 in every byte when read as u8
    rhs = 0x01010101  # 1, 1, 1, 1
    raw = self._run_dot(v_dot4_i32_iu8, lhs, rhs, 0)  # both unsigned
    # 255*1 four times = 1020; compared without sign reinterpretation
    self.assertEqual(raw, 1020)

  def test_v_dot8_i32_iu4_signed_both(self):
    """V_DOT8_I32_IU4 with both inputs signed (neg=0b011)."""
    # as i4: 0xf = -1, 0xe = -2, 0x3 = 3, 0x4 = 4, so lhs holds {-1, -2, 3, 4} twice
    lhs = 0xfe34fe34
    rhs = 0x11111111  # eight nibbles of 1
    raw = self._run_dot(v_dot8_i32_iu4, lhs, rhs, 0b011)  # both signed
    # 2 * ((-1)*1 + (-2)*1 + 3*1 + 4*1) = 2 * 4 = 8
    self.assertEqual(self._as_i32(raw), 8)

  def test_v_dot8_i32_iu4_all_negative(self):
    """V_DOT8_I32_IU4 with all negative signed values."""
    lhs = 0xffffffff  # eight nibbles of -1 as i4
    rhs = 0x11111111  # eight nibbles of 1
    raw = self._run_dot(v_dot8_i32_iu4, lhs, rhs, 0b011)  # both signed
    # 8 * ((-1)*1) = -8
    self.assertEqual(self._as_i32(raw), -8)
|
||||
|
||||
|
||||
class TestPkMinMaxF16(unittest.TestCase):
  """Tests for V_PK_MIN_F16 and V_PK_MAX_F16."""

  @staticmethod
  def _run_packed(pk_op):
    """Run pk_op on {hi=3.0, lo=1.0} and {hi=2.0, lo=4.0}; return the result halves as (lo, hi) f16."""
    a = (f32_to_f16(3.0) << 16) | f32_to_f16(1.0)
    b = (f32_to_f16(2.0) << 16) | f32_to_f16(4.0)
    st = run_program([
      s_mov_b32(s[0], a),
      s_mov_b32(s[1], b),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      pk_op(v[2], v[0], v[1]),
    ], n_lanes=1)
    word = st.vgpr[0][2]
    return f16(word & 0xffff), f16((word >> 16) & 0xffff)

  def test_v_pk_min_f16_basic(self):
    """V_PK_MIN_F16: packed min of two f16 pairs."""
    lo, hi = self._run_packed(v_pk_min_f16)
    # lo = min(1, 4) = 1, hi = min(3, 2) = 2
    self.assertAlmostEqual(lo, 1.0, delta=0.01)
    self.assertAlmostEqual(hi, 2.0, delta=0.01)

  def test_v_pk_max_f16_basic(self):
    """V_PK_MAX_F16: packed max of two f16 pairs."""
    lo, hi = self._run_packed(v_pk_max_f16)
    # lo = max(1, 4) = 4, hi = max(3, 2) = 3
    self.assertAlmostEqual(lo, 4.0, delta=0.01)
    self.assertAlmostEqual(hi, 3.0, delta=0.01)
|
||||
|
||||
|
||||
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  unittest.main()
|
||||
|
||||
@@ -294,7 +294,7 @@ class TestAllPcode(unittest.TestCase):
|
||||
'ADDR': u32(), 'ADDR_BASE': u32(), 'TADDR': u32(), 'DATA': u32(), 'DATA0': u32(), 'DATA1': u32(), 'DATA2': u32(),
|
||||
'VDATA': u32(), 'VDATA0': u32(), 'VDATA1': u32(), 'VDATA2': u32(), 'VDATA3': u32(),
|
||||
'OPSEL': u32(), 'OPSEL_HI': u32(), 'NEG': u32(), 'NEG_HI': u32(), 'CLAMP': u32(),
|
||||
'M0': u32(), 'PC': u64(), 'DENORM': u32(1), 'ROUND_MODE': u32(), 'WAVE_STATUS': u32(),
|
||||
'M0': u32(), 'PC': u64(), 'DENORM': u32(1), 'ROUND_MODE': u32(), 'ROUND_TOWARD_ZERO': u32(), 'ROUND_NEAREST_EVEN': u32(), 'WAVE_STATUS': u32(),
|
||||
'MAX_FLOAT_F32': u32(0x7f7fffff), 'Unsigned': u32(1), 'clampedLOD': u32(),
|
||||
'_lds': lds, '_vmem': lds, '_active': UOp.const(dtypes.bool, True)}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user