add train dataloader for unet3d

This commit is contained in:
Francis Lata
2024-06-04 18:54:23 +00:00
parent 6678ee692e
commit c166d129df

View File

@@ -348,6 +348,7 @@ def train_retinanet():
def train_unet3d():
from examples.mlperf.losses import dice_ce_loss
from examples.mlperf.metrics import dice_score
from examples.mlperf.dataloader import batch_load_unet3d
from extra.models.unet3d import UNet3D
from extra.datasets.kits19 import iterate, get_train_files, get_val_files, sliding_window_inference, preprocess_dataset, BASEDIR
from tinygrad import Device, Tensor, GlobalCounters, Context
@@ -423,6 +424,10 @@ def train_unet3d():
print(f"saving checkpoint to {fn}")
safe_save(state_dict, fn)
def data_get(it):
x, y, cookie = next(it)
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
@TinyJit
@Tensor.train()
def train_step(model, x, y):
@@ -467,33 +472,49 @@ def train_unet3d():
if epoch <= LR_WARMUP_EPOCHS and LR_WARMUP_EPOCHS > 0:
lr_warm_up(optim, LR_WARMUP_INIT_LR, LR, epoch, LR_WARMUP_EPOCHS)
train_dataloader = batch_load_unet3d(PREPROCESSED_DIR / "train", batch_size=BS, val=False, shuffle=True, seed=SEED)
it = iter(tqdm(train_dataloader, total=SAMPLES_PER_EPOCH, desc=f"epoch {epoch}", disable=BENCHMARK))
i, proc = 0, data_get(it)
prev_cookies = []
st = time.perf_counter()
for i, (x, y) in enumerate(tqdm(iterate(get_train_files(), preprocessed_dir=PREPROCESSED_DIR, val=False, shuffle=True, bs=BS), total=SAMPLES_PER_EPOCH, desc=f"epoch {epoch}", disable=BENCHMARK), start=1):
dt = time.perf_counter()
while proc is not None:
GlobalCounters.reset()
x, y = Tensor(x).realize().shard(GPUS, axis=0), Tensor(y, requires_grad=False).shard(GPUS, axis=0)
loss, proc = train_step(model, proc[0], proc[1]), proc[2]
loss = train_step(model, x, y)
pt = time.perf_counter()
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
try:
next_proc = data_get(it)
except StopIteration:
next_proc = None
dt = time.perf_counter()
device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
loss = loss.numpy().item()
cl = time.perf_counter()
if BENCHMARK: step_times.append(cl - st)
tqdm.write(
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - st) * 1000.0:6.2f} ms fetch data, "
f"{loss:5.3f} loss, {optim.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optim.lr.numpy()[0]:.6f} LR, "
f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
)
if WANDB:
wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - st,
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": epoch + (i + 1) / SAMPLES_PER_EPOCH})
wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - pt,
"train/cl_time": cl - dt, "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": epoch + (i + 1) / SAMPLES_PER_EPOCH})
st = cl
prev_cookies.append(proc)
proc, next_proc = next_proc, None # return old cookie
i += 1
if i == BENCHMARK:
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds