From 20d6d2119d752a8902004c797eb71b500a5fbe34 Mon Sep 17 00:00:00 2001 From: Francis Lata Date: Thu, 21 Nov 2024 03:55:54 -0800 Subject: [PATCH] get forward call to model work and setup multi-GPU --- examples/mlperf/model_train.py | 19 ++++++++++++------- extra/datasets/openimages.py | 16 +++++++--------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index ab22fbaef2..9e7958e736 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -343,7 +343,7 @@ def train_resnet(): def train_retinanet(): from examples.mlperf.dataloader import batch_load_retinanet from examples.mlperf.initializers import FrozenBatchNorm2d - from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset + from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset, normalize from extra.models.retinanet import RetinaNet from extra.models import resnet from extra.lr_scheduler import LambdaLR @@ -354,6 +354,9 @@ def train_retinanet(): NUM_CLASSES = len(MLPERF_CLASSES) BASE_DIR = getenv("BASE_DIR", BASEDIR) + GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))] + + for x in GPUS: Device[x] def _freeze_backbone_layers(backbone, trainable_layers, loaded_keys): model_layers = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers] @@ -365,7 +368,7 @@ def train_retinanet(): def _data_get(it): x, y_boxes, y_labels, cookie = next(it) - return x, y_boxes, y_labels, cookie + return x.shard(GPUS, axis=0).realize(), y_boxes, y_labels, cookie def _create_lr_scheduler(optim, start_iter, warmup_iters, warmup_factor): # TODO: refactor this a bit more so we don't have to recreate it, unlike what MLPerf script is doing @@ -376,11 +379,10 @@ def train_retinanet(): return warmup_factor * (1 - alpha) + alpha return LambdaLR(optim, _lr_lambda) - def _train_step(model, optim, lr_scheduler): + def _train_step(model, optim, lr_scheduler, x): optim.zero_grad() - optim.step() - lr_scheduler.step() + y_hat = model(normalize(x, GPUS)) # ** hyperparameters ** # using https://github.com/mlcommons/logging/blob/96d0acee011ba97702532dcc39e6eeaa99ebef24/mlperf_logging/rcp_checker/training_4.1.0/rcps_ssd.json#L3 @@ -400,9 +402,11 @@ def train_retinanet(): _freeze_backbone_layers(backbone, 3, loaded_keys) model = RetinaNet(backbone, num_classes=NUM_CLASSES) + params = get_parameters(model) + + for p in params: p.realize().to_(GPUS) # ** optimizer ** - params = get_parameters(model) optim = Adam(params, lr=LR) # ** dataset ** @@ -428,7 +432,8 @@ def train_retinanet(): st = time.perf_counter() while proc is not None: - # _train_step(model, optim, lr_scheduler) # TODO: enable once full model has been integrated + x, y_boxes, y_labels, proc = proc + _train_step(model, optim, lr_scheduler, x) # TODO: enable once full model has been integrated if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued try: diff --git a/extra/datasets/openimages.py b/extra/datasets/openimages.py index 17f6e02f7c..83b2dc79e9 100644 --- a/extra/datasets/openimages.py +++ b/extra/datasets/openimages.py @@ -4,9 +4,9 @@ import numpy as np from PIL import Image from pathlib import Path import boto3, botocore -from tinygrad import Tensor +from tinygrad import Tensor, dtypes from tinygrad.helpers import fetch, tqdm, getenv -from typing import Optional, Dict, Tuple, Union +from typing import Optional, Dict, Tuple, Union, List import pandas as pd import concurrent.futures @@ -206,13 +206,11 @@ def resize(img:Image, tgt:Optional[Dict[str, Union[np.ndarray, Tuple]]]=None, si return img, img_size -def normalize(img): - mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1) - std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1) - img = img.permute([0,3,1,2]) / 255.0 - img -= mean - img /= std - return img +def normalize(img:Tensor, device:List[str]): # TODO: pass device here + mean = Tensor([0.485, 0.456, 0.406], device=device, dtype=dtypes.float32).reshape(1, -1, 1, 1) + std = Tensor([0.229, 0.224, 0.225], device=device, dtype=dtypes.float32).reshape(1, -1, 1, 1) + img = ((img.permute([0, 3, 1, 2]) / 255.0) - mean) / std + return img.cast(dtypes.default_float) if __name__ == "__main__": download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "train")