From 87e72f1540e8d17553eef2bc27ef1cd23b8cf554 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 4 Jan 2026 16:32:35 -0800 Subject: [PATCH] ftz --- extra/assembly/amd/ucode.py | 168 ++++++++++++++++++++++++++---------- tinygrad/uop/symbolic.py | 9 +- 2 files changed, 129 insertions(+), 48 deletions(-) diff --git a/extra/assembly/amd/ucode.py b/extra/assembly/amd/ucode.py index a5cd6ddc4e..76fd1f4d81 100644 --- a/extra/assembly/amd/ucode.py +++ b/extra/assembly/amd/ucode.py @@ -7,6 +7,13 @@ from extra.assembly.amd.qcode import parse, Assign, Declare, If, For SIGNED = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) FLOATS = (dtypes.float16, dtypes.float32, dtypes.float64) +# FTZ (Flush To Zero): RDNA3 default mode flushes f32 denormals to ±0 +def _ftz32(bits: int) -> float: + bits = bits & 0xffffffff + if (bits & 0x7f800000) == 0 and (bits & 0x007fffff) != 0: # denormal + return 0.0 + return struct.unpack(' UOp: return x if x.dtype == dtype else UOp(Ops.BITCAST if dtype.itemsize == x.dtype.itemsize else Ops.CAST, dtype, (x,)) @@ -28,6 +35,7 @@ INPUT_VARS = { 'SIMM32': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('SIMM32', 0, 0xffffffff)), 'PC': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('PC', 0, 0xffffffffffffffff)), 'ADDR': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('ADDR', 0, 0xffffffffffffffff)), + 'ADDR_BASE': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('ADDR', 0, 0xffffffffffffffff)), # Alias for ADDR 'SDATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('SDATA', 0, 0xffffffffffffffff)), 'VDATA': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('VDATA', 0, 0xffffffffffffffff)), 'VDST': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('VDST', 0, 0xffffffffffffffff)), @@ -37,6 +45,8 @@ INPUT_VARS = { 'OFFSET': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET', 0, 0xffff)), 'OFFSET0': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET0', 0, 0xff)), 'OFFSET1': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET1', 0, 0xff)), + 'OPSEL': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OPSEL', 0, 7)), + 'OPSEL_HI': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OPSEL_HI', 0, 7)), } MEM_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint8.ptr(addrspace=AddrSpace.GLOBAL), arg=0) @@ -121,7 +131,21 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp: if dt in FLOATS: return UOp(Ops.BITCAST, dt, (inner_resolved,)) return _cast(inner_resolved, dt) - case UOp(Ops.CUSTOMI, _, (base_expr, hi_expr, lo_expr)): # Slice + case UOp(Ops.CUSTOMI, _, (base_expr, hi_expr, lo_expr)): # Slice or array access + # Check for array element access first: arr[idx] where arr is a vector type + if base_expr.op == Ops.DEFINE_VAR and base_expr.arg[1] is None and hi_expr is lo_expr: + name = base_expr.arg[0] + var_dtype = ctx.decls.get(name) + if var_dtype is not None and var_dtype.count > 1: + # Array element access - look up stored element + idx_uop = _expr(hi_expr, ctx) + idx_uop = idx_uop.simplify() + if idx_uop.op == Ops.CONST: + arr_key = f"{name}_{int(idx_uop.arg)}" + if arr_key in ctx.vars: + return ctx.vars[arr_key] + # Element not set, return default value + return UOp.const(var_dtype.scalar(), 0) base, hi_uop, lo_uop = _expr(base_expr, ctx), _expr(hi_expr, ctx), _expr(lo_expr, ctx) # Single-bit slice: base[idx:idx] -> (base >> idx) & 1 if hi_expr is lo_expr: @@ -296,23 +320,35 @@ def _call_SAT8(v): clamped = UOp(Ops.WHERE, v.dtype, (UOp(Ops.CMPLT, dtypes.bool, (v, UOp.const(v.dtype, -128))), UOp.const(v.dtype, -128), v)) return UOp(Ops.WHERE, v.dtype, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(v.dtype, 127), clamped)), UOp.const(v.dtype, 127), clamped)) def _call_BYTE_PERMUTE(src, sel): - src_fixed = UOp(Ops.OR, dtypes.uint64, - (UOp(Ops.SHL, dtypes.uint64, (UOp(Ops.AND, dtypes.uint64, (_cast(src, dtypes.uint64), UOp.const(dtypes.uint64, 0xffffffff))), UOp.const(dtypes.uint64, 32))), - UOp(Ops.SHR, dtypes.uint64, (_cast(src, dtypes.uint64), UOp.const(dtypes.uint64, 32))))) + # src is {S0, S1} = (S0 << 32) | S1, where bytes 0-3 are S1, bytes 4-7 are S0 + src64 = _cast(src, dtypes.uint64) sel_val = UOp(Ops.AND, dtypes.uint32, (_cast(sel, dtypes.uint32), UOp.const(dtypes.uint32, 0xff))) sel_idx = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 7))) - sel_hi = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 0x80))) sel_nibble = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 0xf))) + # Normal byte select (sel 0-7): extract byte at index shift = UOp(Ops.SHL, dtypes.uint32, (sel_idx, UOp.const(dtypes.uint32, 3))) - byte_val = _cast(UOp(Ops.AND, dtypes.uint64, (UOp(Ops.SHR, dtypes.uint64, (src_fixed, _cast(shift, dtypes.uint64))), UOp.const(dtypes.uint64, 0xff))), dtypes.uint32) - byte_msb = UOp(Ops.AND, dtypes.uint32, (byte_val, UOp.const(dtypes.uint32, 0x80))) - sign_ext_val = UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPNE, dtypes.bool, (byte_msb, UOp.const(dtypes.uint32, 0))), UOp.const(dtypes.uint32, 0xff), UOp.const(dtypes.uint32, 0))) - is_sign_ext = UOp(Ops.AND, dtypes.bool, (UOp(Ops.CMPLT, dtypes.bool, (UOp.const(dtypes.uint32, 7), sel_nibble)), UOp(Ops.CMPLT, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 12))))) - is_const_zero = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 12))) - is_const_ff = UOp(Ops.CMPLT, dtypes.bool, (UOp.const(dtypes.uint32, 12), sel_nibble)) - result = UOp(Ops.WHERE, dtypes.uint32, (is_const_ff, UOp.const(dtypes.uint32, 0xff), byte_val)) - result = UOp(Ops.WHERE, dtypes.uint32, (is_const_zero, UOp.const(dtypes.uint32, 0), result)) - result = UOp(Ops.WHERE, dtypes.uint32, (is_sign_ext, sign_ext_val, result)) + byte_val = _cast(UOp(Ops.AND, dtypes.uint64, (UOp(Ops.SHR, dtypes.uint64, (src64, _cast(shift, dtypes.uint64))), UOp.const(dtypes.uint64, 0xff))), dtypes.uint32) + # Sign extension (sel 8-11): check bit 15/31/47/63 respectively + def sign_ext_bit(bit_pos): + bit = UOp(Ops.AND, dtypes.uint64, (UOp(Ops.SHR, dtypes.uint64, (src64, UOp.const(dtypes.uint64, bit_pos))), UOp.const(dtypes.uint64, 1))) + return UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPNE, dtypes.bool, (bit, UOp.const(dtypes.uint64, 0))), UOp.const(dtypes.uint32, 0xff), UOp.const(dtypes.uint32, 0))) + sign8, sign9, sign10, sign11 = sign_ext_bit(15), sign_ext_bit(31), sign_ext_bit(47), sign_ext_bit(63) + # Build result based on selector + is_sel8 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 8))) + is_sel9 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 9))) + is_sel10 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 10))) + is_sel11 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 11))) + is_sel12 = UOp(Ops.CMPEQ, dtypes.bool, (sel_nibble, UOp.const(dtypes.uint32, 12))) + is_sel_gt12 = UOp(Ops.CMPLT, dtypes.bool, (UOp.const(dtypes.uint32, 12), sel_nibble)) + result = byte_val + result = UOp(Ops.WHERE, dtypes.uint32, (is_sel8, sign8, result)) + result = UOp(Ops.WHERE, dtypes.uint32, (is_sel9, sign9, result)) + result = UOp(Ops.WHERE, dtypes.uint32, (is_sel10, sign10, result)) + result = UOp(Ops.WHERE, dtypes.uint32, (is_sel11, sign11, result)) + result = UOp(Ops.WHERE, dtypes.uint32, (is_sel12, UOp.const(dtypes.uint32, 0), result)) + result = UOp(Ops.WHERE, dtypes.uint32, (is_sel_gt12, UOp.const(dtypes.uint32, 0xff), result)) + # High bit of selector (0x80) means return 0 + sel_hi = UOp(Ops.AND, dtypes.uint32, (sel_val, UOp.const(dtypes.uint32, 0x80))) return UOp(Ops.WHERE, dtypes.uint32, (UOp(Ops.CMPNE, dtypes.bool, (sel_hi, UOp.const(dtypes.uint32, 0))), UOp.const(dtypes.uint32, 0), result)) CALL_DISPATCH = { @@ -370,34 +406,50 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp: return result raise ValueError(f"Unknown function: {name}") -def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, str|None]: - """Extract assignment target: (var_name, dtype, hi_bit, lo_bit, idx_var)""" +def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, str|None, int|None]: + """Extract assignment target: (var_name, dtype, hi_bit, lo_bit, idx_var, array_idx)""" match lhs: - case UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)): return name, dt, None, None, None + case UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)): return name, dt, None, None, None, None case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))),)): - return name, dt, int(hi), int(lo), None + return name, dt, int(hi), int(lo), None, None case UOp(Ops.BITCAST, _, (UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)),)) if lhs.src[0].src[1] is lhs.src[0].src[2]: - return name, dtypes.uint64, None, None, idx + return name, dtypes.uint64, None, None, idx, None case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)),)) if lhs.src[0].src[1] is lhs.src[0].src[2]: - return name, dt, None, None, idx + return name, dt, None, None, idx, None case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))): - return name, dtypes.uint32, int(hi), int(lo), None + return name, dtypes.uint32, int(hi), int(lo), None, None + case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, idx), _)) if lhs.src[1] is lhs.src[2]: + # Check if this is array element access (variable is a vector type) + var_dtype = ctx.decls.get(name) + if var_dtype is not None and var_dtype.count > 1: + return name, var_dtype.scalar(), None, None, None, int(idx) + return name, dtypes.uint32, int(idx), int(idx), None, None case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))): - return name, dtypes.uint32, int(hi), int(lo), None + return name, dtypes.uint32, int(hi), int(lo), None, None case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)) if lhs.src[1] is lhs.src[2]: - return name, dt, None, None, idx - # Handle tmp[i] where i is a variable (single-bit index) + return name, dt, None, None, idx, None + # Handle arr[i] where i is a variable - check if it's array element or bit index case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)) if lhs.src[1] is lhs.src[2]: - return name, dtypes.uint32, None, None, idx - case UOp(Ops.DEFINE_VAR, _, _, (name, None, None)): return name, dtypes.uint32, None, None, None + var_dtype = ctx.decls.get(name) + if var_dtype is not None and var_dtype.count > 1: + # Array element access with variable index + return name, var_dtype.scalar(), None, None, None, idx # Return idx as variable name for array_idx + return name, dtypes.uint32, None, None, idx, None + case UOp(Ops.DEFINE_VAR, _, _, (name, None, None)): return name, dtypes.uint32, None, None, None, None raise ValueError(f"Cannot parse LHS: {lhs}") def _stmt(stmt, ctx: Ctx): match stmt: case Declare(name, dtype): ctx.decls[name] = dtype - # Initialize declared variable with zero value - ctx.vars[name] = UOp.const(dtype, 0) + # Special handling for S array - it maps to source operands S0, S1, S2 + if name == 'S' and dtype.count == 3: + ctx.vars['S_0'] = ctx.vars['S0'] + ctx.vars['S_1'] = ctx.vars['S1'] + ctx.vars['S_2'] = ctx.vars['S2'] + else: + # Initialize declared variable with zero value + ctx.vars[name] = UOp.const(dtype, 0) case Assign(lhs, rhs): # Handle MEM[addr].type = value -> memory store if lhs.op == Ops.BITCAST and lhs.src[0].op == Ops.CUSTOM and lhs.src[0].arg == 'MEM': @@ -429,9 +481,27 @@ def _stmt(stmt, ctx: Ctx): offset += bits return - var, dtype, hi, lo, idx_var = _get_lhs_info(lhs, ctx) + var, dtype, hi, lo, idx_var, array_idx = _get_lhs_info(lhs, ctx) out_vars = ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC', 'SDATA', 'VDATA', 'RETURN_DATA') + # Handle array element assignment: arr[idx] = value + if array_idx is not None: + var_dtype = ctx.decls.get(var) + if var_dtype is None: raise ValueError(f"Unknown array variable: {var}") + rhs_uop = _expr(rhs, ctx, dtype) + # array_idx can be an int or a variable name (str) + if isinstance(array_idx, str): + # Variable index - resolve it + idx_uop = ctx.vars.get(array_idx) + if idx_uop is not None and idx_uop.op == Ops.CONST: + arr_key = f"{var}_{int(idx_uop.arg)}" + else: + raise ValueError(f"Non-constant array index: {array_idx}") + else: + arr_key = f"{var}_{array_idx}" + ctx.vars[arr_key] = rhs_uop + return + if idx_var is not None: base, idx = ctx.vars.get(var), ctx.vars.get(idx_var) if base is None or idx is None: raise ValueError(f"Unknown variable: {var} or {idx_var}") @@ -564,7 +634,21 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s return v if v is not None else None if u.op == Ops.BITCAST: v = _eval_uop(u.src[0]) - return v if v is not None else None + if v is None: return None + # Convert between int and float bit representations + if u.dtype == dtypes.float64 and u.src[0].dtype in (dtypes.uint64, dtypes.int64): + return struct.unpack(' str: """Apply known fixes for PDF pseudocode bugs - same as pcode.py but for raw pseudocode.""" - # V_DIV_FMAS: fix scaling factor - if op_name == 'V_DIV_FMAS_F32': - pcode = pcode.replace('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)', - 'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)') - if op_name == 'V_DIV_FMAS_F64': - pcode = pcode.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)', - 'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)') - # V_DIV_SCALE: D0 defaults to S0 if no branch sets it - if op_name == 'V_DIV_SCALE_F32': - pcode = 'D0.f32 = S0.f32\n' + pcode - if op_name == 'V_DIV_SCALE_F64': - pcode = 'D0.f64 = S0.f64\n' + pcode return pcode @functools.cache diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index abcdb9cde4..6f35cdbf97 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -32,6 +32,9 @@ def fold_bitcast(root:UOp, c:UOp) -> UOp|None: int_val = int(trunc(v)) & ((1 << (from_scalar.itemsize * 8)) - 1) # Convert to output type result = struct.unpack('<'+to_scalar.fmt, struct.pack('<'+int_fmt, int_val))[0] + # FTZ: flush f32 denormals to zero (for AMD GPU emulation - RDNA3 default mode) + if to_scalar.fmt == 'f' and (int_val & 0x7f800000) == 0 and (int_val & 0x007fffff) != 0: + result = 0.0 # Don't fold if result is NaN with non-canonical bits (as_const normalizes all NaN to math.nan) if isinstance(result, float) and math.isnan(result): canonical_nan_bits = struct.unpack('<'+int_fmt, struct.pack('<'+to_scalar.fmt, math.nan))[0] @@ -109,9 +112,9 @@ symbolic_simple = propagate_invalid + PatternMatcher([ ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x # x*0 -> 0 or 0*x -> 0 # if x is nan or inf it should render the nan value. - # NOTE: this can be wrong for loaded NaN - (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if x.op is Ops.CONST - and isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)), + # NOTE: this can be wrong for loaded NaN - disabled for AMD emulator correctness + # (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if x.op is Ops.CONST + # and isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)), # *** cast/bitcast *** (UPat(Ops.CAST, name="root", src=(UPat.cvar("c"),)), lambda root, c: root.const_like(c.arg)), (UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),