Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
RetinaNet MLPerf (#8385)
* add support for a custom BASEDIR for openimages download
* make export step faster
* add focal loss
* update model_eval with new dataloader
* generate_anchors in tinygrad
* update initializers for model
* small cleanup
* revert isin enhancements
* recursively go through backbone layers to freeze them
* add optimizer
* minor cleanup
* start dataloader work with input images
* add first transform for train set
* reuse existing prepare_target
* continue with dataloader implementation
* add dataloader
* separate out KiTS19 dataset test cases
* create mock data samples for test
* add dataloader + test
* cleanup dataloader test and revert shm path
* trim dataloader related code needed from ref
* got dataloader with normalize working
* update image to be float32
* add back normalization and negate it in test
* clean up reference dataset implementation + ruff changes
* add validation set test
* add proper training loop over the training dataset
* add LambdaLR support
* add LR scheduler and the start of training step
* get forward call to model work and setup multi-GPU
* already passed device
* return matches from dataloader
* hotfix for dataloader typo causing some hang
* start some work on classification loss
* update focal loss to support masking
* add missing test and cleanup focal loss
* cleanup unit tests
* remove masking support for sigmoid_focal_loss
* make ClassificationHead loss work
* cleanups + fix dataloader tests
* remove sigmoid when computing loss
* make anchors use Tensors
* simplify anchors batching
* revert anchors to use np
* implement regression loss
* fix regression loss
* cleanup losses
* move BoxCoder to MLPerf helpers
* revert helper changes
* fixes after helper refactor cleanup
* add tests for l1_loss
* start re-enabling training step
* minor cleanup
* add pycocotools to testing dependencies
* make training work
* adjust regression loss to mask after L1 loss is calculated
* reduce img and lbl sizes by half for KiTS19 dataset tests
* Revert "reduce img and lbl sizes by half for KiTS19 dataset tests". This reverts commit d115b0c664.
* temporarily disable openimages dataset tests to debug CI
* enable openimages dataset test and create samples once
* temporarily disable openimages validation set test
* reenable test and add some debugging to the test
* add boto3 testing dependencies
* add pandas to testing dependencies
* This reverts commit 467704fec6.
* reenable test
* move sample creation to setup
* realize boxcoder's encoding
* add wandb
* fix wandb resuming feature
* move anchors as part of dataloader
* fix dtype for anchor inside dataloader and fix horizontal flip transformation
* add support for BENCHMARK
* set seed
* debug dataset test failure
* Revert "debug dataset test failure". This reverts commit 1b2f9d7f50.
* fix dataloader script
* do not realize when sharding model weights
* setup openimages samples differently
* create the necessary samples per test case
* enable lr scheduler and fix benchmark timing
* add jit to the training loop
* add checkpointing and training resume capabilities
* refactor on training loop and start the work on val loop
* add debug logging for dataloader test
* debug test
* assert boxes again
* update validation dataloader and more cleanups
* fix validation test case
* add multi device support to retinanet eval
* fix issue with realized on dataloader
* remove optional disk tensors in dataloader
* remove verbose debugging on datasets test
* put back parallel testing and remove img_ids Tensor from dataloader
* cleanup train and validation dataloader
* return validation targets in dataloader
* cleanup boxes and labels in dataloader
* fix img_ids repeating its values
* remove unnecessary targets from validation dataloader
* add validation loop to training script
* adjust LR to be the ratio of the batch size
* minor cleanups
* remove frozen layers from optimizer's params
* hyperparameter adjustments and cleanups
* model init, hyperparam, and data preprocessing updates
* no need to return loaded keys for resnet
* fix train script
* update loss calculation for regression head and some cleanups
* add JIT reset support
* add nan check during training
* Revert "add nan check during training". This reverts commit ddf1f0d5dd.
* Revert "Revert "add nan check during training"". This reverts commit b7b2943197.
* some typing cleanups
* update seeding on dataloader and the start of training script
* undo changes
* undo more changes
* more typing fixes
* minor cleanups
* update dataloader seed
* hotfix: log metric and move target metric check outside of CKPT
* check for CKPT when target metric is reached before saving
* add TRAIN_BEAM and EVAL_BEAM
* minor cleanup
* update hyperparams and add support for EVAL_BS
* add green coloring to metric reached statement
* initial work to support f16
* update model initializers to be monkeypatched
* update layers to support float32 weight loading + float16 training
* don't return loss that's scaled
* run eval on benchmark beam
* move BEAM to their respective steps
* update layers to be compatible with fp16
* end BENCHMARK after first eval
* cleanups and adjust learning rate for fp16
* remove duplicated files from test
* revert losses changes
* Revert "revert losses changes". This reverts commit aebccf93ac.
* go back to old LR
* cast batchnorm to float32
* set new loss scaler default value for float16
* remove LambdaLRScheduler
* remove runner and use dataloader on eval
* fix retinanet eval with new dataloader
* remove unused import
* revert lr_scheduler updates
* use BS=96 with new learning rate
* rename module initializers
* more cleanups on training loop
* remove contig from optim.step
* simplify sum when computing loss
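
Among the changes listed above, the focal loss is the central training-loss piece ("add focal loss", "remove sigmoid when computing loss"). For orientation, here is a minimal, numerically naive sketch of the standard sigmoid focal loss (Lin et al.) written with tinygrad tensors; the function name and the sum reduction are illustrative, not the PR's actual helper, and a production version would work on logits directly for stability:

from tinygrad import Tensor

def sigmoid_focal_loss_sketch(logits:Tensor, targets:Tensor, alpha:float=0.25, gamma:float=2.0) -> Tensor:
  p = logits.sigmoid()
  ce = -(targets * p.log() + (1 - targets) * (1 - p).log())  # per-element binary cross-entropy
  p_t = targets * p + (1 - targets) * (1 - p)                # probability assigned to the true class
  alpha_t = targets * alpha + (1 - targets) * (1 - alpha)
  return (alpha_t * (1 - p_t) ** gamma * ce).sum()           # down-weights easy, well-classified anchors
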
@@ -68,3 +68,62 @@ class LayerNormBert:
    xn = x.cast(dtypes.float32).layernorm(eps=self.eps, axis=self.axis).cast(x.dtype)
    if not self.elementwise_affine: return xn
    return (xn * self.weight.cast(dtypes.default_float) + self.bias.cast(dtypes.default_float))

class FrozenBatchNorm2dRetinaNet(nn.BatchNorm2d):
  def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
    self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum

    self.weight = Tensor.ones(sz, dtype=dtypes.float32, requires_grad=False) if affine else None
    self.bias = Tensor.zeros(sz, dtype=dtypes.float32, requires_grad=False) if affine else None

    if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, dtype=dtypes.float32, requires_grad=False), Tensor.ones(sz, dtype=dtypes.float32, requires_grad=False)
    self.num_batches_tracked = Tensor.zeros(1, dtype=dtypes.long, requires_grad=False)

  def __call__(self, x:Tensor) -> Tensor:
    batch_mean, batch_var = super().calc_stats(x.cast(dtypes.float32))
    if self.track_running_stats and Tensor.training:
      self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach().cast(self.running_mean.dtype))
      self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * x.numel()/(x.numel()-x.shape[1]) * batch_var.detach().cast(self.running_var.dtype))
      self.num_batches_tracked += 1
    return x.cast(dtypes.float32).batchnorm(self.weight, self.bias, batch_mean, batch_var.add(self.eps).rsqrt()).cast(x.dtype)
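
FrozenBatchNorm2dRetinaNet creates its affine parameters and running stats with requires_grad=False and computes statistics in float32 (the log's "cast batchnorm to float32"), so the backbone's BN layers neither train nor destabilize fp16 runs. A small check of the freezing behavior, assuming only the class above (the probe itself is illustrative):

bn = FrozenBatchNorm2dRetinaNet(64)
# tinygrad optimizers keep only tensors with requires_grad=True, so these
# parameters are skipped when the model's params are handed to Adam later on
assert bn.weight is not None and not bn.weight.requires_grad
assert bn.bias is not None and not bn.bias.requires_grad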

class Conv2dNormalRetinaNet(nn.Conv2d):
  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...],
               stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1,
               bias:bool=True, prior_prob:float|None=None):
    super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    self.weight = Tensor.normal(*self.weight.shape, std=0.01, dtype=dtypes.float32)
    if bias:
      if prior_prob:
        prior_prob = Tensor(prior_prob, device=self.bias.device, dtype=dtypes.float32).expand(*self.bias.shape)
        self.bias = -(((1 - prior_prob) / prior_prob).log())
      else: self.bias = Tensor.zeros_like(self.bias, dtype=dtypes.float32)

  def __call__(self, x:Tensor) -> Tensor:
    return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
                    groups=self.groups, stride=self.stride, padding=self.padding)
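
The prior_prob path reproduces the RetinaNet classification-head bias trick: initializing the bias to -log((1-p)/p) makes the head's initial sigmoid output roughly p, so early training is not swamped by the overwhelmingly many easy background anchors. A quick check of the arithmetic (plain Python, values illustrative):

import math
p = 0.01                          # prior_prob used for the classification head below
b = -math.log((1 - p) / p)        # ≈ -4.595, the initial bias value
assert abs(1 / (1 + math.exp(-b)) - p) < 1e-12  # sigmoid(b) recovers the prior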

class Conv2dKaimingUniformRetinaNet(nn.Conv2d):
  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...],
               stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1,
               bias:bool=True):
    super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    self.weight = Tensor.kaiming_uniform(*self.weight.shape, a=1, dtype=dtypes.float32)
    if bias: self.bias = Tensor.zeros_like(self.bias, dtype=dtypes.float32)

  def __call__(self, x:Tensor) -> Tensor:
    return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
                    groups=self.groups, stride=self.stride, padding=self.padding)
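
The a=1 here mirrors torchvision's FPN initialization (kaiming_uniform_ with a=1). For reference, the kaiming-uniform bound is sqrt(6 / ((1 + a^2) * fan_in)), so a=1 shrinks the usual a=0 bound by sqrt(2); the numbers below are illustrative:

import math
fan_in = 256 * 3 * 3                              # e.g. a 3x3 conv over 256 channels
bound_a1 = math.sqrt(6 / ((1 + 1**2) * fan_in))   # = sqrt(3 / fan_in), as used above
bound_a0 = math.sqrt(6 / fan_in)                  # default kaiming_uniform bound (a=0)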

class Conv2dRetinaNet(nn.Conv2d):
  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...],
               stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1,
               bias:bool=True):
    super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
    self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale, dtype=dtypes.float32)
    self.bias: Tensor|None = Tensor.uniform(out_channels, low=-scale, high=scale, dtype=dtypes.float32) if bias else None

  def __call__(self, x:Tensor) -> Tensor:
    return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
                    groups=self.groups, stride=self.stride, dilation=self.dilation, padding=self.padding)
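
This last variant matches PyTorch's default Conv2d initialization bounds: weight and bias drawn from U(-scale, scale) with scale = 1/sqrt(fan_in) and fan_in = in_channels * prod(kernel_size). A worked instance (illustrative values):

import math
in_channels, kernel_size = 64, (7, 7)
scale = 1 / math.sqrt(in_channels * math.prod(kernel_size))  # 1/56 ≈ 0.0179 here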

@@ -1,11 +1,11 @@
-import os, time, math, functools
+import os, time, math, functools, random
from pathlib import Path
import multiprocessing

from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
-from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
+from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam

from extra.lr_scheduler import LRSchedulerGroup
from examples.mlperf.helpers import get_training_state, load_training_state

@@ -343,8 +343,243 @@ def train_resnet():
      safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)

def train_retinanet():
-  # TODO: Retinanet
-  pass
  from contextlib import redirect_stdout
  from examples.mlperf.dataloader import batch_load_retinanet
  from examples.mlperf.initializers import FrozenBatchNorm2dRetinaNet, Conv2dNormalRetinaNet, Conv2dKaimingUniformRetinaNet, Linear, Conv2dRetinaNet
  from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset, normalize
  from extra.models import resnet
  from pycocotools.coco import COCO
  from pycocotools.cocoeval import COCOeval
  from tinygrad.helpers import colored, Context
  from typing import Iterator
  import extra.models.retinanet as retinanet

  import numpy as np

  config, target_metric = {}, 0.34

  NUM_CLASSES = len(MLPERF_CLASSES)
  BASE_DIR = getenv("BASE_DIR", BASEDIR)
  BENCHMARK = getenv("BENCHMARK")
  config["gpus"] = GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 6))]

  for x in GPUS: Device[x]
  print(f"training on {GPUS}")

  def _freeze_backbone_layers(backbone:resnet.ResNet, trainable_layers:int):
    layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
    for k, v in get_state_dict(backbone).items():
      if all([not k.startswith(layer) for layer in layers_to_train]):
        v.requires_grad = False
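
  # note: trainable_layers counts from the top of the ResNet, torchvision-style;
  # trainable_layers=3 (used below) keeps layer4/layer3/layer2 trainable and
  # freezes everything else (conv1, bn1, layer1)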

  def _data_get(it:Iterator[tuple[Tensor, ...]], val:bool=False):
    if val:
      x, img_ids, img_sizes, cookie = next(it)
      return x.shard(GPUS, axis=0).realize(), img_ids, img_sizes, cookie

    x, y_boxes, y_labels, matches, anchors, cookie = next(it)
    return x.shard(GPUS, axis=0).realize(), y_boxes.shard(GPUS, axis=0), y_labels.shard(GPUS, axis=0), matches.shard(GPUS, axis=0), anchors.shard(GPUS, axis=0), cookie
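
  # `cookie` is an opaque handle from the dataloader; the loops below hold the
  # previous batch's cookie until the next GPU work has been enqueued, so the
  # loader's shared-memory buffers aren't recycled while still in flight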

  @TinyJit
  def _train_step(model, optim, loss_scaler, x, **kwargs):
    with Context(BEAM=TRAIN_BEAM):
      optim.zero_grad()

      losses = model(normalize(x, GPUS), **kwargs)
      loss = sum([l for l in losses.values()])
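
      # static loss scaling for fp16: scale the loss up so small gradients don't
      # flush to zero in half precision, then divide the gradients back down
      # before the optimizer step; with float32 the scaler defaults to 1.0 (a no-op)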
      (loss * loss_scaler).backward()
      for t in optim.params: t.grad = t.grad / loss_scaler

      optim.step()

      return loss.realize(), losses

  @TinyJit
  def _eval_step(model, x, **kwargs):
    with Context(BEAM=EVAL_BEAM):
      out = model(normalize(x, GPUS), **kwargs)
      return out.realize()

  # ** hyperparameters **
  config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
  config["bs"] = BS = getenv("BS", 16 * len(GPUS) if dtypes.default_float == dtypes.float16 else 12 * len(GPUS))
  config["eval_bs"] = EVAL_BS = getenv("EVAL_BS", BS)
  config["epochs"] = EPOCHS = getenv("EPOCHS", 4)
  config["train_beam"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
  config["eval_beam"] = EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
  config["lr"] = lr = getenv("LR", 8.5e-5 * (BS / 96))
  config["loss_scaler"] = loss_scaler = getenv("LOSS_SCALER", 2**11 if dtypes.default_float == dtypes.float16 else 1.0)
  config["default_float"] = dtypes.default_float.name

  if SEED: Tensor.manual_seed(SEED)

  # ** model initializers **
  resnet.BatchNorm = FrozenBatchNorm2dRetinaNet
  resnet.Linear = Linear
  resnet.Conv2d = Conv2dRetinaNet

  retinanet.ConvHead = Conv2dNormalRetinaNet
  retinanet.ConvClassificationHeadLogits = functools.partial(Conv2dNormalRetinaNet, prior_prob=0.01)
  retinanet.ConvFPN = Conv2dKaimingUniformRetinaNet
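
  # the rebinding above monkeypatches the layer classes that the resnet and
  # retinanet modules look up at construction time, so every BatchNorm/Conv2d
  # the model builds below picks up the MLPerf-specific initializers without
  # touching the model code itself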

  # ** model setup **
  backbone = resnet.ResNeXt50_32X4D(num_classes=None)
  backbone.load_from_pretrained()
  _freeze_backbone_layers(backbone, 3)

  model = retinanet.RetinaNet(backbone, num_classes=NUM_CLASSES)
  params = get_parameters(model)

  for p in params: p.to_(GPUS)

  step_times, start_epoch = [], 0

  # ** optimizer **
  optim = Adam(params, lr=lr)

  # ** dataset **
  train_dataset = COCO(download_dataset(BASE_DIR, "train"))
  val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
  coco_val = COCOeval(cocoGt=val_dataset, iouType="bbox")

  # ** lr scheduler **
  config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), BS) // BS
  config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(len(val_dataset.imgs.keys()), BS) // BS)
  start_iter = start_epoch * steps_in_train_epoch

  # ** initialize wandb **
  if (WANDB:=getenv("WANDB")):
    import wandb
    wandb.init(config=config, project="MLPerf-RetinaNet")

  print(f"training with batch size {BS} for {EPOCHS} epochs")

  for e in range(start_epoch, EPOCHS):
    # ** training loop **
    train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=BS, seed=SEED)
    it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
    i, proc = 0, _data_get(it)

    prev_cookies = []
    st = time.perf_counter()

    while proc is not None:
      GlobalCounters.reset()

      x, y_bboxes, y_labels, matches, anchors, proc = proc
      loss, losses = _train_step(model, optim, loss_scaler, x, labels=y_labels, matches=matches, anchors=anchors, bboxes=y_bboxes)

      pt = time.perf_counter()

      if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
      try:
        next_proc = _data_get(it)
      except StopIteration:
        next_proc = None

      dt = time.perf_counter()

      device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
      loss = loss.item()

      cl = time.perf_counter()
      if BENCHMARK: step_times.append(cl - st)

      if not math.isfinite(loss):
        print("loss is nan")
        return

      tqdm.write(
        f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
        f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {losses['classification_loss'].item():5.4f} classification loss, {losses['regression_loss'].item():5.4f} regression loss, "
        f"{optim.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
      )

      if WANDB:
        wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/classification_loss": losses["classification_loss"].item(), "train/regression_loss": losses["regression_loss"].item(),
                   "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
                   "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": e + (i + 1) / steps_in_train_epoch})

      st = cl
      prev_cookies.append(proc)
      proc, next_proc = next_proc, None  # return old cookie
      i += 1

      if i == BENCHMARK:
        assert not math.isnan(loss)
        median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * EPOCHS / 60)
        print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
        print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
              f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
        # if we are doing beam search, run the first eval too
        if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
        return

    # ** eval loop **
    if getenv("RESET_STEP", 1): _train_step.reset()

    with Tensor.train(mode=False), Tensor.test():
      val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
      it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
      i, proc = 0, _data_get(it, val=val)

      eval_times, prev_cookies = [], []
      val_img_ids, val_imgs, ncats, narea = [], [], len(coco_val.params.catIds), len(coco_val.params.areaRng)

      while proc is not None:
        GlobalCounters.reset()
        st = time.time()

        out, img_ids, img_sizes, proc = _eval_step(model, (x:=proc[0])).numpy(), proc[1], proc[2], proc[3]
        out = model.postprocess_detections(out, input_size=x.shape[1:3], orig_image_sizes=img_sizes)
        coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
                        for i, prediction in enumerate(out) for box, score, label in zip(*prediction.values())]

        with redirect_stdout(None):
          coco_val.cocoDt = val_dataset.loadRes(coco_results)
          coco_val.params.imgIds = img_ids
          coco_val.evaluate()

        val_img_ids.extend(img_ids)
        val_imgs.append(np.array(coco_val.evalImgs).reshape(ncats, narea, len(img_ids)))
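
        # COCOeval.evalImgs is a flat list ordered (category, areaRng, image), so
        # each batch reshapes to (ncats, narea, n_imgs); the batches are then
        # concatenated along the image axis after the loop and handed back to
        # accumulate() as if the whole validation set had been evaluated at once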

        if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
        try:
          next_proc = _data_get(it, val=val)
        except StopIteration:
          next_proc = None

        prev_cookies.append(proc)
        proc, next_proc = next_proc, None
        i += 1

        if i == BENCHMARK:
          return

        et = time.time()
        eval_times.append(et - st)

      if getenv("RESET_STEP", 1): _eval_step.reset()
      total_fw_time = sum(eval_times) / len(eval_times)

      coco_val.params.imgIds = val_img_ids
      coco_val._paramsEval.imgIds = val_img_ids
      coco_val.evalImgs = list(np.concatenate(val_imgs, -1).flatten())
      coco_val.accumulate()
      coco_val.summarize()

      val_metric = coco_val.stats[0]

      tqdm.write(f"eval time: {total_fw_time:.2f}, eval metric: {val_metric:.4f}")

      if WANDB:
        wandb.log({"eval/forward_time": total_fw_time, "eval/metric": val_metric, "epoch": e + 1})

      if val_metric >= target_metric:
        print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}", color="green"))
        break

def train_unet3d():
  """