bring back native exp log

This commit is contained in:
George Hotz
2022-09-06 07:59:04 -07:00
parent d6f499fd69
commit f683b26eef

View File

@@ -10,6 +10,7 @@ import numpy as np
import pyopencl as cl
UNSAFE_FLOAT4 = int(os.getenv("UNSAFE_FLOAT4", 0))
NATIVE_EXPLOG = int(os.getenv("NATIVE_EXPLOG", 0)) # this is needed as a switch for the tests to pass
FLOAT16 = int(os.getenv("FLOAT16", 0))
import pathlib
@@ -102,9 +103,9 @@ def roundup(x, n=4): return (x+(n-1))//n * n
class OpenCLBuffer(GPUBuffer):
code_for_op = {
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.SIGN: "sign(A)",
UnaryOps.EXP: "native_exp(A)",
UnaryOps.LOG: "native_log(A)",
UnaryOps.RECIPROCAL: "native_recip(A)",
UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
UnaryOps.RECIPROCAL: "native_recip(A)" if NATIVE_EXPLOG else "((float)1.0/A)",
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
}