mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
remove mla [run_process_replay] (#6357)
* remove mla * other bad uses of const
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
# type: ignore
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# TARGET arch is: []
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# type: ignore
|
||||
import ctypes, ctypes.util, struct, fcntl, re
|
||||
from hexdump import hexdump
|
||||
import pathlib, sys
|
||||
|
||||
@@ -86,9 +86,7 @@ def frexp(v:UOp) -> Tuple[UOp, UOp]:
|
||||
if v.dtype == dtypes.float16: exp = exp.bitcast(dtypes.int16)
|
||||
return value, exp
|
||||
|
||||
def mla(x:UOp, y:UOp, z:UOp) -> UOp: return x * y + z
|
||||
|
||||
def polyN(u:UOp, s:UOp, coeffs:List[float]) -> UOp: return functools.reduce(lambda u,c: mla(u, s, u.const(c)), coeffs, u)
|
||||
def polyN(u:UOp, s:UOp, coeffs:List[float]) -> UOp: return functools.reduce(lambda u,c: u*s+c, coeffs, u)
|
||||
# *** reduction algorithms for sine ***
|
||||
def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
|
||||
"""
|
||||
@@ -122,10 +120,10 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
|
||||
def _shl_lazy(x, y): return (x.cast(acc_dtype) * _exact_pow2if(y)).cast(dtypes.uint32)
|
||||
def _shr_lazy(x, y): return (x.cast(acc_dtype) // _exact_pow2if(y)).cast(dtypes.uint32)
|
||||
# a_n = (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
|
||||
a1 = _take(i.const(0).cast(dtypes.uint32), 0)
|
||||
a2 = _take(i.const(0).cast(dtypes.uint32), 1)
|
||||
a3 = _take(i.const(0).cast(dtypes.uint32), 2)
|
||||
a4 = _take(i.const(0).cast(dtypes.uint32), 3)
|
||||
a1 = _take(UOp.const(dtypes.uint32, 0), 0)
|
||||
a2 = _take(UOp.const(dtypes.uint32, 0), 1)
|
||||
a3 = _take(UOp.const(dtypes.uint32, 0), 2)
|
||||
a4 = _take(UOp.const(dtypes.uint32, 0), 3)
|
||||
# Note: e >= 1 for all numbers d >= 1.0. assume e != 0
|
||||
hi = _shl_lazy(a1, e) | _shr_lazy(a2, offset)
|
||||
mi = _shl_lazy(a2, e) | _shr_lazy(a3, offset)
|
||||
@@ -141,7 +139,7 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
|
||||
r = (p.cast(dtype_via) * (3.4061215800865545e-19)).cast(input_dtype)
|
||||
|
||||
# if fraction >= 0.5, r -= pi/2, q += 1
|
||||
return f.lt(0.5).where(r, r + r.const(-math.pi / 2)), f.lt(0.5).where(q, q + 1)
|
||||
return f.lt(0.5).where(r, r + (-math.pi / 2)), f.lt(0.5).where(q, q + 1)
|
||||
|
||||
def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
|
||||
"""
|
||||
@@ -152,25 +150,25 @@ def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
|
||||
m_1_pi = 0.318309886183790671537767526745028724
|
||||
qdh = (d * (m_1_pi / 16777216)).cast(dtypes.int64).cast(d.dtype) * 16777216.0
|
||||
def _quadrant(x:UOp) -> UOp:
|
||||
if x.dtype == dtypes.float64: return rintk(mla(d, d.const(m_1_pi), -qdh)).cast(x.dtype)
|
||||
if x.dtype == dtypes.float64: return rintk(d * m_1_pi -qdh).cast(x.dtype)
|
||||
return rintk(x * m_1_pi).cast(x.dtype)
|
||||
def _reduce_d(x:UOp, q:UOp):
|
||||
if x.dtype == dtypes.float64:
|
||||
d = mla(qdh, x.const(-3.1415926218032836914), x)
|
||||
d = mla(q, x.const(-3.1415926218032836914), d)
|
||||
d = mla(qdh, x.const(-3.1786509424591713469e-08), d)
|
||||
d = mla(q, x.const(-3.1786509424591713469e-08), d)
|
||||
d = mla(qdh, x.const(-1.2246467864107188502e-16), d)
|
||||
d = mla(q, x.const(-1.2246467864107188502e-16), d)
|
||||
d = mla(qdh + q, x.const(-1.2736634327021899816e-24), d)
|
||||
d = qdh * -3.1415926218032836914 + x
|
||||
d = q * -3.1415926218032836914 + d
|
||||
d = qdh * -3.1786509424591713469e-08 + d
|
||||
d = q * -3.1786509424591713469e-08 + d
|
||||
d = qdh * -1.2246467864107188502e-16 + d
|
||||
d = q * -1.2246467864107188502e-16 + d
|
||||
d = (qdh + q) * -1.2736634327021899816e-24 + d
|
||||
elif x.dtype == dtypes.float16:
|
||||
# [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
|
||||
d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
|
||||
else:
|
||||
d = mla(q, x.const(-3.1414794921875), x)
|
||||
d = mla(q, x.const(-0.00011315941810607910156), d)
|
||||
d = mla(q, x.const(-1.9841872589410058936e-09), d)
|
||||
d = mla(q, x.const(-1.2154201256553420762e-10), d)
|
||||
d = q * -3.1414794921875 + x
|
||||
d = q * -0.00011315941810607910156 + d
|
||||
d = q * -1.9841872589410058936e-09 + d
|
||||
d = q * -1.2154201256553420762e-10 + d
|
||||
return d
|
||||
return _reduce_d(d, (q := _quadrant(d))), q.cast(dtypes.int32)
|
||||
# *** approximate sine on small angle. ***
|
||||
@@ -180,13 +178,13 @@ def trig_poly(d:UOp, coeff32, coeff64):
|
||||
if d.dtype == dtypes.float64:
|
||||
s2 = s * s
|
||||
s4 = s2 * s2
|
||||
def __poly4(x:UOp, x2:UOp, c3, c2, c1, c0) -> UOp: return mla(x2, mla(x, x.const(c3), x.const(c2)), mla(x, x.const(c1), x.const(c0)))
|
||||
def __poly8(x, x2, x4, c7, c6, c5, c4, c3, c2, c1, c0) -> UOp: return mla(x4, __poly4(x, x2, c7, c6, c5, c4), __poly4(x, x2, c3, c2, c1, c0))
|
||||
def __poly4(x:UOp, x2:UOp, c3, c2, c1, c0) -> UOp: return x2 * (x*c3+c2) + (x*c1+c0)
|
||||
def __poly8(x, x2, x4, c7, c6, c5, c4, c3, c2, c1, c0) -> UOp: return x4 * __poly4(x, x2, c7, c6, c5, c4) + __poly4(x, x2, c3, c2, c1, c0)
|
||||
u = __poly8(s, s2, s4, *coeff64[:-1])
|
||||
u = mla(u, s, d.const(coeff64[-1]))
|
||||
u = u * s + coeff64[-1]
|
||||
else:
|
||||
u = polyN(s.const(coeff32[0]), s, coeff32[1:])
|
||||
return mla(s, u * d, d)
|
||||
return s * (u * d) + d
|
||||
# approximate sine on [-pi/2, pi/2]
|
||||
def sin_poly(d:UOp) -> UOp:
|
||||
return trig_poly(d, [2.6083159809786593541503e-06, -0.0001981069071916863322258, 0.00833307858556509017944336, -0.166666597127914428710938],
|
||||
@@ -288,7 +286,7 @@ def xlog2(d:UOp) -> UOp:
|
||||
t = polyN(x.const(0.2211941750456081490e+0), x2, [0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
|
||||
0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
|
||||
s_hi, s_lo = dfadd2_f2_f2_f2(e, e.const(0), *dfmul2_f2_f2_f2(t.const(2.885390081777926774), t.const(0), x, x.const(0)))
|
||||
r = mla(t, x * x2, s_hi + s_lo)
|
||||
r = t * (x * x2) + (s_hi + s_lo)
|
||||
else:
|
||||
xx, xy = dfdiv2_f2_f2_f2(*dfadd2_f2_f2_f2(m.const(-1), m.const(0), m, m.const(0)), *dfadd2_f2_f2_f2(m.const(1), m.const(0), m, m.const(0)))
|
||||
x2 = xx * xx
|
||||
|
||||
Reference in New Issue
Block a user