load/store

This commit is contained in:
George Hotz
2026-01-11 12:46:38 +09:00
parent c30e6a7169
commit 519ac0bb42
2 changed files with 33 additions and 15 deletions

View File

@@ -5,6 +5,9 @@ from tinygrad.dtype import dtypes, DType
from extra.assembly.amd.pcode_parse import parse, If, For, Lambda, Break, Return
import math
# Placeholder buffer for MEM operations. The MEM-read pattern in pcode_pm emits INDEX(MEM_BUF, addr);
# the ucode pass recognizes this exact object (`buf is MEM_BUF`) and substitutes ctx.mem_buf for it.
MEM_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint8.ptr(0), arg=0)
# ═══════════════════════════════════════════════════════════════════════════════
# TYPE MAPPINGS
# ═══════════════════════════════════════════════════════════════════════════════
@@ -146,6 +149,12 @@ def _typed_cast(x, op):
_fpat = UPat.var('x', dtype=dtypes.floats)
pcode_pm = PatternMatcher([
# MEM read: BITCAST(CUSTOM('MEM', addr)) -> INDEX(buf, addr) with element type
(UPat(Ops.BITCAST, name='bc', src=(UPat(Ops.CUSTOM, arg='MEM', src=(UPat.var('addr'),)),)),
lambda bc, addr: UOp(Ops.INDEX, bc.dtype, (MEM_BUF, addr))),
# MEM write: ASSIGN(INDEX, val) -> STORE (INDEX created by MEM read pattern above)
(UPat(Ops.ASSIGN, src=(UPat(Ops.INDEX, name='idx'), UPat.var('val'))),
lambda idx, val: UOp(Ops.STORE, dtypes.void, (idx, val))),
# Float ops (preserve input type)
(UPat(Ops.CUSTOM, arg='trunc', src=(_fpat,)), lambda x: UOp(Ops.TRUNC, x.dtype, (x,))),
(UPat(Ops.CUSTOM, arg='sqrt', src=(_fpat,)), lambda x: UOp(Ops.SQRT, x.dtype, (x,))),
@@ -297,10 +306,14 @@ pcode_spec = PatternMatcher([
# CUSTOMI/CAT: must be typed (slice bounds or bit concat determine type)
(UPat(Ops.CUSTOMI, name="x"), lambda x: x.dtype != dtypes.void),
(UPat(Ops.CAT, name="x"), lambda x: x.dtype != dtypes.void),
# CUSTOM: MEM and passthrough ops (abs, cvtToQuietNAN) can be void (wrapped by BITCAST/CAST)
(UPat(Ops.CUSTOM, name="x"), lambda x: x.dtype != dtypes.void or x.arg in {'MEM', 'abs', 'cvtToQuietNAN'}),
# CUSTOM: passthrough ops (abs, cvtToQuietNAN) can be void (wrapped by BITCAST/CAST)
(UPat(Ops.CUSTOM, name="x"), lambda x: x.dtype != dtypes.void or x.arg in {'abs', 'cvtToQuietNAN'}),
# POW allows int exponent with float base
(UPat(Ops.POW, dtype=dtypes.floats, src=(UPat(dtype=dtypes.floats), UPat(dtype=dtypes.ints))), lambda: True),
# Memory ops: STORE must be void; INDEX (expected: element type) and DEFINE_GLOBAL (expected: ptr) are accepted with any dtype
(UPat(Ops.STORE, dtype=dtypes.void), lambda: True),
(UPat(Ops.INDEX), lambda: True),
(UPat(Ops.DEFINE_GLOBAL), lambda: True),
]) + shared_spec
# ═══════════════════════════════════════════════════════════════════════════════

View File

@@ -3,7 +3,7 @@ import functools, struct, math
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes, DType, AddrSpace
from extra.assembly.amd.pcode_parse import If, For
from extra.assembly.amd.pcode_transform import parse_transform
from extra.assembly.amd.pcode_transform import parse_transform, MEM_BUF as PCODE_MEM_BUF
SIGNED, FLOATS = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64), (dtypes.float16, dtypes.float32, dtypes.float64)
MASK32, MASK64 = 0xffffffff, 0xffffffffffffffff
@@ -84,11 +84,6 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
return _cast(ctx.vars[name], hint or ctx.vars[name].dtype)
case UOp(Ops.BITCAST, dt, (inner,)):
# Memory load: MEM[addr].type
if inner.op == Ops.CUSTOM and inner.arg == 'MEM':
addr = _expr(inner.src[0], ctx, dtypes.uint64)
idx = UOp(Ops.INDEX, dt.ptr(0, ctx.mem_buf.dtype.addrspace), (ctx.mem_buf, addr))
return UOp(Ops.LOAD, dt, (idx,))
# Typed variable access: Var.type
if inner.op == Ops.DEFINE_VAR and inner.arg[1] is None:
name = inner.arg[0]
@@ -185,6 +180,13 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
return UOp(Ops.OR, dtypes.uint64, (UOp(Ops.SHL, dtypes.uint64, (_cast(hi, dtypes.uint64), UOp.const(dtypes.uint64, 32))), _cast(lo, dtypes.uint64)))
return UOp(Ops.OR, dtypes.uint32, (UOp(Ops.SHL, dtypes.uint32, (_cast(hi, dtypes.uint32), UOp.const(dtypes.uint32, 16))),
UOp(Ops.AND, dtypes.uint32, (_cast(lo, dtypes.uint32), UOp.const(dtypes.uint32, 0xffff)))))
# Memory operations: INDEX from pcode_transform -> LOAD with actual buffer
# dt is element type (e.g. uint32), not ptr type
case UOp(Ops.INDEX, dt, (buf, addr)):
actual_buf = ctx.mem_buf if buf is PCODE_MEM_BUF else buf
idx = UOp(Ops.INDEX, dt.ptr(0, ctx.mem_buf.dtype.addrspace), (actual_buf, _expr(addr, ctx, dtypes.uint64)))
return UOp(Ops.LOAD, dt, (idx,))
raise ValueError(f"Cannot transform expression: {node}")
# ═══════════════════════════════════════════════════════════════════════════════
@@ -307,15 +309,18 @@ def _stmt(stmt, ctx: Ctx):
else:
ctx.vars[name] = UOp.const(dtype, 0)
# Memory store (from pcode_transform): STORE(INDEX(buf, addr), val)
case UOp(Ops.STORE, _, (idx, val)) if idx.op == Ops.INDEX:
buf, addr = idx.src
actual_buf = ctx.mem_buf if buf is PCODE_MEM_BUF else buf
dt = idx.dtype # element type
idx_expr = UOp(Ops.INDEX, dt.ptr(0, ctx.mem_buf.dtype.addrspace), (actual_buf, _expr(addr, ctx, dtypes.uint64)))
val_expr = _expr(val, ctx, dt)
ctx.mem_stores.append(UOp(Ops.STORE, dtypes.void, (idx_expr, val_expr)))
return
# Assignment: ASSIGN(lhs, rhs)
case UOp(Ops.ASSIGN, _, (lhs, rhs)):
# Memory store
if lhs.op == Ops.BITCAST and lhs.src[0].op == Ops.CUSTOM and lhs.src[0].arg == 'MEM':
addr, val = _expr(lhs.src[0].src[0], ctx, dtypes.uint64), _expr(rhs, ctx, lhs.dtype)
idx = UOp(Ops.INDEX, lhs.dtype.ptr(0, ctx.mem_buf.dtype.addrspace), (ctx.mem_buf, addr))
ctx.mem_stores.append(UOp(Ops.STORE, dtypes.void, (idx, val)))
return
# CAT assignment: {D1.u1, D0.u64} = ...
if lhs.op == Ops.CAT:
rhs_uop, offset = _expr(rhs, ctx), 0