Revert "transcendental works with long decomp" (#14676)

This commit is contained in:
Christopher Milan
2026-02-10 16:46:34 -08:00
committed by GitHub
parent 0662c8037d
commit 389e2eeda1
2 changed files with 22 additions and 25 deletions

View File

@@ -26,7 +26,8 @@ class TestTranscendentalMath(unittest.TestCase):
atol=3e-2, rtol=1e-5) # sin can have bigger atol for very big x
@unittest.skipIf(getenv("MOCKGPU") and Device.DEFAULT in {"NV", "CUDA"}, "crashed")
@given(ht.float32, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)]))
@given(ht.float32, strat.sampled_from([(Tensor.exp, np.exp),(Tensor.log, np.log)] +
([(Tensor.sin, np.sin)] if is_dtype_supported(dtypes.ulong) else [])))
def test_float32(self, x, op):
# wrong nan behavior on Vulkan
if (math.isnan(x) or (x < 0 and op[0] == Tensor.log)) and CI and Device.DEFAULT == "WEBGPU" and not OSX: return
@@ -36,7 +37,8 @@ class TestTranscendentalMath(unittest.TestCase):
atol=2e-5, rtol=1e-5)
@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
@given(ht.float16, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)]))
@given(ht.float16, strat.sampled_from([(Tensor.exp, np.exp),(Tensor.log, np.log)] +
([(Tensor.sin, np.sin)] if is_dtype_supported(dtypes.ulong) else [])))
def test_float16(self, x, op):
# wrong nan behavior on Vulkan
if (math.isnan(x) or (x < 0 and op[0] == Tensor.log)) and CI and Device.DEFAULT == "WEBGPU" and not OSX: return
@@ -59,6 +61,7 @@ class TestTranscendentalMath(unittest.TestCase):
class TestFromFuzzer(unittest.TestCase):
@given(strat.sampled_from(dtypes_float))
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
def test_sin(self, dtype):
if not is_dtype_supported(dtype): return
if dtype == dtypes.float64:
@@ -139,6 +142,7 @@ class TestFloat16Log2(unittest.TestCase):
np.testing.assert_allclose(result, expected, rtol=5e-2, err_msg=f"log2({val})")
class TestTranscendentalSchedule(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
def test_transcendental_sin_fusion(self):
with Context(TRANSCENDENTAL=2):
a = Tensor.empty(10)

View File

@@ -65,7 +65,7 @@ def frexp(v:UOp) -> tuple[UOp, UOp]:
return mantissa, exp
# *** reduction algorithms for sine ***
def payne_hanek_reduction(d:UOp, supports_long:bool=True) -> tuple[UOp, UOp]:
def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
"""
Performs Payne-Hanek Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where
39800.0 <= d <= +Inf
@@ -80,10 +80,10 @@ def payne_hanek_reduction(d:UOp, supports_long:bool=True) -> tuple[UOp, UOp]:
intermediate_dtype = dtypes.float32.vec(d.dtype.count) if d.dtype.base.scalar() == dtypes.float16 else d.dtype
f, e = frexp(d) # NOTE: this implicitly assumes that double support implies long support
ia = (f.cast(intermediate_dtype) * 4.294967296e9)
f, e = frexp(d)
ia = (f.cast(intermediate_dtype) * 4.294967296e9).cast(dtypes.uint64)
# extract 96 relevant bits of 2/pi based on magnitude of argument
i = shr(e.cast(dtypes.uint32), 5)
i = shr(e.cast(dtypes.uint64), 5)
e = e.cast(dtypes.int32) & 31
offset = 32 - e
@@ -92,8 +92,8 @@ def payne_hanek_reduction(d:UOp, supports_long:bool=True) -> tuple[UOp, UOp]:
if count+offset < len(two_over_pi_f) - 1:
an = i.ne(count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
return an
def _shl_lazy(x:UOp, y:UOp): return x * (y >= 32).where(0, pow2if(y, d.dtype).cast(dtypes.uint32))
def _shr_lazy(x:UOp, y:UOp): return (overflow := y >= 32).where(0, x // pow2if(overflow.where(0, y), d.dtype).cast(dtypes.uint32))
def _shl_lazy(x:UOp, y:UOp): return (x.cast(dtypes.uint64) * pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
def _shr_lazy(x:UOp, y:UOp): return (x.cast(dtypes.uint64) // pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
a = [_take(UOp.const(dtypes.uint32.vec(d.dtype.count), 0), i) for i in range(4)]
# (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
@@ -102,21 +102,14 @@ def payne_hanek_reduction(d:UOp, supports_long:bool=True) -> tuple[UOp, UOp]:
mi = _shl_lazy(a[1], e) | _shr_lazy(a[2], offset)
lo = _shl_lazy(a[2], e) | _shr_lazy(a[3], offset)
if supports_long:
def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64)
# compute x * 2/pi
p = shl(_hp_mul(ia, hi), 32) + _hp_mul(ia, mi) + shr(_hp_mul(ia, lo), 32)
# round quotient to nearest
q = shr(p, 62).cast(dtypes.int32)
p = (p & 0x3fffffffffffffff).cast(intermediate_dtype)
else:
zero, ia = UOp.const(dtypes.uint, 0), l2i(Ops.CAST, dtypes.uint64, ia)
# compute x * 2/pi: (ia * { hi, lo }) + (ia * lo, 32)_lo
p = l2i(Ops.ADD, dtypes.uint, *l2i(Ops.MUL, dtypes.uint, *ia, mi, hi), l2i(Ops.MUL, dtypes.uint, *ia, lo, zero)[1], zero)
# round quotient to nearest: q = shr(p, 62) = p_hi >> 30
q = shr(p[1], 30).cast(dtypes.int32)
p = l2i(Ops.CAST, intermediate_dtype, p[0], p[1] & 0x3FFFFFFF)
r = (p * (3.4061215800865545e-19)).cast(d.dtype)
def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64)
# compute x * 2/pi
p = shl(_hp_mul(ia, hi), 32) + _hp_mul(ia, mi) + shr(_hp_mul(ia, lo), 32)
# round quotient to nearest
q = shr(p, 62).cast(dtypes.int32)
p = p & 0x3fffffffffffffff
r = (p.cast(intermediate_dtype) * (3.4061215800865545e-19)).cast(d.dtype)
# if fraction >= 0.5, r -= pi/2, q += 1
return (f<0.5).where(r, r - math.pi/2), (f<0.5).where(q, q + 1)
@@ -176,7 +169,7 @@ def sin_poly_large(d:UOp, q:UOp) -> UOp:
# *** toplevel functions for xsin/xlog2/xexp2 ***
def xsin(d:UOp, ctx:str, fast:bool=False, switch_over:float=30.0) -> UOp:
def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
"""
Implements a 1.0 ULP approximation for Ops.SIN.
- fast=True assumes x <= switch_over.
@@ -188,7 +181,7 @@ def xsin(d:UOp, ctx:str, fast:bool=False, switch_over:float=30.0) -> UOp:
# x_sign = sign(x)
x_sign = x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
x_abs = x * x_sign
r, q = cody_waite_reduction(x_abs) if fast else payne_hanek_reduction(x_abs, is_dtype_supported(dtypes.long, ctx))
r, q = (cody_waite_reduction if fast else payne_hanek_reduction)(x_abs)
if fast: result = sin_poly_small(r, q)
else:
# Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.