verify eval acc for hlb_cifar training (#3344)

Set the target eval accuracy to 93% to reduce flakiness for now.
This commit is contained in:
chenyu
2024-02-07 19:19:59 -05:00
committed by GitHub
parent 0d2dacb549
commit d8ad9e5660
2 changed files with 12 additions and 4 deletions

View File

@@ -145,7 +145,7 @@ jobs:
# - name: Run 10 CIFAR training steps w 6 GPUS
# run: time HALF=1 STEPS=10 BS=1536 GPUS=6 python3 examples/hlb_cifar10.py
- name: Run full CIFAR training
run: time HIP=1 HALF=1 LATEWINO=1 STEPS=1000 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.py
run: time HIP=1 HALF=1 LATEWINO=1 STEPS=1000 TARGET_EVAL_ACC_PCT=93 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.py
- uses: actions/upload-artifact@v4
with:
name: Speed (AMD)

View File

@@ -11,7 +11,7 @@ from extra.lr_scheduler import OneCycleLR
from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
from tinygrad.nn.state import get_state_dict, get_parameters
from tinygrad.nn import optim
from tinygrad.helpers import Context, BEAM, WINO, getenv
from tinygrad.helpers import Context, BEAM, WINO, getenv, colored
from tinygrad.features.multi import MultiLazyBuffer
BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
@@ -351,6 +351,7 @@ def train_cifar():
model_ema: Optional[modelEMA] = None
projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps']
i = 0
eval_acc_pct = 0.0
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
with Tensor.train():
st = time.monotonic()
@@ -378,9 +379,9 @@ def train_cifar():
correct_sum, correct_len = sum(corrects), len(corrects)
if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
acc = correct_sum/correct_len*100.0
eval_acc_pct = correct_sum/correct_len*100.0
if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
print(f"eval {correct_sum}/{correct_len} {eval_acc_pct:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
if STEPS == 0 or i == STEPS: break
@@ -407,5 +408,12 @@ def train_cifar():
st = cl
i += 1
# verify eval acc
if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
if eval_acc_pct >= target:
print(colored(f"{eval_acc_pct=} >= {target}", "green"))
else:
raise ValueError(colored(f"{eval_acc_pct=} < {target}", "red"))
if __name__ == "__main__":
train_cifar()