diff --git a/extra/assembly/amd/ucode.py b/extra/assembly/amd/ucode.py
index 76fd1f4d81..f65f61a809 100644
--- a/extra/assembly/amd/ucode.py
+++ b/extra/assembly/amd/ucode.py
@@ -175,7 +175,11 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
     case UOp(Ops.CAST, dt, (inner,)):
       inner_resolved = _expr(inner, ctx, dt)
-      if dt in FLOATS: return UOp(Ops.CAST, dt, (inner_resolved,))
+      if dt in FLOATS:
+        # For 32'F(0xffc00000) etc, treat integer constants as BITCAST (interpret bits as float)
+        if inner_resolved.op == Ops.CONST and inner_resolved.dtype not in FLOATS:
+          return UOp(Ops.BITCAST, dt, (inner_resolved,))
+        return UOp(Ops.CAST, dt, (inner_resolved,))
       if inner_resolved.dtype.itemsize == dt.itemsize: return _cast(inner_resolved, dt)
       if dt in SIGNED and inner_resolved.dtype in SIGNED: return UOp(Ops.CAST, dt, (inner_resolved,))
@@ -291,6 +295,15 @@ def _call_cvtToQuietNAN(v): return v
 def _call_isINF(v):
   return UOp(Ops.OR, dtypes.bool, (UOp(Ops.CMPEQ, dtypes.bool, (v, UOp.const(v.dtype, float('inf')))), UOp(Ops.CMPEQ, dtypes.bool, (v, UOp.const(v.dtype, float('-inf'))))))
+def _call_isDENORM(v):
+  # Denormalized float: exponent is 0, mantissa is non-zero, value is not zero
+  uint_dt, _, exp_shift, exp_mask, mant_mask = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32])
+  bits = UOp(Ops.BITCAST, uint_dt, (v,))
+  exp = UOp(Ops.AND, uint_dt, (UOp(Ops.SHR, uint_dt, (bits, UOp.const(uint_dt, exp_shift))), UOp.const(uint_dt, exp_mask)))
+  mant = UOp(Ops.AND, uint_dt, (bits, UOp.const(uint_dt, mant_mask)))
+  is_exp_zero = UOp(Ops.CMPEQ, dtypes.bool, (exp, UOp.const(uint_dt, 0)))
+  is_mant_nonzero = UOp(Ops.CMPNE, dtypes.bool, (mant, UOp.const(uint_dt, 0)))
+  return UOp(Ops.AND, dtypes.bool, (is_exp_zero, is_mant_nonzero))
 def _call_sign(v):
   uint_dt, sign_shift, _, _, _ = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32])
   bits = UOp(Ops.BITCAST, uint_dt, (v,))
@@ -300,9 +313,18 @@ def _call_exponent(v):
   bits = UOp(Ops.BITCAST, uint_dt, (v,))
   return UOp(Ops.AND, dtypes.uint32, (_cast(UOp(Ops.SHR, uint_dt, (bits, UOp.const(uint_dt, exp_shift))), dtypes.uint32), UOp.const(dtypes.uint32, exp_mask)))
 def _call_mantissa(v):
-  uint_dt, _, _, _, mant_mask = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32])
-  bits, out_dt = UOp(Ops.BITCAST, uint_dt, (v,)), dtypes.uint64 if v.dtype == dtypes.float64 else dtypes.uint32
-  return UOp(Ops.AND, out_dt, (_cast(bits, out_dt) if out_dt != uint_dt else bits, UOp.const(out_dt, mant_mask)))
+  # AMD V_FREXP_MANT returns mantissa in [0.5, 1.0) range like math.frexp()[0]
+  # For normalized floats: set exponent to bias-1 (makes value in [0.5,1.0))
+  # For zero: return zero; for inf/nan: should be handled by caller
+  uint_dt, sign_shift, exp_shift, exp_mask, mant_mask = FP_INFO.get(v.dtype, FP_INFO[dtypes.float32])
+  bias = {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}.get(v.dtype, 127)
+  bits = UOp(Ops.BITCAST, uint_dt, (v,))
+  sign_and_mant = UOp(Ops.AND, uint_dt, (bits, UOp.const(uint_dt, (1 << sign_shift) | mant_mask)))
+  new_exp = UOp.const(uint_dt, (bias - 1) << exp_shift)  # exponent = -1 in biased form
+  result_bits = UOp(Ops.OR, uint_dt, (sign_and_mant, new_exp))
+  result = UOp(Ops.BITCAST, v.dtype, (result_bits,))
+  is_zero = UOp(Ops.CMPEQ, dtypes.bool, (v, UOp.const(v.dtype, 0.0)))
+  return UOp(Ops.WHERE, v.dtype, (is_zero, v, result))
 def _call_isEven(v):
   int_val = UOp(Ops.CAST, dtypes.int64, (v,))
   return UOp(Ops.CMPEQ, dtypes.bool, (UOp(Ops.AND, dtypes.int64, (int_val, UOp.const(dtypes.int64, 1))), UOp.const(dtypes.int64, 0)))
@@ -354,7 +376,7 @@ def _call_BYTE_PERMUTE(src, sel):
 CALL_DISPATCH = {
   'MEM': _call_MEM, 'fma': _call_fma, 'abs': _call_abs, 'cos': _call_cos, 'rsqrt': _call_rsqrt, 'clamp': _call_clamp,
   'floor': _call_floor, 'fract': _call_fract, 'isNAN': _call_isNAN, 'isQuietNAN': _call_isQuietNAN,
-  'isSignalNAN': _call_isSignalNAN, 'cvtToQuietNAN': _call_cvtToQuietNAN, 'isINF': _call_isINF,
+  'isSignalNAN': _call_isSignalNAN, 'cvtToQuietNAN': _call_cvtToQuietNAN, 'isINF': _call_isINF, 'isDENORM': _call_isDENORM,
   'sign': _call_sign, 'exponent': _call_exponent, 'mantissa': _call_mantissa, 'isEven': _call_isEven,
   'signext': _call_signext, 'signext_from_bit': _call_signext_from_bit, 'ABSDIFF': _call_ABSDIFF, 'SAT8': _call_SAT8,
   'BYTE_PERMUTE': _call_BYTE_PERMUTE, 'bf16_to_f32': _call_bf16_to_f32,
@@ -435,7 +457,11 @@ def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, s
         # Array element access with variable index
         return name, var_dtype.scalar(), None, None, None, idx  # Return idx as variable name for array_idx
       return name, dtypes.uint32, None, None, idx, None
-    case UOp(Ops.DEFINE_VAR, _, _, (name, None, None)): return name, dtypes.uint32, None, None, None, None
+    case UOp(Ops.DEFINE_VAR, _, _, (name, None, None)):
+      # If the variable already exists, use its dtype; otherwise default to uint32
+      existing = ctx.vars.get(name)
+      dtype = existing.dtype if existing is not None else dtypes.uint32
+      return name, dtype, None, None, None, None
   raise ValueError(f"Cannot parse LHS: {lhs}")
 
 def _stmt(stmt, ctx: Ctx):
@@ -660,6 +686,20 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
     if u.op == Ops.XOR: return int(l) ^ int(r)
     if u.op == Ops.SHR: return int(l) >> int(r)
     if u.op == Ops.SHL: return int(l) << int(r)
+    if u.op == Ops.NEG:
+      v = _eval_uop(u.src[0])
+      return -v if v is not None else None
+    if u.op in (Ops.CMPEQ, Ops.CMPNE, Ops.CMPLT, Ops.CMPLE):
+      l, r = _eval_uop(u.src[0]), _eval_uop(u.src[1])
+      if l is None or r is None: return None
+      if u.op == Ops.CMPEQ: return l == r
+      if u.op == Ops.CMPNE: return l != r
+      if u.op == Ops.CMPLT: return l < r
+      if u.op == Ops.CMPLE: return l <= r
+    if u.op == Ops.WHERE:
+      c, t, f = _eval_uop(u.src[0]), _eval_uop(u.src[1]), _eval_uop(u.src[2])
+      if c is None or t is None or f is None: return None
+      return t if c else f
     return None
 
   def _extract_results(s, MEM=None):
@@ -733,12 +773,8 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
     return _extract_results(sink.substitute(dvars).simplify())
   return fn
 
-# Ops that need Python exec features (inline conditionals, complex PDF fixes, precise FMA) - fall back to pcode.py
-_SKIP_OPS: set[str] = {'V_DIV_FMAS_F32', 'V_DIV_FMAS_F64', 'V_DIV_SCALE_F32', 'V_DIV_SCALE_F64',
-                       'V_DIV_FIXUP_F32', 'V_DIV_FIXUP_F64', 'V_TRIG_PREOP_F64',
-                       'V_FMA_F64', 'V_FMA_F32',  # FMA needs precise math.fma semantics
-                       'V_FREXP_MANT_F64', 'V_FREXP_MANT_F32',  # mantissa() returns [0.5,1.0) range float
-                       'V_DOT2_F32_BF16'}  # compound assignment parsing issues
+# Ops that need Python exec features (inline conditionals, complex PDF fixes) - fall back to pcode.py
+_SKIP_OPS: set[str] = {'V_TRIG_PREOP_F64'}
 
 _PCODE_PATTERNS = ('LDS[', 'LDS(', 'VGPR[', 'SGPR[', 'GPR[', 'GS_REGS', 'thread_in[', 'thread_out[', 'thread_valid[')
 _WIDE_OUTPUT_PATTERNS = ('SDATA[95', 'SDATA[127', 'SDATA[159', 'SDATA[191', 'SDATA[223', 'SDATA[255',
@@ -746,6 +782,62 @@ _WIDE_OUTPUT_PATTERNS = ('SDATA[95', 'SDATA[127', 'SDATA[159', 'SDATA[191', 'SDA
 
 def _apply_pseudocode_fixes(op_name: str, pcode: str) -> str:
   """Apply known fixes for PDF pseudocode bugs - same as pcode.py but for raw pseudocode."""
+  if op_name == 'V_DIV_FMAS_F32':
+    pcode = pcode.replace('D0.f32 = 2.0F ** 32 * fma(S0.f32, S1.f32, S2.f32)',
+                          'D0.f32 = (exponent(S2.f32) > 127) ? (2.0F ** 64 * fma(S0.f32, S1.f32, S2.f32)) : (2.0F ** -64 * fma(S0.f32, S1.f32, S2.f32))')
+  if op_name == 'V_DIV_FMAS_F64':
+    pcode = pcode.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
+                          'D0.f64 = (exponent(S2.f64) > 1023) ? (2.0 ** 128 * fma(S0.f64, S1.f64, S2.f64)) : (2.0 ** -128 * fma(S0.f64, S1.f64, S2.f64))')
+  if op_name == 'V_DIV_FIXUP_F32':
+    # When S0 (estimate) is NaN but inputs are valid, return OVERFLOW instead of NaN
+    pcode = pcode.replace('D0.f32 = sign_out ? -abs(S0.f32) : abs(S0.f32)',
+                          'D0.f32 = isNAN(S0.f32) ? (sign_out ? -OVERFLOW_F32 : OVERFLOW_F32) : (sign_out ? -abs(S0.f32) : abs(S0.f32))')
+  if op_name == 'V_DIV_FIXUP_F64':
+    pcode = pcode.replace('D0.f64 = sign_out ? -abs(S0.f64) : abs(S0.f64)',
+                          'D0.f64 = isNAN(S0.f64) ? (sign_out ? -OVERFLOW_F64 : OVERFLOW_F64) : (sign_out ? -abs(S0.f64) : abs(S0.f64))')
+  if op_name == 'V_DIV_SCALE_F32':
+    # Fix 0: Replace DENORM comparisons with isDENORM() calls (order matters - do longer patterns first)
+    pcode = pcode.replace('S2.f32 / S1.f32 == DENORM.f32', 'isDENORM(S2.f32 / S1.f32)')
+    pcode = pcode.replace('1.0 / 64\'F(S1.f32) == DENORM.f64', 'isDENORM(1.0 / 64\'F(S1.f32))')
+    pcode = pcode.replace('S1.f32 == DENORM.f32', 'isDENORM(S1.f32)')
+    # Fix 1: Set VCC=1 when returning NAN for zero inputs
+    pcode = pcode.replace('D0.f32 = NAN.f32', 'VCC = 0x1LL;\nD0.f32 = NAN.f32')
+    # Fix 2: Remove the S1==DENORM branch (it's wrong), handle at end
+    pcode = pcode.replace('elsif isDENORM(S1.f32) then\nD0.f32 = ldexp(S0.f32, 64)',
+                          'elsif 1 == 0 then\nD0.f32 = S0.f32')
+    # Fix 3: Set VCC=1 for tiny numerator case
+    pcode = pcode.replace('elsif exponent(S2.f32) <= 23 then\n// Numerator is tiny\nD0.f32 = ldexp(S0.f32, 64)',
+                          'elsif exponent(S2.f32) <= 23 then\nVCC = 0x1LL;\nD0.f32 = ldexp(S0.f32, 64)')
+    # Fix 4: Simplify S2/S1==DENORM case (just set VCC, don't check S0==S2)
+    pcode = pcode.replace('elsif isDENORM(S2.f32 / S1.f32) then\nVCC = 0x1LL;\nif S0.f32 == S2.f32 then\n// Only scale the numerator\nD0.f32 = ldexp(S0.f32, 64)\nendif',
+                          'elsif isDENORM(S2.f32 / S1.f32) then\nVCC = 0x1LL;\nD0.f32 = S0.f32')
+    # Fix 5: Add else to nested ifs that don't have D0 assignment
+    pcode = pcode.replace('D0.f32 = ldexp(S0.f32, 64)\nendif\nelsif', 'D0.f32 = ldexp(S0.f32, 64)\nelse\nD0.f32 = S0.f32\nendif\nelsif')
+    # Fix 6: Add else clause to outermost if before final endif, and check for S1==DENORM at end
+    lines = pcode.rstrip().split('\n')
+    for i in range(len(lines) - 1, -1, -1):
+      if lines[i].strip() == 'endif':
+        lines.insert(i, 'else\nD0.f32 = S0.f32')
+        break
+    pcode = '\n'.join(lines) + ';\nif isDENORM(S1.f32) then\nD0.f32 = NAN.f32\nendif'
+  if op_name == 'V_DIV_SCALE_F64':
+    pcode = pcode.replace('S2.f64 / S1.f64 == DENORM.f64', 'isDENORM(S2.f64 / S1.f64)')
+    pcode = pcode.replace('1.0 / S1.f64 == DENORM.f64', 'isDENORM(1.0 / S1.f64)')
+    pcode = pcode.replace('S1.f64 == DENORM.f64', 'isDENORM(S1.f64)')
+    pcode = pcode.replace('D0.f64 = NAN.f64', 'VCC = 0x1LL;\nD0.f64 = NAN.f64')
+    pcode = pcode.replace('elsif isDENORM(S1.f64) then\nD0.f64 = ldexp(S0.f64, 128)',
+                          'elsif 1 == 0 then\nD0.f64 = S0.f64')
+    pcode = pcode.replace('elsif exponent(S2.f64) <= 52 then\n// Numerator is tiny\nD0.f64 = ldexp(S0.f64, 128)',
+                          'elsif exponent(S2.f64) <= 52 then\nVCC = 0x1LL;\nD0.f64 = ldexp(S0.f64, 128)')
+    pcode = pcode.replace('elsif isDENORM(S2.f64 / S1.f64) then\nVCC = 0x1LL;\nif S0.f64 == S2.f64 then\n// Only scale the numerator\nD0.f64 = ldexp(S0.f64, 128)\nendif',
+                          'elsif isDENORM(S2.f64 / S1.f64) then\nVCC = 0x1LL;\nD0.f64 = S0.f64')
+    pcode = pcode.replace('D0.f64 = ldexp(S0.f64, 128)\nendif\nelsif', 'D0.f64 = ldexp(S0.f64, 128)\nelse\nD0.f64 = S0.f64\nendif\nelsif')
+    lines = pcode.rstrip().split('\n')
+    for i in range(len(lines) - 1, -1, -1):
+      if lines[i].strip() == 'endif':
+        lines.insert(i, 'else\nD0.f64 = S0.f64')
+        break
+    pcode = '\n'.join(lines) + ';\nif isDENORM(S1.f64) then\nD0.f64 = NAN.f64\nendif'
   return pcode
 
 @functools.cache
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 06b1911d7e..1efc832974 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -878,7 +878,8 @@ python_alu: dict[Ops, Callable] = {
   Ops.TRUNC: lambda x: x if math.isinf(x) or math.isnan(x) else math.copysign(math.trunc(x), x), Ops.NEG: operator.neg,
   Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt, Ops.CMPLE: operator.le,
   Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
-  Ops.MOD: cmod, Ops.IDIV: cdiv, Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq,
+  Ops.MOD: cmod, Ops.IDIV: cdiv, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq,
+  Ops.MULACC: lambda x,y,z: math.fma(x,y,z) if not (math.isinf(x) or math.isinf(y) or math.isnan(x) or math.isnan(y)) else (x*y)+z,
  Ops.FDIV: lambda x,y: x/y if y != 0 else (math.nan if x == 0 else math.copysign(math.inf, x*y))}
 
 def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
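
# A minimal plain-Python sketch (struct-based, not tinygrad UOps; names here are
# hypothetical) of the denormal test that _call_isDENORM builds as a UOp tree:
# a float32 is denormal iff its biased exponent field is all zeros and its
# mantissa field is non-zero.
import struct

def is_denorm_f32(x: float) -> bool:
  bits = struct.unpack('<I', struct.pack('<f', x))[0]
  exp = (bits >> 23) & 0xff   # 8-bit biased exponent
  mant = bits & 0x7fffff      # 23-bit mantissa
  return exp == 0 and mant != 0

assert is_denorm_f32(1e-40)       # below the smallest normal f32 (~1.18e-38)
assert not is_denorm_f32(0.0)     # zero: mantissa field is also 0
assert not is_denorm_f32(1e-30)   # still a normal float32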
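
# The rewritten _call_mantissa implements V_FREXP_MANT by re-biasing the
# exponent field instead of calling frexp. A standalone sketch of the same bit
# trick (hypothetical helper, float32 only, normals and zero): keep the sign and
# mantissa bits, force the biased exponent to bias-1, and the result lands in
# [0.5, 1.0) with math.frexp()[0] semantics.
import math, struct

def frexp_mant_f32(x: float) -> float:
  bits = struct.unpack('<I', struct.pack('<f', x))[0]
  sign_and_mant = bits & 0x807fffff   # sign bit | 23 mantissa bits
  new_exp = (127 - 1) << 23           # biased exponent of 2**-1
  out = struct.unpack('<f', struct.pack('<I', sign_and_mant | new_exp))[0]
  return x if x == 0.0 else out       # zero passes through, as in the WHERE

# values chosen to be exactly representable in float32
for v in (1.0, 3.5, -12.25, 0.75, 0.0):
  assert frexp_mant_f32(v) == math.frexp(v)[0]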
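
# The ops.py change makes MULACC a true fused multiply-add. Sketch of the new
# semantics (math.fma requires Python >= 3.13; older interpreters would have to
# keep the unfused (x*y)+z): fuse for finite x and y, fall back to the unfused
# form when x or y is inf/nan so special values behave as before.
import math

def mulacc(x: float, y: float, z: float) -> float:
  if math.isinf(x) or math.isinf(y) or math.isnan(x) or math.isnan(y): return (x * y) + z
  return math.fma(x, y, z)   # single rounding at the end

# Why the single rounding matters: x*x rounds away the 2**-104 tail, fma keeps it.
x = 1.0 + 2.0**-52
print((x * x) + -(1.0 + 2.0**-51))      # 0.0
print(mulacc(x, x, -(1.0 + 2.0**-51)))  # 2**-104 (~4.9e-32)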