YoloV8 on WebGPU (#8007)

Port YoloV8 to WebGPU
This commit is contained in:
Ahmed Harmouche
2024-12-03 15:10:41 +01:00
committed by GitHub
parent 09eac42fd6
commit 8818046940
4 changed files with 82 additions and 49 deletions

View File

@@ -0,0 +1,19 @@
from pathlib import Path
from examples.yolov8 import YOLOv8, get_weights_location
from tinygrad.tensor import Tensor
from tinygrad.nn.state import safe_save
from extra.export_model import export_model
from tinygrad.device import Device
from tinygrad.nn.state import safe_load, load_state_dict
if __name__ == "__main__":
Device.DEFAULT = "WEBGPU"
yolo_variant = 'n'
yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
state_dict = safe_load(get_weights_location(yolo_variant))
load_state_dict(yolo_infer, state_dict)
prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,640,640))
dirname = Path(__file__).parent
safe_save(state, (dirname / "net.safetensors").as_posix())
with open(dirname / f"net.js", "w") as text_file:
text_file.write(prg)

View File

@@ -0,0 +1,232 @@
<!-- Pre and post-processing functions from: https://github.com/AndreyGermanov/yolov8_onnx_javascript -->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>YOLOv8 tinygrad WebGPU</title>
<script src="./net.js"></script>
<style>
body {
text-align: center;
font-family: Arial, sans-serif;
margin: 0;
padding: 0;
overflow: hidden;
}
.video-container {
position: relative;
width: 100%;
margin: 0 auto;
}
#video, #canvas {
position: absolute;
top: 0;
left: 0;
width: 100%;
height: auto;
}
#canvas {
background: transparent;
}
h1 {
margin-top: 20px;
}
</style>
</head>
<body>
<h1>YOLOv8 tinygrad WebGPU</h1>
<div class="video-container">
<video id="video" muted autoplay playsinline></video>
<canvas id="canvas"></canvas>
</div>
<script>
let net = null;
const video = document.getElementById('video');
const canvas = document.getElementById('canvas');
const context = canvas.getContext('2d');
const offscreenCanvas = document.createElement('canvas');
offscreenCanvas.width = 640;
offscreenCanvas.height = 640;
const offscreenContext = offscreenCanvas.getContext('2d');
if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
navigator.mediaDevices.getUserMedia({ audio: false, video: { facingMode: { ideal: "environment" }}}).then(function (stream) {
video.srcObject = stream;
video.onloadedmetadata = function() {
canvas.width = video.clientWidth;
canvas.height = video.clientHeight;
}
});
}
async function processFrame() {
if (video.videoWidth == 0 || video.videoHeight == 0) {
requestAnimationFrame(processFrame);
return;
}
const videoAspectRatio = video.videoWidth / video.videoHeight;
let targetWidth, targetHeight;
if (videoAspectRatio > 1) {
targetWidth = 640;
targetHeight = 640 / videoAspectRatio;
} else {
targetHeight = 640;
targetWidth = 640 * videoAspectRatio;
}
const offsetX = (640 - targetWidth) / 2;
const offsetY = (640 - targetHeight) / 2;
offscreenContext.clearRect(0, 0, 640, 640);
offscreenContext.drawImage(video, offsetX, offsetY, targetWidth, targetHeight);
const boxes = await detectObjectsOnFrame(offscreenContext);
drawBoxes(offscreenCanvas, boxes, targetWidth, targetHeight, offsetX, offsetY);
requestAnimationFrame(processFrame);
}
requestAnimationFrame(processFrame);
function drawBoxes(offscreenCanvas, boxes, targetWidth, targetHeight, offsetX, offsetY) {
const ctx = document.querySelector("canvas").getContext("2d");
ctx.clearRect(0, 0, canvas.width, canvas.height);
ctx.lineWidth = 3;
ctx.font = "30px serif";
const scaleX = canvas.width / targetWidth;
const scaleY = canvas.height / targetHeight;
boxes.forEach(([x1, y1, x2, y2, label]) => {
const classIndex = yolo_classes.indexOf(label);
const color = classColors[classIndex];
ctx.strokeStyle = color;
ctx.fillStyle = color;
const adjustedX1 = (x1 - offsetX) * scaleX;
const adjustedY1 = (y1 - offsetY) * scaleY;
const adjustedX2 = (x2 - offsetX) * scaleX;
const adjustedY2 = (y2 - offsetY) * scaleY;
const boxWidth = adjustedX2 - adjustedX1;
const boxHeight = adjustedY2 - adjustedY1;
ctx.strokeRect(adjustedX1, adjustedY1, boxWidth, boxHeight);
const textWidth = ctx.measureText(label).width;
ctx.fillRect(adjustedX1, adjustedY1 - 25, textWidth + 10, 25);
ctx.fillStyle = "#FFFFFF";
ctx.fillText(label, adjustedX1 + 5, adjustedY1 - 7);
});
}
async function detectObjectsOnFrame(offscreenContext) {
if (!net) net = await loadNet(await getDevice());
let start = performance.now();
const [input,img_width,img_height] = await prepareInput(offscreenContext);
console.log("Preprocess took: " + (performance.now() - start) + " ms");
start = performance.now();
const output = await net(new Float32Array(input));
console.log("Inference took: " + (performance.now() - start) + " ms");
start = performance.now();
let out = processOutput(output[0],img_width,img_height);
console.log("Postprocess took: " + (performance.now() - start) + " ms");
return out;
}
async function prepareInput(offscreenContext) {
return new Promise(resolve => {
const [img_width,img_height] = [640, 640]
const imgData = offscreenContext.getImageData(0,0,640,640);
const pixels = imgData.data;
const red = [], green = [], blue = [];
for (let index=0; index<pixels.length; index+=4) {
red.push(pixels[index]/255.0);
green.push(pixels[index+1]/255.0);
blue.push(pixels[index+2]/255.0);
}
const input = [...red, ...green, ...blue];
resolve([input, img_width, img_height])
})
}
const getDevice = async () => {
if (!navigator.gpu) error("WebGPU not supported.");
const adapter = await navigator.gpu.requestAdapter();
return await adapter.requestDevice();
};
function processOutput(output, img_width, img_height) {
let boxes = [];
for (let index=0;index<8400;index++) {
const [class_id,prob] = [...Array(80).keys()]
.map(col => [col, output[8400*(col+4)+index]])
.reduce((accum, item) => item[1]>accum[1] ? item : accum,[0,0]);
if (prob < 0.25) continue;
const label = yolo_classes[class_id];
const xc = output[index];
const yc = output[8400+index];
const w = output[2*8400+index];
const h = output[3*8400+index];
const x1 = (xc-w/2)/640*img_width;
const y1 = (yc-h/2)/640*img_height;
const x2 = (xc+w/2)/640*img_width;
const y2 = (yc+h/2)/640*img_height;
boxes.push([x1,y1,x2,y2,label,prob]);
}
boxes = boxes.sort((box1,box2) => box2[5]-box1[5])
const result = [];
while (boxes.length>0) {
result.push(boxes[0]);
boxes = boxes.filter(box => iou(boxes[0],box)<0.7);
}
return result;
}
function iou(box1,box2) {
return intersection(box1,box2)/union(box1,box2);
}
function union(box1,box2) {
const [box1_x1,box1_y1,box1_x2,box1_y2] = box1;
const [box2_x1,box2_y1,box2_x2,box2_y2] = box2;
const box1_area = (box1_x2-box1_x1)*(box1_y2-box1_y1)
const box2_area = (box2_x2-box2_x1)*(box2_y2-box2_y1)
return box1_area + box2_area - intersection(box1,box2)
}
function intersection(box1,box2) {
const [box1_x1,box1_y1,box1_x2,box1_y2] = box1;
const [box2_x1,box2_y1,box2_x2,box2_y2] = box2;
const x1 = Math.max(box1_x1,box2_x1);
const y1 = Math.max(box1_y1,box2_y1);
const x2 = Math.min(box1_x2,box2_x2);
const y2 = Math.min(box1_y2,box2_y2);
return (x2-x1)*(y2-y1)
}
const yolo_classes = [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase',
'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
];
function generateColors(numColors) {
const colors = [];
for (let i = 0; i < 360; i += 360 / numColors) {
colors.push(`hsl(${i}, 100%, 50%)`);
}
return colors;
}
const classColors = generateColors(yolo_classes.length);
</script>
</body>
</html>