Merge origin/master, delete pcode.py

This commit is contained in:
George Hotz
2026-01-05 18:53:43 -08:00
38 changed files with 4993 additions and 4450 deletions

View File

@@ -40,11 +40,13 @@ jobs:
- name: Install autogen support packages
run: sudo apt-get install -y --no-install-recommends libclang-20-dev llvm-20-dev hip-dev libusb-1.0-0-dev
- name: Verify OpenCL autogen
continue-on-error: true
run: |
mv tinygrad/runtime/autogen/opencl.py /tmp/opencl.py.bak
python3 -c "from tinygrad.runtime.autogen import opencl"
diff /tmp/opencl.py.bak tinygrad/runtime/autogen/opencl.py
- name: Verify CUDA autogen
continue-on-error: true
run: |
mv tinygrad/runtime/autogen/cuda.py /tmp/cuda.py.bak
mv tinygrad/runtime/autogen/nvrtc.py /tmp/nvrtc.py.bak
@@ -58,6 +60,7 @@ jobs:
diff /tmp/nv_570.py.bak tinygrad/runtime/autogen/nv_570.py
diff /tmp/nv.py.bak tinygrad/runtime/autogen/nv.py
- name: Verify AMD autogen
continue-on-error: true
run: |
mv tinygrad/runtime/autogen/comgr.py /tmp/comgr.py.bak
mv tinygrad/runtime/autogen/hsa.py /tmp/hsa.py.bak
@@ -89,6 +92,7 @@ jobs:
diff /tmp/am_smu_v13_0_0.py.bak tinygrad/runtime/autogen/am/smu_v13_0_0.py
diff /tmp/am_smu_v14_0_2.py.bak tinygrad/runtime/autogen/am/smu_v14_0_2.py
- name: Verify Linux autogen
continue-on-error: true
run: |
mv tinygrad/runtime/autogen/libc.py /tmp/libc.py.bak
mv tinygrad/runtime/autogen/kfd.py /tmp/kfd.py.bak
@@ -104,16 +108,19 @@ jobs:
diff /tmp/pci.py.bak tinygrad/runtime/autogen/pci.py
diff /tmp/vfio.py.bak tinygrad/runtime/autogen/vfio.py
- name: Verify LLVM autogen
continue-on-error: true
run: |
mv tinygrad/runtime/autogen/llvm.py /tmp/llvm.py.bak
python3 -c "from tinygrad.runtime.autogen import llvm"
diff /tmp/llvm.py.bak tinygrad/runtime/autogen/llvm.py
- name: Verify WebGPU autogen
continue-on-error: true
run: |
mv tinygrad/runtime/autogen/webgpu.py /tmp/webgpu.py.bak
python3 -c "from tinygrad.runtime.autogen import webgpu"
diff /tmp/webgpu.py.bak tinygrad/runtime/autogen/webgpu.py
- name: Verify Qualcomm autogen
continue-on-error: true
run: |
mv tinygrad/runtime/autogen/kgsl.py /tmp/kgsl.py.bak
mv tinygrad/runtime/autogen/qcom_dsp.py /tmp/qcom_dsp.py.bak
@@ -121,20 +128,36 @@ jobs:
diff /tmp/kgsl.py.bak tinygrad/runtime/autogen/kgsl.py
diff /tmp/qcom_dsp.py.bak tinygrad/runtime/autogen/qcom_dsp.py
- name: Verify libusb autogen
continue-on-error: true
run: |
mv tinygrad/runtime/autogen/libusb.py /tmp/libusb.py.bak
python3 -c "from tinygrad.runtime.autogen import libusb"
diff /tmp/libusb.py.bak tinygrad/runtime/autogen/libusb.py
- name: Verify mesa autogen
continue-on-error: true
run: |
mv tinygrad/runtime/autogen/mesa.py /tmp/mesa.py.bak
python3 -c "from tinygrad.runtime.autogen import mesa"
diff /tmp/mesa.py.bak tinygrad/runtime/autogen/mesa.py
- name: Verify libclang autogen
continue-on-error: true
run: |
cp tinygrad/runtime/autogen/libclang.py /tmp/libclang.py.bak
REGEN=1 python3 -c "from tinygrad.runtime.autogen import libclang"
diff /tmp/libclang.py.bak tinygrad/runtime/autogen/libclang.py
- name: Generate patch for differences
run: |
if ! git diff --quiet; then
git diff > autogen-ubuntu.patch
fi
- name: Upload patch artifact
uses: actions/upload-artifact@v4
with:
name: autogen-ubuntu-patch
path: autogen-ubuntu.patch
if-no-files-found: ignore
- name: Fail if differences found
run: git diff --quiet
autogen-mac:
name: In-tree Autogen (macos)
runs-on: macos-14
@@ -147,10 +170,24 @@ jobs:
with:
llvm: 'true'
- name: Verify macos autogen
continue-on-error: true
run: |
mv tinygrad/runtime/autogen/metal.py /tmp/metal.py.bak
LIBCLANG_PATH=/opt/homebrew/opt/llvm@20/lib/libclang.dylib python3 -c "from tinygrad.runtime.autogen import metal"
diff /tmp/metal.py.bak tinygrad/runtime/autogen/metal.py
- name: Generate patch for differences
run: |
if ! git diff --quiet; then
git diff > autogen-macos.patch
fi
- name: Upload patch artifact
uses: actions/upload-artifact@v4
with:
name: autogen-macos-patch
path: autogen-macos.patch
if-no-files-found: ignore
- name: Fail if differences found
run: git diff --quiet
autogen-comgr-3:
name: In-tree Autogen (comgr 3)
runs-on: ubuntu-24.04
@@ -170,7 +207,21 @@ jobs:
sudo apt -qq update || true
sudo apt-get install -y --no-install-recommends libclang-20-dev comgr
- name: Verify comgr (3) autogen
continue-on-error: true
run: |
mv tinygrad/runtime/autogen/comgr_3.py /tmp/comgr_3.py.bak
python3 -c "from tinygrad.runtime.autogen import comgr_3"
diff /tmp/comgr_3.py.bak tinygrad/runtime/autogen/comgr_3.py
- name: Generate patch for differences
run: |
if ! git diff --quiet; then
git diff > autogen-comgr3.patch
fi
- name: Upload patch artifact
uses: actions/upload-artifact@v4
with:
name: autogen-comgr3-patch
path: autogen-comgr3.patch
if-no-files-found: ignore
- name: Fail if differences found
run: git diff --quiet

View File

@@ -257,6 +257,7 @@ jobs:
key: unittest-12
pydeps: "pillow numpy ftfy regex"
deps: testing_unit
llvm: 'true'
- name: Check Device.DEFAULT
run: python -c "from tinygrad import Device; assert Device.DEFAULT == 'CPU', Device.DEFAULT"
- name: Run unit tests
@@ -309,7 +310,7 @@ jobs:
deps: testing_unit
python-version: '3.14'
- name: Test SPEC=2
run: SPEC=2 pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/test_custom_kernel.py --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
run: SPEC=2 pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/test_custom_kernel.py --ignore test/unit/test_hashing.py --ignore test/unit/test_autogen.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
fuzzing:
name: Fuzzing
@@ -669,6 +670,10 @@ jobs:
deps: testing_minimal
amd: 'true'
python-version: '3.13'
- name: Verify AMD autogen is up to date
run: |
python -m extra.assembly.amd.pdf
git diff --exit-code extra/assembly/amd/autogen/
- name: Install LLVM 21
run: |
wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc
@@ -689,23 +694,6 @@ jobs:
- name: Run RDNA3 ops tests
run: SKIP_SLOW_TEST=1 AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=auto test/test_ops.py -k "test_sparse_categorical_crossentropy or test_tril"
testamdautogen:
name: AMD autogen
runs-on: ubuntu-24.04
timeout-minutes: 10
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: rdna3-autogen
pydeps: "pdfplumber"
- name: Verify AMD autogen is up to date
run: |
python -m extra.assembly.amd.pdf --arch all
git diff --exit-code extra/assembly/amd/autogen/
testnvidia:
strategy:
fail-fast: false
@@ -793,6 +781,8 @@ jobs:
ocelot: 'true'
llvm: 'true'
- name: Run unit tests
env:
LIBCLANG_PATH: '/opt/homebrew/opt/llvm@20/lib/libclang.dylib'
run: METAL=1 python -m pytest -n=auto test/unit/ --durations=20
- name: Run ONNX
run: METAL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20

View File

@@ -192,9 +192,12 @@ When optimizing tinygrad internals:
9. **Avoid creating intermediate objects in hot paths** - For example, `any(x.op in ops for x in self.backward_slice)` is faster than `any(x.op in ops for x in {self:None, **self.backward_slice})` because it avoids dict creation.
## Pattern Matching Profiling
## Pattern Matching Analysis
Use `TRACK_MATCH_STATS=2` to identify expensive patterns:
**Use the right tool:**
- `TRACK_MATCH_STATS=2` - **Profiling**: identify expensive patterns
- `VIZ=-1` - **Inspection**: see all transformations, what every match pattern does, the before/after diffs
```bash
TRACK_MATCH_STATS=2 PYTHONPATH="." python3 test/external/external_benchmark_schedule.py
@@ -209,6 +212,14 @@ Key patterns to watch (from ResNet50 benchmark):
Patterns with 0% match rate are workload-specific overhead. They may be useful in other workloads, so don't remove them without understanding their purpose.
```bash
# Save the trace
VIZ=-1 python test/test_tiny.py TestTiny.test_gemm
# Explore it
./extra/viz/cli.py --help
```
## AMD Performance Counter Profiling
Set VIZ to `-2` to save performance-counter traces for the AMD backend.

View File

@@ -153,8 +153,7 @@ class SMICtx:
tables = {}
for dev in self.devs:
match dev.ip_ver[am.MP1_HWIP]:
case (13,0,6): table_t = dev.smu.smu_mod.MetricsTableX_t
case (13,0,12): table_t = dev.smu.smu_mod.MetricsTableV2_t
case (13,0,6)|(13,0,12): table_t = dev.smu.smu_mod.MetricsTableX_t
case _: table_t = dev.smu.smu_mod.SmuMetricsExternal_t
tables[dev] = dev.smu.read_table(table_t, dev.smu.smu_mod.SMU_TABLE_SMU_METRICS) if dev.pci_state == "D0" else None
return tables
@@ -165,17 +164,17 @@ class SMICtx:
def get_gfx_activity(self, dev, metrics):
match dev.ip_ver[am.MP1_HWIP]:
case (13,0,6): return max(0, min(100, self._smuq10_round(metrics.SocketGfxBusy)))
case (13,0,6)|(13,0,12): return max(0, min(100, self._smuq10_round(metrics.SocketGfxBusy)))
case _: return metrics.SmuMetrics.AverageGfxActivity
def get_mem_activity(self, dev, metrics):
match dev.ip_ver[am.MP1_HWIP]:
case (13,0,6): return max(0, min(100, self._smuq10_round(metrics.DramBandwidthUtilization)))
case (13,0,6)|(13,0,12): return max(0, min(100, self._smuq10_round(metrics.DramBandwidthUtilization)))
case _: return metrics.SmuMetrics.AverageUclkActivity
def get_temps(self, dev, metrics, compact=False):
match dev.ip_ver[am.MP1_HWIP]:
case (13,0,6):
case (13,0,6)|(13,0,12):
temps = {
"Hotspot": self._smuq10_round(metrics.MaxSocketTemperature),
"HBM": self._smuq10_round(metrics.MaxHbmTemperature),
@@ -191,7 +190,7 @@ class SMICtx:
def get_voltage(self, dev, metrics, compact=False):
match dev.ip_ver[am.MP1_HWIP]:
case (13,0,6): return {}
case (13,0,6)|(13,0,12): return {}
case _:
voltage_keys = [(k, name) for k, name in dev.smu.smu_mod.SVI_PLANE_e.items()
if k < dev.smu.smu_mod.SVI_PLANE_COUNT and metrics.SmuMetrics.AvgVoltage[k] != 0]
@@ -205,33 +204,33 @@ class SMICtx:
def get_gfx_freq(self, dev, metrics):
if metrics is None: return 0
match dev.ip_ver[am.MP1_HWIP]:
case (13,0,6): return self._smuq10_round(metrics.GfxclkFrequency[0])
case (13,0,6)|(13,0,12): return self._smuq10_round(metrics.GfxclkFrequency[0])
case _:
return metrics.SmuMetrics.AverageGfxclkFrequencyPostDs if self.get_gfx_activity(dev, metrics) <= self.get_busy_threshold(dev) else \
metrics.SmuMetrics.AverageGfxclkFrequencyPreDs
def get_mem_freq(self, dev, metrics):
match dev.ip_ver[am.MP1_HWIP]:
case (13,0,6): return self._smuq10_round(metrics.UclkFrequency)
case (13,0,6)|(13,0,12): return self._smuq10_round(metrics.UclkFrequency)
case _:
return metrics.SmuMetrics.AverageMemclkFrequencyPostDs if self.get_mem_activity(dev, metrics) <= self.get_busy_threshold(dev) else \
metrics.SmuMetrics.AverageMemclkFrequencyPreDs
def get_fckl_freq(self, dev, metrics):
match dev.ip_ver[am.MP1_HWIP]:
case (13,0,6): return self._smuq10_round(metrics.FclkFrequency)
case (13,0,6)|(13,0,12): return self._smuq10_round(metrics.FclkFrequency)
case _:
return metrics.SmuMetrics.AverageFclkFrequencyPostDs if self.get_mem_activity(dev, metrics) <= self.get_busy_threshold(dev) else \
metrics.SmuMetrics.AverageFclkFrequencyPreDs
def get_fan_rpm_pwm(self, dev, metrics):
match dev.ip_ver[am.MP1_HWIP]:
case (13,0,6): return None, None
case (13,0,6)|(13,0,12): return None, None
case _: return metrics.SmuMetrics.AvgFanRpm, metrics.SmuMetrics.AvgFanPwm
def get_power(self, dev, metrics):
match dev.ip_ver[am.MP1_HWIP]:
case (13,0,6): return self._smuq10_round(metrics.SocketPower), self._smuq10_round(metrics.MaxSocketPowerLimit)
case (13,0,6)|(13,0,12): return self._smuq10_round(metrics.SocketPower), self._smuq10_round(metrics.MaxSocketPowerLimit)
case _: return metrics.SmuMetrics.AverageSocketPower, metrics.SmuMetrics.dGPU_W_MAX
def get_mem_usage(self, dev):

View File

@@ -3,10 +3,11 @@ from __future__ import annotations
import re
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, FLOAT_DEC, FLOAT_ENC, decode_src
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, SPECIAL_PAIRS_CDNA, FLOAT_DEC, FLOAT_ENC, decode_src
from extra.assembly.amd.autogen.rdna3 import ins
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp, MTBUFOp)
from extra.assembly.amd.autogen.rdna3.enum import BufFmt
# True when the instruction's class is defined in a module whose dotted name contains 'cdna'
# (the check is a plain substring test on inst.__class__.__module__).
def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__
@@ -17,21 +18,37 @@ def _matches_encoding(word: int, cls: type[Inst]) -> bool:
return ((word >> bf.lo) & bf.mask()) == val
# Order matters: more specific encodings first, VOP2 last (it's a catch-all for bit31=0)
_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP]
_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls
_RDNA_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP]
_RDNA_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as C_VOP1, VOP2 as C_VOP2, VOPC as C_VOPC, VOP3A, VOP3B, VOP3P as C_VOP3P,
SOP1 as C_SOP1, SOP2 as C_SOP2, SOPC as C_SOPC, SOPK as C_SOPK, SOPP as C_SOPP, SMEM as C_SMEM, DS as C_DS,
FLAT as C_FLAT, MUBUF as C_MUBUF, MTBUF as C_MTBUF, SDWA, DPP)
_CDNA_FORMATS_64 = [C_VOP3P, VOP3A, C_DS, C_FLAT, C_MUBUF, C_MTBUF, C_SMEM]
_CDNA_FORMATS_32 = [SDWA, DPP, C_SOP1, C_SOPC, C_SOPP, C_SOPK, C_VOPC, C_VOP1, C_SOP2, C_VOP2]
_CDNA_VOP3B_OPS = {281, 282, 283, 284, 285, 286, 480, 481, 488, 489} # VOP3B opcodes
# CDNA opcode name aliases for disasm (new name -> old name expected by tests)
_CDNA_DISASM_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32', 'v_fmamk_f32': 'v_madmk_f32', 'v_fmaak_f32': 'v_madak_f32'}
def detect_format(data: bytes) -> type[Inst]:
def detect_format(data: bytes, arch: str = "rdna3") -> type[Inst]:
"""Detect instruction format from machine code bytes."""
assert len(data) >= 4, f"need at least 4 bytes, got {len(data)}"
word = int.from_bytes(data[:4], 'little')
# Check 64-bit formats first (bits[31:30] == 0b11)
if arch == "cdna":
if (word >> 30) == 0b11:
for cls in _CDNA_FORMATS_64:
if _matches_encoding(word, cls):
return VOP3B if cls is VOP3A and ((word >> 16) & 0x3ff) in _CDNA_VOP3B_OPS else cls
raise ValueError(f"unknown CDNA 64-bit format word={word:#010x}")
for cls in _CDNA_FORMATS_32:
if _matches_encoding(word, cls): return cls
raise ValueError(f"unknown CDNA 32-bit format word={word:#010x}")
# RDNA (default)
if (word >> 30) == 0b11:
for cls in _FORMATS_64:
for cls in _RDNA_FORMATS_64:
if _matches_encoding(word, cls):
return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in Inst._VOP3SD_OPS else cls
raise ValueError(f"unknown 64-bit format word={word:#010x}")
# 32-bit formats
for cls in _FORMATS_32:
for cls in _RDNA_FORMATS_32:
if _matches_encoding(word, cls): return cls
raise ValueError(f"unknown 32-bit format word={word:#010x}")
@@ -44,6 +61,11 @@ HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_H
19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK',
23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'}
HWREG_IDS = {v.lower(): k for k, v in HWREG.items()}
# RDNA unified buffer format - extracted from PDF, use enum for name->value lookup
BUF_FMT = {e.name: e.value for e in BufFmt}
def _parse_buf_fmt_combo(s: str) -> int | None:
  """Parse the inside of ``format:[BUF_DATA_FORMAT_X, BUF_NUM_FORMAT_Y]`` into a BUF_FMT value.

  ``s`` is the comma-separated text between the brackets. The ``BUF_DATA_FORMAT_`` /
  ``BUF_NUM_FORMAT_`` prefixes are removed from each part and the two remainders are
  looked up in ``BUF_FMT`` under the key ``BUF_FMT_<X>_<Y>``.

  Returns the matching value, or ``None`` when ``s`` does not split into exactly two
  comma-separated parts or the combination is not present in ``BUF_FMT``.
  """
  parts = [p.strip().replace('BUF_DATA_FORMAT_', '').replace('BUF_NUM_FORMAT_', '') for p in s.split(',')]
  # anything other than a (data format, num format) pair cannot be resolved
  if len(parts) != 2: return None
  return BUF_FMT.get(f'BUF_FMT_{parts[0]}_{parts[1]}')
MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA',
131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'}
@@ -54,22 +76,28 @@ MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_T
def _reg(p: str, b: int, n: int = 1) -> str: return f"{p}{b}" if n == 1 else f"{p}[{b}:{b+n-1}]"
def _sreg(b: int, n: int = 1) -> str: return _reg("s", b, n)
def _vreg(b: int, n: int = 1) -> str: return _reg("v", b, n)
def _areg(b: int, n: int = 1) -> str: return _reg("a", b, n) # accumulator registers for GFX90a
def _ttmp(b: int, n: int = 1) -> str: return _reg("ttmp", b - 108, n) if 108 <= b <= 123 else None
def _sreg_or_ttmp(b: int, n: int = 1) -> str: return _ttmp(b, n) or _sreg(b, n)
def _fmt_sdst(v: int, n: int = 1) -> str:
if v == 124: return "null"
def _fmt_sdst(v: int, n: int = 1, cdna: bool = False) -> str:
from extra.assembly.amd.dsl import SPECIAL_PAIRS_CDNA, SPECIAL_GPRS_CDNA
if t := _ttmp(v, n): return t
if n > 1: return SPECIAL_PAIRS.get(v) or _sreg(v, n)
return SPECIAL_GPRS.get(v, f"s{v}")
pairs = SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS
gprs = SPECIAL_GPRS_CDNA if cdna else SPECIAL_GPRS
if n > 1: return pairs.get(v) or gprs.get(v) or _sreg(v, n) # also check gprs for null/m0
return gprs.get(v, f"s{v}")
def _fmt_src(v: int, n: int = 1) -> str:
if n == 1: return decode_src(v)
def _fmt_src(v: int, n: int = 1, cdna: bool = False) -> str:
from extra.assembly.amd.dsl import SPECIAL_PAIRS_CDNA
if n == 1: return decode_src(v, cdna)
if v >= 256: return _vreg(v - 256, n)
if v <= 105: return _sreg(v, n)
if n == 2 and v in SPECIAL_PAIRS: return SPECIAL_PAIRS[v]
if v <= 101: return _sreg(v, n) # s0-s101 can be pairs, but 102+ are special on CDNA
pairs = SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS
if n == 2 and v in pairs: return pairs[v]
if v <= 105: return _sreg(v, n) # s102-s105 regular pairs for RDNA
if t := _ttmp(v, n): return t
return decode_src(v)
return decode_src(v, cdna)
def _fmt_v16(v: int, base: int = 256, hi_thresh: int = 384) -> str:
return f"v{(v - base) & 0x7f}.{'h' if v >= hi_thresh else 'l'}"
@@ -106,46 +134,72 @@ def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
# ═══════════════════════════════════════════════════════════════════════════════
def _disasm_vop1(inst: VOP1) -> str:
name = inst.op_name.lower()
if inst.op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name
if inst.op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}"
# 16-bit dst: uses .h/.l suffix (determined by name pattern, not dtype - e.g. sat_pk_u8_i16 outputs 8-bit but uses 16-bit encoding)
name, cdna = inst.op_name.lower() or f'vop1_op_{inst.op}', _is_cdna(inst)
suf = "" if cdna else "_e32"
if name in ('v_nop', 'v_pipeflush', 'v_clrexcp'): return name # no operands
if 'readfirstlane' in name:
src = f"v{inst.src0 - 256}" if inst.src0 >= 256 else decode_src(inst.src0, cdna)
return f"{name} {_fmt_sdst(inst.vdst, 1, cdna)}, {src}"
# 16-bit dst: uses .h/.l suffix for RDNA (CDNA uses plain vN)
parts = name.split('_')
is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)
is_16d = not cdna and (any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name))
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
return f"{name}_e32 {dst}, {src}"
src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0), cdna) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if not cdna and inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
return f"{name}{suf} {dst}, {src}"
_VOP2_CARRY_OUT = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'} # carry out only
_VOP2_CARRY_INOUT = {'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'} # carry in and out
def _disasm_vop2(inst: VOP2) -> str:
name, cdna = inst.op_name.lower(), _is_cdna(inst)
suf = "" if not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
if cdna: name = _CDNA_DISASM_ALIASES.get(name, name) # apply CDNA aliases
suf = "" if cdna or (not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16) else "_e32"
lit = getattr(inst, '_literal', None)
is16 = not cdna and inst.is_16bit()
# fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1
if 'fmaak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)):
# fmaak/madak: dst = src0 * vsrc1 + K, fmamk/madmk: dst = src0 * K + vsrc1
if 'fmaak' in name or 'madak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)):
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}, 0x{lit:x}"
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{lit:x}"
if 'fmamk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)):
if 'fmamk' in name or 'madmk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)):
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, 0x{lit:x}, {_fmt_v16(inst.vsrc1, 0, 128)}"
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{lit:x}, v{inst.vsrc1}"
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
vcc = "vcc" if cdna else "vcc_lo"
# CDNA carry ops output vcc after vdst
if cdna and name in _VOP2_CARRY_OUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}"
if cdna and name in _VOP2_CARRY_INOUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}, {vcc}"
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (f", {vcc}" if name == 'v_cndmask_b32' else "")
def _disasm_vopc(inst: VOPC) -> str:
name, cdna = inst.op_name.lower(), _is_cdna(inst)
if cdna:
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0))
return f"{name}_e32 {s0}, v{inst.vsrc1}" if inst.op.value >= 128 else f"{name}_e32 vcc, {s0}, v{inst.vsrc1}"
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0), cdna)
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else f"v{inst.vsrc1}"
return f"{name} vcc, {s0}, {s1}" # CDNA VOPC always outputs vcc
# RDNA: v_cmpx_* writes to exec (no vcc), v_cmp_* writes to vcc_lo
has_vcc = 'cmpx' not in name
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}"
return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
return f"{name}_e32 vcc_lo, {s0}, {s1}" if has_vcc else f"{name}_e32 {s0}, {s1}"
NO_ARG_SOPP = {SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE, SOPPOp.S_TTRACEDATA}
# CDNA uses name-based matching since opcode values differ from RDNA
_CDNA_NO_ARG_SOPP = {'s_endpgm', 's_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata', 's_nop', 's_sethalt', 's_sleep',
's_setprio', 's_trap', 's_incperflevel', 's_decperflevel', 's_sendmsg', 's_sendmsghalt'}
def _disasm_sopp(inst: SOPP) -> str:
name = inst.op_name.lower()
name, cdna = inst.op_name.lower(), _is_cdna(inst)
if cdna:
# CDNA: use name-based matching
if name == 's_endpgm': return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
if name in ('s_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata'): return name
if name == 's_waitcnt':
vm, lgkm, exp = inst.simm16 & 0xf, (inst.simm16 >> 8) & 0x3f, (inst.simm16 >> 4) & 0x7
p = [f"vmcnt({vm})" if vm != 0xf else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
if name.startswith(('s_cbranch', 's_branch')): return f"{name} {inst.simm16}"
return f"{name} 0x{inst.simm16:x}" if inst.simm16 else name
# RDNA
if inst.op in NO_ARG_SOPP: return name
if inst.op == SOPPOp.S_ENDPGM: return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
if inst.op == SOPPOp.S_WAITCNT:
@@ -161,64 +215,98 @@ def _disasm_sopp(inst: SOPP) -> str:
return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}"
def _disasm_smem(inst: SMEM) -> str:
name = inst.op_name.lower()
name, cdna = inst.op_name.lower(), _is_cdna(inst)
if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset)
sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op.value <= 12 or name == 's_atc_probe_buffer') else 2
sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
# GFX9 SMEM: soe and imm bits determine offset interpretation
# soe=1, imm=1: soffset is SGPR, offset is immediate (both used)
# soe=0, imm=1: offset is immediate
# soe=0, imm=0: offset field is SGPR encoding (0-255)
soe, imm = getattr(inst, 'soe', 0), getattr(inst, 'imm', 1)
if cdna:
if soe and imm:
off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{inst.offset:x}" # SGPR + immediate
elif imm:
off_s = f"0x{inst.offset:x}" # Immediate offset only
elif inst.offset < 256:
off_s = decode_src(inst.offset, cdna) # SGPR encoding in offset field
else:
off_s = decode_src(inst.soffset, cdna)
elif inst.offset and inst.soffset != 124:
off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{inst.offset:x}"
elif inst.offset:
off_s = f"0x{inst.offset:x}"
else:
off_s = decode_src(inst.soffset, cdna)
op_val = inst.op.value if hasattr(inst.op, 'value') else inst.op
# s_buffer_* instructions use 4 SGPRs for sbase (buffer descriptor)
is_buffer = 'buffer' in name or 's_atc_probe_buffer' == name
sbase_idx, sbase_count = inst.sbase * 2, 4 if is_buffer else 2
sbase_str = _fmt_src(sbase_idx, sbase_count, cdna) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs(), cdna)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (getattr(inst, 'dlc', 0), " dlc"))
def _disasm_flat(inst: FLAT) -> str:
name, cdna = inst.op_name.lower(), _is_cdna(inst)
acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
w = inst.dst_regs() * (2 if 'cmpswap' in name else 1)
if cdna: mods = f"{f' offset:{off_val}' if off_val else ''}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if inst.sc1 else ''}"
else: mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
w = inst.dst_regs() * (2 if '_x2' in name else 1) * (2 if 'cmpswap' in name else 1)
off_s = f" offset:{off_val}" if off_val else "" # Omit offset:0
if cdna: mods = f"{off_s}{' glc' if inst.sc0 else ''}{' slc' if inst.nt else ''}" # GFX9: sc0->glc, nt->slc
else: mods = f"{off_s}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
# saddr
if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
elif inst.saddr == 124: saddr_s = ", off"
elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr)}"
elif inst.saddr in SPECIAL_PAIRS: saddr_s = f", {SPECIAL_PAIRS[inst.saddr]}"
elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr, cdna)}"
elif inst.saddr in (SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS): saddr_s = f", {(SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS)[inst.saddr]}"
elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}"
else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr)}"
else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr, cdna)}"
# addtid: no addr
if 'addtid' in name: return f"{instr} v{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}"
# addr width
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2)
data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w)
if 'addtid' in name: return f"{instr} {'a' if acc else 'v'}{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}"
# addr width: CDNA flat always uses 2 VGPRs (64-bit), scratch uses 1, RDNA uses 2 only when no saddr
if cdna:
addr_w = 1 if seg == 'scratch' else 2 # CDNA: flat/global always 64-bit addr
else:
addr_w = 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, addr_w)
data_s, vdst_s = reg_fn(inst.data, w), reg_fn(inst.vdst, w // 2 if 'cmpswap' in name else w)
glc_or_sc0 = inst.sc0 if cdna else inst.glc
if 'atomic' in name:
return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
return f"{instr} {reg_fn(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
def _disasm_ds(inst: DS) -> str:
op, name = inst.op, inst.op_name.lower()
acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
rp = 'a' if acc else 'v' # register prefix for single regs
gds = " gds" if inst.gds else ""
off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else ""
off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else ""
off2 = (" offset0:" + str(inst.offset0) if inst.offset0 else "") + (" offset1:" + str(inst.offset1) if inst.offset1 else "")
w = inst.dst_regs()
d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}"
d0, d1, dst, addr = reg_fn(inst.data0, w), reg_fn(inst.data1, w), reg_fn(inst.vdst, w), f"v{inst.addr}"
if op == DSOp.DS_NOP: return name
if op == DSOp.DS_BVH_STACK_RTN_B32: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}"
if 'gws_sema' in name and op != DSOp.DS_GWS_SEMA_BR: return f"{name}{off}{gds}"
if 'gws_' in name: return f"{name} {addr}{off}{gds}"
if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} v{inst.vdst}{off}{gds}"
if 'gs_reg' in name: return f"{name} {_vreg(inst.vdst, 2)}, v{inst.data0}{off}{gds}"
if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} {rp}{inst.vdst}{off}{gds}"
if 'gs_reg' in name: return f"{name} {reg_fn(inst.vdst, 2)}, {rp}{inst.data0}{off}{gds}"
if '2addr' in name:
if 'load' in name: return f"{name} {_vreg(inst.vdst, w*2)}, {addr}{off2}{gds}"
if 'load' in name: return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}{off2}{gds}"
if 'store' in name and 'xchg' not in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
return f"{name} {_vreg(inst.vdst, w*2)}, {addr}, {d0}, {d1}{off2}{gds}"
if 'load' in name: return f"{name} v{inst.vdst}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}, {d0}, {d1}{off2}{gds}"
if 'write2' in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
if 'read2' in name: return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}{off2}{gds}"
if 'load' in name: return f"{name} {rp}{inst.vdst}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
if 'store' in name and not _has(name, 'cmp', 'xchg'):
return f"{name} v{inst.data0}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}"
if 'swizzle' in name or op == DSOp.DS_ORDERED_COUNT: return f"{name} v{inst.vdst}, {addr}{off}{gds}"
if 'permute' in name: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}{off}{gds}"
if 'condxchg' in name: return f"{name} {_vreg(inst.vdst, 2)}, {addr}, {_vreg(inst.data0, 2)}{off}{gds}"
return f"{name} {rp}{inst.data0}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}"
if 'swizzle' in name or op == DSOp.DS_ORDERED_COUNT: return f"{name} {rp}{inst.vdst}, {addr}{off}{gds}"
if 'permute' in name: return f"{name} {rp}{inst.vdst}, {addr}, {rp}{inst.data0}{off}{gds}"
if 'condxchg' in name: return f"{name} {reg_fn(inst.vdst, 2)}, {addr}, {reg_fn(inst.data0, 2)}{off}{gds}"
if _has(name, 'cmpstore', 'mskor', 'wrap'):
return f"{name} {dst}, {addr}, {d0}, {d1}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}, {d1}{off}{gds}"
return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}"
@@ -318,6 +406,8 @@ def _disasm_vop3p(inst: VOP3P) -> str:
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
name, cdna = inst.op_name.lower(), _is_cdna(inst)
acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
if cdna and name in ('buffer_wbl2', 'buffer_inv'): return name
if not cdna and inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \
@@ -326,9 +416,27 @@ def _disasm_buf(inst: MUBUF | MTBUF) -> str:
if hasattr(inst, 'tfe') and inst.tfe: w += 1
vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off"
srsrc = _sreg_or_ttmp(inst.srsrc*4, 4)
if cdna: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"sc0"),(inst.nt,"nt"),(inst.sc1,"sc1")] if c]
else: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}"
is_mtbuf = isinstance(inst, MTBUF) or isinstance(inst, C_MTBUF)
if is_mtbuf:
dfmt, nfmt = inst.format & 0xf, (inst.format >> 4) & 0x7
if acc: # GFX90a accumulator style: show dfmt/nfmt as numbers
fmt_s = f" dfmt:{dfmt}, nfmt:{nfmt}," # double space before dfmt per LLVM format
elif not cdna: # RDNA style: show combined format number
fmt_s = f" format:{inst.format}" if inst.format else ""
else: # CDNA: show format:[BUF_DATA_FORMAT_X] or format:[BUF_NUM_FORMAT_X]
dfmt_names = ['INVALID', '8', '16', '8_8', '32', '16_16', '10_11_11', '11_11_10', '10_10_10_2', '2_10_10_10', '8_8_8_8', '32_32', '16_16_16_16', '32_32_32', '32_32_32_32', 'RESERVED_15']
nfmt_names = ['UNORM', 'SNORM', 'USCALED', 'SSCALED', 'UINT', 'SINT', 'RESERVED_6', 'FLOAT']
if dfmt == 1 and nfmt == 0: fmt_s = "" # default, no format shown
elif nfmt == 0: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]}]" # only dfmt differs
elif dfmt == 1: fmt_s = f" format:[BUF_NUM_FORMAT_{nfmt_names[nfmt]}]" # only nfmt differs
else: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]},BUF_NUM_FORMAT_{nfmt_names[nfmt]}]" # both differ
else:
fmt_s = ""
if cdna: mods = [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"glc"),(inst.nt,"slc"),(inst.sc1,"sc1")] if c]
else: mods = [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
soffset_s = decode_src(inst.soffset, cdna)
if cdna and not acc and is_mtbuf: return f"{name} {reg_fn(inst.vdata, w)}, {vaddr}, {srsrc}, {soffset_s}{fmt_s}{' ' + ' '.join(mods) if mods else ''}"
return f"{name} {reg_fn(inst.vdata, w)}, {vaddr}, {srsrc},{fmt_s} {soffset_s}{' ' + ' '.join(mods) if mods else ''}"
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
"""Calculate vaddr register count for MIMG sample/gather operations."""
@@ -377,21 +485,23 @@ def _disasm_mimg(inst: MIMG) -> str:
return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}"
def _disasm_sop1(inst: SOP1) -> str:
op, name = inst.op, inst.op_name.lower()
src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))
if not _is_cdna(inst):
op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)
if not cdna:
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {src}"
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {src}"
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {src}"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, {src}"
def _disasm_sop2(inst: SOP2) -> str:
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}"
cdna = _is_cdna(inst)
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1), cdna)}"
def _disasm_sopc(inst: SOPC) -> str:
s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))
s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))
cdna = _is_cdna(inst)
s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)
s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1), cdna)
return f"{inst.op_name.lower()} {s0}, {s1}"
def _disasm_sopk(inst: SOPK) -> str:
@@ -405,10 +515,10 @@ def _disasm_sopk(inst: SOPK) -> str:
if (not cdna and op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32)) or (cdna and name in ('s_setreg_b32', 's_getreg_b32')):
hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1, cdna)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1, cdna)}, {hs}"
if not cdna and op in (SOPKOp.S_SUBVECTOR_LOOP_BEGIN, SOPKOp.S_SUBVECTOR_LOOP_END):
return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{inst.simm16:x}"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, 0x{inst.simm16:x}"
def _disasm_vinterp(inst: VINTERP) -> str:
mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
@@ -464,11 +574,54 @@ def _parse_ops(s: str) -> list[str]:
return ops
def _extract(text: str, pat: str, flags=re.I):
if m := re.search(pat, text, flags): return m, text[:m.start()] + text[m.end():]
if m := re.search(pat, text, flags): return m, text[:m.start()] + ' ' + text[m.end():]
return None, text
# Instruction aliases: LLVM uses different names for some instructions
_ALIASES = {
'v_cmp_tru_f16': 'v_cmp_t_f16', 'v_cmp_tru_f32': 'v_cmp_t_f32', 'v_cmp_tru_f64': 'v_cmp_t_f64',
'v_cmpx_tru_f16': 'v_cmpx_t_f16', 'v_cmpx_tru_f32': 'v_cmpx_t_f32', 'v_cmpx_tru_f64': 'v_cmpx_t_f64',
'v_cvt_flr_i32_f32': 'v_cvt_floor_i32_f32', 'v_cvt_rpi_i32_f32': 'v_cvt_nearest_i32_f32',
'v_ffbh_i32': 'v_cls_i32', 'v_ffbh_u32': 'v_clz_i32_u32', 'v_ffbl_b32': 'v_ctz_i32_b32',
'v_cvt_pkrtz_f16_f32': 'v_cvt_pk_rtz_f16_f32', 'v_fmac_legacy_f32': 'v_fmac_dx9_zero_f32', 'v_mul_legacy_f32': 'v_mul_dx9_zero_f32',
# SMEM aliases (dword -> b32, dwordx2 -> b64, etc.)
's_load_dword': 's_load_b32', 's_load_dwordx2': 's_load_b64', 's_load_dwordx4': 's_load_b128',
's_load_dwordx8': 's_load_b256', 's_load_dwordx16': 's_load_b512',
's_buffer_load_dword': 's_buffer_load_b32', 's_buffer_load_dwordx2': 's_buffer_load_b64',
's_buffer_load_dwordx4': 's_buffer_load_b128', 's_buffer_load_dwordx8': 's_buffer_load_b256',
's_buffer_load_dwordx16': 's_buffer_load_b512',
# VOP3 aliases
'v_cvt_pknorm_i16_f16': 'v_cvt_pk_norm_i16_f16', 'v_cvt_pknorm_u16_f16': 'v_cvt_pk_norm_u16_f16',
'v_add3_nc_u32': 'v_add3_u32', 'v_xor_add_u32': 'v_xad_u32',
# VINTERP aliases
'v_interp_p2_new_f32': 'v_interp_p2_f32',
# SOP1 aliases
's_ff1_i32_b32': 's_ctz_i32_b32', 's_ff1_i32_b64': 's_ctz_i32_b64',
's_flbit_i32_b32': 's_clz_i32_u32', 's_flbit_i32_b64': 's_clz_i32_u64', 's_flbit_i32': 's_cls_i32', 's_flbit_i32_i64': 's_cls_i32_i64',
's_andn1_saveexec_b32': 's_and_not0_saveexec_b32', 's_andn1_saveexec_b64': 's_and_not0_saveexec_b64',
's_andn1_wrexec_b32': 's_and_not0_wrexec_b32', 's_andn1_wrexec_b64': 's_and_not0_wrexec_b64',
's_andn2_saveexec_b32': 's_and_not1_saveexec_b32', 's_andn2_saveexec_b64': 's_and_not1_saveexec_b64',
's_andn2_wrexec_b32': 's_and_not1_wrexec_b32', 's_andn2_wrexec_b64': 's_and_not1_wrexec_b64',
's_orn1_saveexec_b32': 's_or_not0_saveexec_b32', 's_orn1_saveexec_b64': 's_or_not0_saveexec_b64',
's_orn2_saveexec_b32': 's_or_not1_saveexec_b32', 's_orn2_saveexec_b64': 's_or_not1_saveexec_b64',
# SOP2 aliases
's_andn2_b32': 's_and_not1_b32', 's_andn2_b64': 's_and_not1_b64',
's_orn2_b32': 's_or_not1_b32', 's_orn2_b64': 's_or_not1_b64',
# VOP2 aliases
'v_dot2c_f32_f16': 'v_dot2acc_f32_f16',
# More VOP3 aliases
'v_fma_legacy_f32': 'v_fma_dx9_zero_f32',
}
def _apply_alias(text: str) -> str:
mn = text.split()[0].lower() if ' ' in text else text.lower().rstrip('_')
# Try exact match first, then strip _e32/_e64 suffix
for m in (mn, mn.removesuffix('_e32'), mn.removesuffix('_e64')):
if m in _ALIASES: return _ALIASES[m] + text[len(m):]
return text
def get_dsl(text: str) -> str:
text, kw = text.strip(), []
text, kw = _apply_alias(text.strip()), []
# Extract modifiers
for pat, val in [(r'\s+mul:2(?:\s|$)', 1), (r'\s+mul:4(?:\s|$)', 2), (r'\s+div:2(?:\s|$)', 3)]:
if (m := _extract(text, pat))[0]: kw.append(f'omod={val}'); text = m[1]; break
@@ -484,6 +637,11 @@ def get_dsl(text: str) -> str:
m, text = _extract(text, r'\s+dlc(?:\s|$)'); dlc = 1 if m else None
m, text = _extract(text, r'\s+glc(?:\s|$)'); glc = 1 if m else None
m, text = _extract(text, r'\s+slc(?:\s|$)'); slc = 1 if m else None
m, text = _extract(text, r'\s+tfe(?:\s|$)'); tfe = 1 if m else None
m, text = _extract(text, r'\s+offen(?:\s|$)'); offen = 1 if m else None
m, text = _extract(text, r'\s+idxen(?:\s|$)'); idxen = 1 if m else None
m, text = _extract(text, r'\s+format:\[([^\]]+)\]'); fmt_val = m.group(1) if m else None
m, text = _extract(text, r'\s+format:(\d+)'); fmt_val = m.group(1) if m and not fmt_val else fmt_val
m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
if waitexp: kw.append(f'waitexp={waitexp}')
@@ -530,9 +688,30 @@ def get_dsl(text: str) -> str:
if off_val and len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}{gs}{ds})"
if len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, soffset={args[2]}{gs}{ds})"
# Buffer
if mn.startswith('buffer_') and len(ops) >= 2 and ops[1].strip().lower() == 'off':
return f"{mn}(vdata={args[0]}, vaddr=0, srsrc={args[2]}, soffset={f'RawImm({args[3].strip()})' if len(args) > 3 else 'RawImm(0)'})"
# Buffer (MUBUF/MTBUF) instructions
if mn.startswith(('buffer_', 'tbuffer_')):
is_tbuf = mn.startswith('tbuffer_')
# Parse format value for tbuffer
fmt_num = None
if fmt_val is not None:
if fmt_val.isdigit(): fmt_num = int(fmt_val)
else: fmt_num = BUF_FMT.get(fmt_val.replace(' ', '')) or _parse_buf_fmt_combo(fmt_val)
# Handle special no-arg buffer ops
if mn in ('buffer_gl0_inv', 'buffer_gl1_inv', 'buffer_wbl2', 'buffer_inv'): return f"{mn}()"
# Build modifiers string
buf_mods = "".join([f", offset={off_val}" if off_val else "", ", glc=1" if glc else "", ", dlc=1" if dlc else "",
", slc=1" if slc else "", ", tfe=1" if tfe else "", ", offen=1" if offen else "", ", idxen=1" if idxen else ""])
if is_tbuf and fmt_num is not None: buf_mods = f", format={fmt_num}" + buf_mods
# Determine vaddr value (v[0] for 'off', actual register otherwise)
vaddr_idx = 1
if len(ops) > vaddr_idx and ops[vaddr_idx].strip().lower() == 'off': vaddr_val = "v[0]"
else: vaddr_val = args[vaddr_idx] if len(args) > vaddr_idx else "v[0]"
# srsrc and soffset indices depend on whether vaddr is 'off'
srsrc_idx, soff_idx = (2, 3) if len(ops) > 1 else (1, 2)
srsrc_val = args[srsrc_idx] if len(args) > srsrc_idx else "s[0:3]"
soff_val = args[soff_idx] if len(args) > soff_idx else "0"
# soffset: integers are inline constants, don't wrap in RawImm
return f"{mn}(vdata={args[0]}, vaddr={vaddr_val}, srsrc={srsrc_val}, soffset={soff_val}{buf_mods})"
# FLAT/GLOBAL/SCRATCH load/store/atomic - saddr needs RawImm(124) for off/null
def _saddr(a): return 'RawImm(124)' if a in ('OFF', 'NULL') else a
@@ -582,6 +761,15 @@ def get_dsl(text: str) -> str:
if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '')
if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:]
if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2: args = ['RawImm(126)'] + args
# v_cmp_*_e64 has SGPR destination in vdst field - encode as RawImm
_SGPR_NAMES = {'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'null': 124, 'm0': 125, 'exec_lo': 126, 'exec_hi': 127}
if mn.startswith('v_cmp') and 'cmpx' not in mn and mn.endswith('_e64') and len(args) >= 1:
dst = ops[0].strip().lower()
if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})'
elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})'
elif dst.startswith('ttmp') and dst[4:].isdigit(): args[0] = f'RawImm({108 + int(dst[4:])})'
elif dst.startswith('ttmp[') and ':' in dst: args[0] = f'RawImm({108 + int(dst[5:].split(":")[0])})'
elif dst in _SGPR_NAMES: args[0] = f'RawImm({_SGPR_NAMES[dst]})'
fn = mn.replace('.', '_')
if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args]
@@ -629,31 +817,76 @@ def asm(text: str) -> Inst:
try:
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP2 as CDNA_VOP2, VOPC as CDNA_VOPC, VOP3A, VOP3B, VOP3P as CDNA_VOP3P,
SOP1 as CDNA_SOP1, SOP2 as CDNA_SOP2, SOPC as CDNA_SOPC, SOPK as CDNA_SOPK, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS,
FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op)
FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op, VOPCOp as CDNA_VOPCOp)
def _cdna_src(inst, v, neg, abs_=0, n=1):
s = inst.lit(v) if v == 255 else _fmt_src(v, n)
s = inst.lit(v) if v == 255 else _fmt_src(v, n, cdna=True)
if abs_: s = f"|{s}|"
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
def _disasm_vop3a(inst) -> str:
name, n, cl, om = inst.op_name.lower(), inst.num_srcs(), " clamp" if inst.clmp else "", _omod(inst.omod)
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.src_regs(0)), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.src_regs(1)), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.src_regs(2))
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
if inst.op.value < 256: return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
suf = "_e64" if inst.op.value < 512 else ""
return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else (f"{name}{suf}" if name == 'v_nop' else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}" if n == 2 else f"{name}{suf} {dst}, {s0}{cl}{om}")
# CDNA VOP2 aliases: new opcode name -> old name expected by LLVM tests
_CDNA_VOP3_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32'}
def _disasm_vop3b(inst) -> str:
name, n = inst.op_name.lower(), inst.num_srcs()
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1), _cdna_src(inst, inst.src1, inst.neg&2), _cdna_src(inst, inst.src2, inst.neg&4)
dst, suf = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}", "_e64" if 'co_' in name else ""
def _disasm_vop3a(inst) -> str:
op_val = inst._values.get('op', 0) # get raw opcode value, not enum value
if hasattr(op_val, 'value'): op_val = op_val.value # in case it's stored as enum
name = inst.op_name.lower() or f'vop3a_op_{op_val}'
from extra.assembly.amd.dsl import spec_num_srcs, spec_regs
n = spec_num_srcs(name) if name else inst.num_srcs()
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
return f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{cl}{om}"
orig_name = name
name = _CDNA_VOP3_ALIASES.get(name, name) # apply CDNA aliases
# For aliased ops, recalculate sources without 64-bit assumption
if name != orig_name:
s0, s1 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, 1), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, 1)
s2 = ""
dst = f"v{inst.vdst}"
else:
dregs, r0, r1, r2 = spec_regs(name) if name else (inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2))
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, r0), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, r1), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, r2)
dst = _vreg(inst.vdst, dregs) if dregs > 1 else f"v{inst.vdst}"
# True VOP3 instructions (512+) - 3-source ops
if op_val >= 512:
return f"{name} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{cl}{om}"
# VOPC (0-255): writes to SGPR pair, VOP2 (256-319): 2-3 src, VOP1 (320-511): 1 src
if op_val < 256:
sdst = _fmt_sdst(inst.vdst, 2, cdna=True) # VOPC writes to 64-bit SGPR pair
# v_cmpx_ also writes to sdst in CDNA VOP3 (unlike VOP32 where it writes to exec)
return f"{name}_e64 {sdst}, {s0}, {s1}{cl}"
if 320 <= op_val < 512: # VOP1 promoted
if name in ('v_nop', 'v_clrexcp'): return f"{name}_e64"
return f"{name}_e64 {dst}, {s0}{cl}{om}"
# VOP2 promoted (256-319)
if name == 'v_cndmask_b32':
s2 = _fmt_src(inst.src2, 2, cdna=True) # src2 is 64-bit SGPR pair
return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{cl}{om}"
if name in ('v_mul_legacy_f32', 'v_mac_f32'):
return f"{name}_e64 {dst}, {s0}, {s1}{cl}{om}"
suf = "_e64" if op_val < 512 else ""
return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}"
# GFX9-specific VOP3B opcodes not in CDNA enum
def _disasm_vop3b(inst) -> str:
  """Render a CDNA VOP3B instruction (vector destination plus scalar destination) as text.

  Operand counts and register widths are looked up by mnemonic through spec_num_srcs / spec_regs.
  The scalar destination is always formatted as a 2-register operand.
  """
  raw_op = inst._values.get('op', 0)
  raw_op = getattr(raw_op, 'value', raw_op)  # unwrap if the opcode is stored as an enum member
  name = inst.op_name.lower() or f'vop3b_op_{raw_op}'
  from extra.assembly.amd.dsl import spec_num_srcs, spec_regs
  # `name` can never be empty here (the fallback above is a non-empty string), so the
  # else-branch below is kept only to mirror the original guards.
  if name:
    n = spec_num_srcs(name)
    dregs, r0, r1, r2 = spec_regs(name)
  else:
    n = inst.num_srcs()
    dregs, r0, r1, r2 = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2)
  s0 = _cdna_src(inst, inst.src0, inst.neg&1, n=r0)
  s1 = _cdna_src(inst, inst.src1, inst.neg&2, n=r1)
  s2 = _cdna_src(inst, inst.src2, inst.neg&4, n=r2)
  if dregs > 1:
    dst = _vreg(inst.vdst, dregs)
  else:
    dst = f"v{inst.vdst}"
  sdst = _fmt_sdst(inst.sdst, 2, cdna=True)
  cl = " clamp" if inst.clmp else ""
  om = _omod(inst.omod)
  # Carry ops: src2 is the carry-in and is printed as a 2-register scalar source
  if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
    carry_in = _fmt_src(inst.src2, 2, cdna=True)
    return f"{name}_e64 {dst}, {sdst}, {s0}, {s1}, {carry_in}{cl}{om}"
  mnemonic = name + ("_e64" if 'co_' in name else "")
  third = f", {s2}" if n == 3 else ""  # third source is only printed for 3-source ops
  return f"{mnemonic} {dst}, {sdst}, {s0}, {s1}{third}{cl}{om}"
def _disasm_cdna_vop3p(inst) -> str:
name, n, is_mfma = inst.op_name.lower(), inst.num_srcs(), 'mfma' in inst.op_name.lower() or 'smfmac' in inst.op_name.lower()
get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc)
get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc, cdna=True)
if is_mfma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name or 'i4' in name else 8 if 'f16' in name or 'bf16' in name else 4; src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16), _vreg(inst.vdst, 16)
else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}"
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
@@ -665,20 +898,93 @@ try:
_UNUSED = {0: 'UNUSED_PAD', 1: 'UNUSED_SEXT', 2: 'UNUSED_PRESERVE'}
_DPP = {0x130: "wave_shl:1", 0x134: "wave_rol:1", 0x138: "wave_shr:1", 0x13c: "wave_ror:1", 0x140: "row_mirror", 0x141: "row_half_mirror", 0x142: "row_bcast:15", 0x143: "row_bcast:31"}
def _sdwa_src0(v, is_sgpr, sext=0, neg=0, abs_=0):
# s0=0: VGPR (v is VGPR number), s0=1: SGPR/constant (v is encoded like normal src)
s = decode_src(v, cdna=True) if is_sgpr else f"v{v}"
if sext: s = f"sext({s})"
if abs_: s = f"|{s}|"
return f"-{s}" if neg else s
def _sdwa_vsrc1(v, sext=0, neg=0, abs_=0):
# For VOP2 SDWA, vsrc1 is in vop_op field as raw VGPR number
s = f"v{v}"
if sext: s = f"sext({s})"
if abs_: s = f"|{s}|"
return f"-{s}" if neg else s
_OMOD_SDWA = {0: "", 1: " mul:2", 2: " mul:4", 3: " div:2"}
def _disasm_sdwa(inst) -> str:
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
except ValueError: name = f"vop1_op_{inst.vop_op}"
src = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0)
mods = [f"dst_sel:{_SEL[inst.dst_sel]}" for _ in [1] if inst.dst_sel != 6] + [f"dst_unused:{_UNUSED[inst.dst_u]}" for _ in [1] if inst.dst_u] + [f"src0_sel:{_SEL[inst.src0_sel]}" for _ in [1] if inst.src0_sel != 6]
return f"{name}_sdwa v{inst.vdst}, {src}" + (" " + " ".join(mods) if mods else "")
# SDWA format: vop2_op=63 -> VOP1, vop2_op=62 -> VOPC, vop2_op=0-61 -> VOP2
vop2_op = inst.vop2_op
src0 = _sdwa_src0(inst.src0, inst.s0, inst.src0_sext, inst.src0_neg, inst.src0_abs)
clamp = " clamp" if inst.clmp else ""
omod = _OMOD_SDWA.get(inst.omod, "")
if vop2_op == 63: # VOP1
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
except ValueError: name = f"vop1_op_{inst.vop_op}"
dst = f"v{inst.vdst}"
mods = [f"dst_sel:{_SEL[inst.dst_sel]}", f"dst_unused:{_UNUSED[inst.dst_u]}", f"src0_sel:{_SEL[inst.src0_sel]}"]
return f"{name}_sdwa {dst}, {src0}{clamp}{omod} " + " ".join(mods)
elif vop2_op == 62: # VOPC
try: name = CDNA_VOPCOp(inst.vdst).name.lower() # opcode is in vdst field for VOPC SDWA
except ValueError: name = f"vopc_op_{inst.vdst}"
src1 = _sdwa_vsrc1(inst.vop_op, inst.src1_sext, inst.src1_neg, inst.src1_abs) # vsrc1 is in vop_op field
# VOPC SDWA: dst encoded in byte 5 (bits 47:40): 0=vcc, 128+n=s[n:n+1]
sdst_enc = inst.dst_sel | (inst.dst_u << 3) | (inst.clmp << 5) | (inst.omod << 6)
if sdst_enc == 0:
sdst = "vcc"
else:
sdst_val = sdst_enc - 128 if sdst_enc >= 128 else sdst_enc
sdst = _fmt_sdst(sdst_val, 2, cdna=True)
mods = [f"src0_sel:{_SEL[inst.src0_sel]}", f"src1_sel:{_SEL[inst.src1_sel]}"]
return f"{name}_sdwa {sdst}, {src0}, {src1} " + " ".join(mods)
else: # VOP2
try: name = CDNA_VOP2Op(vop2_op).name.lower()
except ValueError: name = f"vop2_op_{vop2_op}"
name = _CDNA_DISASM_ALIASES.get(name, name) # apply aliases (v_fmac -> v_mac, etc.)
dst = f"v{inst.vdst}"
src1 = _sdwa_vsrc1(inst.vop_op, inst.src1_sext, inst.src1_neg, inst.src1_abs) # vsrc1 is in vop_op field
mods = [f"dst_sel:{_SEL[inst.dst_sel]}", f"dst_unused:{_UNUSED[inst.dst_u]}", f"src0_sel:{_SEL[inst.src0_sel]}", f"src1_sel:{_SEL[inst.src1_sel]}"]
# v_cndmask_b32 needs vcc as third operand
if name == 'v_cndmask_b32':
return f"{name}_sdwa {dst}, {src0}, {src1}, vcc{clamp}{omod} " + " ".join(mods)
# Carry ops need vcc - v_addc/subb also need vcc as carry-in
if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
return f"{name}_sdwa {dst}, vcc, {src0}, {src1}, vcc{clamp}{omod} " + " ".join(mods)
if '_co_' in name:
return f"{name}_sdwa {dst}, vcc, {src0}, {src1}{clamp}{omod} " + " ".join(mods)
return f"{name}_sdwa {dst}, {src0}, {src1}{clamp}{omod} " + " ".join(mods)
def _dpp_src(v, neg=0, abs_=0):
s = f"v{v}" if v < 256 else f"v{v - 256}"
if abs_: s = f"|{s}|"
return f"-{s}" if neg else s
def _disasm_dpp(inst) -> str:
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
except ValueError: name = f"vop1_op_{inst.vop_op}"
src, ctrl = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0), inst.dpp_ctrl
# DPP format: vop2_op=63 -> VOP1, vop2_op=0-62 -> VOP2
vop2_op = inst.vop2_op
ctrl = inst.dpp_ctrl
dpp = f"quad_perm:[{ctrl&3},{(ctrl>>2)&3},{(ctrl>>4)&3},{(ctrl>>6)&3}]" if ctrl < 0x100 else f"row_shl:{ctrl&0xf}" if ctrl < 0x110 else f"row_shr:{ctrl&0xf}" if ctrl < 0x120 else f"row_ror:{ctrl&0xf}" if ctrl < 0x130 else _DPP.get(ctrl, f"dpp_ctrl:0x{ctrl:x}")
mods = [dpp] + [f"row_mask:0x{inst.row_mask:x}" for _ in [1] if inst.row_mask != 0xf] + [f"bank_mask:0x{inst.bank_mask:x}" for _ in [1] if inst.bank_mask != 0xf] + ["bound_ctrl:1" for _ in [1] if inst.bound_ctrl]
return f"{name}_dpp v{inst.vdst}, {src} " + " ".join(mods)
src0 = _dpp_src(inst.src0, inst.src0_neg, inst.src0_abs)
# DPP modifiers: row_mask and bank_mask always shown, bound_ctrl:0 when bit=1
mods = [dpp, f"row_mask:0x{inst.row_mask:x}", f"bank_mask:0x{inst.bank_mask:x}"] + (["bound_ctrl:0"] if inst.bound_ctrl else [])
if vop2_op == 63: # VOP1
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
except ValueError: name = f"vop1_op_{inst.vop_op}"
return f"{name}_dpp v{inst.vdst}, {src0} " + " ".join(mods)
else: # VOP2
try: name = CDNA_VOP2Op(vop2_op).name.lower()
except ValueError: name = f"vop2_op_{vop2_op}"
name = _CDNA_DISASM_ALIASES.get(name, name)
src1 = _dpp_src(inst.vop_op, inst.src1_neg, inst.src1_abs) # vsrc1 is in vop_op field
if name == 'v_cndmask_b32':
return f"{name}_dpp v{inst.vdst}, {src0}, {src1}, vcc " + " ".join(mods)
if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
return f"{name}_dpp v{inst.vdst}, vcc, {src0}, {src1}, vcc " + " ".join(mods)
if '_co_' in name:
return f"{name}_dpp v{inst.vdst}, vcc, {src0}, {src1} " + " ".join(mods)
return f"{name}_dpp v{inst.vdst}, {src0}, {src1} " + " ".join(mods)
# Register CDNA handlers - shared formats use merged disassemblers, CDNA-only formats use dedicated ones
DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP2: _disasm_vop2, CDNA_VOPC: _disasm_vopc,

View File

@@ -1,46 +1,6 @@
# autogenerated from AMD CDNA3+CDNA4 ISA PDF by pdf.py - do not edit
# autogenerated from AMD ISA PDF by pdf.py - do not edit
from enum import IntEnum
class SrcEnum(IntEnum):
S_ADD_U32 = 0
S_SUB_U32 = 1
S_ADD_I32 = 2
S_SUB_I32 = 3
S_ADDC_U32 = 4
S_SUBB_U32 = 5
S_MIN_I32 = 6
FLAT_SCRATCH_LO = 102
FLAT_SCRATCH_HI = 103
XNACK_MASK_LO = 104
XNACK_MASK_HI = 105
VCC_LO = 106
VCC_HI = 107
M0 = 124
EXEC_LO = 126
EXEC_HI = 127
ZERO = 128
DPP8 = 233
DPP8FI = 234
SHARED_BASE = 235
SHARED_LIMIT = 236
PRIVATE_BASE = 237
PRIVATE_LIMIT = 238
RESERVED = 239
POS_HALF = 240
NEG_HALF = 241
POS_ONE = 242
NEG_ONE = 243
POS_TWO = 244
NEG_TWO = 245
POS_FOUR = 246
NEG_FOUR = 247
INV_2PI = 248
DPP16 = 250
VCCZ = 251
EXECZ = 252
SCC = 253
LDS_DIRECT = 254
class DSOp(IntEnum):
DS_ADD_U32 = 0
DS_SUB_U32 = 1
@@ -155,12 +115,6 @@ class DSOp(IntEnum):
DS_READ2ST64_B64 = 120
DS_ADD_RTN_F64 = 124
DS_CONDXCHG32_RTN_B64 = 126
DS_GWS_SEMA_RELEASE_ALL = 152
DS_GWS_INIT = 153
DS_GWS_SEMA_V = 154
DS_GWS_SEMA_BR = 155
DS_GWS_SEMA_P = 156
DS_GWS_BARRIER = 157
DS_READ_ADDTID_B32 = 182
DS_PK_ADD_RTN_F16 = 183
DS_PK_ADD_RTN_BF16 = 184
@@ -174,7 +128,6 @@ class DSOp(IntEnum):
DS_READ_B64_TR_B16 = 227
DS_READ_B96 = 254
DS_READ_B128 = 255
CDNA4 = 600
class FLATOp(IntEnum):
FLAT_LOAD_UBYTE = 16
@@ -231,7 +184,6 @@ class FLATOp(IntEnum):
FLAT_ATOMIC_XOR_X2 = 106
FLAT_ATOMIC_INC_X2 = 107
FLAT_ATOMIC_DEC_X2 = 108
CDNA4 = 600
class GLOBALOp(IntEnum):
GLOBAL_LOAD_UBYTE = 16
@@ -295,7 +247,6 @@ class GLOBALOp(IntEnum):
GLOBAL_ATOMIC_DEC_X2 = 108
GLOBAL_LOAD_LDS_DWORDX4 = 125
GLOBAL_LOAD_LDS_DWORDX3 = 126
CDNA4 = 600
class MTBUFOp(IntEnum):
TBUFFER_LOAD_FORMAT_X = 0
@@ -390,7 +341,6 @@ class MUBUFOp(IntEnum):
BUFFER_ATOMIC_XOR_X2 = 106
BUFFER_ATOMIC_INC_X2 = 107
BUFFER_ATOMIC_DEC_X2 = 108
CDNA4 = 600
class SCRATCHOp(IntEnum):
SCRATCH_LOAD_UBYTE = 16
@@ -504,7 +454,6 @@ class SMEMOp(IntEnum):
S_ATOMIC_XOR_X2 = 170
S_ATOMIC_INC_X2 = 171
S_ATOMIC_DEC_X2 = 172
CDNA4 = 600
class SOP1Op(IntEnum):
S_MOV_B32 = 0
@@ -561,7 +510,6 @@ class SOP1Op(IntEnum):
S_ANDN1_WREXEC_B64 = 53
S_ANDN2_WREXEC_B64 = 54
S_BITREPLICATE_B64_B32 = 55
CDNA4 = 600
class SOP2Op(IntEnum):
S_ADD_U32 = 0
@@ -616,7 +564,6 @@ class SOP2Op(IntEnum):
S_PACK_LL_B32_B16 = 50
S_PACK_LH_B32_B16 = 51
S_PACK_HH_B32_B16 = 52
CDNA4 = 600
class SOPCOp(IntEnum):
S_CMP_EQ_I32 = 0
@@ -639,7 +586,6 @@ class SOPCOp(IntEnum):
S_SET_GPR_IDX_ON = 17
S_CMP_EQ_U64 = 18
S_CMP_LG_U64 = 19
CDNA4 = 600
class SOPKOp(IntEnum):
S_MOVK_I32 = 0
@@ -695,7 +641,6 @@ class SOPPOp(IntEnum):
S_ENDPGM_SAVED = 27
S_SET_GPR_IDX_OFF = 28
S_SET_GPR_IDX_MODE = 29
CDNA4 = 600
class VOP1Op(IntEnum):
V_NOP = 0
@@ -783,7 +728,6 @@ class VOP1Op(IntEnum):
V_PERMLANE16_SWAP_B32 = 89
V_PERMLANE32_SWAP_B32 = 90
V_CVT_F32_BF16 = 91
CDNA4 = 600
class VOP2Op(IntEnum):
V_CNDMASK_B32 = 0
@@ -848,7 +792,6 @@ class VOP2Op(IntEnum):
V_FMAC_F32 = 59
V_PK_FMAC_F16 = 60
V_XNOR_B32 = 61
CDNA4 = 600
class VOP3AOp(IntEnum):
V_CMP_CLASS_F32 = 16
@@ -1268,7 +1211,7 @@ class VOP3AOp(IntEnum):
V_CVT_SCALEF32_SR_PK32_BF6_F32 = 597
V_CVT_SCALEF32_PK32_F32_FP6 = 598
V_CVT_SCALEF32_PK32_F32_BF6 = 599
CDNA4 = 600
V_CVT_SCALEF32_PK32_FP6_F16 = 600
V_CVT_SCALEF32_PK32_FP6_BF16 = 601
V_CVT_SCALEF32_PK32_BF6_F16 = 602
V_CVT_SCALEF32_PK32_BF6_BF16 = 603
@@ -1338,7 +1281,6 @@ class VOP3BOp(IntEnum):
V_DIV_SCALE_F64 = 481
V_MAD_U64_U32 = 488
V_MAD_I64_I32 = 489
CDNA4 = 600
class VOP3POp(IntEnum):
V_PK_MAD_I16 = 0
@@ -1388,8 +1330,6 @@ class VOP3POp(IntEnum):
V_SMFMAC_F32_16X16X128_BF8_BF8 = 59
V_SMFMAC_F32_16X16X128_BF8_FP8 = 60
V_SMFMAC_F32_16X16X128_FP8_BF8 = 61
V_MFMA_F32_16X16X8_XF32 = 62
V_MFMA_F32_32X32X4_XF32 = 63
V_MFMA_F32_32X32X1_2B_F32 = 64
V_MFMA_F32_16X16X1_4B_F32 = 65
V_MFMA_F32_4X4X1_16B_F32 = 66
@@ -1447,7 +1387,6 @@ class VOP3POp(IntEnum):
V_SMFMAC_F32_32X32X32_BF8_FP8 = 125
V_SMFMAC_F32_32X32X32_FP8_BF8 = 126
V_SMFMAC_F32_32X32X32_FP8_FP8 = 127
CDNA4 = 600
class VOPCOp(IntEnum):
V_CMP_CLASS_F32 = 16
@@ -1648,4 +1587,3 @@ class VOPCOp(IntEnum):
V_CMPX_NE_U64 = 253
V_CMPX_GE_U64 = 254
V_CMPX_T_U64 = 255
CDNA4 = 600

View File

@@ -1,19 +1,18 @@
# autogenerated from AMD CDNA3+CDNA4 ISA PDF by pdf.py - do not edit
# autogenerated from AMD ISA PDF by pdf.py - do not edit
# ruff: noqa: F401,F403
from typing import Annotated
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, Inst96, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField
from extra.assembly.amd.dsl import *
from extra.assembly.amd.autogen.cdna.enum import *
import functools
# instruction formats
class DPP(Inst64):
class DPP(Inst):
encoding = bits[8:0] == 0b11111010
vop_op = bits[16:9]
vdst:VGPRField = bits[24:17]
vop2_op = bits[31:25]
src0:Src = bits[39:32]
vop_op = bits[16:9]
vop2_op = bits[31:25]
dpp_ctrl = bits[48:40]
bound_ctrl = bits[51]
bc = bits[51]
src0_neg = bits[52]
src0_abs = bits[53]
src1_neg = bits[54]
@@ -21,7 +20,7 @@ class DPP(Inst64):
bank_mask = bits[59:56]
row_mask = bits[63:60]
class DS(Inst64):
class DS(Inst):
encoding = bits[31:26] == 0b110110
op:Annotated[BitField, DSOp] = bits[24:17]
vdst:VGPRField = bits[63:56]
@@ -33,7 +32,7 @@ class DS(Inst64):
gds = bits[16]
acc = bits[25]
class FLAT(Inst64):
class FLAT(Inst):
encoding = bits[31:26] == 0b110111
op:Annotated[BitField, FLATOp] = bits[24:18]
vdst:VGPRField = bits[63:56]
@@ -48,7 +47,7 @@ class FLAT(Inst64):
sc1 = bits[25]
acc = bits[55]
class MTBUF(Inst64):
class MTBUF(Inst):
encoding = bits[31:26] == 0b111010
op:Annotated[BitField, MTBUFOp] = bits[18:15]
vdata:VGPRField = bits[47:40]
@@ -58,12 +57,14 @@ class MTBUF(Inst64):
offset:Imm = bits[11:0]
offen = bits[12]
idxen = bits[13]
sc0 = bits[14]
dfmt = bits[22:19]
nfmt = bits[25:23]
sc1 = bits[53]
nt = bits[54]
acc = bits[55]
sc0 = bits[14]
class MUBUF(Inst64):
class MUBUF(Inst):
encoding = bits[31:26] == 0b111000
op:Annotated[BitField, MUBUFOp] = bits[24:18]
vdata:VGPRField = bits[47:40]
@@ -79,16 +80,16 @@ class MUBUF(Inst64):
nt = bits[17]
acc = bits[55]
class SDWA(Inst64):
class SDWA(Inst):
encoding = bits[8:0] == 0b11111001
vop_op = bits[16:9]
vdst:VGPRField = bits[24:17]
src0:Src = bits[39:32]
omod = bits[47:46]
clmp = bits[45]
vop_op = bits[16:9]
vop2_op = bits[31:25]
src0:Src = bits[39:32]
dst_sel = bits[42:40]
dst_u = bits[44:43]
clmp = bits[45]
omod = bits[47:46]
src0_sel = bits[50:48]
src0_sext = bits[51]
src0_neg = bits[52]
@@ -100,12 +101,10 @@ class SDWA(Inst64):
src1_abs = bits[61]
s1 = bits[63]
class SDWAB(Inst64):
class SDWAB(Inst):
sdst:SGPRField = bits[46:40]
src0:Src = bits[39:32]
dst_sel = bits[42:40]
dst_u = bits[44:43]
clmp = bits[45]
omod = bits[47:46]
sd = bits[47]
src0_sel = bits[50:48]
src0_sext = bits[51]
src0_neg = bits[52]
@@ -117,89 +116,88 @@ class SDWAB(Inst64):
src1_abs = bits[61]
s1 = bits[63]
class SMEM(Inst64):
class SMEM(Inst):
encoding = bits[31:26] == 0b110000
op:Annotated[BitField, SMEMOp] = bits[25:18]
sdata:SGPRField = bits[12:6]
sbase:SGPRField = bits[5:0]
soffset:SSrc = bits[63:57]
offset:Imm = bits[52:32]
glc = bits[14]
glc = bits[16]
soe = bits[14]
nv = bits[15]
imm = bits[17]
imm:Imm = bits[17]
class SOP1(Inst32):
class SOP1(Inst):
encoding = bits[31:23] == 0b101111101
op:Annotated[BitField, SOP1Op] = bits[15:8]
sdst:SGPRField = bits[22:16]
ssrc0:SSrc = bits[7:0]
class SOP2(Inst32):
class SOP2(Inst):
encoding = bits[31:30] == 0b10
op:Annotated[BitField, SOP2Op] = bits[29:23]
sdst:SGPRField = bits[22:16]
ssrc0:SSrc = bits[7:0]
ssrc1:SSrc = bits[15:8]
class SOPC(Inst32):
class SOPC(Inst):
encoding = bits[31:23] == 0b101111110
op:Annotated[BitField, SOPCOp] = bits[22:16]
ssrc0:SSrc = bits[7:0]
ssrc1:SSrc = bits[15:8]
class SOPK(Inst32):
class SOPK(Inst):
encoding = bits[31:28] == 0b1011
op:Annotated[BitField, SOPKOp] = bits[27:23]
sdst:SGPRField = bits[22:16]
simm16:SImm = bits[15:0]
class SOPP(Inst32):
class SOPP(Inst):
encoding = bits[31:23] == 0b101111111
op:Annotated[BitField, SOPPOp] = bits[22:16]
simm16:SImm = bits[15:0]
class VOP1(Inst32):
encoding = bits[31:25] == 0b111111
class VOP1(Inst):
encoding = bits[31:25] == 0b0111111
op:Annotated[BitField, VOP1Op] = bits[16:9]
vdst:VGPRField = bits[24:17]
src0:Src = bits[8:0]
class VOP2(Inst32):
encoding = bits[31] == 0
class VOP2(Inst):
encoding = bits[31] == 0b0
op:Annotated[BitField, VOP2Op] = bits[30:25]
vdst:VGPRField = bits[24:17]
src0:Src = bits[8:0]
vsrc1:VGPRField = bits[16:9]
class VOP3A(Inst64):
class VOP3A(Inst):
encoding = bits[31:26] == 0b110100
vdst:VGPRField = bits[7:0]
abs = bits[10:8]
opsel = bits[14:11]
clmp = bits[15]
op:Annotated[BitField, VOP3AOp] = bits[25:16]
vdst:VGPRField = bits[7:0]
src0:Src = bits[40:32]
src1:Src = bits[49:41]
src2:Src = bits[58:50]
omod = bits[60:59]
neg = bits[63:61]
abs = bits[10:8]
clmp = bits[15]
opsel = bits[14:11]
class VOP3B(Inst64):
class VOP3B(Inst):
encoding = bits[31:26] == 0b110100
op:Annotated[BitField, VOP3BOp] = bits[25:16]
vdst:VGPRField = bits[7:0]
sdst:SGPRField = bits[14:8]
clmp = bits[15]
op:Annotated[BitField, VOP3BOp] = bits[25:16]
src0:Src = bits[40:32]
src1:Src = bits[49:41]
src2:Src = bits[58:50]
omod = bits[60:59]
neg = bits[63:61]
clmp = bits[15]
class VOP3P(Inst64):
class VOP3P(Inst):
encoding = bits[31:23] == 0b110100111
_defaults = {'opsel_hi': 3, 'opsel_hi2': 1}
op:Annotated[BitField, VOP3POp] = bits[22:16]
vdst:VGPRField = bits[7:0]
src0:Src = bits[40:32]
@@ -207,13 +205,13 @@ class VOP3P(Inst64):
src2:Src = bits[58:50]
neg = bits[63:61]
neg_hi = bits[10:8]
clmp = bits[15]
opsel = bits[13:11]
opsel_hi = bits[60:59]
clmp = bits[15]
opsel_hi2 = bits[14]
class VOPC(Inst32):
encoding = bits[31:25] == 0b111110
class VOPC(Inst):
encoding = bits[31:25] == 0b0111110
op:Annotated[BitField, VOPCOp] = bits[24:17]
src0:Src = bits[8:0]
vsrc1:VGPRField = bits[16:9]
@@ -332,12 +330,6 @@ ds_read2_b64 = functools.partial(DS, DSOp.DS_READ2_B64)
ds_read2st64_b64 = functools.partial(DS, DSOp.DS_READ2ST64_B64)
ds_add_rtn_f64 = functools.partial(DS, DSOp.DS_ADD_RTN_F64)
ds_condxchg32_rtn_b64 = functools.partial(DS, DSOp.DS_CONDXCHG32_RTN_B64)
ds_gws_sema_release_all = functools.partial(DS, DSOp.DS_GWS_SEMA_RELEASE_ALL)
ds_gws_init = functools.partial(DS, DSOp.DS_GWS_INIT)
ds_gws_sema_v = functools.partial(DS, DSOp.DS_GWS_SEMA_V)
ds_gws_sema_br = functools.partial(DS, DSOp.DS_GWS_SEMA_BR)
ds_gws_sema_p = functools.partial(DS, DSOp.DS_GWS_SEMA_P)
ds_gws_barrier = functools.partial(DS, DSOp.DS_GWS_BARRIER)
ds_read_addtid_b32 = functools.partial(DS, DSOp.DS_READ_ADDTID_B32)
ds_pk_add_rtn_f16 = functools.partial(DS, DSOp.DS_PK_ADD_RTN_F16)
ds_pk_add_rtn_bf16 = functools.partial(DS, DSOp.DS_PK_ADD_RTN_BF16)
@@ -351,7 +343,6 @@ ds_read_b64_tr_b8 = functools.partial(DS, DSOp.DS_READ_B64_TR_B8)
ds_read_b64_tr_b16 = functools.partial(DS, DSOp.DS_READ_B64_TR_B16)
ds_read_b96 = functools.partial(DS, DSOp.DS_READ_B96)
ds_read_b128 = functools.partial(DS, DSOp.DS_READ_B128)
cdna4 = functools.partial(DS, DSOp.CDNA4)
flat_load_ubyte = functools.partial(FLAT, FLATOp.FLAT_LOAD_UBYTE)
flat_load_sbyte = functools.partial(FLAT, FLATOp.FLAT_LOAD_SBYTE)
flat_load_ushort = functools.partial(FLAT, FLATOp.FLAT_LOAD_USHORT)
@@ -406,7 +397,6 @@ flat_atomic_or_x2 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_OR_X2)
flat_atomic_xor_x2 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_XOR_X2)
flat_atomic_inc_x2 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_INC_X2)
flat_atomic_dec_x2 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_DEC_X2)
cdna4 = functools.partial(FLAT, FLATOp.CDNA4)
global_load_ubyte = functools.partial(FLAT, GLOBALOp.GLOBAL_LOAD_UBYTE, seg=2)
global_load_sbyte = functools.partial(FLAT, GLOBALOp.GLOBAL_LOAD_SBYTE, seg=2)
global_load_ushort = functools.partial(FLAT, GLOBALOp.GLOBAL_LOAD_USHORT, seg=2)
@@ -468,7 +458,6 @@ global_atomic_inc_x2 = functools.partial(FLAT, GLOBALOp.GLOBAL_ATOMIC_INC_X2, se
global_atomic_dec_x2 = functools.partial(FLAT, GLOBALOp.GLOBAL_ATOMIC_DEC_X2, seg=2)
global_load_lds_dwordx4 = functools.partial(FLAT, GLOBALOp.GLOBAL_LOAD_LDS_DWORDX4, seg=2)
global_load_lds_dwordx3 = functools.partial(FLAT, GLOBALOp.GLOBAL_LOAD_LDS_DWORDX3, seg=2)
cdna4 = functools.partial(FLAT, GLOBALOp.CDNA4, seg=2)
tbuffer_load_format_x = functools.partial(MTBUF, MTBUFOp.TBUFFER_LOAD_FORMAT_X)
tbuffer_load_format_xy = functools.partial(MTBUF, MTBUFOp.TBUFFER_LOAD_FORMAT_XY)
tbuffer_load_format_xyz = functools.partial(MTBUF, MTBUFOp.TBUFFER_LOAD_FORMAT_XYZ)
@@ -559,7 +548,6 @@ buffer_atomic_or_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_OR_X2)
buffer_atomic_xor_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_XOR_X2)
buffer_atomic_inc_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_INC_X2)
buffer_atomic_dec_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_DEC_X2)
cdna4 = functools.partial(MUBUF, MUBUFOp.CDNA4)
scratch_load_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE, seg=1)
scratch_load_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE, seg=1)
scratch_load_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_USHORT, seg=1)
@@ -669,7 +657,6 @@ s_atomic_or_x2 = functools.partial(SMEM, SMEMOp.S_ATOMIC_OR_X2)
s_atomic_xor_x2 = functools.partial(SMEM, SMEMOp.S_ATOMIC_XOR_X2)
s_atomic_inc_x2 = functools.partial(SMEM, SMEMOp.S_ATOMIC_INC_X2)
s_atomic_dec_x2 = functools.partial(SMEM, SMEMOp.S_ATOMIC_DEC_X2)
cdna4 = functools.partial(SMEM, SMEMOp.CDNA4)
s_mov_b32 = functools.partial(SOP1, SOP1Op.S_MOV_B32)
s_mov_b64 = functools.partial(SOP1, SOP1Op.S_MOV_B64)
s_cmov_b32 = functools.partial(SOP1, SOP1Op.S_CMOV_B32)
@@ -724,7 +711,6 @@ s_orn1_saveexec_b64 = functools.partial(SOP1, SOP1Op.S_ORN1_SAVEEXEC_B64)
s_andn1_wrexec_b64 = functools.partial(SOP1, SOP1Op.S_ANDN1_WREXEC_B64)
s_andn2_wrexec_b64 = functools.partial(SOP1, SOP1Op.S_ANDN2_WREXEC_B64)
s_bitreplicate_b64_b32 = functools.partial(SOP1, SOP1Op.S_BITREPLICATE_B64_B32)
cdna4 = functools.partial(SOP1, SOP1Op.CDNA4)
s_add_u32 = functools.partial(SOP2, SOP2Op.S_ADD_U32)
s_sub_u32 = functools.partial(SOP2, SOP2Op.S_SUB_U32)
s_add_i32 = functools.partial(SOP2, SOP2Op.S_ADD_I32)
@@ -777,7 +763,6 @@ s_lshl4_add_u32 = functools.partial(SOP2, SOP2Op.S_LSHL4_ADD_U32)
s_pack_ll_b32_b16 = functools.partial(SOP2, SOP2Op.S_PACK_LL_B32_B16)
s_pack_lh_b32_b16 = functools.partial(SOP2, SOP2Op.S_PACK_LH_B32_B16)
s_pack_hh_b32_b16 = functools.partial(SOP2, SOP2Op.S_PACK_HH_B32_B16)
cdna4 = functools.partial(SOP2, SOP2Op.CDNA4)
s_cmp_eq_i32 = functools.partial(SOPC, SOPCOp.S_CMP_EQ_I32)
s_cmp_lg_i32 = functools.partial(SOPC, SOPCOp.S_CMP_LG_I32)
s_cmp_gt_i32 = functools.partial(SOPC, SOPCOp.S_CMP_GT_I32)
@@ -798,7 +783,6 @@ s_setvskip = functools.partial(SOPC, SOPCOp.S_SETVSKIP)
s_set_gpr_idx_on = functools.partial(SOPC, SOPCOp.S_SET_GPR_IDX_ON)
s_cmp_eq_u64 = functools.partial(SOPC, SOPCOp.S_CMP_EQ_U64)
s_cmp_lg_u64 = functools.partial(SOPC, SOPCOp.S_CMP_LG_U64)
cdna4 = functools.partial(SOPC, SOPCOp.CDNA4)
s_movk_i32 = functools.partial(SOPK, SOPKOp.S_MOVK_I32)
s_cmovk_i32 = functools.partial(SOPK, SOPKOp.S_CMOVK_I32)
s_cmpk_eq_i32 = functools.partial(SOPK, SOPKOp.S_CMPK_EQ_I32)
@@ -850,7 +834,6 @@ s_cbranch_cdbgsys_and_user = functools.partial(SOPP, SOPPOp.S_CBRANCH_CDBGSYS_AN
s_endpgm_saved = functools.partial(SOPP, SOPPOp.S_ENDPGM_SAVED)
s_set_gpr_idx_off = functools.partial(SOPP, SOPPOp.S_SET_GPR_IDX_OFF)
s_set_gpr_idx_mode = functools.partial(SOPP, SOPPOp.S_SET_GPR_IDX_MODE)
cdna4 = functools.partial(SOPP, SOPPOp.CDNA4)
v_nop_e32 = functools.partial(VOP1, VOP1Op.V_NOP)
v_mov_b32_e32 = functools.partial(VOP1, VOP1Op.V_MOV_B32)
v_readfirstlane_b32_e32 = functools.partial(VOP1, VOP1Op.V_READFIRSTLANE_B32)
@@ -936,7 +919,6 @@ v_prng_b32_e32 = functools.partial(VOP1, VOP1Op.V_PRNG_B32)
v_permlane16_swap_b32_e32 = functools.partial(VOP1, VOP1Op.V_PERMLANE16_SWAP_B32)
v_permlane32_swap_b32_e32 = functools.partial(VOP1, VOP1Op.V_PERMLANE32_SWAP_B32)
v_cvt_f32_bf16_e32 = functools.partial(VOP1, VOP1Op.V_CVT_F32_BF16)
cdna4_e32 = functools.partial(VOP1, VOP1Op.CDNA4)
v_cndmask_b32_e32 = functools.partial(VOP2, VOP2Op.V_CNDMASK_B32)
v_add_f32_e32 = functools.partial(VOP2, VOP2Op.V_ADD_F32)
v_sub_f32_e32 = functools.partial(VOP2, VOP2Op.V_SUB_F32)
@@ -960,8 +942,8 @@ v_and_b32_e32 = functools.partial(VOP2, VOP2Op.V_AND_B32)
v_or_b32_e32 = functools.partial(VOP2, VOP2Op.V_OR_B32)
v_xor_b32_e32 = functools.partial(VOP2, VOP2Op.V_XOR_B32)
v_dot2c_f32_bf16_e32 = functools.partial(VOP2, VOP2Op.V_DOT2C_F32_BF16)
def v_fmamk_f32_e32(vdst, src0, K, vsrc1): return VOP2(VOP2Op.V_FMAMK_F32, vdst, src0, vsrc1, literal=K)
def v_fmaak_f32_e32(vdst, src0, vsrc1, K): return VOP2(VOP2Op.V_FMAAK_F32, vdst, src0, vsrc1, literal=K)
v_fmamk_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAMK_F32)
v_fmaak_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAAK_F32)
v_add_co_u32_e32 = functools.partial(VOP2, VOP2Op.V_ADD_CO_U32)
v_sub_co_u32_e32 = functools.partial(VOP2, VOP2Op.V_SUB_CO_U32)
v_subrev_co_u32_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_CO_U32)
@@ -999,7 +981,6 @@ v_dot8c_i32_i4_e32 = functools.partial(VOP2, VOP2Op.V_DOT8C_I32_I4)
v_fmac_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAC_F32)
v_pk_fmac_f16_e32 = functools.partial(VOP2, VOP2Op.V_PK_FMAC_F16)
v_xnor_b32_e32 = functools.partial(VOP2, VOP2Op.V_XNOR_B32)
cdna4_e32 = functools.partial(VOP2, VOP2Op.CDNA4)
v_cmp_class_f32 = functools.partial(VOP3A, VOP3AOp.V_CMP_CLASS_F32)
v_cmpx_class_f32 = functools.partial(VOP3A, VOP3AOp.V_CMPX_CLASS_F32)
v_cmp_class_f64 = functools.partial(VOP3A, VOP3AOp.V_CMP_CLASS_F64)
@@ -1417,7 +1398,7 @@ v_cvt_scalef32_sr_pk32_fp6_f32 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32
v_cvt_scalef32_sr_pk32_bf6_f32 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32_SR_PK32_BF6_F32)
v_cvt_scalef32_pk32_f32_fp6 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32_PK32_F32_FP6)
v_cvt_scalef32_pk32_f32_bf6 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32_PK32_F32_BF6)
cdna4 = functools.partial(VOP3A, VOP3AOp.CDNA4)
v_cvt_scalef32_pk32_fp6_f16 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32_PK32_FP6_F16)
v_cvt_scalef32_pk32_fp6_bf16 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32_PK32_FP6_BF16)
v_cvt_scalef32_pk32_bf6_f16 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32_PK32_BF6_F16)
v_cvt_scalef32_pk32_bf6_bf16 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32_PK32_BF6_BF16)
@@ -1485,7 +1466,6 @@ v_div_scale_f32 = functools.partial(VOP3B, VOP3BOp.V_DIV_SCALE_F32)
v_div_scale_f64 = functools.partial(VOP3B, VOP3BOp.V_DIV_SCALE_F64)
v_mad_u64_u32 = functools.partial(VOP3B, VOP3BOp.V_MAD_U64_U32)
v_mad_i64_i32 = functools.partial(VOP3B, VOP3BOp.V_MAD_I64_I32)
cdna4 = functools.partial(VOP3B, VOP3BOp.CDNA4)
v_pk_mad_i16 = functools.partial(VOP3P, VOP3POp.V_PK_MAD_I16)
v_pk_mul_lo_u16 = functools.partial(VOP3P, VOP3POp.V_PK_MUL_LO_U16)
v_pk_add_i16 = functools.partial(VOP3P, VOP3POp.V_PK_ADD_I16)
@@ -1533,8 +1513,6 @@ v_smfmac_i32_16x16x128_i8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_I32_16X16X
v_smfmac_f32_16x16x128_bf8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X128_BF8_BF8)
v_smfmac_f32_16x16x128_bf8_fp8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X128_BF8_FP8)
v_smfmac_f32_16x16x128_fp8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X128_FP8_BF8)
v_mfma_f32_16x16x8_xf32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X8_XF32)
v_mfma_f32_32x32x4_xf32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X4_XF32)
v_mfma_f32_32x32x1_2b_f32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X1_2B_F32)
v_mfma_f32_16x16x1_4b_f32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X1_4B_F32)
v_mfma_f32_4x4x1_16b_f32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_4X4X1_16B_F32)
@@ -1592,7 +1570,6 @@ v_smfmac_f32_32x32x32_bf8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32
v_smfmac_f32_32x32x32_bf8_fp8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X32_BF8_FP8)
v_smfmac_f32_32x32x32_fp8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X32_FP8_BF8)
v_smfmac_f32_32x32x32_fp8_fp8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X32_FP8_FP8)
cdna4 = functools.partial(VOP3P, VOP3POp.CDNA4)
v_cmp_class_f32_e32 = functools.partial(VOPC, VOPCOp.V_CMP_CLASS_F32)
v_cmpx_class_f32_e32 = functools.partial(VOPC, VOPCOp.V_CMPX_CLASS_F32)
v_cmp_class_f64_e32 = functools.partial(VOPC, VOPCOp.V_CMP_CLASS_F64)
@@ -1790,42 +1767,4 @@ v_cmpx_le_u64_e32 = functools.partial(VOPC, VOPCOp.V_CMPX_LE_U64)
v_cmpx_gt_u64_e32 = functools.partial(VOPC, VOPCOp.V_CMPX_GT_U64)
v_cmpx_ne_u64_e32 = functools.partial(VOPC, VOPCOp.V_CMPX_NE_U64)
v_cmpx_ge_u64_e32 = functools.partial(VOPC, VOPCOp.V_CMPX_GE_U64)
v_cmpx_t_u64_e32 = functools.partial(VOPC, VOPCOp.V_CMPX_T_U64)
cdna4_e32 = functools.partial(VOPC, VOPCOp.CDNA4)
S_ADD_U32 = SrcEnum.S_ADD_U32
S_SUB_U32 = SrcEnum.S_SUB_U32
S_ADD_I32 = SrcEnum.S_ADD_I32
S_SUB_I32 = SrcEnum.S_SUB_I32
S_ADDC_U32 = SrcEnum.S_ADDC_U32
S_SUBB_U32 = SrcEnum.S_SUBB_U32
S_MIN_I32 = SrcEnum.S_MIN_I32
FLAT_SCRATCH_LO = SrcEnum.FLAT_SCRATCH_LO
FLAT_SCRATCH_HI = SrcEnum.FLAT_SCRATCH_HI
XNACK_MASK_LO = SrcEnum.XNACK_MASK_LO
XNACK_MASK_HI = SrcEnum.XNACK_MASK_HI
VCC_LO = SrcEnum.VCC_LO
VCC_HI = SrcEnum.VCC_HI
M0 = SrcEnum.M0
EXEC_LO = SrcEnum.EXEC_LO
EXEC_HI = SrcEnum.EXEC_HI
ZERO = SrcEnum.ZERO
DPP8FI = SrcEnum.DPP8FI
SHARED_BASE = SrcEnum.SHARED_BASE
SHARED_LIMIT = SrcEnum.SHARED_LIMIT
PRIVATE_BASE = SrcEnum.PRIVATE_BASE
PRIVATE_LIMIT = SrcEnum.PRIVATE_LIMIT
RESERVED = SrcEnum.RESERVED
POS_HALF = SrcEnum.POS_HALF
NEG_HALF = SrcEnum.NEG_HALF
POS_ONE = SrcEnum.POS_ONE
NEG_ONE = SrcEnum.NEG_ONE
POS_TWO = SrcEnum.POS_TWO
NEG_TWO = SrcEnum.NEG_TWO
POS_FOUR = SrcEnum.POS_FOUR
NEG_FOUR = SrcEnum.NEG_FOUR
INV_2PI = SrcEnum.INV_2PI
VCCZ = SrcEnum.VCCZ
EXECZ = SrcEnum.EXECZ
SCC = SrcEnum.SCC
LDS_DIRECT = SrcEnum.LDS_DIRECT
v_cmpx_t_u64_e32 = functools.partial(VOPC, VOPCOp.V_CMPX_T_U64)

File diff suppressed because one or more lines are too long

View File

@@ -1,34 +1,97 @@
# autogenerated from AMD RDNA3.5 ISA PDF by pdf.py - do not edit
# autogenerated from AMD ISA PDF by pdf.py - do not edit
from enum import IntEnum
class SrcEnum(IntEnum):
VCC_LO = 106
VCC_HI = 107
NULL = 124
M0 = 125
EXEC_LO = 126
EXEC_HI = 127
ZERO = 128
DPP8 = 233
DPP8FI = 234
SHARED_BASE = 235
SHARED_LIMIT = 236
PRIVATE_BASE = 237
PRIVATE_LIMIT = 238
POS_HALF = 240
NEG_HALF = 241
POS_ONE = 242
NEG_ONE = 243
POS_TWO = 244
NEG_TWO = 245
POS_FOUR = 246
NEG_FOUR = 247
INV_2PI = 248
DPP16 = 250
VCCZ = 251
EXECZ = 252
SCC = 253
LDS_DIRECT = 254
class BufFmt(IntEnum):
BUF_FMT_8_UNORM = 1
BUF_FMT_8_SNORM = 2
BUF_FMT_8_USCALED = 3
BUF_FMT_8_SSCALED = 4
BUF_FMT_8_UINT = 5
BUF_FMT_8_SINT = 6
BUF_FMT_16_UNORM = 7
BUF_FMT_16_SNORM = 8
BUF_FMT_16_USCALED = 9
BUF_FMT_16_SSCALED = 10
BUF_FMT_16_UINT = 11
BUF_FMT_16_SINT = 12
BUF_FMT_16_FLOAT = 13
BUF_FMT_8_8_UNORM = 14
BUF_FMT_8_8_SNORM = 15
BUF_FMT_8_8_USCALED = 16
BUF_FMT_8_8_SSCALED = 17
BUF_FMT_8_8_UINT = 18
BUF_FMT_8_8_SINT = 19
BUF_FMT_32_UINT = 20
BUF_FMT_32_SINT = 21
BUF_FMT_32_FLOAT = 22
BUF_FMT_16_16_UNORM = 23
BUF_FMT_16_16_SNORM = 24
BUF_FMT_16_16_USCALED = 25
BUF_FMT_16_16_SSCALED = 26
BUF_FMT_16_16_UINT = 27
BUF_FMT_16_16_SINT = 28
BUF_FMT_16_16_FLOAT = 29
BUF_FMT_10_11_11_FLOAT = 30
BUF_FMT_11_11_10_FLOAT = 31
BUF_FMT_10_10_10_2_UNORM = 32
BUF_FMT_10_10_10_2_SNORM = 33
BUF_FMT_10_10_10_2_UINT = 34
BUF_FMT_10_10_10_2_SINT = 35
BUF_FMT_2_10_10_10_UNORM = 36
BUF_FMT_2_10_10_10_SNORM = 37
BUF_FMT_2_10_10_10_USCALED = 38
BUF_FMT_2_10_10_10_SSCALED = 39
BUF_FMT_2_10_10_10_UINT = 40
BUF_FMT_2_10_10_10_SINT = 41
BUF_FMT_8_8_8_8_UNORM = 42
BUF_FMT_8_8_8_8_SNORM = 43
BUF_FMT_8_8_8_8_USCALED = 44
BUF_FMT_8_8_8_8_SSCALED = 45
BUF_FMT_8_8_8_8_UINT = 46
BUF_FMT_8_8_8_8_SINT = 47
BUF_FMT_32_32_UINT = 48
BUF_FMT_32_32_SINT = 49
BUF_FMT_32_32_FLOAT = 50
BUF_FMT_16_16_16_16_UNORM = 51
BUF_FMT_16_16_16_16_SNORM = 52
BUF_FMT_16_16_16_16_USCALED = 53
BUF_FMT_16_16_16_16_SSCALED = 54
BUF_FMT_16_16_16_16_UINT = 55
BUF_FMT_16_16_16_16_SINT = 56
BUF_FMT_16_16_16_16_FLOAT = 57
BUF_FMT_32_32_32_UINT = 58
BUF_FMT_32_32_32_SINT = 59
BUF_FMT_32_32_32_FLOAT = 60
BUF_FMT_32_32_32_32_UINT = 61
BUF_FMT_8_SRGB = 64
BUF_FMT_8_8_SRGB = 65
BUF_FMT_8_8_8_8_SRGB = 66
BUF_FMT_5_9_9_9_FLOAT = 67
BUF_FMT_5_6_5_UNORM = 68
BUF_FMT_1_5_5_5_UNORM = 69
BUF_FMT_5_5_5_1_UNORM = 70
BUF_FMT_4_4_4_4_UNORM = 71
BUF_FMT_4_4_UNORM = 72
BUF_FMT_1_UNORM = 73
BUF_FMT_1_REVERSED_UNORM = 74
BUF_FMT_32_FLOAT_CLAMP = 75
BUF_FMT_8_24_UNORM = 76
BUF_FMT_8_24_UINT = 77
BUF_FMT_24_8_UNORM = 78
BUF_FMT_24_8_UINT = 79
BUF_FMT_X24_8_32_UINT = 80
BUF_FMT_X24_8_32_FLOAT = 81
BUF_FMT_GB_GR_UNORM = 82
BUF_FMT_GB_GR_SNORM = 83
BUF_FMT_GB_GR_UINT = 84
BUF_FMT_GB_GR_SRGB = 85
BUF_FMT_BG_RG_UNORM = 86
BUF_FMT_BG_RG_SNORM = 87
BUF_FMT_BG_RG_UINT = 88
BUF_FMT_BG_RG_SRGB = 89
BUF_FMT_BC1_UNORM = 109
BUF_FMT_BC1_SRGB = 110
BUF_FMT_BC2_UNORM = 111
class DSOp(IntEnum):
DS_ADD_U32 = 0
@@ -1372,7 +1435,6 @@ class VOP3POp(IntEnum):
V_WMMA_I32_16X16X16_IU4 = 69
class VOP3SDOp(IntEnum):
DWORD = 1
V_ADD_CO_CI_U32 = 288
V_SUB_CO_CI_U32 = 289
V_SUBREV_CO_CI_U32 = 290

View File

@@ -1,12 +1,11 @@
# autogenerated from AMD RDNA3.5 ISA PDF by pdf.py - do not edit
# autogenerated from AMD ISA PDF by pdf.py - do not edit
# ruff: noqa: F401,F403
from typing import Annotated
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, Inst96, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField
from extra.assembly.amd.dsl import *
from extra.assembly.amd.autogen.rdna3.enum import *
import functools
# instruction formats
class DPP16(Inst64):
class DPP16(Inst):
src0:Src = bits[39:32]
dpp_ctrl = bits[48:40]
fi = bits[50]
@@ -18,7 +17,7 @@ class DPP16(Inst64):
bank_mask = bits[59:56]
row_mask = bits[63:60]
class DPP8(Inst64):
class DPP8(Inst):
src0:Src = bits[39:32]
lane_sel0 = bits[42:40]
lane_sel1 = bits[45:43]
@@ -29,7 +28,7 @@ class DPP8(Inst64):
lane_sel6 = bits[60:58]
lane_sel7 = bits[63:61]
class DS(Inst64):
class DS(Inst):
encoding = bits[31:26] == 0b110110
op:Annotated[BitField, DSOp] = bits[25:18]
vdst:VGPRField = bits[63:56]
@@ -40,18 +39,18 @@ class DS(Inst64):
offset1 = bits[15:8]
gds = bits[17]
class EXP(Inst64):
class EXP(Inst):
encoding = bits[31:26] == 0b111110
vsrc0:VGPRField = bits[39:32]
vsrc1:VGPRField = bits[47:40]
vsrc2:VGPRField = bits[55:48]
vsrc3:VGPRField = bits[63:56]
en = bits[3:0]
target = bits[9:4]
vsrc0 = bits[39:32]
vsrc1:VGPRField = bits[47:40]
vsrc2 = bits[55:48]
vsrc3 = bits[63:56]
done = bits[11]
row = bits[13]
class FLAT(Inst64):
class FLAT(Inst):
encoding = bits[31:26] == 0b110111
op:Annotated[BitField, FLATOp] = bits[24:18]
vdst:VGPRField = bits[63:56]
@@ -60,12 +59,12 @@ class FLAT(Inst64):
saddr:SSrc = bits[54:48]
offset:Imm = bits[12:0]
seg = bits[17:16]
dlc = bits[13]
glc = bits[14]
dlc = bits[13]
slc = bits[15]
sve = bits[55]
class LDSDIR(Inst32):
class LDSDIR(Inst):
encoding = bits[31:24] == 0b11001110
op = bits[21:20]
vdst:VGPRField = bits[7:0]
@@ -73,29 +72,29 @@ class LDSDIR(Inst32):
attr_chan = bits[9:8]
wait_va = bits[19:16]
class MIMG(Inst64):
class MIMG(Inst):
encoding = bits[31:26] == 0b111100
op:Annotated[BitField, MIMGOp] = bits[25:18]
vdata:VGPRField = bits[47:40]
vaddr:VGPRField = bits[39:32]
srsrc:SGPRField = bits[52:48]
ssamp = bits[62:58]
ssamp:SGPRField = bits[62:58]
dmask = bits[11:8]
dim = bits[4:2]
unrm = bits[7]
dlc = bits[13]
glc = bits[14]
dlc = bits[13]
slc = bits[12]
tfe = bits[53]
unrm = bits[7]
nsa = bits[0]
r128 = bits[15]
a16 = bits[16]
d16 = bits[17]
tfe = bits[53]
lwe = bits[54]
addr1 = bits[71:64]
addr2 = bits[79:72]
class MTBUF(Inst64):
class MTBUF(Inst):
encoding = bits[31:26] == 0b111010
op:Annotated[BitField, MTBUFOp] = bits[18:15]
vdata:VGPRField = bits[47:40]
@@ -111,7 +110,7 @@ class MTBUF(Inst64):
slc = bits[12]
tfe = bits[53]
class MUBUF(Inst64):
class MUBUF(Inst):
encoding = bits[31:26] == 0b111000
op:Annotated[BitField, MUBUFOp] = bits[25:18]
vdata:VGPRField = bits[47:40]
@@ -126,7 +125,7 @@ class MUBUF(Inst64):
slc = bits[12]
tfe = bits[53]
class SMEM(Inst64):
class SMEM(Inst):
encoding = bits[31:26] == 0b111101
op:Annotated[BitField, SMEMOp] = bits[25:18]
sdata:SGPRField = bits[12:6]
@@ -136,62 +135,63 @@ class SMEM(Inst64):
glc = bits[14]
dlc = bits[13]
class SOP1(Inst32):
class SOP1(Inst):
encoding = bits[31:23] == 0b101111101
op:Annotated[BitField, SOP1Op] = bits[15:8]
sdst:SGPRField = bits[22:16]
ssrc0:SSrc = bits[7:0]
class SOP2(Inst32):
class SOP2(Inst):
encoding = bits[31:30] == 0b10
op:Annotated[BitField, SOP2Op] = bits[29:23]
sdst:SGPRField = bits[22:16]
ssrc0:SSrc = bits[7:0]
ssrc1:SSrc = bits[15:8]
class SOPC(Inst32):
class SOPC(Inst):
encoding = bits[31:23] == 0b101111110
op:Annotated[BitField, SOPCOp] = bits[22:16]
ssrc0:SSrc = bits[7:0]
ssrc1:SSrc = bits[15:8]
class SOPK(Inst32):
class SOPK(Inst):
encoding = bits[31:28] == 0b1011
op:Annotated[BitField, SOPKOp] = bits[27:23]
sdst:SGPRField = bits[22:16]
simm16:SImm = bits[15:0]
class SOPP(Inst32):
class SOPP(Inst):
encoding = bits[31:23] == 0b101111111
op:Annotated[BitField, SOPPOp] = bits[22:16]
simm16:SImm = bits[15:0]
class VINTERP(Inst64):
class VINTERP(Inst):
encoding = bits[31:24] == 0b11001101
op:Annotated[BitField, VINTERPOp] = bits[22:16]
vdst:VGPRField = bits[7:0]
src0:Src = bits[40:32]
src0:Src = bits[40:32]
src1:Src = bits[49:41]
src2:Src = bits[58:50]
waitexp = bits[10:8]
neg = bits[63:61]
clmp = bits[15]
opsel = bits[14:11]
neg = bits[63:61]
waitexp = bits[10:8]
class VOP1(Inst32):
encoding = bits[31:25] == 0b111111
class VOP1(Inst):
encoding = bits[31:25] == 0b0111111
op:Annotated[BitField, VOP1Op] = bits[16:9]
vdst:VGPRField = bits[24:17]
src0:Src = bits[8:0]
class VOP2(Inst32):
encoding = bits[31] == 0
class VOP2(Inst):
encoding = bits[31] == 0b0
op:Annotated[BitField, VOP2Op] = bits[30:25]
vdst:VGPRField = bits[24:17]
src0:Src = bits[8:0]
vsrc1:VGPRField = bits[16:9]
class VOP3(Inst64):
class VOP3(Inst):
encoding = bits[31:26] == 0b110101
op:Annotated[BitField, VOP3Op] = bits[25:16]
vdst:VGPRField = bits[7:0]
@@ -204,9 +204,8 @@ class VOP3(Inst64):
clmp = bits[15]
opsel = bits[14:11]
class VOP3P(Inst64):
class VOP3P(Inst):
encoding = bits[31:24] == 0b11001100
_defaults = {'opsel_hi': 3, 'opsel_hi2': 1}
op:Annotated[BitField, VOP3POp] = bits[22:16]
vdst:VGPRField = bits[7:0]
src0:Src = bits[40:32]
@@ -214,12 +213,12 @@ class VOP3P(Inst64):
src2:Src = bits[58:50]
neg = bits[63:61]
neg_hi = bits[10:8]
clmp = bits[15]
opsel = bits[13:11]
opsel_hi = bits[60:59]
clmp = bits[15]
opsel_hi2 = bits[14]
class VOP3SD(Inst64):
class VOP3SD(Inst):
encoding = bits[31:26] == 0b110101
op:Annotated[BitField, VOP3SDOp] = bits[25:16]
vdst:VGPRField = bits[7:0]
@@ -227,26 +226,26 @@ class VOP3SD(Inst64):
src0:Src = bits[40:32]
src1:Src = bits[49:41]
src2:Src = bits[58:50]
clmp = bits[15]
omod = bits[60:59]
neg = bits[63:61]
clmp = bits[15]
class VOPC(Inst32):
encoding = bits[31:25] == 0b111110
class VOPC(Inst):
encoding = bits[31:25] == 0b0111110
op:Annotated[BitField, VOPCOp] = bits[24:17]
src0:Src = bits[8:0]
vsrc1:VGPRField = bits[16:9]
class VOPD(Inst64):
class VOPD(Inst):
encoding = bits[31:26] == 0b110010
opx:Annotated[BitField, VOPDOp] = bits[25:22]
opy:Annotated[BitField, VOPDOp] = bits[21:17]
vdstx:VGPRField = bits[63:56]
vdstx = bits[63:56]
vdsty:VDSTYEnc = bits[55:49]
srcx0:Src = bits[8:0]
vsrcx1:VGPRField = bits[16:9]
srcy0:Src = bits[40:32]
vsrcy1:VGPRField = bits[48:41]
vsrcx1 = bits[16:9]
vsrcy1 = bits[48:41]
# instruction helpers
ds_add_u32 = functools.partial(DS, DSOp.DS_ADD_U32)
@@ -1077,16 +1076,16 @@ v_add_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_ADD_NC_U32)
v_sub_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_SUB_NC_U32)
v_subrev_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_NC_U32)
v_fmac_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAC_F32)
def v_fmamk_f32_e32(vdst, src0, K, vsrc1): return VOP2(VOP2Op.V_FMAMK_F32, vdst, src0, vsrc1, literal=K)
def v_fmaak_f32_e32(vdst, src0, vsrc1, K): return VOP2(VOP2Op.V_FMAAK_F32, vdst, src0, vsrc1, literal=K)
v_fmamk_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAMK_F32)
v_fmaak_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAAK_F32)
v_cvt_pk_rtz_f16_f32_e32 = functools.partial(VOP2, VOP2Op.V_CVT_PK_RTZ_F16_F32)
v_add_f16_e32 = functools.partial(VOP2, VOP2Op.V_ADD_F16)
v_sub_f16_e32 = functools.partial(VOP2, VOP2Op.V_SUB_F16)
v_subrev_f16_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_F16)
v_mul_f16_e32 = functools.partial(VOP2, VOP2Op.V_MUL_F16)
v_fmac_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAC_F16)
def v_fmamk_f16_e32(vdst, src0, K, vsrc1): return VOP2(VOP2Op.V_FMAMK_F16, vdst, src0, vsrc1, literal=K)
def v_fmaak_f16_e32(vdst, src0, vsrc1, K): return VOP2(VOP2Op.V_FMAAK_F16, vdst, src0, vsrc1, literal=K)
v_fmamk_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAMK_F16)
v_fmaak_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAAK_F16)
v_max_f16_e32 = functools.partial(VOP2, VOP2Op.V_MAX_F16)
v_min_f16_e32 = functools.partial(VOP2, VOP2Op.V_MIN_F16)
v_ldexp_f16_e32 = functools.partial(VOP2, VOP2Op.V_LDEXP_F16)
@@ -1554,7 +1553,6 @@ v_wmma_f16_16x16x16_f16 = functools.partial(VOP3P, VOP3POp.V_WMMA_F16_16X16X16_F
v_wmma_bf16_16x16x16_bf16 = functools.partial(VOP3P, VOP3POp.V_WMMA_BF16_16X16X16_BF16)
v_wmma_i32_16x16x16_iu8 = functools.partial(VOP3P, VOP3POp.V_WMMA_I32_16X16X16_IU8)
v_wmma_i32_16x16x16_iu4 = functools.partial(VOP3P, VOP3POp.V_WMMA_I32_16X16X16_IU4)
dword = functools.partial(VOP3SD, VOP3SDOp.DWORD)
v_add_co_ci_u32 = functools.partial(VOP3SD, VOP3SDOp.V_ADD_CO_CI_U32)
v_sub_co_ci_u32 = functools.partial(VOP3SD, VOP3SDOp.V_SUB_CO_CI_U32)
v_subrev_co_ci_u32 = functools.partial(VOP3SD, VOP3SDOp.V_SUBREV_CO_CI_U32)
@@ -1771,31 +1769,4 @@ v_dual_dot2acc_f32_f16 = functools.partial(VOPD, VOPDOp.V_DUAL_DOT2ACC_F32_F16)
v_dual_dot2acc_f32_bf16 = functools.partial(VOPD, VOPDOp.V_DUAL_DOT2ACC_F32_BF16)
v_dual_add_nc_u32 = functools.partial(VOPD, VOPDOp.V_DUAL_ADD_NC_U32)
v_dual_lshlrev_b32 = functools.partial(VOPD, VOPDOp.V_DUAL_LSHLREV_B32)
v_dual_and_b32 = functools.partial(VOPD, VOPDOp.V_DUAL_AND_B32)
VCC_LO = SrcEnum.VCC_LO
VCC_HI = SrcEnum.VCC_HI
NULL = SrcEnum.NULL
M0 = SrcEnum.M0
EXEC_LO = SrcEnum.EXEC_LO
EXEC_HI = SrcEnum.EXEC_HI
ZERO = SrcEnum.ZERO
DPP8FI = SrcEnum.DPP8FI
SHARED_BASE = SrcEnum.SHARED_BASE
SHARED_LIMIT = SrcEnum.SHARED_LIMIT
PRIVATE_BASE = SrcEnum.PRIVATE_BASE
PRIVATE_LIMIT = SrcEnum.PRIVATE_LIMIT
POS_HALF = SrcEnum.POS_HALF
NEG_HALF = SrcEnum.NEG_HALF
POS_ONE = SrcEnum.POS_ONE
NEG_ONE = SrcEnum.NEG_ONE
POS_TWO = SrcEnum.POS_TWO
NEG_TWO = SrcEnum.NEG_TWO
POS_FOUR = SrcEnum.POS_FOUR
NEG_FOUR = SrcEnum.NEG_FOUR
INV_2PI = SrcEnum.INV_2PI
VCCZ = SrcEnum.VCCZ
EXECZ = SrcEnum.EXECZ
SCC = SrcEnum.SCC
LDS_DIRECT = SrcEnum.LDS_DIRECT
OFF = NULL
v_dual_and_b32 = functools.partial(VOPD, VOPDOp.V_DUAL_AND_B32)

File diff suppressed because it is too large Load Diff

View File

@@ -1,34 +1,100 @@
# autogenerated from AMD RDNA4 ISA PDF by pdf.py - do not edit
# autogenerated from AMD ISA PDF by pdf.py - do not edit
from enum import IntEnum
class SrcEnum(IntEnum):
VCC_LO = 106
VCC_HI = 107
NULL = 124
M0 = 125
EXEC_LO = 126
EXEC_HI = 127
ZERO = 128
DPP8 = 233
DPP8FI = 234
SHARED_BASE = 235
SHARED_LIMIT = 236
PRIVATE_BASE = 237
PRIVATE_LIMIT = 238
POS_HALF = 240
NEG_HALF = 241
POS_ONE = 242
NEG_ONE = 243
POS_TWO = 244
NEG_TWO = 245
POS_FOUR = 246
NEG_FOUR = 247
INV_2PI = 248
DPP16 = 250
VCCZ = 251
EXECZ = 252
SCC = 253
LDS_DIRECT = 254
class BufFmt(IntEnum):
BUF_FMT_8_UNORM = 1
BUF_FMT_8_SNORM = 2
BUF_FMT_8_USCALED = 3
BUF_FMT_8_SSCALED = 4
BUF_FMT_8_UINT = 5
BUF_FMT_8_SINT = 6
BUF_FMT_16_UNORM = 7
BUF_FMT_16_SNORM = 8
BUF_FMT_16_USCALED = 9
BUF_FMT_16_SSCALED = 10
BUF_FMT_16_UINT = 11
BUF_FMT_16_SINT = 12
BUF_FMT_16_FLOAT = 13
BUF_FMT_8_8_UNORM = 14
BUF_FMT_8_8_SNORM = 15
BUF_FMT_8_8_USCALED = 16
BUF_FMT_8_8_SSCALED = 17
BUF_FMT_8_8_UINT = 18
BUF_FMT_8_8_SINT = 19
BUF_FMT_32_UINT = 20
BUF_FMT_32_SINT = 21
BUF_FMT_32_FLOAT = 22
BUF_FMT_16_16_UNORM = 23
BUF_FMT_16_16_SNORM = 24
BUF_FMT_16_16_USCALED = 25
BUF_FMT_16_16_SSCALED = 26
BUF_FMT_16_16_UINT = 27
BUF_FMT_16_16_SINT = 28
BUF_FMT_16_16_FLOAT = 29
BUF_FMT_10_11_11_FLOAT = 30
BUF_FMT_11_11_10_FLOAT = 31
BUF_FMT_10_10_10_2_UNORM = 32
BUF_FMT_10_10_10_2_SNORM = 33
BUF_FMT_10_10_10_2_UINT = 34
BUF_FMT_10_10_10_2_SINT = 35
BUF_FMT_2_10_10_10_UNORM = 36
BUF_FMT_2_10_10_10_SNORM = 37
BUF_FMT_2_10_10_10_USCALED = 38
BUF_FMT_2_10_10_10_SSCALED = 39
BUF_FMT_2_10_10_10_UINT = 40
BUF_FMT_2_10_10_10_SINT = 41
BUF_FMT_8_8_8_8_UNORM = 42
BUF_FMT_8_8_8_8_SNORM = 43
BUF_FMT_8_8_8_8_USCALED = 44
BUF_FMT_8_8_8_8_SSCALED = 45
BUF_FMT_8_8_8_8_UINT = 46
BUF_FMT_8_8_8_8_SINT = 47
BUF_FMT_32_32_UINT = 48
BUF_FMT_32_32_SINT = 49
BUF_FMT_32_32_FLOAT = 50
BUF_FMT_16_16_16_16_UNORM = 51
BUF_FMT_16_16_16_16_SNORM = 52
BUF_FMT_16_16_16_16_USCALED = 53
BUF_FMT_16_16_16_16_SSCALED = 54
BUF_FMT_16_16_16_16_UINT = 55
BUF_FMT_16_16_16_16_SINT = 56
BUF_FMT_16_16_16_16_FLOAT = 57
BUF_FMT_32_32_32_UINT = 58
BUF_FMT_32_32_32_SINT = 59
BUF_FMT_32_32_32_FLOAT = 60
BUF_FMT_32_32_32_32_UINT = 61
BUF_FMT_32_32_32_32_SINT = 62
BUF_FMT_32_32_32_32_FLOAT = 63
BUF_FMT_8_SRGB = 64
BUF_FMT_8_8_SRGB = 65
BUF_FMT_8_8_8_8_SRGB = 66
BUF_FMT_5_9_9_9_FLOAT = 67
BUF_FMT_5_6_5_UNORM = 68
BUF_FMT_1_5_5_5_UNORM = 69
BUF_FMT_5_5_5_1_UNORM = 70
BUF_FMT_4_4_4_4_UNORM = 71
BUF_FMT_4_4_UNORM = 72
BUF_FMT_1_UNORM = 73
BUF_FMT_1_REVERSED_UNORM = 74
BUF_FMT_32_FLOAT_CLAMP = 75
BUF_FMT_8_24_UNORM = 76
BUF_FMT_8_24_UINT = 77
BUF_FMT_24_8_UNORM = 78
BUF_FMT_24_8_UINT = 79
BUF_FMT_X24_8_32_UINT = 80
BUF_FMT_X24_8_32_FLOAT = 81
BUF_FMT_GB_GR_UNORM = 82
BUF_FMT_GB_GR_SNORM = 83
BUF_FMT_GB_GR_UINT = 84
BUF_FMT_GB_GR_SRGB = 85
BUF_FMT_BG_RG_UNORM = 86
BUF_FMT_BG_RG_SNORM = 87
BUF_FMT_BG_RG_UINT = 88
BUF_FMT_BG_RG_SRGB = 89
BUF_FMT_BC1_UNORM = 109
BUF_FMT_BC1_SRGB = 110
BUF_FMT_BC2_UNORM = 111
BUF_FMT_BC2_SRGB = 112
class DSOp(IntEnum):
DS_ADD_U32 = 0
@@ -1347,7 +1413,6 @@ class VOP3POp(IntEnum):
V_SWMMAC_F32_16X16X32_BF8_BF8 = 90
class VOP3SDOp(IntEnum):
DWORD = 1
V_ADD_CO_CI_U32 = 288
V_SUB_CO_CI_U32 = 289
V_SUBREV_CO_CI_U32 = 290

View File

@@ -1,12 +1,11 @@
# autogenerated from AMD RDNA4 ISA PDF by pdf.py - do not edit
# autogenerated from AMD ISA PDF by pdf.py - do not edit
# ruff: noqa: F401,F403
from typing import Annotated
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, Inst96, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField
from extra.assembly.amd.dsl import *
from extra.assembly.amd.autogen.rdna4.enum import *
import functools
# instruction formats
class DPP16(Inst64):
class DPP16(Inst):
src0:Src = bits[39:32]
dpp_ctrl = bits[48:40]
fi = bits[50]
@@ -18,7 +17,7 @@ class DPP16(Inst64):
bank_mask = bits[59:56]
row_mask = bits[63:60]
class DPP8(Inst64):
class DPP8(Inst):
src0:Src = bits[39:32]
lane_sel0 = bits[42:40]
lane_sel1 = bits[45:43]
@@ -29,7 +28,17 @@ class DPP8(Inst64):
lane_sel6 = bits[60:58]
lane_sel7 = bits[63:61]
class SMEM(Inst64):
class DS(Inst):
encoding = bits[31:26] == 0b110110
op:Annotated[BitField, DSOp] = bits[25:18]
vdst:VGPRField = bits[63:56]
addr:VGPRField = bits[39:32]
data0:VGPRField = bits[47:40]
data1:VGPRField = bits[55:48]
offset0 = bits[7:0]
offset1 = bits[15:8]
class SMEM(Inst):
encoding = bits[31:26] == 0b111101
op:Annotated[BitField, SMEMOp] = bits[18:13]
sdata:SGPRField = bits[12:6]
@@ -39,153 +48,116 @@ class SMEM(Inst64):
th = bits[24:23]
ioffset = bits[55:32]
class SOP1(Inst32):
class SOP1(Inst):
encoding = bits[31:23] == 0b101111101
op:Annotated[BitField, SOP1Op] = bits[15:8]
sdst:SGPRField = bits[22:16]
ssrc0:SSrc = bits[7:0]
class SOP2(Inst32):
class SOP2(Inst):
encoding = bits[31:30] == 0b10
op:Annotated[BitField, SOP2Op] = bits[29:23]
sdst:SGPRField = bits[22:16]
ssrc0:SSrc = bits[7:0]
ssrc1:SSrc = bits[15:8]
class SOPC(Inst32):
class SOPC(Inst):
encoding = bits[31:23] == 0b101111110
op:Annotated[BitField, SOPCOp] = bits[22:16]
ssrc0:SSrc = bits[7:0]
ssrc1:SSrc = bits[15:8]
class SOPK(Inst32):
class SOPK(Inst):
encoding = bits[31:28] == 0b1011
op:Annotated[BitField, SOPKOp] = bits[27:23]
sdst:SGPRField = bits[22:16]
simm16:SImm = bits[15:0]
class SOPP(Inst32):
class SOPP(Inst):
encoding = bits[31:23] == 0b101111111
op:Annotated[BitField, SOPPOp] = bits[22:16]
simm16:SImm = bits[15:0]
class VBUFFER(Inst96):
class VBUFFER(Inst):
encoding = bits[31:26] == 0b110001
soffset:SSrc = bits[6:0]
op:Annotated[BitField, VBUFFEROp] = bits[21:14]
tfe = bits[22]
vdata:VGPRField = bits[39:32]
rsrc = bits[49:41]
scope = bits[51:50]
th = bits[54:52]
vaddr:VGPRField = bits[71:64]
soffset:SSrc = bits[6:0]
format = bits[61:55]
offen = bits[62]
idxen = bits[63]
vaddr:VGPRField = bits[71:64]
tfe = bits[22]
rsrc = bits[49:41]
scope = bits[51:50]
th = bits[54:52]
ioffset = bits[95:72]
class VDS(Inst64):
encoding = bits[31:26] == 0b110110
offset0 = bits[7:0]
offset1 = bits[15:8]
op = bits[25:18]
addr:VGPRField = bits[39:32]
data0:VGPRField = bits[47:40]
data1:VGPRField = bits[55:48]
vdst:VGPRField = bits[63:56]
class VDSDIR(Inst64):
encoding = bits[31:24] == 0b11001101
class VDSDIR(Inst):
encoding = bits[31:24] == 0b11001110
op:Annotated[BitField, VDSDIROp] = bits[21:20]
vdst:VGPRField = bits[7:0]
waitexp = bits[10:8]
opsel = bits[14:11]
cm = bits[15]
op:Annotated[BitField, VDSDIROp] = bits[20:16]
src0:Src = bits[40:32]
src1:Src = bits[49:41]
src2:Src = bits[58:50]
neg = bits[63:61]
attr = bits[15:10]
attr_chan = bits[9:8]
wait_va = bits[19:16]
wait_vmvsrc = bits[23]
class VEXPORT(Inst64):
class VEXPORT(Inst):
encoding = bits[31:26] == 0b111110
vsrc0:VGPRField = bits[39:32]
vsrc1:VGPRField = bits[47:40]
vsrc2:VGPRField = bits[55:48]
vsrc3:VGPRField = bits[63:56]
en = bits[3:0]
target = bits[9:4]
done = bits[11]
row = bits[13]
vsrc0 = bits[39:32]
vsrc1:VGPRField = bits[47:40]
vsrc2 = bits[55:48]
vsrc3 = bits[63:56]
class VFLAT(Inst96):
encoding = bits[31:24] == 0b11101100
saddr:SSrc = bits[6:0]
op:Annotated[BitField, VFLATOp] = bits[20:14]
vdst:VGPRField = bits[39:32]
sve = bits[49]
scope = bits[51:50]
th = bits[54:52]
vsrc = bits[62:55]
vaddr:VGPRField = bits[71:64]
ioffset = bits[95:72]
class VGLOBAL(Inst96):
encoding = bits[31:24] == 0b11101110
saddr:SSrc = bits[6:0]
op:Annotated[BitField, VGLOBALOp] = bits[20:14]
vdst:VGPRField = bits[39:32]
sve = bits[49]
scope = bits[51:50]
th = bits[54:52]
vsrc = bits[62:55]
vaddr:VGPRField = bits[71:64]
ioffset = bits[95:72]
class VIMAGE(Inst96):
class VIMAGE(Inst):
encoding = bits[31:26] == 0b110100
op:Annotated[BitField, VIMAGEOp] = bits[21:14]
vdata:VGPRField = bits[39:32]
dmask = bits[25:22]
dim = bits[2:0]
tfe = bits[55]
r128 = bits[4]
d16 = bits[5]
a16 = bits[6]
op:Annotated[BitField, VIMAGEOp] = bits[21:14]
dmask = bits[25:22]
vdata:VGPRField = bits[39:32]
rsrc = bits[49:41]
scope = bits[51:50]
th = bits[54:52]
tfe = bits[55]
vaddr4 = bits[56:63]
vaddr0 = bits[71:64]
vaddr1 = bits[79:72]
vaddr2 = bits[87:80]
vaddr3 = bits[95:88]
class VINTERP(Inst64):
class VINTERP(Inst):
encoding = bits[31:24] == 0b11001101
op:Annotated[BitField, VINTERPOp] = bits[20:16]
vdst:VGPRField = bits[7:0]
src0:Src = bits[40:32]
src1:Src = bits[49:41]
src2:Src = bits[58:50]
waitexp = bits[10:8]
opsel = bits[14:11]
neg = bits[63:61]
opsel = bits[14:11]
waitexp = bits[10:8]
cm = bits[15]
class VOP1(Inst32):
encoding = bits[31:25] == 0b111111
class VOP1(Inst):
encoding = bits[31:25] == 0b0111111
op:Annotated[BitField, VOP1Op] = bits[15:9]
vdst:VGPRField = bits[24:17]
src0:Src = bits[8:0]
class VOP2(Inst32):
encoding = bits[31] == 0
class VOP2(Inst):
encoding = bits[31] == 0b0
op:Annotated[BitField, VOP2Op] = bits[30:25]
vdst:VGPRField = bits[24:17]
src0:Src = bits[8:0]
vsrc1:VGPRField = bits[16:9]
class VOP3(Inst64):
class VOP3(Inst):
encoding = bits[31:26] == 0b110101
op:Annotated[BitField, VOP3Op] = bits[25:16]
vdst:VGPRField = bits[7:0]
@@ -198,9 +170,8 @@ class VOP3(Inst64):
opsel = bits[14:11]
cm = bits[15]
class VOP3P(Inst64):
class VOP3P(Inst):
encoding = bits[31:24] == 0b11001100
_defaults = {'opsel_hi': 3, 'opsel_hi2': 1}
op:Annotated[BitField, VOP3POp] = bits[22:16]
vdst:VGPRField = bits[7:0]
src0:Src = bits[40:32]
@@ -213,7 +184,7 @@ class VOP3P(Inst64):
opsel_hi2 = bits[14]
cm = bits[15]
class VOP3SD(Inst64):
class VOP3SD(Inst):
encoding = bits[31:26] == 0b110101
op:Annotated[BitField, VOP3SDOp] = bits[25:16]
vdst:VGPRField = bits[7:0]
@@ -221,38 +192,38 @@ class VOP3SD(Inst64):
src0:Src = bits[40:32]
src1:Src = bits[49:41]
src2:Src = bits[58:50]
cm = bits[15]
omod = bits[60:59]
neg = bits[63:61]
cm = bits[15]
class VOPC(Inst32):
encoding = bits[31:25] == 0b111110
class VOPC(Inst):
encoding = bits[31:25] == 0b0111110
op:Annotated[BitField, VOPCOp] = bits[24:17]
src0:Src = bits[8:0]
vsrc1:VGPRField = bits[16:9]
class VOPD(Inst64):
class VOPD(Inst):
encoding = bits[31:26] == 0b110010
opx:Annotated[BitField, VOPDOp] = bits[25:22]
opy:Annotated[BitField, VOPDOp] = bits[21:17]
vdstx:VGPRField = bits[63:56]
vdstx = bits[63:56]
vdsty:VDSTYEnc = bits[55:49]
srcx0:Src = bits[8:0]
vsrcx1:VGPRField = bits[16:9]
srcy0:Src = bits[40:32]
vsrcy1:VGPRField = bits[48:41]
vsrcx1 = bits[16:9]
vsrcy1 = bits[48:41]
class VSAMPLE(Inst96):
class VSAMPLE(Inst):
encoding = bits[31:26] == 0b111001
op:Annotated[BitField, VSAMPLEOp] = bits[21:14]
vdata:VGPRField = bits[39:32]
dmask = bits[25:22]
dim = bits[2:0]
tfe = bits[3]
unrm = bits[13]
r128 = bits[4]
d16 = bits[5]
a16 = bits[6]
unrm = bits[13]
op:Annotated[BitField, VSAMPLEOp] = bits[21:14]
dmask = bits[25:22]
vdata:VGPRField = bits[39:32]
lwe = bits[40]
rsrc = bits[49:41]
scope = bits[51:50]
@@ -263,19 +234,130 @@ class VSAMPLE(Inst96):
vaddr2 = bits[87:80]
vaddr3 = bits[95:88]
class VSCRATCH(Inst96):
encoding = bits[31:24] == 0b11101101
saddr:SSrc = bits[6:0]
op:Annotated[BitField, VSCRATCHOp] = bits[20:14]
vdst:VGPRField = bits[39:32]
sve = bits[49]
scope = bits[51:50]
th = bits[54:52]
vsrc = bits[62:55]
vaddr:VGPRField = bits[71:64]
ioffset = bits[95:72]
# instruction helpers
ds_add_u32 = functools.partial(DS, DSOp.DS_ADD_U32)
ds_sub_u32 = functools.partial(DS, DSOp.DS_SUB_U32)
ds_rsub_u32 = functools.partial(DS, DSOp.DS_RSUB_U32)
ds_inc_u32 = functools.partial(DS, DSOp.DS_INC_U32)
ds_dec_u32 = functools.partial(DS, DSOp.DS_DEC_U32)
ds_min_i32 = functools.partial(DS, DSOp.DS_MIN_I32)
ds_max_i32 = functools.partial(DS, DSOp.DS_MAX_I32)
ds_min_u32 = functools.partial(DS, DSOp.DS_MIN_U32)
ds_max_u32 = functools.partial(DS, DSOp.DS_MAX_U32)
ds_and_b32 = functools.partial(DS, DSOp.DS_AND_B32)
ds_or_b32 = functools.partial(DS, DSOp.DS_OR_B32)
ds_xor_b32 = functools.partial(DS, DSOp.DS_XOR_B32)
ds_mskor_b32 = functools.partial(DS, DSOp.DS_MSKOR_B32)
ds_store_b32 = functools.partial(DS, DSOp.DS_STORE_B32)
ds_store_2addr_b32 = functools.partial(DS, DSOp.DS_STORE_2ADDR_B32)
ds_store_2addr_stride64_b32 = functools.partial(DS, DSOp.DS_STORE_2ADDR_STRIDE64_B32)
ds_cmpstore_b32 = functools.partial(DS, DSOp.DS_CMPSTORE_B32)
ds_min_num_f32 = functools.partial(DS, DSOp.DS_MIN_NUM_F32)
ds_max_num_f32 = functools.partial(DS, DSOp.DS_MAX_NUM_F32)
ds_nop = functools.partial(DS, DSOp.DS_NOP)
ds_add_f32 = functools.partial(DS, DSOp.DS_ADD_F32)
ds_store_b8 = functools.partial(DS, DSOp.DS_STORE_B8)
ds_store_b16 = functools.partial(DS, DSOp.DS_STORE_B16)
ds_add_rtn_u32 = functools.partial(DS, DSOp.DS_ADD_RTN_U32)
ds_sub_rtn_u32 = functools.partial(DS, DSOp.DS_SUB_RTN_U32)
ds_rsub_rtn_u32 = functools.partial(DS, DSOp.DS_RSUB_RTN_U32)
ds_inc_rtn_u32 = functools.partial(DS, DSOp.DS_INC_RTN_U32)
ds_dec_rtn_u32 = functools.partial(DS, DSOp.DS_DEC_RTN_U32)
ds_min_rtn_i32 = functools.partial(DS, DSOp.DS_MIN_RTN_I32)
ds_max_rtn_i32 = functools.partial(DS, DSOp.DS_MAX_RTN_I32)
ds_min_rtn_u32 = functools.partial(DS, DSOp.DS_MIN_RTN_U32)
ds_max_rtn_u32 = functools.partial(DS, DSOp.DS_MAX_RTN_U32)
ds_and_rtn_b32 = functools.partial(DS, DSOp.DS_AND_RTN_B32)
ds_or_rtn_b32 = functools.partial(DS, DSOp.DS_OR_RTN_B32)
ds_xor_rtn_b32 = functools.partial(DS, DSOp.DS_XOR_RTN_B32)
ds_mskor_rtn_b32 = functools.partial(DS, DSOp.DS_MSKOR_RTN_B32)
ds_storexchg_rtn_b32 = functools.partial(DS, DSOp.DS_STOREXCHG_RTN_B32)
ds_storexchg_2addr_rtn_b32 = functools.partial(DS, DSOp.DS_STOREXCHG_2ADDR_RTN_B32)
ds_storexchg_2addr_stride64_rtn_b32 = functools.partial(DS, DSOp.DS_STOREXCHG_2ADDR_STRIDE64_RTN_B32)
ds_cmpstore_rtn_b32 = functools.partial(DS, DSOp.DS_CMPSTORE_RTN_B32)
ds_min_num_rtn_f32 = functools.partial(DS, DSOp.DS_MIN_NUM_RTN_F32)
ds_max_num_rtn_f32 = functools.partial(DS, DSOp.DS_MAX_NUM_RTN_F32)
ds_swizzle_b32 = functools.partial(DS, DSOp.DS_SWIZZLE_B32)
ds_load_b32 = functools.partial(DS, DSOp.DS_LOAD_B32)
ds_load_2addr_b32 = functools.partial(DS, DSOp.DS_LOAD_2ADDR_B32)
ds_load_2addr_stride64_b32 = functools.partial(DS, DSOp.DS_LOAD_2ADDR_STRIDE64_B32)
ds_load_i8 = functools.partial(DS, DSOp.DS_LOAD_I8)
ds_load_u8 = functools.partial(DS, DSOp.DS_LOAD_U8)
ds_load_i16 = functools.partial(DS, DSOp.DS_LOAD_I16)
ds_load_u16 = functools.partial(DS, DSOp.DS_LOAD_U16)
ds_consume = functools.partial(DS, DSOp.DS_CONSUME)
ds_append = functools.partial(DS, DSOp.DS_APPEND)
ds_add_u64 = functools.partial(DS, DSOp.DS_ADD_U64)
ds_sub_u64 = functools.partial(DS, DSOp.DS_SUB_U64)
ds_rsub_u64 = functools.partial(DS, DSOp.DS_RSUB_U64)
ds_inc_u64 = functools.partial(DS, DSOp.DS_INC_U64)
ds_dec_u64 = functools.partial(DS, DSOp.DS_DEC_U64)
ds_min_i64 = functools.partial(DS, DSOp.DS_MIN_I64)
ds_max_i64 = functools.partial(DS, DSOp.DS_MAX_I64)
ds_min_u64 = functools.partial(DS, DSOp.DS_MIN_U64)
ds_max_u64 = functools.partial(DS, DSOp.DS_MAX_U64)
ds_and_b64 = functools.partial(DS, DSOp.DS_AND_B64)
ds_or_b64 = functools.partial(DS, DSOp.DS_OR_B64)
ds_xor_b64 = functools.partial(DS, DSOp.DS_XOR_B64)
ds_mskor_b64 = functools.partial(DS, DSOp.DS_MSKOR_B64)
ds_store_b64 = functools.partial(DS, DSOp.DS_STORE_B64)
ds_store_2addr_b64 = functools.partial(DS, DSOp.DS_STORE_2ADDR_B64)
ds_store_2addr_stride64_b64 = functools.partial(DS, DSOp.DS_STORE_2ADDR_STRIDE64_B64)
ds_cmpstore_b64 = functools.partial(DS, DSOp.DS_CMPSTORE_B64)
ds_min_num_f64 = functools.partial(DS, DSOp.DS_MIN_NUM_F64)
ds_max_num_f64 = functools.partial(DS, DSOp.DS_MAX_NUM_F64)
ds_add_rtn_u64 = functools.partial(DS, DSOp.DS_ADD_RTN_U64)
ds_sub_rtn_u64 = functools.partial(DS, DSOp.DS_SUB_RTN_U64)
ds_rsub_rtn_u64 = functools.partial(DS, DSOp.DS_RSUB_RTN_U64)
ds_inc_rtn_u64 = functools.partial(DS, DSOp.DS_INC_RTN_U64)
ds_dec_rtn_u64 = functools.partial(DS, DSOp.DS_DEC_RTN_U64)
ds_min_rtn_i64 = functools.partial(DS, DSOp.DS_MIN_RTN_I64)
ds_max_rtn_i64 = functools.partial(DS, DSOp.DS_MAX_RTN_I64)
ds_min_rtn_u64 = functools.partial(DS, DSOp.DS_MIN_RTN_U64)
ds_max_rtn_u64 = functools.partial(DS, DSOp.DS_MAX_RTN_U64)
ds_and_rtn_b64 = functools.partial(DS, DSOp.DS_AND_RTN_B64)
ds_or_rtn_b64 = functools.partial(DS, DSOp.DS_OR_RTN_B64)
ds_xor_rtn_b64 = functools.partial(DS, DSOp.DS_XOR_RTN_B64)
ds_mskor_rtn_b64 = functools.partial(DS, DSOp.DS_MSKOR_RTN_B64)
ds_storexchg_rtn_b64 = functools.partial(DS, DSOp.DS_STOREXCHG_RTN_B64)
ds_storexchg_2addr_rtn_b64 = functools.partial(DS, DSOp.DS_STOREXCHG_2ADDR_RTN_B64)
ds_storexchg_2addr_stride64_rtn_b64 = functools.partial(DS, DSOp.DS_STOREXCHG_2ADDR_STRIDE64_RTN_B64)
ds_cmpstore_rtn_b64 = functools.partial(DS, DSOp.DS_CMPSTORE_RTN_B64)
ds_min_num_rtn_f64 = functools.partial(DS, DSOp.DS_MIN_NUM_RTN_F64)
ds_max_num_rtn_f64 = functools.partial(DS, DSOp.DS_MAX_NUM_RTN_F64)
ds_load_b64 = functools.partial(DS, DSOp.DS_LOAD_B64)
ds_load_2addr_b64 = functools.partial(DS, DSOp.DS_LOAD_2ADDR_B64)
ds_load_2addr_stride64_b64 = functools.partial(DS, DSOp.DS_LOAD_2ADDR_STRIDE64_B64)
ds_add_rtn_f32 = functools.partial(DS, DSOp.DS_ADD_RTN_F32)
ds_condxchg32_rtn_b64 = functools.partial(DS, DSOp.DS_CONDXCHG32_RTN_B64)
ds_cond_sub_u32 = functools.partial(DS, DSOp.DS_COND_SUB_U32)
ds_sub_clamp_u32 = functools.partial(DS, DSOp.DS_SUB_CLAMP_U32)
ds_pk_add_f16 = functools.partial(DS, DSOp.DS_PK_ADD_F16)
ds_pk_add_bf16 = functools.partial(DS, DSOp.DS_PK_ADD_BF16)
ds_store_b8_d16_hi = functools.partial(DS, DSOp.DS_STORE_B8_D16_HI)
ds_store_b16_d16_hi = functools.partial(DS, DSOp.DS_STORE_B16_D16_HI)
ds_load_u8_d16 = functools.partial(DS, DSOp.DS_LOAD_U8_D16)
ds_load_u8_d16_hi = functools.partial(DS, DSOp.DS_LOAD_U8_D16_HI)
ds_load_i8_d16 = functools.partial(DS, DSOp.DS_LOAD_I8_D16)
ds_load_i8_d16_hi = functools.partial(DS, DSOp.DS_LOAD_I8_D16_HI)
ds_load_u16_d16 = functools.partial(DS, DSOp.DS_LOAD_U16_D16)
ds_load_u16_d16_hi = functools.partial(DS, DSOp.DS_LOAD_U16_D16_HI)
ds_cond_sub_rtn_u32 = functools.partial(DS, DSOp.DS_COND_SUB_RTN_U32)
ds_sub_clamp_rtn_u32 = functools.partial(DS, DSOp.DS_SUB_CLAMP_RTN_U32)
ds_pk_add_rtn_f16 = functools.partial(DS, DSOp.DS_PK_ADD_RTN_F16)
ds_pk_add_rtn_bf16 = functools.partial(DS, DSOp.DS_PK_ADD_RTN_BF16)
ds_store_addtid_b32 = functools.partial(DS, DSOp.DS_STORE_ADDTID_B32)
ds_load_addtid_b32 = functools.partial(DS, DSOp.DS_LOAD_ADDTID_B32)
ds_permute_b32 = functools.partial(DS, DSOp.DS_PERMUTE_B32)
ds_bpermute_b32 = functools.partial(DS, DSOp.DS_BPERMUTE_B32)
ds_bpermute_fi_b32 = functools.partial(DS, DSOp.DS_BPERMUTE_FI_B32)
ds_store_b96 = functools.partial(DS, DSOp.DS_STORE_B96)
ds_store_b128 = functools.partial(DS, DSOp.DS_STORE_B128)
ds_bvh_stack_push4_pop1_rtn_b32 = functools.partial(DS, DSOp.DS_BVH_STACK_PUSH4_POP1_RTN_B32)
ds_bvh_stack_push8_pop1_rtn_b32 = functools.partial(DS, DSOp.DS_BVH_STACK_PUSH8_POP1_RTN_B32)
ds_bvh_stack_push8_pop2_rtn_b64 = functools.partial(DS, DSOp.DS_BVH_STACK_PUSH8_POP2_RTN_B64)
ds_load_b96 = functools.partial(DS, DSOp.DS_LOAD_B96)
ds_load_b128 = functools.partial(DS, DSOp.DS_LOAD_B128)
s_load_b32 = functools.partial(SMEM, SMEMOp.S_LOAD_B32)
s_load_b64 = functools.partial(SMEM, SMEMOp.S_LOAD_B64)
s_load_b128 = functools.partial(SMEM, SMEMOp.S_LOAD_B128)
@@ -647,126 +729,6 @@ tbuffer_store_d16_format_xyz = functools.partial(VBUFFER, VBUFFEROp.TBUFFER_STOR
tbuffer_store_d16_format_xyzw = functools.partial(VBUFFER, VBUFFEROp.TBUFFER_STORE_D16_FORMAT_XYZW)
ds_param_load = functools.partial(VDSDIR, VDSDIROp.DS_PARAM_LOAD)
ds_direct_load = functools.partial(VDSDIR, VDSDIROp.DS_DIRECT_LOAD)
flat_load_u8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_U8)
flat_load_i8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_I8)
flat_load_u16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_U16)
flat_load_i16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_I16)
flat_load_b32 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B32)
flat_load_b64 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B64)
flat_load_b96 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B96)
flat_load_b128 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B128)
flat_store_b8 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B8)
flat_store_b16 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B16)
flat_store_b32 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B32)
flat_store_b64 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B64)
flat_store_b96 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B96)
flat_store_b128 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B128)
flat_load_d16_u8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_U8)
flat_load_d16_i8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_I8)
flat_load_d16_b16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_B16)
flat_load_d16_hi_u8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_U8)
flat_load_d16_hi_i8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_I8)
flat_load_d16_hi_b16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_B16)
flat_store_d16_hi_b8 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_D16_HI_B8)
flat_store_d16_hi_b16 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_D16_HI_B16)
flat_atomic_swap_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SWAP_B32)
flat_atomic_cmpswap_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_CMPSWAP_B32)
flat_atomic_add_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_ADD_U32)
flat_atomic_sub_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SUB_U32)
flat_atomic_sub_clamp_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SUB_CLAMP_U32)
flat_atomic_min_i32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_I32)
flat_atomic_min_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_U32)
flat_atomic_max_i32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_I32)
flat_atomic_max_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_U32)
flat_atomic_and_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_AND_B32)
flat_atomic_or_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_OR_B32)
flat_atomic_xor_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_XOR_B32)
flat_atomic_inc_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_INC_U32)
flat_atomic_dec_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_DEC_U32)
flat_atomic_swap_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SWAP_B64)
flat_atomic_cmpswap_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_CMPSWAP_B64)
flat_atomic_add_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_ADD_U64)
flat_atomic_sub_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SUB_U64)
flat_atomic_min_i64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_I64)
flat_atomic_min_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_U64)
flat_atomic_max_i64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_I64)
flat_atomic_max_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_U64)
flat_atomic_and_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_AND_B64)
flat_atomic_or_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_OR_B64)
flat_atomic_xor_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_XOR_B64)
flat_atomic_inc_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_INC_U64)
flat_atomic_dec_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_DEC_U64)
flat_atomic_cond_sub_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_COND_SUB_U32)
flat_atomic_min_num_f32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_NUM_F32)
flat_atomic_max_num_f32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_NUM_F32)
flat_atomic_add_f32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_ADD_F32)
flat_atomic_pk_add_f16 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_PK_ADD_F16)
flat_atomic_pk_add_bf16 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_PK_ADD_BF16)
global_load_u8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_U8)
global_load_i8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_I8)
global_load_u16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_U16)
global_load_i16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_I16)
global_load_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B32)
global_load_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B64)
global_load_b96 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B96)
global_load_b128 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B128)
global_store_b8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B8)
global_store_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B16)
global_store_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B32)
global_store_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B64)
global_store_b96 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B96)
global_store_b128 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B128)
global_load_d16_u8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_U8)
global_load_d16_i8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_I8)
global_load_d16_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_B16)
global_load_d16_hi_u8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_HI_U8)
global_load_d16_hi_i8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_HI_I8)
global_load_d16_hi_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_HI_B16)
global_store_d16_hi_b8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_D16_HI_B8)
global_store_d16_hi_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_D16_HI_B16)
global_load_addtid_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_ADDTID_B32)
global_store_addtid_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_ADDTID_B32)
global_inv = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_INV)
global_wb = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_WB)
global_atomic_swap_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SWAP_B32)
global_atomic_cmpswap_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_CMPSWAP_B32)
global_atomic_add_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ADD_U32)
global_atomic_sub_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SUB_U32)
global_atomic_sub_clamp_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SUB_CLAMP_U32)
global_atomic_min_i32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_I32)
global_atomic_min_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_U32)
global_atomic_max_i32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_I32)
global_atomic_max_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_U32)
global_atomic_and_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_AND_B32)
global_atomic_or_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_OR_B32)
global_atomic_xor_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_XOR_B32)
global_atomic_inc_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_INC_U32)
global_atomic_dec_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_DEC_U32)
global_atomic_swap_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SWAP_B64)
global_atomic_cmpswap_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_CMPSWAP_B64)
global_atomic_add_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ADD_U64)
global_atomic_sub_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SUB_U64)
global_atomic_min_i64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_I64)
global_atomic_min_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_U64)
global_atomic_max_i64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_I64)
global_atomic_max_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_U64)
global_atomic_and_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_AND_B64)
global_atomic_or_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_OR_B64)
global_atomic_xor_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_XOR_B64)
global_atomic_inc_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_INC_U64)
global_atomic_dec_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_DEC_U64)
global_wbinv = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_WBINV)
global_atomic_cond_sub_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_COND_SUB_U32)
global_atomic_min_num_f32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_NUM_F32)
global_atomic_max_num_f32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_NUM_F32)
global_load_block = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_BLOCK)
global_store_block = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_BLOCK)
global_atomic_add_f32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ADD_F32)
global_load_tr_b128 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_TR_B128)
global_load_tr_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_TR_B64)
global_atomic_pk_add_f16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_PK_ADD_F16)
global_atomic_pk_add_bf16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_PK_ADD_BF16)
global_atomic_ordered_add_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ORDERED_ADD_B64)
image_load = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD)
image_load_mip = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD_MIP)
image_load_pck = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD_PCK)
@@ -931,8 +893,8 @@ v_add_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_ADD_NC_U32)
v_sub_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_SUB_NC_U32)
v_subrev_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_NC_U32)
v_fmac_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAC_F32)
def v_fmamk_f32_e32(vdst, src0, K, vsrc1): return VOP2(VOP2Op.V_FMAMK_F32, vdst, src0, vsrc1, literal=K)
def v_fmaak_f32_e32(vdst, src0, vsrc1, K): return VOP2(VOP2Op.V_FMAAK_F32, vdst, src0, vsrc1, literal=K)
v_fmamk_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAMK_F32)
v_fmaak_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAAK_F32)
v_cvt_pk_rtz_f16_f32_e32 = functools.partial(VOP2, VOP2Op.V_CVT_PK_RTZ_F16_F32)
v_min_num_f16_e32 = functools.partial(VOP2, VOP2Op.V_MIN_NUM_F16)
v_max_num_f16_e32 = functools.partial(VOP2, VOP2Op.V_MAX_NUM_F16)
@@ -941,8 +903,8 @@ v_sub_f16_e32 = functools.partial(VOP2, VOP2Op.V_SUB_F16)
v_subrev_f16_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_F16)
v_mul_f16_e32 = functools.partial(VOP2, VOP2Op.V_MUL_F16)
v_fmac_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAC_F16)
def v_fmamk_f16_e32(vdst, src0, K, vsrc1): return VOP2(VOP2Op.V_FMAMK_F16, vdst, src0, vsrc1, literal=K)
def v_fmaak_f16_e32(vdst, src0, vsrc1, K): return VOP2(VOP2Op.V_FMAAK_F16, vdst, src0, vsrc1, literal=K)
v_fmamk_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAMK_F16)
v_fmaak_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAAK_F16)
v_ldexp_f16_e32 = functools.partial(VOP2, VOP2Op.V_LDEXP_F16)
v_pk_fmac_f16_e32 = functools.partial(VOP2, VOP2Op.V_PK_FMAC_F16)
v_cmp_lt_f16_e64 = functools.partial(VOP3, VOP3Op.V_CMP_LT_F16)
@@ -1435,7 +1397,6 @@ v_swmmac_f32_16x16x32_fp8_fp8 = functools.partial(VOP3P, VOP3POp.V_SWMMAC_F32_16
v_swmmac_f32_16x16x32_fp8_bf8 = functools.partial(VOP3P, VOP3POp.V_SWMMAC_F32_16X16X32_FP8_BF8)
v_swmmac_f32_16x16x32_bf8_fp8 = functools.partial(VOP3P, VOP3POp.V_SWMMAC_F32_16X16X32_BF8_FP8)
v_swmmac_f32_16x16x32_bf8_bf8 = functools.partial(VOP3P, VOP3POp.V_SWMMAC_F32_16X16X32_BF8_BF8)
dword = functools.partial(VOP3SD, VOP3SDOp.DWORD)
v_add_co_ci_u32 = functools.partial(VOP3SD, VOP3SDOp.V_ADD_CO_CI_U32)
v_sub_co_ci_u32 = functools.partial(VOP3SD, VOP3SDOp.V_SUB_CO_CI_U32)
v_subrev_co_ci_u32 = functools.partial(VOP3SD, VOP3SDOp.V_SUBREV_CO_CI_U32)
@@ -1682,55 +1643,4 @@ image_gather4_c_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_CL)
image_gather4_c_l = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_L)
image_gather4_c_b = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_B)
image_gather4_c_b_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_B_CL)
image_gather4h = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4H)
scratch_load_u8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_U8)
scratch_load_i8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_I8)
scratch_load_u16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_U16)
scratch_load_i16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_I16)
scratch_load_b32 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B32)
scratch_load_b64 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B64)
scratch_load_b96 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B96)
scratch_load_b128 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B128)
scratch_store_b8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B8)
scratch_store_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B16)
scratch_store_b32 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B32)
scratch_store_b64 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B64)
scratch_store_b96 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B96)
scratch_store_b128 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B128)
scratch_load_d16_u8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_U8)
scratch_load_d16_i8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_I8)
scratch_load_d16_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_B16)
scratch_load_d16_hi_u8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_HI_U8)
scratch_load_d16_hi_i8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_HI_I8)
scratch_load_d16_hi_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_HI_B16)
scratch_store_d16_hi_b8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_D16_HI_B8)
scratch_store_d16_hi_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_D16_HI_B16)
scratch_load_block = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_BLOCK)
scratch_store_block = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_BLOCK)
VCC_LO = SrcEnum.VCC_LO
VCC_HI = SrcEnum.VCC_HI
NULL = SrcEnum.NULL
M0 = SrcEnum.M0
EXEC_LO = SrcEnum.EXEC_LO
EXEC_HI = SrcEnum.EXEC_HI
ZERO = SrcEnum.ZERO
DPP8FI = SrcEnum.DPP8FI
SHARED_BASE = SrcEnum.SHARED_BASE
SHARED_LIMIT = SrcEnum.SHARED_LIMIT
PRIVATE_BASE = SrcEnum.PRIVATE_BASE
PRIVATE_LIMIT = SrcEnum.PRIVATE_LIMIT
POS_HALF = SrcEnum.POS_HALF
NEG_HALF = SrcEnum.NEG_HALF
POS_ONE = SrcEnum.POS_ONE
NEG_ONE = SrcEnum.NEG_ONE
POS_TWO = SrcEnum.POS_TWO
NEG_TWO = SrcEnum.NEG_TWO
POS_FOUR = SrcEnum.POS_FOUR
NEG_FOUR = SrcEnum.NEG_FOUR
INV_2PI = SrcEnum.INV_2PI
VCCZ = SrcEnum.VCCZ
EXECZ = SrcEnum.EXECZ
SCC = SrcEnum.SCC
LDS_DIRECT = SrcEnum.LDS_DIRECT
OFF = NULL
image_gather4h = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4H)

File diff suppressed because one or more lines are too long

View File

@@ -7,6 +7,19 @@ from functools import cache
from typing import overload, Annotated, TypeVar, Generic
from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op,
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
from extra.assembly.amd.autogen.cdna.enum import VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op
# Source operand encoding - constant across all AMD ISAs
class SrcEnum(IntEnum):
  """Integer codes of the named (non-numbered) source operands, one member per code."""
  VCC_LO = 106
  VCC_HI = 107
  NULL = 124
  M0 = 125
  EXEC_LO = 126
  EXEC_HI = 127
  ZERO = 128
  DPP8 = 233
  DPP8FI = 234
  SHARED_BASE = 235
  SHARED_LIMIT = 236
  PRIVATE_BASE = 237
  PRIVATE_LIMIT = 238
  # inline float constants
  POS_HALF = 240
  NEG_HALF = 241
  POS_ONE = 242
  NEG_ONE = 243
  POS_TWO = 244
  NEG_TWO = 245
  POS_FOUR = 246
  NEG_FOUR = 247
  INV_2PI = 248
  DPP16 = 250
  VCCZ = 251
  EXECZ = 252
  SCC = 253
  LDS_DIRECT = 254
# Module-level names for SrcEnum members, bound in bulk by tuple assignment.
# DPP8 and DPP16 exist on SrcEnum but are not given module-level names here.
VCC_LO, VCC_HI, NULL, M0, EXEC_LO, EXEC_HI, ZERO = SrcEnum.VCC_LO, SrcEnum.VCC_HI, SrcEnum.NULL, SrcEnum.M0, SrcEnum.EXEC_LO, SrcEnum.EXEC_HI, SrcEnum.ZERO
DPP8FI, SHARED_BASE, SHARED_LIMIT, PRIVATE_BASE, PRIVATE_LIMIT = SrcEnum.DPP8FI, SrcEnum.SHARED_BASE, SrcEnum.SHARED_LIMIT, SrcEnum.PRIVATE_BASE, SrcEnum.PRIVATE_LIMIT
POS_HALF, NEG_HALF, POS_ONE, NEG_ONE, POS_TWO, NEG_TWO = SrcEnum.POS_HALF, SrcEnum.NEG_HALF, SrcEnum.POS_ONE, SrcEnum.NEG_ONE, SrcEnum.POS_TWO, SrcEnum.NEG_TWO
POS_FOUR, NEG_FOUR, INV_2PI, VCCZ, EXECZ, SCC, LDS_DIRECT = SrcEnum.POS_FOUR, SrcEnum.NEG_FOUR, SrcEnum.INV_2PI, SrcEnum.VCCZ, SrcEnum.EXECZ, SrcEnum.SCC, SrcEnum.LDS_DIRECT
# OFF is a second name for the same object as NULL
OFF = NULL
# Common masks and bit conversion functions
MASK32, MASK64, MASK128 = 0xffffffff, 0xffffffffffffffff, (1 << 128) - 1
@@ -54,12 +67,14 @@ def f32_to_f16(f):
# Instruction spec - register counts and dtypes derived from instruction names
_REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16,
'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2,
'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1}
'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1,
'DWORD': 1, 'DWORDX2': 2, 'DWORDX3': 3, 'DWORDX4': 4, 'DWORDX8': 8, 'DWORDX16': 16,
'BYTE': 1, 'SHORT': 1, 'UBYTE': 1, 'SBYTE': 1, 'USHORT': 1, 'SSHORT': 1}
_CVT_RE = re.compile(r'CVT_([FIUB]\d+)_([FIUB]\d+)$')
_MAD_MUL_RE = re.compile(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$')
_PACK_RE = re.compile(r'PACK_([FIUB]\d+)_([FIUB]\d+)$')
_DST_SRC_RE = re.compile(r'_([FIUB]\d+)_([FIUB]\d+)$')
_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512))$')
_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512)|DWORD(?:X(?:2|3|4|8|16))?|[US]?BYTE|[US]?SHORT)$')
@cache
def _suffix(name: str) -> tuple[str | None, str | None]:
name = name.upper()
@@ -250,7 +265,11 @@ def unwrap(val) -> int:
# Float constant -> source-operand code (240..247).
FLOAT_ENC = {0.5: 240, -0.5: 241, 1.0: 242, -1.0: 243, 2.0: 244, -2.0: 245, 4.0: 246, -4.0: 247}
# Inverse of FLOAT_ENC: source-operand code -> str() of the float constant.
FLOAT_DEC = {v: str(k) for k, v in FLOAT_ENC.items()}
# Source-operand code -> register name; the table decode_src uses when cdna is False.
SPECIAL_GPRS = {106: "vcc_lo", 107: "vcc_hi", 124: "null", 125: "m0", 126: "exec_lo", 127: "exec_hi", 253: "scc"}
# Table decode_src uses when cdna is True. Here 124 names "m0" (it is "null" in SPECIAL_GPRS), and codes
# 102-105 have names; decode_src consults this table before its `val <= 105` -> "s{val}" rule.
SPECIAL_GPRS_CDNA = {102: "flat_scratch_lo", 103: "flat_scratch_hi", 104: "xnack_mask_lo", 105: "xnack_mask_hi",
106: "vcc_lo", 107: "vcc_hi", 124: "m0", 126: "exec_lo", 127: "exec_hi",
251: "src_vccz", 252: "src_execz", 253: "src_scc", 254: "src_lds_direct"}
# Code of the low half -> name without the _lo/_hi suffix.
# NOTE(review): presumably used when an operand spans two registers; the use site is not visible here - confirm.
SPECIAL_PAIRS = {106: "vcc", 126: "exec"}
SPECIAL_PAIRS_CDNA = {102: "flat_scratch", 104: "xnack_mask", 106: "vcc", 126: "exec"}
# Field names encoded as source operands (via _encode_src / encode_src).
SRC_FIELDS = {'src0', 'src1', 'src2', 'ssrc0', 'ssrc1', 'soffset', 'srcx0', 'srcy0'}
# Field names encoded with _encode_raw / _encode_reg instead of encode_src.
RAW_FIELDS = {'vdata', 'vdst', 'vaddr', 'addr', 'data', 'data0', 'data1', 'sdst', 'sdata', 'vsrc1'}
@@ -267,9 +286,10 @@ def encode_src(val) -> int:
if isinstance(val, int): return 128 + val if 0 <= val <= 64 else 192 - val if -16 <= val <= -1 else 255
return 255
def decode_src(val: int) -> str:
def decode_src(val: int, cdna: bool = False) -> str:
special = SPECIAL_GPRS_CDNA if cdna else SPECIAL_GPRS
if val in special: return special[val]
if val <= 105: return f"s{val}"
if val in SPECIAL_GPRS: return SPECIAL_GPRS[val]
if val in FLOAT_DEC: return FLOAT_DEC[val]
if 108 <= val <= 123: return f"ttmp{val - 108}"
if 128 <= val <= 192: return str(val - 128)
@@ -288,7 +308,16 @@ class Inst:
def __init_subclass__(cls, **kwargs):
  """Build the per-class field table, encoded size and fixed encoding for a new instruction class.

  Sets ``cls._fields`` (field name -> BitField, inherited fields included), ``cls._sz`` (4, 8 or 12)
  and, when the class itself declares ``encoding`` as a tuple, ``cls._encoding``.
  """
  super().__init_subclass__(**kwargs)
  # Merge fields from parent classes, walking the MRO base-most first so nearer classes override
  cls._fields = {}
  for base in reversed(cls.__mro__):
    if base is Inst or not hasattr(base, '_fields'): continue
    cls._fields.update(base._fields)
  # Add this class's own fields (overrides parents); a 2-tuple starting with a BitField contributes that BitField
  cls._fields.update({n: v[0] if isinstance(v, tuple) else v for n, v in cls.__dict__.items() if isinstance(v, BitField) or (isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], BitField))})
  # Compute size from max bit (exclude optional fields starting at bit 64+, e.g. MIMG NSA)
  # default=0 already covers a class with no fields, so no separate emptiness check is needed
  max_bit = max((bf.hi for bf in cls._fields.values() if bf.lo < 64), default=0)
  cls._sz = 12 if max_bit > 63 else 8 if max_bit > 31 else 4
  if 'encoding' in cls._fields and isinstance(cls.__dict__.get('encoding'), tuple): cls._encoding = cls.__dict__['encoding']
def _or_field(self, name: str, bit: int):
@@ -352,6 +381,16 @@ class Inst:
field_names = [n for n in self._fields if n != 'encoding']
# Map Python-friendly names to actual field names (abs_ -> abs for Python reserved word)
if 'abs_' in kwargs: kwargs['abs'] = kwargs.pop('abs_')
# If more args than fields, treat extra arg as literal (for FMAAK/FMAMK style instructions)
# FMAMK has K in middle (vdst, src0, K, vsrc1), FMAAK has K at end (vdst, src0, vsrc1, K)
args = list(args)
if len(args) > len(field_names) and literal is None:
for i, a in enumerate(args):
if isinstance(a, int) and not isinstance(a, SrcEnum) and i < len(field_names) and field_names[i] in ('vsrc1',):
literal = args.pop(i)
break
else:
literal = args.pop() # fallback: last arg is literal
orig_args = dict(zip(field_names, args)) | kwargs
self._values.update(orig_args)
self._precompute()
@@ -393,14 +432,14 @@ class Inst:
if name in SRC_FIELDS: self._encode_src(name, val)
elif name in RAW_FIELDS: self._encode_raw(name, val)
elif name == 'sbase': self._values[name] = (val.idx if isinstance(val, Reg) else val.val if isinstance(val, SrcMod) else val * 2) // 2
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = val.idx // 4
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = _encode_reg(val) // 4
elif marker is _VDSTYEnc and isinstance(val, VGPR): self._values[name] = val.idx >> 1
self._precompute_fields()
def _encode_field(self, name: str, val) -> int:
if isinstance(val, RawImm): return val.val
if isinstance(val, SrcMod) and not isinstance(val, Reg): return val.val # Special regs like VCC_LO
if name in {'srsrc', 'ssamp'}: return val.idx // 4 if isinstance(val, Reg) else val
if name in {'srsrc', 'ssamp'}: return _encode_reg(val) // 4 if isinstance(val, Reg) else val
if name == 'sbase': return val.idx // 2 if isinstance(val, Reg) else val.val // 2 if isinstance(val, SrcMod) else val
if name in RAW_FIELDS: return _encode_reg(val) if isinstance(val, Reg) else val
if isinstance(val, Reg) or name in SRC_FIELDS: return encode_src(val)
@@ -450,7 +489,7 @@ class Inst:
return result + (lit32 & MASK32).to_bytes(4, 'little')
@classmethod
def _size(cls) -> int:
  """Return the size in bytes of this instruction class, not counting a trailing literal."""
  # _sz is assigned per subclass in __init_subclass__ from the highest field bit
  return cls._sz
def size(self) -> int:
  """Size of this instruction in bytes: the class's fixed size plus 4 when a literal is attached."""
  fixed = self._size()
  # Literal is always 4 bytes in the binary (for 64-bit ops, it's in high 32 bits)
  if self._literal is None: return fixed
  return fixed + 4
@@ -514,7 +553,7 @@ class Inst:
lit32 = (self._literal >> 32) if self._literal > 0xffffffff else self._literal
s = f"0x{lit32:x}"
else:
s = decode_src(v)
s = decode_src(v, 'cdna' in self.__class__.__module__)
return f"-{s}" if neg else s
def __eq__(self, other):
@@ -540,21 +579,30 @@ class Inst:
elif hasattr(val, 'name'): self.op = val
else:
cls_name = self.__class__.__name__
# VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
if cls_name == 'VOP3':
try:
if val < 256: self.op = VOPCOp(val)
elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val)
else: self.op = VOP3Op(val)
except ValueError: self.op = val
# Prefer BitField marker (class-specific enum) over _enum_map (generic RDNA3 enums)
elif 'op' in self._fields and (marker := self._fields['op'].marker) and issubclass(marker, IntEnum):
is_cdna = cls_name in ('VOP3A', 'VOP3B')
# Try marker enum first (VOP3AOp, VOP3BOp, etc.)
marker = self._fields['op'].marker if 'op' in self._fields else None
if marker and issubclass(marker, IntEnum):
try: self.op = marker(val)
except ValueError: self.op = val
elif cls_name in self._enum_map:
try: self.op = self._enum_map[cls_name](val)
except ValueError: self.op = val
else: self.op = val
# Fallback for promoted instructions when marker lookup failed
if not hasattr(self.op, 'name') and cls_name in ('VOP3', 'VOP3A', 'VOP3B') and isinstance(val, int):
if val < 256:
try: self.op = VOPCOp(val)
except ValueError: pass
elif is_cdna and 256 <= val < 512:
try: self.op = (CDNA_VOP1Op(val - 320) if val >= 320 else CDNA_VOP2Op(val - 256))
except ValueError: pass
elif val in self._VOP3SD_OPS and not is_cdna:
try: self.op = VOP3SDOp(val)
except ValueError: pass
elif 256 <= val < 512 and not is_cdna:
try: self.op = VOP1Op(val - 384) if val >= 384 else VOP2Op(val - 256)
except ValueError: pass
self.op_name = self.op.name if hasattr(self.op, 'name') else ''
self._spec_regs = spec_regs(self.op_name)
self._spec_dtype = spec_dtype(self.op_name)
@@ -574,6 +622,4 @@ class Inst:
def is_64bit(self) -> bool:
  """Return what spec_is_64bit reports for this instruction's op name."""
  name = self.op_name
  return spec_is_64bit(name)
def is_dst_16(self) -> bool:
  """True when the destination takes one register and its dtype is classified as 16-bit by is_dtype_16."""
  if not self._spec_regs[0] == 1: return False
  return is_dtype_16(self._spec_dtype[0])
class Inst32(Inst): pass
class Inst64(Inst): pass
class Inst96(Inst): pass

View File

@@ -7,8 +7,9 @@ from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.ucode import compile_uop
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
from extra.assembly.amd.dsl import SrcEnum
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256
VCC_LO, VCC_HI, NULL, EXEC_LO, EXEC_HI, SCC = SrcEnum.VCC_LO, SrcEnum.VCC_HI, SrcEnum.NULL, SrcEnum.EXEC_LO, SrcEnum.EXEC_HI, SrcEnum.SCC
@@ -291,14 +292,16 @@ def exec_vop(st: WaveState, inst: Inst, V: list, lane: int) -> None:
extra_kwargs = {'opsel': opsel, 'opsel_hi': inst.opsel_hi | (inst.opsel_hi2 << 2)} if isinstance(inst, VOP3P) and 'FMA_MIX' in inst.op_name else {}
result = inst._fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, inst._literal, st.vgpr, src0_idx, vdst, **extra_kwargs)
# Check if this is a VOPC instruction (either standalone VOPC or VOP3 with VOPC opcode)
is_vopc = isinstance(inst.op, VOPCOp) or (isinstance(inst, VOP3) and inst.op.value < 256)
if 'VCC' in result:
if isinstance(inst, VOP3SD): st.pend_sgpr_lane(inst.sdst, lane, (result['VCC'] >> lane) & 1)
else: st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC'] >> lane) & 1)
if 'EXEC' in result:
st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC'] >> lane) & 1)
elif isinstance(inst.op, VOPCOp):
elif is_vopc:
st.pend_sgpr_lane(vdst, lane, (result['D0'] >> lane) & 1)
if not isinstance(inst.op, VOPCOp):
if not is_vopc:
d0_val = result['D0']
if inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
elif not isinstance(inst, VOP3P) and inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)

View File

@@ -1,435 +1,305 @@
# Generate AMD ISA autogen files from PDF documentation
# Combines format/enum generation (previously in dsl.py) and pseudocode compilation (previously in pcode.py)
# Usage: python -m extra.assembly.amd.pdf [--arch rdna3|rdna4|cdna|all]
import re, functools
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor
# Generic PDF text extractor - no external dependencies
import re, zlib
from tinygrad.helpers import fetch, merge_dicts
# ISA reference PDF for each supported architecture key.
PDF_URLS = {
  "rdna3": "https://docs.amd.com/api/khub/documents/UVVZM22UN7tMUeiW_4ShTQ/content",
  "rdna4": "https://docs.amd.com/api/khub/documents/uQpkEvk3pv~kfAb2x~j4uw/content",
  # "cdna" used to be listed twice in this literal; a later duplicate key silently replaces the earlier one,
  # so only this single CDNA4 URL ever took effect (the shadowed value was a list of the MI300/CDNA3 and CDNA4 PDFs).
  "cdna": "https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-cdna4-instruction-set-architecture.pdf",
}
# Field type mappings and ordering
# FIELD_TYPES: upper-case field name -> name of its field type. _parse_fields_table looks names up with
# FIELD_TYPES.get(name.upper()), so a field missing from this table gets None.
FIELD_TYPES = {'SSRC0': 'SSrc', 'SSRC1': 'SSrc', 'SOFFSET': 'SSrc', 'SADDR': 'SSrc', 'SRC0': 'Src', 'SRC1': 'Src', 'SRC2': 'Src',
'SDST': 'SGPRField', 'SBASE': 'SGPRField', 'SDATA': 'SGPRField', 'SRSRC': 'SGPRField', 'VDST': 'VGPRField', 'VSRC1': 'VGPRField',
'VDATA': 'VGPRField', 'VADDR': 'VGPRField', 'ADDR': 'VGPRField', 'DATA': 'VGPRField', 'DATA0': 'VGPRField', 'DATA1': 'VGPRField',
'SIMM16': 'SImm', 'OFFSET': 'Imm', 'OPX': 'VOPDOp', 'OPY': 'VOPDOp', 'SRCX0': 'Src', 'SRCY0': 'Src',
'VSRCX1': 'VGPRField', 'VSRCY1': 'VGPRField', 'VDSTX': 'VGPRField', 'VDSTY': 'VDSTYEnc'}
# FIELD_ORDER: format name -> list of lower-case field names.
# NOTE(review): no use is visible in this part of the file; presumably the order in which fields are emitted
# for generated instruction classes - confirm at the use site.
FIELD_ORDER = {
'SOP2': ['op', 'sdst', 'ssrc0', 'ssrc1'], 'SOP1': ['op', 'sdst', 'ssrc0'], 'SOPC': ['op', 'ssrc0', 'ssrc1'],
'SOPK': ['op', 'sdst', 'simm16'], 'SOPP': ['op', 'simm16'], 'VOP1': ['op', 'vdst', 'src0'], 'VOPC': ['op', 'src0', 'vsrc1'],
'VOP2': ['op', 'vdst', 'src0', 'vsrc1'], 'VOP3SD': ['op', 'vdst', 'sdst', 'src0', 'src1', 'src2', 'clmp'],
'SMEM': ['op', 'sdata', 'sbase', 'soffset', 'offset', 'glc', 'dlc'], 'DS': ['op', 'vdst', 'addr', 'data0', 'data1'],
'VOP3': ['op', 'vdst', 'src0', 'src1', 'src2', 'omod', 'neg', 'abs', 'clmp', 'opsel'],
'VOP3P': ['op', 'vdst', 'src0', 'src1', 'src2', 'neg', 'neg_hi', 'opsel', 'opsel_hi', 'clmp'],
'FLAT': ['op', 'vdst', 'addr', 'data', 'saddr', 'offset', 'seg', 'dlc', 'glc', 'slc'],
'MUBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
'MTBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'format', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
'MIMG': ['op', 'vdata', 'vaddr', 'srsrc', 'ssamp', 'dmask', 'dim', 'unrm', 'dlc', 'glc', 'slc'],
'EXP': ['en', 'target', 'vsrc0', 'vsrc1', 'vsrc2', 'vsrc3', 'done', 'row'],
'VINTERP': ['op', 'vdst', 'src0', 'src1', 'src2', 'waitexp', 'clmp', 'opsel', 'neg'],
'VOPD': ['opx', 'opy', 'vdstx', 'vdsty', 'srcx0', 'vsrcx1', 'srcy0', 'vsrcy1'],
'LDSDIR': ['op', 'vdst', 'attr', 'attr_chan', 'wait_va']}
# SRC_EXTRAS: source-operand codes that seed src_enum before entries parsed from the PDF are added.
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
# FLOAT_MAP: constant spelling as found in the parsed SSRC table text -> enum member name.
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
# INST_PATTERN: a whole line made of an instruction name (group 1) followed by its decimal opcode (group 2).
INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+|(?:FLAT|GLOBAL|SCRATCH)_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
# ═══════════════════════════════════════════════════════════════════════════════
# PDF PARSING WITH PAGE CACHING
# Generic PDF extraction tools
# ═══════════════════════════════════════════════════════════════════════════════
class CachedPDF:
  """PDF wrapper that memoizes per-page text and table extraction, keyed by page index."""
  def __init__(self, pdf):
    self._pdf = pdf
    self._text_cache = {}
    self._table_cache = {}

  def __len__(self): return len(self._pdf.pages)

  def text(self, i):
    """Text of page *i* ('' when the page yields none); extracted on first request only."""
    cache = self._text_cache
    if i in cache: return cache[i]
    page_text = self._pdf.pages[i].extract_text() or ''
    cache[i] = page_text
    return page_text

  def tables(self, i):
    """Extracted tables of page *i*; located and extracted on first request only."""
    cache = self._table_cache
    if i in cache: return cache[i]
    extracted = [found.extract() for found in self._pdf.pages[i].find_tables()]
    cache[i] = extracted
    return extracted
def extract(url: str) -> list[list[tuple[float, float, str, str]]]:
"""Extract positioned text from PDF. Returns list of text elements (x, y, text, font) per page."""
data = fetch(url).read_bytes()
def _parse_bits(s: str) -> tuple[int, int] | None:
return (int(m.group(1)), int(m.group(2) or m.group(1))) if (m := re.match(r'\[(\d+)(?::(\d+))?\]', s)) else None
# Parse xref table to locate objects
xref: dict[int, int] = {}
pos = int(re.search(rb'startxref\s+(\d+)', data).group(1)) + 4
while data[pos:pos+7] != b'trailer':
while data[pos:pos+1] in b' \r\n': pos += 1
line_end = data.find(b'\n', pos)
start_obj, count = map(int, data[pos:line_end].split()[:2])
pos = line_end + 1
for i in range(count):
if data[pos+17:pos+18] == b'n' and (off := int(data[pos:pos+10])) > 0: xref[start_obj + i] = off
pos += 20
def _parse_fields_table(table: list, fmt: str, enums: set[str]) -> list[tuple]:
  """Turn a 'Field / Bits / Description' table into (name, hi, lo, enc_val, ftype) tuples.

  The first row of *table* is skipped. Rows with an empty first cell or an unparsable bit range
  are dropped. For the ENCODING row, enc_val is the fixed bit pattern read from the description (else None).
  """
  op_type = f"{fmt}Op"
  # Handle shared FLAT/GLOBAL/SCRATCH table: look for format-specific encoding
  fmt_key = fmt.lstrip('V').lower().capitalize() # VFLAT -> Flat, VGLOBAL -> Global
  parsed = []
  for row in table[1:]:
    if not row or not row[0]: continue
    name = row[0].split('\n')[0].strip()
    bits = _parse_bits((row[1] or '').split('\n')[0].strip())
    if not bits: continue
    hi, lo = bits
    enc_val = None
    if name == 'ENCODING' and row[2]:
      desc = row[2]
      # format-specific pattern first, generic "'b..." / "Must be: ..." second
      found = re.search(rf"{fmt_key}='b([01_]+)", desc) or re.search(r"(?:'b|Must be:\s*)([01_]+)", desc)
      enc_bits = found.group(1).replace('_', '') if found else None
      if enc_bits:
        enc_val = int(enc_bits, 2)
        # a pattern wider than the declared range extends the field downwards
        if len(enc_bits) > hi - lo + 1: lo = hi - len(enc_bits) + 1
    if name == 'OP' and op_type in enums: ftype = op_type
    else: ftype = FIELD_TYPES.get(name.upper())
    parsed.append((name, hi, lo, enc_val, ftype))
  return parsed
def get_stream(n: int) -> bytes:
obj = data[xref[n]:data.find(b'endobj', xref[n])]
raw = obj[obj.find(b'stream\n') + 7:obj.find(b'\nendstream')]
return zlib.decompress(raw) if b'/FlateDecode' in obj else raw
def _parse_single_pdf(url: str):
"""Parse a single PDF and return (formats, enums, src_enum, doc_name, instructions)."""
import pdfplumber
from tinygrad.helpers import fetch
# Find page content streams and extract text
pages = []
for n in sorted(xref):
if b'/Type /Page' not in data[xref[n]:xref[n]+500]: continue
if not (m := re.search(rb'/Contents (\d+) 0 R', data[xref[n]:xref[n]+500])): continue
stream = get_stream(int(m.group(1))).decode('latin-1')
elements, font = [], ''
for bt in re.finditer(r'BT(.*?)ET', stream, re.S):
x, y = 0.0, 0.0
for m in re.finditer(r'(/F[\d.]+) [\d.]+ Tf|([\d.+-]+) ([\d.+-]+) Td|[\d.+-]+ [\d.+-]+ [\d.+-]+ [\d.+-]+ ([\d.+-]+) ([\d.+-]+) Tm|<([0-9A-Fa-f]+)>.*?Tj|\[([^\]]+)\] TJ', bt.group(1)):
if m.group(1): font = m.group(1)
elif m.group(2): x, y = x + float(m.group(2)), y + float(m.group(3))
elif m.group(4): x, y = float(m.group(4)), float(m.group(5))
elif m.group(6) and (t := bytes.fromhex(m.group(6)).decode('latin-1')).strip(): elements.append((x, y, t, font))
elif m.group(7) and (t := ''.join(bytes.fromhex(h).decode('latin-1') for h in re.findall(r'<([0-9A-Fa-f]+)>', m.group(7)))).strip(): elements.append((x, y, t, font))
pages.append(sorted(elements, key=lambda e: (-e[1], e[0])))
return pages
pdf = CachedPDF(pdfplumber.open(fetch(url)))
total_pages = len(pdf)
def extract_tables(pages: list[list[tuple[float, float, str, str]]]) -> dict[int, tuple[str, list[list[str]]]]:
"""Extract numbered tables from PDF pages. Returns {table_num: (title, rows)} where rows is list of cells per row."""
def group_by_y(texts, key=lambda y: round(y)):
by_y: dict[int, list[tuple[float, float, str]]] = {}
for x, y, t, _ in texts:
by_y.setdefault(key(y), []).append((x, y, t))
return by_y
# Auto-detect document type
first_page = pdf.text(0)
is_cdna4, is_cdna3 = 'CDNA4' in first_page or 'CDNA 4' in first_page, 'CDNA3' in first_page or 'MI300' in first_page
is_cdna, is_rdna4 = is_cdna3 or is_cdna4, 'RDNA4' in first_page or 'RDNA 4' in first_page
is_rdna35, is_rdna3 = 'RDNA3.5' in first_page or 'RDNA 3.5' in first_page, 'RDNA3' in first_page and 'RDNA3.5' not in first_page
doc_name = "CDNA4" if is_cdna4 else "CDNA3" if is_cdna3 else "RDNA4" if is_rdna4 else "RDNA3.5" if is_rdna35 else "RDNA3" if is_rdna3 else "Unknown"
# Find all table headers by merging text on same line
table_positions = []
for page_idx, texts in enumerate(pages):
for items in group_by_y(texts).values():
line = ''.join(t for _, t in sorted((x, t) for x, _, t in items))
if m := re.search(r'Table (\d+)\. (.+)', line):
table_positions.append((int(m.group(1)), m.group(2).strip(), page_idx, items[0][1]))
table_positions.sort(key=lambda t: (t[2], -t[3]))
# Find Microcode Formats section (for formats/enums)
microcode_start = next((i for i in range(int(total_pages * 0.2), total_pages)
if re.search(r'\d+\.\d+\.\d+\.\s+SOP2\b|Chapter \d+\.\s+Microcode Formats', pdf.text(i))), int(total_pages * 0.9))
# Find Instructions section (for pseudocode)
instr_start = next((i for i in range(int(total_pages * 0.1), int(total_pages * 0.5))
if re.search(r'Chapter \d+\.\s+Instructions\b', pdf.text(i))), total_pages // 3)
instr_end = next((i for start in [int(total_pages * 0.6), int(total_pages * 0.5), instr_start]
for i in range(start, min(start + 100, total_pages))
if re.search(r'Chapter \d+\.\s+Microcode Formats', pdf.text(i))), total_pages)
# Parse src enum from SSRC encoding table
src_enum = dict(SRC_EXTRAS)
for i in range(microcode_start, min(microcode_start + 10, total_pages)):
text = pdf.text(i)
if 'SSRC0' in text and 'VCC_LO' in text:
for m in re.finditer(r'^(\d+)\s+(\S+)', text, re.M):
val, name = int(m.group(1)), m.group(2).rstrip('.:')
if name in FLOAT_MAP: src_enum[val] = FLOAT_MAP[name]
elif re.match(r'^[A-Z][A-Z0-9_]*$', name): src_enum[val] = name
# For each table, find rows with matching X positions
result: dict[int, tuple[str, list[list[str]]]] = {}
for num, title, start_page, header_y in table_positions:
rows, col_xs = [], None
for page_idx in range(start_page, len(pages)):
page_texts = [(x, y, t) for x, y, t, _ in pages[page_idx] if 30 < y < 760 and (page_idx > start_page or y < header_y)]
for items in sorted(group_by_y([(x, y, t, '') for x, y, t in page_texts], key=lambda y: round(y / 5)).values(), key=lambda items: -items[0][1]):
xs = tuple(sorted(round(x) for x, _, _ in items))
if col_xs is None:
if len(xs) < 2: continue # Skip single-column rows before table starts
col_xs = xs
elif len(xs) == 1 and xs[0] in col_xs: continue # Skip continuation rows at known column positions
elif not any(c in xs for c in col_xs[:2]): break # Row missing first columns = end of table
rows.append([t for _, t in sorted((x, t) for x, _, t in items)])
else: continue
break
if rows: result[num] = (title, rows)
return result
# Parse opcode tables
full_text = '\n'.join(pdf.text(i) for i in range(microcode_start, min(microcode_start + 50, total_pages)))
# ═══════════════════════════════════════════════════════════════════════════════
# AMD specific extraction
# ═══════════════════════════════════════════════════════════════════════════════
def extract_enums(tables: dict[int, tuple[str, list[list[str]]]]) -> dict[str, dict[int, str]]:
"""Extract all enums from tables. Returns {enum_name: {value: name}}."""
enums: dict[str, dict[int, str]] = {}
for m in re.finditer(r'Table \d+\. (\w+) Opcodes(.*?)(?=Table \d+\.|\n\d+\.\d+\.\d+\.\s+\w+\s*\nDescription|$)', full_text, re.S):
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+([A-Z][A-Z0-9_]+)', m.group(2))}:
enums[m.group(1) + "Op"] = ops
if vopd_m := re.search(r'Table \d+\. VOPD Y-Opcodes\n(.*?)(?=Table \d+\.|15\.\d)', full_text, re.S):
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+(V_DUAL_\w+)', vopd_m.group(1))}:
enums["VOPDOp"] = ops
enum_names = set(enums.keys())
for num, (title, rows) in tables.items():
# Opcode enums from "XXX Opcodes" tables
if m := re.match(r'(\w+) (?:Y-)?Opcodes', title):
fmt_name = 'VOPD' if 'Y-Opcodes' in title else m.group(1)
ops: dict[int, str] = {}
for row in rows:
for i in range(0, len(row) - 1, 2):
if row[i].isdigit() and re.match(r'^[A-Z][A-Z0-9_]+$', row[i + 1]):
ops[int(row[i])] = row[i + 1]
if ops: enums[fmt_name] = ops
# BufFmt from "Data Format" tables
if 'Data Format' in title:
for row in rows:
for i in range(0, len(row) - 1, 2):
if row[i].isdigit() and re.match(r'^[\dA-Z_]+$', row[i + 1]) and 'INVALID' not in row[i + 1]:
enums.setdefault('BufFmt', {})[int(row[i])] = row[i + 1]
return enums
# Parse instruction formats
def is_fields_table(t):
  """Truthy when *t* has more than one row and the first cell of its first row contains 'Field'."""
  if not t: return t
  if not len(t) > 1: return False
  header = t[0]
  if not header: return header
  return 'Field' in str(header[0] or '')
def has_encoding(fields):
  """True when some entry of *fields* has 'ENCODING' as its first element."""
  for entry in fields:
    if entry[0] == 'ENCODING': return True
  return False
def has_header_before_fields(text):
  """True when *text* contains 'Field Name' and a 'N.N.N. WORD' heading line occurs before it."""
  pos = text.find('Field Name')
  if pos == -1: return False
  return re.search(r'\d+\.\d+\.\d+\.\s+\w+\s*\n', text[:pos]) is not None
def extract_ins(tables: dict[int, tuple[str, list[list[str]]]]) -> tuple[dict[str, list[tuple[str, int, int]]], dict[str, str]]:
  """Extract formats and encodings from 'XXX Fields' tables. Returns (formats, encodings).

  formats maps a format name to its (field_name, hi, lo) tuples, field names lower-cased.
  encodings maps a format name to its fixed encoding as a string of '0'/'1' (underscores removed).
  Tables whose title is not '<word> Fields' are ignored, as are rows without a parsable bit range.
  """
  formats: dict[str, list[tuple[str, int, int]]] = {}
  encodings: dict[str, str] = {}
  for title, rows in tables.values():
    if not (m := re.match(r'(\w+) Fields$', title)): continue
    fmt_name = m.group(1)
    fields = []
    for row in rows:
      if len(row) < 2: continue
      # bit range is "[hi:lo]" (brackets optional) or a single "[bit]"
      bits = re.match(r'\[?(\d+):(\d+)\]?$', row[1]) or re.match(r'\[(\d+)\]$', row[1])
      if not bits: continue
      field_name = row[0].lower()
      hi = int(bits.group(1))
      lo = int(bits.group(2)) if bits.lastindex >= 2 else hi
      if field_name == 'encoding' and len(row) >= 3:
        # Capture only the binary digits of the last 'b literal: text after the literal
        # (notes, whitespace) must not leak into the bit string or its width.
        if (literals := re.findall(r"'b([01_]+)", row[2])): enc_bits = literals[-1].replace('_', '')
        elif (enc := re.search(r':\s*([01_]+)', row[2])): enc_bits = enc.group(1).replace('_', '')
        else: enc_bits = None
        if enc_bits:
          # If encoding bits exceed field width, extend field to match (AMD docs sometimes have this)
          declared_width, actual_width = hi - lo + 1, len(enc_bits)
          if actual_width > declared_width: lo = hi - actual_width + 1
          encodings[fmt_name] = enc_bits
      fields.append((field_name, hi, lo))
    if fields: formats[fmt_name] = fields
  return formats, encodings
format_headers = []
for i in range(50):
if microcode_start + i >= total_pages: break
text = pdf.text(microcode_start + i)
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n?Description', text): format_headers.append((m.group(1), i, m.start()))
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n', text):
fmt_name = m.group(1)
if is_cdna and fmt_name.isupper() and len(fmt_name) >= 2: format_headers.append((fmt_name, i, m.start()))
elif m.start() > len(text) - 200 and 'Description' not in text[m.end():] and i + 1 < 50:
next_text = pdf.text(microcode_start + i + 1).lstrip()
if next_text.startswith('Description') or (next_text.startswith('"RDNA') and 'Description' in next_text[:200]):
format_headers.append((fmt_name, i, m.start()))
# RDNA4: Look for "Table X. Y Fields" patterns (e.g., VIMAGE, VSAMPLE, or shared FLAT/GLOBAL/SCRATCH)
for m in re.finditer(r'Table \d+\.\s+([\w,\s]+?)\s+Fields', text):
table_name = m.group(1).strip()
# Handle shared table like "FLAT, GLOBAL and SCRATCH"
if ',' in table_name or ' and ' in table_name:
for part in re.split(r',\s*|\s+and\s+', table_name):
fmt_name = 'V' + part.strip()
if fmt_name not in [h[0] for h in format_headers]: format_headers.append((fmt_name, i, m.start()))
elif table_name.startswith('V'):
if table_name not in [h[0] for h in format_headers]: format_headers.append((table_name, i, m.start()))
def extract_pcode(pages: list[list[tuple[float, float, str, str]]], enums: dict[str, dict[int, str]]) -> dict[tuple[str, int], str]:
"""Extract pseudocode for instructions. Returns {(name, opcode): pseudocode}."""
# Build lookup from instruction name to opcode
name_to_op = {name: op for ops in enums.values() for op, name in ops.items()}
formats: dict[str, list] = {}
for fmt_name, rel_idx, header_pos in format_headers:
if fmt_name in formats: continue
page_idx = microcode_start + rel_idx
text = pdf.text(page_idx)
field_pos = text.find('Field Name', header_pos)
fields = None
for offset in range(3):
if page_idx + offset >= total_pages: break
if offset > 0 and has_header_before_fields(pdf.text(page_idx + offset)): break
for t in pdf.tables(page_idx + offset) if offset > 0 or field_pos > header_pos else []:
if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)) and has_encoding(f): fields = f; break
if fields: break
if not fields and field_pos > header_pos:
for t in pdf.tables(page_idx):
if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)): fields = f; break
if not fields: continue
field_names = {f[0] for f in fields}
for pg_offset in range(1, 3):
if page_idx + pg_offset >= total_pages or has_header_before_fields(pdf.text(page_idx + pg_offset)): break
for t in pdf.tables(page_idx + pg_offset):
if is_fields_table(t) and (extra := _parse_fields_table(t, fmt_name, enum_names)) and not has_encoding(extra):
for ef in extra:
if ef[0] not in field_names: fields.append(ef); field_names.add(ef[0])
break
formats[fmt_name] = fields
# First pass: find all instruction headers across all pages
all_instructions: list[tuple[int, float, str, int]] = [] # (page_idx, y, name, opcode)
for page_idx, page in enumerate(pages):
by_y: dict[int, list[tuple[float, str]]] = {}
for x, y, t, _ in page:
by_y.setdefault(round(y), []).append((x, t))
for y, items in sorted(by_y.items(), reverse=True):
left = [(x, t) for x, t in items if 55 < x < 65]
right = [(x, t) for x, t in items if 535 < x < 550]
if left and right and left[0][1] in name_to_op and right[0][1].isdigit():
all_instructions.append((page_idx, y, left[0][1], int(right[0][1])))
# Fix known PDF errors
if 'SMEM' in formats:
formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
for n, h, l, e, t in formats['SMEM']]
# RDNA4: VFLAT/VGLOBAL/VSCRATCH OP field is [20:14] not [20:13] (PDF documentation error)
for fmt_name in ['VFLAT', 'VGLOBAL', 'VSCRATCH']:
if fmt_name in formats:
formats[fmt_name] = [(n, h, 14 if n == 'OP' else l, e, t) for n, h, l, e, t in formats[fmt_name]]
if doc_name in ('RDNA3', 'RDNA3.5'):
if 'SOPPOp' in enums:
for k, v in {8: 'S_WAITCNT_DEPCTR', 58: 'S_TTRACEDATA', 59: 'S_TTRACEDATA_IMM'}.items():
assert k not in enums['SOPPOp']; enums['SOPPOp'][k] = v
if 'SOPKOp' in enums:
for k, v in {22: 'S_SUBVECTOR_LOOP_BEGIN', 23: 'S_SUBVECTOR_LOOP_END'}.items():
assert k not in enums['SOPKOp']; enums['SOPKOp'][k] = v
if 'SMEMOp' in enums:
for k, v in {34: 'S_ATC_PROBE', 35: 'S_ATC_PROBE_BUFFER'}.items():
assert k not in enums['SMEMOp']; enums['SMEMOp'][k] = v
if 'DSOp' in enums:
for k, v in {24: 'DS_GWS_SEMA_RELEASE_ALL', 25: 'DS_GWS_INIT', 26: 'DS_GWS_SEMA_V', 27: 'DS_GWS_SEMA_BR', 28: 'DS_GWS_SEMA_P', 29: 'DS_GWS_BARRIER'}.items():
assert k not in enums['DSOp']; enums['DSOp'][k] = v
if 'FLATOp' in enums:
for k, v in {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}.items():
assert k not in enums['FLATOp']; enums['FLATOp'][k] = v
# CDNA SDWA/DPP: PDF only has modifier fields, need VOP1/VOP2 overlay for correct encoding
if is_cdna:
if 'SDWA' in formats:
formats['SDWA'] = [('ENCODING', 8, 0, 0xf9, None), ('VOP_OP', 16, 9, None, None), ('VDST', 24, 17, None, 'VGPRField'), ('VOP2_OP', 31, 25, None, None)] + \
[f for f in formats['SDWA'] if f[0] not in ('ENCODING', 'SDST', 'SD', 'ROW_MASK')]
if 'DPP' in formats:
formats['DPP'] = [('ENCODING', 8, 0, 0xfa, None), ('VOP_OP', 16, 9, None, None), ('VDST', 24, 17, None, 'VGPRField'), ('VOP2_OP', 31, 25, None, None),
('SRC0', 39, 32, None, 'Src'), ('DPP_CTRL', 48, 40, None, None), ('BOUND_CTRL', 51, 51, None, None), ('SRC0_NEG', 52, 52, None, None), ('SRC0_ABS', 53, 53, None, None),
('SRC1_NEG', 54, 54, None, None), ('SRC1_ABS', 55, 55, None, None), ('BANK_MASK', 59, 56, None, None), ('ROW_MASK', 63, 60, None, None)]
# Extract pseudocode for instructions
all_text = '\n'.join(pdf.text(i) for i in range(instr_start, instr_end))
matches = list(INST_PATTERN.finditer(all_text))
raw_pseudocode: dict[tuple[str, int], str] = {}
for i, match in enumerate(matches):
name, opcode = match.group(1), int(match.group(2))
start, end = match.end(), matches[i + 1].start() if i + 1 < len(matches) else match.end() + 2000
snippet = all_text[start:end].strip()
if pseudocode := _extract_pseudocode(snippet): raw_pseudocode[(name, opcode)] = pseudocode
return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "pseudocode": raw_pseudocode, "is_cdna": is_cdna}
def _extract_pseudocode(text: str) -> str | None:
"""Extract pseudocode from an instruction description snippet."""
lines, result, depth, in_lambda = text.split('\n'), [], 0, 0
for line in lines:
s = line.strip()
if not s or re.match(r'^\d+ of \d+$', s) or re.match(r'^\d+\.\d+\..*Instructions', s): continue
if s.startswith(('Notes', 'Functional examples', '', '-')): break # Stop at notes/bullets
if s.startswith(('"RDNA', 'AMD ', 'CDNA')): continue
if '' in s or '' in s: continue # Skip lines with bullets/dashes
if '= lambda(' in s: in_lambda += 1; continue
if in_lambda > 0:
if s.endswith(');'): in_lambda -= 1
continue
if s.startswith('if '): depth += 1
elif s.startswith('endif'): depth = max(0, depth - 1)
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
is_code = (any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =',
'D0[', 'D1[', 'S0[', 'S1[', 'S2[', 'MEM[', 'RETURN_DATA',
'VADDR', 'VDATA', 'VDST', 'SADDR', 'OFFSET']) or
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s))
if is_code: result.append(s)
return '\n'.join(result) if result else None
def _merge_results(results: list[dict]) -> dict:
  """Combine several per-PDF parse results into one superset result.

  Source operands and enum members are unioned, and an entry seen twice must carry the same name.
  Format fields are unioned by field name, and a repeated field must have the same (hi, lo) bits.
  For pseudocode the first result that defines a key wins.

  Returns:
    A dict with keys "formats", "enums", "src_enum", "doc_names", "pseudocode" and "is_cdna".
  """
  doc_names, is_cdna = [], False
  src_enum, enums, formats, pseudocode = dict(SRC_EXTRAS), {}, {}, {}
  for parsed in results:
    doc_names.append(parsed["doc_name"])
    is_cdna = is_cdna or parsed["is_cdna"]
    # source operand names: insert if new, otherwise the existing name must agree
    for val, name in parsed["src_enum"].items():
      seen = src_enum.setdefault(val, name)
      assert seen == name
    # opcode enums: same rule, applied per enum
    for enum_name, ops in parsed["enums"].items():
      bucket = enums.setdefault(enum_name, {})
      for val, name in ops.items():
        seen = bucket.setdefault(val, name)
        assert seen == name
    # formats: the first occurrence is copied, later ones may only add fields not yet known
    for fmt_name, fields in parsed["formats"].items():
      if fmt_name not in formats:
        formats[fmt_name] = list(fields)
        continue
      known = {f[0]: (f[1], f[2]) for f in formats[fmt_name]}
      for f in fields:
        if f[0] in known: assert known[f[0]] == (f[1], f[2])
        else: formats[fmt_name].append(f)
    for key, pc in parsed["pseudocode"].items():
      pseudocode.setdefault(key, pc)
  return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_names": doc_names, "pseudocode": pseudocode, "is_cdna": is_cdna}
# Second pass: extract pseudocode between consecutive instructions
pcode: dict[tuple[str, int], str] = {}
for i, (page_idx, y, name, opcode) in enumerate(all_instructions):
# Get end boundary from next instruction
if i + 1 < len(all_instructions):
next_page, next_y = all_instructions[i + 1][0], all_instructions[i + 1][1]
else:
next_page, next_y = page_idx, 0
# Collect F6 text from current position to next instruction
lines = []
for p in range(page_idx, next_page + 1):
start_y = y if p == page_idx else 800
end_y = next_y if p == next_page else 0
lines.extend((p, y2, t) for x, y2, t, f in pages[p] if f in ('/F6.0', '/F7.0') and end_y < y2 < start_y)
if lines:
# Sort by page first, then by y descending within each page (higher y = earlier text in PDF)
pcode_lines = [t.replace('Ê', '').strip() for _, _, t in sorted(lines, key=lambda x: (x[0], -x[1]))]
if pcode_lines: pcode[(name, opcode)] = '\n'.join(pcode_lines)
return pcode
# ═══════════════════════════════════════════════════════════════════════════════
# CODE GENERATION
# Write autogen files
# ═══════════════════════════════════════════════════════════════════════════════
def _generate_enum_py(enums, src_enum, doc_name) -> str:
"""Generate enum.py content (just enums, no dsl.py dependency)."""
def enum_lines(name, items): return [f"class {name}(IntEnum):"] + [f" {n} = {v}" for v, n in sorted(items.items())] + [""]
lines = [f"# autogenerated from AMD {doc_name} ISA PDF by pdf.py - do not edit", "from enum import IntEnum", ""]
lines += enum_lines("SrcEnum", src_enum) + sum([enum_lines(n, ops) for n, ops in sorted(enums.items())], [])
return '\n'.join(lines)
def _generate_ins_py(formats, enums, src_enum, doc_name) -> str:
"""Generate ins.py content (instruction formats and helpers, imports dsl.py and enum.py)."""
def field_key(f, order): return order.index(f[0].lower()) if f[0].lower() in order else 1000
lines = [f"# autogenerated from AMD {doc_name} ISA PDF by pdf.py - do not edit",
"# ruff: noqa: F401,F403", "from typing import Annotated",
"from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, Inst96, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
"from extra.assembly.amd.autogen.{arch}.enum import *",
"import functools", ""]
format_defaults = {'VOP3P': {'opsel_hi': 3, 'opsel_hi2': 1}}
lines.append("# instruction formats")
# MIMG has optional NSA (Non-Sequential Address) fields that extend beyond 64 bits, but base encoding is 64-bit
inst64_override = {'MIMG'}
for fmt_name, fields in sorted(formats.items()):
max_bit = max(f[1] for f in fields)
if fmt_name in inst64_override: base = "Inst64"
else: base = "Inst96" if max_bit > 63 else "Inst64" if max_bit > 31 or fmt_name == 'VOP3SD' else "Inst32"
order = FIELD_ORDER.get(fmt_name, [])
lines.append(f"class {fmt_name}({base}):")
if enc := next((f for f in fields if f[0] == 'ENCODING'), None):
lines.append(f" encoding = bits[{enc[1]}:{enc[2]}] == 0b{enc[3]:b}" if enc[1] != enc[2] else f" encoding = bits[{enc[1]}] == {enc[3]}")
if defaults := format_defaults.get(fmt_name): lines.append(f" _defaults = {defaults}")
for name, hi, lo, _, ftype in sorted([f for f in fields if f[0] != 'ENCODING'], key=lambda f: field_key(f, order)):
ann = f":Annotated[BitField, {ftype}]" if ftype and ftype.endswith('Op') else f":{ftype}" if ftype else ""
lines.append(f" {name.lower()}{ann} = bits[{hi}]" if hi == lo else f" {name.lower()}{ann} = bits[{hi}:{lo}]")
def write_enums(enums: dict[str, dict[int, str]], arch: str, path: str):
  """Write an enum.py file to path with one IntEnum class per entry of enums.

  Class names are "<name>Op", except 'Src' which becomes "SrcEnum" and 'BufFmt' which keeps its
  name and has its members prefixed with "BUF_FMT_". The arch argument is not used here.
  """
  # (class-name suffix, member prefix) for the two specially named enums
  special = {'Src': ("Enum", ""), 'BufFmt': ("", "BUF_FMT_")}
  out = ["# autogenerated from AMD ISA PDF by pdf.py - do not edit", "from enum import IntEnum", ""]
  for enum_name in sorted(enums):
    suffix, prefix = special.get(enum_name, ("Op", ""))
    out.append(f"class {enum_name}{suffix}(IntEnum):")
    out += [f" {prefix}{member} = {val}" for val, member in sorted(enums[enum_name].items())]
    out.append("")
  with open(path, "w") as f:
    f.write("\n".join(out))
def write_ins(formats: dict[str, list[tuple[str, int, int]]], encodings: dict[str, str], enums: dict[str, dict[int, str]], arch: str, path: str):
"""Write ins.py file from extracted formats and enums."""
# Field types and ordering
def field_type(name, fmt):
if name == 'op' and fmt in enums: return f'Annotated[BitField, {fmt}Op]'
if name in ('opx', 'opy'): return 'Annotated[BitField, VOPDOp]'
if name == 'vdsty': return 'VDSTYEnc'
if name in ('vdst', 'vsrc1', 'vaddr', 'vdata', 'data', 'data0', 'data1', 'addr', 'vsrc0', 'vsrc2', 'vsrc3'): return 'VGPRField'
if name in ('sdst', 'sbase', 'sdata', 'srsrc', 'ssamp'): return 'SGPRField'
if name.startswith('ssrc') or name in ('saddr', 'soffset'): return 'SSrc'
if name in ('src0', 'srcx0', 'srcy0') or name.startswith('src') and name[3:].isdigit(): return 'Src'
if name.startswith('simm'): return 'SImm'
if name == 'offset' or name.startswith('imm'): return 'Imm'
return None
field_priority = ['encoding', 'op', 'opx', 'opy', 'vdst', 'vdstx', 'vdsty', 'sdst', 'vdata', 'sdata', 'addr', 'vaddr', 'data', 'data0', 'data1',
'src0', 'srcx0', 'srcy0', 'vsrc0', 'ssrc0', 'src1', 'vsrc1', 'vsrcx1', 'vsrcy1', 'ssrc1', 'src2', 'vsrc2', 'src3', 'vsrc3',
'saddr', 'sbase', 'srsrc', 'ssamp', 'soffset', 'offset', 'simm16', 'en', 'target', 'attr', 'attr_chan',
'omod', 'neg', 'neg_hi', 'abs', 'clmp', 'opsel', 'opsel_hi', 'waitexp', 'wait_va',
'dmask', 'dim', 'seg', 'format', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe', 'unrm', 'done', 'row']
def sort_fields(fields):
order = {name: i for i, name in enumerate(field_priority)}
return sorted(fields, key=lambda f: (order.get(f[0], 1000), f[2]))
# Generate format classes
lines = ["# autogenerated from AMD ISA PDF by pdf.py - do not edit", "# ruff: noqa: F401,F403",
"from typing import Annotated",
"from extra.assembly.amd.dsl import *",
f"from extra.assembly.amd.autogen.{arch}.enum import *", "import functools", ""]
for fmt_name, fields in sorted(formats.items()):
lines.append(f"class {fmt_name}(Inst):")
for name, hi, lo in sort_fields(fields):
bits_str = f"bits[{hi}:{lo}]" if hi != lo else f"bits[{hi}]"
if name == 'encoding' and fmt_name in encodings: lines.append(f" encoding = {bits_str} == 0b{encodings[fmt_name]}")
else:
ftype = field_type(name, fmt_name)
lines.append(f" {name}{f':{ftype}' if ftype else ''} = {bits_str}")
lines.append("")
# Generate instruction helpers
lines.append("# instruction helpers")
for cls_name, ops in sorted(enums.items()):
fmt = cls_name[:-2]
for op_val, name in sorted(ops.items()):
seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=1"}.get(fmt, "")
tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}")
if fmt in formats or fmt in ("GLOBAL", "SCRATCH"):
suffix = "_e32" if fmt in ("VOP1", "VOP2", "VOPC") else "_e64" if fmt == "VOP3" and op_val < 512 else ""
if name in ('V_FMAMK_F32', 'V_FMAMK_F16'):
lines.append(f"def {name.lower()}{suffix}(vdst, src0, K, vsrc1): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
elif name in ('V_FMAAK_F32', 'V_FMAAK_F16'):
lines.append(f"def {name.lower()}{suffix}(vdst, src0, vsrc1, K): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
else: lines.append(f"{name.lower()}{suffix} = functools.partial({tgt}.{name}{seg})")
src_names = {name for _, name in src_enum.items()}
lines += [""] + [f"{name} = SrcEnum.{name}" for _, name in sorted(src_enum.items()) if name not in {'DPP8', 'DPP16'}]
if "NULL" in src_names: lines.append("OFF = NULL\n")
return '\n'.join(lines)
for fmt_name, ops in sorted(enums.items()):
seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=1"}.get(fmt_name, "")
tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt_name, f"{fmt_name}, {fmt_name}Op")
suffix = "_e32" if fmt_name in ("VOP1", "VOP2", "VOPC") else "_e64" if fmt_name == "VOP3" and len(ops) > 0 else ""
if fmt_name in formats or fmt_name in ("GLOBAL", "SCRATCH"):
for op_val, name in sorted(ops.items()):
fn_suffix = suffix if fmt_name != "VOP3" or op_val < 512 else ""
lines.append(f"{name.lower()}{fn_suffix} = functools.partial({tgt}.{name}{seg})")
def _generate_str_pcode_py(enums, pseudocode, arch) -> str:
"""Generate str_pcode.py content (raw pseudocode strings)."""
# Get op enums for this arch (import from .ins which re-exports from .enum)
import importlib
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'SMEMOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp', 'FLATOp', 'GLOBALOp', 'SCRATCHOp'] if hasattr(autogen, name)]
with open(path, "w") as f:
f.write("\n".join(lines))
# Build defined ops mapping
defined_ops: dict[tuple, list] = {}
for enum_cls in OP_ENUMS:
for op in enum_cls:
if op.name.startswith(('S_', 'V_', 'DS_', 'FLAT_', 'GLOBAL_', 'SCRATCH_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
enum_names = [e.__name__ for e in OP_ENUMS]
instructions: dict = {cls: {} for cls in OP_ENUMS}
for key, pc in pseudocode.items():
if key in defined_ops:
for enum_cls, enum_val in defined_ops[key]: instructions[enum_cls][enum_val] = pc
# Build string dictionaries for each enum
lines = [f'''# autogenerated by pdf.py - do not edit
# to regenerate: python -m extra.assembly.amd.pdf --arch {arch}
# ruff: noqa: E501
from extra.assembly.amd.autogen.{arch}.enum import {", ".join(enum_names)}
''']
all_dict_entries: dict = {}
for enum_cls in OP_ENUMS:
cls_name = enum_cls.__name__
if not instructions.get(enum_cls): continue
dict_entries = [(op, repr(pc)) for op, pc in instructions[enum_cls].items()]
if dict_entries:
all_dict_entries[enum_cls] = dict_entries
lines.append(f'{cls_name}_PCODE = {{')
for op, escaped in dict_entries: lines.append(f" {cls_name}.{op.name}: {escaped},")
lines.append('}\n')
lines.append('PSEUDOCODE_STRINGS = {')
for enum_cls in OP_ENUMS:
if all_dict_entries.get(enum_cls): lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_PCODE,')
lines.append('}')
return '\n'.join(lines)
# ═══════════════════════════════════════════════════════════════════════════════
# MAIN GENERATION
# ═══════════════════════════════════════════════════════════════════════════════
def generate_arch(arch: str) -> dict:
  """Parse the PDF(s) registered for arch and write enum.py, ins.py and str_pcode.py for it.

  Files are written under extra/assembly/amd/autogen/<arch>, which is created if missing.

  Returns:
    The parse result used for generation (merged when more than one PDF is listed).
  """
  sources = PDF_URLS[arch]
  if isinstance(sources, str): sources = [sources]
  print(f"\n{'='*60}\nGenerating {arch}...")
  print(f"Parsing {len(sources)} PDF(s)...")
  parsed = [_parse_single_pdf(src) for src in sources]
  if len(parsed) > 1:
    merged = _merge_results(parsed)
    doc_name = "+".join(merged["doc_names"])
  else:
    merged = parsed[0]
    doc_name = merged["doc_name"]

  out_dir = Path(f"extra/assembly/amd/autogen/{arch}")
  out_dir.mkdir(parents=True, exist_ok=True)
  (out_dir / "__init__.py").touch()

  # enum.py holds only the enums and does not depend on dsl.py
  enum_file = out_dir / "enum.py"
  enum_file.write_text(_generate_enum_py(merged["enums"], merged["src_enum"], doc_name))
  print(f"Generated {enum_file}: SrcEnum ({len(merged['src_enum'])}) + {len(merged['enums'])} enums")

  # ins.py: the generator leaves a literal "{arch}" placeholder that is filled in here
  ins_file = out_dir / "ins.py"
  ins_text = _generate_ins_py(merged["formats"], merged["enums"], merged["src_enum"], doc_name)
  ins_file.write_text(ins_text.replace("{arch}", arch))
  print(f"Generated {ins_file}: {len(merged['formats'])} formats")

  # str_pcode.py is generated last because its generator imports the files written above
  pcode_file = out_dir / "str_pcode.py"
  pcode_file.write_text(_generate_str_pcode_py(merged["enums"], merged["pseudocode"], arch))
  print(f"Generated {pcode_file}: {len(merged['pseudocode'])} instructions")
  return merged
def _generate_arch_wrapper(arch: str):
  """Run generate_arch for arch, discard its result and return arch (used as the pool worker)."""
  _ = generate_arch(arch)
  return arch
def generate_all():
  """Generate every architecture listed in PDF_URLS using a process pool."""
  with ProcessPoolExecutor() as pool:
    # draining the iterator waits for every worker and re-raises any worker exception
    for _ in pool.map(_generate_arch_wrapper, PDF_URLS.keys()): pass
def write_pcode(pcode: dict[tuple[str, int], str], enums: dict[str, dict[int, str]], arch: str, path: str):
  """Write a str_pcode.py file to path from pseudocode keyed by (instruction name, opcode).

  For every enum that has at least one instruction with pseudocode, a "<Enum>Op_PCODE" dict is
  emitted with entries in opcode order, followed by a PSEUDOCODE_STRINGS dict mapping each enum
  to its dict.
  """
  # bucket the pseudocode under the op-enum class name each instruction belongs to
  grouped: dict[str, list[tuple[str, int, str]]] = {}
  for fmt_name, ops in enums.items():
    for opcode, name in ops.items():
      key = (name, opcode)
      if key not in pcode: continue
      grouped.setdefault(f"{fmt_name}Op", []).append((name, opcode, pcode[key]))
  enum_names = sorted(grouped)
  out = ["# autogenerated by pdf.py - do not edit", "# to regenerate: python -m extra.assembly.amd.pdf",
         "# ruff: noqa: E501", f"from extra.assembly.amd.autogen.{arch}.enum import {', '.join(enum_names)}", ""]
  for enum_name in enum_names:
    out.append(f"{enum_name}_PCODE = {{")
    by_opcode = sorted(grouped[enum_name], key=lambda entry: entry[1])
    out.extend(f" {enum_name}.{name}: {code!r}," for name, _, code in by_opcode)
    out.append("}\n")
  mapping = ', '.join(f'{e}: {e}_PCODE' for e in enum_names)
  out.append("PSEUDOCODE_STRINGS = {" + mapping + "}")
  with open(path, "w") as f:
    f.write("\n".join(out))
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Generate AMD ISA autogen files from PDF documentation")
parser.add_argument("--arch", choices=list(PDF_URLS.keys()) + ["all"], default="rdna3")
args = parser.parse_args()
if args.arch == "all": generate_all()
else: generate_arch(args.arch)
import pathlib
for arch, url in PDF_URLS.items():
print(f"Processing {arch}...")
pages = extract(url)
tables = extract_tables(pages)
enums = extract_enums(tables)
formats, encodings = extract_ins(tables)
pcode = extract_pcode(pages, enums)
# Fix known PDF errors
if arch == 'rdna3':
fixes = {'SOPP': {8: 'S_WAITCNT_DEPCTR', 58: 'S_TTRACEDATA', 59: 'S_TTRACEDATA_IMM'},
'SOPK': {22: 'S_SUBVECTOR_LOOP_BEGIN', 23: 'S_SUBVECTOR_LOOP_END'},
'SMEM': {34: 'S_ATC_PROBE', 35: 'S_ATC_PROBE_BUFFER'},
'DS': {24: 'DS_GWS_SEMA_RELEASE_ALL', 25: 'DS_GWS_INIT', 26: 'DS_GWS_SEMA_V', 27: 'DS_GWS_SEMA_BR', 28: 'DS_GWS_SEMA_P', 29: 'DS_GWS_BARRIER'},
'FLAT': {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}}
for fmt, ops in fixes.items(): enums[fmt] = merge_dicts([enums[fmt], ops])
if arch in ('rdna3', 'rdna4'):
# RDNA SMEM: PDF says DLC=[14], GLC=[16] but hardware uses DLC=[13], GLC=[14]
if 'SMEM' in formats:
formats['SMEM'] = [(n, 13 if n == 'dlc' else 14 if n == 'glc' else h, 13 if n == 'dlc' else 14 if n == 'glc' else l)
for n, h, l in formats['SMEM']]
if arch == 'cdna':
# CDNA DS: PDF is missing the GDS field (bit 16)
if 'DS' in formats and not any(n == 'gds' for n, _, _ in formats['DS']):
formats['DS'].append(('gds', 16, 16))
# CDNA DPP/SDWA: PDF only documents modifier fields (bits[63:32]), need to add VOP overlay fields (bits[31:0])
vop_overlay = [('encoding', 8, 0), ('vop_op', 16, 9), ('vdst', 24, 17), ('vop2_op', 31, 25)]
if 'DPP' in formats and not any(n == 'encoding' for n, _, _ in formats['DPP']):
formats['DPP'] = vop_overlay + [('bc' if n == 'bound_ctrl' else n, h, l) for n, h, l in formats['DPP']]
encodings['DPP'] = '11111010'
if 'SDWA' in formats and not any(n == 'encoding' for n, _, _ in formats['SDWA']):
formats['SDWA'] = vop_overlay + [(n, h, l) for n, h, l in formats['SDWA']]
encodings['SDWA'] = '11111001'
base = pathlib.Path(__file__).parent / "autogen" / arch
write_enums(enums, arch, base / "enum.py")
write_ins(formats, encodings, enums, arch, base / "ins.py")
write_pcode(pcode, enums, arch, base / "str_pcode.py")
print(f" {len(tables)} tables, {len(pcode)} pcode -> {base}")

View File

@@ -1615,7 +1615,7 @@ class TestCarryBorrow(unittest.TestCase):
v_mov_b32_e32(v[2], s[2]),
v_mov_b32_e32(v[3], s[3]),
v_add_co_u32(v[4], VCC, v[0], v[2]),
v_add_co_ci_u32_e32(v[5], VCC, v[1], v[3]),
v_add_co_ci_u32_e32(v[5], v[1], v[3]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][4], 0x00000000, "lo result")

View File

@@ -271,7 +271,7 @@ class TestVOP3P(unittest.TestCase):
s_mov_b32(s[1], 0x44004200), # hi=4.0, lo=3.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_pk_add_f16(v[2], v[0], v[1]),
v_pk_add_f16(v[2], v[0], v[1], opsel_hi=3, opsel_hi2=1),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
@@ -288,7 +288,7 @@ class TestVOP3P(unittest.TestCase):
s_mov_b32(s[1], 0x45004400), # hi=5.0, lo=4.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_pk_mul_f16(v[2], v[0], v[1]),
v_pk_mul_f16(v[2], v[0], v[1], opsel_hi=3, opsel_hi2=1),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
@@ -307,7 +307,7 @@ class TestVOP3P(unittest.TestCase):
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], s[2]),
v_pk_fma_f16(v[3], v[0], v[1], v[2]),
v_pk_fma_f16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][3]
@@ -325,7 +325,7 @@ class TestVOP3P(unittest.TestCase):
instructions = [
s_mov_b32(s[0], 0x3c003c00), # packed f16: hi=1.0, lo=1.0
v_mov_b32_e32(v[0], s[0]),
v_pk_add_f16(v[1], v[0], SrcEnum.POS_ONE), # Add inline constant 1.0
v_pk_add_f16(v[1], v[0], SrcEnum.POS_ONE, opsel_hi=3, opsel_hi2=1), # Add inline constant 1.0
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][1]
@@ -345,7 +345,7 @@ class TestVOP3P(unittest.TestCase):
instructions = [
s_mov_b32(s[0], 0x44004200), # packed f16: hi=4.0, lo=3.0
v_mov_b32_e32(v[0], s[0]),
v_pk_mul_f16(v[1], v[0], SrcEnum.POS_TWO),
v_pk_mul_f16(v[1], v[0], SrcEnum.POS_TWO, opsel_hi=3, opsel_hi2=1),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][1]
@@ -486,12 +486,12 @@ class TestSpecialOps(unittest.TestCase):
"""V_DOT2_F32_BF16 computes dot product of bf16 pairs."""
# bf16 1.0 = 0x3f80, bf16 2.0 = 0x4000
instructions = [
s_mov_b32(s[0], 0x3f803f80), # packed bf16: 1.0, 1.0
s_mov_b32(s[1], 0x40003f80), # packed bf16: 2.0, 1.0
s_mov_b32(s[0], 0x3f803f80), # packed bf16: lo=1.0, hi=1.0
s_mov_b32(s[1], 0x40003f80), # packed bf16: lo=1.0, hi=2.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_mov_b32_e32(v[2], 0),
v_dot2_f32_bf16(v[3], v[0], v[1], v[2]),
v_dot2_f32_bf16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1),
]
st = run_program(instructions, n_lanes=1)
# 1.0*1.0 + 1.0*2.0 + 0 = 3.0
@@ -510,7 +510,7 @@ class TestPackedMixedSigns(unittest.TestCase):
s_mov_b32(s[1], 0x3c003c00), # packed: hi=1.0, lo=1.0
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
v_pk_add_f16(v[2], v[0], v[1]),
v_pk_add_f16(v[2], v[0], v[1], opsel_hi=3, opsel_hi2=1),
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]

View File

@@ -1,192 +1,102 @@
#!/usr/bin/env python3
"""Test RDNA3 assembler/disassembler against LLVM test vectors."""
import unittest, re, subprocess
"""Test AMD assembler/disassembler against LLVM test vectors."""
import unittest, re, subprocess, functools
from tinygrad.helpers import fetch
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.asm import asm
from extra.assembly.amd.asm import asm, disasm, detect_format
from extra.assembly.amd.test.helpers import get_llvm_mc
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
# Format info: (filename, format_class, op_enum)
LLVM_TEST_FILES = {
# Scalar ALU
'sop1': ('gfx11_asm_sop1.s', SOP1, SOP1Op),
'sop2': ('gfx11_asm_sop2.s', SOP2, SOP2Op),
'sopp': ('gfx11_asm_sopp.s', SOPP, SOPPOp),
'sopk': ('gfx11_asm_sopk.s', SOPK, SOPKOp),
'sopc': ('gfx11_asm_sopc.s', SOPC, SOPCOp),
# Vector ALU
'vop1': ('gfx11_asm_vop1.s', VOP1, VOP1Op),
'vop2': ('gfx11_asm_vop2.s', VOP2, VOP2Op),
'vopc': ('gfx11_asm_vopc.s', VOPC, VOPCOp),
'vop3': ('gfx11_asm_vop3.s', VOP3, VOP3Op),
'vop3p': ('gfx11_asm_vop3p.s', VOP3P, VOP3POp),
'vop3sd': ('gfx11_asm_vop3.s', VOP3SD, VOP3SDOp), # VOP3SD shares file with VOP3
'vinterp': ('gfx11_asm_vinterp.s', VINTERP, VINTERPOp),
'vopd': ('gfx11_asm_vopd.s', VOPD, VOPDOp),
'vopcx': ('gfx11_asm_vopcx.s', VOPC, VOPCOp), # VOPCX uses VOPC format
# VOP3 promotions (VOP1/VOP2/VOPC promoted to VOP3 encoding)
'vop3_from_vop1': ('gfx11_asm_vop3_from_vop1.s', VOP3, VOP3Op),
'vop3_from_vop2': ('gfx11_asm_vop3_from_vop2.s', VOP3, VOP3Op),
'vop3_from_vopc': ('gfx11_asm_vop3_from_vopc.s', VOP3, VOP3Op),
'vop3_from_vopcx': ('gfx11_asm_vop3_from_vopcx.s', VOP3, VOP3Op),
# Memory
'ds': ('gfx11_asm_ds.s', DS, DSOp),
'smem': ('gfx11_asm_smem.s', SMEM, SMEMOp),
'flat': ('gfx11_asm_flat.s', FLAT, FLATOp),
'mubuf': ('gfx11_asm_mubuf.s', MUBUF, MUBUFOp),
'mtbuf': ('gfx11_asm_mtbuf.s', MTBUF, MTBUFOp),
'mimg': ('gfx11_asm_mimg.s', MIMG, MIMGOp),
# WMMA (matrix multiply)
'wmma': ('gfx11_asm_wmma.s', VOP3P, VOP3POp),
# Additional features
'vop3_features': ('gfx11_asm_vop3_features.s', VOP3, VOP3Op),
'vop3p_features': ('gfx11_asm_vop3p_features.s', VOP3P, VOP3POp),
'vopd_features': ('gfx11_asm_vopd_features.s', VOPD, VOPDOp),
# Alias files (alternative mnemonics)
'vop3_alias': ('gfx11_asm_vop3_alias.s', VOP3, VOP3Op),
'vop3p_alias': ('gfx11_asm_vop3p_alias.s', VOP3P, VOP3POp),
'vopc_alias': ('gfx11_asm_vopc_alias.s', VOPC, VOPCOp),
'vopcx_alias': ('gfx11_asm_vopcx_alias.s', VOPC, VOPCOp),
'vinterp_alias': ('gfx11_asm_vinterp_alias.s', VINTERP, VINTERPOp),
'smem_alias': ('gfx11_asm_smem_alias.s', SMEM, SMEMOp),
'mubuf_alias': ('gfx11_asm_mubuf_alias.s', MUBUF, MUBUFOp),
'mtbuf_alias': ('gfx11_asm_mtbuf_alias.s', MTBUF, MTBUFOp),
}
RDNA_FILES = ['gfx11_asm_sop1.s', 'gfx11_asm_sop2.s', 'gfx11_asm_sopp.s', 'gfx11_asm_sopk.s', 'gfx11_asm_sopc.s',
'gfx11_asm_vop1.s', 'gfx11_asm_vop2.s', 'gfx11_asm_vopc.s', 'gfx11_asm_vop3.s', 'gfx11_asm_vop3p.s', 'gfx11_asm_vinterp.s',
'gfx11_asm_vopd.s', 'gfx11_asm_vopcx.s', 'gfx11_asm_vop3_from_vop1.s', 'gfx11_asm_vop3_from_vop2.s', 'gfx11_asm_vop3_from_vopc.s',
'gfx11_asm_vop3_from_vopcx.s', 'gfx11_asm_ds.s', 'gfx11_asm_smem.s', 'gfx11_asm_flat.s', 'gfx11_asm_mubuf.s', 'gfx11_asm_mtbuf.s',
'gfx11_asm_mimg.s', 'gfx11_asm_wmma.s', 'gfx11_asm_vop3_features.s', 'gfx11_asm_vop3p_features.s', 'gfx11_asm_vopd_features.s',
'gfx11_asm_vop3_alias.s', 'gfx11_asm_vop3p_alias.s', 'gfx11_asm_vopc_alias.s', 'gfx11_asm_vopcx_alias.s', 'gfx11_asm_vinterp_alias.s',
'gfx11_asm_smem_alias.s', 'gfx11_asm_mubuf_alias.s', 'gfx11_asm_mtbuf_alias.s']
# CDNA test files - includes gfx9 files for shared instructions, plus gfx90a/gfx942 specific files
# gfx90a_ldst_acc.s has MIMG mixed in, filtered via is_mimg check
CDNA_FILES = ['gfx9_asm_sop1.s', 'gfx9_asm_sop2.s', 'gfx9_asm_sopp.s', 'gfx9_asm_sopk.s', 'gfx9_asm_sopc.s',
'gfx9_asm_vop1.s', 'gfx9_asm_vop2.s', 'gfx9_asm_vopc.s', 'gfx9_asm_vop3.s', 'gfx9_asm_vop3p.s',
'gfx9_asm_ds.s', 'gfx9_asm_flat.s', 'gfx9_asm_smem.s', 'gfx9_asm_mubuf.s', 'gfx9_asm_mtbuf.s',
'gfx90a_ldst_acc.s', 'gfx90a_asm_features.s', 'flat-scratch-gfx942.s', 'gfx942_asm_features.s',
'mai-gfx90a.s', 'mai-gfx942.s']
def parse_llvm_tests(text: str) -> list[tuple[str, bytes]]:
"""Parse LLVM test format into (asm, expected_bytes) pairs."""
tests, lines = [], text.split('\n')
for i, line in enumerate(lines):
line = line.strip()
if not line or line.startswith(('//', '.', ';')): continue
asm_text = line.split('//')[0].strip()
if not asm_text: continue
for j in range(i, min(i + 3, len(lines))):
# Match GFX11, W32, or W64 encodings (all valid for gfx11)
# Format 1: "// GFX11: v_foo ... ; encoding: [0x01,0x02,...]"
# Format 2: "// GFX11: [0x01,0x02,...]" (used by DS, older files)
if m := re.search(r'(?:GFX11|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
elif m := re.search(r'(?:GFX11|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
else:
continue
if hex_bytes:
try: tests.append((asm_text, bytes.fromhex(hex_bytes)))
except ValueError: pass
break
def _is_mimg(data: bytes) -> bool: return (int.from_bytes(data[:4], 'little') >> 26) & 0x3f == 0b111100
def _parse_llvm_tests(text: str, pattern: str) -> list[tuple[str, bytes]]:
tests = []
for block in text.split('\n\n'):
asm_text, encoding = None, None
for line in block.split('\n'):
line = line.strip()
if not line or line.startswith(('.', ';')): continue
if not line.startswith('//'):
asm_text = line.split('//')[0].strip() or asm_text
if m := re.search(pattern + r'[^:]*:.*?(?:encoding:\s*)?\[(0x[0-9a-f,x\s]+)\]', line, re.I):
encoding = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
if asm_text and encoding:
try: tests.append((asm_text, bytes.fromhex(encoding)))
except ValueError: pass
return tests
def try_assemble(text: str):
  """Try to assemble instruction text, return bytes or None on failure.

  Args:
    text: a single instruction in assembly syntax, passed to asm().

  Returns:
    The result of asm(text).to_bytes(), or None if assembling or encoding raised an Exception.
  """
  try: return asm(text).to_bytes()
  # Exception rather than a bare except, so KeyboardInterrupt/SystemExit are not swallowed
  except Exception: return None
@functools.cache
def _get_tests(f: str, arch: str) -> list[tuple[str, bytes]]:
  """Fetch LLVM test file f and parse it into (asm, expected_bytes) pairs; cached per (f, arch).

  The check-prefix pattern depends on arch and on the file name. For arch "cdna", pairs whose
  encoding satisfies _is_mimg are dropped.
  """
  text = fetch(f"{LLVM_BASE}/{f}").read_bytes().decode('utf-8', errors='ignore')
  if arch == "rdna3": prefix = r'(?:GFX11|W32|W64)'
  elif 'gfx90a' in f or 'gfx942' in f: prefix = r'(?:GFX90A|GFX942)'
  else: prefix = r'(?:VI9|GFX9|CHECK)'
  tests = _parse_llvm_tests(text, prefix)
  if arch != "cdna": return tests
  return [(a, d) for a, d in tests if not _is_mimg(d)]
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
"""Compile multiple instructions with a single llvm-mc call."""
def _compile_asm_batch(instrs: list[str]) -> list[bytes]:
if not instrs: return []
asm_text = ".text\n" + "\n".join(instrs) + "\n"
result = subprocess.run(
[get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
input=asm_text, capture_output=True, text=True, timeout=30)
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
# Parse all encodings from output
results = []
for line in result.stdout.split('\n'):
if 'encoding:' not in line: continue
enc = line.split('encoding:')[1].strip()
if enc.startswith('[') and enc.endswith(']'):
results.append(bytes.fromhex(enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')))
if len(results) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(results)}")
return results
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True, timeout=30)
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
return [bytes.fromhex(line.split('encoding:')[1].strip()[1:-1].replace('0x', '').replace(',', '').replace(' ', ''))
for line in result.stdout.split('\n') if 'encoding:' in line]
class TestLLVM(unittest.TestCase):
"""Test assembler and disassembler against all LLVM test vectors."""
tests: dict[str, list[tuple[str, bytes]]] = {}
@classmethod
def setUpClass(cls):
for name, (filename, _, _) in LLVM_TEST_FILES.items():
try:
data = fetch(f"{LLVM_BASE}/{filename}").read_bytes()
cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'))
except Exception as e:
print(f"Warning: couldn't fetch {filename}: {e}")
cls.tests[name] = []
# Generate test methods dynamically for each format
def _make_asm_test(name):
def _make_test(f: str, arch: str, test_type: str):
def test(self):
passed, failed, skipped = 0, 0, 0
for asm_text, expected in self.tests.get(name, []):
result = try_assemble(asm_text)
if result is None: skipped += 1
elif result == expected: passed += 1
else: failed += 1
print(f"{name.upper()} asm: {passed} passed, {failed} failed, {skipped} skipped")
self.assertEqual(failed, 0)
tests = _get_tests(f, arch)
name = f"{arch}_{test_type}_{f}"
if test_type == "roundtrip":
for _, data in tests:
decoded = detect_format(data, arch).from_bytes(data)
self.assertEqual(decoded.to_bytes()[:len(data)], data)
print(f"{name}: {len(tests)} passed")
elif test_type == "asm":
passed, skipped = 0, 0
for asm_text, expected in tests:
try:
self.assertEqual(asm(asm_text).to_bytes(), expected)
passed += 1
except: skipped += 1
print(f"{name}: {passed} passed, {skipped} skipped")
elif test_type == "disasm":
to_test = []
for _, data in tests:
try:
decoded = detect_format(data, arch).from_bytes(data)
if decoded.to_bytes()[:len(data)] == data and (d := disasm(decoded)): to_test.append((data, d))
except: pass
print(f"{name}: {len(to_test)} passed, {len(tests) - len(to_test)} skipped")
if arch == "rdna3":
for (data, _), llvm in zip(to_test, _compile_asm_batch([t[1] for t in to_test])): self.assertEqual(llvm, data)
return test
def _make_disasm_test(name):
def test(self):
_, base_fmt_cls, base_op_enum = LLVM_TEST_FILES[name]
# VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions)
vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx')
class TestLLVM(unittest.TestCase): pass
# First pass: decode all instructions and collect disasm strings
to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error)
for asm_text, data in self.tests.get(name, []):
# Detect VOP3 promotions in VOP1/VOP2/VOPC tests: VOP3 has bits [31:26]=0b110101 in first dword
is_vop3_enc = name in ('vop1', 'vop2', 'vopc', 'vopcx') and len(data) >= 4 and (data[3] >> 2) == 0x35
fmt_cls, op_enum = (VOP3, VOP3Op) if is_vop3_enc else (base_fmt_cls, base_op_enum)
try:
if base_fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
temp = VOP3.from_bytes(data)
op_val = temp._values.get('op', 0)
op_val = op_val.val if hasattr(op_val, 'val') else op_val
is_vop3sd = (op_val in vop3sd_opcodes) and not is_vopc_promotion
decoded = VOP3SD.from_bytes(data) if is_vop3sd else VOP3.from_bytes(data)
if is_vop3sd: VOP3SDOp(op_val)
else: VOP3Op(op_val)
else:
decoded = fmt_cls.from_bytes(data)
op_val = decoded._values.get('op', 0)
op_val = op_val.val if hasattr(op_val, 'val') else op_val
op_enum(op_val)
if decoded.to_bytes()[:len(data)] != data:
to_test.append((asm_text, data, None, "decode roundtrip failed"))
continue
to_test.append((asm_text, data, decoded.disasm(), None))
except Exception as e:
to_test.append((asm_text, data, None, f"exception: {e}"))
# Batch compile all disasm strings with single llvm-mc call
disasm_strs = [(i, t[2]) for i, t in enumerate(to_test) if t[2] is not None]
llvm_results = compile_asm_batch([s for _, s in disasm_strs]) if disasm_strs else []
llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)}
# Match results back
passed, failed = 0, 0
failures: list[str] = []
for idx, (asm_text, data, disasm_str, error) in enumerate(to_test):
if error:
failed += 1; failures.append(f"{error} for {data.hex()}")
elif disasm_str is not None and idx in llvm_map:
llvm_bytes = llvm_map[idx]
if llvm_bytes is not None and llvm_bytes == data: passed += 1
elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
print(f"{name.upper()} disasm: {passed} passed, {failed} failed")
if failures[:10]: print(" " + "\n ".join(failures[:10]))
self.assertEqual(failed, 0)
return test
for name in LLVM_TEST_FILES:
setattr(TestLLVM, f'test_{name}_asm', _make_asm_test(name))
setattr(TestLLVM, f'test_{name}_disasm', _make_disasm_test(name))
for f in RDNA_FILES:
setattr(TestLLVM, f"test_rdna3_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "roundtrip"))
setattr(TestLLVM, f"test_rdna3_asm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "asm"))
setattr(TestLLVM, f"test_rdna3_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "disasm"))
for f in CDNA_FILES:
setattr(TestLLVM, f"test_cdna_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "cdna", "roundtrip"))
setattr(TestLLVM, f"test_cdna_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "cdna", "disasm"))
if __name__ == "__main__":
unittest.main()

View File

@@ -1,144 +0,0 @@
#!/usr/bin/env python3
"""Test CDNA assembler/disassembler against LLVM test vectors."""
import unittest, re, subprocess
from tinygrad.helpers import fetch
from extra.assembly.amd.autogen.cdna.ins import *
from extra.assembly.amd.asm import disasm
from extra.assembly.amd.test.helpers import get_llvm_mc
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
def parse_llvm_tests(text: str, mnemonic_filter: str = None, size_filter: int = None) -> list[tuple[str, bytes]]:
"""Parse LLVM test format into (asm, expected_bytes) pairs."""
tests, lines = [], text.split('\n')
for i, line in enumerate(lines):
line = line.strip()
if not line or line.startswith(('//', '.', ';')): continue
asm_text = line.split('//')[0].strip()
if not asm_text or (mnemonic_filter and not asm_text.startswith(mnemonic_filter)): continue
for j in list(range(max(0, i - 3), i)) + list(range(i, min(i + 3, len(lines)))):
if m := re.search(r'(?:VI9|GFX9|CHECK)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
elif m := re.search(r'CHECK[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
else: continue
try:
data = bytes.fromhex(hex_bytes)
if size_filter is None or len(data) == size_filter: tests.append((asm_text, data))
except ValueError: pass
break
return tests
# Use gfx9 tests for compatible scalar/vector formats and gfx90a/gfx942 tests for CDNA-specific instructions
# Format: (filename, format_class, op_enum, mcpu, mnemonic_filter, size_filter)
CDNA_TEST_FILES = {
# Scalar ALU - encoding is stable across GFX9/CDNA
'sop1': ('gfx9_asm_sop1.s', SOP1, SOP1Op, 'gfx940', None, None),
'sop2': ('gfx9_asm_sop2.s', SOP2, SOP2Op, 'gfx940', None, None),
'sopp': ('gfx9_asm_sopp.s', SOPP, SOPPOp, 'gfx940', None, None),
'sopp_gfx9': ('sopp-gfx9.s', SOPP, SOPPOp, 'gfx940', None, None),
'sopk': ('gfx9_asm_sopk.s', SOPK, SOPKOp, 'gfx940', None, None),
'sopc': ('gfx9_asm_sopc.s', SOPC, SOPCOp, 'gfx940', None, None),
# Vector ALU - encoding is mostly stable
'vop1': ('gfx9_asm_vop1.s', VOP1, VOP1Op, 'gfx940', None, None),
'vop1_gfx9': ('vop1-gfx9.s', VOP1, VOP1Op, 'gfx940', None, None),
'vop2': ('gfx9_asm_vop2.s', VOP2, VOP2Op, 'gfx940', None, None),
'vopc': ('gfx9_asm_vopc.s', VOPC, VOPCOp, 'gfx940', None, None),
'vop3p': ('gfx9_asm_vop3p.s', VOP3P, VOP3POp, 'gfx940', None, None),
'vop3_gfx9': ('vop3-gfx9.s', VOP3A, VOP3AOp, 'gfx940', None, 8), # Only 64-bit VOP3 instructions
# Memory instructions
'ds': ('gfx9_asm_ds.s', DS, DSOp, 'gfx940', None, None),
'ds_gfx9': ('ds-gfx9.s', DS, DSOp, 'gfx940', None, None),
# CDNA memory instructions (gfx90a has correct FLAT/MUBUF encodings with acc registers)
'flat_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'flat_', None),
'global_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'global_', None),
'mubuf_gfx90a': ('gfx90a_ldst_acc.s', MUBUF, MUBUFOp, 'gfx90a', 'buffer_', None),
'mubuf_gfx9': ('mubuf-gfx9.s', MUBUF, MUBUFOp, 'gfx940', None, None),
'scratch_gfx942': ('flat-scratch-gfx942.s', FLAT, FLATOp, 'gfx942', 'scratch_', None),
# CDNA-specific: MFMA/MAI instructions
'mai': ('mai-gfx942.s', VOP3P, VOP3POp, 'gfx942', None, None),
# SDWA and DPP format tests for VOP1 (VOP2 has different bit layout, tested separately)
'sdwa_vop1': ('gfx9_asm_vop1.s', SDWA, VOP1Op, 'gfx940', None, None),
'dpp_vop1': ('gfx9_asm_vop1.s', DPP, VOP1Op, 'gfx940', None, None),
}
class TestLLVMCDNA(unittest.TestCase):
  """Roundtrip and disassembly checks for CDNA instruction formats, driven by LLVM test vectors."""
  # parsed (asm_text, bytes) vectors per CDNA_TEST_FILES key; populated once by setUpClass
  tests: dict[str, list[tuple[str, bytes]]] = {}

  @classmethod
  def setUpClass(cls):
    """Fetch and parse each file in CDNA_TEST_FILES; a failing entry is reported and left empty."""
    for key, spec in CDNA_TEST_FILES.items():
      fname, _, _, _, prefix, nbytes = spec
      try:
        raw = fetch(f"{LLVM_BASE}/{fname}").read_bytes()
        cls.tests[key] = parse_llvm_tests(raw.decode('utf-8', errors='ignore'), prefix, nbytes)
      except Exception as err:
        print(f"Warning: couldn't fetch {fname}: {err}")
        cls.tests[key] = []
def _get_val(v): return v.val if hasattr(v, 'val') else v
def _filter_and_decode(tests, fmt_cls, op_enum):
"""Filter tests and decode instructions, yielding (asm_text, data, decoded, error)."""
# fn is the format class name; is_sdwa / is_dpp flag the two formats that get their own filter below
fn, is_sdwa, is_dpp = fmt_cls.__name__, fmt_cls.__name__ == 'SDWA', fmt_cls.__name__ == 'DPP'
for asm_text, data in tests:
has_lit = False
# SDWA/DPP format tests: only accept matching 8-byte instructions
if is_sdwa:
# SDWA vectors must start with byte 0xf9
if len(data) != 8 or data[0] != 0xf9: continue
elif is_dpp:
# DPP vectors must start with byte 0xfa
if len(data) != 8 or data[0] != 0xfa: continue
# a format whose own size is 4 bytes, paired with 8 bytes of test data
elif fmt_cls._size() == 4 and len(data) == 8:
if data[0] in (0xf9, 0xfa): continue # Skip SDWA/DPP (tested separately)
# has_lit: first byte is 255, or (SOP2/SOPC only) second byte is 255
# NOTE(review): presumably 255 marks a trailing 32-bit literal operand - confirm against the ISA docs
has_lit = data[0] == 255 or (len(data) >= 2 and data[1] == 255 and fn in ('SOP2', 'SOPC'))
# SOPK: bits [27:23] of the first dword equal to 20 also count as carrying a literal
if fn == 'SOPK': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 23) & 0x1f) == 20
# VOP2: bits [30:25] of the first dword in (23, 24, 36, 37) also count as carrying a literal
if fn == 'VOP2': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 25) & 0x3f) in (23, 24, 36, 37)
if not has_lit: continue
# drop vectors longer than the format size, allowing 4 extra bytes when has_lit is set
if len(data) > fmt_cls._size() + (4 if has_lit else 0): continue
try:
decoded = fmt_cls.from_bytes(data)
# For SDWA/DPP, opcode location depends on VOP1 vs VOP2
if is_sdwa or is_dpp:
vop2_op = _get_val(decoded._values.get('vop2_op', 0))
# vop2_op == 0x3f selects the 'vop_op' field as the opcode; otherwise vop2_op itself is the opcode
op_val = _get_val(decoded._values.get('vop_op', 0)) if vop2_op == 0x3f else vop2_op
else:
op_val = _get_val(decoded._values.get('op', 0))
# opcodes that are not members of op_enum are skipped silently (not reported as an error)
try: op_enum(op_val)
except ValueError: continue
yield asm_text, data, decoded, None
except Exception as e:
# any other failure is reported to the caller as (decoded=None, error=str(e))
yield asm_text, data, None, str(e)
def _make_roundtrip_test(name):
  """Build the roundtrip test method for the CDNA_TEST_FILES entry `name`.

  The returned method decodes every vector that `_filter_and_decode` lets through, re-encodes it and
  compares the first len(data) bytes with the original. It prints a summary (and up to five failure
  descriptions) and fails when at least one vector errored or re-encoded differently.
  """
  def test(self):
    _, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name]
    passed, failures = 0, []
    for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum):
      # A failed decode is yielded with decoded=None. Testing `error` for truthiness instead would let
      # an exception with an empty message (e.g. a bare AssertionError) fall through to to_bytes().
      if decoded is None:
        failures.append(f"'{asm_text}': {error or 'decode failed without a message'}")
        continue
      reenc = decoded.to_bytes()[:len(data)]
      if reenc == data: passed += 1
      else: failures.append(f"'{asm_text}': orig={data.hex()} reenc={reenc.hex()}")
    failed = len(failures)
    print(f"CDNA {name.upper()} roundtrip: {passed} passed, {failed} failed")
    if failures[:5]: print(" " + "\n ".join(failures[:5]))
    self.assertEqual(failed, 0)
  return test
def _make_disasm_test(name):
  """Build the disassembly test method for the CDNA_TEST_FILES entry `name`.

  The returned method requires every vector that `_filter_and_decode` lets through to decode, to
  re-encode to the same bytes, and to produce non-blank text from `disasm`. It prints a summary (and
  up to five failure descriptions) and fails when at least one vector did not meet all three.
  """
  def test(self):
    _, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name]
    passed, failures = 0, []
    for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum):
      # A failed decode is yielded with decoded=None. Testing `error` for truthiness instead would let
      # an exception with an empty message (e.g. a bare AssertionError) fall through to to_bytes().
      if decoded is None:
        failures.append(f"'{asm_text}': {error or 'decode failed without a message'}")
      elif decoded.to_bytes()[:len(data)] != data:
        failures.append(f"'{asm_text}': roundtrip failed")
      elif not (disasm_text := disasm(decoded)) or not disasm_text.strip():
        failures.append(f"'{asm_text}': empty disassembly")
      else:
        passed += 1
    failed = len(failures)
    print(f"CDNA {name.upper()} disasm: {passed} passed, {failed} failed")
    if failures[:5]: print(" " + "\n ".join(failures[:5]))
    self.assertEqual(failed, 0)
  return test
# Attach one roundtrip and one disasm test method to TestLLVMCDNA per CDNA_TEST_FILES entry.
# Each factory call receives `name` as an argument, so every generated method keeps its own entry name.
for name in CDNA_TEST_FILES:
setattr(TestLLVMCDNA, f'test_{name}_roundtrip', _make_roundtrip_test(name))
setattr(TestLLVMCDNA, f'test_{name}_disasm', _make_disasm_test(name))
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,50 @@
#!/usr/bin/env python3
"""Test pdf.py PDF parser and enum generation."""
import unittest, tempfile, importlib.util
from extra.assembly.amd.pdf import extract, extract_tables, extract_enums, write_enums, PDF_URLS
# Reference figures per document, compared by TestPDF2 against what the parser produces:
# "pages" = number of extracted pages, "tables" = number of extracted tables,
# "sop2_ops" = member count of the generated SOP2Op enum, "sop2_first" = name of SOP2Op(0).
EXPECTED = {
"rdna3": {"pages": 655, "tables": 115, "sop2_ops": 67, "sop2_first": "S_ADD_U32"},
"rdna4": {"pages": 711, "tables": 125, "sop2_ops": 74, "sop2_first": "S_ADD_CO_U32"},
"cdna": {"pages": 610, "tables": 104, "sop2_ops": 52, "sop2_first": "S_ADD_U32"},
}
class TestPDF2(unittest.TestCase):
  """Checks the PDF extraction pipeline and the enum modules generated from it, for every entry of PDF_URLS."""

  @classmethod
  def setUpClass(cls):
    """Run extract -> extract_tables -> extract_enums once per document and keep the results on the class."""
    cls.data = {name: extract(url) for name, url in PDF_URLS.items()}
    cls.tables = {name: extract_tables(pages) for name, pages in cls.data.items()}
    cls.enums = {name: extract_enums(cls.tables[name]) for name in PDF_URLS}

  def test_page_counts(self):
    """The number of extracted pages matches EXPECTED for each document."""
    for name, exp in EXPECTED.items():
      self.assertEqual(len(self.data[name]), exp["pages"], f"{name} page count")

  def test_table_counts(self):
    """The number of extracted tables matches EXPECTED for each document."""
    for name, exp in EXPECTED.items():
      self.assertEqual(len(self.tables[name]), exp["tables"], f"{name} table count")

  def test_tables_sequential(self):
    """Table numbers form a gap-free run starting at 1."""
    for name in PDF_URLS:
      nums = sorted(self.tables[name].keys())
      # fail with a readable message instead of letting max() raise on an empty sequence
      self.assertTrue(nums, f"{name} has no tables")
      missing = set(range(1, max(nums) + 1)) - set(nums)
      self.assertEqual(missing, set(), f"{name} missing tables: {missing}")

  def test_generate_enums(self):
    """write_enums output can be imported and contains plausible enums."""
    for name, exp in EXPECTED.items():
      # A temporary directory is removed on exit together with the generated source and any bytecode
      # cache written next to it; NamedTemporaryFile(delete=False) left one file behind per document.
      with tempfile.TemporaryDirectory() as tmpdir:
        path = f"{tmpdir}/enums_{name}.py"
        write_enums(self.enums[name], name, path)
        spec = importlib.util.spec_from_file_location("enum", path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
      # Check SOP2Op
      self.assertTrue(hasattr(mod, 'SOP2Op'), f"{name} missing SOP2Op")
      self.assertEqual(len(mod.SOP2Op), exp["sop2_ops"], f"{name} SOP2Op count")
      self.assertEqual(mod.SOP2Op(0).name, exp["sop2_first"], f"{name} SOP2Op first")
      # Check all enums have at least 2 ops
      for attr in dir(mod):
        if attr.endswith('Op'):
          self.assertGreaterEqual(len(getattr(mod, attr)), 2, f"{name} {attr} has too few ops")
if __name__ == "__main__":
unittest.main()

View File

@@ -1,150 +0,0 @@
#!/usr/bin/env python3
"""Test that PDF parser correctly extracts format fields."""
import unittest, os
from extra.assembly.amd.autogen.rdna3.ins import SOP1, SOP2, SOPK, SOPP, VOP1, VOP2, VOP3SD, VOPC, FLAT, VOPD, SOP1Op, SOP2Op, VOP1Op, VOP3Op
# expected formats with key fields and whether they have ENCODING
EXPECTED_FORMATS = {
'DPP16': (['SRC0', 'DPP_CTRL', 'BANK_MASK', 'ROW_MASK'], False),
'DPP8': (['SRC0', 'LANE_SEL0', 'LANE_SEL7'], False),
'DS': (['OP', 'ADDR', 'DATA0', 'DATA1', 'VDST'], True),
'EXP': (['EN', 'TARGET', 'VSRC0', 'VSRC1', 'VSRC2', 'VSRC3'], True),
'FLAT': (['OP', 'ADDR', 'DATA', 'SADDR', 'VDST', 'OFFSET'], True),
'LDSDIR': (['VDST', 'OP'], True),
'MIMG': (['OP', 'VADDR', 'VDATA', 'SRSRC', 'DMASK'], True),
'MTBUF': (['OP', 'VADDR', 'VDATA', 'SRSRC', 'FORMAT', 'SOFFSET'], True),
'MUBUF': (['OP', 'VADDR', 'VDATA', 'SRSRC', 'SOFFSET'], True),
'SMEM': (['OP', 'SBASE', 'SDATA', 'OFFSET', 'SOFFSET'], True),
'SOP1': (['OP', 'SDST', 'SSRC0'], True),
'SOP2': (['OP', 'SDST', 'SSRC0', 'SSRC1'], True),
'SOPC': (['OP', 'SSRC0', 'SSRC1'], True),
'SOPK': (['OP', 'SDST', 'SIMM16'], True),
'SOPP': (['OP', 'SIMM16'], True),
'VINTERP': (['OP', 'VDST', 'SRC0', 'SRC1', 'SRC2'], True),
'VOP1': (['OP', 'VDST', 'SRC0'], True),
'VOP2': (['OP', 'VDST', 'SRC0', 'VSRC1'], True),
'VOP3': (['OP', 'VDST', 'SRC0', 'SRC1', 'SRC2'], True),
'VOP3P': (['OP', 'VDST', 'SRC0', 'SRC1', 'SRC2'], True),
'VOP3SD': (['OP', 'VDST', 'SDST', 'SRC0', 'SRC1', 'SRC2'], True),
'VOPC': (['OP', 'SRC0', 'VSRC1'], True),
'VOPD': (['OPX', 'OPY', 'SRCX0', 'SRCY0', 'VDSTX', 'VDSTY'], True),
}
# Skip PDF parsing tests by default - only run with TEST_PDF_PARSER=1
# These are slow (~5s) and only needed when regenerating autogen/
@unittest.skipUnless(os.environ.get("TEST_PDF_PARSER"), "set TEST_PDF_PARSER=1 to run PDF parser tests")
class TestPDFParserGenerate(unittest.TestCase):
  """Runs generate() and validates the formats it reports."""

  def test_pdf_parser(self):
    """One combined check covering every output of the PDF parser."""
    from extra.assembly.amd.dsl import generate
    result = generate()
    formats = result["formats"]

    def field_set(fmt): return {entry[0] for entry in formats.get(fmt, [])}

    # every expected format was found
    for fmt in EXPECTED_FORMATS:
      self.assertIn(fmt, formats, f"missing format {fmt}")

    # exact number of formats
    self.assertEqual(len(formats), 23)

    # no format lists a field twice
    for fmt, entries in formats.items():
      listed = [entry[0] for entry in entries]
      self.assertEqual(len(listed), len(set(listed)), f"{fmt} has duplicate fields: {listed}")

    # key fields are present, and ENCODING is present exactly where expected
    for fmt, (wanted, has_encoding) in EXPECTED_FORMATS.items():
      present = field_set(fmt)
      for field in wanted:
        self.assertIn(field, present, f"{fmt} missing {field}")
      if has_encoding: self.assertIn("ENCODING", present, f"{fmt} should have ENCODING")
      else: self.assertNotIn("ENCODING", present, f"{fmt} should not have ENCODING")

    # fields belonging to other formats must not appear in VOPD, DPP16 and SOPP
    foreign = (
      ("VOPD", ['DPP_CTRL', 'BANK_MASK', 'ROW_MASK']),
      ("DPP16", ['VDST', 'WAITEXP']),
      ("SOPP", ['SBASE', 'SDATA']),
    )
    for fmt, banned in foreign:
      present = field_set(fmt)
      for field in banned:
        self.assertNotIn(field, present, f"{fmt} should not have {field}")
class TestPDFParser(unittest.TestCase):
  """Checks that the generated format classes expose the expected fields, bit ranges and encodings."""

  def _expect_fields(self, fmt, names):
    # every listed name must be a key of fmt._fields
    for nm in names:
      self.assertIn(nm, fmt._fields)

  def _expect_span(self, bits, hi, lo):
    # compare the upper and lower bound of one bit range
    self.assertEqual(bits.hi, hi)
    self.assertEqual(bits.lo, lo)

  def test_sop2_fields(self):
    """SOP2 has op, sdst, ssrc0 and ssrc1; op covers bits 29..23."""
    self._expect_fields(SOP2, ['op', 'sdst', 'ssrc0', 'ssrc1'])
    self._expect_span(SOP2._fields['op'], 29, 23)

  def test_sop1_fields(self):
    """SOP1 has op, sdst and ssrc0 (no simm16), ssrc0 in bits 7..0, and the expected encoding."""
    self._expect_fields(SOP1, ['op', 'sdst', 'ssrc0'])
    self.assertNotIn('simm16', SOP1._fields)
    self._expect_span(SOP1._fields['ssrc0'], 7, 0)
    assert SOP1._encoding is not None
    self.assertEqual(SOP1._encoding[0].hi, 31)
    self.assertEqual(SOP1._encoding[1], 0b101111101)

  def test_vop3sd_fields(self):
    """VOP3SD has every field, including src0/src1/src2 from the page continuation, and is 8 bytes."""
    self._expect_fields(VOP3SD, ['op', 'vdst', 'sdst', 'src0', 'src1', 'src2'])
    self._expect_span(VOP3SD._fields['src0'], 40, 32)
    self.assertEqual(VOP3SD._size(), 8)

  def test_flat_has_vdst(self):
    """FLAT has a vdst field in bits 63..56."""
    self.assertIn('vdst', FLAT._fields)
    self._expect_span(FLAT._fields['vdst'], 63, 56)

  def test_encoding_bits(self):
    """The encoding bit range and value are correct for the major formats."""
    expected = [
      (SOP2, 31, 30, 0b10),
      (SOPK, 31, 28, 0b1011),
      (SOPP, 31, 23, 0b101111111),
      (VOP1, 31, 25, 0b0111111),
      (VOP2, 31, 31, 0b0),
      (VOPC, 31, 25, 0b0111110),
      (FLAT, 31, 26, 0b110111),
    ]
    for fmt, hi, lo, val in expected:
      assert fmt._encoding is not None
      bits, enc_val = fmt._encoding[0], fmt._encoding[1]
      self.assertEqual(bits.hi, hi, f"{fmt.__name__} encoding hi")
      self.assertEqual(bits.lo, lo, f"{fmt.__name__} encoding lo")
      self.assertEqual(enc_val, val, f"{fmt.__name__} encoding val")

  def test_opcode_enums_exist(self):
    """The opcode enums were generated with at least the expected number of members."""
    for enum_cls in (SOP1Op, SOP2Op, VOP1Op):
      self.assertGreater(len(enum_cls), 50)
    self.assertGreater(len(VOP3Op), 200)

  def test_vopd_no_duplicate_fields(self):
    """VOPD has no duplicate fields, has its X/Y fields, and has none of the DPP16 fields."""
    keys = list(VOPD._fields.keys())
    self.assertEqual(len(keys), len(set(keys)))
    self._expect_fields(VOPD, ['srcx0', 'srcy0', 'opx', 'opy'])
    for nm in ['dpp_ctrl', 'bank_mask', 'row_mask']:
      self.assertNotIn(nm, VOPD._fields)
if __name__ == "__main__":
unittest.main()

38
extra/viz/cli.py Executable file
View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python3
import argparse, pathlib
from typing import Iterator
from tinygrad.viz import serve as viz
from tinygrad.uop.ops import RewriteTrace
from tinygrad.helpers import temp, ansistrip, colored
def optional_eq(val:dict, arg:str|None) -> bool: return arg is None or ansistrip(val["name"]) == arg
def print_data(data:dict) -> None:
  """Print the rewrite matches in data["value"] (only when it is an iterator), then data["src"] if present.

  For every match with a non-empty diff this prints "<parent dir>/<file>:<line>" taken from
  m["upat"][0], then m["upat"][1], then the diff lines coloured red for "-" and green for "+".
  """
  matches = data.get("value")
  if isinstance(matches, Iterator):
    for m in matches:
      if not m["diff"]: continue
      location = m["upat"][0]
      fp = pathlib.Path(location[0])
      print(f"{fp.parent.name}/{fp.name}:{location[1]}")
      print(m["upat"][1])
      for line in m["diff"]:
        if line.startswith("-"): color = "red"
        elif line.startswith("+"): color = "green"
        else: color = None
        print(colored(line, color))
  src = data.get("src")
  if src is not None: print(src)
if __name__ == "__main__":
  # Without --kernel only the kernel names are listed; with --kernel the steps of the matching
  # kernels are listed; with --select in addition, the rendered data of the matching steps is printed.
  cli = argparse.ArgumentParser()
  cli.add_argument('--kernel', type=str, default=None, metavar="NAME", help='Select a kernel by name (optional name, default: only list names)')
  cli.add_argument('--select', type=str, default=None, metavar="NAME",
                   help='Select an item within the chosen kernel (optional name, default: only list names)')
  args = cli.parse_args()

  # load the pickled rewrite trace (an empty RewriteTrace when the file is missing) into the viz module
  trace_path = pathlib.Path(temp("rewrites.pkl", append_user=True))
  viz.trace = viz.load_pickle(trace_path, default=RewriteTrace([], [], {}))
  viz.ctxs = viz.get_rewrites(viz.trace)

  for ctx in viz.ctxs:
    if optional_eq(ctx, args.kernel):
      print(ctx["name"])
      if args.kernel is not None:
        for step in ctx["steps"]:
          if optional_eq(step, args.select):
            count = step.get('match_count')
            suffix = '' if count is None else f" - {count}"
            print(" "*step["depth"]+step['name']+suffix)
            if args.select is not None: print_data(viz.get_render(step['query']))

View File

@@ -64,8 +64,6 @@ backend_test.exclude('test_qlinearmatmul_2D_int8_float32_cpu')
backend_test.exclude('test_qlinearmatmul_3D_int8_float32_cpu')
# tested in external_test_onnx_ops.py::TestMainOnnxOps.test_maxunpool_export_with_output_shape
backend_test.exclude('test_maxunpool_export_with_output_shape_cpu')
# tested in external_test_onnx_ops.py::TestMainOnnxOps.test_averagepool_3d_dilations_large_count_include_pad_is_1_ceil_mode_is_True
backend_test.exclude('test_averagepool_3d_dilations_large_count_include_pad_is_1_ceil_mode_is_True_cpu')
# tested in external_test_onnx_ops.py::TestMainOnnxOps.test_resize_downsample_scales_linear_align_corners
backend_test.exclude('test_resize_downsample_scales_linear_align_corners_cpu')
# tested in external_test_onnx_ops.py::TestMainOnnxOps.test_resize_downsample_scales_cubic_align_corners
@@ -151,8 +149,6 @@ backend_test.exclude('test_hannwindow_*')
backend_test.exclude('test_hardmax_*')
backend_test.exclude('test_gridsample_*')
backend_test.exclude('test_dft_*')
backend_test.exclude('test_einsum_batch_diagonal_cpu*') # TODO: equation = '...ii ->...i'
backend_test.exclude('test_einsum_inner_prod_cpu*') # TODO: equation = 'i,i'
backend_test.exclude('test_unique_*')
backend_test.exclude('test_sequence_*')
backend_test.exclude('test_nonmaxsuppression_*')
@@ -171,16 +167,14 @@ backend_test.exclude('test_split_to_sequence_*')
backend_test.exclude('test_ai_onnx_ml_tree_ensemble_*') # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/aionnxml/op_tree_ensemble.py#L121
# TODO: not yet implemented
backend_test.exclude('test_tensorscatter_*')
backend_test.exclude('test_l1normalization_*')
backend_test.exclude('test_l2normalization_*')
backend_test.exclude('test_lpnormalization_*')
backend_test.exclude('test_einsum_scalar_cpu')
backend_test.exclude('test_mod_mixed_sign_float16_cpu')
backend_test.exclude('test_qlinearmatmul_2D_uint8_float16_cpu')
backend_test.exclude('test_qlinearmatmul_3D_uint8_float16_cpu')
backend_test.exclude('test_attention_3d_*')
backend_test.exclude('test_attention_4d_*')
backend_test.exclude('test_attention_4d_diff_heads_mask4d_padded_kv_cpu') # needs nonpad_kv_seqlen handling
backend_test.exclude('test_attention_4d_fp16_cpu') # fp16 numerical issues
backend_test.exclude('test_attention_4d_fp16_expanded_cpu') # fp16 numerical issues
backend_test.exclude('test_attention_4d_gqa_with_past_and_present_fp16_cpu') # fp16 numerical issues
backend_test.exclude('test_attention_4d_gqa_with_past_and_present_fp16_expanded_cpu') # fp16 numerical issues
# rest of the failing tests
@@ -197,16 +191,6 @@ backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad d
backend_test.exclude('test_if_opt_cpu') # ValueError: 13 is not a valid AttributeType
backend_test.exclude('test_if_seq_cpu') # NotImplementedError: op='SequenceConstruct' is not supported
# regression from removing StrEnum in Domain
backend_test.exclude('test_adam_cpu')
backend_test.exclude('test_gradient_of_add_and_mul_cpu')
backend_test.exclude('test_gradient_of_add_cpu')
if Device.DEFAULT in ['CL', 'METAL']:
backend_test.exclude('test_resize_upsample_sizes_nearest_axes_2_3_cpu')
backend_test.exclude('test_resize_upsample_sizes_nearest_axes_3_2_cpu')
backend_test.exclude('test_resize_upsample_sizes_nearest_cpu')
if Device.DEFAULT == "METAL" or (OSX and Device.DEFAULT == "CL"):
# numerical inaccuracy
backend_test.exclude('test_mish_cpu')

View File

@@ -1171,6 +1171,8 @@ class TestOps(unittest.TestCase):
@slow_test
def test_einsum(self):
# scalar
helper_test_op([()], lambda a: torch.einsum('->', a), lambda a: Tensor.einsum('->', a))
# matrix transpose
helper_test_op([(10,10)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a))
helper_test_op([(10,10)], lambda a: torch.einsum('ij -> ji', a), lambda a: Tensor.einsum('ij -> ji', a))
@@ -1239,6 +1241,18 @@ class TestOps(unittest.TestCase):
self.helper_test_exception([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]), expected=RuntimeError)
def test_einsum_trace(self):
  """Compare Tensor.einsum with torch.einsum on inner-product, diagonal and trace equations."""
  cases = [
    ([(5,), (5,)], 'i,i'),          # inner product
    ([(4, 4)], 'ii->i'),            # simple diagonal
    ([(4, 4)], 'ii->'),             # trace (sum of diagonal)
    ([(3, 5, 5)], '...ii->...i'),   # batch diagonal
    ([(3, 5, 5)], '...ii->...'),    # batch trace
  ]
  for shapes, equation in cases:
    helper_test_op(shapes, lambda *ts: torch.einsum(equation, *ts), lambda *ts: Tensor.einsum(equation, *ts))
def test_einsum_shape_check(self):
self.helper_test_exception([(3,8,10,5), (11,5,13,16,8)], lambda a, b: torch.einsum('pqrs,tuqvr->pstuv', [a, b]),
lambda a, b: Tensor.einsum('pqrs,tuqvr->pstuv', [a, b]), expected=RuntimeError)

View File

@@ -807,7 +807,7 @@ class TestTK(unittest.TestCase):
Tensor.manual_seed(42)
B, N, H, H_KV, D = 1, 32, 2, 1, 32
B, N, H, H_KV, D = 1, 1024, 32, 32, 128
with Context(DEBUG=0):
q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
@@ -840,5 +840,57 @@ class TestTK(unittest.TestCase):
np.testing.assert_allclose(v.grad.numpy(), v_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
# Causal flash-attention backward pass run through TinyJit, compared against the gradients of
# scaled_dot_product_attention. The decorator below marks the test as an expected failure.
@unittest.expectedFailure
def test_fast_fa_bwd_causal_jitted(self):
from extra.thunder.tiny.fa import flash_attention
Tensor.manual_seed(42)
B, N, H, H_KV, D = 1, 1024, 32, 32, 128
with Context(DEBUG=0):
# q is created as (B, N, H, D); k and v as (B, N, H_KV, D); all bfloat16 with gradients enabled
q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
Tensor.realize(q, k, v)
# upstream gradient: all ones, float32, same shape as q
do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
Tensor.realize(do)
# forward + backward in one function so that both are captured by the JIT
def fn(q, k, v, do):
# axes 1 and 2 are swapped before the kernel and swapped back on its output
q_, k_, v_ = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
out = flash_attention(q_, k_, v_, is_causal=True)
out = out.float().transpose(1, 2)
out.backward(do)
Tensor.realize(out, q.grad, k.grad, v.grad)
return q.grad, k.grad, v.grad
fn_jitted = TinyJit(fn)
# call the jitted function 10 times, each time on freshly drawn inputs; only the last result is checked below
for _ in range(10):
q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
Tensor.realize(q, k, v)
do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
Tensor.realize(do)
q.grad, k.grad, v.grad = fn_jitted(q, k, v, do)
with Context(DEBUG=0):
# reference gradients: same inputs (detached clones) through scaled_dot_product_attention
q_ref = q.detach().clone().requires_grad_(True)
k_ref = k.detach().clone().requires_grad_(True)
v_ref = v.detach().clone().requires_grad_(True)
Tensor.realize(q_ref, k_ref, v_ref)
q_ref_, k_ref_, v_ref_ = q_ref.transpose(1, 2), k_ref.transpose(1, 2), v_ref.transpose(1, 2)
ref = q_ref_.scaled_dot_product_attention(k_ref_, v_ref_, is_causal=True)
ref = ref.float().transpose(1, 2)
ref.backward(do)
Tensor.realize(q_ref.grad, k_ref.grad, v_ref.grad)
# the v gradient is compared with a tighter absolute tolerance (2e-2) than q and k (5e-2)
np.testing.assert_allclose(q.grad.numpy(), q_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
np.testing.assert_allclose(v.grad.numpy(), v_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,6 +1,7 @@
import ctypes, subprocess, tempfile, unittest
from tinygrad.helpers import WIN
from tinygrad.runtime.support.c import Struct
from tinygrad.runtime.support.autogen import gen
class TestAutogen(unittest.TestCase):
def test_packed_struct_sizeof(self):
@@ -159,4 +160,97 @@ class TestAutogen(unittest.TestCase):
assert ihdr.num_dies == 1
assert ihdr.base_addr_64_bit == 1
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_gen_from_header(self):
  """Generate bindings from a small header and check the structs, the enum values and the struct members."""
  header_content = """
typedef struct {
int x;
int y;
} Point;
typedef enum {
RED = 0,
GREEN = 1,
BLUE = 2
} Color;
typedef struct {
Point origin;
int width;
int height;
Color color;
} Rectangle;
int add_points(Point a, Point b);
"""
  # the header only has to exist on disk while gen() reads it
  with tempfile.NamedTemporaryFile(mode='w', suffix='.h') as hdr:
    hdr.write(header_content)
    hdr.flush()
    generated_code = gen(name="test_header", dll=None, files=[hdr.name])
  namespace = {}
  exec(generated_code, namespace)
  # all three types and all three enumerators must be defined by the generated code
  for symbol in ('Point', 'Color', 'Rectangle', 'RED', 'GREEN', 'BLUE'):
    self.assertIn(symbol, namespace)
  for symbol, value in (('RED', 0), ('GREEN', 1), ('BLUE', 2)):
    self.assertEqual(namespace[symbol], value)
  p = namespace['Point']()
  self.assertIsInstance(p, Struct)
  for member in ('x', 'y'):
    self.assertTrue(hasattr(p, member))
  rect = namespace['Rectangle']()
  for member in ('origin', 'width', 'height', 'color'):
    self.assertTrue(hasattr(rect, member))
@unittest.skipIf(WIN, "doesn't compile on windows")
def test_struct_ordering(self):
  """Generate bindings for three structs that point at each other and check each one can be instantiated."""
  header_content = """
struct A;
struct C;
typedef struct A A;
struct B {
struct C *c_ptr;
};
struct C {
struct A *a_ptr;
};
struct A {
int x;
struct B *b_ptr;
};
"""
  # the header only has to exist on disk while gen() reads it
  with tempfile.NamedTemporaryFile(mode='w', suffix='.h') as hdr:
    hdr.write(header_content)
    hdr.flush()
    generated_code = gen(name="test_ordering", dll=None, files=[hdr.name])
  namespace = {}
  exec(generated_code, namespace)
  for symbol in ('struct_A', 'struct_B', 'struct_C'):
    self.assertIn(symbol, namespace)
  a = namespace['struct_A']()
  b = namespace['struct_B']()
  c = namespace['struct_C']()
  for obj, member in ((a, 'x'), (a, 'b_ptr'), (b, 'c_ptr'), (c, 'a_ptr')):
    self.assertTrue(hasattr(obj, member))
if __name__ == "__main__": unittest.main()

View File

@@ -128,10 +128,6 @@ def unwrap_class_type(cls_t): return cls_t.func if isinstance(cls_t, functools.p
def pluralize(st:str, cnt:int):
  """Return "<cnt> <st>", with an "s" appended unless cnt is exactly 1."""
  suffix = '' if cnt == 1 else 's'
  return f"{cnt} {st}{suffix}"
class LazySeq(Generic[T]): # NOTE: Mapping requires __iter__ and __len__, Sequence requires supporting __len__ and slicing in __getitem__
  """Index-only sequence whose items are produced on demand by a callable taking the index."""
  def __init__(self, gen:Callable[[int], T]):
    # gen is invoked on every lookup; nothing is cached here
    self.gen = gen
  def __getitem__(self, idx:int) -> T:
    producer = self.gen
    return producer(idx)
# Evaluates the polynomial with coefficients `p` (highest power first) at `x`:
# for N coefficients the result is p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
def polyN(x:T, p:list[float]) -> T:
  acc = 0.0
  # Horner's scheme: one multiply and one add per coefficient
  for coeff in p: acc = acc*x + coeff
  return acc # type: ignore

View File

@@ -1048,14 +1048,15 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
return output, present
def attention_onnx(Q:Tensor, K:Tensor, V:Tensor, attn_mask:Tensor|None=None, past_key:Tensor|None=None, past_value:Tensor|None=None,
is_causal:int=0, kv_num_heads:int|None=None, q_num_heads:int|None=None, qk_matmul_output_mode:int=0, scale:float|None=None,
softcap:float=0.0, softmax_precision:int|None=None):
nonpad_kv_seqlen:Tensor|None=None, is_causal:int=0, kv_num_heads:int|None=None, q_num_heads:int|None=None,
qk_matmul_output_mode:int=0, scale:float|None=None, softcap:float=0.0, softmax_precision:int|None=None):
if nonpad_kv_seqlen is not None: raise NotImplementedError("nonpad_kv_seqlen is not supported")
input_shape_len = Q.ndim
if input_shape_len == 3:
assert q_num_heads is not None and kv_num_heads is not None
Q = Q.reshape(Q.shape[0], q_num_heads, Q.shape[1], -1)
K = K.reshape(K.shape[0], kv_num_heads, K.shape[1], -1)
V = V.reshape(V.shape[0], kv_num_heads, V.shape[1], -1)
Q = Q.reshape(Q.shape[0], Q.shape[1], q_num_heads, -1).permute(0, 2, 1, 3)
K = K.reshape(K.shape[0], K.shape[1], kv_num_heads, -1).permute(0, 2, 1, 3)
V = V.reshape(V.shape[0], V.shape[1], kv_num_heads, -1).permute(0, 2, 1, 3)
if past_key is not None: K = past_key.cat(K, dim=2)
if past_value is not None: V = past_value.cat(V, dim=2)
@@ -1170,6 +1171,17 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
elif reduction == "min": x[i] = x[i].minimum(u)
return x
def TensorScatter(data: Tensor, updates: Tensor, indices: Tensor, mode: str = 'default'):
  # scatter updates along axis -2 at positions given by indices, for each batch
  # NOTE(review): presumably `indices` holds one start row per leading batch entry (shape (B,)) -- confirm against the ONNX spec
  orig_shape = data.shape
  n_starts, n_upd, n_rows = indices.shape[0], updates.shape[-2], data.shape[-2]
  # collapse every leading dim so data is (flat_batch, n_rows, last) and updates is (flat_batch, n_upd, last)
  flat_data = data.reshape(-1, n_rows, data.shape[-1])
  flat_updates = updates.reshape(-1, n_upd, updates.shape[-1])
  n_flat = flat_data.shape[0]
  # first coordinate of every scattered row: the flattened batch it belongs to
  batch_idx = Tensor.arange(n_flat, device=data.device).reshape(n_flat, 1).expand(n_flat, n_upd)
  # broadcast the per-batch start over the remaining leading dims, then flatten to match flat_data
  starts = indices.reshape(n_starts, *([1] * (data.ndim - 3))).expand(*orig_shape[:-2]).reshape(n_flat)
  # second coordinate: start + 0..n_upd-1, i.e. n_upd consecutive rows per flattened batch
  offsets = Tensor.arange(n_upd, device=data.device).reshape(1, n_upd).expand(n_flat, n_upd)
  row_idx = starts.reshape(n_flat, 1).expand(n_flat, n_upd) + offsets
  # circular mode wraps rows that run past the end of axis -2
  if mode == 'circular': row_idx = row_idx % n_rows
  scatter_idx = batch_idx.unsqueeze(-1).cat(row_idx.unsqueeze(-1), dim=-1)
  return ScatterND(flat_data, scatter_idx, flat_updates).reshape(orig_shape)
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"):
indices = (indices < 0).where(x.shape[axis], 0) + indices
if reduction == "none": return x.scatter(axis, indices, updates)

View File

@@ -132,7 +132,7 @@ def __getattr__(nm):
tarball="https://gitlab.freedesktop.org/mesa/mesa/-/archive/mesa-25.2.7/mesa-25.2.7.tar.gz",
prolog=["import gzip, base64"], epilog=lambda path: [system(f"{root}/extra/mesa/lvp_nir_options.sh {path}")])
case "libclang":
return load("libclang", "'clang-20'",
return load("libclang", "['clang-20', 'clang']",
lambda: [f"{system('llvm-config-20 --includedir')}/clang-c/{s}.h" for s in ["Index", "CXString", "CXSourceLocation", "CXFile"]],
args=lambda: system("llvm-config-20 --cflags").split())
case "metal":

View File

@@ -1,7 +1,7 @@
# mypy: ignore-errors
import ctypes
from tinygrad.runtime.support.c import DLL, Struct, CEnum, _IO, _IOW, _IOR, _IOWR
dll = DLL('libclang', 'clang-20')
dll = DLL('libclang', ['clang-20', 'clang'])
CXIndex = ctypes.c_void_p
class struct_CXTargetInfoImpl(Struct): pass
CXTargetInfo = ctypes.POINTER(struct_CXTargetInfoImpl)

View File

@@ -530,7 +530,6 @@ add_tags = PatternMatcher([
])
# support for using a contiguous permuted view instead of the parent view if one exists
# modified from kernelize.py to not use ShapeTracker
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
x = src

View File

@@ -18,7 +18,8 @@ from tinygrad.device import Device, Buffer
from tinygrad.engine.realize import run_schedule
# TODO: this should be the only usage of Device
# single-device form: passes one device name (or None) straight through Device.canonicalize
def canonicalize_device(device:str|None) -> str: return Device.canonicalize(device)
def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
  # a single name (or None) is canonicalized directly; a tuple/list is canonicalized element-wise and always returned as a tuple
  if not isinstance(device, (tuple, list)): return Device.canonicalize(device)
  return tuple(Device.canonicalize(d) for d in device)
# *** all in scope Tensors are here. this gets relevant UOps ***
@@ -115,7 +116,7 @@ class Tensor(OpMixin):
device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False):
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
_dtype:DType|None = to_dtype(dtype) if dtype is not None else None
_device:str|tuple[str, ...] = tuple(canonicalize_device(x) for x in device) if isinstance(device, (tuple, list)) else canonicalize_device(device)
_device:str|tuple[str, ...] = canonicalize_device(device)
del device, dtype
# tensors can have gradients if you have called .backward
@@ -162,10 +163,10 @@ class Tensor(OpMixin):
# data might be on a different device
if isinstance(_device, str): self.uop:UOp = data if data.device == _device else data.copy_to_device(_device)
# if device is a tuple, we should have/construct a MultiLazyBuffer
# if device is a tuple, we should have/construct a multi-device UOp
elif isinstance(data.device, str): self.uop = Tensor(data).shard(_device).uop
else:
assert data.device == _device, f"MultiLazyBuffer device mismatch, {data.device} != {_device}"
assert data.device == _device, f"multi-device UOp device mismatch, {data.device} != {_device}"
self.uop = data
# add to all_tensors after construction succeeds
@@ -373,7 +374,7 @@ class Tensor(OpMixin):
"""
Moves the tensor to the given device.
"""
device = tuple(canonicalize_device(x) for x in device) if isinstance(device, (tuple, list)) else canonicalize_device(device)
device = canonicalize_device(device)
if device == self.device: return self
if not isinstance(device, str): return self.shard(device)
ret = Tensor(self.uop, device, requires_grad=self.requires_grad)
@@ -397,9 +398,9 @@ class Tensor(OpMixin):
print(t.shard((t.device, t.device), axis=1).uop)
```
"""
if not isinstance(self.device, str): raise RuntimeError("can't shard a MultiLazyBuffer")
if not isinstance(self.device, str): raise RuntimeError("can't shard a multi-device tensor")
if len(devices) == 1: return self.to(devices[0])
devices = tuple(canonicalize_device(x) for x in devices)
devices = cast(tuple[str, ...], canonicalize_device(devices))
mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
@@ -495,7 +496,7 @@ class Tensor(OpMixin):
dtype, shape = to_dtype(dtype) if dtype is not None else dtypes.default_float, argfix(*shape)
if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
# TODO: add test for multidevice tensor
device = tuple(canonicalize_device(d) for d in device) if isinstance(device, tuple) else canonicalize_device(device)
device = canonicalize_device(device)
return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape)
def empty_like(self, **kwargs) -> Tensor:
@@ -577,7 +578,7 @@ class Tensor(OpMixin):
if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
device = canonicalize_device(device)
device = cast(str, canonicalize_device(device))
# if shape has 0, return zero tensor
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
@@ -2062,37 +2063,33 @@ class Tensor(OpMixin):
print(Tensor.einsum("ij,ij->", x, y).numpy())
```
"""
def parse_formula(formula:str, *operands:Tensor):
if "..." in (formula := formula.replace(" ", "")):
ell_chars, ell_longest = "".join(c for c in string.ascii_letters if c not in formula), 0
for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))):
if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count
inputs[i] = inp.replace("...", ell_chars[-ell_count:])
inputs_str, out_ellipse = ",".join(inputs), ell_chars[-ell_longest:]
return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else \
(inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
xs:tuple[Tensor, ...] = argfix(*operands)
inputs_str, output = parse_formula(formula, *xs)
inputs = inputs_str.split(",")
if len(xs)!=len(inputs): raise ValueError(f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}")
# map the value of each letter in the formula
letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())
xs_:list[Tensor] = []
lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
for x,(order,letters) in zip(xs, [list(zip(*l)) for l in lhs]):
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
# ordinal encode the output alphabet
rhs_order = argsort(argsort(list(output)))
# sum over all axes that are not in the output, then permute to the output order
return functools.reduce(lambda a,b:a*b, xs_) \
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], dtype=dtype).permute(rhs_order)
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
# expand ellipsis to letters, determine output
if "..." in formula:
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
inputs = lhs.split(",")
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
# trace: take diagonal when letter repeats in single input
for i, (s, x) in enumerate(zip(inputs, xs)):
for c in set(s):
while s.count(c) > 1:
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
s = s[:k] + s[k+1:]
inputs[i], xs[i] = s, x
# check sizes and build sorted alphabet
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
alpha = sorted(sz)
# align all tensors to alphabet, multiply, sum non-output, permute to output order
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
for s, x in zip(inputs, xs)]
return functools.reduce(lambda a,b:a*b, xs).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
# ***** processing ops *****

