mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)
add checkpointing and training resume capabilities
@@ -363,6 +363,7 @@ def train_retinanet():
   config["gpus"] = GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
   for x in GPUS: Device[x]
   print(f"training on {GPUS}")

   def _freeze_backbone_layers(backbone, trainable_layers, loaded_keys):
     model_layers = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
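For context on the device setup above: tinygrad names additional accelerators by suffixing the default device with an index, and indexing into Device forces each backend to initialize. A minimal sketch of the same pattern, assuming a multi-GPU machine (not part of this commit):

from tinygrad import Device
from tinygrad.helpers import getenv

# GPUS=2 with Device.DEFAULT == "GPU" yields ["GPU:0", "GPU:1"]
GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
for x in GPUS: Device[x]  # touching Device[x] initializes each device up front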
@@ -403,24 +404,13 @@ def train_retinanet():
   config["lr"] = lr = 1e-4
   config["lr_warmup_epochs"] = lr_warmup_epochs = 1
   config["lr_warmup_factor"] = lr_warmup_factor = 1e-3
-  config["seed"] = seed = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
-  config["bs"] = bs = getenv("BS", 128)
-  config["num_epochs"] = num_epochs = getenv("EPOCHS", 4)
+  config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
+  config["bs"] = BS = getenv("BS", 128)
+  config["epochs"] = EPOCHS = getenv("EPOCHS", 4)

-  if seed:
-    Tensor.manual_seed(seed)
-    np.random.seed(seed=seed)
-
-  # ** initialize wandb **
-  if (WANDB := getenv("WANDB")):
-    import wandb
-
-    wandb_args = {"project": "MLPerf-RetinaNet"}
-    if (wandb_id := getenv("WANDB_RESUME", "")):
-      wandb_args["id"] = wandb_id
-      wandb_args["resume"] = "must"
-
-    wandb.init(config=config, **wandb_args)
+  if SEED:
+    Tensor.manual_seed(SEED)
+    np.random.seed(seed=SEED)

   # ** model initializers **
   resnet.BatchNorm = FrozenBatchNorm2d
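A note on the seeding guard above: getenv returns an int, and the truthiness check means an explicit SEED=0 skips seeding entirely, while the random.SystemRandom() fallback keeps unseeded runs reproducible only because the chosen value is recorded in config["seed"]. A minimal standalone sketch of the same pattern (assumes tinygrad's getenv and Tensor.manual_seed):

import random
import numpy as np
from tinygrad import Tensor
from tinygrad.helpers import getenv

# fresh OS-entropy seed unless SEED is set in the environment
SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
if SEED:  # SEED=0 falls through and leaves both RNGs unseeded
  Tensor.manual_seed(SEED)
  np.random.seed(seed=SEED)
print(f"effective seed: {SEED}")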
@@ -445,13 +435,32 @@ def train_retinanet():
   val_dataset = COCO(download_dataset(BASE_DIR, "validation"))

   # ** lr scheduler **
-  config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), bs) // bs
+  config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), BS) // BS
   start_iter, warmup_iters = start_epoch * steps_in_train_epoch, lr_warmup_epochs * steps_in_train_epoch
   lr_scheduler = _create_lr_scheduler(optim, start_iter, warmup_iters, lr_warmup_factor)

+  # ** resume from checkpointing **
+  if ckpt := getenv("RESUME", ""):
+    load_training_state(model, optim, lr_scheduler, safe_load(ckpt))
+    start_epoch = int(lr_scheduler.epoch_counter.item() / steps_in_train_epoch)
+    print(f"resuming from {ckpt} at epoch {start_epoch}")
+
+  # ** initialize wandb **
+  if WANDB := getenv("WANDB"):
+    import wandb
+
+    wandb_args = {"project": "MLPerf-RetinaNet"}
+    if wandb_id := getenv("WANDB_RESUME", ""):
+      wandb_args["id"] = wandb_id
+      wandb_args["resume"] = "must"
+
+    wandb.init(config=config, **wandb_args)
+
+  print(f"training with batch size {BS} for {EPOCHS} epochs")
+
   # ** training loop **
-  for e in range(start_epoch, num_epochs):
-    train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=bs, seed=seed)
+  for e in range(start_epoch, EPOCHS):
+    train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=BS, seed=SEED)
     it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
     i, proc = 0, _data_get(it)
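Two details of the resume path are worth spelling out. First, start_epoch is recovered from the LR scheduler itself: its epoch_counter ticks once per scheduler step (here, once per training step), so integer-dividing it by steps_in_train_epoch yields the epoch to restart from. Second, get_training_state/load_training_state bundle model, optimizer, and scheduler tensors into one flat dict for safe_save/safe_load. A hypothetical sketch of what such helpers could look like, built on tinygrad's get_state_dict/load_state_dict (the real helpers live elsewhere in the repo and may differ):

from tinygrad.nn.state import get_state_dict, load_state_dict

def get_training_state(model, optimizer, lr_scheduler):
  # wrapping everything in one dict makes get_state_dict prefix the flat
  # checkpoint keys as "model.*", "optimizer.*", "scheduler.*"
  return get_state_dict({"model": model, "optimizer": optimizer, "scheduler": lr_scheduler})

def load_training_state(model, optimizer, lr_scheduler, state_dict):
  # rebuild the same tree so the checkpoint keys line up with the live tensors
  load_state_dict({"model": model, "optimizer": optimizer, "scheduler": lr_scheduler}, state_dict)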
@@ -499,12 +508,21 @@ def train_retinanet():
       if i == BENCHMARK:
         assert not math.isnan(loss)
         median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
-        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * num_epochs / 60)
+        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * EPOCHS / 60)
         print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
         print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
               f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
         return

+    if getenv("CKPT"):
+      if not os.path.exists(ckpt_dir := Path(getenv("CKPT_DIR", "./ckpts"))): os.mkdir(ckpt_dir)
+      if WANDB and wandb.run is not None:
+        fn = ckpt_dir / Path(f"{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{e}.safe")
+      else:
+        fn = ckpt_dir / Path(f"{time.strftime('%Y%m%d_%H%M%S')}_e{e}.safe")
+      print(f"saving ckpt to {fn}")
+      safe_save(get_training_state(model, optim, lr_scheduler), fn)
+
 def train_unet3d():
   """
   Trains the UNet3D model.
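Putting the new knobs together: CKPT turns on per-epoch saving into CKPT_DIR, and RESUME points a later run at a saved .safe file, with WANDB_RESUME reattaching the same wandb run. A hypothetical driver script, assuming the standard entry point examples/mlperf/model_train.py selects the workload via a MODEL env var (paths and flow are illustrative, not part of this commit):

import glob, os, subprocess

env = dict(os.environ, MODEL="retinanet", CKPT="1", CKPT_DIR="./ckpts", BS="128", EPOCHS="4")
subprocess.run(["python", "examples/mlperf/model_train.py"], env=env)  # may be interrupted mid-run

# resume from the newest checkpoint; set WANDB_RESUME only if WANDB was on for the first run
latest = max(glob.glob("./ckpts/*.safe"), key=os.path.getmtime)
subprocess.run(["python", "examples/mlperf/model_train.py"], env=dict(env, RESUME=latest), check=True)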