From b52ff63896249ef711a283aa7b69ff38ae6b8cbc Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sun, 4 Jan 2026 15:48:31 -0800
Subject: [PATCH] fixes

---
 extra/assembly/amd/ucode.py    | 173 +++++++++++++++++++++++++--------
 test/unit/test_uop_symbolic.py |  55 +++++++++++
 tinygrad/dtype.py              |   6 +-
 tinygrad/uop/symbolic.py       |  28 +++++-
 4 files changed, 218 insertions(+), 44 deletions(-)

diff --git a/extra/assembly/amd/ucode.py b/extra/assembly/amd/ucode.py
index 8d9f64fc0d..a5cd6ddc4e 100644
--- a/extra/assembly/amd/ucode.py
+++ b/extra/assembly/amd/ucode.py
@@ -89,6 +89,14 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
     # Handle Var.type
     if inner.op == Ops.DEFINE_VAR and inner.arg[1] is None:
       name = inner.arg[0]
+      # Handle INF.f32, INF.f64, NAN.f32, NAN.f64, etc.
+      if name == 'INF' or name in ('+INF', '-INF'):
+        return UOp.const(dt, float('-inf') if name.startswith('-') else float('inf'))
+      if name == 'NAN':
+        return UOp.const(dt, float('nan'))
+      if name == 'DENORM':
+        denorm = {dtypes.float32: 1.17549435e-38, dtypes.float64: 2.2250738585072014e-308}.get(dt, 1.17549435e-38)
+        return UOp.const(dt, denorm)
       if name in ('VCCZ', 'EXECZ'):
         return _cast(UOp(Ops.CMPEQ, dtypes.bool, (ctx.vars.get('VCC' if name == 'VCCZ' else 'EXEC'), UOp.const(dtypes.uint64, 0))), dtypes.uint32)
       if name.startswith('WAVE_STATUS.COND_DBG'): return UOp.const(dtypes.uint32, 0)
@@ -118,6 +126,8 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
       # Single-bit slice: base[idx:idx] -> (base >> idx) & 1
       if hi_expr is lo_expr:
         return UOp(Ops.AND, dtypes.uint32, (_cast(UOp(Ops.SHR, base.dtype, (base, _cast(lo_uop, base.dtype))), dtypes.uint32), UOp.const(dtypes.uint32, 1)))
+      # Simplify the bounds to get constant values (needed when loop variables are substituted)
+      hi_uop, lo_uop = hi_uop.simplify(), lo_uop.simplify()
       if hi_uop.op == Ops.CONST and lo_uop.op == Ops.CONST:
         hi_val, lo_val = int(hi_uop.arg), int(lo_uop.arg)
         if hi_val < lo_val:
@@ -230,7 +240,29 @@ def _call_floor(v):
   return UOp(Ops.WHERE, v.dtype, (needs_adjust, UOp(Ops.SUB, v.dtype, (truncated, UOp.const(v.dtype, 1))), truncated))
 def _call_fract(v): return UOp(Ops.SUB, v.dtype, (v, _call_floor(v)))
 def _call_isNAN(v): return UOp(Ops.CMPNE, dtypes.bool, (v, v))
-def _call_isSignalNAN(v): return UOp.const(dtypes.bool, 0)
+def _call_isSignalNAN(v):
+  # Signaling NaN: exponent all 1s, mantissa non-zero, MSB of mantissa is 0
+  # Unwrap CAST to check on original float type
+  while v.op == Ops.CAST and v.dtype in FLOATS: v = v.src[0]
+  uint_dt, _, exp_shift, exp_mask, mant_mask = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32])
+  bits = UOp(Ops.BITCAST, uint_dt, (v,))
+  exp = UOp(Ops.AND, uint_dt, (UOp(Ops.SHR, uint_dt, (bits, UOp.const(uint_dt, exp_shift))), UOp.const(uint_dt, exp_mask)))
+  mant = UOp(Ops.AND, uint_dt, (bits, UOp.const(uint_dt, mant_mask)))
+  quiet_bit = {dtypes.float64: 0x8000000000000, dtypes.float32: 0x400000, dtypes.float16: 0x200}.get(v.dtype, 0x400000)
+  is_exp_all_ones = UOp(Ops.CMPEQ, dtypes.bool, (exp, UOp.const(uint_dt, exp_mask)))
+  is_mant_nonzero = UOp(Ops.CMPNE, dtypes.bool, (mant, UOp.const(uint_dt, 0)))
+  is_quiet_bit_clear = UOp(Ops.CMPEQ, dtypes.bool, (UOp(Ops.AND, uint_dt, (mant, UOp.const(uint_dt, quiet_bit))), UOp.const(uint_dt, 0)))
+  return UOp(Ops.AND, dtypes.bool, (UOp(Ops.AND, dtypes.bool, (is_exp_all_ones, is_mant_nonzero)), is_quiet_bit_clear))
+def _call_isQuietNAN(v):
+  # Quiet NaN: exponent all 1s, MSB of mantissa is 1
+  while v.op == Ops.CAST and v.dtype in FLOATS: v = v.src[0]
+  uint_dt, _, exp_shift, exp_mask, mant_mask = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32])
+  bits = UOp(Ops.BITCAST, uint_dt, (v,))
+  exp = UOp(Ops.AND, uint_dt, (UOp(Ops.SHR, uint_dt, (bits, UOp.const(uint_dt, exp_shift))), UOp.const(uint_dt, exp_mask)))
+  quiet_bit = {dtypes.float64: 0x8000000000000, dtypes.float32: 0x400000, dtypes.float16: 0x200}.get(v.dtype, 0x400000)
+  is_exp_all_ones = UOp(Ops.CMPEQ, dtypes.bool, (exp, UOp.const(uint_dt, exp_mask)))
+  is_quiet_bit_set = UOp(Ops.CMPNE, dtypes.bool, (UOp(Ops.AND, uint_dt, (bits, UOp.const(uint_dt, quiet_bit))), UOp.const(uint_dt, 0)))
+  return UOp(Ops.AND, dtypes.bool, (is_exp_all_ones, is_quiet_bit_set))
 def _call_cvtToQuietNAN(v): return v
 def _call_isINF(v):
   return UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, dtypes.bool, (v, UOp.const(v.dtype, float('inf')))),
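For reference, the classification that the new _call_isSignalNAN and _call_isQuietNAN build as UOp trees can be sketched directly on raw float32 bits. This is an illustrative sketch only, not part of the patch; the masks are the standard IEEE-754 float32 fields and match the constants used above (exponent mask 0xff at bit 23, quiet bit 0x400000).

def is_signaling_nan_f32(bits: int) -> bool:
  # exponent all 1s, mantissa non-zero, quiet bit (mantissa MSB) clear
  exp, mant = (bits >> 23) & 0xff, bits & 0x7fffff
  return exp == 0xff and mant != 0 and (mant & 0x400000) == 0

def is_quiet_nan_f32(bits: int) -> bool:
  # exponent all 1s and quiet bit set
  return ((bits >> 23) & 0xff) == 0xff and (bits & 0x400000) != 0

assert is_signaling_nan_f32(0x7f800001) and not is_quiet_nan_f32(0x7f800001)
assert is_quiet_nan_f32(0x7fc00000) and not is_signaling_nan_f32(0x7fc00000)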
@@ -285,7 +317,7 @@ def _call_BYTE_PERMUTE(src, sel):
 
 CALL_DISPATCH = {
   'MEM': _call_MEM, 'fma': _call_fma, 'abs': _call_abs, 'cos': _call_cos, 'rsqrt': _call_rsqrt,
-  'clamp': _call_clamp, 'floor': _call_floor, 'fract': _call_fract, 'isNAN': _call_isNAN, 'isQuietNAN': _call_isNAN,
+  'clamp': _call_clamp, 'floor': _call_floor, 'fract': _call_fract, 'isNAN': _call_isNAN, 'isQuietNAN': _call_isQuietNAN,
   'isSignalNAN': _call_isSignalNAN, 'cvtToQuietNAN': _call_cvtToQuietNAN, 'isINF': _call_isINF, 'sign': _call_sign,
   'exponent': _call_exponent, 'mantissa': _call_mantissa, 'isEven': _call_isEven, 'signext': _call_signext,
   'signext_from_bit': _call_signext_from_bit, 'ABSDIFF': _call_ABSDIFF, 'SAT8': _call_SAT8,
@@ -299,6 +331,10 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
     result_dt = a[0].dtype if a[0].dtype in FLOATS else hint or dtypes.float32
     return UOp(Ops.EXP2, result_dt, (a[1] if a[1].dtype == result_dt else UOp(Ops.CAST, result_dt, (a[1],)),))
   if name in MATH_OPS: return UOp(MATH_OPS[name], a[0].dtype, (a[0],))
+  if name == 'ldexp':
+    # ldexp(x, exp) = x * 2^exp
+    exp_float = UOp(Ops.CAST, a[0].dtype, (a[1],)) if a[1].dtype != a[0].dtype else a[1]
+    return UOp(Ops.MUL, a[0].dtype, (a[0], UOp(Ops.EXP2, a[0].dtype, (exp_float,))))
   if name in ('min', 'max'):
     return UOp(Ops.WHERE, a[0].dtype, (UOp(Ops.CMPLT, dtypes.bool, ((a[0], a[1]) if name == 'min' else (a[1], a[0]))), a[0], a[1]))
   if name in CVT_MAP:
@@ -350,12 +386,18 @@ def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, s
       return name, dtypes.uint32, int(hi), int(lo), None
     case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)) if lhs.src[1] is lhs.src[2]:
       return name, dt, None, None, idx
+    # Handle tmp[i] where i is a variable (single-bit index)
+    case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)) if lhs.src[1] is lhs.src[2]:
+      return name, dtypes.uint32, None, None, idx
     case UOp(Ops.DEFINE_VAR, _, _, (name, None, None)): return name, dtypes.uint32, None, None, None
   raise ValueError(f"Cannot parse LHS: {lhs}")
 
 def _stmt(stmt, ctx: Ctx):
   match stmt:
-    case Declare(name, dtype): ctx.decls[name] = dtype
+    case Declare(name, dtype):
+      ctx.decls[name] = dtype
+      # Initialize declared variable with zero value
+      ctx.vars[name] = UOp.const(dtype, 0)
     case Assign(lhs, rhs):
       # Handle MEM[addr].type = value -> memory store
       if lhs.op == Ops.BITCAST and lhs.src[0].op == Ops.CUSTOM and lhs.src[0].arg == 'MEM':
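The ldexp lowering added above is just multiplication by a power of two. A quick stand-alone check of that identity (illustrative only, not part of the patch):

import math
for x, e in [(1.5, 3), (-2.0, -4), (0.75, 10)]:
  assert math.ldexp(x, e) == x * 2.0 ** e  # ldexp(x, e) == x * 2**e, the form the UOp lowering builds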
@@ -367,6 +409,26 @@ def _stmt(stmt, ctx: Ctx):
           ctx.mem_stores.append(UOp(Ops.STORE, dtypes.void, (idx, val_uop)))
         return
 
+      # Handle CAT (multi-output assignment) like {D1.u1, D0.u64} = ...
+      if lhs.op == Ops.CAT:
+        rhs_uop = _expr(rhs, ctx)
+        out_vars = ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC', 'SDATA', 'VDATA', 'RETURN_DATA')
+        offset = 0
+        for part in reversed(lhs.src):  # CAT is hi, lo order, so reverse to get lo first
+          if part.op == Ops.BITCAST and part.src[0].op == Ops.DEFINE_VAR:
+            dt, name = part.dtype, part.src[0].arg[0]
+            # Map non-standard dtypes to real dtypes
+            if dt.name == 'u1': bits, real_dt = 1, dtypes.uint32
+            elif dt == dtypes.ulong or dt.name == 'ulong': bits, real_dt = 64, dtypes.uint64
+            else: bits, real_dt = dt.itemsize * 8, dt
+            mask = (1 << bits) - 1
+            extracted = UOp(Ops.AND, rhs_uop.dtype, (UOp(Ops.SHR, rhs_uop.dtype, (rhs_uop, UOp.const(rhs_uop.dtype, offset))), UOp.const(rhs_uop.dtype, mask)))
+            val = _cast(extracted, real_dt)
+            ctx.vars[name] = val
+            if name in out_vars: ctx.outputs.append((name, val, real_dt))
+            offset += bits
+        return
+
       var, dtype, hi, lo, idx_var = _get_lhs_info(lhs, ctx)
       out_vars = ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC', 'SDATA', 'VDATA', 'RETURN_DATA')
@@ -420,25 +482,41 @@ def _transform_if(branches: tuple, ctx: Ctx):
       for s in body: _stmt(s, sub_ctx)
     parsed.append((cond_uop, sub_ctx))
 
-  # Collect all assigned variables across all branches
+  # Collect all assigned variables across all branches (both outputs and locals)
   out_vars = ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC', 'SDATA', 'VDATA', 'RETURN_DATA')
-  assigned = set()
+  assigned_outputs = set()
+  assigned_locals = set()
   for _, sub_ctx in parsed:
     for name, _, _ in sub_ctx.outputs:
-      if name in out_vars: assigned.add(name)
+      if name in out_vars: assigned_outputs.add(name)
+    # Track local variables that were modified in branches
+    for name, val in sub_ctx.vars.items():
+      if name not in ctx.vars or ctx.vars[name] is not val:
+        if name not in out_vars and name not in INPUT_VARS:
+          assigned_locals.add(name)
 
-  for var in assigned:
+  # Merge output variables
+  for var in assigned_outputs:
     dtype = next((d for _, sub_ctx in parsed for n, _, d in sub_ctx.outputs if n == var), dtypes.uint32)
     result = ctx.vars.get(var, UOp.const(dtype, 0))
     for cond_uop, sub_ctx in reversed(parsed):
       branch_val = next((u for n, u, _ in sub_ctx.outputs if n == var), None)
       if branch_val is not None:
         result = branch_val if cond_uop is None else UOp(Ops.WHERE, branch_val.dtype, (cond_uop, branch_val, _cast(result, branch_val.dtype)))
-
     ctx.vars[var] = result
     ctx.outputs = [(n, u, d) for n, u, d in ctx.outputs if n != var]
     ctx.outputs.append((var, result, dtype))
 
+  # Merge local variables (like 'result')
+  for var in assigned_locals:
+    dtype = ctx.decls.get(var, dtypes.uint32)
+    result = ctx.vars.get(var, UOp.const(dtype, 0))
+    for cond_uop, sub_ctx in reversed(parsed):
+      if var in sub_ctx.vars and (var not in ctx.vars or sub_ctx.vars[var] is not ctx.vars[var]):
+        branch_val = sub_ctx.vars[var]
+        result = branch_val if cond_uop is None else UOp(Ops.WHERE, branch_val.dtype, (cond_uop, branch_val, _cast(result, branch_val.dtype)))
+    ctx.vars[var] = result
+
 def _transform_for(var: str, start: UOp, end: UOp, body: tuple, ctx: Ctx):
   start_val = start.arg if start.op == Ops.CONST else int(_expr(start, ctx).arg)
   end_val = end.arg if end.op == Ops.CONST else int(_expr(end, ctx).arg)
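The branch merge above is a select chain: folding the branches in reverse and wrapping each value in WHERE(cond, value, previous) gives first-match-wins if/elif semantics. A plain-Python sketch of that fold follows; it is illustrative only, and merge_branches is a hypothetical helper, not something the patch defines.

def merge_branches(initial, branches):
  # branches: ordered (condition, value) pairs; condition None means an unconditional else
  result = initial
  for cond, val in reversed(branches):
    result = val if (cond is None or cond) else result
  return result

assert merge_branches(0, [(True, 1), (False, 2)]) == 1   # first taken branch wins
assert merge_branches(0, [(False, 1), (None, 2)]) == 2   # falls through to the else value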
@@ -478,6 +556,28 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
   is_lds = any(u.op == Ops.DEFINE_LOCAL for u in topo)
   is_mem = bool(mem_stores) or any(u.op == Ops.LOAD for u in topo)
 
+  def _eval_uop(u: UOp) -> int|float|None:
+    """Recursively evaluate a UOp tree to a constant value."""
+    if u.op == Ops.CONST: return u.arg
+    if u.op == Ops.CAST:
+      v = _eval_uop(u.src[0])
+      return v if v is not None else None
+    if u.op == Ops.BITCAST:
+      v = _eval_uop(u.src[0])
+      return v if v is not None else None
+    if u.op in (Ops.ADD, Ops.SUB, Ops.MUL, Ops.AND, Ops.OR, Ops.XOR, Ops.SHR, Ops.SHL):
+      l, r = _eval_uop(u.src[0]), _eval_uop(u.src[1])
+      if l is None or r is None: return None
+      if u.op == Ops.ADD: return l + r
+      if u.op == Ops.SUB: return l - r
+      if u.op == Ops.MUL: return l * r
+      if u.op == Ops.AND: return int(l) & int(r)
+      if u.op == Ops.OR: return int(l) | int(r)
+      if u.op == Ops.XOR: return int(l) ^ int(r)
+      if u.op == Ops.SHR: return int(l) >> int(r)
+      if u.op == Ops.SHL: return int(l) << int(r)
+    return None
+
   def _extract_results(s, MEM=None):
     for u in s.src:
       if u.op == Ops.STORE:
@@ -489,8 +589,13 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
           setattr(MEM[addr], acc, int(val))
     result = {}
     for i, (name, dtype) in enumerate(output_info):
-      if i >= len(s.src) or s.src[i].op != Ops.CONST: continue
-      result[name] = _float_to_bits(s.src[i].arg, dtype) if dtype in FLOATS else int(s.src[i].arg) & (0xffffffff if dtype.itemsize <= 4 else 0xffffffffffffffff)
+      if i >= len(s.src): continue
+      if s.src[i].op == Ops.CONST:
+        val = s.src[i].arg
+      else:
+        val = _eval_uop(s.src[i])
+        if val is None: continue
+      result[name] = _float_to_bits(val, dtype) if dtype in FLOATS else int(val) & (0xffffffff if dtype.itemsize <= 4 else 0xffffffffffffffff)
     return result
 
   if is_lds:
@@ -543,44 +648,34 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
     return _extract_results(sink.substitute(dvars).simplify())
   return fn
 
-_SKIP_OPS = {
-  'V_DIV_FMAS_F32', 'V_DIV_FMAS_F64',
-  'V_CMP_CLASS_F16', 'V_CMP_CLASS_F32', 'V_CMP_CLASS_F64', 'V_CMPX_CLASS_F16', 'V_CMPX_CLASS_F32', 'V_CMPX_CLASS_F64',
-  'V_FREXP_MANT_F16', 'V_FREXP_MANT_F32', 'V_FREXP_MANT_F64',
-  'V_DIV_FIXUP_F16', 'V_DIV_FIXUP_F32', 'V_DIV_SCALE_F32', 'V_DIV_SCALE_F64',
-  'V_TRIG_PREOP_F64',
-  'S_FF0_I32_B32', 'S_FF0_I32_B64', 'S_FF1_I32_B32', 'S_FF1_I32_B64',
-  'S_FLBIT_I32', 'S_FLBIT_I32_B32', 'S_FLBIT_I32_B64', 'S_FLBIT_I32_I32', 'S_FLBIT_I32_I64',
-  'S_BITREPLICATE_B64_B32',
-  'S_BITSET0_B32', 'S_BITSET0_B64', 'S_BITSET1_B32', 'S_BITSET1_B64',
-  'S_QUADMASK_B32', 'S_QUADMASK_B64', 'S_WQM_B32', 'S_WQM_B64',
-  'S_GETREG_B32', 'S_SETREG_B32', 'S_SETREG_IMM32_B32',
-  'S_MOVRELD_B32', 'S_MOVRELD_B64', 'S_MOVRELS_B32', 'S_MOVRELS_B64', 'S_MOVRELSD_2_B32',
-  'V_MOVRELD_B32', 'V_MOVRELS_B32', 'V_MOVRELSD_B32', 'V_MOVRELSD_2_B32',
-  'V_READFIRSTLANE_B32', 'V_WRITELANE_B32', 'V_READLANE_B32',
-  'V_PERMLANE16_B32', 'V_PERMLANE64_B32', 'V_PERMLANEX16_B32',
-  'V_MBCNT_HI_U32_B32', 'V_MBCNT_LO_U32_B32',
-  'S_SENDMSG_RTN_B32', 'S_SENDMSG_RTN_B64',
-  'V_SWAPREL_B32', 'V_PERM_B32',
-  'S_NOP', 'S_SETHALT', 'S_TRAP',
-  'V_MAD_U64_U32', 'V_MAD_I64_I32',
-  'V_DOT2_F32_BF16',
-  'V_DOT4_I32_IU8', 'V_DOT4_U32_U8',
-  'V_DOT8_I32_IU4', 'V_DOT8_U32_U4',
-  'V_FMA_MIX_F32', 'V_FMA_MIXLO_F16', 'V_FMA_MIXHI_F16',
-  'V_CVT_OFF_F32_I4',
-}
+_SKIP_OPS: set[str] = set()
 
-_PCODE_PATTERNS = ('LDS[', 'LDS(', 'VGPR[', 'SGPR[', 'GPR[', 'GS_REGS', 'thread_in[', 'thread_out[', 'thread_valid[',
-                   'DATA2', 'OFFSET0', 'OFFSET1')
+_PCODE_PATTERNS = ('LDS[', 'LDS(', 'VGPR[', 'SGPR[', 'GPR[', 'GS_REGS', 'thread_in[', 'thread_out[', 'thread_valid[')
 _WIDE_OUTPUT_PATTERNS = ('SDATA[95', 'SDATA[127', 'SDATA[159', 'SDATA[191', 'SDATA[223', 'SDATA[255', 'VDATA[95', 'VDATA[127')
 
+def _apply_pseudocode_fixes(op_name: str, pcode: str) -> str:
+  """Apply known fixes for PDF pseudocode bugs - same as pcode.py but for raw pseudocode."""
+  # V_DIV_FMAS: fix scaling factor
+  if op_name == 'V_DIV_FMAS_F32':
+    pcode = pcode.replace('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
+                          'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
+  if op_name == 'V_DIV_FMAS_F64':
+    pcode = pcode.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
+                          'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
+  # V_DIV_SCALE: D0 defaults to S0 if no branch sets it
+  if op_name == 'V_DIV_SCALE_F32':
+    pcode = 'D0.f32 = S0.f32\n' + pcode
+  if op_name == 'V_DIV_SCALE_F64':
+    pcode = 'D0.f64 = S0.f64\n' + pcode
+  return pcode
+
 @functools.cache
 def compile_uop(op_name: str, pseudocode: str):
   if op_name in _SKIP_OPS: return None
   if any(p in pseudocode for p in _PCODE_PATTERNS): return None
   if any(p in pseudocode for p in _WIDE_OUTPUT_PATTERNS): return None
+  pseudocode = _apply_pseudocode_fixes(op_name, pseudocode)
   is_ds = op_name.startswith('DS_')
   mem_buf = LDS_BUF if is_ds else MEM_BUF
   sink, output_info, input_vars, mem_stores = _compile_pseudocode(pseudocode, mem_buf)
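As a usage sketch of the new fixup pass (illustrative only, and assuming _apply_pseudocode_fixes is importable from the module this patch edits), the V_DIV_FMAS_F32 rule rewrites the PDF's fixed 2**32 scale into a +/-2**64 scale chosen from S2's exponent; the strings below are the same ones used in the replace call above.

from extra.assembly.amd.ucode import _apply_pseudocode_fixes  # assumed import path
fixed = _apply_pseudocode_fixes('V_DIV_FMAS_F32', 'D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)')
assert fixed == 'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)'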
diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py
index 82a2767dce..e7e74d09fc 100644
--- a/test/unit/test_uop_symbolic.py
+++ b/test/unit/test_uop_symbolic.py
@@ -1096,5 +1096,60 @@ class TestFuzzFailure(unittest.TestCase):
       rn = expr.substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()
       assert num==rn, f"{num} != {rn}"
 
+class TestBitcast(unittest.TestCase):
+  def test_bitcast_preserves_signaling_nan_bits(self):
+    # Signaling NaN in f32: exponent all 1s, mantissa non-zero with MSB clear
+    snan_bits = 0x7f800001
+    bits = UOp.const(dtypes.uint32, snan_bits)
+    # BITCAST uint32 -> float32 should NOT fold (would corrupt NaN bits)
+    bf = bits.bitcast(dtypes.float32)
+    result = bf.simplify()
+    self.assertEqual(result.op, Ops.BITCAST, "signaling NaN bitcast should not fold to CONST")
+    self.assertEqual(result.src[0].arg, snan_bits)
+
+  def test_bitcast_double_preserves_bits(self):
+    # BITCAST(BITCAST(x)) where outer dtype == x.dtype should fold back to x
+    snan_bits = 0x7f800001
+    bits = UOp.const(dtypes.uint32, snan_bits)
+    bf1 = bits.bitcast(dtypes.float32)
+    bf2 = bf1.bitcast(dtypes.uint32)
+    result = bf2.simplify()
+    self.assertEqual(result.op, Ops.CONST)
+    self.assertEqual(result.arg, snan_bits, "double bitcast should preserve original bits")
+
+  def test_bitcast_quiet_nan_folds(self):
+    # Quiet NaN in f32: can be folded since Python's nan preserves these bits
+    qnan_bits = 0x7fc00000
+    bits = UOp.const(dtypes.uint32, qnan_bits)
+    bf = bits.bitcast(dtypes.float32)
+    result = bf.simplify()
+    self.assertEqual(result.op, Ops.CONST)
+    self.assertTrue(math.isnan(result.arg))
+
+  def test_bitcast_normal_float_folds(self):
+    # Normal float values should fold
+    bits = UOp.const(dtypes.uint32, 0x40490fdb) # pi
+    bf = bits.bitcast(dtypes.float32)
+    result = bf.simplify()
+    self.assertEqual(result.op, Ops.CONST)
+    self.assertAlmostEqual(result.arg, 3.14159265, places=5)
+
+  def test_bitcast_f64_signaling_nan(self):
+    # Signaling NaN in f64
+    snan_bits = 0x7ff0000000000001
+    bits = UOp.const(dtypes.uint64, snan_bits)
+    bf = bits.bitcast(dtypes.float64)
+    result = bf.simplify()
+    self.assertEqual(result.op, Ops.BITCAST, "f64 signaling NaN bitcast should not fold")
+
+  def test_bitcast_f64_double_preserves_bits(self):
+    snan_bits = 0x7ff0000000000001
+    bits = UOp.const(dtypes.uint64, snan_bits)
+    bf1 = bits.bitcast(dtypes.float64)
+    bf2 = bf1.bitcast(dtypes.uint64)
+    result = bf2.simplify()
+    self.assertEqual(result.op, Ops.CONST)
+    self.assertEqual(result.arg, snan_bits)
+
 if __name__ == '__main__':
   unittest.main()
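Why the signaling-NaN cases above must stay as BITCAST: dtypes.as_const canonicalizes every float NaN to math.nan, whose float32 bit pattern differs from the signaling payload, so folding through a Python float would rewrite the bits. A minimal check of that fact (illustrative only, not part of the patch):

import math, struct
snan_bits = 0x7f800001                                               # signaling NaN payload from the tests
canonical_bits = struct.unpack('<I', struct.pack('<f', math.nan))[0] # bits of the canonical quiet NaN
assert canonical_bits != snan_bits                                   # folding + canonicalization would change the value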
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index 9a79696c10..1cba6650d7 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -140,7 +140,11 @@ class dtypes:
     if isinstance(val, InvalidType): return val
     # NOTE: float('nan') != float('nan'), so we canonicalize here
     if isinstance(val, float) and math.isnan(val): val = math.nan
-    return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
+    if dtypes.is_int(dtype): return int(val)
+    if dtypes.is_float(dtype): return float(val)
+    if dtype == dtypes.bool: return bool(val)
+    # For unknown types (e.g. wide integers u65, b65), preserve as int
+    return int(val) if isinstance(val, (int, float)) else val
   @staticmethod
   @functools.cache
   def min(dtype:DType):
diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py
index d21347d9a3..abcdb9cde4 100644
--- a/tinygrad/uop/symbolic.py
+++ b/tinygrad/uop/symbolic.py
@@ -17,11 +17,28 @@ def simplify_pow(x:UOp, c:UOp) -> UOp|None:
   return None
 
 def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
-  if (from_fmt:=c.dtype.scalar().fmt) is None or (to_fmt:=root.dtype.scalar().fmt) is None: return None
   if c.dtype.itemsize != root.dtype.itemsize: return None
-  trunc = truncate.get(c.dtype.scalar(), lambda x: x)
-  def convert(v:ConstType): return struct.unpack(to_fmt, struct.pack(from_fmt, trunc(v)))[0]
-  return root.const_like(convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg)))
+  from_scalar, to_scalar = c.dtype.scalar(), root.dtype.scalar()
+  if from_scalar.fmt is None or to_scalar.fmt is None: return None
+  trunc = truncate.get(from_scalar, lambda x: x)
+  int_fmts = {1: 'B', 2: 'H', 4: 'I', 8: 'Q'} # unsigned int formats by size
+  def convert(v:ConstType):
+    int_fmt = int_fmts.get(from_scalar.itemsize)
+    if int_fmt is None: return None
+    # Get the integer bit representation
+    if isinstance(v, float):
+      int_val = struct.unpack('<'+int_fmt, struct.pack('<'+from_scalar.fmt, v))[0]
+    else:
+      int_val = int(trunc(v)) & ((1 << (from_scalar.itemsize * 8)) - 1)
+    # Convert to output type
+    result = struct.unpack('<'+to_scalar.fmt, struct.pack('<'+int_fmt, int_val))[0]
+    # Don't fold if result is NaN with non-canonical bits (as_const normalizes all NaN to math.nan)
+    if isinstance(result, float) and math.isnan(result):
+      canonical_nan_bits = struct.unpack('<'+int_fmt, struct.pack('<'+to_scalar.fmt, math.nan))[0]
+      if int_val != canonical_nan_bits: return None # would be corrupted by as_const
+    return result
+  result = convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg))
+  return None if result is None or (isinstance(result, tuple) and None in result) else root.const_like(result)
 
 invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i")
 invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)
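For a value that is safe to fold, convert() above is just a pack/unpack through the unsigned integer format of the same width. For example, on the float32 pi constant also used in the new tests (illustrative sketch only, not part of the patch):

import struct
int_val = 0x40490fdb                                        # raw uint32 bits
value = struct.unpack('<f', struct.pack('<I', int_val))[0]  # uint32 -> float32, what convert() returns
assert abs(value - 3.14159265) < 1e-6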
@@ -99,6 +116,9 @@ symbolic_simple = propagate_invalid + PatternMatcher([
   (UPat(Ops.CAST, name="root", src=(UPat.cvar("c"),)), lambda root, c: root.const_like(c.arg)),
   (UPat((Ops.CAST, Ops.BITCAST), name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
   (UPat(Ops.BITCAST, name="root", src=(UPat.cvar("c"),)), fold_bitcast),
+  # BITCAST(BITCAST(x)) -> x when outer dtype matches x's dtype (preserves NaN bits)
+  (UPat(Ops.BITCAST, name="outer", src=(UPat(Ops.BITCAST, src=(UPat.var("x"),)),)),
+   lambda outer, x: x if outer.dtype == x.dtype and outer.dtype.itemsize == outer.src[0].dtype.itemsize else None),
   # b.cast(a).cast(b) -> b if a preserves all values in b
   (UPat.var('x').cast(name="a").cast(name="b"), lambda x,a,b: x if x.dtype == b.dtype and can_lossless_cast(b.dtype, a.dtype) else None),
   # ** pow **