mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
select torch device in examples/beautiful_mnist_torch.py (#9575)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from tinygrad import dtypes, getenv
|
||||
from tinygrad.helpers import trange, colored
|
||||
from tinygrad import dtypes, getenv, Device
|
||||
from tinygrad.helpers import trange, colored, DEBUG
|
||||
from tinygrad.nn.datasets import mnist
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
@@ -30,7 +30,8 @@ if __name__ == "__main__":
|
||||
import tinygrad.frontend.torch
|
||||
device = torch.device("tiny")
|
||||
else:
|
||||
device = torch.device("mps")
|
||||
device = torch.device({"METAL":"mps","NV":"cuda"}.get(Device.DEFAULT, "cpu"))
|
||||
if DEBUG >= 1: print(f"using torch backend {device}")
|
||||
X_train, Y_train, X_test, Y_test = mnist()
|
||||
X_train = torch.tensor(X_train.float().numpy(), device=device)
|
||||
Y_train = torch.tensor(Y_train.cast(dtypes.int64).numpy(), device=device)
|
||||
|
||||
Reference in New Issue
Block a user