Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00
examples/webgpu/yolov8/compile.py (new file, 19 lines added)
@@ -0,0 +1,19 @@
from pathlib import Path
from examples.yolov8 import YOLOv8, get_weights_location
from tinygrad.tensor import Tensor
from tinygrad.nn.state import safe_save
from extra.export_model import export_model
from tinygrad.device import Device
from tinygrad.nn.state import safe_load, load_state_dict

if __name__ == "__main__":
  Device.DEFAULT = "WEBGPU"
  yolo_variant = 'n'  # the width/depth multipliers below correspond to the 'n' (nano) variant
  yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
  state_dict = safe_load(get_weights_location(yolo_variant))
  load_state_dict(yolo_infer, state_dict)
  # export the model as a WebGPU JavaScript program plus a safetensors weight file
  prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,640,640))
  dirname = Path(__file__).parent
  safe_save(state, (dirname / "net.safetensors").as_posix())
  with open(dirname / "net.js", "w") as text_file:
    text_file.write(prg)
examples/webgpu/yolov8/index.html (new file, 232 lines added)
@@ -0,0 +1,232 @@
<!-- Pre and post-processing functions from: https://github.com/AndreyGermanov/yolov8_onnx_javascript -->
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>YOLOv8 tinygrad WebGPU</title>
  <script src="./net.js"></script>
  <style>
    body {
      text-align: center;
      font-family: Arial, sans-serif;
      margin: 0;
      padding: 0;
      overflow: hidden;
    }

    .video-container {
      position: relative;
      width: 100%;
      margin: 0 auto;
    }

    #video, #canvas {
      position: absolute;
      top: 0;
      left: 0;
      width: 100%;
      height: auto;
    }

    #canvas {
      background: transparent;
    }

    h1 {
      margin-top: 20px;
    }
  </style>
</head>
<body>
  <h1>YOLOv8 tinygrad WebGPU</h1>
  <div class="video-container">
    <video id="video" muted autoplay playsinline></video>
    <canvas id="canvas"></canvas>
  </div>

  <script>
    let net = null;

    const video = document.getElementById('video');
    const canvas = document.getElementById('canvas');
    const context = canvas.getContext('2d');
    // 640x640 offscreen canvas: the input resolution the exported model expects
    const offscreenCanvas = document.createElement('canvas');
    offscreenCanvas.width = 640;
    offscreenCanvas.height = 640;
    const offscreenContext = offscreenCanvas.getContext('2d');

    if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
      navigator.mediaDevices.getUserMedia({ audio: false, video: { facingMode: { ideal: "environment" }}}).then(function (stream) {
        video.srcObject = stream;
        video.onloadedmetadata = function() {
          canvas.width = video.clientWidth;
          canvas.height = video.clientHeight;
        }
      });
    }

    async function processFrame() {
      if (video.videoWidth == 0 || video.videoHeight == 0) {
        requestAnimationFrame(processFrame);
        return;
      }
      // letterbox the camera frame into the 640x640 model input, preserving aspect ratio
      const videoAspectRatio = video.videoWidth / video.videoHeight;
      let targetWidth, targetHeight;

      if (videoAspectRatio > 1) {
        targetWidth = 640;
        targetHeight = 640 / videoAspectRatio;
      } else {
        targetHeight = 640;
        targetWidth = 640 * videoAspectRatio;
      }

      const offsetX = (640 - targetWidth) / 2;
      const offsetY = (640 - targetHeight) / 2;
      offscreenContext.clearRect(0, 0, 640, 640);
      offscreenContext.drawImage(video, offsetX, offsetY, targetWidth, targetHeight);
      const boxes = await detectObjectsOnFrame(offscreenContext);
      drawBoxes(offscreenCanvas, boxes, targetWidth, targetHeight, offsetX, offsetY);
      requestAnimationFrame(processFrame);
    }

    requestAnimationFrame(processFrame);

    function drawBoxes(offscreenCanvas, boxes, targetWidth, targetHeight, offsetX, offsetY) {
      const ctx = document.querySelector("canvas").getContext("2d");
      ctx.clearRect(0, 0, canvas.width, canvas.height);
      ctx.lineWidth = 3;
      ctx.font = "30px serif";
      // map coordinates from the 640x640 model space back onto the on-screen canvas
      const scaleX = canvas.width / targetWidth;
      const scaleY = canvas.height / targetHeight;

      boxes.forEach(([x1, y1, x2, y2, label]) => {
        const classIndex = yolo_classes.indexOf(label);
        const color = classColors[classIndex];
        ctx.strokeStyle = color;
        ctx.fillStyle = color;
        const adjustedX1 = (x1 - offsetX) * scaleX;
        const adjustedY1 = (y1 - offsetY) * scaleY;
        const adjustedX2 = (x2 - offsetX) * scaleX;
        const adjustedY2 = (y2 - offsetY) * scaleY;
        const boxWidth = adjustedX2 - adjustedX1;
        const boxHeight = adjustedY2 - adjustedY1;
        ctx.strokeRect(adjustedX1, adjustedY1, boxWidth, boxHeight);
        const textWidth = ctx.measureText(label).width;
        ctx.fillRect(adjustedX1, adjustedY1 - 25, textWidth + 10, 25);
        ctx.fillStyle = "#FFFFFF";
        ctx.fillText(label, adjustedX1 + 5, adjustedY1 - 7);
      });
    }

    async function detectObjectsOnFrame(offscreenContext) {
      // loadNet comes from the generated net.js included above
      if (!net) net = await loadNet(await getDevice());
      let start = performance.now();
      const [input, img_width, img_height] = await prepareInput(offscreenContext);
      console.log("Preprocess took: " + (performance.now() - start) + " ms");
      start = performance.now();
      const output = await net(new Float32Array(input));
      console.log("Inference took: " + (performance.now() - start) + " ms");
      start = performance.now();
      let out = processOutput(output[0], img_width, img_height);
      console.log("Postprocess took: " + (performance.now() - start) + " ms");
      return out;
    }

    async function prepareInput(offscreenContext) {
      return new Promise(resolve => {
        const [img_width, img_height] = [640, 640];
        const imgData = offscreenContext.getImageData(0, 0, 640, 640);
        const pixels = imgData.data;
        // convert interleaved RGBA bytes to planar RGB floats in [0, 1] (CHW layout)
        const red = [], green = [], blue = [];

        for (let index = 0; index < pixels.length; index += 4) {
          red.push(pixels[index] / 255.0);
          green.push(pixels[index + 1] / 255.0);
          blue.push(pixels[index + 2] / 255.0);
        }
        const input = [...red, ...green, ...blue];
        resolve([input, img_width, img_height]);
      });
    }

    const getDevice = async () => {
      if (!navigator.gpu) throw new Error("WebGPU not supported.");
      const adapter = await navigator.gpu.requestAdapter();
      return await adapter.requestDevice();
    };

    function processOutput(output, img_width, img_height) {
      // output layout: [84, 8400] flattened; rows 0-3 are xc, yc, w, h and rows 4-83 are class scores
      let boxes = [];
      for (let index = 0; index < 8400; index++) {
        // pick the class with the highest score for this anchor
        const [class_id, prob] = [...Array(80).keys()]
          .map(col => [col, output[8400*(col+4)+index]])
          .reduce((accum, item) => item[1] > accum[1] ? item : accum, [0, 0]);

        if (prob < 0.25) continue;
        const label = yolo_classes[class_id];
        const xc = output[index];
        const yc = output[8400+index];
        const w = output[2*8400+index];
        const h = output[3*8400+index];
        // convert center/size to corner coordinates, scaled to the image size
        const x1 = (xc - w/2) / 640 * img_width;
        const y1 = (yc - h/2) / 640 * img_height;
        const x2 = (xc + w/2) / 640 * img_width;
        const y2 = (yc + h/2) / 640 * img_height;
        boxes.push([x1, y1, x2, y2, label, prob]);
      }

      // non-maximum suppression: keep the highest-confidence box, drop boxes that overlap it too much
      boxes = boxes.sort((box1, box2) => box2[5] - box1[5]);
      const result = [];
      while (boxes.length > 0) {
        result.push(boxes[0]);
        boxes = boxes.filter(box => iou(boxes[0], box) < 0.7);
      }
      return result;
    }

    function iou(box1, box2) {
      return intersection(box1, box2) / union(box1, box2);
    }

    function union(box1, box2) {
      const [box1_x1, box1_y1, box1_x2, box1_y2] = box1;
      const [box2_x1, box2_y1, box2_x2, box2_y2] = box2;
      const box1_area = (box1_x2 - box1_x1) * (box1_y2 - box1_y1);
      const box2_area = (box2_x2 - box2_x1) * (box2_y2 - box2_y1);
      return box1_area + box2_area - intersection(box1, box2);
    }

    function intersection(box1, box2) {
      const [box1_x1, box1_y1, box1_x2, box1_y2] = box1;
      const [box2_x1, box2_y1, box2_x2, box2_y2] = box2;
      const x1 = Math.max(box1_x1, box2_x1);
      const y1 = Math.max(box1_y1, box2_y1);
      const x2 = Math.min(box1_x2, box2_x2);
      const y2 = Math.min(box1_y2, box2_y2);
      // clamp to zero so non-overlapping boxes don't produce a spurious intersection area
      return Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
    }

    const yolo_classes = [
      'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
      'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
      'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase',
      'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
      'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
      'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
      'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
      'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ];

    function generateColors(numColors) {
      const colors = [];
      for (let i = 0; i < 360; i += 360 / numColors) {
        colors.push(`hsl(${i}, 100%, 50%)`);
      }
      return colors;
    }

    const classColors = generateColors(yolo_classes.length);
  </script>
</body>
</html>