diff --git a/datasets/openimages.py b/datasets/openimages.py new file mode 100644 index 0000000000..11aee4d0cb --- /dev/null +++ b/datasets/openimages.py @@ -0,0 +1,165 @@ +import os +import math +import json +from extra.utils import OSX +import numpy as np +from PIL import Image +import pathlib +import boto3, botocore +from extra.utils import download_file +from tqdm import tqdm +import pandas as pd +import concurrent.futures + +BASEDIR = pathlib.Path(__file__).parent.parent / "datasets/open-images-v6-mlperf" +BUCKET_NAME = "open-images-dataset" +BBOX_ANNOTATIONS_URL = "https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv" +MAP_CLASSES_URL = "https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv" +MLPERF_CLASSES = ['Airplane', 'Antelope', 'Apple', 'Backpack', 'Balloon', 'Banana', + 'Barrel', 'Baseball bat', 'Baseball glove', 'Bee', 'Beer', 'Bench', 'Bicycle', + 'Bicycle helmet', 'Bicycle wheel', 'Billboard', 'Book', 'Bookcase', 'Boot', + 'Bottle', 'Bowl', 'Bowling equipment', 'Box', 'Boy', 'Brassiere', 'Bread', + 'Broccoli', 'Bronze sculpture', 'Bull', 'Bus', 'Bust', 'Butterfly', 'Cabinetry', + 'Cake', 'Camel', 'Camera', 'Candle', 'Candy', 'Cannon', 'Canoe', 'Carrot', 'Cart', + 'Castle', 'Cat', 'Cattle', 'Cello', 'Chair', 'Cheese', 'Chest of drawers', 'Chicken', + 'Christmas tree', 'Coat', 'Cocktail', 'Coffee', 'Coffee cup', 'Coffee table', 'Coin', + 'Common sunflower', 'Computer keyboard', 'Computer monitor', 'Convenience store', + 'Cookie', 'Countertop', 'Cowboy hat', 'Crab', 'Crocodile', 'Cucumber', 'Cupboard', + 'Curtain', 'Deer', 'Desk', 'Dinosaur', 'Dog', 'Doll', 'Dolphin', 'Door', 'Dragonfly', + 'Drawer', 'Dress', 'Drum', 'Duck', 'Eagle', 'Earrings', 'Egg (Food)', 'Elephant', + 'Falcon', 'Fedora', 'Flag', 'Flowerpot', 'Football', 'Football helmet', 'Fork', + 'Fountain', 'French fries', 'French horn', 'Frog', 'Giraffe', 'Girl', 'Glasses', + 'Goat', 'Goggles', 'Goldfish', 'Gondola', 'Goose', 'Grape', 'Grapefruit', 'Guitar', + 'Hamburger', 'Handbag', 'Harbor seal', 'Headphones', 'Helicopter', 'High heels', + 'Hiking equipment', 'Horse', 'House', 'Houseplant', 'Human arm', 'Human beard', + 'Human body', 'Human ear', 'Human eye', 'Human face', 'Human foot', 'Human hair', + 'Human hand', 'Human head', 'Human leg', 'Human mouth', 'Human nose', 'Ice cream', + 'Jacket', 'Jeans', 'Jellyfish', 'Juice', 'Kitchen & dining room table', 'Kite', + 'Lamp', 'Lantern', 'Laptop', 'Lavender (Plant)', 'Lemon', 'Light bulb', 'Lighthouse', + 'Lily', 'Lion', 'Lipstick', 'Lizard', 'Man', 'Maple', 'Microphone', 'Mirror', + 'Mixing bowl', 'Mobile phone', 'Monkey', 'Motorcycle', 'Muffin', 'Mug', 'Mule', + 'Mushroom', 'Musical keyboard', 'Necklace', 'Nightstand', 'Office building', + 'Orange', 'Owl', 'Oyster', 'Paddle', 'Palm tree', 'Parachute', 'Parrot', 'Pen', + 'Penguin', 'Personal flotation device', 'Piano', 'Picture frame', 'Pig', 'Pillow', + 'Pizza', 'Plate', 'Platter', 'Porch', 'Poster', 'Pumpkin', 'Rabbit', 'Rifle', + 'Roller skates', 'Rose', 'Salad', 'Sandal', 'Saucer', 'Saxophone', 'Scarf', 'Sea lion', + 'Sea turtle', 'Sheep', 'Shelf', 'Shirt', 'Shorts', 'Shrimp', 'Sink', 'Skateboard', + 'Ski', 'Skull', 'Skyscraper', 'Snake', 'Sock', 'Sofa bed', 'Sparrow', 'Spider', 'Spoon', + 'Sports uniform', 'Squirrel', 'Stairs', 'Stool', 'Strawberry', 'Street light', + 'Studio couch', 'Suit', 'Sun hat', 'Sunglasses', 'Surfboard', 'Sushi', 'Swan', + 'Swimming pool', 'Swimwear', 'Tank', 'Tap', 'Taxi', 'Tea', 'Teddy bear', 'Television', + 'Tent', 'Tie', 'Tiger', 'Tin can', 'Tire', 'Toilet', 'Tomato', 'Tortoise', 'Tower', + 'Traffic light', 'Train', 'Tripod', 'Truck', 'Trumpet', 'Umbrella', 'Van', 'Vase', + 'Vehicle registration plate', 'Violin', 'Wall clock', 'Waste container', 'Watch', + 'Whale', 'Wheel', 'Wheelchair', 'Whiteboard', 'Window', 'Wine', 'Wine glass', 'Woman', + 'Zebra', 'Zucchini', +] + +def openimages(): + ann_file = BASEDIR / "validation/labels/openimages-mlperf.json" + if not ann_file.is_file(): + fetch_openimages(ann_file) + return ann_file + +# this slows down the conversion a lot! +# maybe use https://raw.githubusercontent.com/scardine/image_size/master/get_image_size.py +def extract_dims(path): return Image.open(path).size[::-1] + +def export_to_coco(class_map, annotations, image_list, dataset_path, output_path, classes=MLPERF_CLASSES): + output_path.parent.mkdir(parents=True, exist_ok=True) + cats = [{"id": i, "name": c, "supercategory": None} for i, c in enumerate(classes)] + categories_map = pd.DataFrame([(i, c) for i, c in enumerate(classes)], columns=["category_id", "category_name"]) + class_map = class_map.merge(categories_map, left_on="DisplayName", right_on="category_name", how="inner") + annotations = annotations[np.isin(annotations["ImageID"], image_list)] + annotations = annotations.merge(class_map, on="LabelName", how="inner") + annotations["image_id"] = pd.factorize(annotations["ImageID"].tolist())[0] + annotations[["height", "width"]] = annotations.apply(lambda x: extract_dims(dataset_path / f"{x['ImageID']}.jpg"), axis=1, result_type="expand") + + # Images + imgs = [{"id": int(id + 1), "file_name": f"{image_id}.jpg", "height": row["height"], "width": row["width"], "license": None, "coco_url": None} + for (id, image_id), row in (annotations.groupby(["image_id", "ImageID"]).first().iterrows()) + ] + + # Annotations + annots = [] + for i, row in annotations.iterrows(): + xmin, ymin, xmax, ymax, img_w, img_h = [row[k] for k in ["XMin", "YMin", "XMax", "YMax", "width", "height"]] + x, y, w, h = xmin * img_w, ymin * img_h, (xmax - xmin) * img_w, (ymax - ymin) * img_h + coco_annot = {"id": int(i) + 1, "image_id": int(row["image_id"] + 1), "category_id": int(row["category_id"]), "bbox": [x, y, w, h], "area": w * h} + coco_annot.update({k: row[k] for k in ["IsOccluded", "IsInside", "IsDepiction", "IsTruncated", "IsGroupOf"]}) + coco_annot["iscrowd"] = int(row["IsGroupOf"]) + annots.append(coco_annot) + + info = {"dataset": "openimages_mlperf", "version": "v6"} + coco_annotations = {"info": info, "licenses": [], "categories": cats, "images": imgs, "annotations": annots} + with open(output_path, "w") as fp: + json.dump(coco_annotations, fp) + +def get_image_list(class_map, annotations, classes=MLPERF_CLASSES): + labels = class_map[np.isin(class_map["DisplayName"], classes)]["LabelName"] + image_ids = annotations[np.isin(annotations["LabelName"], labels)]["ImageID"].unique() + return image_ids + +def download_image(bucket, image_id, data_dir): + try: + bucket.download_file(f"validation/{image_id}.jpg", f"{data_dir}/{image_id}.jpg") + except botocore.exceptions.ClientError as exception: + sys.exit(f"ERROR when downloading image `validation/{image_id}`: {str(exception)}") + +def fetch_openimages(output_fn): + bucket = boto3.resource("s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED)).Bucket(BUCKET_NAME) + + annotations_dir, data_dir = BASEDIR / "annotations", BASEDIR / "validation/data" + annotations_dir.mkdir(parents=True, exist_ok=True) + data_dir.mkdir(parents=True, exist_ok=True) + + annotations_fn = annotations_dir / BBOX_ANNOTATIONS_URL.split('/')[-1] + download_file(BBOX_ANNOTATIONS_URL, annotations_fn) + annotations = pd.read_csv(annotations_fn) + + classmap_fn = annotations_dir / MAP_CLASSES_URL.split('/')[-1] + download_file(MAP_CLASSES_URL, classmap_fn) + class_map = pd.read_csv(classmap_fn, names=["LabelName", "DisplayName"]) + + image_list = get_image_list(class_map, annotations) + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(download_image, bucket, image_id, data_dir) for image_id in image_list] + for future in (t := tqdm(concurrent.futures.as_completed(futures), total=len(image_list))): + t.set_description(f"Downloading images") + future.result() + + print("Converting annotations to COCO format...") + export_to_coco(class_map, annotations, image_list, data_dir, output_fn) + +def image_load(fn): + img_folder = BASEDIR / "validation/data" + img = Image.open(img_folder / fn).convert('RGB') + import torchvision.transforms.functional as F + ret = F.resize(img, size=(800, 800)) + ret = np.array(ret) + return ret, img.size[::-1] + +def prepare_target(annotations, img_id, img_size): + boxes = [annot["bbox"] for annot in annotations] + boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2] = boxes[:, 0::2].clip(0, img_size[1]) + boxes[:, 1::2] = boxes[:, 1::2].clip(0, img_size[0]) + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = [annot["category_id"] for annot in annotations] + classes = np.array(classes, dtype=np.int64) + classes = classes[keep] + return {"boxes": boxes, "labels": classes, "image_id": img_id, "image_size": img_size} + +def iterate(coco, bs=8): + image_ids = sorted(coco.imgs.keys()) + for i in range(0, len(image_ids), bs): + X, targets = [], [] + for img_id in image_ids[i:i+bs]: + x, original_size = image_load(coco.loadImgs(img_id)[0]["file_name"]) + X.append(x) + annotations = coco.loadAnns(coco.getAnnIds(img_id)) + targets.append(prepare_target(annotations, img_id, original_size)) + yield np.array(X), targets diff --git a/examples/mlperf/model_eval.py b/examples/mlperf/model_eval.py index ef66fd6834..bc75b1da83 100644 --- a/examples/mlperf/model_eval.py +++ b/examples/mlperf/model_eval.py @@ -42,6 +42,64 @@ def eval_resnet(): print(f"****** {n}/{d} {n*100.0/d:.2f}%") st = time.perf_counter() +def eval_retinanet(): + # RetinaNet with ResNeXt50_32X4D + from models.resnet import ResNeXt50_32X4D + from models.retinanet import RetinaNet + mdl = RetinaNet(ResNeXt50_32X4D()) + mdl.load_from_pretrained() + + input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1) + input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1) + def input_fixup(x): + x = x.permute([0,3,1,2]) / 255.0 + x -= input_mean + x /= input_std + return x + + from datasets.openimages import openimages, iterate + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + from contextlib import redirect_stdout + coco = COCO(openimages()) + coco_eval = COCOeval(coco, iouType="bbox") + coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng) + + from tinygrad.jit import TinyJit + mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize()) + + n, bs = 0, 8 + st = time.perf_counter() + for x, targets in iterate(coco, bs): + dat = Tensor(x.astype(np.float32)) + mt = time.perf_counter() + if dat.shape[0] == bs: + outs = mdlrun(dat).numpy() + else: + mdlrun.jit_cache = None + outs = mdl(input_fixup(dat)).numpy() + et = time.perf_counter() + predictions = mdl.postprocess_detections(outs, input_size=dat.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets]) + ext = time.perf_counter() + n += len(targets) + print(f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing") + img_ids = [t["image_id"] for t in targets] + coco_results = [{"image_id": targets[i]["image_id"], "category_id": label, "bbox": box, "score": score} + for i, prediction in enumerate(predictions) for box, score, label in zip(*prediction.values())] + with redirect_stdout(None): + coco_eval.cocoDt = coco.loadRes(coco_results) + coco_eval.params.imgIds = img_ids + coco_eval.evaluate() + evaluated_imgs.extend(img_ids) + coco_evalimgs.append(np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids))) + st = time.perf_counter() + + coco_eval.params.imgIds = evaluated_imgs + coco_eval._paramsEval.imgIds = evaluated_imgs + coco_eval.evalImgs = list(np.concatenate(coco_evalimgs, -1).flatten()) + coco_eval.accumulate() + coco_eval.summarize() + def eval_rnnt(): # RNN-T from models.rnnt import RNNT diff --git a/examples/mlperf/model_spec.py b/examples/mlperf/model_spec.py index 6b7c37a9e4..f11f76eb05 100644 --- a/examples/mlperf/model_spec.py +++ b/examples/mlperf/model_spec.py @@ -17,8 +17,12 @@ def spec_resnet(): test_model(mdl, img) def spec_retinanet(): - # TODO: Retinanet - pass + # Retinanet with ResNet backbone + from models.resnet import ResNet50 + from models.retinanet import RetinaNet + mdl = RetinaNet(ResNet50(), num_classes=91, num_anchors=9) + img = Tensor.randn(1, 3, 224, 224) + test_model(mdl, img) def spec_unet3d(): # 3D UNET diff --git a/models/resnet.py b/models/resnet.py index 0712274ef9..8bc955a35f 100644 --- a/models/resnet.py +++ b/models/resnet.py @@ -5,7 +5,8 @@ from extra.utils import get_child class BasicBlock: expansion = 1 - def __init__(self, in_planes, planes, stride=1): + def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64): + assert groups == 1 and base_width == 64, "BasicBlock only supports groups=1 and base_width=64" self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False) @@ -29,12 +30,13 @@ class Bottleneck: # NOTE: the original implementation places stride at the first convolution (self.conv1), this is the v1.5 variant expansion = 4 - def __init__(self, in_planes, planes, stride=1): - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=stride, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) + def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64): + width = int(planes * (base_width / 64.0)) * groups + self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=stride, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = nn.Conv2d(width, self.expansion*planes, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(self.expansion*planes) self.downsample = [] if stride != 1 or in_planes != self.expansion*planes: @@ -52,7 +54,7 @@ class Bottleneck: return out class ResNet: - def __init__(self, num, num_classes): + def __init__(self, num, num_classes, groups=1, width_per_group=64): self.num = num self.block = { @@ -73,6 +75,8 @@ class ResNet: self.in_planes = 64 + self.groups = groups + self.base_width = width_per_group self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3) self.bn1 = nn.BatchNorm2d(64) self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1) @@ -85,7 +89,7 @@ class ResNet: strides = [stride] + [1] * (num_blocks-1) layers = [] for stride in strides: - layers.append(block(self.in_planes, planes, stride)) + layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width)) self.in_planes = planes * block.expansion return layers @@ -107,14 +111,15 @@ class ResNet: # TODO replace with fake torch load model_urls = { - 18: 'https://download.pytorch.org/models/resnet18-5c106cde.pth', - 34: 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', - 50: 'https://download.pytorch.org/models/resnet50-19c8e357.pth', - 101: 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', - 152: 'https://download.pytorch.org/models/resnet152-b121ed2d.pth' + (18, 1, 64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + (34, 1, 64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + (50, 1, 64): 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + (50, 32, 4): 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + (101, 1, 64): 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + (152, 1, 64): 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', } - self.url = model_urls[self.num] + self.url = model_urls[(self.num, self.groups, self.base_width)] from torch.hub import load_state_dict_from_url state_dict = load_state_dict_from_url(self.url, progress=True) @@ -126,7 +131,8 @@ class ResNet: print("skipping fully connected layer") continue # Skip FC if transfer learning - assert obj.shape == dat.shape, (k, obj.shape, dat.shape) + # TODO: remove or when #777 is merged + assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), (k, obj.shape, dat.shape) obj.assign(dat) ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes) @@ -134,3 +140,4 @@ ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes) ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes) ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes) ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes) +ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4) diff --git a/models/retinanet.py b/models/retinanet.py new file mode 100644 index 0000000000..8baf0f073e --- /dev/null +++ b/models/retinanet.py @@ -0,0 +1,237 @@ +import math +from tinygrad.helpers import flatten +import tinygrad.nn as nn +from models.resnet import ResNet +from extra.utils import get_child +import numpy as np + +def nms(boxes, scores, thresh=0.5): + x1, y1, x2, y2 = np.rollaxis(boxes, 1) + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + to_process, keep = scores.argsort()[::-1], [] + while to_process.size > 0: + cur, to_process = to_process[0], to_process[1:] + keep.append(cur) + inter_x1 = np.maximum(x1[cur], x1[to_process]) + inter_y1 = np.maximum(y1[cur], y1[to_process]) + inter_x2 = np.minimum(x2[cur], x2[to_process]) + inter_y2 = np.minimum(y2[cur], y2[to_process]) + inter_area = np.maximum(0, inter_x2 - inter_x1 + 1) * np.maximum(0, inter_y2 - inter_y1 + 1) + iou = inter_area / (areas[cur] + areas[to_process] - inter_area) + to_process = to_process[np.where(iou <= thresh)[0]] + return keep + +def decode_bbox(offsets, anchors): + dx, dy, dw, dh = np.rollaxis(offsets, 1) + widths, heights = anchors[:, 2] - anchors[:, 0], anchors[:, 3] - anchors[:, 1] + cx, cy = anchors[:, 0] + 0.5 * widths, anchors[:, 1] + 0.5 * heights + pred_cx, pred_cy = dx * widths + cx, dy * heights + cy + pred_w, pred_h = np.exp(dw) * widths, np.exp(dh) * heights + pred_x1, pred_y1 = pred_cx - 0.5 * pred_w, pred_cy - 0.5 * pred_h + pred_x2, pred_y2 = pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h + return np.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=1, dtype=np.float32) + +def generate_anchors(input_size, grid_sizes, scales, aspect_ratios): + assert len(scales) == len(aspect_ratios) == len(grid_sizes) + anchors = [] + for s, ar, gs in zip(scales, aspect_ratios, grid_sizes): + s, ar = np.array(s), np.array(ar) + h_ratios = np.sqrt(ar) + w_ratios = 1 / h_ratios + ws = (w_ratios[:, None] * s[None, :]).reshape(-1) + hs = (h_ratios[:, None] * s[None, :]).reshape(-1) + base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round() + stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1] + shifts_x, shifts_y = np.meshgrid(np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h) + shifts_x = shifts_x.reshape(-1) + shifts_y = shifts_y.reshape(-1) + shifts = np.stack([shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32) + anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4)) + return anchors + +class RetinaNet: + def __init__(self, backbone: ResNet, num_classes=264, num_anchors=9, scales=None, aspect_ratios=None): + assert isinstance(backbone, ResNet) + scales = tuple((i, int(i*2**(1/3)), int(i*2**(2/3))) for i in 2**np.arange(5, 10)) if scales is None else scales + aspect_ratios = ((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios + self.num_anchors, self.num_classes = num_anchors, num_classes + assert len(scales) == len(aspect_ratios) and all([self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios)]) + + self.backbone = ResNetFPN(backbone) + self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes) + self.anchor_gen = lambda input_size: generate_anchors(input_size, self.backbone.compute_grid_sizes(input_size), scales, aspect_ratios) + + def __call__(self, x): + return self.forward(x) + def forward(self, x): + return self.head(self.backbone(x)) + + def load_from_pretrained(self): + model_urls = { + (50, 1, 64): "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", + (50, 32, 4): "https://zenodo.org/record/6605272/files/retinanet_model_10.zip", + } + self.url = model_urls[(self.backbone.body.num, self.backbone.body.groups, self.backbone.body.base_width)] + from torch.hub import load_state_dict_from_url + state_dict = load_state_dict_from_url(self.url, progress=True, map_location='cpu') + state_dict = state_dict['model'] if 'model' in state_dict.keys() else state_dict + for k, v in state_dict.items(): + obj = get_child(self, k) + dat = v.detach().numpy() + assert obj.shape == dat.shape, (k, obj.shape, dat.shape) + obj.assign(dat) + + # predictions: (BS, (H1W1+...+HmWm)A, 4 + K) + def postprocess_detections(self, predictions, input_size=(800, 800), image_sizes=None, orig_image_sizes=None, score_thresh=0.05, topk_candidates=1000, nms_thresh=0.5): + anchors = self.anchor_gen(input_size) + grid_sizes = self.backbone.compute_grid_sizes(input_size) + split_idx = np.cumsum([int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]]) + detections = [] + for i, predictions_per_image in enumerate(predictions): + h, w = input_size if image_sizes is None else image_sizes[i] + + predictions_per_image = np.split(predictions_per_image, split_idx) + offsets_per_image = [br[:, :4] for br in predictions_per_image] + scores_per_image = [cl[:, 4:] for cl in predictions_per_image] + + image_boxes, image_scores, image_labels = [], [], [] + for offsets_per_level, scores_per_level, anchors_per_level in zip(offsets_per_image, scores_per_image, anchors): + # remove low scoring boxes + scores_per_level = scores_per_level.flatten() + keep_idxs = scores_per_level > score_thresh + scores_per_level = scores_per_level[keep_idxs] + + # keep topk + topk_idxs = np.where(keep_idxs)[0] + num_topk = min(len(topk_idxs), topk_candidates) + sort_idxs = scores_per_level.argsort()[-num_topk:][::-1] + topk_idxs, scores_per_level = topk_idxs[sort_idxs], scores_per_level[sort_idxs] + + # bbox coords from offsets + anchor_idxs = topk_idxs // self.num_classes + labels_per_level = topk_idxs % self.num_classes + boxes_per_level = decode_bbox(offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs]) + # clip to image size + clipped_x = boxes_per_level[:, 0::2].clip(0, w) + clipped_y = boxes_per_level[:, 1::2].clip(0, h) + boxes_per_level = np.stack([clipped_x, clipped_y], axis=2).reshape(-1, 4) + + image_boxes.append(boxes_per_level) + image_scores.append(scores_per_level) + image_labels.append(labels_per_level) + + image_boxes = np.concatenate(image_boxes) + image_scores = np.concatenate(image_scores) + image_labels = np.concatenate(image_labels) + + # nms for each class + keep_mask = np.zeros_like(image_scores, dtype=bool) + for class_id in np.unique(image_labels): + curr_indices = np.where(image_labels == class_id)[0] + curr_keep_indices = nms(image_boxes[curr_indices], image_scores[curr_indices], nms_thresh) + keep_mask[curr_indices[curr_keep_indices]] = True + keep = np.where(keep_mask)[0] + keep = keep[image_scores[keep].argsort()[::-1]] + + # resize bboxes back to original size + image_boxes = image_boxes[keep] + if orig_image_sizes is not None: + resized_x = image_boxes[:, 0::2] * orig_image_sizes[i][1] / w + resized_y = image_boxes[:, 1::2] * orig_image_sizes[i][0] / h + image_boxes = np.stack([resized_x, resized_y], axis=2).reshape(-1, 4) + # xywh format + image_boxes = np.concatenate([image_boxes[:, :2], image_boxes[:, 2:] - image_boxes[:, :2]], axis=1) + + detections.append({"boxes":image_boxes, "scores":image_scores[keep], "labels":image_labels[keep]}) + return detections + +class ClassificationHead: + def __init__(self, in_channels, num_anchors, num_classes): + self.num_classes = num_classes + self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)]) + self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1) + def __call__(self, x): + out = [self.cls_logits(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, self.num_classes) for feat in x] + return out[0].cat(*out[1:], dim=1).sigmoid() + +class RegressionHead: + def __init__(self, in_channels, num_anchors): + self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)]) + self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, padding=1) + def __call__(self, x): + out = [self.bbox_reg(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, 4) for feat in x] + return out[0].cat(*out[1:], dim=1) + +class RetinaHead: + def __init__(self, in_channels, num_anchors, num_classes): + self.classification_head = ClassificationHead(in_channels, num_anchors, num_classes) + self.regression_head = RegressionHead(in_channels, num_anchors) + def __call__(self, x): + pred_bbox, pred_class = self.regression_head(x), self.classification_head(x) + out = pred_bbox.cat(pred_class, dim=-1) + return out + +class ResNetFPN: + def __init__(self, resnet, out_channels=256, returned_layers=[2, 3, 4]): + self.out_channels = out_channels + self.body = resnet + in_channels_list = [(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers] + self.fpn = FPN(in_channels_list, out_channels) + + # this is needed to decouple inference from postprocessing (anchors generation) + def compute_grid_sizes(self, input_size): + return np.ceil(np.array(input_size)[None, :] / 2 ** np.arange(3, 8)[:, None]) + + def __call__(self, x): + out = self.body.bn1(self.body.conv1(x)).relu() + out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2) + out = out.sequential(self.body.layer1) + p3 = out.sequential(self.body.layer2) + p4 = p3.sequential(self.body.layer3) + p5 = p4.sequential(self.body.layer4) + return self.fpn([p3, p4, p5]) + +class ExtraFPNBlock: + def __init__(self, in_channels, out_channels): + self.p6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1) + self.p7 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1) + self.use_P5 = in_channels == out_channels + + def __call__(self, p, c): + p5, c5 = p[-1], c[-1] + x = p5 if self.use_P5 else c5 + p6 = self.p6(x) + p7 = self.p7(p6.relu()) + p.extend([p6, p7]) + return p + +class FPN: + def __init__(self, in_channels_list, out_channels, extra_blocks=None): + self.inner_blocks, self.layer_blocks = [], [] + for in_channels in in_channels_list: + self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1)) + self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)) + self.extra_blocks = ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks + + def __call__(self, x): + last_inner = self.inner_blocks[-1](x[-1]) + results = [self.layer_blocks[-1](last_inner)] + for idx in range(len(x) - 2, -1, -1): + inner_lateral = self.inner_blocks[idx](x[idx]) + + # upsample to inner_lateral's shape + (ih, iw), (oh, ow), prefix = last_inner.shape[-2:], inner_lateral.shape[-2:], last_inner.shape[:-2] + eh, ew = math.ceil(oh / ih), math.ceil(ow / iw) + inner_top_down = last_inner.reshape(*prefix, ih, 1, iw, 1).expand(*prefix, ih, eh, iw, ew).reshape(*prefix, ih*eh, iw*ew)[:, :, :oh, :ow] + + last_inner = inner_lateral + inner_top_down + results.insert(0, self.layer_blocks[idx](last_inner)) + if self.extra_blocks is not None: + results = self.extra_blocks(results, x) + return results + +if __name__ == "__main__": + from models.resnet import ResNeXt50_32X4D + backbone = ResNeXt50_32X4D() + retina = RetinaNet(backbone) + retina.load_from_pretrained()