add TORCHVIZ=1 to beautiful_mnist_torch (#9576)

This commit is contained in:
qazal
2025-03-26 11:17:08 +08:00
committed by GitHub
parent 93bcb974c5
commit c03dadfcb9

View File

@@ -1,5 +1,5 @@
from tinygrad import dtypes, getenv, Device
from tinygrad.helpers import trange, colored, DEBUG
from tinygrad.helpers import trange, colored, DEBUG, temp
from tinygrad.nn.datasets import mnist
import torch
from torch import nn, optim
@@ -38,6 +38,7 @@ if __name__ == "__main__":
X_test = torch.tensor(X_test.float().numpy(), device=device)
Y_test = torch.tensor(Y_test.cast(dtypes.int64).numpy(), device=device)
if getenv("TORCHVIZ"): torch.cuda.memory._record_memory_history()
model = Model().to(device)
optimizer = optim.Adam(model.parameters(), 1e-3)
@@ -63,3 +64,6 @@ if __name__ == "__main__":
if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
if test_acc >= target and test_acc != 100.0: print(colored(f"{test_acc=} >= {target}", "green"))
else: raise ValueError(colored(f"{test_acc=} < {target}", "red"))
if getenv("TORCHVIZ"):
torch.cuda.memory._dump_snapshot(fp:=temp("torchviz.pkl", append_user=True))
print(f"saved torch memory snapshot to {fp}, view in https://pytorch.org/memory_viz")