hlb_cifar multi gpu training (#3150)

* cifar train with multi gpu

* GPUS=1 is a no-op
Author: chenyu
Date: 2024-01-16 14:38:45 -05:00 (committed by GitHub)
Parent: cc0de99751
Commit: 589c16756f

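In short: the multi-process DIST path (extra.dist, OOB send/recv, an explicit gradient allreduce) in examples/hlb_cifar10.py is replaced by tinygrad's single-process multi-device tensors. The device count is driven by the GPUS environment variable, so a run that previously needed DIST=1 plus one spawned process per device becomes e.g. GPUS=2 python3 examples/hlb_cifar10.py; with GPUS=1 the device list collapses to a single entry and behavior is unchanged.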

@@ -1,11 +1,4 @@
 #!/usr/bin/env python3
-# setup for distributed
-from extra import dist
-from tinygrad.helpers import getenv
-if __name__ == "__main__":
-  if getenv("DIST"):
-    dist.preinit()
-from extra.dist import collectives
 
 # tinygrad implementation of https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
 # https://myrtle.ai/learn/how-to-train-your-resnet-8-bag-of-tricks/
@@ -15,12 +8,15 @@ import numpy as np
 from typing import Optional
 from extra.datasets import fetch_cifar, cifar_mean, cifar_std
 from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
-from tinygrad.nn.state import get_state_dict
+from tinygrad.nn.state import get_state_dict, get_parameters
 from tinygrad.nn import optim
 from extra.lr_scheduler import OneCycleLR
 from tinygrad.helpers import Context, getenv
 
 BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)
+GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
+assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}"
+for x in GPUS: Device[x]
 
 if getenv("HALF", 0):
   dtypes.default_float = dtypes.float16
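For reference, a sketch of what this device-list setup evaluates to; the backend name is illustrative, and Device[x] is indexed purely for its side effect of instantiating each device before training starts:

```python
from tinygrad import Device
from tinygrad.helpers import getenv

# GPUS=3 with Device.DEFAULT == "GPU" expands to ['GPU:0', 'GPU:1', 'GPU:2'];
# GPUS unset (default 1) gives ['GPU:0'] -- the "GPUS=1 is a no-op" case
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
for x in GPUS: Device[x]  # touch each device so it is initialized up front
```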
@@ -75,7 +71,7 @@ class SpeedyResNet:
 
   def __call__(self, x, training=True):
     # pad to 32x32 because whitening conv creates 31x31 images that are awfully slow to compute with
-    # TODO: remove the pad but instead let the kernel optimizer itself
+    # TODO: remove the pad but instead let the kernel optimize itself
     forward = lambda x: x.conv2d(self.whitening).pad2d((1,0,0,1)).sequential(self.net)
     return forward(x) if training else forward(x)*0.5 + forward(x[..., ::-1])*0.5
@@ -251,11 +247,6 @@ def train_cifar():
   set_seed(hyp['seed'])
 
-  # this import needs to be done here because this is running in a subprocess
-  from extra.dist import OOB
-  assert OOB is not None or not getenv("DIST"), "OOB should be initialized"
-  rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1)
-
   X_train, Y_train, X_test, Y_test = fetch_cifar()
   # load data and label into GPU and convert to dtype accordingly
   X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
@@ -278,6 +269,10 @@ def train_cifar():
   X_train, Y_train = X_train.cast(dtypes.default_float), Y_train.cast(dtypes.default_float)
   X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float)
 
+  if len(GPUS) > 1:
+    for x in get_parameters(model):
+      x.to_(GPUS)
+
   # parse the training params into bias and non-bias
   params_dict = get_state_dict(model)
   params_bias = []
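This is the data-parallel split: parameters are replicated in place (to_ with a device list and no axis), while batches are sharded on axis 0 so each device sees BS/len(GPUS) samples. A minimal sketch, assuming two instances of the default device exist:

```python
from tinygrad import Tensor, Device

GPUS = [f'{Device.DEFAULT}:{i}' for i in range(2)]  # illustrative two-device list

w = Tensor.randn(16, 16)
w.to_(GPUS)                                  # in-place: a full copy of the weight on every device
x = Tensor.randn(8, 16).shard(GPUS, axis=0)  # each device holds a 4-row slice of the batch
y = x @ w                                    # computed per device on the local slice
print(y.shape, y.numpy().shape)              # logically still (8, 16); .numpy() gathers to host
```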
@@ -310,17 +305,6 @@ def train_cifar():
     optimizer[1].zero_grad()
     loss.backward()
 
-    if getenv("DIST"):
-      # sync gradients across ranks
-      bucket, offset = [], 0
-      for _, v in params_dict.items():
-        if v.grad is not None: bucket.append(v.grad.flatten())
-      grads = collectives.allreduce(Tensor.cat(*bucket))
-      for _, v in params_dict.items():
-        if v.grad is not None:
-          v.grad.assign(grads[offset:offset+v.grad.numel()].reshape(*v.grad.shape))
-          offset += v.grad.numel()
-
     optimizer[0].step()
     optimizer[1].step()
     lr_scheduler[0].step()
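The deleted block was the manual data-parallel gradient sync: flatten every gradient, concatenate into one bucket so a single collective covers all parameters, allreduce, then scatter the reduced values back into each grad tensor. After this commit that reduction happens inside tinygrad's multi-device tensor machinery instead. A sketch of just the bucketing bookkeeping, in plain numpy with a hypothetical allreduce standing in for extra.dist.collectives.allreduce:

```python
import numpy as np

def allreduce(x: np.ndarray) -> np.ndarray:
  # hypothetical stand-in: the real collective summed this buffer across all ranks
  return x

grads = [np.ones((2, 3)), np.ones(5)]                # per-parameter gradients
bucket = np.concatenate([g.ravel() for g in grads])  # one flat buffer -> one collective call
synced, offset = allreduce(bucket), 0
for i, g in enumerate(grads):                        # unflatten back into parameter shapes
  grads[i] = synced[offset:offset + g.size].reshape(g.shape)
  offset += g.size
```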
@@ -360,9 +344,8 @@ def train_cifar():
       losses = []
       losses_ema = []
       for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
-        # further split batch if distributed
-        if getenv("DIST"):
-          Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)]
+        if len(GPUS) > 1:
+          Xt, Yt = Xt.shard(GPUS, axis=0), Yt.shard(GPUS, axis=0)
         correct, loss = eval_step_jitted(model, Xt, Yt)
         losses.append(loss.numpy().tolist())
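Behavior note: the old DIST eval path split each batch across at most five ranks (hence min(world_size, 5), matching the EVAL_BS divisibility assert removed at the bottom of this diff), while the new path simply shards the whole EVAL_BS batch across every device in GPUS.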
@@ -375,77 +358,34 @@ def train_cifar():
       # collect accuracy across ranks
       correct_sum, correct_len = sum(corrects), len(corrects)
       if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
-      if getenv("DIST"):
-        if rank == 0:
-          for j in range(1, min(world_size, 5)):
-            if model_ema:
-              recv_sum, recv_len, recv_sum_ema, recv_len_ema = OOB.recv(j)
-            else:
-              recv_sum, recv_len = OOB.recv(j)
-            correct_sum += recv_sum
-            correct_len += recv_len
-            if model_ema:
-              correct_sum_ema += recv_sum_ema
-              correct_len_ema += recv_len_ema
-        elif rank < min(world_size, 5):
-          if model_ema:
-            OOB.send((correct_sum, correct_len, correct_sum_ema, correct_len_ema), 0)
-          else:
-            OOB.send((correct_sum, correct_len), 0)
-
-      # only rank 0 prints
-      if rank == 0:
-        acc = correct_sum/correct_len*100.0
-        if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
-        print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
-        if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
+      acc = correct_sum/correct_len*100.0
+      if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
+      print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
+      if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
 
     if STEPS == 0 or i == STEPS: break
 
     X, Y = next(batcher)
-    if getenv("DIST"):
-      X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
+    if len(GPUS) > 1:
+      X, Y = X.shard(GPUS, axis=0), Y.shard(GPUS, axis=0)
     GlobalCounters.reset()
     with Context(BEAM=getenv("LATEBEAM")):
      loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
     et = time.monotonic()
     loss_cpu = loss.numpy()
     # EMA for network weights
-    if i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
+    if getenv("EMA") and i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
       if model_ema is None:
         model_ema = modelEMA(W, model)
       model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
     cl = time.monotonic()
-    if not getenv("DIST"):
-      # 53 221.74 ms run, 2.22 ms python, 219.52 ms CL, 803.39 loss, 0.000807 LR, 4.66 GB used, 3042.49 GFLOPS, 674.65 GOPS
-      print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms {loss.device}, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS, {GlobalCounters.global_ops*1e-9:9.2f} GOPS")
-    else:
-      print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms {loss.device}, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
+    device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
+    # 53 221.74 ms run, 2.22 ms python, 219.52 ms CL, 803.39 loss, 0.000807 LR, 4.66 GB used, 3042.49 GFLOPS, 674.65 GOPS
+    print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms {device_str}, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS, {GlobalCounters.global_ops*1e-9:9.2f} GOPS")
     st = cl
     i += 1
 
 if __name__ == "__main__":
-  if not getenv("DIST"):
-    train_cifar()
-  else: # distributed
-    if getenv("HIP"):
-      from tinygrad.runtime.ops_hip import HIP
-      devices = [f"hip:{i}" for i in range(HIP.device_count)]
-    else:
-      from tinygrad.runtime.ops_gpu import CLDevice
-      devices = [f"gpu:{i}" for i in range(len(CLDevice.device_ids))]
-    world_size = len(devices)
-    # ensure that the batch size is divisible by the number of devices
-    assert BS % world_size == 0, f"batch size {BS} is not divisible by world size {world_size}"
-    # ensure that the evaluation batch size is divisible by the number of devices
-    assert EVAL_BS % min(world_size, 5) == 0, f"evaluation batch size {EVAL_BS} is not divisible by world size {min(world_size, 5)}"
-    # init out-of-band communication
-    dist.init_oob(world_size)
-    # start the processes
-    processes = []
-    for rank, device in enumerate(devices):
-      processes.append(dist.spawn(rank, device, fn=train_cifar, args=()))
-    for p in processes: p.join()
+  train_cifar()
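With that, the launcher collapses too: instead of enumerating HIP or OpenCL devices, initializing out-of-band communication, and spawning one training process per rank, the multi-GPU case now goes through the same single-process train_cifar() entry point, with all parallelism expressed through the GPUS device list above.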