cleanup mnist data load in beautiful_mnist (#5106)

This commit is contained in:
chenyu
2024-06-22 18:31:51 -04:00
committed by GitHub
parent 5516b790ad
commit 055e616302
2 changed files with 2 additions and 5 deletions

View File

@@ -20,9 +20,6 @@ class Model:
if __name__ == "__main__":
X_train, Y_train, X_test, Y_test = mnist()
# TODO: remove this when HIP is fixed
X_train, X_test = X_train.float(), X_test.float()
model = Model()
opt = nn.optim.Adam(nn.state.get_parameters(model))

View File

@@ -2,7 +2,7 @@
from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
from tinygrad.helpers import getenv, colored, trange
from extra.datasets import fetch_mnist
from tinygrad.nn.datasets import mnist
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2))]
@@ -20,7 +20,7 @@ class Model:
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
if __name__ == "__main__":
X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)
X_train, Y_train, X_test, Y_test = mnist()
# we shard the test data on axis 0
X_test.shard_(GPUS, axis=0)
Y_test.shard_(GPUS, axis=0)