Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
Webgpu support (#1077)

* initial commit
* 81 passing
* 105 passing tests
* 148 passing
* CI tests
* install dep on ci
* try opencl pkgs
* try using vulkan
* down to only 6 failing
* refactor
* cleaning up
* another test skipped due to buffer limit
* linter
* segfault
* indent fix
* another segfault found
* small touchups
* Fix max and maxpool tests
* Add constant folding
* Add javascript export script
* better asserts in codegen
* manual upcasting
* reverted token type change
* skip safetensor test due to unsupported type
* Fix efficientnet and all other model tests
* Remove np copy
* fixed indent and missing import
* manually destroy the buffer
* revert back to length
* linter errors
* removed extra val
* skip broken tests
* skipping more tests
* Make the page pretty
* Save model weights as safetensor
* Fix imagenet to c test
* Fix second imagenet to c bug
* Async and parallel kernel compilation
* workgroup support
* reversed local size
* fixed non local bug
* correct local groups
* ci experiment
* removed typo
* Fix define local by using shared memory
* Refactor
* try running on mac
* match metal tests
* add more workers
* scope down tests
* trying windows runner
* fixed windows env
* see how many it can do
* merged master
* refactor
* missed refactor
* increase test suite coverage
* missing import
* whitespace in test_efficientnet.py
* getting there
* fixed reset
* fixed bufs
* switched to cstyle
* cleanup
* min/max rename
* one more linter issue
* fixed demo
* linter
* testing ci chrome
* add unsafe webgpu arg
* add build step
* remove WEBGPU from cmd line
* use module
* try forcing directx
* trying forced metal backend
* temp disable conv2d for CI
* disable conv_transpose2d

---------

Co-authored-by: 0x4d - Martin Loretz <20306567+martinloretzzz@users.noreply.github.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
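The compile script added below traces the model with the JIT, emits each generated WGSL kernel and the buffer layout into a standalone net.js, and saves the weights as net.safetensors next to the demo page. A minimal sketch of how the demo might be generated and served locally, assuming a checkout of this commit and a WebGPU-capable browser (the exact invocation is not part of this diff and may differ):

# Run from the tinygrad repo root so "examples" and "models" are importable.
import subprocess

# Emit examples/webgpu/net.js and examples/webgpu/net.safetensors.
subprocess.run(["python", "examples/webgpu/compile_webgpu.py"], check=True)

# Serve the demo; open http://localhost:8000 in a browser with WebGPU enabled.
subprocess.run(["python", "-m", "http.server", "8000", "--directory", "examples/webgpu"], check=True)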
examples/webgpu/compile_webgpu.py (new file, 87 lines)
@@ -0,0 +1,87 @@
from os import path
from examples.compile_efficientnet import compile_net, jit_model
from models.efficientnet import EfficientNet
from tinygrad.state import get_state_dict, safe_save
from tinygrad.tensor import Tensor

if __name__ == "__main__":
  model = EfficientNet(0)
  model.load_from_pretrained()
  run, special_names = jit_model(model, Tensor.randn(1,3,224,224))
  functions, statements, bufs, _bufs_to_save = compile_net(run, special_names)

  state = get_state_dict(model)
  weights = {id(x.lazydata.realized): name for name, x in state.items()}
  safe_save(state, path.join(path.dirname(__file__), "net.safetensors"))

  kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
  kernel_names = ', '.join([name for (name, _args, _global_size) in statements])
  kernel_calls = '\n    '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size) in enumerate(statements)])
  bufs = '\n    '.join([f"const {buf[0]} = " + (f"createEmptyBuf(device, {buf[1]});" if buf[2] not in weights else f"createWeightBuf(device, {buf[1]}, getTensorBuffer(safetensor, metadata['{weights[buf[2]]}']))") + ";" for buf in bufs.values()])

  prg = f"""const getTensorMetadata = (safetensorBuffer) => {{
  const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
  const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
  return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
}};

const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
  return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}}

const createEmptyBuf = (device, size) => {{
  return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
}};

const createWeightBuf = (device, size, data) => {{
  const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
  new Uint8Array(buf.getMappedRange()).set(data);
  buf.unmap();
  return buf;
}};

const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
  const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
  const passEncoder = commandEncoder.beginComputePass();
  passEncoder.setPipeline(pipeline);
  passEncoder.setBindGroup(0, bindGroup);
  passEncoder.dispatchWorkgroups(...workgroup);
  passEncoder.end();
}};

{kernel_code}

const setupNet = async (device, safetensor) => {{
  const metadata = getTensorMetadata(safetensor);

  {bufs}

  const gpuWriteBuffer = device.createBuffer({{size:input.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});
  const gpuReadBuffer = device.createBuffer({{ size: outputs.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});

  const kernels = [{kernel_names}];
  const piplines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));

  return async (data) => {{
    await gpuWriteBuffer.mapAsync(GPUMapMode.WRITE);
    new Float32Array(gpuWriteBuffer.getMappedRange()).set(data);
    gpuWriteBuffer.unmap();

    const commandEncoder = device.createCommandEncoder();
    commandEncoder.copyBufferToBuffer(gpuWriteBuffer, 0, input, 0, gpuWriteBuffer.size);
    {kernel_calls}
    commandEncoder.copyBufferToBuffer(outputs, 0, gpuReadBuffer, 0, outputs.size);
    const gpuCommands = commandEncoder.finish();
    device.queue.submit([gpuCommands]);

    await gpuReadBuffer.mapAsync(GPUMapMode.READ);
    const resultBuffer = new Float32Array(gpuReadBuffer.size);
    resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
    gpuReadBuffer.unmap();
    return resultBuffer;
  }}
}}
"""

  with open(path.join(path.dirname(__file__), "net.js"), "w") as text_file:
    text_file.write(prg)
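The generated getTensorMetadata/getTensorBuffer helpers above rely on the safetensors layout that safe_save writes: an 8-byte little-endian length, a JSON header whose data_offsets are relative to the start of the payload, then the raw tensor bytes. A small Python sketch of the same parsing, shown only to illustrate the format the browser code reads (file path and helper names are illustrative, not part of this commit):

import json, struct

def get_tensor_metadata(buf: bytes) -> dict:
  # First 8 bytes: little-endian u64 giving the JSON header length.
  (metadata_length,) = struct.unpack_from("<Q", buf, 0)
  metadata = json.loads(buf[8:8 + metadata_length].decode("utf8"))
  # Shift data_offsets so they index into the whole file, as the JS above does.
  return {k: {**v, "data_offsets": [8 + metadata_length + x for x in v["data_offsets"]]}
          for k, v in metadata.items() if k != "__metadata__"}

def get_tensor_buffer(buf: bytes, tensor_metadata: dict) -> bytes:
  start, end = tensor_metadata["data_offsets"]
  return buf[start:end]

if __name__ == "__main__":
  data = open("examples/webgpu/net.safetensors", "rb").read()
  meta = get_tensor_metadata(data)
  name = next(iter(meta))
  print(name, meta[name]["dtype"], meta[name]["shape"], len(get_tensor_buffer(data, meta[name])))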
examples/webgpu/index.html (new file, 119 lines)
@@ -0,0 +1,119 @@
<html>
<head>
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <style>
    #result { font-size: 48px; }
    #time { font-size: 16px; color: grey; }
    #mybox { padding: 20px; }
    #resultbox { padding: 50px; }
    .bigggg { font-size: 18px; margin-top: 10px; }
    .bigg { font-size: 18px; }
    #url { font-size: 18px; width: 70%; }
    a { text-decoration: none; }
    h1 { padding: 50px; padding-bottom: 0px; font-size: 36px; font-weight: normal; }
    #imagebox { height:224px; width:224px; border: 1px dotted black; }
    #video { height:0px; width:0px; border: 1px dotted black; object-fit: cover;}
    canvas { display: none; }
    * { text-align: center; font-family: monospace; }
  </style>
  <title>tinygrad has WebGPU</title>
  <script src="./net.js"></script>
</head>
<body>
  <h1>WebGPU <a href="https://github.com/geohot/tinygrad">tinygrad</a> EfficientNet!</h1>
  <div id="mybox">
    <input type="text" id="url" placeholder="put url here" value="https://upload.wikimedia.org/wikipedia/commons/d/da/Norwegian_hen.jpg">
    <input class="bigg" type="button" onclick="runNetWResource(document.getElementById('url').value)" value="Use URL">
  </div>
  <br/>
  <img id="imagebox"></img>
  <canvas id="canvas" width="200" height="200"> </canvas>
  <div id="resultbox">
    <div id="result">result will go here</div>
    <div id="time"></div>
  </div>
  <script>
    const ctx = document.getElementById("canvas").getContext("2d", { willReadFrequently: true });
    const resultText = document.getElementById('result');
    let labels, net;

    const error = (err) => {
      resultText.innerHTML = `Error: ${err}`;
      throw new Error(err);
    }

    const getDevice = async () => {
      if (!navigator.gpu) error("WebGPU not supported.");
      const adapter = await navigator.gpu.requestAdapter();
      return await adapter.requestDevice();
    };

    const timer = async (func, label = "") => {
      document.getElementById('time').innerHTML = "";
      const start = performance.now();
      const out = await func();
      const delta = (performance.now() - start).toFixed(1)
      console.log(`${delta} ms ${label}`);
      document.getElementById('time').innerHTML = `${delta} ms ${label}`;
      return out;
    }

    const getLabels = async () => (await fetch("https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")).json();

    const getSavetensorBuffer = async () => new Uint8Array(await (await fetch("./net.safetensors")).arrayBuffer());

    const reorderChannelsAndRemoveAlpha = (data) => {
      const out = [];
      let i = 0;
      for (let c = 0; c < 3; c++) {
        for (let x = 0; x < 224 * 224; x++) {
          out[i] = data[x * 4 + c];
          i++;
        }
      }
      return out;
    };

    const runNetWResource = async (resource) => {
      resultText.innerHTML = "pending..."
      if (resource == "") error("sir. please type in a URL");
      const response = await fetch(resource)
      if (!response.ok) error("sir. that is not a good URL. try a new one");
      document.getElementById("imagebox").src = resource

      const img = new Image();
      img.crossOrigin = "Anonymous";
      img.onload = () => {
        URL.revokeObjectURL(img.src);
        ctx.drawImage(img, 0, 0, 224, 224);
        const data = ctx.getImageData(0, 0, 224, 224).data;
        runNet(data)
      };
      img.src = resource;
    }

    const loadLet = async () => {
      resultText.innerHTML = "loading..."
      labels = await getLabels();
      const safetensor = await getSavetensorBuffer();
      const device = await getDevice();
      net = await timer(() => setupNet(device, safetensor), "(compilation)");
      resultText.innerHTML = "ready"
    }

    const runNet = async (data) => {
      if (!net) error("Net not loaded yet.");

      const input = reorderChannelsAndRemoveAlpha(Array.from(data).map((pix) => (pix / 255.0) * 0.45 - 0.225));
      const out = await timer(() => net(new Float32Array(input)));

      const arr = Array.from(new Float32Array(out));
      const index = arr.indexOf(Math.max(...arr));

      resultText.textContent = labels[index];
    };

    loadLet();
  </script>
</body>
</html>
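Before dispatch, runNet above scales every canvas pixel to roughly [-0.225, 0.225] and reorderChannelsAndRemoveAlpha converts the interleaved 224x224 RGBA data to planar RGB (channels-first). A numpy sketch of the equivalent preprocessing, included only to clarify the input layout the exported network expects (function and array names are illustrative):

import numpy as np

def preprocess(rgba: np.ndarray) -> np.ndarray:
  # rgba: uint8 array of shape (224, 224, 4), like ctx.getImageData(...).data.
  x = rgba.astype(np.float32) / 255.0 * 0.45 - 0.225  # same scaling as the JS map()
  x = x[:, :, :3]                                      # drop the alpha channel
  return x.transpose(2, 0, 1).ravel()                  # HWC -> CHW, flattened

if __name__ == "__main__":
  fake = np.random.randint(0, 256, (224, 224, 4), dtype=np.uint8)
  print(preprocess(fake).shape)  # (150528,) == 3 * 224 * 224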