mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
* renderer/amd: move in tree * fix paths in tests * 24000 lines * no delete for amd files
230 lines
10 KiB
Python
230 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
"""Test DS and other compute-relevant instruction formats.
|
|
|
|
Note: Graphics-only formats (EXP, MUBUF, MTBUF, MIMG) are not supported - use GLOBAL/FLAT for memory access in compute.
|
|
"""
|
|
import unittest
|
|
from tinygrad.runtime.autogen.amd.rdna3.ins import *
|
|
from tinygrad.renderer.amd.dsl import VCC_HI, EXEC_LO, NULL
|
|
OFF = NULL # OFF is alias for NULL
|
|
from tinygrad.renderer.amd import detect_format
|
|
|
|
|
|
class TestDS(unittest.TestCase):
|
|
"""Test DS (data share / LDS) instructions."""
|
|
|
|
def test_ds_store_b32(self):
|
|
# ds_store_b32 v0, v1
|
|
# GFX11: encoding: [0x00,0x00,0x34,0xd8,0x00,0x01,0x00,0x00]
|
|
inst = ds_store_b32(addr=v[0], data0=v[1])
|
|
self.assertEqual(inst.to_bytes(), bytes([0x00,0x00,0x34,0xd8,0x00,0x01,0x00,0x00]))
|
|
|
|
def test_ds_load_b32(self):
|
|
# ds_load_b32 v0, v1
|
|
# GFX11: encoding: [0x00,0x00,0xd8,0xd8,0x01,0x00,0x00,0x00]
|
|
inst = ds_load_b32(vdst=v[0], addr=v[1])
|
|
self.assertEqual(inst.to_bytes(), bytes([0x00,0x00,0xd8,0xd8,0x01,0x00,0x00,0x00]))
|
|
|
|
def test_ds_store_b32_offset(self):
|
|
# ds_store_b32 v0, v1 offset:64
|
|
# GFX11: encoding: [0x40,0x00,0x34,0xd8,0x00,0x01,0x00,0x00]
|
|
inst = ds_store_b32(addr=v[0], data0=v[1], offset0=64)
|
|
self.assertEqual(inst.to_bytes(), bytes([0x40,0x00,0x34,0xd8,0x00,0x01,0x00,0x00]))
|
|
|
|
def test_ds_load_b64(self):
|
|
# ds_load_b64 v[0:1], v2
|
|
# GFX11: encoding: [0x00,0x00,0xd8,0xd9,0x02,0x00,0x00,0x00]
|
|
inst = ds_load_b64(vdst=v[0:1], addr=v[2])
|
|
self.assertEqual(inst.to_bytes(), bytes([0x00,0x00,0xd8,0xd9,0x02,0x00,0x00,0x00]))
|
|
|
|
def test_ds_add_u32(self):
|
|
# ds_add_u32 v0, v1
|
|
# GFX11: encoding: [0x00,0x00,0x00,0xd8,0x00,0x01,0x00,0x00]
|
|
inst = ds_add_u32(addr=v[0], data0=v[1])
|
|
self.assertEqual(inst.to_bytes(), bytes([0x00,0x00,0x00,0xd8,0x00,0x01,0x00,0x00]))
|
|
|
|
def test_ds_store_b32_gds(self):
|
|
# ds_store_b32 v0, v1 gds
|
|
# GFX11: encoding: [0x00,0x00,0x36,0xd8,0x00,0x01,0x00,0x00]
|
|
inst = ds_store_b32(addr=v[0], data0=v[1], gds=1)
|
|
self.assertEqual(inst.to_bytes(), bytes([0x00,0x00,0x36,0xd8,0x00,0x01,0x00,0x00]))
|
|
|
|
|
|
class TestVOP3(unittest.TestCase):
|
|
"""Test VOP3 (3-operand vector) instructions."""
|
|
|
|
def test_v_fma_f32(self):
|
|
# v_fma_f32 v0, v1, v2, v3
|
|
# GFX11: encoding: [0x00,0x00,0x13,0xd6,0x01,0x05,0x0e,0x04]
|
|
inst = v_fma_f32(vdst=v[0], src0=v[1], src1=v[2], src2=v[3])
|
|
self.assertEqual(inst.to_bytes(), bytes([0x00,0x00,0x13,0xd6,0x01,0x05,0x0e,0x04]))
|
|
|
|
def test_v_mad_f32(self):
|
|
# v_fmac_f32_e64 v0, v1, v2 (fmac is fma with implicit dst as src2)
|
|
# Use v_fma_f32 with vdst == src2
|
|
inst = v_fma_f32(vdst=v[0], src0=v[1], src1=v[2], src2=v[0])
|
|
self.assertEqual(inst.to_bytes()[:4], bytes([0x00,0x00,0x13,0xd6]))
|
|
|
|
def test_v_add3_u32(self):
|
|
# v_add3_u32 v0, v1, v2, v3
|
|
# GFX11: encoding: [0x00,0x00,0x55,0xd6,0x01,0x05,0x0e,0x04]
|
|
inst = v_add3_u32(vdst=v[0], src0=v[1], src1=v[2], src2=v[3])
|
|
self.assertEqual(inst.to_bytes(), bytes([0x00,0x00,0x55,0xd6,0x01,0x05,0x0e,0x04]))
|
|
|
|
|
|
class TestFLAT(unittest.TestCase):
|
|
"""Test FLAT/GLOBAL/SCRATCH memory instructions."""
|
|
|
|
def test_global_load_b32(self):
|
|
# global_load_b32 v0, v[1:2], off (seg=2 for global)
|
|
# GFX11: encoding: [0x00,0x00,0x52,0xdc,0x01,0x00,0x7c,0x00]
|
|
inst = global_load_b32(vdst=v[0], addr=v[1:2], saddr=OFF)
|
|
self.assertEqual(inst.to_bytes(), bytes([0x00,0x00,0x52,0xdc,0x01,0x00,0x7c,0x00]))
|
|
|
|
def test_global_store_b32(self):
|
|
# global_store_b32 v[0:1], v2, off (seg=2 for global)
|
|
# GFX11: encoding: [0x00,0x00,0x6a,0xdc,0x00,0x02,0x7c,0x00]
|
|
inst = global_store_b32(addr=v[0:1], data=v[2], saddr=OFF)
|
|
self.assertEqual(inst.to_bytes(), bytes([0x00,0x00,0x6a,0xdc,0x00,0x02,0x7c,0x00]))
|
|
|
|
def test_global_load_b32_saddr(self):
|
|
# global_load_b32 v0, v1, s[0:1] (seg=2 for global)
|
|
# GFX11: encoding: [0x00,0x00,0x52,0xdc,0x01,0x00,0x00,0x00]
|
|
inst = global_load_b32(vdst=v[0], addr=v[1], saddr=s[0:1])
|
|
self.assertEqual(inst.to_bytes(), bytes([0x00,0x00,0x52,0xdc,0x01,0x00,0x00,0x00]))
|
|
|
|
def test_global_load_b32_offset(self):
|
|
# global_load_b32 v0, v[1:2], off offset:256 (seg=2 for global)
|
|
# GFX11: encoding: [0x00,0x01,0x52,0xdc,0x01,0x00,0x7c,0x00]
|
|
inst = global_load_b32(vdst=v[0], addr=v[1:2], saddr=OFF, offset=256)
|
|
self.assertEqual(inst.to_bytes(), bytes([0x00,0x01,0x52,0xdc,0x01,0x00,0x7c,0x00]))
|
|
|
|
def test_global_load_b64(self):
|
|
# global_load_b64 v[0:1], v[2:3], off (seg=2 for global)
|
|
# GFX11: encoding: [0x00,0x00,0x56,0xdc,0x02,0x00,0x7c,0x00]
|
|
inst = global_load_b64(vdst=v[0:1], addr=v[2:3], saddr=OFF)
|
|
self.assertEqual(inst.to_bytes(), bytes([0x00,0x00,0x56,0xdc,0x02,0x00,0x7c,0x00]))
|
|
|
|
|
|
class TestSMEM(unittest.TestCase):
|
|
"""Test SMEM (scalar memory) instructions - regression tests for glc/dlc bit positions."""
|
|
|
|
def test_smem_dlc_bit_position(self):
|
|
# s_load_b32 s5, s[2:3], s0 dlc - tests that DLC is at bit 13 (not bit 14)
|
|
# GFX11: encoding: [0x41,0x21,0x00,0xf4,0x00,0x00,0x00,0x00]
|
|
inst = s_load_b32(sdata=s[5], sbase=s[2:3], soffset=s[0], dlc=1)
|
|
self.assertEqual(inst.to_bytes(), bytes([0x41,0x21,0x00,0xf4,0x00,0x00,0x00,0x00]))
|
|
|
|
def test_smem_glc_bit_position(self):
|
|
# s_load_b32 s5, s[2:3], s0 glc - tests that GLC is at bit 14 (not bit 16)
|
|
# GFX11: encoding: [0x41,0x41,0x00,0xf4,0x00,0x00,0x00,0x00]
|
|
inst = s_load_b32(sdata=s[5], sbase=s[2:3], soffset=s[0], glc=1)
|
|
self.assertEqual(inst.to_bytes(), bytes([0x41,0x41,0x00,0xf4,0x00,0x00,0x00,0x00]))
|
|
|
|
def test_smem_glc_dlc_combined(self):
|
|
# s_load_b32 s5, s[2:3], s0 glc dlc - tests both flags together
|
|
# GFX11: encoding: [0x41,0x61,0x00,0xf4,0x00,0x00,0x00,0x00]
|
|
inst = s_load_b32(sdata=s[5], sbase=s[2:3], soffset=s[0], glc=1, dlc=1)
|
|
self.assertEqual(inst.to_bytes(), bytes([0x41,0x61,0x00,0xf4,0x00,0x00,0x00,0x00]))
|
|
|
|
def test_smem_disasm_roundtrip_dlc(self):
|
|
# Test that disassembly/reassembly preserves DLC bit correctly
|
|
data = bytes([0x41,0x21,0x00,0xf4,0x00,0x00,0x00,0x00])
|
|
decoded = SMEM.from_bytes(data)
|
|
self.assertEqual(decoded.to_bytes(), data)
|
|
|
|
def test_smem_disasm_roundtrip_glc_dlc(self):
|
|
# Test that disassembly/reassembly preserves GLC+DLC bits correctly
|
|
data = bytes([0x41,0x61,0x00,0xf4,0x00,0x00,0x00,0x00])
|
|
decoded = SMEM.from_bytes(data)
|
|
self.assertEqual(decoded.to_bytes(), data)
|
|
|
|
|
|
class TestVOP3Literal(unittest.TestCase):
|
|
"""Test VOP3 literal handling - regression tests for Inst64 literal encoding."""
|
|
|
|
def test_vop3_with_literal(self):
|
|
# v_add3_u32 v5, vcc_hi, 0xaf123456, v255
|
|
# GFX11: encoding: [0x05,0x00,0x55,0xd6,0x6b,0xfe,0xfd,0x07,0x56,0x34,0x12,0xaf]
|
|
inst = VOP3(VOP3Op.V_ADD3_U32, vdst=v[5], src0=VCC_HI, src1=0xaf123456, src2=v[255])
|
|
expected = bytes([0x05,0x00,0x55,0xd6,0x6b,0xfe,0xfd,0x07,0x56,0x34,0x12,0xaf])
|
|
self.assertEqual(inst.to_bytes(), expected)
|
|
|
|
def test_vop3_literal_null_operand(self):
|
|
# v_add3_u32 v5, null, exec_lo, 0xaf123456
|
|
# GFX11: encoding: [0x05,0x00,0x55,0xd6,0x7c,0xfc,0xfc,0x03,0x56,0x34,0x12,0xaf]
|
|
inst = VOP3(VOP3Op.V_ADD3_U32, vdst=v[5], src0=NULL, src1=EXEC_LO, src2=0xaf123456)
|
|
expected = bytes([0x05,0x00,0x55,0xd6,0x7c,0xfc,0xfc,0x03,0x56,0x34,0x12,0xaf])
|
|
self.assertEqual(inst.to_bytes(), expected)
|
|
|
|
def test_vop3p_with_literal(self):
|
|
# Test VOP3P literal encoding (also uses Inst64)
|
|
inst = VOP3P(VOP3POp.V_PK_ADD_F16, vdst=v[5], src0=0.5, src1=0x12345678, src2=v[0])
|
|
self.assertEqual(len(inst.to_bytes()), 12) # 8 bytes + 4 byte literal
|
|
|
|
|
|
class TestDetectFormat(unittest.TestCase):
|
|
"""Test detect_format uses encoding from autogen classes."""
|
|
|
|
def test_detect_sopp(self):
|
|
self.assertEqual(detect_format(s_endpgm().to_bytes()), SOPP)
|
|
self.assertEqual(detect_format(s_nop(0).to_bytes()), SOPP)
|
|
self.assertEqual(detect_format(s_barrier().to_bytes()), SOPP)
|
|
|
|
def test_detect_sop1(self):
|
|
self.assertEqual(detect_format(s_mov_b32(s[0], 0).to_bytes()), SOP1)
|
|
self.assertEqual(detect_format(s_mov_b64(s[0:1], 0).to_bytes()), SOP1)
|
|
|
|
def test_detect_sop2(self):
|
|
self.assertEqual(detect_format(s_add_u32(s[0], s[1], s[2]).to_bytes()), SOP2)
|
|
self.assertEqual(detect_format(s_mul_i32(s[0], s[1], s[2]).to_bytes()), SOP2)
|
|
|
|
def test_detect_sopc(self):
|
|
self.assertEqual(detect_format(s_cmp_eq_i32(s[0], s[1]).to_bytes()), SOPC)
|
|
|
|
def test_detect_sopk(self):
|
|
self.assertEqual(detect_format(s_movk_i32(s[0], 0x1234).to_bytes()), SOPK)
|
|
|
|
def test_detect_vop1(self):
|
|
self.assertEqual(detect_format(v_mov_b32_e32(v[0], 0).to_bytes()), VOP1)
|
|
self.assertEqual(detect_format(v_rcp_f32_e32(v[0], v[1]).to_bytes()), VOP1)
|
|
|
|
def test_detect_vop2(self):
|
|
self.assertEqual(detect_format(v_add_f32_e32(v[0], v[1], v[2]).to_bytes()), VOP2)
|
|
self.assertEqual(detect_format(v_mul_f32_e32(v[0], v[1], v[2]).to_bytes()), VOP2)
|
|
|
|
def test_detect_vopc(self):
|
|
self.assertEqual(detect_format(v_cmp_eq_f32_e32(v[0], v[1]).to_bytes()), VOPC)
|
|
self.assertEqual(detect_format(v_cmp_lt_i32_e32(v[0], v[1]).to_bytes()), VOPC)
|
|
|
|
def test_detect_vop3(self):
|
|
self.assertEqual(detect_format(v_add_f32_e64(v[0], v[1], v[2]).to_bytes()), VOP3)
|
|
self.assertEqual(detect_format(v_fma_f32(v[0], v[1], v[2], v[3]).to_bytes()), VOP3)
|
|
|
|
def test_detect_vop3p(self):
|
|
self.assertEqual(detect_format(VOP3P(VOP3POp.V_PK_ADD_F16, v[0], v[1], v[2], v[3]).to_bytes()), VOP3P)
|
|
|
|
def test_detect_smem(self):
|
|
self.assertEqual(detect_format(s_load_b32(sdata=s[0], sbase=s[2:3], offset=0).to_bytes()), SMEM)
|
|
self.assertEqual(detect_format(s_load_b64(sdata=s[0:1], sbase=s[2:3], soffset=s[5]).to_bytes()), SMEM)
|
|
|
|
def test_detect_ds(self):
|
|
self.assertEqual(detect_format(ds_load_b32(v[0], v[1]).to_bytes()), DS)
|
|
self.assertEqual(detect_format(ds_store_b32(v[0], v[1]).to_bytes()), DS)
|
|
|
|
def test_detect_flat(self):
|
|
self.assertEqual(detect_format(global_load_b32(vdst=v[0], addr=v[1:2], saddr=NULL).to_bytes()), GLOBAL)
|
|
self.assertEqual(detect_format(global_store_b32(addr=v[0:1], data=v[2], saddr=NULL).to_bytes()), GLOBAL)
|
|
|
|
def test_detect_vopd(self):
|
|
inst = VOPD(VOPDOp.V_DUAL_MOV_B32, VOPDOp.V_DUAL_MOV_B32, vdstx=v[0], vdsty=v[1], srcx0=0, srcy0=0)
|
|
self.assertEqual(detect_format(inst.to_bytes()), VOPD)
|
|
|
|
def test_detect_vinterp(self):
|
|
inst = VINTERP(VINTERPOp.V_INTERP_P10_F32, vdst=v[0], src0=v[1], src1=v[2], src2=v[3])
|
|
self.assertEqual(detect_format(inst.to_bytes()), VINTERP)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|