From 2be5f8b6881eb7d3229cbee7fc7fa512c90dad93 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 4 Jan 2026 11:57:42 -0800 Subject: [PATCH] work --- extra/assembly/amd/qcode.py | 4 ++- extra/assembly/amd/ucode.py | 68 +++++++++++++++++++++++++++++-------- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/extra/assembly/amd/qcode.py b/extra/assembly/amd/qcode.py index 86dde9dffb..f6956f2eab 100644 --- a/extra/assembly/amd/qcode.py +++ b/extra/assembly/amd/qcode.py @@ -97,7 +97,9 @@ def expr(s: str) -> Expr: a = _split(s[m.end():e]); return Call(m[1], tuple(expr(x) for x in a) if a != [''] else ()) if s[:4] == 'MEM[' and (e := _match(s, 3, '[', ']')) != -1: r, b = s[e+1:], Call('MEM', (expr(s[4:e]),)) - return Typed(b, DTYPES[r[1:]]) if r[:1] == '.' and r[1:] in DTYPES else b + if not r: return b # Just MEM[addr] + if r[:1] == '.' and r[1:] in DTYPES: return Typed(b, DTYPES[r[1:]]) # MEM[addr].type + # Otherwise fall through to let binary operators parse (e.g., MEM[ADDR].b32.u32 + X) if (q := _fop(s, ('?',))) > 0: d = b = 0 for i in range(q+1, len(s)): diff --git a/extra/assembly/amd/ucode.py b/extra/assembly/amd/ucode.py index 54f9a785f4..f6efe83385 100644 --- a/extra/assembly/amd/ucode.py +++ b/extra/assembly/amd/ucode.py @@ -42,17 +42,26 @@ INPUT_VARS = { 'VDATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('VDATA', 0, 0xffffffffffffffff)), 'VDST': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('VDST', 0, 0xffffffffffffffff)), 'RETURN_DATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('RETURN_DATA', 0, 0xffffffffffffffff)), + # DS (LDS) op variables - DATA/DATA2 are the source data registers, OFFSET is address offset + 'DATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('DATA', 0, 0xffffffffffffffff)), + 'DATA2': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('DATA2', 0, 0xffffffffffffffff)), + 'OFFSET': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET', 0, 0xffff)), + 'OFFSET0': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET0', 0, 0xff)), + 'OFFSET1': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET1', 0, 0xff)), } # Global memory buffer for MEM[] accesses - DEFINE_GLOBAL with byte pointer MEM_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint8.ptr(addrspace=AddrSpace.GLOBAL), arg=0) +# LDS (local) memory buffer for DS ops - DEFINE_LOCAL with byte pointer +LDS_BUF = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(addrspace=AddrSpace.LOCAL), arg=0) class Ctx: """Compilation context - tracks variables and outputs.""" - def __init__(self): + def __init__(self, mem_buf: UOp = MEM_BUF): self.vars: dict[str, UOp] = dict(INPUT_VARS) self.outputs: list[tuple[str, UOp, DType]] = [] - self.mem_stores: list[UOp] = [] # STORE UOps for MEM + self.mem_stores: list[UOp] = [] # STORE UOps for MEM/LDS + self.mem_buf = mem_buf # MEM_BUF for global, LDS_BUF for local def _expr(node, ctx: Ctx, hint: DType = None) -> UOp: """Transform qcode AST expression to UOp.""" @@ -85,10 +94,11 @@ def _expr(node, ctx: Ctx, hint: DType = None) -> UOp: case Typed(expr, qdt): dt = _qdt(qdt) - # Handle MEM[addr].type -> memory load using DEFINE_GLOBAL + INDEX + LOAD + # Handle MEM[addr].type -> memory load using INDEX + LOAD (buffer set by caller) if isinstance(expr, Call) and expr.name == 'MEM': addr_uop = _expr(expr.args[0], ctx, dtypes.uint64) - idx = UOp(Ops.INDEX, dt.ptr(0, AddrSpace.GLOBAL), (MEM_BUF, addr_uop)) + buf = ctx.mem_buf + idx = UOp(Ops.INDEX, dt.ptr(0, buf.dtype.addrspace), (buf, addr_uop)) return UOp(Ops.LOAD, dt, (idx,)) if isinstance(expr, Var): if expr.name in ('VCCZ', 'EXECZ'): @@ -323,12 +333,13 @@ def _stmt(stmt, ctx: Ctx): match stmt: case Declare(_, _): pass case Assign(lhs, rhs): - # Handle MEM[addr].type = value -> memory store using DEFINE_GLOBAL + INDEX + STORE + # Handle MEM[addr].type = value -> memory store using INDEX + STORE (buffer set by caller) if isinstance(lhs, Typed) and isinstance(lhs.expr, Call) and lhs.expr.name == 'MEM': dt = _qdt(lhs.dtype) addr_uop = _expr(lhs.expr.args[0], ctx, dtypes.uint64) val_uop = _expr(rhs, ctx, dt) - idx = UOp(Ops.INDEX, dt.ptr(0, AddrSpace.GLOBAL), (MEM_BUF, addr_uop)) + buf = ctx.mem_buf + idx = UOp(Ops.INDEX, dt.ptr(0, buf.dtype.addrspace), (buf, addr_uop)) ctx.mem_stores.append(UOp(Ops.STORE, dtypes.void, (idx, val_uop))) return @@ -423,9 +434,9 @@ def _float_to_bits(val: float, dtype: DType) -> int: if dtype == dtypes.float64: return struct.unpack(' tuple[UOp, list[tuple[str, DType]], dict[str, UOp], list[UOp]]: +def _compile_pseudocode(pseudocode: str, mem_buf: UOp = MEM_BUF) -> tuple[UOp, list[tuple[str, DType]], dict[str, UOp], list[UOp]]: """Compile pseudocode to UOp graph. Returns (sink, outputs, input_vars, mem_stores).""" - ctx = Ctx() + ctx = Ctx(mem_buf=mem_buf) for stmt in parse(pseudocode): _stmt(stmt, ctx) sink = UOp(Ops.SINK, dtypes.void, tuple(u for _, u, _ in ctx.outputs) or ()) return sink, [(n, d) for n, _, d in ctx.outputs], INPUT_VARS, ctx.mem_stores @@ -439,9 +450,12 @@ _DTYPE_ACCESSOR = { def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[str, UOp], mem_stores: list[UOp]): """Create runtime function using substitute+simplify, with memory ops on sink.""" - is_mem = bool(mem_stores) or any(u.op == Ops.LOAD for u in sink.toposort()) - # Add stores to sink so they get substituted/simplified too + # Add stores to sink first so they're included in detection if mem_stores: sink = UOp(Ops.SINK, dtypes.void, sink.src + tuple(mem_stores)) + # Detect memory type from UOps: check if any INDEX uses DEFINE_LOCAL + topo = sink.toposort() + is_lds = any(u.op == Ops.DEFINE_LOCAL for u in topo) + is_mem = bool(mem_stores) or any(u.op == Ops.LOAD for u in topo) def _extract_results(s, MEM=None): # Execute stores and extract output values from simplified sink @@ -459,10 +473,31 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s result[name] = _float_to_bits(s.src[i].arg, dtype) if dtype in FLOATS else int(s.src[i].arg) & (0xffffffff if dtype.itemsize <= 4 else 0xffffffffffffffff) return result - if is_mem: + if is_lds: + # DS (LDS) ops: fn(MEM, addr, data0, data1, offset0, offset1) + def fn(MEM, addr, data0=0, data1=0, offset0=0, offset1=0): + dvars = {input_vars['ADDR']: UOp.const(dtypes.uint64, addr), input_vars['DATA']: UOp.const(dtypes.uint64, data0), + input_vars['DATA2']: UOp.const(dtypes.uint64, data1), input_vars['OFFSET']: UOp.const(dtypes.uint32, offset0), + input_vars['OFFSET0']: UOp.const(dtypes.uint32, offset0), input_vars['OFFSET1']: UOp.const(dtypes.uint32, offset1), + input_vars['RETURN_DATA']: UOp.const(dtypes.uint64, 0)} + s1 = sink.substitute(dvars).simplify() + # Replace LOADs with actual values from LDS, then simplify again + loads = {} + for u in s1.toposort(): + if u.op == Ops.LOAD: + idx_uop = u.src[0] + load_addr, dt = int(idx_uop.src[1].arg), idx_uop.dtype.base + acc = _DTYPE_ACCESSOR.get(dt, 'u32') + loads[u] = UOp.const(dt, getattr(MEM[load_addr], acc)) + s2 = s1.substitute(loads).simplify() if loads else s1 + return _extract_results(s2, MEM) + return fn + elif is_mem: + # SMEM/FLAT/GLOBAL ops: fn(MEM, addr, vdata, vdst) def fn(MEM, addr, vdata=0, vdst=0): dvars = {input_vars['ADDR']: UOp.const(dtypes.uint64, addr), input_vars['SDATA']: UOp.const(dtypes.uint64, 0), input_vars['VDATA']: UOp.const(dtypes.uint64, vdata), input_vars['VDST']: UOp.const(dtypes.uint64, vdst), + input_vars['DATA']: UOp.const(dtypes.uint64, vdata), input_vars['DATA2']: UOp.const(dtypes.uint64, 0), input_vars['RETURN_DATA']: UOp.const(dtypes.uint64, 0)} s1 = sink.substitute(dvars).simplify() # Replace LOADs with actual values from MEM, then simplify again @@ -477,6 +512,7 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s return _extract_results(s2, MEM) return fn else: + # ALU ops: fn(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, ...) def fn(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None): simm16 = (literal if -32768 <= literal <= 32767 else (literal - 65536 if literal < 65536 else 0)) if literal is not None else 0 dvars = { @@ -534,10 +570,9 @@ _SKIP_OPS = { 'V_CVT_OFF_F32_I4', # CVT_OFF_TABLE lookup } -# Patterns that still need pcode (LDS, register arrays, special ops, atomics/DS with DATA/OFFSET) -# Note: 'DATA.' matches atomic ops (not SDATA/VDATA), 'OFFSET0' matches DS ops (not just OFFSET in addr calc) +# Patterns that still need pcode (register arrays, special ops, complex atomics) _PCODE_PATTERNS = ('LDS[', 'LDS(', 'VGPR[', 'SGPR[', 'GPR[', 'GS_REGS', 'thread_in[', 'thread_out[', 'thread_valid[', - ' DATA.', '=DATA.', ' DATA[', '=DATA[', 'DATA2', 'OFFSET0', 'OFFSET1', 'OFFSET.') + 'DATA2', 'OFFSET0', 'OFFSET1') # DS ops with dual offsets or DATA2 # Wide outputs (>64 bit) that need pcode's arbitrary-precision integers _WIDE_OUTPUT_PATTERNS = ('SDATA[95', 'SDATA[127', 'SDATA[159', 'SDATA[191', 'SDATA[223', 'SDATA[255', # SMEM B128, B256, B512 'VDATA[95', 'VDATA[127') # FLAT B128 @@ -548,5 +583,8 @@ def compile_uop(op_name: str, pseudocode: str): if op_name in _SKIP_OPS: return None if any(p in pseudocode for p in _PCODE_PATTERNS): return None # these patterns still need pcode if any(p in pseudocode for p in _WIDE_OUTPUT_PATTERNS): return None # >64-bit outputs need pcode - sink, output_info, input_vars, mem_stores = _compile_pseudocode(pseudocode) + # DS ops use LDS (local memory), others use global MEM + is_ds = op_name.startswith('DS_') + mem_buf = LDS_BUF if is_ds else MEM_BUF + sink, output_info, input_vars, mem_stores = _compile_pseudocode(pseudocode, mem_buf) return _make_fn(sink, output_info, input_vars, mem_stores)