From 532e3fe07a5daca64b6b7083c9172e954bf8c83f Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 3 Jan 2026 13:06:37 -0800 Subject: [PATCH] speedups --- extra/assembly/amd/dsl.py | 47 ++++++++++++++++++++++++--------------- extra/assembly/amd/emu.py | 22 +++++++++--------- 2 files changed, 39 insertions(+), 30 deletions(-) diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index 08b34bee27..0e34374fd8 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -3,7 +3,7 @@ from __future__ import annotations import struct, math, re from enum import IntEnum -from functools import cache, cached_property +from functools import cache from typing import overload, Annotated, TypeVar, Generic from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp) @@ -346,6 +346,7 @@ class Inst: if 'abs_' in kwargs: kwargs['abs'] = kwargs.pop('abs_') orig_args = dict(zip(field_names, args)) | kwargs self._values.update(orig_args) + self._precompute() self._validate(orig_args) # Pre-shift literal for 64-bit sources (literal param is always raw 32-bit value from user) if literal is not None: @@ -386,6 +387,7 @@ class Inst: elif name == 'sbase': self._values[name] = (val.idx if isinstance(val, Reg) else val.val if isinstance(val, SrcMod) else val * 2) // 2 elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = val.idx // 4 elif marker is _VDSTYEnc and isinstance(val, VGPR): self._values[name] = val.idx >> 1 + self._precompute_fields() def _encode_field(self, name: str, val) -> int: if isinstance(val, RawImm): return val.val @@ -450,6 +452,8 @@ class Inst: inst = object.__new__(cls) inst._values = {n: RawImm(v) if n in SRC_FIELDS else v for n, bf in cls._fields.items() if n != 'encoding' for v in [(word >> bf.lo) & bf.mask()]} inst._literal = None + inst._precompute() + inst._precompute_fields() return inst @classmethod @@ -510,25 +514,32 @@ class Inst: 'VOPD': VOPDOp, 'VINTERP': VINTERPOp} _VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770} - @property - def op(self): - """Return the op as an enum (e.g., VOP1Op.V_MOV_B32). VOP3 returns VOPCOp/VOP3SDOp for those op ranges.""" + def _precompute(self): + """Precompute op, op_name, _spec_regs, _spec_dtype for fast access.""" val = self._values.get('op') - if val is None: return None - if hasattr(val, 'name'): return val # already an enum - cls_name = self.__class__.__name__ - assert cls_name in self._enum_map, f"no enum map for {cls_name}" - return self._enum_map[cls_name](val) + if val is None: self.op = None + elif hasattr(val, 'name'): self.op = val + else: + cls_name = self.__class__.__name__ + # VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp + if cls_name == 'VOP3': + try: + if val < 256: self.op = VOPCOp(val) + elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val) + else: self.op = VOP3Op(val) + except ValueError: self.op = val + elif cls_name in self._enum_map: + try: self.op = self._enum_map[cls_name](val) + except ValueError: self.op = val + else: self.op = val + self.op_name = self.op.name if hasattr(self.op, 'name') else '' + self._spec_regs = spec_regs(self.op_name) + self._spec_dtype = spec_dtype(self.op_name) - @cached_property - def op_name(self) -> str: - op = self.op - return op.name if hasattr(op, 'name') else '' - - @cached_property - def _spec_regs(self) -> tuple[int, int, int, int]: return spec_regs(self.op_name) - @cached_property - def _spec_dtype(self) -> tuple[str | None, str | None, str | None, str | None]: return spec_dtype(self.op_name) + def _precompute_fields(self): + """Unwrap all field values as direct attributes for fast access.""" + for name, val in self._values.items(): + if name != 'op': setattr(self, name, unwrap(val)) def dst_regs(self) -> int: return self._spec_regs[0] def src_regs(self, n: int) -> int: return self._spec_regs[n + 1] def num_srcs(self) -> int: return spec_num_srcs(self.op_name) diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index 566713b57e..a38e5fc112 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -179,9 +179,8 @@ def decode_program(data: bytes) -> Program: base_size = inst_class._size() # Pass enough data for potential 64-bit literal (base + 8 bytes max) inst = inst_class.from_bytes(data[i:i+base_size+8]) - for name, val in inst._values.items(): - if name != 'op': setattr(inst, name, unwrap(val)) # skip op to preserve property access inst._words = inst.size() // 4 + inst._fn = COMPILED_FUNCTIONS.get(type(inst.op), {}).get(inst.op) result[i // 4] = inst i += inst._words * 4 return result @@ -201,10 +200,10 @@ def exec_scalar(st: WaveState, inst: Inst) -> int: elif isinstance(inst, SOPP): ssrc0, sdst = None, None else: raise NotImplementedError(f"Unknown scalar type {type(inst)}") - fn = COMPILED_FUNCTIONS.get(type(inst.op), {}).get(inst.op) + fn = inst._fn if fn is None: if isinstance(inst, SOPP): return 0 # SOPP without pseudocode (waits, hints, nops) are no-ops - raise NotImplementedError(f"{inst.op.name} not in pseudocode") + raise NotImplementedError(f"{inst.op_name} not in pseudocode") # SMEM: memory loads if isinstance(inst, SMEM): @@ -250,26 +249,24 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: LDSMem | None = None) for vopd_op, s0, s1, d0, dst in inputs: V[dst] = _exec_vopd(vopd_op, s0, s1, d0, st, lane) return - # Lookup compiled function for this op (V_NOP has no pcode) - if isinstance(inst, VOP1) and inst.op == VOP1Op.V_NOP: return - fn = COMPILED_FUNCTIONS.get(type(inst.op), {}).get(inst.op) + fn = inst._fn if fn is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode") # Memory ops (FLAT/GLOBAL/SCRATCH and DS) if isinstance(inst, (FLAT, DS)): - ndwords = _op_ndwords(inst.op.name) + ndwords = _op_ndwords(inst.op_name) if isinstance(inst, FLAT): addr = V[inst.addr] | (V[inst.addr + 1] << 32) ADDR = (st.rsgpr64(inst.saddr) + V[inst.addr] + _sext(inst.offset, 13)) & MASK64 if inst.saddr not in (NULL, 0x7f) else (addr + _sext(inst.offset, 13)) & MASK64 - vdata_src = inst.vdst if 'LOAD' in inst.op.name else inst.data + vdata_src = inst.vdst if 'LOAD' in inst.op_name else inst.data result = fn(GlobalMem, ADDR, _vgpr_read(V, vdata_src, ndwords), V[inst.vdst]) if 'VDATA' in result: _vgpr_write(V, inst.vdst, result['VDATA'], ndwords) if 'RETURN_DATA' in result: _vgpr_write(V, inst.vdst, result['RETURN_DATA'], ndwords) else: # DS data0, data1 = _vgpr_read(V, inst.data0, ndwords), _vgpr_read(V, inst.data1, ndwords) if inst.data1 is not None else 0 result = fn(lds, V[inst.addr], data0, data1, inst.offset0, inst.offset1) - if 'RETURN_DATA' in result and ('_RTN' in inst.op.name or '_LOAD' in inst.op.name): - _vgpr_write(V, inst.vdst, result['RETURN_DATA'], ndwords * 2 if '_2ADDR_' in inst.op.name else ndwords) + if 'RETURN_DATA' in result and ('_RTN' in inst.op_name or '_LOAD' in inst.op_name): + _vgpr_write(V, inst.vdst, result['RETURN_DATA'], ndwords * 2 if '_2ADDR_' in inst.op_name else ndwords) return # VOP3SD: has extra scalar dest for carry output @@ -389,7 +386,8 @@ def step_wave(program: Program, st: WaveState, lds: LDSMem, n_lanes: int) -> int st.pc += inst_words + exec_scalar(st, inst) return 0 # Wave-level vector ops: execute once for entire wave (not per-lane) - if isinstance(inst, VOP3P) and 'WMMA' in inst.op_name: + if isinstance(inst, VOP1) and inst.op == VOP1Op.V_NOP: pass + elif isinstance(inst, VOP3P) and 'WMMA' in inst.op_name: exec_wmma(st, inst, inst.op) elif isinstance(inst, VOP3) and inst.op == VOP3Op.V_WRITELANE_B32: wr_lane = st.rsrc(inst.src1, 0) & 0x1f