fix examples/hlb_cifar10.py
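Both hunks are the same fix: drop the stray [0] index from the result of .numpy(). Presumably the loss tensor (a .mean() over the batch) is now scalar-shaped rather than shape (1,), so indexing its .numpy() output with [0] raises an IndexError.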
@@ -114,7 +114,7 @@ def train_cifar():
       # use training batchnorm (and no_grad would change the kernels)
       out = model(Xt)
       outs = out.numpy().argmax(axis=1)
-      loss = (out * Yt).mean().numpy()[0]
+      loss = (out * Yt).mean().numpy()
       correct = outs == Yt.numpy().argmin(axis=1)
       print(f"eval {sum(correct)}/{len(correct)} {sum(correct)/len(correct)*100.0:.2f}%, {loss:7.2f} val_loss")
     if STEPS == 0: break
@@ -123,7 +123,7 @@ def train_cifar():
     loss = train_step_jitted(model, optimizer, X, Y)
     et = time.monotonic()
     X, Y = fetch_batch(X_train, Y_train, BS=BS) # do this here
-    loss_cpu = loss.numpy()[0]
+    loss_cpu = loss.numpy()
     cl = time.monotonic()
     print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
 
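A minimal sketch of the failure mode being fixed, assuming the eval and train losses now come back as 0-d numpy arrays (the value below is made up; only the indexing behavior matters):

import numpy as np

# Stand-in for (out * Yt).mean().numpy() under the assumed shape change
# from (1,) to (); 2.37 is a made-up value.
loss = np.array(2.37)

print(f"{float(loss):7.2f}")  # fine: a 0-d array converts cleanly to float

try:
  loss[0]                     # the old code path this commit removes
except IndexError as e:
  print(e)                    # "too many indices for array: array is 0-dimensional..."

If .numpy() actually hands back a numpy scalar instead of a 0-d array, the [0] fails the same way (IndexError: invalid index to scalar variable), so dropping the index is the robust fix either way.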