mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
100% asm
This commit is contained in:
@@ -60,7 +60,17 @@ HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_H
|
||||
6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS', 15: 'HW_REG_SH_MEM_BASES', 18: 'HW_REG_PERF_SNAPSHOT_PC_LO',
|
||||
19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK',
|
||||
23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'}
|
||||
# GFX942-specific HWREG values
|
||||
_HWREG_GFX942 = {'HW_REG_XCC_ID': 20, 'HW_REG_SQ_PERF_SNAPSHOT_DATA': 21, 'HW_REG_SQ_PERF_SNAPSHOT_DATA1': 22,
|
||||
'HW_REG_SQ_PERF_SNAPSHOT_PC_LO': 23, 'HW_REG_SQ_PERF_SNAPSHOT_PC_HI': 24}
|
||||
HWREG_IDS = {v.lower(): k for k, v in HWREG.items()}
|
||||
HWREG_IDS.update({k.lower(): v for k, v in _HWREG_GFX942.items()})
|
||||
def hwreg(name, offset=0, size=32):
|
||||
"""Encode hwreg(name[, offset[, size]]) -> simm16 value. id[5:0], offset[10:6], size-1[15:11]"""
|
||||
if isinstance(name, int): hid = name
|
||||
else: hid = HWREG_IDS.get(name.lower(), HWREG_IDS.get(name.lower().replace('hw_reg_', ''), None))
|
||||
if hid is None: raise ValueError(f"unknown hwreg: {name}")
|
||||
return hid | (offset << 6) | ((size - 1) << 11)
|
||||
# RDNA unified buffer format - extracted from PDF, use enum for name->value lookup
|
||||
BUF_FMT = {e.name: e.value for e in BufFmt}
|
||||
def _parse_buf_fmt_combo(s: str) -> int: # parse format:[BUF_DATA_FORMAT_X, BUF_NUM_FORMAT_Y]
|
||||
@@ -571,6 +581,13 @@ def _op2dsl(op: str, arch: str = "rdna3") -> str:
|
||||
if m := re.match(r'^([asvt](?:tmp)?)\[(\d+):(\d+)\]$', lo): return wrap(f"{rp.get(m.group(1), m.group(1))}[{m.group(2)}:{m.group(3)}]")
|
||||
if m := re.match(r'^([asvt](?:tmp)?)(\d+)$', lo): return wrap(f"{rp.get(m.group(1), m.group(1))}[{m.group(2)}]")
|
||||
if re.match(r'^-?\d+$|^-?0x[0-9a-fA-F]+$', op): return f"SrcMod({op}, neg={neg}, abs_={abs_})" if neg or abs_ else op
|
||||
# Floating-point literal: convert to IEEE 754 32-bit integer representation
|
||||
import struct
|
||||
try:
|
||||
f = float(op)
|
||||
as_int = struct.unpack('<I', struct.pack('<f', f))[0]
|
||||
return f"SrcMod({as_int}, neg={neg}, abs_={abs_})" if neg or abs_ else str(as_int)
|
||||
except ValueError: pass
|
||||
return wrap(op)
|
||||
|
||||
def _parse_ops(s: str) -> list[str]:
|
||||
@@ -722,6 +739,7 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
|
||||
m, text = _extract(text, r'\s+row_half_mirror(?:\s|$)'); dpp_ctrl = 0x141 if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+row_bcast:15(?:\s|$)'); dpp_ctrl = 0x142 if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+row_bcast:31(?:\s|$)'); dpp_ctrl = 0x143 if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+row_newbcast:(\d+)'); dpp_ctrl = 0x150 + int(m.group(1)) if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+row_mask:(0x[0-9a-fA-F]+|\d+)'); dpp_row_mask = int(m.group(1), 0) if m else None; dpp_row_mask_specified = m is not None
|
||||
m, text = _extract(text, r'\s+bank_mask:(0x[0-9a-fA-F]+|\d+)'); dpp_bank_mask = int(m.group(1), 0) if m else None; dpp_bank_mask_specified = m is not None
|
||||
m, text = _extract(text, r'\s+bound_ctrl:([01])'); dpp_bound_ctrl = 1 if m else None # bound_ctrl:0 or bound_ctrl:1 both set bit to 1
|
||||
@@ -1311,7 +1329,9 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
|
||||
# ACC register support for CDNA: detect a[N] registers and set acc=1
|
||||
acc_mod = ', acc=1' if arch == 'cdna' and _has_acc(args) else ''
|
||||
args = [_acc_to_vgpr(a) for a in args] # convert a[N] to v[N] for encoding
|
||||
if glc and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3], seg)}' if len(args) >= 4 else f', saddr={_saddr_off(seg)}'}{flat_mods}{acc_mod})"
|
||||
# For atomics with return value: vdst, addr, data, [saddr] - triggered by glc (or sc0 for GFX942)
|
||||
has_return = glc or sc0
|
||||
if has_return and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3], seg)}' if len(args) >= 4 else f', saddr={_saddr_off(seg)}'}{flat_mods}{acc_mod})"
|
||||
if len(args) >= 2: return f"{mn}(addr={args[0]}, data={args[1]}{f', saddr={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{flat_mods}{acc_mod})"
|
||||
|
||||
# DS instructions
|
||||
@@ -1453,6 +1473,95 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
|
||||
a_str = ', '.join(vop3_args + all_kw)
|
||||
return f"{fn[:-4]}({a_str})"
|
||||
|
||||
# CDNA VOP1 with modifiers: auto-promote to VOP3A/SDWA/DPP
|
||||
# Check if this is a VOP1 instruction needing extended encoding (not already _e64/_sdwa/_dpp)
|
||||
has_vop3_mods = any(k.startswith(('omod=', 'clmp=')) for k in all_kw)
|
||||
has_sdwa_mods = sdwa_src0_sel is not None or sdwa_src1_sel is not None or sdwa_dst_sel is not None
|
||||
has_dpp_mods = dpp_ctrl is not None
|
||||
if arch == "cdna" and fn.startswith('v_') and not fn.endswith(('_e64', '_sdwa', '_dpp')) and (has_vop3_mods or has_sdwa_mods or has_dpp_mods):
|
||||
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, SDWA, DPP
|
||||
fn_upper = fn.upper()
|
||||
vop1_op = getattr(VOP1Op, fn_upper, None)
|
||||
vop2_op = getattr(VOP2Op, fn_upper, None)
|
||||
if vop1_op is not None or vop2_op is not None:
|
||||
if has_sdwa_mods:
|
||||
# SDWA encoding for VOP1/VOP2 with src0_sel/src1_sel/dst_sel
|
||||
sdwa_kw = []
|
||||
src0_orig = ops[1].strip().lower() if len(ops) > 1 else ''
|
||||
src0_is_sgpr = src0_orig.startswith('s') and not src0_orig.startswith('src')
|
||||
src0_is_literal = src0_orig.isdigit() or (len(src0_orig) > 2 and src0_orig[:2] == '0x')
|
||||
if vop1_op is not None:
|
||||
sdwa_kw.append(f'vop_op={vop1_op.value}')
|
||||
sdwa_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode
|
||||
sdwa_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]')
|
||||
sdwa_kw.append(f'src0={args[1]}' if len(args) > 1 else 'src0=v[0]')
|
||||
else:
|
||||
sdwa_kw.append(f'vop_op={args[1] if len(args) > 1 else "v[0]"}')
|
||||
sdwa_kw.append(f'vop2_op={vop2_op.value}')
|
||||
sdwa_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]')
|
||||
sdwa_kw.append(f'src0={args[2] if len(args) > 2 else "v[0]"}')
|
||||
sdwa_kw.append(f'dst_sel={sdwa_dst_sel if sdwa_dst_sel is not None else 6}')
|
||||
sdwa_kw.append('dst_u=0')
|
||||
sdwa_kw.append(f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}')
|
||||
sdwa_kw.append('src0_sext=0')
|
||||
sdwa_kw.append('src0_neg=0')
|
||||
sdwa_kw.append('src0_abs=0')
|
||||
sdwa_kw.append(f's0={1 if src0_is_sgpr or src0_is_literal else 0}') # s0=1 for SGPR/literal
|
||||
sdwa_kw.append(f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 0}') # 0 for VOP1
|
||||
sdwa_kw.append('src1_sext=0')
|
||||
sdwa_kw.append('src1_neg=0')
|
||||
sdwa_kw.append('src1_abs=0')
|
||||
sdwa_kw.append('s1=0')
|
||||
# Add clamp and omod if present
|
||||
if any(k == 'clmp=1' for k in all_kw): sdwa_kw.append('clmp=1')
|
||||
for k in all_kw:
|
||||
if k.startswith('omod='): sdwa_kw.append(k); break
|
||||
return f"SDWA({', '.join(sdwa_kw)})"
|
||||
elif has_dpp_mods:
|
||||
# DPP encoding for VOP1/VOP2 with quad_perm/row_shl/etc.
|
||||
dpp_kw = []
|
||||
if vop1_op is not None:
|
||||
dpp_kw.append(f'vop_op={vop1_op.value}')
|
||||
dpp_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode
|
||||
dpp_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]')
|
||||
dpp_kw.append(f'src0={args[1]}' if len(args) > 1 else 'src0=v[0]')
|
||||
else:
|
||||
dpp_kw.append(f'vop_op={args[1] if len(args) > 1 else "v[0]"}')
|
||||
dpp_kw.append(f'vop2_op={vop2_op.value}')
|
||||
dpp_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]')
|
||||
dpp_kw.append(f'src0={args[2] if len(args) > 2 else "v[0]"}')
|
||||
dpp_kw.append(f'dpp_ctrl={dpp_ctrl}')
|
||||
dpp_kw.append(f'row_mask={dpp_row_mask if dpp_row_mask is not None else 15}')
|
||||
dpp_kw.append(f'bank_mask={dpp_bank_mask if dpp_bank_mask is not None else 15}')
|
||||
dpp_kw.append(f'bound_ctrl={dpp_bound_ctrl if dpp_bound_ctrl is not None else 0}')
|
||||
return f"DPP({', '.join(dpp_kw)})"
|
||||
elif has_vop3_mods and vop1_op is not None:
|
||||
# VOP3A encoding for VOP1 with clamp/omod
|
||||
from extra.assembly.amd.autogen.cdna.ins import VOP3AOp
|
||||
# Calculate promoted opcode: VOP3 op = 320 + VOP1_op
|
||||
promoted_op = 320 + vop1_op.value
|
||||
vop3_kw = [f'op={promoted_op}']
|
||||
vop3_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]')
|
||||
vop3_kw.append(f'src0={args[1]}' if len(args) > 1 else 'src0=v[0]')
|
||||
vop3_kw.append('src1=RawImm(0)')
|
||||
vop3_kw.append('src2=RawImm(0)')
|
||||
vop3_kw.extend(all_kw)
|
||||
return f"VOP3A({', '.join(vop3_kw)})"
|
||||
|
||||
# GFX942-specific VOP3A opcode adjustments: some instructions need +64 offset
|
||||
_GFX942_VOP3A_OFFSET64 = {'V_CVT_PK_BF8_F32', 'V_CVT_PK_FP8_F32', 'V_CVT_SR_BF8_F32', 'V_CVT_SR_FP8_F32', 'V_LSHL_ADD_U64'}
|
||||
if gfx942 and fn.upper() in _GFX942_VOP3A_OFFSET64:
|
||||
from extra.assembly.amd.autogen.cdna.ins import VOP3AOp
|
||||
base_op = getattr(VOP3AOp, fn.upper(), None)
|
||||
if base_op is not None:
|
||||
vop3_kw = [f'op={base_op + 64}']
|
||||
vop3_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]')
|
||||
vop3_kw.append(f'src0={args[1]}' if len(args) > 1 else 'src0=v[0]')
|
||||
vop3_kw.append(f'src1={args[2]}' if len(args) > 2 else 'src1=RawImm(0)')
|
||||
vop3_kw.append(f'src2={args[3]}' if len(args) > 3 else 'src2=RawImm(0)')
|
||||
vop3_kw.extend(all_kw)
|
||||
return f"VOP3A({', '.join(vop3_kw)})"
|
||||
|
||||
a_str, kw_str = ', '.join(args), ', '.join(all_kw)
|
||||
return f"{fn}({a_str}, {kw_str})" if kw_str and a_str else f"{fn}({kw_str})" if kw_str else f"{fn}({a_str})"
|
||||
|
||||
@@ -1465,9 +1574,15 @@ def asm(text: str, arch: str = "rdna3") -> Inst:
|
||||
from extra.assembly.amd.autogen.cdna import ins as cdna_ins
|
||||
ns = {n: getattr(cdna_ins, n) for n in dir(cdna_ins) if not n.startswith('_')}
|
||||
# CDNA special registers: m0=124, flat_scratch=102-103, xnack_mask=104-105, no NULL (use m0 for off)
|
||||
# HWREG symbolic names for s_getreg_b32/s_setreg_b32
|
||||
_hwreg_names = {k: v for k, v in _HWREG_GFX942.items()}
|
||||
_hwreg_names.update({v: k for k, v in HWREG.items()}) # standard names: id -> name
|
||||
_hwreg_ids = {v: k for k, v in _hwreg_names.items()} # reverse: name -> id
|
||||
ns.update({'s': s, 'v': v, 'ttmp': ttmp, 'abs': abs, 'RawImm': RawImm, 'SrcMod': SrcMod, 'VGPR': VGPR, 'SGPR': SGPR, 'TTMP': TTMP,
|
||||
'VCC_LO': RawImm(106), 'VCC_HI': RawImm(107), 'VCC': RawImm(106), 'EXEC_LO': RawImm(126), 'EXEC_HI': RawImm(127), 'EXEC': RawImm(126),
|
||||
'SCC': RawImm(253), 'M0': RawImm(124), 'NULL': RawImm(124), 'OFF': RawImm(124),
|
||||
'SCC': RawImm(253), 'M0': RawImm(124), 'NULL': RawImm(124), 'OFF': RawImm(124), 'hwreg': hwreg,
|
||||
'HW_REG_XCC_ID': 20, 'HW_REG_SQ_PERF_SNAPSHOT_DATA': 21, 'HW_REG_SQ_PERF_SNAPSHOT_DATA1': 22,
|
||||
'HW_REG_SQ_PERF_SNAPSHOT_PC_LO': 23, 'HW_REG_SQ_PERF_SNAPSHOT_PC_HI': 24,
|
||||
'FLAT_SCRATCH_LO': RawImm(102), 'FLAT_SCRATCH_HI': RawImm(103), 'FLAT_SCRATCH': RawImm(102),
|
||||
'XNACK_MASK_LO': RawImm(104), 'XNACK_MASK_HI': RawImm(105), 'XNACK_MASK': RawImm(104),
|
||||
'SRC_VCCZ': RawImm(251), 'SRC_EXECZ': RawImm(252), 'SRC_SCC': RawImm(253), 'SRC_LDS_DIRECT': RawImm(254)})
|
||||
|
||||
Reference in New Issue
Block a user