move sample inside jit for beautiful_mnist (#3115)

Also removed .realize() from the jitted functions, since the jit does it automatically now. A little more beautiful.
Author: chenyu
Date: 2024-01-14 01:36:30 -05:00 (committed by GitHub)
Parent: a313e63a9b
Commit: fb3f8f7597


@@ -24,21 +24,21 @@ if __name__ == "__main__":
   opt = nn.optim.Adam(nn.state.get_parameters(model))
 
   @TinyJit
-  def train_step(samples:Tensor) -> Tensor:
+  def train_step() -> Tensor:
     with Tensor.train():
       opt.zero_grad()
+      samples = Tensor.randint(512, high=X_train.shape[0])
+      # TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
       loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
       opt.step()
-      return loss.realize()
+      return loss
 
   @TinyJit
-  def get_test_acc() -> Tensor: return ((model(X_test).argmax(axis=1) == Y_test).mean()*100).realize()
+  def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
 
   test_acc = float('nan')
   for i in (t:=trange(70)):
     GlobalCounters.reset()  # NOTE: this makes it nice for DEBUG=2 timing
-    samples = Tensor.randint(512, high=X_train.shape[0])  # TODO: put this in the JIT when rand is fixed
-    loss = train_step(samples)
+    loss = train_step()
     if i%10 == 9: test_acc = get_test_acc().item()
     t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")