diff --git a/examples/webgpu/yolov8/index.html b/examples/webgpu/yolov8/index.html
index 66e6a88780..dafa04afdd 100644
--- a/examples/webgpu/yolov8/index.html
+++ b/examples/webgpu/yolov8/index.html
@@ -176,10 +176,12 @@
offscreenContext.clearRect(0, 0, modelInputSize, modelInputSize);
offscreenContext.drawImage(video, offsetX, offsetY, targetWidth, targetHeight);
const boxes = await detectObjectsOnFrame(offscreenContext);
- drawBoxes(offscreenCanvas, boxes, targetWidth, targetHeight, offsetX, offsetY);
+ const validBoxes = [];
+ for (let i = 0; i < boxes.length; i += 6)
+ if (boxes[i + 4] > 0) validBoxes.push([boxes[i], boxes[i + 1], boxes[i + 2], boxes[i + 3], boxes[i + 5]]);
+ drawBoxes(offscreenCanvas, validBoxes, targetWidth, targetHeight, offsetX, offsetY);
requestAnimationFrame(processFrame);
}
-
requestAnimationFrame(processFrame);
function drawBoxes(offscreenCanvas, boxes, targetWidth, targetHeight, offsetX, offsetY) {
@@ -190,8 +192,8 @@
const scaleX = canvas.width / targetWidth;
const scaleY = canvas.height / targetHeight;
- boxes.forEach(([x1, y1, x2, y2, label]) => {
- const classIndex = yolo_classes.indexOf(label);
+ boxes.forEach(([x1, y1, x2, y2, classIndex]) => {
+ const label = yolo_classes[classIndex];
const color = classColors[classIndex];
ctx.strokeStyle = color;
ctx.fillStyle = color;
@@ -219,21 +221,13 @@
net = await yolov8.load(device, "./net.safetensors");
loadingContainer.style.display = "none";
}
- let start = performance.now();
- const [input,img_width,img_height] = await prepareInput(offscreenContext);
- console.log("Preprocess took: " + (performance.now() - start) + " ms");
- start = performance.now();
+ const input = await prepareInput(offscreenContext);
const output = await net(new Float32Array(input));
- console.log("Inference took: " + (performance.now() - start) + " ms");
- start = performance.now();
- let out = processOutput(output[0],img_width,img_height);
- console.log("Postprocess took: " + (performance.now() - start) + " ms");
- return out;
+ return output[0];
}
async function prepareInput(offscreenContext) {
return new Promise(resolve => {
- const [img_width,img_height] = [modelInputSize, modelInputSize]
const imgData = offscreenContext.getImageData(0,0,modelInputSize,modelInputSize);
const pixels = imgData.data;
const red = [], green = [], blue = [];
@@ -244,7 +238,7 @@
blue.push(pixels[index+2]/255.0);
}
const input = [...red, ...green, ...blue];
- resolve([input, img_width, img_height])
+ resolve(input)
})
}
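
The rewritten `prepareInput` above keeps the planar layout the network expects (all red values, then all green, then all blue) and only drops the redundant width/height return. A minimal Python sketch of the equivalent transform, assuming a numpy RGBA frame shaped like the one `getImageData` yields:

```python
import numpy as np

def prepare_input(frame: np.ndarray) -> np.ndarray:
    # frame: (H, W, 4) uint8 RGBA, as returned by getImageData
    rgb = frame[..., :3].astype(np.float32) / 255.0   # drop alpha, scale 0-255 to 0.0-1.0
    # HWC -> CHW, then flatten: all red values first, then green, then blue
    return rgb.transpose(2, 0, 1).ravel()
```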
@@ -257,57 +251,6 @@
});
};
- function processOutput(output, img_width, img_height) {
- let boxes = [];
- const numPredictions = Math.pow(modelInputSize/32, 2) * 21;
- for (let index=0;index<numPredictions;index++) {
- const [class_id,prob] = [...Array(80).keys()]
- .map(col => [col, output[numPredictions*(col+4)+index]])
- .reduce((accum, item) => item[1]>accum[1] ? item : accum,[0,0]);
-
- if (prob < 0.25) continue;
- const label = yolo_classes[class_id];
- const xc = output[index];
- const yc = output[numPredictions+index];
- const w = output[2*numPredictions+index];
- const h = output[3*numPredictions+index];
- const x1 = (xc-w/2)/modelInputSize*img_width;
- const y1 = (yc-h/2)/modelInputSize*img_height;
- const x2 = (xc+w/2)/modelInputSize*img_width;
- const y2 = (yc+h/2)/modelInputSize*img_height;
- boxes.push([x1,y1,x2,y2,label,prob]);
- }
-
- boxes = boxes.sort((box1,box2) => box2[5]-box1[5])
- const result = [];
- while (boxes.length>0) {
- result.push(boxes[0]);
- boxes = boxes.filter(box => iou(boxes[0],box)<0.7);
- }
- return result;
- }
-
- function iou(box1,box2) {
- return intersection(box1,box2)/union(box1,box2);
- }
-
- function union(box1,box2) {
- const [box1_x1,box1_y1,box1_x2,box1_y2] = box1;
- const [box2_x1,box2_y1,box2_x2,box2_y2] = box2;
- const box1_area = (box1_x2-box1_x1)*(box1_y2-box1_y1)
- const box2_area = (box2_x2-box2_x1)*(box2_y2-box2_y1)
- return box1_area + box2_area - intersection(box1,box2)
- }
-
- function intersection(box1,box2) {
- const [box1_x1,box1_y1,box1_x2,box1_y2] = box1;
- const [box2_x1,box2_y1,box2_x2,box2_y2] = box2;
- const x1 = Math.max(box1_x1,box2_x1);
- const y1 = Math.max(box1_y1,box2_y1);
- const x2 = Math.min(box1_x2,box2_x2);
- const y2 = Math.min(box1_y2,box2_y2);
- return (x2-x1)*(y2-y1)
- }
const yolo_classes = [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
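
With `processOutput` gone from the JS side, `detectObjectsOnFrame` now returns the model's postprocessed output directly: a flat array in which each detection occupies six consecutive floats, `[x1, y1, x2, y2, conf, class_id]`, with suppressed detections zeroed out. That is the contract the stride-6 loop in `processFrame` consumes; a hypothetical Python decoder of the same buffer (the function name is illustrative):

```python
def decode_flat_boxes(flat, stride=6):
    boxes = []
    for i in range(0, len(flat), stride):
        x1, y1, x2, y2, conf, class_id = flat[i:i + stride]
        if conf > 0:  # suppressed boxes were multiplied by zero in postprocess
            boxes.append((x1, y1, x2, y2, int(class_id)))
    return boxes
```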
diff --git a/examples/yolov8.py b/examples/yolov8.py
index 4fbf5ed4f2..782fa7c036 100644
--- a/examples/yolov8.py
+++ b/examples/yolov8.py
@@ -42,69 +42,8 @@ def preprocess(im, imgsz=640, model_stride=32, model_pt=True):
im = im / 255.0 # 0 - 255 to 0.0 - 1.0
return im
-# Post Processing functions
-def box_area(box):
- return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
-def box_iou(box1, box2):
- lt = np.maximum(box1[:, None, :2], box2[:, :2])
- rb = np.minimum(box1[:, None, 2:], box2[:, 2:])
- wh = np.clip(rb - lt, 0, None)
- inter = wh[:, :, 0] * wh[:, :, 1]
- area1 = box_area(box1)[:, None]
- area2 = box_area(box2)[None, :]
- iou = inter / (area1 + area2 - inter)
- return iou
-
-def compute_nms(boxes, scores, iou_threshold):
- order, keep = scores.argsort()[::-1], []
- while order.size > 0:
- i = order[0]
- keep.append(i)
- if order.size == 1:
- break
- iou = box_iou(boxes[i][None, :], boxes[order[1:]])
- inds = np.where(np.atleast_1d(iou.squeeze()) <= iou_threshold)[0]
- order = order[inds + 1]
- return np.array(keep)
-
-def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=False, max_det=300, nc=0, max_wh=7680):
- prediction = prediction[0] if isinstance(prediction, (list, tuple)) else prediction
- bs, nc = prediction.shape[0], nc or (prediction.shape[1] - 4)
- xc = np.amax(prediction[:, 4:4 + nc], axis=1) > conf_thres
- nm = prediction.shape[1] - nc - 4
- output = [np.zeros((0, 6 + nm))] * bs
-
- for xi, x in enumerate(prediction):
- x = x.swapaxes(0, -1)[xc[xi]]
- if not x.shape[0]: continue
- box, cls, mask = np.split(x, [4, 4 + nc], axis=1)
- conf, j = np.max(cls, axis=1, keepdims=True), np.argmax(cls, axis=1, keepdims=True)
- x = np.concatenate((xywh2xyxy(box), conf, j.astype(np.float32), mask), axis=1)
- x = x[conf.ravel() > conf_thres]
- if not x.shape[0]: continue
- x = x[np.argsort(-x[:, 4])]
- c = x[:, 5:6] * (0 if agnostic else max_wh)
- boxes, scores = x[:, :4] + c, x[:, 4]
- i = compute_nms(boxes, scores, iou_thres)[:max_det]
- output[xi] = x[i]
- return output
-
-def postprocess(preds, img, orig_imgs):
- print('copying to CPU now for post processing')
- #if you are on CPU, this causes an overflow runtime error. doesn't "seem" to make any difference in the predictions though.
- # TODO: make non_max_suppression in tinygrad - to make this faster
- preds = preds.numpy() if isinstance(preds, Tensor) else preds
- preds = non_max_suppression(prediction=preds, conf_thres=0.25, iou_thres=0.7, agnostic=False, max_det=300)
- all_preds = []
- for i, pred in enumerate(preds):
- orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
- if not isinstance(orig_imgs, Tensor):
- pred[:, :4] = scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
- all_preds.append(pred)
- return all_preds
-
-def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictions, class_labels, iou_threshold=0.5):
+def draw_bounding_boxes_and_save(orig_img_path, output_img_path, predictions, class_labels):
color_dict = {label: tuple((((i+1) * 50) % 256, ((i+1) * 100) % 256, ((i+1) * 150) % 256)) for i, label in enumerate(class_labels)}
font = cv2.FONT_HERSHEY_SIMPLEX
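
The numpy post-processing stack removed here (`box_iou`, `compute_nms`, `non_max_suppression`, `postprocess`) reappears further down in this diff as tensor ops inside the model itself, with the removed `xywh2xyxy` helper inlined as the four `xc ± w/2` / `yc ± h/2` lines. For reference, the conversion that helper performed:

```python
import numpy as np

def xywh2xyxy(x: np.ndarray) -> np.ndarray:
    # center-x, center-y, width, height -> x1, y1, x2, y2
    xy, wh = x[..., :2], x[..., 2:4]
    return np.concatenate((xy - wh / 2, xy + wh / 2), axis=-1)
```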
@@ -113,52 +52,32 @@ def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictio
brightness = (r * 299 + g * 587 + b * 114) / 1000
return brightness > 127
- for img_idx, (orig_img_path, output_img_path, predictions) in enumerate(zip(orig_img_paths, output_img_paths, all_predictions)):
- predictions = np.array(predictions)
- orig_img = cv2.imread(orig_img_path) if not isinstance(orig_img_path, np.ndarray) else cv2.imdecode(orig_img_path, 1)
- height, width, _ = orig_img.shape
- box_thickness = int((height + width) / 400)
- font_scale = (height + width) / 2500
+ orig_img = cv2.imread(orig_img_path) if not isinstance(orig_img_path, np.ndarray) else cv2.imdecode(orig_img_path, 1)
+ height, width, _ = orig_img.shape
+ box_thickness = int((height + width) / 400)
+ font_scale = (height + width) / 2500
+ object_count = defaultdict(int)
- grouped_preds = defaultdict(list)
- object_count = defaultdict(int)
+ for pred in predictions:
+ x1, y1, x2, y2, conf, class_id = pred
+ if conf == 0: continue
+ x1, y1, x2, y2, class_id = map(int, (x1, y1, x2, y2, class_id))
+ color = color_dict[class_labels[class_id]]
+ cv2.rectangle(orig_img, (x1, y1), (x2, y2), color, box_thickness)
+ label = f"{class_labels[class_id]} {conf:.2f}"
+ text_size, _ = cv2.getTextSize(label, font, font_scale, 1)
+ label_y, bg_y = (y1 - 4, y1 - text_size[1] - 4) if y1 - text_size[1] - 4 > 0 else (y1 + text_size[1], y1)
+ cv2.rectangle(orig_img, (x1, bg_y), (x1 + text_size[0], bg_y + text_size[1]), color, -1)
+ font_color = (0, 0, 0) if is_bright_color(color) else (255, 255, 255)
+ cv2.putText(orig_img, label, (x1, label_y), font, font_scale, font_color, 1, cv2.LINE_AA)
+ object_count[class_labels[class_id]] += 1
- for pred_np in predictions:
- grouped_preds[int(pred_np[-1])].append(pred_np)
+ print("Objects detected:")
+ for obj, count in object_count.items():
+ print(f"- {obj}: {count}")
- def draw_box_and_label(pred, color):
- x1, y1, x2, y2, conf, _ = pred
- x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))
- cv2.rectangle(orig_img, (x1, y1), (x2, y2), color, box_thickness)
- label = f"{class_labels[class_id]} {conf:.2f}"
- text_size, _ = cv2.getTextSize(label, font, font_scale, 1)
- label_y, bg_y = (y1 - 4, y1 - text_size[1] - 4) if y1 - text_size[1] - 4 > 0 else (y1 + text_size[1], y1)
- cv2.rectangle(orig_img, (x1, bg_y), (x1 + text_size[0], bg_y + text_size[1]), color, -1)
- font_color = (0, 0, 0) if is_bright_color(color) else (255, 255, 255)
- cv2.putText(orig_img, label, (x1, label_y), font, font_scale, font_color, 1, cv2.LINE_AA)
-
- for class_id, pred_list in grouped_preds.items():
- pred_list = np.array(pred_list)
- while len(pred_list) > 0:
- max_conf_idx = np.argmax(pred_list[:, 4])
- max_conf_pred = pred_list[max_conf_idx]
- pred_list = np.delete(pred_list, max_conf_idx, axis=0)
- color = color_dict[class_labels[class_id]]
- draw_box_and_label(max_conf_pred, color)
- object_count[class_labels[class_id]] += 1
- iou_scores = box_iou(np.array([max_conf_pred[:4]]), pred_list[:, :4])
- low_iou_indices = np.where(iou_scores[0] < iou_threshold)[0]
- pred_list = pred_list[low_iou_indices]
- for low_conf_pred in pred_list:
- draw_box_and_label(low_conf_pred, color)
-
- print(f"Image {img_idx + 1}:")
- print("Objects detected:")
- for obj, count in object_count.items():
- print(f"- {obj}: {count}")
-
- cv2.imwrite(output_img_path, orig_img)
- print(f'saved detections at {output_img_path}')
+ cv2.imwrite(output_img_path, orig_img)
+ print(f'saved detections at {output_img_path}')
# utility functions for forward pass.
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
@@ -202,34 +121,26 @@ def clip_boxes(boxes, shape):
boxes[..., [1, 3]] = np.clip(boxes[..., [1, 3]], 0, shape[0]) # y1, y2
return boxes
-def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
+def scale_boxes(img1_shape, predictions, img0_shape, ratio_pad=None):
gain = ratio_pad if ratio_pad else min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
pad = ((img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2)
- boxes_np = boxes.numpy() if isinstance(boxes, Tensor) else boxes
- boxes_np[..., [0, 2]] -= pad[0]
- boxes_np[..., [1, 3]] -= pad[1]
- boxes_np[..., :4] /= gain
- boxes_np = clip_boxes(boxes_np, img0_shape)
- return boxes_np
-
-def xywh2xyxy(x):
- xy = x[..., :2] # center x, y
- wh = x[..., 2:4] # width, height
- xy1 = xy - wh / 2 # top left x, y
- xy2 = xy + wh / 2 # bottom right x, y
- result = np.concatenate((xy1, xy2), axis=-1)
- return Tensor(result) if isinstance(x, Tensor) else result
+ for pred in predictions:
+ boxes_np = pred[:4].numpy() if isinstance(pred[:4], Tensor) else pred[:4]
+ boxes_np[..., [0, 2]] -= pad[0]
+ boxes_np[..., [1, 3]] -= pad[1]
+ boxes_np[..., :4] /= gain
+ boxes_np = clip_boxes(boxes_np, img0_shape)
+ pred[:4] = boxes_np
+ return predictions
def get_variant_multiples(variant):
return {'n':(0.33, 0.25, 2.0), 's':(0.33, 0.50, 2.0), 'm':(0.67, 0.75, 1.5), 'l':(1.0, 1.0, 1.0), 'x':(1, 1.25, 1.0) }.get(variant, None)
def label_predictions(all_predictions):
class_index_count = defaultdict(int)
- for predictions in all_predictions:
- predictions = np.array(predictions)
- for pred_np in predictions:
- class_id = int(pred_np[-1])
- class_index_count[class_id] += 1
+ for pred in all_predictions:
+ class_id = int(pred[-1])
+ if pred[-2] != 0: class_index_count[class_id] += 1
return dict(class_index_count)
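
For reference, the letterbox math `scale_boxes` undoes, traced with illustrative numbers for a 640x640 model input built from a 480x640 (HxW) source image:

```python
# img1_shape = (640, 640) model input, img0_shape = (480, 640, 3) original image
gain = min(640 / 480, 640 / 640)                        # = 1.0, letterbox scale factor
pad = ((640 - 640 * gain) / 2, (640 - 480 * gain) / 2)  # = (0.0, 80.0) x/y padding
x1, y1 = 100.0, 120.0                                   # box corner in model-input space
orig_x1, orig_y1 = (x1 - pad[0]) / gain, (y1 - pad[1]) / gain  # = (100.0, 40.0) in image space
```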
@@ -380,7 +291,9 @@ class YOLOv8:
def __call__(self, x):
x = self.net(x)
x = self.fpn(*x)
- return self.head(x)
+ x = self.head(x)
+    # TODO: postprocess needs to live in the model so it can be compiled to WebGPU
+ return postprocess(x)
def return_all_trainable_modules(self):
backbone_modules = [*range(10)]
@@ -403,6 +316,39 @@ def convert_f16_safetensor_to_f32(input_file: Path, output_file: Path):
f.write(new_metadata_bytes)
float32_values.tofile(f)
+def compute_iou_matrix(boxes):
+ x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
+ areas = (x2 - x1) * (y2 - y1)
+ x1 = Tensor.maximum(x1[:, None], x1[None, :])
+ y1 = Tensor.maximum(y1[:, None], y1[None, :])
+ x2 = Tensor.minimum(x2[:, None], x2[None, :])
+ y2 = Tensor.minimum(y2[:, None], y2[None, :])
+ w = Tensor.maximum(Tensor(0), x2 - x1)
+ h = Tensor.maximum(Tensor(0), y2 - y1)
+ intersection = w * h
+ union = areas[:, None] + areas[None, :] - intersection
+ return intersection / union
+
+def postprocess(output, max_det=300, conf_threshold=0.25, iou_threshold=0.45):
+ xc, yc, w, h, class_scores = output[0][0], output[0][1], output[0][2], output[0][3], output[0][4:]
+ class_ids = Tensor.argmax(class_scores, axis=0)
+ probs = Tensor.max(class_scores, axis=0)
+ probs = Tensor.where(probs >= conf_threshold, probs, 0)
+ x1 = xc - w / 2
+ y1 = yc - h / 2
+ x2 = xc + w / 2
+ y2 = yc + h / 2
+ boxes = Tensor.stack(x1, y1, x2, y2, probs, class_ids, dim=1)
+ order = Tensor.topk(probs, max_det)[1]
+ boxes = boxes[order]
+ iou = compute_iou_matrix(boxes[:, :4])
+ iou = Tensor.triu(iou, diagonal=1)
+ same_class_mask = boxes[:, -1][:, None] == boxes[:, -1][None, :]
+ high_iou_mask = (iou > iou_threshold) & same_class_mask
+ no_overlap_mask = high_iou_mask.sum(axis=0) == 0
+ boxes = boxes * no_overlap_mask.unsqueeze(-1)
+ return boxes
+
def get_weights_location(yolo_variant: str) -> Path:
weights_location = Path(__file__).parents[1] / "weights" / f'yolov8{yolo_variant}.safetensors'
fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors', weights_location)
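
The design choice behind the new `postprocess` is that NMS never changes tensor shapes: where the removed numpy `compute_nms` filtered boxes dynamically, suppression here is a mask multiply, so the output stays `(max_det, 6)` and the whole graph can compile to WebGPU. A minimal numpy sketch of the same fixed-shape idea, assuming boxes are already sorted by descending confidence:

```python
import numpy as np

def masked_nms(boxes: np.ndarray, iou: np.ndarray, iou_threshold: float = 0.45) -> np.ndarray:
    # boxes: (N, 6) rows of [x1, y1, x2, y2, conf, class_id], sorted by conf descending
    # iou:   (N, N) pairwise IoU matrix
    upper = np.triu(iou, k=1)  # only earlier (higher-confidence) boxes may suppress later ones
    same_class = boxes[:, 5][:, None] == boxes[:, 5][None, :]
    suppressed = ((upper > iou_threshold) & same_class).sum(axis=0) > 0
    return boxes * ~suppressed[:, None]  # zero out suppressed rows; shape stays (N, 6)
```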
@@ -428,14 +374,13 @@ if __name__ == '__main__':
output_folder_path = Path('./outputs_yolov8')
output_folder_path.mkdir(parents=True, exist_ok=True)
#absolute image path or URL
- image_location = [np.frombuffer(fetch(img_path).read_bytes(), np.uint8)]
- image = [cv2.imdecode(image_location[0], 1)]
- out_paths = [(output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix or '.png'}").as_posix()]
+ image_location = np.frombuffer(fetch(img_path).read_bytes(), np.uint8)
+ image = [cv2.imdecode(image_location, 1)]
+ out_path = (output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix or '.png'}").as_posix()
if not isinstance(image[0], np.ndarray):
print('Error in image loading. Check your image file.')
sys.exit(1)
pre_processed_image = preprocess(image)
-
  # Different YOLOv8 variants use different w, r, and d multiples. For a list, refer to the scales section of https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/models/v8/yolov8.yaml
depth, width, ratio = get_variant_multiples(yolo_variant)
yolo_infer = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
@@ -443,15 +388,13 @@ if __name__ == '__main__':
load_state_dict(yolo_infer, state_dict)
st = time.time()
- predictions = yolo_infer(pre_processed_image)
+ predictions = yolo_infer(pre_processed_image).numpy()
+
print(f'did inference in {int(round(((time.time() - st) * 1000)))}ms')
-
- post_predictions = postprocess(preds=predictions, img=pre_processed_image, orig_imgs=image)
-
#v8 and v3 have same 80 class names for Object Detection
class_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_text().split("\n")
-
- draw_bounding_boxes_and_save(orig_img_paths=image_location, output_img_paths=out_paths, all_predictions=post_predictions, class_labels=class_labels)
+ predictions = scale_boxes(pre_processed_image.shape[2:], predictions, image[0].shape)
+ draw_bounding_boxes_and_save(orig_img_path=image_location, output_img_path=out_path, predictions=predictions, class_labels=class_labels)
# TODO for later:
# 1. Fix SPPF minor difference due to maxpool
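
Taken together, the `__main__` changes collapse inference into one model call plus a box rescale. A condensed sketch of the new flow (weight loading elided; `input.png` is an illustrative path):

```python
import cv2
from examples.yolov8 import YOLOv8, get_variant_multiples, preprocess, scale_boxes

img = cv2.imread("input.png")
depth, width, ratio = get_variant_multiples('n')
model = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
# ... load weights with safe_load / load_state_dict as in __main__ ...
x = preprocess([img])
preds = model(x).numpy()       # (max_det, 6) rows: x1, y1, x2, y2, conf, class_id
preds = scale_boxes(x.shape[2:], preds, img.shape)  # map back to original image coords
```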
diff --git a/test/external/external_test_yolov8.py b/test/external/external_test_yolov8.py
index a215aa0675..5506fae7f0 100644
--- a/test/external/external_test_yolov8.py
+++ b/test/external/external_test_yolov8.py
@@ -1,5 +1,6 @@
import numpy as np
-from examples.yolov8 import YOLOv8, get_variant_multiples, preprocess, postprocess, label_predictions
+from examples.yolov8 import YOLOv8, get_variant_multiples, preprocess, label_predictions, postprocess
+from tinygrad import Tensor
import unittest
import io, cv2
import onnxruntime as ort
@@ -28,9 +29,8 @@ class TestYOLOv8(unittest.TestCase):
img = cv2.imdecode(np.frombuffer(fetch(test_image_urls[i]).read_bytes(), np.uint8), 1)
test_image = preprocess([img])
predictions = TinyYolov8(test_image)
- post_predictions = postprocess(preds=predictions, img=test_image, orig_imgs=[img])
- labels = label_predictions(post_predictions)
- assert labels == {5: 1, 0: 4, 11: 1} if i == 0 else labels == {0: 13, 29: 1, 32: 1}
+ labels = label_predictions(predictions.numpy())
+ assert labels == {5: 1, 0: 4, 11: 1} if i == 0 else labels == {0: 12, 29: 1, 32: 1}
def test_forward_pass_torch_onnx(self):
variant = 'n'
@@ -58,12 +58,16 @@ class TestYOLOv8(unittest.TestCase):
onnx_output_name = onnx_session.get_outputs()[0].name
onnx_output = onnx_session.run([onnx_output_name], {onnx_input_name: input_image.numpy()})
- tiny_output = TinyYolov8(input_image)
+ tiny_output = TinyYolov8(input_image).numpy()
+ onnx_output = postprocess(Tensor(onnx_output[0])).numpy()
+    # invalid boxes are zeroed out by postprocess, so drop them before comparing
+ onnx_output = onnx_output[onnx_output[:, 4] != 0]
+ tiny_output = tiny_output[tiny_output[:, 4] != 0]
# currently rtol is 0.025 because there is a 1-2% difference in our predictions
# because of the zero padding in SPPF module (line 280) maxpooling layers rather than the -infinity in torch.
# This difference does not make a difference "visually".
- np.testing.assert_allclose(onnx_output[0], tiny_output.numpy(), atol=5e-4, rtol=0.025)
+ np.testing.assert_allclose(onnx_output, tiny_output, atol=5e-4, rtol=0.025)
if __name__ == '__main__':
unittest.main()