hlb_cifar: support EVAL_STEPS=1000, print when dataset is shuffled

This commit is contained in:
George Hotz
2023-10-18 01:11:08 +00:00
parent 2b5ea7d9cb
commit 9b1c3cd9ca

View File

@@ -93,7 +93,7 @@ def train_cifar():
'bias_decay': 1.08 * 6.45e-4 * BS/bias_scaler,
'non_bias_decay': 1.08 * 6.45e-4 * BS,
'final_lr_ratio': 0.025,
'initial_div_factor': 1e16,
'initial_div_factor': 1e16,
'label_smoothing': 0.20,
'momentum': 0.85,
'percent_start': 0.23,
@@ -102,7 +102,7 @@ def train_cifar():
'net': {
'kernel_size': 2, # kernel size for the whitening layer
'cutmix_size': 3,
'cutmix_steps': 499,
'cutmix_steps': 499,
'pad_amount': 2
},
'ema': {
@@ -187,7 +187,7 @@ def train_cifar():
return X_cropped.reshape((-1, 3, crop_size, crop_size))
def cutmix(X, Y, mask_size=3):
def cutmix(X:Tensor, Y:Tensor, mask_size=3):
# fill the square with randomly selected images from the same batch
mask = make_square_mask(X.shape, mask_size)
order = list(range(0, X.shape[0]))
@@ -200,9 +200,10 @@ def train_cifar():
return X_cutmix, Y_cutmix
# the operations that remain inside the batch fetcher are the ones that involve random operations
def fetch_batches(X_in, Y_in, BS, is_train):
def fetch_batches(X_in:Tensor, Y_in:Tensor, BS:int, is_train:bool):
step = 0
while True:
st = time.monotonic()
X, Y = X_in, Y_in
order = list(range(0, X.shape[0]))
random.shuffle(order)
@@ -211,6 +212,8 @@ def train_cifar():
X = Tensor.where(Tensor.rand(X.shape[0],1,1,1) < 0.5, X[..., ::-1], X) # flip LR
if step >= hyp['net']['cutmix_steps']: X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
X, Y = X.numpy(), Y.numpy()
et = time.monotonic()
print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms")
for i in range(0, X.shape[0], BS):
# pad the last batch
batch_end = min(i+BS, Y.shape[0])
@@ -344,8 +347,9 @@ def train_cifar():
i = 0
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
with Tensor.train():
st = time.monotonic()
while i <= STEPS:
if i%100 == 0 and i > 1:
if i%getenv("EVAL_STEPS", 100) == 0 and i > 1:
# Using Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
corrects = []
corrects_ema = []
@@ -399,7 +403,6 @@ def train_cifar():
if getenv("DIST"):
X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
GlobalCounters.reset()
st = time.monotonic()
loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
et = time.monotonic()
loss_cpu = loss.numpy()
@@ -413,6 +416,7 @@ def train_cifar():
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
else:
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
st = cl
i += 1
if __name__ == "__main__":