This commit is contained in:
George Hotz
2026-01-03 13:06:37 -08:00
parent 8ded12b01b
commit 532e3fe07a
2 changed files with 39 additions and 30 deletions

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
import struct, math, re
from enum import IntEnum
from functools import cache, cached_property
from functools import cache
from typing import overload, Annotated, TypeVar, Generic
from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op,
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
@@ -346,6 +346,7 @@ class Inst:
if 'abs_' in kwargs: kwargs['abs'] = kwargs.pop('abs_')
orig_args = dict(zip(field_names, args)) | kwargs
self._values.update(orig_args)
self._precompute()
self._validate(orig_args)
# Pre-shift literal for 64-bit sources (literal param is always raw 32-bit value from user)
if literal is not None:
@@ -386,6 +387,7 @@ class Inst:
elif name == 'sbase': self._values[name] = (val.idx if isinstance(val, Reg) else val.val if isinstance(val, SrcMod) else val * 2) // 2
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = val.idx // 4
elif marker is _VDSTYEnc and isinstance(val, VGPR): self._values[name] = val.idx >> 1
self._precompute_fields()
def _encode_field(self, name: str, val) -> int:
if isinstance(val, RawImm): return val.val
@@ -450,6 +452,8 @@ class Inst:
inst = object.__new__(cls)
inst._values = {n: RawImm(v) if n in SRC_FIELDS else v for n, bf in cls._fields.items() if n != 'encoding' for v in [(word >> bf.lo) & bf.mask()]}
inst._literal = None
inst._precompute()
inst._precompute_fields()
return inst
@classmethod
@@ -510,25 +514,32 @@ class Inst:
'VOPD': VOPDOp, 'VINTERP': VINTERPOp}
_VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
@property
def op(self):
"""Return the op as an enum (e.g., VOP1Op.V_MOV_B32). VOP3 returns VOPCOp/VOP3SDOp for those op ranges."""
def _precompute(self):
"""Precompute op, op_name, _spec_regs, _spec_dtype for fast access."""
val = self._values.get('op')
if val is None: return None
if hasattr(val, 'name'): return val # already an enum
cls_name = self.__class__.__name__
assert cls_name in self._enum_map, f"no enum map for {cls_name}"
return self._enum_map[cls_name](val)
if val is None: self.op = None
elif hasattr(val, 'name'): self.op = val
else:
cls_name = self.__class__.__name__
# VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
if cls_name == 'VOP3':
try:
if val < 256: self.op = VOPCOp(val)
elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val)
else: self.op = VOP3Op(val)
except ValueError: self.op = val
elif cls_name in self._enum_map:
try: self.op = self._enum_map[cls_name](val)
except ValueError: self.op = val
else: self.op = val
self.op_name = self.op.name if hasattr(self.op, 'name') else ''
self._spec_regs = spec_regs(self.op_name)
self._spec_dtype = spec_dtype(self.op_name)
@cached_property
def op_name(self) -> str:
op = self.op
return op.name if hasattr(op, 'name') else ''
@cached_property
def _spec_regs(self) -> tuple[int, int, int, int]: return spec_regs(self.op_name)
@cached_property
def _spec_dtype(self) -> tuple[str | None, str | None, str | None, str | None]: return spec_dtype(self.op_name)
def _precompute_fields(self):
"""Unwrap all field values as direct attributes for fast access."""
for name, val in self._values.items():
if name != 'op': setattr(self, name, unwrap(val))
def dst_regs(self) -> int: return self._spec_regs[0]
def src_regs(self, n: int) -> int: return self._spec_regs[n + 1]
def num_srcs(self) -> int: return spec_num_srcs(self.op_name)

View File

