This commit is contained in:
George Hotz
2026-01-04 11:57:42 -08:00
parent db9140b8b7
commit 2be5f8b688
2 changed files with 56 additions and 16 deletions

View File

@@ -97,7 +97,9 @@ def expr(s: str) -> Expr:
a = _split(s[m.end():e]); return Call(m[1], tuple(expr(x) for x in a) if a != [''] else ())
if s[:4] == 'MEM[' and (e := _match(s, 3, '[', ']')) != -1:
r, b = s[e+1:], Call('MEM', (expr(s[4:e]),))
return Typed(b, DTYPES[r[1:]]) if r[:1] == '.' and r[1:] in DTYPES else b
if not r: return b # Just MEM[addr]
if r[:1] == '.' and r[1:] in DTYPES: return Typed(b, DTYPES[r[1:]]) # MEM[addr].type
# Otherwise fall through to let binary operators parse (e.g., MEM[ADDR].b32.u32 + X)
if (q := _fop(s, ('?',))) > 0:
d = b = 0
for i in range(q+1, len(s)):

View File

@@ -42,17 +42,26 @@ INPUT_VARS = {
'VDATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('VDATA', 0, 0xffffffffffffffff)),
'VDST': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('VDST', 0, 0xffffffffffffffff)),
'RETURN_DATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('RETURN_DATA', 0, 0xffffffffffffffff)),
# DS (LDS) op variables - DATA/DATA2 are the source data registers, OFFSET is address offset
'DATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('DATA', 0, 0xffffffffffffffff)),
'DATA2': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('DATA2', 0, 0xffffffffffffffff)),
'OFFSET': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET', 0, 0xffff)),
'OFFSET0': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET0', 0, 0xff)),
'OFFSET1': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET1', 0, 0xff)),
}
# Global memory buffer for MEM[] accesses - DEFINE_GLOBAL with byte pointer
MEM_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint8.ptr(addrspace=AddrSpace.GLOBAL), arg=0)
# LDS (local) memory buffer for DS ops - DEFINE_LOCAL with byte pointer
LDS_BUF = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(addrspace=AddrSpace.LOCAL), arg=0)
class Ctx:
"""Compilation context - tracks variables and outputs."""
def __init__(self):
def __init__(self, mem_buf: UOp = MEM_BUF):
self.vars: dict[str, UOp] = dict(INPUT_VARS)
self.outputs: list[tuple[str, UOp, DType]] = []
self.mem_stores: list[UOp] = [] # STORE UOps for MEM
self.mem_stores: list[UOp] = [] # STORE UOps for MEM/LDS
self.mem_buf = mem_buf # MEM_BUF for global, LDS_BUF for local
def _expr(node, ctx: Ctx, hint: DType = None) -> UOp:
"""Transform qcode AST expression to UOp."""
@@ -85,10 +94,11 @@ def _expr(node, ctx: Ctx, hint: DType = None) -> UOp:
case Typed(expr, qdt):
dt = _qdt(qdt)
# Handle MEM[addr].type -> memory load using DEFINE_GLOBAL + INDEX + LOAD
# Handle MEM[addr].type -> memory load using INDEX + LOAD (buffer set by caller)
if isinstance(expr, Call) and expr.name == 'MEM':
addr_uop = _expr(expr.args[0], ctx, dtypes.uint64)
idx = UOp(Ops.INDEX, dt.ptr(0, AddrSpace.GLOBAL), (MEM_BUF, addr_uop))
buf = ctx.mem_buf
idx = UOp(Ops.INDEX, dt.ptr(0, buf.dtype.addrspace), (buf, addr_uop))
return UOp(Ops.LOAD, dt, (idx,))
if isinstance(expr, Var):
if expr.name in ('VCCZ', 'EXECZ'):
@@ -323,12 +333,13 @@ def _stmt(stmt, ctx: Ctx):
match stmt:
case Declare(_, _): pass
case Assign(lhs, rhs):
# Handle MEM[addr].type = value -> memory store using DEFINE_GLOBAL + INDEX + STORE
# Handle MEM[addr].type = value -> memory store using INDEX + STORE (buffer set by caller)
if isinstance(lhs, Typed) and isinstance(lhs.expr, Call) and lhs.expr.name == 'MEM':
dt = _qdt(lhs.dtype)
addr_uop = _expr(lhs.expr.args[0], ctx, dtypes.uint64)
val_uop = _expr(rhs, ctx, dt)
idx = UOp(Ops.INDEX, dt.ptr(0, AddrSpace.GLOBAL), (MEM_BUF, addr_uop))
buf = ctx.mem_buf
idx = UOp(Ops.INDEX, dt.ptr(0, buf.dtype.addrspace), (buf, addr_uop))
ctx.mem_stores.append(UOp(Ops.STORE, dtypes.void, (idx, val_uop)))
return
@@ -423,9 +434,9 @@ def _float_to_bits(val: float, dtype: DType) -> int:
if dtype == dtypes.float64: return struct.unpack('<Q', struct.pack('<d', val))[0]
return int(val)
def _compile_pseudocode(pseudocode: str) -> tuple[UOp, list[tuple[str, DType]], dict[str, UOp], list[UOp]]:
def _compile_pseudocode(pseudocode: str, mem_buf: UOp = MEM_BUF) -> tuple[UOp, list[tuple[str, DType]], dict[str, UOp], list[UOp]]:
"""Compile pseudocode to UOp graph. Returns (sink, outputs, input_vars, mem_stores)."""
ctx = Ctx()
ctx = Ctx(mem_buf=mem_buf)
for stmt in parse(pseudocode): _stmt(stmt, ctx)
sink = UOp(Ops.SINK, dtypes.void, tuple(u for _, u, _ in ctx.outputs) or ())
return sink, [(n, d) for n, _, d in ctx.outputs], INPUT_VARS, ctx.mem_stores
@@ -439,9 +450,12 @@ _DTYPE_ACCESSOR = {
def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[str, UOp], mem_stores: list[UOp]):
"""Create runtime function using substitute+simplify, with memory ops on sink."""
is_mem = bool(mem_stores) or any(u.op == Ops.LOAD for u in sink.toposort())
# Add stores to sink so they get substituted/simplified too
# Add stores to sink first so they're included in detection
if mem_stores: sink = UOp(Ops.SINK, dtypes.void, sink.src + tuple(mem_stores))
# Detect memory type from UOps: check if any INDEX uses DEFINE_LOCAL
topo = sink.toposort()
is_lds = any(u.op == Ops.DEFINE_LOCAL for u in topo)
is_mem = bool(mem_stores) or any(u.op == Ops.LOAD for u in topo)
def _extract_results(s, MEM=None):
# Execute stores and extract output values from simplified sink
@@ -459,10 +473,31 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
result[name] = _float_to_bits(s.src[i].arg, dtype) if dtype in FLOATS else int(s.src[i].arg) & (0xffffffff if dtype.itemsize <= 4 else 0xffffffffffffffff)
return result
if is_mem:
if is_lds:
# DS (LDS) ops: fn(MEM, addr, data0, data1, offset0, offset1)
def fn(MEM, addr, data0=0, data1=0, offset0=0, offset1=0):
dvars = {input_vars['ADDR']: UOp.const(dtypes.uint64, addr), input_vars['DATA']: UOp.const(dtypes.uint64, data0),
input_vars['DATA2']: UOp.const(dtypes.uint64, data1), input_vars['OFFSET']: UOp.const(dtypes.uint32, offset0),
input_vars['OFFSET0']: UOp.const(dtypes.uint32, offset0), input_vars['OFFSET1']: UOp.const(dtypes.uint32, offset1),
input_vars['RETURN_DATA']: UOp.const(dtypes.uint64, 0)}
s1 = sink.substitute(dvars).simplify()
# Replace LOADs with actual values from LDS, then simplify again
loads = {}
for u in s1.toposort():
if u.op == Ops.LOAD:
idx_uop = u.src[0]
load_addr, dt = int(idx_uop.src[1].arg), idx_uop.dtype.base
acc = _DTYPE_ACCESSOR.get(dt, 'u32')
loads[u] = UOp.const(dt, getattr(MEM[load_addr], acc))
s2 = s1.substitute(loads).simplify() if loads else s1
return _extract_results(s2, MEM)
return fn
elif is_mem:
# SMEM/FLAT/GLOBAL ops: fn(MEM, addr, vdata, vdst)
def fn(MEM, addr, vdata=0, vdst=0):
dvars = {input_vars['ADDR']: UOp.const(dtypes.uint64, addr), input_vars['SDATA']: UOp.const(dtypes.uint64, 0),
input_vars['VDATA']: UOp.const(dtypes.uint64, vdata), input_vars['VDST']: UOp.const(dtypes.uint64, vdst),
input_vars['DATA']: UOp.const(dtypes.uint64, vdata), input_vars['DATA2']: UOp.const(dtypes.uint64, 0),
input_vars['RETURN_DATA']: UOp.const(dtypes.uint64, 0)}
s1 = sink.substitute(dvars).simplify()
# Replace LOADs with actual values from MEM, then simplify again
@@ -477,6 +512,7 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
return _extract_results(s2, MEM)
return fn
else:
# ALU ops: fn(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, ...)
def fn(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None):
simm16 = (literal if -32768 <= literal <= 32767 else (literal - 65536 if literal < 65536 else 0)) if literal is not None else 0
dvars = {
@@ -534,10 +570,9 @@ _SKIP_OPS = {
'V_CVT_OFF_F32_I4', # CVT_OFF_TABLE lookup
}
# Patterns that still need pcode (LDS, register arrays, special ops, atomics/DS with DATA/OFFSET)
# Note: 'DATA.' matches atomic ops (not SDATA/VDATA), 'OFFSET0' matches DS ops (not just OFFSET in addr calc)
# Patterns that still need pcode (register arrays, special ops, complex atomics)
_PCODE_PATTERNS = ('LDS[', 'LDS(', 'VGPR[', 'SGPR[', 'GPR[', 'GS_REGS', 'thread_in[', 'thread_out[', 'thread_valid[',
' DATA.', '=DATA.', ' DATA[', '=DATA[', 'DATA2', 'OFFSET0', 'OFFSET1', 'OFFSET.')
'DATA2', 'OFFSET0', 'OFFSET1') # DS ops with dual offsets or DATA2
# Wide outputs (>64 bit) that need pcode's arbitrary-precision integers
_WIDE_OUTPUT_PATTERNS = ('SDATA[95', 'SDATA[127', 'SDATA[159', 'SDATA[191', 'SDATA[223', 'SDATA[255', # SMEM B128, B256, B512
'VDATA[95', 'VDATA[127') # FLAT B128
@@ -548,5 +583,8 @@ def compile_uop(op_name: str, pseudocode: str):
if op_name in _SKIP_OPS: return None
if any(p in pseudocode for p in _PCODE_PATTERNS): return None # these patterns still need pcode
if any(p in pseudocode for p in _WIDE_OUTPUT_PATTERNS): return None # >64-bit outputs need pcode
sink, output_info, input_vars, mem_stores = _compile_pseudocode(pseudocode)
# DS ops use LDS (local memory), others use global MEM
is_ds = op_name.startswith('DS_')
mem_buf = LDS_BUF if is_ds else MEM_BUF
sink, output_info, input_vars, mem_stores = _compile_pseudocode(pseudocode, mem_buf)
return _make_fn(sink, output_info, input_vars, mem_stores)