mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
hmm, the native exp/log breaks it too much
This commit is contained in:
@@ -3,13 +3,14 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from tinygrad.llops.ops_gpu import GPUBuffer, CL, CLProgram, CLBuffer
|
||||
from tinygrad.ops import ProcessingOps, ReduceOps
|
||||
from tinygrad.ops import ProcessingOps, ReduceOps, UnaryOps, BinaryOps
|
||||
from tinygrad.helpers import prod, ConvArgs
|
||||
from typing import List, Tuple, Optional, Dict, Set
|
||||
import numpy as np
|
||||
import pyopencl as cl
|
||||
|
||||
UNSAFE_FLOAT4 = int(os.getenv("UNSAFE_FLOAT4", 0))
|
||||
NATIVE_EXPLOG = int(os.getenv("NATIVE_EXPLOG", 0))
|
||||
|
||||
import pathlib
|
||||
def load(x):
|
||||
@@ -99,6 +100,11 @@ def get_getters(ewbufs, ret):
|
||||
|
||||
def roundup(x, n=4): return (x+(n-1))//n * n
|
||||
class OpenCLBuffer(GPUBuffer):
|
||||
code_for_op = {
|
||||
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)", UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)", UnaryOps.SIGN: "sign(A)",
|
||||
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
|
||||
ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
|
||||
}
|
||||
def __init__(self, shape, hostbuf:Optional[OpenCLBuffer]=None, backing:Optional[np.ndarray]=None):
|
||||
self._image = hostbuf._image if hostbuf is not None else None
|
||||
super().__init__(shape, hostbuf, backing)
|
||||
|
||||
@@ -66,7 +66,7 @@ class CLProgram:
|
||||
|
||||
class GPUBuffer:
|
||||
code_for_op = {
|
||||
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "native_exp(A)", UnaryOps.LOG: "native_log(A)", UnaryOps.SIGN: "sign(A)",
|
||||
UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.SIGN: "sign(A)",
|
||||
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
|
||||
ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user