mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
test backward in test_tiny (#11697)
* test backward in test_tiny * empty
This commit is contained in:
@@ -29,8 +29,7 @@ if __name__ == "__main__":
|
||||
opt.zero_grad()
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
|
||||
opt.step()
|
||||
return loss
|
||||
return loss.realize(*opt.schedule_step())
|
||||
|
||||
@TinyJit
|
||||
def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
|
||||
|
||||
Reference in New Issue
Block a user