mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 08:48:15 -05:00
minor cleanup payne_hanek_reduction [pr] (#7383)
This commit is contained in:
@@ -91,9 +91,8 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
|
||||
# 190 bits of 2/pi for Payne-Hanek style argument reduction
|
||||
two_over_pi_f = [0x00000000,0x28be60db,0x9391054a,0x7f09d5f4,0x7d4d3770,0x36d8a566,0x4f10e410]
|
||||
|
||||
input_dtype: DType = d.dtype
|
||||
input_dtype = d.dtype
|
||||
dtype_via = dtypes.float32 if d.dtype == dtypes.float16 else d.dtype
|
||||
acc_dtype = dtypes.uint64
|
||||
|
||||
f, e = frexp(d)
|
||||
ia = (f.cast(dtype_via) * 4.294967296e9).cast(dtypes.uint64)
|
||||
@@ -108,18 +107,14 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
|
||||
if count+offset <= len(two_over_pi_f[0:-2]):
|
||||
an = _eq(i, count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
|
||||
return an
|
||||
def _exact_pow2if(x): return pow2if(x, input_dtype).cast(acc_dtype)
|
||||
def _shl_lazy(x, y): return (x.cast(acc_dtype) * _exact_pow2if(y)).cast(dtypes.uint32)
|
||||
def _shr_lazy(x, y): return (x.cast(acc_dtype) // _exact_pow2if(y)).cast(dtypes.uint32)
|
||||
def _shl_lazy(x, y): return (x.cast(dtypes.uint64) * pow2if(y, input_dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
|
||||
def _shr_lazy(x, y): return (x.cast(dtypes.uint64) // pow2if(y, input_dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
|
||||
# a_n = (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
|
||||
a1 = _take(UOp.const(dtypes.uint32, 0), 0)
|
||||
a2 = _take(UOp.const(dtypes.uint32, 0), 1)
|
||||
a3 = _take(UOp.const(dtypes.uint32, 0), 2)
|
||||
a4 = _take(UOp.const(dtypes.uint32, 0), 3)
|
||||
a = [_take(UOp.const(dtypes.uint32, 0), i) for i in range(4)]
|
||||
# Note: e >= 1 for all numbers d >= 1.0. assume e != 0
|
||||
hi = _shl_lazy(a1, e) | _shr_lazy(a2, offset)
|
||||
mi = _shl_lazy(a2, e) | _shr_lazy(a3, offset)
|
||||
lo = _shl_lazy(a3, e) | _shr_lazy(a4, offset)
|
||||
hi = _shl_lazy(a[0], e) | _shr_lazy(a[1], offset)
|
||||
mi = _shl_lazy(a[1], e) | _shr_lazy(a[2], offset)
|
||||
lo = _shl_lazy(a[2], e) | _shr_lazy(a[3], offset)
|
||||
|
||||
def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64)
|
||||
p = shl(_hp_mul(ia, hi), 32) + _hp_mul(ia, mi) + shr(_hp_mul(ia, lo), 32)
|
||||
|
||||
Reference in New Issue
Block a user