hlb_cifar multi gpu training (#3150)
* cifar train with multi gpu
* GPUS=1 is noop
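The gist of the change: instead of the old DIST path (one spawned process per device with explicit OOB/allreduce syncing), the script now drives all devices from a single process using tinygrad's multi-device tensors. Weights are replicated onto every device in GPUS with Tensor.to_, and each batch is split across them with Tensor.shard on axis 0. A minimal sketch of that pattern (the 2-device count and the toy Linear model are illustrative assumptions, not part of the diff):

    from tinygrad import Tensor, Device, nn
    from tinygrad.nn.state import get_parameters

    GPUS = [f"{Device.DEFAULT}:{i}" for i in range(2)]   # assumed 2 devices for illustration

    model = nn.Linear(32, 10)                             # toy stand-in for SpeedyResNet
    for p in get_parameters(model): p.to_(GPUS)           # replicate the weights on every device

    X = Tensor.rand(8, 32).shard(GPUS, axis=0)            # split the batch along axis 0
    print(model(X).shape)                                 # forward runs data-parallel across GPUS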
@@ -1,11 +1,4 @@
 #!/usr/bin/env python3
-# setup for distributed
-from extra import dist
-from tinygrad.helpers import getenv
-if __name__ == "__main__":
-  if getenv("DIST"):
-    dist.preinit()
-    from extra.dist import collectives
 
 # tinygrad implementation of https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
 # https://myrtle.ai/learn/how-to-train-your-resnet-8-bag-of-tricks/
@@ -15,12 +8,15 @@ import numpy as np
 from typing import Optional
 from extra.datasets import fetch_cifar, cifar_mean, cifar_std
 from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
-from tinygrad.nn.state import get_state_dict
+from tinygrad.nn.state import get_state_dict, get_parameters
 from tinygrad.nn import optim
 from extra.lr_scheduler import OneCycleLR
 from tinygrad.helpers import Context, getenv
 
 BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)
+GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
+assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}"
+for x in GPUS: Device[x]
 
 if getenv("HALF", 0):
   dtypes.default_float = dtypes.float16
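Note how GPUS expands: with GPUS=1 (the default) the list is just [f'{Device.DEFAULT}:0'], so single-device runs behave as before ("GPUS=1 is noop"); larger values enumerate device copies of the default backend. A quick illustration (the backend name "GPU" is an assumption for the example):

    from tinygrad import Device
    # with GPUS=3 and Device.DEFAULT == "GPU" (assumed), the list from the diff becomes:
    GPUS = [f"{Device.DEFAULT}:{i}" for i in range(3)]
    print(GPUS)  # ['GPU:0', 'GPU:1', 'GPU:2']

The assert right after it keeps BS a multiple of len(GPUS), since the batch is later sharded along axis 0.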
@@ -75,7 +71,7 @@ class SpeedyResNet:
 
   def __call__(self, x, training=True):
     # pad to 32x32 because whitening conv creates 31x31 images that are awfully slow to compute with
-    # TODO: remove the pad but instead let the kernel optimizer itself
+    # TODO: remove the pad but instead let the kernel optimize itself
     forward = lambda x: x.conv2d(self.whitening).pad2d((1,0,0,1)).sequential(self.net)
     return forward(x) if training else forward(x)*0.5 + forward(x[..., ::-1])*0.5
 
@@ -251,11 +247,6 @@ def train_cifar():
 
   set_seed(hyp['seed'])
 
-  # this import needs to be done here because this is running in a subprocess
-  from extra.dist import OOB
-  assert OOB is not None or not getenv("DIST"), "OOB should be initialized"
-  rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1)
-
   X_train, Y_train, X_test, Y_test = fetch_cifar()
   # load data and label into GPU and convert to dtype accordingly
   X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
@@ -278,6 +269,10 @@ def train_cifar():
   X_train, Y_train = X_train.cast(dtypes.default_float), Y_train.cast(dtypes.default_float)
   X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float)
 
+  if len(GPUS) > 1:
+    for x in get_parameters(model):
+      x.to_(GPUS)
+
   # parse the training params into bias and non-bias
   params_dict = get_state_dict(model)
   params_bias = []
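Two different multi-device calls are in play: Tensor.to_(GPUS) (above) copies a tensor to every device in the list, i.e. the model weights are replicated, while Tensor.shard(GPUS, axis=0) (used for the batches further down) splits a tensor along an axis so each device only holds a slice. A small sketch of the difference, with shapes chosen only for illustration:

    from tinygrad import Tensor, Device
    GPUS = [f"{Device.DEFAULT}:{i}" for i in range(2)]  # assumed 2 devices

    w = Tensor.rand(64, 64)
    w.to_(GPUS)                  # replicated: every device holds the full 64x64 tensor

    x = Tensor.rand(8, 64)
    xs = x.shard(GPUS, axis=0)   # sharded: each device holds 4 of the 8 rows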
@@ -310,17 +305,6 @@ def train_cifar():
     optimizer[1].zero_grad()
     loss.backward()
 
-    if getenv("DIST"):
-      # sync gradients across ranks
-      bucket, offset = [], 0
-      for _, v in params_dict.items():
-        if v.grad is not None: bucket.append(v.grad.flatten())
-      grads = collectives.allreduce(Tensor.cat(*bucket))
-      for _, v in params_dict.items():
-        if v.grad is not None:
-          v.grad.assign(grads[offset:offset+v.grad.numel()].reshape(*v.grad.shape))
-          offset += v.grad.numel()
-
     optimizer[0].step()
     optimizer[1].step()
     lr_scheduler[0].step()
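The manual gradient sync above goes away because it is no longer the script's job: with the weights replicated across GPUS and the batch sharded on axis 0, the cross-device reduction is handled by tinygrad's multi-device tensors when the backward pass and optimizer step run, so no explicit collectives.allreduce is needed. (That is the apparent rationale; the removed code was the old DIST-only path.)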
@@ -360,9 +344,8 @@ def train_cifar():
      losses = []
      losses_ema = []
      for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
-       # further split batch if distributed
-       if getenv("DIST"):
-         Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)]
+       if len(GPUS) > 1:
+         Xt, Yt = Xt.shard(GPUS, axis=0), Yt.shard(GPUS, axis=0)
 
        correct, loss = eval_step_jitted(model, Xt, Yt)
        losses.append(loss.numpy().tolist())
@@ -375,77 +358,34 @@ def train_cifar():
      # collect accuracy across ranks
      correct_sum, correct_len = sum(corrects), len(corrects)
      if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
-     if getenv("DIST"):
-       if rank == 0:
-         for j in range(1, min(world_size, 5)):
-           if model_ema:
-             recv_sum, recv_len, recv_sum_ema, recv_len_ema = OOB.recv(j)
-           else:
-             recv_sum, recv_len = OOB.recv(j)
-           correct_sum += recv_sum
-           correct_len += recv_len
-           if model_ema:
-             correct_sum_ema += recv_sum_ema
-             correct_len_ema += recv_len_ema
-       elif rank < min(world_size, 5):
-         if model_ema:
-           OOB.send((correct_sum, correct_len, correct_sum_ema, correct_len_ema), 0)
-         else:
-           OOB.send((correct_sum, correct_len), 0)
-
-     # only rank 0 prints
-     if rank == 0:
-       acc = correct_sum/correct_len*100.0
-       if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
-       print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
-       if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
+     acc = correct_sum/correct_len*100.0
+     if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
+     print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
+     if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
 
     if STEPS == 0 or i == STEPS: break
 
     X, Y = next(batcher)
-    if getenv("DIST"):
-      X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
+    if len(GPUS) > 1:
+      X, Y = X.shard(GPUS, axis=0), Y.shard(GPUS, axis=0)
 
     GlobalCounters.reset()
    with Context(BEAM=getenv("LATEBEAM")):
       loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
     et = time.monotonic()
     loss_cpu = loss.numpy()
     # EMA for network weights
-    if i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
+    if getenv("EMA") and i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
      if model_ema is None:
        model_ema = modelEMA(W, model)
      model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
     cl = time.monotonic()
-    if not getenv("DIST"):
-      # 53 221.74 ms run, 2.22 ms python, 219.52 ms CL, 803.39 loss, 0.000807 LR, 4.66 GB used, 3042.49 GFLOPS, 674.65 GOPS
-      print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms {loss.device}, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS, {GlobalCounters.global_ops*1e-9:9.2f} GOPS")
-    else:
-      print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms {loss.device}, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
+    device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
+    # 53 221.74 ms run, 2.22 ms python, 219.52 ms CL, 803.39 loss, 0.000807 LR, 4.66 GB used, 3042.49 GFLOPS, 674.65 GOPS
+    print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms {device_str}, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS, {GlobalCounters.global_ops*1e-9:9.2f} GOPS")
     st = cl
     i += 1
 
 if __name__ == "__main__":
-  if not getenv("DIST"):
-    train_cifar()
-  else: # distributed
-    if getenv("HIP"):
-      from tinygrad.runtime.ops_hip import HIP
-      devices = [f"hip:{i}" for i in range(HIP.device_count)]
-    else:
-      from tinygrad.runtime.ops_gpu import CLDevice
-      devices = [f"gpu:{i}" for i in range(len(CLDevice.device_ids))]
-    world_size = len(devices)
-
-    # ensure that the batch size is divisible by the number of devices
-    assert BS % world_size == 0, f"batch size {BS} is not divisible by world size {world_size}"
-
-    # ensure that the evaluation batch size is divisible by the number of devices
-    assert EVAL_BS % min(world_size, 5) == 0, f"evaluation batch size {EVAL_BS} is not divisible by world size {min(world_size, 5)}"
-
-    # init out-of-band communication
-    dist.init_oob(world_size)
-
-    # start the processes
-    processes = []
-    for rank, device in enumerate(devices):
-      processes.append(dist.spawn(rank, device, fn=train_cifar, args=()))
-    for p in processes: p.join()
+  train_cifar()
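With the __main__ block reduced to a plain train_cifar() call, multi-GPU runs no longer go through the DIST spawn path; the device count is picked entirely by the GPUS environment variable. Roughly (device counts illustrative, script path assumed to be examples/hlb_cifar10.py):

    # before this change: one spawned process per device, OOB/collectives syncing
    DIST=1 python3 examples/hlb_cifar10.py

    # after this change: a single process drives all devices
    GPUS=2 python3 examples/hlb_cifar10.py
    GPUS=1 python3 examples/hlb_cifar10.py   # same as the old single-GPU run ("GPUS=1 is noop")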