From 389e2eeda1447fdac7449b951c1c1683624367ce Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Tue, 10 Feb 2026 16:46:34 -0800 Subject: [PATCH] Revert "transcendental works with long decomp" (#14676) --- test/test_transcendental.py | 8 +++++-- tinygrad/uop/decompositions.py | 39 ++++++++++++++-------------------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/test/test_transcendental.py b/test/test_transcendental.py index d9dbeae074..7c2ddbc87f 100644 --- a/test/test_transcendental.py +++ b/test/test_transcendental.py @@ -26,7 +26,8 @@ class TestTranscendentalMath(unittest.TestCase): atol=3e-2, rtol=1e-5) # sin can have bigger atol for very big x @unittest.skipIf(getenv("MOCKGPU") and Device.DEFAULT in {"NV", "CUDA"}, "crashed") - @given(ht.float32, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)])) + @given(ht.float32, strat.sampled_from([(Tensor.exp, np.exp),(Tensor.log, np.log)] + + ([(Tensor.sin, np.sin)] if is_dtype_supported(dtypes.ulong) else []))) def test_float32(self, x, op): # wrong nan behavior on Vulkan if (math.isnan(x) or (x < 0 and op[0] == Tensor.log)) and CI and Device.DEFAULT == "WEBGPU" and not OSX: return @@ -36,7 +37,8 @@ class TestTranscendentalMath(unittest.TestCase): atol=2e-5, rtol=1e-5) @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}") - @given(ht.float16, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)])) + @given(ht.float16, strat.sampled_from([(Tensor.exp, np.exp),(Tensor.log, np.log)] + + ([(Tensor.sin, np.sin)] if is_dtype_supported(dtypes.ulong) else []))) def test_float16(self, x, op): # wrong nan behavior on Vulkan if (math.isnan(x) or (x < 0 and op[0] == Tensor.log)) and CI and Device.DEFAULT == "WEBGPU" and not OSX: return @@ -59,6 +61,7 @@ class TestTranscendentalMath(unittest.TestCase): class TestFromFuzzer(unittest.TestCase): @given(strat.sampled_from(dtypes_float)) + @unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong") def test_sin(self, dtype): if not is_dtype_supported(dtype): return if dtype == dtypes.float64: @@ -139,6 +142,7 @@ class TestFloat16Log2(unittest.TestCase): np.testing.assert_allclose(result, expected, rtol=5e-2, err_msg=f"log2({val})") class TestTranscendentalSchedule(unittest.TestCase): + @unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong") def test_transcendental_sin_fusion(self): with Context(TRANSCENDENTAL=2): a = Tensor.empty(10) diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index 3485ac22f8..1689f45ab8 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -65,7 +65,7 @@ def frexp(v:UOp) -> tuple[UOp, UOp]: return mantissa, exp # *** reduction algorithms for sine *** -def payne_hanek_reduction(d:UOp, supports_long:bool=True) -> tuple[UOp, UOp]: +def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]: """ Performs Payne-Hanek Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where 39800.0 <= d <= +Inf @@ -80,10 +80,10 @@ def payne_hanek_reduction(d:UOp, supports_long:bool=True) -> tuple[UOp, UOp]: intermediate_dtype = dtypes.float32.vec(d.dtype.count) if d.dtype.base.scalar() == dtypes.float16 else d.dtype - f, e = frexp(d) # NOTE: this implicitly assumes that double support implies long support - ia = (f.cast(intermediate_dtype) * 4.294967296e9) + f, e = frexp(d) + ia = (f.cast(intermediate_dtype) * 4.294967296e9).cast(dtypes.uint64) # extract 96 relevant bits of 2/pi based on magnitude of argument - i = shr(e.cast(dtypes.uint32), 5) + i = shr(e.cast(dtypes.uint64), 5) e = e.cast(dtypes.int32) & 31 offset = 32 - e @@ -92,8 +92,8 @@ def payne_hanek_reduction(d:UOp, supports_long:bool=True) -> tuple[UOp, UOp]: if count+offset < len(two_over_pi_f) - 1: an = i.ne(count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset])) return an - def _shl_lazy(x:UOp, y:UOp): return x * (y >= 32).where(0, pow2if(y, d.dtype).cast(dtypes.uint32)) - def _shr_lazy(x:UOp, y:UOp): return (overflow := y >= 32).where(0, x // pow2if(overflow.where(0, y), d.dtype).cast(dtypes.uint32)) + def _shl_lazy(x:UOp, y:UOp): return (x.cast(dtypes.uint64) * pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32) + def _shr_lazy(x:UOp, y:UOp): return (x.cast(dtypes.uint64) // pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32) a = [_take(UOp.const(dtypes.uint32.vec(d.dtype.count), 0), i) for i in range(4)] # (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e)) @@ -102,21 +102,14 @@ def payne_hanek_reduction(d:UOp, supports_long:bool=True) -> tuple[UOp, UOp]: mi = _shl_lazy(a[1], e) | _shr_lazy(a[2], offset) lo = _shl_lazy(a[2], e) | _shr_lazy(a[3], offset) - if supports_long: - def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64) - # compute x * 2/pi - p = shl(_hp_mul(ia, hi), 32) + _hp_mul(ia, mi) + shr(_hp_mul(ia, lo), 32) - # round quotient to nearest - q = shr(p, 62).cast(dtypes.int32) - p = (p & 0x3fffffffffffffff).cast(intermediate_dtype) - else: - zero, ia = UOp.const(dtypes.uint, 0), l2i(Ops.CAST, dtypes.uint64, ia) - # compute x * 2/pi: (ia * { hi, lo }) + (ia * lo, 32)_lo - p = l2i(Ops.ADD, dtypes.uint, *l2i(Ops.MUL, dtypes.uint, *ia, mi, hi), l2i(Ops.MUL, dtypes.uint, *ia, lo, zero)[1], zero) - # round quotient to nearest: q = shr(p, 62) = p_hi >> 30 - q = shr(p[1], 30).cast(dtypes.int32) - p = l2i(Ops.CAST, intermediate_dtype, p[0], p[1] & 0x3FFFFFFF) - r = (p * (3.4061215800865545e-19)).cast(d.dtype) + def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64) + # compute x * 2/pi + p = shl(_hp_mul(ia, hi), 32) + _hp_mul(ia, mi) + shr(_hp_mul(ia, lo), 32) + + # round quotient to nearest + q = shr(p, 62).cast(dtypes.int32) + p = p & 0x3fffffffffffffff + r = (p.cast(intermediate_dtype) * (3.4061215800865545e-19)).cast(d.dtype) # if fraction >= 0.5, r -= pi/2, q += 1 return (f<0.5).where(r, r - math.pi/2), (f<0.5).where(q, q + 1) @@ -176,7 +169,7 @@ def sin_poly_large(d:UOp, q:UOp) -> UOp: # *** toplevel functions for xsin/xlog2/xexp2 *** -def xsin(d:UOp, ctx:str, fast:bool=False, switch_over:float=30.0) -> UOp: +def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp: """ Implements a 1.0 ULP approximation for Ops.SIN. - fast=True assumes x <= switch_over. @@ -188,7 +181,7 @@ def xsin(d:UOp, ctx:str, fast:bool=False, switch_over:float=30.0) -> UOp: # x_sign = sign(x) x_sign = x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0)) x_abs = x * x_sign - r, q = cody_waite_reduction(x_abs) if fast else payne_hanek_reduction(x_abs, is_dtype_supported(dtypes.long, ctx)) + r, q = (cody_waite_reduction if fast else payne_hanek_reduction)(x_abs) if fast: result = sin_poly_small(r, q) else: # Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.