YoloV8 on WebGPU (#8007)

Port YoloV8 to WebGPU
Ahmed Harmouche
2024-12-03 15:10:41 +01:00
committed by GitHub
parent 09eac42fd6
commit 8818046940
4 changed files with 82 additions and 49 deletions

View File

@@ -1,20 +1,16 @@
 from pathlib import Path
-from examples.yolov8 import YOLOv8
+from examples.yolov8 import YOLOv8, get_weights_location
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import safe_save
 from extra.export_model import export_model
-from tinygrad.helpers import fetch
+from tinygrad.helpers import getenv
 from tinygrad.device import Device
 from tinygrad.nn.state import safe_load, load_state_dict
 
 if __name__ == "__main__":
-  Device.DEFAULT = "WEBGL"
+  Device.DEFAULT = "WEBGPU"
   yolo_variant = 'n'
   yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
-  weights_location = Path(__file__).parents[1] / "weights" / f'yolov8{yolo_variant}.safetensors'
-  fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors', weights_location)
-  state_dict = safe_load(weights_location)
+  state_dict = safe_load(get_weights_location(yolo_variant))
  load_state_dict(yolo_infer, state_dict)
  prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,640,640))
  dirname = Path(__file__).parent
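
On the browser side, the artifacts this script produces (the generated net.js plus the weights in net.safetensors) are consumed through the loadNet/getDevice pair that index.html uses below. A minimal consumer, as a sketch only; the exact shape of the generated API is whatever export_model emits, and the output size is taken from the postprocessing code, not from this file:

const device = await getDevice();                   // request a WebGPU device (see index.html below)
const net = await loadNet(device);                  // loads net.safetensors, builds the compute pipelines
const input = new Float32Array(1 * 3 * 640 * 640);  // one 640x640 RGB frame, CHW float32
const outputs = await net(input);                   // inference is async on WebGPU
const preds = outputs[0];                           // single output tensor, 84 x 8400 floats flattened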

View File

@@ -3,7 +3,7 @@
<html lang="en">
<head>
<meta charset="UTF-8">
<title>YOLOv8 tinygrad WebGL</title>
<title>YOLOv8 tinygrad WebGPU</title>
<script src="./net.js"></script>
<style>
body {
@@ -38,7 +38,7 @@
   </style>
 </head>
 <body>
-  <h1>YOLOv8 tinygrad WebGL</h1>
+  <h1>YOLOv8 tinygrad WebGPU</h1>
   <div class="video-container">
     <video id="video" muted autoplay playsinline></video>
     <canvas id="canvas"></canvas>
@@ -56,7 +56,7 @@
 const offscreenContext = offscreenCanvas.getContext('2d');
 if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
-  navigator.mediaDevices.getUserMedia({ audio: false, video: true }).then(function (stream) {
+  navigator.mediaDevices.getUserMedia({ audio: false, video: { facingMode: { ideal: "environment" }}}).then(function (stream) {
     video.srcObject = stream;
     video.onloadedmetadata = function() {
       canvas.width = video.clientWidth;
@@ -66,54 +66,69 @@
 }
 async function processFrame() {
-  offscreenContext.drawImage(video, 0, 0, 640, 640);
+  if (video.videoWidth == 0 || video.videoHeight == 0) {
+    requestAnimationFrame(processFrame);
+    return;
+  }
+  const videoAspectRatio = video.videoWidth / video.videoHeight;
+  let targetWidth, targetHeight;
+  if (videoAspectRatio > 1) {
+    targetWidth = 640;
+    targetHeight = 640 / videoAspectRatio;
+  } else {
+    targetHeight = 640;
+    targetWidth = 640 * videoAspectRatio;
+  }
+  const offsetX = (640 - targetWidth) / 2;
+  const offsetY = (640 - targetHeight) / 2;
+  offscreenContext.clearRect(0, 0, 640, 640);
+  offscreenContext.drawImage(video, offsetX, offsetY, targetWidth, targetHeight);
   const boxes = await detectObjectsOnFrame(offscreenContext);
-  drawBoxes(offscreenCanvas, boxes);
+  drawBoxes(offscreenCanvas, boxes, targetWidth, targetHeight, offsetX, offsetY);
   requestAnimationFrame(processFrame);
 }
 requestAnimationFrame(processFrame);
-function drawBoxes(offscreenCanvas, boxes) {
-  const canvas = document.querySelector("canvas");
-  const ctx = canvas.getContext("2d");
+function drawBoxes(offscreenCanvas, boxes, targetWidth, targetHeight, offsetX, offsetY) {
+  const ctx = document.querySelector("canvas").getContext("2d");
   ctx.clearRect(0, 0, canvas.width, canvas.height);
   ctx.lineWidth = 3;
-  ctx.font = "20px serif";
-  const scaleX = canvas.width / 640;
-  const scaleY = canvas.height / 640;
+  ctx.font = "30px serif";
+  const scaleX = canvas.width / targetWidth;
+  const scaleY = canvas.height / targetHeight;
   boxes.forEach(([x1, y1, x2, y2, label]) => {
     const classIndex = yolo_classes.indexOf(label);
     const color = classColors[classIndex];
-    const textWidth = ctx.measureText(label).width;
     ctx.strokeStyle = color;
     ctx.fillStyle = color;
-    let adjustedX1 = x1 * scaleX;
-    let adjustedY1 = y1 * scaleY;
-    let adjustedX2 = x2 * scaleX;
-    let adjustedY2 = y2 * scaleY;
-    let boxWidth = adjustedX2 - adjustedX1;
-    let boxHeight = adjustedY2 - adjustedY1;
+    const adjustedX1 = (x1 - offsetX) * scaleX;
+    const adjustedY1 = (y1 - offsetY) * scaleY;
+    const adjustedX2 = (x2 - offsetX) * scaleX;
+    const adjustedY2 = (y2 - offsetY) * scaleY;
+    const boxWidth = adjustedX2 - adjustedX1;
+    const boxHeight = adjustedY2 - adjustedY1;
     ctx.strokeRect(adjustedX1, adjustedY1, boxWidth, boxHeight);
+    const textWidth = ctx.measureText(label).width;
     ctx.fillRect(adjustedX1, adjustedY1 - 25, textWidth + 10, 25);
-    ctx.fillStyle = "#000000";
-    ctx.fillText(label, adjustedX1, adjustedY1 - 7);
+    ctx.fillStyle = "#FFFFFF";
+    ctx.fillText(label, adjustedX1 + 5, adjustedY1 - 7);
   });
 }
 async function detectObjectsOnFrame(offscreenContext) {
-  if (!net) net = await loadNet();
+  if (!net) net = await loadNet(await getDevice());
   let start = performance.now();
   const [input,img_width,img_height] = await prepareInput(offscreenContext);
   console.log("Preprocess took: " + (performance.now() - start) + " ms");
   start = performance.now();
-  const output = net(new Float32Array(input));
+  const output = await net(new Float32Array(input));
   console.log("Inference took: " + (performance.now() - start) + " ms");
   start = performance.now();
-  let out = processOutput(output,img_width,img_height);
+  let out = processOutput(output[0],img_width,img_height);
   console.log("Postprocess took: " + (performance.now() - start) + " ms");
   return out;
 }
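
The rewritten processFrame letterboxes the camera frame: it preserves the video's aspect ratio inside the fixed 640x640 model input and centers it with offsetX/offsetY padding, and drawBoxes then inverts that mapping when projecting boxes onto the visible canvas. The round trip, condensed into a standalone sketch (same math as the diff above, names illustrative):

// Fit a W x H frame into 640 x 640 with centered padding.
function letterbox(W, H) {
  const r = W / H;
  const targetWidth  = r > 1 ? 640 : 640 * r;
  const targetHeight = r > 1 ? 640 / r : 640;
  return { targetWidth, targetHeight, offsetX: (640 - targetWidth) / 2, offsetY: (640 - targetHeight) / 2 };
}
// A 1280x720 camera maps to a 640x360 strip at offsetY = 140; a model-space
// coordinate x1 returns to canvas space via (x1 - offsetX) * (canvas.width / targetWidth).
const { targetWidth, targetHeight, offsetX, offsetY } = letterbox(1280, 720);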
@@ -134,27 +149,21 @@
     resolve([input, img_width, img_height])
   })
 }
-const loadNet = async () => {
-  try {
-    const safetensor = await (new Uint8Array(await (await fetch("./net.safetensors")).arrayBuffer()));
-    const gl = document.createElement("canvas").getContext("webgl2");
-    return setupNet(gl, safetensor);
-  } catch (e) {
-    console.log(e);
-    return null;
-  }
-}
+const getDevice = async () => {
+  if (!navigator.gpu) error("WebGPU not supported.");
+  const adapter = await navigator.gpu.requestAdapter();
+  return await adapter.requestDevice();
+};
 function processOutput(output, img_width, img_height) {
   let boxes = [];
   for (let index=0;index<8400;index++) {
     const [class_id,prob] = [...Array(80).keys()]
       .map(col => [col, output[8400*(col+4)+index]])
       .reduce((accum, item) => item[1]>accum[1] ? item : accum,[0,0]);
-    if (prob < 0.25) {
-      continue;
-    }
+    if (prob < 0.25) continue;
     const label = yolo_classes[class_id];
     const xc = output[index];
     const yc = output[8400+index];
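
processOutput reads the prediction tensor channel-major: channel c of anchor i sits at output[8400*c + i], with channels 0-3 the box center/size and channels 4-83 the 80 class scores (hence the 8400*(col+4)+index lookup above). The diff truncates after yc; a hedged sketch of the rest of the box decode, where the w/h channels and the rescale back to the source image are inferred from the standard YOLOv8 output layout rather than shown in this hunk:

function decodeBox(output, index, img_width, img_height) {
  const xc = output[index];             // channel 0: box center x (640x640 model space)
  const yc = output[8400 + index];      // channel 1: box center y
  const w  = output[2 * 8400 + index];  // channel 2: box width  (assumed)
  const h  = output[3 * 8400 + index];  // channel 3: box height (assumed)
  // center/size -> corners, then 640x640 model space -> source image pixels
  const x1 = (xc - w / 2) / 640 * img_width;
  const y1 = (yc - h / 2) / 640 * img_height;
  const x2 = (xc + w / 2) / 640 * img_width;
  const y2 = (yc + h / 2) / 640 * img_height;
  return [x1, y1, x2, y2];
}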

View File

@@ -1,5 +1,7 @@
 from tinygrad.nn import Conv2d, BatchNorm2d
 from tinygrad.tensor import Tensor
+from tinygrad.device import is_dtype_supported
+from tinygrad import dtypes
 import numpy as np
 from itertools import chain
 from pathlib import Path
@@ -8,6 +10,7 @@ from collections import defaultdict
 import time, sys
 from tinygrad.helpers import fetch
 from tinygrad.nn.state import safe_load, load_state_dict
+import json
 
 #Model architecture from https://github.com/ultralytics/ultralytics/issues/189
 #The upsampling class was taken from https://github.com/tinygrad/tinygrad/pull/784 by dc-dc-dc. Two models now use upsampling (retinanet and this one).
@@ -385,6 +388,32 @@ class YOLOv8:
     yolov8_head_weights = [(22, self.head)]
     return [*zip(backbone_modules, self.net.return_modules()), *zip(yolov8neck_modules, self.fpn.return_modules()), *yolov8_head_weights]
+
+def convert_f16_safetensor_to_f32(input_file: Path, output_file: Path):
+  with open(input_file, 'rb') as f:
+    metadata_length = int.from_bytes(f.read(8), 'little')
+    metadata = json.loads(f.read(metadata_length).decode())
+    float32_values = np.fromfile(f, dtype=np.float16).astype(np.float32)
+
+  for v in metadata.values():
+    if v["dtype"] == "F16": v.update({"dtype": "F32", "data_offsets": [offset * 2 for offset in v["data_offsets"]]})
+
+  with open(output_file, 'wb') as f:
+    new_metadata_bytes = json.dumps(metadata).encode()
+    f.write(len(new_metadata_bytes).to_bytes(8, 'little'))
+    f.write(new_metadata_bytes)
+    float32_values.tofile(f)
+
+def get_weights_location(yolo_variant: str) -> Path:
+  weights_location = Path(__file__).parents[1] / "weights" / f'yolov8{yolo_variant}.safetensors'
+  fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors', weights_location)
+  if not is_dtype_supported(dtypes.half):
+    f32_weights = weights_location.with_name(f"{weights_location.stem}_f32.safetensors")
+    if not f32_weights.exists(): convert_f16_safetensor_to_f32(weights_location, f32_weights)
+    weights_location = f32_weights
+  return weights_location
+
 if __name__ == '__main__':
   # usage: python3 yolov8.py "image_URL OR image_path" "v8 variant" (optional, n is default)
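
The converter works because a safetensors file is just an 8-byte little-endian header length, a JSON header, and the raw tensor bytes; widening every F16 value to F32 only requires rewriting the dtypes and doubling each data_offsets entry. get_weights_location falls back to this conversion when the device cannot handle half precision (is_dtype_supported(dtypes.half)), e.g. WebGPU adapters without the shader-f16 feature. For illustration, the same header can be inspected from the browser; a sketch assuming only the layout described above:

// Read the JSON header of a .safetensors file (sketch).
async function readSafetensorsHeader(url) {
  const buf = await (await fetch(url)).arrayBuffer();
  const headerLen = Number(new DataView(buf).getBigUint64(0, true)); // 8-byte little-endian length
  const header = new TextDecoder().decode(new Uint8Array(buf, 8, headerLen));
  return JSON.parse(header); // { tensor_name: { dtype, shape, data_offsets }, ... }
}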
@@ -410,8 +439,7 @@ if __name__ == '__main__':
   # Different YOLOv8 variants use different w, r, and d multiples. For a list, refer to this yaml file (the scales section): https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/models/v8/yolov8.yaml
   depth, width, ratio = get_variant_multiples(yolo_variant)
   yolo_infer = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
-  state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors'))
+  state_dict = safe_load(get_weights_location(yolo_variant))
   load_state_dict(yolo_infer, state_dict)
   st = time.time()

View File

@@ -262,7 +262,7 @@ def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names,
   input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
   gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
   outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
-  output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
+  output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size/4);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
   output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
   return f"""
 {web_utils["getTensorBuffer"]}
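
The output_readers fix above is a bytes-versus-elements correction: GPUBuffer.size is measured in bytes, while a Float32Array length is counted in 4-byte elements, so the old new Float32Array(gpuReadBuffer.size) allocated a result four times too long and the copy only filled the first quarter of it. The JS this template generates boils down to the following pattern, shown here as a standalone sketch with device, outputBuf, and commandEncoder assumed in scope:

// Copy one GPU output buffer back into JS-visible memory.
const gpuReadBuffer = device.createBuffer({
  size: outputBuf.size,  // size in bytes
  usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
});
commandEncoder.copyBufferToBuffer(outputBuf, 0, gpuReadBuffer, 0, outputBuf.size);
device.queue.submit([commandEncoder.finish()]);
await gpuReadBuffer.mapAsync(GPUMapMode.READ);
const result = new Float32Array(gpuReadBuffer.size / 4);  // /4: bytes -> float32 element count
result.set(new Float32Array(gpuReadBuffer.getMappedRange()));
gpuReadBuffer.unmap();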