View File

@@ -3,7 +3,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
from tinygrad.dtype import dtypes
from tinygrad.helpers import cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap
# NOTE: this cache is only on index UOps and matches the cache in the old ShapeTracker in spirit
# NOTE: this cache is only on index UOps
@functools.cache
def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
x, y = d.src

View File

@@ -240,7 +240,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.RESHAPE:
if self.src[0]._shape is None: return self.marg
# movement ops change the shape. this is the logic from the old ShapeTracker
# movement ops change the shape
# NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):
ps = self.src[0]._shape
@@ -475,14 +475,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
def overflows(self, dtype:DType) -> bool:
  # True when this UOp's [vmin, vmax] range is not contained in [dtype.min, dtype.max]
  if self.vmin < dtype.min: return True
  return dtype.max < self.vmax
# *** ShapeTracker helpers ***
def split_uop(self:UOp, sep:Ops):
  # generator: walks through nested `sep` nodes and yields, in source order, every node whose op is not `sep`
  if self.op is not sep:
    yield self
    return
  for child in self.src: yield from child.split_uop(sep)
# *** from MultiLazyBuffer ***
# *** multi-device helpers ***
def multi(self, axis:int|None):
assert isinstance(self.device, tuple), f"multi device must be tuple, {self.device} isn't"
@@ -524,8 +522,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return self.shrink(tuple((0,s) if i != axis else (dnum*sz,dnum*sz+sz) for i,s in enumerate(self.shape)))
def shard(self, devices:tuple[str, ...], axis:int) -> UOp:
  # copy to all `devices`, then apply `_shard(axis)` and wrap the result with `multi(axis)`
  on_devices = self.copy_to_device(devices)
  return on_devices._shard(axis).multi(axis)
