Merge origin/master into gen_pdf_fast

This commit is contained in:
George Hotz
2025-12-30 14:48:03 -05:00
7 changed files with 723 additions and 834 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -2209,33 +2209,33 @@ buffer_atomic_xor_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_XOR_X2)
buffer_atomic_inc_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_INC_X2)
buffer_atomic_dec_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_DEC_X2)
cdna4 = functools.partial(MUBUF, MUBUFOp.CDNA4)
scratch_load_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE, seg=2)
scratch_load_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE, seg=2)
scratch_load_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_USHORT, seg=2)
scratch_load_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SSHORT, seg=2)
scratch_load_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORD, seg=2)
scratch_load_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX2, seg=2)
scratch_load_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX3, seg=2)
scratch_load_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX4, seg=2)
scratch_store_byte = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE, seg=2)
scratch_store_byte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE_D16_HI, seg=2)
scratch_store_short = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT, seg=2)
scratch_store_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT_D16_HI, seg=2)
scratch_store_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORD, seg=2)
scratch_store_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX2, seg=2)
scratch_store_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX3, seg=2)
scratch_store_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX4, seg=2)
scratch_load_ubyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16, seg=2)
scratch_load_ubyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16_HI, seg=2)
scratch_load_sbyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16, seg=2)
scratch_load_sbyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16_HI, seg=2)
scratch_load_short_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16, seg=2)
scratch_load_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16_HI, seg=2)
scratch_load_lds_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_UBYTE, seg=2)
scratch_load_lds_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SBYTE, seg=2)
scratch_load_lds_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_USHORT, seg=2)
scratch_load_lds_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SSHORT, seg=2)
scratch_load_lds_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_DWORD, seg=2)
scratch_load_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE, seg=1)
scratch_load_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE, seg=1)
scratch_load_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_USHORT, seg=1)
scratch_load_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SSHORT, seg=1)
scratch_load_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORD, seg=1)
scratch_load_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX2, seg=1)
scratch_load_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX3, seg=1)
scratch_load_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX4, seg=1)
scratch_store_byte = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE, seg=1)
scratch_store_byte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE_D16_HI, seg=1)
scratch_store_short = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT, seg=1)
scratch_store_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT_D16_HI, seg=1)
scratch_store_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORD, seg=1)
scratch_store_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX2, seg=1)
scratch_store_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX3, seg=1)
scratch_store_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX4, seg=1)
scratch_load_ubyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16, seg=1)
scratch_load_ubyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16_HI, seg=1)
scratch_load_sbyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16, seg=1)
scratch_load_sbyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16_HI, seg=1)
scratch_load_short_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16, seg=1)
scratch_load_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16_HI, seg=1)
scratch_load_lds_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_UBYTE, seg=1)
scratch_load_lds_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SBYTE, seg=1)
scratch_load_lds_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_USHORT, seg=1)
scratch_load_lds_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SSHORT, seg=1)
scratch_load_lds_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_DWORD, seg=1)
s_load_dword = functools.partial(SMEM, SMEMOp.S_LOAD_DWORD)
s_load_dwordx2 = functools.partial(SMEM, SMEMOp.S_LOAD_DWORDX2)
s_load_dwordx4 = functools.partial(SMEM, SMEMOp.S_LOAD_DWORDX4)

View File

