diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index 164101622a..1b3b9f5f29 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -13,9 +13,11 @@ from tinygrad.nn.state import get_state_dict, get_parameters from tinygrad.nn import optim from tinygrad.helpers import Context, BEAM, WINO, getenv -BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000) +BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000) +EVAL_BS = getenv("EVAL_BS", BS) GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))] -assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}" +assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow" +assert EVAL_BS % len(GPUS) == 0, f"{EVAL_BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow" for x in GPUS: Device[x] if getenv("HALF"):