mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
cleanup mnist data load in beautiful_mnist (#5106)
This commit is contained in:
@@ -20,9 +20,6 @@ class Model:
|
||||
if __name__ == "__main__":
|
||||
X_train, Y_train, X_test, Y_test = mnist()
|
||||
|
||||
# TODO: remove this when HIP is fixed
|
||||
X_train, X_test = X_train.float(), X_test.float()
|
||||
|
||||
model = Model()
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(model))
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from typing import List, Callable
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, Device
|
||||
from tinygrad.helpers import getenv, colored, trange
|
||||
from extra.datasets import fetch_mnist
|
||||
from tinygrad.nn.datasets import mnist
|
||||
|
||||
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 2))]
|
||||
|
||||
@@ -20,7 +20,7 @@ class Model:
|
||||
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
|
||||
|
||||
if __name__ == "__main__":
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)
|
||||
X_train, Y_train, X_test, Y_test = mnist()
|
||||
# we shard the test data on axis 0
|
||||
X_test.shard_(GPUS, axis=0)
|
||||
Y_test.shard_(GPUS, axis=0)
|
||||
|
||||
Reference in New Issue
Block a user