Hmm, native exp/log breaks things too much: default to exp/log, with native_exp/native_log opt-in via NATIVE_EXPLOG

This commit is contained in:
Comma Device
2022-08-22 17:13:08 -07:00
parent 0c1378e7db
commit 9678cb8a1a
2 changed files with 8 additions and 2 deletions

View File

@@ -3,13 +3,14 @@
from __future__ import annotations
import os
from tinygrad.llops.ops_gpu import GPUBuffer, CL, CLProgram, CLBuffer
from tinygrad.ops import ProcessingOps, ReduceOps
from tinygrad.ops import ProcessingOps, ReduceOps, UnaryOps, BinaryOps
from tinygrad.helpers import prod, ConvArgs
from typing import List, Tuple, Optional, Dict, Set
import numpy as np
import pyopencl as cl
UNSAFE_FLOAT4 = int(os.getenv("UNSAFE_FLOAT4", 0))
NATIVE_EXPLOG = int(os.getenv("NATIVE_EXPLOG", 0))
import pathlib
def load(x):
@@ -99,6 +100,11 @@ def get_getters(ewbufs, ret):
def roundup(x, n=4): return (x+(n-1))//n * n
class OpenCLBuffer(GPUBuffer):
code_for_op = {
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)", UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)", UnaryOps.SIGN: "sign(A)",
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
}
def __init__(self, shape, hostbuf:Optional[OpenCLBuffer]=None, backing:Optional[np.ndarray]=None):
self._image = hostbuf._image if hostbuf is not None else None
super().__init__(shape, hostbuf, backing)

View File

@@ -66,7 +66,7 @@ class CLProgram:
class GPUBuffer:
code_for_op = {
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "native_exp(A)", UnaryOps.LOG: "native_log(A)", UnaryOps.SIGN: "sign(A)",
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)",
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
}