more complex

This commit is contained in:
George Hotz
2025-12-17 10:27:41 -04:00
parent d142b4eef8
commit 11dc895757

View File

@@ -80,10 +80,12 @@ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
intermediate_dtype = dtypes.float32.vec(d.dtype.count) if d.dtype.base.scalar() == dtypes.float16 else d.dtype
f, e = frexp(d)
ia = (f.cast(intermediate_dtype) * 4.294967296e9).cast(dtypes.uint64)
vc = d.dtype.count
u64, u32, i32 = dtypes.uint64.vec(vc), dtypes.uint32.vec(vc), dtypes.int32.vec(vc)
ia = (f.cast(intermediate_dtype) * 4.294967296e9).cast(u64)
# extract 96 relevant bits of 2/pi based on magnitude of argument
i = shr(e.cast(dtypes.uint64), 5)
e = e.cast(dtypes.int32) & 31
i = shr(e.cast(u64), 5)
e = e.cast(i32) & 31
offset = 32 - e
def _take(an:UOp, offset:int, count:int=0) -> UOp:
@@ -91,8 +93,8 @@ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
if count+offset < len(two_over_pi_f) - 1:
an = i.ne(count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
return an
def _shl_lazy(x:UOp, y:UOp): return (x.cast(dtypes.uint64) * pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
def _shr_lazy(x:UOp, y:UOp): return (x.cast(dtypes.uint64) // pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
def _shl_lazy(x:UOp, y:UOp): return (x.cast(u64) * pow2if(y, d.dtype).cast(u64)).cast(u32)
def _shr_lazy(x:UOp, y:UOp): return (x.cast(u64) // pow2if(y, d.dtype).cast(u64)).cast(u32)
a = [_take(UOp.const(dtypes.uint32.vec(d.dtype.count), 0), i) for i in range(4)]
# (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
@@ -101,12 +103,12 @@ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
mi = _shl_lazy(a[1], e) | _shr_lazy(a[2], offset)
lo = _shl_lazy(a[2], e) | _shr_lazy(a[3], offset)
def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64)
def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(u64) * y.cast(u64)
# compute x * 2/pi
p = shl(_hp_mul(ia, hi), 32) + _hp_mul(ia, mi) + shr(_hp_mul(ia, lo), 32)
# round quotient to nearest
q = shr(p, 62).cast(dtypes.int32)
q = shr(p, 62).cast(i32)
p = p & 0x3fffffffffffffff
r = (p.cast(intermediate_dtype) * (3.4061215800865545e-19)).cast(d.dtype)
@@ -133,7 +135,8 @@ def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:
d = (qdh + q) * -PI_D + d
elif x.dtype.scalar() == dtypes.float16:
# [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
f32 = dtypes.float32.vec(x.dtype.count)
d = _reduce_d(x.cast(f32), q.cast(f32)).cast(x.dtype)
else:
# https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefsp.c#L464-L503
d = q * -3.1414794921875 + x
@@ -143,9 +146,9 @@ def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:
return d
m_1_pi = 0.318309886183790671537767526745028724
qdh = (d * (m_1_pi / 2.0**24)).cast(dtypes.int64).cast(d.dtype) * (2.0**24)
qdh = (d * (m_1_pi / 2.0**24)).cast(dtypes.int64.vec(d.dtype.count)).cast(d.dtype) * (2.0**24)
quadrant = rintk(d * m_1_pi -qdh) if d.dtype.base.scalar() == dtypes.float64 else rintk(d * m_1_pi)
return _reduce_d(d, quadrant.cast(d.dtype)), quadrant.cast(dtypes.int32)
return _reduce_d(d, quadrant.cast(d.dtype)), quadrant.cast(dtypes.int32.vec(d.dtype.count))
# *** approximate sine on small angle. ***
def trig_poly(d:UOp, coeff32, coeff64): return d * (polyN(d*d, coeff64) if d.dtype.scalar() == dtypes.float64 else polyN(d*d, coeff32))
@@ -224,7 +227,8 @@ def xlog2(d:UOp) -> UOp:
"""
assert d.dtype.scalar() in TRANSCENDENTAL_DTYPES
# TODO: float16 denormal need float32 to achieve precision
if d.dtype.scalar() == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
if d.dtype.scalar() == dtypes.float16:
return xlog2(d.cast(dtypes.float32.vec(d.dtype.count))).cast(d.dtype)
FLT_MIN = d.const_like(1e-6 if d.dtype.scalar() == dtypes.float16 else 1e-4)
is_denormal = d<FLT_MIN
a = is_denormal.where(d * (2 ** 64), d)
@@ -261,9 +265,10 @@ def xpow(base:UOp, exponent:UOp) -> UOp:
# start with b ** e = exp2(e * log2(b))
ret = (base < 0).where(-base, base).log2().mul(exponent).exp2()
# negative base adjustment: nan for non-integer exponent and -1 for odd exponent
non_int = exponent != exponent.cast(dtypes.int32).cast(exponent.dtype)
int32_dtype = dtypes.int32.vec(exponent.dtype.count)
non_int = exponent != exponent.cast(int32_dtype).cast(exponent.dtype)
adj = non_int.where(ret.const_like(math.nan),
(exponent < 0).where(-exponent, exponent).cast(dtypes.int32).mod(2).cast(dtypes.bool).where(ret.const_like(-1), ret.const_like(1)))
(exponent < 0).where(-exponent, exponent).cast(int32_dtype).mod(2).cast(dtypes.bool.vec(exponent.dtype.count)).where(ret.const_like(-1), ret.const_like(1)))
# fix 0 ** 0 = 1
return (base.eq(0) & exponent.eq(0)).where(ret.const_like(1), ret * (base < 0).where(adj, ret.const_like(1)))