diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py index 980a946219..26254a006d 100644 --- a/tinygrad/codegen/transcendental.py +++ b/tinygrad/codegen/transcendental.py @@ -91,9 +91,8 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]: # 190 bits of 2/pi for Payne-Hanek style argument reduction two_over_pi_f = [0x00000000,0x28be60db,0x9391054a,0x7f09d5f4,0x7d4d3770,0x36d8a566,0x4f10e410] - input_dtype: DType = d.dtype + input_dtype = d.dtype dtype_via = dtypes.float32 if d.dtype == dtypes.float16 else d.dtype - acc_dtype = dtypes.uint64 f, e = frexp(d) ia = (f.cast(dtype_via) * 4.294967296e9).cast(dtypes.uint64) @@ -108,18 +107,14 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]: if count+offset <= len(two_over_pi_f[0:-2]): an = _eq(i, count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset])) return an - def _exact_pow2if(x): return pow2if(x, input_dtype).cast(acc_dtype) - def _shl_lazy(x, y): return (x.cast(acc_dtype) * _exact_pow2if(y)).cast(dtypes.uint32) - def _shr_lazy(x, y): return (x.cast(acc_dtype) // _exact_pow2if(y)).cast(dtypes.uint32) + def _shl_lazy(x, y): return (x.cast(dtypes.uint64) * pow2if(y, input_dtype).cast(dtypes.uint64)).cast(dtypes.uint32) + def _shr_lazy(x, y): return (x.cast(dtypes.uint64) // pow2if(y, input_dtype).cast(dtypes.uint64)).cast(dtypes.uint32) # a_n = (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e)) - a1 = _take(UOp.const(dtypes.uint32, 0), 0) - a2 = _take(UOp.const(dtypes.uint32, 0), 1) - a3 = _take(UOp.const(dtypes.uint32, 0), 2) - a4 = _take(UOp.const(dtypes.uint32, 0), 3) + a = [_take(UOp.const(dtypes.uint32, 0), i) for i in range(4)] # Note: e >= 1 for all numbers d >= 1.0. assume e != 0 - hi = _shl_lazy(a1, e) | _shr_lazy(a2, offset) - mi = _shl_lazy(a2, e) | _shr_lazy(a3, offset) - lo = _shl_lazy(a3, e) | _shr_lazy(a4, offset) + hi = _shl_lazy(a[0], e) | _shr_lazy(a[1], offset) + mi = _shl_lazy(a[1], e) | _shr_lazy(a[2], offset) + lo = _shl_lazy(a[2], e) | _shr_lazy(a[3], offset) def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64) p = shl(_hp_mul(ia, hi), 32) + _hp_mul(ia, mi) + shr(_hp_mul(ia, lo), 32)