mirror of https://github.com/tinygrad/tinygrad.git
stupid numpy hack
@@ -3,6 +3,7 @@ from tinygrad.helpers import get_conv_args
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps

class CPUBuffer(np.ndarray):
  def __new__(cls, shape): return np.zeros(shape, dtype=np.float32).view(CPUBuffer)
  def relu(x): return np.maximum(x, 0)
  def exp(x): return np.exp(x)
  def log(x): return np.log(x)
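The "numpy hack" in this hunk is that CPUBuffer subclasses np.ndarray and buffers are produced by calling .view() on an existing float32 array, so the CPU backend inherits numpy's storage and ufuncs instead of wrapping them. A minimal sketch of that view-subclass pattern, using an illustrative Buf class and demo values that are not part of the commit:

import numpy as np

class Buf(np.ndarray):
  # extra methods live on the subclass; storage and ufuncs come from numpy
  def relu(self): return np.maximum(self, 0)

raw = np.array([-1.0, 2.0], dtype=np.float32)
buf = raw.view(Buf)                 # reinterpret the same memory, no copy
print(isinstance(buf, np.ndarray))  # True
print(buf.relu())                   # [0. 2.]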
@@ -182,10 +182,10 @@ class Tensor:
  def _move_data(data, device):
    if isinstance(data, list):
      data = np.array(data, dtype=np.float32)
    if isinstance(data, np.ndarray):
      data = data.view(Device.buffers[Device.CPU])

    if isinstance(data, Device.buffers[device]):
      return data
    if isinstance(data, np.ndarray):
      data = data.view(Device.buffers[Device.CPU])

    if Tensor._get_data_dtype(data) != np.float32 and not Tensor.did_float_warning:
      # warning? float64 is actually needed for numerical jacobian
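Read top to bottom, _move_data normalizes its input before any device dispatch: a Python list becomes a float32 ndarray, a plain ndarray is reinterpreted as the CPU buffer class via .view(), and if the result is already an instance of the target device's buffer class it is returned immediately. A standalone sketch of that normalization chain, assuming a stand-in Device/CPUBuffer harness (the names mirror the diff, but the harness and the non-CPU branch are placeholders):

import numpy as np

class CPUBuffer(np.ndarray): pass

class Device:
  CPU = 0
  buffers = {CPU: CPUBuffer}

def move_data(data, device=Device.CPU):
  if isinstance(data, list):                    # list -> float32 ndarray
    data = np.array(data, dtype=np.float32)
  if isinstance(data, np.ndarray):              # plain ndarray -> CPU buffer view, no copy
    data = data.view(Device.buffers[Device.CPU])
  if isinstance(data, Device.buffers[device]):  # already the right buffer type: done
    return data
  raise NotImplementedError("copy to a non-CPU device would happen here")

print(type(move_data([1.0, 2.0])).__name__)   # CPUBuffer
print(type(move_data(np.zeros(3))).__name__)  # CPUBuffer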