fix old list behavior

This commit is contained in:
George Hotz
2021-11-29 01:49:21 -05:00
parent 7c160c6cee
commit c6fb087695
2 changed files with 3 additions and 1 deletions

View File

@@ -84,7 +84,7 @@ if __name__ == "__main__":
from tinygrad.tensor import Tensor, Device
r1 = Tensor(np.random.random(16).astype(np.float32)-0.5, device=Device.METAL)
r1 = Tensor([-2,-1,0,1,2], device=Device.METAL)
r2 = r1.relu()
print(r1.cpu())
print(r2.cpu())

View File

@@ -139,6 +139,8 @@ class Tensor:
@staticmethod
def _move_data(data, device):
if isinstance(data, list):
data = np.array(data, dtype=np.float32)
if isinstance(data, np.ndarray):
data = data.view(Device.buffers[Device.CPU])
if isinstance(data, Device.buffers[device]):