mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
bring back native exp log
This commit is contained in:
@@ -10,6 +10,7 @@ import numpy as np
|
||||
import pyopencl as cl
|
||||
|
||||
UNSAFE_FLOAT4 = int(os.getenv("UNSAFE_FLOAT4", 0))
|
||||
NATIVE_EXPLOG = int(os.getenv("NATIVE_EXPLOG", 0)) # this is needed as a switch for the tests to pass
|
||||
FLOAT16 = int(os.getenv("FLOAT16", 0))
|
||||
|
||||
import pathlib
|
||||
@@ -102,9 +103,9 @@ def roundup(x, n=4): return (x+(n-1))//n * n
|
||||
class OpenCLBuffer(GPUBuffer):
|
||||
code_for_op = {
|
||||
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.SIGN: "sign(A)",
|
||||
UnaryOps.EXP: "native_exp(A)",
|
||||
UnaryOps.LOG: "native_log(A)",
|
||||
UnaryOps.RECIPROCAL: "native_recip(A)",
|
||||
UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
|
||||
UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
|
||||
UnaryOps.RECIPROCAL: "native_recip(A)" if NATIVE_EXPLOG else "((float)1.0/A)",
|
||||
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
|
||||
ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user