mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
small cleanups (#9947)
This commit is contained in:
@@ -432,7 +432,8 @@ def train_retinanet():
|
||||
model = retinanet.RetinaNet(backbone, num_classes=NUM_CLASSES)
|
||||
params = get_parameters(model)
|
||||
|
||||
for p in params: p.to_(GPUS)
|
||||
if len(GPUS) > 1:
|
||||
for p in params: p.to_(GPUS)
|
||||
|
||||
step_times, start_epoch = [], 0
|
||||
|
||||
@@ -446,8 +447,7 @@ def train_retinanet():
|
||||
|
||||
# ** lr scheduler **
|
||||
config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), BS) // BS
|
||||
config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(len(val_dataset.imgs.keys()), BS) // BS)
|
||||
start_iter = start_epoch * steps_in_train_epoch
|
||||
config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(len(val_dataset.imgs.keys()), EVAL_BS) // EVAL_BS)
|
||||
|
||||
# ** initialize wandb **
|
||||
if (WANDB:=getenv("WANDB")):
|
||||
|
||||
Reference in New Issue
Block a user