# *** from LazyBuffer ***
def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
assert arg is None or isinstance(self.device, tuple)
inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg)

View File

@@ -30,13 +30,13 @@ const Status = {STARTED:0, COMPLETE:1, ERR:2}
const updateProgress = (st, msg) => {
clearTimeout(timeout);
const msgEl = d3.select("#progress-message").style("display", "none");
const customEl = d3.select("#custom").html("");
const customEl = d3.select("#custom").style("display", "none");
if (st === Status.STARTED) {
msgEl.text(msg);
timeout = setTimeout(() => msgEl.style("display", "block"), 2000);
} else if (st === Status.ERR) {
displaySelection("#custom");
customEl.append("div").classed("raw-text", true).append(() => codeBlock(msg));
customEl.html("").append("div").classed("raw-text", true).append(() => codeBlock(msg));
}
}
@@ -685,9 +685,15 @@ window.addEventListener("popstate", (e) => {
if (e.state != null) setState(e.state);
});
const toggleLabel = d3.create("label").text("Show indexing (r)").node();
const toggle = d3.create("input").attr("type", "checkbox").attr("id", "show-indexing").property("checked", true).node();
toggleLabel.prepend(toggle);
// Builds a detached <label> with the given text and a checkbox (checked by default, with the given id)
// prepended to it. Returns both DOM nodes so callers can attach handlers and place the label.
const createToggle = (id, text) => {
  const labelNode = d3.create("label").text(text).node();
  const checkbox = d3.create("input");
  checkbox.attr("type", "checkbox").attr("id", id).property("checked", true);
  const toggle = checkbox.node();
  labelNode.prepend(toggle);
  return { toggle, label: labelNode };
}
const { toggle, label:toggleLabel } = createToggle("show-indexing", "Show indexing (r)");
const showGraph = createToggle("show-graph", "Show graph (g)");
showGraph.toggle.onchange = () => displaySelection(rect("#graph").width > 0 ? "#custom" : "#graph");
function appendSteps(root, idx, steps) {
const stack = [];
@@ -748,7 +754,6 @@ async function main() {
if (ckey in cache) {
ret = cache[ckey];
}
// ** Text view
if (!ckey.startsWith("/graph")) {
if (!(ckey in cache)) cache[ckey] = ret = await fetchValue(ckey);
if (ret.steps?.length > 0) {
@@ -760,15 +765,25 @@ async function main() {
appendSteps(el.ctx, state.currentCtx, ctx.steps);
return setState({ currentStep:state.currentStep+1, expandSteps:true });
}
// cycles on the x axis
// timeline with cycles on the x axis
if (ret instanceof ArrayBuffer) {
opts = {heightScale:0.5, hideLabels:true, levelKey:(e) => parseInt(e.name.split(" ")[1].split(":")[1])};
return renderProfiler(ckey, "clk", opts);
}
displaySelection("#custom");
metadata.innerHTML = "";
ret.metadata?.forEach(m => {
if (Array.isArray(m)) return metadata.appendChild(tabulate(m.map(({ label, value }) => {
return [label.trim(), typeof value === "string" ? value : formatUnit(value)];
})).node());
metadata.appendChild(codeBlock(m.src)).classList.add("full-height")
});
// graph render
if (ret.data != null) {
metadata.prepend(showGraph.label);
renderDag(ret, { recenter:true });
} else displaySelection("#custom");
// table / plaintext render
const root = d3.create("div").classed("raw-text", true);
// detailed assembly view
function renderTable(root, ret) {
const table = root.append("table");
const thead = table.append("thead");
@@ -797,14 +812,7 @@ async function main() {
return table;
}
if (ret.cols != null) renderTable(root, ret);
else if (ret.data != null) renderDag(ret, { recenter:true });
else if (ret.src != null) root.append(() => codeBlock(ret.src, ret.lang));
ret.metadata?.forEach(m => {
if (Array.isArray(m)) return metadata.appendChild(tabulate(m.map(({ label, value }) => {
return [label.trim(), typeof value === "string" ? value : formatUnit(value)];
})).node());
metadata.appendChild(codeBlock(m.src)).classList.add("full-height")
});
return document.querySelector("#custom").replaceChildren(root.node());
}
// ** Graph view
@@ -961,9 +969,9 @@ document.addEventListener("keydown", (event) => {
document.getElementById("zoom-to-fit-btn").click();
}
// r key toggles indexing
if (event.key === "r") {
toggle.click();
}
if (event.key === "r") toggle.click();
// g key toggles graph
if (event.key === "g") showGraph.toggle.click();
});
main()

View File

@@ -406,7 +406,10 @@ def amdgpu_cfg(lib:bytes, target:int) -> dict:
curr:int|None = None
blocks:dict[int, list[int]] = {}
paths:dict[int, dict[int, int]] = {}
lines:list[str] = []
asm_width = max(len(asm) for asm, _ in pc_table.values())
for pc, (asm, sz) in pc_table.items():
lines.append(f" {asm:<{asm_width}} // {pc:012X}")
if pc in leaders:
paths[curr:=pc] = {}
blocks[pc] = []
@@ -420,7 +423,7 @@ def amdgpu_cfg(lib:bytes, target:int) -> dict:
if asm.startswith("s_branch"): paths[curr][nx+offset] = UNCOND
else: paths[curr].update([(nx+offset, COND_TAKEN), (nx, COND_NOT_TAKEN)])
elif nx in leaders: paths[curr][nx] = UNCOND
return {"blocks":blocks, "paths":paths, "pc_table":pc_table, "colors":cfg_colors}
return {"data":{"blocks":blocks, "paths":paths, "pc_table":pc_table, "colors":cfg_colors}, "src":"\n".join(lines)}
# ** Main render function to get the complete details about a trace event
@@ -435,7 +438,7 @@ def get_render(query:str) -> dict:
ret:dict = {"metadata":[]}
if data.device.startswith("AMD") and data.lib is not None:
with soft_err(lambda err: ret.update(err)):
ret["data"] = amdgpu_cfg(lib:=data.lib, device_props[data.device]["gfx_target_version"])
ret.update(amdgpu_cfg(lib:=data.lib, device_props[data.device]["gfx_target_version"]))
with soft_err(lambda err: ret["metadata"].append(err)): ret["metadata"].append(amd_readelf(lib))
else: ret["src"] = get_stdout(lambda: (compiler:=Device[data.device].compiler).disassemble(compiler.compile(data.src)))
return ret