add dataloader

This commit is contained in:
Francis Lata
2024-10-11 23:04:21 -04:00
parent 5dbebf460e
commit b802f74cee
5 changed files with 104 additions and 39 deletions

View File

@@ -1,5 +1,5 @@
import os, random, pickle, queue
from typing import List
from typing import List, Tuple
from pathlib import Path
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count
@@ -356,34 +356,50 @@ def batch_load_unet3d(preprocessed_dataset_dir:Path, batch_size:int=6, val:bool=
### RetinaNet
def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue, X:Tensor, Y_boxes:Tensor, Y_labels:Tensor, anchors:np.ndarray):
  """
  Worker loop of the RetinaNet dataloader.

  Pulls `(idx, img, ann)` work items from `queue_in` until a `None` sentinel arrives. For every item the image is
  loaded and resized, its targets are matched against `anchors`, and the raw bytes of the image, per-anchor boxes
  and per-anchor labels are written into slot `idx` of `X`, `Y_boxes` and `Y_labels`. `idx` is then reported on
  `queue_out`; a final `None` is put on `queue_out` when the worker stops.

  `val` disables the random horizontal flip so validation samples are not augmented.
  """
  from extra.datasets.openimages import image_load, prepare_target, random_horizontal_flip, resize
  from examples.mlperf.helpers import box_iou, find_matches

  def _write(dst:Tensor, payload:bytes):
    # copy the payload straight into the buffer backing `dst`
    dst.contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = payload

  while (data:=queue_in.get()) is not None:
    idx, img, ann = data
    img_id = img["id"]
    img = image_load(base_dir, img["subset"], img["file_name"])
    tgt = prepare_target(ann, img_id, img.size[::-1])
    # augmentation is a training-only step
    if not val: img, tgt = random_horizontal_flip(img, tgt)
    img, tgt, _ = resize(img, tgt=tgt)

    match_quality_matrix = box_iou(tgt["boxes"], anchors)
    matches = find_matches(match_quality_matrix, allow_low_quality_matches=True)
    # find_matches marks unmatched anchors with negative codes; clipping maps them to ground-truth 0 so the gather
    # below stays in range. NOTE(review): the "unmatched" information is not forwarded to the consumer.
    matches = np.clip(matches, 0, None)
    boxes, labels = tgt["boxes"][matches], tgt["labels"][matches]

    _write(X[idx], img.tobytes())
    # the destination tensors are created as float32 / int64 by batch_load_retinanet, so the byte counts must agree
    _write(Y_boxes[idx], boxes.astype(np.float32, copy=False).tobytes())
    _write(Y_labels[idx], labels.astype(np.int64, copy=False).tobytes())

    queue_out.put(idx)
  queue_out.put(None)
def batch_load_retinanet(dataset, base_dir:Path, batch_size:int=32, seed:int=None):
def batch_load_retinanet(dataset, val:bool, anchors:np.ndarray, base_dir:Path, batch_size:int=32, seed:int=None):
def _enqueue_batch(bc):
  # Queue one work item per slot of batch `bc`: the next image from `dataset_iter` together with its annotations.
  # StopIteration from the exhausted iterator is left to propagate to the caller.
  first_slot = bc * batch_size
  for offset in range(batch_size):
    img = dataset.loadImgs(next(dataset_iter))[0]
    ann_ids = dataset.getAnnIds(img["id"])
    ann = dataset.loadAnns(ann_ids)
    queue_in.put((first_slot + offset, img, ann))
def _setup_shared_mem(shm_name:str, size:Tuple[int, ...], dtype:dtypes) -> Tuple[shared_memory.SharedMemory, Tensor]:
  # Create a fresh shared-memory segment called `shm_name` and a disk-backed Tensor of shape `size` / `dtype` on top
  # of it. Returns both so the caller can close/unlink the segment when done.
  # drop a stale segment left behind by a previous run (EAFP: no exists/unlink race)
  try: os.unlink(f"/dev/shm/{shm_name}")
  except FileNotFoundError: pass
  # SharedMemory sizes are in bytes, so scale the element count by the dtype width
  shm = shared_memory.SharedMemory(name=shm_name, create=True, size=prod(size) * dtype.itemsize)
  shm_tensor = Tensor.empty(*size, dtype=dtype, device=f"disk:/dev/shm/{shm_name}")
  return shm, shm_tensor
image_ids = sorted(dataset.imgs.keys())
batch_count = min(32, len(image_ids) // batch_size)
queue_in, queue_out = Queue(), Queue()
procs, data_out_count = [], [0] * batch_count
shm_name_x = "retinanet_x"
# shm_name_x, shm_name_y = "retinanet_x", "retinanet_y"
sz = (batch_size * batch_count, 800, 800, 3)
if os.path.exists(f"/dev/shm/{shm_name_x}"): os.unlink(f"/dev/shm/{shm_name_x}")
# if os.path.exists(f"/dev/shm/{shm_name_y}"): os.unlink(f"/dev/shm/{shm_name_y}")
shm_x = shared_memory.SharedMemory(name=shm_name_x, create=True, size=prod(sz))
# shm_y = shared_memory.SharedMemory(name=shm_name_y, create=True, size=prod(sz))
shm_x, X = _setup_shared_mem("retinanet_x", (batch_size * batch_count, 800, 800, 3), dtypes.uint8)
shm_y_boxes, Y_boxes = _setup_shared_mem("retinanet_y_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32)
shm_y_labels, Y_labels = _setup_shared_mem("retinanet_y_labels", (batch_size * batch_count, 120087), dtypes.int64)
shutdown = False
class Cookie:
@@ -391,15 +407,9 @@ def batch_load_retinanet(dataset, base_dir:Path, batch_size:int=32, seed:int=Non
self.bc = bc
def __del__(self):
if not shutdown:
try: enqueue_batch(self.bc)
try: _enqueue_batch(self.bc)
except StopIteration: pass
def enqueue_batch(bc):
for idx in range(bc * batch_size, (bc+1) * batch_size):
img = dataset.loadImgs(next(dataset_iter))[0]
ann = dataset.loadAnns(dataset.getAnnIds(img["id"]))
queue_in.put((idx, img, ann))
# def shuffle_indices(file_indices, seed=None):
# rng = random.Random(seed)
# rng.shuffle(file_indices)
@@ -408,17 +418,14 @@ def batch_load_retinanet(dataset, base_dir:Path, batch_size:int=32, seed:int=Non
dataset_iter = iter(image_ids)
try:
X = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/{shm_name_x}")
# Y = Tensor.empty(*sz, dtype=dtypes.uint8, device=f"disk:/dev/shm/{shm_name_y}")
for _ in range(cpu_count()):
proc = Process(target=load_retinanet_data, args=(base_dir, queue_in, queue_out, X))
proc = Process(target=load_retinanet_data, args=(base_dir, val, queue_in, queue_out, X, Y_boxes, Y_labels, anchors))
proc.daemon = True
proc.start()
procs.append(proc)
for bc in range(batch_count):
enqueue_batch(bc)
_enqueue_batch(bc)
for _ in range(len(image_ids) // batch_size):
while True:
@@ -427,7 +434,7 @@ def batch_load_retinanet(dataset, base_dir:Path, batch_size:int=32, seed:int=Non
if data_out_count[bc] == batch_size: break
data_out_count[bc] = 0
yield X[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)
yield X[bc * batch_size:(bc + 1) * batch_size], Y_boxes[bc * batch_size:(bc + 1) * batch_size], Y_labels[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)
finally:
shutdown = True
@@ -442,10 +449,12 @@ def batch_load_retinanet(dataset, base_dir:Path, batch_size:int=32, seed:int=Non
for proc in procs: proc.join()
shm_x.close()
# shm_y.close()
shm_y_boxes.close()
shm_y_labels.close()
try:
shm_x.unlink()
# shm_y.unlink()
shm_y_boxes.unlink()
shm_y_labels.unlink()
except FileNotFoundError:
# happens with BENCHMARK set
pass
@@ -474,8 +483,9 @@ if __name__ == "__main__":
from extra.datasets.openimages import BASEDIR, download_dataset
from pycocotools.coco import COCO
dataset = COCO(download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "validation" if val else "train"))
anchors = np.ones((120087, 4))
with tqdm(total=len(dataset.imgs.keys())) as pbar:
for x, _ in batch_load_retinanet(dataset, base_dir):
for x, _, _, _ in batch_load_retinanet(dataset, val, anchors, base_dir):
pbar.update(x.shape[0])
load_fn_name = f"load_{getenv('MODEL', 'resnet')}"

View File

@@ -1,6 +1,6 @@
from collections import OrderedDict
import unicodedata
from typing import Optional
from typing import Optional, Tuple
import numpy as np
from tinygrad.nn import state
from tinygrad.tensor import Tensor, dtypes
@@ -238,3 +238,40 @@ def get_fake_data_bert(GPUS:list[str], BS:int):
"masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
"next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
}
def find_matches(match_quality_matrix:np.ndarray, high_threshold:float=0.5, low_threshold:float=0.4, allow_low_quality_matches:bool=False) -> np.ndarray:
  """
  Assign every prediction (column) the index of its best ground-truth element (row).

  `match_quality_matrix` has one row per ground-truth element and one column per prediction. The result has one
  entry per column:
    - the row index of the best match when its quality is >= `high_threshold`
    - -2 when the best quality lies in [`low_threshold`, `high_threshold`)
    - -1 when the best quality is below `low_threshold`
  With `allow_low_quality_matches`, every column that attains a row's maximum quality gets its best row index back,
  regardless of the thresholds.

  Raises ValueError when the matrix has no ground-truth rows.
  """
  BELOW_LOW_THRESHOLD, BETWEEN_THRESHOLDS = -1, -2

  assert low_threshold <= high_threshold
  if match_quality_matrix.shape[0] == 0:
    # numpy would fail on the reduction below with an opaque "zero-size array" message
    raise ValueError("find_matches requires at least one ground-truth row in match_quality_matrix")

  matched_vals, best_rows = match_quality_matrix.max(axis=0), match_quality_matrix.argmax(axis=0)

  matches = best_rows.copy()
  matches[matched_vals < low_threshold] = BELOW_LOW_THRESHOLD
  matches[(matched_vals >= low_threshold) & (matched_vals < high_threshold)] = BETWEEN_THRESHOLDS

  if allow_low_quality_matches:
    # columns that reach the maximum quality of some row (ties included) are restored to their best row
    best_per_row = match_quality_matrix.max(axis=1, keepdims=True)
    restore = np.nonzero(match_quality_matrix == best_per_row)[1]
    matches[restore] = best_rows[restore]

  return matches
def box_iou(boxes1:np.ndarray, boxes2:np.ndarray) -> np.ndarray:
  """
  Pairwise intersection-over-union between two sets of boxes.

  Both inputs are indexed as `[:, 0..3]` = (x1, y1, x2, y2). The result has shape `(len(boxes1), len(boxes2))`.
  Pairs whose union is zero (both boxes degenerate) get an IoU of 0 instead of NaN.
  """
  def _box_area(boxes:np.ndarray) -> np.ndarray: return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

  area1, area2 = _box_area(boxes1), _box_area(boxes2)
  # broadcast boxes1 against boxes2 to get the overlap rectangle of every pair
  top_left = np.maximum(boxes1[:, None, :2], boxes2[:, :2])
  bottom_right = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:])
  wh = np.clip(bottom_right - top_left, a_min=0, a_max=None)  # non-overlapping pairs clamp to zero extent
  inter = wh[:, :, 0] * wh[:, :, 1]
  union = area1[:, None] + area2 - inter

  # 0/0 for degenerate pairs would produce NaN and a RuntimeWarning; define their IoU as 0
  with np.errstate(divide="ignore", invalid="ignore"): iou = inter / union
  iou[union == 0] = 0
  return iou

View File

@@ -349,6 +349,8 @@ def train_retinanet():
from extra.models import resnet
from pycocotools.coco import COCO
from tinygrad.helpers import get_child
import numpy as np
NUM_CLASSES = len(MLPERF_CLASSES)
BASE_DIR = getenv("BASE_DIR", BASEDIR)
@@ -380,14 +382,16 @@ def train_retinanet():
optim = Adam(params, lr=LR)
# ** dataset **
anchors = np.concatenate(model.anchor_gen((800, 800)), axis=0)
train_dataset = COCO(download_dataset(BASE_DIR, "train"))
val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
train_dataloader = batch_load_retinanet(train_dataset, Path(BASE_DIR), batch_size=256)
train_dataloader = batch_load_retinanet(train_dataset, False, anchors, Path(BASE_DIR), batch_size=256)
# ** training loop **
with tqdm(total=len(train_dataset.imgs.keys())) as pbar:
for x, _ in train_dataloader:
for x, _, _, _ in train_dataloader:
pbar.update(x.shape[0])
def train_unet3d():

View File

@@ -6,6 +6,7 @@ from pathlib import Path
import boto3, botocore
from tinygrad import Tensor
from tinygrad.helpers import fetch, tqdm, getenv
from typing import Optional, Dict, Tuple, Union
import pandas as pd
import concurrent.futures
@@ -185,11 +186,24 @@ def random_horizontal_flip(img, tgt, prob=0.5):
tgt["boxes"][:, [0, 2]] = w - tgt["boxes"][:, [2, 0]]
return img, tgt
def resize(img:Image, tgt:Optional[Dict[str, Union[np.ndarray, Tuple]]]=None, size:Tuple[int, int]=(800, 800)) -> Union[Tuple[np.ndarray, np.ndarray, Tuple], Tuple[np.ndarray, Tuple]]:
  """
  Resize `img` to `size` and, when `tgt` is given, rescale `tgt["boxes"]` to the new resolution.

  Returns `(image_array, original_size)` without `tgt`, or `(image_array, tgt, original_size)` with it.
  `original_size` is `img.size` reversed. `tgt` is updated in place and the same dict is returned.
  Boxes are read as columns (x_min, y_min, x_max, y_max).
  """
  import torchvision.transforms.functional as F
  img_size = img.size[::-1]  # NOTE(review): assumes a PIL image, whose size is (width, height) -> stored as (height, width)
  img = np.array(F.resize(img, size=size))
  if tgt is None: return img, img_size

  # The scale factors must come from the ORIGINAL resolution. The resized array's shape (which also carries the
  # channel axis) cannot be used for this. Assumes `size` is ordered (height, width) like `img_size` -- TODO confirm.
  ratio_h, ratio_w = [s / s_orig for s, s_orig in zip(size, img_size)]
  boxes = tgt["boxes"]
  tgt["boxes"] = np.stack([boxes[:, 0] * ratio_w, boxes[:, 1] * ratio_h, boxes[:, 2] * ratio_w, boxes[:, 3] * ratio_h], axis=1)
  return img, tgt, img_size
def normalize(img):

View File

@@ -21,18 +21,18 @@ def generate_anchors(input_size, grid_sizes, scales, aspect_ratios):
assert len(scales) == len(aspect_ratios) == len(grid_sizes)
anchors = []
for s, ar, gs in zip(scales, aspect_ratios, grid_sizes):
s, ar = Tensor(s), Tensor(ar)
h_ratios = ar.sqrt()
s, ar = np.array(s), np.array(ar)
h_ratios = np.sqrt(ar)
w_ratios = 1 / h_ratios
ws = (w_ratios[:, None] * s[None, :]).reshape(-1)
hs = (h_ratios[:, None] * s[None, :]).reshape(-1)
base_anchors = (Tensor.stack(-ws, -hs, ws, hs, dim=1) / 2).round()
base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round()
stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1]
shifts_y, shifts_x = meshgrid(Tensor.arange(0, stop=gs[0], dtype=dtypes.float32) * stride_h, Tensor.arange(0, stop=gs[1], dtype=dtypes.float32) * stride_w)
shifts_x, shifts_y = np.meshgrid(np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h)
shifts_x = shifts_x.reshape(-1)
shifts_y = shifts_y.reshape(-1)
shifts = Tensor.stack(shifts_x, shifts_y, shifts_x, shifts_y, dim=1)
anchors.append((shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4))
shifts = np.stack([shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32)
anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4))
return anchors
class RetinaNet:
@@ -96,7 +96,7 @@ class RetinaNet:
# bbox coords from offsets
anchor_idxs = topk_idxs // self.num_classes
labels_per_level = topk_idxs % self.num_classes
boxes_per_level = decode_bbox(offsets_per_level[anchor_idxs], anchors_per_level.numpy()[anchor_idxs]) # TODO: remove numpy conversion
boxes_per_level = decode_bbox(offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs])
# clip to image size
clipped_x = boxes_per_level[:, 0::2].clip(0, w)
clipped_y = boxes_per_level[:, 1::2].clip(0, h)