mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
work
This commit is contained in:
@@ -97,7 +97,9 @@ def expr(s: str) -> Expr:
|
||||
a = _split(s[m.end():e]); return Call(m[1], tuple(expr(x) for x in a) if a != [''] else ())
|
||||
if s[:4] == 'MEM[' and (e := _match(s, 3, '[', ']')) != -1:
|
||||
r, b = s[e+1:], Call('MEM', (expr(s[4:e]),))
|
||||
return Typed(b, DTYPES[r[1:]]) if r[:1] == '.' and r[1:] in DTYPES else b
|
||||
if not r: return b # Just MEM[addr]
|
||||
if r[:1] == '.' and r[1:] in DTYPES: return Typed(b, DTYPES[r[1:]]) # MEM[addr].type
|
||||
# Otherwise fall through to let binary operators parse (e.g., MEM[ADDR].b32.u32 + X)
|
||||
if (q := _fop(s, ('?',))) > 0:
|
||||
d = b = 0
|
||||
for i in range(q+1, len(s)):
|
||||
|
||||
@@ -42,17 +42,26 @@ INPUT_VARS = {
|
||||
'VDATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('VDATA', 0, 0xffffffffffffffff)),
|
||||
'VDST': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('VDST', 0, 0xffffffffffffffff)),
|
||||
'RETURN_DATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('RETURN_DATA', 0, 0xffffffffffffffff)),
|
||||
# DS (LDS) op variables - DATA/DATA2 are the source data registers, OFFSET is address offset
|
||||
'DATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('DATA', 0, 0xffffffffffffffff)),
|
||||
'DATA2': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('DATA2', 0, 0xffffffffffffffff)),
|
||||
'OFFSET': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET', 0, 0xffff)),
|
||||
'OFFSET0': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET0', 0, 0xff)),
|
||||
'OFFSET1': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET1', 0, 0xff)),
|
||||
}
|
||||
|
||||
# Global memory buffer for MEM[] accesses - DEFINE_GLOBAL with byte pointer
|
||||
MEM_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint8.ptr(addrspace=AddrSpace.GLOBAL), arg=0)
|
||||
# LDS (local) memory buffer for DS ops - DEFINE_LOCAL with byte pointer
|
||||
LDS_BUF = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(addrspace=AddrSpace.LOCAL), arg=0)
|
||||
|
||||
class Ctx:
|
||||
"""Compilation context - tracks variables and outputs."""
|
||||
def __init__(self):
|
||||
def __init__(self, mem_buf: UOp = MEM_BUF):
|
||||
self.vars: dict[str, UOp] = dict(INPUT_VARS)
|
||||
self.outputs: list[tuple[str, UOp, DType]] = []
|
||||
self.mem_stores: list[UOp] = [] # STORE UOps for MEM
|
||||
self.mem_stores: list[UOp] = [] # STORE UOps for MEM/LDS
|
||||
self.mem_buf = mem_buf # MEM_BUF for global, LDS_BUF for local
|
||||
|
||||
def _expr(node, ctx: Ctx, hint: DType = None) -> UOp:
|
||||
"""Transform qcode AST expression to UOp."""
|
||||
@@ -85,10 +94,11 @@ def _expr(node, ctx: Ctx, hint: DType = None) -> UOp:
|
||||
|
||||
case Typed(expr, qdt):
|
||||
dt = _qdt(qdt)
|
||||
# Handle MEM[addr].type -> memory load using DEFINE_GLOBAL + INDEX + LOAD
|
||||
# Handle MEM[addr].type -> memory load using INDEX + LOAD (buffer set by caller)
|
||||
if isinstance(expr, Call) and expr.name == 'MEM':
|
||||
addr_uop = _expr(expr.args[0], ctx, dtypes.uint64)
|
||||
idx = UOp(Ops.INDEX, dt.ptr(0, AddrSpace.GLOBAL), (MEM_BUF, addr_uop))
|
||||
buf = ctx.mem_buf
|
||||
idx = UOp(Ops.INDEX, dt.ptr(0, buf.dtype.addrspace), (buf, addr_uop))
|
||||
return UOp(Ops.LOAD, dt, (idx,))
|
||||
if isinstance(expr, Var):
|
||||
if expr.name in ('VCCZ', 'EXECZ'):
|
||||
@@ -323,12 +333,13 @@ def _stmt(stmt, ctx: Ctx):
|
||||
match stmt:
|
||||
case Declare(_, _): pass
|
||||
case Assign(lhs, rhs):
|
||||
# Handle MEM[addr].type = value -> memory store using DEFINE_GLOBAL + INDEX + STORE
|
||||
# Handle MEM[addr].type = value -> memory store using INDEX + STORE (buffer set by caller)
|
||||
if isinstance(lhs, Typed) and isinstance(lhs.expr, Call) and lhs.expr.name == 'MEM':
|
||||
dt = _qdt(lhs.dtype)
|
||||
addr_uop = _expr(lhs.expr.args[0], ctx, dtypes.uint64)
|
||||
val_uop = _expr(rhs, ctx, dt)
|
||||
idx = UOp(Ops.INDEX, dt.ptr(0, AddrSpace.GLOBAL), (MEM_BUF, addr_uop))
|
||||
buf = ctx.mem_buf
|
||||
idx = UOp(Ops.INDEX, dt.ptr(0, buf.dtype.addrspace), (buf, addr_uop))
|
||||
ctx.mem_stores.append(UOp(Ops.STORE, dtypes.void, (idx, val_uop)))
|
||||
return
|
||||
|
||||
@@ -423,9 +434,9 @@ def _float_to_bits(val: float, dtype: DType) -> int:
|
||||
if dtype == dtypes.float64: return struct.unpack('<Q', struct.pack('<d', val))[0]
|
||||
return int(val)
|
||||
|
||||
def _compile_pseudocode(pseudocode: str) -> tuple[UOp, list[tuple[str, DType]], dict[str, UOp], list[UOp]]:
|
||||
def _compile_pseudocode(pseudocode: str, mem_buf: UOp = MEM_BUF) -> tuple[UOp, list[tuple[str, DType]], dict[str, UOp], list[UOp]]:
|
||||
"""Compile pseudocode to UOp graph. Returns (sink, outputs, input_vars, mem_stores)."""
|
||||
ctx = Ctx()
|
||||
ctx = Ctx(mem_buf=mem_buf)
|
||||
for stmt in parse(pseudocode): _stmt(stmt, ctx)
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(u for _, u, _ in ctx.outputs) or ())
|
||||
return sink, [(n, d) for n, _, d in ctx.outputs], INPUT_VARS, ctx.mem_stores
|
||||
@@ -439,9 +450,12 @@ _DTYPE_ACCESSOR = {
|
||||
|
||||
def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[str, UOp], mem_stores: list[UOp]):
|
||||
"""Create runtime function using substitute+simplify, with memory ops on sink."""
|
||||
is_mem = bool(mem_stores) or any(u.op == Ops.LOAD for u in sink.toposort())
|
||||
# Add stores to sink so they get substituted/simplified too
|
||||
# Add stores to sink first so they're included in detection
|
||||
if mem_stores: sink = UOp(Ops.SINK, dtypes.void, sink.src + tuple(mem_stores))
|
||||
# Detect memory type from UOps: check if any INDEX uses DEFINE_LOCAL
|
||||
topo = sink.toposort()
|
||||
is_lds = any(u.op == Ops.DEFINE_LOCAL for u in topo)
|
||||
is_mem = bool(mem_stores) or any(u.op == Ops.LOAD for u in topo)
|
||||
|
||||
def _extract_results(s, MEM=None):
|
||||
# Execute stores and extract output values from simplified sink
|
||||
@@ -459,10 +473,31 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
|
||||
result[name] = _float_to_bits(s.src[i].arg, dtype) if dtype in FLOATS else int(s.src[i].arg) & (0xffffffff if dtype.itemsize <= 4 else 0xffffffffffffffff)
|
||||
return result
|
||||
|
||||
if is_mem:
|
||||
if is_lds:
|
||||
# DS (LDS) ops: fn(MEM, addr, data0, data1, offset0, offset1)
|
||||
def fn(MEM, addr, data0=0, data1=0, offset0=0, offset1=0):
|
||||
dvars = {input_vars['ADDR']: UOp.const(dtypes.uint64, addr), input_vars['DATA']: UOp.const(dtypes.uint64, data0),
|
||||
input_vars['DATA2']: UOp.const(dtypes.uint64, data1), input_vars['OFFSET']: UOp.const(dtypes.uint32, offset0),
|
||||
input_vars['OFFSET0']: UOp.const(dtypes.uint32, offset0), input_vars['OFFSET1']: UOp.const(dtypes.uint32, offset1),
|
||||
input_vars['RETURN_DATA']: UOp.const(dtypes.uint64, 0)}
|
||||
s1 = sink.substitute(dvars).simplify()
|
||||
# Replace LOADs with actual values from LDS, then simplify again
|
||||
loads = {}
|
||||
for u in s1.toposort():
|
||||
if u.op == Ops.LOAD:
|
||||
idx_uop = u.src[0]
|
||||
load_addr, dt = int(idx_uop.src[1].arg), idx_uop.dtype.base
|
||||
acc = _DTYPE_ACCESSOR.get(dt, 'u32')
|
||||
loads[u] = UOp.const(dt, getattr(MEM[load_addr], acc))
|
||||
s2 = s1.substitute(loads).simplify() if loads else s1
|
||||
return _extract_results(s2, MEM)
|
||||
return fn
|
||||
elif is_mem:
|
||||
# SMEM/FLAT/GLOBAL ops: fn(MEM, addr, vdata, vdst)
|
||||
def fn(MEM, addr, vdata=0, vdst=0):
|
||||
dvars = {input_vars['ADDR']: UOp.const(dtypes.uint64, addr), input_vars['SDATA']: UOp.const(dtypes.uint64, 0),
|
||||
input_vars['VDATA']: UOp.const(dtypes.uint64, vdata), input_vars['VDST']: UOp.const(dtypes.uint64, vdst),
|
||||
input_vars['DATA']: UOp.const(dtypes.uint64, vdata), input_vars['DATA2']: UOp.const(dtypes.uint64, 0),
|
||||
input_vars['RETURN_DATA']: UOp.const(dtypes.uint64, 0)}
|
||||
s1 = sink.substitute(dvars).simplify()
|
||||
# Replace LOADs with actual values from MEM, then simplify again
|
||||
@@ -477,6 +512,7 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
|
||||
return _extract_results(s2, MEM)
|
||||
return fn
|
||||
else:
|
||||
# ALU ops: fn(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, ...)
|
||||
def fn(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None):
|
||||
simm16 = (literal if -32768 <= literal <= 32767 else (literal - 65536 if literal < 65536 else 0)) if literal is not None else 0
|
||||
dvars = {
|
||||
@@ -534,10 +570,9 @@ _SKIP_OPS = {
|
||||
'V_CVT_OFF_F32_I4', # CVT_OFF_TABLE lookup
|
||||
}
|
||||
|
||||
# Patterns that still need pcode (LDS, register arrays, special ops, atomics/DS with DATA/OFFSET)
|
||||
# Note: 'DATA.' matches atomic ops (not SDATA/VDATA), 'OFFSET0' matches DS ops (not just OFFSET in addr calc)
|
||||
# Patterns that still need pcode (register arrays, special ops, complex atomics)
|
||||
_PCODE_PATTERNS = ('LDS[', 'LDS(', 'VGPR[', 'SGPR[', 'GPR[', 'GS_REGS', 'thread_in[', 'thread_out[', 'thread_valid[',
|
||||
' DATA.', '=DATA.', ' DATA[', '=DATA[', 'DATA2', 'OFFSET0', 'OFFSET1', 'OFFSET.')
|
||||
'DATA2', 'OFFSET0', 'OFFSET1') # DS ops with dual offsets or DATA2
|
||||
# Wide outputs (>64 bit) that need pcode's arbitrary-precision integers
|
||||
_WIDE_OUTPUT_PATTERNS = ('SDATA[95', 'SDATA[127', 'SDATA[159', 'SDATA[191', 'SDATA[223', 'SDATA[255', # SMEM B128, B256, B512
|
||||
'VDATA[95', 'VDATA[127') # FLAT B128
|
||||
@@ -548,5 +583,8 @@ def compile_uop(op_name: str, pseudocode: str):
|
||||
if op_name in _SKIP_OPS: return None
|
||||
if any(p in pseudocode for p in _PCODE_PATTERNS): return None # these patterns still need pcode
|
||||
if any(p in pseudocode for p in _WIDE_OUTPUT_PATTERNS): return None # >64-bit outputs need pcode
|
||||
sink, output_info, input_vars, mem_stores = _compile_pseudocode(pseudocode)
|
||||
# DS ops use LDS (local memory), others use global MEM
|
||||
is_ds = op_name.startswith('DS_')
|
||||
mem_buf = LDS_BUF if is_ds else MEM_BUF
|
||||
sink, output_info, input_vars, mem_stores = _compile_pseudocode(pseudocode, mem_buf)
|
||||
return _make_fn(sink, output_info, input_vars, mem_stores)
|
||||
|
||||
Reference in New Issue
Block a user