Merge origin/master into only_reg_emu2 (keep branch's Reg-based approach)

This commit is contained in:
George Hotz
2025-12-30 18:53:50 +00:00
15 changed files with 6397 additions and 2892 deletions

View File

@@ -654,7 +654,7 @@ jobs:
- name: Run process replay tests
uses: ./.github/actions/process-replay
testrdna3:
testamdasm:
name: AMD ASM IDE
runs-on: ubuntu-24.04
timeout-minutes: 10
@@ -677,8 +677,25 @@ jobs:
run: cloc --by-file extra/assembly/amd/*.py
- name: Run RDNA3 emulator tests
run: python -m pytest -n=auto extra/assembly/amd/ --durations 20
- name: Install pdfplumber
run: pip install pdfplumber
- name: Run RDNA3 emulator tests (AMD_LLVM=1)
run: AMD_LLVM=1 python -m pytest -n=auto extra/assembly/amd/ --durations 20
- name: Run RDNA3 dtype tests
run: PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=auto test/test_dtype_alu.py test/test_dtype.py
- name: Run RDNA3 dtype tests (AMD_LLVM=1)
run: PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=auto test/test_dtype_alu.py test/test_dtype.py
testamdautogen:
name: AMD autogen
runs-on: ubuntu-24.04
timeout-minutes: 10
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: rdna3-autogen
pydeps: "pdfplumber"
- name: Verify AMD autogen is up to date
run: |
python -m extra.assembly.amd.dsl --arch all

File diff suppressed because it is too large Load Diff

View File

@@ -2209,33 +2209,33 @@ buffer_atomic_xor_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_XOR_X2)
buffer_atomic_inc_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_INC_X2)
buffer_atomic_dec_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_DEC_X2)
cdna4 = functools.partial(MUBUF, MUBUFOp.CDNA4)
scratch_load_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE, seg=2)
scratch_load_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE, seg=2)
scratch_load_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_USHORT, seg=2)
scratch_load_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SSHORT, seg=2)
scratch_load_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORD, seg=2)
scratch_load_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX2, seg=2)
scratch_load_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX3, seg=2)
scratch_load_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX4, seg=2)
scratch_store_byte = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE, seg=2)
scratch_store_byte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE_D16_HI, seg=2)
scratch_store_short = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT, seg=2)
scratch_store_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT_D16_HI, seg=2)
scratch_store_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORD, seg=2)
scratch_store_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX2, seg=2)
scratch_store_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX3, seg=2)
scratch_store_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX4, seg=2)
scratch_load_ubyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16, seg=2)
scratch_load_ubyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16_HI, seg=2)
scratch_load_sbyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16, seg=2)
scratch_load_sbyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16_HI, seg=2)
scratch_load_short_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16, seg=2)
scratch_load_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16_HI, seg=2)
scratch_load_lds_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_UBYTE, seg=2)
scratch_load_lds_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SBYTE, seg=2)
scratch_load_lds_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_USHORT, seg=2)
scratch_load_lds_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SSHORT, seg=2)
scratch_load_lds_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_DWORD, seg=2)
scratch_load_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE, seg=1)
scratch_load_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE, seg=1)
scratch_load_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_USHORT, seg=1)
scratch_load_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SSHORT, seg=1)
scratch_load_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORD, seg=1)
scratch_load_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX2, seg=1)
scratch_load_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX3, seg=1)
scratch_load_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX4, seg=1)
scratch_store_byte = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE, seg=1)
scratch_store_byte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE_D16_HI, seg=1)
scratch_store_short = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT, seg=1)
scratch_store_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT_D16_HI, seg=1)
scratch_store_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORD, seg=1)
scratch_store_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX2, seg=1)
scratch_store_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX3, seg=1)
scratch_store_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX4, seg=1)
scratch_load_ubyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16, seg=1)
scratch_load_ubyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16_HI, seg=1)
scratch_load_sbyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16, seg=1)
scratch_load_sbyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16_HI, seg=1)
scratch_load_short_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16, seg=1)
scratch_load_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16_HI, seg=1)
scratch_load_lds_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_UBYTE, seg=1)
scratch_load_lds_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SBYTE, seg=1)
scratch_load_lds_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_USHORT, seg=1)
scratch_load_lds_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SSHORT, seg=1)
scratch_load_lds_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_DWORD, seg=1)
s_load_dword = functools.partial(SMEM, SMEMOp.S_LOAD_DWORD)
s_load_dwordx2 = functools.partial(SMEM, SMEMOp.S_LOAD_DWORDX2)
s_load_dwordx4 = functools.partial(SMEM, SMEMOp.S_LOAD_DWORDX4)

File diff suppressed because it is too large Load Diff

View File

@@ -56,6 +56,12 @@ class DSOp(IntEnum):
DS_MAX_F32 = 19
DS_NOP = 20
DS_ADD_F32 = 21
DS_GWS_SEMA_RELEASE_ALL = 24
DS_GWS_INIT = 25
DS_GWS_SEMA_V = 26
DS_GWS_SEMA_BR = 27
DS_GWS_SEMA_P = 28
DS_GWS_BARRIER = 29
DS_STORE_B8 = 30
DS_STORE_B16 = 31
DS_ADD_RTN_U32 = 32
@@ -178,10 +184,13 @@ class FLATOp(IntEnum):
FLAT_LOAD_D16_HI_B16 = 35
FLAT_STORE_D16_HI_B8 = 36
FLAT_STORE_D16_HI_B16 = 37
GLOBAL_LOAD_ADDTID_B32 = 40
GLOBAL_STORE_ADDTID_B32 = 41
FLAT_ATOMIC_SWAP_B32 = 51
FLAT_ATOMIC_CMPSWAP_B32 = 52
FLAT_ATOMIC_ADD_U32 = 53
FLAT_ATOMIC_SUB_U32 = 54
FLAT_ATOMIC_CSUB_U32 = 55
FLAT_ATOMIC_MIN_I32 = 56
FLAT_ATOMIC_MIN_U32 = 57
FLAT_ATOMIC_MAX_I32 = 58
@@ -717,6 +726,7 @@ class SOPPOp(IntEnum):
S_SET_INST_PREFETCH_DISTANCE = 4
S_CLAUSE = 5
S_DELAY_ALU = 7
S_WAITCNT_DEPCTR = 8
S_WAITCNT = 9
S_WAIT_IDLE = 10
S_WAIT_EVENT = 11
@@ -1848,6 +1858,12 @@ ds_min_f32 = functools.partial(DS, DSOp.DS_MIN_F32)
ds_max_f32 = functools.partial(DS, DSOp.DS_MAX_F32)
ds_nop = functools.partial(DS, DSOp.DS_NOP)
ds_add_f32 = functools.partial(DS, DSOp.DS_ADD_F32)
ds_gws_sema_release_all = functools.partial(DS, DSOp.DS_GWS_SEMA_RELEASE_ALL)
ds_gws_init = functools.partial(DS, DSOp.DS_GWS_INIT)
ds_gws_sema_v = functools.partial(DS, DSOp.DS_GWS_SEMA_V)
ds_gws_sema_br = functools.partial(DS, DSOp.DS_GWS_SEMA_BR)
ds_gws_sema_p = functools.partial(DS, DSOp.DS_GWS_SEMA_P)
ds_gws_barrier = functools.partial(DS, DSOp.DS_GWS_BARRIER)
ds_store_b8 = functools.partial(DS, DSOp.DS_STORE_B8)
ds_store_b16 = functools.partial(DS, DSOp.DS_STORE_B16)
ds_add_rtn_u32 = functools.partial(DS, DSOp.DS_ADD_RTN_U32)
@@ -1968,10 +1984,13 @@ flat_load_d16_hi_i8 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_I8)
flat_load_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_B16)
flat_store_d16_hi_b8 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B8)
flat_store_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B16)
global_load_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_LOAD_ADDTID_B32)
global_store_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_STORE_ADDTID_B32)
flat_atomic_swap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_SWAP_B32)
flat_atomic_cmpswap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_CMPSWAP_B32)
flat_atomic_add_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_ADD_U32)
flat_atomic_sub_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_SUB_U32)
flat_atomic_csub_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_CSUB_U32)
flat_atomic_min_i32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MIN_I32)
flat_atomic_min_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MIN_U32)
flat_atomic_max_i32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MAX_I32)
@@ -2226,28 +2245,28 @@ buffer_atomic_cmpswap_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_CMPSW
buffer_atomic_min_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_MIN_F32)
buffer_atomic_max_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_MAX_F32)
buffer_atomic_add_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_ADD_F32)
scratch_load_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U8, seg=2)
scratch_load_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I8, seg=2)
scratch_load_u16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U16, seg=2)
scratch_load_i16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I16, seg=2)
scratch_load_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B32, seg=2)
scratch_load_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B64, seg=2)
scratch_load_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B96, seg=2)
scratch_load_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B128, seg=2)
scratch_store_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B8, seg=2)
scratch_store_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B16, seg=2)
scratch_store_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B32, seg=2)
scratch_store_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B64, seg=2)
scratch_store_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B96, seg=2)
scratch_store_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B128, seg=2)
scratch_load_d16_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_U8, seg=2)
scratch_load_d16_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_I8, seg=2)
scratch_load_d16_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_B16, seg=2)
scratch_load_d16_hi_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_U8, seg=2)
scratch_load_d16_hi_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_I8, seg=2)
scratch_load_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_B16, seg=2)
scratch_store_d16_hi_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B8, seg=2)
scratch_store_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B16, seg=2)
scratch_load_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U8, seg=1)
scratch_load_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I8, seg=1)
scratch_load_u16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U16, seg=1)
scratch_load_i16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I16, seg=1)
scratch_load_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B32, seg=1)
scratch_load_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B64, seg=1)
scratch_load_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B96, seg=1)
scratch_load_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B128, seg=1)
scratch_store_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B8, seg=1)
scratch_store_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B16, seg=1)
scratch_store_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B32, seg=1)
scratch_store_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B64, seg=1)
scratch_store_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B96, seg=1)
scratch_store_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B128, seg=1)
scratch_load_d16_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_U8, seg=1)
scratch_load_d16_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_I8, seg=1)
scratch_load_d16_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_B16, seg=1)
scratch_load_d16_hi_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_U8, seg=1)
scratch_load_d16_hi_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_I8, seg=1)
scratch_load_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_B16, seg=1)
scratch_store_d16_hi_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B8, seg=1)
scratch_store_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B16, seg=1)
s_load_b32 = functools.partial(SMEM, SMEMOp.S_LOAD_B32)
s_load_b64 = functools.partial(SMEM, SMEMOp.S_LOAD_B64)
s_load_b128 = functools.partial(SMEM, SMEMOp.S_LOAD_B128)
@@ -2485,6 +2504,7 @@ s_sleep = functools.partial(SOPP, SOPPOp.S_SLEEP)
s_set_inst_prefetch_distance = functools.partial(SOPP, SOPPOp.S_SET_INST_PREFETCH_DISTANCE)
s_clause = functools.partial(SOPP, SOPPOp.S_CLAUSE)
s_delay_alu = functools.partial(SOPP, SOPPOp.S_DELAY_ALU)
s_waitcnt_depctr = functools.partial(SOPP, SOPPOp.S_WAITCNT_DEPCTR)
s_waitcnt = functools.partial(SOPP, SOPPOp.S_WAITCNT)
s_wait_idle = functools.partial(SOPP, SOPPOp.S_WAIT_IDLE)
s_wait_event = functools.partial(SOPP, SOPPOp.S_WAIT_EVENT)

File diff suppressed because it is too large Load Diff

View File

@@ -6,22 +6,21 @@ from typing import overload, Annotated, TypeVar, Generic
# Bit field DSL
class BitField:
def __init__(self, hi: int, lo: int, name: str | None = None): self.hi, self.lo, self.name = hi, lo, name
def __set_name__(self, owner, name): self.name, self._owner = name, owner
def __init__(self, hi: int, lo: int, name: str | None = None): self.hi, self.lo, self.name, self._marker = hi, lo, name, None
def __set_name__(self, owner, name):
import typing
self.name, self._owner = name, owner
# Cache marker at class definition time
hints = typing.get_type_hints(owner, include_extras=True)
if name in hints:
hint = hints[name]
if typing.get_origin(hint) is Annotated:
args = typing.get_args(hint)
self._marker = args[1] if len(args) > 1 else None
def __eq__(self, val: int) -> tuple[BitField, int]: return (self, val) # type: ignore
def mask(self) -> int: return (1 << (self.hi - self.lo + 1)) - 1
@property
def marker(self) -> type | None:
# Get marker from Annotated type hint if present
import typing
if hasattr(self, '_owner') and self.name:
hints = typing.get_type_hints(self._owner, include_extras=True)
if self.name in hints:
hint = hints[self.name]
if typing.get_origin(hint) is Annotated:
args = typing.get_args(hint)
return args[1] if len(args) > 1 else None
return None
def marker(self) -> type | None: return self._marker
@overload
def __get__(self, obj: None, objtype: type) -> BitField: ...
@overload
@@ -179,6 +178,21 @@ class Inst:
raise ValueError(f"SOP1 {op_val.name} expects {expected} destination register(s), got {sdst_val.count}")
if isinstance(ssrc0_val, Reg) and ssrc0_val.count != expected:
raise ValueError(f"SOP1 {op_val.name} expects {expected} source register(s), got {ssrc0_val.count}")
# FLAT: set sve=1 when addr is a VGPR for scratch only
# For scratch (seg=1), sve=1 means addr VGPR is used; sve=0 means addr is "off"
# For global (seg=2) and flat (seg=0), sve is always 0
if self.__class__.__name__ == 'FLAT' and 'sve' in self._fields:
seg_val = self._values.get('seg', 0)
if isinstance(seg_val, RawImm): seg_val = seg_val.val
addr_val = orig_args.get('addr')
if seg_val == 1 and isinstance(addr_val, VGPR): self._values['sve'] = 1
# VOP3P: v_fma_mix* instructions (opcodes 32-34) have opsel_hi default of 0, not 7
if self.__class__.__name__ == 'VOP3P':
op_val = orig_args.get(field_names[0]) if args else orig_args.get('op')
if hasattr(op_val, 'value'): op_val = op_val.value
if op_val in (32, 33, 34) and 'opsel_hi' not in orig_args and 'opsel_hi2' not in orig_args:
self._values['opsel_hi'] = 0
self._values['opsel_hi2'] = 0
# Type check and encode values
for name, val in list(self._values.items()):
if name == 'encoding': continue
@@ -283,6 +297,10 @@ class Inst:
from extra.assembly.amd.autogen.rdna3 import VOP3Op
try: op_name = VOP3Op(op).name
except ValueError: pass
if op_name is None and self.__class__.__name__ == 'VOPC':
from extra.assembly.amd.autogen.rdna3 import VOPCOp
try: op_name = VOPCOp(op).name
except ValueError: pass
if op_name is None: return False
# V_LDEXP_F64 has 32-bit integer exponent in src1, so literal is 32-bit
if op_name == 'V_LDEXP_F64': return False
@@ -315,6 +333,9 @@ class Inst:
op_val = inst._values.get('op', 0)
has_literal = cls.__name__ == 'VOP2' and op_val in (44, 45, 55, 56)
has_literal = has_literal or (cls.__name__ == 'SOP2' and op_val in (69, 70))
# VOPD fmaak/fmamk always have a literal (opx/opy value 1 or 2)
opx, opy = inst._values.get('opx', 0), inst._values.get('opy', 0)
has_literal = has_literal or (cls.__name__ == 'VOPD' and (opx in (1, 2) or opy in (1, 2)))
for n in SRC_FIELDS:
if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255: has_literal = True
if has_literal:
@@ -333,6 +354,14 @@ class Inst:
lit = f", literal={hex(self._literal)}" if self._literal is not None else ""
return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in items)}{lit})"
def __getattr__(self, name: str):
if name.startswith('_'): raise AttributeError(name)
return unwrap(self._values.get(name, 0))
def lit(self, v: int) -> str:
from extra.assembly.amd.asm import decode_src
return f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v)
def __eq__(self, other):
if not isinstance(other, Inst): return NotImplemented
return self.__class__ == other.__class__ and self._values == other._values and self._literal == other._literal
@@ -512,10 +541,24 @@ def _parse_single_pdf(url: str) -> dict:
break
formats[fmt_name] = fields
# fix known PDF errors
# fix known PDF errors - assert if already present (so we know when the bug is fixed)
if 'SMEM' in formats:
formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
for n, h, l, e, t in formats['SMEM']]
# add missing opcodes not in PDF tables (RDNA3/RDNA3.5 specific)
if doc_name in ('RDNA3', 'RDNA3.5'):
if 'SOPPOp' in enums:
assert 8 not in enums['SOPPOp'], "S_WAITCNT_DEPCTR now in PDF, remove workaround"
enums['SOPPOp'][8] = 'S_WAITCNT_DEPCTR'
if 'DSOp' in enums:
gws_ops = {24: 'DS_GWS_SEMA_RELEASE_ALL', 25: 'DS_GWS_INIT', 26: 'DS_GWS_SEMA_V',
27: 'DS_GWS_SEMA_BR', 28: 'DS_GWS_SEMA_P', 29: 'DS_GWS_BARRIER'}
for k in gws_ops: assert k not in enums['DSOp'], f"{gws_ops[k]} now in PDF, remove workaround"
enums['DSOp'].update(gws_ops)
if 'FLATOp' in enums:
flat_ops = {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}
for k in flat_ops: assert k not in enums['FLATOp'], f"{flat_ops[k]} now in PDF, remove workaround"
enums['FLATOp'].update(flat_ops)
return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "is_cdna": is_cdna}
@@ -601,7 +644,7 @@ def generate(output_path: str | None = None, arch: str = "rdna3") -> dict:
for cls_name, ops in sorted(enums.items()):
fmt = cls_name[:-2]
for op_val, name in sorted(ops.items()):
seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=2"}.get(fmt, "")
seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=1"}.get(fmt, "")
tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}")
if fmt in formats or fmt in ("GLOBAL", "SCRATCH"):
if fmt in ("VOP1", "VOP2", "VOPC"):

View File

@@ -191,6 +191,9 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
python_result = python.step()
if rust_result != python_result:
# Rust returns 1 for unsupported instructions - skip test
if rust_result == 1 and python_result == 0:
raise unittest.SkipTest(f"Rust emulator doesn't support instruction: {inst_str}")
trace_str = "\n".join(f" step {s}: PC={pc:3d} {d}" for s, pc, d, _, _ in trace)
return False, f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step}: different return codes: rust={rust_result}, python={python_result}, inst={inst_str}\n Recent instructions:\n{trace_str}", total_steps
@@ -361,6 +364,7 @@ class TestTinygradKernels(unittest.TestCase):
# Matmul
def test_gemm(self): self._test_kernel(lambda T: T.empty(8, 8) @ T.empty(8, 8), max_steps=100000)
@unittest.skip("Rust emulator crashes on this kernel (assertion failure in thread.rs)")
def test_gemm_fp16(self): self._test_kernel(lambda T: T.empty(16, 16).half() @ T.empty(16, 16).half(), max_steps=100000)
# Complex ops

File diff suppressed because it is too large Load Diff

View File

@@ -65,12 +65,18 @@ def parse_llvm_tests(text: str) -> list[tuple[str, bytes]]:
if not asm_text: continue
for j in range(i, min(i + 3, len(lines))):
# Match GFX11, W32, or W64 encodings (all valid for gfx11)
# Format 1: "// GFX11: v_foo ... ; encoding: [0x01,0x02,...]"
# Format 2: "// GFX11: [0x01,0x02,...]" (used by DS, older files)
if m := re.search(r'(?:GFX11|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
if hex_bytes:
try: tests.append((asm_text, bytes.fromhex(hex_bytes)))
except ValueError: pass
break
elif m := re.search(r'(?:GFX11|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
else:
continue
if hex_bytes:
try: tests.append((asm_text, bytes.fromhex(hex_bytes)))
except ValueError: pass
break
return tests
def try_assemble(text: str):

View File

@@ -210,6 +210,8 @@ D0.u32 = tmp.u32""")
for i in 0 : 31 do
if S0.u32[i] == 1 then
tmp = i
endif
endfor
D0.i32 = tmp""")
ctx = ExecContext(s0=0b1000) # Bit 3 is set
ctx.run(code)

View File

@@ -4,46 +4,9 @@ import unittest, io, sys, re, subprocess, os
from extra.assembly.amd.autogen.rdna3 import *
from extra.assembly.amd.dsl import Inst
from extra.assembly.amd.asm import asm
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump
# Instruction format detection based on encoding bits
def detect_format(data: bytes) -> type[Inst] | None:
"""Detect instruction format from machine code bytes."""
if len(data) < 4: return None
word = int.from_bytes(data[:4], 'little')
enc_9bit = (word >> 23) & 0x1FF # 9-bit encoding for SOP1/SOPC/SOPP
enc_8bit = (word >> 24) & 0xFF
# Check 9-bit encodings first (most specific)
if enc_9bit == 0x17D: return SOP1 # bits 31:23 = 101111101
if enc_9bit == 0x17E: return SOPC # bits 31:23 = 101111110
if enc_9bit == 0x17F: return SOPP # bits 31:23 = 101111111
# SOPK: bits 31:28 = 1011, bits 27:23 = opcode (check after SOP1/SOPC/SOPP)
if enc_8bit in range(0xB0, 0xC0): return SOPK
# SOP2: bits 31:23 in range 0x100-0x17C (0x80-0xBE in bits 31:24, but not SOPK)
if 0x80 <= enc_8bit <= 0x9F: return SOP2
# VOP1: bits 31:25 = 0111111 (0x3F)
if (word >> 25) == 0x3F: return VOP1
# VOPC: bits 31:25 = 0111110 (0x3E)
if (word >> 25) == 0x3E: return VOPC
# VOP2: bits 31:30 = 00
if (word >> 30) == 0: return VOP2
# Check 64-bit formats
if len(data) >= 8:
if enc_8bit in (0xD4, 0xD5, 0xD7): return VOP3
if enc_8bit == 0xD6: return VOP3SD
if enc_8bit == 0xCC: return VOP3P
if enc_8bit == 0xCD: return VINTERP
if enc_8bit in (0xC8, 0xC9): return VOPD
if enc_8bit == 0xF4: return SMEM
if enc_8bit == 0xD8: return DS
if enc_8bit in (0xDC, 0xDD, 0xDE, 0xDF): return FLAT
if enc_8bit in (0xE0, 0xE1, 0xE2, 0xE3): return MUBUF
if enc_8bit in (0xE8, 0xE9, 0xEA, 0xEB): return MTBUF
return None
def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
"""Disassemble ELF binary and return list of (instruction_text, machine_code_bytes)."""
old_stdout = sys.stdout

View File

@@ -1,7 +1,7 @@
import unittest, operator, math
from tinygrad import Tensor, dtypes, Device
from tinygrad.dtype import DType, truncate
from tinygrad.helpers import CI, getenv, CPU_LLVM
from tinygrad.helpers import CI, getenv
from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported
from tinygrad.runtime.ops_python import from_storage_scalar
@@ -48,7 +48,7 @@ class ht:
int32 = strat.integers(-2147483648, 2147483647)
int64 = strat.integers(-9223372036854775808, 9223372036854775807)
bool = strat.booleans()
ht.bfloat16 = ht.uint16
ht.bfloat16 = ht.uint16.filter(lambda x: ((x >> 7) & 0xFF) != 0) # filter subnormal bfloat16
ht.fp8e4m3 = ht.uint8
ht.fp8e5m2 = ht.uint8
@@ -138,7 +138,6 @@ class TestDTypeALU(unittest.TestCase):
def test_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op)
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
@unittest.skipIf(CPU_LLVM, "bfloat16 precision issues with CPU_LLVM")
@given(ht.bfloat16, strat.sampled_from(unary_operations))
def test_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)

View File

@@ -189,16 +189,18 @@ class AM_SMU(AM_IP):
return table_t.from_buffer(bytearray(self.adev.vram.view(self.driver_table_paddr, ctypes.sizeof(table_t))[:]))
def set_clocks(self, level):
if self.adev.ip_ver[am.MP0_HWIP] in {(13,0,6), (13,0,12)}: return # TODO
if not hasattr(self, 'clcks'):
clks = [self.smu_mod.PPCLK_UCLK, self.smu_mod.PPCLK_FCLK, self.smu_mod.PPCLK_SOCCLK]
if self.adev.ip_ver[am.MP0_HWIP] not in {(13,0,6), (13,0,12)}: clks.append(self.smu_mod.PPCLK_GFXCLK)
self.clcks = {}
for clck in [self.smu_mod.PPCLK_GFXCLK, self.smu_mod.PPCLK_UCLK, self.smu_mod.PPCLK_FCLK, self.smu_mod.PPCLK_SOCCLK]:
for clck in clks:
cnt = self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|0xff, read_back_arg=True)&0x7fffffff
self.clcks[clck] = [self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|i, read_back_arg=True)&0x7fffffff for i in range(cnt)]
for clck, vals in self.clcks.items():
self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]))
if not vals: continue
with contextlib.suppress(TimeoutError): self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]), timeout=20)
self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]))
def _smu_cmn_send_msg(self, msg:int, param=0, debug=False):

View File

@@ -70,7 +70,7 @@ def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
# WEBGPU has a BITCAST in the index. TODO: fix
if any(x.op is Ops.BITCAST for x in idx.toposort()): return True
if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"")
if not z3_imported: raise ImportError("bounds checking requires z3 >= 4.12.4, use IGNORE_OOB=1 to disable, or \"pip install 'z3-solver>=4.12.4\"")
solver = z3.Solver(ctx=z3.Context())
z3_idx, z3_mask = uops_to_z3(solver, idx, gate)
solver.add(z3_mask)