select torch device in examples/beautiful_mnist_torch.py (#9575)

This commit is contained in:
qazal
2025-03-26 11:01:25 +08:00
committed by GitHub
parent 2c32126fc8
commit 93bcb974c5

View File

@@ -1,5 +1,5 @@
from tinygrad import dtypes, getenv
from tinygrad.helpers import trange, colored
from tinygrad import dtypes, getenv, Device
from tinygrad.helpers import trange, colored, DEBUG
from tinygrad.nn.datasets import mnist
import torch
from torch import nn, optim
@@ -30,7 +30,8 @@ if __name__ == "__main__":
import tinygrad.frontend.torch
device = torch.device("tiny")
else:
device = torch.device("mps")
device = torch.device({"METAL":"mps","NV":"cuda"}.get(Device.DEFAULT, "cpu"))
if DEBUG >= 1: print(f"using torch backend {device}")
X_train, Y_train, X_test, Y_test = mnist()
X_train = torch.tensor(X_train.float().numpy(), device=device)
Y_train = torch.tensor(Y_train.cast(dtypes.int64).numpy(), device=device)