@@ -179,9 +179,8 @@ def decode_program(data: bytes) -> Program:
base_size = inst_class._size()
# Pass enough data for potential 64-bit literal (base + 8 bytes max)
inst = inst_class.from_bytes(data[i:i+base_size+8])
for name, val in inst._values.items():
if name != 'op': setattr(inst, name, unwrap(val)) # skip op to preserve property access
inst._words = inst.size() // 4
inst._fn = COMPILED_FUNCTIONS.get(type(inst.op), {}).get(inst.op)
result[i // 4] = inst
i += inst._words * 4
return result
@@ -201,10 +200,10 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
elif isinstance(inst, SOPP): ssrc0, sdst = None, None
else: raise NotImplementedError(f"Unknown scalar type {type(inst)}")
fn = COMPILED_FUNCTIONS.get(type(inst.op), {}).get(inst.op)
fn = inst._fn
if fn is None:
if isinstance(inst, SOPP): return 0 # SOPP without pseudocode (waits, hints, nops) are no-ops
raise NotImplementedError(f"{inst.op.name} not in pseudocode")
raise NotImplementedError(f"{inst.op_name} not in pseudocode")
# SMEM: memory loads
if isinstance(inst, SMEM):
@@ -250,26 +249,24 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None)
for vopd_op, s0, s1, d0, dst in inputs: V[dst] = _exec_vopd(vopd_op, s0, s1, d0, st, lane)
return
# Lookup compiled function for this op (V_NOP has no pcode)
if isinstance(inst, VOP1) and inst.op == VOP1Op.V_NOP: return
fn = COMPILED_FUNCTIONS.get(type(inst.op), {}).get(inst.op)
fn = inst._fn
if fn is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
# Memory ops (FLAT/GLOBAL/SCRATCH and DS)
if isinstance(inst, (FLAT, DS)):
ndwords = _op_ndwords(inst.op.name)
ndwords = _op_ndwords(inst.op_name)
if isinstance(inst, FLAT):
addr = V[inst.addr] | (V[inst.addr + 1] << 32)
ADDR = (st.rsgpr64(inst.saddr) + V[inst.addr] + _sext(inst.offset, 13)) & MASK64 if inst.saddr not in (NULL, 0x7f) else (addr + _sext(inst.offset, 13)) & MASK64
vdata_src = inst.vdst if 'LOAD' in inst.op.name else inst.data
vdata_src = inst.vdst if 'LOAD' in inst.op_name else inst.data
result = fn(GlobalMem, ADDR, _vgpr_read(V, vdata_src, ndwords), V[inst.vdst])
if 'VDATA' in result: _vgpr_write(V, inst.vdst, result['VDATA'], ndwords)
if 'RETURN_DATA' in result: _vgpr_write(V, inst.vdst, result['RETURN_DATA'], ndwords)
else: # DS
data0, data1 = _vgpr_read(V, inst.data0, ndwords), _vgpr_read(V, inst.data1, ndwords) if inst.data1 is not None else 0
result = fn(lds, V[inst.addr], data0, data1, inst.offset0, inst.offset1)
if 'RETURN_DATA' in result and ('_RTN' in inst.op.name or '_LOAD' in inst.op.name):
_vgpr_write(V, inst.vdst, result['RETURN_DATA'], ndwords * 2 if '_2ADDR_' in inst.op.name else ndwords)
if 'RETURN_DATA' in result and ('_RTN' in inst.op_name or '_LOAD' in inst.op_name):
_vgpr_write(V, inst.vdst, result['RETURN_DATA'], ndwords * 2 if '_2ADDR_' in inst.op_name else ndwords)
return
# VOP3SD: has extra scalar dest for carry output
@@ -389,7 +386,8 @@ def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int
st.pc += inst_words + exec_scalar(st, inst)
return 0
# Wave-level vector ops: execute once for entire wave (not per-lane)
if isinstance(inst, VOP3P) and 'WMMA' in inst.op_name:
if isinstance(inst, VOP1) and inst.op == VOP1Op.V_NOP: pass
elif isinstance(inst, VOP3P) and 'WMMA' in inst.op_name:
exec_wmma(st, inst, inst.op)
elif isinstance(inst, VOP3) and inst.op == VOP3Op.V_WRITELANE_B32:
wr_lane = st.rsrc(inst.src1, 0) & 0x1f