From de6ab564583154c01ac5472187d4a08ad07371c6 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Sat, 13 Jul 2024 14:00:14 -0400
Subject: [PATCH] clean up transcend math with uop syntactic sugar [run_process_replay] (#5455)

* clean up transcend math with uop syntactic sugar [run_process_replay]

* that?

* maybe
---
 tinygrad/codegen/transcendental.py | 163 ++++++++++++++---------------
 tinygrad/codegen/uops.py           |   3 +-
 2 files changed, 79 insertions(+), 87 deletions(-)

diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py
index 5f2c72aeb5..e4229d718c 100644
--- a/tinygrad/codegen/transcendental.py
+++ b/tinygrad/codegen/transcendental.py
@@ -1,14 +1,13 @@
 import math
 from typing import Tuple, List
 from tinygrad.dtype import dtypes, DType
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
 from tinygrad.codegen.uops import UOp
 TRANSCENDENTAL_SUPPORTED_DTYPES = {dtypes.float16, dtypes.float32, dtypes.float64}
 
 def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
   """replace inf -> inf, -inf -> _inf, nan -> nan, otherwise -> ratio"""
-  return x.e(BinaryOps.CMPNE, x.const(math.inf)).e(TernaryOps.WHERE, x.e(BinaryOps.CMPNE, x).e(TernaryOps.WHERE, nan, x.e(BinaryOps.CMPNE, x.const(-math.inf)).e(TernaryOps.WHERE, ratio, _inf)), inf) # noqa: E501
+  return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf)
 
 # *** helper functions for double/quad precision arithmetics ***
 def dfadd2_f2_f2_f2(xx:UOp, xy:UOp, yx:UOp, yy:UOp) -> Tuple[UOp, UOp]: return xx + yx, xy + yy
 def dfmul2_f2_f2_f2(xx:UOp, xy:UOp, yx:UOp, yy:UOp) -> Tuple[UOp, UOp]: return xx * yx, xx * yy + xy * yx
@@ -40,20 +39,20 @@ def bits_to_float(d:UOp, float_dtype:DType) -> UOp:
   cast_to = {dtypes.uint64: dtypes.float64, dtypes.uint32: dtypes.float32, dtypes.uint16: float_dtype}[d.dtype]
   return d.bitcast(cast_to)
 # **** utils ****
-def shr(x:UOp, y:int) -> UOp: return x.e(BinaryOps.IDIV, x.const(2**y))
-def shl(x:UOp, y:int) -> UOp: return x.e(BinaryOps.MUL, x.const(2**y))
+def shr(x:UOp, y:int) -> UOp: return x // (2**y)
+def shl(x:UOp, y:int) -> UOp: return x * (2**y)
 
 def rintk(d:UOp) -> UOp:
   """ceiling(d:float) -> int"""
   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
   return_t = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
-  return d.e(BinaryOps.ADD, d.e(BinaryOps.CMPLT, d.const(0.0)).e(TernaryOps.WHERE, d.const(-0.5), d.const(0.5))).cast(return_t)
+  return (d + d.lt(0.0).where(d.const(-0.5), d.const(0.5))).cast(return_t)
 
 def pow2if(q:UOp, float_dtype:DType):
   """cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
   assert q.dtype in (dtypes.int64, dtypes.int32, dtypes.int16, dtypes.uint32)
   final_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype, dtypes.uint32: dtypes.float32}[q.dtype]
-  return shl(q.e(BinaryOps.ADD, q.const(exponent_bias(final_dtype)+1)), significand_bits(final_dtype)).bitcast(final_dtype)
+  return shl((q + (exponent_bias(final_dtype)+1)), significand_bits(final_dtype)).bitcast(final_dtype)
 
 def ilogb2k(d:UOp) -> UOp:
   """calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf)."""
@@ -61,7 +60,7 @@ def ilogb2k(d:UOp) -> UOp:
   dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype])
   # -1 <= ilog2bk(d) <= 128
   # ((float_to_bits(d) >> significand_bits(dtype)) & exponent_mask(dtype)) - exponent_bias(dtype)
-  return shr(dint, significand_bits(d.dtype)).e(BinaryOps.AND, dint.const(exponent_mask(d.dtype))).e(BinaryOps.ADD, dint.const(-(exponent_bias(d.dtype)+1))) # noqa: E501
+  return (shr(dint, significand_bits(d.dtype)) & exponent_mask(d.dtype)) - (exponent_bias(d.dtype)+1)
 
 def ldexp3k(d:UOp, e:UOp) -> UOp:
   """d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number."""
@@ -71,12 +70,12 @@ def ldexp3k(d:UOp, e:UOp) -> UOp:
   e = e.cast(cast_map[d.dtype])
   m1 = d.bitcast(cast_map[d.dtype])
   m2 = shl(e, significand_bits(d.dtype))
-  return m1.e(BinaryOps.ADD, m2).bitcast(d.dtype).cast(dtype)
+  return (m1 + m2).bitcast(d.dtype).cast(dtype)
 
 def ldexp2k(d:UOp, e:UOp) -> UOp:
   """d*2^e. much faster than ldexp3k but risky. d > 0 and d is not denormal."""
   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in (dtypes.int16, dtypes.int32, dtypes.int64)
-  return d.e(BinaryOps.MUL, pow2if(shr(e, 1), d.dtype)).e(BinaryOps.MUL, pow2if(e.e(BinaryOps.ADD, shr(e, 1).e(UnaryOps.NEG)), d.dtype))
+  return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)
 
 def frexp(v:UOp) -> Tuple[UOp, UOp]:
   """frexp(v) -> (mantissa, exponent)"""
@@ -86,23 +85,19 @@ def frexp(v:UOp) -> Tuple[UOp, UOp]:
   m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3C00}[v.dtype]
   bias = {dtypes.float64: 1022, dtypes.float32: 126, dtypes.float16: 15}[v.dtype]
   bits = float_to_bits(v)
-  exponent = shr(bits, significand_bits(v.dtype)).e(BinaryOps.AND, bits.const(exponent_mask(v.dtype)))
-  exponent_zero = exponent.e(BinaryOps.CMPNE, exponent.const(0.0))
-  result_f = bits_to_float(bits.e(BinaryOps.AND, bits.const(m1)).e(BinaryOps.OR, bits.const(m2)), v.dtype)
-  value = exponent_zero.e(TernaryOps.WHERE, result_f, v)
-  exp = exponent.e(BinaryOps.ADD, exponent.const(-bias))
-  exp = exponent_zero.e(TernaryOps.WHERE, exp, exp.const(0))
-  if v.dtype == dtypes.float16:
-    exp = exp.bitcast(dtypes.int16)
+  exponent = shr(bits, significand_bits(v.dtype)) & exponent_mask(v.dtype)
+  exponent_zero = exponent.ne(0.0)
+  result_f = bits_to_float((bits & m1) | m2, v.dtype)
+  value = exponent_zero.where(result_f, v)
+  exp = exponent + (-bias)
+  exp = exponent_zero.where(exp, exp.const(0))
+  if v.dtype == dtypes.float16: exp = exp.bitcast(dtypes.int16)
   return value, exp
 
-def mla(x:UOp, y:UOp, z:UOp) -> UOp:
-  """x*y+z"""
-  return x.e(BinaryOps.MUL, y).e(BinaryOps.ADD, z)
+def mla(x:UOp, y:UOp, z:UOp) -> UOp: return x * y + z
 
 def polyN(u:UOp, s:UOp, coeffs:List[float]) -> UOp:
-  for c in coeffs:
-    u = mla(u, s, u.const(c))
+  for c in coeffs: u = mla(u, s, u.const(c))
   return u
 # *** reduction algorithms for sine ***
 def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
@@ -122,46 +117,46 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
   acc_dtype = dtypes.uint64
 
   f, e = frexp(d)
-  ia = (k := f.cast(dtype_via)).e(BinaryOps.MUL, k.const(4.294967296e9)).cast(dtypes.uint64)
+  ia = (f.cast(dtype_via) * 4.294967296e9).cast(dtypes.uint64)
   i = shr(e.cast(dtypes.uint64), 5)
-  e = (k := e.cast(dtypes.uint64)).e(BinaryOps.AND, k.const(31)).cast(dtypes.uint32)
-  offset = e.const(32).e(BinaryOps.ADD, e.e(UnaryOps.NEG))
+  e = (e.cast(dtypes.uint64) & 31).cast(dtypes.uint32)
+  offset = -e + 32
 
-  def _eq(arr:UOp, eq_to:int) -> UOp: return arr.e(BinaryOps.CMPNE, arr.const(eq_to))
+  def _eq(arr:UOp, eq_to:int) -> UOp: return arr.ne(eq_to)
   def _take(an:UOp, offset:int, count:int=0) -> UOp:
     """an = two_over_pi_f[i+offset]"""
     if count+offset <= len(two_over_pi_f[0:-2]):
-      an = _eq(i, count).e(TernaryOps.WHERE, _take(an, offset, count=count+1), an.const(two_over_pi_f[count+offset]))
+      an = _eq(i, count).where(_take(an, offset, count=count+1), an.const(two_over_pi_f[count+offset]))
     return an
 
   def _exact_pow2if(x): return pow2if(x, input_dtype).cast(acc_dtype)
-  def _shl_lazy(x, y): return x.cast(acc_dtype).e(BinaryOps.MUL, _exact_pow2if(y)).cast(dtypes.uint32)
-  def _shr_lazy(x, y): return x.cast(acc_dtype).e(BinaryOps.IDIV, _exact_pow2if(y)).cast(dtypes.uint32)
+  def _shl_lazy(x, y): return (x.cast(acc_dtype) * _exact_pow2if(y)).cast(dtypes.uint32)
+  def _shr_lazy(x, y): return (x.cast(acc_dtype) // _exact_pow2if(y)).cast(dtypes.uint32)
   # a_n = (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
   a1 = _take(i.const(0).cast(dtypes.uint32), 0)
   a2 = _take(i.const(0).cast(dtypes.uint32), 1)
   a3 = _take(i.const(0).cast(dtypes.uint32), 2)
   a4 = _take(i.const(0).cast(dtypes.uint32), 3)
   # Note: e >= 1 for all numbers d >= 1.0. assume e != 0
-  hi = _shl_lazy(a1, e).e(BinaryOps.OR, _shr_lazy(a2, offset))
-  mi = _shl_lazy(a2, e).e(BinaryOps.OR, _shr_lazy(a3, offset))
-  lo = _shl_lazy(a3, e).e(BinaryOps.OR, _shr_lazy(a4, offset))
+  hi = _shl_lazy(a1, e) | _shr_lazy(a2, offset)
+  mi = _shl_lazy(a2, e) | _shr_lazy(a3, offset)
+  lo = _shl_lazy(a3, e) | _shr_lazy(a4, offset)
 
-  def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64).e(BinaryOps.MUL, y.cast(dtypes.uint64))
+  def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64)
   p = _hp_mul(ia, lo)
-  p = _hp_mul(ia, mi).e(BinaryOps.ADD, shr(p, 32))
-  p = shl(_hp_mul(ia, hi), 32).e(BinaryOps.ADD, p)
+  p = _hp_mul(ia, mi) + shr(p, 32)
+  p = shl(_hp_mul(ia, hi), 32) + p
 
   q = shr(p, 62).cast(dtypes.int32)
-  p = p.e(BinaryOps.AND, p.const(0x3fffffffffffffff))
+  p = p & 0x3fffffffffffffff
   d = p.cast(dtype_via)
-  d = d.e(BinaryOps.MUL, d.const(3.4061215800865545e-19))
+  d = d * (3.4061215800865545e-19)
   r = d.cast(input_dtype)
 
-  fraction_map = f.e(BinaryOps.CMPLT, f.const(0.5))
+  fraction_map = f.lt(0.5)
   # if fraction >= 0.5, r -= pi/2, q += 1
-  r = fraction_map.e(TernaryOps.WHERE, r, r.e(BinaryOps.ADD, r.const(-math.pi / 2)))
-  q = fraction_map.e(TernaryOps.WHERE, q, q.e(BinaryOps.ADD, q.const(1)))
+  r = fraction_map.where(r, r + r.const(-math.pi / 2))
+  q = fraction_map.where(q, q + 1)
   return r, q
 
 def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
@@ -171,11 +166,10 @@ def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
   Returns a tuple of `(r, q)`, where the output format is the same as that of `payne_hanek_reduction`.
   """
   m_1_pi = 0.318309886183790671537767526745028724
-  qdh = d.e(BinaryOps.MUL, d.const(m_1_pi / 16777216)).cast(dtypes.int64).cast(d.dtype).e(BinaryOps.MUL, d.const(16777216.0))
+  qdh = (d * (m_1_pi / 16777216)).cast(dtypes.int64).cast(d.dtype) * 16777216.0
   def _quadrant(x:UOp) -> UOp:
-    if x.dtype == dtypes.float64:
-      return rintk(mla(d, d.const(m_1_pi), qdh.e(UnaryOps.NEG))).cast(x.dtype)
-    return rintk(x.e(BinaryOps.MUL, d.const(m_1_pi))).cast(x.dtype)
+    if x.dtype == dtypes.float64: return rintk(mla(d, d.const(m_1_pi), -qdh)).cast(x.dtype)
+    return rintk(x * m_1_pi).cast(x.dtype)
   def _reduce_d(x:UOp, q:UOp):
     if x.dtype == dtypes.float64:
       d = mla(qdh, x.const(-3.1415926218032836914), x)
@@ -184,7 +178,7 @@ def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
       d = mla(q, x.const(-3.1786509424591713469e-08), d)
       d = mla(qdh, x.const(-1.2246467864107188502e-16), d)
       d = mla(q, x.const(-1.2246467864107188502e-16), d)
-      d = mla(qdh.e(BinaryOps.ADD, q), x.const(-1.2736634327021899816e-24), d)
+      d = mla(qdh + q, x.const(-1.2736634327021899816e-24), d)
     elif x.dtype == dtypes.float16:
       # [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
       d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
@@ -198,30 +192,30 @@ def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
 # *** approximate sine on small angle. ***
 def trig_poly(d:UOp, coeff32, coeff64):
   u = None
-  s = d.e(BinaryOps.MUL, d)
+  s = d * d
   if d.dtype == dtypes.float64:
-    s2 = s.e(BinaryOps.MUL, s)
-    s4 = s2.e(BinaryOps.MUL, s2)
-    def __poly4(x: UOp, x2: UOp, c3, c2, c1, c0) -> UOp: return mla(x2, mla(x, x.const(c3), x.const(c2)), mla(x, x.const(c1), x.const(c0)))
+    s2 = s * s
+    s4 = s2 * s2
+    def __poly4(x:UOp, x2:UOp, c3, c2, c1, c0) -> UOp: return mla(x2, mla(x, x.const(c3), x.const(c2)), mla(x, x.const(c1), x.const(c0)))
     def __poly8(x, x2, x4, c7, c6, c5, c4, c3, c2, c1, c0) -> UOp: return mla(x4, __poly4(x, x2, c7, c6, c5, c4), __poly4(x, x2, c3, c2, c1, c0))
     u = __poly8(s, s2, s4, *coeff64[:-1])
     u = mla(u, s, d.const(coeff64[-1]))
   else:
     u = polyN(s.const(coeff32[0]), s, coeff32[1:])
-  return mla(s, u.e(BinaryOps.MUL, d), d)
+  return mla(s, u * d, d)
 # approximate sine on [-pi/2, pi/2]
 def sin_poly(d:UOp) -> UOp: return trig_poly(d, [2.6083159809786593541503e-06, -0.0001981069071916863322258, 0.00833307858556509017944336, -0.166666597127914428710938], [-7.97255955009037868891952e-18, 2.81009972710863200091251e-15, -7.64712219118158833288484e-13, 1.60590430605664501629054e-10, -2.50521083763502045810755e-08, 2.75573192239198747630416e-06, -0.000198412698412696162806809, 0.00833333333333332974823815, -0.166666666666666657414808]) # noqa: E501
 
 def sin_poly_small(d:UOp, q:UOp) -> UOp:
-  def _ifand(n: int): return q.e(BinaryOps.AND, q.const(n)).e(BinaryOps.CMPNE, q.const(0))
+  def _ifand(n:int): return (q & n).ne(0)
   r = sin_poly(d)
-  return r.e(BinaryOps.MUL, _ifand(1).e(TernaryOps.WHERE, r.const(-1), r.const(1)))
+  return r * _ifand(1).where(r.const(-1), r.const(1))
 
 def sin_poly_large(d:UOp, q:UOp) -> UOp:
-  def _ifand(n: int): return q.e(BinaryOps.AND, q.const(n)).e(BinaryOps.CMPNE, q.const(0))
-  d = d.e(BinaryOps.ADD, _ifand(1).e(TernaryOps.WHERE, d.const(math.pi / 2), d.const(0)))
+  def _ifand(n:int): return (q & n).ne(0)
+  d = d + _ifand(1).where(d.const(math.pi / 2), d.const(0))
   r = sin_poly(d)
-  return r.e(BinaryOps.MUL, _ifand(2).e(TernaryOps.WHERE, r.const(-1), r.const(1)))
+  return r * _ifand(2).where(r.const(-1), r.const(1))
 # *** toplevel functions for xsin/xlog2/xexp2 ***
 def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
   """
@@ -234,19 +228,19 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
   # mask +-inf/nan as zero
   x = _lazy_map_numbers(d, d.const(0.0), d.const(0.0), d.const(0.0), d)
   # x_sign = sign(x)
-  x_sign = x.e(BinaryOps.CMPNE, d.const(0)).e(TernaryOps.WHERE, x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)), x.const(0)) # noqa: E501
-  x_abs = x.e(BinaryOps.MUL, x_sign)
+  x_sign = x.ne(0).where(x.lt(0).where(x.const(-1), x.const(1)), x.const(0))
+  x_abs = x * x_sign
 
   r, q = reduction_algo(x_abs)
   if fast: result = sin_poly_small(r, q)
   else:
     # Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.
-    switch_over_map = x_abs.e(BinaryOps.CMPLT, x.const(switch_over))
+    switch_over_map = x_abs.lt(switch_over)
     r_fast, q_fast = cody_waite_reduction(x_abs)
-    r = switch_over_map.e(TernaryOps.WHERE, r_fast, r)
-    q = switch_over_map.e(TernaryOps.WHERE, q_fast, q)
-    result = switch_over_map.e(TernaryOps.WHERE, sin_poly_small(r, q), sin_poly_large(r, q))
-  result = result.e(BinaryOps.MUL, x_sign) # adjusts the sign for abs(x).
+    r = switch_over_map.where(r_fast, r)
+    q = switch_over_map.where(q_fast, q)
+    result = switch_over_map.where(sin_poly_small(r, q), sin_poly_large(r, q))
+  result = result * x_sign # adjusts the sign for abs(x).
   # sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN
   return _lazy_map_numbers(d, d.const(math.nan), d.const(math.nan), d.const(math.nan), result)
 
@@ -261,7 +255,7 @@ def xexp2(x:UOp) -> UOp:
   d = _lazy_map_numbers(x, x.const(0.0), x.const(0.0), x.const(0.0), x)
   q = rintk(d)
   # s = d - round(d)
-  s = d.e(BinaryOps.ADD, q.cast(d.dtype).e(UnaryOps.NEG))
+  s = d - q.cast(d.dtype)
   # a polynomial approximation with 13 non-zero terms in the range of [−(log 2)/2,(log 2)/2].
   if fp64_p:
     u = polyN(s.const(0.4434359082926529454e-9), s, [0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, 0.1525273353517584730e-4, 0.1540353045101147808e-3, 0.1333355814670499073e-2, 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0, 0.6931471805599452862e+0, 0.1000000000000000000e+1]) # noqa: E501
@@ -271,12 +265,12 @@ def xexp2(x:UOp) -> UOp:
   upper = {dtypes.float64: 1024, dtypes.float32: 128, dtypes.float16: 23.0}[d.dtype]
   lower = {dtypes.float64: -2000, dtypes.float32: -150, dtypes.float16: -22}[d.dtype]
   # Replace x >= upper with +inf
-  u = d.e(BinaryOps.CMPNE, d.const(upper)).e(TernaryOps.WHERE, u, d.const(math.inf))
-  u = d.e(BinaryOps.CMPLT, d.const(upper)).e(TernaryOps.WHERE, u, d.const(math.inf))
+  u = d.ne(upper).where(u, d.const(math.inf))
+  u = d.lt(upper).where(u, d.const(math.inf))
   # Replace x <= lower with zero.
-  u = d.e(BinaryOps.CMPLT, d.const(lower)).e(TernaryOps.WHERE, d.const(0.0), u)
+  u = d.lt(lower).where(d.const(0.0), u)
   # x=NaN never satisfies x < Inf. (for fastmode)
-  u = d.e(BinaryOps.CMPLT, d.const(math.inf)).e(TernaryOps.WHERE, u, u.const(math.nan))
+  u = d.lt(math.inf).where(u, u.const(math.nan))
   # exp2(Inf) = Inf, exp2(-Inf) = 0, exp2(NaN) = NaN
   return _lazy_map_numbers(x, x.const(math.inf), x.const(0.0), x.const(math.nan), u)
 
@@ -289,39 +283,36 @@ def xlog2(d:UOp) -> UOp:
   fp64_p = d.dtype == dtypes.float64
   FLT_MIN = d.const(1e-6 if d.dtype == dtypes.float16 else 1e-4)
   d_orig = d
-  denormal_map = d.e(BinaryOps.CMPLT, FLT_MIN)
-  for _ in range(8):
-    d = denormal_map.e(TernaryOps.WHERE, d.e(BinaryOps.MUL, d.const(2 ** 8)), d)
+  denormal_map = d.lt(FLT_MIN)
+  for _ in range(8): d = denormal_map.where(d * (2 ** 8), d)
 
-  e = ilogb2k(d.e(BinaryOps.MUL, d.const(1.0 / 0.75))).cast(d.dtype)
-  m = ldexp3k(d, e.e(UnaryOps.NEG))
-  e = denormal_map.e(TernaryOps.WHERE, e.e(BinaryOps.ADD, e.const(-64)), e)
+  e = ilogb2k(d * (1.0 / 0.75)).cast(d.dtype)
+  m = ldexp3k(d, -e)
+  e = denormal_map.where(e + (-64), e)
 
   if fp64_p:
-    x = m.e(BinaryOps.ADD, m.const(-1.0)).e(BinaryOps.MUL, m.e(BinaryOps.ADD, m.const(1.0)).e(UnaryOps.RECIP))
-    x2 = x.e(BinaryOps.MUL, x)
+    x = (m - 1.0) * (m + 1.0).recip()
+    x2 = x * x
     t = polyN(x.const(0.2211941750456081490e+0), x2, [0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0, 0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449]) # noqa: E501
     s_hi, s_lo = dfadd2_f2_f2_f2(e, e.const(0), *dfmul2_f2_f2_f2(t.const(2.885390081777926774), t.const(0), x, x.const(0)))
-    r = mla(t, x.e(BinaryOps.MUL, x2), s_hi.e(BinaryOps.ADD, s_lo))
+    r = mla(t, x * x2, s_hi + s_lo)
   else:
     xx, xy = dfdiv2_f2_f2_f2(*dfadd2_f2_f2_f2(m.const(-1), m.const(0), m, m.const(0)), *dfadd2_f2_f2_f2(m.const(1), m.const(0), m, m.const(0)))
-    x2 = xx.e(BinaryOps.MUL, xx)
+    x2 = xx * xx
     t = polyN(d.const(0.4374550283e+0), x2, [0.5764790177e+0, 0.9618012905120])
     sx, sy = dfadd2_f2_f2_f2(e, e.const(0), *dfmul2_f2_f2_f2(xx, xy, xx.const(2.8853900432586669922), xy.const(3.2734474483568488616e-08)))
-    sx, sy = dfadd2_f2_f2_f2(sx, sy, x2.const(0), x2.e(BinaryOps.MUL, xx).e(BinaryOps.MUL, t))
-    r = sx.e(BinaryOps.ADD, sy)
+    sx, sy = dfadd2_f2_f2_f2(sx, sy, x2.const(0), (x2 * xx) * t)
+    r = sx + sy
   # log2(Inf) = Inf
-  r = d_orig.e(BinaryOps.CMPNE, d.const(math.inf)).e(TernaryOps.WHERE, r, r.const(math.inf))
+  r = d_orig.ne(math.inf).where(r, r.const(math.inf))
   # log2(x=-0.01) = NaN. where x < 0
-  r = d_orig.e(BinaryOps.CMPLT, d.const(-0.0)).e(TernaryOps.WHERE, r.const(math.nan), r)
+  r = d_orig.lt(-0.0).where(r.const(math.nan), r)
   # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
   # log2_zero = the value of unmasked xlog2(0.0).
   log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79, None: -math.inf}[d.dtype]
-  r = r.e(BinaryOps.CMPNE, r.const(log2_zero)).e(TernaryOps.WHERE, r, r.const(-math.inf))
+  r = r.ne(log2_zero).where(r, r.const(-math.inf))
   # log(NaN) = NaN, using for all real number x, either of x < Inf, x == Inf becomes True.
-  r = d_orig.e(BinaryOps.CMPLT, d_orig.const(math.inf)).e(
-    TernaryOps.WHERE, r, d_orig.e(BinaryOps.CMPNE, d_orig.const(math.inf)).e(
-    TernaryOps.WHERE, d.const(math.nan), d))
+  r = d_orig.lt(math.inf).where(r, d_orig.ne(math.inf).where(d.const(math.nan), d))
   # log(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
-  r = d_orig.e(UnaryOps.RECIP).e(BinaryOps.CMPNE, d_orig.const(-math.inf)).e(TernaryOps.WHERE, r, r.const(-math.inf))
+  r = d_orig.recip().ne(-math.inf).where(r, r.const(-math.inf))
   return r
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index 446eef7a1d..59d61659ef 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -57,6 +57,8 @@ class UOp:
   def __floordiv__(self, x): return UOp.alu(BinaryOps.IDIV, self, ufix(self.dtype, x))
   def __truediv__(self, x): return UOp.alu(BinaryOps.MUL, self, UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x)))
   def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
+  def __and__(self, x): return UOp.alu(BinaryOps.AND, self, ufix(self.dtype, x))
+  def __or__(self, x): return UOp.alu(BinaryOps.OR, self, ufix(self.dtype, x))
   def ne(self, x): return UOp.alu(BinaryOps.CMPNE, self, ufix(self.dtype, x))
   def eq(self, x): return -self.ne(x)
   def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))
@@ -73,7 +75,6 @@ class UOp:
     return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
   @staticmethod
   def alu(arg, *src:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else src[-1].dtype, src, arg)
-  def e(self, arg, *src:UOp): return UOp.alu(arg, self, *src)
  @staticmethod
   def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
   @staticmethod
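
For reference, a minimal self-contained sketch of the pattern this PR relies on: the small helpers on UOp (ne, lt, where) and the operator dunders (including the new __and__ and __or__) all funnel into the single UOp.alu constructor, so an expression like x.ne(math.inf).where(a, b) builds the same ALU graph as the removed x.e(BinaryOps.CMPNE, ...).e(TernaryOps.WHERE, ...) chains. The MiniUOp class and its string op names below are illustrative stand-ins, not tinygrad's actual classes:

# simplified illustration of operator sugar that lowers to one ALU-node constructor
from dataclasses import dataclass
from typing import Any, Tuple

@dataclass(frozen=True)
class MiniUOp:  # hypothetical stand-in for UOp
  op: str
  src: Tuple["MiniUOp", ...] = ()
  arg: Any = None

  @staticmethod
  def const(b) -> "MiniUOp": return MiniUOp("CONST", (), b)
  @staticmethod
  def alu(arg:str, *src:"MiniUOp") -> "MiniUOp": return MiniUOp("ALU", src, arg)
  def _fix(self, x): return x if isinstance(x, MiniUOp) else MiniUOp.const(x)
  # each dunder/helper builds the same graph the verbose .e(BinaryOps.X, ...) calls did
  def __add__(self, x): return MiniUOp.alu("ADD", self, self._fix(x))
  def __mul__(self, x): return MiniUOp.alu("MUL", self, self._fix(x))
  def __and__(self, x): return MiniUOp.alu("AND", self, self._fix(x))
  def ne(self, x): return MiniUOp.alu("CMPNE", self, self._fix(x))
  def lt(self, x): return MiniUOp.alu("CMPLT", self, self._fix(x))
  def where(self, t, f): return MiniUOp.alu("WHERE", self, self._fix(t), self._fix(f))

x = MiniUOp.const(2.0)
y = x.ne(3.0).where(x * 4.0, x + 1.0)  # reads like math, still an explicit ALU graph
print(y.arg, len(y.src))               # prints: WHERE 3

In the patch itself the equivalent sugar already lives on UOp in tinygrad/codegen/uops.py (ufix plus the operator dunders and ne/lt/where), which is why transcendental.py can drop every .e() call and the e helper can be deleted outright.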