diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index df6d80deb1..ef8c60925b 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -28,6 +28,9 @@ def train_resnet():
   print(f"Training on {GPUS}")
   for x in GPUS: Device[x]
 
+  TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
+  EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
+
   # ** model definition and initializers **
   num_classes = 1000
   resnet.Conv2d = Conv2dHeNormal
@@ -61,9 +64,11 @@ def train_resnet():
   steps_in_val_epoch = config["steps_in_val_epoch"] = (len(get_val_files()) // EVAL_BS)
 
   config["DEFAULT_FLOAT"] = dtypes.default_float.name
-  config["BEAM"] = BEAM.value
-  config["WINO"] = WINO.value
-  config["SYNCBN"] = getenv("SYNCBN")
+  config["BEAM"]       = BEAM.value
+  config["TRAIN_BEAM"] = TRAIN_BEAM
+  config["EVAL_BEAM"]  = EVAL_BEAM
+  config["WINO"]       = WINO.value
+  config["SYNCBN"]     = getenv("SYNCBN")
 
   # ** Optimizer **
   skip_list = [v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k]
@@ -105,7 +110,7 @@ def train_resnet():
   def normalize(x): return (x.permute([0, 3, 1, 2]) - input_mean).cast(dtypes.default_float)
   @TinyJit
   def train_step(X, Y):
-    with Context(BEAM=getenv("TRAIN_BEAM", BEAM.value)):
+    with Context(BEAM=TRAIN_BEAM):
       optimizer_group.zero_grad()
       X = normalize(X)
       out = model.forward(X)
@@ -119,7 +124,7 @@ def train_resnet():
 
   @TinyJit
   def eval_step(X, Y):
-    with Context(BEAM=getenv("EVAL_BEAM", BEAM.value)):
+    with Context(BEAM=EVAL_BEAM):
       X = normalize(X)
       out = model.forward(X)
       loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
@@ -177,7 +182,7 @@ def train_resnet():
         estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
         print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
         # if we are doing beam search, run the first eval too
-        if BEAM.value and e == start_epoch: break
+        if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
         return
 
   # ** eval loop **