stupid numpy hack

This commit is contained in:
George Hotz
2022-06-11 17:01:56 -07:00
parent c03c835d75
commit a4d0d3f17a
2 changed files with 3 additions and 2 deletions

View File

@@ -3,6 +3,7 @@ from tinygrad.helpers import get_conv_args
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
class CPUBuffer(np.ndarray):
    """float32 numpy buffer for the CPU backend.

    A thin ndarray subclass: construction allocates zeroed float32 storage,
    and the elementwise ops below are written as plain functions whose first
    argument is the buffer itself (project convention — no `self`).
    """

    def __new__(cls, shape):
        # Allocate zero-initialized float32 storage and reinterpret it as a
        # CPUBuffer via numpy view casting (no data copy).
        storage = np.zeros(shape, dtype=np.float32)
        return storage.view(CPUBuffer)

    def relu(x):
        # Elementwise max(x, 0).
        return np.maximum(x, 0)

    def exp(x):
        # Elementwise natural exponential.
        return np.exp(x)

    def log(x):
        # Elementwise natural logarithm.
        return np.log(x)

View File

@@ -182,10 +182,10 @@ class Tensor:
def _move_data(data, device):
if isinstance(data, list):
data = np.array(data, dtype=np.float32)
if isinstance(data, np.ndarray):
data = data.view(Device.buffers[Device.CPU])
if isinstance(data, Device.buffers[device]):
return data
if isinstance(data, np.ndarray):
data = data.view(Device.buffers[Device.CPU])
if Tensor._get_data_dtype(data) != np.float32 and not Tensor.did_float_warning:
# warning? float64 is actually needed for numerical jacobian