pad resnet training data with training data mean (#4369)

update model_train resnet to pad training
This commit is contained in:
chenyu
2024-05-02 20:26:15 -04:00
committed by GitHub
parent 3cf8291f2f
commit 2c3b7f8e70
2 changed files with 7 additions and 7 deletions

View File

@@ -61,8 +61,8 @@ def loader_process(q_in, q_out, X:Tensor, seed):
random.seed(seed * 2 ** 10 + idx)
img = preprocess_train(img)
else:
# pad zeros
img = np.zeros((224, 224, 3), dtype=np.uint8)
# pad data with training mean
img = np.tile(np.array([[[123.68, 116.78, 103.94]]], dtype=np.uint8), (224, 224, 1))
# broken out
#img_tensor = Tensor(img.tobytes(), device='CPU')

View File

@@ -79,7 +79,7 @@ def train_resnet():
eval_start_epoch = getenv("EVAL_START_EPOCH", 0)
eval_freq = getenv("EVAL_FREQ", 1)
steps_in_train_epoch = config["steps_in_train_epoch"] = (len(get_train_files()) // BS)
steps_in_train_epoch = config["steps_in_train_epoch"] = (round_up(len(get_train_files()), BS) // BS)
steps_in_val_epoch = config["steps_in_val_epoch"] = (round_up(len(get_val_files()), EVAL_BS) // EVAL_BS)
config["DEFAULT_FLOAT"] = dtypes.default_float.name
@@ -185,14 +185,13 @@ def train_resnet():
MLLOGGER.start(key=mllog_constants.EPOCH_START, value=e+1, metadata=dict(epoch_num=e+1))
Tensor.training = True
BEAM.value = TRAIN_BEAM
batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e)
batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e, pad_first_batch=True)
it = iter(tqdm(batch_loader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
i, proc = 0, data_get(it)
st = time.perf_counter()
while proc is not None:
GlobalCounters.reset()
# TODO: pad training data
(loss, top_1), _, proc = train_step(proc[0], proc[1]), proc[2], proc[3]
(loss, top_1), y, proc = train_step(proc[0], proc[1]), proc[2], proc[3]
pt = time.perf_counter()
@@ -204,7 +203,8 @@ def train_resnet():
dt = time.perf_counter()
device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
loss, top_1_acc = loss.numpy().item(), top_1.numpy().item() / BS
loss, top_1 = loss.numpy().item(), top_1.numpy().item()
top_1_acc = top_1 / sum(yi != -1 for yi in y)
cl = time.perf_counter()
if BENCHMARK: