get forward call to model working and set up multi-GPU

Francis Lata committed 2024-11-21 03:55:54 -08:00
parent e1bc499074
commit 20d6d2119d
2 changed files with 19 additions and 16 deletions

examples/mlperf/model_train.py

@@ -343,7 +343,7 @@ def train_resnet():
 def train_retinanet():
   from examples.mlperf.dataloader import batch_load_retinanet
   from examples.mlperf.initializers import FrozenBatchNorm2d
-  from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset
+  from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset, normalize
   from extra.models.retinanet import RetinaNet
   from extra.models import resnet
   from extra.lr_scheduler import LambdaLR
@@ -354,6 +354,9 @@ def train_retinanet():
   NUM_CLASSES = len(MLPERF_CLASSES)
   BASE_DIR = getenv("BASE_DIR", BASEDIR)
+  GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
+  for x in GPUS: Device[x]
+
   def _freeze_backbone_layers(backbone, trainable_layers, loaded_keys):
     model_layers = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
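
The two added lines build the device list for data-parallel training. A minimal standalone sketch of what they do, assuming two devices (the actual names depend on the backend behind Device.DEFAULT):

from tinygrad import Device

GPUS = [f"{Device.DEFAULT}:{i}" for i in range(2)]  # e.g. ["NV:0", "NV:1"] or ["GPU:0", "GPU:1"]
for x in GPUS: Device[x]  # indexing Device instantiates each device eagerly, so no init cost is paid mid-training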
@@ -365,7 +368,7 @@ def train_retinanet():
   def _data_get(it):
     x, y_boxes, y_labels, cookie = next(it)
-    return x, y_boxes, y_labels, cookie
+    return x.shard(GPUS, axis=0).realize(), y_boxes, y_labels, cookie

   def _create_lr_scheduler(optim, start_iter, warmup_iters, warmup_factor):
     # TODO: refactor this a bit more so we don't have to recreate it, unlike what MLPerf script is doing
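
For context, Tensor.shard is tinygrad's data-parallel primitive: with axis=0 the batch dimension is split evenly across the device list. A small self-contained sketch, assuming two devices:

from tinygrad import Tensor, Device

GPUS = [f"{Device.DEFAULT}:{i}" for i in range(2)]
x = Tensor.rand(8, 800, 800, 3)       # a full batch on the default device
xs = x.shard(GPUS, axis=0).realize()  # 4 samples per device; realize() enqueues the copies now
print(xs.device)                      # a tuple of device names rather than a single string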
@@ -376,11 +379,10 @@ def train_retinanet():
       return warmup_factor * (1 - alpha) + alpha
     return LambdaLR(optim, _lr_lambda)

-  def _train_step(model, optim, lr_scheduler):
+  def _train_step(model, optim, lr_scheduler, x):
     optim.zero_grad()
     optim.step()
     lr_scheduler.step()
+    y_hat = model(normalize(x, GPUS))

   # ** hyperparameters **
   # using https://github.com/mlcommons/logging/blob/96d0acee011ba97702532dcc39e6eeaa99ebef24/mlperf_logging/rcp_checker/training_4.1.0/rcps_ssd.json#L3
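
The warmup lambda visible in this hunk linearly interpolates the LR multiplier from warmup_factor up to 1 over warmup_iters steps. A standalone sketch with illustrative values (start_iter, warmup_iters, and warmup_factor below are placeholders, not the tuned MLPerf hyperparameters):

def lr_multiplier(step, start_iter=0, warmup_iters=1000, warmup_factor=1e-3):
  if step >= start_iter + warmup_iters: return 1.0  # warmup finished
  alpha = (step - start_iter) / warmup_iters        # goes 0 -> 1 over the warmup window
  return warmup_factor * (1 - alpha) + alpha

print(lr_multiplier(0))     # 0.001
print(lr_multiplier(500))   # ~0.5005
print(lr_multiplier(1000))  # 1.0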
@@ -400,9 +402,11 @@ def train_retinanet():
   _freeze_backbone_layers(backbone, 3, loaded_keys)

   model = RetinaNet(backbone, num_classes=NUM_CLASSES)
+  params = get_parameters(model)
+  for p in params: p.realize().to_(GPUS)

   # ** optimizer **
-  params = get_parameters(model)
   optim = Adam(params, lr=LR)

   # ** dataset **
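
The added loop replicates (rather than shards) every weight: to_ with a device list and no axis places a full copy on each device, which is what data parallelism needs. A toy sketch of the same pattern, with Linear standing in for RetinaNet purely for illustration:

from tinygrad import Device
from tinygrad.nn import Linear
from tinygrad.nn.state import get_parameters

GPUS = [f"{Device.DEFAULT}:{i}" for i in range(2)]
model = Linear(16, 4)                    # stand-in; the pattern is identical for RetinaNet
for p in get_parameters(model):
  p.realize().to_(GPUS)                  # realize first, then move the buffer onto all devices in place
print(get_parameters(model)[0].device)   # tuple: one replica per device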
@@ -428,7 +432,8 @@ def train_retinanet():
     st = time.perf_counter()
     while proc is not None:
-      # _train_step(model, optim, lr_scheduler)  # TODO: enable once full model has been integrated
       x, y_boxes, y_labels, proc = proc
+      _train_step(model, optim, lr_scheduler, x)  # TODO: enable once full model has been integrated
       if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
       try:
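
With the input sharded along the batch axis (_data_get above) and the weights replicated, the forward call inside _train_step runs on every device in parallel. A self-contained toy version of that interaction (Conv2d is a hypothetical stand-in for the model; normalization omitted):

from tinygrad import Tensor, Device
from tinygrad.nn import Conv2d

GPUS = [f"{Device.DEFAULT}:{i}" for i in range(2)]
model = Conv2d(3, 8, 3)
for p in [model.weight] + ([model.bias] if model.bias is not None else []):
  p.realize().to_(GPUS)                            # replicate the weights
x = Tensor.rand(4, 3, 64, 64).shard(GPUS, axis=0)  # 2 images per device
y_hat = model(x).realize()                         # each device computes its half of the batch
print(y_hat.shape, y_hat.device)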

extra/datasets/openimages.py

@@ -4,9 +4,9 @@ import numpy as np
 from PIL import Image
 from pathlib import Path
 import boto3, botocore
-from tinygrad import Tensor
+from tinygrad import Tensor, dtypes
 from tinygrad.helpers import fetch, tqdm, getenv
-from typing import Optional, Dict, Tuple, Union
+from typing import Optional, Dict, Tuple, Union, List
 import pandas as pd
 import concurrent.futures
@@ -206,13 +206,11 @@ def resize(img:Image, tgt:Optional[Dict[str, Union[np.ndarray, Tuple]]]=None, si
   return img, img_size

-def normalize(img):
-  mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
-  std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
-  img = img.permute([0,3,1,2]) / 255.0
-  img -= mean
-  img /= std
-  return img
+def normalize(img:Tensor, device:List[str]): # TODO: pass device here
+  mean = Tensor([0.485, 0.456, 0.406], device=device, dtype=dtypes.float32).reshape(1, -1, 1, 1)
+  std = Tensor([0.229, 0.224, 0.225], device=device, dtype=dtypes.float32).reshape(1, -1, 1, 1)
+  img = ((img.permute([0, 3, 1, 2]) / 255.0) - mean) / std
+  return img.cast(dtypes.default_float)

 if __name__ == "__main__":
   download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "train")
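
As a quick sanity check of the rewritten normalize: it takes an NHWC image batch in the 0-255 range, converts it to NCHW, and applies the standard ImageNet per-channel mean/std. A minimal usage sketch on the default device (despite the List[str] annotation, a single device string also satisfies the Tensor constructor's device argument):

from tinygrad import Device, Tensor
from extra.datasets.openimages import normalize

batch = Tensor.rand(2, 800, 800, 3) * 255      # fake NHWC image batch
out = normalize(batch, device=Device.DEFAULT)  # NCHW, mean-subtracted, std-divided
print(out.shape)                               # (2, 3, 800, 800)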