mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
* assembly/amd: clean up pcode * regen * lil * jit the pcode * sendmsg * cleanups * inst prefetch lol
67 lines
3.3 KiB
Python
67 lines
3.3 KiB
Python
"""Shared test helpers for RDNA3 tests."""
|
|
import shutil
|
|
from dataclasses import dataclass
|
|
|
|
@dataclass
class KernelInfo:
  """Kernel code bytes together with its global/local sizes and the buffers it binds.

  buf_idxs and buf_sizes are parallel lists: buf_idxs[i] selects an entry in the shared buffer
  pool and buf_sizes[i] is the size associated with that entry.
  """
  code: bytes                        # raw kernel code
  global_size: tuple[int, int, int]  # three-component global size
  local_size: tuple[int, int, int]   # three-component local size
  buf_idxs: list[int]                # positions in the shared buffer pool
  buf_sizes: list[int]               # one size per entry of buf_idxs
|
|
|
|
# LLVM tool detection (shared across test files)
|
|
def get_llvm_mc():
  """Return the name of the first llvm-mc executable that resolves on PATH.

  Candidates are checked in this order: the unversioned 'llvm-mc', then 'llvm-mc-21', then 'llvm-mc-20'.

  Raises:
    FileNotFoundError: if none of the candidates can be found.
  """
  candidates = ('llvm-mc', 'llvm-mc-21', 'llvm-mc-20')
  found = next((name for name in candidates if shutil.which(name)), None)
  if found is None: raise FileNotFoundError("llvm-mc not found")
  return found
|
|
|
|
def get_llvm_objdump():
  """Return the name of the first llvm-objdump executable that resolves on PATH.

  Search order: 'llvm-objdump', then 'llvm-objdump-21', then 'llvm-objdump-20'.

  Raises:
    FileNotFoundError: if none of the candidates can be found.
  """
  base = 'llvm-objdump'
  for name in (base, f'{base}-21', f'{base}-20'):
    if shutil.which(name): return name
  raise FileNotFoundError("llvm-objdump not found")
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
# EXECUTION CONTEXT (for testing compiled pseudocode)
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
|
|
class ExecContext:
  """Context for running compiled pseudocode in tests.

  Holds register state (S0-S2, D0/D1, SCC, VCC, EXEC, tmp, saveexec, SIMM16/SIMM32, SRC0, VDST) as pcode `Reg`
  objects, plus the lane index, the literal value and a VGPR mapping. `run` executes compiled pseudocode against
  that state, and `result` reports D0 and SCC afterwards.
  """
  def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=0xffffffff, literal=0, vgprs=None, src0_idx=0, vdst_idx=0):
    from extra.assembly.amd.pcode import Reg, MASK64, TypedView
    # stored on the instance so run() / _sync_reg() can use them without re-importing
    self._Reg, self._MASK64, self._TypedView = Reg, MASK64, TypedView
    self.S0, self.S1, self.S2 = Reg(s0), Reg(s1), Reg(s2)
    self.D0, self.D1 = Reg(d0), Reg(0)
    self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
    # saveexec is a separate Reg initialised to the same value as EXEC
    self.tmp, self.saveexec = Reg(0), Reg(exec_mask)
    # both spellings of the lane index are exposed to the pseudocode namespace in run()
    self.lane, self.laneId, self.literal = lane, lane, literal
    # both immediate registers are seeded from the same literal
    self.SIMM16, self.SIMM32 = Reg(literal), Reg(literal)
    # None -> a fresh dict per instance, so contexts never share VGPR state
    self.VGPR = vgprs if vgprs is not None else {}
    # initialised from the src0_idx / vdst_idx index arguments
    self.SRC0, self.VDST = Reg(src0_idx), Reg(vdst_idx)

  def run(self, code: str):
    """Execute compiled pseudocode `code` against this context's state.

    The namespace contains every public name of the pcode module, the `_pack`/`_pack32` helpers when the module
    defines them, and this context's registers. The Reg objects are placed in the namespace as-is, so in-place
    mutation by the code is already visible on the context. SCC, VCC, EXEC, D0, D1, tmp and saveexec are also
    copied back when the code rebinds them to a different object. SIMM16/SIMM32/SRC0/VDST, VGPR and the plain
    ints are not copied back.

    NOTE(review): `code` is passed to exec(); it is assumed to be trusted, generated pseudocode.
    """
    import extra.assembly.amd.pcode as pcode
    ns = {k: getattr(pcode, k) for k in dir(pcode) if not k.startswith('_')}
    # compiled pseudocode also calls these underscore-prefixed helpers
    for k in ('_pack', '_pack32'):
      if hasattr(pcode, k): ns[k] = getattr(pcode, k)
    ns.update({
      'S0': self.S0, 'S1': self.S1, 'S2': self.S2, 'D0': self.D0, 'D1': self.D1,
      'SCC': self.SCC, 'VCC': self.VCC, 'EXEC': self.EXEC,
      # bit-range views onto the same EXEC register (bits 31:0 and 63:32)
      'EXEC_LO': self._TypedView(self.EXEC, 31, 0), 'EXEC_HI': self._TypedView(self.EXEC, 63, 32),
      'tmp': self.tmp, 'saveexec': self.saveexec,
      'lane': self.lane, 'laneId': self.laneId, 'literal': self.literal,
      'SIMM16': self.SIMM16, 'SIMM32': self.SIMM32, 'VGPR': self.VGPR, 'SRC0': self.SRC0, 'VDST': self.VDST,
    })
    exec(code, ns)
    # a name still bound to our own Reg was (at most) mutated in place and needs no copy-back
    for name in ('SCC', 'VCC', 'EXEC', 'D0', 'D1', 'tmp', 'saveexec'):
      if ns.get(name) is not getattr(self, name): self._sync_reg(getattr(self, name), ns[name])

  def _sync_reg(self, ctx_reg, ns_val):
    """Copy `ns_val` into `ctx_reg`: a Reg's raw value directly, anything else via int() masked with MASK64."""
    if isinstance(ns_val, self._Reg): ctx_reg._val = ns_val._val
    else: ctx_reg._val = int(ns_val) & self._MASK64

  def result(self) -> dict:
    """Return the raw D0 value and the low bit of SCC."""
    return {"d0": self.D0._val, "scc": self.SCC._val & 1}
|