@@ -56,6 +56,12 @@ class DSOp(IntEnum):
DS_MAX_F32 = 19
DS_NOP = 20
DS_ADD_F32 = 21
DS_GWS_SEMA_RELEASE_ALL = 24
DS_GWS_INIT = 25
DS_GWS_SEMA_V = 26
DS_GWS_SEMA_BR = 27
DS_GWS_SEMA_P = 28
DS_GWS_BARRIER = 29
DS_STORE_B8 = 30
DS_STORE_B16 = 31
DS_ADD_RTN_U32 = 32
@@ -178,10 +184,13 @@ class FLATOp(IntEnum):
FLAT_LOAD_D16_HI_B16 = 35
FLAT_STORE_D16_HI_B8 = 36
FLAT_STORE_D16_HI_B16 = 37
GLOBAL_LOAD_ADDTID_B32 = 40
GLOBAL_STORE_ADDTID_B32 = 41
FLAT_ATOMIC_SWAP_B32 = 51
FLAT_ATOMIC_CMPSWAP_B32 = 52
FLAT_ATOMIC_ADD_U32 = 53
FLAT_ATOMIC_SUB_U32 = 54
FLAT_ATOMIC_CSUB_U32 = 55
FLAT_ATOMIC_MIN_I32 = 56
FLAT_ATOMIC_MIN_U32 = 57
FLAT_ATOMIC_MAX_I32 = 58
@@ -717,6 +726,7 @@ class SOPPOp(IntEnum):
S_SET_INST_PREFETCH_DISTANCE = 4
S_CLAUSE = 5
S_DELAY_ALU = 7
S_WAITCNT_DEPCTR = 8
S_WAITCNT = 9
S_WAIT_IDLE = 10
S_WAIT_EVENT = 11
@@ -1848,6 +1858,12 @@ ds_min_f32 = functools.partial(DS, DSOp.DS_MIN_F32)
ds_max_f32 = functools.partial(DS, DSOp.DS_MAX_F32)
ds_nop = functools.partial(DS, DSOp.DS_NOP)
ds_add_f32 = functools.partial(DS, DSOp.DS_ADD_F32)
ds_gws_sema_release_all = functools.partial(DS, DSOp.DS_GWS_SEMA_RELEASE_ALL)
ds_gws_init = functools.partial(DS, DSOp.DS_GWS_INIT)
ds_gws_sema_v = functools.partial(DS, DSOp.DS_GWS_SEMA_V)
ds_gws_sema_br = functools.partial(DS, DSOp.DS_GWS_SEMA_BR)
ds_gws_sema_p = functools.partial(DS, DSOp.DS_GWS_SEMA_P)
ds_gws_barrier = functools.partial(DS, DSOp.DS_GWS_BARRIER)
ds_store_b8 = functools.partial(DS, DSOp.DS_STORE_B8)
ds_store_b16 = functools.partial(DS, DSOp.DS_STORE_B16)
ds_add_rtn_u32 = functools.partial(DS, DSOp.DS_ADD_RTN_U32)
@@ -1968,10 +1984,13 @@ flat_load_d16_hi_i8 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_I8)
flat_load_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_B16)
flat_store_d16_hi_b8 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B8)
flat_store_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B16)
global_load_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_LOAD_ADDTID_B32)
global_store_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_STORE_ADDTID_B32)
flat_atomic_swap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_SWAP_B32)
flat_atomic_cmpswap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_CMPSWAP_B32)
flat_atomic_add_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_ADD_U32)
flat_atomic_sub_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_SUB_U32)
flat_atomic_csub_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_CSUB_U32)
flat_atomic_min_i32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MIN_I32)
flat_atomic_min_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MIN_U32)
flat_atomic_max_i32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MAX_I32)
@@ -2226,28 +2245,28 @@ buffer_atomic_cmpswap_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_CMPSW
buffer_atomic_min_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_MIN_F32)
buffer_atomic_max_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_MAX_F32)
buffer_atomic_add_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_ADD_F32)
scratch_load_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U8, seg=2)
scratch_load_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I8, seg=2)
scratch_load_u16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U16, seg=2)
scratch_load_i16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I16, seg=2)
scratch_load_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B32, seg=2)
scratch_load_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B64, seg=2)
scratch_load_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B96, seg=2)
scratch_load_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B128, seg=2)
scratch_store_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B8, seg=2)
scratch_store_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B16, seg=2)
scratch_store_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B32, seg=2)
scratch_store_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B64, seg=2)
scratch_store_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B96, seg=2)
scratch_store_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B128, seg=2)
scratch_load_d16_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_U8, seg=2)
scratch_load_d16_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_I8, seg=2)
scratch_load_d16_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_B16, seg=2)
scratch_load_d16_hi_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_U8, seg=2)
scratch_load_d16_hi_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_I8, seg=2)
scratch_load_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_B16, seg=2)
scratch_store_d16_hi_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B8, seg=2)
scratch_store_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B16, seg=2)
scratch_load_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U8, seg=1)
scratch_load_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I8, seg=1)
scratch_load_u16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U16, seg=1)
scratch_load_i16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I16, seg=1)
scratch_load_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B32, seg=1)
scratch_load_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B64, seg=1)
scratch_load_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B96, seg=1)
scratch_load_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B128, seg=1)
scratch_store_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B8, seg=1)
scratch_store_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B16, seg=1)
scratch_store_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B32, seg=1)
scratch_store_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B64, seg=1)
scratch_store_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B96, seg=1)
scratch_store_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B128, seg=1)
scratch_load_d16_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_U8, seg=1)
scratch_load_d16_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_I8, seg=1)
scratch_load_d16_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_B16, seg=1)
scratch_load_d16_hi_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_U8, seg=1)
scratch_load_d16_hi_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_I8, seg=1)
scratch_load_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_B16, seg=1)
scratch_store_d16_hi_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B8, seg=1)
scratch_store_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B16, seg=1)
s_load_b32 = functools.partial(SMEM, SMEMOp.S_LOAD_B32)
s_load_b64 = functools.partial(SMEM, SMEMOp.S_LOAD_B64)
s_load_b128 = functools.partial(SMEM, SMEMOp.S_LOAD_B128)
@@ -2485,6 +2504,7 @@ s_sleep = functools.partial(SOPP, SOPPOp.S_SLEEP)
s_set_inst_prefetch_distance = functools.partial(SOPP, SOPPOp.S_SET_INST_PREFETCH_DISTANCE)
s_clause = functools.partial(SOPP, SOPPOp.S_CLAUSE)
s_delay_alu = functools.partial(SOPP, SOPPOp.S_DELAY_ALU)
s_waitcnt_depctr = functools.partial(SOPP, SOPPOp.S_WAITCNT_DEPCTR)
s_waitcnt = functools.partial(SOPP, SOPPOp.S_WAITCNT)
s_wait_idle = functools.partial(SOPP, SOPPOp.S_WAIT_IDLE)
s_wait_event = functools.partial(SOPP, SOPPOp.S_WAIT_EVENT)

