mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
Merge origin/master into gen_pdf_fast
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -2209,33 +2209,33 @@ buffer_atomic_xor_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_XOR_X2)
|
||||
buffer_atomic_inc_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_INC_X2)
|
||||
buffer_atomic_dec_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_DEC_X2)
|
||||
cdna4 = functools.partial(MUBUF, MUBUFOp.CDNA4)
|
||||
scratch_load_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE, seg=2)
|
||||
scratch_load_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE, seg=2)
|
||||
scratch_load_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_USHORT, seg=2)
|
||||
scratch_load_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SSHORT, seg=2)
|
||||
scratch_load_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORD, seg=2)
|
||||
scratch_load_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX2, seg=2)
|
||||
scratch_load_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX3, seg=2)
|
||||
scratch_load_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX4, seg=2)
|
||||
scratch_store_byte = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE, seg=2)
|
||||
scratch_store_byte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE_D16_HI, seg=2)
|
||||
scratch_store_short = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT, seg=2)
|
||||
scratch_store_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT_D16_HI, seg=2)
|
||||
scratch_store_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORD, seg=2)
|
||||
scratch_store_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX2, seg=2)
|
||||
scratch_store_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX3, seg=2)
|
||||
scratch_store_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX4, seg=2)
|
||||
scratch_load_ubyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16, seg=2)
|
||||
scratch_load_ubyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16_HI, seg=2)
|
||||
scratch_load_sbyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16, seg=2)
|
||||
scratch_load_sbyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16_HI, seg=2)
|
||||
scratch_load_short_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16, seg=2)
|
||||
scratch_load_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16_HI, seg=2)
|
||||
scratch_load_lds_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_UBYTE, seg=2)
|
||||
scratch_load_lds_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SBYTE, seg=2)
|
||||
scratch_load_lds_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_USHORT, seg=2)
|
||||
scratch_load_lds_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SSHORT, seg=2)
|
||||
scratch_load_lds_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_DWORD, seg=2)
|
||||
scratch_load_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE, seg=1)
|
||||
scratch_load_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE, seg=1)
|
||||
scratch_load_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_USHORT, seg=1)
|
||||
scratch_load_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SSHORT, seg=1)
|
||||
scratch_load_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORD, seg=1)
|
||||
scratch_load_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX2, seg=1)
|
||||
scratch_load_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX3, seg=1)
|
||||
scratch_load_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX4, seg=1)
|
||||
scratch_store_byte = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE, seg=1)
|
||||
scratch_store_byte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE_D16_HI, seg=1)
|
||||
scratch_store_short = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT, seg=1)
|
||||
scratch_store_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT_D16_HI, seg=1)
|
||||
scratch_store_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORD, seg=1)
|
||||
scratch_store_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX2, seg=1)
|
||||
scratch_store_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX3, seg=1)
|
||||
scratch_store_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX4, seg=1)
|
||||
scratch_load_ubyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16, seg=1)
|
||||
scratch_load_ubyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16_HI, seg=1)
|
||||
scratch_load_sbyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16, seg=1)
|
||||
scratch_load_sbyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16_HI, seg=1)
|
||||
scratch_load_short_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16, seg=1)
|
||||
scratch_load_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16_HI, seg=1)
|
||||
scratch_load_lds_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_UBYTE, seg=1)
|
||||
scratch_load_lds_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SBYTE, seg=1)
|
||||
scratch_load_lds_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_USHORT, seg=1)
|
||||
scratch_load_lds_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SSHORT, seg=1)
|
||||
scratch_load_lds_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_DWORD, seg=1)
|
||||
s_load_dword = functools.partial(SMEM, SMEMOp.S_LOAD_DWORD)
|
||||
s_load_dwordx2 = functools.partial(SMEM, SMEMOp.S_LOAD_DWORDX2)
|
||||
s_load_dwordx4 = functools.partial(SMEM, SMEMOp.S_LOAD_DWORDX4)
|
||||
|
||||
@@ -56,6 +56,12 @@ class DSOp(IntEnum):
|
||||
DS_MAX_F32 = 19
|
||||
DS_NOP = 20
|
||||
DS_ADD_F32 = 21
|
||||
DS_GWS_SEMA_RELEASE_ALL = 24
|
||||
DS_GWS_INIT = 25
|
||||
DS_GWS_SEMA_V = 26
|
||||
DS_GWS_SEMA_BR = 27
|
||||
DS_GWS_SEMA_P = 28
|
||||
DS_GWS_BARRIER = 29
|
||||
DS_STORE_B8 = 30
|
||||
DS_STORE_B16 = 31
|
||||
DS_ADD_RTN_U32 = 32
|
||||
@@ -178,10 +184,13 @@ class FLATOp(IntEnum):
|
||||
FLAT_LOAD_D16_HI_B16 = 35
|
||||
FLAT_STORE_D16_HI_B8 = 36
|
||||
FLAT_STORE_D16_HI_B16 = 37
|
||||
GLOBAL_LOAD_ADDTID_B32 = 40
|
||||
GLOBAL_STORE_ADDTID_B32 = 41
|
||||
FLAT_ATOMIC_SWAP_B32 = 51
|
||||
FLAT_ATOMIC_CMPSWAP_B32 = 52
|
||||
FLAT_ATOMIC_ADD_U32 = 53
|
||||
FLAT_ATOMIC_SUB_U32 = 54
|
||||
FLAT_ATOMIC_CSUB_U32 = 55
|
||||
FLAT_ATOMIC_MIN_I32 = 56
|
||||
FLAT_ATOMIC_MIN_U32 = 57
|
||||
FLAT_ATOMIC_MAX_I32 = 58
|
||||
@@ -717,6 +726,7 @@ class SOPPOp(IntEnum):
|
||||
S_SET_INST_PREFETCH_DISTANCE = 4
|
||||
S_CLAUSE = 5
|
||||
S_DELAY_ALU = 7
|
||||
S_WAITCNT_DEPCTR = 8
|
||||
S_WAITCNT = 9
|
||||
S_WAIT_IDLE = 10
|
||||
S_WAIT_EVENT = 11
|
||||
@@ -1848,6 +1858,12 @@ ds_min_f32 = functools.partial(DS, DSOp.DS_MIN_F32)
|
||||
ds_max_f32 = functools.partial(DS, DSOp.DS_MAX_F32)
|
||||
ds_nop = functools.partial(DS, DSOp.DS_NOP)
|
||||
ds_add_f32 = functools.partial(DS, DSOp.DS_ADD_F32)
|
||||
ds_gws_sema_release_all = functools.partial(DS, DSOp.DS_GWS_SEMA_RELEASE_ALL)
|
||||
ds_gws_init = functools.partial(DS, DSOp.DS_GWS_INIT)
|
||||
ds_gws_sema_v = functools.partial(DS, DSOp.DS_GWS_SEMA_V)
|
||||
ds_gws_sema_br = functools.partial(DS, DSOp.DS_GWS_SEMA_BR)
|
||||
ds_gws_sema_p = functools.partial(DS, DSOp.DS_GWS_SEMA_P)
|
||||
ds_gws_barrier = functools.partial(DS, DSOp.DS_GWS_BARRIER)
|
||||
ds_store_b8 = functools.partial(DS, DSOp.DS_STORE_B8)
|
||||
ds_store_b16 = functools.partial(DS, DSOp.DS_STORE_B16)
|
||||
ds_add_rtn_u32 = functools.partial(DS, DSOp.DS_ADD_RTN_U32)
|
||||
@@ -1968,10 +1984,13 @@ flat_load_d16_hi_i8 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_I8)
|
||||
flat_load_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_B16)
|
||||
flat_store_d16_hi_b8 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B8)
|
||||
flat_store_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B16)
|
||||
global_load_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_LOAD_ADDTID_B32)
|
||||
global_store_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_STORE_ADDTID_B32)
|
||||
flat_atomic_swap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_SWAP_B32)
|
||||
flat_atomic_cmpswap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_CMPSWAP_B32)
|
||||
flat_atomic_add_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_ADD_U32)
|
||||
flat_atomic_sub_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_SUB_U32)
|
||||
flat_atomic_csub_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_CSUB_U32)
|
||||
flat_atomic_min_i32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MIN_I32)
|
||||
flat_atomic_min_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MIN_U32)
|
||||
flat_atomic_max_i32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MAX_I32)
|
||||
@@ -2226,28 +2245,28 @@ buffer_atomic_cmpswap_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_CMPSW
|
||||
buffer_atomic_min_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_MIN_F32)
|
||||
buffer_atomic_max_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_MAX_F32)
|
||||
buffer_atomic_add_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_ADD_F32)
|
||||
scratch_load_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U8, seg=2)
|
||||
scratch_load_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I8, seg=2)
|
||||
scratch_load_u16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U16, seg=2)
|
||||
scratch_load_i16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I16, seg=2)
|
||||
scratch_load_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B32, seg=2)
|
||||
scratch_load_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B64, seg=2)
|
||||
scratch_load_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B96, seg=2)
|
||||
scratch_load_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B128, seg=2)
|
||||
scratch_store_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B8, seg=2)
|
||||
scratch_store_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B16, seg=2)
|
||||
scratch_store_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B32, seg=2)
|
||||
scratch_store_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B64, seg=2)
|
||||
scratch_store_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B96, seg=2)
|
||||
scratch_store_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B128, seg=2)
|
||||
scratch_load_d16_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_U8, seg=2)
|
||||
scratch_load_d16_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_I8, seg=2)
|
||||
scratch_load_d16_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_B16, seg=2)
|
||||
scratch_load_d16_hi_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_U8, seg=2)
|
||||
scratch_load_d16_hi_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_I8, seg=2)
|
||||
scratch_load_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_B16, seg=2)
|
||||
scratch_store_d16_hi_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B8, seg=2)
|
||||
scratch_store_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B16, seg=2)
|
||||
scratch_load_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U8, seg=1)
|
||||
scratch_load_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I8, seg=1)
|
||||
scratch_load_u16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U16, seg=1)
|
||||
scratch_load_i16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I16, seg=1)
|
||||
scratch_load_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B32, seg=1)
|
||||
scratch_load_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B64, seg=1)
|
||||
scratch_load_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B96, seg=1)
|
||||
scratch_load_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B128, seg=1)
|
||||
scratch_store_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B8, seg=1)
|
||||
scratch_store_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B16, seg=1)
|
||||
scratch_store_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B32, seg=1)
|
||||
scratch_store_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B64, seg=1)
|
||||
scratch_store_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B96, seg=1)
|
||||
scratch_store_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B128, seg=1)
|
||||
scratch_load_d16_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_U8, seg=1)
|
||||
scratch_load_d16_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_I8, seg=1)
|
||||
scratch_load_d16_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_B16, seg=1)
|
||||
scratch_load_d16_hi_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_U8, seg=1)
|
||||
scratch_load_d16_hi_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_I8, seg=1)
|
||||
scratch_load_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_B16, seg=1)
|
||||
scratch_store_d16_hi_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B8, seg=1)
|
||||
scratch_store_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B16, seg=1)
|
||||
s_load_b32 = functools.partial(SMEM, SMEMOp.S_LOAD_B32)
|
||||
s_load_b64 = functools.partial(SMEM, SMEMOp.S_LOAD_B64)
|
||||
s_load_b128 = functools.partial(SMEM, SMEMOp.S_LOAD_B128)
|
||||
@@ -2485,6 +2504,7 @@ s_sleep = functools.partial(SOPP, SOPPOp.S_SLEEP)
|
||||
s_set_inst_prefetch_distance = functools.partial(SOPP, SOPPOp.S_SET_INST_PREFETCH_DISTANCE)
|
||||
s_clause = functools.partial(SOPP, SOPPOp.S_CLAUSE)
|
||||
s_delay_alu = functools.partial(SOPP, SOPPOp.S_DELAY_ALU)
|
||||
s_waitcnt_depctr = functools.partial(SOPP, SOPPOp.S_WAITCNT_DEPCTR)
|
||||
s_waitcnt = functools.partial(SOPP, SOPPOp.S_WAITCNT)
|
||||
s_wait_idle = functools.partial(SOPP, SOPPOp.S_WAIT_IDLE)
|
||||
s_wait_event = functools.partial(SOPP, SOPPOp.S_WAIT_EVENT)
|
||||
|
||||
@@ -6,22 +6,21 @@ from typing import overload, Annotated, TypeVar, Generic
|
||||
|
||||
# Bit field DSL
|
||||
class BitField:
|
||||
def __init__(self, hi: int, lo: int, name: str | None = None): self.hi, self.lo, self.name = hi, lo, name
|
||||
def __set_name__(self, owner, name): self.name, self._owner = name, owner
|
||||
def __init__(self, hi: int, lo: int, name: str | None = None): self.hi, self.lo, self.name, self._marker = hi, lo, name, None
|
||||
def __set_name__(self, owner, name):
|
||||
import typing
|
||||
self.name, self._owner = name, owner
|
||||
# Cache marker at class definition time
|
||||
hints = typing.get_type_hints(owner, include_extras=True)
|
||||
if name in hints:
|
||||
hint = hints[name]
|
||||
if typing.get_origin(hint) is Annotated:
|
||||
args = typing.get_args(hint)
|
||||
self._marker = args[1] if len(args) > 1 else None
|
||||
def __eq__(self, val: int) -> tuple[BitField, int]: return (self, val) # type: ignore
|
||||
def mask(self) -> int: return (1 << (self.hi - self.lo + 1)) - 1
|
||||
@property
|
||||
def marker(self) -> type | None:
|
||||
# Get marker from Annotated type hint if present
|
||||
import typing
|
||||
if hasattr(self, '_owner') and self.name:
|
||||
hints = typing.get_type_hints(self._owner, include_extras=True)
|
||||
if self.name in hints:
|
||||
hint = hints[self.name]
|
||||
if typing.get_origin(hint) is Annotated:
|
||||
args = typing.get_args(hint)
|
||||
return args[1] if len(args) > 1 else None
|
||||
return None
|
||||
def marker(self) -> type | None: return self._marker
|
||||
@overload
|
||||
def __get__(self, obj: None, objtype: type) -> BitField: ...
|
||||
@overload
|
||||
@@ -179,6 +178,21 @@ class Inst:
|
||||
raise ValueError(f"SOP1 {op_val.name} expects {expected} destination register(s), got {sdst_val.count}")
|
||||
if isinstance(ssrc0_val, Reg) and ssrc0_val.count != expected:
|
||||
raise ValueError(f"SOP1 {op_val.name} expects {expected} source register(s), got {ssrc0_val.count}")
|
||||
# FLAT: set sve=1 when addr is a VGPR for scratch only
|
||||
# For scratch (seg=1), sve=1 means addr VGPR is used; sve=0 means addr is "off"
|
||||
# For global (seg=2) and flat (seg=0), sve is always 0
|
||||
if self.__class__.__name__ == 'FLAT' and 'sve' in self._fields:
|
||||
seg_val = self._values.get('seg', 0)
|
||||
if isinstance(seg_val, RawImm): seg_val = seg_val.val
|
||||
addr_val = orig_args.get('addr')
|
||||
if seg_val == 1 and isinstance(addr_val, VGPR): self._values['sve'] = 1
|
||||
# VOP3P: v_fma_mix* instructions (opcodes 32-34) have opsel_hi default of 0, not 7
|
||||
if self.__class__.__name__ == 'VOP3P':
|
||||
op_val = orig_args.get(field_names[0]) if args else orig_args.get('op')
|
||||
if hasattr(op_val, 'value'): op_val = op_val.value
|
||||
if op_val in (32, 33, 34) and 'opsel_hi' not in orig_args and 'opsel_hi2' not in orig_args:
|
||||
self._values['opsel_hi'] = 0
|
||||
self._values['opsel_hi2'] = 0
|
||||
# Type check and encode values
|
||||
for name, val in list(self._values.items()):
|
||||
if name == 'encoding': continue
|
||||
@@ -340,6 +354,14 @@ class Inst:
|
||||
lit = f", literal={hex(self._literal)}" if self._literal is not None else ""
|
||||
return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in items)}{lit})"
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if name.startswith('_'): raise AttributeError(name)
|
||||
return unwrap(self._values.get(name, 0))
|
||||
|
||||
def lit(self, v: int) -> str:
|
||||
from extra.assembly.amd.asm import decode_src
|
||||
return f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Inst): return NotImplemented
|
||||
return self.__class__ == other.__class__ and self._values == other._values and self._literal == other._literal
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
import ctypes, os
|
||||
from extra.assembly.amd.dsl import Inst, RawImm
|
||||
from extra.assembly.amd.asm import detect_format
|
||||
from extra.assembly.amd.pcode import _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
|
||||
from extra.assembly.amd.autogen.rdna3 import (
|
||||
@@ -146,21 +147,7 @@ class WaveState:
|
||||
for reg, val in self._pend_sgpr.items(): self.sgpr[reg] = val
|
||||
self._pend_sgpr.clear()
|
||||
|
||||
# Instruction decode
|
||||
def decode_format(word: int) -> tuple[type[Inst] | None, bool]:
|
||||
hi2 = (word >> 30) & 0x3
|
||||
if hi2 == 0b11:
|
||||
enc = (word >> 26) & 0xf
|
||||
if enc == 0b1101: return SMEM, True
|
||||
if enc == 0b0101:
|
||||
op = (word >> 16) & 0x3ff
|
||||
return (VOP3SD, True) if op in (288, 289, 290, 764, 765, 766, 767, 768, 769, 770) else (VOP3, True)
|
||||
return {0b0011: (VOP3P, True), 0b0110: (DS, True), 0b0111: (FLAT, True), 0b0010: (VOPD, True)}.get(enc, (None, True))
|
||||
if hi2 == 0b10:
|
||||
enc = (word >> 23) & 0x7f
|
||||
return {0b1111101: (SOP1, False), 0b1111110: (SOPC, False), 0b1111111: (SOPP, False)}.get(enc, (SOPK, False) if ((word >> 28) & 0xf) == 0b1011 else (SOP2, False))
|
||||
enc = (word >> 25) & 0x7f
|
||||
return (VOPC, False) if enc == 0b0111110 else (VOP1, False) if enc == 0b0111111 else (VOP2, False)
|
||||
|
||||
|
||||
def _unwrap(v) -> int: return v.val if isinstance(v, RawImm) else v.value if hasattr(v, 'value') else v
|
||||
|
||||
@@ -168,10 +155,10 @@ def decode_program(data: bytes) -> Program:
|
||||
result: Program = {}
|
||||
i = 0
|
||||
while i < len(data):
|
||||
word = int.from_bytes(data[i:i+4], 'little')
|
||||
inst_class, is_64 = decode_format(word)
|
||||
try: inst_class = detect_format(data[i:])
|
||||
except ValueError: break # stop at invalid instruction (padding/metadata after code)
|
||||
if inst_class is None: i += 4; continue
|
||||
base_size = 8 if is_64 else 4
|
||||
base_size = inst_class._size()
|
||||
# Pass enough data for potential 64-bit literal (base + 8 bytes max)
|
||||
inst = inst_class.from_bytes(data[i:i+base_size+8])
|
||||
for name, val in inst._values.items(): setattr(inst, name, _unwrap(val))
|
||||
|
||||
@@ -65,12 +65,18 @@ def parse_llvm_tests(text: str) -> list[tuple[str, bytes]]:
|
||||
if not asm_text: continue
|
||||
for j in range(i, min(i + 3, len(lines))):
|
||||
# Match GFX11, W32, or W64 encodings (all valid for gfx11)
|
||||
# Format 1: "// GFX11: v_foo ... ; encoding: [0x01,0x02,...]"
|
||||
# Format 2: "// GFX11: [0x01,0x02,...]" (used by DS, older files)
|
||||
if m := re.search(r'(?:GFX11|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
if hex_bytes:
|
||||
try: tests.append((asm_text, bytes.fromhex(hex_bytes)))
|
||||
except ValueError: pass
|
||||
break
|
||||
elif m := re.search(r'(?:GFX11|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
else:
|
||||
continue
|
||||
if hex_bytes:
|
||||
try: tests.append((asm_text, bytes.fromhex(hex_bytes)))
|
||||
except ValueError: pass
|
||||
break
|
||||
return tests
|
||||
|
||||
def try_assemble(text: str):
|
||||
|
||||
@@ -4,51 +4,9 @@ import unittest, io, sys, re, subprocess, os
|
||||
from extra.assembly.amd.autogen.rdna3 import *
|
||||
from extra.assembly.amd.dsl import Inst
|
||||
from extra.assembly.amd.asm import asm
|
||||
from extra.assembly.amd.asm import detect_format
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump
|
||||
|
||||
# Instruction format detection based on encoding bits
|
||||
def detect_format(data: bytes) -> type[Inst] | None:
|
||||
"""Detect instruction format from machine code bytes."""
|
||||
if len(data) < 4: return None
|
||||
word = int.from_bytes(data[:4], 'little')
|
||||
enc_9bit = (word >> 23) & 0x1FF # 9-bit encoding for SOP1/SOPC/SOPP
|
||||
enc_8bit = (word >> 24) & 0xFF
|
||||
|
||||
# Check 9-bit encodings first (most specific)
|
||||
if enc_9bit == 0x17D: return SOP1 # bits 31:23 = 101111101
|
||||
if enc_9bit == 0x17E: return SOPC # bits 31:23 = 101111110
|
||||
if enc_9bit == 0x17F: return SOPP # bits 31:23 = 101111111
|
||||
# SOPK: bits 31:28 = 1011, bits 27:23 = opcode (check after SOP1/SOPC/SOPP)
|
||||
if enc_8bit in range(0xB0, 0xC0): return SOPK
|
||||
# SOP2: bits 31:23 in range 0x100-0x17C (0x80-0xBE in bits 31:24, but not SOPK)
|
||||
if 0x80 <= enc_8bit <= 0x9F: return SOP2
|
||||
# VOP1: bits 31:25 = 0111111 (0x3F)
|
||||
if (word >> 25) == 0x3F: return VOP1
|
||||
# VOPC: bits 31:25 = 0111110 (0x3E)
|
||||
if (word >> 25) == 0x3E: return VOPC
|
||||
# VOP2: bits 31:30 = 00
|
||||
if (word >> 30) == 0: return VOP2
|
||||
|
||||
# Check 64-bit formats
|
||||
if len(data) >= 8:
|
||||
if enc_8bit in (0xD4, 0xD5, 0xD7):
|
||||
# VOP3 and VOP3SD share encoding - check opcode to determine which
|
||||
# VOP3SD opcodes: 288-290 (v_*_co_ci_*), 764-770 (v_div_scale_*, v_mad_*, v_*_co_u32)
|
||||
op = (int.from_bytes(data[:8], 'little') >> 16) & 0x3FF
|
||||
if op in {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}: return VOP3SD
|
||||
return VOP3
|
||||
if enc_8bit == 0xD6: return VOP3SD
|
||||
if enc_8bit == 0xCC: return VOP3P
|
||||
if enc_8bit == 0xCD: return VINTERP
|
||||
if enc_8bit in (0xC8, 0xC9): return VOPD
|
||||
if enc_8bit == 0xF4: return SMEM
|
||||
if enc_8bit == 0xD8: return DS
|
||||
if enc_8bit in (0xDC, 0xDD, 0xDE, 0xDF): return FLAT
|
||||
if enc_8bit in (0xE0, 0xE1, 0xE2, 0xE3): return MUBUF
|
||||
if enc_8bit in (0xE8, 0xE9, 0xEA, 0xEB): return MTBUF
|
||||
|
||||
return None
|
||||
|
||||
def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
|
||||
"""Disassemble ELF binary and return list of (instruction_text, machine_code_bytes)."""
|
||||
old_stdout = sys.stdout
|
||||
|
||||
Reference in New Issue
Block a user