import re
import math
import os
import numpy as np
from pathlib import Path
from tinygrad import nn, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.helpers import get_child, fetch
from tinygrad.nn.state import torch_load
from extra.models.resnet import ResNet
from extra.models.retinanet import nms as _box_nms

USE_NP_GATHER = os.getenv('FULL_TINYGRAD', '0') == '0'

def rint(tensor):
  x = (tensor*2).cast(dtypes.int32).contiguous().cast(dtypes.float32)/2
  return (x<0).where(x.floor(), x.ceil())
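
# (illustrative) rint rounds halves away from zero, e.g. rint(Tensor([1.4, 2.5, -1.5])) -> [1.0, 3.0, -2.0]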

def nearest_interpolate(tensor, scale_factor):
  bs, c, py, px = tensor.shape
  return tensor.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, scale_factor, px, scale_factor).reshape(bs, c, py * scale_factor, px * scale_factor)
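
# (illustrative) each pixel is repeated scale_factor times along H and W:
# (N, C, H, W) -> (N, C, H*scale_factor, W*scale_factor), e.g. (1, 1, 2, 2) -> (1, 1, 4, 4) with scale_factor=2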

def meshgrid(x, y):
  grid_x = Tensor.cat(*[x[idx:idx+1].expand(y.shape).unsqueeze(0) for idx in range(x.shape[0])])
  grid_y = Tensor.cat(*[y.unsqueeze(0)]*x.shape[0])
  return grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)
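
# (illustrative) equivalent to np.meshgrid(x, y, indexing='ij') with both grids flattened
# to columns: grid_x[i*len(y)+j] == x[i], grid_y[i*len(y)+j] == y[j]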

def topk(input_, k, dim=-1, largest=True, sorted=False):
  k = min(k, input_.shape[dim]-1)
  input_ = input_.numpy()
  if largest: input_ *= -1
  ind = np.argpartition(input_, k, axis=dim)
  if largest: input_ *= -1
  ind = np.take(ind, np.arange(k), axis=dim)  # k non-sorted indices
  input_ = np.take_along_axis(input_, ind, axis=dim)  # k non-sorted values
  if not sorted: return Tensor(input_), ind
  if largest: input_ *= -1
  ind_part = np.argsort(input_, axis=dim)
  ind = np.take_along_axis(ind, ind_part, axis=dim)
  if largest: input_ *= -1
  val = np.take_along_axis(input_, ind_part, axis=dim)
  return Tensor(val), ind
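
# (illustrative) numpy-backed top-k, e.g. topk(Tensor([3., 1., 2.]), k=2) returns the 2 largest
# values and their indices; note k is clamped to input_.shape[dim]-1 above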

# This is very slow for large arrays or indices
def _gather(array, indices):
  indices = indices.float().to(array.device)
  reshape_arg = [1]*array.ndim + [array.shape[-1]]
  return Tensor.where(
    indices.unsqueeze(indices.ndim).expand(*indices.shape, array.shape[-1]) == Tensor.arange(array.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, array.shape[-1]),
    array, 0,
  ).sum(indices.ndim)
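
# (illustrative) _gather one-hot compares indices against arange(last_dim) and sums it out,
# so _gather(Tensor([10., 20., 30.]), Tensor([2, 0])) -> [30., 10.]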

# TODO: replace npgather with a faster gather using tinygrad only
# NOTE: this blocks the gradient
def npgather(array, indices):
  if isinstance(array, Tensor): array = array.numpy()
  if isinstance(indices, Tensor): indices = indices.numpy()
  if isinstance(indices, list): indices = np.asarray(indices)
  return Tensor(array[indices.astype(int)])

def get_strides(shape):
  prod = [1]
  for idx in range(len(shape)-1, -1, -1): prod.append(prod[-1] * shape[idx])
  # something about ints is broken with gpu, cuda
  return Tensor(prod[::-1][1:], dtype=dtypes.int32).unsqueeze(0)
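
# (illustrative) row-major strides, e.g. get_strides((2, 3, 4)) -> Tensor([[12, 4, 1]])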

# with keys as integer tensors for all axes
def tensor_getitem(tensor, *keys):
  # something about ints is broken with gpu, cuda
  flat_keys = Tensor.stack(*[key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1).cast(dtypes.int32)
  strides = get_strides(tensor.shape)
  idxs = (flat_keys * strides).sum(1)
  gatherer = npgather if USE_NP_GATHER else _gather
  return gatherer(tensor.reshape(-1), idxs).reshape(sum(keys).shape)
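
# (illustrative) fancy indexing via flattened offsets: the keys broadcast together and
# tensor_getitem(t, i, j) matches t[i, j], computed as t.reshape(-1)[i*stride0 + j*stride1]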


# for gather with indices only on axis=0
def tensor_gather(tensor, indices):
  if not isinstance(indices, Tensor):
    indices = Tensor(indices, requires_grad=False)
  if len(tensor.shape) > 2:
    rem_shape = list(tensor.shape)[1:]
    tensor = tensor.reshape(tensor.shape[0], -1)
  else:
    rem_shape = None
  if len(tensor.shape) > 1:
    tensor = tensor.T
    repeat_arg = [1]*(tensor.ndim-1) + [tensor.shape[-2]]
    indices = indices.unsqueeze(indices.ndim).repeat(repeat_arg)
    ret = _gather(tensor, indices)
    if rem_shape:
      ret = ret.reshape([indices.shape[0]] + rem_shape)
  else:
    ret = _gather(tensor, indices)
  del indices
  return ret
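
# (illustrative) row gather, e.g. tensor_gather(Tensor([[1., 2.], [3., 4.], [5., 6.]]), [2, 0]) -> [[5., 6.], [1., 2.]]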


class LastLevelMaxPool:
  def __call__(self, x): return [Tensor.max_pool2d(x, 1, 2)]


# flip methods for BoxList.transpose
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1


def permute_and_flatten(layer:Tensor, N, A, C, H, W):
  layer = layer.reshape(N, -1, C, H, W)
  layer = layer.permute(0, 3, 4, 1, 2)
  layer = layer.reshape(N, -1, C)
  return layer
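
# (illustrative) (N, A*C, H, W) -> (N, H*W*A, C): the per-location predictions become rows,
# ordered by spatial position first, then anchor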


class BoxList:
  def __init__(self, bbox, image_size, mode="xyxy"):
    if not isinstance(bbox, Tensor):
      bbox = Tensor(bbox)
    if bbox.ndim != 2:
      raise ValueError(
        "bbox should have 2 dimensions, got {}".format(bbox.ndim)
      )
    if bbox.shape[-1] != 4:
      raise ValueError(
        "last dimension of bbox should have a "
        "size of 4, got {}".format(bbox.shape[-1])
      )
    if mode not in ("xyxy", "xywh"):
      raise ValueError("mode should be 'xyxy' or 'xywh'")

    self.bbox = bbox
    self.size = image_size  # (image_width, image_height)
    self.mode = mode
    self.extra_fields = {}

  def __repr__(self):
    s = self.__class__.__name__ + "("
    s += "num_boxes={}, ".format(len(self))
    s += "image_width={}, ".format(self.size[0])
    s += "image_height={}, ".format(self.size[1])
    s += "mode={})".format(self.mode)
    return s

  def area(self):
    box = self.bbox
    if self.mode == "xyxy":
      TO_REMOVE = 1
      area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
    elif self.mode == "xywh":
      area = box[:, 2] * box[:, 3]
    return area

  def add_field(self, field, field_data):
    self.extra_fields[field] = field_data

  def get_field(self, field):
    return self.extra_fields[field]

  def has_field(self, field):
    return field in self.extra_fields

  def fields(self):
    return list(self.extra_fields.keys())

  def _copy_extra_fields(self, bbox):
    for k, v in bbox.extra_fields.items():
      self.extra_fields[k] = v

  def convert(self, mode):
    if mode == self.mode:
      return self
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    if mode == "xyxy":
      bbox = Tensor.cat(*(xmin, ymin, xmax, ymax), dim=-1)
      bbox = BoxList(bbox, self.size, mode=mode)
    else:
      TO_REMOVE = 1
      bbox = Tensor.cat(
        *(xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1
      )
      bbox = BoxList(bbox, self.size, mode=mode)
    bbox._copy_extra_fields(self)
    return bbox

  def _split_into_xyxy(self):
    if self.mode == "xyxy":
      xmin, ymin, xmax, ymax = self.bbox.chunk(4, dim=-1)
      return xmin, ymin, xmax, ymax
    if self.mode == "xywh":
      TO_REMOVE = 1
      xmin, ymin, w, h = self.bbox.chunk(4, dim=-1)
      return (
        xmin,
        ymin,
        xmin + (w - TO_REMOVE).clamp(min=0),
        ymin + (h - TO_REMOVE).clamp(min=0),
      )

  def resize(self, size, *args, **kwargs):
    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
    if ratios[0] == ratios[1]:
      ratio = ratios[0]
      scaled_box = self.bbox * ratio
      bbox = BoxList(scaled_box, size, mode=self.mode)
      for k, v in self.extra_fields.items():
        if not isinstance(v, Tensor):
          v = v.resize(size, *args, **kwargs)
        bbox.add_field(k, v)
      return bbox

    ratio_width, ratio_height = ratios
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    scaled_xmin = xmin * ratio_width
    scaled_xmax = xmax * ratio_width
    scaled_ymin = ymin * ratio_height
    scaled_ymax = ymax * ratio_height
    scaled_box = Tensor.cat(
      *(scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1
    )
    bbox = BoxList(scaled_box, size, mode="xyxy")
    for k, v in self.extra_fields.items():
      if not isinstance(v, Tensor):
        v = v.resize(size, *args, **kwargs)
      bbox.add_field(k, v)

    return bbox.convert(self.mode)

  def transpose(self, method):
    image_width, image_height = self.size
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    if method == FLIP_LEFT_RIGHT:
      TO_REMOVE = 1
      transposed_xmin = image_width - xmax - TO_REMOVE
      transposed_xmax = image_width - xmin - TO_REMOVE
      transposed_ymin = ymin
      transposed_ymax = ymax
    elif method == FLIP_TOP_BOTTOM:
      transposed_xmin = xmin
      transposed_xmax = xmax
      transposed_ymin = image_height - ymax
      transposed_ymax = image_height - ymin

    transposed_boxes = Tensor.cat(
      *(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
    )
    bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
    for k, v in self.extra_fields.items():
      if not isinstance(v, Tensor):
        v = v.transpose(method)
      bbox.add_field(k, v)
    return bbox.convert(self.mode)

  def clip_to_image(self, remove_empty=True):
    TO_REMOVE = 1
    bb1 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 0]
    bb2 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 1]
    bb3 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 2]
    bb4 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 3]
    self.bbox = Tensor.stack(bb1, bb2, bb3, bb4, dim=1)
    if remove_empty:
      box = self.bbox
      keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
      return self[keep]
    return self

  def __getitem__(self, item):
    if isinstance(item, list):
      if len(item) == 0:
        return []
      if sum(item) == len(item) and isinstance(item[0], bool):
        return self
    bbox = BoxList(tensor_gather(self.bbox, item), self.size, self.mode)
    for k, v in self.extra_fields.items():
      bbox.add_field(k, tensor_gather(v, item))
    return bbox

  def __len__(self):
    return self.bbox.shape[0]
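
# (illustrative) BoxList usage:
#   bl = BoxList(Tensor([[0., 0., 10., 10.]]), (100, 100), mode="xyxy")
#   bl.add_field("scores", Tensor([0.9]))
#   bl.convert("xywh").bbox  # -> [[0., 0., 11., 11.]] with the inclusive-pixel (+1) convention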


def cat_boxlist(bboxes):
  size = bboxes[0].size
  mode = bboxes[0].mode
  fields = set(bboxes[0].fields())
  cat_box_list = [bbox.bbox for bbox in bboxes if bbox.bbox.shape[0] > 0]

  if len(cat_box_list) > 0:
    cat_boxes = BoxList(Tensor.cat(*cat_box_list, dim=0), size, mode)
  else:
    cat_boxes = BoxList(bboxes[0].bbox, size, mode)
  for field in fields:
    cat_field_list = [bbox.get_field(field) for bbox in bboxes if bbox.get_field(field).shape[0] > 0]

    if len(cat_box_list) > 0:
      data = Tensor.cat(*cat_field_list, dim=0)
    else:
      data = bboxes[0].get_field(field)

    cat_boxes.add_field(field, data)

  return cat_boxes


class FPN:
  def __init__(self, in_channels_list, out_channels):
    self.inner_blocks, self.layer_blocks = [], []
    for in_channels in in_channels_list:
      self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
      self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
    self.top_block = LastLevelMaxPool()

  def __call__(self, x: Tensor):
    last_inner = self.inner_blocks[-1](x[-1])
    results = []
    results.append(self.layer_blocks[-1](last_inner))
    for feature, inner_block, layer_block in zip(
      x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
    ):
      if not inner_block:
        continue
      inner_top_down = nearest_interpolate(last_inner, scale_factor=2)
      inner_lateral = inner_block(feature)
      last_inner = inner_lateral + inner_top_down
      layer_result = layer_block(last_inner)
      results.insert(0, layer_result)
    last_results = self.top_block(results[-1])
    results.extend(last_results)

    return tuple(results)
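
# (illustrative) top-down FPN: the coarsest lateral is upsampled 2x and summed with the next
# finer lateral, so FPN([C2, C3, C4, C5]) -> (P2, P3, P4, P5, P6), with P6 from max-pooling P5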


class ResNetFPN:
  def __init__(self, resnet, out_channels=256):
    self.out_channels = out_channels
    self.body = resnet
    in_channels_stage2 = 256
    in_channels_list = [
      in_channels_stage2,
      in_channels_stage2 * 2,
      in_channels_stage2 * 4,
      in_channels_stage2 * 8,
    ]
    self.fpn = FPN(in_channels_list, out_channels)

  def __call__(self, x):
    x = self.body(x)
    return self.fpn(x)


class AnchorGenerator:
  def __init__(
    self,
    sizes=(32, 64, 128, 256, 512),
    aspect_ratios=(0.5, 1.0, 2.0),
    anchor_strides=(4, 8, 16, 32, 64),
    straddle_thresh=0,
  ):
    if len(anchor_strides) == 1:
      anchor_stride = anchor_strides[0]
      cell_anchors = [
        generate_anchors(anchor_stride, sizes, aspect_ratios)
      ]
    else:
      if len(anchor_strides) != len(sizes):
        raise RuntimeError("FPN should have #anchor_strides == #sizes")

      cell_anchors = [
        generate_anchors(
          anchor_stride,
          size if isinstance(size, (tuple, list)) else (size,),
          aspect_ratios
        )
        for anchor_stride, size in zip(anchor_strides, sizes)
      ]
    self.strides = anchor_strides
    self.cell_anchors = cell_anchors
    self.straddle_thresh = straddle_thresh

  def num_anchors_per_location(self):
    return [cell_anchors.shape[0] for cell_anchors in self.cell_anchors]

  def grid_anchors(self, grid_sizes):
    anchors = []
    for size, stride, base_anchors in zip(
      grid_sizes, self.strides, self.cell_anchors
    ):
      grid_height, grid_width = size
      device = base_anchors.device
      shifts_x = Tensor.arange(
        start=0, stop=grid_width * stride, step=stride, dtype=dtypes.float32, device=device
      )
      shifts_y = Tensor.arange(
        start=0, stop=grid_height * stride, step=stride, dtype=dtypes.float32, device=device
      )
      shift_y, shift_x = meshgrid(shifts_y, shifts_x)
      shift_x = shift_x.reshape(-1)
      shift_y = shift_y.reshape(-1)
      shifts = Tensor.stack(shift_x, shift_y, shift_x, shift_y, dim=1)

      anchors.append(
        (shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4)
      )

    return anchors
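
  # (illustrative) each of the H*W grid shifts is broadcast-added to each of the A base
  # anchors, so one feature map yields H*W*A anchors in (x1, y1, x2, y2) image coordinates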

  def add_visibility_to(self, boxlist):
    image_width, image_height = boxlist.size
    anchors = boxlist.bbox
    if self.straddle_thresh >= 0:
      inds_inside = (
        (anchors[:, 0] >= -self.straddle_thresh)
        * (anchors[:, 1] >= -self.straddle_thresh)
        * (anchors[:, 2] < image_width + self.straddle_thresh)
        * (anchors[:, 3] < image_height + self.straddle_thresh)
      )
    else:
      device = anchors.device
      inds_inside = Tensor.ones(anchors.shape[0], dtype=dtypes.uint8, device=device)
    boxlist.add_field("visibility", inds_inside)

  def __call__(self, image_list, feature_maps):
    grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
    anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
    anchors = []
    for (image_height, image_width) in image_list.image_sizes:
      anchors_in_image = []
      for anchors_per_feature_map in anchors_over_all_feature_maps:
        boxlist = BoxList(
          anchors_per_feature_map, (image_width, image_height), mode="xyxy"
        )
        self.add_visibility_to(boxlist)
        anchors_in_image.append(boxlist)
      anchors.append(anchors_in_image)
    return anchors


def generate_anchors(
  stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)
):
  return _generate_anchors(stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios)))


def _generate_anchors(base_size, scales, aspect_ratios):
  anchor = Tensor([1, 1, base_size, base_size]) - 1
  anchors = _ratio_enum(anchor, aspect_ratios)
  anchors = Tensor.cat(
    *[_scale_enum(anchors[i, :], scales).reshape(-1, 4) for i in range(anchors.shape[0])]
  )
  return anchors
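
# (illustrative) anchors start from a stride x stride base box, are reshaped to each aspect
# ratio at roughly constant area, then scaled, e.g. stride=16, size=32 gives a 32x32 box
# (plus its 0.5 and 2.0 aspect-ratio variants) centered on the 16x16 base cell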


def _whctrs(anchor):
  w = anchor[2] - anchor[0] + 1
  h = anchor[3] - anchor[1] + 1
  x_ctr = anchor[0] + 0.5 * (w - 1)
  y_ctr = anchor[1] + 0.5 * (h - 1)
  return w, h, x_ctr, y_ctr


def _mkanchors(ws, hs, x_ctr, y_ctr):
  ws = ws[:, None]
  hs = hs[:, None]
  anchors = Tensor.cat(*(
    x_ctr - 0.5 * (ws - 1),
    y_ctr - 0.5 * (hs - 1),
    x_ctr + 0.5 * (ws - 1),
    y_ctr + 0.5 * (hs - 1),
  ), dim=1)
  return anchors


def _ratio_enum(anchor, ratios):
  w, h, x_ctr, y_ctr = _whctrs(anchor)
  size = w * h
  size_ratios = size / ratios
  ws = rint(Tensor.sqrt(size_ratios))
  hs = rint(ws * ratios)
  anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
  return anchors


def _scale_enum(anchor, scales):
  w, h, x_ctr, y_ctr = _whctrs(anchor)
  ws = w * scales
  hs = h * scales
  anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
  return anchors


class RPNHead:
  def __init__(self, in_channels, num_anchors):
    self.conv = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1)
    self.cls_logits = nn.Conv2d(256, num_anchors, kernel_size=1)
    self.bbox_pred = nn.Conv2d(256, num_anchors * 4, kernel_size=1)

  def __call__(self, x):
    logits = []
    bbox_reg = []
    for feature in x:
      t = Tensor.relu(self.conv(feature))
      logits.append(self.cls_logits(t))
      bbox_reg.append(self.bbox_pred(t))
    return logits, bbox_reg


class BoxCoder(object):
  def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
    self.weights = weights
    self.bbox_xform_clip = bbox_xform_clip

  def encode(self, reference_boxes, proposals):
    TO_REMOVE = 1  # TODO remove
    ex_widths = proposals[:, 2] - proposals[:, 0] + TO_REMOVE
    ex_heights = proposals[:, 3] - proposals[:, 1] + TO_REMOVE
    ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths
    ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights

    gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + TO_REMOVE
    gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + TO_REMOVE
    gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths
    gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights

    wx, wy, ww, wh = self.weights
    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
    targets_dw = ww * Tensor.log(gt_widths / ex_widths)
    targets_dh = wh * Tensor.log(gt_heights / ex_heights)

    targets = Tensor.stack(targets_dx, targets_dy, targets_dw, targets_dh, dim=1)
    return targets

  def decode(self, rel_codes, boxes):
    boxes = boxes.cast(rel_codes.dtype)

    TO_REMOVE = 1  # TODO remove
    widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE
    heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    wx, wy, ww, wh = self.weights
    dx = rel_codes[:, 0::4] / wx
    dy = rel_codes[:, 1::4] / wy
    dw = rel_codes[:, 2::4] / ww
    dh = rel_codes[:, 3::4] / wh

    # Prevent sending too large values into Tensor.exp()
    dw = dw.clip(min_=dw.min(), max_=self.bbox_xform_clip)
    dh = dh.clip(min_=dh.min(), max_=self.bbox_xform_clip)

    pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
    pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
    pred_w = dw.exp() * widths[:, None]
    pred_h = dh.exp() * heights[:, None]
    x1 = pred_ctr_x - 0.5 * pred_w
    y1 = pred_ctr_y - 0.5 * pred_h
    x2 = pred_ctr_x + 0.5 * pred_w - 1
    y2 = pred_ctr_y + 0.5 * pred_h - 1
    pred_boxes = Tensor.stack(x1, y1, x2, y2).permute(1,2,0).reshape(rel_codes.shape[0], rel_codes.shape[1])
    return pred_boxes
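
# (illustrative) decode inverts encode: with weights (1, 1, 1, 1) a zero code returns the
# input box exactly, and dw = log(2) doubles the predicted width around the same center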


def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
  if nms_thresh <= 0:
    return boxlist
  mode = boxlist.mode
  boxlist = boxlist.convert("xyxy")
  boxes = boxlist.bbox
  score = boxlist.get_field(score_field)
  keep = _box_nms(boxes.numpy(), score.numpy(), nms_thresh)
  if max_proposals > 0:
    keep = keep[:max_proposals]
  boxlist = boxlist[keep]
  return boxlist.convert(mode)


def remove_small_boxes(boxlist, min_size):
  xywh_boxes = boxlist.convert("xywh").bbox
  _, _, ws, hs = xywh_boxes.chunk(4, dim=1)
  keep = ((
    (ws >= min_size) * (hs >= min_size)
  ) > 0).reshape(-1)
  if keep.sum().numpy() == len(boxlist):
    return boxlist
  else:
    keep = keep.numpy().nonzero()[0]
    return boxlist[keep]


class RPNPostProcessor:
  # Not used in loss calculation
  def __init__(
    self,
    pre_nms_top_n,
    post_nms_top_n,
    nms_thresh,
    min_size,
    box_coder=None,
    fpn_post_nms_top_n=None,
  ):
    self.pre_nms_top_n = pre_nms_top_n
    self.post_nms_top_n = post_nms_top_n
    self.nms_thresh = nms_thresh
    self.min_size = min_size

    if box_coder is None:
      box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
    self.box_coder = box_coder

    if fpn_post_nms_top_n is None:
      fpn_post_nms_top_n = post_nms_top_n
    self.fpn_post_nms_top_n = fpn_post_nms_top_n

  def forward_for_single_feature_map(self, anchors, objectness, box_regression):
    device = objectness.device
    N, A, H, W = objectness.shape
    objectness = permute_and_flatten(objectness, N, A, 1, H, W).reshape(N, -1)
    objectness = objectness.sigmoid()

    box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)

    num_anchors = A * H * W

    pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
    objectness, topk_idx = topk(objectness, pre_nms_top_n, dim=1, sorted=False)
    concat_anchors = Tensor.cat(*[a.bbox for a in anchors], dim=0).reshape(N, -1, 4)
    image_shapes = [box.size for box in anchors]

    box_regression_list = []
    concat_anchors_list = []
    for batch_idx in range(N):
      box_regression_list.append(tensor_gather(box_regression[batch_idx], topk_idx[batch_idx]))
      concat_anchors_list.append(tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx]))

    box_regression = Tensor.stack(*box_regression_list)
    concat_anchors = Tensor.stack(*concat_anchors_list)

    proposals = self.box_coder.decode(
      box_regression.reshape(-1, 4), concat_anchors.reshape(-1, 4)
    )

    proposals = proposals.reshape(N, -1, 4)

    result = []
    for proposal, score, im_shape in zip(proposals, objectness, image_shapes):
      boxlist = BoxList(proposal, im_shape, mode="xyxy")
      boxlist.add_field("objectness", score)
      boxlist = boxlist.clip_to_image(remove_empty=False)
      boxlist = remove_small_boxes(boxlist, self.min_size)
      boxlist = boxlist_nms(
        boxlist,
        self.nms_thresh,
        max_proposals=self.post_nms_top_n,
        score_field="objectness",
      )
      result.append(boxlist)
    return result
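
  # (illustrative) per level: score all anchors, keep the pre_nms_top_n best, decode the
  # matching regressions into proposals, then clip, drop tiny boxes, and run per-image NMS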

  def __call__(self, anchors, objectness, box_regression):
    sampled_boxes = []
    num_levels = len(objectness)
    anchors = list(zip(*anchors))
    for a, o, b in zip(anchors, objectness, box_regression):
      sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))

    boxlists = list(zip(*sampled_boxes))
    boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]

    if num_levels > 1:
      boxlists = self.select_over_all_levels(boxlists)

    return boxlists

  def select_over_all_levels(self, boxlists):
    num_images = len(boxlists)
    for i in range(num_images):
      objectness = boxlists[i].get_field("objectness")
      post_nms_top_n = min(self.fpn_post_nms_top_n, objectness.shape[0])
      _, inds_sorted = topk(objectness, post_nms_top_n, dim=0, sorted=False)
      boxlists[i] = boxlists[i][inds_sorted]
    return boxlists


class RPN:
  def __init__(self, in_channels):
    self.anchor_generator = AnchorGenerator()

    in_channels = 256
    head = RPNHead(
      in_channels, self.anchor_generator.num_anchors_per_location()[0]
    )
    rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
    box_selector_test = RPNPostProcessor(
      pre_nms_top_n=1000,
      post_nms_top_n=1000,
      nms_thresh=0.7,
      min_size=0,
      box_coder=rpn_box_coder,
      fpn_post_nms_top_n=1000
    )
    self.head = head
    self.box_selector_test = box_selector_test

  def __call__(self, images, features, targets=None):
    objectness, rpn_box_regression = self.head(features)
    anchors = self.anchor_generator(images, features)
    boxes = self.box_selector_test(anchors, objectness, rpn_box_regression)
    return boxes, {}


def make_conv3x3(
  in_channels,
  out_channels,
  dilation=1,
  stride=1,
  use_gn=False,
):
  conv = nn.Conv2d(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=stride,
    padding=dilation,
    dilation=dilation,
    bias=False if use_gn else True
  )
  return conv


class MaskRCNNFPNFeatureExtractor:
  def __init__(self):
    resolution = 14
    scales = (0.25, 0.125, 0.0625, 0.03125)
    sampling_ratio = 2
    pooler = Pooler(
      output_size=(resolution, resolution),
      scales=scales,
      sampling_ratio=sampling_ratio,
    )
    input_size = 256
    self.pooler = pooler

    use_gn = False
    layers = (256, 256, 256, 256)
    dilation = 1
    self.mask_fcn1 = make_conv3x3(input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn)
    self.mask_fcn2 = make_conv3x3(layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn)
    self.mask_fcn3 = make_conv3x3(layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn)
    self.mask_fcn4 = make_conv3x3(layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn)
    self.blocks = [self.mask_fcn1, self.mask_fcn2, self.mask_fcn3, self.mask_fcn4]

  def __call__(self, x, proposals):
    x = self.pooler(x, proposals)
    for layer in self.blocks:
      if x is not None:
        x = Tensor.relu(layer(x))
    return x


class MaskRCNNC4Predictor:
  def __init__(self):
    num_classes = 81
    dim_reduced = 256
    num_inputs = dim_reduced
    self.conv5_mask = nn.ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0)
    self.mask_fcn_logits = nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)

  def __call__(self, x):
    x = Tensor.relu(self.conv5_mask(x))
    return self.mask_fcn_logits(x)


class FPN2MLPFeatureExtractor:
  def __init__(self, cfg):
    resolution = 7
    scales = (0.25, 0.125, 0.0625, 0.03125)
    sampling_ratio = 2
    pooler = Pooler(
      output_size=(resolution, resolution),
      scales=scales,
      sampling_ratio=sampling_ratio,
    )
    input_size = 256 * resolution ** 2
    representation_size = 1024
    self.pooler = pooler
    self.fc6 = nn.Linear(input_size, representation_size)
    self.fc7 = nn.Linear(representation_size, representation_size)

  def __call__(self, x, proposals):
    x = self.pooler(x, proposals)
    x = x.reshape(x.shape[0], -1)
    x = Tensor.relu(self.fc6(x))
    x = Tensor.relu(self.fc7(x))
    return x


def _bilinear_interpolate(
  input,          # [N, C, H, W]
  roi_batch_ind,  # [K]
  y,              # [K, PH, IY]
  x,              # [K, PW, IX]
  ymask,          # [K, IY]
  xmask,          # [K, IX]
):
  _, channels, height, width = input.shape
  y = y.clip(min_=0.0, max_=float(height-1))
  x = x.clip(min_=0.0, max_=float(width-1))

  # Tensor.where doesn't work well with int32 data so cast to float32
  y_low = y.cast(dtypes.int32).contiguous().float().contiguous()
  x_low = x.cast(dtypes.int32).contiguous().float().contiguous()

  y_high = Tensor.where(y_low >= height - 1, float(height - 1), y_low + 1)
  y_low = Tensor.where(y_low >= height - 1, float(height - 1), y_low)

  x_high = Tensor.where(x_low >= width - 1, float(width - 1), x_low + 1)
  x_low = Tensor.where(x_low >= width - 1, float(width - 1), x_low)

  ly = y - y_low
  lx = x - x_low
  hy = 1.0 - ly
  hx = 1.0 - lx

  def masked_index(
    y,  # [K, PH, IY]
    x,  # [K, PW, IX]
  ):
    if ymask is not None:
      assert xmask is not None
      y = Tensor.where(ymask[:, None, :], y, 0)
      x = Tensor.where(xmask[:, None, :], x, 0)
    key1 = roi_batch_ind[:, None, None, None, None, None]
    key2 = Tensor.arange(channels, device=input.device)[None, :, None, None, None, None]
    key3 = y[:, None, :, None, :, None]
    key4 = x[:, None, None, :, None, :]
    return tensor_getitem(input, key1, key2, key3, key4)  # [K, C, PH, PW, IY, IX]

  v1 = masked_index(y_low, x_low)
  v2 = masked_index(y_low, x_high)
  v3 = masked_index(y_high, x_low)
  v4 = masked_index(y_high, x_high)

  # all ws preemptively [K, C, PH, PW, IY, IX]
  def outer_prod(y, x):
    return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]

  w1 = outer_prod(hy, hx)
  w2 = outer_prod(hy, lx)
  w3 = outer_prod(ly, hx)
  w4 = outer_prod(ly, lx)

  val = w1*v1 + w2*v2 + w3*v3 + w4*v4
  return val
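
# (illustrative) standard bilinear interpolation: a sample at (y, x) mixes its four integer
# neighbors with weights (1-dy)(1-dx), (1-dy)dx, dy(1-dx), dy*dx where dy = y - floor(y)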

# https://pytorch.org/vision/main/_modules/torchvision/ops/roi_align.html#roi_align
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
  orig_dtype = input.dtype
  _, _, height, width = input.shape
  ph = Tensor.arange(pooled_height, device=input.device)
  pw = Tensor.arange(pooled_width, device=input.device)

  roi_batch_ind = rois[:, 0].cast(dtypes.int32).contiguous()
  offset = 0.5 if aligned else 0.0
  roi_start_w = rois[:, 1] * spatial_scale - offset
  roi_start_h = rois[:, 2] * spatial_scale - offset
  roi_end_w = rois[:, 3] * spatial_scale - offset
  roi_end_h = rois[:, 4] * spatial_scale - offset

  roi_width = roi_end_w - roi_start_w
  roi_height = roi_end_h - roi_start_h
  if not aligned:
    roi_width = roi_width.maximum(1.0)
    roi_height = roi_height.maximum(1.0)

  bin_size_h = roi_height / pooled_height
  bin_size_w = roi_width / pooled_width

  exact_sampling = sampling_ratio > 0
  roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil()
  roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil()

  if exact_sampling:
    count = max(roi_bin_grid_h * roi_bin_grid_w, 1)
    iy = Tensor.arange(roi_bin_grid_h, device=input.device)
    ix = Tensor.arange(roi_bin_grid_w, device=input.device)
    ymask = None
    xmask = None
  else:
    count = (roi_bin_grid_h * roi_bin_grid_w).maximum(1)
    iy = Tensor.arange(height, device=input.device)
    ix = Tensor.arange(width, device=input.device)
    ymask = iy[None, :] < roi_bin_grid_h[:, None]
    xmask = ix[None, :] < roi_bin_grid_w[:, None]

  def from_K(t):
    return t[:, None, None]

  y = (
    from_K(roi_start_h)
    + ph[None, :, None] * from_K(bin_size_h)
    + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h)
  )
  x = (
    from_K(roi_start_w)
    + pw[None, :, None] * from_K(bin_size_w)
    + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w)
  )

  val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)
  if not exact_sampling:
    val = ymask[:, None, None, None, :, None].where(val, 0)
    val = xmask[:, None, None, None, None, :].where(val, 0)

  output = val.sum((-1, -2))
  if isinstance(count, Tensor):
    output /= count[:, None, None, None]
  else:
    output /= count

  output = output.cast(orig_dtype)
  return output
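
# (illustrative) each ROI is split into a pooled_height x pooled_width grid; every bin is
# averaged over its bilinear samples, so rois of any size map to a fixed (K, C, PH, PW) output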

class ROIAlign:
  def __init__(self, output_size, spatial_scale, sampling_ratio):
    self.output_size = output_size
    self.spatial_scale = spatial_scale
    self.sampling_ratio = sampling_ratio

  def __call__(self, input, rois):
    output = _roi_align(
      input, rois, self.spatial_scale, self.output_size[0], self.output_size[1], self.sampling_ratio, aligned=False
    )
    return output


class LevelMapper:
  def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
    self.k_min = k_min
    self.k_max = k_max
    self.s0 = canonical_scale
    self.lvl0 = canonical_level
    self.eps = eps

  def __call__(self, boxlists):
    s = Tensor.sqrt(Tensor.cat(*[boxlist.area() for boxlist in boxlists]))
    target_lvls = (self.lvl0 + Tensor.log2(s / self.s0 + self.eps)).floor()
    target_lvls = target_lvls.clip(min_=self.k_min, max_=self.k_max)
    return target_lvls - self.k_min
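
# (illustrative) FPN paper eq. 1: k = floor(k0 + log2(sqrt(area) / 224)), e.g. a 224x224 box
# maps to level k0=4 (P4) and a 112x112 box maps one level finer (P3)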


class Pooler:
  def __init__(self, output_size, scales, sampling_ratio):
    self.output_size = output_size
    self.scales = scales
    self.sampling_ratio = sampling_ratio
    poolers = []
    for scale in scales:
      poolers.append(
        ROIAlign(
          output_size, spatial_scale=scale, sampling_ratio=sampling_ratio
        )
      )
    self.poolers = poolers
    lvl_min = -math.log2(scales[0])
    lvl_max = -math.log2(scales[-1])
    self.map_levels = LevelMapper(lvl_min, lvl_max)

  def convert_to_roi_format(self, boxes):
    concat_boxes = Tensor.cat(*[b.bbox for b in boxes], dim=0)
    device, dtype = concat_boxes.device, concat_boxes.dtype
    ids = Tensor.cat(
      *[
        Tensor.full((len(b), 1), i, dtype=dtype, device=device)
        for i, b in enumerate(boxes)
      ],
      dim=0,
    )
    if concat_boxes.shape[0] != 0:
      rois = Tensor.cat(*[ids, concat_boxes], dim=1)
      return rois

  def __call__(self, x, boxes):
    num_levels = len(self.poolers)
    rois = self.convert_to_roi_format(boxes)
    if rois is not None:
      if num_levels == 1:
        return self.poolers[0](x[0], rois)

      levels = self.map_levels(boxes)
      results = []
      all_idxs = []
      for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
        # this is fine because no grad will flow through index
        idx_in_level = (levels.numpy() == level).nonzero()[0]
        if len(idx_in_level) > 0:
          rois_per_level = tensor_gather(rois, idx_in_level)
          pooler_output = pooler(per_level_feature, rois_per_level)
          all_idxs.extend(idx_in_level)
          results.append(pooler_output)

      return tensor_gather(Tensor.cat(*results), [x[0] for x in sorted({i: idx for i, idx in enumerate(all_idxs)}.items(), key=lambda x: x[1])])
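
# (illustrative) Pooler fans ROIs out to one ROIAlign per FPN level, then the final
# tensor_gather re-sorts the concatenated per-level outputs back into the original ROI order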


class FPNPredictor:
  def __init__(self):
    num_classes = 81
    representation_size = 1024
    self.cls_score = nn.Linear(representation_size, num_classes)
    num_bbox_reg_classes = num_classes
    self.bbox_pred = nn.Linear(representation_size, num_bbox_reg_classes * 4)

  def __call__(self, x):
    scores = self.cls_score(x)
    bbox_deltas = self.bbox_pred(x)
    return scores, bbox_deltas


class PostProcessor:
  # Not used in training
  def __init__(
    self,
    score_thresh=0.05,
    nms=0.5,
    detections_per_img=100,
    box_coder=None,
    cls_agnostic_bbox_reg=False
  ):
    self.score_thresh = score_thresh
    self.nms = nms
    self.detections_per_img = detections_per_img
    if box_coder is None:
      box_coder = BoxCoder(weights=(10., 10., 5., 5.))
    self.box_coder = box_coder
    self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg

  def __call__(self, x, boxes):
    class_logits, box_regression = x
    class_prob = Tensor.softmax(class_logits, -1)
    image_shapes = [box.size for box in boxes]
    boxes_per_image = [len(box) for box in boxes]
    concat_boxes = Tensor.cat(*[a.bbox for a in boxes], dim=0)

    if self.cls_agnostic_bbox_reg:
      box_regression = box_regression[:, -4:]
    proposals = self.box_coder.decode(
      box_regression.reshape(sum(boxes_per_image), -1), concat_boxes
    )
    if self.cls_agnostic_bbox_reg:
      proposals = proposals.repeat([1, class_prob.shape[1]])
    num_classes = class_prob.shape[1]
    proposals = proposals.unsqueeze(0)
    class_prob = class_prob.unsqueeze(0)
    results = []
    for prob, boxes_per_img, image_shape in zip(
      class_prob, proposals, image_shapes
    ):
      boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape)
      boxlist = boxlist.clip_to_image(remove_empty=False)
      boxlist = self.filter_results(boxlist, num_classes)
      results.append(boxlist)
    return results

  def prepare_boxlist(self, boxes, scores, image_shape):
    boxes = boxes.reshape(-1, 4)
    scores = scores.reshape(-1)
    boxlist = BoxList(boxes, image_shape, mode="xyxy")
    boxlist.add_field("scores", scores)
    return boxlist

  def filter_results(self, boxlist, num_classes):
    boxes = boxlist.bbox.reshape(-1, num_classes * 4)
    scores = boxlist.get_field("scores").reshape(-1, num_classes)

    device = scores.device
    result = []
    scores = scores.numpy()
    boxes = boxes.numpy()
    inds_all = scores > self.score_thresh
    for j in range(1, num_classes):
      inds = inds_all[:, j].nonzero()[0]
      # This needs to be done in numpy because it can create empty arrays
      scores_j = scores[inds, j]
      boxes_j = boxes[inds, j * 4: (j + 1) * 4]
      boxes_j = Tensor(boxes_j)
      scores_j = Tensor(scores_j)
      boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
      boxlist_for_class.add_field("scores", scores_j)
      if len(boxlist_for_class):
        boxlist_for_class = boxlist_nms(
          boxlist_for_class, self.nms
        )
        num_labels = len(boxlist_for_class)
        boxlist_for_class.add_field(
          "labels", Tensor.full((num_labels,), j, device=device)
        )
        result.append(boxlist_for_class)

    result = cat_boxlist(result)
    number_of_detections = len(result)

    if number_of_detections > self.detections_per_img > 0:
      cls_scores = result.get_field("scores")
      image_thresh, _ = topk(cls_scores, k=self.detections_per_img)
      image_thresh = image_thresh.numpy()[-1]
      keep = (cls_scores.numpy() >= image_thresh).nonzero()[0]
      result = result[keep]
    return result


class RoIBoxHead:
  def __init__(self, in_channels):
    self.feature_extractor = FPN2MLPFeatureExtractor(in_channels)
    self.predictor = FPNPredictor()
    self.post_processor = PostProcessor(
      score_thresh=0.05,
      nms=0.5,
      detections_per_img=100,
      box_coder=BoxCoder(weights=(10., 10., 5., 5.)),
      cls_agnostic_bbox_reg=False
    )

  def __call__(self, features, proposals, targets=None):
    x = self.feature_extractor(features, proposals)
    class_logits, box_regression = self.predictor(x)
    if not Tensor.training:
      result = self.post_processor((class_logits, box_regression), proposals)
      return x, result, {}


class MaskPostProcessor:
  # Not used in loss calculation
  def __call__(self, x, boxes):
    mask_prob = x.sigmoid().numpy()
    num_masks = x.shape[0]
    labels = [bbox.get_field("labels") for bbox in boxes]
    labels = Tensor.cat(*labels).numpy().astype(np.int32)
    index = np.arange(num_masks)
    mask_prob = mask_prob[index, labels][:, None]
    boxes_per_image, cumsum = [], 0
    for box in boxes:
      cumsum += len(box)
      boxes_per_image.append(cumsum)
    # using numpy here as Tensor.chunk doesn't have custom chunk sizes
    mask_prob = np.split(mask_prob, boxes_per_image, axis=0)
    results = []
    for prob, box in zip(mask_prob, boxes):
      bbox = BoxList(box.bbox, box.size, mode="xyxy")
      for field in box.fields():
        bbox.add_field(field, box.get_field(field))
      prob = Tensor(prob)
      bbox.add_field("mask", prob)
      results.append(bbox)

    return results
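
# (illustrative, given this config) each detection keeps only the mask channel of its
# predicted label: (num_masks, 81, 28, 28) logits -> (num_masks, 1, 28, 28) probabilities,
# then split back into one chunk per image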


class Mask:
  def __init__(self):
    self.feature_extractor = MaskRCNNFPNFeatureExtractor()
    self.predictor = MaskRCNNC4Predictor()
    self.post_processor = MaskPostProcessor()

  def __call__(self, features, proposals, targets=None):
    x = self.feature_extractor(features, proposals)
    if x:
      mask_logits = self.predictor(x)
      if not Tensor.training:
        result = self.post_processor(mask_logits, proposals)
        return x, result, {}
    return x, [], {}


class RoIHeads:
  def __init__(self, in_channels):
    self.box = RoIBoxHead(in_channels)
    self.mask = Mask()

  def __call__(self, features, proposals, targets=None):
    x, detections, _ = self.box(features, proposals, targets)
    x, detections, _ = self.mask(features, detections, targets)
    return x, detections, {}


class ImageList(object):
  def __init__(self, tensors, image_sizes):
    self.tensors = tensors
    self.image_sizes = image_sizes

  def to(self, *args, **kwargs):
    cast_tensor = self.tensors.to(*args, **kwargs)
    return ImageList(cast_tensor, self.image_sizes)


def to_image_list(tensors, size_divisible=32):
  # Preprocessing
  if isinstance(tensors, Tensor) and size_divisible > 0:
    tensors = [tensors]

  if isinstance(tensors, ImageList):
    return tensors
  elif isinstance(tensors, Tensor):
    # single tensor shape can be inferred
    assert tensors.ndim == 4
    image_sizes = [tensor.shape[-2:] for tensor in tensors]
    return ImageList(tensors, image_sizes)
  elif isinstance(tensors, (tuple, list)):
    max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
    if size_divisible > 0:
      stride = size_divisible
      max_size = list(max_size)
      max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
      max_size[2] = int(math.ceil(max_size[2] / stride) * stride)
      max_size = tuple(max_size)

    batch_shape = (len(tensors),) + max_size
    batched_imgs = np.zeros(batch_shape, dtype=_to_np_dtype(tensors[0].dtype))
    for img, pad_img in zip(tensors, batched_imgs):
      pad_img[: img.shape[0], : img.shape[1], : img.shape[2]] += img.numpy()

    batched_imgs = Tensor(batched_imgs)
    image_sizes = [im.shape[-2:] for im in tensors]

    return ImageList(batched_imgs, image_sizes)
  else:
    raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors)))
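
# (illustrative) a list of (C, H, W) images is zero-padded to a shared H/W rounded up to a
# multiple of size_divisible (32, the coarsest FPN stride), keeping the true sizes alongside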


class MaskRCNN:
  def __init__(self, backbone: ResNet):
    self.backbone = ResNetFPN(backbone, out_channels=256)
    self.rpn = RPN(self.backbone.out_channels)
    self.roi_heads = RoIHeads(self.backbone.out_channels)

  def load_from_pretrained(self):
    fn = Path('./') / "weights/maskrcnn.pt"
    fetch("https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", fn)

    state_dict = torch_load(fn)['model']
    loaded_keys = []
    for k, v in state_dict.items():
      if "module." in k:
        k = k.replace("module.", "")
      if "stem." in k:
        k = k.replace("stem.", "")
      if "fpn_inner" in k:
        block_index = int(re.search(r"fpn_inner(\d+)", k).group(1))
        k = re.sub(r"fpn_inner\d+", f"inner_blocks.{block_index - 1}", k)
      if "fpn_layer" in k:
        block_index = int(re.search(r"fpn_layer(\d+)", k).group(1))
        k = re.sub(r"fpn_layer\d+", f"layer_blocks.{block_index - 1}", k)
      loaded_keys.append(k)
      get_child(self, k).assign(v.numpy()).realize()
    return loaded_keys

  def __call__(self, images):
    images = to_image_list(images)
    features = self.backbone(images.tensors)
    proposals, _ = self.rpn(images, features)
    x, result, _ = self.roi_heads(features, proposals)
    return result


if __name__ == '__main__':
  resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
  model = MaskRCNN(backbone=resnet)
  model.load_from_pretrained()
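
# (illustrative) inference sketch: model(Tensor(np.zeros((1, 3, 800, 800), dtype=np.float32)))
# returns one BoxList per image carrying "scores", "labels", and "mask" fields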