assembly/amd: speed up emulator (#13932)

This commit is contained in:
George Hotz
2025-12-31 13:32:25 -05:00
committed by GitHub
parent 13973e4dea
commit f14428090f

View File

@@ -1,48 +1,56 @@
# library for RDNA3 assembly DSL
# mypy: ignore-errors
from __future__ import annotations
import struct, math
import struct, math, re
from enum import IntEnum
from functools import cache, cached_property
from typing import overload, Annotated, TypeVar, Generic
from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op,
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
# Common masks and bit conversion functions
MASK32, MASK64 = 0xffffffff, 0xffffffffffffffff
def _f32(i): return struct.unpack("<f", struct.pack("<I", i & MASK32))[0]
_struct_f, _struct_I = struct.Struct("<f"), struct.Struct("<I")
_struct_e, _struct_H = struct.Struct("<e"), struct.Struct("<H")
_struct_d, _struct_Q = struct.Struct("<d"), struct.Struct("<Q")
def _f32(i): return _struct_f.unpack(_struct_I.pack(i & MASK32))[0]
def _i32(f):
if isinstance(f, int): f = float(f)
if math.isnan(f): return 0xffc00000 if math.copysign(1.0, f) < 0 else 0x7fc00000
if math.isinf(f): return 0x7f800000 if f > 0 else 0xff800000
try: return struct.unpack("<I", struct.pack("<f", f))[0]
try: return _struct_I.unpack(_struct_f.pack(f))[0]
except (OverflowError, struct.error): return 0x7f800000 if f > 0 else 0xff800000
def _sext(v, b): return v - (1 << b) if v & (1 << (b - 1)) else v
def _f16(i): return struct.unpack("<e", struct.pack("<H", i & 0xffff))[0]
def _f16(i): return _struct_e.unpack(_struct_H.pack(i & 0xffff))[0]
def _i16(f):
if math.isnan(f): return 0x7e00
if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00
try: return struct.unpack("<H", struct.pack("<e", f))[0]
try: return _struct_H.unpack(_struct_e.pack(f))[0]
except (OverflowError, struct.error): return 0x7c00 if f > 0 else 0xfc00
def _f64(i): return struct.unpack("<d", struct.pack("<Q", i & MASK64))[0]
def _f64(i): return _struct_d.unpack(_struct_Q.pack(i & MASK64))[0]
def _i64(f):
if math.isnan(f): return 0x7ff8000000000000
if math.isinf(f): return 0x7ff0000000000000 if f > 0 else 0xfff0000000000000
try: return struct.unpack("<Q", struct.pack("<d", f))[0]
try: return _struct_Q.unpack(_struct_d.pack(f))[0]
except (OverflowError, struct.error): return 0x7ff0000000000000 if f > 0 else 0xfff0000000000000
# Instruction spec - register counts and dtypes derived from instruction names
import re
_REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16,
'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2,
'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1}
_CVT_RE = re.compile(r'CVT_([FIUB]\d+)_([FIUB]\d+)$')
_MAD_MUL_RE = re.compile(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$')
_PACK_RE = re.compile(r'PACK_([FIUB]\d+)_([FIUB]\d+)$')
_DST_SRC_RE = re.compile(r'_([FIUB]\d+)_([FIUB]\d+)$')
_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512))$')
@cache
def _suffix(name: str) -> tuple[str | None, str | None]:
name = name.upper()
if m := re.search(r'CVT_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
if m := re.search(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$', name): return m.group(1), m.group(2)
if m := re.search(r'PACK_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
# Generic dst_src pattern: S_BCNT0_I32_B64, S_BITREPLICATE_B64_B32, V_FREXP_EXP_I32_F64, etc.
if m := re.search(r'_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
if m := re.search(r'_([FIUB](?:32|64|16|8|96|128|256|512))$', name): return m.group(1), m.group(1)
if m := _CVT_RE.search(name): return m.group(1), m.group(2)
if m := _MAD_MUL_RE.search(name): return m.group(1), m.group(2)
if m := _PACK_RE.search(name): return m.group(1), m.group(2)
if m := _DST_SRC_RE.search(name): return m.group(1), m.group(2)
if m := _SINGLE_RE.search(name): return m.group(1), m.group(1)
return None, None
_SPECIAL_REGS = {
'V_LSHLREV_B64': (2, 1, 2, 1), 'V_LSHRREV_B64': (2, 1, 2, 1), 'V_ASHRREV_I64': (2, 1, 2, 1),
@@ -71,27 +79,33 @@ _SPECIAL_DTYPE = {
'V_QSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'), 'V_MQSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'),
'V_MQSAD_U32_U8': ('B128', 'B64', 'B64', 'B128'),
}
@cache
def spec_regs(name: str) -> tuple[int, int, int, int]:
name = name.upper()
if name in _SPECIAL_REGS: return _SPECIAL_REGS[name]
if 'SAD' in name and 'U8' in name and 'QSAD' not in name and 'MQSAD' not in name: return 1, 1, 1, 1
uname = name.upper()
if uname in _SPECIAL_REGS: return _SPECIAL_REGS[uname]
if 'SAD' in uname and 'U8' in uname and 'QSAD' not in uname and 'MQSAD' not in uname: return 1, 1, 1, 1
dst_suf, src_suf = _suffix(name)
return _REGS.get(dst_suf, 1), _REGS.get(src_suf, 1), _REGS.get(src_suf, 1), _REGS.get(src_suf, 1)
@cache
def spec_dtype(name: str) -> tuple[str | None, str | None, str | None, str | None]:
name = name.upper()
if name in _SPECIAL_DTYPE: return _SPECIAL_DTYPE[name]
if 'SAD' in name and ('U8' in name or 'U16' in name) and 'QSAD' not in name and 'MQSAD' not in name: return 'U32', 'U32', 'U32', 'U32'
if '_CMP_' in name or '_CMPX_' in name:
uname = name.upper()
if uname in _SPECIAL_DTYPE: return _SPECIAL_DTYPE[uname]
if 'SAD' in uname and ('U8' in uname or 'U16' in uname) and 'QSAD' not in uname and 'MQSAD' not in uname: return 'U32', 'U32', 'U32', 'U32'
if '_CMP_' in uname or '_CMPX_' in uname:
dst_suf, src_suf = _suffix(name)
return 'EXEC' if '_CMPX_' in name else 'VCC', src_suf, src_suf, None
return 'EXEC' if '_CMPX_' in uname else 'VCC', src_suf, src_suf, None
dst_suf, src_suf = _suffix(name)
return dst_suf, src_suf, src_suf, src_suf
_F16_RE = re.compile(r'_[FIUB]16(?:_|$)')
_F64_RE = re.compile(r'_[FIUB]64(?:_|$)')
@cache
def spec_is_16bit(name: str) -> bool:
name = name.upper()
if 'SAD' in name or 'PACK' in name or '_PK_' in name or 'SAT_PK' in name or 'DOT2' in name: return False
if '_F32' in name or '_I32' in name or '_U32' in name or '_B32' in name: return False # mixed ops like V_DOT2ACC_F32_F16
return bool(re.search(r'_[FIUB]16(?:_|$)', name))
def spec_is_64bit(name: str) -> bool: return bool(re.search(r'_[FIUB]64(?:_|$)', name.upper()))
uname = name.upper()
if 'SAD' in uname or 'PACK' in uname or '_PK_' in uname or 'SAT_PK' in uname or 'DOT2' in uname: return False
if '_F32' in uname or '_I32' in uname or '_U32' in uname or '_B32' in uname: return False
return bool(_F16_RE.search(uname))
@cache
def spec_is_64bit(name: str) -> bool: return bool(_F64_RE.search(name.upper()))
_3SRC = {'FMA', 'MAD', 'MIN3', 'MAX3', 'MED3', 'DIV_FIX', 'DIV_FMAS', 'DIV_SCALE', 'SAD', 'LERP', 'ALIGN', 'CUBE', 'BFE', 'BFI',
'PERM_B32', 'PERMLANE', 'CNDMASK', 'XOR3', 'OR3', 'ADD3', 'LSHL_OR', 'AND_OR', 'LSHL_ADD', 'ADD_LSHL', 'XAD', 'MAXMIN',
'MINMAX', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'}
@@ -495,21 +509,25 @@ class Inst:
assert cls_name in self._enum_map, f"no enum map for {cls_name}"
return self._enum_map[cls_name](val)
@property
@cached_property
def op_name(self) -> str:
op = self.op
return op.name if hasattr(op, 'name') else ''
def dst_regs(self) -> int: return spec_regs(self.op_name)[0]
def src_regs(self, n: int) -> int: return spec_regs(self.op_name)[n + 1]
@cached_property
def _spec_regs(self) -> tuple[int, int, int, int]: return spec_regs(self.op_name)
@cached_property
def _spec_dtype(self) -> tuple[str | None, str | None, str | None, str | None]: return spec_dtype(self.op_name)
def dst_regs(self) -> int: return self._spec_regs[0]
def src_regs(self, n: int) -> int: return self._spec_regs[n + 1]
def num_srcs(self) -> int: return spec_num_srcs(self.op_name)
def dst_dtype(self) -> str | None: return spec_dtype(self.op_name)[0]
def src_dtype(self, n: int) -> str | None: return spec_dtype(self.op_name)[n + 1]
def is_src_16(self, n: int) -> bool: return self.src_regs(n) == 1 and is_dtype_16(self.src_dtype(n))
def is_src_64(self, n: int) -> bool: return self.src_regs(n) == 2
def dst_dtype(self) -> str | None: return self._spec_dtype[0]
def src_dtype(self, n: int) -> str | None: return self._spec_dtype[n + 1]
def is_src_16(self, n: int) -> bool: return self._spec_regs[n + 1] == 1 and is_dtype_16(self._spec_dtype[n + 1])
def is_src_64(self, n: int) -> bool: return self._spec_regs[n + 1] == 2
def is_16bit(self) -> bool: return spec_is_16bit(self.op_name)
def is_64bit(self) -> bool: return spec_is_64bit(self.op_name)
def is_dst_16(self) -> bool: return self.dst_regs() == 1 and is_dtype_16(self.dst_dtype())
def is_dst_16(self) -> bool: return self._spec_regs[0] == 1 and is_dtype_16(self._spec_dtype[0])
class Inst32(Inst): pass
class Inst64(Inst): pass