mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
Merge origin/master, delete pcode.py
This commit is contained in:
51
.github/workflows/autogen.yml
vendored
51
.github/workflows/autogen.yml
vendored
@@ -40,11 +40,13 @@ jobs:
|
||||
- name: Install autogen support packages
|
||||
run: sudo apt-get install -y --no-install-recommends libclang-20-dev llvm-20-dev hip-dev libusb-1.0-0-dev
|
||||
- name: Verify OpenCL autogen
|
||||
continue-on-error: true
|
||||
run: |
|
||||
mv tinygrad/runtime/autogen/opencl.py /tmp/opencl.py.bak
|
||||
python3 -c "from tinygrad.runtime.autogen import opencl"
|
||||
diff /tmp/opencl.py.bak tinygrad/runtime/autogen/opencl.py
|
||||
- name: Verify CUDA autogen
|
||||
continue-on-error: true
|
||||
run: |
|
||||
mv tinygrad/runtime/autogen/cuda.py /tmp/cuda.py.bak
|
||||
mv tinygrad/runtime/autogen/nvrtc.py /tmp/nvrtc.py.bak
|
||||
@@ -58,6 +60,7 @@ jobs:
|
||||
diff /tmp/nv_570.py.bak tinygrad/runtime/autogen/nv_570.py
|
||||
diff /tmp/nv.py.bak tinygrad/runtime/autogen/nv.py
|
||||
- name: Verify AMD autogen
|
||||
continue-on-error: true
|
||||
run: |
|
||||
mv tinygrad/runtime/autogen/comgr.py /tmp/comgr.py.bak
|
||||
mv tinygrad/runtime/autogen/hsa.py /tmp/hsa.py.bak
|
||||
@@ -89,6 +92,7 @@ jobs:
|
||||
diff /tmp/am_smu_v13_0_0.py.bak tinygrad/runtime/autogen/am/smu_v13_0_0.py
|
||||
diff /tmp/am_smu_v14_0_2.py.bak tinygrad/runtime/autogen/am/smu_v14_0_2.py
|
||||
- name: Verify Linux autogen
|
||||
continue-on-error: true
|
||||
run: |
|
||||
mv tinygrad/runtime/autogen/libc.py /tmp/libc.py.bak
|
||||
mv tinygrad/runtime/autogen/kfd.py /tmp/kfd.py.bak
|
||||
@@ -104,16 +108,19 @@ jobs:
|
||||
diff /tmp/pci.py.bak tinygrad/runtime/autogen/pci.py
|
||||
diff /tmp/vfio.py.bak tinygrad/runtime/autogen/vfio.py
|
||||
- name: Verify LLVM autogen
|
||||
continue-on-error: true
|
||||
run: |
|
||||
mv tinygrad/runtime/autogen/llvm.py /tmp/llvm.py.bak
|
||||
python3 -c "from tinygrad.runtime.autogen import llvm"
|
||||
diff /tmp/llvm.py.bak tinygrad/runtime/autogen/llvm.py
|
||||
- name: Verify WebGPU autogen
|
||||
continue-on-error: true
|
||||
run: |
|
||||
mv tinygrad/runtime/autogen/webgpu.py /tmp/webgpu.py.bak
|
||||
python3 -c "from tinygrad.runtime.autogen import webgpu"
|
||||
diff /tmp/webgpu.py.bak tinygrad/runtime/autogen/webgpu.py
|
||||
- name: Verify Qualcomm autogen
|
||||
continue-on-error: true
|
||||
run: |
|
||||
mv tinygrad/runtime/autogen/kgsl.py /tmp/kgsl.py.bak
|
||||
mv tinygrad/runtime/autogen/qcom_dsp.py /tmp/qcom_dsp.py.bak
|
||||
@@ -121,20 +128,36 @@ jobs:
|
||||
diff /tmp/kgsl.py.bak tinygrad/runtime/autogen/kgsl.py
|
||||
diff /tmp/qcom_dsp.py.bak tinygrad/runtime/autogen/qcom_dsp.py
|
||||
- name: Verify libusb autogen
|
||||
continue-on-error: true
|
||||
run: |
|
||||
mv tinygrad/runtime/autogen/libusb.py /tmp/libusb.py.bak
|
||||
python3 -c "from tinygrad.runtime.autogen import libusb"
|
||||
diff /tmp/libusb.py.bak tinygrad/runtime/autogen/libusb.py
|
||||
- name: Verify mesa autogen
|
||||
continue-on-error: true
|
||||
run: |
|
||||
mv tinygrad/runtime/autogen/mesa.py /tmp/mesa.py.bak
|
||||
python3 -c "from tinygrad.runtime.autogen import mesa"
|
||||
diff /tmp/mesa.py.bak tinygrad/runtime/autogen/mesa.py
|
||||
- name: Verify libclang autogen
|
||||
continue-on-error: true
|
||||
run: |
|
||||
cp tinygrad/runtime/autogen/libclang.py /tmp/libclang.py.bak
|
||||
REGEN=1 python3 -c "from tinygrad.runtime.autogen import libclang"
|
||||
diff /tmp/libclang.py.bak tinygrad/runtime/autogen/libclang.py
|
||||
- name: Generate patch for differences
|
||||
run: |
|
||||
if ! git diff --quiet; then
|
||||
git diff > autogen-ubuntu.patch
|
||||
fi
|
||||
- name: Upload patch artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: autogen-ubuntu-patch
|
||||
path: autogen-ubuntu.patch
|
||||
if-no-files-found: ignore
|
||||
- name: Fail if differences found
|
||||
run: git diff --quiet
|
||||
autogen-mac:
|
||||
name: In-tree Autogen (macos)
|
||||
runs-on: macos-14
|
||||
@@ -147,10 +170,24 @@ jobs:
|
||||
with:
|
||||
llvm: 'true'
|
||||
- name: Verify macos autogen
|
||||
continue-on-error: true
|
||||
run: |
|
||||
mv tinygrad/runtime/autogen/metal.py /tmp/metal.py.bak
|
||||
LIBCLANG_PATH=/opt/homebrew/opt/llvm@20/lib/libclang.dylib python3 -c "from tinygrad.runtime.autogen import metal"
|
||||
diff /tmp/metal.py.bak tinygrad/runtime/autogen/metal.py
|
||||
- name: Generate patch for differences
|
||||
run: |
|
||||
if ! git diff --quiet; then
|
||||
git diff > autogen-macos.patch
|
||||
fi
|
||||
- name: Upload patch artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: autogen-macos-patch
|
||||
path: autogen-macos.patch
|
||||
if-no-files-found: ignore
|
||||
- name: Fail if differences found
|
||||
run: git diff --quiet
|
||||
autogen-comgr-3:
|
||||
name: In-tree Autogen (comgr 3)
|
||||
runs-on: ubuntu-24.04
|
||||
@@ -170,7 +207,21 @@ jobs:
|
||||
sudo apt -qq update || true
|
||||
sudo apt-get install -y --no-install-recommends libclang-20-dev comgr
|
||||
- name: Verify comgr (3) autogen
|
||||
continue-on-error: true
|
||||
run: |
|
||||
mv tinygrad/runtime/autogen/comgr_3.py /tmp/comgr_3.py.bak
|
||||
python3 -c "from tinygrad.runtime.autogen import comgr_3"
|
||||
diff /tmp/comgr_3.py.bak tinygrad/runtime/autogen/comgr_3.py
|
||||
- name: Generate patch for differences
|
||||
run: |
|
||||
if ! git diff --quiet; then
|
||||
git diff > autogen-comgr3.patch
|
||||
fi
|
||||
- name: Upload patch artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: autogen-comgr3-patch
|
||||
path: autogen-comgr3.patch
|
||||
if-no-files-found: ignore
|
||||
- name: Fail if differences found
|
||||
run: git diff --quiet
|
||||
|
||||
26
.github/workflows/test.yml
vendored
26
.github/workflows/test.yml
vendored
@@ -257,6 +257,7 @@ jobs:
|
||||
key: unittest-12
|
||||
pydeps: "pillow numpy ftfy regex"
|
||||
deps: testing_unit
|
||||
llvm: 'true'
|
||||
- name: Check Device.DEFAULT
|
||||
run: python -c "from tinygrad import Device; assert Device.DEFAULT == 'CPU', Device.DEFAULT"
|
||||
- name: Run unit tests
|
||||
@@ -309,7 +310,7 @@ jobs:
|
||||
deps: testing_unit
|
||||
python-version: '3.14'
|
||||
- name: Test SPEC=2
|
||||
run: SPEC=2 pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/test_custom_kernel.py --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
|
||||
run: SPEC=2 pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/test_custom_kernel.py --ignore test/unit/test_hashing.py --ignore test/unit/test_autogen.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
|
||||
|
||||
fuzzing:
|
||||
name: Fuzzing
|
||||
@@ -669,6 +670,10 @@ jobs:
|
||||
deps: testing_minimal
|
||||
amd: 'true'
|
||||
python-version: '3.13'
|
||||
- name: Verify AMD autogen is up to date
|
||||
run: |
|
||||
python -m extra.assembly.amd.pdf
|
||||
git diff --exit-code extra/assembly/amd/autogen/
|
||||
- name: Install LLVM 21
|
||||
run: |
|
||||
wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc
|
||||
@@ -689,23 +694,6 @@ jobs:
|
||||
- name: Run RDNA3 ops tests
|
||||
run: SKIP_SLOW_TEST=1 AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=auto test/test_ops.py -k "test_sparse_categorical_crossentropy or test_tril"
|
||||
|
||||
testamdautogen:
|
||||
name: AMD autogen
|
||||
runs-on: ubuntu-24.04
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: rdna3-autogen
|
||||
pydeps: "pdfplumber"
|
||||
- name: Verify AMD autogen is up to date
|
||||
run: |
|
||||
python -m extra.assembly.amd.pdf --arch all
|
||||
git diff --exit-code extra/assembly/amd/autogen/
|
||||
|
||||
testnvidia:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -793,6 +781,8 @@ jobs:
|
||||
ocelot: 'true'
|
||||
llvm: 'true'
|
||||
- name: Run unit tests
|
||||
env:
|
||||
LIBCLANG_PATH: '/opt/homebrew/opt/llvm@20/lib/libclang.dylib'
|
||||
run: METAL=1 python -m pytest -n=auto test/unit/ --durations=20
|
||||
- name: Run ONNX
|
||||
run: METAL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
|
||||
|
||||
15
CLAUDE.md
15
CLAUDE.md
@@ -192,9 +192,12 @@ When optimizing tinygrad internals:
|
||||
|
||||
9. **Avoid creating intermediate objects in hot paths** - For example, `any(x.op in ops for x in self.backward_slice)` is faster than `any(x.op in ops for x in {self:None, **self.backward_slice})` because it avoids dict creation.
|
||||
|
||||
## Pattern Matching Profiling
|
||||
## Pattern Matching Analysis
|
||||
|
||||
Use `TRACK_MATCH_STATS=2` to identify expensive patterns:
|
||||
**Use the right tool:**
|
||||
|
||||
- `TRACK_MATCH_STATS=2` - **Profiling**: identify expensive patterns
|
||||
- `VIZ=-1` - **Inspection**: see all transformations, what every match pattern does, and the before/after diffs
|
||||
|
||||
```bash
|
||||
TRACK_MATCH_STATS=2 PYTHONPATH="." python3 test/external/external_benchmark_schedule.py
|
||||
@@ -209,6 +212,14 @@ Key patterns to watch (from ResNet50 benchmark):
|
||||
|
||||
Patterns with 0% match rate are workload-specific overhead. They may be useful in other workloads, so don't remove them without understanding their purpose.
|
||||
|
||||
```bash
|
||||
# Save the trace
|
||||
VIZ=-1 python test/test_tiny.py TestTiny.test_gemm
|
||||
|
||||
# Explore it
|
||||
./extra/viz/cli.py --help
|
||||
```
|
||||
|
||||
## AMD Performance Counter Profiling
|
||||
|
||||
Set VIZ to `-2` to save performance counter traces for the AMD backend.
|
||||
|
||||
@@ -153,8 +153,7 @@ class SMICtx:
|
||||
tables = {}
|
||||
for dev in self.devs:
|
||||
match dev.ip_ver[am.MP1_HWIP]:
|
||||
case (13,0,6): table_t = dev.smu.smu_mod.MetricsTableX_t
|
||||
case (13,0,12): table_t = dev.smu.smu_mod.MetricsTableV2_t
|
||||
case (13,0,6)|(13,0,12): table_t = dev.smu.smu_mod.MetricsTableX_t
|
||||
case _: table_t = dev.smu.smu_mod.SmuMetricsExternal_t
|
||||
tables[dev] = dev.smu.read_table(table_t, dev.smu.smu_mod.SMU_TABLE_SMU_METRICS) if dev.pci_state == "D0" else None
|
||||
return tables
|
||||
@@ -165,17 +164,17 @@ class SMICtx:
|
||||
|
||||
def get_gfx_activity(self, dev, metrics):
|
||||
match dev.ip_ver[am.MP1_HWIP]:
|
||||
case (13,0,6): return max(0, min(100, self._smuq10_round(metrics.SocketGfxBusy)))
|
||||
case (13,0,6)|(13,0,12): return max(0, min(100, self._smuq10_round(metrics.SocketGfxBusy)))
|
||||
case _: return metrics.SmuMetrics.AverageGfxActivity
|
||||
|
||||
def get_mem_activity(self, dev, metrics):
|
||||
match dev.ip_ver[am.MP1_HWIP]:
|
||||
case (13,0,6): return max(0, min(100, self._smuq10_round(metrics.DramBandwidthUtilization)))
|
||||
case (13,0,6)|(13,0,12): return max(0, min(100, self._smuq10_round(metrics.DramBandwidthUtilization)))
|
||||
case _: return metrics.SmuMetrics.AverageUclkActivity
|
||||
|
||||
def get_temps(self, dev, metrics, compact=False):
|
||||
match dev.ip_ver[am.MP1_HWIP]:
|
||||
case (13,0,6):
|
||||
case (13,0,6)|(13,0,12):
|
||||
temps = {
|
||||
"Hotspot": self._smuq10_round(metrics.MaxSocketTemperature),
|
||||
"HBM": self._smuq10_round(metrics.MaxHbmTemperature),
|
||||
@@ -191,7 +190,7 @@ class SMICtx:
|
||||
|
||||
def get_voltage(self, dev, metrics, compact=False):
|
||||
match dev.ip_ver[am.MP1_HWIP]:
|
||||
case (13,0,6): return {}
|
||||
case (13,0,6)|(13,0,12): return {}
|
||||
case _:
|
||||
voltage_keys = [(k, name) for k, name in dev.smu.smu_mod.SVI_PLANE_e.items()
|
||||
if k < dev.smu.smu_mod.SVI_PLANE_COUNT and metrics.SmuMetrics.AvgVoltage[k] != 0]
|
||||
@@ -205,33 +204,33 @@ class SMICtx:
|
||||
def get_gfx_freq(self, dev, metrics):
|
||||
if metrics is None: return 0
|
||||
match dev.ip_ver[am.MP1_HWIP]:
|
||||
case (13,0,6): return self._smuq10_round(metrics.GfxclkFrequency[0])
|
||||
case (13,0,6)|(13,0,12): return self._smuq10_round(metrics.GfxclkFrequency[0])
|
||||
case _:
|
||||
return metrics.SmuMetrics.AverageGfxclkFrequencyPostDs if self.get_gfx_activity(dev, metrics) <= self.get_busy_threshold(dev) else \
|
||||
metrics.SmuMetrics.AverageGfxclkFrequencyPreDs
|
||||
|
||||
def get_mem_freq(self, dev, metrics):
|
||||
match dev.ip_ver[am.MP1_HWIP]:
|
||||
case (13,0,6): return self._smuq10_round(metrics.UclkFrequency)
|
||||
case (13,0,6)|(13,0,12): return self._smuq10_round(metrics.UclkFrequency)
|
||||
case _:
|
||||
return metrics.SmuMetrics.AverageMemclkFrequencyPostDs if self.get_mem_activity(dev, metrics) <= self.get_busy_threshold(dev) else \
|
||||
metrics.SmuMetrics.AverageMemclkFrequencyPreDs
|
||||
|
||||
def get_fckl_freq(self, dev, metrics):
|
||||
match dev.ip_ver[am.MP1_HWIP]:
|
||||
case (13,0,6): return self._smuq10_round(metrics.FclkFrequency)
|
||||
case (13,0,6)|(13,0,12): return self._smuq10_round(metrics.FclkFrequency)
|
||||
case _:
|
||||
return metrics.SmuMetrics.AverageFclkFrequencyPostDs if self.get_mem_activity(dev, metrics) <= self.get_busy_threshold(dev) else \
|
||||
metrics.SmuMetrics.AverageFclkFrequencyPreDs
|
||||
|
||||
def get_fan_rpm_pwm(self, dev, metrics):
|
||||
match dev.ip_ver[am.MP1_HWIP]:
|
||||
case (13,0,6): return None, None
|
||||
case (13,0,6)|(13,0,12): return None, None
|
||||
case _: return metrics.SmuMetrics.AvgFanRpm, metrics.SmuMetrics.AvgFanPwm
|
||||
|
||||
def get_power(self, dev, metrics):
|
||||
match dev.ip_ver[am.MP1_HWIP]:
|
||||
case (13,0,6): return self._smuq10_round(metrics.SocketPower), self._smuq10_round(metrics.MaxSocketPowerLimit)
|
||||
case (13,0,6)|(13,0,12): return self._smuq10_round(metrics.SocketPower), self._smuq10_round(metrics.MaxSocketPowerLimit)
|
||||
case _: return metrics.SmuMetrics.AverageSocketPower, metrics.SmuMetrics.dGPU_W_MAX
|
||||
|
||||
def get_mem_usage(self, dev):
|
||||
|
||||
@@ -3,10 +3,11 @@ from __future__ import annotations
|
||||
import re
|
||||
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
|
||||
from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF
|
||||
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, FLOAT_DEC, FLOAT_ENC, decode_src
|
||||
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, SPECIAL_PAIRS_CDNA, FLOAT_DEC, FLOAT_ENC, decode_src
|
||||
from extra.assembly.amd.autogen.rdna3 import ins
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
|
||||
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
|
||||
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp, MTBUFOp)
|
||||
from extra.assembly.amd.autogen.rdna3.enum import BufFmt
|
||||
|
||||
def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__
|
||||
|
||||
@@ -17,21 +18,37 @@ def _matches_encoding(word: int, cls: type[Inst]) -> bool:
|
||||
return ((word >> bf.lo) & bf.mask()) == val
|
||||
|
||||
# Order matters: more specific encodings first, VOP2 last (it's a catch-all for bit31=0)
|
||||
_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP]
|
||||
_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls
|
||||
_RDNA_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP]
|
||||
_RDNA_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls
|
||||
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as C_VOP1, VOP2 as C_VOP2, VOPC as C_VOPC, VOP3A, VOP3B, VOP3P as C_VOP3P,
|
||||
SOP1 as C_SOP1, SOP2 as C_SOP2, SOPC as C_SOPC, SOPK as C_SOPK, SOPP as C_SOPP, SMEM as C_SMEM, DS as C_DS,
|
||||
FLAT as C_FLAT, MUBUF as C_MUBUF, MTBUF as C_MTBUF, SDWA, DPP)
|
||||
_CDNA_FORMATS_64 = [C_VOP3P, VOP3A, C_DS, C_FLAT, C_MUBUF, C_MTBUF, C_SMEM]
|
||||
_CDNA_FORMATS_32 = [SDWA, DPP, C_SOP1, C_SOPC, C_SOPP, C_SOPK, C_VOPC, C_VOP1, C_SOP2, C_VOP2]
|
||||
_CDNA_VOP3B_OPS = {281, 282, 283, 284, 285, 286, 480, 481, 488, 489} # VOP3B opcodes
|
||||
# CDNA opcode name aliases for disasm (new name -> old name expected by tests)
|
||||
_CDNA_DISASM_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32', 'v_fmamk_f32': 'v_madmk_f32', 'v_fmaak_f32': 'v_madak_f32'}
|
||||
|
||||
def detect_format(data: bytes) -> type[Inst]:
|
||||
def detect_format(data: bytes, arch: str = "rdna3") -> type[Inst]:
|
||||
"""Detect instruction format from machine code bytes."""
|
||||
assert len(data) >= 4, f"need at least 4 bytes, got {len(data)}"
|
||||
word = int.from_bytes(data[:4], 'little')
|
||||
# Check 64-bit formats first (bits[31:30] == 0b11)
|
||||
if arch == "cdna":
|
||||
if (word >> 30) == 0b11:
|
||||
for cls in _CDNA_FORMATS_64:
|
||||
if _matches_encoding(word, cls):
|
||||
return VOP3B if cls is VOP3A and ((word >> 16) & 0x3ff) in _CDNA_VOP3B_OPS else cls
|
||||
raise ValueError(f"unknown CDNA 64-bit format word={word:#010x}")
|
||||
for cls in _CDNA_FORMATS_32:
|
||||
if _matches_encoding(word, cls): return cls
|
||||
raise ValueError(f"unknown CDNA 32-bit format word={word:#010x}")
|
||||
# RDNA (default)
|
||||
if (word >> 30) == 0b11:
|
||||
for cls in _FORMATS_64:
|
||||
for cls in _RDNA_FORMATS_64:
|
||||
if _matches_encoding(word, cls):
|
||||
return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in Inst._VOP3SD_OPS else cls
|
||||
raise ValueError(f"unknown 64-bit format word={word:#010x}")
|
||||
# 32-bit formats
|
||||
for cls in _FORMATS_32:
|
||||
for cls in _RDNA_FORMATS_32:
|
||||
if _matches_encoding(word, cls): return cls
|
||||
raise ValueError(f"unknown 32-bit format word={word:#010x}")
|
||||
|
||||
@@ -44,6 +61,11 @@ HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_H
|
||||
19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK',
|
||||
23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'}
|
||||
HWREG_IDS = {v.lower(): k for k, v in HWREG.items()}
|
||||
# RDNA unified buffer format - extracted from PDF, use enum for name->value lookup
|
||||
BUF_FMT = {e.name: e.value for e in BufFmt}
|
||||
def _parse_buf_fmt_combo(s: str) -> int: # parse format:[BUF_DATA_FORMAT_X, BUF_NUM_FORMAT_Y]
|
||||
parts = [p.strip().replace('BUF_DATA_FORMAT_', '').replace('BUF_NUM_FORMAT_', '') for p in s.split(',')]
|
||||
return BUF_FMT.get(f'BUF_FMT_{parts[0]}_{parts[1]}') if len(parts) == 2 else None
|
||||
MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA',
|
||||
131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'}
|
||||
|
||||
@@ -54,22 +76,28 @@ MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_T
|
||||
def _reg(p: str, b: int, n: int = 1) -> str: return f"{p}{b}" if n == 1 else f"{p}[{b}:{b+n-1}]"
|
||||
def _sreg(b: int, n: int = 1) -> str: return _reg("s", b, n)
|
||||
def _vreg(b: int, n: int = 1) -> str: return _reg("v", b, n)
|
||||
def _areg(b: int, n: int = 1) -> str: return _reg("a", b, n) # accumulator registers for GFX90a
|
||||
def _ttmp(b: int, n: int = 1) -> str: return _reg("ttmp", b - 108, n) if 108 <= b <= 123 else None
|
||||
def _sreg_or_ttmp(b: int, n: int = 1) -> str: return _ttmp(b, n) or _sreg(b, n)
|
||||
|
||||
def _fmt_sdst(v: int, n: int = 1) -> str:
|
||||
if v == 124: return "null"
|
||||
def _fmt_sdst(v: int, n: int = 1, cdna: bool = False) -> str:
|
||||
from extra.assembly.amd.dsl import SPECIAL_PAIRS_CDNA, SPECIAL_GPRS_CDNA
|
||||
if t := _ttmp(v, n): return t
|
||||
if n > 1: return SPECIAL_PAIRS.get(v) or _sreg(v, n)
|
||||
return SPECIAL_GPRS.get(v, f"s{v}")
|
||||
pairs = SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS
|
||||
gprs = SPECIAL_GPRS_CDNA if cdna else SPECIAL_GPRS
|
||||
if n > 1: return pairs.get(v) or gprs.get(v) or _sreg(v, n) # also check gprs for null/m0
|
||||
return gprs.get(v, f"s{v}")
|
||||
|
||||
def _fmt_src(v: int, n: int = 1) -> str:
|
||||
if n == 1: return decode_src(v)
|
||||
def _fmt_src(v: int, n: int = 1, cdna: bool = False) -> str:
|
||||
from extra.assembly.amd.dsl import SPECIAL_PAIRS_CDNA
|
||||
if n == 1: return decode_src(v, cdna)
|
||||
if v >= 256: return _vreg(v - 256, n)
|
||||
if v <= 105: return _sreg(v, n)
|
||||
if n == 2 and v in SPECIAL_PAIRS: return SPECIAL_PAIRS[v]
|
||||
if v <= 101: return _sreg(v, n) # s0-s101 can be pairs, but 102+ are special on CDNA
|
||||
pairs = SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS
|
||||
if n == 2 and v in pairs: return pairs[v]
|
||||
if v <= 105: return _sreg(v, n) # s102-s105 regular pairs for RDNA
|
||||
if t := _ttmp(v, n): return t
|
||||
return decode_src(v)
|
||||
return decode_src(v, cdna)
|
||||
|
||||
def _fmt_v16(v: int, base: int = 256, hi_thresh: int = 384) -> str:
|
||||
return f"v{(v - base) & 0x7f}.{'h' if v >= hi_thresh else 'l'}"
|
||||
@@ -106,46 +134,72 @@ def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _disasm_vop1(inst: VOP1) -> str:
|
||||
name = inst.op_name.lower()
|
||||
if inst.op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name
|
||||
if inst.op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}"
|
||||
# 16-bit dst: uses .h/.l suffix (determined by name pattern, not dtype - e.g. sat_pk_u8_i16 outputs 8-bit but uses 16-bit encoding)
|
||||
name, cdna = inst.op_name.lower() or f'vop1_op_{inst.op}', _is_cdna(inst)
|
||||
suf = "" if cdna else "_e32"
|
||||
if name in ('v_nop', 'v_pipeflush', 'v_clrexcp'): return name # no operands
|
||||
if 'readfirstlane' in name:
|
||||
src = f"v{inst.src0 - 256}" if inst.src0 >= 256 else decode_src(inst.src0, cdna)
|
||||
return f"{name} {_fmt_sdst(inst.vdst, 1, cdna)}, {src}"
|
||||
# 16-bit dst: uses .h/.l suffix for RDNA (CDNA uses plain vN)
|
||||
parts = name.split('_')
|
||||
is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)
|
||||
is_16d = not cdna and (any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
|
||||
src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
|
||||
return f"{name}_e32 {dst}, {src}"
|
||||
src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0), cdna) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if not cdna and inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
|
||||
return f"{name}{suf} {dst}, {src}"
|
||||
|
||||
_VOP2_CARRY_OUT = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'} # carry out only
|
||||
_VOP2_CARRY_INOUT = {'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'} # carry in and out
|
||||
def _disasm_vop2(inst: VOP2) -> str:
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
suf = "" if not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
|
||||
if cdna: name = _CDNA_DISASM_ALIASES.get(name, name) # apply CDNA aliases
|
||||
suf = "" if cdna or (not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16) else "_e32"
|
||||
lit = getattr(inst, '_literal', None)
|
||||
is16 = not cdna and inst.is_16bit()
|
||||
# fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1
|
||||
if 'fmaak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)):
|
||||
# fmaak/madak: dst = src0 * vsrc1 + K, fmamk/madmk: dst = src0 * K + vsrc1
|
||||
if 'fmaak' in name or 'madak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)):
|
||||
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}, 0x{lit:x}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{lit:x}"
|
||||
if 'fmamk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)):
|
||||
if 'fmamk' in name or 'madmk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)):
|
||||
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, 0x{lit:x}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{lit:x}, v{inst.vsrc1}"
|
||||
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
vcc = "vcc" if cdna else "vcc_lo"
|
||||
# CDNA carry ops output vcc after vdst
|
||||
if cdna and name in _VOP2_CARRY_OUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}"
|
||||
if cdna and name in _VOP2_CARRY_INOUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}, {vcc}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (f", {vcc}" if name == 'v_cndmask_b32' else "")
|
||||
|
||||
def _disasm_vopc(inst: VOPC) -> str:
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
if cdna:
|
||||
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0))
|
||||
return f"{name}_e32 {s0}, v{inst.vsrc1}" if inst.op.value >= 128 else f"{name}_e32 vcc, {s0}, v{inst.vsrc1}"
|
||||
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0), cdna)
|
||||
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else f"v{inst.vsrc1}"
|
||||
return f"{name} vcc, {s0}, {s1}" # CDNA VOPC always outputs vcc
|
||||
# RDNA: v_cmpx_* writes to exec (no vcc), v_cmp_* writes to vcc_lo
|
||||
has_vcc = 'cmpx' not in name
|
||||
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
|
||||
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}"
|
||||
return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
|
||||
return f"{name}_e32 vcc_lo, {s0}, {s1}" if has_vcc else f"{name}_e32 {s0}, {s1}"
|
||||
|
||||
NO_ARG_SOPP = {SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
|
||||
SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE, SOPPOp.S_TTRACEDATA}
|
||||
# CDNA uses name-based matching since opcode values differ from RDNA
|
||||
_CDNA_NO_ARG_SOPP = {'s_endpgm', 's_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata', 's_nop', 's_sethalt', 's_sleep',
|
||||
's_setprio', 's_trap', 's_incperflevel', 's_decperflevel', 's_sendmsg', 's_sendmsghalt'}
|
||||
|
||||
def _disasm_sopp(inst: SOPP) -> str:
|
||||
name = inst.op_name.lower()
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
if cdna:
|
||||
# CDNA: use name-based matching
|
||||
if name == 's_endpgm': return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
|
||||
if name in ('s_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata'): return name
|
||||
if name == 's_waitcnt':
|
||||
vm, lgkm, exp = inst.simm16 & 0xf, (inst.simm16 >> 8) & 0x3f, (inst.simm16 >> 4) & 0x7
|
||||
p = [f"vmcnt({vm})" if vm != 0xf else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
|
||||
return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
|
||||
if name.startswith(('s_cbranch', 's_branch')): return f"{name} {inst.simm16}"
|
||||
return f"{name} 0x{inst.simm16:x}" if inst.simm16 else name
|
||||
# RDNA
|
||||
if inst.op in NO_ARG_SOPP: return name
|
||||
if inst.op == SOPPOp.S_ENDPGM: return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
|
||||
if inst.op == SOPPOp.S_WAITCNT:
|
||||
@@ -161,64 +215,98 @@ def _disasm_sopp(inst: SOPP) -> str:
|
||||
return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}"
|
||||
|
||||
def _disasm_smem(inst: SMEM) -> str:
|
||||
name = inst.op_name.lower()
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
|
||||
off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset)
|
||||
sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op.value <= 12 or name == 's_atc_probe_buffer') else 2
|
||||
sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
|
||||
# GFX9 SMEM: soe and imm bits determine offset interpretation
|
||||
# soe=1, imm=1: soffset is SGPR, offset is immediate (both used)
|
||||
# soe=0, imm=1: offset is immediate
|
||||
# soe=0, imm=0: offset field is SGPR encoding (0-255)
|
||||
soe, imm = getattr(inst, 'soe', 0), getattr(inst, 'imm', 1)
|
||||
if cdna:
|
||||
if soe and imm:
|
||||
off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{inst.offset:x}" # SGPR + immediate
|
||||
elif imm:
|
||||
off_s = f"0x{inst.offset:x}" # Immediate offset only
|
||||
elif inst.offset < 256:
|
||||
off_s = decode_src(inst.offset, cdna) # SGPR encoding in offset field
|
||||
else:
|
||||
off_s = decode_src(inst.soffset, cdna)
|
||||
elif inst.offset and inst.soffset != 124:
|
||||
off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{inst.offset:x}"
|
||||
elif inst.offset:
|
||||
off_s = f"0x{inst.offset:x}"
|
||||
else:
|
||||
off_s = decode_src(inst.soffset, cdna)
|
||||
op_val = inst.op.value if hasattr(inst.op, 'value') else inst.op
|
||||
# s_buffer_* instructions use 4 SGPRs for sbase (buffer descriptor)
|
||||
is_buffer = 'buffer' in name or 's_atc_probe_buffer' == name
|
||||
sbase_idx, sbase_count = inst.sbase * 2, 4 if is_buffer else 2
|
||||
sbase_str = _fmt_src(sbase_idx, sbase_count, cdna) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
|
||||
if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
|
||||
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
|
||||
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs(), cdna)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (getattr(inst, 'dlc', 0), " dlc"))
|
||||
|
||||
def _disasm_flat(inst: FLAT) -> str:
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
|
||||
reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
|
||||
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
|
||||
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
|
||||
off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
|
||||
w = inst.dst_regs() * (2 if 'cmpswap' in name else 1)
|
||||
if cdna: mods = f"{f' offset:{off_val}' if off_val else ''}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if inst.sc1 else ''}"
|
||||
else: mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
|
||||
w = inst.dst_regs() * (2 if '_x2' in name else 1) * (2 if 'cmpswap' in name else 1)
|
||||
off_s = f" offset:{off_val}" if off_val else "" # Omit offset:0
|
||||
if cdna: mods = f"{off_s}{' glc' if inst.sc0 else ''}{' slc' if inst.nt else ''}" # GFX9: sc0->glc, nt->slc
|
||||
else: mods = f"{off_s}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
|
||||
# saddr
|
||||
if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
|
||||
elif inst.saddr == 124: saddr_s = ", off"
|
||||
elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr)}"
|
||||
elif inst.saddr in SPECIAL_PAIRS: saddr_s = f", {SPECIAL_PAIRS[inst.saddr]}"
|
||||
elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr, cdna)}"
|
||||
elif inst.saddr in (SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS): saddr_s = f", {(SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS)[inst.saddr]}"
|
||||
elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}"
|
||||
else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr)}"
|
||||
else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr, cdna)}"
|
||||
# addtid: no addr
|
||||
if 'addtid' in name: return f"{instr} v{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}"
|
||||
# addr width
|
||||
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2)
|
||||
data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w)
|
||||
if 'addtid' in name: return f"{instr} {'a' if acc else 'v'}{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}"
|
||||
# addr width: CDNA flat always uses 2 VGPRs (64-bit), scratch uses 1, RDNA uses 2 only when no saddr
|
||||
if cdna:
|
||||
addr_w = 1 if seg == 'scratch' else 2 # CDNA: flat/global always 64-bit addr
|
||||
else:
|
||||
addr_w = 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2
|
||||
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, addr_w)
|
||||
data_s, vdst_s = reg_fn(inst.data, w), reg_fn(inst.vdst, w // 2 if 'cmpswap' in name else w)
|
||||
glc_or_sc0 = inst.sc0 if cdna else inst.glc
|
||||
if 'atomic' in name:
|
||||
return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
|
||||
if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
|
||||
return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
|
||||
return f"{instr} {reg_fn(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
|
||||
|
||||
def _disasm_ds(inst: DS) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
|
||||
reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
|
||||
rp = 'a' if acc else 'v' # register prefix for single regs
|
||||
gds = " gds" if inst.gds else ""
|
||||
off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else ""
|
||||
off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else ""
|
||||
off2 = (" offset0:" + str(inst.offset0) if inst.offset0 else "") + (" offset1:" + str(inst.offset1) if inst.offset1 else "")
|
||||
w = inst.dst_regs()
|
||||
d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}"
|
||||
d0, d1, dst, addr = reg_fn(inst.data0, w), reg_fn(inst.data1, w), reg_fn(inst.vdst, w), f"v{inst.addr}"
|
||||
|
||||
if op == DSOp.DS_NOP: return name
|
||||
if op == DSOp.DS_BVH_STACK_RTN_B32: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}"
|
||||
if 'gws_sema' in name and op != DSOp.DS_GWS_SEMA_BR: return f"{name}{off}{gds}"
|
||||
if 'gws_' in name: return f"{name} {addr}{off}{gds}"
|
||||
if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} v{inst.vdst}{off}{gds}"
|
||||
if 'gs_reg' in name: return f"{name} {_vreg(inst.vdst, 2)}, v{inst.data0}{off}{gds}"
|
||||
if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} {rp}{inst.vdst}{off}{gds}"
|
||||
if 'gs_reg' in name: return f"{name} {reg_fn(inst.vdst, 2)}, {rp}{inst.data0}{off}{gds}"
|
||||
if '2addr' in name:
|
||||
if 'load' in name: return f"{name} {_vreg(inst.vdst, w*2)}, {addr}{off2}{gds}"
|
||||
if 'load' in name: return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}{off2}{gds}"
|
||||
if 'store' in name and 'xchg' not in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
|
||||
return f"{name} {_vreg(inst.vdst, w*2)}, {addr}, {d0}, {d1}{off2}{gds}"
|
||||
if 'load' in name: return f"{name} v{inst.vdst}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
|
||||
return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}, {d0}, {d1}{off2}{gds}"
|
||||
if 'write2' in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
|
||||
if 'read2' in name: return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}{off2}{gds}"
|
||||
if 'load' in name: return f"{name} {rp}{inst.vdst}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
|
||||
if 'store' in name and not _has(name, 'cmp', 'xchg'):
|
||||
return f"{name} v{inst.data0}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}"
|
||||
if 'swizzle' in name or op == DSOp.DS_ORDERED_COUNT: return f"{name} v{inst.vdst}, {addr}{off}{gds}"
|
||||
if 'permute' in name: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}{off}{gds}"
|
||||
if 'condxchg' in name: return f"{name} {_vreg(inst.vdst, 2)}, {addr}, {_vreg(inst.data0, 2)}{off}{gds}"
|
||||
return f"{name} {rp}{inst.data0}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}"
|
||||
if 'swizzle' in name or op == DSOp.DS_ORDERED_COUNT: return f"{name} {rp}{inst.vdst}, {addr}{off}{gds}"
|
||||
if 'permute' in name: return f"{name} {rp}{inst.vdst}, {addr}, {rp}{inst.data0}{off}{gds}"
|
||||
if 'condxchg' in name: return f"{name} {reg_fn(inst.vdst, 2)}, {addr}, {reg_fn(inst.data0, 2)}{off}{gds}"
|
||||
if _has(name, 'cmpstore', 'mskor', 'wrap'):
|
||||
return f"{name} {dst}, {addr}, {d0}, {d1}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}, {d1}{off}{gds}"
|
||||
return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}"
|
||||
@@ -318,6 +406,8 @@ def _disasm_vop3p(inst: VOP3P) -> str:
|
||||
|
||||
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
|
||||
reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
|
||||
if cdna and name in ('buffer_wbl2', 'buffer_inv'): return name
|
||||
if not cdna and inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
|
||||
w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \
|
||||
@@ -326,9 +416,27 @@ def _disasm_buf(inst: MUBUF | MTBUF) -> str:
|
||||
if hasattr(inst, 'tfe') and inst.tfe: w += 1
|
||||
vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off"
|
||||
srsrc = _sreg_or_ttmp(inst.srsrc*4, 4)
|
||||
if cdna: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"sc0"),(inst.nt,"nt"),(inst.sc1,"sc1")] if c]
|
||||
else: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
|
||||
return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}"
|
||||
is_mtbuf = isinstance(inst, MTBUF) or isinstance(inst, C_MTBUF)
|
||||
if is_mtbuf:
|
||||
dfmt, nfmt = inst.format & 0xf, (inst.format >> 4) & 0x7
|
||||
if acc: # GFX90a accumulator style: show dfmt/nfmt as numbers
|
||||
fmt_s = f" dfmt:{dfmt}, nfmt:{nfmt}," # double space before dfmt per LLVM format
|
||||
elif not cdna: # RDNA style: show combined format number
|
||||
fmt_s = f" format:{inst.format}" if inst.format else ""
|
||||
else: # CDNA: show format:[BUF_DATA_FORMAT_X] or format:[BUF_NUM_FORMAT_X]
|
||||
dfmt_names = ['INVALID', '8', '16', '8_8', '32', '16_16', '10_11_11', '11_11_10', '10_10_10_2', '2_10_10_10', '8_8_8_8', '32_32', '16_16_16_16', '32_32_32', '32_32_32_32', 'RESERVED_15']
|
||||
nfmt_names = ['UNORM', 'SNORM', 'USCALED', 'SSCALED', 'UINT', 'SINT', 'RESERVED_6', 'FLOAT']
|
||||
if dfmt == 1 and nfmt == 0: fmt_s = "" # default, no format shown
|
||||
elif nfmt == 0: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]}]" # only dfmt differs
|
||||
elif dfmt == 1: fmt_s = f" format:[BUF_NUM_FORMAT_{nfmt_names[nfmt]}]" # only nfmt differs
|
||||
else: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]},BUF_NUM_FORMAT_{nfmt_names[nfmt]}]" # both differ
|
||||
else:
|
||||
fmt_s = ""
|
||||
if cdna: mods = [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"glc"),(inst.nt,"slc"),(inst.sc1,"sc1")] if c]
|
||||
else: mods = [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
|
||||
soffset_s = decode_src(inst.soffset, cdna)
|
||||
if cdna and not acc and is_mtbuf: return f"{name} {reg_fn(inst.vdata, w)}, {vaddr}, {srsrc}, {soffset_s}{fmt_s}{' ' + ' '.join(mods) if mods else ''}"
|
||||
return f"{name} {reg_fn(inst.vdata, w)}, {vaddr}, {srsrc},{fmt_s} {soffset_s}{' ' + ' '.join(mods) if mods else ''}"
|
||||
|
||||
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
|
||||
"""Calculate vaddr register count for MIMG sample/gather operations."""
|
||||
@@ -377,21 +485,23 @@ def _disasm_mimg(inst: MIMG) -> str:
|
||||
return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}"
|
||||
|
||||
def _disasm_sop1(inst: SOP1) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))
|
||||
if not _is_cdna(inst):
|
||||
op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
|
||||
src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)
|
||||
if not cdna:
|
||||
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
|
||||
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {src}"
|
||||
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {src}"
|
||||
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {src}"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, {src}"
|
||||
|
||||
def _disasm_sop2(inst: SOP2) -> str:
|
||||
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}"
|
||||
cdna = _is_cdna(inst)
|
||||
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1), cdna)}"
|
||||
|
||||
def _disasm_sopc(inst: SOPC) -> str:
|
||||
s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))
|
||||
s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))
|
||||
cdna = _is_cdna(inst)
|
||||
s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)
|
||||
s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1), cdna)
|
||||
return f"{inst.op_name.lower()} {s0}, {s1}"
|
||||
|
||||
def _disasm_sopk(inst: SOPK) -> str:
|
||||
@@ -405,10 +515,10 @@ def _disasm_sopk(inst: SOPK) -> str:
|
||||
if (not cdna and op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32)) or (cdna and name in ('s_setreg_b32', 's_getreg_b32')):
|
||||
hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
|
||||
hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"
|
||||
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
|
||||
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1, cdna)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1, cdna)}, {hs}"
|
||||
if not cdna and op in (SOPKOp.S_SUBVECTOR_LOOP_BEGIN, SOPKOp.S_SUBVECTOR_LOOP_END):
|
||||
return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{inst.simm16:x}"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, 0x{inst.simm16:x}"
|
||||
|
||||
def _disasm_vinterp(inst: VINTERP) -> str:
|
||||
mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
|
||||
@@ -464,11 +574,54 @@ def _parse_ops(s: str) -> list[str]:
|
||||
return ops
|
||||
|
||||
def _extract(text: str, pat: str, flags=re.I):
|
||||
if m := re.search(pat, text, flags): return m, text[:m.start()] + text[m.end():]
|
||||
if m := re.search(pat, text, flags): return m, text[:m.start()] + ' ' + text[m.end():]
|
||||
return None, text
|
||||
|
||||
# Instruction aliases: LLVM uses different names for some instructions
|
||||
_ALIASES = {
|
||||
'v_cmp_tru_f16': 'v_cmp_t_f16', 'v_cmp_tru_f32': 'v_cmp_t_f32', 'v_cmp_tru_f64': 'v_cmp_t_f64',
|
||||
'v_cmpx_tru_f16': 'v_cmpx_t_f16', 'v_cmpx_tru_f32': 'v_cmpx_t_f32', 'v_cmpx_tru_f64': 'v_cmpx_t_f64',
|
||||
'v_cvt_flr_i32_f32': 'v_cvt_floor_i32_f32', 'v_cvt_rpi_i32_f32': 'v_cvt_nearest_i32_f32',
|
||||
'v_ffbh_i32': 'v_cls_i32', 'v_ffbh_u32': 'v_clz_i32_u32', 'v_ffbl_b32': 'v_ctz_i32_b32',
|
||||
'v_cvt_pkrtz_f16_f32': 'v_cvt_pk_rtz_f16_f32', 'v_fmac_legacy_f32': 'v_fmac_dx9_zero_f32', 'v_mul_legacy_f32': 'v_mul_dx9_zero_f32',
|
||||
# SMEM aliases (dword -> b32, dwordx2 -> b64, etc.)
|
||||
's_load_dword': 's_load_b32', 's_load_dwordx2': 's_load_b64', 's_load_dwordx4': 's_load_b128',
|
||||
's_load_dwordx8': 's_load_b256', 's_load_dwordx16': 's_load_b512',
|
||||
's_buffer_load_dword': 's_buffer_load_b32', 's_buffer_load_dwordx2': 's_buffer_load_b64',
|
||||
's_buffer_load_dwordx4': 's_buffer_load_b128', 's_buffer_load_dwordx8': 's_buffer_load_b256',
|
||||
's_buffer_load_dwordx16': 's_buffer_load_b512',
|
||||
# VOP3 aliases
|
||||
'v_cvt_pknorm_i16_f16': 'v_cvt_pk_norm_i16_f16', 'v_cvt_pknorm_u16_f16': 'v_cvt_pk_norm_u16_f16',
|
||||
'v_add3_nc_u32': 'v_add3_u32', 'v_xor_add_u32': 'v_xad_u32',
|
||||
# VINTERP aliases
|
||||
'v_interp_p2_new_f32': 'v_interp_p2_f32',
|
||||
# SOP1 aliases
|
||||
's_ff1_i32_b32': 's_ctz_i32_b32', 's_ff1_i32_b64': 's_ctz_i32_b64',
|
||||
's_flbit_i32_b32': 's_clz_i32_u32', 's_flbit_i32_b64': 's_clz_i32_u64', 's_flbit_i32': 's_cls_i32', 's_flbit_i32_i64': 's_cls_i32_i64',
|
||||
's_andn1_saveexec_b32': 's_and_not0_saveexec_b32', 's_andn1_saveexec_b64': 's_and_not0_saveexec_b64',
|
||||
's_andn1_wrexec_b32': 's_and_not0_wrexec_b32', 's_andn1_wrexec_b64': 's_and_not0_wrexec_b64',
|
||||
's_andn2_saveexec_b32': 's_and_not1_saveexec_b32', 's_andn2_saveexec_b64': 's_and_not1_saveexec_b64',
|
||||
's_andn2_wrexec_b32': 's_and_not1_wrexec_b32', 's_andn2_wrexec_b64': 's_and_not1_wrexec_b64',
|
||||
's_orn1_saveexec_b32': 's_or_not0_saveexec_b32', 's_orn1_saveexec_b64': 's_or_not0_saveexec_b64',
|
||||
's_orn2_saveexec_b32': 's_or_not1_saveexec_b32', 's_orn2_saveexec_b64': 's_or_not1_saveexec_b64',
|
||||
# SOP2 aliases
|
||||
's_andn2_b32': 's_and_not1_b32', 's_andn2_b64': 's_and_not1_b64',
|
||||
's_orn2_b32': 's_or_not1_b32', 's_orn2_b64': 's_or_not1_b64',
|
||||
# VOP2 aliases
|
||||
'v_dot2c_f32_f16': 'v_dot2acc_f32_f16',
|
||||
# More VOP3 aliases
|
||||
'v_fma_legacy_f32': 'v_fma_dx9_zero_f32',
|
||||
}
|
||||
|
||||
def _apply_alias(text: str) -> str:
|
||||
mn = text.split()[0].lower() if ' ' in text else text.lower().rstrip('_')
|
||||
# Try exact match first, then strip _e32/_e64 suffix
|
||||
for m in (mn, mn.removesuffix('_e32'), mn.removesuffix('_e64')):
|
||||
if m in _ALIASES: return _ALIASES[m] + text[len(m):]
|
||||
return text
|
||||
|
||||
def get_dsl(text: str) -> str:
|
||||
text, kw = text.strip(), []
|
||||
text, kw = _apply_alias(text.strip()), []
|
||||
# Extract modifiers
|
||||
for pat, val in [(r'\s+mul:2(?:\s|$)', 1), (r'\s+mul:4(?:\s|$)', 2), (r'\s+div:2(?:\s|$)', 3)]:
|
||||
if (m := _extract(text, pat))[0]: kw.append(f'omod={val}'); text = m[1]; break
|
||||
@@ -484,6 +637,11 @@ def get_dsl(text: str) -> str:
|
||||
m, text = _extract(text, r'\s+dlc(?:\s|$)'); dlc = 1 if m else None
|
||||
m, text = _extract(text, r'\s+glc(?:\s|$)'); glc = 1 if m else None
|
||||
m, text = _extract(text, r'\s+slc(?:\s|$)'); slc = 1 if m else None
|
||||
m, text = _extract(text, r'\s+tfe(?:\s|$)'); tfe = 1 if m else None
|
||||
m, text = _extract(text, r'\s+offen(?:\s|$)'); offen = 1 if m else None
|
||||
m, text = _extract(text, r'\s+idxen(?:\s|$)'); idxen = 1 if m else None
|
||||
m, text = _extract(text, r'\s+format:\[([^\]]+)\]'); fmt_val = m.group(1) if m else None
|
||||
m, text = _extract(text, r'\s+format:(\d+)'); fmt_val = m.group(1) if m and not fmt_val else fmt_val
|
||||
m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
|
||||
m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
|
||||
if waitexp: kw.append(f'waitexp={waitexp}')
|
||||
@@ -530,9 +688,30 @@ def get_dsl(text: str) -> str:
|
||||
if off_val and len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}{gs}{ds})"
|
||||
if len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, soffset={args[2]}{gs}{ds})"
|
||||
|
||||
# Buffer
|
||||
if mn.startswith('buffer_') and len(ops) >= 2 and ops[1].strip().lower() == 'off':
|
||||
return f"{mn}(vdata={args[0]}, vaddr=0, srsrc={args[2]}, soffset={f'RawImm({args[3].strip()})' if len(args) > 3 else 'RawImm(0)'})"
|
||||
# Buffer (MUBUF/MTBUF) instructions
|
||||
if mn.startswith(('buffer_', 'tbuffer_')):
|
||||
is_tbuf = mn.startswith('tbuffer_')
|
||||
# Parse format value for tbuffer
|
||||
fmt_num = None
|
||||
if fmt_val is not None:
|
||||
if fmt_val.isdigit(): fmt_num = int(fmt_val)
|
||||
else: fmt_num = BUF_FMT.get(fmt_val.replace(' ', '')) or _parse_buf_fmt_combo(fmt_val)
|
||||
# Handle special no-arg buffer ops
|
||||
if mn in ('buffer_gl0_inv', 'buffer_gl1_inv', 'buffer_wbl2', 'buffer_inv'): return f"{mn}()"
|
||||
# Build modifiers string
|
||||
buf_mods = "".join([f", offset={off_val}" if off_val else "", ", glc=1" if glc else "", ", dlc=1" if dlc else "",
|
||||
", slc=1" if slc else "", ", tfe=1" if tfe else "", ", offen=1" if offen else "", ", idxen=1" if idxen else ""])
|
||||
if is_tbuf and fmt_num is not None: buf_mods = f", format={fmt_num}" + buf_mods
|
||||
# Determine vaddr value (v[0] for 'off', actual register otherwise)
|
||||
vaddr_idx = 1
|
||||
if len(ops) > vaddr_idx and ops[vaddr_idx].strip().lower() == 'off': vaddr_val = "v[0]"
|
||||
else: vaddr_val = args[vaddr_idx] if len(args) > vaddr_idx else "v[0]"
|
||||
# srsrc and soffset indices depend on whether vaddr is 'off'
|
||||
srsrc_idx, soff_idx = (2, 3) if len(ops) > 1 else (1, 2)
|
||||
srsrc_val = args[srsrc_idx] if len(args) > srsrc_idx else "s[0:3]"
|
||||
soff_val = args[soff_idx] if len(args) > soff_idx else "0"
|
||||
# soffset: integers are inline constants, don't wrap in RawImm
|
||||
return f"{mn}(vdata={args[0]}, vaddr={vaddr_val}, srsrc={srsrc_val}, soffset={soff_val}{buf_mods})"
|
||||
|
||||
# FLAT/GLOBAL/SCRATCH load/store/atomic - saddr needs RawImm(124) for off/null
|
||||
def _saddr(a): return 'RawImm(124)' if a in ('OFF', 'NULL') else a
|
||||
@@ -582,6 +761,15 @@ def get_dsl(text: str) -> str:
|
||||
if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '')
|
||||
if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:]
|
||||
if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2: args = ['RawImm(126)'] + args
|
||||
# v_cmp_*_e64 has SGPR destination in vdst field - encode as RawImm
|
||||
_SGPR_NAMES = {'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'null': 124, 'm0': 125, 'exec_lo': 126, 'exec_hi': 127}
|
||||
if mn.startswith('v_cmp') and 'cmpx' not in mn and mn.endswith('_e64') and len(args) >= 1:
|
||||
dst = ops[0].strip().lower()
|
||||
if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})'
|
||||
elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})'
|
||||
elif dst.startswith('ttmp') and dst[4:].isdigit(): args[0] = f'RawImm({108 + int(dst[4:])})'
|
||||
elif dst.startswith('ttmp[') and ':' in dst: args[0] = f'RawImm({108 + int(dst[5:].split(":")[0])})'
|
||||
elif dst in _SGPR_NAMES: args[0] = f'RawImm({_SGPR_NAMES[dst]})'
|
||||
|
||||
fn = mn.replace('.', '_')
|
||||
if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args]
|
||||
@@ -629,31 +817,76 @@ def asm(text: str) -> Inst:
|
||||
try:
|
||||
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP2 as CDNA_VOP2, VOPC as CDNA_VOPC, VOP3A, VOP3B, VOP3P as CDNA_VOP3P,
|
||||
SOP1 as CDNA_SOP1, SOP2 as CDNA_SOP2, SOPC as CDNA_SOPC, SOPK as CDNA_SOPK, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS,
|
||||
FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op)
|
||||
FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op, VOPCOp as CDNA_VOPCOp)
|
||||
|
||||
def _cdna_src(inst, v, neg, abs_=0, n=1):
|
||||
s = inst.lit(v) if v == 255 else _fmt_src(v, n)
|
||||
s = inst.lit(v) if v == 255 else _fmt_src(v, n, cdna=True)
|
||||
if abs_: s = f"|{s}|"
|
||||
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
|
||||
|
||||
def _disasm_vop3a(inst) -> str:
|
||||
name, n, cl, om = inst.op_name.lower(), inst.num_srcs(), " clamp" if inst.clmp else "", _omod(inst.omod)
|
||||
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.src_regs(0)), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.src_regs(1)), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.src_regs(2))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
|
||||
if inst.op.value < 256: return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
|
||||
suf = "_e64" if inst.op.value < 512 else ""
|
||||
return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else (f"{name}{suf}" if name == 'v_nop' else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}" if n == 2 else f"{name}{suf} {dst}, {s0}{cl}{om}")
|
||||
# CDNA VOP2 aliases: new opcode name -> old name expected by LLVM tests
|
||||
_CDNA_VOP3_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32'}
|
||||
|
||||
def _disasm_vop3b(inst) -> str:
|
||||
name, n = inst.op_name.lower(), inst.num_srcs()
|
||||
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1), _cdna_src(inst, inst.src1, inst.neg&2), _cdna_src(inst, inst.src2, inst.neg&4)
|
||||
dst, suf = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}", "_e64" if 'co_' in name else ""
|
||||
def _disasm_vop3a(inst) -> str:
|
||||
op_val = inst._values.get('op', 0) # get raw opcode value, not enum value
|
||||
if hasattr(op_val, 'value'): op_val = op_val.value # in case it's stored as enum
|
||||
name = inst.op_name.lower() or f'vop3a_op_{op_val}'
|
||||
from extra.assembly.amd.dsl import spec_num_srcs, spec_regs
|
||||
n = spec_num_srcs(name) if name else inst.num_srcs()
|
||||
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
|
||||
return f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{cl}{om}"
|
||||
orig_name = name
|
||||
name = _CDNA_VOP3_ALIASES.get(name, name) # apply CDNA aliases
|
||||
# For aliased ops, recalculate sources without 64-bit assumption
|
||||
if name != orig_name:
|
||||
s0, s1 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, 1), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, 1)
|
||||
s2 = ""
|
||||
dst = f"v{inst.vdst}"
|
||||
else:
|
||||
dregs, r0, r1, r2 = spec_regs(name) if name else (inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2))
|
||||
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, r0), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, r1), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, r2)
|
||||
dst = _vreg(inst.vdst, dregs) if dregs > 1 else f"v{inst.vdst}"
|
||||
# True VOP3 instructions (512+) - 3-source ops
|
||||
if op_val >= 512:
|
||||
return f"{name} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{cl}{om}"
|
||||
# VOPC (0-255): writes to SGPR pair, VOP2 (256-319): 2-3 src, VOP1 (320-511): 1 src
|
||||
if op_val < 256:
|
||||
sdst = _fmt_sdst(inst.vdst, 2, cdna=True) # VOPC writes to 64-bit SGPR pair
|
||||
# v_cmpx_ also writes to sdst in CDNA VOP3 (unlike VOP32 where it writes to exec)
|
||||
return f"{name}_e64 {sdst}, {s0}, {s1}{cl}"
|
||||
if 320 <= op_val < 512: # VOP1 promoted
|
||||
if name in ('v_nop', 'v_clrexcp'): return f"{name}_e64"
|
||||
return f"{name}_e64 {dst}, {s0}{cl}{om}"
|
||||
# VOP2 promoted (256-319)
|
||||
if name == 'v_cndmask_b32':
|
||||
s2 = _fmt_src(inst.src2, 2, cdna=True) # src2 is 64-bit SGPR pair
|
||||
return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{cl}{om}"
|
||||
if name in ('v_mul_legacy_f32', 'v_mac_f32'):
|
||||
return f"{name}_e64 {dst}, {s0}, {s1}{cl}{om}"
|
||||
suf = "_e64" if op_val < 512 else ""
|
||||
return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}"
|
||||
|
||||
# GFX9-specific VOP3B opcodes not in CDNA enum
|
||||
def _disasm_vop3b(inst) -> str:
|
||||
op_val = inst._values.get('op', 0)
|
||||
if hasattr(op_val, 'value'): op_val = op_val.value
|
||||
name = inst.op_name.lower() or f'vop3b_op_{op_val}'
|
||||
from extra.assembly.amd.dsl import spec_num_srcs, spec_regs
|
||||
n = spec_num_srcs(name) if name else inst.num_srcs()
|
||||
dregs, r0, r1, r2 = spec_regs(name) if name else (inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2))
|
||||
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, n=r0), _cdna_src(inst, inst.src1, inst.neg&2, n=r1), _cdna_src(inst, inst.src2, inst.neg&4, n=r2)
|
||||
dst = _vreg(inst.vdst, dregs) if dregs > 1 else f"v{inst.vdst}"
|
||||
sdst = _fmt_sdst(inst.sdst, 2, cdna=True) # VOP3B sdst is always 64-bit SGPR pair
|
||||
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
|
||||
# Carry ops need special handling
|
||||
if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
|
||||
s2 = _fmt_src(inst.src2, 2, cdna=True) # src2 is carry-in (64-bit SGPR pair)
|
||||
return f"{name}_e64 {dst}, {sdst}, {s0}, {s1}, {s2}{cl}{om}"
|
||||
suf = "_e64" if 'co_' in name else ""
|
||||
return f"{name}{suf} {dst}, {sdst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {sdst}, {s0}, {s1}{cl}{om}"
|
||||
|
||||
def _disasm_cdna_vop3p(inst) -> str:
|
||||
name, n, is_mfma = inst.op_name.lower(), inst.num_srcs(), 'mfma' in inst.op_name.lower() or 'smfmac' in inst.op_name.lower()
|
||||
get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc)
|
||||
get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc, cdna=True)
|
||||
if is_mfma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name or 'i4' in name else 8 if 'f16' in name or 'bf16' in name else 4; src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16), _vreg(inst.vdst, 16)
|
||||
else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}"
|
||||
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
|
||||
@@ -665,20 +898,93 @@ try:
|
||||
_UNUSED = {0: 'UNUSED_PAD', 1: 'UNUSED_SEXT', 2: 'UNUSED_PRESERVE'}
|
||||
_DPP = {0x130: "wave_shl:1", 0x134: "wave_rol:1", 0x138: "wave_shr:1", 0x13c: "wave_ror:1", 0x140: "row_mirror", 0x141: "row_half_mirror", 0x142: "row_bcast:15", 0x143: "row_bcast:31"}
|
||||
|
||||
def _sdwa_src0(v, is_sgpr, sext=0, neg=0, abs_=0):
|
||||
# s0=0: VGPR (v is VGPR number), s0=1: SGPR/constant (v is encoded like normal src)
|
||||
s = decode_src(v, cdna=True) if is_sgpr else f"v{v}"
|
||||
if sext: s = f"sext({s})"
|
||||
if abs_: s = f"|{s}|"
|
||||
return f"-{s}" if neg else s
|
||||
|
||||
def _sdwa_vsrc1(v, sext=0, neg=0, abs_=0):
|
||||
# For VOP2 SDWA, vsrc1 is in vop_op field as raw VGPR number
|
||||
s = f"v{v}"
|
||||
if sext: s = f"sext({s})"
|
||||
if abs_: s = f"|{s}|"
|
||||
return f"-{s}" if neg else s
|
||||
|
||||
_OMOD_SDWA = {0: "", 1: " mul:2", 2: " mul:4", 3: " div:2"}
|
||||
|
||||
def _disasm_sdwa(inst) -> str:
  """Disassemble an SDWA-encoded CDNA instruction into assembly text.

  The ``vop2_op`` field selects the wrapped encoding: 63 is VOP1, 62 is VOPC, and any other
  value is treated as the VOP2 opcode itself. The stale VOP1-only body that used to return
  before this dispatch has been removed so all three forms are reachable.

  Args:
    inst: decoded SDWA instruction; the fields read here are ``vop2_op``, ``vop_op``, ``vdst``,
      ``src0``, ``s0``, ``clmp``, ``omod``, the ``dst_sel``/``dst_u``/``src0_sel``/``src1_sel``
      selectors and the ``src0_*``/``src1_*`` sext/neg/abs modifier bits.

  Returns:
    The instruction text with an ``_sdwa`` suffix followed by its selector modifiers.
  """
  # SDWA format: vop2_op=63 -> VOP1, vop2_op=62 -> VOPC, vop2_op=0-61 -> VOP2
  vop2_op = inst.vop2_op
  src0 = _sdwa_src0(inst.src0, inst.s0, inst.src0_sext, inst.src0_neg, inst.src0_abs)
  clamp = " clamp" if inst.clmp else ""
  omod = _OMOD_SDWA.get(inst.omod, "")

  if vop2_op == 63:  # VOP1
    try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
    except ValueError: name = f"vop1_op_{inst.vop_op}"  # unknown opcode: print a numeric placeholder
    dst = f"v{inst.vdst}"
    mods = [f"dst_sel:{_SEL[inst.dst_sel]}", f"dst_unused:{_UNUSED[inst.dst_u]}", f"src0_sel:{_SEL[inst.src0_sel]}"]
    return f"{name}_sdwa {dst}, {src0}{clamp}{omod} " + " ".join(mods)

  if vop2_op == 62:  # VOPC
    try: name = CDNA_VOPCOp(inst.vdst).name.lower()  # opcode is in vdst field for VOPC SDWA
    except ValueError: name = f"vopc_op_{inst.vdst}"
    src1 = _sdwa_vsrc1(inst.vop_op, inst.src1_sext, inst.src1_neg, inst.src1_abs)  # vsrc1 is in vop_op field
    # VOPC SDWA: dst encoded in byte 5 (bits 47:40): 0=vcc, 128+n=s[n:n+1]
    # That byte is reassembled from the fields that overlay it in the generic SDWA layout.
    sdst_enc = inst.dst_sel | (inst.dst_u << 3) | (inst.clmp << 5) | (inst.omod << 6)
    if sdst_enc == 0:
      sdst = "vcc"
    else:
      sdst_val = sdst_enc - 128 if sdst_enc >= 128 else sdst_enc
      sdst = _fmt_sdst(sdst_val, 2, cdna=True)
    mods = [f"src0_sel:{_SEL[inst.src0_sel]}", f"src1_sel:{_SEL[inst.src1_sel]}"]
    return f"{name}_sdwa {sdst}, {src0}, {src1} " + " ".join(mods)

  # VOP2
  try: name = CDNA_VOP2Op(vop2_op).name.lower()
  except ValueError: name = f"vop2_op_{vop2_op}"
  name = _CDNA_DISASM_ALIASES.get(name, name)  # apply aliases (v_fmac -> v_mac, etc.)
  dst = f"v{inst.vdst}"
  src1 = _sdwa_vsrc1(inst.vop_op, inst.src1_sext, inst.src1_neg, inst.src1_abs)  # vsrc1 is in vop_op field
  mods = [f"dst_sel:{_SEL[inst.dst_sel]}", f"dst_unused:{_UNUSED[inst.dst_u]}", f"src0_sel:{_SEL[inst.src0_sel]}", f"src1_sel:{_SEL[inst.src1_sel]}"]
  # v_cndmask_b32 needs vcc as third operand
  if name == 'v_cndmask_b32':
    return f"{name}_sdwa {dst}, {src0}, {src1}, vcc{clamp}{omod} " + " ".join(mods)
  # Carry ops need vcc - v_addc/subb also need vcc as carry-in
  if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
    return f"{name}_sdwa {dst}, vcc, {src0}, {src1}, vcc{clamp}{omod} " + " ".join(mods)
  if '_co_' in name:
    return f"{name}_sdwa {dst}, vcc, {src0}, {src1}{clamp}{omod} " + " ".join(mods)
  return f"{name}_sdwa {dst}, {src0}, {src1}{clamp}{omod} " + " ".join(mods)
|
||||
|
||||
def _dpp_src(v, neg=0, abs_=0):
|
||||
s = f"v{v}" if v < 256 else f"v{v - 256}"
|
||||
if abs_: s = f"|{s}|"
|
||||
return f"-{s}" if neg else s
|
||||
|
||||
def _disasm_dpp(inst) -> str:
  """Disassemble a DPP-encoded CDNA instruction into assembly text.

  The ``vop2_op`` field selects the wrapped encoding: 63 is VOP1, anything else is treated as
  the VOP2 opcode itself. The stale VOP1-only body that used to return before this dispatch
  has been removed so the VOP2 path is reachable.

  Args:
    inst: decoded DPP instruction; the fields read here are ``vop2_op``, ``vop_op``, ``vdst``,
      ``src0``, ``dpp_ctrl``, ``row_mask``, ``bank_mask``, ``bound_ctrl`` and the
      ``src0_*``/``src1_*`` neg/abs modifier bits.

  Returns:
    The instruction text with a ``_dpp`` suffix followed by the DPP control and mask modifiers.
  """
  # DPP format: vop2_op=63 -> VOP1, vop2_op=0-62 -> VOP2
  vop2_op = inst.vop2_op
  ctrl = inst.dpp_ctrl
  # Decode dpp_ctrl by range; values from 0x130 up are looked up in _DPP, falling back to raw hex.
  if ctrl < 0x100: dpp = f"quad_perm:[{ctrl&3},{(ctrl>>2)&3},{(ctrl>>4)&3},{(ctrl>>6)&3}]"
  elif ctrl < 0x110: dpp = f"row_shl:{ctrl&0xf}"
  elif ctrl < 0x120: dpp = f"row_shr:{ctrl&0xf}"
  elif ctrl < 0x130: dpp = f"row_ror:{ctrl&0xf}"
  else: dpp = _DPP.get(ctrl, f"dpp_ctrl:0x{ctrl:x}")
  src0 = _dpp_src(inst.src0, inst.src0_neg, inst.src0_abs)
  # DPP modifiers: row_mask and bank_mask always shown, bound_ctrl:0 when bit=1
  mods = [dpp, f"row_mask:0x{inst.row_mask:x}", f"bank_mask:0x{inst.bank_mask:x}"] + (["bound_ctrl:0"] if inst.bound_ctrl else [])

  if vop2_op == 63:  # VOP1
    try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
    except ValueError: name = f"vop1_op_{inst.vop_op}"  # unknown opcode: print a numeric placeholder
    return f"{name}_dpp v{inst.vdst}, {src0} " + " ".join(mods)

  # VOP2
  try: name = CDNA_VOP2Op(vop2_op).name.lower()
  except ValueError: name = f"vop2_op_{vop2_op}"
  name = _CDNA_DISASM_ALIASES.get(name, name)
  src1 = _dpp_src(inst.vop_op, inst.src1_neg, inst.src1_abs)  # vsrc1 is in vop_op field
  # v_cndmask_b32 takes vcc as an explicit third source operand.
  if name == 'v_cndmask_b32':
    return f"{name}_dpp v{inst.vdst}, {src0}, {src1}, vcc " + " ".join(mods)
  # These carry ops print vcc both as carry-out destination and as carry-in source.
  if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
    return f"{name}_dpp v{inst.vdst}, vcc, {src0}, {src1}, vcc " + " ".join(mods)
  if '_co_' in name:
    return f"{name}_dpp v{inst.vdst}, vcc, {src0}, {src1} " + " ".join(mods)
  return f"{name}_dpp v{inst.vdst}, {src0}, {src1} " + " ".join(mods)
|
||||
|
||||
# Register CDNA handlers - shared formats use merged disassemblers, CDNA-only formats use dedicated ones
|
||||
DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP2: _disasm_vop2, CDNA_VOPC: _disasm_vopc,
|
||||
|
||||
@@ -1,46 +1,6 @@
|
||||
# autogenerated from AMD CDNA3+CDNA4 ISA PDF by pdf.py - do not edit
|
||||
# autogenerated from AMD ISA PDF by pdf.py - do not edit
|
||||
from enum import IntEnum
|
||||
|
||||
class SrcEnum(IntEnum):
|
||||
S_ADD_U32 = 0
|
||||
S_SUB_U32 = 1
|
||||
S_ADD_I32 = 2
|
||||
S_SUB_I32 = 3
|
||||
S_ADDC_U32 = 4
|
||||
S_SUBB_U32 = 5
|
||||
S_MIN_I32 = 6
|
||||
FLAT_SCRATCH_LO = 102
|
||||
FLAT_SCRATCH_HI = 103
|
||||
XNACK_MASK_LO = 104
|
||||
XNACK_MASK_HI = 105
|
||||
VCC_LO = 106
|
||||
VCC_HI = 107
|
||||
M0 = 124
|
||||
EXEC_LO = 126
|
||||
EXEC_HI = 127
|
||||
ZERO = 128
|
||||
DPP8 = 233
|
||||
DPP8FI = 234
|
||||
SHARED_BASE = 235
|
||||
SHARED_LIMIT = 236
|
||||
PRIVATE_BASE = 237
|
||||
PRIVATE_LIMIT = 238
|
||||
RESERVED = 239
|
||||
POS_HALF = 240
|
||||
NEG_HALF = 241
|
||||
POS_ONE = 242
|
||||
NEG_ONE = 243
|
||||
POS_TWO = 244
|
||||
NEG_TWO = 245
|
||||
POS_FOUR = 246
|
||||
NEG_FOUR = 247
|
||||
INV_2PI = 248
|
||||
DPP16 = 250
|
||||
VCCZ = 251
|
||||
EXECZ = 252
|
||||
SCC = 253
|
||||
LDS_DIRECT = 254
|
||||
|
||||
class DSOp(IntEnum):
|
||||
DS_ADD_U32 = 0
|
||||
DS_SUB_U32 = 1
|
||||
@@ -155,12 +115,6 @@ class DSOp(IntEnum):
|
||||
DS_READ2ST64_B64 = 120
|
||||
DS_ADD_RTN_F64 = 124
|
||||
DS_CONDXCHG32_RTN_B64 = 126
|
||||
DS_GWS_SEMA_RELEASE_ALL = 152
|
||||
DS_GWS_INIT = 153
|
||||
DS_GWS_SEMA_V = 154
|
||||
DS_GWS_SEMA_BR = 155
|
||||
DS_GWS_SEMA_P = 156
|
||||
DS_GWS_BARRIER = 157
|
||||
DS_READ_ADDTID_B32 = 182
|
||||
DS_PK_ADD_RTN_F16 = 183
|
||||
DS_PK_ADD_RTN_BF16 = 184
|
||||
@@ -174,7 +128,6 @@ class DSOp(IntEnum):
|
||||
DS_READ_B64_TR_B16 = 227
|
||||
DS_READ_B96 = 254
|
||||
DS_READ_B128 = 255
|
||||
CDNA4 = 600
|
||||
|
||||
class FLATOp(IntEnum):
|
||||
FLAT_LOAD_UBYTE = 16
|
||||
@@ -231,7 +184,6 @@ class FLATOp(IntEnum):
|
||||
FLAT_ATOMIC_XOR_X2 = 106
|
||||
FLAT_ATOMIC_INC_X2 = 107
|
||||
FLAT_ATOMIC_DEC_X2 = 108
|
||||
CDNA4 = 600
|
||||
|
||||
class GLOBALOp(IntEnum):
|
||||
GLOBAL_LOAD_UBYTE = 16
|
||||
@@ -295,7 +247,6 @@ class GLOBALOp(IntEnum):
|
||||
GLOBAL_ATOMIC_DEC_X2 = 108
|
||||
GLOBAL_LOAD_LDS_DWORDX4 = 125
|
||||
GLOBAL_LOAD_LDS_DWORDX3 = 126
|
||||
CDNA4 = 600
|
||||
|
||||
class MTBUFOp(IntEnum):
|
||||
TBUFFER_LOAD_FORMAT_X = 0
|
||||
@@ -390,7 +341,6 @@ class MUBUFOp(IntEnum):
|
||||
BUFFER_ATOMIC_XOR_X2 = 106
|
||||
BUFFER_ATOMIC_INC_X2 = 107
|
||||
BUFFER_ATOMIC_DEC_X2 = 108
|
||||
CDNA4 = 600
|
||||
|
||||
class SCRATCHOp(IntEnum):
|
||||
SCRATCH_LOAD_UBYTE = 16
|
||||
@@ -504,7 +454,6 @@ class SMEMOp(IntEnum):
|
||||
S_ATOMIC_XOR_X2 = 170
|
||||
S_ATOMIC_INC_X2 = 171
|
||||
S_ATOMIC_DEC_X2 = 172
|
||||
CDNA4 = 600
|
||||
|
||||
class SOP1Op(IntEnum):
|
||||
S_MOV_B32 = 0
|
||||
@@ -561,7 +510,6 @@ class SOP1Op(IntEnum):
|
||||
S_ANDN1_WREXEC_B64 = 53
|
||||
S_ANDN2_WREXEC_B64 = 54
|
||||
S_BITREPLICATE_B64_B32 = 55
|
||||
CDNA4 = 600
|
||||
|
||||
class SOP2Op(IntEnum):
|
||||
S_ADD_U32 = 0
|
||||
@@ -616,7 +564,6 @@ class SOP2Op(IntEnum):
|
||||
S_PACK_LL_B32_B16 = 50
|
||||
S_PACK_LH_B32_B16 = 51
|
||||
S_PACK_HH_B32_B16 = 52
|
||||
CDNA4 = 600
|
||||
|
||||
class SOPCOp(IntEnum):
|
||||
S_CMP_EQ_I32 = 0
|
||||
@@ -639,7 +586,6 @@ class SOPCOp(IntEnum):
|
||||
S_SET_GPR_IDX_ON = 17
|
||||
S_CMP_EQ_U64 = 18
|
||||
S_CMP_LG_U64 = 19
|
||||
CDNA4 = 600
|
||||
|
||||
class SOPKOp(IntEnum):
|
||||
S_MOVK_I32 = 0
|
||||
@@ -695,7 +641,6 @@ class SOPPOp(IntEnum):
|
||||
S_ENDPGM_SAVED = 27
|
||||
S_SET_GPR_IDX_OFF = 28
|
||||
S_SET_GPR_IDX_MODE = 29
|
||||
CDNA4 = 600
|
||||
|
||||
class VOP1Op(IntEnum):
|
||||
V_NOP = 0
|
||||
@@ -783,7 +728,6 @@ class VOP1Op(IntEnum):
|
||||
V_PERMLANE16_SWAP_B32 = 89
|
||||
V_PERMLANE32_SWAP_B32 = 90
|
||||
V_CVT_F32_BF16 = 91
|
||||
CDNA4 = 600
|
||||
|
||||
class VOP2Op(IntEnum):
|
||||
V_CNDMASK_B32 = 0
|
||||
@@ -848,7 +792,6 @@ class VOP2Op(IntEnum):
|
||||
V_FMAC_F32 = 59
|
||||
V_PK_FMAC_F16 = 60
|
||||
V_XNOR_B32 = 61
|
||||
CDNA4 = 600
|
||||
|
||||
class VOP3AOp(IntEnum):
|
||||
V_CMP_CLASS_F32 = 16
|
||||
@@ -1268,7 +1211,7 @@ class VOP3AOp(IntEnum):
|
||||
V_CVT_SCALEF32_SR_PK32_BF6_F32 = 597
|
||||
V_CVT_SCALEF32_PK32_F32_FP6 = 598
|
||||
V_CVT_SCALEF32_PK32_F32_BF6 = 599
|
||||
CDNA4 = 600
|
||||
V_CVT_SCALEF32_PK32_FP6_F16 = 600
|
||||
V_CVT_SCALEF32_PK32_FP6_BF16 = 601
|
||||
V_CVT_SCALEF32_PK32_BF6_F16 = 602
|
||||
V_CVT_SCALEF32_PK32_BF6_BF16 = 603
|
||||
@@ -1338,7 +1281,6 @@ class VOP3BOp(IntEnum):
|
||||
V_DIV_SCALE_F64 = 481
|
||||
V_MAD_U64_U32 = 488
|
||||
V_MAD_I64_I32 = 489
|
||||
CDNA4 = 600
|
||||
|
||||
class VOP3POp(IntEnum):
|
||||
V_PK_MAD_I16 = 0
|
||||
@@ -1388,8 +1330,6 @@ class VOP3POp(IntEnum):
|
||||
V_SMFMAC_F32_16X16X128_BF8_BF8 = 59
|
||||
V_SMFMAC_F32_16X16X128_BF8_FP8 = 60
|
||||
V_SMFMAC_F32_16X16X128_FP8_BF8 = 61
|
||||
V_MFMA_F32_16X16X8_XF32 = 62
|
||||
V_MFMA_F32_32X32X4_XF32 = 63
|
||||
V_MFMA_F32_32X32X1_2B_F32 = 64
|
||||
V_MFMA_F32_16X16X1_4B_F32 = 65
|
||||
V_MFMA_F32_4X4X1_16B_F32 = 66
|
||||
@@ -1447,7 +1387,6 @@ class VOP3POp(IntEnum):
|
||||
V_SMFMAC_F32_32X32X32_BF8_FP8 = 125
|
||||
V_SMFMAC_F32_32X32X32_FP8_BF8 = 126
|
||||
V_SMFMAC_F32_32X32X32_FP8_FP8 = 127
|
||||
CDNA4 = 600
|
||||
|
||||
class VOPCOp(IntEnum):
|
||||
V_CMP_CLASS_F32 = 16
|
||||
@@ -1648,4 +1587,3 @@ class VOPCOp(IntEnum):
|
||||
V_CMPX_NE_U64 = 253
|
||||
V_CMPX_GE_U64 = 254
|
||||
V_CMPX_T_U64 = 255
|
||||
CDNA4 = 600
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
# autogenerated from AMD CDNA3+CDNA4 ISA PDF by pdf.py - do not edit
|
||||
# autogenerated from AMD ISA PDF by pdf.py - do not edit
|
||||
# ruff: noqa: F401,F403
|
||||
from typing import Annotated
|
||||
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, Inst96, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField
|
||||
from extra.assembly.amd.dsl import *
|
||||
from extra.assembly.amd.autogen.cdna.enum import *
|
||||
import functools
|
||||
|
||||
# instruction formats
|
||||
class DPP(Inst64):
|
||||
class DPP(Inst):
|
||||
encoding = bits[8:0] == 0b11111010
|
||||
vop_op = bits[16:9]
|
||||
vdst:VGPRField = bits[24:17]
|
||||
vop2_op = bits[31:25]
|
||||
src0:Src = bits[39:32]
|
||||
vop_op = bits[16:9]
|
||||
vop2_op = bits[31:25]
|
||||
dpp_ctrl = bits[48:40]
|
||||
bound_ctrl = bits[51]
|
||||
bc = bits[51]
|
||||
src0_neg = bits[52]
|
||||
src0_abs = bits[53]
|
||||
src1_neg = bits[54]
|
||||
@@ -21,7 +20,7 @@ class DPP(Inst64):
|
||||
bank_mask = bits[59:56]
|
||||
row_mask = bits[63:60]
|
||||
|
||||
class DS(Inst64):
|
||||
class DS(Inst):
|
||||
encoding = bits[31:26] == 0b110110
|
||||
op:Annotated[BitField, DSOp] = bits[24:17]
|
||||
vdst:VGPRField = bits[63:56]
|
||||
@@ -33,7 +32,7 @@ class DS(Inst64):
|
||||
gds = bits[16]
|
||||
acc = bits[25]
|
||||
|
||||
class FLAT(Inst64):
|
||||
class FLAT(Inst):
|
||||
encoding = bits[31:26] == 0b110111
|
||||
op:Annotated[BitField, FLATOp] = bits[24:18]
|
||||
vdst:VGPRField = bits[63:56]
|
||||
@@ -48,7 +47,7 @@ class FLAT(Inst64):
|
||||
sc1 = bits[25]
|
||||
acc = bits[55]
|
||||
|
||||
class MTBUF(Inst64):
|
||||
class MTBUF(Inst):
|
||||
encoding = bits[31:26] == 0b111010
|
||||
op:Annotated[BitField, MTBUFOp] = bits[18:15]
|
||||
vdata:VGPRField = bits[47:40]
|
||||
@@ -58,12 +57,14 @@ class MTBUF(Inst64):
|
||||
offset:Imm = bits[11:0]
|
||||
offen = bits[12]
|
||||
idxen = bits[13]
|
||||
sc0 = bits[14]
|
||||
dfmt = bits[22:19]
|
||||
nfmt = bits[25:23]
|
||||
sc1 = bits[53]
|
||||
nt = bits[54]
|
||||
acc = bits[55]
|
||||
sc0 = bits[14]
|
||||
|
||||
class MUBUF(Inst64):
|
||||
class MUBUF(Inst):
|
||||
encoding = bits[31:26] == 0b111000
|
||||
op:Annotated[BitField, MUBUFOp] = bits[24:18]
|
||||
vdata:VGPRField = bits[47:40]
|
||||
@@ -79,16 +80,16 @@ class MUBUF(Inst64):
|
||||
nt = bits[17]
|
||||
acc = bits[55]
|
||||
|
||||
class SDWA(Inst64):
|
||||
class SDWA(Inst):
|
||||
encoding = bits[8:0] == 0b11111001
|
||||
vop_op = bits[16:9]
|
||||
vdst:VGPRField = bits[24:17]
|
||||
src0:Src = bits[39:32]
|
||||
omod = bits[47:46]
|
||||
clmp = bits[45]
|
||||
vop_op = bits[16:9]
|
||||
vop2_op = bits[31:25]
|
||||
src0:Src = bits[39:32]
|
||||
dst_sel = bits[42:40]
|
||||
dst_u = bits[44:43]
|
||||
clmp = bits[45]
|
||||
omod = bits[47:46]
|
||||
src0_sel = bits[50:48]
|
||||
src0_sext = bits[51]
|
||||
src0_neg = bits[52]
|
||||
@@ -100,12 +101,10 @@ class SDWA(Inst64):
|
||||
src1_abs = bits[61]
|
||||
s1 = bits[63]
|
||||
|
||||
class SDWAB(Inst64):
|
||||
class SDWAB(Inst):
|
||||
sdst:SGPRField = bits[46:40]
|
||||
src0:Src = bits[39:32]
|
||||
dst_sel = bits[42:40]
|
||||
dst_u = bits[44:43]
|
||||
clmp = bits[45]
|
||||
omod = bits[47:46]
|
||||
sd = bits[47]
|
||||
src0_sel = bits[50:48]
|
||||
src0_sext = bits[51]
|
||||
src0_neg = bits[52]
|
||||
@@ -117,89 +116,88 @@ class SDWAB(Inst64):
|
||||
src1_abs = bits[61]
|
||||
s1 = bits[63]
|
||||
|
||||
class SMEM(Inst64):
|
||||
class SMEM(Inst):
|
||||
encoding = bits[31:26] == 0b110000
|
||||
op:Annotated[BitField, SMEMOp] = bits[25:18]
|
||||
sdata:SGPRField = bits[12:6]
|
||||
sbase:SGPRField = bits[5:0]
|
||||
soffset:SSrc = bits[63:57]
|
||||
offset:Imm = bits[52:32]
|
||||
glc = bits[14]
|
||||
glc = bits[16]
|
||||
soe = bits[14]
|
||||
nv = bits[15]
|
||||
imm = bits[17]
|
||||
imm:Imm = bits[17]
|
||||
|
||||
class SOP1(Inst32):
|
||||
class SOP1(Inst):
|
||||
encoding = bits[31:23] == 0b101111101
|
||||
op:Annotated[BitField, SOP1Op] = bits[15:8]
|
||||
sdst:SGPRField = bits[22:16]
|
||||
ssrc0:SSrc = bits[7:0]
|
||||
|
||||
class SOP2(Inst32):
|
||||
class SOP2(Inst):
|
||||
encoding = bits[31:30] == 0b10
|
||||
op:Annotated[BitField, SOP2Op] = bits[29:23]
|
||||
sdst:SGPRField = bits[22:16]
|
||||
ssrc0:SSrc = bits[7:0]
|
||||
ssrc1:SSrc = bits[15:8]
|
||||
|
||||
class SOPC(Inst32):
|
||||
class SOPC(Inst):
|
||||
encoding = bits[31:23] == 0b101111110
|
||||
op:Annotated[BitField, SOPCOp] = bits[22:16]
|
||||
ssrc0:SSrc = bits[7:0]
|
||||
ssrc1:SSrc = bits[15:8]
|
||||
|
||||
class SOPK(Inst32):
|
||||
class SOPK(Inst):
|
||||
encoding = bits[31:28] == 0b1011
|
||||
op:Annotated[BitField, SOPKOp] = bits[27:23]
|
||||
sdst:SGPRField = bits[22:16]
|
||||
simm16:SImm = bits[15:0]
|
||||
|
||||
class SOPP(Inst32):
|
||||
class SOPP(Inst):
|
||||
encoding = bits[31:23] == 0b101111111
|
||||
op:Annotated[BitField, SOPPOp] = bits[22:16]
|
||||
simm16:SImm = bits[15:0]
|
||||
|
||||
class VOP1(Inst32):
|
||||
encoding = bits[31:25] == 0b111111
|
||||
class VOP1(Inst):
|
||||
encoding = bits[31:25] == 0b0111111
|
||||
op:Annotated[BitField, VOP1Op] = bits[16:9]
|
||||
vdst:VGPRField = bits[24:17]
|
||||
src0:Src = bits[8:0]
|
||||
|
||||
class VOP2(Inst32):
|
||||
encoding = bits[31] == 0
|
||||
class VOP2(Inst):
|
||||
encoding = bits[31] == 0b0
|
||||
op:Annotated[BitField, VOP2Op] = bits[30:25]
|
||||
vdst:VGPRField = bits[24:17]
|
||||
src0:Src = bits[8:0]
|
||||
vsrc1:VGPRField = bits[16:9]
|
||||
|
||||
class VOP3A(Inst64):
|
||||
class VOP3A(Inst):
|
||||
encoding = bits[31:26] == 0b110100
|
||||
vdst:VGPRField = bits[7:0]
|
||||
abs = bits[10:8]
|
||||
opsel = bits[14:11]
|
||||
clmp = bits[15]
|
||||
op:Annotated[BitField, VOP3AOp] = bits[25:16]
|
||||
vdst:VGPRField = bits[7:0]
|
||||
src0:Src = bits[40:32]
|
||||
src1:Src = bits[49:41]
|
||||
src2:Src = bits[58:50]
|
||||
omod = bits[60:59]
|
||||
neg = bits[63:61]
|
||||
abs = bits[10:8]
|
||||
clmp = bits[15]
|
||||
opsel = bits[14:11]
|
||||
|
||||
class VOP3B(Inst64):
|
||||
class VOP3B(Inst):
|
||||
encoding = bits[31:26] == 0b110100
|
||||
op:Annotated[BitField, VOP3BOp] = bits[25:16]
|
||||
vdst:VGPRField = bits[7:0]
|
||||
sdst:SGPRField = bits[14:8]
|
||||
clmp = bits[15]
|
||||
op:Annotated[BitField, VOP3BOp] = bits[25:16]
|
||||
src0:Src = bits[40:32]
|
||||
src1:Src = bits[49:41]
|
||||
src2:Src = bits[58:50]
|
||||
omod = bits[60:59]
|
||||
neg = bits[63:61]
|
||||
clmp = bits[15]
|
||||
|
||||
class VOP3P(Inst64):
|
||||
class VOP3P(Inst):
|
||||
encoding = bits[31:23] == 0b110100111
|
||||
_defaults = {'opsel_hi': 3, 'opsel_hi2': 1}
|
||||
op:Annotated[BitField, VOP3POp] = bits[22:16]
|
||||
vdst:VGPRField = bits[7:0]
|
||||
src0:Src = bits[40:32]
|
||||
@@ -207,13 +205,13 @@ class VOP3P(Inst64):
|
||||
src2:Src = bits[58:50]
|
||||
neg = bits[63:61]
|
||||
neg_hi = bits[10:8]
|
||||
clmp = bits[15]
|
||||
opsel = bits[13:11]
|
||||
opsel_hi = bits[60:59]
|
||||
clmp = bits[15]
|
||||
opsel_hi2 = bits[14]
|
||||
|
||||
class VOPC(Inst32):
|
||||
encoding = bits[31:25] == 0b111110
|
||||
class VOPC(Inst):
|
||||
encoding = bits[31:25] == 0b0111110
|
||||
op:Annotated[BitField, VOPCOp] = bits[24:17]
|
||||
src0:Src = bits[8:0]
|
||||
vsrc1:VGPRField = bits[16:9]
|
||||
@@ -332,12 +330,6 @@ ds_read2_b64 = functools.partial(DS, DSOp.DS_READ2_B64)
|
||||
ds_read2st64_b64 = functools.partial(DS, DSOp.DS_READ2ST64_B64)
|
||||
ds_add_rtn_f64 = functools.partial(DS, DSOp.DS_ADD_RTN_F64)
|
||||
ds_condxchg32_rtn_b64 = functools.partial(DS, DSOp.DS_CONDXCHG32_RTN_B64)
|
||||
ds_gws_sema_release_all = functools.partial(DS, DSOp.DS_GWS_SEMA_RELEASE_ALL)
|
||||
ds_gws_init = functools.partial(DS, DSOp.DS_GWS_INIT)
|
||||
ds_gws_sema_v = functools.partial(DS, DSOp.DS_GWS_SEMA_V)
|
||||
ds_gws_sema_br = functools.partial(DS, DSOp.DS_GWS_SEMA_BR)
|
||||
ds_gws_sema_p = functools.partial(DS, DSOp.DS_GWS_SEMA_P)
|
||||
ds_gws_barrier = functools.partial(DS, DSOp.DS_GWS_BARRIER)
|
||||
ds_read_addtid_b32 = functools.partial(DS, DSOp.DS_READ_ADDTID_B32)
|
||||
ds_pk_add_rtn_f16 = functools.partial(DS, DSOp.DS_PK_ADD_RTN_F16)
|
||||
ds_pk_add_rtn_bf16 = functools.partial(DS, DSOp.DS_PK_ADD_RTN_BF16)
|
||||
@@ -351,7 +343,6 @@ ds_read_b64_tr_b8 = functools.partial(DS, DSOp.DS_READ_B64_TR_B8)
|
||||
ds_read_b64_tr_b16 = functools.partial(DS, DSOp.DS_READ_B64_TR_B16)
|
||||
ds_read_b96 = functools.partial(DS, DSOp.DS_READ_B96)
|
||||
ds_read_b128 = functools.partial(DS, DSOp.DS_READ_B128)
|
||||
cdna4 = functools.partial(DS, DSOp.CDNA4)
|
||||
flat_load_ubyte = functools.partial(FLAT, FLATOp.FLAT_LOAD_UBYTE)
|
||||
flat_load_sbyte = functools.partial(FLAT, FLATOp.FLAT_LOAD_SBYTE)
|
||||
flat_load_ushort = functools.partial(FLAT, FLATOp.FLAT_LOAD_USHORT)
|
||||
@@ -406,7 +397,6 @@ flat_atomic_or_x2 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_OR_X2)
|
||||
flat_atomic_xor_x2 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_XOR_X2)
|
||||
flat_atomic_inc_x2 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_INC_X2)
|
||||
flat_atomic_dec_x2 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_DEC_X2)
|
||||
cdna4 = functools.partial(FLAT, FLATOp.CDNA4)
|
||||
global_load_ubyte = functools.partial(FLAT, GLOBALOp.GLOBAL_LOAD_UBYTE, seg=2)
|
||||
global_load_sbyte = functools.partial(FLAT, GLOBALOp.GLOBAL_LOAD_SBYTE, seg=2)
|
||||
global_load_ushort = functools.partial(FLAT, GLOBALOp.GLOBAL_LOAD_USHORT, seg=2)
|
||||
@@ -468,7 +458,6 @@ global_atomic_inc_x2 = functools.partial(FLAT, GLOBALOp.GLOBAL_ATOMIC_INC_X2, se
|
||||
global_atomic_dec_x2 = functools.partial(FLAT, GLOBALOp.GLOBAL_ATOMIC_DEC_X2, seg=2)
|
||||
global_load_lds_dwordx4 = functools.partial(FLAT, GLOBALOp.GLOBAL_LOAD_LDS_DWORDX4, seg=2)
|
||||
global_load_lds_dwordx3 = functools.partial(FLAT, GLOBALOp.GLOBAL_LOAD_LDS_DWORDX3, seg=2)
|
||||
cdna4 = functools.partial(FLAT, GLOBALOp.CDNA4, seg=2)
|
||||
tbuffer_load_format_x = functools.partial(MTBUF, MTBUFOp.TBUFFER_LOAD_FORMAT_X)
|
||||
tbuffer_load_format_xy = functools.partial(MTBUF, MTBUFOp.TBUFFER_LOAD_FORMAT_XY)
|
||||
tbuffer_load_format_xyz = functools.partial(MTBUF, MTBUFOp.TBUFFER_LOAD_FORMAT_XYZ)
|
||||
@@ -559,7 +548,6 @@ buffer_atomic_or_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_OR_X2)
|
||||
buffer_atomic_xor_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_XOR_X2)
|
||||
buffer_atomic_inc_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_INC_X2)
|
||||
buffer_atomic_dec_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_DEC_X2)
|
||||
cdna4 = functools.partial(MUBUF, MUBUFOp.CDNA4)
|
||||
scratch_load_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE, seg=1)
|
||||
scratch_load_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE, seg=1)
|
||||
scratch_load_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_USHORT, seg=1)
|
||||
@@ -669,7 +657,6 @@ s_atomic_or_x2 = functools.partial(SMEM, SMEMOp.S_ATOMIC_OR_X2)
|
||||
s_atomic_xor_x2 = functools.partial(SMEM, SMEMOp.S_ATOMIC_XOR_X2)
|
||||
s_atomic_inc_x2 = functools.partial(SMEM, SMEMOp.S_ATOMIC_INC_X2)
|
||||
s_atomic_dec_x2 = functools.partial(SMEM, SMEMOp.S_ATOMIC_DEC_X2)
|
||||
cdna4 = functools.partial(SMEM, SMEMOp.CDNA4)
|
||||
s_mov_b32 = functools.partial(SOP1, SOP1Op.S_MOV_B32)
|
||||
s_mov_b64 = functools.partial(SOP1, SOP1Op.S_MOV_B64)
|
||||
s_cmov_b32 = functools.partial(SOP1, SOP1Op.S_CMOV_B32)
|
||||
@@ -724,7 +711,6 @@ s_orn1_saveexec_b64 = functools.partial(SOP1, SOP1Op.S_ORN1_SAVEEXEC_B64)
|
||||
s_andn1_wrexec_b64 = functools.partial(SOP1, SOP1Op.S_ANDN1_WREXEC_B64)
|
||||
s_andn2_wrexec_b64 = functools.partial(SOP1, SOP1Op.S_ANDN2_WREXEC_B64)
|
||||
s_bitreplicate_b64_b32 = functools.partial(SOP1, SOP1Op.S_BITREPLICATE_B64_B32)
|
||||
cdna4 = functools.partial(SOP1, SOP1Op.CDNA4)
|
||||
s_add_u32 = functools.partial(SOP2, SOP2Op.S_ADD_U32)
|
||||
s_sub_u32 = functools.partial(SOP2, SOP2Op.S_SUB_U32)
|
||||
s_add_i32 = functools.partial(SOP2, SOP2Op.S_ADD_I32)
|
||||
@@ -777,7 +763,6 @@ s_lshl4_add_u32 = functools.partial(SOP2, SOP2Op.S_LSHL4_ADD_U32)
|
||||
s_pack_ll_b32_b16 = functools.partial(SOP2, SOP2Op.S_PACK_LL_B32_B16)
|
||||
s_pack_lh_b32_b16 = functools.partial(SOP2, SOP2Op.S_PACK_LH_B32_B16)
|
||||
s_pack_hh_b32_b16 = functools.partial(SOP2, SOP2Op.S_PACK_HH_B32_B16)
|
||||
cdna4 = functools.partial(SOP2, SOP2Op.CDNA4)
|
||||
s_cmp_eq_i32 = functools.partial(SOPC, SOPCOp.S_CMP_EQ_I32)
|
||||
s_cmp_lg_i32 = functools.partial(SOPC, SOPCOp.S_CMP_LG_I32)
|
||||
s_cmp_gt_i32 = functools.partial(SOPC, SOPCOp.S_CMP_GT_I32)
|
||||
@@ -798,7 +783,6 @@ s_setvskip = functools.partial(SOPC, SOPCOp.S_SETVSKIP)
|
||||
s_set_gpr_idx_on = functools.partial(SOPC, SOPCOp.S_SET_GPR_IDX_ON)
|
||||
s_cmp_eq_u64 = functools.partial(SOPC, SOPCOp.S_CMP_EQ_U64)
|
||||
s_cmp_lg_u64 = functools.partial(SOPC, SOPCOp.S_CMP_LG_U64)
|
||||
cdna4 = functools.partial(SOPC, SOPCOp.CDNA4)
|
||||
s_movk_i32 = functools.partial(SOPK, SOPKOp.S_MOVK_I32)
|
||||
s_cmovk_i32 = functools.partial(SOPK, SOPKOp.S_CMOVK_I32)
|
||||
s_cmpk_eq_i32 = functools.partial(SOPK, SOPKOp.S_CMPK_EQ_I32)
|
||||
@@ -850,7 +834,6 @@ s_cbranch_cdbgsys_and_user = functools.partial(SOPP, SOPPOp.S_CBRANCH_CDBGSYS_AN
|
||||
s_endpgm_saved = functools.partial(SOPP, SOPPOp.S_ENDPGM_SAVED)
|
||||
s_set_gpr_idx_off = functools.partial(SOPP, SOPPOp.S_SET_GPR_IDX_OFF)
|
||||
s_set_gpr_idx_mode = functools.partial(SOPP, SOPPOp.S_SET_GPR_IDX_MODE)
|
||||
cdna4 = functools.partial(SOPP, SOPPOp.CDNA4)
|
||||
v_nop_e32 = functools.partial(VOP1, VOP1Op.V_NOP)
|
||||
v_mov_b32_e32 = functools.partial(VOP1, VOP1Op.V_MOV_B32)
|
||||
v_readfirstlane_b32_e32 = functools.partial(VOP1, VOP1Op.V_READFIRSTLANE_B32)
|
||||
@@ -936,7 +919,6 @@ v_prng_b32_e32 = functools.partial(VOP1, VOP1Op.V_PRNG_B32)
|
||||
v_permlane16_swap_b32_e32 = functools.partial(VOP1, VOP1Op.V_PERMLANE16_SWAP_B32)
|
||||
v_permlane32_swap_b32_e32 = functools.partial(VOP1, VOP1Op.V_PERMLANE32_SWAP_B32)
|
||||
v_cvt_f32_bf16_e32 = functools.partial(VOP1, VOP1Op.V_CVT_F32_BF16)
|
||||
cdna4_e32 = functools.partial(VOP1, VOP1Op.CDNA4)
|
||||
v_cndmask_b32_e32 = functools.partial(VOP2, VOP2Op.V_CNDMASK_B32)
|
||||
v_add_f32_e32 = functools.partial(VOP2, VOP2Op.V_ADD_F32)
|
||||
v_sub_f32_e32 = functools.partial(VOP2, VOP2Op.V_SUB_F32)
|
||||
@@ -960,8 +942,8 @@ v_and_b32_e32 = functools.partial(VOP2, VOP2Op.V_AND_B32)
|
||||
v_or_b32_e32 = functools.partial(VOP2, VOP2Op.V_OR_B32)
|
||||
v_xor_b32_e32 = functools.partial(VOP2, VOP2Op.V_XOR_B32)
|
||||
v_dot2c_f32_bf16_e32 = functools.partial(VOP2, VOP2Op.V_DOT2C_F32_BF16)
|
||||
def v_fmamk_f32_e32(vdst, src0, K, vsrc1): return VOP2(VOP2Op.V_FMAMK_F32, vdst, src0, vsrc1, literal=K)
|
||||
def v_fmaak_f32_e32(vdst, src0, vsrc1, K): return VOP2(VOP2Op.V_FMAAK_F32, vdst, src0, vsrc1, literal=K)
|
||||
v_fmamk_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAMK_F32)
|
||||
v_fmaak_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAAK_F32)
|
||||
v_add_co_u32_e32 = functools.partial(VOP2, VOP2Op.V_ADD_CO_U32)
|
||||
v_sub_co_u32_e32 = functools.partial(VOP2, VOP2Op.V_SUB_CO_U32)
|
||||
v_subrev_co_u32_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_CO_U32)
|
||||
@@ -999,7 +981,6 @@ v_dot8c_i32_i4_e32 = functools.partial(VOP2, VOP2Op.V_DOT8C_I32_I4)
|
||||
v_fmac_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAC_F32)
|
||||
v_pk_fmac_f16_e32 = functools.partial(VOP2, VOP2Op.V_PK_FMAC_F16)
|
||||
v_xnor_b32_e32 = functools.partial(VOP2, VOP2Op.V_XNOR_B32)
|
||||
cdna4_e32 = functools.partial(VOP2, VOP2Op.CDNA4)
|
||||
v_cmp_class_f32 = functools.partial(VOP3A, VOP3AOp.V_CMP_CLASS_F32)
|
||||
v_cmpx_class_f32 = functools.partial(VOP3A, VOP3AOp.V_CMPX_CLASS_F32)
|
||||
v_cmp_class_f64 = functools.partial(VOP3A, VOP3AOp.V_CMP_CLASS_F64)
|
||||
@@ -1417,7 +1398,7 @@ v_cvt_scalef32_sr_pk32_fp6_f32 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32
|
||||
v_cvt_scalef32_sr_pk32_bf6_f32 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32_SR_PK32_BF6_F32)
|
||||
v_cvt_scalef32_pk32_f32_fp6 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32_PK32_F32_FP6)
|
||||
v_cvt_scalef32_pk32_f32_bf6 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32_PK32_F32_BF6)
|
||||
cdna4 = functools.partial(VOP3A, VOP3AOp.CDNA4)
|
||||
v_cvt_scalef32_pk32_fp6_f16 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32_PK32_FP6_F16)
|
||||
v_cvt_scalef32_pk32_fp6_bf16 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32_PK32_FP6_BF16)
|
||||
v_cvt_scalef32_pk32_bf6_f16 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32_PK32_BF6_F16)
|
||||
v_cvt_scalef32_pk32_bf6_bf16 = functools.partial(VOP3A, VOP3AOp.V_CVT_SCALEF32_PK32_BF6_BF16)
|
||||
@@ -1485,7 +1466,6 @@ v_div_scale_f32 = functools.partial(VOP3B, VOP3BOp.V_DIV_SCALE_F32)
|
||||
v_div_scale_f64 = functools.partial(VOP3B, VOP3BOp.V_DIV_SCALE_F64)
|
||||
v_mad_u64_u32 = functools.partial(VOP3B, VOP3BOp.V_MAD_U64_U32)
|
||||
v_mad_i64_i32 = functools.partial(VOP3B, VOP3BOp.V_MAD_I64_I32)
|
||||
cdna4 = functools.partial(VOP3B, VOP3BOp.CDNA4)
|
||||
v_pk_mad_i16 = functools.partial(VOP3P, VOP3POp.V_PK_MAD_I16)
|
||||
v_pk_mul_lo_u16 = functools.partial(VOP3P, VOP3POp.V_PK_MUL_LO_U16)
|
||||
v_pk_add_i16 = functools.partial(VOP3P, VOP3POp.V_PK_ADD_I16)
|
||||
@@ -1533,8 +1513,6 @@ v_smfmac_i32_16x16x128_i8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_I32_16X16X
|
||||
v_smfmac_f32_16x16x128_bf8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X128_BF8_BF8)
|
||||
v_smfmac_f32_16x16x128_bf8_fp8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X128_BF8_FP8)
|
||||
v_smfmac_f32_16x16x128_fp8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X128_FP8_BF8)
|
||||
v_mfma_f32_16x16x8_xf32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X8_XF32)
|
||||
v_mfma_f32_32x32x4_xf32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X4_XF32)
|
||||
v_mfma_f32_32x32x1_2b_f32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X1_2B_F32)
|
||||
v_mfma_f32_16x16x1_4b_f32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X1_4B_F32)
|
||||
v_mfma_f32_4x4x1_16b_f32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_4X4X1_16B_F32)
|
||||
@@ -1592,7 +1570,6 @@ v_smfmac_f32_32x32x32_bf8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32
|
||||
v_smfmac_f32_32x32x32_bf8_fp8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X32_BF8_FP8)
|
||||
v_smfmac_f32_32x32x32_fp8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X32_FP8_BF8)
|
||||
v_smfmac_f32_32x32x32_fp8_fp8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X32_FP8_FP8)
|
||||
cdna4 = functools.partial(VOP3P, VOP3POp.CDNA4)
|
||||
v_cmp_class_f32_e32 = functools.partial(VOPC, VOPCOp.V_CMP_CLASS_F32)
|
||||
v_cmpx_class_f32_e32 = functools.partial(VOPC, VOPCOp.V_CMPX_CLASS_F32)
|
||||
v_cmp_class_f64_e32 = functools.partial(VOPC, VOPCOp.V_CMP_CLASS_F64)
|
||||
@@ -1790,42 +1767,4 @@ v_cmpx_le_u64_e32 = functools.partial(VOPC, VOPCOp.V_CMPX_LE_U64)
|
||||
v_cmpx_gt_u64_e32 = functools.partial(VOPC, VOPCOp.V_CMPX_GT_U64)
|
||||
v_cmpx_ne_u64_e32 = functools.partial(VOPC, VOPCOp.V_CMPX_NE_U64)
|
||||
v_cmpx_ge_u64_e32 = functools.partial(VOPC, VOPCOp.V_CMPX_GE_U64)
|
||||
v_cmpx_t_u64_e32 = functools.partial(VOPC, VOPCOp.V_CMPX_T_U64)
|
||||
cdna4_e32 = functools.partial(VOPC, VOPCOp.CDNA4)
|
||||
|
||||
S_ADD_U32 = SrcEnum.S_ADD_U32
|
||||
S_SUB_U32 = SrcEnum.S_SUB_U32
|
||||
S_ADD_I32 = SrcEnum.S_ADD_I32
|
||||
S_SUB_I32 = SrcEnum.S_SUB_I32
|
||||
S_ADDC_U32 = SrcEnum.S_ADDC_U32
|
||||
S_SUBB_U32 = SrcEnum.S_SUBB_U32
|
||||
S_MIN_I32 = SrcEnum.S_MIN_I32
|
||||
FLAT_SCRATCH_LO = SrcEnum.FLAT_SCRATCH_LO
|
||||
FLAT_SCRATCH_HI = SrcEnum.FLAT_SCRATCH_HI
|
||||
XNACK_MASK_LO = SrcEnum.XNACK_MASK_LO
|
||||
XNACK_MASK_HI = SrcEnum.XNACK_MASK_HI
|
||||
VCC_LO = SrcEnum.VCC_LO
|
||||
VCC_HI = SrcEnum.VCC_HI
|
||||
M0 = SrcEnum.M0
|
||||
EXEC_LO = SrcEnum.EXEC_LO
|
||||
EXEC_HI = SrcEnum.EXEC_HI
|
||||
ZERO = SrcEnum.ZERO
|
||||
DPP8FI = SrcEnum.DPP8FI
|
||||
SHARED_BASE = SrcEnum.SHARED_BASE
|
||||
SHARED_LIMIT = SrcEnum.SHARED_LIMIT
|
||||
PRIVATE_BASE = SrcEnum.PRIVATE_BASE
|
||||
PRIVATE_LIMIT = SrcEnum.PRIVATE_LIMIT
|
||||
RESERVED = SrcEnum.RESERVED
|
||||
POS_HALF = SrcEnum.POS_HALF
|
||||
NEG_HALF = SrcEnum.NEG_HALF
|
||||
POS_ONE = SrcEnum.POS_ONE
|
||||
NEG_ONE = SrcEnum.NEG_ONE
|
||||
POS_TWO = SrcEnum.POS_TWO
|
||||
NEG_TWO = SrcEnum.NEG_TWO
|
||||
POS_FOUR = SrcEnum.POS_FOUR
|
||||
NEG_FOUR = SrcEnum.NEG_FOUR
|
||||
INV_2PI = SrcEnum.INV_2PI
|
||||
VCCZ = SrcEnum.VCCZ
|
||||
EXECZ = SrcEnum.EXECZ
|
||||
SCC = SrcEnum.SCC
|
||||
LDS_DIRECT = SrcEnum.LDS_DIRECT
|
||||
v_cmpx_t_u64_e32 = functools.partial(VOPC, VOPCOp.V_CMPX_T_U64)
|
||||
File diff suppressed because one or more lines are too long
@@ -1,34 +1,97 @@
|
||||
# autogenerated from AMD RDNA3.5 ISA PDF by pdf.py - do not edit
|
||||
# autogenerated from AMD ISA PDF by pdf.py - do not edit
|
||||
from enum import IntEnum
|
||||
|
||||
class SrcEnum(IntEnum):
|
||||
VCC_LO = 106
|
||||
VCC_HI = 107
|
||||
NULL = 124
|
||||
M0 = 125
|
||||
EXEC_LO = 126
|
||||
EXEC_HI = 127
|
||||
ZERO = 128
|
||||
DPP8 = 233
|
||||
DPP8FI = 234
|
||||
SHARED_BASE = 235
|
||||
SHARED_LIMIT = 236
|
||||
PRIVATE_BASE = 237
|
||||
PRIVATE_LIMIT = 238
|
||||
POS_HALF = 240
|
||||
NEG_HALF = 241
|
||||
POS_ONE = 242
|
||||
NEG_ONE = 243
|
||||
POS_TWO = 244
|
||||
NEG_TWO = 245
|
||||
POS_FOUR = 246
|
||||
NEG_FOUR = 247
|
||||
INV_2PI = 248
|
||||
DPP16 = 250
|
||||
VCCZ = 251
|
||||
EXECZ = 252
|
||||
SCC = 253
|
||||
LDS_DIRECT = 254
|
||||
class BufFmt(IntEnum):
|
||||
BUF_FMT_8_UNORM = 1
|
||||
BUF_FMT_8_SNORM = 2
|
||||
BUF_FMT_8_USCALED = 3
|
||||
BUF_FMT_8_SSCALED = 4
|
||||
BUF_FMT_8_UINT = 5
|
||||
BUF_FMT_8_SINT = 6
|
||||
BUF_FMT_16_UNORM = 7
|
||||
BUF_FMT_16_SNORM = 8
|
||||
BUF_FMT_16_USCALED = 9
|
||||
BUF_FMT_16_SSCALED = 10
|
||||
BUF_FMT_16_UINT = 11
|
||||
BUF_FMT_16_SINT = 12
|
||||
BUF_FMT_16_FLOAT = 13
|
||||
BUF_FMT_8_8_UNORM = 14
|
||||
BUF_FMT_8_8_SNORM = 15
|
||||
BUF_FMT_8_8_USCALED = 16
|
||||
BUF_FMT_8_8_SSCALED = 17
|
||||
BUF_FMT_8_8_UINT = 18
|
||||
BUF_FMT_8_8_SINT = 19
|
||||
BUF_FMT_32_UINT = 20
|
||||
BUF_FMT_32_SINT = 21
|
||||
BUF_FMT_32_FLOAT = 22
|
||||
BUF_FMT_16_16_UNORM = 23
|
||||
BUF_FMT_16_16_SNORM = 24
|
||||
BUF_FMT_16_16_USCALED = 25
|
||||
BUF_FMT_16_16_SSCALED = 26
|
||||
BUF_FMT_16_16_UINT = 27
|
||||
BUF_FMT_16_16_SINT = 28
|
||||
BUF_FMT_16_16_FLOAT = 29
|
||||
BUF_FMT_10_11_11_FLOAT = 30
|
||||
BUF_FMT_11_11_10_FLOAT = 31
|
||||
BUF_FMT_10_10_10_2_UNORM = 32
|
||||
BUF_FMT_10_10_10_2_SNORM = 33
|
||||
BUF_FMT_10_10_10_2_UINT = 34
|
||||
BUF_FMT_10_10_10_2_SINT = 35
|
||||
BUF_FMT_2_10_10_10_UNORM = 36
|
||||
BUF_FMT_2_10_10_10_SNORM = 37
|
||||
BUF_FMT_2_10_10_10_USCALED = 38
|
||||
BUF_FMT_2_10_10_10_SSCALED = 39
|
||||
BUF_FMT_2_10_10_10_UINT = 40
|
||||
BUF_FMT_2_10_10_10_SINT = 41
|
||||
BUF_FMT_8_8_8_8_UNORM = 42
|
||||
BUF_FMT_8_8_8_8_SNORM = 43
|
||||
BUF_FMT_8_8_8_8_USCALED = 44
|
||||
BUF_FMT_8_8_8_8_SSCALED = 45
|
||||
BUF_FMT_8_8_8_8_UINT = 46
|
||||
BUF_FMT_8_8_8_8_SINT = 47
|
||||
BUF_FMT_32_32_UINT = 48
|
||||
BUF_FMT_32_32_SINT = 49
|
||||
BUF_FMT_32_32_FLOAT = 50
|
||||
BUF_FMT_16_16_16_16_UNORM = 51
|
||||
BUF_FMT_16_16_16_16_SNORM = 52
|
||||
BUF_FMT_16_16_16_16_USCALED = 53
|
||||
BUF_FMT_16_16_16_16_SSCALED = 54
|
||||
BUF_FMT_16_16_16_16_UINT = 55
|
||||
BUF_FMT_16_16_16_16_SINT = 56
|
||||
BUF_FMT_16_16_16_16_FLOAT = 57
|
||||
BUF_FMT_32_32_32_UINT = 58
|
||||
BUF_FMT_32_32_32_SINT = 59
|
||||
BUF_FMT_32_32_32_FLOAT = 60
|
||||
BUF_FMT_32_32_32_32_UINT = 61
|
||||
BUF_FMT_8_SRGB = 64
|
||||
BUF_FMT_8_8_SRGB = 65
|
||||
BUF_FMT_8_8_8_8_SRGB = 66
|
||||
BUF_FMT_5_9_9_9_FLOAT = 67
|
||||
BUF_FMT_5_6_5_UNORM = 68
|
||||
BUF_FMT_1_5_5_5_UNORM = 69
|
||||
BUF_FMT_5_5_5_1_UNORM = 70
|
||||
BUF_FMT_4_4_4_4_UNORM = 71
|
||||
BUF_FMT_4_4_UNORM = 72
|
||||
BUF_FMT_1_UNORM = 73
|
||||
BUF_FMT_1_REVERSED_UNORM = 74
|
||||
BUF_FMT_32_FLOAT_CLAMP = 75
|
||||
BUF_FMT_8_24_UNORM = 76
|
||||
BUF_FMT_8_24_UINT = 77
|
||||
BUF_FMT_24_8_UNORM = 78
|
||||
BUF_FMT_24_8_UINT = 79
|
||||
BUF_FMT_X24_8_32_UINT = 80
|
||||
BUF_FMT_X24_8_32_FLOAT = 81
|
||||
BUF_FMT_GB_GR_UNORM = 82
|
||||
BUF_FMT_GB_GR_SNORM = 83
|
||||
BUF_FMT_GB_GR_UINT = 84
|
||||
BUF_FMT_GB_GR_SRGB = 85
|
||||
BUF_FMT_BG_RG_UNORM = 86
|
||||
BUF_FMT_BG_RG_SNORM = 87
|
||||
BUF_FMT_BG_RG_UINT = 88
|
||||
BUF_FMT_BG_RG_SRGB = 89
|
||||
BUF_FMT_BC1_UNORM = 109
|
||||
BUF_FMT_BC1_SRGB = 110
|
||||
BUF_FMT_BC2_UNORM = 111
|
||||
|
||||
class DSOp(IntEnum):
|
||||
DS_ADD_U32 = 0
|
||||
@@ -1372,7 +1435,6 @@ class VOP3POp(IntEnum):
|
||||
V_WMMA_I32_16X16X16_IU4 = 69
|
||||
|
||||
class VOP3SDOp(IntEnum):
|
||||
DWORD = 1
|
||||
V_ADD_CO_CI_U32 = 288
|
||||
V_SUB_CO_CI_U32 = 289
|
||||
V_SUBREV_CO_CI_U32 = 290
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
# autogenerated from AMD RDNA3.5 ISA PDF by pdf.py - do not edit
|
||||
# autogenerated from AMD ISA PDF by pdf.py - do not edit
|
||||
# ruff: noqa: F401,F403
|
||||
from typing import Annotated
|
||||
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, Inst96, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField
|
||||
from extra.assembly.amd.dsl import *
|
||||
from extra.assembly.amd.autogen.rdna3.enum import *
|
||||
import functools
|
||||
|
||||
# instruction formats
|
||||
class DPP16(Inst64):
|
||||
class DPP16(Inst):
|
||||
src0:Src = bits[39:32]
|
||||
dpp_ctrl = bits[48:40]
|
||||
fi = bits[50]
|
||||
@@ -18,7 +17,7 @@ class DPP16(Inst64):
|
||||
bank_mask = bits[59:56]
|
||||
row_mask = bits[63:60]
|
||||
|
||||
class DPP8(Inst64):
|
||||
class DPP8(Inst):
|
||||
src0:Src = bits[39:32]
|
||||
lane_sel0 = bits[42:40]
|
||||
lane_sel1 = bits[45:43]
|
||||
@@ -29,7 +28,7 @@ class DPP8(Inst64):
|
||||
lane_sel6 = bits[60:58]
|
||||
lane_sel7 = bits[63:61]
|
||||
|
||||
class DS(Inst64):
|
||||
class DS(Inst):
|
||||
encoding = bits[31:26] == 0b110110
|
||||
op:Annotated[BitField, DSOp] = bits[25:18]
|
||||
vdst:VGPRField = bits[63:56]
|
||||
@@ -40,18 +39,18 @@ class DS(Inst64):
|
||||
offset1 = bits[15:8]
|
||||
gds = bits[17]
|
||||
|
||||
class EXP(Inst64):
|
||||
class EXP(Inst):
|
||||
encoding = bits[31:26] == 0b111110
|
||||
vsrc0:VGPRField = bits[39:32]
|
||||
vsrc1:VGPRField = bits[47:40]
|
||||
vsrc2:VGPRField = bits[55:48]
|
||||
vsrc3:VGPRField = bits[63:56]
|
||||
en = bits[3:0]
|
||||
target = bits[9:4]
|
||||
vsrc0 = bits[39:32]
|
||||
vsrc1:VGPRField = bits[47:40]
|
||||
vsrc2 = bits[55:48]
|
||||
vsrc3 = bits[63:56]
|
||||
done = bits[11]
|
||||
row = bits[13]
|
||||
|
||||
class FLAT(Inst64):
|
||||
class FLAT(Inst):
|
||||
encoding = bits[31:26] == 0b110111
|
||||
op:Annotated[BitField, FLATOp] = bits[24:18]
|
||||
vdst:VGPRField = bits[63:56]
|
||||
@@ -60,12 +59,12 @@ class FLAT(Inst64):
|
||||
saddr:SSrc = bits[54:48]
|
||||
offset:Imm = bits[12:0]
|
||||
seg = bits[17:16]
|
||||
dlc = bits[13]
|
||||
glc = bits[14]
|
||||
dlc = bits[13]
|
||||
slc = bits[15]
|
||||
sve = bits[55]
|
||||
|
||||
class LDSDIR(Inst32):
|
||||
class LDSDIR(Inst):
|
||||
encoding = bits[31:24] == 0b11001110
|
||||
op = bits[21:20]
|
||||
vdst:VGPRField = bits[7:0]
|
||||
@@ -73,29 +72,29 @@ class LDSDIR(Inst32):
|
||||
attr_chan = bits[9:8]
|
||||
wait_va = bits[19:16]
|
||||
|
||||
class MIMG(Inst64):
|
||||
class MIMG(Inst):
|
||||
encoding = bits[31:26] == 0b111100
|
||||
op:Annotated[BitField, MIMGOp] = bits[25:18]
|
||||
vdata:VGPRField = bits[47:40]
|
||||
vaddr:VGPRField = bits[39:32]
|
||||
srsrc:SGPRField = bits[52:48]
|
||||
ssamp = bits[62:58]
|
||||
ssamp:SGPRField = bits[62:58]
|
||||
dmask = bits[11:8]
|
||||
dim = bits[4:2]
|
||||
unrm = bits[7]
|
||||
dlc = bits[13]
|
||||
glc = bits[14]
|
||||
dlc = bits[13]
|
||||
slc = bits[12]
|
||||
tfe = bits[53]
|
||||
unrm = bits[7]
|
||||
nsa = bits[0]
|
||||
r128 = bits[15]
|
||||
a16 = bits[16]
|
||||
d16 = bits[17]
|
||||
tfe = bits[53]
|
||||
lwe = bits[54]
|
||||
addr1 = bits[71:64]
|
||||
addr2 = bits[79:72]
|
||||
|
||||
class MTBUF(Inst64):
|
||||
class MTBUF(Inst):
|
||||
encoding = bits[31:26] == 0b111010
|
||||
op:Annotated[BitField, MTBUFOp] = bits[18:15]
|
||||
vdata:VGPRField = bits[47:40]
|
||||
@@ -111,7 +110,7 @@ class MTBUF(Inst64):
|
||||
slc = bits[12]
|
||||
tfe = bits[53]
|
||||
|
||||
class MUBUF(Inst64):
|
||||
class MUBUF(Inst):
|
||||
encoding = bits[31:26] == 0b111000
|
||||
op:Annotated[BitField, MUBUFOp] = bits[25:18]
|
||||
vdata:VGPRField = bits[47:40]
|
||||
@@ -126,7 +125,7 @@ class MUBUF(Inst64):
|
||||
slc = bits[12]
|
||||
tfe = bits[53]
|
||||
|
||||
class SMEM(Inst64):
|
||||
class SMEM(Inst):
|
||||
encoding = bits[31:26] == 0b111101
|
||||
op:Annotated[BitField, SMEMOp] = bits[25:18]
|
||||
sdata:SGPRField = bits[12:6]
|
||||
@@ -136,62 +135,63 @@ class SMEM(Inst64):
|
||||
glc = bits[14]
|
||||
dlc = bits[13]
|
||||
|
||||
class SOP1(Inst32):
|
||||
class SOP1(Inst):
|
||||
encoding = bits[31:23] == 0b101111101
|
||||
op:Annotated[BitField, SOP1Op] = bits[15:8]
|
||||
sdst:SGPRField = bits[22:16]
|
||||
ssrc0:SSrc = bits[7:0]
|
||||
|
||||
class SOP2(Inst32):
|
||||
class SOP2(Inst):
|
||||
encoding = bits[31:30] == 0b10
|
||||
op:Annotated[BitField, SOP2Op] = bits[29:23]
|
||||
sdst:SGPRField = bits[22:16]
|
||||
ssrc0:SSrc = bits[7:0]
|
||||
ssrc1:SSrc = bits[15:8]
|
||||
|
||||
class SOPC(Inst32):
|
||||
class SOPC(Inst):
|
||||
encoding = bits[31:23] == 0b101111110
|
||||
op:Annotated[BitField, SOPCOp] = bits[22:16]
|
||||
ssrc0:SSrc = bits[7:0]
|
||||
ssrc1:SSrc = bits[15:8]
|
||||
|
||||
class SOPK(Inst32):
|
||||
class SOPK(Inst):
|
||||
encoding = bits[31:28] == 0b1011
|
||||
op:Annotated[BitField, SOPKOp] = bits[27:23]
|
||||
sdst:SGPRField = bits[22:16]
|
||||
simm16:SImm = bits[15:0]
|
||||
|
||||
class SOPP(Inst32):
|
||||
class SOPP(Inst):
|
||||
encoding = bits[31:23] == 0b101111111
|
||||
op:Annotated[BitField, SOPPOp] = bits[22:16]
|
||||
simm16:SImm = bits[15:0]
|
||||
|
||||
class VINTERP(Inst64):
|
||||
class VINTERP(Inst):
|
||||
encoding = bits[31:24] == 0b11001101
|
||||
op:Annotated[BitField, VINTERPOp] = bits[22:16]
|
||||
vdst:VGPRField = bits[7:0]
|
||||
src0:Src = bits[40:32]
|
||||
src0:Src = bits[40:32]
|
||||
src1:Src = bits[49:41]
|
||||
src2:Src = bits[58:50]
|
||||
waitexp = bits[10:8]
|
||||
neg = bits[63:61]
|
||||
clmp = bits[15]
|
||||
opsel = bits[14:11]
|
||||
neg = bits[63:61]
|
||||
waitexp = bits[10:8]
|
||||
|
||||
class VOP1(Inst32):
|
||||
encoding = bits[31:25] == 0b111111
|
||||
class VOP1(Inst):
|
||||
encoding = bits[31:25] == 0b0111111
|
||||
op:Annotated[BitField, VOP1Op] = bits[16:9]
|
||||
vdst:VGPRField = bits[24:17]
|
||||
src0:Src = bits[8:0]
|
||||
|
||||
class VOP2(Inst32):
|
||||
encoding = bits[31] == 0
|
||||
class VOP2(Inst):
|
||||
encoding = bits[31] == 0b0
|
||||
op:Annotated[BitField, VOP2Op] = bits[30:25]
|
||||
vdst:VGPRField = bits[24:17]
|
||||
src0:Src = bits[8:0]
|
||||
vsrc1:VGPRField = bits[16:9]
|
||||
|
||||
class VOP3(Inst64):
|
||||
class VOP3(Inst):
|
||||
encoding = bits[31:26] == 0b110101
|
||||
op:Annotated[BitField, VOP3Op] = bits[25:16]
|
||||
vdst:VGPRField = bits[7:0]
|
||||
@@ -204,9 +204,8 @@ class VOP3(Inst64):
|
||||
clmp = bits[15]
|
||||
opsel = bits[14:11]
|
||||
|
||||
class VOP3P(Inst64):
|
||||
class VOP3P(Inst):
|
||||
encoding = bits[31:24] == 0b11001100
|
||||
_defaults = {'opsel_hi': 3, 'opsel_hi2': 1}
|
||||
op:Annotated[BitField, VOP3POp] = bits[22:16]
|
||||
vdst:VGPRField = bits[7:0]
|
||||
src0:Src = bits[40:32]
|
||||
@@ -214,12 +213,12 @@ class VOP3P(Inst64):
|
||||
src2:Src = bits[58:50]
|
||||
neg = bits[63:61]
|
||||
neg_hi = bits[10:8]
|
||||
clmp = bits[15]
|
||||
opsel = bits[13:11]
|
||||
opsel_hi = bits[60:59]
|
||||
clmp = bits[15]
|
||||
opsel_hi2 = bits[14]
|
||||
|
||||
class VOP3SD(Inst64):
|
||||
class VOP3SD(Inst):
|
||||
encoding = bits[31:26] == 0b110101
|
||||
op:Annotated[BitField, VOP3SDOp] = bits[25:16]
|
||||
vdst:VGPRField = bits[7:0]
|
||||
@@ -227,26 +226,26 @@ class VOP3SD(Inst64):
|
||||
src0:Src = bits[40:32]
|
||||
src1:Src = bits[49:41]
|
||||
src2:Src = bits[58:50]
|
||||
clmp = bits[15]
|
||||
omod = bits[60:59]
|
||||
neg = bits[63:61]
|
||||
clmp = bits[15]
|
||||
|
||||
class VOPC(Inst32):
|
||||
encoding = bits[31:25] == 0b111110
|
||||
class VOPC(Inst):
|
||||
encoding = bits[31:25] == 0b0111110
|
||||
op:Annotated[BitField, VOPCOp] = bits[24:17]
|
||||
src0:Src = bits[8:0]
|
||||
vsrc1:VGPRField = bits[16:9]
|
||||
|
||||
class VOPD(Inst64):
|
||||
class VOPD(Inst):
|
||||
encoding = bits[31:26] == 0b110010
|
||||
opx:Annotated[BitField, VOPDOp] = bits[25:22]
|
||||
opy:Annotated[BitField, VOPDOp] = bits[21:17]
|
||||
vdstx:VGPRField = bits[63:56]
|
||||
vdstx = bits[63:56]
|
||||
vdsty:VDSTYEnc = bits[55:49]
|
||||
srcx0:Src = bits[8:0]
|
||||
vsrcx1:VGPRField = bits[16:9]
|
||||
srcy0:Src = bits[40:32]
|
||||
vsrcy1:VGPRField = bits[48:41]
|
||||
vsrcx1 = bits[16:9]
|
||||
vsrcy1 = bits[48:41]
|
||||
|
||||
# instruction helpers
|
||||
ds_add_u32 = functools.partial(DS, DSOp.DS_ADD_U32)
|
||||
@@ -1077,16 +1076,16 @@ v_add_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_ADD_NC_U32)
|
||||
v_sub_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_SUB_NC_U32)
|
||||
v_subrev_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_NC_U32)
|
||||
v_fmac_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAC_F32)
|
||||
def v_fmamk_f32_e32(vdst, src0, K, vsrc1): return VOP2(VOP2Op.V_FMAMK_F32, vdst, src0, vsrc1, literal=K)
|
||||
def v_fmaak_f32_e32(vdst, src0, vsrc1, K): return VOP2(VOP2Op.V_FMAAK_F32, vdst, src0, vsrc1, literal=K)
|
||||
v_fmamk_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAMK_F32)
|
||||
v_fmaak_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAAK_F32)
|
||||
v_cvt_pk_rtz_f16_f32_e32 = functools.partial(VOP2, VOP2Op.V_CVT_PK_RTZ_F16_F32)
|
||||
v_add_f16_e32 = functools.partial(VOP2, VOP2Op.V_ADD_F16)
|
||||
v_sub_f16_e32 = functools.partial(VOP2, VOP2Op.V_SUB_F16)
|
||||
v_subrev_f16_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_F16)
|
||||
v_mul_f16_e32 = functools.partial(VOP2, VOP2Op.V_MUL_F16)
|
||||
v_fmac_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAC_F16)
|
||||
def v_fmamk_f16_e32(vdst, src0, K, vsrc1): return VOP2(VOP2Op.V_FMAMK_F16, vdst, src0, vsrc1, literal=K)
|
||||
def v_fmaak_f16_e32(vdst, src0, vsrc1, K): return VOP2(VOP2Op.V_FMAAK_F16, vdst, src0, vsrc1, literal=K)
|
||||
v_fmamk_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAMK_F16)
|
||||
v_fmaak_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAAK_F16)
|
||||
v_max_f16_e32 = functools.partial(VOP2, VOP2Op.V_MAX_F16)
|
||||
v_min_f16_e32 = functools.partial(VOP2, VOP2Op.V_MIN_F16)
|
||||
v_ldexp_f16_e32 = functools.partial(VOP2, VOP2Op.V_LDEXP_F16)
|
||||
@@ -1554,7 +1553,6 @@ v_wmma_f16_16x16x16_f16 = functools.partial(VOP3P, VOP3POp.V_WMMA_F16_16X16X16_F
|
||||
v_wmma_bf16_16x16x16_bf16 = functools.partial(VOP3P, VOP3POp.V_WMMA_BF16_16X16X16_BF16)
|
||||
v_wmma_i32_16x16x16_iu8 = functools.partial(VOP3P, VOP3POp.V_WMMA_I32_16X16X16_IU8)
|
||||
v_wmma_i32_16x16x16_iu4 = functools.partial(VOP3P, VOP3POp.V_WMMA_I32_16X16X16_IU4)
|
||||
dword = functools.partial(VOP3SD, VOP3SDOp.DWORD)
|
||||
v_add_co_ci_u32 = functools.partial(VOP3SD, VOP3SDOp.V_ADD_CO_CI_U32)
|
||||
v_sub_co_ci_u32 = functools.partial(VOP3SD, VOP3SDOp.V_SUB_CO_CI_U32)
|
||||
v_subrev_co_ci_u32 = functools.partial(VOP3SD, VOP3SDOp.V_SUBREV_CO_CI_U32)
|
||||
@@ -1771,31 +1769,4 @@ v_dual_dot2acc_f32_f16 = functools.partial(VOPD, VOPDOp.V_DUAL_DOT2ACC_F32_F16)
|
||||
v_dual_dot2acc_f32_bf16 = functools.partial(VOPD, VOPDOp.V_DUAL_DOT2ACC_F32_BF16)
|
||||
v_dual_add_nc_u32 = functools.partial(VOPD, VOPDOp.V_DUAL_ADD_NC_U32)
|
||||
v_dual_lshlrev_b32 = functools.partial(VOPD, VOPDOp.V_DUAL_LSHLREV_B32)
|
||||
v_dual_and_b32 = functools.partial(VOPD, VOPDOp.V_DUAL_AND_B32)
|
||||
|
||||
VCC_LO = SrcEnum.VCC_LO
|
||||
VCC_HI = SrcEnum.VCC_HI
|
||||
NULL = SrcEnum.NULL
|
||||
M0 = SrcEnum.M0
|
||||
EXEC_LO = SrcEnum.EXEC_LO
|
||||
EXEC_HI = SrcEnum.EXEC_HI
|
||||
ZERO = SrcEnum.ZERO
|
||||
DPP8FI = SrcEnum.DPP8FI
|
||||
SHARED_BASE = SrcEnum.SHARED_BASE
|
||||
SHARED_LIMIT = SrcEnum.SHARED_LIMIT
|
||||
PRIVATE_BASE = SrcEnum.PRIVATE_BASE
|
||||
PRIVATE_LIMIT = SrcEnum.PRIVATE_LIMIT
|
||||
POS_HALF = SrcEnum.POS_HALF
|
||||
NEG_HALF = SrcEnum.NEG_HALF
|
||||
POS_ONE = SrcEnum.POS_ONE
|
||||
NEG_ONE = SrcEnum.NEG_ONE
|
||||
POS_TWO = SrcEnum.POS_TWO
|
||||
NEG_TWO = SrcEnum.NEG_TWO
|
||||
POS_FOUR = SrcEnum.POS_FOUR
|
||||
NEG_FOUR = SrcEnum.NEG_FOUR
|
||||
INV_2PI = SrcEnum.INV_2PI
|
||||
VCCZ = SrcEnum.VCCZ
|
||||
EXECZ = SrcEnum.EXECZ
|
||||
SCC = SrcEnum.SCC
|
||||
LDS_DIRECT = SrcEnum.LDS_DIRECT
|
||||
OFF = NULL
|
||||
v_dual_and_b32 = functools.partial(VOPD, VOPDOp.V_DUAL_AND_B32)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,34 +1,100 @@
|
||||
# autogenerated from AMD RDNA4 ISA PDF by pdf.py - do not edit
|
||||
# autogenerated from AMD ISA PDF by pdf.py - do not edit
|
||||
from enum import IntEnum
|
||||
|
||||
class SrcEnum(IntEnum):
|
||||
VCC_LO = 106
|
||||
VCC_HI = 107
|
||||
NULL = 124
|
||||
M0 = 125
|
||||
EXEC_LO = 126
|
||||
EXEC_HI = 127
|
||||
ZERO = 128
|
||||
DPP8 = 233
|
||||
DPP8FI = 234
|
||||
SHARED_BASE = 235
|
||||
SHARED_LIMIT = 236
|
||||
PRIVATE_BASE = 237
|
||||
PRIVATE_LIMIT = 238
|
||||
POS_HALF = 240
|
||||
NEG_HALF = 241
|
||||
POS_ONE = 242
|
||||
NEG_ONE = 243
|
||||
POS_TWO = 244
|
||||
NEG_TWO = 245
|
||||
POS_FOUR = 246
|
||||
NEG_FOUR = 247
|
||||
INV_2PI = 248
|
||||
DPP16 = 250
|
||||
VCCZ = 251
|
||||
EXECZ = 252
|
||||
SCC = 253
|
||||
LDS_DIRECT = 254
|
||||
class BufFmt(IntEnum):
|
||||
BUF_FMT_8_UNORM = 1
|
||||
BUF_FMT_8_SNORM = 2
|
||||
BUF_FMT_8_USCALED = 3
|
||||
BUF_FMT_8_SSCALED = 4
|
||||
BUF_FMT_8_UINT = 5
|
||||
BUF_FMT_8_SINT = 6
|
||||
BUF_FMT_16_UNORM = 7
|
||||
BUF_FMT_16_SNORM = 8
|
||||
BUF_FMT_16_USCALED = 9
|
||||
BUF_FMT_16_SSCALED = 10
|
||||
BUF_FMT_16_UINT = 11
|
||||
BUF_FMT_16_SINT = 12
|
||||
BUF_FMT_16_FLOAT = 13
|
||||
BUF_FMT_8_8_UNORM = 14
|
||||
BUF_FMT_8_8_SNORM = 15
|
||||
BUF_FMT_8_8_USCALED = 16
|
||||
BUF_FMT_8_8_SSCALED = 17
|
||||
BUF_FMT_8_8_UINT = 18
|
||||
BUF_FMT_8_8_SINT = 19
|
||||
BUF_FMT_32_UINT = 20
|
||||
BUF_FMT_32_SINT = 21
|
||||
BUF_FMT_32_FLOAT = 22
|
||||
BUF_FMT_16_16_UNORM = 23
|
||||
BUF_FMT_16_16_SNORM = 24
|
||||
BUF_FMT_16_16_USCALED = 25
|
||||
BUF_FMT_16_16_SSCALED = 26
|
||||
BUF_FMT_16_16_UINT = 27
|
||||
BUF_FMT_16_16_SINT = 28
|
||||
BUF_FMT_16_16_FLOAT = 29
|
||||
BUF_FMT_10_11_11_FLOAT = 30
|
||||
BUF_FMT_11_11_10_FLOAT = 31
|
||||
BUF_FMT_10_10_10_2_UNORM = 32
|
||||
BUF_FMT_10_10_10_2_SNORM = 33
|
||||
BUF_FMT_10_10_10_2_UINT = 34
|
||||
BUF_FMT_10_10_10_2_SINT = 35
|
||||
BUF_FMT_2_10_10_10_UNORM = 36
|
||||
BUF_FMT_2_10_10_10_SNORM = 37
|
||||
BUF_FMT_2_10_10_10_USCALED = 38
|
||||
BUF_FMT_2_10_10_10_SSCALED = 39
|
||||
BUF_FMT_2_10_10_10_UINT = 40
|
||||
BUF_FMT_2_10_10_10_SINT = 41
|
||||
BUF_FMT_8_8_8_8_UNORM = 42
|
||||
BUF_FMT_8_8_8_8_SNORM = 43
|
||||
BUF_FMT_8_8_8_8_USCALED = 44
|
||||
BUF_FMT_8_8_8_8_SSCALED = 45
|
||||
BUF_FMT_8_8_8_8_UINT = 46
|
||||
BUF_FMT_8_8_8_8_SINT = 47
|
||||
BUF_FMT_32_32_UINT = 48
|
||||
BUF_FMT_32_32_SINT = 49
|
||||
BUF_FMT_32_32_FLOAT = 50
|
||||
BUF_FMT_16_16_16_16_UNORM = 51
|
||||
BUF_FMT_16_16_16_16_SNORM = 52
|
||||
BUF_FMT_16_16_16_16_USCALED = 53
|
||||
BUF_FMT_16_16_16_16_SSCALED = 54
|
||||
BUF_FMT_16_16_16_16_UINT = 55
|
||||
BUF_FMT_16_16_16_16_SINT = 56
|
||||
BUF_FMT_16_16_16_16_FLOAT = 57
|
||||
BUF_FMT_32_32_32_UINT = 58
|
||||
BUF_FMT_32_32_32_SINT = 59
|
||||
BUF_FMT_32_32_32_FLOAT = 60
|
||||
BUF_FMT_32_32_32_32_UINT = 61
|
||||
BUF_FMT_32_32_32_32_SINT = 62
|
||||
BUF_FMT_32_32_32_32_FLOAT = 63
|
||||
BUF_FMT_8_SRGB = 64
|
||||
BUF_FMT_8_8_SRGB = 65
|
||||
BUF_FMT_8_8_8_8_SRGB = 66
|
||||
BUF_FMT_5_9_9_9_FLOAT = 67
|
||||
BUF_FMT_5_6_5_UNORM = 68
|
||||
BUF_FMT_1_5_5_5_UNORM = 69
|
||||
BUF_FMT_5_5_5_1_UNORM = 70
|
||||
BUF_FMT_4_4_4_4_UNORM = 71
|
||||
BUF_FMT_4_4_UNORM = 72
|
||||
BUF_FMT_1_UNORM = 73
|
||||
BUF_FMT_1_REVERSED_UNORM = 74
|
||||
BUF_FMT_32_FLOAT_CLAMP = 75
|
||||
BUF_FMT_8_24_UNORM = 76
|
||||
BUF_FMT_8_24_UINT = 77
|
||||
BUF_FMT_24_8_UNORM = 78
|
||||
BUF_FMT_24_8_UINT = 79
|
||||
BUF_FMT_X24_8_32_UINT = 80
|
||||
BUF_FMT_X24_8_32_FLOAT = 81
|
||||
BUF_FMT_GB_GR_UNORM = 82
|
||||
BUF_FMT_GB_GR_SNORM = 83
|
||||
BUF_FMT_GB_GR_UINT = 84
|
||||
BUF_FMT_GB_GR_SRGB = 85
|
||||
BUF_FMT_BG_RG_UNORM = 86
|
||||
BUF_FMT_BG_RG_SNORM = 87
|
||||
BUF_FMT_BG_RG_UINT = 88
|
||||
BUF_FMT_BG_RG_SRGB = 89
|
||||
BUF_FMT_BC1_UNORM = 109
|
||||
BUF_FMT_BC1_SRGB = 110
|
||||
BUF_FMT_BC2_UNORM = 111
|
||||
BUF_FMT_BC2_SRGB = 112
|
||||
|
||||
class DSOp(IntEnum):
|
||||
DS_ADD_U32 = 0
|
||||
@@ -1347,7 +1413,6 @@ class VOP3POp(IntEnum):
|
||||
V_SWMMAC_F32_16X16X32_BF8_BF8 = 90
|
||||
|
||||
class VOP3SDOp(IntEnum):
|
||||
DWORD = 1
|
||||
V_ADD_CO_CI_U32 = 288
|
||||
V_SUB_CO_CI_U32 = 289
|
||||
V_SUBREV_CO_CI_U32 = 290
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
# autogenerated from AMD RDNA4 ISA PDF by pdf.py - do not edit
|
||||
# autogenerated from AMD ISA PDF by pdf.py - do not edit
|
||||
# ruff: noqa: F401,F403
|
||||
from typing import Annotated
|
||||
from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, Inst96, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField
|
||||
from extra.assembly.amd.dsl import *
|
||||
from extra.assembly.amd.autogen.rdna4.enum import *
|
||||
import functools
|
||||
|
||||
# instruction formats
|
||||
class DPP16(Inst64):
|
||||
class DPP16(Inst):
|
||||
src0:Src = bits[39:32]
|
||||
dpp_ctrl = bits[48:40]
|
||||
fi = bits[50]
|
||||
@@ -18,7 +17,7 @@ class DPP16(Inst64):
|
||||
bank_mask = bits[59:56]
|
||||
row_mask = bits[63:60]
|
||||
|
||||
class DPP8(Inst64):
|
||||
class DPP8(Inst):
|
||||
src0:Src = bits[39:32]
|
||||
lane_sel0 = bits[42:40]
|
||||
lane_sel1 = bits[45:43]
|
||||
@@ -29,7 +28,17 @@ class DPP8(Inst64):
|
||||
lane_sel6 = bits[60:58]
|
||||
lane_sel7 = bits[63:61]
|
||||
|
||||
class SMEM(Inst64):
|
||||
class DS(Inst):
|
||||
encoding = bits[31:26] == 0b110110
|
||||
op:Annotated[BitField, DSOp] = bits[25:18]
|
||||
vdst:VGPRField = bits[63:56]
|
||||
addr:VGPRField = bits[39:32]
|
||||
data0:VGPRField = bits[47:40]
|
||||
data1:VGPRField = bits[55:48]
|
||||
offset0 = bits[7:0]
|
||||
offset1 = bits[15:8]
|
||||
|
||||
class SMEM(Inst):
|
||||
encoding = bits[31:26] == 0b111101
|
||||
op:Annotated[BitField, SMEMOp] = bits[18:13]
|
||||
sdata:SGPRField = bits[12:6]
|
||||
@@ -39,153 +48,116 @@ class SMEM(Inst64):
|
||||
th = bits[24:23]
|
||||
ioffset = bits[55:32]
|
||||
|
||||
class SOP1(Inst32):
|
||||
class SOP1(Inst):
|
||||
encoding = bits[31:23] == 0b101111101
|
||||
op:Annotated[BitField, SOP1Op] = bits[15:8]
|
||||
sdst:SGPRField = bits[22:16]
|
||||
ssrc0:SSrc = bits[7:0]
|
||||
|
||||
class SOP2(Inst32):
|
||||
class SOP2(Inst):
|
||||
encoding = bits[31:30] == 0b10
|
||||
op:Annotated[BitField, SOP2Op] = bits[29:23]
|
||||
sdst:SGPRField = bits[22:16]
|
||||
ssrc0:SSrc = bits[7:0]
|
||||
ssrc1:SSrc = bits[15:8]
|
||||
|
||||
class SOPC(Inst32):
|
||||
class SOPC(Inst):
|
||||
encoding = bits[31:23] == 0b101111110
|
||||
op:Annotated[BitField, SOPCOp] = bits[22:16]
|
||||
ssrc0:SSrc = bits[7:0]
|
||||
ssrc1:SSrc = bits[15:8]
|
||||
|
||||
class SOPK(Inst32):
|
||||
class SOPK(Inst):
|
||||
encoding = bits[31:28] == 0b1011
|
||||
op:Annotated[BitField, SOPKOp] = bits[27:23]
|
||||
sdst:SGPRField = bits[22:16]
|
||||
simm16:SImm = bits[15:0]
|
||||
|
||||
class SOPP(Inst32):
|
||||
class SOPP(Inst):
|
||||
encoding = bits[31:23] == 0b101111111
|
||||
op:Annotated[BitField, SOPPOp] = bits[22:16]
|
||||
simm16:SImm = bits[15:0]
|
||||
|
||||
class VBUFFER(Inst96):
|
||||
class VBUFFER(Inst):
|
||||
encoding = bits[31:26] == 0b110001
|
||||
soffset:SSrc = bits[6:0]
|
||||
op:Annotated[BitField, VBUFFEROp] = bits[21:14]
|
||||
tfe = bits[22]
|
||||
vdata:VGPRField = bits[39:32]
|
||||
rsrc = bits[49:41]
|
||||
scope = bits[51:50]
|
||||
th = bits[54:52]
|
||||
vaddr:VGPRField = bits[71:64]
|
||||
soffset:SSrc = bits[6:0]
|
||||
format = bits[61:55]
|
||||
offen = bits[62]
|
||||
idxen = bits[63]
|
||||
vaddr:VGPRField = bits[71:64]
|
||||
tfe = bits[22]
|
||||
rsrc = bits[49:41]
|
||||
scope = bits[51:50]
|
||||
th = bits[54:52]
|
||||
ioffset = bits[95:72]
|
||||
|
||||
class VDS(Inst64):
|
||||
encoding = bits[31:26] == 0b110110
|
||||
offset0 = bits[7:0]
|
||||
offset1 = bits[15:8]
|
||||
op = bits[25:18]
|
||||
addr:VGPRField = bits[39:32]
|
||||
data0:VGPRField = bits[47:40]
|
||||
data1:VGPRField = bits[55:48]
|
||||
vdst:VGPRField = bits[63:56]
|
||||
|
||||
class VDSDIR(Inst64):
|
||||
encoding = bits[31:24] == 0b11001101
|
||||
class VDSDIR(Inst):
|
||||
encoding = bits[31:24] == 0b11001110
|
||||
op:Annotated[BitField, VDSDIROp] = bits[21:20]
|
||||
vdst:VGPRField = bits[7:0]
|
||||
waitexp = bits[10:8]
|
||||
opsel = bits[14:11]
|
||||
cm = bits[15]
|
||||
op:Annotated[BitField, VDSDIROp] = bits[20:16]
|
||||
src0:Src = bits[40:32]
|
||||
src1:Src = bits[49:41]
|
||||
src2:Src = bits[58:50]
|
||||
neg = bits[63:61]
|
||||
attr = bits[15:10]
|
||||
attr_chan = bits[9:8]
|
||||
wait_va = bits[19:16]
|
||||
wait_vmvsrc = bits[23]
|
||||
|
||||
class VEXPORT(Inst64):
|
||||
class VEXPORT(Inst):
|
||||
encoding = bits[31:26] == 0b111110
|
||||
vsrc0:VGPRField = bits[39:32]
|
||||
vsrc1:VGPRField = bits[47:40]
|
||||
vsrc2:VGPRField = bits[55:48]
|
||||
vsrc3:VGPRField = bits[63:56]
|
||||
en = bits[3:0]
|
||||
target = bits[9:4]
|
||||
done = bits[11]
|
||||
row = bits[13]
|
||||
vsrc0 = bits[39:32]
|
||||
vsrc1:VGPRField = bits[47:40]
|
||||
vsrc2 = bits[55:48]
|
||||
vsrc3 = bits[63:56]
|
||||
|
||||
class VFLAT(Inst96):
|
||||
encoding = bits[31:24] == 0b11101100
|
||||
saddr:SSrc = bits[6:0]
|
||||
op:Annotated[BitField, VFLATOp] = bits[20:14]
|
||||
vdst:VGPRField = bits[39:32]
|
||||
sve = bits[49]
|
||||
scope = bits[51:50]
|
||||
th = bits[54:52]
|
||||
vsrc = bits[62:55]
|
||||
vaddr:VGPRField = bits[71:64]
|
||||
ioffset = bits[95:72]
|
||||
|
||||
class VGLOBAL(Inst96):
|
||||
encoding = bits[31:24] == 0b11101110
|
||||
saddr:SSrc = bits[6:0]
|
||||
op:Annotated[BitField, VGLOBALOp] = bits[20:14]
|
||||
vdst:VGPRField = bits[39:32]
|
||||
sve = bits[49]
|
||||
scope = bits[51:50]
|
||||
th = bits[54:52]
|
||||
vsrc = bits[62:55]
|
||||
vaddr:VGPRField = bits[71:64]
|
||||
ioffset = bits[95:72]
|
||||
|
||||
class VIMAGE(Inst96):
|
||||
class VIMAGE(Inst):
|
||||
encoding = bits[31:26] == 0b110100
|
||||
op:Annotated[BitField, VIMAGEOp] = bits[21:14]
|
||||
vdata:VGPRField = bits[39:32]
|
||||
dmask = bits[25:22]
|
||||
dim = bits[2:0]
|
||||
tfe = bits[55]
|
||||
r128 = bits[4]
|
||||
d16 = bits[5]
|
||||
a16 = bits[6]
|
||||
op:Annotated[BitField, VIMAGEOp] = bits[21:14]
|
||||
dmask = bits[25:22]
|
||||
vdata:VGPRField = bits[39:32]
|
||||
rsrc = bits[49:41]
|
||||
scope = bits[51:50]
|
||||
th = bits[54:52]
|
||||
tfe = bits[55]
|
||||
vaddr4 = bits[56:63]
|
||||
vaddr0 = bits[71:64]
|
||||
vaddr1 = bits[79:72]
|
||||
vaddr2 = bits[87:80]
|
||||
vaddr3 = bits[95:88]
|
||||
|
||||
class VINTERP(Inst64):
|
||||
class VINTERP(Inst):
|
||||
encoding = bits[31:24] == 0b11001101
|
||||
op:Annotated[BitField, VINTERPOp] = bits[20:16]
|
||||
vdst:VGPRField = bits[7:0]
|
||||
src0:Src = bits[40:32]
|
||||
src1:Src = bits[49:41]
|
||||
src2:Src = bits[58:50]
|
||||
waitexp = bits[10:8]
|
||||
opsel = bits[14:11]
|
||||
neg = bits[63:61]
|
||||
opsel = bits[14:11]
|
||||
waitexp = bits[10:8]
|
||||
cm = bits[15]
|
||||
|
||||
class VOP1(Inst32):
|
||||
encoding = bits[31:25] == 0b111111
|
||||
class VOP1(Inst):
|
||||
encoding = bits[31:25] == 0b0111111
|
||||
op:Annotated[BitField, VOP1Op] = bits[15:9]
|
||||
vdst:VGPRField = bits[24:17]
|
||||
src0:Src = bits[8:0]
|
||||
|
||||
class VOP2(Inst32):
|
||||
encoding = bits[31] == 0
|
||||
class VOP2(Inst):
|
||||
encoding = bits[31] == 0b0
|
||||
op:Annotated[BitField, VOP2Op] = bits[30:25]
|
||||
vdst:VGPRField = bits[24:17]
|
||||
src0:Src = bits[8:0]
|
||||
vsrc1:VGPRField = bits[16:9]
|
||||
|
||||
class VOP3(Inst64):
|
||||
class VOP3(Inst):
|
||||
encoding = bits[31:26] == 0b110101
|
||||
op:Annotated[BitField, VOP3Op] = bits[25:16]
|
||||
vdst:VGPRField = bits[7:0]
|
||||
@@ -198,9 +170,8 @@ class VOP3(Inst64):
|
||||
opsel = bits[14:11]
|
||||
cm = bits[15]
|
||||
|
||||
class VOP3P(Inst64):
|
||||
class VOP3P(Inst):
|
||||
encoding = bits[31:24] == 0b11001100
|
||||
_defaults = {'opsel_hi': 3, 'opsel_hi2': 1}
|
||||
op:Annotated[BitField, VOP3POp] = bits[22:16]
|
||||
vdst:VGPRField = bits[7:0]
|
||||
src0:Src = bits[40:32]
|
||||
@@ -213,7 +184,7 @@ class VOP3P(Inst64):
|
||||
opsel_hi2 = bits[14]
|
||||
cm = bits[15]
|
||||
|
||||
class VOP3SD(Inst64):
|
||||
class VOP3SD(Inst):
|
||||
encoding = bits[31:26] == 0b110101
|
||||
op:Annotated[BitField, VOP3SDOp] = bits[25:16]
|
||||
vdst:VGPRField = bits[7:0]
|
||||
@@ -221,38 +192,38 @@ class VOP3SD(Inst64):
|
||||
src0:Src = bits[40:32]
|
||||
src1:Src = bits[49:41]
|
||||
src2:Src = bits[58:50]
|
||||
cm = bits[15]
|
||||
omod = bits[60:59]
|
||||
neg = bits[63:61]
|
||||
cm = bits[15]
|
||||
|
||||
class VOPC(Inst32):
|
||||
encoding = bits[31:25] == 0b111110
|
||||
class VOPC(Inst):
|
||||
encoding = bits[31:25] == 0b0111110
|
||||
op:Annotated[BitField, VOPCOp] = bits[24:17]
|
||||
src0:Src = bits[8:0]
|
||||
vsrc1:VGPRField = bits[16:9]
|
||||
|
||||
class VOPD(Inst64):
|
||||
class VOPD(Inst):
|
||||
encoding = bits[31:26] == 0b110010
|
||||
opx:Annotated[BitField, VOPDOp] = bits[25:22]
|
||||
opy:Annotated[BitField, VOPDOp] = bits[21:17]
|
||||
vdstx:VGPRField = bits[63:56]
|
||||
vdstx = bits[63:56]
|
||||
vdsty:VDSTYEnc = bits[55:49]
|
||||
srcx0:Src = bits[8:0]
|
||||
vsrcx1:VGPRField = bits[16:9]
|
||||
srcy0:Src = bits[40:32]
|
||||
vsrcy1:VGPRField = bits[48:41]
|
||||
vsrcx1 = bits[16:9]
|
||||
vsrcy1 = bits[48:41]
|
||||
|
||||
class VSAMPLE(Inst96):
|
||||
class VSAMPLE(Inst):
|
||||
encoding = bits[31:26] == 0b111001
|
||||
op:Annotated[BitField, VSAMPLEOp] = bits[21:14]
|
||||
vdata:VGPRField = bits[39:32]
|
||||
dmask = bits[25:22]
|
||||
dim = bits[2:0]
|
||||
tfe = bits[3]
|
||||
unrm = bits[13]
|
||||
r128 = bits[4]
|
||||
d16 = bits[5]
|
||||
a16 = bits[6]
|
||||
unrm = bits[13]
|
||||
op:Annotated[BitField, VSAMPLEOp] = bits[21:14]
|
||||
dmask = bits[25:22]
|
||||
vdata:VGPRField = bits[39:32]
|
||||
lwe = bits[40]
|
||||
rsrc = bits[49:41]
|
||||
scope = bits[51:50]
|
||||
@@ -263,19 +234,130 @@ class VSAMPLE(Inst96):
|
||||
vaddr2 = bits[87:80]
|
||||
vaddr3 = bits[95:88]
|
||||
|
||||
class VSCRATCH(Inst96):
|
||||
encoding = bits[31:24] == 0b11101101
|
||||
saddr:SSrc = bits[6:0]
|
||||
op:Annotated[BitField, VSCRATCHOp] = bits[20:14]
|
||||
vdst:VGPRField = bits[39:32]
|
||||
sve = bits[49]
|
||||
scope = bits[51:50]
|
||||
th = bits[54:52]
|
||||
vsrc = bits[62:55]
|
||||
vaddr:VGPRField = bits[71:64]
|
||||
ioffset = bits[95:72]
|
||||
|
||||
# instruction helpers
|
||||
ds_add_u32 = functools.partial(DS, DSOp.DS_ADD_U32)
|
||||
ds_sub_u32 = functools.partial(DS, DSOp.DS_SUB_U32)
|
||||
ds_rsub_u32 = functools.partial(DS, DSOp.DS_RSUB_U32)
|
||||
ds_inc_u32 = functools.partial(DS, DSOp.DS_INC_U32)
|
||||
ds_dec_u32 = functools.partial(DS, DSOp.DS_DEC_U32)
|
||||
ds_min_i32 = functools.partial(DS, DSOp.DS_MIN_I32)
|
||||
ds_max_i32 = functools.partial(DS, DSOp.DS_MAX_I32)
|
||||
ds_min_u32 = functools.partial(DS, DSOp.DS_MIN_U32)
|
||||
ds_max_u32 = functools.partial(DS, DSOp.DS_MAX_U32)
|
||||
ds_and_b32 = functools.partial(DS, DSOp.DS_AND_B32)
|
||||
ds_or_b32 = functools.partial(DS, DSOp.DS_OR_B32)
|
||||
ds_xor_b32 = functools.partial(DS, DSOp.DS_XOR_B32)
|
||||
ds_mskor_b32 = functools.partial(DS, DSOp.DS_MSKOR_B32)
|
||||
ds_store_b32 = functools.partial(DS, DSOp.DS_STORE_B32)
|
||||
ds_store_2addr_b32 = functools.partial(DS, DSOp.DS_STORE_2ADDR_B32)
|
||||
ds_store_2addr_stride64_b32 = functools.partial(DS, DSOp.DS_STORE_2ADDR_STRIDE64_B32)
|
||||
ds_cmpstore_b32 = functools.partial(DS, DSOp.DS_CMPSTORE_B32)
|
||||
ds_min_num_f32 = functools.partial(DS, DSOp.DS_MIN_NUM_F32)
|
||||
ds_max_num_f32 = functools.partial(DS, DSOp.DS_MAX_NUM_F32)
|
||||
ds_nop = functools.partial(DS, DSOp.DS_NOP)
|
||||
ds_add_f32 = functools.partial(DS, DSOp.DS_ADD_F32)
|
||||
ds_store_b8 = functools.partial(DS, DSOp.DS_STORE_B8)
|
||||
ds_store_b16 = functools.partial(DS, DSOp.DS_STORE_B16)
|
||||
ds_add_rtn_u32 = functools.partial(DS, DSOp.DS_ADD_RTN_U32)
|
||||
ds_sub_rtn_u32 = functools.partial(DS, DSOp.DS_SUB_RTN_U32)
|
||||
ds_rsub_rtn_u32 = functools.partial(DS, DSOp.DS_RSUB_RTN_U32)
|
||||
ds_inc_rtn_u32 = functools.partial(DS, DSOp.DS_INC_RTN_U32)
|
||||
ds_dec_rtn_u32 = functools.partial(DS, DSOp.DS_DEC_RTN_U32)
|
||||
ds_min_rtn_i32 = functools.partial(DS, DSOp.DS_MIN_RTN_I32)
|
||||
ds_max_rtn_i32 = functools.partial(DS, DSOp.DS_MAX_RTN_I32)
|
||||
ds_min_rtn_u32 = functools.partial(DS, DSOp.DS_MIN_RTN_U32)
|
||||
ds_max_rtn_u32 = functools.partial(DS, DSOp.DS_MAX_RTN_U32)
|
||||
ds_and_rtn_b32 = functools.partial(DS, DSOp.DS_AND_RTN_B32)
|
||||
ds_or_rtn_b32 = functools.partial(DS, DSOp.DS_OR_RTN_B32)
|
||||
ds_xor_rtn_b32 = functools.partial(DS, DSOp.DS_XOR_RTN_B32)
|
||||
ds_mskor_rtn_b32 = functools.partial(DS, DSOp.DS_MSKOR_RTN_B32)
|
||||
ds_storexchg_rtn_b32 = functools.partial(DS, DSOp.DS_STOREXCHG_RTN_B32)
|
||||
ds_storexchg_2addr_rtn_b32 = functools.partial(DS, DSOp.DS_STOREXCHG_2ADDR_RTN_B32)
|
||||
ds_storexchg_2addr_stride64_rtn_b32 = functools.partial(DS, DSOp.DS_STOREXCHG_2ADDR_STRIDE64_RTN_B32)
|
||||
ds_cmpstore_rtn_b32 = functools.partial(DS, DSOp.DS_CMPSTORE_RTN_B32)
|
||||
ds_min_num_rtn_f32 = functools.partial(DS, DSOp.DS_MIN_NUM_RTN_F32)
|
||||
ds_max_num_rtn_f32 = functools.partial(DS, DSOp.DS_MAX_NUM_RTN_F32)
|
||||
ds_swizzle_b32 = functools.partial(DS, DSOp.DS_SWIZZLE_B32)
|
||||
ds_load_b32 = functools.partial(DS, DSOp.DS_LOAD_B32)
|
||||
ds_load_2addr_b32 = functools.partial(DS, DSOp.DS_LOAD_2ADDR_B32)
|
||||
ds_load_2addr_stride64_b32 = functools.partial(DS, DSOp.DS_LOAD_2ADDR_STRIDE64_B32)
|
||||
ds_load_i8 = functools.partial(DS, DSOp.DS_LOAD_I8)
|
||||
ds_load_u8 = functools.partial(DS, DSOp.DS_LOAD_U8)
|
||||
ds_load_i16 = functools.partial(DS, DSOp.DS_LOAD_I16)
|
||||
ds_load_u16 = functools.partial(DS, DSOp.DS_LOAD_U16)
|
||||
ds_consume = functools.partial(DS, DSOp.DS_CONSUME)
|
||||
ds_append = functools.partial(DS, DSOp.DS_APPEND)
|
||||
ds_add_u64 = functools.partial(DS, DSOp.DS_ADD_U64)
|
||||
ds_sub_u64 = functools.partial(DS, DSOp.DS_SUB_U64)
|
||||
ds_rsub_u64 = functools.partial(DS, DSOp.DS_RSUB_U64)
|
||||
ds_inc_u64 = functools.partial(DS, DSOp.DS_INC_U64)
|
||||
ds_dec_u64 = functools.partial(DS, DSOp.DS_DEC_U64)
|
||||
ds_min_i64 = functools.partial(DS, DSOp.DS_MIN_I64)
|
||||
ds_max_i64 = functools.partial(DS, DSOp.DS_MAX_I64)
|
||||
ds_min_u64 = functools.partial(DS, DSOp.DS_MIN_U64)
|
||||
ds_max_u64 = functools.partial(DS, DSOp.DS_MAX_U64)
|
||||
ds_and_b64 = functools.partial(DS, DSOp.DS_AND_B64)
|
||||
ds_or_b64 = functools.partial(DS, DSOp.DS_OR_B64)
|
||||
ds_xor_b64 = functools.partial(DS, DSOp.DS_XOR_B64)
|
||||
ds_mskor_b64 = functools.partial(DS, DSOp.DS_MSKOR_B64)
|
||||
ds_store_b64 = functools.partial(DS, DSOp.DS_STORE_B64)
|
||||
ds_store_2addr_b64 = functools.partial(DS, DSOp.DS_STORE_2ADDR_B64)
|
||||
ds_store_2addr_stride64_b64 = functools.partial(DS, DSOp.DS_STORE_2ADDR_STRIDE64_B64)
|
||||
ds_cmpstore_b64 = functools.partial(DS, DSOp.DS_CMPSTORE_B64)
|
||||
ds_min_num_f64 = functools.partial(DS, DSOp.DS_MIN_NUM_F64)
|
||||
ds_max_num_f64 = functools.partial(DS, DSOp.DS_MAX_NUM_F64)
|
||||
ds_add_rtn_u64 = functools.partial(DS, DSOp.DS_ADD_RTN_U64)
|
||||
ds_sub_rtn_u64 = functools.partial(DS, DSOp.DS_SUB_RTN_U64)
|
||||
ds_rsub_rtn_u64 = functools.partial(DS, DSOp.DS_RSUB_RTN_U64)
|
||||
ds_inc_rtn_u64 = functools.partial(DS, DSOp.DS_INC_RTN_U64)
|
||||
ds_dec_rtn_u64 = functools.partial(DS, DSOp.DS_DEC_RTN_U64)
|
||||
ds_min_rtn_i64 = functools.partial(DS, DSOp.DS_MIN_RTN_I64)
|
||||
ds_max_rtn_i64 = functools.partial(DS, DSOp.DS_MAX_RTN_I64)
|
||||
ds_min_rtn_u64 = functools.partial(DS, DSOp.DS_MIN_RTN_U64)
|
||||
ds_max_rtn_u64 = functools.partial(DS, DSOp.DS_MAX_RTN_U64)
|
||||
ds_and_rtn_b64 = functools.partial(DS, DSOp.DS_AND_RTN_B64)
|
||||
ds_or_rtn_b64 = functools.partial(DS, DSOp.DS_OR_RTN_B64)
|
||||
ds_xor_rtn_b64 = functools.partial(DS, DSOp.DS_XOR_RTN_B64)
|
||||
ds_mskor_rtn_b64 = functools.partial(DS, DSOp.DS_MSKOR_RTN_B64)
|
||||
ds_storexchg_rtn_b64 = functools.partial(DS, DSOp.DS_STOREXCHG_RTN_B64)
|
||||
ds_storexchg_2addr_rtn_b64 = functools.partial(DS, DSOp.DS_STOREXCHG_2ADDR_RTN_B64)
|
||||
ds_storexchg_2addr_stride64_rtn_b64 = functools.partial(DS, DSOp.DS_STOREXCHG_2ADDR_STRIDE64_RTN_B64)
|
||||
ds_cmpstore_rtn_b64 = functools.partial(DS, DSOp.DS_CMPSTORE_RTN_B64)
|
||||
ds_min_num_rtn_f64 = functools.partial(DS, DSOp.DS_MIN_NUM_RTN_F64)
|
||||
ds_max_num_rtn_f64 = functools.partial(DS, DSOp.DS_MAX_NUM_RTN_F64)
|
||||
ds_load_b64 = functools.partial(DS, DSOp.DS_LOAD_B64)
|
||||
ds_load_2addr_b64 = functools.partial(DS, DSOp.DS_LOAD_2ADDR_B64)
|
||||
ds_load_2addr_stride64_b64 = functools.partial(DS, DSOp.DS_LOAD_2ADDR_STRIDE64_B64)
|
||||
ds_add_rtn_f32 = functools.partial(DS, DSOp.DS_ADD_RTN_F32)
|
||||
ds_condxchg32_rtn_b64 = functools.partial(DS, DSOp.DS_CONDXCHG32_RTN_B64)
|
||||
ds_cond_sub_u32 = functools.partial(DS, DSOp.DS_COND_SUB_U32)
|
||||
ds_sub_clamp_u32 = functools.partial(DS, DSOp.DS_SUB_CLAMP_U32)
|
||||
ds_pk_add_f16 = functools.partial(DS, DSOp.DS_PK_ADD_F16)
|
||||
ds_pk_add_bf16 = functools.partial(DS, DSOp.DS_PK_ADD_BF16)
|
||||
ds_store_b8_d16_hi = functools.partial(DS, DSOp.DS_STORE_B8_D16_HI)
|
||||
ds_store_b16_d16_hi = functools.partial(DS, DSOp.DS_STORE_B16_D16_HI)
|
||||
ds_load_u8_d16 = functools.partial(DS, DSOp.DS_LOAD_U8_D16)
|
||||
ds_load_u8_d16_hi = functools.partial(DS, DSOp.DS_LOAD_U8_D16_HI)
|
||||
ds_load_i8_d16 = functools.partial(DS, DSOp.DS_LOAD_I8_D16)
|
||||
ds_load_i8_d16_hi = functools.partial(DS, DSOp.DS_LOAD_I8_D16_HI)
|
||||
ds_load_u16_d16 = functools.partial(DS, DSOp.DS_LOAD_U16_D16)
|
||||
ds_load_u16_d16_hi = functools.partial(DS, DSOp.DS_LOAD_U16_D16_HI)
|
||||
ds_cond_sub_rtn_u32 = functools.partial(DS, DSOp.DS_COND_SUB_RTN_U32)
|
||||
ds_sub_clamp_rtn_u32 = functools.partial(DS, DSOp.DS_SUB_CLAMP_RTN_U32)
|
||||
ds_pk_add_rtn_f16 = functools.partial(DS, DSOp.DS_PK_ADD_RTN_F16)
|
||||
ds_pk_add_rtn_bf16 = functools.partial(DS, DSOp.DS_PK_ADD_RTN_BF16)
|
||||
ds_store_addtid_b32 = functools.partial(DS, DSOp.DS_STORE_ADDTID_B32)
|
||||
ds_load_addtid_b32 = functools.partial(DS, DSOp.DS_LOAD_ADDTID_B32)
|
||||
ds_permute_b32 = functools.partial(DS, DSOp.DS_PERMUTE_B32)
|
||||
ds_bpermute_b32 = functools.partial(DS, DSOp.DS_BPERMUTE_B32)
|
||||
ds_bpermute_fi_b32 = functools.partial(DS, DSOp.DS_BPERMUTE_FI_B32)
|
||||
ds_store_b96 = functools.partial(DS, DSOp.DS_STORE_B96)
|
||||
ds_store_b128 = functools.partial(DS, DSOp.DS_STORE_B128)
|
||||
ds_bvh_stack_push4_pop1_rtn_b32 = functools.partial(DS, DSOp.DS_BVH_STACK_PUSH4_POP1_RTN_B32)
|
||||
ds_bvh_stack_push8_pop1_rtn_b32 = functools.partial(DS, DSOp.DS_BVH_STACK_PUSH8_POP1_RTN_B32)
|
||||
ds_bvh_stack_push8_pop2_rtn_b64 = functools.partial(DS, DSOp.DS_BVH_STACK_PUSH8_POP2_RTN_B64)
|
||||
ds_load_b96 = functools.partial(DS, DSOp.DS_LOAD_B96)
|
||||
ds_load_b128 = functools.partial(DS, DSOp.DS_LOAD_B128)
|
||||
s_load_b32 = functools.partial(SMEM, SMEMOp.S_LOAD_B32)
|
||||
s_load_b64 = functools.partial(SMEM, SMEMOp.S_LOAD_B64)
|
||||
s_load_b128 = functools.partial(SMEM, SMEMOp.S_LOAD_B128)
|
||||
@@ -647,126 +729,6 @@ tbuffer_store_d16_format_xyz = functools.partial(VBUFFER, VBUFFEROp.TBUFFER_STOR
|
||||
tbuffer_store_d16_format_xyzw = functools.partial(VBUFFER, VBUFFEROp.TBUFFER_STORE_D16_FORMAT_XYZW)
|
||||
ds_param_load = functools.partial(VDSDIR, VDSDIROp.DS_PARAM_LOAD)
|
||||
ds_direct_load = functools.partial(VDSDIR, VDSDIROp.DS_DIRECT_LOAD)
|
||||
flat_load_u8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_U8)
|
||||
flat_load_i8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_I8)
|
||||
flat_load_u16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_U16)
|
||||
flat_load_i16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_I16)
|
||||
flat_load_b32 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B32)
|
||||
flat_load_b64 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B64)
|
||||
flat_load_b96 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B96)
|
||||
flat_load_b128 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_B128)
|
||||
flat_store_b8 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B8)
|
||||
flat_store_b16 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B16)
|
||||
flat_store_b32 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B32)
|
||||
flat_store_b64 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B64)
|
||||
flat_store_b96 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B96)
|
||||
flat_store_b128 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_B128)
|
||||
flat_load_d16_u8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_U8)
|
||||
flat_load_d16_i8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_I8)
|
||||
flat_load_d16_b16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_B16)
|
||||
flat_load_d16_hi_u8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_U8)
|
||||
flat_load_d16_hi_i8 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_I8)
|
||||
flat_load_d16_hi_b16 = functools.partial(VFLAT, VFLATOp.FLAT_LOAD_D16_HI_B16)
|
||||
flat_store_d16_hi_b8 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_D16_HI_B8)
|
||||
flat_store_d16_hi_b16 = functools.partial(VFLAT, VFLATOp.FLAT_STORE_D16_HI_B16)
|
||||
flat_atomic_swap_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SWAP_B32)
|
||||
flat_atomic_cmpswap_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_CMPSWAP_B32)
|
||||
flat_atomic_add_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_ADD_U32)
|
||||
flat_atomic_sub_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SUB_U32)
|
||||
flat_atomic_sub_clamp_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SUB_CLAMP_U32)
|
||||
flat_atomic_min_i32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_I32)
|
||||
flat_atomic_min_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_U32)
|
||||
flat_atomic_max_i32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_I32)
|
||||
flat_atomic_max_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_U32)
|
||||
flat_atomic_and_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_AND_B32)
|
||||
flat_atomic_or_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_OR_B32)
|
||||
flat_atomic_xor_b32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_XOR_B32)
|
||||
flat_atomic_inc_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_INC_U32)
|
||||
flat_atomic_dec_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_DEC_U32)
|
||||
flat_atomic_swap_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SWAP_B64)
|
||||
flat_atomic_cmpswap_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_CMPSWAP_B64)
|
||||
flat_atomic_add_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_ADD_U64)
|
||||
flat_atomic_sub_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_SUB_U64)
|
||||
flat_atomic_min_i64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_I64)
|
||||
flat_atomic_min_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_U64)
|
||||
flat_atomic_max_i64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_I64)
|
||||
flat_atomic_max_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_U64)
|
||||
flat_atomic_and_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_AND_B64)
|
||||
flat_atomic_or_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_OR_B64)
|
||||
flat_atomic_xor_b64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_XOR_B64)
|
||||
flat_atomic_inc_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_INC_U64)
|
||||
flat_atomic_dec_u64 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_DEC_U64)
|
||||
flat_atomic_cond_sub_u32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_COND_SUB_U32)
|
||||
flat_atomic_min_num_f32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MIN_NUM_F32)
|
||||
flat_atomic_max_num_f32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_MAX_NUM_F32)
|
||||
flat_atomic_add_f32 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_ADD_F32)
|
||||
flat_atomic_pk_add_f16 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_PK_ADD_F16)
|
||||
flat_atomic_pk_add_bf16 = functools.partial(VFLAT, VFLATOp.FLAT_ATOMIC_PK_ADD_BF16)
|
||||
global_load_u8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_U8)
|
||||
global_load_i8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_I8)
|
||||
global_load_u16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_U16)
|
||||
global_load_i16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_I16)
|
||||
global_load_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B32)
|
||||
global_load_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B64)
|
||||
global_load_b96 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B96)
|
||||
global_load_b128 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_B128)
|
||||
global_store_b8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B8)
|
||||
global_store_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B16)
|
||||
global_store_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B32)
|
||||
global_store_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B64)
|
||||
global_store_b96 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B96)
|
||||
global_store_b128 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_B128)
|
||||
global_load_d16_u8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_U8)
|
||||
global_load_d16_i8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_I8)
|
||||
global_load_d16_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_B16)
|
||||
global_load_d16_hi_u8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_HI_U8)
|
||||
global_load_d16_hi_i8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_HI_I8)
|
||||
global_load_d16_hi_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_D16_HI_B16)
|
||||
global_store_d16_hi_b8 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_D16_HI_B8)
|
||||
global_store_d16_hi_b16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_D16_HI_B16)
|
||||
global_load_addtid_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_ADDTID_B32)
|
||||
global_store_addtid_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_ADDTID_B32)
|
||||
global_inv = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_INV)
|
||||
global_wb = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_WB)
|
||||
global_atomic_swap_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SWAP_B32)
|
||||
global_atomic_cmpswap_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_CMPSWAP_B32)
|
||||
global_atomic_add_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ADD_U32)
|
||||
global_atomic_sub_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SUB_U32)
|
||||
global_atomic_sub_clamp_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SUB_CLAMP_U32)
|
||||
global_atomic_min_i32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_I32)
|
||||
global_atomic_min_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_U32)
|
||||
global_atomic_max_i32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_I32)
|
||||
global_atomic_max_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_U32)
|
||||
global_atomic_and_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_AND_B32)
|
||||
global_atomic_or_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_OR_B32)
|
||||
global_atomic_xor_b32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_XOR_B32)
|
||||
global_atomic_inc_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_INC_U32)
|
||||
global_atomic_dec_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_DEC_U32)
|
||||
global_atomic_swap_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SWAP_B64)
|
||||
global_atomic_cmpswap_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_CMPSWAP_B64)
|
||||
global_atomic_add_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ADD_U64)
|
||||
global_atomic_sub_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_SUB_U64)
|
||||
global_atomic_min_i64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_I64)
|
||||
global_atomic_min_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_U64)
|
||||
global_atomic_max_i64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_I64)
|
||||
global_atomic_max_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_U64)
|
||||
global_atomic_and_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_AND_B64)
|
||||
global_atomic_or_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_OR_B64)
|
||||
global_atomic_xor_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_XOR_B64)
|
||||
global_atomic_inc_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_INC_U64)
|
||||
global_atomic_dec_u64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_DEC_U64)
|
||||
global_wbinv = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_WBINV)
|
||||
global_atomic_cond_sub_u32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_COND_SUB_U32)
|
||||
global_atomic_min_num_f32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MIN_NUM_F32)
|
||||
global_atomic_max_num_f32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_MAX_NUM_F32)
|
||||
global_load_block = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_BLOCK)
|
||||
global_store_block = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_STORE_BLOCK)
|
||||
global_atomic_add_f32 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ADD_F32)
|
||||
global_load_tr_b128 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_TR_B128)
|
||||
global_load_tr_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_LOAD_TR_B64)
|
||||
global_atomic_pk_add_f16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_PK_ADD_F16)
|
||||
global_atomic_pk_add_bf16 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_PK_ADD_BF16)
|
||||
global_atomic_ordered_add_b64 = functools.partial(VGLOBAL, VGLOBALOp.GLOBAL_ATOMIC_ORDERED_ADD_B64)
|
||||
image_load = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD)
|
||||
image_load_mip = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD_MIP)
|
||||
image_load_pck = functools.partial(VIMAGE, VIMAGEOp.IMAGE_LOAD_PCK)
|
||||
@@ -931,8 +893,8 @@ v_add_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_ADD_NC_U32)
|
||||
v_sub_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_SUB_NC_U32)
|
||||
v_subrev_nc_u32_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_NC_U32)
|
||||
v_fmac_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAC_F32)
|
||||
def v_fmamk_f32_e32(vdst, src0, K, vsrc1): return VOP2(VOP2Op.V_FMAMK_F32, vdst, src0, vsrc1, literal=K)
|
||||
def v_fmaak_f32_e32(vdst, src0, vsrc1, K): return VOP2(VOP2Op.V_FMAAK_F32, vdst, src0, vsrc1, literal=K)
|
||||
v_fmamk_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAMK_F32)
|
||||
v_fmaak_f32_e32 = functools.partial(VOP2, VOP2Op.V_FMAAK_F32)
|
||||
v_cvt_pk_rtz_f16_f32_e32 = functools.partial(VOP2, VOP2Op.V_CVT_PK_RTZ_F16_F32)
|
||||
v_min_num_f16_e32 = functools.partial(VOP2, VOP2Op.V_MIN_NUM_F16)
|
||||
v_max_num_f16_e32 = functools.partial(VOP2, VOP2Op.V_MAX_NUM_F16)
|
||||
@@ -941,8 +903,8 @@ v_sub_f16_e32 = functools.partial(VOP2, VOP2Op.V_SUB_F16)
|
||||
v_subrev_f16_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_F16)
|
||||
v_mul_f16_e32 = functools.partial(VOP2, VOP2Op.V_MUL_F16)
|
||||
v_fmac_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAC_F16)
|
||||
def v_fmamk_f16_e32(vdst, src0, K, vsrc1): return VOP2(VOP2Op.V_FMAMK_F16, vdst, src0, vsrc1, literal=K)
|
||||
def v_fmaak_f16_e32(vdst, src0, vsrc1, K): return VOP2(VOP2Op.V_FMAAK_F16, vdst, src0, vsrc1, literal=K)
|
||||
v_fmamk_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAMK_F16)
|
||||
v_fmaak_f16_e32 = functools.partial(VOP2, VOP2Op.V_FMAAK_F16)
|
||||
v_ldexp_f16_e32 = functools.partial(VOP2, VOP2Op.V_LDEXP_F16)
|
||||
v_pk_fmac_f16_e32 = functools.partial(VOP2, VOP2Op.V_PK_FMAC_F16)
|
||||
v_cmp_lt_f16_e64 = functools.partial(VOP3, VOP3Op.V_CMP_LT_F16)
|
||||
@@ -1435,7 +1397,6 @@ v_swmmac_f32_16x16x32_fp8_fp8 = functools.partial(VOP3P, VOP3POp.V_SWMMAC_F32_16
|
||||
v_swmmac_f32_16x16x32_fp8_bf8 = functools.partial(VOP3P, VOP3POp.V_SWMMAC_F32_16X16X32_FP8_BF8)
|
||||
v_swmmac_f32_16x16x32_bf8_fp8 = functools.partial(VOP3P, VOP3POp.V_SWMMAC_F32_16X16X32_BF8_FP8)
|
||||
v_swmmac_f32_16x16x32_bf8_bf8 = functools.partial(VOP3P, VOP3POp.V_SWMMAC_F32_16X16X32_BF8_BF8)
|
||||
dword = functools.partial(VOP3SD, VOP3SDOp.DWORD)
|
||||
v_add_co_ci_u32 = functools.partial(VOP3SD, VOP3SDOp.V_ADD_CO_CI_U32)
|
||||
v_sub_co_ci_u32 = functools.partial(VOP3SD, VOP3SDOp.V_SUB_CO_CI_U32)
|
||||
v_subrev_co_ci_u32 = functools.partial(VOP3SD, VOP3SDOp.V_SUBREV_CO_CI_U32)
|
||||
@@ -1682,55 +1643,4 @@ image_gather4_c_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_CL)
|
||||
image_gather4_c_l = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_L)
|
||||
image_gather4_c_b = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_B)
|
||||
image_gather4_c_b_cl = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4_C_B_CL)
|
||||
image_gather4h = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4H)
|
||||
scratch_load_u8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_U8)
|
||||
scratch_load_i8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_I8)
|
||||
scratch_load_u16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_U16)
|
||||
scratch_load_i16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_I16)
|
||||
scratch_load_b32 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B32)
|
||||
scratch_load_b64 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B64)
|
||||
scratch_load_b96 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B96)
|
||||
scratch_load_b128 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_B128)
|
||||
scratch_store_b8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B8)
|
||||
scratch_store_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B16)
|
||||
scratch_store_b32 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B32)
|
||||
scratch_store_b64 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B64)
|
||||
scratch_store_b96 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B96)
|
||||
scratch_store_b128 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_B128)
|
||||
scratch_load_d16_u8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_U8)
|
||||
scratch_load_d16_i8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_I8)
|
||||
scratch_load_d16_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_B16)
|
||||
scratch_load_d16_hi_u8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_HI_U8)
|
||||
scratch_load_d16_hi_i8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_HI_I8)
|
||||
scratch_load_d16_hi_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_D16_HI_B16)
|
||||
scratch_store_d16_hi_b8 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_D16_HI_B8)
|
||||
scratch_store_d16_hi_b16 = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_D16_HI_B16)
|
||||
scratch_load_block = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_LOAD_BLOCK)
|
||||
scratch_store_block = functools.partial(VSCRATCH, VSCRATCHOp.SCRATCH_STORE_BLOCK)
|
||||
|
||||
VCC_LO = SrcEnum.VCC_LO
|
||||
VCC_HI = SrcEnum.VCC_HI
|
||||
NULL = SrcEnum.NULL
|
||||
M0 = SrcEnum.M0
|
||||
EXEC_LO = SrcEnum.EXEC_LO
|
||||
EXEC_HI = SrcEnum.EXEC_HI
|
||||
ZERO = SrcEnum.ZERO
|
||||
DPP8FI = SrcEnum.DPP8FI
|
||||
SHARED_BASE = SrcEnum.SHARED_BASE
|
||||
SHARED_LIMIT = SrcEnum.SHARED_LIMIT
|
||||
PRIVATE_BASE = SrcEnum.PRIVATE_BASE
|
||||
PRIVATE_LIMIT = SrcEnum.PRIVATE_LIMIT
|
||||
POS_HALF = SrcEnum.POS_HALF
|
||||
NEG_HALF = SrcEnum.NEG_HALF
|
||||
POS_ONE = SrcEnum.POS_ONE
|
||||
NEG_ONE = SrcEnum.NEG_ONE
|
||||
POS_TWO = SrcEnum.POS_TWO
|
||||
NEG_TWO = SrcEnum.NEG_TWO
|
||||
POS_FOUR = SrcEnum.POS_FOUR
|
||||
NEG_FOUR = SrcEnum.NEG_FOUR
|
||||
INV_2PI = SrcEnum.INV_2PI
|
||||
VCCZ = SrcEnum.VCCZ
|
||||
EXECZ = SrcEnum.EXECZ
|
||||
SCC = SrcEnum.SCC
|
||||
LDS_DIRECT = SrcEnum.LDS_DIRECT
|
||||
OFF = NULL
|
||||
image_gather4h = functools.partial(VSAMPLE, VSAMPLEOp.IMAGE_GATHER4H)
|
||||
File diff suppressed because one or more lines are too long
@@ -7,6 +7,19 @@ from functools import cache
|
||||
from typing import overload, Annotated, TypeVar, Generic
|
||||
from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op,
|
||||
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
|
||||
from extra.assembly.amd.autogen.cdna.enum import VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op
|
||||
|
||||
# Source operand encoding - constant across all AMD ISAs
|
||||
class SrcEnum(IntEnum):
|
||||
VCC_LO=106; VCC_HI=107; NULL=124; M0=125; EXEC_LO=126; EXEC_HI=127; ZERO=128
|
||||
DPP8=233; DPP8FI=234; SHARED_BASE=235; SHARED_LIMIT=236; PRIVATE_BASE=237; PRIVATE_LIMIT=238
|
||||
POS_HALF=240; NEG_HALF=241; POS_ONE=242; NEG_ONE=243; POS_TWO=244; NEG_TWO=245
|
||||
POS_FOUR=246; NEG_FOUR=247; INV_2PI=248; DPP16=250; VCCZ=251; EXECZ=252; SCC=253; LDS_DIRECT=254
|
||||
VCC_LO, VCC_HI, NULL, M0, EXEC_LO, EXEC_HI, ZERO = SrcEnum.VCC_LO, SrcEnum.VCC_HI, SrcEnum.NULL, SrcEnum.M0, SrcEnum.EXEC_LO, SrcEnum.EXEC_HI, SrcEnum.ZERO
|
||||
DPP8FI, SHARED_BASE, SHARED_LIMIT, PRIVATE_BASE, PRIVATE_LIMIT = SrcEnum.DPP8FI, SrcEnum.SHARED_BASE, SrcEnum.SHARED_LIMIT, SrcEnum.PRIVATE_BASE, SrcEnum.PRIVATE_LIMIT
|
||||
POS_HALF, NEG_HALF, POS_ONE, NEG_ONE, POS_TWO, NEG_TWO = SrcEnum.POS_HALF, SrcEnum.NEG_HALF, SrcEnum.POS_ONE, SrcEnum.NEG_ONE, SrcEnum.POS_TWO, SrcEnum.NEG_TWO
|
||||
POS_FOUR, NEG_FOUR, INV_2PI, VCCZ, EXECZ, SCC, LDS_DIRECT = SrcEnum.POS_FOUR, SrcEnum.NEG_FOUR, SrcEnum.INV_2PI, SrcEnum.VCCZ, SrcEnum.EXECZ, SrcEnum.SCC, SrcEnum.LDS_DIRECT
|
||||
OFF = NULL
|
||||
|
||||
# Common masks and bit conversion functions
|
||||
MASK32, MASK64, MASK128 = 0xffffffff, 0xffffffffffffffff, (1 << 128) - 1
|
||||
@@ -54,12 +67,14 @@ def f32_to_f16(f):
|
||||
# Instruction spec - register counts and dtypes derived from instruction names
|
||||
_REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16,
|
||||
'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2,
|
||||
'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1}
|
||||
'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1,
|
||||
'DWORD': 1, 'DWORDX2': 2, 'DWORDX3': 3, 'DWORDX4': 4, 'DWORDX8': 8, 'DWORDX16': 16,
|
||||
'BYTE': 1, 'SHORT': 1, 'UBYTE': 1, 'SBYTE': 1, 'USHORT': 1, 'SSHORT': 1}
|
||||
_CVT_RE = re.compile(r'CVT_([FIUB]\d+)_([FIUB]\d+)$')
|
||||
_MAD_MUL_RE = re.compile(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$')
|
||||
_PACK_RE = re.compile(r'PACK_([FIUB]\d+)_([FIUB]\d+)$')
|
||||
_DST_SRC_RE = re.compile(r'_([FIUB]\d+)_([FIUB]\d+)$')
|
||||
_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512))$')
|
||||
_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512)|DWORD(?:X(?:2|3|4|8|16))?|[US]?BYTE|[US]?SHORT)$')
|
||||
@cache
|
||||
def _suffix(name: str) -> tuple[str | None, str | None]:
|
||||
name = name.upper()
|
||||
@@ -250,7 +265,11 @@ def unwrap(val) -> int:
|
||||
FLOAT_ENC = {0.5: 240, -0.5: 241, 1.0: 242, -1.0: 243, 2.0: 244, -2.0: 245, 4.0: 246, -4.0: 247}
|
||||
FLOAT_DEC = {v: str(k) for k, v in FLOAT_ENC.items()}
|
||||
SPECIAL_GPRS = {106: "vcc_lo", 107: "vcc_hi", 124: "null", 125: "m0", 126: "exec_lo", 127: "exec_hi", 253: "scc"}
|
||||
SPECIAL_GPRS_CDNA = {102: "flat_scratch_lo", 103: "flat_scratch_hi", 104: "xnack_mask_lo", 105: "xnack_mask_hi",
|
||||
106: "vcc_lo", 107: "vcc_hi", 124: "m0", 126: "exec_lo", 127: "exec_hi",
|
||||
251: "src_vccz", 252: "src_execz", 253: "src_scc", 254: "src_lds_direct"}
|
||||
SPECIAL_PAIRS = {106: "vcc", 126: "exec"}
|
||||
SPECIAL_PAIRS_CDNA = {102: "flat_scratch", 104: "xnack_mask", 106: "vcc", 126: "exec"}
|
||||
SRC_FIELDS = {'src0', 'src1', 'src2', 'ssrc0', 'ssrc1', 'soffset', 'srcx0', 'srcy0'}
|
||||
RAW_FIELDS = {'vdata', 'vdst', 'vaddr', 'addr', 'data', 'data0', 'data1', 'sdst', 'sdata', 'vsrc1'}
|
||||
|
||||
@@ -267,9 +286,10 @@ def encode_src(val) -> int:
|
||||
if isinstance(val, int): return 128 + val if 0 <= val <= 64 else 192 - val if -16 <= val <= -1 else 255
|
||||
return 255
|
||||
|
||||
def decode_src(val: int) -> str:
|
||||
def decode_src(val: int, cdna: bool = False) -> str:
|
||||
special = SPECIAL_GPRS_CDNA if cdna else SPECIAL_GPRS
|
||||
if val in special: return special[val]
|
||||
if val <= 105: return f"s{val}"
|
||||
if val in SPECIAL_GPRS: return SPECIAL_GPRS[val]
|
||||
if val in FLOAT_DEC: return FLOAT_DEC[val]
|
||||
if 108 <= val <= 123: return f"ttmp{val - 108}"
|
||||
if 128 <= val <= 192: return str(val - 128)
|
||||
@@ -288,7 +308,16 @@ class Inst:
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
cls._fields = {n: v[0] if isinstance(v, tuple) else v for n, v in cls.__dict__.items() if isinstance(v, BitField) or (isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], BitField))}
|
||||
# Merge fields from parent classes
|
||||
cls._fields = {}
|
||||
for base in reversed(cls.__mro__):
|
||||
if base is Inst or not hasattr(base, '_fields'): continue
|
||||
cls._fields.update(base._fields)
|
||||
# Add this class's own fields (overrides parents)
|
||||
cls._fields.update({n: v[0] if isinstance(v, tuple) else v for n, v in cls.__dict__.items() if isinstance(v, BitField) or (isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], BitField))})
|
||||
# Compute size from max bit (exclude optional fields starting at bit 64+, e.g. MIMG NSA)
|
||||
max_bit = max((bf.hi for bf in cls._fields.values() if bf.lo < 64), default=0) if cls._fields else 0
|
||||
cls._sz = 12 if max_bit > 63 else 8 if max_bit > 31 else 4
|
||||
if 'encoding' in cls._fields and isinstance(cls.__dict__.get('encoding'), tuple): cls._encoding = cls.__dict__['encoding']
|
||||
|
||||
def _or_field(self, name: str, bit: int):
|
||||
@@ -352,6 +381,16 @@ class Inst:
|
||||
field_names = [n for n in self._fields if n != 'encoding']
|
||||
# Map Python-friendly names to actual field names (abs_ -> abs for Python reserved word)
|
||||
if 'abs_' in kwargs: kwargs['abs'] = kwargs.pop('abs_')
|
||||
# If more args than fields, treat extra arg as literal (for FMAAK/FMAMK style instructions)
|
||||
# FMAMK has K in middle (vdst, src0, K, vsrc1), FMAAK has K at end (vdst, src0, vsrc1, K)
|
||||
args = list(args)
|
||||
if len(args) > len(field_names) and literal is None:
|
||||
for i, a in enumerate(args):
|
||||
if isinstance(a, int) and not isinstance(a, SrcEnum) and i < len(field_names) and field_names[i] in ('vsrc1',):
|
||||
literal = args.pop(i)
|
||||
break
|
||||
else:
|
||||
literal = args.pop() # fallback: last arg is literal
|
||||
orig_args = dict(zip(field_names, args)) | kwargs
|
||||
self._values.update(orig_args)
|
||||
self._precompute()
|
||||
@@ -393,14 +432,14 @@ class Inst:
|
||||
if name in SRC_FIELDS: self._encode_src(name, val)
|
||||
elif name in RAW_FIELDS: self._encode_raw(name, val)
|
||||
elif name == 'sbase': self._values[name] = (val.idx if isinstance(val, Reg) else val.val if isinstance(val, SrcMod) else val * 2) // 2
|
||||
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = val.idx // 4
|
||||
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = _encode_reg(val) // 4
|
||||
elif marker is _VDSTYEnc and isinstance(val, VGPR): self._values[name] = val.idx >> 1
|
||||
self._precompute_fields()
|
||||
|
||||
def _encode_field(self, name: str, val) -> int:
|
||||
if isinstance(val, RawImm): return val.val
|
||||
if isinstance(val, SrcMod) and not isinstance(val, Reg): return val.val # Special regs like VCC_LO
|
||||
if name in {'srsrc', 'ssamp'}: return val.idx // 4 if isinstance(val, Reg) else val
|
||||
if name in {'srsrc', 'ssamp'}: return _encode_reg(val) // 4 if isinstance(val, Reg) else val
|
||||
if name == 'sbase': return val.idx // 2 if isinstance(val, Reg) else val.val // 2 if isinstance(val, SrcMod) else val
|
||||
if name in RAW_FIELDS: return _encode_reg(val) if isinstance(val, Reg) else val
|
||||
if isinstance(val, Reg) or name in SRC_FIELDS: return encode_src(val)
|
||||
@@ -450,7 +489,7 @@ class Inst:
|
||||
return result + (lit32 & MASK32).to_bytes(4, 'little')
|
||||
|
||||
@classmethod
|
||||
def _size(cls) -> int: return 4 if issubclass(cls, Inst32) else 12 if issubclass(cls, Inst96) else 8
|
||||
def _size(cls) -> int: return cls._sz
|
||||
def size(self) -> int:
|
||||
# Literal is always 4 bytes in the binary (for 64-bit ops, it's in high 32 bits)
|
||||
return self._size() + (4 if self._literal is not None else 0)
|
||||
@@ -514,7 +553,7 @@ class Inst:
|
||||
lit32 = (self._literal >> 32) if self._literal > 0xffffffff else self._literal
|
||||
s = f"0x{lit32:x}"
|
||||
else:
|
||||
s = decode_src(v)
|
||||
s = decode_src(v, 'cdna' in self.__class__.__module__)
|
||||
return f"-{s}" if neg else s
|
||||
|
||||
def __eq__(self, other):
|
||||
@@ -540,21 +579,30 @@ class Inst:
|
||||
elif hasattr(val, 'name'): self.op = val
|
||||
else:
|
||||
cls_name = self.__class__.__name__
|
||||
# VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
|
||||
if cls_name == 'VOP3':
|
||||
try:
|
||||
if val < 256: self.op = VOPCOp(val)
|
||||
elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val)
|
||||
else: self.op = VOP3Op(val)
|
||||
except ValueError: self.op = val
|
||||
# Prefer BitField marker (class-specific enum) over _enum_map (generic RDNA3 enums)
|
||||
elif 'op' in self._fields and (marker := self._fields['op'].marker) and issubclass(marker, IntEnum):
|
||||
is_cdna = cls_name in ('VOP3A', 'VOP3B')
|
||||
# Try marker enum first (VOP3AOp, VOP3BOp, etc.)
|
||||
marker = self._fields['op'].marker if 'op' in self._fields else None
|
||||
if marker and issubclass(marker, IntEnum):
|
||||
try: self.op = marker(val)
|
||||
except ValueError: self.op = val
|
||||
elif cls_name in self._enum_map:
|
||||
try: self.op = self._enum_map[cls_name](val)
|
||||
except ValueError: self.op = val
|
||||
else: self.op = val
|
||||
# Fallback for promoted instructions when marker lookup failed
|
||||
if not hasattr(self.op, 'name') and cls_name in ('VOP3', 'VOP3A', 'VOP3B') and isinstance(val, int):
|
||||
if val < 256:
|
||||
try: self.op = VOPCOp(val)
|
||||
except ValueError: pass
|
||||
elif is_cdna and 256 <= val < 512:
|
||||
try: self.op = (CDNA_VOP1Op(val - 320) if val >= 320 else CDNA_VOP2Op(val - 256))
|
||||
except ValueError: pass
|
||||
elif val in self._VOP3SD_OPS and not is_cdna:
|
||||
try: self.op = VOP3SDOp(val)
|
||||
except ValueError: pass
|
||||
elif 256 <= val < 512 and not is_cdna:
|
||||
try: self.op = VOP1Op(val - 384) if val >= 384 else VOP2Op(val - 256)
|
||||
except ValueError: pass
|
||||
self.op_name = self.op.name if hasattr(self.op, 'name') else ''
|
||||
self._spec_regs = spec_regs(self.op_name)
|
||||
self._spec_dtype = spec_dtype(self.op_name)
|
||||
@@ -574,6 +622,4 @@ class Inst:
|
||||
def is_64bit(self) -> bool: return spec_is_64bit(self.op_name)
|
||||
def is_dst_16(self) -> bool: return self._spec_regs[0] == 1 and is_dtype_16(self._spec_dtype[0])
|
||||
|
||||
class Inst32(Inst): pass
|
||||
class Inst64(Inst): pass
|
||||
class Inst96(Inst): pass
|
||||
|
||||
|
||||
@@ -7,8 +7,9 @@ from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32
|
||||
from extra.assembly.amd.asm import detect_format
|
||||
from extra.assembly.amd.ucode import compile_uop
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
|
||||
from extra.assembly.amd.dsl import SrcEnum
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
|
||||
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
|
||||
SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
|
||||
|
||||
WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256
|
||||
VCC_LO, VCC_HI, NULL, EXEC_LO, EXEC_HI, SCC = SrcEnum.VCC_LO, SrcEnum.VCC_HI, SrcEnum.NULL, SrcEnum.EXEC_LO, SrcEnum.EXEC_HI, SrcEnum.SCC
|
||||
@@ -291,14 +292,16 @@ def exec_vop(st: WaveState, inst: Inst, V: list, lane: int) -> None:
|
||||
extra_kwargs = {'opsel': opsel, 'opsel_hi': inst.opsel_hi | (inst.opsel_hi2 << 2)} if isinstance(inst, VOP3P) and 'FMA_MIX' in inst.op_name else {}
|
||||
result = inst._fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, inst._literal, st.vgpr, src0_idx, vdst, **extra_kwargs)
|
||||
|
||||
# Check if this is a VOPC instruction (either standalone VOPC or VOP3 with VOPC opcode)
|
||||
is_vopc = isinstance(inst.op, VOPCOp) or (isinstance(inst, VOP3) and inst.op.value < 256)
|
||||
if 'VCC' in result:
|
||||
if isinstance(inst, VOP3SD): st.pend_sgpr_lane(inst.sdst, lane, (result['VCC'] >> lane) & 1)
|
||||
else: st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC'] >> lane) & 1)
|
||||
if 'EXEC' in result:
|
||||
st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC'] >> lane) & 1)
|
||||
elif isinstance(inst.op, VOPCOp):
|
||||
elif is_vopc:
|
||||
st.pend_sgpr_lane(vdst, lane, (result['D0'] >> lane) & 1)
|
||||
if not isinstance(inst.op, VOPCOp):
|
||||
if not is_vopc:
|
||||
d0_val = result['D0']
|
||||
if inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
|
||||
elif not isinstance(inst, VOP3P) and inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
|
||||
|
||||
@@ -1,435 +1,305 @@
|
||||
# Generate AMD ISA autogen files from PDF documentation
|
||||
# Combines format/enum generation (previously in dsl.py) and pseudocode compilation (previously in pcode.py)
|
||||
# Usage: python -m extra.assembly.amd.pdf [--arch rdna3|rdna4|cdna|all]
|
||||
import re, functools
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
# Generic PDF text extractor - no external dependencies
|
||||
import re, zlib
|
||||
from tinygrad.helpers import fetch, merge_dicts
|
||||
|
||||
PDF_URLS = {
|
||||
"rdna3": "https://docs.amd.com/api/khub/documents/UVVZM22UN7tMUeiW_4ShTQ/content",
|
||||
"rdna4": "https://docs.amd.com/api/khub/documents/uQpkEvk3pv~kfAb2x~j4uw/content",
|
||||
"cdna": ["https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-mi300-cdna3-instruction-set-architecture.pdf",
|
||||
"https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-cdna4-instruction-set-architecture.pdf"],
|
||||
"cdna": "https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/amd-instinct-cdna4-instruction-set-architecture.pdf",
|
||||
}
|
||||
|
||||
# Field type mappings and ordering
|
||||
FIELD_TYPES = {'SSRC0': 'SSrc', 'SSRC1': 'SSrc', 'SOFFSET': 'SSrc', 'SADDR': 'SSrc', 'SRC0': 'Src', 'SRC1': 'Src', 'SRC2': 'Src',
|
||||
'SDST': 'SGPRField', 'SBASE': 'SGPRField', 'SDATA': 'SGPRField', 'SRSRC': 'SGPRField', 'VDST': 'VGPRField', 'VSRC1': 'VGPRField',
|
||||
'VDATA': 'VGPRField', 'VADDR': 'VGPRField', 'ADDR': 'VGPRField', 'DATA': 'VGPRField', 'DATA0': 'VGPRField', 'DATA1': 'VGPRField',
|
||||
'SIMM16': 'SImm', 'OFFSET': 'Imm', 'OPX': 'VOPDOp', 'OPY': 'VOPDOp', 'SRCX0': 'Src', 'SRCY0': 'Src',
|
||||
'VSRCX1': 'VGPRField', 'VSRCY1': 'VGPRField', 'VDSTX': 'VGPRField', 'VDSTY': 'VDSTYEnc'}
|
||||
FIELD_ORDER = {
|
||||
'SOP2': ['op', 'sdst', 'ssrc0', 'ssrc1'], 'SOP1': ['op', 'sdst', 'ssrc0'], 'SOPC': ['op', 'ssrc0', 'ssrc1'],
|
||||
'SOPK': ['op', 'sdst', 'simm16'], 'SOPP': ['op', 'simm16'], 'VOP1': ['op', 'vdst', 'src0'], 'VOPC': ['op', 'src0', 'vsrc1'],
|
||||
'VOP2': ['op', 'vdst', 'src0', 'vsrc1'], 'VOP3SD': ['op', 'vdst', 'sdst', 'src0', 'src1', 'src2', 'clmp'],
|
||||
'SMEM': ['op', 'sdata', 'sbase', 'soffset', 'offset', 'glc', 'dlc'], 'DS': ['op', 'vdst', 'addr', 'data0', 'data1'],
|
||||
'VOP3': ['op', 'vdst', 'src0', 'src1', 'src2', 'omod', 'neg', 'abs', 'clmp', 'opsel'],
|
||||
'VOP3P': ['op', 'vdst', 'src0', 'src1', 'src2', 'neg', 'neg_hi', 'opsel', 'opsel_hi', 'clmp'],
|
||||
'FLAT': ['op', 'vdst', 'addr', 'data', 'saddr', 'offset', 'seg', 'dlc', 'glc', 'slc'],
|
||||
'MUBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
|
||||
'MTBUF': ['op', 'vdata', 'vaddr', 'srsrc', 'soffset', 'offset', 'format', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe'],
|
||||
'MIMG': ['op', 'vdata', 'vaddr', 'srsrc', 'ssamp', 'dmask', 'dim', 'unrm', 'dlc', 'glc', 'slc'],
|
||||
'EXP': ['en', 'target', 'vsrc0', 'vsrc1', 'vsrc2', 'vsrc3', 'done', 'row'],
|
||||
'VINTERP': ['op', 'vdst', 'src0', 'src1', 'src2', 'waitexp', 'clmp', 'opsel', 'neg'],
|
||||
'VOPD': ['opx', 'opy', 'vdstx', 'vdsty', 'srcx0', 'vsrcx1', 'srcy0', 'vsrcy1'],
|
||||
'LDSDIR': ['op', 'vdst', 'attr', 'attr_chan', 'wait_va']}
|
||||
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
|
||||
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
|
||||
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
|
||||
INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+|(?:FLAT|GLOBAL|SCRATCH)_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
||||
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PDF PARSING WITH PAGE CACHING
|
||||
# Generic PDF extraction tools
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class CachedPDF:
  """Memoising wrapper around a parsed PDF.

  The wrapped object must expose ``pages``; each page must provide ``extract_text()`` and ``find_tables()``.
  Text and tables are extracted at most once per page index and served from a dict afterwards.
  """
  def __init__(self, pdf):
    self._pdf = pdf
    self._text_cache, self._table_cache = {}, {}

  def __len__(self):
    """Number of pages in the wrapped document."""
    return len(self._pdf.pages)

  def text(self, i):
    """Return the text of page ``i`` ('' when the page yields a falsy result)."""
    if i in self._text_cache: return self._text_cache[i]
    extracted = self._pdf.pages[i].extract_text() or ''
    self._text_cache[i] = extracted
    return extracted

  def tables(self, i):
    """Return the extracted tables of page ``i`` as a list (one entry per table found on the page)."""
    if i in self._table_cache: return self._table_cache[i]
    extracted = [tbl.extract() for tbl in self._pdf.pages[i].find_tables()]
    self._table_cache[i] = extracted
    return extracted
|
||||
def extract(url: str) -> list[list[tuple[float, float, str, str]]]:
|
||||
"""Extract positioned text from PDF. Returns list of text elements (x, y, text, font) per page."""
|
||||
data = fetch(url).read_bytes()
|
||||
|
||||
def _parse_bits(s: str) -> tuple[int, int] | None:
|
||||
return (int(m.group(1)), int(m.group(2) or m.group(1))) if (m := re.match(r'\[(\d+)(?::(\d+))?\]', s)) else None
|
||||
# Parse xref table to locate objects
|
||||
xref: dict[int, int] = {}
|
||||
pos = int(re.search(rb'startxref\s+(\d+)', data).group(1)) + 4
|
||||
while data[pos:pos+7] != b'trailer':
|
||||
while data[pos:pos+1] in b' \r\n': pos += 1
|
||||
line_end = data.find(b'\n', pos)
|
||||
start_obj, count = map(int, data[pos:line_end].split()[:2])
|
||||
pos = line_end + 1
|
||||
for i in range(count):
|
||||
if data[pos+17:pos+18] == b'n' and (off := int(data[pos:pos+10])) > 0: xref[start_obj + i] = off
|
||||
pos += 20
|
||||
|
||||
def _parse_fields_table(table: list, fmt: str, enums: set[str]) -> list[tuple]:
  """Turn a 'Field Name / Bits / Description' table into ``(name, hi, lo, enc_val, ftype)`` tuples.

  The first row of ``table`` is treated as the header and skipped, as are rows with an empty first cell or
  whose bits cell does not parse. ``enc_val`` is only set for the ENCODING row, ``ftype`` is ``f"{fmt}Op"``
  for an OP field when that enum name is in ``enums``, otherwise the FIELD_TYPES entry for the field (or None).
  """
  op_enum = f"{fmt}Op"
  parsed = []
  for row in table[1:]:
    if not row or not row[0]: continue
    name = row[0].split('\n')[0].strip()
    bits = _parse_bits((row[1] or '').split('\n')[0].strip())
    if not bits: continue
    hi, lo = bits[0], bits[1]
    enc_val = None
    if name == 'ENCODING' and row[2]:
      desc = row[2]
      # Handle shared FLAT/GLOBAL/SCRATCH table: look for format-specific encoding
      fmt_key = fmt.lstrip('V').lower().capitalize()  # VFLAT -> Flat, VGLOBAL -> Global
      hit = re.search(rf"{fmt_key}='b([01_]+)", desc) or re.search(r"(?:'b|Must be:\s*)([01_]+)", desc)
      enc_bits = hit.group(1).replace('_', '') if hit else None
      if enc_bits:
        enc_val = int(enc_bits, 2)
        # the documented bit range can be narrower than the encoding literal: widen downwards to fit
        if len(enc_bits) > hi - lo + 1: lo = hi - len(enc_bits) + 1
    if name == 'OP' and op_enum in enums: ftype = op_enum
    else: ftype = FIELD_TYPES.get(name.upper())
    parsed.append((name, hi, lo, enc_val, ftype))
  return parsed
|
||||
def get_stream(n: int) -> bytes:
|
||||
obj = data[xref[n]:data.find(b'endobj', xref[n])]
|
||||
raw = obj[obj.find(b'stream\n') + 7:obj.find(b'\nendstream')]
|
||||
return zlib.decompress(raw) if b'/FlateDecode' in obj else raw
|
||||
|
||||
def _parse_single_pdf(url: str):
|
||||
"""Parse a single PDF and return (formats, enums, src_enum, doc_name, instructions)."""
|
||||
import pdfplumber
|
||||
from tinygrad.helpers import fetch
|
||||
# Find page content streams and extract text
|
||||
pages = []
|
||||
for n in sorted(xref):
|
||||
if b'/Type /Page' not in data[xref[n]:xref[n]+500]: continue
|
||||
if not (m := re.search(rb'/Contents (\d+) 0 R', data[xref[n]:xref[n]+500])): continue
|
||||
stream = get_stream(int(m.group(1))).decode('latin-1')
|
||||
elements, font = [], ''
|
||||
for bt in re.finditer(r'BT(.*?)ET', stream, re.S):
|
||||
x, y = 0.0, 0.0
|
||||
for m in re.finditer(r'(/F[\d.]+) [\d.]+ Tf|([\d.+-]+) ([\d.+-]+) Td|[\d.+-]+ [\d.+-]+ [\d.+-]+ [\d.+-]+ ([\d.+-]+) ([\d.+-]+) Tm|<([0-9A-Fa-f]+)>.*?Tj|\[([^\]]+)\] TJ', bt.group(1)):
|
||||
if m.group(1): font = m.group(1)
|
||||
elif m.group(2): x, y = x + float(m.group(2)), y + float(m.group(3))
|
||||
elif m.group(4): x, y = float(m.group(4)), float(m.group(5))
|
||||
elif m.group(6) and (t := bytes.fromhex(m.group(6)).decode('latin-1')).strip(): elements.append((x, y, t, font))
|
||||
elif m.group(7) and (t := ''.join(bytes.fromhex(h).decode('latin-1') for h in re.findall(r'<([0-9A-Fa-f]+)>', m.group(7)))).strip(): elements.append((x, y, t, font))
|
||||
pages.append(sorted(elements, key=lambda e: (-e[1], e[0])))
|
||||
return pages
|
||||
|
||||
pdf = CachedPDF(pdfplumber.open(fetch(url)))
|
||||
total_pages = len(pdf)
|
||||
def extract_tables(pages: list[list[tuple[float, float, str, str]]]) -> dict[int, tuple[str, list[list[str]]]]:
|
||||
"""Extract numbered tables from PDF pages. Returns {table_num: (title, rows)} where rows is list of cells per row."""
|
||||
def group_by_y(texts, key=lambda y: round(y)):
|
||||
by_y: dict[int, list[tuple[float, float, str]]] = {}
|
||||
for x, y, t, _ in texts:
|
||||
by_y.setdefault(key(y), []).append((x, y, t))
|
||||
return by_y
|
||||
|
||||
# Auto-detect document type
|
||||
first_page = pdf.text(0)
|
||||
is_cdna4, is_cdna3 = 'CDNA4' in first_page or 'CDNA 4' in first_page, 'CDNA3' in first_page or 'MI300' in first_page
|
||||
is_cdna, is_rdna4 = is_cdna3 or is_cdna4, 'RDNA4' in first_page or 'RDNA 4' in first_page
|
||||
is_rdna35, is_rdna3 = 'RDNA3.5' in first_page or 'RDNA 3.5' in first_page, 'RDNA3' in first_page and 'RDNA3.5' not in first_page
|
||||
doc_name = "CDNA4" if is_cdna4 else "CDNA3" if is_cdna3 else "RDNA4" if is_rdna4 else "RDNA3.5" if is_rdna35 else "RDNA3" if is_rdna3 else "Unknown"
|
||||
# Find all table headers by merging text on same line
|
||||
table_positions = []
|
||||
for page_idx, texts in enumerate(pages):
|
||||
for items in group_by_y(texts).values():
|
||||
line = ''.join(t for _, t in sorted((x, t) for x, _, t in items))
|
||||
if m := re.search(r'Table (\d+)\. (.+)', line):
|
||||
table_positions.append((int(m.group(1)), m.group(2).strip(), page_idx, items[0][1]))
|
||||
table_positions.sort(key=lambda t: (t[2], -t[3]))
|
||||
|
||||
# Find Microcode Formats section (for formats/enums)
|
||||
microcode_start = next((i for i in range(int(total_pages * 0.2), total_pages)
|
||||
if re.search(r'\d+\.\d+\.\d+\.\s+SOP2\b|Chapter \d+\.\s+Microcode Formats', pdf.text(i))), int(total_pages * 0.9))
|
||||
# Find Instructions section (for pseudocode)
|
||||
instr_start = next((i for i in range(int(total_pages * 0.1), int(total_pages * 0.5))
|
||||
if re.search(r'Chapter \d+\.\s+Instructions\b', pdf.text(i))), total_pages // 3)
|
||||
instr_end = next((i for start in [int(total_pages * 0.6), int(total_pages * 0.5), instr_start]
|
||||
for i in range(start, min(start + 100, total_pages))
|
||||
if re.search(r'Chapter \d+\.\s+Microcode Formats', pdf.text(i))), total_pages)
|
||||
|
||||
# Parse src enum from SSRC encoding table
|
||||
src_enum = dict(SRC_EXTRAS)
|
||||
for i in range(microcode_start, min(microcode_start + 10, total_pages)):
|
||||
text = pdf.text(i)
|
||||
if 'SSRC0' in text and 'VCC_LO' in text:
|
||||
for m in re.finditer(r'^(\d+)\s+(\S+)', text, re.M):
|
||||
val, name = int(m.group(1)), m.group(2).rstrip('.:')
|
||||
if name in FLOAT_MAP: src_enum[val] = FLOAT_MAP[name]
|
||||
elif re.match(r'^[A-Z][A-Z0-9_]*$', name): src_enum[val] = name
|
||||
# For each table, find rows with matching X positions
|
||||
result: dict[int, tuple[str, list[list[str]]]] = {}
|
||||
for num, title, start_page, header_y in table_positions:
|
||||
rows, col_xs = [], None
|
||||
for page_idx in range(start_page, len(pages)):
|
||||
page_texts = [(x, y, t) for x, y, t, _ in pages[page_idx] if 30 < y < 760 and (page_idx > start_page or y < header_y)]
|
||||
for items in sorted(group_by_y([(x, y, t, '') for x, y, t in page_texts], key=lambda y: round(y / 5)).values(), key=lambda items: -items[0][1]):
|
||||
xs = tuple(sorted(round(x) for x, _, _ in items))
|
||||
if col_xs is None:
|
||||
if len(xs) < 2: continue # Skip single-column rows before table starts
|
||||
col_xs = xs
|
||||
elif len(xs) == 1 and xs[0] in col_xs: continue # Skip continuation rows at known column positions
|
||||
elif not any(c in xs for c in col_xs[:2]): break # Row missing first columns = end of table
|
||||
rows.append([t for _, t in sorted((x, t) for x, _, t in items)])
|
||||
else: continue
|
||||
break
|
||||
if rows: result[num] = (title, rows)
|
||||
return result
|
||||
|
||||
# Parse opcode tables
|
||||
full_text = '\n'.join(pdf.text(i) for i in range(microcode_start, min(microcode_start + 50, total_pages)))
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# AMD specific extraction
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def extract_enums(tables: dict[int, tuple[str, list[list[str]]]]) -> dict[str, dict[int, str]]:
|
||||
"""Extract all enums from tables. Returns {enum_name: {value: name}}."""
|
||||
enums: dict[str, dict[int, str]] = {}
|
||||
for m in re.finditer(r'Table \d+\. (\w+) Opcodes(.*?)(?=Table \d+\.|\n\d+\.\d+\.\d+\.\s+\w+\s*\nDescription|$)', full_text, re.S):
|
||||
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+([A-Z][A-Z0-9_]+)', m.group(2))}:
|
||||
enums[m.group(1) + "Op"] = ops
|
||||
if vopd_m := re.search(r'Table \d+\. VOPD Y-Opcodes\n(.*?)(?=Table \d+\.|15\.\d)', full_text, re.S):
|
||||
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+(V_DUAL_\w+)', vopd_m.group(1))}:
|
||||
enums["VOPDOp"] = ops
|
||||
enum_names = set(enums.keys())
|
||||
for num, (title, rows) in tables.items():
|
||||
# Opcode enums from "XXX Opcodes" tables
|
||||
if m := re.match(r'(\w+) (?:Y-)?Opcodes', title):
|
||||
fmt_name = 'VOPD' if 'Y-Opcodes' in title else m.group(1)
|
||||
ops: dict[int, str] = {}
|
||||
for row in rows:
|
||||
for i in range(0, len(row) - 1, 2):
|
||||
if row[i].isdigit() and re.match(r'^[A-Z][A-Z0-9_]+$', row[i + 1]):
|
||||
ops[int(row[i])] = row[i + 1]
|
||||
if ops: enums[fmt_name] = ops
|
||||
# BufFmt from "Data Format" tables
|
||||
if 'Data Format' in title:
|
||||
for row in rows:
|
||||
for i in range(0, len(row) - 1, 2):
|
||||
if row[i].isdigit() and re.match(r'^[\dA-Z_]+$', row[i + 1]) and 'INVALID' not in row[i + 1]:
|
||||
enums.setdefault('BufFmt', {})[int(row[i])] = row[i + 1]
|
||||
return enums
|
||||
|
||||
# Parse instruction formats
|
||||
def is_fields_table(t): return t and len(t) > 1 and t[0] and 'Field' in str(t[0][0] or '')
|
||||
def has_encoding(fields): return any(f[0] == 'ENCODING' for f in fields)
|
||||
def has_header_before_fields(text): return (pos := text.find('Field Name')) != -1 and bool(re.search(r'\d+\.\d+\.\d+\.\s+\w+\s*\n', text[:pos]))
|
||||
def extract_ins(tables: dict[int, tuple[str, list[list[str]]]]) -> tuple[dict[str, list[tuple[str, int, int]]], dict[str, str]]:
  """Extract formats and encodings from 'XXX Fields' tables. Returns (formats, encodings).

  Args:
    tables: {table_num: (title, rows)}. Only tables whose title matches ``<word> Fields`` are used.

  Returns:
    formats: {format_name: [(field_name, hi, lo), ...]} with lower-cased field names, taken from rows whose
      second cell is a ``hi:lo`` / ``[hi:lo]`` range or a single ``[bit]``.
    encodings: {format_name: bit string} taken from the third cell of the ``encoding`` row, underscores removed.
  """
  formats: dict[str, list[tuple[str, int, int]]] = {}
  encodings: dict[str, str] = {}
  for title, rows in tables.values():
    if not (m := re.match(r'(\w+) Fields$', title)): continue
    fmt_name = m.group(1)
    fields = []
    for row in rows:
      if len(row) < 2: continue
      bits = re.match(r'\[?(\d+):(\d+)\]?$', row[1]) or re.match(r'\[(\d+)\]$', row[1])
      if not bits: continue
      field_name = row[0].lower()
      hi = int(bits.group(1))
      lo = int(bits.group(2)) if bits.lastindex >= 2 else hi  # single-bit field: lo == hi
      if field_name == 'encoding' and len(row) >= 3:
        enc_bits = None
        if "'b" in row[2]:
          # Keep only the binary digits after the last 'b marker, so trailing description text in the
          # cell cannot leak into the encoding string or inflate its width.
          if literals := re.findall(r"'b([01_]+)", row[2]): enc_bits = literals[-1].replace('_', '')
        elif (enc := re.search(r':\s*([01_]+)', row[2])): enc_bits = enc.group(1).replace('_', '')
        if enc_bits:
          # If encoding bits exceed field width, extend field to match (AMD docs sometimes have this)
          declared_width, actual_width = hi - lo + 1, len(enc_bits)
          if actual_width > declared_width: lo = hi - actual_width + 1
          encodings[fmt_name] = enc_bits
      fields.append((field_name, hi, lo))
    if fields: formats[fmt_name] = fields
  return formats, encodings
|
||||
|
||||
format_headers = []
|
||||
for i in range(50):
|
||||
if microcode_start + i >= total_pages: break
|
||||
text = pdf.text(microcode_start + i)
|
||||
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n?Description', text): format_headers.append((m.group(1), i, m.start()))
|
||||
for m in re.finditer(r'\d+\.\d+\.\d+\.\s+(\w+)\s*\n', text):
|
||||
fmt_name = m.group(1)
|
||||
if is_cdna and fmt_name.isupper() and len(fmt_name) >= 2: format_headers.append((fmt_name, i, m.start()))
|
||||
elif m.start() > len(text) - 200 and 'Description' not in text[m.end():] and i + 1 < 50:
|
||||
next_text = pdf.text(microcode_start + i + 1).lstrip()
|
||||
if next_text.startswith('Description') or (next_text.startswith('"RDNA') and 'Description' in next_text[:200]):
|
||||
format_headers.append((fmt_name, i, m.start()))
|
||||
# RDNA4: Look for "Table X. Y Fields" patterns (e.g., VIMAGE, VSAMPLE, or shared FLAT/GLOBAL/SCRATCH)
|
||||
for m in re.finditer(r'Table \d+\.\s+([\w,\s]+?)\s+Fields', text):
|
||||
table_name = m.group(1).strip()
|
||||
# Handle shared table like "FLAT, GLOBAL and SCRATCH"
|
||||
if ',' in table_name or ' and ' in table_name:
|
||||
for part in re.split(r',\s*|\s+and\s+', table_name):
|
||||
fmt_name = 'V' + part.strip()
|
||||
if fmt_name not in [h[0] for h in format_headers]: format_headers.append((fmt_name, i, m.start()))
|
||||
elif table_name.startswith('V'):
|
||||
if table_name not in [h[0] for h in format_headers]: format_headers.append((table_name, i, m.start()))
|
||||
def extract_pcode(pages: list[list[tuple[float, float, str, str]]], enums: dict[str, dict[int, str]]) -> dict[tuple[str, int], str]:
|
||||
"""Extract pseudocode for instructions. Returns {(name, opcode): pseudocode}."""
|
||||
# Build lookup from instruction name to opcode
|
||||
name_to_op = {name: op for ops in enums.values() for op, name in ops.items()}
|
||||
|
||||
formats: dict[str, list] = {}
|
||||
for fmt_name, rel_idx, header_pos in format_headers:
|
||||
if fmt_name in formats: continue
|
||||
page_idx = microcode_start + rel_idx
|
||||
text = pdf.text(page_idx)
|
||||
field_pos = text.find('Field Name', header_pos)
|
||||
fields = None
|
||||
for offset in range(3):
|
||||
if page_idx + offset >= total_pages: break
|
||||
if offset > 0 and has_header_before_fields(pdf.text(page_idx + offset)): break
|
||||
for t in pdf.tables(page_idx + offset) if offset > 0 or field_pos > header_pos else []:
|
||||
if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)) and has_encoding(f): fields = f; break
|
||||
if fields: break
|
||||
if not fields and field_pos > header_pos:
|
||||
for t in pdf.tables(page_idx):
|
||||
if is_fields_table(t) and (f := _parse_fields_table(t, fmt_name, enum_names)): fields = f; break
|
||||
if not fields: continue
|
||||
field_names = {f[0] for f in fields}
|
||||
for pg_offset in range(1, 3):
|
||||
if page_idx + pg_offset >= total_pages or has_header_before_fields(pdf.text(page_idx + pg_offset)): break
|
||||
for t in pdf.tables(page_idx + pg_offset):
|
||||
if is_fields_table(t) and (extra := _parse_fields_table(t, fmt_name, enum_names)) and not has_encoding(extra):
|
||||
for ef in extra:
|
||||
if ef[0] not in field_names: fields.append(ef); field_names.add(ef[0])
|
||||
break
|
||||
formats[fmt_name] = fields
|
||||
# First pass: find all instruction headers across all pages
|
||||
all_instructions: list[tuple[int, float, str, int]] = [] # (page_idx, y, name, opcode)
|
||||
for page_idx, page in enumerate(pages):
|
||||
by_y: dict[int, list[tuple[float, str]]] = {}
|
||||
for x, y, t, _ in page:
|
||||
by_y.setdefault(round(y), []).append((x, t))
|
||||
for y, items in sorted(by_y.items(), reverse=True):
|
||||
left = [(x, t) for x, t in items if 55 < x < 65]
|
||||
right = [(x, t) for x, t in items if 535 < x < 550]
|
||||
if left and right and left[0][1] in name_to_op and right[0][1].isdigit():
|
||||
all_instructions.append((page_idx, y, left[0][1], int(right[0][1])))
|
||||
|
||||
# Fix known PDF errors
|
||||
if 'SMEM' in formats:
|
||||
formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
|
||||
for n, h, l, e, t in formats['SMEM']]
|
||||
# RDNA4: VFLAT/VGLOBAL/VSCRATCH OP field is [20:14] not [20:13] (PDF documentation error)
|
||||
for fmt_name in ['VFLAT', 'VGLOBAL', 'VSCRATCH']:
|
||||
if fmt_name in formats:
|
||||
formats[fmt_name] = [(n, h, 14 if n == 'OP' else l, e, t) for n, h, l, e, t in formats[fmt_name]]
|
||||
if doc_name in ('RDNA3', 'RDNA3.5'):
|
||||
if 'SOPPOp' in enums:
|
||||
for k, v in {8: 'S_WAITCNT_DEPCTR', 58: 'S_TTRACEDATA', 59: 'S_TTRACEDATA_IMM'}.items():
|
||||
assert k not in enums['SOPPOp']; enums['SOPPOp'][k] = v
|
||||
if 'SOPKOp' in enums:
|
||||
for k, v in {22: 'S_SUBVECTOR_LOOP_BEGIN', 23: 'S_SUBVECTOR_LOOP_END'}.items():
|
||||
assert k not in enums['SOPKOp']; enums['SOPKOp'][k] = v
|
||||
if 'SMEMOp' in enums:
|
||||
for k, v in {34: 'S_ATC_PROBE', 35: 'S_ATC_PROBE_BUFFER'}.items():
|
||||
assert k not in enums['SMEMOp']; enums['SMEMOp'][k] = v
|
||||
if 'DSOp' in enums:
|
||||
for k, v in {24: 'DS_GWS_SEMA_RELEASE_ALL', 25: 'DS_GWS_INIT', 26: 'DS_GWS_SEMA_V', 27: 'DS_GWS_SEMA_BR', 28: 'DS_GWS_SEMA_P', 29: 'DS_GWS_BARRIER'}.items():
|
||||
assert k not in enums['DSOp']; enums['DSOp'][k] = v
|
||||
if 'FLATOp' in enums:
|
||||
for k, v in {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}.items():
|
||||
assert k not in enums['FLATOp']; enums['FLATOp'][k] = v
|
||||
# CDNA SDWA/DPP: PDF only has modifier fields, need VOP1/VOP2 overlay for correct encoding
|
||||
if is_cdna:
|
||||
if 'SDWA' in formats:
|
||||
formats['SDWA'] = [('ENCODING', 8, 0, 0xf9, None), ('VOP_OP', 16, 9, None, None), ('VDST', 24, 17, None, 'VGPRField'), ('VOP2_OP', 31, 25, None, None)] + \
|
||||
[f for f in formats['SDWA'] if f[0] not in ('ENCODING', 'SDST', 'SD', 'ROW_MASK')]
|
||||
if 'DPP' in formats:
|
||||
formats['DPP'] = [('ENCODING', 8, 0, 0xfa, None), ('VOP_OP', 16, 9, None, None), ('VDST', 24, 17, None, 'VGPRField'), ('VOP2_OP', 31, 25, None, None),
|
||||
('SRC0', 39, 32, None, 'Src'), ('DPP_CTRL', 48, 40, None, None), ('BOUND_CTRL', 51, 51, None, None), ('SRC0_NEG', 52, 52, None, None), ('SRC0_ABS', 53, 53, None, None),
|
||||
('SRC1_NEG', 54, 54, None, None), ('SRC1_ABS', 55, 55, None, None), ('BANK_MASK', 59, 56, None, None), ('ROW_MASK', 63, 60, None, None)]
|
||||
|
||||
# Extract pseudocode for instructions
|
||||
all_text = '\n'.join(pdf.text(i) for i in range(instr_start, instr_end))
|
||||
matches = list(INST_PATTERN.finditer(all_text))
|
||||
raw_pseudocode: dict[tuple[str, int], str] = {}
|
||||
for i, match in enumerate(matches):
|
||||
name, opcode = match.group(1), int(match.group(2))
|
||||
start, end = match.end(), matches[i + 1].start() if i + 1 < len(matches) else match.end() + 2000
|
||||
snippet = all_text[start:end].strip()
|
||||
if pseudocode := _extract_pseudocode(snippet): raw_pseudocode[(name, opcode)] = pseudocode
|
||||
|
||||
return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "pseudocode": raw_pseudocode, "is_cdna": is_cdna}
|
||||
|
||||
def _extract_pseudocode(text: str) -> str | None:
|
||||
"""Extract pseudocode from an instruction description snippet."""
|
||||
lines, result, depth, in_lambda = text.split('\n'), [], 0, 0
|
||||
for line in lines:
|
||||
s = line.strip()
|
||||
if not s or re.match(r'^\d+ of \d+$', s) or re.match(r'^\d+\.\d+\..*Instructions', s): continue
|
||||
if s.startswith(('Notes', 'Functional examples', '•', '-')): break # Stop at notes/bullets
|
||||
if s.startswith(('"RDNA', 'AMD ', 'CDNA')): continue
|
||||
if '•' in s or '–' in s: continue # Skip lines with bullets/dashes
|
||||
if '= lambda(' in s: in_lambda += 1; continue
|
||||
if in_lambda > 0:
|
||||
if s.endswith(');'): in_lambda -= 1
|
||||
continue
|
||||
if s.startswith('if '): depth += 1
|
||||
elif s.startswith('endif'): depth = max(0, depth - 1)
|
||||
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
|
||||
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
|
||||
is_code = (any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =',
|
||||
'D0[', 'D1[', 'S0[', 'S1[', 'S2[', 'MEM[', 'RETURN_DATA',
|
||||
'VADDR', 'VDATA', 'VDST', 'SADDR', 'OFFSET']) or
|
||||
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
|
||||
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s))
|
||||
if is_code: result.append(s)
|
||||
return '\n'.join(result) if result else None
|
||||
|
||||
def _merge_results(results: list[dict]) -> dict:
  """Combine several per-PDF parse results into one superset dictionary.

  The result starts with a copy of SRC_EXTRAS as its "src_enum". For every
  input: its "doc_name" is appended to "doc_names", "is_cdna" is OR-ed in,
  source/enum values and format fields are added when new and asserted to agree
  when already present, and pseudocode is kept from the first input that
  provides a given key.
  """
  merged = {"formats": {}, "enums": {}, "src_enum": dict(SRC_EXTRAS), "doc_names": [], "pseudocode": {}, "is_cdna": False}
  src_out, enums_out, formats_out, pcode_out = merged["src_enum"], merged["enums"], merged["formats"], merged["pseudocode"]
  for parsed in results:
    merged["doc_names"].append(parsed["doc_name"])
    merged["is_cdna"] = merged["is_cdna"] or parsed["is_cdna"]

    # source operand encodings: same value must map to the same name everywhere
    for val, name in parsed["src_enum"].items():
      previous = src_out.setdefault(val, name)
      assert previous == name

    # opcode enums, merged per enum name
    for enum_name, ops in parsed["enums"].items():
      bucket = enums_out.setdefault(enum_name, {})
      for val, name in ops.items():
        previous = bucket.setdefault(val, name)
        assert previous == name

    # instruction formats: a field seen before must occupy the same (hi, lo) bits
    for fmt_name, fields in parsed["formats"].items():
      if fmt_name not in formats_out:
        formats_out[fmt_name] = list(fields)
        continue
      known = {f[0]: (f[1], f[2]) for f in formats_out[fmt_name]}
      for field in fields:
        if field[0] not in known: formats_out[fmt_name].append(field)
        else: assert known[field[0]] == (field[1], field[2])

    # first occurrence of a pseudocode key wins
    for key, pc in parsed["pseudocode"].items():
      pcode_out.setdefault(key, pc)
  return merged
|
||||
# Second pass: extract pseudocode between consecutive instructions
|
||||
pcode: dict[tuple[str, int], str] = {}
|
||||
for i, (page_idx, y, name, opcode) in enumerate(all_instructions):
|
||||
# Get end boundary from next instruction
|
||||
if i + 1 < len(all_instructions):
|
||||
next_page, next_y = all_instructions[i + 1][0], all_instructions[i + 1][1]
|
||||
else:
|
||||
next_page, next_y = page_idx, 0
|
||||
# Collect F6 text from current position to next instruction
|
||||
lines = []
|
||||
for p in range(page_idx, next_page + 1):
|
||||
start_y = y if p == page_idx else 800
|
||||
end_y = next_y if p == next_page else 0
|
||||
lines.extend((p, y2, t) for x, y2, t, f in pages[p] if f in ('/F6.0', '/F7.0') and end_y < y2 < start_y)
|
||||
if lines:
|
||||
# Sort by page first, then by y descending within each page (higher y = earlier text in PDF)
|
||||
pcode_lines = [t.replace('Ê', '').strip() for _, _, t in sorted(lines, key=lambda x: (x[0], -x[1]))]
|
||||
if pcode_lines: pcode[(name, opcode)] = '\n'.join(pcode_lines)
|
||||
return pcode
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# CODE GENERATION
|
||||
# Write autogen files
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _generate_enum_py(enums, src_enum, doc_name) -> str:
|
||||
"""Generate enum.py content (just enums, no dsl.py dependency)."""
|
||||
def enum_lines(name, items): return [f"class {name}(IntEnum):"] + [f" {n} = {v}" for v, n in sorted(items.items())] + [""]
|
||||
lines = [f"# autogenerated from AMD {doc_name} ISA PDF by pdf.py - do not edit", "from enum import IntEnum", ""]
|
||||
lines += enum_lines("SrcEnum", src_enum) + sum([enum_lines(n, ops) for n, ops in sorted(enums.items())], [])
|
||||
return '\n'.join(lines)
|
||||
|
||||
def _generate_ins_py(formats, enums, src_enum, doc_name) -> str:
|
||||
"""Generate ins.py content (instruction formats and helpers, imports dsl.py and enum.py)."""
|
||||
def field_key(f, order): return order.index(f[0].lower()) if f[0].lower() in order else 1000
|
||||
lines = [f"# autogenerated from AMD {doc_name} ISA PDF by pdf.py - do not edit",
|
||||
"# ruff: noqa: F401,F403", "from typing import Annotated",
|
||||
"from extra.assembly.amd.dsl import bits, BitField, Inst32, Inst64, Inst96, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
|
||||
"from extra.assembly.amd.autogen.{arch}.enum import *",
|
||||
"import functools", ""]
|
||||
format_defaults = {'VOP3P': {'opsel_hi': 3, 'opsel_hi2': 1}}
|
||||
lines.append("# instruction formats")
|
||||
# MIMG has optional NSA (Non-Sequential Address) fields that extend beyond 64 bits, but base encoding is 64-bit
|
||||
inst64_override = {'MIMG'}
|
||||
for fmt_name, fields in sorted(formats.items()):
|
||||
max_bit = max(f[1] for f in fields)
|
||||
if fmt_name in inst64_override: base = "Inst64"
|
||||
else: base = "Inst96" if max_bit > 63 else "Inst64" if max_bit > 31 or fmt_name == 'VOP3SD' else "Inst32"
|
||||
order = FIELD_ORDER.get(fmt_name, [])
|
||||
lines.append(f"class {fmt_name}({base}):")
|
||||
if enc := next((f for f in fields if f[0] == 'ENCODING'), None):
|
||||
lines.append(f" encoding = bits[{enc[1]}:{enc[2]}] == 0b{enc[3]:b}" if enc[1] != enc[2] else f" encoding = bits[{enc[1]}] == {enc[3]}")
|
||||
if defaults := format_defaults.get(fmt_name): lines.append(f" _defaults = {defaults}")
|
||||
for name, hi, lo, _, ftype in sorted([f for f in fields if f[0] != 'ENCODING'], key=lambda f: field_key(f, order)):
|
||||
ann = f":Annotated[BitField, {ftype}]" if ftype and ftype.endswith('Op') else f":{ftype}" if ftype else ""
|
||||
lines.append(f" {name.lower()}{ann} = bits[{hi}]" if hi == lo else f" {name.lower()}{ann} = bits[{hi}:{lo}]")
|
||||
def write_enums(enums: dict[str, dict[int, str]], arch: str, path: str):
  """Write enum.py file from extracted enums.

  Args:
    enums: mapping of enum base name -> {value: member name}. 'Src' is emitted
      as ``SrcEnum``, 'BufFmt' keeps its name and gets ``BUF_FMT_``-prefixed
      members, every other name gets an ``Op`` suffix.
    arch: accepted for signature parity with the other writers; not used here.
    path: destination file; it is created or overwritten.
  """
  lines = ["# autogenerated from AMD ISA PDF by pdf.py - do not edit", "from enum import IntEnum", ""]
  for name, values in sorted(enums.items()):
    suffix = "Op" if name not in ('Src', 'BufFmt') else ("Enum" if name == 'Src' else "")
    prefix = "BUF_FMT_" if name == 'BufFmt' else ""
    lines.append(f"class {name}{suffix}(IntEnum):")
    members = [f" {prefix}{member} = {val}" for val, member in sorted(values.items())]
    # keep the generated module importable even when an enum came out empty
    lines += members or [" pass"]
    lines.append("")
  # Python source is read as UTF-8 by default, so do not depend on the locale encoding
  with open(path, "w", encoding="utf-8") as f:
    f.write("\n".join(lines))
|
||||
|
||||
def write_ins(formats: dict[str, list[tuple[str, int, int]]], encodings: dict[str, str], enums: dict[str, dict[int, str]], arch: str, path: str):
|
||||
"""Write ins.py file from extracted formats and enums."""
|
||||
# Field types and ordering
|
||||
def field_type(name, fmt):
|
||||
if name == 'op' and fmt in enums: return f'Annotated[BitField, {fmt}Op]'
|
||||
if name in ('opx', 'opy'): return 'Annotated[BitField, VOPDOp]'
|
||||
if name == 'vdsty': return 'VDSTYEnc'
|
||||
if name in ('vdst', 'vsrc1', 'vaddr', 'vdata', 'data', 'data0', 'data1', 'addr', 'vsrc0', 'vsrc2', 'vsrc3'): return 'VGPRField'
|
||||
if name in ('sdst', 'sbase', 'sdata', 'srsrc', 'ssamp'): return 'SGPRField'
|
||||
if name.startswith('ssrc') or name in ('saddr', 'soffset'): return 'SSrc'
|
||||
if name in ('src0', 'srcx0', 'srcy0') or name.startswith('src') and name[3:].isdigit(): return 'Src'
|
||||
if name.startswith('simm'): return 'SImm'
|
||||
if name == 'offset' or name.startswith('imm'): return 'Imm'
|
||||
return None
|
||||
field_priority = ['encoding', 'op', 'opx', 'opy', 'vdst', 'vdstx', 'vdsty', 'sdst', 'vdata', 'sdata', 'addr', 'vaddr', 'data', 'data0', 'data1',
|
||||
'src0', 'srcx0', 'srcy0', 'vsrc0', 'ssrc0', 'src1', 'vsrc1', 'vsrcx1', 'vsrcy1', 'ssrc1', 'src2', 'vsrc2', 'src3', 'vsrc3',
|
||||
'saddr', 'sbase', 'srsrc', 'ssamp', 'soffset', 'offset', 'simm16', 'en', 'target', 'attr', 'attr_chan',
|
||||
'omod', 'neg', 'neg_hi', 'abs', 'clmp', 'opsel', 'opsel_hi', 'waitexp', 'wait_va',
|
||||
'dmask', 'dim', 'seg', 'format', 'offen', 'idxen', 'glc', 'dlc', 'slc', 'tfe', 'unrm', 'done', 'row']
|
||||
def sort_fields(fields):
|
||||
order = {name: i for i, name in enumerate(field_priority)}
|
||||
return sorted(fields, key=lambda f: (order.get(f[0], 1000), f[2]))
|
||||
|
||||
# Generate format classes
|
||||
lines = ["# autogenerated from AMD ISA PDF by pdf.py - do not edit", "# ruff: noqa: F401,F403",
|
||||
"from typing import Annotated",
|
||||
"from extra.assembly.amd.dsl import *",
|
||||
f"from extra.assembly.amd.autogen.{arch}.enum import *", "import functools", ""]
|
||||
for fmt_name, fields in sorted(formats.items()):
|
||||
lines.append(f"class {fmt_name}(Inst):")
|
||||
for name, hi, lo in sort_fields(fields):
|
||||
bits_str = f"bits[{hi}:{lo}]" if hi != lo else f"bits[{hi}]"
|
||||
if name == 'encoding' and fmt_name in encodings: lines.append(f" encoding = {bits_str} == 0b{encodings[fmt_name]}")
|
||||
else:
|
||||
ftype = field_type(name, fmt_name)
|
||||
lines.append(f" {name}{f':{ftype}' if ftype else ''} = {bits_str}")
|
||||
lines.append("")
|
||||
|
||||
# Generate instruction helpers
|
||||
lines.append("# instruction helpers")
|
||||
for cls_name, ops in sorted(enums.items()):
|
||||
fmt = cls_name[:-2]
|
||||
for op_val, name in sorted(ops.items()):
|
||||
seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=1"}.get(fmt, "")
|
||||
tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}")
|
||||
if fmt in formats or fmt in ("GLOBAL", "SCRATCH"):
|
||||
suffix = "_e32" if fmt in ("VOP1", "VOP2", "VOPC") else "_e64" if fmt == "VOP3" and op_val < 512 else ""
|
||||
if name in ('V_FMAMK_F32', 'V_FMAMK_F16'):
|
||||
lines.append(f"def {name.lower()}{suffix}(vdst, src0, K, vsrc1): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
|
||||
elif name in ('V_FMAAK_F32', 'V_FMAAK_F16'):
|
||||
lines.append(f"def {name.lower()}{suffix}(vdst, src0, vsrc1, K): return {fmt}({cls_name}.{name}, vdst, src0, vsrc1, literal=K)")
|
||||
else: lines.append(f"{name.lower()}{suffix} = functools.partial({tgt}.{name}{seg})")
|
||||
src_names = {name for _, name in src_enum.items()}
|
||||
lines += [""] + [f"{name} = SrcEnum.{name}" for _, name in sorted(src_enum.items()) if name not in {'DPP8', 'DPP16'}]
|
||||
if "NULL" in src_names: lines.append("OFF = NULL\n")
|
||||
return '\n'.join(lines)
|
||||
for fmt_name, ops in sorted(enums.items()):
|
||||
seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=1"}.get(fmt_name, "")
|
||||
tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt_name, f"{fmt_name}, {fmt_name}Op")
|
||||
suffix = "_e32" if fmt_name in ("VOP1", "VOP2", "VOPC") else "_e64" if fmt_name == "VOP3" and len(ops) > 0 else ""
|
||||
if fmt_name in formats or fmt_name in ("GLOBAL", "SCRATCH"):
|
||||
for op_val, name in sorted(ops.items()):
|
||||
fn_suffix = suffix if fmt_name != "VOP3" or op_val < 512 else ""
|
||||
lines.append(f"{name.lower()}{fn_suffix} = functools.partial({tgt}.{name}{seg})")
|
||||
|
||||
def _generate_str_pcode_py(enums, pseudocode, arch) -> str:
|
||||
"""Generate str_pcode.py content (raw pseudocode strings)."""
|
||||
# Get op enums for this arch (import from .ins which re-exports from .enum)
|
||||
import importlib
|
||||
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
|
||||
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'SMEMOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp', 'FLATOp', 'GLOBALOp', 'SCRATCHOp'] if hasattr(autogen, name)]
|
||||
with open(path, "w") as f:
|
||||
f.write("\n".join(lines))
|
||||
|
||||
# Build defined ops mapping
|
||||
defined_ops: dict[tuple, list] = {}
|
||||
for enum_cls in OP_ENUMS:
|
||||
for op in enum_cls:
|
||||
if op.name.startswith(('S_', 'V_', 'DS_', 'FLAT_', 'GLOBAL_', 'SCRATCH_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
|
||||
|
||||
enum_names = [e.__name__ for e in OP_ENUMS]
|
||||
instructions: dict = {cls: {} for cls in OP_ENUMS}
|
||||
for key, pc in pseudocode.items():
|
||||
if key in defined_ops:
|
||||
for enum_cls, enum_val in defined_ops[key]: instructions[enum_cls][enum_val] = pc
|
||||
|
||||
# Build string dictionaries for each enum
|
||||
lines = [f'''# autogenerated by pdf.py - do not edit
|
||||
# to regenerate: python -m extra.assembly.amd.pdf --arch {arch}
|
||||
# ruff: noqa: E501
|
||||
from extra.assembly.amd.autogen.{arch}.enum import {", ".join(enum_names)}
|
||||
''']
|
||||
all_dict_entries: dict = {}
|
||||
for enum_cls in OP_ENUMS:
|
||||
cls_name = enum_cls.__name__
|
||||
if not instructions.get(enum_cls): continue
|
||||
dict_entries = [(op, repr(pc)) for op, pc in instructions[enum_cls].items()]
|
||||
if dict_entries:
|
||||
all_dict_entries[enum_cls] = dict_entries
|
||||
lines.append(f'{cls_name}_PCODE = {{')
|
||||
for op, escaped in dict_entries: lines.append(f" {cls_name}.{op.name}: {escaped},")
|
||||
lines.append('}\n')
|
||||
|
||||
lines.append('PSEUDOCODE_STRINGS = {')
|
||||
for enum_cls in OP_ENUMS:
|
||||
if all_dict_entries.get(enum_cls): lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_PCODE,')
|
||||
lines.append('}')
|
||||
return '\n'.join(lines)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# MAIN GENERATION
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def generate_arch(arch: str) -> dict:
  """Generate enum.py, ins.py and str_pcode.py for a single architecture.

  Looks up the PDF URL(s) for ``arch`` in PDF_URLS, parses each one, merges the
  results when there is more than one, and writes the three generated modules
  under ``extra/assembly/amd/autogen/<arch>/`` (creating the directory and an
  ``__init__.py`` if needed). Progress is printed to stdout.

  Args:
    arch: key into PDF_URLS.

  Returns:
    The merged parse result (or the single parse result when only one PDF was parsed).
  """
  urls = PDF_URLS[arch]
  if isinstance(urls, str): urls = [urls]

  print(f"\n{'='*60}\nGenerating {arch}...")
  print(f"Parsing {len(urls)} PDF(s)...")
  results = [_parse_single_pdf(url) for url in urls]
  # a single result has "doc_name"; a merged one carries the list "doc_names" instead
  merged = _merge_results(results) if len(results) > 1 else results[0]
  doc_name = "+".join(merged["doc_names"]) if len(results) > 1 else merged["doc_name"]

  base_path = Path(f"extra/assembly/amd/autogen/{arch}")
  base_path.mkdir(parents=True, exist_ok=True)
  (base_path / "__init__.py").touch()

  # Generated files are Python source, which is decoded as UTF-8 by default, so they are
  # written as UTF-8 explicitly instead of with the platform's locale encoding.

  # Write enum.py (enums only, no dsl.py dependency)
  enum_path = base_path / "enum.py"
  enum_content = _generate_enum_py(merged["enums"], merged["src_enum"], doc_name)
  enum_path.write_text(enum_content, encoding="utf-8")
  print(f"Generated {enum_path}: SrcEnum ({len(merged['src_enum'])}) + {len(merged['enums'])} enums")

  # Write ins.py (instruction formats and helpers, imports dsl.py and enum.py)
  ins_path = base_path / "ins.py"
  # the generator leaves a literal "{arch}" placeholder in its import line
  ins_content = _generate_ins_py(merged["formats"], merged["enums"], merged["src_enum"], doc_name).replace("{arch}", arch)
  ins_path.write_text(ins_content, encoding="utf-8")
  print(f"Generated {ins_path}: {len(merged['formats'])} formats")

  # Write str_pcode.py (needs enum.py to exist first for imports)
  pcode_path = base_path / "str_pcode.py"
  pcode_content = _generate_str_pcode_py(merged["enums"], merged["pseudocode"], arch)
  pcode_path.write_text(pcode_content, encoding="utf-8")
  print(f"Generated {pcode_path}: {len(merged['pseudocode'])} instructions")

  return merged
|
||||
|
||||
def _generate_arch_wrapper(arch: str):
  """Run generate_arch for ``arch`` and return the architecture name instead of its result.

  Used as the callable handed to the process pool in generate_all.
  """
  _ = generate_arch(arch)  # result intentionally dropped; only the name is returned
  return arch
|
||||
|
||||
def generate_all():
  """Generate every architecture listed in PDF_URLS using a process pool."""
  with ProcessPoolExecutor() as pool:
    # drain the result iterator so an exception raised in a worker surfaces here
    for _ in pool.map(_generate_arch_wrapper, PDF_URLS.keys()): pass
|
||||
def write_pcode(pcode: dict[tuple[str, int], str], enums: dict[str, dict[int, str]], arch: str, path: str):
  """Write str_pcode.py file from extracted pseudocode.

  Args:
    pcode: mapping of (instruction name, opcode) -> pseudocode text.
    enums: mapping of format name -> {opcode: instruction name}; only entries
      whose (name, opcode) appears in ``pcode`` are emitted, under ``<format>Op``.
    arch: interpolated into the generated import path.
    path: destination file; it is created or overwritten.

  The generated module defines one ``<Enum>_PCODE`` dict per enum (entries
  sorted by opcode) and a ``PSEUDOCODE_STRINGS`` dict mapping each enum to it.
  """
  # Group pseudocode by enum class
  by_enum: dict[str, list[tuple[str, int, str]]] = {}
  for fmt_name, ops in enums.items():
    for opcode, name in ops.items():
      if (name, opcode) in pcode: by_enum.setdefault(f"{fmt_name}Op", []).append((name, opcode, pcode[(name, opcode)]))

  # Generate file
  enum_names = sorted(by_enum.keys())
  lines = ["# autogenerated by pdf.py - do not edit", "# to regenerate: python -m extra.assembly.amd.pdf", "# ruff: noqa: E501"]
  # "from x import" with nothing after it is a syntax error, so only import when there is something to import
  if enum_names: lines.append(f"from extra.assembly.amd.autogen.{arch}.enum import {', '.join(enum_names)}")
  lines.append("")
  for enum_name in enum_names:
    lines.append(f"{enum_name}_PCODE = {{")
    for name, opcode, code in sorted(by_enum[enum_name], key=lambda x: x[1]):
      lines.append(f" {enum_name}.{name}: {code!r},")
    lines.append("}\n")
  lines.append(f"PSEUDOCODE_STRINGS = {{{', '.join(f'{e}: {e}_PCODE' for e in enum_names)}}}")
  # repr() keeps printable non-ASCII characters, and Python reads source as UTF-8 by default
  with open(path, "w", encoding="utf-8") as f:
    f.write("\n".join(lines))
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Generate AMD ISA autogen files from PDF documentation")
|
||||
parser.add_argument("--arch", choices=list(PDF_URLS.keys()) + ["all"], default="rdna3")
|
||||
args = parser.parse_args()
|
||||
if args.arch == "all": generate_all()
|
||||
else: generate_arch(args.arch)
|
||||
import pathlib
|
||||
for arch, url in PDF_URLS.items():
|
||||
print(f"Processing {arch}...")
|
||||
pages = extract(url)
|
||||
tables = extract_tables(pages)
|
||||
enums = extract_enums(tables)
|
||||
formats, encodings = extract_ins(tables)
|
||||
pcode = extract_pcode(pages, enums)
|
||||
# Fix known PDF errors
|
||||
if arch == 'rdna3':
|
||||
fixes = {'SOPP': {8: 'S_WAITCNT_DEPCTR', 58: 'S_TTRACEDATA', 59: 'S_TTRACEDATA_IMM'},
|
||||
'SOPK': {22: 'S_SUBVECTOR_LOOP_BEGIN', 23: 'S_SUBVECTOR_LOOP_END'},
|
||||
'SMEM': {34: 'S_ATC_PROBE', 35: 'S_ATC_PROBE_BUFFER'},
|
||||
'DS': {24: 'DS_GWS_SEMA_RELEASE_ALL', 25: 'DS_GWS_INIT', 26: 'DS_GWS_SEMA_V', 27: 'DS_GWS_SEMA_BR', 28: 'DS_GWS_SEMA_P', 29: 'DS_GWS_BARRIER'},
|
||||
'FLAT': {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}}
|
||||
for fmt, ops in fixes.items(): enums[fmt] = merge_dicts([enums[fmt], ops])
|
||||
if arch in ('rdna3', 'rdna4'):
|
||||
# RDNA SMEM: PDF says DLC=[14], GLC=[16] but hardware uses DLC=[13], GLC=[14]
|
||||
if 'SMEM' in formats:
|
||||
formats['SMEM'] = [(n, 13 if n == 'dlc' else 14 if n == 'glc' else h, 13 if n == 'dlc' else 14 if n == 'glc' else l)
|
||||
for n, h, l in formats['SMEM']]
|
||||
if arch == 'cdna':
|
||||
# CDNA DS: PDF is missing the GDS field (bit 16)
|
||||
if 'DS' in formats and not any(n == 'gds' for n, _, _ in formats['DS']):
|
||||
formats['DS'].append(('gds', 16, 16))
|
||||
# CDNA DPP/SDWA: PDF only documents modifier fields (bits[63:32]), need to add VOP overlay fields (bits[31:0])
|
||||
vop_overlay = [('encoding', 8, 0), ('vop_op', 16, 9), ('vdst', 24, 17), ('vop2_op', 31, 25)]
|
||||
if 'DPP' in formats and not any(n == 'encoding' for n, _, _ in formats['DPP']):
|
||||
formats['DPP'] = vop_overlay + [('bc' if n == 'bound_ctrl' else n, h, l) for n, h, l in formats['DPP']]
|
||||
encodings['DPP'] = '11111010'
|
||||
if 'SDWA' in formats and not any(n == 'encoding' for n, _, _ in formats['SDWA']):
|
||||
formats['SDWA'] = vop_overlay + [(n, h, l) for n, h, l in formats['SDWA']]
|
||||
encodings['SDWA'] = '11111001'
|
||||
base = pathlib.Path(__file__).parent / "autogen" / arch
|
||||
write_enums(enums, arch, base / "enum.py")
|
||||
write_ins(formats, encodings, enums, arch, base / "ins.py")
|
||||
write_pcode(pcode, enums, arch, base / "str_pcode.py")
|
||||
print(f" {len(tables)} tables, {len(pcode)} pcode -> {base}")
|
||||
|
||||
@@ -1615,7 +1615,7 @@ class TestCarryBorrow(unittest.TestCase):
|
||||
v_mov_b32_e32(v[2], s[2]),
|
||||
v_mov_b32_e32(v[3], s[3]),
|
||||
v_add_co_u32(v[4], VCC, v[0], v[2]),
|
||||
v_add_co_ci_u32_e32(v[5], VCC, v[1], v[3]),
|
||||
v_add_co_ci_u32_e32(v[5], v[1], v[3]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][4], 0x00000000, "lo result")
|
||||
|
||||
@@ -271,7 +271,7 @@ class TestVOP3P(unittest.TestCase):
|
||||
s_mov_b32(s[1], 0x44004200), # hi=4.0, lo=3.0
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_pk_add_f16(v[2], v[0], v[1]),
|
||||
v_pk_add_f16(v[2], v[0], v[1], opsel_hi=3, opsel_hi2=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
@@ -288,7 +288,7 @@ class TestVOP3P(unittest.TestCase):
|
||||
s_mov_b32(s[1], 0x45004400), # hi=5.0, lo=4.0
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_pk_mul_f16(v[2], v[0], v[1]),
|
||||
v_pk_mul_f16(v[2], v[0], v[1], opsel_hi=3, opsel_hi2=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
@@ -307,7 +307,7 @@ class TestVOP3P(unittest.TestCase):
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], s[2]),
|
||||
v_pk_fma_f16(v[3], v[0], v[1], v[2]),
|
||||
v_pk_fma_f16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][3]
|
||||
@@ -325,7 +325,7 @@ class TestVOP3P(unittest.TestCase):
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x3c003c00), # packed f16: hi=1.0, lo=1.0
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_pk_add_f16(v[1], v[0], SrcEnum.POS_ONE), # Add inline constant 1.0
|
||||
v_pk_add_f16(v[1], v[0], SrcEnum.POS_ONE, opsel_hi=3, opsel_hi2=1), # Add inline constant 1.0
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][1]
|
||||
@@ -345,7 +345,7 @@ class TestVOP3P(unittest.TestCase):
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x44004200), # packed f16: hi=4.0, lo=3.0
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_pk_mul_f16(v[1], v[0], SrcEnum.POS_TWO),
|
||||
v_pk_mul_f16(v[1], v[0], SrcEnum.POS_TWO, opsel_hi=3, opsel_hi2=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][1]
|
||||
@@ -486,12 +486,12 @@ class TestSpecialOps(unittest.TestCase):
|
||||
"""V_DOT2_F32_BF16 computes dot product of bf16 pairs."""
|
||||
# bf16 1.0 = 0x3f80, bf16 2.0 = 0x4000
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x3f803f80), # packed bf16: 1.0, 1.0
|
||||
s_mov_b32(s[1], 0x40003f80), # packed bf16: 2.0, 1.0
|
||||
s_mov_b32(s[0], 0x3f803f80), # packed bf16: lo=1.0, hi=1.0
|
||||
s_mov_b32(s[1], 0x40003f80), # packed bf16: lo=1.0, hi=2.0
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_dot2_f32_bf16(v[3], v[0], v[1], v[2]),
|
||||
v_dot2_f32_bf16(v[3], v[0], v[1], v[2], opsel_hi=3, opsel_hi2=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
# 1.0*1.0 + 1.0*2.0 + 0 = 3.0
|
||||
@@ -510,7 +510,7 @@ class TestPackedMixedSigns(unittest.TestCase):
|
||||
s_mov_b32(s[1], 0x3c003c00), # packed: hi=1.0, lo=1.0
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_pk_add_f16(v[2], v[0], v[1]),
|
||||
v_pk_add_f16(v[2], v[0], v[1], opsel_hi=3, opsel_hi2=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = st.vgpr[0][2]
|
||||
|
||||
@@ -1,192 +1,102 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test RDNA3 assembler/disassembler against LLVM test vectors."""
|
||||
import unittest, re, subprocess
|
||||
"""Test AMD assembler/disassembler against LLVM test vectors."""
|
||||
import unittest, re, subprocess, functools
|
||||
from tinygrad.helpers import fetch
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
from extra.assembly.amd.asm import asm
|
||||
from extra.assembly.amd.asm import asm, disasm, detect_format
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc
|
||||
|
||||
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
|
||||
|
||||
# Format info: (filename, format_class, op_enum)
|
||||
LLVM_TEST_FILES = {
|
||||
# Scalar ALU
|
||||
'sop1': ('gfx11_asm_sop1.s', SOP1, SOP1Op),
|
||||
'sop2': ('gfx11_asm_sop2.s', SOP2, SOP2Op),
|
||||
'sopp': ('gfx11_asm_sopp.s', SOPP, SOPPOp),
|
||||
'sopk': ('gfx11_asm_sopk.s', SOPK, SOPKOp),
|
||||
'sopc': ('gfx11_asm_sopc.s', SOPC, SOPCOp),
|
||||
# Vector ALU
|
||||
'vop1': ('gfx11_asm_vop1.s', VOP1, VOP1Op),
|
||||
'vop2': ('gfx11_asm_vop2.s', VOP2, VOP2Op),
|
||||
'vopc': ('gfx11_asm_vopc.s', VOPC, VOPCOp),
|
||||
'vop3': ('gfx11_asm_vop3.s', VOP3, VOP3Op),
|
||||
'vop3p': ('gfx11_asm_vop3p.s', VOP3P, VOP3POp),
|
||||
'vop3sd': ('gfx11_asm_vop3.s', VOP3SD, VOP3SDOp), # VOP3SD shares file with VOP3
|
||||
'vinterp': ('gfx11_asm_vinterp.s', VINTERP, VINTERPOp),
|
||||
'vopd': ('gfx11_asm_vopd.s', VOPD, VOPDOp),
|
||||
'vopcx': ('gfx11_asm_vopcx.s', VOPC, VOPCOp), # VOPCX uses VOPC format
|
||||
# VOP3 promotions (VOP1/VOP2/VOPC promoted to VOP3 encoding)
|
||||
'vop3_from_vop1': ('gfx11_asm_vop3_from_vop1.s', VOP3, VOP3Op),
|
||||
'vop3_from_vop2': ('gfx11_asm_vop3_from_vop2.s', VOP3, VOP3Op),
|
||||
'vop3_from_vopc': ('gfx11_asm_vop3_from_vopc.s', VOP3, VOP3Op),
|
||||
'vop3_from_vopcx': ('gfx11_asm_vop3_from_vopcx.s', VOP3, VOP3Op),
|
||||
# Memory
|
||||
'ds': ('gfx11_asm_ds.s', DS, DSOp),
|
||||
'smem': ('gfx11_asm_smem.s', SMEM, SMEMOp),
|
||||
'flat': ('gfx11_asm_flat.s', FLAT, FLATOp),
|
||||
'mubuf': ('gfx11_asm_mubuf.s', MUBUF, MUBUFOp),
|
||||
'mtbuf': ('gfx11_asm_mtbuf.s', MTBUF, MTBUFOp),
|
||||
'mimg': ('gfx11_asm_mimg.s', MIMG, MIMGOp),
|
||||
# WMMA (matrix multiply)
|
||||
'wmma': ('gfx11_asm_wmma.s', VOP3P, VOP3POp),
|
||||
# Additional features
|
||||
'vop3_features': ('gfx11_asm_vop3_features.s', VOP3, VOP3Op),
|
||||
'vop3p_features': ('gfx11_asm_vop3p_features.s', VOP3P, VOP3POp),
|
||||
'vopd_features': ('gfx11_asm_vopd_features.s', VOPD, VOPDOp),
|
||||
# Alias files (alternative mnemonics)
|
||||
'vop3_alias': ('gfx11_asm_vop3_alias.s', VOP3, VOP3Op),
|
||||
'vop3p_alias': ('gfx11_asm_vop3p_alias.s', VOP3P, VOP3POp),
|
||||
'vopc_alias': ('gfx11_asm_vopc_alias.s', VOPC, VOPCOp),
|
||||
'vopcx_alias': ('gfx11_asm_vopcx_alias.s', VOPC, VOPCOp),
|
||||
'vinterp_alias': ('gfx11_asm_vinterp_alias.s', VINTERP, VINTERPOp),
|
||||
'smem_alias': ('gfx11_asm_smem_alias.s', SMEM, SMEMOp),
|
||||
'mubuf_alias': ('gfx11_asm_mubuf_alias.s', MUBUF, MUBUFOp),
|
||||
'mtbuf_alias': ('gfx11_asm_mtbuf_alias.s', MTBUF, MTBUFOp),
|
||||
}
|
||||
RDNA_FILES = ['gfx11_asm_sop1.s', 'gfx11_asm_sop2.s', 'gfx11_asm_sopp.s', 'gfx11_asm_sopk.s', 'gfx11_asm_sopc.s',
|
||||
'gfx11_asm_vop1.s', 'gfx11_asm_vop2.s', 'gfx11_asm_vopc.s', 'gfx11_asm_vop3.s', 'gfx11_asm_vop3p.s', 'gfx11_asm_vinterp.s',
|
||||
'gfx11_asm_vopd.s', 'gfx11_asm_vopcx.s', 'gfx11_asm_vop3_from_vop1.s', 'gfx11_asm_vop3_from_vop2.s', 'gfx11_asm_vop3_from_vopc.s',
|
||||
'gfx11_asm_vop3_from_vopcx.s', 'gfx11_asm_ds.s', 'gfx11_asm_smem.s', 'gfx11_asm_flat.s', 'gfx11_asm_mubuf.s', 'gfx11_asm_mtbuf.s',
|
||||
'gfx11_asm_mimg.s', 'gfx11_asm_wmma.s', 'gfx11_asm_vop3_features.s', 'gfx11_asm_vop3p_features.s', 'gfx11_asm_vopd_features.s',
|
||||
'gfx11_asm_vop3_alias.s', 'gfx11_asm_vop3p_alias.s', 'gfx11_asm_vopc_alias.s', 'gfx11_asm_vopcx_alias.s', 'gfx11_asm_vinterp_alias.s',
|
||||
'gfx11_asm_smem_alias.s', 'gfx11_asm_mubuf_alias.s', 'gfx11_asm_mtbuf_alias.s']
|
||||
# CDNA test files - includes gfx9 files for shared instructions, plus gfx90a/gfx942 specific files
|
||||
# gfx90a_ldst_acc.s has MIMG mixed in, filtered via is_mimg check
|
||||
CDNA_FILES = ['gfx9_asm_sop1.s', 'gfx9_asm_sop2.s', 'gfx9_asm_sopp.s', 'gfx9_asm_sopk.s', 'gfx9_asm_sopc.s',
|
||||
'gfx9_asm_vop1.s', 'gfx9_asm_vop2.s', 'gfx9_asm_vopc.s', 'gfx9_asm_vop3.s', 'gfx9_asm_vop3p.s',
|
||||
'gfx9_asm_ds.s', 'gfx9_asm_flat.s', 'gfx9_asm_smem.s', 'gfx9_asm_mubuf.s', 'gfx9_asm_mtbuf.s',
|
||||
'gfx90a_ldst_acc.s', 'gfx90a_asm_features.s', 'flat-scratch-gfx942.s', 'gfx942_asm_features.s',
|
||||
'mai-gfx90a.s', 'mai-gfx942.s']
|
||||
|
||||
def parse_llvm_tests(text: str) -> list[tuple[str, bytes]]:
|
||||
"""Parse LLVM test format into (asm, expected_bytes) pairs."""
|
||||
tests, lines = [], text.split('\n')
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
if not line or line.startswith(('//', '.', ';')): continue
|
||||
asm_text = line.split('//')[0].strip()
|
||||
if not asm_text: continue
|
||||
for j in range(i, min(i + 3, len(lines))):
|
||||
# Match GFX11, W32, or W64 encodings (all valid for gfx11)
|
||||
# Format 1: "// GFX11: v_foo ... ; encoding: [0x01,0x02,...]"
|
||||
# Format 2: "// GFX11: [0x01,0x02,...]" (used by DS, older files)
|
||||
if m := re.search(r'(?:GFX11|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
elif m := re.search(r'(?:GFX11|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
else:
|
||||
continue
|
||||
if hex_bytes:
|
||||
try: tests.append((asm_text, bytes.fromhex(hex_bytes)))
|
||||
except ValueError: pass
|
||||
break
|
||||
def _is_mimg(data: bytes) -> bool: return (int.from_bytes(data[:4], 'little') >> 26) & 0x3f == 0b111100
|
||||
|
||||
def _parse_llvm_tests(text: str, pattern: str) -> list[tuple[str, bytes]]:
|
||||
tests = []
|
||||
for block in text.split('\n\n'):
|
||||
asm_text, encoding = None, None
|
||||
for line in block.split('\n'):
|
||||
line = line.strip()
|
||||
if not line or line.startswith(('.', ';')): continue
|
||||
if not line.startswith('//'):
|
||||
asm_text = line.split('//')[0].strip() or asm_text
|
||||
if m := re.search(pattern + r'[^:]*:.*?(?:encoding:\s*)?\[(0x[0-9a-f,x\s]+)\]', line, re.I):
|
||||
encoding = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
if asm_text and encoding:
|
||||
try: tests.append((asm_text, bytes.fromhex(encoding)))
|
||||
except ValueError: pass
|
||||
return tests
|
||||
|
||||
def try_assemble(text: str):
  """Try to assemble instruction text, return bytes or None on failure.

  Any ordinary exception raised while assembling or encoding is treated as
  "could not assemble"; KeyboardInterrupt and SystemExit are allowed to propagate.
  """
  try: return asm(text).to_bytes()
  except Exception: return None
|
||||
@functools.cache
def _get_tests(f: str, arch: str) -> list[tuple[str, bytes]]:
  """Fetch LLVM test file ``f`` and return its (asm, bytes) pairs for ``arch``.

  The check-prefix pattern depends on the architecture and, for anything other
  than "rdna3", on whether the file name mentions gfx90a/gfx942. For "cdna" the
  pairs whose bytes satisfy _is_mimg are removed. Results are memoised per
  (f, arch) by functools.cache.
  """
  source = fetch(f"{LLVM_BASE}/{f}").read_bytes().decode('utf-8', errors='ignore')
  if arch == "rdna3": prefix_pattern = r'(?:GFX11|W32|W64)'
  elif 'gfx90a' in f or 'gfx942' in f: prefix_pattern = r'(?:GFX90A|GFX942)'
  else: prefix_pattern = r'(?:VI9|GFX9|CHECK)'
  parsed = _parse_llvm_tests(source, prefix_pattern)
  if arch != "cdna": return parsed
  return [(asm_text, data) for asm_text, data in parsed if not _is_mimg(data)]
|
||||
|
||||
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
  """Compile multiple instructions with a single llvm-mc call."""
  return _compile_asm_batch(instrs)

def _compile_asm_batch(instrs: list[str]) -> list[bytes]:
  """Assemble every entry of `instrs` with one llvm-mc invocation and return one encoding per entry.

  Raises:
    RuntimeError: if llvm-mc exits non-zero, or if the number of parsed encodings differs from
      the number of instructions (otherwise callers that zip the result would silently misalign).
  """
  if not instrs: return []
  result = subprocess.run(
    [get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
    input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True, timeout=30)
  if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
  # every output line containing "encoding: [..]" contributes one byte string, in order
  results = []
  for line in result.stdout.split('\n'):
    if 'encoding:' not in line: continue
    enc = line.split('encoding:')[1].strip()
    if enc.startswith('[') and enc.endswith(']'):
      results.append(bytes.fromhex(enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')))
  if len(results) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(results)}")
  return results
||||
|
||||
class TestLLVM(unittest.TestCase):
|
||||
"""Test assembler and disassembler against all LLVM test vectors."""
|
||||
tests: dict[str, list[tuple[str, bytes]]] = {}
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
for name, (filename, _, _) in LLVM_TEST_FILES.items():
|
||||
try:
|
||||
data = fetch(f"{LLVM_BASE}/{filename}").read_bytes()
|
||||
cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'))
|
||||
except Exception as e:
|
||||
print(f"Warning: couldn't fetch {filename}: {e}")
|
||||
cls.tests[name] = []
|
||||
|
||||
# Generate test methods dynamically for each format
|
||||
def _make_asm_test(name):
|
||||
def _make_test(f: str, arch: str, test_type: str):
|
||||
def test(self):
|
||||
passed, failed, skipped = 0, 0, 0
|
||||
for asm_text, expected in self.tests.get(name, []):
|
||||
result = try_assemble(asm_text)
|
||||
if result is None: skipped += 1
|
||||
elif result == expected: passed += 1
|
||||
else: failed += 1
|
||||
print(f"{name.upper()} asm: {passed} passed, {failed} failed, {skipped} skipped")
|
||||
self.assertEqual(failed, 0)
|
||||
tests = _get_tests(f, arch)
|
||||
name = f"{arch}_{test_type}_{f}"
|
||||
if test_type == "roundtrip":
|
||||
for _, data in tests:
|
||||
decoded = detect_format(data, arch).from_bytes(data)
|
||||
self.assertEqual(decoded.to_bytes()[:len(data)], data)
|
||||
print(f"{name}: {len(tests)} passed")
|
||||
elif test_type == "asm":
|
||||
passed, skipped = 0, 0
|
||||
for asm_text, expected in tests:
|
||||
try:
|
||||
self.assertEqual(asm(asm_text).to_bytes(), expected)
|
||||
passed += 1
|
||||
except: skipped += 1
|
||||
print(f"{name}: {passed} passed, {skipped} skipped")
|
||||
elif test_type == "disasm":
|
||||
to_test = []
|
||||
for _, data in tests:
|
||||
try:
|
||||
decoded = detect_format(data, arch).from_bytes(data)
|
||||
if decoded.to_bytes()[:len(data)] == data and (d := disasm(decoded)): to_test.append((data, d))
|
||||
except: pass
|
||||
print(f"{name}: {len(to_test)} passed, {len(tests) - len(to_test)} skipped")
|
||||
if arch == "rdna3":
|
||||
for (data, _), llvm in zip(to_test, _compile_asm_batch([t[1] for t in to_test])): self.assertEqual(llvm, data)
|
||||
return test
|
||||
|
||||
def _make_disasm_test(name):
|
||||
def test(self):
|
||||
_, base_fmt_cls, base_op_enum = LLVM_TEST_FILES[name]
|
||||
# VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions)
|
||||
vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
|
||||
is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx')
|
||||
class TestLLVM(unittest.TestCase): pass
|
||||
|
||||
# First pass: decode all instructions and collect disasm strings
|
||||
to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error)
|
||||
for asm_text, data in self.tests.get(name, []):
|
||||
# Detect VOP3 promotions in VOP1/VOP2/VOPC tests: VOP3 has bits [31:26]=0b110101 in first dword
|
||||
is_vop3_enc = name in ('vop1', 'vop2', 'vopc', 'vopcx') and len(data) >= 4 and (data[3] >> 2) == 0x35
|
||||
fmt_cls, op_enum = (VOP3, VOP3Op) if is_vop3_enc else (base_fmt_cls, base_op_enum)
|
||||
try:
|
||||
if base_fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
|
||||
temp = VOP3.from_bytes(data)
|
||||
op_val = temp._values.get('op', 0)
|
||||
op_val = op_val.val if hasattr(op_val, 'val') else op_val
|
||||
is_vop3sd = (op_val in vop3sd_opcodes) and not is_vopc_promotion
|
||||
decoded = VOP3SD.from_bytes(data) if is_vop3sd else VOP3.from_bytes(data)
|
||||
if is_vop3sd: VOP3SDOp(op_val)
|
||||
else: VOP3Op(op_val)
|
||||
else:
|
||||
decoded = fmt_cls.from_bytes(data)
|
||||
op_val = decoded._values.get('op', 0)
|
||||
op_val = op_val.val if hasattr(op_val, 'val') else op_val
|
||||
op_enum(op_val)
|
||||
if decoded.to_bytes()[:len(data)] != data:
|
||||
to_test.append((asm_text, data, None, "decode roundtrip failed"))
|
||||
continue
|
||||
to_test.append((asm_text, data, decoded.disasm(), None))
|
||||
except Exception as e:
|
||||
to_test.append((asm_text, data, None, f"exception: {e}"))
|
||||
|
||||
# Batch compile all disasm strings with single llvm-mc call
|
||||
disasm_strs = [(i, t[2]) for i, t in enumerate(to_test) if t[2] is not None]
|
||||
llvm_results = compile_asm_batch([s for _, s in disasm_strs]) if disasm_strs else []
|
||||
llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)}
|
||||
|
||||
# Match results back
|
||||
passed, failed = 0, 0
|
||||
failures: list[str] = []
|
||||
for idx, (asm_text, data, disasm_str, error) in enumerate(to_test):
|
||||
if error:
|
||||
failed += 1; failures.append(f"{error} for {data.hex()}")
|
||||
elif disasm_str is not None and idx in llvm_map:
|
||||
llvm_bytes = llvm_map[idx]
|
||||
if llvm_bytes is not None and llvm_bytes == data: passed += 1
|
||||
elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
|
||||
|
||||
print(f"{name.upper()} disasm: {passed} passed, {failed} failed")
|
||||
if failures[:10]: print(" " + "\n ".join(failures[:10]))
|
||||
self.assertEqual(failed, 0)
|
||||
return test
|
||||
|
||||
for name in LLVM_TEST_FILES:
|
||||
setattr(TestLLVM, f'test_{name}_asm', _make_asm_test(name))
|
||||
setattr(TestLLVM, f'test_{name}_disasm', _make_disasm_test(name))
|
||||
for f in RDNA_FILES:
|
||||
setattr(TestLLVM, f"test_rdna3_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "roundtrip"))
|
||||
setattr(TestLLVM, f"test_rdna3_asm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "asm"))
|
||||
setattr(TestLLVM, f"test_rdna3_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "disasm"))
|
||||
for f in CDNA_FILES:
|
||||
setattr(TestLLVM, f"test_cdna_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "cdna", "roundtrip"))
|
||||
setattr(TestLLVM, f"test_cdna_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "cdna", "disasm"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,144 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test CDNA assembler/disassembler against LLVM test vectors."""
|
||||
import unittest, re, subprocess
|
||||
from tinygrad.helpers import fetch
|
||||
from extra.assembly.amd.autogen.cdna.ins import *
|
||||
from extra.assembly.amd.asm import disasm
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc
|
||||
|
||||
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
|
||||
|
||||
def parse_llvm_tests(text: str, mnemonic_filter: str = None, size_filter: int = None) -> list[tuple[str, bytes]]:
|
||||
"""Parse LLVM test format into (asm, expected_bytes) pairs."""
|
||||
tests, lines = [], text.split('\n')
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
if not line or line.startswith(('//', '.', ';')): continue
|
||||
asm_text = line.split('//')[0].strip()
|
||||
if not asm_text or (mnemonic_filter and not asm_text.startswith(mnemonic_filter)): continue
|
||||
for j in list(range(max(0, i - 3), i)) + list(range(i, min(i + 3, len(lines)))):
|
||||
if m := re.search(r'(?:VI9|GFX9|CHECK)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
elif m := re.search(r'CHECK[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
else: continue
|
||||
try:
|
||||
data = bytes.fromhex(hex_bytes)
|
||||
if size_filter is None or len(data) == size_filter: tests.append((asm_text, data))
|
||||
except ValueError: pass
|
||||
break
|
||||
return tests
|
||||
|
||||
# Use gfx9 tests for compatible scalar/vector formats and gfx90a/gfx942 tests for CDNA-specific instructions
# Format: (filename, format_class, op_enum, mcpu, mnemonic_filter, size_filter)
CDNA_TEST_FILES = {
  # Scalar ALU - encoding is stable across GFX9/CDNA
  'sop1': ('gfx9_asm_sop1.s', SOP1, SOP1Op, 'gfx940', None, None),
  'sop2': ('gfx9_asm_sop2.s', SOP2, SOP2Op, 'gfx940', None, None),
  'sopp': ('gfx9_asm_sopp.s', SOPP, SOPPOp, 'gfx940', None, None),
  'sopp_gfx9': ('sopp-gfx9.s', SOPP, SOPPOp, 'gfx940', None, None),
  'sopk': ('gfx9_asm_sopk.s', SOPK, SOPKOp, 'gfx940', None, None),
  'sopc': ('gfx9_asm_sopc.s', SOPC, SOPCOp, 'gfx940', None, None),
  # Vector ALU - encoding is mostly stable
  'vop1': ('gfx9_asm_vop1.s', VOP1, VOP1Op, 'gfx940', None, None),
  'vop1_gfx9': ('vop1-gfx9.s', VOP1, VOP1Op, 'gfx940', None, None),
  'vop2': ('gfx9_asm_vop2.s', VOP2, VOP2Op, 'gfx940', None, None),
  'vopc': ('gfx9_asm_vopc.s', VOPC, VOPCOp, 'gfx940', None, None),
  'vop3p': ('gfx9_asm_vop3p.s', VOP3P, VOP3POp, 'gfx940', None, None),
  'vop3_gfx9': ('vop3-gfx9.s', VOP3A, VOP3AOp, 'gfx940', None, 8),  # Only 64-bit VOP3 instructions
  # Memory instructions
  'ds': ('gfx9_asm_ds.s', DS, DSOp, 'gfx940', None, None),
  'ds_gfx9': ('ds-gfx9.s', DS, DSOp, 'gfx940', None, None),
  # CDNA memory instructions (gfx90a has correct FLAT/MUBUF encodings with acc registers)
  'flat_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'flat_', None),
  'global_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'global_', None),
  'mubuf_gfx90a': ('gfx90a_ldst_acc.s', MUBUF, MUBUFOp, 'gfx90a', 'buffer_', None),
  'mubuf_gfx9': ('mubuf-gfx9.s', MUBUF, MUBUFOp, 'gfx940', None, None),
  'scratch_gfx942': ('flat-scratch-gfx942.s', FLAT, FLATOp, 'gfx942', 'scratch_', None),
  # CDNA-specific: MFMA/MAI instructions
  'mai': ('mai-gfx942.s', VOP3P, VOP3POp, 'gfx942', None, None),
  # SDWA and DPP format tests for VOP1 (VOP2 has different bit layout, tested separately)
  'sdwa_vop1': ('gfx9_asm_vop1.s', SDWA, VOP1Op, 'gfx940', None, None),
  'dpp_vop1': ('gfx9_asm_vop1.s', DPP, VOP1Op, 'gfx940', None, None),
}
|
||||
|
||||
class TestLLVMCDNA(unittest.TestCase):
  """Test CDNA instruction format decode/encode roundtrip and disassembly."""
  # parsed (asm, bytes) vectors per CDNA_TEST_FILES key; filled once in setUpClass
  tests: dict[str, list[tuple[str, bytes]]] = {}

  @classmethod
  def setUpClass(cls):
    """Fetch and parse every file in CDNA_TEST_FILES; a file that cannot be loaded yields an empty list."""
    for name, (filename, _, _, _, mnemonic_filter, size_filter) in CDNA_TEST_FILES.items():
      try:
        data = fetch(f"{LLVM_BASE}/{filename}").read_bytes()
        cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'), mnemonic_filter, size_filter)
      except Exception as e:
        # deliberate best-effort: a missing file only warns, and its generated tests then see no vectors
        print(f"Warning: couldn't fetch {filename}: {e}")
        cls.tests[name] = []
|
||||
|
||||
def _get_val(v):
  """Return `v.val` when `v` has a `val` attribute, otherwise `v` itself."""
  if hasattr(v, 'val'):
    return v.val
  return v
|
||||
|
||||
def _filter_and_decode(tests, fmt_cls, op_enum):
  """Filter tests and decode instructions, yielding (asm_text, data, decoded, error).

  Vectors that do not fit `fmt_cls` (wrong size, SDWA/DPP mismatch, opcode not in `op_enum`) are
  skipped silently. A vector whose decoding raises is yielded with `decoded=None` and the
  exception text as `error`; otherwise `error` is None.
  """
  fn = fmt_cls.__name__
  is_sdwa, is_dpp = fn == 'SDWA', fn == 'DPP'
  for asm_text, data in tests:
    has_lit = False
    # SDWA/DPP format tests: only accept matching 8-byte instructions
    if is_sdwa:
      if len(data) != 8 or data[0] != 0xf9: continue
    elif is_dpp:
      if len(data) != 8 or data[0] != 0xfa: continue
    elif fmt_cls._size() == 4 and len(data) == 8:
      if data[0] in (0xf9, 0xfa): continue  # Skip SDWA/DPP (tested separately)
      # an 8-byte vector for a 4-byte format is only kept when the checks below flag a trailing literal
      first_dword = int.from_bytes(data[:4], 'little')
      has_lit = data[0] == 255 or (len(data) >= 2 and data[1] == 255 and fn in ('SOP2', 'SOPC'))
      if fn == 'SOPK': has_lit = has_lit or ((first_dword >> 23) & 0x1f) == 20
      if fn == 'VOP2': has_lit = has_lit or ((first_dword >> 25) & 0x3f) in (23, 24, 36, 37)
      if not has_lit: continue
    if len(data) > fmt_cls._size() + (4 if has_lit else 0): continue
    try:
      decoded = fmt_cls.from_bytes(data)
      # For SDWA/DPP, opcode location depends on VOP1 vs VOP2
      if is_sdwa or is_dpp:
        vop2_op = _get_val(decoded._values.get('vop2_op', 0))
        op_val = _get_val(decoded._values.get('vop_op', 0)) if vop2_op == 0x3f else vop2_op
      else:
        op_val = _get_val(decoded._values.get('op', 0))
      # an opcode the enum does not know is skipped rather than reported as a failure
      try: op_enum(op_val)
      except ValueError: continue
      yield asm_text, data, decoded, None
    except Exception as e:
      yield asm_text, data, None, str(e)
|
||||
|
||||
def _make_roundtrip_test(name):
  """Build a test method checking that every vector of CDNA_TEST_FILES[name] decodes and re-encodes to the same bytes."""
  def test(self):
    _, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name]
    passed, failed, failures = 0, 0, []
    for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum):
      # compare against None: an exception with an empty message must still count as a failure
      if error is not None:
        failed += 1
        failures.append(f"'{asm_text}': {error}")
        continue
      reenc = decoded.to_bytes()[:len(data)]
      if reenc == data:
        passed += 1
      else:
        failed += 1
        failures.append(f"'{asm_text}': orig={data.hex()} reenc={reenc.hex()}")
    print(f"CDNA {name.upper()} roundtrip: {passed} passed, {failed} failed")
    # only the first few failures are printed
    if failures[:5]: print(" " + "\n ".join(failures[:5]))
    self.assertEqual(failed, 0)
  return test
|
||||
|
||||
def _make_disasm_test(name):
  """Build a test method checking that every vector of CDNA_TEST_FILES[name] roundtrips and disassembles to non-empty text."""
  def test(self):
    _, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name]
    passed, failed, failures = 0, 0, []
    for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum):
      # compare against None: an exception with an empty message must still count as a failure
      if error is not None:
        failed += 1
        failures.append(f"'{asm_text}': {error}")
        continue
      if decoded.to_bytes()[:len(data)] != data:
        failed += 1
        failures.append(f"'{asm_text}': roundtrip failed")
        continue
      if not (disasm_text := disasm(decoded)) or not disasm_text.strip():
        failed += 1
        failures.append(f"'{asm_text}': empty disassembly")
        continue
      passed += 1
    print(f"CDNA {name.upper()} disasm: {passed} passed, {failed} failed")
    # only the first few failures are printed
    if failures[:5]: print(" " + "\n ".join(failures[:5]))
    self.assertEqual(failed, 0)
  return test
|
||||
|
||||
# attach one roundtrip and one disasm test method per CDNA_TEST_FILES entry
for name in CDNA_TEST_FILES:
  setattr(TestLLVMCDNA, f'test_{name}_roundtrip', _make_roundtrip_test(name))
  setattr(TestLLVMCDNA, f'test_{name}_disasm', _make_disasm_test(name))

if __name__ == "__main__":
  unittest.main()
|
||||
50
extra/assembly/amd/test/test_pdf.py
Normal file
50
extra/assembly/amd/test/test_pdf.py
Normal file
@@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test pdf.py PDF parser and enum generation."""
|
||||
import unittest, tempfile, importlib.util
|
||||
from extra.assembly.amd.pdf import extract, extract_tables, extract_enums, write_enums, PDF_URLS
|
||||
|
||||
# expected extraction results per architecture: page count, table count, SOP2 opcode count, name of SOP2 opcode 0
EXPECTED = {
  "rdna3": {"pages": 655, "tables": 115, "sop2_ops": 67, "sop2_first": "S_ADD_U32"},
  "rdna4": {"pages": 711, "tables": 125, "sop2_ops": 74, "sop2_first": "S_ADD_CO_U32"},
  "cdna": {"pages": 610, "tables": 104, "sop2_ops": 52, "sop2_first": "S_ADD_U32"},
}
|
||||
|
||||
class TestPDF2(unittest.TestCase):
  """Checks the PDF extraction pipeline (pages -> tables -> enums) against the EXPECTED figures."""

  @classmethod
  def setUpClass(cls):
    """Run extraction once for every entry of PDF_URLS and share the results across tests."""
    cls.data = {name: extract(url) for name, url in PDF_URLS.items()}
    cls.tables = {name: extract_tables(pages) for name, pages in cls.data.items()}
    cls.enums = {name: extract_enums(cls.tables[name]) for name in PDF_URLS}

  def test_page_counts(self):
    for name, exp in EXPECTED.items():
      self.assertEqual(len(self.data[name]), exp["pages"], f"{name} page count")

  def test_table_counts(self):
    for name, exp in EXPECTED.items():
      self.assertEqual(len(self.tables[name]), exp["tables"], f"{name} table count")

  def test_tables_sequential(self):
    for name in PDF_URLS:
      nums = sorted(self.tables[name].keys())
      # fail with a clear message instead of a ValueError from max() when nothing was extracted
      self.assertTrue(nums, f"{name} has no tables")
      missing = set(range(1, max(nums) + 1)) - set(nums)
      self.assertEqual(missing, set(), f"{name} missing tables: {missing}")

  def test_generate_enums(self):
    for name, exp in EXPECTED.items():
      # a temporary directory is removed on exit, so no generated file is left behind
      with tempfile.TemporaryDirectory() as tmp:
        path = f"{tmp}/{name}_enums.py"
        write_enums(self.enums[name], name, path)
        spec = importlib.util.spec_from_file_location("enum", path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
      # Check SOP2Op
      self.assertTrue(hasattr(mod, 'SOP2Op'), f"{name} missing SOP2Op")
      self.assertEqual(len(mod.SOP2Op), exp["sop2_ops"], f"{name} SOP2Op count")
      self.assertEqual(mod.SOP2Op(0).name, exp["sop2_first"], f"{name} SOP2Op first")
      # Check all enums have at least 2 ops
      for attr in dir(mod):
        if attr.endswith('Op'):
          self.assertGreaterEqual(len(getattr(mod, attr)), 2, f"{name} {attr} has too few ops")

if __name__ == "__main__":
  unittest.main()
|
||||
@@ -1,150 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test that PDF parser correctly extracts format fields."""
|
||||
import unittest, os
|
||||
from extra.assembly.amd.autogen.rdna3.ins import SOP1, SOP2, SOPK, SOPP, VOP1, VOP2, VOP3SD, VOPC, FLAT, VOPD, SOP1Op, SOP2Op, VOP1Op, VOP3Op
|
||||
|
||||
# expected formats with key fields and whether they have ENCODING
EXPECTED_FORMATS = {
  'DPP16': (['SRC0', 'DPP_CTRL', 'BANK_MASK', 'ROW_MASK'], False),
  'DPP8': (['SRC0', 'LANE_SEL0', 'LANE_SEL7'], False),
  'DS': (['OP', 'ADDR', 'DATA0', 'DATA1', 'VDST'], True),
  'EXP': (['EN', 'TARGET', 'VSRC0', 'VSRC1', 'VSRC2', 'VSRC3'], True),
  'FLAT': (['OP', 'ADDR', 'DATA', 'SADDR', 'VDST', 'OFFSET'], True),
  'LDSDIR': (['VDST', 'OP'], True),
  'MIMG': (['OP', 'VADDR', 'VDATA', 'SRSRC', 'DMASK'], True),
  'MTBUF': (['OP', 'VADDR', 'VDATA', 'SRSRC', 'FORMAT', 'SOFFSET'], True),
  'MUBUF': (['OP', 'VADDR', 'VDATA', 'SRSRC', 'SOFFSET'], True),
  'SMEM': (['OP', 'SBASE', 'SDATA', 'OFFSET', 'SOFFSET'], True),
  'SOP1': (['OP', 'SDST', 'SSRC0'], True),
  'SOP2': (['OP', 'SDST', 'SSRC0', 'SSRC1'], True),
  'SOPC': (['OP', 'SSRC0', 'SSRC1'], True),
  'SOPK': (['OP', 'SDST', 'SIMM16'], True),
  'SOPP': (['OP', 'SIMM16'], True),
  'VINTERP': (['OP', 'VDST', 'SRC0', 'SRC1', 'SRC2'], True),
  'VOP1': (['OP', 'VDST', 'SRC0'], True),
  'VOP2': (['OP', 'VDST', 'SRC0', 'VSRC1'], True),
  'VOP3': (['OP', 'VDST', 'SRC0', 'SRC1', 'SRC2'], True),
  'VOP3P': (['OP', 'VDST', 'SRC0', 'SRC1', 'SRC2'], True),
  'VOP3SD': (['OP', 'VDST', 'SDST', 'SRC0', 'SRC1', 'SRC2'], True),
  'VOPC': (['OP', 'SRC0', 'VSRC1'], True),
  'VOPD': (['OPX', 'OPY', 'SRCX0', 'SRCY0', 'VDSTX', 'VDSTY'], True),
}
|
||||
|
||||
# Skip PDF parsing tests by default - only run with TEST_PDF_PARSER=1
# These are slow (~5s) and only needed when regenerating autogen/
@unittest.skipUnless(os.environ.get("TEST_PDF_PARSER"), "set TEST_PDF_PARSER=1 to run PDF parser tests")
class TestPDFParserGenerate(unittest.TestCase):
  """Test the PDF parser by running generate() and checking results."""

  def test_pdf_parser(self):
    """Single test that validates all PDF parser outputs."""
    # imported lazily so that merely loading this module does not pull in the generator
    from extra.assembly.amd.dsl import generate
    result = generate()

    # test_all_formats_present
    for fmt_name in EXPECTED_FORMATS:
      self.assertIn(fmt_name, result["formats"], f"missing format {fmt_name}")

    # test_format_count
    self.assertEqual(len(result["formats"]), 23)

    # test_no_duplicate_fields
    for fmt_name, fields in result["formats"].items():
      field_names = [f[0] for f in fields]
      self.assertEqual(len(field_names), len(set(field_names)), f"{fmt_name} has duplicate fields: {field_names}")

    # test_expected_fields
    for fmt_name, (expected_fields, has_encoding) in EXPECTED_FORMATS.items():
      fields = {f[0] for f in result["formats"].get(fmt_name, [])}
      for field in expected_fields:
        self.assertIn(field, fields, f"{fmt_name} missing {field}")
      if has_encoding:
        self.assertIn("ENCODING", fields, f"{fmt_name} should have ENCODING")
      else:
        self.assertNotIn("ENCODING", fields, f"{fmt_name} should not have ENCODING")

    # test_vopd_no_dpp16_fields
    vopd_fields = {f[0] for f in result["formats"].get("VOPD", [])}
    for field in ['DPP_CTRL', 'BANK_MASK', 'ROW_MASK']:
      self.assertNotIn(field, vopd_fields, f"VOPD should not have {field}")

    # test_dpp16_no_vinterp_fields
    dpp16_fields = {f[0] for f in result["formats"].get("DPP16", [])}
    for field in ['VDST', 'WAITEXP']:
      self.assertNotIn(field, dpp16_fields, f"DPP16 should not have {field}")

    # test_sopp_no_smem_fields
    sopp_fields = {f[0] for f in result["formats"].get("SOPP", [])}
    for field in ['SBASE', 'SDATA']:
      self.assertNotIn(field, sopp_fields, f"SOPP should not have {field}")
|
||||
|
||||
class TestPDFParser(unittest.TestCase):
  """Verify format classes have correct fields from PDF parsing."""

  def test_sop2_fields(self):
    """SOP2 should have op, sdst, ssrc0, ssrc1."""
    for field in ['op', 'sdst', 'ssrc0', 'ssrc1']:
      self.assertIn(field, SOP2._fields)
    self.assertEqual(SOP2._fields['op'].hi, 29)
    self.assertEqual(SOP2._fields['op'].lo, 23)

  def test_sop1_fields(self):
    """SOP1 should have op, sdst, ssrc0 with correct bit positions."""
    for field in ['op', 'sdst', 'ssrc0']:
      self.assertIn(field, SOP1._fields)
    self.assertNotIn('simm16', SOP1._fields)
    self.assertEqual(SOP1._fields['ssrc0'].hi, 7)
    self.assertEqual(SOP1._fields['ssrc0'].lo, 0)
    # unittest assertion instead of a bare assert, which is stripped under `python -O`
    self.assertIsNotNone(SOP1._encoding)
    self.assertEqual(SOP1._encoding[0].hi, 31)
    self.assertEqual(SOP1._encoding[1], 0b101111101)

  def test_vop3sd_fields(self):
    """VOP3SD should have all fields including src0/src1/src2 from page continuation."""
    for field in ['op', 'vdst', 'sdst', 'src0', 'src1', 'src2']:
      self.assertIn(field, VOP3SD._fields)
    self.assertEqual(VOP3SD._fields['src0'].hi, 40)
    self.assertEqual(VOP3SD._fields['src0'].lo, 32)
    self.assertEqual(VOP3SD._size(), 8)

  def test_flat_has_vdst(self):
    """FLAT should have vdst field."""
    self.assertIn('vdst', FLAT._fields)
    self.assertEqual(FLAT._fields['vdst'].hi, 63)
    self.assertEqual(FLAT._fields['vdst'].lo, 56)

  def test_encoding_bits(self):
    """Verify encoding bits are correct for major formats."""
    # (format class, encoding field hi bit, lo bit, expected value)
    tests = [
      (SOP2, 31, 30, 0b10),
      (SOPK, 31, 28, 0b1011),
      (SOPP, 31, 23, 0b101111111),
      (VOP1, 31, 25, 0b0111111),
      (VOP2, 31, 31, 0b0),
      (VOPC, 31, 25, 0b0111110),
      (FLAT, 31, 26, 0b110111),
    ]
    for cls, hi, lo, val in tests:
      self.assertIsNotNone(cls._encoding, f"{cls.__name__} encoding missing")
      self.assertEqual(cls._encoding[0].hi, hi, f"{cls.__name__} encoding hi")
      self.assertEqual(cls._encoding[0].lo, lo, f"{cls.__name__} encoding lo")
      self.assertEqual(cls._encoding[1], val, f"{cls.__name__} encoding val")

  def test_opcode_enums_exist(self):
    """Verify opcode enums are generated with expected counts."""
    self.assertGreater(len(SOP1Op), 50)
    self.assertGreater(len(SOP2Op), 50)
    self.assertGreater(len(VOP1Op), 50)
    self.assertGreater(len(VOP3Op), 200)

  def test_vopd_no_duplicate_fields(self):
    """VOPD should not have duplicate fields and should not include DPP16 fields."""
    field_names = list(VOPD._fields.keys())
    self.assertEqual(len(field_names), len(set(field_names)))
    for field in ['srcx0', 'srcy0', 'opx', 'opy']:
      self.assertIn(field, VOPD._fields)
    for field in ['dpp_ctrl', 'bank_mask', 'row_mask']:
      self.assertNotIn(field, VOPD._fields)

if __name__ == "__main__":
  unittest.main()
|
||||
38
extra/viz/cli.py
Executable file
38
extra/viz/cli.py
Executable file
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse, pathlib
|
||||
from typing import Iterator
|
||||
from tinygrad.viz import serve as viz
|
||||
from tinygrad.uop.ops import RewriteTrace
|
||||
from tinygrad.helpers import temp, ansistrip, colored
|
||||
|
||||
def optional_eq(val:dict, arg:str|None) -> bool: return arg is None or ansistrip(val["name"]) == arg
|
||||
|
||||
def print_data(data:dict) -> None:
  """Print a rendered item: its rewrite matches (when "value" is an iterator) and its "src" text (when present).

  For each match with a non-empty "diff", prints the pattern location as `<dir>/<file>:<line>`,
  the pattern text, then the diff lines ('-' lines in red, '+' lines in green).
  """
  if isinstance(data.get("value"), Iterator):
    for m in data["value"]:
      if not m["diff"]: continue  # matches without a diff are not shown
      fp = pathlib.Path(m["upat"][0][0])
      print(f"{fp.parent.name}/{fp.name}:{m['upat'][0][1]}")
      print(m["upat"][1])
      for line in m["diff"]:
        color = "red" if line.startswith("-") else "green" if line.startswith("+") else None
        print(colored(line, color))
  if data.get("src") is not None: print(data["src"])
|
||||
|
||||
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument('--kernel', type=str, default=None, metavar="NAME", help='Select a kernel by name (optional name, default: only list names)')
  parser.add_argument('--select', type=str, default=None, metavar="NAME",
                      help='Select an item within the chosen kernel (optional name, default: only list names)')
  args = parser.parse_args()

  # load the saved rewrite trace (an empty trace if the pickle is absent) and build the kernel list from it
  viz.trace = viz.load_pickle(pathlib.Path(temp("rewrites.pkl", append_user=True)), default=RewriteTrace([], [], {}))
  viz.ctxs = viz.get_rewrites(viz.trace)
  for k in viz.ctxs:
    if not optional_eq(k, args.kernel): continue
    print(k["name"])
    # without --kernel only the kernel names are listed
    if args.kernel is None: continue
    for s in k["steps"]:
      if not optional_eq(s, args.select): continue
      print(" "*s["depth"]+s['name']+(f" - {s['match_count']}" if s.get('match_count') is not None else ''))
      # step details are rendered only for an explicit --select
      if args.select is not None: print_data(viz.get_render(s['query']))
|
||||
26
test/external/external_test_onnx_backend.py
vendored
26
test/external/external_test_onnx_backend.py
vendored
@@ -64,8 +64,6 @@ backend_test.exclude('test_qlinearmatmul_2D_int8_float32_cpu')
|
||||
backend_test.exclude('test_qlinearmatmul_3D_int8_float32_cpu')
|
||||
# tested in external_test_onnx_ops.py::TestMainOnnxOps.test_maxunpool_export_with_output_shape
|
||||
backend_test.exclude('test_maxunpool_export_with_output_shape_cpu')
|
||||
# tested in external_test_onnx_ops.py::TestMainOnnxOps.test_averagepool_3d_dilations_large_count_include_pad_is_1_ceil_mode_is_True
|
||||
backend_test.exclude('test_averagepool_3d_dilations_large_count_include_pad_is_1_ceil_mode_is_True_cpu')
|
||||
# tested in external_test_onnx_ops.py::TestMainOnnxOps.test_resize_downsample_scales_linear_align_corners
|
||||
backend_test.exclude('test_resize_downsample_scales_linear_align_corners_cpu')
|
||||
# tested in external_test_onnx_ops.py::TestMainOnnxOps.test_resize_downsample_scales_cubic_align_corners
|
||||
@@ -151,8 +149,6 @@ backend_test.exclude('test_hannwindow_*')
|
||||
backend_test.exclude('test_hardmax_*')
|
||||
backend_test.exclude('test_gridsample_*')
|
||||
backend_test.exclude('test_dft_*')
|
||||
backend_test.exclude('test_einsum_batch_diagonal_cpu*') # TODO: equation = '...ii ->...i'
|
||||
backend_test.exclude('test_einsum_inner_prod_cpu*') # TODO: equation = 'i,i'
|
||||
backend_test.exclude('test_unique_*')
|
||||
backend_test.exclude('test_sequence_*')
|
||||
backend_test.exclude('test_nonmaxsuppression_*')
|
||||
@@ -171,16 +167,14 @@ backend_test.exclude('test_split_to_sequence_*')
|
||||
backend_test.exclude('test_ai_onnx_ml_tree_ensemble_*') # https://github.com/onnx/onnx/blob/main/onnx/reference/ops/aionnxml/op_tree_ensemble.py#L121
|
||||
|
||||
# TODO: not yet implemented
|
||||
backend_test.exclude('test_tensorscatter_*')
|
||||
backend_test.exclude('test_l1normalization_*')
|
||||
backend_test.exclude('test_l2normalization_*')
|
||||
backend_test.exclude('test_lpnormalization_*')
|
||||
backend_test.exclude('test_einsum_scalar_cpu')
|
||||
backend_test.exclude('test_mod_mixed_sign_float16_cpu')
|
||||
backend_test.exclude('test_qlinearmatmul_2D_uint8_float16_cpu')
|
||||
backend_test.exclude('test_qlinearmatmul_3D_uint8_float16_cpu')
|
||||
backend_test.exclude('test_attention_3d_*')
|
||||
backend_test.exclude('test_attention_4d_*')
|
||||
backend_test.exclude('test_attention_4d_diff_heads_mask4d_padded_kv_cpu') # needs nonpad_kv_seqlen handling
|
||||
backend_test.exclude('test_attention_4d_fp16_cpu') # fp16 numerical issues
|
||||
backend_test.exclude('test_attention_4d_fp16_expanded_cpu') # fp16 numerical issues
|
||||
backend_test.exclude('test_attention_4d_gqa_with_past_and_present_fp16_cpu') # fp16 numerical issues
|
||||
backend_test.exclude('test_attention_4d_gqa_with_past_and_present_fp16_expanded_cpu') # fp16 numerical issues
|
||||
|
||||
|
||||
# rest of the failing tests
|
||||
@@ -197,16 +191,6 @@ backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad d
|
||||
backend_test.exclude('test_if_opt_cpu') # ValueError: 13 is not a valid AttributeType
|
||||
backend_test.exclude('test_if_seq_cpu') # NotImplementedError: op='SequenceConstruct' is not supported
|
||||
|
||||
# regression from removing StrEnum in Domain
|
||||
backend_test.exclude('test_adam_cpu')
|
||||
backend_test.exclude('test_gradient_of_add_and_mul_cpu')
|
||||
backend_test.exclude('test_gradient_of_add_cpu')
|
||||
|
||||
if Device.DEFAULT in ['CL', 'METAL']:
|
||||
backend_test.exclude('test_resize_upsample_sizes_nearest_axes_2_3_cpu')
|
||||
backend_test.exclude('test_resize_upsample_sizes_nearest_axes_3_2_cpu')
|
||||
backend_test.exclude('test_resize_upsample_sizes_nearest_cpu')
|
||||
|
||||
if Device.DEFAULT == "METAL" or (OSX and Device.DEFAULT == "CL"):
|
||||
# numerical inaccuracy
|
||||
backend_test.exclude('test_mish_cpu')
|
||||
|
||||
@@ -1171,6 +1171,8 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
@slow_test
|
||||
def test_einsum(self):
|
||||
# scalar
|
||||
helper_test_op([()], lambda a: torch.einsum('->', a), lambda a: Tensor.einsum('->', a))
|
||||
# matrix transpose
|
||||
helper_test_op([(10,10)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a))
|
||||
helper_test_op([(10,10)], lambda a: torch.einsum('ij -> ji', a), lambda a: Tensor.einsum('ij -> ji', a))
|
||||
@@ -1239,6 +1241,18 @@ class TestOps(unittest.TestCase):
|
||||
self.helper_test_exception([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
|
||||
lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]), expected=RuntimeError)
|
||||
|
||||
def test_einsum_trace(self):
|
||||
# inner product
|
||||
helper_test_op([(5,), (5,)], lambda a, b: torch.einsum('i,i', a, b), lambda a, b: Tensor.einsum('i,i', a, b))
|
||||
# simple diagonal
|
||||
helper_test_op([(4, 4)], lambda a: torch.einsum('ii->i', a), lambda a: Tensor.einsum('ii->i', a))
|
||||
# trace (sum of diagonal)
|
||||
helper_test_op([(4, 4)], lambda a: torch.einsum('ii->', a), lambda a: Tensor.einsum('ii->', a))
|
||||
# batch diagonal
|
||||
helper_test_op([(3, 5, 5)], lambda a: torch.einsum('...ii->...i', a), lambda a: Tensor.einsum('...ii->...i', a))
|
||||
# batch trace
|
||||
helper_test_op([(3, 5, 5)], lambda a: torch.einsum('...ii->...', a), lambda a: Tensor.einsum('...ii->...', a))
|
||||
|
||||
def test_einsum_shape_check(self):
|
||||
self.helper_test_exception([(3,8,10,5), (11,5,13,16,8)], lambda a, b: torch.einsum('pqrs,tuqvr->pstuv', [a, b]),
|
||||
lambda a, b: Tensor.einsum('pqrs,tuqvr->pstuv', [a, b]), expected=RuntimeError)
|
||||
|
||||
@@ -807,7 +807,7 @@ class TestTK(unittest.TestCase):
|
||||
|
||||
Tensor.manual_seed(42)
|
||||
|
||||
B, N, H, H_KV, D = 1, 32, 2, 1, 32
|
||||
B, N, H, H_KV, D = 1, 1024, 32, 32, 128
|
||||
|
||||
with Context(DEBUG=0):
|
||||
q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
|
||||
@@ -840,5 +840,57 @@ class TestTK(unittest.TestCase):
|
||||
np.testing.assert_allclose(v.grad.numpy(), v_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
|
||||
np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_fast_fa_bwd_causal_jitted(self):
|
||||
from extra.thunder.tiny.fa import flash_attention
|
||||
|
||||
Tensor.manual_seed(42)
|
||||
|
||||
B, N, H, H_KV, D = 1, 1024, 32, 32, 128
|
||||
|
||||
with Context(DEBUG=0):
|
||||
q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
|
||||
k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
|
||||
v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
|
||||
Tensor.realize(q, k, v)
|
||||
|
||||
do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
|
||||
Tensor.realize(do)
|
||||
|
||||
def fn(q, k, v, do):
|
||||
q_, k_, v_ = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
|
||||
out = flash_attention(q_, k_, v_, is_causal=True)
|
||||
out = out.float().transpose(1, 2)
|
||||
out.backward(do)
|
||||
Tensor.realize(out, q.grad, k.grad, v.grad)
|
||||
return q.grad, k.grad, v.grad
|
||||
|
||||
fn_jitted = TinyJit(fn)
|
||||
|
||||
for _ in range(10):
|
||||
q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
|
||||
k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
|
||||
v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
|
||||
Tensor.realize(q, k, v)
|
||||
do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
|
||||
Tensor.realize(do)
|
||||
q.grad, k.grad, v.grad = fn_jitted(q, k, v, do)
|
||||
|
||||
with Context(DEBUG=0):
|
||||
q_ref = q.detach().clone().requires_grad_(True)
|
||||
k_ref = k.detach().clone().requires_grad_(True)
|
||||
v_ref = v.detach().clone().requires_grad_(True)
|
||||
Tensor.realize(q_ref, k_ref, v_ref)
|
||||
|
||||
q_ref_, k_ref_, v_ref_ = q_ref.transpose(1, 2), k_ref.transpose(1, 2), v_ref.transpose(1, 2)
|
||||
ref = q_ref_.scaled_dot_product_attention(k_ref_, v_ref_, is_causal=True)
|
||||
ref = ref.float().transpose(1, 2)
|
||||
ref.backward(do)
|
||||
Tensor.realize(q_ref.grad, k_ref.grad, v_ref.grad)
|
||||
|
||||
np.testing.assert_allclose(q.grad.numpy(), q_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
|
||||
np.testing.assert_allclose(v.grad.numpy(), v_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
|
||||
np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import ctypes, subprocess, tempfile, unittest
|
||||
from tinygrad.helpers import WIN
|
||||
from tinygrad.runtime.support.c import Struct
|
||||
from tinygrad.runtime.support.autogen import gen
|
||||
|
||||
class TestAutogen(unittest.TestCase):
|
||||
def test_packed_struct_sizeof(self):
|
||||
@@ -159,4 +160,97 @@ class TestAutogen(unittest.TestCase):
|
||||
assert ihdr.num_dies == 1
|
||||
assert ihdr.base_addr_64_bit == 1
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_gen_from_header(self):
|
||||
header_content = """
|
||||
typedef struct {
|
||||
int x;
|
||||
int y;
|
||||
} Point;
|
||||
|
||||
typedef enum {
|
||||
RED = 0,
|
||||
GREEN = 1,
|
||||
BLUE = 2
|
||||
} Color;
|
||||
|
||||
typedef struct {
|
||||
Point origin;
|
||||
int width;
|
||||
int height;
|
||||
Color color;
|
||||
} Rectangle;
|
||||
|
||||
int add_points(Point a, Point b);
|
||||
"""
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.h') as f:
|
||||
f.write(header_content)
|
||||
f.flush()
|
||||
|
||||
generated_code = gen(name="test_header", dll=None, files=[f.name])
|
||||
|
||||
namespace = {}
|
||||
exec(generated_code, namespace)
|
||||
|
||||
self.assertIn('Point', namespace)
|
||||
self.assertIn('Color', namespace)
|
||||
self.assertIn('Rectangle', namespace)
|
||||
self.assertIn('RED', namespace)
|
||||
self.assertIn('GREEN', namespace)
|
||||
self.assertIn('BLUE', namespace)
|
||||
|
||||
self.assertEqual(namespace['RED'], 0)
|
||||
self.assertEqual(namespace['GREEN'], 1)
|
||||
self.assertEqual(namespace['BLUE'], 2)
|
||||
|
||||
Point = namespace['Point']
|
||||
p = Point()
|
||||
self.assertIsInstance(p, Struct)
|
||||
self.assertTrue(hasattr(p, 'x'))
|
||||
self.assertTrue(hasattr(p, 'y'))
|
||||
|
||||
Rectangle = namespace['Rectangle']
|
||||
rect = Rectangle()
|
||||
self.assertTrue(hasattr(rect, 'origin'))
|
||||
self.assertTrue(hasattr(rect, 'width'))
|
||||
self.assertTrue(hasattr(rect, 'height'))
|
||||
self.assertTrue(hasattr(rect, 'color'))
|
||||
|
||||
@unittest.skipIf(WIN, "doesn't compile on windows")
|
||||
def test_struct_ordering(self):
|
||||
header_content = """
|
||||
struct A;
|
||||
struct C;
|
||||
typedef struct A A;
|
||||
|
||||
struct B {
|
||||
struct C *c_ptr;
|
||||
};
|
||||
|
||||
struct C {
|
||||
struct A *a_ptr;
|
||||
};
|
||||
|
||||
struct A {
|
||||
int x;
|
||||
struct B *b_ptr;
|
||||
};
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.h') as f:
|
||||
f.write(header_content)
|
||||
f.flush()
|
||||
generated_code = gen(name="test_ordering", dll=None, files=[f.name])
|
||||
namespace = {}
|
||||
exec(generated_code, namespace)
|
||||
self.assertIn('struct_A', namespace)
|
||||
self.assertIn('struct_B', namespace)
|
||||
self.assertIn('struct_C', namespace)
|
||||
A, B, C = namespace['struct_A'], namespace['struct_B'], namespace['struct_C']
|
||||
a, b, c = A(), B(), C()
|
||||
self.assertTrue(hasattr(a, 'x'))
|
||||
self.assertTrue(hasattr(a, 'b_ptr'))
|
||||
self.assertTrue(hasattr(b, 'c_ptr'))
|
||||
self.assertTrue(hasattr(c, 'a_ptr'))
|
||||
|
||||
if __name__ == "__main__": unittest.main()
|
||||
|
||||
@@ -128,10 +128,6 @@ def unwrap_class_type(cls_t): return cls_t.func if isinstance(cls_t, functools.p
|
||||
|
||||
def pluralize(st:str, cnt:int): return f"{cnt} {st}"+('' if cnt == 1 else 's')
|
||||
|
||||
class LazySeq(Generic[T]): # NOTE: Mapping requires __iter__ and __len__, Sequence requires supporting __len__ and slicing in __getitem__
|
||||
def __init__(self, gen:Callable[[int], T]): self.gen = gen
|
||||
def __getitem__(self, idx:int) -> T: return self.gen(idx)
|
||||
|
||||
# for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
|
||||
def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore
|
||||
|
||||
|
||||
@@ -1048,14 +1048,15 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
return output, present
|
||||
|
||||
def attention_onnx(Q:Tensor, K:Tensor, V:Tensor, attn_mask:Tensor|None=None, past_key:Tensor|None=None, past_value:Tensor|None=None,
|
||||
is_causal:int=0, kv_num_heads:int|None=None, q_num_heads:int|None=None, qk_matmul_output_mode:int=0, scale:float|None=None,
|
||||
softcap:float=0.0, softmax_precision:int|None=None):
|
||||
nonpad_kv_seqlen:Tensor|None=None, is_causal:int=0, kv_num_heads:int|None=None, q_num_heads:int|None=None,
|
||||
qk_matmul_output_mode:int=0, scale:float|None=None, softcap:float=0.0, softmax_precision:int|None=None):
|
||||
if nonpad_kv_seqlen is not None: raise NotImplementedError("nonpad_kv_seqlen is not supported")
|
||||
input_shape_len = Q.ndim
|
||||
if input_shape_len == 3:
|
||||
assert q_num_heads is not None and kv_num_heads is not None
|
||||
Q = Q.reshape(Q.shape[0], q_num_heads, Q.shape[1], -1)
|
||||
K = K.reshape(K.shape[0], kv_num_heads, K.shape[1], -1)
|
||||
V = V.reshape(V.shape[0], kv_num_heads, V.shape[1], -1)
|
||||
Q = Q.reshape(Q.shape[0], Q.shape[1], q_num_heads, -1).permute(0, 2, 1, 3)
|
||||
K = K.reshape(K.shape[0], K.shape[1], kv_num_heads, -1).permute(0, 2, 1, 3)
|
||||
V = V.reshape(V.shape[0], V.shape[1], kv_num_heads, -1).permute(0, 2, 1, 3)
|
||||
|
||||
if past_key is not None: K = past_key.cat(K, dim=2)
|
||||
if past_value is not None: V = past_value.cat(V, dim=2)
|
||||
@@ -1170,6 +1171,17 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
|
||||
elif reduction == "min": x[i] = x[i].minimum(u)
|
||||
return x
|
||||
|
||||
def TensorScatter(data: Tensor, updates: Tensor, indices: Tensor, mode: str = 'default'):
|
||||
# scatter updates along axis -2 at positions given by indices, for each batch
|
||||
B, U, D = indices.shape[0], updates.shape[-2], data.shape[-2]
|
||||
orig_shape, data_flat, updates_flat = data.shape, data.reshape(-1, D, data.shape[-1]), updates.reshape(-1, U, updates.shape[-1])
|
||||
B_total = data_flat.shape[0]
|
||||
batch_idx = Tensor.arange(B_total, device=data.device).reshape(B_total, 1).expand(B_total, U)
|
||||
indices_expanded = indices.reshape(B, *([1] * (data.ndim - 3))).expand(*orig_shape[:-2]).reshape(B_total)
|
||||
row_idx = indices_expanded.reshape(B_total, 1).expand(B_total, U) + Tensor.arange(U, device=data.device).reshape(1, U).expand(B_total, U)
|
||||
if mode == 'circular': row_idx = row_idx % D
|
||||
return ScatterND(data_flat, batch_idx.unsqueeze(-1).cat(row_idx.unsqueeze(-1), dim=-1), updates_flat).reshape(orig_shape)
|
||||
|
||||
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"):
|
||||
indices = (indices < 0).where(x.shape[axis], 0) + indices
|
||||
if reduction == "none": return x.scatter(axis, indices, updates)
|
||||
|
||||
@@ -132,7 +132,7 @@ def __getattr__(nm):
|
||||
tarball="https://gitlab.freedesktop.org/mesa/mesa/-/archive/mesa-25.2.7/mesa-25.2.7.tar.gz",
|
||||
prolog=["import gzip, base64"], epilog=lambda path: [system(f"{root}/extra/mesa/lvp_nir_options.sh {path}")])
|
||||
case "libclang":
|
||||
return load("libclang", "'clang-20'",
|
||||
return load("libclang", "['clang-20', 'clang']",
|
||||
lambda: [f"{system('llvm-config-20 --includedir')}/clang-c/{s}.h" for s in ["Index", "CXString", "CXSourceLocation", "CXFile"]],
|
||||
args=lambda: system("llvm-config-20 --cflags").split())
|
||||
case "metal":
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# mypy: ignore-errors
|
||||
import ctypes
|
||||
from tinygrad.runtime.support.c import DLL, Struct, CEnum, _IO, _IOW, _IOR, _IOWR
|
||||
dll = DLL('libclang', 'clang-20')
|
||||
dll = DLL('libclang', ['clang-20', 'clang'])
|
||||
CXIndex = ctypes.c_void_p
|
||||
class struct_CXTargetInfoImpl(Struct): pass
|
||||
CXTargetInfo = ctypes.POINTER(struct_CXTargetInfoImpl)
|
||||
|
||||
@@ -530,7 +530,6 @@ add_tags = PatternMatcher([
|
||||
])
|
||||
|
||||
# support for using a contiguous permuted view instead of the parent view if one exists
|
||||
# modified from kernelize.py to not use ShapeTracker
|
||||
|
||||
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
|
||||
x = src
|
||||
|
||||
@@ -18,7 +18,8 @@ from tinygrad.device import Device, Buffer
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
|
||||
# TODO: this should be the only usage of Device
|
||||
def canonicalize_device(device:str|None) -> str: return Device.canonicalize(device)
|
||||
def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
|
||||
return tuple(Device.canonicalize(d) for d in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
|
||||
|
||||
# *** all in scope Tensors are here. this gets relevant UOps ***
|
||||
|
||||
@@ -115,7 +116,7 @@ class Tensor(OpMixin):
|
||||
device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False):
|
||||
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
|
||||
_dtype:DType|None = to_dtype(dtype) if dtype is not None else None
|
||||
_device:str|tuple[str, ...] = tuple(canonicalize_device(x) for x in device) if isinstance(device, (tuple, list)) else canonicalize_device(device)
|
||||
_device:str|tuple[str, ...] = canonicalize_device(device)
|
||||
del device, dtype
|
||||
|
||||
# tensors can have gradients if you have called .backward
|
||||
@@ -162,10 +163,10 @@ class Tensor(OpMixin):
|
||||
|
||||
# data might be on a different device
|
||||
if isinstance(_device, str): self.uop:UOp = data if data.device == _device else data.copy_to_device(_device)
|
||||
# if device is a tuple, we should have/construct a MultiLazyBuffer
|
||||
# if device is a tuple, we should have/construct a multi-device UOp
|
||||
elif isinstance(data.device, str): self.uop = Tensor(data).shard(_device).uop
|
||||
else:
|
||||
assert data.device == _device, f"MultiLazyBuffer device mismatch, {data.device} != {_device}"
|
||||
assert data.device == _device, f"multi-device UOp device mismatch, {data.device} != {_device}"
|
||||
self.uop = data
|
||||
|
||||
# add to all_tensors after construction succeeds
|
||||
@@ -373,7 +374,7 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
Moves the tensor to the given device.
|
||||
"""
|
||||
device = tuple(canonicalize_device(x) for x in device) if isinstance(device, (tuple, list)) else canonicalize_device(device)
|
||||
device = canonicalize_device(device)
|
||||
if device == self.device: return self
|
||||
if not isinstance(device, str): return self.shard(device)
|
||||
ret = Tensor(self.uop, device, requires_grad=self.requires_grad)
|
||||
@@ -397,9 +398,9 @@ class Tensor(OpMixin):
|
||||
print(t.shard((t.device, t.device), axis=1).uop)
|
||||
```
|
||||
"""
|
||||
if not isinstance(self.device, str): raise RuntimeError("can't shard a MultiLazyBuffer")
|
||||
if not isinstance(self.device, str): raise RuntimeError("can't shard a multi-device tensor")
|
||||
if len(devices) == 1: return self.to(devices[0])
|
||||
devices = tuple(canonicalize_device(x) for x in devices)
|
||||
devices = cast(tuple[str, ...], canonicalize_device(devices))
|
||||
mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
|
||||
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
|
||||
|
||||
@@ -495,7 +496,7 @@ class Tensor(OpMixin):
|
||||
dtype, shape = to_dtype(dtype) if dtype is not None else dtypes.default_float, argfix(*shape)
|
||||
if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
|
||||
# TODO: add test for multidevice tensor
|
||||
device = tuple(canonicalize_device(d) for d in device) if isinstance(device, tuple) else canonicalize_device(device)
|
||||
device = canonicalize_device(device)
|
||||
return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape)
|
||||
|
||||
def empty_like(self, **kwargs) -> Tensor:
|
||||
@@ -577,7 +578,7 @@ class Tensor(OpMixin):
|
||||
if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
|
||||
if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
|
||||
if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
|
||||
device = canonicalize_device(device)
|
||||
device = cast(str, canonicalize_device(device))
|
||||
|
||||
# if shape has 0, return zero tensor
|
||||
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
|
||||
@@ -2062,37 +2063,33 @@ class Tensor(OpMixin):
|
||||
print(Tensor.einsum("ij,ij->", x, y).numpy())
|
||||
```
|
||||
"""
|
||||
def parse_formula(formula:str, *operands:Tensor):
|
||||
if "..." in (formula := formula.replace(" ", "")):
|
||||
ell_chars, ell_longest = "".join(c for c in string.ascii_letters if c not in formula), 0
|
||||
for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))):
|
||||
if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count
|
||||
inputs[i] = inp.replace("...", ell_chars[-ell_count:])
|
||||
inputs_str, out_ellipse = ",".join(inputs), ell_chars[-ell_longest:]
|
||||
return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else \
|
||||
(inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
|
||||
return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
|
||||
|
||||
xs:tuple[Tensor, ...] = argfix(*operands)
|
||||
inputs_str, output = parse_formula(formula, *xs)
|
||||
inputs = inputs_str.split(",")
|
||||
if len(xs)!=len(inputs): raise ValueError(f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}")
|
||||
|
||||
# map the value of each letter in the formula
|
||||
letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())
|
||||
|
||||
xs_:list[Tensor] = []
|
||||
lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
|
||||
for x,(order,letters) in zip(xs, [list(zip(*l)) for l in lhs]):
|
||||
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
|
||||
xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
|
||||
|
||||
# ordinal encode the output alphabet
|
||||
rhs_order = argsort(argsort(list(output)))
|
||||
|
||||
# sum over all axes that's not in the output, then permute to the output order
|
||||
return functools.reduce(lambda a,b:a*b, xs_) \
|
||||
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], dtype=dtype).permute(rhs_order)
|
||||
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
|
||||
# expand ellipsis to letters, determine output
|
||||
if "..." in formula:
|
||||
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
|
||||
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
|
||||
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
|
||||
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
|
||||
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
|
||||
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
|
||||
inputs = lhs.split(",")
|
||||
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
|
||||
# trace: take diagonal when letter repeats in single input
|
||||
for i, (s, x) in enumerate(zip(inputs, xs)):
|
||||
for c in set(s):
|
||||
while s.count(c) > 1:
|
||||
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
|
||||
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
|
||||
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
|
||||
s = s[:k] + s[k+1:]
|
||||
inputs[i], xs[i] = s, x
|
||||
# check sizes and build sorted alphabet
|
||||
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
|
||||
alpha = sorted(sz)
|
||||
# align all tensors to alphabet, multiply, sum non-output, permute to output order
|
||||
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
|
||||
for s, x in zip(inputs, xs)]
|
||||
return functools.reduce(lambda a,b:a*b, xs).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
|
||||
|
||||
# ***** processing ops *****
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap
|
||||
|
||||
# NOTE: this cache is only on index UOps and matches the cache in the old ShapeTracker in spirit
|
||||
# NOTE: this cache is only on index UOps
|
||||
@functools.cache
|
||||
def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
|
||||
x, y = d.src
|
||||
|
||||
@@ -240,7 +240,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
case Ops.RESHAPE:
|
||||
if self.src[0]._shape is None: return self.marg
|
||||
|
||||
# movement ops change the shape. this is the logic from the old ShapeTracker
|
||||
# movement ops change the shape
|
||||
# NOTE: ssimplify is required because the shape needs to be canonical for broadcasting and same shape checking
|
||||
if self.op in GroupOp.Movement.union({Ops.MULTI, Ops.REDUCE_AXIS, Ops.WMMA}):
|
||||
ps = self.src[0]._shape
|
||||
@@ -475,14 +475,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
|
||||
def overflows(self, dtype:DType) -> bool: return self.vmin < dtype.min or dtype.max < self.vmax
|
||||
|
||||
# *** ShapeTracker helpers ***
|
||||
|
||||
def split_uop(self:UOp, sep:Ops):
|
||||
if self.op is sep:
|
||||
for s in self.src: yield from s.split_uop(sep)
|
||||
else: yield self
|
||||
|
||||
# *** from MultiLazyBuffer ***
|
||||
# *** multi-device helpers ***
|
||||
|
||||
def multi(self, axis:int|None):
|
||||
assert isinstance(self.device, tuple), f"multi device must be tuple, {self.device} isn't"
|
||||
@@ -524,8 +522,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return self.shrink(tuple((0,s) if i != axis else (dnum*sz,dnum*sz+sz) for i,s in enumerate(self.shape)))
|
||||
def shard(self, devices:tuple[str, ...], axis:int) -> UOp: return self.copy_to_device(devices)._shard(axis).multi(axis)
|
||||
|
||||
# *** from LazyBuffer ***
|
||||
|
||||
def copy_to_device(self, device:str|tuple[str, ...]|UOp, arg=None):
|
||||
assert arg is None or isinstance(self.device, tuple)
|
||||
inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg)
|
||||
|
||||
@@ -30,13 +30,13 @@ const Status = {STARTED:0, COMPLETE:1, ERR:2}
|
||||
const updateProgress = (st, msg) => {
|
||||
clearTimeout(timeout);
|
||||
const msgEl = d3.select("#progress-message").style("display", "none");
|
||||
const customEl = d3.select("#custom").html("");
|
||||
const customEl = d3.select("#custom").style("display", "none");
|
||||
if (st === Status.STARTED) {
|
||||
msgEl.text(msg);
|
||||
timeout = setTimeout(() => msgEl.style("display", "block"), 2000);
|
||||
} else if (st === Status.ERR) {
|
||||
displaySelection("#custom");
|
||||
customEl.append("div").classed("raw-text", true).append(() => codeBlock(msg));
|
||||
customEl.html("").append("div").classed("raw-text", true).append(() => codeBlock(msg));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -685,9 +685,15 @@ window.addEventListener("popstate", (e) => {
|
||||
if (e.state != null) setState(e.state);
|
||||
});
|
||||
|
||||
const toggleLabel = d3.create("label").text("Show indexing (r)").node();
|
||||
const toggle = d3.create("input").attr("type", "checkbox").attr("id", "show-indexing").property("checked", true).node();
|
||||
toggleLabel.prepend(toggle);
|
||||
const createToggle = (id, text) => {
|
||||
const label = d3.create("label").text(text).node();
|
||||
const toggle = d3.create("input").attr("type", "checkbox").attr("id", id).property("checked", true).node();
|
||||
label.prepend(toggle);
|
||||
return { toggle, label };
|
||||
}
|
||||
const { toggle, label:toggleLabel } = createToggle("show-indexing", "Show indexing (r)");
|
||||
const showGraph = createToggle("show-graph", "Show graph (g)");
|
||||
showGraph.toggle.onchange = () => displaySelection(rect("#graph").width > 0 ? "#custom" : "#graph");
|
||||
|
||||
function appendSteps(root, idx, steps) {
|
||||
const stack = [];
|
||||
@@ -748,7 +754,6 @@ async function main() {
|
||||
if (ckey in cache) {
|
||||
ret = cache[ckey];
|
||||
}
|
||||
// ** Text view
|
||||
if (!ckey.startsWith("/graph")) {
|
||||
if (!(ckey in cache)) cache[ckey] = ret = await fetchValue(ckey);
|
||||
if (ret.steps?.length > 0) {
|
||||
@@ -760,15 +765,25 @@ async function main() {
|
||||
appendSteps(el.ctx, state.currentCtx, ctx.steps);
|
||||
return setState({ currentStep:state.currentStep+1, expandSteps:true });
|
||||
}
|
||||
// cycles on the x axis
|
||||
// timeline with cycles on the x axis
|
||||
if (ret instanceof ArrayBuffer) {
|
||||
opts = {heightScale:0.5, hideLabels:true, levelKey:(e) => parseInt(e.name.split(" ")[1].split(":")[1])};
|
||||
return renderProfiler(ckey, "clk", opts);
|
||||
}
|
||||
displaySelection("#custom");
|
||||
metadata.innerHTML = "";
|
||||
ret.metadata?.forEach(m => {
|
||||
if (Array.isArray(m)) return metadata.appendChild(tabulate(m.map(({ label, value }) => {
|
||||
return [label.trim(), typeof value === "string" ? value : formatUnit(value)];
|
||||
})).node());
|
||||
metadata.appendChild(codeBlock(m.src)).classList.add("full-height")
|
||||
});
|
||||
// graph render
|
||||
if (ret.data != null) {
|
||||
metadata.prepend(showGraph.label);
|
||||
renderDag(ret, { recenter:true });
|
||||
} else displaySelection("#custom");
|
||||
// table / plaintext render
|
||||
const root = d3.create("div").classed("raw-text", true);
|
||||
// detailed assembly view
|
||||
function renderTable(root, ret) {
|
||||
const table = root.append("table");
|
||||
const thead = table.append("thead");
|
||||
@@ -797,14 +812,7 @@ async function main() {
|
||||
return table;
|
||||
}
|
||||
if (ret.cols != null) renderTable(root, ret);
|
||||
else if (ret.data != null) renderDag(ret, { recenter:true });
|
||||
else if (ret.src != null) root.append(() => codeBlock(ret.src, ret.lang));
|
||||
ret.metadata?.forEach(m => {
|
||||
if (Array.isArray(m)) return metadata.appendChild(tabulate(m.map(({ label, value }) => {
|
||||
return [label.trim(), typeof value === "string" ? value : formatUnit(value)];
|
||||
})).node());
|
||||
metadata.appendChild(codeBlock(m.src)).classList.add("full-height")
|
||||
});
|
||||
return document.querySelector("#custom").replaceChildren(root.node());
|
||||
}
|
||||
// ** Graph view
|
||||
@@ -961,9 +969,9 @@ document.addEventListener("keydown", (event) => {
|
||||
document.getElementById("zoom-to-fit-btn").click();
|
||||
}
|
||||
// r key toggles indexing
|
||||
if (event.key === "r") {
|
||||
toggle.click();
|
||||
}
|
||||
if (event.key === "r") toggle.click();
|
||||
// g key toggles graph
|
||||
if (event.key === "g") showGraph.toggle.click();
|
||||
});
|
||||
|
||||
main()
|
||||
|
||||
@@ -406,7 +406,10 @@ def amdgpu_cfg(lib:bytes, target:int) -> dict:
|
||||
curr:int|None = None
|
||||
blocks:dict[int, list[int]] = {}
|
||||
paths:dict[int, dict[int, int]] = {}
|
||||
lines:list[str] = []
|
||||
asm_width = max(len(asm) for asm, _ in pc_table.values())
|
||||
for pc, (asm, sz) in pc_table.items():
|
||||
lines.append(f" {asm:<{asm_width}} // {pc:012X}")
|
||||
if pc in leaders:
|
||||
paths[curr:=pc] = {}
|
||||
blocks[pc] = []
|
||||
@@ -420,7 +423,7 @@ def amdgpu_cfg(lib:bytes, target:int) -> dict:
|
||||
if asm.startswith("s_branch"): paths[curr][nx+offset] = UNCOND
|
||||
else: paths[curr].update([(nx+offset, COND_TAKEN), (nx, COND_NOT_TAKEN)])
|
||||
elif nx in leaders: paths[curr][nx] = UNCOND
|
||||
return {"blocks":blocks, "paths":paths, "pc_table":pc_table, "colors":cfg_colors}
|
||||
return {"data":{"blocks":blocks, "paths":paths, "pc_table":pc_table, "colors":cfg_colors}, "src":"\n".join(lines)}
|
||||
|
||||
# ** Main render function to get the complete details about a trace event
|
||||
|
||||
@@ -435,7 +438,7 @@ def get_render(query:str) -> dict:
|
||||
ret:dict = {"metadata":[]}
|
||||
if data.device.startswith("AMD") and data.lib is not None:
|
||||
with soft_err(lambda err: ret.update(err)):
|
||||
ret["data"] = amdgpu_cfg(lib:=data.lib, device_props[data.device]["gfx_target_version"])
|
||||
ret.update(amdgpu_cfg(lib:=data.lib, device_props[data.device]["gfx_target_version"]))
|
||||
with soft_err(lambda err: ret["metadata"].append(err)): ret["metadata"].append(amd_readelf(lib))
|
||||
else: ret["src"] = get_stdout(lambda: (compiler:=Device[data.device].compiler).disassemble(compiler.compile(data.src)))
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user