move sample inside jit for beautiful_mnist (#3115)
also removed .realize() for jit functions since jit does it automatically now. a little more beautiful
@@ -24,21 +24,21 @@ if __name__ == "__main__":
   opt = nn.optim.Adam(nn.state.get_parameters(model))

   @TinyJit
-  def train_step(samples:Tensor) -> Tensor:
+  def train_step() -> Tensor:
     with Tensor.train():
       opt.zero_grad()
+      samples = Tensor.randint(512, high=X_train.shape[0])
       # TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
       loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
       opt.step()
-      return loss.realize()
+      return loss

   @TinyJit
-  def get_test_acc() -> Tensor: return ((model(X_test).argmax(axis=1) == Y_test).mean()*100).realize()
+  def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100

   test_acc = float('nan')
   for i in (t:=trange(70)):
     GlobalCounters.reset()   # NOTE: this makes it nice for DEBUG=2 timing
-    samples = Tensor.randint(512, high=X_train.shape[0])  # TODO: put this in the JIT when rand is fixed
-    loss = train_step(samples)
+    loss = train_step()
     if i%10 == 9: test_acc = get_test_acc().item()
     t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
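A minimal sketch of the pattern this change relies on, using hypothetical names (toy_data, toy_step) rather than anything from the repo: the random batch indices are drawn inside the TinyJit-decorated function, and the returned Tensor is left unrealized because, per the commit message, the JIT now realizes its outputs itself.

# sketch only: toy_data / toy_step are made-up stand-ins, not repo code
from tinygrad import Tensor, TinyJit

toy_data = Tensor.rand(1000, 8)   # stand-in for a training set like X_train

@TinyJit
def toy_step() -> Tensor:
  # drawing the random indices inside the jitted function means each call
  # still gets a fresh batch, which is what lets the sampling move in here
  samples = Tensor.randint(32, high=toy_data.shape[0])
  # no .realize(): the JIT realizes the tensor returned from the jitted function
  return toy_data[samples].sum()

for _ in range(3):
  print(toy_step().item())   # .item() pulls out a concrete value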