diff --git a/examples/other_mnist/beautiful_mnist_torch.py b/examples/other_mnist/beautiful_mnist_torch.py index a1a818501f..227edec95a 100644 --- a/examples/other_mnist/beautiful_mnist_torch.py +++ b/examples/other_mnist/beautiful_mnist_torch.py @@ -1,5 +1,5 @@ from tinygrad import dtypes, getenv, Device -from tinygrad.helpers import trange, colored, DEBUG +from tinygrad.helpers import trange, colored, DEBUG, temp from tinygrad.nn.datasets import mnist import torch from torch import nn, optim @@ -38,6 +38,7 @@ if __name__ == "__main__": X_test = torch.tensor(X_test.float().numpy(), device=device) Y_test = torch.tensor(Y_test.cast(dtypes.int64).numpy(), device=device) + if getenv("TORCHVIZ"): torch.cuda.memory._record_memory_history() model = Model().to(device) optimizer = optim.Adam(model.parameters(), 1e-3) @@ -63,3 +64,6 @@ if __name__ == "__main__": if target := getenv("TARGET_EVAL_ACC_PCT", 0.0): if test_acc >= target and test_acc != 100.0: print(colored(f"{test_acc=} >= {target}", "green")) else: raise ValueError(colored(f"{test_acc=} < {target}", "red")) + if getenv("TORCHVIZ"): + torch.cuda.memory._dump_snapshot(fp:=temp("torchviz.pkl", append_user=True)) + print(f"saved torch memory snapshot to {fp}, view in https://pytorch.org/memory_viz")