mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 06:34:03 -05:00
ds compiled
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# to regenerate: python -m extra.assembly.amd.pdf --arch rdna3
|
||||
# ruff: noqa: E501,F405,F403
|
||||
# mypy: ignore-errors
|
||||
from extra.assembly.amd.autogen.rdna3.enum import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp
|
||||
from extra.assembly.amd.autogen.rdna3.enum import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp
|
||||
from extra.assembly.amd.pcode import *
|
||||
|
||||
def _SOP1Op_S_MOV_B32(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):
|
||||
@@ -6360,6 +6360,119 @@ def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
|
||||
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
|
||||
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# DS (Data Share) INSTRUCTIONS
|
||||
# DS instructions operate on LDS (Local Data Share) memory
|
||||
# They receive: lds (LDS byte buffer), addr (address), data0/data1 (data VGPR dwords), vdst, offset0/offset1 (byte offsets; the 2ADDR ops scale them by the element size)
|
||||
# They return: {'vdst': [...]} for loads, {} for stores (stores write into the lds buffer in place)
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _ds_load(lds, addr, offset, size, sign_extend=False):
|
||||
"""Load from LDS memory. Returns list of 32-bit values."""
|
||||
a = (addr + offset) & 0xffff
|
||||
if size <= 4:
|
||||
val = int.from_bytes(lds[a:a+size], 'little')
|
||||
if sign_extend and size < 4:
|
||||
# Sign extend from size*8 bits to 32 bits
|
||||
sign_bit = 1 << (size * 8 - 1)
|
||||
if val & sign_bit: val |= ~((1 << (size * 8)) - 1)
|
||||
return [val & 0xffffffff]
|
||||
# Multi-dword load
|
||||
return [int.from_bytes(lds[a+i*4:a+i*4+4], 'little') for i in range(size // 4)]
|
||||
|
||||
def _ds_store(lds, addr, offset, values, size):
|
||||
"""Store to LDS memory. values is list of 32-bit dwords, size is bytes per element."""
|
||||
a = (addr + offset) & 0xffff
|
||||
if size <= 4:
|
||||
lds[a:a+size] = (values[0] & ((1 << (size * 8)) - 1)).to_bytes(size, 'little')
|
||||
else:
|
||||
for i, v in enumerate(values):
|
||||
lds[a+i*4:a+i*4+4] = (v & 0xffffffff).to_bytes(4, 'little')
|
||||
|
||||
# Load operations: DS_LOAD_B32, DS_LOAD_B64, DS_LOAD_B128, DS_LOAD_U8, DS_LOAD_I8, DS_LOAD_U16, DS_LOAD_I16
|
||||
def _DSOp_DS_LOAD_B32(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_LOAD_B32: load 4 bytes from LDS at addr + offset0; data0, data1, vdst and offset1 are unused."""
  loaded = _ds_load(lds, addr, offset0, 4)
  return {'vdst': loaded}
|
||||
|
||||
def _DSOp_DS_LOAD_B64(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_LOAD_B64: load 8 bytes (two dwords) from LDS at addr + offset0."""
  loaded = _ds_load(lds, addr, offset0, 8)
  return {'vdst': loaded}
|
||||
|
||||
def _DSOp_DS_LOAD_B128(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_LOAD_B128: load 16 bytes (four dwords) from LDS at addr + offset0."""
  loaded = _ds_load(lds, addr, offset0, 16)
  return {'vdst': loaded}
|
||||
|
||||
def _DSOp_DS_LOAD_U8(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_LOAD_U8: load one byte from LDS at addr + offset0 without sign extension."""
  byte_val = _ds_load(lds, addr, offset0, 1, sign_extend=False)
  return {'vdst': byte_val}
|
||||
|
||||
def _DSOp_DS_LOAD_I8(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_LOAD_I8: load one byte from LDS at addr + offset0, sign-extended to 32 bits."""
  byte_val = _ds_load(lds, addr, offset0, 1, sign_extend=True)
  return {'vdst': byte_val}
|
||||
|
||||
def _DSOp_DS_LOAD_U16(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_LOAD_U16: load two bytes from LDS at addr + offset0 without sign extension."""
  half_val = _ds_load(lds, addr, offset0, 2, sign_extend=False)
  return {'vdst': half_val}
|
||||
|
||||
def _DSOp_DS_LOAD_I16(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_LOAD_I16: load two bytes from LDS at addr + offset0, sign-extended to 32 bits."""
  half_val = _ds_load(lds, addr, offset0, 2, sign_extend=True)
  return {'vdst': half_val}
|
||||
|
||||
# Store operations: DS_STORE_B32, DS_STORE_B64, DS_STORE_B128, DS_STORE_B8, DS_STORE_B16
|
||||
def _DSOp_DS_STORE_B32(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_STORE_B32: hand data0 to _ds_store with size 4 at addr + offset0; returns an empty result dict."""
  _ds_store(lds, addr, offset0, values=data0, size=4)
  return {}
|
||||
|
||||
def _DSOp_DS_STORE_B64(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_STORE_B64: hand data0 to _ds_store with size 8 at addr + offset0; returns an empty result dict."""
  _ds_store(lds, addr, offset0, values=data0, size=8)
  return {}
|
||||
|
||||
def _DSOp_DS_STORE_B128(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_STORE_B128: hand data0 to _ds_store with size 16 at addr + offset0; returns an empty result dict."""
  _ds_store(lds, addr, offset0, values=data0, size=16)
  return {}
|
||||
|
||||
def _DSOp_DS_STORE_B8(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_STORE_B8: store the low byte of data0[0] to LDS at addr + offset0; returns an empty result dict."""
  _ds_store(lds, addr, offset0, values=data0, size=1)
  return {}
|
||||
|
||||
def _DSOp_DS_STORE_B16(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_STORE_B16: store the low 16 bits of data0[0] to LDS at addr + offset0; returns an empty result dict."""
  _ds_store(lds, addr, offset0, values=data0, size=2)
  return {}
|
||||
|
||||
# 2-address operations: DS_LOAD_2ADDR_B32, DS_LOAD_2ADDR_B64, DS_STORE_2ADDR_B32, DS_STORE_2ADDR_B64
|
||||
# Note: offsets are scaled by data size (4 for B32, 8 for B64)
|
||||
def _DSOp_DS_LOAD_2ADDR_B32(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_LOAD_2ADDR_B32: load one dword at addr + offset0*4, then one at addr + offset1*4, concatenated in that order."""
  return {'vdst': [dw for off in (offset0, offset1) for dw in _ds_load(lds, addr, off * 4, 4)]}
|
||||
|
||||
def _DSOp_DS_LOAD_2ADDR_B64(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_LOAD_2ADDR_B64: load two dwords at addr + offset0*8, then two at addr + offset1*8, concatenated in that order."""
  return {'vdst': [dw for off in (offset0, offset1) for dw in _ds_load(lds, addr, off * 8, 8)]}
|
||||
|
||||
def _DSOp_DS_STORE_2ADDR_B32(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_STORE_2ADDR_B32: store data0 at addr + offset0*4, then data1 at addr + offset1*4 (size 4 each)."""
  # data0 is written first, so if both targets coincide the data1 write wins
  for off, data in ((offset0, data0), (offset1, data1)):
    _ds_store(lds, addr, off * 4, data, 4)
  return {}
|
||||
|
||||
def _DSOp_DS_STORE_2ADDR_B64(lds, addr, data0, data1, vdst, offset0, offset1):
  """DS_STORE_2ADDR_B64: call _ds_store with size 8 for data0 at addr + offset0*8, then for data1 at addr + offset1*8."""
  # data0 is written first, so where the two writes overlap the data1 bytes win
  for off, data in ((offset0, data0), (offset1, data1)):
    _ds_store(lds, addr, off * 8, data, 8)
  return {}
|
||||
|
||||
# Dispatch table: DS opcode -> handler taking (lds, addr, data0, data1, vdst, offset0, offset1)
DSOp_FUNCTIONS = {
  # single-address loads
  DSOp.DS_LOAD_B32: _DSOp_DS_LOAD_B32, DSOp.DS_LOAD_B64: _DSOp_DS_LOAD_B64, DSOp.DS_LOAD_B128: _DSOp_DS_LOAD_B128,
  DSOp.DS_LOAD_U8: _DSOp_DS_LOAD_U8, DSOp.DS_LOAD_I8: _DSOp_DS_LOAD_I8,
  DSOp.DS_LOAD_U16: _DSOp_DS_LOAD_U16, DSOp.DS_LOAD_I16: _DSOp_DS_LOAD_I16,
  # single-address stores
  DSOp.DS_STORE_B32: _DSOp_DS_STORE_B32, DSOp.DS_STORE_B64: _DSOp_DS_STORE_B64, DSOp.DS_STORE_B128: _DSOp_DS_STORE_B128,
  DSOp.DS_STORE_B8: _DSOp_DS_STORE_B8, DSOp.DS_STORE_B16: _DSOp_DS_STORE_B16,
  # two-address loads and stores
  DSOp.DS_LOAD_2ADDR_B32: _DSOp_DS_LOAD_2ADDR_B32, DSOp.DS_LOAD_2ADDR_B64: _DSOp_DS_LOAD_2ADDR_B64,
  DSOp.DS_STORE_2ADDR_B32: _DSOp_DS_STORE_2ADDR_B32, DSOp.DS_STORE_2ADDR_B64: _DSOp_DS_STORE_2ADDR_B64,
}
|
||||
|
||||
COMPILED_FUNCTIONS = {
|
||||
SOP1Op: SOP1Op_FUNCTIONS,
|
||||
SOP2Op: SOP2Op_FUNCTIONS,
|
||||
@@ -6372,6 +6485,7 @@ COMPILED_FUNCTIONS = {
|
||||
VOP3SDOp: VOP3SDOp_FUNCTIONS,
|
||||
VOP3POp: VOP3POp_FUNCTIONS,
|
||||
VOPCOp: VOPCOp_FUNCTIONS,
|
||||
DSOp: DSOp_FUNCTIONS,
|
||||
}
|
||||
|
||||
def get_compiled_functions(): return COMPILED_FUNCTIONS
|
||||
@@ -51,11 +51,6 @@ _D16_LOAD_MAP = {'LOAD_D16_U8': (1,0,0), 'LOAD_D16_I8': (1,1,0), 'LOAD_D16_B16':
|
||||
_D16_STORE_MAP = {'STORE_D16_HI_B8': (1,1), 'STORE_D16_HI_B16': (2,1)} # (size, hi)
|
||||
FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP)
|
||||
FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP)
|
||||
DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
|
||||
DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
|
||||
# 2ADDR ops: load/store two values using offset0 and offset1
|
||||
DS_LOAD_2ADDR = {DSOp.DS_LOAD_2ADDR_B32: 4, DSOp.DS_LOAD_2ADDR_B64: 8}
|
||||
DS_STORE_2ADDR = {DSOp.DS_STORE_2ADDR_B32: 4, DSOp.DS_STORE_2ADDR_B64: 8}
|
||||
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
|
||||
|
||||
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
|
||||
@@ -225,32 +220,15 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
return
|
||||
|
||||
if isinstance(inst, DS):
|
||||
op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
|
||||
if op in DS_LOAD:
|
||||
cnt, sz, sign = DS_LOAD[op]
|
||||
for i in range(cnt): val = int.from_bytes(lds[addr0+i*sz:addr0+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
|
||||
elif op in DS_STORE:
|
||||
cnt, sz = DS_STORE[op]
|
||||
for i in range(cnt): lds[addr0+i*sz:addr0+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little')
|
||||
elif op in DS_LOAD_2ADDR:
|
||||
# Load two values from addr+offset0*sz and addr+offset1*sz into vdst (B32: 1 dword each, B64: 2 dwords each)
|
||||
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
|
||||
sz = DS_LOAD_2ADDR[op]
|
||||
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
|
||||
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
|
||||
cnt = sz // 4 # 1 for B32, 2 for B64
|
||||
for i in range(cnt): V[vdst + i] = int.from_bytes(lds[addr0+i*4:addr0+i*4+4], 'little')
|
||||
for i in range(cnt): V[vdst + cnt + i] = int.from_bytes(lds[addr1+i*4:addr1+i*4+4], 'little')
|
||||
elif op in DS_STORE_2ADDR:
|
||||
# Store two values from data0 and data1 to addr+offset0*sz and addr+offset1*sz
|
||||
# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
|
||||
sz = DS_STORE_2ADDR[op]
|
||||
addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
|
||||
addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
|
||||
cnt = sz // 4
|
||||
for i in range(cnt): lds[addr0+i*4:addr0+i*4+4] = (V[inst.data0 + i] & MASK32).to_bytes(4, 'little')
|
||||
for i in range(cnt): lds[addr1+i*4:addr1+i*4+4] = (V[inst.data1 + i] & MASK32).to_bytes(4, 'little')
|
||||
else: raise NotImplementedError(f"DS op {op}")
|
||||
fn = compiled.get(DSOp, {}).get(inst.op)
|
||||
if fn is None: raise NotImplementedError(f"DS op {inst.op.name} not in pseudocode")
|
||||
# Prepare data registers as lists of dwords
|
||||
data0 = [V[inst.data0 + i] for i in range(4)] # up to 4 dwords
|
||||
data1 = [V[inst.data1 + i] for i in range(4)] if inst.data1 else [0, 0, 0, 0]
|
||||
result = fn(lds, V[inst.addr], data0, data1, inst.vdst, inst.offset0, inst.offset1)
|
||||
# Write results for loads
|
||||
if 'vdst' in result:
|
||||
for i, val in enumerate(result['vdst']): V[inst.vdst + i] = val & MASK32
|
||||
return
|
||||
|
||||
# VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)
|
||||
|
||||
@@ -36,7 +36,7 @@ FIELD_ORDER = {
|
||||
SRC_EXTRAS = {233: 'DPP8', 234: 'DPP8FI', 250: 'DPP16', 251: 'VCCZ', 252: 'EXECZ', 254: 'LDS_DIRECT'}
|
||||
FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'NEG_ONE', '2.0': 'POS_TWO', '-2.0': 'NEG_TWO',
|
||||
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
|
||||
INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
||||
INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
||||
|
||||
# Patterns that can't be handled by the DSL (require special handling in emu.py)
|
||||
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
|
||||
@@ -381,7 +381,7 @@ def _extract_pseudocode(text: str) -> str | None:
|
||||
if s.endswith('.') and not any(p in s for p in ['D0', 'D1', 'S0', 'S1', 'S2', 'SCC', 'VCC', 'tmp', '=']): continue
|
||||
if re.match(r'^[a-z].*\.$', s) and '=' not in s: continue
|
||||
is_code = (any(p in s for p in ['D0.', 'D1.', 'S0.', 'S1.', 'S2.', 'SCC =', 'SCC ?', 'VCC', 'EXEC', 'tmp =', 'tmp[', 'lane =', 'PC =',
|
||||
'D0[', 'D1[', 'S0[', 'S1[', 'S2[']) or
|
||||
'D0[', 'D1[', 'S0[', 'S1[', 'S2[', 'MEM[', 'RETURN_DATA', 'DATA.', 'DATA0', 'DATA1', 'ADDR']) or
|
||||
s.startswith(('if ', 'else', 'elsif', 'endif', 'declare ', 'for ', 'endfor', '//')) or
|
||||
re.match(r'^[a-z_]+\s*=', s) or re.match(r'^[a-z_]+\[', s) or (depth > 0 and '=' in s))
|
||||
if is_code: result.append(s)
|
||||
@@ -467,13 +467,13 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
|
||||
# Get op enums for this arch (import from .ins which re-exports from .enum)
|
||||
import importlib
|
||||
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
|
||||
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp'] if hasattr(autogen, name)]
|
||||
OP_ENUMS = [getattr(autogen, name) for name in ['SOP1Op', 'SOP2Op', 'SOPCOp', 'SOPKOp', 'SOPPOp', 'VOP1Op', 'VOP2Op', 'VOP3Op', 'VOP3SDOp', 'VOP3POp', 'VOPCOp', 'VOP3AOp', 'VOP3BOp', 'DSOp'] if hasattr(autogen, name)]
|
||||
|
||||
# Build defined ops mapping
|
||||
defined_ops: dict[tuple, list] = {}
|
||||
for enum_cls in OP_ENUMS:
|
||||
for op in enum_cls:
|
||||
if op.name.startswith(('S_', 'V_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
|
||||
if op.name.startswith(('S_', 'V_', 'DS_')): defined_ops.setdefault((op.name, op.value), []).append((enum_cls, op))
|
||||
|
||||
enum_names = [e.__name__ for e in OP_ENUMS]
|
||||
lines = [f'''# autogenerated by pdf.py - do not edit
|
||||
|
||||
Reference in New Issue
Block a user