mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
hwtest fixes for rdna3 dsl
This commit is contained in:
@@ -259,8 +259,10 @@ def disasm(inst: Inst) -> str:
|
||||
else:
|
||||
is_16bit_op = any(x in op_name for x in _16BIT_TYPES) and not any(x in op_name for x in ('dot2', 'pk_', 'sad', 'msad', 'qsad', 'mqsad'))
|
||||
is_f16_dst = is_f16_src = is_f16_src2 = is_16bit_op
|
||||
# Check if any opsel bit is set (any operand uses .h) - if so, we need explicit .l for low-half
|
||||
any_hi = opsel != 0
|
||||
def fmt_vop3_src(v, neg_bit, abs_bit, hi_bit=False, reg_cnt=1, is_16=False):
|
||||
s = _fmt_src_n(v, reg_cnt) if reg_cnt > 1 else f"v{v - 256}.h" if is_16 and v >= 256 and hi_bit else f"v{v - 256}.l" if is_16 and v >= 256 else fmt_src(v)
|
||||
s = _fmt_src_n(v, reg_cnt) if reg_cnt > 1 else f"v{v - 256}.h" if is_16 and v >= 256 and hi_bit else f"v{v - 256}.l" if is_16 and v >= 256 and any_hi else fmt_src(v)
|
||||
if abs_bit: s = f"|{s}|"
|
||||
return f"-{s}" if neg_bit else s
|
||||
# Determine register count for each source (check for cvt-specific 64-bit flags first)
|
||||
@@ -280,7 +282,7 @@ def disasm(inst: Inst) -> str:
|
||||
elif dst_cnt > 1:
|
||||
dst_str = _vreg(vdst, dst_cnt)
|
||||
elif is_f16_dst:
|
||||
dst_str = f"v{vdst}.h" if (opsel & 8) else f"v{vdst}.l"
|
||||
dst_str = f"v{vdst}.h" if (opsel & 8) else f"v{vdst}.l" if any_hi else f"v{vdst}"
|
||||
else:
|
||||
dst_str = f"v{vdst}"
|
||||
clamp_str = " clamp" if clmp else ""
|
||||
|
||||
@@ -40,8 +40,9 @@ bits = _Bits()
|
||||
|
||||
# Register types
|
||||
class Reg:
|
||||
def __init__(self, idx: int, count: int = 1, hi: bool = False): self.idx, self.count, self.hi = idx, count, hi
|
||||
def __init__(self, idx: int, count: int = 1, hi: bool = False, neg: bool = False): self.idx, self.count, self.hi, self.neg = idx, count, hi, neg
|
||||
def __repr__(self): return f"{self.__class__.__name__.lower()[0]}[{self.idx}]" if self.count == 1 else f"{self.__class__.__name__.lower()[0]}[{self.idx}:{self.idx + self.count}]"
|
||||
def __neg__(self): return self.__class__(self.idx, self.count, self.hi, neg=not self.neg)
|
||||
|
||||
T = TypeVar('T', bound=Reg)
|
||||
class _RegFactory(Generic[T]):
|
||||
@@ -162,6 +163,11 @@ class Inst:
|
||||
if name in SRC_FIELDS:
|
||||
encoded = encode_src(val)
|
||||
self._values[name] = RawImm(encoded)
|
||||
# Handle negation modifier for VOP3 instructions
|
||||
if isinstance(val, Reg) and val.neg and 'neg' in self._fields:
|
||||
neg_bit = {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0)
|
||||
cur_neg = self._values.get('neg', 0)
|
||||
self._values['neg'] = (cur_neg.val if isinstance(cur_neg, RawImm) else cur_neg) | neg_bit
|
||||
# Track literal value if needed (encoded as 255)
|
||||
if encoded == 255 and self._literal is None and isinstance(val, int) and not isinstance(val, IntEnum):
|
||||
self._literal = val
|
||||
|
||||
@@ -48,14 +48,12 @@ class TestHW(unittest.TestCase):
|
||||
])
|
||||
self.assertEqual(out, [2])
|
||||
|
||||
# assembler err
|
||||
@unittest.expectedFailure
|
||||
def test_simple_s_mov(self):
|
||||
out = get_output([
|
||||
s_mov_b32(s[7], 0x7fffffff),
|
||||
v_mov_b32_e32(v[1], s[7]),
|
||||
])
|
||||
self.assertEqual(out, [2])
|
||||
self.assertEqual(out, [0x7fffffff])
|
||||
|
||||
def test_exec_mov(self):
|
||||
out = get_output([
|
||||
@@ -102,8 +100,6 @@ class TestHW(unittest.TestCase):
|
||||
self.assertEqual(run_fmac(a, -b), f16_to_bits(-10.0))
|
||||
self.assertEqual(run_fmac(-a, -b), f16_to_bits(14.0))
|
||||
|
||||
# assembler err
|
||||
@unittest.expectedFailure
|
||||
def test_s_abs_i32(self):
|
||||
def check(x, y, dst=s[10], scc=0):
|
||||
for reg,val in [(dst, y), (SCC, scc)]:
|
||||
@@ -121,8 +117,6 @@ class TestHW(unittest.TestCase):
|
||||
check(0xffffffff, 0x00000001, scc=1)
|
||||
check(0, 0, scc=0)
|
||||
|
||||
# how do I negate a VGPR operand?
|
||||
@unittest.expectedFailure
|
||||
def test_v_rcp_f32_neg_vop3(self):
|
||||
def v_neg_rcp_f32(x:float, y:float):
|
||||
out = get_output([
|
||||
@@ -138,14 +132,12 @@ class TestHW(unittest.TestCase):
|
||||
v_neg_rcp_f32(-2.0, 0.5)
|
||||
v_neg_rcp_f32(2.0, -0.5)
|
||||
|
||||
# how do I negate a VGPR operand?
|
||||
@unittest.expectedFailure
|
||||
def test_v_cndmask_b32_neg(self):
|
||||
def v_neg(x:float, y:float):
|
||||
out = get_output([
|
||||
v_mov_b32_e32(v[1], f32_to_bits(x)),
|
||||
s_mov_b32(s[10], 1),
|
||||
v_cndmask_b32_e32(v[1], v[1], -v[1], s[10]),
|
||||
v_cndmask_b32_e64(v[1], v[1], -v[1], s[10]),
|
||||
])[0]
|
||||
assert out == f32_to_bits(y), f"{f32_from_bits(out)} != {y} / {out} != {f32_to_bits(y)}"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user