mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
more complex
This commit is contained in:
@@ -80,10 +80,12 @@ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
|
||||
intermediate_dtype = dtypes.float32.vec(d.dtype.count) if d.dtype.base.scalar() == dtypes.float16 else d.dtype
|
||||
|
||||
f, e = frexp(d)
|
||||
ia = (f.cast(intermediate_dtype) * 4.294967296e9).cast(dtypes.uint64)
|
||||
vc = d.dtype.count
|
||||
u64, u32, i32 = dtypes.uint64.vec(vc), dtypes.uint32.vec(vc), dtypes.int32.vec(vc)
|
||||
ia = (f.cast(intermediate_dtype) * 4.294967296e9).cast(u64)
|
||||
# extract 96 relevant bits of 2/pi based on magnitude of argument
|
||||
i = shr(e.cast(dtypes.uint64), 5)
|
||||
e = e.cast(dtypes.int32) & 31
|
||||
i = shr(e.cast(u64), 5)
|
||||
e = e.cast(i32) & 31
|
||||
offset = 32 - e
|
||||
|
||||
def _take(an:UOp, offset:int, count:int=0) -> UOp:
|
||||
@@ -91,8 +93,8 @@ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
|
||||
if count+offset < len(two_over_pi_f) - 1:
|
||||
an = i.ne(count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
|
||||
return an
|
||||
def _shl_lazy(x:UOp, y:UOp): return (x.cast(dtypes.uint64) * pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
|
||||
def _shr_lazy(x:UOp, y:UOp): return (x.cast(dtypes.uint64) // pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
|
||||
# Multiply-based stand-in for the `<<` terms of the combine step below: widen x to u64, multiply by
# pow2if(y, d.dtype) cast to u64, then narrow the product to u32.
# NOTE(review): presumably pow2if(y, ...) yields 2**y -- confirm against its definition.
def _shl_lazy(x:UOp, y:UOp): return (x.cast(u64) * pow2if(y, d.dtype).cast(u64)).cast(u32)
|
||||
# Division-based stand-in for the `>>` terms of the combine step below: widen x to u64, floor-divide by
# pow2if(y, d.dtype) cast to u64, then narrow the quotient to u32.
# NOTE(review): presumably pow2if(y, ...) yields 2**y -- confirm against its definition.
def _shr_lazy(x:UOp, y:UOp): return (x.cast(u64) // pow2if(y, d.dtype).cast(u64)).cast(u32)
|
||||
|
||||
a = [_take(UOp.const(dtypes.uint32.vec(d.dtype.count), 0), i) for i in range(4)]
|
||||
# (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
|
||||
@@ -101,12 +103,12 @@ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
|
||||
mi = _shl_lazy(a[1], e) | _shr_lazy(a[2], offset)
|
||||
lo = _shl_lazy(a[2], e) | _shr_lazy(a[3], offset)
|
||||
|
||||
def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64)
|
||||
# Cast both operands to the (vector-width-matched) u64 dtype before multiplying, so the product is formed in u64.
def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(u64) * y.cast(u64)
|
||||
# compute x * 2/pi
|
||||
p = shl(_hp_mul(ia, hi), 32) + _hp_mul(ia, mi) + shr(_hp_mul(ia, lo), 32)
|
||||
|
||||
# round quotient to nearest
|
||||
q = shr(p, 62).cast(dtypes.int32)
|
||||
q = shr(p, 62).cast(i32)
|
||||
p = p & 0x3fffffffffffffff
|
||||
r = (p.cast(intermediate_dtype) * (3.4061215800865545e-19)).cast(d.dtype)
|
||||
|
||||
@@ -133,7 +135,8 @@ def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:
|
||||
d = (qdh + q) * -PI_D + d
|
||||
elif x.dtype.scalar() == dtypes.float16:
|
||||
# [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
|
||||
d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
|
||||
f32 = dtypes.float32.vec(x.dtype.count)
|
||||
d = _reduce_d(x.cast(f32), q.cast(f32)).cast(x.dtype)
|
||||
else:
|
||||
# https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefsp.c#L464-L503
|
||||
d = q * -3.1414794921875 + x
|
||||
@@ -143,9 +146,9 @@ def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:
|
||||
return d
|
||||
|
||||
m_1_pi = 0.318309886183790671537767526745028724
|
||||
qdh = (d * (m_1_pi / 2.0**24)).cast(dtypes.int64).cast(d.dtype) * (2.0**24)
|
||||
qdh = (d * (m_1_pi / 2.0**24)).cast(dtypes.int64.vec(d.dtype.count)).cast(d.dtype) * (2.0**24)
|
||||
quadrant = rintk(d * m_1_pi -qdh) if d.dtype.base.scalar() == dtypes.float64 else rintk(d * m_1_pi)
|
||||
return _reduce_d(d, quadrant.cast(d.dtype)), quadrant.cast(dtypes.int32)
|
||||
return _reduce_d(d, quadrant.cast(d.dtype)), quadrant.cast(dtypes.int32.vec(d.dtype.count))
|
||||
|
||||
# *** approximate sine on small angle. ***
|
||||
def trig_poly(d:UOp, coeff32, coeff64):
  # float64 inputs take the coeff64 table; every other dtype takes coeff32
  coeffs = coeff64 if d.dtype.scalar() == dtypes.float64 else coeff32
  # the result is d multiplied by polyN evaluated at d*d with the chosen coefficients
  return d * polyN(d*d, coeffs)
|
||||
@@ -224,7 +227,8 @@ def xlog2(d:UOp) -> UOp:
|
||||
"""
|
||||
assert d.dtype.scalar() in TRANSCENDENTAL_DTYPES
|
||||
# TODO: float16 denormals need float32 to achieve precision
|
||||
if d.dtype.scalar() == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
|
||||
if d.dtype.scalar() == dtypes.float16:
|
||||
return xlog2(d.cast(dtypes.float32.vec(d.dtype.count))).cast(d.dtype)
|
||||
FLT_MIN = d.const_like(1e-6 if d.dtype.scalar() == dtypes.float16 else 1e-4)
|
||||
is_denormal = d<FLT_MIN
|
||||
a = is_denormal.where(d * (2 ** 64), d)
|
||||
@@ -261,9 +265,10 @@ def xpow(base:UOp, exponent:UOp) -> UOp:
|
||||
# start with b ** e = exp2(e * log2(b))
|
||||
ret = (base < 0).where(-base, base).log2().mul(exponent).exp2()
|
||||
# negative base adjustment: nan for non-integer exponent and -1 for odd exponent
|
||||
non_int = exponent != exponent.cast(dtypes.int32).cast(exponent.dtype)
|
||||
int32_dtype = dtypes.int32.vec(exponent.dtype.count)
|
||||
non_int = exponent != exponent.cast(int32_dtype).cast(exponent.dtype)
|
||||
adj = non_int.where(ret.const_like(math.nan),
|
||||
(exponent < 0).where(-exponent, exponent).cast(dtypes.int32).mod(2).cast(dtypes.bool).where(ret.const_like(-1), ret.const_like(1)))
|
||||
(exponent < 0).where(-exponent, exponent).cast(int32_dtype).mod(2).cast(dtypes.bool.vec(exponent.dtype.count)).where(ret.const_like(-1), ret.const_like(1)))
|
||||
# fix 0 ** 0 = 1
|
||||
return (base.eq(0) & exponent.eq(0)).where(ret.const_like(1), ret * (base < 0).where(adj, ret.const_like(1)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user