Merge origin/master into only_reg_emu2 (keep branch's Reg-based approach)

This commit is contained in:
George Hotz
2025-12-30 18:53:50 +00:00
15 changed files with 6397 additions and 2892 deletions

View File

@@ -654,7 +654,7 @@ jobs:
- name: Run process replay tests - name: Run process replay tests
uses: ./.github/actions/process-replay uses: ./.github/actions/process-replay
testrdna3: testamdasm:
name: AMD ASM IDE name: AMD ASM IDE
runs-on: ubuntu-24.04 runs-on: ubuntu-24.04
timeout-minutes: 10 timeout-minutes: 10
@@ -677,8 +677,25 @@ jobs:
run: cloc --by-file extra/assembly/amd/*.py run: cloc --by-file extra/assembly/amd/*.py
- name: Run RDNA3 emulator tests - name: Run RDNA3 emulator tests
run: python -m pytest -n=auto extra/assembly/amd/ --durations 20 run: python -m pytest -n=auto extra/assembly/amd/ --durations 20
- name: Install pdfplumber - name: Run RDNA3 emulator tests (AMD_LLVM=1)
run: pip install pdfplumber run: AMD_LLVM=1 python -m pytest -n=auto extra/assembly/amd/ --durations 20
- name: Run RDNA3 dtype tests
run: PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=auto test/test_dtype_alu.py test/test_dtype.py
- name: Run RDNA3 dtype tests (AMD_LLVM=1)
run: PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=auto test/test_dtype_alu.py test/test_dtype.py
testamdautogen:
name: AMD autogen
runs-on: ubuntu-24.04
timeout-minutes: 10
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: rdna3-autogen
pydeps: "pdfplumber"
- name: Verify AMD autogen is up to date - name: Verify AMD autogen is up to date
run: | run: |
python -m extra.assembly.amd.dsl --arch all python -m extra.assembly.amd.dsl --arch all

File diff suppressed because it is too large Load Diff

View File

@@ -2209,33 +2209,33 @@ buffer_atomic_xor_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_XOR_X2)
buffer_atomic_inc_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_INC_X2) buffer_atomic_inc_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_INC_X2)
buffer_atomic_dec_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_DEC_X2) buffer_atomic_dec_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_DEC_X2)
cdna4 = functools.partial(MUBUF, MUBUFOp.CDNA4) cdna4 = functools.partial(MUBUF, MUBUFOp.CDNA4)
scratch_load_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE, seg=2) scratch_load_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE, seg=1)
scratch_load_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE, seg=2) scratch_load_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE, seg=1)
scratch_load_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_USHORT, seg=2) scratch_load_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_USHORT, seg=1)
scratch_load_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SSHORT, seg=2) scratch_load_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SSHORT, seg=1)
scratch_load_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORD, seg=2) scratch_load_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORD, seg=1)
scratch_load_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX2, seg=2) scratch_load_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX2, seg=1)
scratch_load_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX3, seg=2) scratch_load_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX3, seg=1)
scratch_load_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX4, seg=2) scratch_load_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX4, seg=1)
scratch_store_byte = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE, seg=2) scratch_store_byte = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE, seg=1)
scratch_store_byte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE_D16_HI, seg=2) scratch_store_byte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE_D16_HI, seg=1)
scratch_store_short = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT, seg=2) scratch_store_short = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT, seg=1)
scratch_store_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT_D16_HI, seg=2) scratch_store_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT_D16_HI, seg=1)
scratch_store_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORD, seg=2) scratch_store_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORD, seg=1)
scratch_store_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX2, seg=2) scratch_store_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX2, seg=1)
scratch_store_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX3, seg=2) scratch_store_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX3, seg=1)
scratch_store_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX4, seg=2) scratch_store_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX4, seg=1)
scratch_load_ubyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16, seg=2) scratch_load_ubyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16, seg=1)
scratch_load_ubyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16_HI, seg=2) scratch_load_ubyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16_HI, seg=1)
scratch_load_sbyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16, seg=2) scratch_load_sbyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16, seg=1)
scratch_load_sbyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16_HI, seg=2) scratch_load_sbyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16_HI, seg=1)
scratch_load_short_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16, seg=2) scratch_load_short_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16, seg=1)
scratch_load_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16_HI, seg=2) scratch_load_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16_HI, seg=1)
scratch_load_lds_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_UBYTE, seg=2) scratch_load_lds_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_UBYTE, seg=1)
scratch_load_lds_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SBYTE, seg=2) scratch_load_lds_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SBYTE, seg=1)
scratch_load_lds_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_USHORT, seg=2) scratch_load_lds_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_USHORT, seg=1)
scratch_load_lds_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SSHORT, seg=2) scratch_load_lds_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SSHORT, seg=1)
scratch_load_lds_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_DWORD, seg=2) scratch_load_lds_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_DWORD, seg=1)
s_load_dword = functools.partial(SMEM, SMEMOp.S_LOAD_DWORD) s_load_dword = functools.partial(SMEM, SMEMOp.S_LOAD_DWORD)
s_load_dwordx2 = functools.partial(SMEM, SMEMOp.S_LOAD_DWORDX2) s_load_dwordx2 = functools.partial(SMEM, SMEMOp.S_LOAD_DWORDX2)
s_load_dwordx4 = functools.partial(SMEM, SMEMOp.S_LOAD_DWORDX4) s_load_dwordx4 = functools.partial(SMEM, SMEMOp.S_LOAD_DWORDX4)

File diff suppressed because it is too large Load Diff

View File

@@ -56,6 +56,12 @@ class DSOp(IntEnum):
DS_MAX_F32 = 19 DS_MAX_F32 = 19
DS_NOP = 20 DS_NOP = 20
DS_ADD_F32 = 21 DS_ADD_F32 = 21
DS_GWS_SEMA_RELEASE_ALL = 24
DS_GWS_INIT = 25
DS_GWS_SEMA_V = 26
DS_GWS_SEMA_BR = 27
DS_GWS_SEMA_P = 28
DS_GWS_BARRIER = 29
DS_STORE_B8 = 30 DS_STORE_B8 = 30
DS_STORE_B16 = 31 DS_STORE_B16 = 31
DS_ADD_RTN_U32 = 32 DS_ADD_RTN_U32 = 32
@@ -178,10 +184,13 @@ class FLATOp(IntEnum):
FLAT_LOAD_D16_HI_B16 = 35 FLAT_LOAD_D16_HI_B16 = 35
FLAT_STORE_D16_HI_B8 = 36 FLAT_STORE_D16_HI_B8 = 36
FLAT_STORE_D16_HI_B16 = 37 FLAT_STORE_D16_HI_B16 = 37
GLOBAL_LOAD_ADDTID_B32 = 40
GLOBAL_STORE_ADDTID_B32 = 41
FLAT_ATOMIC_SWAP_B32 = 51 FLAT_ATOMIC_SWAP_B32 = 51
FLAT_ATOMIC_CMPSWAP_B32 = 52 FLAT_ATOMIC_CMPSWAP_B32 = 52
FLAT_ATOMIC_ADD_U32 = 53 FLAT_ATOMIC_ADD_U32 = 53
FLAT_ATOMIC_SUB_U32 = 54 FLAT_ATOMIC_SUB_U32 = 54
FLAT_ATOMIC_CSUB_U32 = 55
FLAT_ATOMIC_MIN_I32 = 56 FLAT_ATOMIC_MIN_I32 = 56
FLAT_ATOMIC_MIN_U32 = 57 FLAT_ATOMIC_MIN_U32 = 57
FLAT_ATOMIC_MAX_I32 = 58 FLAT_ATOMIC_MAX_I32 = 58
@@ -717,6 +726,7 @@ class SOPPOp(IntEnum):
S_SET_INST_PREFETCH_DISTANCE = 4 S_SET_INST_PREFETCH_DISTANCE = 4
S_CLAUSE = 5 S_CLAUSE = 5
S_DELAY_ALU = 7 S_DELAY_ALU = 7
S_WAITCNT_DEPCTR = 8
S_WAITCNT = 9 S_WAITCNT = 9
S_WAIT_IDLE = 10 S_WAIT_IDLE = 10
S_WAIT_EVENT = 11 S_WAIT_EVENT = 11
@@ -1848,6 +1858,12 @@ ds_min_f32 = functools.partial(DS, DSOp.DS_MIN_F32)
ds_max_f32 = functools.partial(DS, DSOp.DS_MAX_F32) ds_max_f32 = functools.partial(DS, DSOp.DS_MAX_F32)
ds_nop = functools.partial(DS, DSOp.DS_NOP) ds_nop = functools.partial(DS, DSOp.DS_NOP)
ds_add_f32 = functools.partial(DS, DSOp.DS_ADD_F32) ds_add_f32 = functools.partial(DS, DSOp.DS_ADD_F32)
ds_gws_sema_release_all = functools.partial(DS, DSOp.DS_GWS_SEMA_RELEASE_ALL)
ds_gws_init = functools.partial(DS, DSOp.DS_GWS_INIT)
ds_gws_sema_v = functools.partial(DS, DSOp.DS_GWS_SEMA_V)
ds_gws_sema_br = functools.partial(DS, DSOp.DS_GWS_SEMA_BR)
ds_gws_sema_p = functools.partial(DS, DSOp.DS_GWS_SEMA_P)
ds_gws_barrier = functools.partial(DS, DSOp.DS_GWS_BARRIER)
ds_store_b8 = functools.partial(DS, DSOp.DS_STORE_B8) ds_store_b8 = functools.partial(DS, DSOp.DS_STORE_B8)
ds_store_b16 = functools.partial(DS, DSOp.DS_STORE_B16) ds_store_b16 = functools.partial(DS, DSOp.DS_STORE_B16)
ds_add_rtn_u32 = functools.partial(DS, DSOp.DS_ADD_RTN_U32) ds_add_rtn_u32 = functools.partial(DS, DSOp.DS_ADD_RTN_U32)
@@ -1968,10 +1984,13 @@ flat_load_d16_hi_i8 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_I8)
flat_load_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_B16) flat_load_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_B16)
flat_store_d16_hi_b8 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B8) flat_store_d16_hi_b8 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B8)
flat_store_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B16) flat_store_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B16)
global_load_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_LOAD_ADDTID_B32)
global_store_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_STORE_ADDTID_B32)
flat_atomic_swap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_SWAP_B32) flat_atomic_swap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_SWAP_B32)
flat_atomic_cmpswap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_CMPSWAP_B32) flat_atomic_cmpswap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_CMPSWAP_B32)
flat_atomic_add_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_ADD_U32) flat_atomic_add_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_ADD_U32)
flat_atomic_sub_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_SUB_U32) flat_atomic_sub_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_SUB_U32)
flat_atomic_csub_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_CSUB_U32)
flat_atomic_min_i32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MIN_I32) flat_atomic_min_i32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MIN_I32)
flat_atomic_min_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MIN_U32) flat_atomic_min_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MIN_U32)
flat_atomic_max_i32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MAX_I32) flat_atomic_max_i32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MAX_I32)
@@ -2226,28 +2245,28 @@ buffer_atomic_cmpswap_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_CMPSW
buffer_atomic_min_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_MIN_F32) buffer_atomic_min_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_MIN_F32)
buffer_atomic_max_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_MAX_F32) buffer_atomic_max_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_MAX_F32)
buffer_atomic_add_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_ADD_F32) buffer_atomic_add_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_ADD_F32)
scratch_load_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U8, seg=2) scratch_load_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U8, seg=1)
scratch_load_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I8, seg=2) scratch_load_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I8, seg=1)
scratch_load_u16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U16, seg=2) scratch_load_u16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U16, seg=1)
scratch_load_i16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I16, seg=2) scratch_load_i16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I16, seg=1)
scratch_load_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B32, seg=2) scratch_load_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B32, seg=1)
scratch_load_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B64, seg=2) scratch_load_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B64, seg=1)
scratch_load_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B96, seg=2) scratch_load_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B96, seg=1)
scratch_load_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B128, seg=2) scratch_load_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B128, seg=1)
scratch_store_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B8, seg=2) scratch_store_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B8, seg=1)
scratch_store_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B16, seg=2) scratch_store_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B16, seg=1)
scratch_store_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B32, seg=2) scratch_store_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B32, seg=1)
scratch_store_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B64, seg=2) scratch_store_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B64, seg=1)
scratch_store_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B96, seg=2) scratch_store_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B96, seg=1)
scratch_store_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B128, seg=2) scratch_store_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B128, seg=1)
scratch_load_d16_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_U8, seg=2) scratch_load_d16_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_U8, seg=1)
scratch_load_d16_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_I8, seg=2) scratch_load_d16_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_I8, seg=1)
scratch_load_d16_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_B16, seg=2) scratch_load_d16_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_B16, seg=1)
scratch_load_d16_hi_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_U8, seg=2) scratch_load_d16_hi_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_U8, seg=1)
scratch_load_d16_hi_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_I8, seg=2) scratch_load_d16_hi_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_I8, seg=1)
scratch_load_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_B16, seg=2) scratch_load_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_B16, seg=1)
scratch_store_d16_hi_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B8, seg=2) scratch_store_d16_hi_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B8, seg=1)
scratch_store_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B16, seg=2) scratch_store_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B16, seg=1)
s_load_b32 = functools.partial(SMEM, SMEMOp.S_LOAD_B32) s_load_b32 = functools.partial(SMEM, SMEMOp.S_LOAD_B32)
s_load_b64 = functools.partial(SMEM, SMEMOp.S_LOAD_B64) s_load_b64 = functools.partial(SMEM, SMEMOp.S_LOAD_B64)
s_load_b128 = functools.partial(SMEM, SMEMOp.S_LOAD_B128) s_load_b128 = functools.partial(SMEM, SMEMOp.S_LOAD_B128)
@@ -2485,6 +2504,7 @@ s_sleep = functools.partial(SOPP, SOPPOp.S_SLEEP)
s_set_inst_prefetch_distance = functools.partial(SOPP, SOPPOp.S_SET_INST_PREFETCH_DISTANCE) s_set_inst_prefetch_distance = functools.partial(SOPP, SOPPOp.S_SET_INST_PREFETCH_DISTANCE)
s_clause = functools.partial(SOPP, SOPPOp.S_CLAUSE) s_clause = functools.partial(SOPP, SOPPOp.S_CLAUSE)
s_delay_alu = functools.partial(SOPP, SOPPOp.S_DELAY_ALU) s_delay_alu = functools.partial(SOPP, SOPPOp.S_DELAY_ALU)
s_waitcnt_depctr = functools.partial(SOPP, SOPPOp.S_WAITCNT_DEPCTR)
s_waitcnt = functools.partial(SOPP, SOPPOp.S_WAITCNT) s_waitcnt = functools.partial(SOPP, SOPPOp.S_WAITCNT)
s_wait_idle = functools.partial(SOPP, SOPPOp.S_WAIT_IDLE) s_wait_idle = functools.partial(SOPP, SOPPOp.S_WAIT_IDLE)
s_wait_event = functools.partial(SOPP, SOPPOp.S_WAIT_EVENT) s_wait_event = functools.partial(SOPP, SOPPOp.S_WAIT_EVENT)

File diff suppressed because it is too large Load Diff

View File

@@ -6,22 +6,21 @@ from typing import overload, Annotated, TypeVar, Generic
# Bit field DSL # Bit field DSL
class BitField: class BitField:
def __init__(self, hi: int, lo: int, name: str | None = None): self.hi, self.lo, self.name = hi, lo, name def __init__(self, hi: int, lo: int, name: str | None = None): self.hi, self.lo, self.name, self._marker = hi, lo, name, None
def __set_name__(self, owner, name): self.name, self._owner = name, owner def __set_name__(self, owner, name):
import typing
self.name, self._owner = name, owner
# Cache marker at class definition time
hints = typing.get_type_hints(owner, include_extras=True)
if name in hints:
hint = hints[name]
if typing.get_origin(hint) is Annotated:
args = typing.get_args(hint)
self._marker = args[1] if len(args) > 1 else None
def __eq__(self, val: int) -> tuple[BitField, int]: return (self, val) # type: ignore def __eq__(self, val: int) -> tuple[BitField, int]: return (self, val) # type: ignore
def mask(self) -> int: return (1 << (self.hi - self.lo + 1)) - 1 def mask(self) -> int: return (1 << (self.hi - self.lo + 1)) - 1
@property @property
def marker(self) -> type | None: def marker(self) -> type | None: return self._marker
# Get marker from Annotated type hint if present
import typing
if hasattr(self, '_owner') and self.name:
hints = typing.get_type_hints(self._owner, include_extras=True)
if self.name in hints:
hint = hints[self.name]
if typing.get_origin(hint) is Annotated:
args = typing.get_args(hint)
return args[1] if len(args) > 1 else None
return None
@overload @overload
def __get__(self, obj: None, objtype: type) -> BitField: ... def __get__(self, obj: None, objtype: type) -> BitField: ...
@overload @overload
@@ -179,6 +178,21 @@ class Inst:
raise ValueError(f"SOP1 {op_val.name} expects {expected} destination register(s), got {sdst_val.count}") raise ValueError(f"SOP1 {op_val.name} expects {expected} destination register(s), got {sdst_val.count}")
if isinstance(ssrc0_val, Reg) and ssrc0_val.count != expected: if isinstance(ssrc0_val, Reg) and ssrc0_val.count != expected:
raise ValueError(f"SOP1 {op_val.name} expects {expected} source register(s), got {ssrc0_val.count}") raise ValueError(f"SOP1 {op_val.name} expects {expected} source register(s), got {ssrc0_val.count}")
# FLAT: set sve=1 when addr is a VGPR for scratch only
# For scratch (seg=1), sve=1 means addr VGPR is used; sve=0 means addr is "off"
# For global (seg=2) and flat (seg=0), sve is always 0
if self.__class__.__name__ == 'FLAT' and 'sve' in self._fields:
seg_val = self._values.get('seg', 0)
if isinstance(seg_val, RawImm): seg_val = seg_val.val
addr_val = orig_args.get('addr')
if seg_val == 1 and isinstance(addr_val, VGPR): self._values['sve'] = 1
# VOP3P: v_fma_mix* instructions (opcodes 32-34) have opsel_hi default of 0, not 7
if self.__class__.__name__ == 'VOP3P':
op_val = orig_args.get(field_names[0]) if args else orig_args.get('op')
if hasattr(op_val, 'value'): op_val = op_val.value
if op_val in (32, 33, 34) and 'opsel_hi' not in orig_args and 'opsel_hi2' not in orig_args:
self._values['opsel_hi'] = 0
self._values['opsel_hi2'] = 0
# Type check and encode values # Type check and encode values
for name, val in list(self._values.items()): for name, val in list(self._values.items()):
if name == 'encoding': continue if name == 'encoding': continue
@@ -283,6 +297,10 @@ class Inst:
from extra.assembly.amd.autogen.rdna3 import VOP3Op from extra.assembly.amd.autogen.rdna3 import VOP3Op
try: op_name = VOP3Op(op).name try: op_name = VOP3Op(op).name
except ValueError: pass except ValueError: pass
if op_name is None and self.__class__.__name__ == 'VOPC':
from extra.assembly.amd.autogen.rdna3 import VOPCOp
try: op_name = VOPCOp(op).name
except ValueError: pass
if op_name is None: return False if op_name is None: return False
# V_LDEXP_F64 has 32-bit integer exponent in src1, so literal is 32-bit # V_LDEXP_F64 has 32-bit integer exponent in src1, so literal is 32-bit
if op_name == 'V_LDEXP_F64': return False if op_name == 'V_LDEXP_F64': return False
@@ -315,6 +333,9 @@ class Inst:
op_val = inst._values.get('op', 0) op_val = inst._values.get('op', 0)
has_literal = cls.__name__ == 'VOP2' and op_val in (44, 45, 55, 56) has_literal = cls.__name__ == 'VOP2' and op_val in (44, 45, 55, 56)
has_literal = has_literal or (cls.__name__ == 'SOP2' and op_val in (69, 70)) has_literal = has_literal or (cls.__name__ == 'SOP2' and op_val in (69, 70))
# VOPD fmaak/fmamk always have a literal (opx/opy value 1 or 2)
opx, opy = inst._values.get('opx', 0), inst._values.get('opy', 0)
has_literal = has_literal or (cls.__name__ == 'VOPD' and (opx in (1, 2) or opy in (1, 2)))
for n in SRC_FIELDS: for n in SRC_FIELDS:
if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255: has_literal = True if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255: has_literal = True
if has_literal: if has_literal:
@@ -333,6 +354,14 @@ class Inst:
lit = f", literal={hex(self._literal)}" if self._literal is not None else "" lit = f", literal={hex(self._literal)}" if self._literal is not None else ""
return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in items)}{lit})" return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in items)}{lit})"
def __getattr__(self, name: str):
if name.startswith('_'): raise AttributeError(name)
return unwrap(self._values.get(name, 0))
def lit(self, v: int) -> str:
from extra.assembly.amd.asm import decode_src
return f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v)
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, Inst): return NotImplemented if not isinstance(other, Inst): return NotImplemented
return self.__class__ == other.__class__ and self._values == other._values and self._literal == other._literal return self.__class__ == other.__class__ and self._values == other._values and self._literal == other._literal
@@ -512,10 +541,24 @@ def _parse_single_pdf(url: str) -> dict:
break break
formats[fmt_name] = fields formats[fmt_name] = fields
# fix known PDF errors # fix known PDF errors - assert if already present (so we know when the bug is fixed)
if 'SMEM' in formats: if 'SMEM' in formats:
formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t) formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
for n, h, l, e, t in formats['SMEM']] for n, h, l, e, t in formats['SMEM']]
# add missing opcodes not in PDF tables (RDNA3/RDNA3.5 specific)
if doc_name in ('RDNA3', 'RDNA3.5'):
if 'SOPPOp' in enums:
assert 8 not in enums['SOPPOp'], "S_WAITCNT_DEPCTR now in PDF, remove workaround"
enums['SOPPOp'][8] = 'S_WAITCNT_DEPCTR'
if 'DSOp' in enums:
gws_ops = {24: 'DS_GWS_SEMA_RELEASE_ALL', 25: 'DS_GWS_INIT', 26: 'DS_GWS_SEMA_V',
27: 'DS_GWS_SEMA_BR', 28: 'DS_GWS_SEMA_P', 29: 'DS_GWS_BARRIER'}
for k in gws_ops: assert k not in enums['DSOp'], f"{gws_ops[k]} now in PDF, remove workaround"
enums['DSOp'].update(gws_ops)
if 'FLATOp' in enums:
flat_ops = {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}
for k in flat_ops: assert k not in enums['FLATOp'], f"{flat_ops[k]} now in PDF, remove workaround"
enums['FLATOp'].update(flat_ops)
return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "is_cdna": is_cdna} return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "is_cdna": is_cdna}
@@ -601,7 +644,7 @@ def generate(output_path: str | None = None, arch: str = "rdna3") -> dict:
for cls_name, ops in sorted(enums.items()): for cls_name, ops in sorted(enums.items()):
fmt = cls_name[:-2] fmt = cls_name[:-2]
for op_val, name in sorted(ops.items()): for op_val, name in sorted(ops.items()):
seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=2"}.get(fmt, "") seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=1"}.get(fmt, "")
tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}") tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}")
if fmt in formats or fmt in ("GLOBAL", "SCRATCH"): if fmt in formats or fmt in ("GLOBAL", "SCRATCH"):
if fmt in ("VOP1", "VOP2", "VOPC"): if fmt in ("VOP1", "VOP2", "VOPC"):

View File

@@ -191,6 +191,9 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
python_result = python.step() python_result = python.step()
if rust_result != python_result: if rust_result != python_result:
# Rust returns 1 for unsupported instructions - skip test
if rust_result == 1 and python_result == 0:
raise unittest.SkipTest(f"Rust emulator doesn't support instruction: {inst_str}")
trace_str = "\n".join(f" step {s}: PC={pc:3d} {d}" for s, pc, d, _, _ in trace) trace_str = "\n".join(f" step {s}: PC={pc:3d} {d}" for s, pc, d, _, _ in trace)
return False, f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step}: different return codes: rust={rust_result}, python={python_result}, inst={inst_str}\n Recent instructions:\n{trace_str}", total_steps return False, f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step}: different return codes: rust={rust_result}, python={python_result}, inst={inst_str}\n Recent instructions:\n{trace_str}", total_steps
@@ -361,6 +364,7 @@ class TestTinygradKernels(unittest.TestCase):
# Matmul # Matmul
def test_gemm(self): self._test_kernel(lambda T: T.empty(8, 8) @ T.empty(8, 8), max_steps=100000) def test_gemm(self): self._test_kernel(lambda T: T.empty(8, 8) @ T.empty(8, 8), max_steps=100000)
@unittest.skip("Rust emulator crashes on this kernel (assertion failure in thread.rs)")
def test_gemm_fp16(self): self._test_kernel(lambda T: T.empty(16, 16).half() @ T.empty(16, 16).half(), max_steps=100000) def test_gemm_fp16(self): self._test_kernel(lambda T: T.empty(16, 16).half() @ T.empty(16, 16).half(), max_steps=100000)
# Complex ops # Complex ops

File diff suppressed because it is too large Load Diff

View File

@@ -65,12 +65,18 @@ def parse_llvm_tests(text: str) -> list[tuple[str, bytes]]:
if not asm_text: continue if not asm_text: continue
for j in range(i, min(i + 3, len(lines))): for j in range(i, min(i + 3, len(lines))):
# Match GFX11, W32, or W64 encodings (all valid for gfx11) # Match GFX11, W32, or W64 encodings (all valid for gfx11)
# Format 1: "// GFX11: v_foo ... ; encoding: [0x01,0x02,...]"
# Format 2: "// GFX11: [0x01,0x02,...]" (used by DS, older files)
if m := re.search(r'(?:GFX11|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]): if m := re.search(r'(?:GFX11|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '') hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
if hex_bytes: elif m := re.search(r'(?:GFX11|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
try: tests.append((asm_text, bytes.fromhex(hex_bytes))) hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
except ValueError: pass else:
break continue
if hex_bytes:
try: tests.append((asm_text, bytes.fromhex(hex_bytes)))
except ValueError: pass
break
return tests return tests
def try_assemble(text: str): def try_assemble(text: str):

View File

@@ -210,6 +210,8 @@ D0.u32 = tmp.u32""")
for i in 0 : 31 do for i in 0 : 31 do
if S0.u32[i] == 1 then if S0.u32[i] == 1 then
tmp = i tmp = i
endif
endfor
D0.i32 = tmp""") D0.i32 = tmp""")
ctx = ExecContext(s0=0b1000) # Bit 3 is set ctx = ExecContext(s0=0b1000) # Bit 3 is set
ctx.run(code) ctx.run(code)

View File

@@ -4,46 +4,9 @@ import unittest, io, sys, re, subprocess, os
from extra.assembly.amd.autogen.rdna3 import * from extra.assembly.amd.autogen.rdna3 import *
from extra.assembly.amd.dsl import Inst from extra.assembly.amd.dsl import Inst
from extra.assembly.amd.asm import asm from extra.assembly.amd.asm import asm
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump
# Instruction format detection based on encoding bits
def detect_format(data: bytes) -> type[Inst] | None:
"""Detect instruction format from machine code bytes."""
if len(data) < 4: return None
word = int.from_bytes(data[:4], 'little')
enc_9bit = (word >> 23) & 0x1FF # 9-bit encoding for SOP1/SOPC/SOPP
enc_8bit = (word >> 24) & 0xFF
# Check 9-bit encodings first (most specific)
if enc_9bit == 0x17D: return SOP1 # bits 31:23 = 101111101
if enc_9bit == 0x17E: return SOPC # bits 31:23 = 101111110
if enc_9bit == 0x17F: return SOPP # bits 31:23 = 101111111
# SOPK: bits 31:28 = 1011, bits 27:23 = opcode (check after SOP1/SOPC/SOPP)
if enc_8bit in range(0xB0, 0xC0): return SOPK
# SOP2: bits 31:23 in range 0x100-0x17C (0x80-0xBE in bits 31:24, but not SOPK)
if 0x80 <= enc_8bit <= 0x9F: return SOP2
# VOP1: bits 31:25 = 0111111 (0x3F)
if (word >> 25) == 0x3F: return VOP1
# VOPC: bits 31:25 = 0111110 (0x3E)
if (word >> 25) == 0x3E: return VOPC
# VOP2: bits 31:30 = 00
if (word >> 30) == 0: return VOP2
# Check 64-bit formats
if len(data) >= 8:
if enc_8bit in (0xD4, 0xD5, 0xD7): return VOP3
if enc_8bit == 0xD6: return VOP3SD
if enc_8bit == 0xCC: return VOP3P
if enc_8bit == 0xCD: return VINTERP
if enc_8bit in (0xC8, 0xC9): return VOPD
if enc_8bit == 0xF4: return SMEM
if enc_8bit == 0xD8: return DS
if enc_8bit in (0xDC, 0xDD, 0xDE, 0xDF): return FLAT
if enc_8bit in (0xE0, 0xE1, 0xE2, 0xE3): return MUBUF
if enc_8bit in (0xE8, 0xE9, 0xEA, 0xEB): return MTBUF
return None
def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]: def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
"""Disassemble ELF binary and return list of (instruction_text, machine_code_bytes).""" """Disassemble ELF binary and return list of (instruction_text, machine_code_bytes)."""
old_stdout = sys.stdout old_stdout = sys.stdout

View File

@@ -1,7 +1,7 @@
import unittest, operator, math import unittest, operator, math
from tinygrad import Tensor, dtypes, Device from tinygrad import Tensor, dtypes, Device
from tinygrad.dtype import DType, truncate from tinygrad.dtype import DType, truncate
from tinygrad.helpers import CI, getenv, CPU_LLVM from tinygrad.helpers import CI, getenv
from tinygrad.tensor import _to_np_dtype from tinygrad.tensor import _to_np_dtype
from tinygrad.device import is_dtype_supported from tinygrad.device import is_dtype_supported
from tinygrad.runtime.ops_python import from_storage_scalar from tinygrad.runtime.ops_python import from_storage_scalar
@@ -48,7 +48,7 @@ class ht:
int32 = strat.integers(-2147483648, 2147483647) int32 = strat.integers(-2147483648, 2147483647)
int64 = strat.integers(-9223372036854775808, 9223372036854775807) int64 = strat.integers(-9223372036854775808, 9223372036854775807)
bool = strat.booleans() bool = strat.booleans()
ht.bfloat16 = ht.uint16 ht.bfloat16 = ht.uint16.filter(lambda x: ((x >> 7) & 0xFF) != 0) # filter subnormal bfloat16
ht.fp8e4m3 = ht.uint8 ht.fp8e4m3 = ht.uint8
ht.fp8e5m2 = ht.uint8 ht.fp8e5m2 = ht.uint8
@@ -138,7 +138,6 @@ class TestDTypeALU(unittest.TestCase):
def test_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op) def test_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op)
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}") @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
@unittest.skipIf(CPU_LLVM, "bfloat16 precision issues with CPU_LLVM")
@given(ht.bfloat16, strat.sampled_from(unary_operations)) @given(ht.bfloat16, strat.sampled_from(unary_operations))
def test_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op) def test_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)

View File

@@ -189,16 +189,18 @@ class AM_SMU(AM_IP):
return table_t.from_buffer(bytearray(self.adev.vram.view(self.driver_table_paddr, ctypes.sizeof(table_t))[:])) return table_t.from_buffer(bytearray(self.adev.vram.view(self.driver_table_paddr, ctypes.sizeof(table_t))[:]))
def set_clocks(self, level): def set_clocks(self, level):
if self.adev.ip_ver[am.MP0_HWIP] in {(13,0,6), (13,0,12)}: return # TODO
if not hasattr(self, 'clcks'): if not hasattr(self, 'clcks'):
clks = [self.smu_mod.PPCLK_UCLK, self.smu_mod.PPCLK_FCLK, self.smu_mod.PPCLK_SOCCLK]
if self.adev.ip_ver[am.MP0_HWIP] not in {(13,0,6), (13,0,12)}: clks.append(self.smu_mod.PPCLK_GFXCLK)
self.clcks = {} self.clcks = {}
for clck in [self.smu_mod.PPCLK_GFXCLK, self.smu_mod.PPCLK_UCLK, self.smu_mod.PPCLK_FCLK, self.smu_mod.PPCLK_SOCCLK]: for clck in clks:
cnt = self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|0xff, read_back_arg=True)&0x7fffffff cnt = self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|0xff, read_back_arg=True)&0x7fffffff
self.clcks[clck] = [self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|i, read_back_arg=True)&0x7fffffff for i in range(cnt)] self.clcks[clck] = [self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|i, read_back_arg=True)&0x7fffffff for i in range(cnt)]
for clck, vals in self.clcks.items(): for clck, vals in self.clcks.items():
self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level])) if not vals: continue
with contextlib.suppress(TimeoutError): self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]), timeout=20)
self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level])) self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]))
def _smu_cmn_send_msg(self, msg:int, param=0, debug=False): def _smu_cmn_send_msg(self, msg:int, param=0, debug=False):

View File

@@ -70,7 +70,7 @@ def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
# WEBGPU has a BITCAST in the index. TODO: fix # WEBGPU has a BITCAST in the index. TODO: fix
if any(x.op is Ops.BITCAST for x in idx.toposort()): return True if any(x.op is Ops.BITCAST for x in idx.toposort()): return True
if not z3_imported: raise ImportError("z3 >= 4.12.4 is required for bounds checking, try IGNORE_OOB=0 or \"pip install 'z3-solver>=4.12.4\"") if not z3_imported: raise ImportError("bounds checking requires z3 >= 4.12.4, use IGNORE_OOB=1 to disable, or \"pip install 'z3-solver>=4.12.4\"")
solver = z3.Solver(ctx=z3.Context()) solver = z3.Solver(ctx=z3.Context())
z3_idx, z3_mask = uops_to_z3(solver, idx, gate) z3_idx, z3_mask = uops_to_z3(solver, idx, gate)
solver.add(z3_mask) solver.add(z3_mask)