From 93bcb974c59a97a0173ad698586041440316909b Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 26 Mar 2025 11:01:25 +0800 Subject: [PATCH] select torch device in examples/beautiful_mnist_torch.py (#9575) --- examples/other_mnist/beautiful_mnist_torch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/other_mnist/beautiful_mnist_torch.py b/examples/other_mnist/beautiful_mnist_torch.py index e1cef16d9c..a1a818501f 100644 --- a/examples/other_mnist/beautiful_mnist_torch.py +++ b/examples/other_mnist/beautiful_mnist_torch.py @@ -1,5 +1,5 @@ -from tinygrad import dtypes, getenv -from tinygrad.helpers import trange, colored +from tinygrad import dtypes, getenv, Device +from tinygrad.helpers import trange, colored, DEBUG from tinygrad.nn.datasets import mnist import torch from torch import nn, optim @@ -30,7 +30,8 @@ if __name__ == "__main__": import tinygrad.frontend.torch device = torch.device("tiny") else: - device = torch.device("mps") + device = torch.device({"METAL":"mps","NV":"cuda"}.get(Device.DEFAULT, "cpu")) + if DEBUG >= 1: print(f"using torch backend {device}") X_train, Y_train, X_test, Y_test = mnist() X_train = torch.tensor(X_train.float().numpy(), device=device) Y_train = torch.tensor(Y_train.cast(dtypes.int64).numpy(), device=device)