diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index fdbc7b462c..cde5ef6984 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -1,48 +1,56 @@ # library for RDNA3 assembly DSL # mypy: ignore-errors from __future__ import annotations -import struct, math +import struct, math, re from enum import IntEnum +from functools import cache, cached_property from typing import overload, Annotated, TypeVar, Generic from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp) # Common masks and bit conversion functions MASK32, MASK64 = 0xffffffff, 0xffffffffffffffff -def _f32(i): return struct.unpack(" 0 else 0xff800000 - try: return struct.unpack(" 0 else 0xff800000 def _sext(v, b): return v - (1 << b) if v & (1 << (b - 1)) else v -def _f16(i): return struct.unpack(" 0 else 0xfc00 - try: return struct.unpack(" 0 else 0xfc00 -def _f64(i): return struct.unpack(" 0 else 0xfff0000000000000 - try: return struct.unpack(" 0 else 0xfff0000000000000 # Instruction spec - register counts and dtypes derived from instruction names -import re _REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16, 'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2, 'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1} +_CVT_RE = re.compile(r'CVT_([FIUB]\d+)_([FIUB]\d+)$') +_MAD_MUL_RE = re.compile(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$') +_PACK_RE = re.compile(r'PACK_([FIUB]\d+)_([FIUB]\d+)$') +_DST_SRC_RE = re.compile(r'_([FIUB]\d+)_([FIUB]\d+)$') +_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512))$') +@cache def _suffix(name: str) -> tuple[str | None, str | None]: name = name.upper() - if m := re.search(r'CVT_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2) - if m := re.search(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$', name): return m.group(1), m.group(2) - if m := re.search(r'PACK_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2) - # Generic dst_src pattern: S_BCNT0_I32_B64, S_BITREPLICATE_B64_B32, V_FREXP_EXP_I32_F64, etc. - if m := re.search(r'_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2) - if m := re.search(r'_([FIUB](?:32|64|16|8|96|128|256|512))$', name): return m.group(1), m.group(1) + if m := _CVT_RE.search(name): return m.group(1), m.group(2) + if m := _MAD_MUL_RE.search(name): return m.group(1), m.group(2) + if m := _PACK_RE.search(name): return m.group(1), m.group(2) + if m := _DST_SRC_RE.search(name): return m.group(1), m.group(2) + if m := _SINGLE_RE.search(name): return m.group(1), m.group(1) return None, None _SPECIAL_REGS = { 'V_LSHLREV_B64': (2, 1, 2, 1), 'V_LSHRREV_B64': (2, 1, 2, 1), 'V_ASHRREV_I64': (2, 1, 2, 1), @@ -71,27 +79,33 @@ _SPECIAL_DTYPE = { 'V_QSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'), 'V_MQSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'), 'V_MQSAD_U32_U8': ('B128', 'B64', 'B64', 'B128'), } +@cache def spec_regs(name: str) -> tuple[int, int, int, int]: - name = name.upper() - if name in _SPECIAL_REGS: return _SPECIAL_REGS[name] - if 'SAD' in name and 'U8' in name and 'QSAD' not in name and 'MQSAD' not in name: return 1, 1, 1, 1 + uname = name.upper() + if uname in _SPECIAL_REGS: return _SPECIAL_REGS[uname] + if 'SAD' in uname and 'U8' in uname and 'QSAD' not in uname and 'MQSAD' not in uname: return 1, 1, 1, 1 dst_suf, src_suf = _suffix(name) return _REGS.get(dst_suf, 1), _REGS.get(src_suf, 1), _REGS.get(src_suf, 1), _REGS.get(src_suf, 1) +@cache def spec_dtype(name: str) -> tuple[str | None, str | None, str | None, str | None]: - name = name.upper() - if name in _SPECIAL_DTYPE: return _SPECIAL_DTYPE[name] - if 'SAD' in name and ('U8' in name or 'U16' in name) and 'QSAD' not in name and 'MQSAD' not in name: return 'U32', 'U32', 'U32', 'U32' - if '_CMP_' in name or '_CMPX_' in name: + uname = name.upper() + if uname in _SPECIAL_DTYPE: return _SPECIAL_DTYPE[uname] + if 'SAD' in uname and ('U8' in uname or 'U16' in uname) and 'QSAD' not in uname and 'MQSAD' not in uname: return 'U32', 'U32', 'U32', 'U32' + if '_CMP_' in uname or '_CMPX_' in uname: dst_suf, src_suf = _suffix(name) - return 'EXEC' if '_CMPX_' in name else 'VCC', src_suf, src_suf, None + return 'EXEC' if '_CMPX_' in uname else 'VCC', src_suf, src_suf, None dst_suf, src_suf = _suffix(name) return dst_suf, src_suf, src_suf, src_suf +_F16_RE = re.compile(r'_[FIUB]16(?:_|$)') +_F64_RE = re.compile(r'_[FIUB]64(?:_|$)') +@cache def spec_is_16bit(name: str) -> bool: - name = name.upper() - if 'SAD' in name or 'PACK' in name or '_PK_' in name or 'SAT_PK' in name or 'DOT2' in name: return False - if '_F32' in name or '_I32' in name or '_U32' in name or '_B32' in name: return False # mixed ops like V_DOT2ACC_F32_F16 - return bool(re.search(r'_[FIUB]16(?:_|$)', name)) -def spec_is_64bit(name: str) -> bool: return bool(re.search(r'_[FIUB]64(?:_|$)', name.upper())) + uname = name.upper() + if 'SAD' in uname or 'PACK' in uname or '_PK_' in uname or 'SAT_PK' in uname or 'DOT2' in uname: return False + if '_F32' in uname or '_I32' in uname or '_U32' in uname or '_B32' in uname: return False + return bool(_F16_RE.search(uname)) +@cache +def spec_is_64bit(name: str) -> bool: return bool(_F64_RE.search(name.upper())) _3SRC = {'FMA', 'MAD', 'MIN3', 'MAX3', 'MED3', 'DIV_FIX', 'DIV_FMAS', 'DIV_SCALE', 'SAD', 'LERP', 'ALIGN', 'CUBE', 'BFE', 'BFI', 'PERM_B32', 'PERMLANE', 'CNDMASK', 'XOR3', 'OR3', 'ADD3', 'LSHL_OR', 'AND_OR', 'LSHL_ADD', 'ADD_LSHL', 'XAD', 'MAXMIN', 'MINMAX', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'} @@ -495,21 +509,25 @@ class Inst: assert cls_name in self._enum_map, f"no enum map for {cls_name}" return self._enum_map[cls_name](val) - @property + @cached_property def op_name(self) -> str: op = self.op return op.name if hasattr(op, 'name') else '' - def dst_regs(self) -> int: return spec_regs(self.op_name)[0] - def src_regs(self, n: int) -> int: return spec_regs(self.op_name)[n + 1] + @cached_property + def _spec_regs(self) -> tuple[int, int, int, int]: return spec_regs(self.op_name) + @cached_property + def _spec_dtype(self) -> tuple[str | None, str | None, str | None, str | None]: return spec_dtype(self.op_name) + def dst_regs(self) -> int: return self._spec_regs[0] + def src_regs(self, n: int) -> int: return self._spec_regs[n + 1] def num_srcs(self) -> int: return spec_num_srcs(self.op_name) - def dst_dtype(self) -> str | None: return spec_dtype(self.op_name)[0] - def src_dtype(self, n: int) -> str | None: return spec_dtype(self.op_name)[n + 1] - def is_src_16(self, n: int) -> bool: return self.src_regs(n) == 1 and is_dtype_16(self.src_dtype(n)) - def is_src_64(self, n: int) -> bool: return self.src_regs(n) == 2 + def dst_dtype(self) -> str | None: return self._spec_dtype[0] + def src_dtype(self, n: int) -> str | None: return self._spec_dtype[n + 1] + def is_src_16(self, n: int) -> bool: return self._spec_regs[n + 1] == 1 and is_dtype_16(self._spec_dtype[n + 1]) + def is_src_64(self, n: int) -> bool: return self._spec_regs[n + 1] == 2 def is_16bit(self) -> bool: return spec_is_16bit(self.op_name) def is_64bit(self) -> bool: return spec_is_64bit(self.op_name) - def is_dst_16(self) -> bool: return self.dst_regs() == 1 and is_dtype_16(self.dst_dtype()) + def is_dst_16(self) -> bool: return self._spec_regs[0] == 1 and is_dtype_16(self._spec_dtype[0]) class Inst32(Inst): pass class Inst64(Inst): pass