mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
YoloV8 WebGPU fixes (#8057)
* Bump up input size to 416, show if webgpu is not supported * Minor fix in export_model
This commit is contained in:
@@ -12,7 +12,7 @@ if __name__ == "__main__":
|
||||
yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
|
||||
state_dict = safe_load(get_weights_location(yolo_variant))
|
||||
load_state_dict(yolo_infer, state_dict)
|
||||
prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,256,256))
|
||||
prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416))
|
||||
dirname = Path(__file__).parent
|
||||
safe_save(state, (dirname / "net.safetensors").as_posix())
|
||||
with open(dirname / f"net.js", "w") as text_file:
|
||||
|
||||
@@ -95,6 +95,7 @@
|
||||
</head>
|
||||
<body>
|
||||
<h2>YOLOv8 tinygrad WebGPU</h2>
|
||||
<h2 id="wgpu-error" style="display: none; color: red;">Error: WebGPU is not supported in this browser</h2>
|
||||
<div class="video-container">
|
||||
<video id="video" muted autoplay playsinline></video>
|
||||
<canvas id="canvas"></canvas>
|
||||
@@ -107,7 +108,7 @@
|
||||
</div>
|
||||
<script>
|
||||
let net = null;
|
||||
const modelInputSize = 256
|
||||
const modelInputSize = 416;
|
||||
let lastCalledTime;
|
||||
let fps = 0, accumFps = 0, frameCounter = 0;
|
||||
|
||||
@@ -117,6 +118,7 @@
|
||||
const offscreenCanvas = document.createElement('canvas');
|
||||
const fpsMeter = document.getElementById('fps-meter');
|
||||
const loadingContainer = document.getElementById('div-loading');
|
||||
const wgpuError = document.getElementById('wgpu-error');
|
||||
offscreenCanvas.width = modelInputSize;
|
||||
offscreenCanvas.height = modelInputSize;
|
||||
const offscreenContext = offscreenCanvas.getContext('2d');
|
||||
@@ -147,7 +149,7 @@
|
||||
lastCalledTime = now;
|
||||
accumFps += 1/delta;
|
||||
|
||||
if (frameCounter++ >= 30) {
|
||||
if (frameCounter++ >= 10) {
|
||||
fps = accumFps/frameCounter;
|
||||
frameCounter = 0;
|
||||
accumFps = 0;
|
||||
@@ -206,7 +208,12 @@
|
||||
|
||||
async function detectObjectsOnFrame(offscreenContext) {
|
||||
if (!net) {
|
||||
net = await loadNet(await getDevice());
|
||||
let device = await getDevice();
|
||||
if (!device) {
|
||||
wgpuError.style.display = "block";
|
||||
loadingContainer.style.display = "none";
|
||||
}
|
||||
net = await loadNet(device);
|
||||
loadingContainer.style.display = "none";
|
||||
}
|
||||
let start = performance.now();
|
||||
@@ -239,7 +246,7 @@
|
||||
}
|
||||
|
||||
const getDevice = async () => {
|
||||
if (!navigator.gpu) error("WebGPU not supported.");
|
||||
if (!navigator.gpu) return false;
|
||||
const adapter = await navigator.gpu.requestAdapter();
|
||||
return await adapter.requestDevice();
|
||||
};
|
||||
|
||||
@@ -81,7 +81,7 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
|
||||
|
||||
def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
|
||||
def dtype_to_js_type(dtype: DType) -> str:
|
||||
return "Uint32Array" if dtype in dtypes.uints else "Int32Array" if (dtype in dtypes.ints or dtype == dtypes.bool) else "Float32Array"
|
||||
return "Uint32Array" if dtype in dtypes.uints else "Int32Array" if (dtype in dtypes.sints or dtype == dtypes.bool) else "Float32Array"
|
||||
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
|
||||
kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
|
||||
create_bind_group_layouts = ",".join([
|
||||
|
||||
Reference in New Issue
Block a user