View File

@@ -6,22 +6,21 @@ from typing import overload, Annotated, TypeVar, Generic
# Bit field DSL
class BitField:
def __init__(self, hi: int, lo: int, name: str | None = None): self.hi, self.lo, self.name = hi, lo, name
def __set_name__(self, owner, name): self.name, self._owner = name, owner
def __init__(self, hi: int, lo: int, name: str | None = None): self.hi, self.lo, self.name, self._marker = hi, lo, name, None
def __set_name__(self, owner, name):
import typing
self.name, self._owner = name, owner
# Cache marker at class definition time
hints = typing.get_type_hints(owner, include_extras=True)
if name in hints:
hint = hints[name]
if typing.get_origin(hint) is Annotated:
args = typing.get_args(hint)
self._marker = args[1] if len(args) > 1 else None
def __eq__(self, val: int) -> tuple[BitField, int]: return (self, val) # type: ignore
def mask(self) -> int: return (1 << (self.hi - self.lo + 1)) - 1
@property
def marker(self) -> type | None:
# Get marker from Annotated type hint if present
import typing
if hasattr(self, '_owner') and self.name:
hints = typing.get_type_hints(self._owner, include_extras=True)
if self.name in hints:
hint = hints[self.name]
if typing.get_origin(hint) is Annotated:
args = typing.get_args(hint)
return args[1] if len(args) > 1 else None
return None
def marker(self) -> type | None: return self._marker
@overload
def __get__(self, obj: None, objtype: type) -> BitField: ...
@overload
@@ -179,6 +178,21 @@ class Inst:
raise ValueError(f"SOP1 {op_val.name} expects {expected} destination register(s), got {sdst_val.count}")
if isinstance(ssrc0_val, Reg) and ssrc0_val.count != expected:
raise ValueError(f"SOP1 {op_val.name} expects {expected} source register(s), got {ssrc0_val.count}")
# FLAT: set sve=1 when addr is a VGPR for scratch only
# For scratch (seg=1), sve=1 means addr VGPR is used; sve=0 means addr is "off"
# For global (seg=2) and flat (seg=0), sve is always 0
if self.__class__.__name__ == 'FLAT' and 'sve' in self._fields:
seg_val = self._values.get('seg', 0)
if isinstance(seg_val, RawImm): seg_val = seg_val.val
addr_val = orig_args.get('addr')
if seg_val == 1 and isinstance(addr_val, VGPR): self._values['sve'] = 1
# VOP3P: v_fma_mix* instructions (opcodes 32-34) have opsel_hi default of 0, not 7
if self.__class__.__name__ == 'VOP3P':
op_val = orig_args.get(field_names[0]) if args else orig_args.get('op')
if hasattr(op_val, 'value'): op_val = op_val.value
if op_val in (32, 33, 34) and 'opsel_hi' not in orig_args and 'opsel_hi2' not in orig_args:
self._values['opsel_hi'] = 0
self._values['opsel_hi2'] = 0
# Type check and encode values
for name, val in list(self._values.items()):
if name == 'encoding': continue
@@ -340,6 +354,14 @@ class Inst:
lit = f", literal={hex(self._literal)}" if self._literal is not None else ""
return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in items)}{lit})"
def __getattr__(self, name: str):
if name.startswith('_'): raise AttributeError(name)
return unwrap(self._values.get(name, 0))
def lit(self, v: int) -> str:
from extra.assembly.amd.asm import decode_src
return f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v)
def __eq__(self, other):
if not isinstance(other, Inst): return NotImplemented
return self.__class__ == other.__class__ and self._values == other._values and self._literal == other._literal

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import ctypes, os
from extra.assembly.amd.dsl import Inst, RawImm
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.pcode import _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
from extra.assembly.amd.autogen.rdna3 import (
@@ -146,21 +147,7 @@ class WaveState:
for reg, val in self._pend_sgpr.items(): self.sgpr[reg] = val
self._pend_sgpr.clear()
# Instruction decode
def decode_format(word: int) -> tuple[type[Inst] | None, bool]:
hi2 = (word >> 30) & 0x3
if hi2 == 0b11:
enc = (word >> 26) & 0xf
if enc == 0b1101: return SMEM, True
if enc == 0b0101:
op = (word >> 16) & 0x3ff
return (VOP3SD, True) if op in (288, 289, 290, 764, 765, 766, 767, 768, 769, 770) else (VOP3, True)
return {0b0011: (VOP3P, True), 0b0110: (DS, True), 0b0111: (FLAT, True), 0b0010: (VOPD, True)}.get(enc, (None, True))
if hi2 == 0b10:
enc = (word >> 23) & 0x7f
return {0b1111101: (SOP1, False), 0b1111110: (SOPC, False), 0b1111111: (SOPP, False)}.get(enc, (SOPK, False) if ((word >> 28) & 0xf) == 0b1011 else (SOP2, False))
enc = (word >> 25) & 0x7f
return (VOPC, False) if enc == 0b0111110 else (VOP1, False) if enc == 0b0111111 else (VOP2, False)
def _unwrap(v) -> int: return v.val if isinstance(v, RawImm) else v.value if hasattr(v, 'value') else v
@@ -168,10 +155,10 @@ def decode_program(data: bytes) -> Program:
result: Program = {}
i = 0
while i < len(data):
word = int.from_bytes(data[i:i+4], 'little')
inst_class, is_64 = decode_format(word)
try: inst_class = detect_format(data[i:])
except ValueError: break # stop at invalid instruction (padding/metadata after code)
if inst_class is None: i += 4; continue
base_size = 8 if is_64 else 4
base_size = inst_class._size()
# Pass enough data for potential 64-bit literal (base + 8 bytes max)
inst = inst_class.from_bytes(data[i:i+base_size+8])
for name, val in inst._values.items(): setattr(inst, name, _unwrap(val))

View File

@@ -65,12 +65,18 @@ def parse_llvm_tests(text: str) -> list[tuple[str, bytes]]:
if not asm_text: continue
for j in range(i, min(i + 3, len(lines))):
# Match GFX11, W32, or W64 encodings (all valid for gfx11)
# Format 1: "// GFX11: v_foo ... ; encoding: [0x01,0x02,...]"
# Format 2: "// GFX11: [0x01,0x02,...]" (used by DS, older files)
if m := re.search(r'(?:GFX11|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
if hex_bytes:
try: tests.append((asm_text, bytes.fromhex(hex_bytes)))
except ValueError: pass
break
elif m := re.search(r'(?:GFX11|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
else:
continue
if hex_bytes:
try: tests.append((asm_text, bytes.fromhex(hex_bytes)))
except ValueError: pass
break
return tests
def try_assemble(text: str):

View File

@@ -4,51 +4,9 @@ import unittest, io, sys, re, subprocess, os
from extra.assembly.amd.autogen.rdna3 import *
from extra.assembly.amd.dsl import Inst
from extra.assembly.amd.asm import asm
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump
# Instruction format detection based on encoding bits
def detect_format(data: bytes) -> type[Inst] | None:
"""Detect instruction format from machine code bytes."""
if len(data) < 4: return None
word = int.from_bytes(data[:4], 'little')
enc_9bit = (word >> 23) & 0x1FF # 9-bit encoding for SOP1/SOPC/SOPP
enc_8bit = (word >> 24) & 0xFF
# Check 9-bit encodings first (most specific)
if enc_9bit == 0x17D: return SOP1 # bits 31:23 = 101111101
if enc_9bit == 0x17E: return SOPC # bits 31:23 = 101111110
if enc_9bit == 0x17F: return SOPP # bits 31:23 = 101111111
# SOPK: bits 31:28 = 1011, bits 27:23 = opcode (check after SOP1/SOPC/SOPP)
if enc_8bit in range(0xB0, 0xC0): return SOPK
# SOP2: bits 31:23 in range 0x100-0x17C (0x80-0xBE in bits 31:24, but not SOPK)
if 0x80 <= enc_8bit <= 0x9F: return SOP2
# VOP1: bits 31:25 = 0111111 (0x3F)
if (word >> 25) == 0x3F: return VOP1
# VOPC: bits 31:25 = 0111110 (0x3E)
if (word >> 25) == 0x3E: return VOPC
# VOP2: bits 31:30 = 00
if (word >> 30) == 0: return VOP2
# Check 64-bit formats
if len(data) >= 8:
if enc_8bit in (0xD4, 0xD5, 0xD7):
# VOP3 and VOP3SD share encoding - check opcode to determine which
# VOP3SD opcodes: 288-290 (v_*_co_ci_*), 764-770 (v_div_scale_*, v_mad_*, v_*_co_u32)
op = (int.from_bytes(data[:8], 'little') >> 16) & 0x3FF
if op in {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}: return VOP3SD
return VOP3
if enc_8bit == 0xD6: return VOP3SD
if enc_8bit == 0xCC: return VOP3P
if enc_8bit == 0xCD: return VINTERP
if enc_8bit in (0xC8, 0xC9): return VOPD
if enc_8bit == 0xF4: return SMEM
if enc_8bit == 0xD8: return DS
if enc_8bit in (0xDC, 0xDD, 0xDE, 0xDF): return FLAT
if enc_8bit in (0xE0, 0xE1, 0xE2, 0xE3): return MUBUF
if enc_8bit in (0xE8, 0xE9, 0xEA, 0xEB): return MTBUF
return None
def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
"""Disassemble ELF binary and return list of (instruction_text, machine_code_bytes)."""
old_stdout = sys.stdout