From 1b42b4e1b85cf3a030c239e2eaed92d88b22c973 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Thu, 1 Jun 2023 19:03:17 -0700
Subject: [PATCH] fix examples/hlb_cifar10.py

---
 examples/hlb_cifar10.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py
index 8cd6548dcc..2fcc0a02a4 100644
--- a/examples/hlb_cifar10.py
+++ b/examples/hlb_cifar10.py
@@ -114,7 +114,7 @@ def train_cifar():
       # use training batchnorm (and no_grad would change the kernels)
       out = model(Xt)
       outs = out.numpy().argmax(axis=1)
-      loss = (out * Yt).mean().numpy()[0]
+      loss = (out * Yt).mean().numpy()
       correct = outs == Yt.numpy().argmin(axis=1)
       print(f"eval {sum(correct)}/{len(correct)} {sum(correct)/len(correct)*100.0:.2f}%, {loss:7.2f} val_loss")
     if STEPS == 0: break
@@ -123,7 +123,7 @@ def train_cifar():
     loss = train_step_jitted(model, optimizer, X, Y)
     et = time.monotonic()
     X, Y = fetch_batch(X_train, Y_train, BS=BS)  # do this here
-    loss_cpu = loss.numpy()[0]
+    loss_cpu = loss.numpy()
     cl = time.monotonic()
     print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
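
For context, a minimal standalone sketch of the behavior this fix relies on, under the assumption (not stated in the patch itself) that tinygrad reductions such as Tensor.mean() now produce a scalar-shaped result, so .numpy() no longer returns a one-element array that needs [0] indexing:

    from tinygrad.tensor import Tensor

    loss = Tensor([1.0, 2.0, 3.0]).mean()  # scalar result of a reduction
    loss_cpu = loss.numpy()                # assumed scalar-shaped, not shape (1,)
    print(float(loss_cpu))                 # usable directly as a number, no [0] needed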