mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
* renderer/amd: move in tree * fix paths in tests * 24000 lines * no delete for amd files
171 lines
5.8 KiB
Python
171 lines
5.8 KiB
Python
import unittest
|
|
from tinygrad.renderer.amd.dsl import *
|
|
from tinygrad.renderer.amd.dsl import VDSTYField
|
|
from tinygrad.runtime.autogen.amd.rdna3.enum import VOP1Op, VOP2Op
|
|
from tinygrad.runtime.autogen.amd.rdna3.ins import VOP1
|
|
|
|
class TestRegisters(unittest.TestCase):
|
|
def test_vgpr_single(self):
|
|
self.assertEqual(repr(v[5]), "v[5]")
|
|
self.assertEqual(v[5].offset, 261) # 256 + 5
|
|
self.assertEqual(v[5].sz, 1)
|
|
|
|
def test_sgpr_single(self):
|
|
self.assertEqual(repr(s[10]), "s[10]")
|
|
self.assertEqual(s[10].offset, 10)
|
|
|
|
def test_vgpr_range(self):
|
|
self.assertEqual(repr(v[0:3]), "v[0:3]")
|
|
self.assertEqual(v[0:3].offset, 256)
|
|
self.assertEqual(v[0:3].sz, 4)
|
|
|
|
def test_sgpr_range(self):
|
|
self.assertEqual(repr(s[4:5]), "s[4:5]")
|
|
self.assertEqual(s[4:5].sz, 2)
|
|
|
|
def test_ttmp_reslice(self):
|
|
# ttmp is src[108:123], so ttmp[0] should be src[108]
|
|
self.assertEqual(ttmp[0].offset, 108)
|
|
self.assertEqual(ttmp[1].offset, 109)
|
|
# ttmp[0:1] is 2 elements (inclusive slicing)
|
|
self.assertEqual(ttmp[0:1].offset, 108)
|
|
self.assertEqual(ttmp[0:1].sz, 2)
|
|
# ttmp[0:1][0] should be src[108]
|
|
self.assertEqual(ttmp[0:1][0].offset, 108)
|
|
|
|
def test_special_regs(self):
|
|
self.assertEqual(NULL.offset, 124)
|
|
self.assertEqual(M0.offset, 125)
|
|
self.assertEqual(EXEC_LO.offset, 126)
|
|
self.assertEqual(EXEC_HI.offset, 127)
|
|
# Check repr round-trips
|
|
self.assertEqual(repr(NULL), "NULL")
|
|
self.assertEqual(repr(M0), "M0")
|
|
self.assertEqual(repr(EXEC_LO), "EXEC_LO")
|
|
self.assertEqual(repr(EXEC), "EXEC")
|
|
|
|
def test_vcc(self):
|
|
self.assertEqual(VCC.offset, 106)
|
|
self.assertEqual(VCC.sz, 2)
|
|
self.assertEqual(VCC_LO.offset, 106)
|
|
self.assertEqual(VCC_HI.offset, 107)
|
|
# Check repr round-trips
|
|
self.assertEqual(repr(VCC_LO), "VCC_LO")
|
|
self.assertEqual(repr(VCC_HI), "VCC_HI")
|
|
self.assertEqual(repr(VCC), "VCC")
|
|
|
|
def test_float_constants(self):
|
|
self.assertEqual(src[240].offset, 240)
|
|
self.assertEqual(repr(src[240]), "0.5")
|
|
self.assertEqual(repr(src[242]), "1.0")
|
|
self.assertEqual(repr(src[243]), "-1.0")
|
|
|
|
def test_int_constants(self):
|
|
self.assertEqual(repr(src[128]), "0")
|
|
self.assertEqual(repr(src[129]), "1")
|
|
self.assertEqual(repr(src[192]), "64")
|
|
self.assertEqual(repr(src[193]), "-1")
|
|
self.assertEqual(repr(src[208]), "-16")
|
|
|
|
class TestEnumBitField(unittest.TestCase):
|
|
def test_enum_name(self):
|
|
self.assertEqual(VOP1Op.V_MOV_B32_E32.name, "V_MOV_B32_E32")
|
|
|
|
def test_enum_value(self):
|
|
self.assertEqual(VOP1Op.V_MOV_B32_E32.value, 1)
|
|
|
|
def test_enum_comparison(self):
|
|
self.assertEqual(VOP1Op.V_MOV_B32_E32, VOP1Op.V_MOV_B32_E32)
|
|
self.assertNotEqual(VOP1Op.V_NOP_E32, VOP1Op.V_MOV_B32_E32)
|
|
|
|
def test_enum_different_types(self):
|
|
# VOP1Op and VOP2Op are different enums, even if same value
|
|
self.assertNotEqual(VOP1Op.V_MOV_B32_E32, VOP2Op.V_CNDMASK_B32_E32)
|
|
|
|
def test_wrong_enum_type_raises(self):
|
|
# Passing VOP2Op to VOP1 should raise
|
|
with self.assertRaises(RuntimeError):
|
|
VOP1(VOP2Op.V_CNDMASK_B32_E32, v[5], v[6])
|
|
|
|
class TestVOP1(unittest.TestCase):
|
|
def test_class_setup(self):
|
|
self.assertEqual(VOP1._size(), 4)
|
|
field_names = [n for n, _ in VOP1._fields]
|
|
self.assertIn('encoding', field_names)
|
|
self.assertIn('op', field_names)
|
|
self.assertIn('vdst', field_names)
|
|
self.assertIn('src0', field_names)
|
|
|
|
def test_encoding_vgpr_vgpr(self):
|
|
i = VOP1(VOP1Op.V_MOV_B32_E32, v[5], v[6])
|
|
raw = i._raw
|
|
# Check each field
|
|
self.assertEqual((raw >> 25) & 0x7f, 0b0111111) # encoding
|
|
self.assertEqual((raw >> 17) & 0xff, 5) # vdst (just VGPR index)
|
|
self.assertEqual((raw >> 9) & 0xff, 1) # op
|
|
self.assertEqual(raw & 0x1ff, 262) # src0 (256 + 6)
|
|
|
|
def test_encoding_vgpr_sgpr(self):
|
|
i = VOP1(VOP1Op.V_MOV_B32_E32, v[5], s[10])
|
|
raw = i._raw
|
|
self.assertEqual((raw >> 17) & 0xff, 5) # vdst (just VGPR index)
|
|
self.assertEqual(raw & 0x1ff, 10) # src0 (SGPR encoded)
|
|
|
|
def test_to_bytes(self):
|
|
i = VOP1(VOP1Op.V_MOV_B32_E32, v[5], v[6])
|
|
b = i.to_bytes()
|
|
self.assertEqual(len(b), 4)
|
|
self.assertEqual(int.from_bytes(b, 'little'), i._raw)
|
|
|
|
def test_from_bytes(self):
|
|
i1 = VOP1(VOP1Op.V_MOV_B32_E32, v[5], v[6])
|
|
i2 = VOP1.from_bytes(i1.to_bytes())
|
|
self.assertEqual(i1._raw, i2._raw)
|
|
|
|
def test_repr(self):
|
|
i = VOP1(VOP1Op.V_MOV_B32_E32, v[5], v[6])
|
|
self.assertEqual(repr(i), "v_mov_b32_e32(v[5], v[6])")
|
|
|
|
def test_repr_sgpr_src(self):
|
|
i = VOP1(VOP1Op.V_MOV_B32_E32, v[5], s[10])
|
|
self.assertEqual(repr(i), "v_mov_b32_e32(v[5], s[10])")
|
|
|
|
def test_kwargs(self):
|
|
i1 = VOP1(VOP1Op.V_MOV_B32_E32, v[5], v[6])
|
|
i2 = VOP1(op=VOP1Op.V_MOV_B32_E32, vdst=v[5], src0=v[6])
|
|
self.assertEqual(i1._raw, i2._raw)
|
|
|
|
def test_kwargs_partial(self):
|
|
i1 = VOP1(VOP1Op.V_MOV_B32_E32, v[5], v[6])
|
|
i2 = VOP1(VOP1Op.V_MOV_B32_E32, src0=v[6], vdst=v[5])
|
|
self.assertEqual(i1._raw, i2._raw)
|
|
|
|
class TestVDSTYField(unittest.TestCase):
|
|
def test_encode_even_vgpr(self):
|
|
f = VDSTYField(6, 0) # 7-bit field
|
|
self.assertEqual(f.encode(v[0]), 0)
|
|
self.assertEqual(f.encode(v[2]), 1)
|
|
self.assertEqual(f.encode(v[4]), 2)
|
|
self.assertEqual(f.encode(v[254]), 127)
|
|
|
|
def test_encode_non_vgpr_raises(self):
|
|
f = VDSTYField(6, 0)
|
|
with self.assertRaises(ValueError) as ctx:
|
|
f.encode(s[0])
|
|
self.assertIn("VGPR", str(ctx.exception))
|
|
|
|
def test_encode_non_reg_raises(self):
|
|
f = VDSTYField(6, 0)
|
|
with self.assertRaises(TypeError) as ctx:
|
|
f.encode(42)
|
|
self.assertIn("Reg", str(ctx.exception))
|
|
|
|
def test_decode_returns_raw(self):
|
|
f = VDSTYField(6, 0)
|
|
# decode returns raw value, actual vdsty computed with vdstx context
|
|
self.assertEqual(f.decode(0), 0)
|
|
self.assertEqual(f.decode(127), 127)
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|