pad resnet training data with training data mean (#4369)
update model_train resnet to pad training
@@ -61,8 +61,8 @@ def loader_process(q_in, q_out, X:Tensor, seed):
         random.seed(seed * 2 ** 10 + idx)
         img = preprocess_train(img)
       else:
-        # pad zeros
-        img = np.zeros((224, 224, 3), dtype=np.uint8)
+        # pad data with training mean
+        img = np.tile(np.array([[[123.68, 116.78, 103.94]]], dtype=np.uint8), (224, 224, 1))

       # broken out
       #img_tensor = Tensor(img.tobytes(), device='CPU')
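The filler images only exist to round out the last batch, so what matters is how they look after the model's input normalization. Assuming the input pipeline subtracts the usual ImageNet per-channel means (the same 123.68/116.78/103.94 constants used above), mean-padded fillers land near zero while zero-padded fillers turn into large negative inputs. A minimal sketch of that effect (not the repo's code):

# sketch: compare zero padding vs. training-mean padding after mean subtraction
# (assumes the preprocessing subtracts these per-channel means; not the repo's code)
import numpy as np

input_mean = np.array([123.68, 116.78, 103.94], dtype=np.float32)
zero_pad = np.zeros((224, 224, 3), dtype=np.uint8)
mean_pad = np.tile(np.array([[[123.68, 116.78, 103.94]]], dtype=np.uint8), (224, 224, 1))

print((zero_pad - input_mean).mean(axis=(0, 1)))  # ~[-123.68, -116.78, -103.94]
print((mean_pad - input_mean).mean(axis=(0, 1)))  # ~[-0.68, -0.78, -0.94] (uint8 truncation)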
@@ -79,7 +79,7 @@ def train_resnet():
   eval_start_epoch = getenv("EVAL_START_EPOCH", 0)
   eval_freq = getenv("EVAL_FREQ", 1)

-  steps_in_train_epoch = config["steps_in_train_epoch"] = (len(get_train_files()) // BS)
+  steps_in_train_epoch = config["steps_in_train_epoch"] = (round_up(len(get_train_files()), BS) // BS)
   steps_in_val_epoch = config["steps_in_val_epoch"] = (round_up(len(get_val_files()), EVAL_BS) // EVAL_BS)

   config["DEFAULT_FLOAT"] = dtypes.default_float.name
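With the old formula the images left over after the last full batch were dropped every epoch; rounding the file count up to a multiple of BS instead yields one extra step whose batch is completed with filler samples. A quick check of the arithmetic (BS=768 is only an illustrative value; 1,281,167 is the ImageNet-1k train-set size):

# sketch of the step-count change; round_up here mirrors tinygrad.helpers.round_up
def round_up(num, amt): return (num + amt - 1) // amt * amt

n_train, BS = 1_281_167, 768              # real train-set size, example batch size
print(n_train // BS)                      # 1668 -> old: the trailing 143 images were dropped
print(round_up(n_train, BS) // BS)        # 1669 -> new: one extra step on a padded batch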
@@ -185,14 +185,13 @@ def train_resnet():
       MLLOGGER.start(key=mllog_constants.EPOCH_START, value=e+1, metadata=dict(epoch_num=e+1))
     Tensor.training = True
     BEAM.value = TRAIN_BEAM
-    batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e)
+    batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e, pad_first_batch=True)
     it = iter(tqdm(batch_loader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
     i, proc = 0, data_get(it)
     st = time.perf_counter()
     while proc is not None:
       GlobalCounters.reset()
-      # TODO: pad training data
-      (loss, top_1), _, proc = train_step(proc[0], proc[1]), proc[2], proc[3]
+      (loss, top_1), y, proc = train_step(proc[0], proc[1]), proc[2], proc[3]

       pt = time.perf_counter()
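pad_first_batch=True and the y that is now threaded through to the stats code rely on a padding convention the next hunk makes explicit: filler samples carry label -1. A hypothetical sketch of that convention (pad_file_list and labels_for are made-up names, not the dataloader's API; judging by the flag name, the fillers presumably land in the first batch):

# hypothetical illustration of the padding convention the training loop assumes;
# the real logic lives in the batch loader, these helpers are made up
import numpy as np

def pad_file_list(files, bs):
  n_pad = (-len(files)) % bs            # fillers needed to reach a multiple of bs
  return [None] * n_pad + files         # None -> loader emits a mean-filled image

def labels_for(files, label_lookup):
  # real files keep their class index, filler entries are marked with -1
  return np.array([-1 if f is None else label_lookup[f] for f in files], dtype=np.int32)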
@@ -204,7 +203,8 @@ def train_resnet():
       dt = time.perf_counter()

       device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
-      loss, top_1_acc = loss.numpy().item(), top_1.numpy().item() / BS
+      loss, top_1 = loss.numpy().item(), top_1.numpy().item()
+      top_1_acc = top_1 / sum(yi != -1 for yi in y)

       cl = time.perf_counter()
       if BENCHMARK:
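Dividing the correct-prediction count by BS would understate accuracy on the step that contains fillers, since a filler can never be predicted correctly; the new code divides by the number of real samples instead. A small worked example with made-up numbers:

# worked example of the new top-1 accuracy denominator (numbers are made up)
import numpy as np

BS = 8
y = np.array([3, 7, 1, 4, 9, -1, -1, -1])  # last three entries are padding
top_1 = 4                                   # correct predictions among the 5 real samples

print(top_1 / BS)                           # 0.50 -> old denominator, padding drags accuracy down
print(top_1 / sum(yi != -1 for yi in y))    # 0.80 -> new denominator counts only real samples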