Fast YoloV8 on WebGPU (#8036)

* Fast yolov8 with downscaled input

* Faster + FPS meter

* Add loader while model is downloading/compiling

* Title touchup
Author: Ahmed Harmouche
Date: 2024-12-04 15:23:09 +01:00
Committed by: GitHub
Parent: b116e1511d
Commit: c9e7701417
2 changed files with 115 additions and 26 deletions

File 1: model export script (Python)

@@ -12,7 +12,7 @@ if __name__ == "__main__":
yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
state_dict = safe_load(get_weights_location(yolo_variant))
load_state_dict(yolo_infer, state_dict)
prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,640,640))
prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,256,256))
dirname = Path(__file__).parent
safe_save(state, (dirname / "net.safetensors").as_posix())
with open(dirname / f"net.js", "w") as text_file:
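
This one-line change is what makes the demo fast: export_model specializes the generated WebGPU kernels for the shape of the example tensor it is traced with, so exporting at 1x3x256x256 instead of 1x3x640x640 cuts the per-frame pixel count by (640/256)^2 = 6.25x, trading some accuracy on small objects for frame rate.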

File 2: browser demo page (HTML/JavaScript)

@@ -1,12 +1,12 @@
<!-- Pre and post-processing functions from: https://github.com/AndreyGermanov/yolov8_onnx_javascript -->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>YOLOv8 tinygrad WebGPU</title>
<script src="./net.js"></script>
<style>
body {
text-align: center;
font-family: Arial, sans-serif;
margin: 0;
@@ -17,7 +17,11 @@
.video-container {
position: relative;
width: 100%;
height: 100vh;
margin: 0 auto;
display: flex;
align-items: center;
justify-content: center;
}
#video, #canvas {
@@ -28,33 +32,96 @@
height: auto;
}
.loader {
width: 48px;
height: 48px;
border: 5px solid #FFF;
border-bottom-color: transparent;
border-radius: 50%;
display: inline-block;
box-sizing: border-box;
animation: rotation 1s linear infinite;
}
@keyframes rotation {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}
#canvas {
background: transparent;
}
#fps-meter {
position: absolute;
top: 20px;
right: 20px;
background-color: rgba(0, 0, 0, 0.7);
color: white;
padding: 10px;
font-size: 18px;
border-radius: 5px;
z-index: 10;
}
h1 {
margin-top: 20px;
}
.loading-container {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
background-color: rgba(0, 0, 0, 0.6);
z-index: 10;
}
.loading-text {
font-size: 24px;
color: white;
margin-bottom: 20px;
}
</style>
</head>
<body>
<h1>YOLOv8 tinygrad WebGPU</h1>
<h2>YOLOv8 tinygrad WebGPU</h2>
<div class="video-container">
<video id="video" muted autoplay playsinline></video>
<canvas id="canvas"></canvas>
<div id="fps-meter"></div>
<div id="div-loading" class="loading-container">
<p class="loading-text">Loading model</p>
<span class="loader"></span>
</div>
</div>
<script>
let net = null;
const modelInputSize = 256
let lastCalledTime;
let fps = 0, accumFps = 0, frameCounter = 0;
const video = document.getElementById('video');
const canvas = document.getElementById('canvas');
const context = canvas.getContext('2d');
const offscreenCanvas = document.createElement('canvas');
offscreenCanvas.width = 640;
offscreenCanvas.height = 640;
const fpsMeter = document.getElementById('fps-meter');
const loadingContainer = document.getElementById('div-loading');
offscreenCanvas.width = modelInputSize;
offscreenCanvas.height = modelInputSize;
const offscreenContext = offscreenCanvas.getContext('2d');
if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
navigator.mediaDevices.getUserMedia({ audio: false, video: { facingMode: { ideal: "environment" }}}).then(function (stream) {
video.srcObject = stream;
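
The facingMode: { ideal: "environment" } constraint prefers the rear camera on phones; because it is ideal rather than exact, getUserMedia still resolves with whatever camera exists on devices without one, so the page also works on laptops. A hard-failure variant, for comparison:

    // `exact` rejects outright when no rear camera is present:
    navigator.mediaDevices.getUserMedia({ video: { facingMode: { exact: "environment" } } })
      .catch(() => console.error("no rear camera"));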
@@ -70,20 +137,38 @@
requestAnimationFrame(processFrame);
return;
}
if (!lastCalledTime) {
lastCalledTime = performance.now();
fps = 0;
} else {
const now = performance.now();
const delta = (now - lastCalledTime)/1000.0; // seconds since the last frame
lastCalledTime = now;
accumFps += 1/delta;
if (frameCounter++ >= 30) {
fps = accumFps/frameCounter;
frameCounter = 0;
accumFps = 0;
fpsMeter.innerText = `FPS: ${fps.toFixed(1)}`
}
}
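// The meter above is a 30-frame moving average: instantaneous 1/delta values
// jitter from frame to frame, so they are accumulated in accumFps and the
// readout only refreshes once frameCounter reaches 30.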
const videoAspectRatio = video.videoWidth / video.videoHeight;
let targetWidth, targetHeight;
if (videoAspectRatio > 1) {
targetWidth = 640;
targetHeight = 640 / videoAspectRatio;
targetWidth = modelInputSize;
targetHeight = modelInputSize / videoAspectRatio;
} else {
targetHeight = 640;
targetWidth = 640 * videoAspectRatio;
targetHeight = modelInputSize;
targetWidth = modelInputSize * videoAspectRatio;
}
const offsetX = (640 - targetWidth) / 2;
const offsetY = (640 - targetHeight) / 2;
offscreenContext.clearRect(0, 0, 640, 640);
const offsetX = (modelInputSize - targetWidth) / 2;
const offsetY = (modelInputSize - targetHeight) / 2;
offscreenContext.clearRect(0, 0, modelInputSize, modelInputSize);
offscreenContext.drawImage(video, offsetX, offsetY, targetWidth, targetHeight);
const boxes = await detectObjectsOnFrame(offscreenContext);
drawBoxes(offscreenCanvas, boxes, targetWidth, targetHeight, offsetX, offsetY);
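
This block letterboxes the camera frame into the square model input: the longer side of the video is scaled to modelInputSize and the image is centered on the other axis. For a 1280x720 camera, videoAspectRatio is about 1.78, so targetWidth = 256, targetHeight is about 144, and offsetY = (256 - 144) / 2 = 56 blank pixels above and below; the same offsets are passed to drawBoxes so the predicted boxes line up with the visible region.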
@@ -120,7 +205,10 @@
}
async function detectObjectsOnFrame(offscreenContext) {
if (!net) net = await loadNet(await getDevice());
if (!net) {
net = await loadNet(await getDevice());
loadingContainer.style.display = "none";
}
let start = performance.now();
const [input,img_width,img_height] = await prepareInput(offscreenContext);
console.log("Preprocess took: " + (performance.now() - start) + " ms");
@@ -135,8 +223,8 @@
async function prepareInput(offscreenContext) {
return new Promise(resolve => {
const [img_width,img_height] = [640, 640]
const imgData = offscreenContext.getImageData(0,0,640,640);
const [img_width,img_height] = [modelInputSize, modelInputSize]
const imgData = offscreenContext.getImageData(0,0,modelInputSize,modelInputSize);
const pixels = imgData.data;
const red = [], green = [], blue = [];
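
getImageData returns interleaved RGBA bytes, while the network expects planar CHW floats, hence the three separate channel arrays. A minimal sketch of the rest of the conversion, assuming the usual [0, 1] normalization from the referenced yolov8_onnx_javascript pre-processing:

    for (let i = 0; i < pixels.length; i += 4) {
      // Interleaved RGBA -> three planes; alpha is dropped.
      red.push(pixels[i] / 255.0);
      green.push(pixels[i + 1] / 255.0);
      blue.push(pixels[i + 2] / 255.0);
    }
    resolve([new Float32Array([...red, ...green, ...blue]), img_width, img_height]);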
@@ -158,21 +246,22 @@
function processOutput(output, img_width, img_height) {
let boxes = [];
for (let index=0;index<8400;index++) {
const numPredictions = Math.pow(modelInputSize/32, 2) * 21;
for (let index=0;index<numPredictions;index++) {
const [class_id,prob] = [...Array(80).keys()]
.map(col => [col, output[8400*(col+4)+index]])
.map(col => [col, output[numPredictions*(col+4)+index]])
.reduce((accum, item) => item[1]>accum[1] ? item : accum,[0,0]);
if (prob < 0.25) continue;
const label = yolo_classes[class_id];
const xc = output[index];
const yc = output[8400+index];
const w = output[2*8400+index];
const h = output[3*8400+index];
const x1 = (xc-w/2)/640*img_width;
const y1 = (yc-h/2)/640*img_height;
const x2 = (xc+w/2)/640*img_width;
const y2 = (yc+h/2)/640*img_height;
const yc = output[numPredictions+index];
const w = output[2*numPredictions+index];
const h = output[3*numPredictions+index];
const x1 = (xc-w/2)/modelInputSize*img_width;
const y1 = (yc-h/2)/modelInputSize*img_height;
const x2 = (xc+w/2)/modelInputSize*img_width;
const y2 = (yc+h/2)/modelInputSize*img_height;
boxes.push([x1,y1,x2,y2,label,prob]);
}
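
The numPredictions formula generalizes the old hardcoded 8400. YOLOv8 emits one prediction per grid cell on three feature maps with strides 8, 16 and 32, so for input side s the count is

    (s/8)^2 + (s/16)^2 + (s/32)^2 = (16 + 4 + 1) * (s/32)^2 = 21 * (s/32)^2

which is 21 * 400 = 8400 at s = 640 and 21 * 64 = 1344 at s = 256. The output tensor is stored plane by plane (xc, yc, w, h, then the 80 class scores), which is why the score of class col for box index lives at output[(col + 4) * numPredictions + index]. A quick sanity check in the page's own language:

    // Per-stride form agrees with the closed form used above:
    console.assert((256/8)**2 + (256/16)**2 + (256/32)**2 === Math.pow(256/32, 2) * 21); // 1344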