mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update seeding on dataloader and the start of training script
This commit is contained in:
@@ -361,13 +361,14 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue
|
||||
idx, img, tgt = data
|
||||
img = image_load(base_dir, img["subset"], img["file_name"])
|
||||
|
||||
if seed is not None:
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if val:
|
||||
img = resize(img)[0]
|
||||
else:
|
||||
if seed is not None:
|
||||
np.random.seed(seed * 2 ** 10 + idx)
|
||||
random.seed(seed * 2 ** 10 + idx)
|
||||
torch.manual_seed(seed * 2 ** 10 + idx)
|
||||
|
||||
img, tgt = random_horizontal_flip(img, tgt)
|
||||
img, tgt, _ = resize(img, tgt=tgt)
|
||||
match_quality_matrix = box_iou(tgt["boxes"], (anchor := np.concatenate(generate_anchors((800, 800)))))
|
||||
|
||||
@@ -414,9 +414,7 @@ def train_retinanet():
|
||||
config["lr_warmup_epochs"] = lr_warmup_epochs = getenv("LR_WARMUP_EPOCHS", 1)
|
||||
config["lr_warmup_factor"] = lr_warmup_factor = getenv("LR_WARMUP_FACTOR", 1e-3)
|
||||
|
||||
if SEED:
|
||||
Tensor.manual_seed(SEED)
|
||||
np.random.seed(seed=SEED)
|
||||
if SEED: Tensor.manual_seed(SEED)
|
||||
|
||||
# ** model initializers **
|
||||
resnet.BatchNorm = FrozenBatchNorm2d
|
||||
|
||||
Reference in New Issue
Block a user