mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add train dataloader for unet3d
This commit is contained in:
@@ -348,6 +348,7 @@ def train_retinanet():
|
||||
def train_unet3d():
|
||||
from examples.mlperf.losses import dice_ce_loss
|
||||
from examples.mlperf.metrics import dice_score
|
||||
from examples.mlperf.dataloader import batch_load_unet3d
|
||||
from extra.models.unet3d import UNet3D
|
||||
from extra.datasets.kits19 import iterate, get_train_files, get_val_files, sliding_window_inference, preprocess_dataset, BASEDIR
|
||||
from tinygrad import Device, Tensor, GlobalCounters, Context
|
||||
@@ -423,6 +424,10 @@ def train_unet3d():
|
||||
print(f"saving checkpoint to {fn}")
|
||||
safe_save(state_dict, fn)
|
||||
|
||||
def data_get(it):
|
||||
x, y, cookie = next(it)
|
||||
return x.shard(GPUS, axis=0).realize(), y.shard(GPUS, axis=0), cookie
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
def train_step(model, x, y):
|
||||
@@ -467,33 +472,49 @@ def train_unet3d():
|
||||
if epoch <= LR_WARMUP_EPOCHS and LR_WARMUP_EPOCHS > 0:
|
||||
lr_warm_up(optim, LR_WARMUP_INIT_LR, LR, epoch, LR_WARMUP_EPOCHS)
|
||||
|
||||
train_dataloader = batch_load_unet3d(PREPROCESSED_DIR / "train", batch_size=BS, val=False, shuffle=True, seed=SEED)
|
||||
it = iter(tqdm(train_dataloader, total=SAMPLES_PER_EPOCH, desc=f"epoch {epoch}", disable=BENCHMARK))
|
||||
i, proc = 0, data_get(it)
|
||||
|
||||
prev_cookies = []
|
||||
st = time.perf_counter()
|
||||
|
||||
for i, (x, y) in enumerate(tqdm(iterate(get_train_files(), preprocessed_dir=PREPROCESSED_DIR, val=False, shuffle=True, bs=BS), total=SAMPLES_PER_EPOCH, desc=f"epoch {epoch}", disable=BENCHMARK), start=1):
|
||||
dt = time.perf_counter()
|
||||
|
||||
while proc is not None:
|
||||
GlobalCounters.reset()
|
||||
|
||||
x, y = Tensor(x).realize().shard(GPUS, axis=0), Tensor(y, requires_grad=False).shard(GPUS, axis=0)
|
||||
loss, proc = train_step(model, proc[0], proc[1]), proc[2]
|
||||
|
||||
loss = train_step(model, x, y)
|
||||
pt = time.perf_counter()
|
||||
|
||||
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
|
||||
try:
|
||||
next_proc = data_get(it)
|
||||
except StopIteration:
|
||||
next_proc = None
|
||||
|
||||
dt = time.perf_counter()
|
||||
|
||||
device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
|
||||
loss = loss.numpy().item()
|
||||
|
||||
cl = time.perf_counter()
|
||||
|
||||
if BENCHMARK: step_times.append(cl - st)
|
||||
|
||||
tqdm.write(
|
||||
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - st) * 1000.0:6.2f} ms fetch data, "
|
||||
f"{loss:5.3f} loss, {optim.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
|
||||
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
|
||||
f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optim.lr.numpy()[0]:.6f} LR, "
|
||||
f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
|
||||
)
|
||||
|
||||
if WANDB:
|
||||
wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - st,
|
||||
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": epoch + (i + 1) / SAMPLES_PER_EPOCH})
|
||||
wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - pt,
|
||||
"train/cl_time": cl - dt, "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": epoch + (i + 1) / SAMPLES_PER_EPOCH})
|
||||
|
||||
st = cl
|
||||
prev_cookies.append(proc)
|
||||
proc, next_proc = next_proc, None # return old cookie
|
||||
i += 1
|
||||
|
||||
if i == BENCHMARK:
|
||||
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
|
||||
|
||||
Reference in New Issue
Block a user