From 640dac46c2c9f033bef0fcb6daefba81891196fb Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 5 Jan 2026 21:15:29 -0800 Subject: [PATCH] pcode_exec --- extra/assembly/amd/emu.py | 5 +- extra/assembly/amd/pcode_exec.py | 239 +++++++++++++++++++++++++++++++ 2 files changed, 242 insertions(+), 2 deletions(-) create mode 100644 extra/assembly/amd/pcode_exec.py diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index 380a21f938..c10318a837 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -6,6 +6,7 @@ from tinygrad.runtime.autogen import hsa from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64 from extra.assembly.amd.asm import detect_format from extra.assembly.amd.ucode import compile_uop +from extra.assembly.amd.pcode_exec import compile_exec from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS from extra.assembly.amd.dsl import SrcEnum from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD, @@ -381,9 +382,9 @@ def decode_program(data: bytes) -> dict[int, Inst]: else: inst._dispatch = dispatch_lane(exec_vop) # Compile pcode for instructions that use it (not VOPD which has _fnx/_fny, not special dispatches) - # Try ucode first (UOp-based), fall back to pcode (Python exec-based) + # Try pcode_exec first (simple UOp->Python), fall back to ucode (UOp graph evaluation) def _compile_op(cls_name, op_name, pcode): - return compile_uop(op_name, pcode) #or compile_pseudocode(cls_name, op_name, pcode) + return compile_exec(op_name, pcode) or compile_uop(op_name, pcode) # VOPD needs separate functions for X and Y ops if isinstance(inst, VOPD): def _compile_vopd_op(op): return _compile_op(type(op).__name__, op.name, PSEUDOCODE_STRINGS[type(op)][op]) diff --git a/extra/assembly/amd/pcode_exec.py b/extra/assembly/amd/pcode_exec.py new file mode 100644 index 0000000000..9c8857267d --- /dev/null +++ b/extra/assembly/amd/pcode_exec.py @@ -0,0 +1,239 @@ +# Simple pseudocode executor - transforms UOps to Python code strings, then exec() +import functools, math +from tinygrad.uop.ops import UOp, Ops +from extra.assembly.amd.pcode_parse import parse, Assign, Declare, If, For, Break, Return +from extra.assembly.amd.ucode import _apply_pseudocode_fixes +from extra.assembly.amd.pcode import ( + MASK32, MASK64, _f32, _i32, _f64, _i64, _f16, _i16, _bf16, _ibf16, _sext, _brev, _div, + fma, ldexp, sqrt, log2, fract, sin, cos, trunc, floor, ceil, exponent, sign, mantissa, signext_from_bit, + isNAN, isQuietNAN, isSignalNAN, cvtToQuietNAN, v_min_f32, v_max_f32, v_min_f16, v_max_f16, v_min_i32, v_max_i32, + f32_to_i32, f32_to_u32, f64_to_i32, f64_to_u32, i32_to_f32, u32_to_f32, f32_to_f16, f16_to_f32, + f32_to_f64, f64_to_f32, i32_to_f64, u32_to_f64, bf16_to_f32, f32_to_bf16, f16_to_i16, f16_to_u16, + ABSDIFF, BYTE_PERMUTE, v_sad_u8, v_msad_u8, s_ff1_i32_b32, s_ff1_i32_b64, + PI, DENORM, INF, OVERFLOW_F32, OVERFLOW_F64, UNDERFLOW_F32, UNDERFLOW_F64, TWO_OVER_PI_1201, +) + +_PARAM = {'OPSEL': 'opsel', 'OPSEL_HI': 'opsel_hi', 'PC': 'pc'} # pseudocode name -> param name +_INPUTS = {'S0', 'S1', 'S2', 'D0', 'SCC', 'VCC', 'EXEC', 'laneId', 'SIMM16', 'SIMM32', 'PC', 'OPSEL', 'OPSEL_HI', 'SRC0', 'VDST'} +_OUTPUTS = {'D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC'} + +from tinygrad.dtype import dtypes, DType + +def _dtype_info(dt: DType) -> tuple[str|None, int, bool, bool]: + """Returns (to_float_fn, bits, is_float, is_signed) for dtype.""" + if dt == dtypes.double: return ('_f64', 64, True, False) + if dt == dtypes.float: return ('_f32', 32, True, False) + if dt == dtypes.half: return ('_f16', 16, True, False) + if dt == dtypes.bfloat16: return ('_bf16', 16, True, False) + # Integer types - use itemsize and check signed via fmt + bits = dt.itemsize * 8 if hasattr(dt, 'itemsize') else 32 + signed = dt.fmt is not None and dt.fmt.islower() if hasattr(dt, 'fmt') else False + return (None, bits, False, signed) + +# Binary ops: op -> format string +_BINOPS = {Ops.ADD: '({a}+{b})', Ops.SUB: '({a}-{b})', Ops.MUL: '({a}*{b})', Ops.FDIV: '_div({a},{b})', Ops.MOD: '({a}%{b})', + Ops.AND: '({a}&{b})', Ops.OR: '({a}|{b})', Ops.XOR: '({a}^{b})', Ops.SHL: '({a}<>int({b}))', + Ops.POW: '({a}**{b})', Ops.CMPLT: '(1 if {a}<{b} else 0)', Ops.CMPLE: '(1 if {a}<={b} else 0)', + Ops.CMPEQ: '(1 if {a}=={b} else 0)', Ops.CMPNE: '(1 if {a}!={b} else 0)'} + +def _gen_expr_raw(u: UOp, ctx: set) -> str: + """Generate Python expression from UOp, yielding raw bits for float types (no float conversion).""" + if u.op == Ops.BITCAST: + _, bits, _, _ = _dtype_info(u.dtype) + src = _gen_expr(u.src[0], ctx) + return f'(int({src})&{hex((1< str: + """Generate Python expression from UOp.""" + if u.op == Ops.CONST: + if isinstance(u.arg, float): + if math.isnan(u.arg): return "float('nan')" + if math.isinf(u.arg): return f"float('{'-' if u.arg < 0 else ''}inf')" + return repr(u.arg) + if u.op == Ops.DEFINE_VAR: + name = u.arg[0] + return _PARAM.get(name, name) + if u.op == Ops.BITCAST: + to_fn, bits, is_float, is_signed = _dtype_info(u.dtype) + src = _gen_expr(u.src[0], ctx) + if to_fn: return f'{to_fn}({src})' + if is_signed: return f'_sext(int({src})&{hex((1<>{lo})&{hex((1<>{lo})&((1<<({hi}-{lo}+1))-1))' + if u.op == Ops.CUSTOM: # function call + fn = 'int' if u.arg == 'signext' else u.arg # signext -> int + # Functions expecting bits need raw values, not float-converted values + args = [_gen_expr_raw(a, ctx) if u.arg in ('f16_to_i16', 'f16_to_u16') else _gen_expr(a, ctx) for a in u.src] + return f'{fn}({",".join(args)})' + if u.op == Ops.CAT and len(u.src) == 2: # {hi, lo} concatenation + return f'(({_gen_expr(u.src[0], ctx)}<<32)|({_gen_expr(u.src[1], ctx)}&0xffffffff))' + return f'0 # unhandled {u.op}' + +# Float-to-int conversion functions for LHS assignments (use _toi* helpers that handle int passthrough) +_FLOAT_TO_INT = {dtypes.double: '_toi64', dtypes.float: '_toi32', dtypes.half: '_toi16', dtypes.bfloat16: '_toibf16'} + +def _extract_lhs(lhs: UOp) -> tuple[str, int, str|None, UOp|None, UOp|None]: + """Extract (var_name, bits, float_conv, hi, lo) from LHS. float_conv is None for int types.""" + match lhs: + case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, _, _)), hi, lo)),)): # tmp[31:16].f16 + _, bits, _, _ = _dtype_info(dt) + return (name, bits, _FLOAT_TO_INT.get(dt), hi, lo) + case UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, _, _)),)): # D0.f32 + _, bits, _, _ = _dtype_info(dt) + return (name, bits, _FLOAT_TO_INT.get(dt), None, None) + case UOp(Ops.DEFINE_VAR, _, _, (name, _, _)): # tmp + return (name, 64, None, None, None) + case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, _, _)),)), hi, lo)): # D0.u32[31:16] + _, bits, _, _ = _dtype_info(dt) + return (name, bits, _FLOAT_TO_INT.get(dt), hi, lo) + case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, _, _)), hi, lo)): # tmp[31:16] + return (name, 32, None, hi, lo) + return ('_unknown', 32, None, None, None) + +def _extract_cat_lhs(lhs: UOp) -> list[tuple[str, int]]: + """Extract list of (var_name, bits) from CAT LHS like { D1.u1, D0.u64 }.""" + if lhs.op != Ops.CAT: return None + result = [] + for src in lhs.src: + match src: + case UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, _, _)),)): + _, bits, _, _ = _dtype_info(dt) + result.append((name, bits)) + case _: return None + return result + +def _gen_stmt(stmt, ctx: set, indent: int = 0) -> list[str]: + """Generate Python statements.""" + p = ' ' * indent + match stmt: + case Assign(lhs, rhs) if lhs.op == Ops.CAT: + # Handle concatenation assignment like { D1.u1, D0.u64 } = expr + parts = _extract_cat_lhs(lhs) + if parts is None: return [f'{p}# unhandled CAT assignment'] + rhs_expr = _gen_expr(rhs, ctx) + lines = [f'{p}_cat_tmp = int({rhs_expr})'] + offset = 0 + for name, bits in reversed(parts): # reversed because low bits come last in { hi, lo } + ctx.add(name) + mask = hex((1 << bits) - 1) + lines.append(f'{p}{name} = (_cat_tmp >> {offset}) & {mask}') + offset += bits + return lines + case Assign(lhs, rhs): + name, bits, fconv, hi, lo = _extract_lhs(lhs) + # Initialize variable if first use + init = [f'{p}{name} = 0'] if name not in ctx and name not in _INPUTS else [] + ctx.add(name) + rhs_expr = _gen_expr(rhs, ctx) + mask = hex((1 << bits) - 1) + # Convert to int: for float types use _toi32/etc (handles NaN), for int types use int() + to_int = f'{fconv}({rhs_expr})' if fconv else f'int({rhs_expr})' + # Handle slice assignment + if hi is not None: + h, l = _gen_expr(hi, ctx), _gen_expr(lo, ctx) + if hi is lo: # single bit + return init + [f'{p}{name} = ({name} & ~(1 << {h})) | (({to_int} & 1) << {h})'] + return init + [f'{p}m = ((1 << ({h} - {l} + 1)) - 1) << {l}', + f'{p}{name} = ({name} & ~m) | (({to_int} << {l}) & m)'] + # Regular assignment - convert and mask to bit width + if bits < 64: return init + [f'{p}{name} = {to_int} & {mask}'] + if fconv: return init + [f'{p}{name} = {to_int}'] # 64-bit float still needs conversion + return init + [f'{p}{name} = {rhs_expr}'] + case Declare(name, _): + ctx.add(name) + return [f'{p}{name} = 0'] + case If(branches): + lines = [] + for i, (cond, body) in enumerate(branches): + kw = 'if' if i == 0 else ('elif' if cond else 'else') + lines.append(f'{p}{kw} {_gen_expr(cond, ctx)}:' if cond else f'{p}else:') + body_lines = [l for s in body for l in _gen_stmt(s, ctx, indent + 1)] + lines.extend(body_lines or [f'{p} pass']) + return lines + case For(var, start, end, body): + ctx.add(var) + lines = [f'{p}for {var} in range(int({_gen_expr(start, ctx)}), int({_gen_expr(end, ctx)}) + 1):'] + body_lines = [l for s in body for l in _gen_stmt(s, ctx, indent + 1)] + lines.extend(body_lines or [f'{p} pass']) + return lines + case Break(): return [f'{p}break'] + case Return(value): return [f'{p}return {_gen_expr(value, ctx)}'] + return [] + +def _gen_function(op_name: str, stmts: tuple, pcode: str) -> str: + """Generate complete function.""" + ctx = set() # tracks variables assigned in the body + sig = f'def _{op_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, SIMM32, VGPR, SRC0=0, VDST=0, pc=0, opsel=0, opsel_hi=0):' + # Generate body - add SIMM16 alias for SIMM32 + body = [' SIMM16 = SIMM32'] + [l for s in stmts for l in _gen_stmt(s, ctx, 1)] + # Generate return - use _PARAM mapping for variable names + ret_items = [f"'{out}': {_PARAM.get(out, out)}" for out in _OUTPUTS if out in ctx] + ret = f" return {{{', '.join(ret_items)}}}" + return '\n'.join([sig] + body + [ret]) + +# Sign extension helpers +_sext8 = lambda v: _sext(int(v) & 0xff, 8) +_sext16 = lambda v: _sext(int(v) & 0xffff, 16) +_sext32 = lambda v: _sext(int(v) & 0xffffffff, 32) +_sext64 = lambda v: _sext(int(v) & 0xffffffffffffffff, 64) + +# Float-to-int with passthrough for already-int values (like f32_to_f16 returns) +_toi64 = lambda v: v if isinstance(v, int) else _i64(float(v)) +_toi32 = lambda v: v if isinstance(v, int) else _i32(float(v)) +_toi16 = lambda v: v if isinstance(v, int) else _i16(float(v)) +_toibf16 = lambda v: v if isinstance(v, int) else _ibf16(float(v)) + +_GLOBALS = { + 'MASK32': MASK32, 'MASK64': MASK64, '_f32': _f32, '_i32': _i32, '_f64': _f64, '_i64': _i64, + '_f16': _f16, '_i16': _i16, '_bf16': _bf16, '_ibf16': _ibf16, '_sext': _sext, '_brev': _brev, '_div': _div, + '_sext8': _sext8, '_sext16': _sext16, '_sext32': _sext32, '_sext64': _sext64, + '_toi64': _toi64, '_toi32': _toi32, '_toi16': _toi16, '_toibf16': _toibf16, + 'fma': fma, 'ldexp': ldexp, 'sqrt': sqrt, 'log2': log2, 'fract': fract, 'sin': sin, 'cos': cos, + 'trunc': trunc, 'floor': floor, 'ceil': ceil, 'exponent': exponent, 'sign': sign, 'mantissa': mantissa, + 'signext_from_bit': signext_from_bit, 'isNAN': isNAN, 'isQuietNAN': isQuietNAN, 'isSignalNAN': isSignalNAN, + 'cvtToQuietNAN': cvtToQuietNAN, 'v_min_f32': v_min_f32, 'v_max_f32': v_max_f32, 'v_min_f16': v_min_f16, 'v_max_f16': v_max_f16, + 'v_min_i32': v_min_i32, 'v_max_i32': v_max_i32, + 'f32_to_i32': f32_to_i32, 'f32_to_u32': f32_to_u32, 'f64_to_i32': f64_to_i32, 'f64_to_u32': f64_to_u32, + 'i32_to_f32': i32_to_f32, 'u32_to_f32': u32_to_f32, 'f32_to_f16': f32_to_f16, 'f16_to_f32': f16_to_f32, + 'f32_to_f64': f32_to_f64, 'f64_to_f32': f64_to_f32, 'i32_to_f64': i32_to_f64, 'u32_to_f64': u32_to_f64, + 'bf16_to_f32': bf16_to_f32, 'f32_to_bf16': f32_to_bf16, 'f16_to_i16': f16_to_i16, 'f16_to_u16': f16_to_u16, + 'ABSDIFF': ABSDIFF, 'BYTE_PERMUTE': BYTE_PERMUTE, + 'v_sad_u8': v_sad_u8, 'v_msad_u8': v_msad_u8, 's_ff1_i32_b32': s_ff1_i32_b32, 's_ff1_i32_b64': s_ff1_i32_b64, + 'abs': abs, 'min': min, 'max': max, 'int': int, 'float': float, + 'PI': PI, 'DENORM': DENORM, 'INF': INF, 'OVERFLOW_F32': OVERFLOW_F32, 'OVERFLOW_F64': OVERFLOW_F64, + 'UNDERFLOW_F32': UNDERFLOW_F32, 'UNDERFLOW_F64': UNDERFLOW_F64, 'TWO_OVER_PI_1201': TWO_OVER_PI_1201, +} + +@functools.cache +def compile_exec(op_name: str, pseudocode: str): + """Compile pseudocode to executable function. Returns None if can't handle.""" + if 'MEM[' in pseudocode or 'LDS[' in pseudocode or 'VGPR[' in pseudocode: return None # skip memory ops + try: + pcode = _apply_pseudocode_fixes(op_name, pseudocode) + stmts = parse(pcode) + fn_code = _gen_function(op_name, stmts, pcode) + local_ns = {} + exec(fn_code, _GLOBALS, local_ns) + return local_ns[f'_{op_name}'] + except Exception: + return None