diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 5431d18842..d7d9c2feb7 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -407,6 +407,10 @@ def train_retinanet(): config["bs"] = bs = getenv("BS", 128) config["num_epochs"] = num_epochs = getenv("EPOCHS", 4) + if seed: + Tensor.manual_seed(seed) + np.random.seed(seed=seed) + # ** initialize wandb ** if (WANDB := getenv("WANDB")): import wandb