minor cleanup payne_hanek_reduction [pr] (#7383)

This commit is contained in:
chenyu
2024-10-29 17:59:18 -04:00
committed by GitHub
parent f6abde95fa
commit 99b82f5708

View File

@@ -91,9 +91,8 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
# 190 bits of 2/pi for Payne-Hanek style argument reduction
two_over_pi_f = [0x00000000,0x28be60db,0x9391054a,0x7f09d5f4,0x7d4d3770,0x36d8a566,0x4f10e410]
input_dtype: DType = d.dtype
input_dtype = d.dtype
dtype_via = dtypes.float32 if d.dtype == dtypes.float16 else d.dtype
acc_dtype = dtypes.uint64
f, e = frexp(d)
ia = (f.cast(dtype_via) * 4.294967296e9).cast(dtypes.uint64)
@@ -108,18 +107,14 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
if count+offset <= len(two_over_pi_f[0:-2]):
an = _eq(i, count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
return an
def _exact_pow2if(x): return pow2if(x, input_dtype).cast(acc_dtype)
def _shl_lazy(x, y): return (x.cast(acc_dtype) * _exact_pow2if(y)).cast(dtypes.uint32)
def _shr_lazy(x, y): return (x.cast(acc_dtype) // _exact_pow2if(y)).cast(dtypes.uint32)
def _shl_lazy(x, y): return (x.cast(dtypes.uint64) * pow2if(y, input_dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
def _shr_lazy(x, y): return (x.cast(dtypes.uint64) // pow2if(y, input_dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
# a_n = (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
a1 = _take(UOp.const(dtypes.uint32, 0), 0)
a2 = _take(UOp.const(dtypes.uint32, 0), 1)
a3 = _take(UOp.const(dtypes.uint32, 0), 2)
a4 = _take(UOp.const(dtypes.uint32, 0), 3)
a = [_take(UOp.const(dtypes.uint32, 0), i) for i in range(4)]
# Note: e >= 1 for all numbers d >= 1.0. assume e != 0
hi = _shl_lazy(a1, e) | _shr_lazy(a2, offset)
mi = _shl_lazy(a2, e) | _shr_lazy(a3, offset)
lo = _shl_lazy(a3, e) | _shr_lazy(a4, offset)
hi = _shl_lazy(a[0], e) | _shr_lazy(a[1], offset)
mi = _shl_lazy(a[1], e) | _shr_lazy(a[2], offset)
lo = _shl_lazy(a[2], e) | _shr_lazy(a[3], offset)
def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64)
p = shl(_hp_mul(ia, hi), 32) + _hp_mul(ia, mi) + shr(_hp_mul(ia, lo), 32)