mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
Verify eval accuracy for hlb_cifar training (#3344)
The target (TARGET_EVAL_ACC_PCT) is set to 93% for now to reduce flakiness.
This commit is contained in:
2
.github/workflows/benchmark.yml
vendored
2
.github/workflows/benchmark.yml
vendored
@@ -145,7 +145,7 @@ jobs:
|
||||
# - name: Run 10 CIFAR training steps w 6 GPUS
|
||||
# run: time HALF=1 STEPS=10 BS=1536 GPUS=6 python3 examples/hlb_cifar10.py
|
||||
- name: Run full CIFAR training
|
||||
run: time HIP=1 HALF=1 LATEWINO=1 STEPS=1000 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.py
|
||||
run: time HIP=1 HALF=1 LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.py
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: Speed (AMD)
|
||||
|
||||
@@ -11,7 +11,7 @@ from extra.lr_scheduler import OneCycleLR
|
||||
from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
|
||||
from tinygrad.nn.state import get_state_dict, get_parameters
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.helpers import Context, BEAM, WINO, getenv
|
||||
from tinygrad.helpers import Context, BEAM, WINO, getenv, colored
|
||||
from tinygrad.features.multi import MultiLazyBuffer
|
||||
|
||||
BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
|
||||
@@ -351,6 +351,7 @@ def train_cifar():
|
||||
model_ema: Optional[modelEMA] = None
|
||||
projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps']
|
||||
i = 0
|
||||
eval_acc_pct = 0.0
|
||||
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
|
||||
with Tensor.train():
|
||||
st = time.monotonic()
|
||||
@@ -378,9 +379,9 @@ def train_cifar():
|
||||
correct_sum, correct_len = sum(corrects), len(corrects)
|
||||
if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
|
||||
|
||||
acc = correct_sum/correct_len*100.0
|
||||
eval_acc_pct = correct_sum/correct_len*100.0
|
||||
if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
|
||||
print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
|
||||
print(f"eval {correct_sum}/{correct_len} {eval_acc_pct:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
|
||||
if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
|
||||
|
||||
if STEPS == 0 or i == STEPS: break
|
||||
@@ -407,5 +408,12 @@ def train_cifar():
|
||||
st = cl
|
||||
i += 1
|
||||
|
||||
# verify eval acc
|
||||
if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
|
||||
if eval_acc_pct >= target:
|
||||
print(colored(f"{eval_acc_pct=} >= {target}", "green"))
|
||||
else:
|
||||
raise ValueError(colored(f"{eval_acc_pct=} < {target}", "red"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cifar()
|
||||
|
||||
Reference in New Issue
Block a user