diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py index a2aad4d3db..e1aace4fd2 100644 --- a/examples/mlperf/dataloader.py +++ b/examples/mlperf/dataloader.py @@ -61,8 +61,8 @@ def loader_process(q_in, q_out, X:Tensor, seed): random.seed(seed * 2 ** 10 + idx) img = preprocess_train(img) else: - # pad zeros - img = np.zeros((224, 224, 3), dtype=np.uint8) + # pad data with training mean + img = np.tile(np.array([[[123.68, 116.78, 103.94]]], dtype=np.uint8), (224, 224, 1)) # broken out #img_tensor = Tensor(img.tobytes(), device='CPU') diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 09ca95cd31..1e4ea09167 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -79,7 +79,7 @@ def train_resnet(): eval_start_epoch = getenv("EVAL_START_EPOCH", 0) eval_freq = getenv("EVAL_FREQ", 1) - steps_in_train_epoch = config["steps_in_train_epoch"] = (len(get_train_files()) // BS) + steps_in_train_epoch = config["steps_in_train_epoch"] = (round_up(len(get_train_files()), BS) // BS) steps_in_val_epoch = config["steps_in_val_epoch"] = (round_up(len(get_val_files()), EVAL_BS) // EVAL_BS) config["DEFAULT_FLOAT"] = dtypes.default_float.name @@ -185,14 +185,13 @@ def train_resnet(): MLLOGGER.start(key=mllog_constants.EPOCH_START, value=e+1, metadata=dict(epoch_num=e+1)) Tensor.training = True BEAM.value = TRAIN_BEAM - batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e) + batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e, pad_first_batch=True) it = iter(tqdm(batch_loader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK)) i, proc = 0, data_get(it) st = time.perf_counter() while proc is not None: GlobalCounters.reset() - # TODO: pad training data - (loss, top_1), _, proc = train_step(proc[0], proc[1]), proc[2], proc[3] + (loss, top_1), y, proc = train_step(proc[0], proc[1]), proc[2], proc[3] pt = time.perf_counter() @@ -204,7 +203,8 @@ def train_resnet(): dt = time.perf_counter() device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}" - loss, top_1_acc = loss.numpy().item(), top_1.numpy().item() / BS + loss, top_1 = loss.numpy().item(), top_1.numpy().item() + top_1_acc = top_1 / sum(yi != -1 for yi in y) cl = time.perf_counter() if BENCHMARK: