From 88180469401b1b5845a22e64f5952a4f129e0ae8 Mon Sep 17 00:00:00 2001
From: Ahmed Harmouche
Date: Tue, 3 Dec 2024 15:10:41 +0100
Subject: [PATCH] YoloV8 on WebGPU (#8007)

Port YoloV8 to WebGPU
---
 examples/{webgl => webgpu}/yolov8/compile.py | 10 +--
 examples/{webgl => webgpu}/yolov8/index.html | 87 +++++++++++---------
 examples/yolov8.py                           | 32 ++++++-
 extra/export_model.py                        |  2 +-
 4 files changed, 82 insertions(+), 49 deletions(-)
 rename examples/{webgl => webgpu}/yolov8/compile.py (63%)
 rename examples/{webgl => webgpu}/yolov8/index.html (76%)

diff --git a/examples/webgl/yolov8/compile.py b/examples/webgpu/yolov8/compile.py
similarity index 63%
rename from examples/webgl/yolov8/compile.py
rename to examples/webgpu/yolov8/compile.py
index 07a345d17a..34f4780752 100644
--- a/examples/webgl/yolov8/compile.py
+++ b/examples/webgpu/yolov8/compile.py
@@ -1,20 +1,16 @@
 from pathlib import Path
-from examples.yolov8 import YOLOv8
+from examples.yolov8 import YOLOv8, get_weights_location
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import safe_save
 from extra.export_model import export_model
-from tinygrad.helpers import fetch
-from tinygrad.helpers import getenv
 from tinygrad.device import Device
 from tinygrad.nn.state import safe_load, load_state_dict
 
 if __name__ == "__main__":
-  Device.DEFAULT = "WEBGL"
+  Device.DEFAULT = "WEBGPU"
   yolo_variant = 'n'
   yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
-  weights_location = Path(__file__).parents[1] / "weights" / f'yolov8{yolo_variant}.safetensors'
-  fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors', weights_location)
-  state_dict = safe_load(weights_location)
+  state_dict = safe_load(get_weights_location(yolo_variant))
  load_state_dict(yolo_infer, state_dict)
   prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,640,640))
   dirname = Path(__file__).parent
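
The hunk above stops at `dirname = Path(__file__).parent`; the rest of compile.py, which writes the exported artifacts to disk, is untouched by this patch. For orientation, a minimal sketch of that step, assuming the net.safetensors name that index.html fetches and a hypothetical net.js name for the generated program text:

  # Sketch, not part of the patch: persist the export_model() outputs.
  # "net.safetensors" is the name index.html fetches below; "net.js" is an
  # assumed name for the generated WebGPU JavaScript.
  safe_save(state, (dirname / "net.safetensors").as_posix())
  with open(dirname / "net.js", "w") as text_file:
    text_file.write(prg)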

diff --git a/examples/webgl/yolov8/index.html b/examples/webgpu/yolov8/index.html
similarity index 76%
rename from examples/webgl/yolov8/index.html
rename to examples/webgpu/yolov8/index.html
index 4de507d136..0bc7c8f020 100644
--- a/examples/webgl/yolov8/index.html
+++ b/examples/webgpu/yolov8/index.html
@@ -3,7 +3,7 @@
-  <title>YOLOv8 tinygrad WebGL</title>
+  <title>YOLOv8 tinygrad WebGPU</title>
@@ -56,7 +56,7 @@
     const offscreenContext = offscreenCanvas.getContext('2d');
 
     if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
-      navigator.mediaDevices.getUserMedia({ audio: false, video: true }).then(function (stream) {
+      navigator.mediaDevices.getUserMedia({ audio: false, video: { facingMode: { ideal: "environment" }}}).then(function (stream) {
         video.srcObject = stream;
         video.onloadedmetadata = function() {
           canvas.width = video.clientWidth;
@@ -66,54 +66,69 @@
     }
 
     async function processFrame() {
-      offscreenContext.drawImage(video, 0, 0, 640, 640);
+      if (video.videoWidth == 0 || video.videoHeight == 0) {
+        requestAnimationFrame(processFrame);
+        return;
+      }
+      const videoAspectRatio = video.videoWidth / video.videoHeight;
+      let targetWidth, targetHeight;
+
+      if (videoAspectRatio > 1) {
+        targetWidth = 640;
+        targetHeight = 640 / videoAspectRatio;
+      } else {
+        targetHeight = 640;
+        targetWidth = 640 * videoAspectRatio;
+      }
+
+      const offsetX = (640 - targetWidth) / 2;
+      const offsetY = (640 - targetHeight) / 2;
+      offscreenContext.clearRect(0, 0, 640, 640);
+      offscreenContext.drawImage(video, offsetX, offsetY, targetWidth, targetHeight);
       const boxes = await detectObjectsOnFrame(offscreenContext);
-      drawBoxes(offscreenCanvas, boxes);
+      drawBoxes(offscreenCanvas, boxes, targetWidth, targetHeight, offsetX, offsetY);
       requestAnimationFrame(processFrame);
     }
 
     requestAnimationFrame(processFrame);
 
-    function drawBoxes(offscreenCanvas, boxes) {
-      const canvas = document.querySelector("canvas");
-      const ctx = canvas.getContext("2d");
+    function drawBoxes(offscreenCanvas, boxes, targetWidth, targetHeight, offsetX, offsetY) {
+      const ctx = document.querySelector("canvas").getContext("2d");
       ctx.clearRect(0, 0, canvas.width, canvas.height);
       ctx.lineWidth = 3;
-      ctx.font = "20px serif";
-      const scaleX = canvas.width / 640;
-      const scaleY = canvas.height / 640;
+      ctx.font = "30px serif";
+      const scaleX = canvas.width / targetWidth;
+      const scaleY = canvas.height / targetHeight;
 
       boxes.forEach(([x1, y1, x2, y2, label]) => {
         const classIndex = yolo_classes.indexOf(label);
         const color = classColors[classIndex];
-        const textWidth = ctx.measureText(label).width;
         ctx.strokeStyle = color;
         ctx.fillStyle = color;
-
-        let adjustedX1 = x1 * scaleX;
-        let adjustedY1 = y1 * scaleY;
-        let adjustedX2 = x2 * scaleX;
-        let adjustedY2 = y2 * scaleY;
-        let boxWidth = adjustedX2 - adjustedX1;
-        let boxHeight = adjustedY2 - adjustedY1;
-
+        const adjustedX1 = (x1 - offsetX) * scaleX;
+        const adjustedY1 = (y1 - offsetY) * scaleY;
+        const adjustedX2 = (x2 - offsetX) * scaleX;
+        const adjustedY2 = (y2 - offsetY) * scaleY;
+        const boxWidth = adjustedX2 - adjustedX1;
+        const boxHeight = adjustedY2 - adjustedY1;
         ctx.strokeRect(adjustedX1, adjustedY1, boxWidth, boxHeight);
+        const textWidth = ctx.measureText(label).width;
         ctx.fillRect(adjustedX1, adjustedY1 - 25, textWidth + 10, 25);
-        ctx.fillStyle = "#000000";
-        ctx.fillText(label, adjustedX1, adjustedY1 - 7);
+        ctx.fillStyle = "#FFFFFF";
+        ctx.fillText(label, adjustedX1 + 5, adjustedY1 - 7);
       });
     }
 
     async function detectObjectsOnFrame(offscreenContext) {
-      if (!net) net = await loadNet();
+      if (!net) net = await loadNet(await getDevice());
       let start = performance.now();
       const [input,img_width,img_height] = await prepareInput(offscreenContext);
       console.log("Preprocess took: " + (performance.now() - start) + " ms");
       start = performance.now();
-      const output = net(new Float32Array(input));
+      const output = await net(new Float32Array(input));
       console.log("Inference took: " + (performance.now() - start) + " ms");
       start = performance.now();
-      let out = processOutput(output,img_width,img_height);
+      let out = processOutput(output[0],img_width,img_height);
       console.log("Postprocess took: " + (performance.now() - start) + " ms");
       return out;
     }
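
The processFrame hunk above letterboxes the camera frame into the 640x640 model input instead of stretching it, and drawBoxes inverts that mapping when drawing. A sketch of the same arithmetic, not part of the patch, written as a plain Python function mirroring the JavaScript:

  # Sketch of the letterbox mapping used above: fit a (w, h) frame into a
  # 640x640 square, preserving aspect ratio, centred with offsets.
  def letterbox(w: float, h: float, side: float = 640.0):
    ar = w / h
    tw, th = (side, side / ar) if ar > 1 else (side * ar, side)
    ox, oy = (side - tw) / 2, (side - th) / 2
    return tw, th, ox, oy

  # drawBoxes inverts it: a model-space corner (x, y) lands on a canvas of
  # size (cw, ch) at ((x - ox) * cw / tw, (y - oy) * ch / th).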
@@ -134,27 +149,21 @@
         resolve([input, img_width, img_height])
       })
     }
-
-    const loadNet = async () => {
-      try {
-        const safetensor = await (new Uint8Array(await (await fetch("./net.safetensors")).arrayBuffer()));
-        const gl = document.createElement("canvas").getContext("webgl2");
-        return setupNet(gl, safetensor);
-      } catch (e) {
-        console.log(e);
-        return null;
-      }
-    }
+    const getDevice = async () => {
+      if (!navigator.gpu) error("WebGPU not supported.");
+      const adapter = await navigator.gpu.requestAdapter();
+      return await adapter.requestDevice();
+    };
+
     function processOutput(output, img_width, img_height) {
       let boxes = [];
       for (let index=0;index<8400;index++) {
         const [class_id,prob] = [...Array(80).keys()]
           .map(col => [col, output[8400*(col+4)+index]])
           .reduce((accum, item) => item[1]>accum[1] ? item : accum,[0,0]);
-        if (prob < 0.25) {
-          continue;
-        }
+
+        if (prob < 0.25) continue;
         const label = yolo_classes[class_id];
         const xc = output[index];
         const yc = output[8400+index];
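
processOutput relies on the flattened layout of the YOLOv8 head output: an (84, 8400) tensor stored row-major, where rows 0-3 hold the box centre and size and rows 4-83 hold the 80 class scores, so class c of anchor i lives at output[8400*(c+4)+i]. A sketch of that indexing contract, not part of the patch:

  import numpy as np

  # `raw` is the (1, 84, 8400) head output, flattened row-major, so
  # raw[8400*r + i] == out[r, i] after the reshape below.
  def decode_anchor(raw: np.ndarray, index: int):
    out = raw.reshape(84, 8400)
    xc, yc, w, h = out[0, index], out[1, index], out[2, index], out[3, index]
    class_id = int(out[4:, index].argmax())   # same argmax as the JS reduce
    prob = float(out[4:, index].max())        # kept only if >= 0.25 upstream
    return xc, yc, w, h, class_id, prob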
diff --git a/examples/yolov8.py b/examples/yolov8.py
index 37396c7501..4fbf5ed4f2 100644
--- a/examples/yolov8.py
+++ b/examples/yolov8.py
@@ -1,5 +1,7 @@
 from tinygrad.nn import Conv2d, BatchNorm2d
 from tinygrad.tensor import Tensor
+from tinygrad.device import is_dtype_supported
+from tinygrad import dtypes
 import numpy as np
 from itertools import chain
 from pathlib import Path
@@ -8,6 +10,7 @@ from collections import defaultdict
 import time, sys
 from tinygrad.helpers import fetch
 from tinygrad.nn.state import safe_load, load_state_dict
+import json
 
 #Model architecture from https://github.com/ultralytics/ultralytics/issues/189
 #The upsampling class has been taken from this pull request https://github.com/tinygrad/tinygrad/pull/784 by dc-dc-dc. Now 2(?) models use upsampling. (retinet and this)
@@ -385,6 +388,32 @@ class YOLOv8:
     yolov8_head_weights = [(22, self.head)]
     return [*zip(backbone_modules, self.net.return_modules()), *zip(yolov8neck_modules, self.fpn.return_modules()), *yolov8_head_weights]
 
+def convert_f16_safetensor_to_f32(input_file: Path, output_file: Path):
+  with open(input_file, 'rb') as f:
+    metadata_length = int.from_bytes(f.read(8), 'little')
+    metadata = json.loads(f.read(metadata_length).decode())
+    float32_values = np.fromfile(f, dtype=np.float16).astype(np.float32)
+
+  for v in metadata.values():
+    if v["dtype"] == "F16": v.update({"dtype": "F32", "data_offsets": [offset * 2 for offset in v["data_offsets"]]})
+
+  with open(output_file, 'wb') as f:
+    new_metadata_bytes = json.dumps(metadata).encode()
+    f.write(len(new_metadata_bytes).to_bytes(8, 'little'))
+    f.write(new_metadata_bytes)
+    float32_values.tofile(f)
+
+def get_weights_location(yolo_variant: str) -> Path:
+  weights_location = Path(__file__).parents[1] / "weights" / f'yolov8{yolo_variant}.safetensors'
+  fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors', weights_location)
+
+  if not is_dtype_supported(dtypes.half):
+    f32_weights = weights_location.with_name(f"{weights_location.stem}_f32.safetensors")
+    if not f32_weights.exists(): convert_f16_safetensor_to_f32(weights_location, f32_weights)
+    weights_location = f32_weights
+
+  return weights_location
+
 if __name__ == '__main__':
 
   # usage : python3 yolov8.py "image_URL OR image_path" "v8 variant" (optional, n is default)
@@ -410,8 +439,7 @@ if __name__ == '__main__':
   # Different YOLOv8 variants use different w , r, and d multiples. For a list , refer to this yaml file (the scales section) https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/models/v8/yolov8.yaml
   depth, width, ratio = get_variant_multiples(yolo_variant)
   yolo_infer = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
-
-  state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors'))
+  state_dict = safe_load(get_weights_location(yolo_variant))
   load_state_dict(yolo_infer, state_dict)
 
   st = time.time()
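
convert_f16_safetensor_to_f32 leans on the safetensors layout: an 8-byte little-endian header length, a JSON header whose data_offsets are byte offsets into the tensor blob, then the blob itself, so widening F16 to F32 only requires doubling each offset. A small self-contained check of that layout, not part of the patch (the *_f32 filename convention comes from get_weights_location above):

  import json
  from pathlib import Path

  # Sketch: read the safetensors header of a converted file and confirm
  # every tensor is now F32. The path argument is illustrative.
  def check_is_f32(path: Path):
    with open(path, 'rb') as f:
      header_len = int.from_bytes(f.read(8), 'little')  # u64, little-endian
      header = json.loads(f.read(header_len).decode())
    for name, t in header.items():
      if name == "__metadata__": continue
      assert t["dtype"] == "F32", f"{name} is still {t['dtype']}"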
diff --git a/extra/export_model.py b/extra/export_model.py
index bb7de53a6b..866f6c2b20 100644
--- a/extra/export_model.py
+++ b/extra/export_model.py
@@ -262,7 +262,7 @@ def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names,
   input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
   gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
   outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
-  output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
+  output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size/4);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
   output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
   return f"""
 {web_utils["getTensorBuffer"]}
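
The one-character fix in output_readers corrects a unit mismatch: GPUBuffer.size is a byte count, while the Float32Array constructor takes an element count, so the generated reader previously allocated a result buffer four times too large. The same distinction in a Python sketch, not part of the patch:

  import numpy as np

  # One f32 element is 4 bytes, hence size/4 elements per byte buffer.
  raw = bytes(16)                                                    # 16 bytes
  assert np.frombuffer(raw, dtype=np.float32).size == len(raw) // 4  # 4 elements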