Fast YoloV8 on WebGPU (#8036)
* Fast yolov8 with downscaled input
* Faster + FPS meter
* Add loader while model is downloading/compiling
* Title touchup
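Why the downscale helps: YOLOv8 emits one candidate box per cell of its stride-8, stride-16 and stride-32 grids, so an SxS input yields (S/8)^2 + (S/16)^2 + (S/32)^2 = 21 * (S/32)^2 candidates, and both the per-frame compute and the post-processing loop shrink roughly with the square of the input size. A minimal sketch of that count (the helper below is illustrative, not part of the diff, but it mirrors the numPredictions formula introduced further down):

// Candidate boxes for an SxS YOLOv8 input, assuming the standard strides of 8, 16 and 32.
const candidates = (s) => Math.pow(s / 32, 2) * 21;
console.log(candidates(640));  // 8400 candidates at the old 640x640 input
console.log(candidates(256));  // 1344 candidates at the new 256x256 input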
@@ -12,7 +12,7 @@ if __name__ == "__main__":
yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
state_dict = safe_load(get_weights_location(yolo_variant))
load_state_dict(yolo_infer, state_dict)
prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,640,640))
prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,256,256))
dirname = Path(__file__).parent
safe_save(state, (dirname / "net.safetensors").as_posix())
with open(dirname / f"net.js", "w") as text_file:
@@ -1,12 +1,12 @@
<!-- Pre and post-processing functions from: https://github.com/AndreyGermanov/yolov8_onnx_javascript -->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>YOLOv8 tinygrad WebGPU</title>
<script src="./net.js"></script>
<style>
body {
body {
text-align: center;
font-family: Arial, sans-serif;
margin: 0;
@@ -17,7 +17,11 @@
.video-container {
position: relative;
width: 100%;
height: 100vh;
margin: 0 auto;
display: flex;
align-items: center;
justify-content: center;
}

#video, #canvas {
@@ -28,33 +32,96 @@
height: auto;
}

.loader {
width: 48px;
height: 48px;
border: 5px solid #FFF;
border-bottom-color: transparent;
border-radius: 50%;
display: inline-block;
box-sizing: border-box;
animation: rotation 1s linear infinite;
}

@keyframes rotation {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}

#canvas {
background: transparent;
}

#fps-meter {
position: absolute;
top: 20px;
right: 20px;
background-color: rgba(0, 0, 0, 0.7);
color: white;
padding: 10px;
font-size: 18px;
border-radius: 5px;
z-index: 10;
}

h1 {
margin-top: 20px;
}

.loading-container {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
background-color: rgba(0, 0, 0, 0.6);
z-index: 10;
}

.loading-text {
font-size: 24px;
color: white;
margin-bottom: 20px;
}
</style>
</head>
<body>
<h1>YOLOv8 tinygrad WebGPU</h1>
<h2>YOLOv8 tinygrad WebGPU</h2>
<div class="video-container">
<video id="video" muted autoplay playsinline></video>
<canvas id="canvas"></canvas>
<div id="fps-meter"></div>

<div id="div-loading" class="loading-container">
<p class="loading-text">Loading model</p>
<span class="loader"></span>
</div>
</div>

<script>
let net = null;
const modelInputSize = 256
let lastCalledTime;
let fps = 0, accumFps = 0, frameCounter = 0;

const video = document.getElementById('video');
const canvas = document.getElementById('canvas');
const context = canvas.getContext('2d');
const offscreenCanvas = document.createElement('canvas');
offscreenCanvas.width = 640;
offscreenCanvas.height = 640;
const fpsMeter = document.getElementById('fps-meter');
const loadingContainer = document.getElementById('div-loading');
offscreenCanvas.width = modelInputSize;
offscreenCanvas.height = modelInputSize;
const offscreenContext = offscreenCanvas.getContext('2d');

if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
navigator.mediaDevices.getUserMedia({ audio: false, video: { facingMode: { ideal: "environment" }}}).then(function (stream) {
video.srcObject = stream;
@@ -70,20 +137,38 @@
requestAnimationFrame(processFrame);
return;
}

if (!lastCalledTime) {
lastCalledTime = performance.now();
fps = 0;
} else {
const now = performance.now();
delta = (now - lastCalledTime)/1000.0;
lastCalledTime = now;
accumFps += 1/delta;

if (frameCounter++ >= 30) {
fps = accumFps/frameCounter;
frameCounter = 0;
accumFps = 0;
fpsMeter.innerText = `FPS: ${fps.toFixed(1)}`
}
}

const videoAspectRatio = video.videoWidth / video.videoHeight;
let targetWidth, targetHeight;

if (videoAspectRatio > 1) {
targetWidth = 640;
targetHeight = 640 / videoAspectRatio;
targetWidth = modelInputSize;
targetHeight = modelInputSize / videoAspectRatio;
} else {
targetHeight = 640;
targetWidth = 640 * videoAspectRatio;
targetHeight = modelInputSize;
targetWidth = modelInputSize * videoAspectRatio;
}

const offsetX = (640 - targetWidth) / 2;
const offsetY = (640 - targetHeight) / 2;
offscreenContext.clearRect(0, 0, 640, 640);
const offsetX = (modelInputSize - targetWidth) / 2;
const offsetY = (modelInputSize - targetHeight) / 2;
offscreenContext.clearRect(0, 0, modelInputSize, modelInputSize);
offscreenContext.drawImage(video, offsetX, offsetY, targetWidth, targetHeight);
const boxes = await detectObjectsOnFrame(offscreenContext);
drawBoxes(offscreenCanvas, boxes, targetWidth, targetHeight, offsetX, offsetY);
@@ -120,7 +205,10 @@
}

async function detectObjectsOnFrame(offscreenContext) {
if (!net) net = await loadNet(await getDevice());
if (!net) {
net = await loadNet(await getDevice());
loadingContainer.style.display = "none";
}
let start = performance.now();
const [input,img_width,img_height] = await prepareInput(offscreenContext);
console.log("Preprocess took: " + (performance.now() - start) + " ms");
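loadNet and getDevice are used here but defined outside these hunks, so their bodies are not shown. As a rough, hypothetical sketch (not taken from the diff), a WebGPU device helper for a page like this could look as follows, using only the standard navigator.gpu API; the requiredLimits handling is an assumption:

// Hypothetical device helper, sketched for illustration only.
async function getDeviceSketch() {
  if (!navigator.gpu) throw new Error("WebGPU is not supported in this browser");
  const adapter = await navigator.gpu.requestAdapter();
  if (!adapter) throw new Error("No WebGPU adapter available");
  // Compiled models can need larger storage buffers than the defaults allow,
  // so pass the adapter's own limit through (assumption, not from the diff).
  return adapter.requestDevice({
    requiredLimits: { maxStorageBufferBindingSize: adapter.limits.maxStorageBufferBindingSize },
  });
}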
@@ -135,8 +223,8 @@
async function prepareInput(offscreenContext) {
return new Promise(resolve => {
const [img_width,img_height] = [640, 640]
const imgData = offscreenContext.getImageData(0,0,640,640);
const [img_width,img_height] = [modelInputSize, modelInputSize]
const imgData = offscreenContext.getImageData(0,0,modelInputSize,modelInputSize);
const pixels = imgData.data;
const red = [], green = [], blue = [];
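The rest of prepareInput sits outside this hunk; the red/green/blue arrays suggest the channel-planar layout used by the referenced yolov8_onnx_javascript pre-processing, i.e. all red values, then all green, then all blue, each scaled to [0,1]. A self-contained sketch of that packing, under that assumption (packPlanarRGB is a hypothetical name, not from the diff):

// Pack interleaved RGBA canvas pixels into a planar, [0,1]-normalized
// Float32Array of length 3*size*size (red plane, then green, then blue).
function packPlanarRGB(imgData, size) {
  const pixels = imgData.data;  // Uint8ClampedArray, RGBA interleaved
  const planar = new Float32Array(3 * size * size);
  for (let i = 0, p = 0; i < size * size; i++, p += 4) {
    planar[i] = pixels[p] / 255.0;
    planar[size * size + i] = pixels[p + 1] / 255.0;
    planar[2 * size * size + i] = pixels[p + 2] / 255.0;
  }
  return planar;
}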
@@ -158,21 +246,22 @@
function processOutput(output, img_width, img_height) {
let boxes = [];
for (let index=0;index<8400;index++) {
const numPredictions = Math.pow(modelInputSize/32, 2) * 21;
for (let index=0;index<numPredictions;index++) {
const [class_id,prob] = [...Array(80).keys()]
.map(col => [col, output[8400*(col+4)+index]])
.map(col => [col, output[numPredictions*(col+4)+index]])
.reduce((accum, item) => item[1]>accum[1] ? item : accum,[0,0]);

if (prob < 0.25) continue;
const label = yolo_classes[class_id];
const xc = output[index];
const yc = output[8400+index];
const w = output[2*8400+index];
const h = output[3*8400+index];
const x1 = (xc-w/2)/640*img_width;
const y1 = (yc-h/2)/640*img_height;
const x2 = (xc+w/2)/640*img_width;
const y2 = (yc+h/2)/640*img_height;
const yc = output[numPredictions+index];
const w = output[2*numPredictions+index];
const h = output[3*numPredictions+index];
const x1 = (xc-w/2)/modelInputSize*img_width;
const y1 = (yc-h/2)/modelInputSize*img_height;
const x2 = (xc+w/2)/modelInputSize*img_width;
const y2 = (yc+h/2)/modelInputSize*img_height;
boxes.push([x1,y1,x2,y2,label,prob]);
}
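The indexing above follows the flat YOLOv8 output layout: conceptually a [4 + 80, numPredictions] tensor stored row-major, where rows 0-3 hold xc, yc, w, h in model-input pixels and rows 4-83 hold the per-class scores. Illustrative accessors for that layout (these helpers are not part of the diff):

// numPredictions is 1344 for a 256x256 input (21 * (256/32)^2).
const box = (output, i) => [output[i], output[numPredictions + i], output[2*numPredictions + i], output[3*numPredictions + i]];  // xc, yc, w, h
const classScore = (output, cls, i) => output[numPredictions * (cls + 4) + i];  // cls in 0..79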