diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py index 7a69c8186f..990931f765 100644 --- a/examples/mlperf/dataloader.py +++ b/examples/mlperf/dataloader.py @@ -1,5 +1,5 @@ import os, random, pickle, queue -from typing import List +from typing import List, Tuple from pathlib import Path from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count @@ -356,34 +356,50 @@ def batch_load_unet3d(preprocessed_dataset_dir:Path, batch_size:int=6, val:bool= ### RetinaNet -def load_retinanet_data(base_dir:Path, queue_in:Queue, queue_out:Queue, X:Tensor): +def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue, X:Tensor, Y_boxes:Tensor, Y_labels:Tensor, anchors:np.ndarray): from extra.datasets.openimages import image_load, prepare_target, random_horizontal_flip, resize + from examples.mlperf.helpers import box_iou, find_matches + while (data:=queue_in.get()) is not None: idx, img, ann = data img_id = img["id"] - img = image_load(base_dir, "train", img["file_name"]) + img = image_load(base_dir, img["subset"], img["file_name"]) tgt = prepare_target(ann, img_id, img.size[::-1]) img, tgt = random_horizontal_flip(img, tgt) - img, _ = resize(img) + img, tgt, _ = resize(img, tgt=tgt) + match_quality_matrix = box_iou(tgt["boxes"], anchors) + matches = find_matches(match_quality_matrix, allow_low_quality_matches=True) + matches = np.clip(matches, 0, None) + boxes, labels = tgt["boxes"][matches], tgt["labels"][matches] X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes() + Y_boxes[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = boxes.tobytes() + Y_labels[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = labels.tobytes() queue_out.put(idx) queue_out.put(None) -def batch_load_retinanet(dataset, base_dir:Path, batch_size:int=32, seed:int=None): +def batch_load_retinanet(dataset, val:bool, anchors:np.ndarray, base_dir:Path, batch_size:int=32, seed:int=None): + def _enqueue_batch(bc): + for idx in range(bc * batch_size, (bc+1) * batch_size): + img = dataset.loadImgs(next(dataset_iter))[0] + ann = dataset.loadAnns(dataset.getAnnIds(img["id"])) + queue_in.put((idx, img, ann)) + + def _setup_shared_mem(shm_name:str, size:Tuple[int, ...], dtype:dtypes) -> Tuple[shared_memory.SharedMemory, Tensor]: + if os.path.exists(f"/dev/shm/{shm_name}"): os.unlink(f"/dev/shm/{shm_name}") + shm = shared_memory.SharedMemory(name=shm_name, create=True, size=prod(size)) + shm_tensor = Tensor.empty(*size, dtype=dtype, device=f"disk:/dev/shm/{shm_name}") + return shm, shm_tensor + image_ids = sorted(dataset.imgs.keys()) batch_count = min(32, len(image_ids) // batch_size) queue_in, queue_out = Queue(), Queue() procs, data_out_count = [], [0] * batch_count - shm_name_x = "retinanet_x" - # shm_name_x, shm_name_y = "retinanet_x", "retinanet_y" - sz = (batch_size * batch_count, 800, 800, 3) - if os.path.exists(f"/dev/shm/{shm_name_x}"): os.unlink(f"/dev/shm/{shm_name_x}") - # if os.path.exists(f"/dev/shm/{shm_name_y}"): os.unlink(f"/dev/shm/{shm_name_y}") - shm_x = shared_memory.SharedMemory(name=shm_name_x, create=True, size=prod(sz)) - # shm_y = shared_memory.SharedMemory(name=shm_name_y, create=True, size=prod(sz)) + shm_x, X = _setup_shared_mem("retinanet_x", (batch_size * batch_count, 800, 800, 3), dtypes.uint8) + shm_y_boxes, Y_boxes = _setup_shared_mem("retinanet_y_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32) + shm_y_labels, Y_labels = _setup_shared_mem("retinanet_y_labels", (batch_size * batch_count, 120087), dtypes.int64) shutdown = False class Cookie: @@ -391,15 +407,9 @@ def batch_load_retinanet(dataset, base_dir:Path, batch_size:int=32, seed:int=Non self.bc = bc def __del__(self): if not shutdown: - try: enqueue_batch(self.bc) + try: _enqueue_batch(self.bc) except StopIteration: pass - def enqueue_batch(bc): - for idx in range(bc * batch_size, (bc+1) * batch_size): - img = dataset.loadImgs(next(dataset_iter))[0] - ann = dataset.loadAnns(dataset.getAnnIds(img["id"])) - queue_in.put((idx, img, ann)) - # def shuffle_indices(file_indices, seed=None): # rng = random.Random(seed) # rng.shuffle(file_indices) @@ -408,17 +418,14 @@ def batch_load_retinanet(dataset, base_dir:Path, batch_size:int=32, seed:int=Non dataset_iter = iter(image_ids) try: - X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/{shm_name_x}") - # Y = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/{shm_name_y}") - for _ in range(cpu_count()): - proc = Process(target=load_retinanet_data, args=(base_dir, queue_in, queue_out, X)) + proc = Process(target=load_retinanet_data, args=(base_dir, val, queue_in, queue_out, X, Y_boxes, Y_labels, anchors)) proc.daemon = True proc.start() procs.append(proc) for bc in range(batch_count): - enqueue_batch(bc) + _enqueue_batch(bc) for _ in range(len(image_ids) // batch_size): while True: @@ -427,7 +434,7 @@ def batch_load_retinanet(dataset, base_dir:Path, batch_size:int=32, seed:int=Non if data_out_count[bc] == batch_size: break data_out_count[bc] = 0 - yield X[bc * batch_size:(bc + 1) * batch_size], Cookie(bc) + yield X[bc * batch_size:(bc + 1) * batch_size], Y_boxes[bc * batch_size:(bc + 1) * batch_size], Y_labels[bc * batch_size:(bc + 1) * batch_size], Cookie(bc) finally: shutdown = True @@ -442,10 +449,12 @@ def batch_load_retinanet(dataset, base_dir:Path, batch_size:int=32, seed:int=Non for proc in procs: proc.join() shm_x.close() - # shm_y.close() + shm_y_boxes.close() + shm_y_labels.close() try: shm_x.unlink() - # shm_y.unlink() + shm_y_boxes.unlink() + shm_y_labels.unlink() except FileNotFoundError: # happens with BENCHMARK set pass @@ -474,8 +483,9 @@ if __name__ == "__main__": from extra.datasets.openimages import BASEDIR, download_dataset from pycocotools.coco import COCO dataset = COCO(download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "validation" if val else "train")) + anchors = np.ones((120087, 4)) with tqdm(total=len(dataset.imgs.keys())) as pbar: - for x, _ in batch_load_retinanet(dataset, base_dir): + for x, _, _, _ in batch_load_retinanet(dataset, val, anchors, base_dir): pbar.update(x.shape[0]) load_fn_name = f"load_{getenv('MODEL', 'resnet')}" diff --git a/examples/mlperf/helpers.py b/examples/mlperf/helpers.py index d232edfd86..90e1199ad1 100644 --- a/examples/mlperf/helpers.py +++ b/examples/mlperf/helpers.py @@ -1,6 +1,6 @@ from collections import OrderedDict import unicodedata -from typing import Optional +from typing import Optional, Tuple import numpy as np from tinygrad.nn import state from tinygrad.tensor import Tensor, dtypes @@ -238,3 +238,40 @@ def get_fake_data_bert(GPUS:list[str], BS:int): "masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0), "next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0), } + +def find_matches(match_quality_matrix:np.ndarray, high_threshold:float=0.5, low_threshold:float=0.4, allow_low_quality_matches:bool=False) -> np.ndarray: + BELOW_LOW_THRESHOLD, BETWEEN_THRESHOLDS = -1, -2 + + def _set_low_quality_matches_(matches:np.ndarray, all_matches:np.ndarray, match_quality_matrix:np.ndarray): + highest_quality_foreach_gt = np.max(match_quality_matrix, axis=1) + pred_inds_to_update = np.nonzero(match_quality_matrix == highest_quality_foreach_gt[:, None])[1] + matches[pred_inds_to_update] = all_matches[pred_inds_to_update] + + assert low_threshold <= high_threshold + + matched_vals, matches = match_quality_matrix.max(axis=0), match_quality_matrix.argmax(axis=0) + all_matches = np.copy(matches) if allow_low_quality_matches else None + below_low_threshold = matched_vals < low_threshold + between_thresholds = (matched_vals >= low_threshold) & (matched_vals < high_threshold) + matches[below_low_threshold] = BELOW_LOW_THRESHOLD + matches[between_thresholds] = BETWEEN_THRESHOLDS + + if allow_low_quality_matches: + assert all_matches is not None + _set_low_quality_matches_(matches, all_matches, match_quality_matrix) + + return matches + +def box_iou(boxes1:np.ndarray, boxes2:np.ndarray) -> np.ndarray: + def _box_area(boxes:np.ndarray) -> np.ndarray: return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + def _box_inter_union(boxes1:np.ndarray, boxes2:np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + area1, area2 = _box_area(boxes1), _box_area(boxes2) + lt, rb = np.maximum(boxes1[:, None, :2], boxes2[:, :2]), np.minimum(boxes1[:, None, 2:], boxes2[:, 2:]) + wh = np.clip(rb - lt, a_min=0, a_max=None) + inter = wh[:, :, 0] * wh[:, :, 1] + union = area1[:, None] + area2 - inter + return inter, union + + inter, union = _box_inter_union(boxes1, boxes2) + return inter / union diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index a4d714fbec..e3b4b2fa0b 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -349,6 +349,8 @@ def train_retinanet(): from extra.models import resnet from pycocotools.coco import COCO from tinygrad.helpers import get_child + + import numpy as np NUM_CLASSES = len(MLPERF_CLASSES) BASE_DIR = getenv("BASE_DIR", BASEDIR) @@ -380,14 +382,16 @@ def train_retinanet(): optim = Adam(params, lr=LR) # ** dataset ** + anchors = np.concatenate(model.anchor_gen((800, 800)), axis=0) + train_dataset = COCO(download_dataset(BASE_DIR, "train")) val_dataset = COCO(download_dataset(BASE_DIR, "validation")) - train_dataloader = batch_load_retinanet(train_dataset, Path(BASE_DIR), batch_size=256) + train_dataloader = batch_load_retinanet(train_dataset, False, anchors, Path(BASE_DIR), batch_size=256) # ** training loop ** with tqdm(total=len(train_dataset.imgs.keys())) as pbar: - for x, _ in train_dataloader: + for x, _, _, _ in train_dataloader: pbar.update(x.shape[0]) def train_unet3d(): diff --git a/extra/datasets/openimages.py b/extra/datasets/openimages.py index c86c162992..585939f2f0 100644 --- a/extra/datasets/openimages.py +++ b/extra/datasets/openimages.py @@ -6,6 +6,7 @@ from pathlib import Path import boto3, botocore from tinygrad import Tensor from tinygrad.helpers import fetch, tqdm, getenv +from typing import Optional, Dict, Tuple, Union import pandas as pd import concurrent.futures @@ -185,11 +186,24 @@ def random_horizontal_flip(img, tgt, prob=0.5): tgt["boxes"][:, [0, 2]] = w - tgt["boxes"][:, [2, 0]] return img, tgt -def resize(img, size=(800, 800)): +def resize(img:Image, tgt:Optional[Dict[str, Union[np.ndarray, Tuple]]]=None, size:Tuple[int, int]=(800, 800)) -> Union[Tuple[np.ndarray, np.ndarray, Tuple], Tuple[np.ndarray, Tuple]]: import torchvision.transforms.functional as F img_size = img.size[::-1] img = F.resize(img, size=size) img = np.array(img) + + if tgt is not None: + ratios = [s / s_orig for s, s_orig in zip(size, img.shape[::-1])] + ratio_w, ratio_h = ratios + x_min, y_min, x_max, y_max = [tgt["boxes"][:, i] for i in range(tgt["boxes"].shape[-1])] + x_min = x_min * ratio_w + x_max = x_max * ratio_w + y_min = y_min * ratio_h + y_max = y_max * ratio_h + + tgt["boxes"] = np.stack([x_min, y_min, x_max, y_max], axis=1) + return img, tgt, img_size + return img, img_size def normalize(img): diff --git a/extra/models/retinanet.py b/extra/models/retinanet.py index ab67b08d4d..415dd5de25 100644 --- a/extra/models/retinanet.py +++ b/extra/models/retinanet.py @@ -21,18 +21,18 @@ def generate_anchors(input_size, grid_sizes, scales, aspect_ratios): assert len(scales) == len(aspect_ratios) == len(grid_sizes) anchors = [] for s, ar, gs in zip(scales, aspect_ratios, grid_sizes): - s, ar = Tensor(s), Tensor(ar) - h_ratios = ar.sqrt() + s, ar = np.array(s), np.array(ar) + h_ratios = np.sqrt(ar) w_ratios = 1 / h_ratios ws = (w_ratios[:, None] * s[None, :]).reshape(-1) hs = (h_ratios[:, None] * s[None, :]).reshape(-1) - base_anchors = (Tensor.stack(-ws, -hs, ws, hs, dim=1) / 2).round() + base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round() stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1] - shifts_y, shifts_x = meshgrid(Tensor.arange(0, stop=gs[0], dtype=dtypes.float32) * stride_h, Tensor.arange(0, stop=gs[1], dtype=dtypes.float32) * stride_w) + shifts_x, shifts_y = np.meshgrid(np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h) shifts_x = shifts_x.reshape(-1) shifts_y = shifts_y.reshape(-1) - shifts = Tensor.stack(shifts_x, shifts_y, shifts_x, shifts_y, dim=1) - anchors.append((shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4)) + shifts = np.stack([shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32) + anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4)) return anchors class RetinaNet: @@ -96,7 +96,7 @@ class RetinaNet: # bbox coords from offsets anchor_idxs = topk_idxs // self.num_classes labels_per_level = topk_idxs % self.num_classes - boxes_per_level = decode_bbox(offsets_per_level[anchor_idxs], anchors_per_level.numpy()[anchor_idxs]) # TODO: remove numpy conversion + boxes_per_level = decode_bbox(offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs]) # clip to image size clipped_x = boxes_per_level[:, 0::2].clip(0, w) clipped_y = boxes_per_level[:, 1::2].clip(0, h)