encapsulate the exported webgpu model (#8203)

commit 651f72442c (parent 5864627abe)
Author: Ahmed Harmouche (committed by GitHub)
Date: 2024-12-13 10:55:37 +01:00
4 changed files with 32 additions and 28 deletions

examples/webgpu/efficientnet/index.html

@@ -17,8 +17,11 @@ canvas { display: none; }
 * { text-align: center; font-family: monospace; }
 </style>
 <title>tinygrad has WebGPU</title>
-<script src="../../net.js"></script>
 <link rel="icon" type="image/x-icon" href="https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/logo.png">
+<script type="module">
+import model from "../../net.js";
+window.model = model;
+</script>
 </head>
 <body>
 <h1>WebGPU <a href="https://github.com/geohot/tinygrad">tinygrad</a> EfficientNet!</h1>
@@ -61,8 +64,6 @@ canvas { display: none; }
 const getLabels = async () => (await fetch("https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")).json();
-const getSavetensorBuffer = async () => new Uint8Array(await (await fetch("../../net.safetensors")).arrayBuffer());
 const reorderChannelsAndRemoveAlpha = (data) => {
 const out = [];
 let i = 0;
@@ -97,9 +98,8 @@ canvas { display: none; }
 try {
 resultText.innerHTML = "loading..."
 labels = await getLabels();
-const safetensor = await getSavetensorBuffer();
 const device = await getDevice();
-net = await timer(() => setupNet(device, safetensor), "(compilation)");
+net = await timer(() => model.load(device, '../../net.safetensors'), "(compilation)");
 resultText.innerHTML = "ready"
 } catch (e) {
 error(e)
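
The net effect for this page: the exported network is consumed as an ES module default export instead of page-level globals (setupNet, loadNet), and weight fetching moves behind model.load. A minimal usage sketch, assuming the page's existing getDevice helper; the 224x224 input size and single-output destructuring are illustrative, not taken from the diff:

    import model from "../../net.js";
    const device = await getDevice();
    // load() fetches the safetensors weights, compiles the WebGPU pipelines,
    // and resolves to the inference function that setupNet used to return
    const net = await model.load(device, "../../net.safetensors");
    // the net function resolves to one typed array per model output
    const [logits] = await net(new Float32Array(1 * 3 * 224 * 224));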

examples/webgpu/yolov8/compile.py

@@ -12,7 +12,7 @@ if __name__ == "__main__":
 yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
 state_dict = safe_load(get_weights_location(yolo_variant))
 load_state_dict(yolo_infer, state_dict)
-prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416))
+prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416), model_name="yolov8")
 dirname = Path(__file__).parent
 safe_save(state, (dirname / "net.safetensors").as_posix())
 with open(dirname / f"net.js", "w") as text_file:
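
Passing model_name="yolov8" names the symbol that the generated net.js exports. A sketch of the resulting module's outer shape, abridged from the export_model_webgpu template in extra/export_model.py below:

    const yolov8 = (() => {
      // ...generated kernel sources, buffer helpers, and setupNet...
      const load = async (device, weight_path) => {
        // fetch the weights and hand them to setupNet
        return await fetch(weight_path).then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x)));
      };
      return { load };
    })();
    export default yolov8;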

examples/webgpu/yolov8/index.html

@@ -4,7 +4,10 @@
 <meta charset="UTF-8">
 <meta name="viewport" content="width=device-width, initial-scale=1.0">
 <title>YOLOv8 tinygrad WebGPU</title>
-<script src="./net.js"></script>
+<script type="module">
+import yolov8 from "./net.js"
+window.yolov8 = yolov8;
+</script>
 <style>
 body {
 text-align: center;
@@ -213,7 +216,7 @@
 wgpuError.style.display = "block";
 loadingContainer.style.display = "none";
 }
-net = await loadNet(device);
+net = await yolov8.load(device, "./net.safetensors");
 loadingContainer.style.display = "none";
 }
 let start = performance.now();
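
As in the EfficientNet page, load() resolves to the setupNet closure, so each later frame just invokes it. An illustrative call, using the 1x3x416x416 shape compile.py exports with; the single-output destructuring is an assumption:

    // one inference pass; input layout matches Tensor.randn(1,3,416,416) at export time
    const [prediction] = await net(new Float32Array(1 * 3 * 416 * 416));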

extra/export_model.py

@@ -1,4 +1,4 @@
-from typing import Tuple, Dict, List
+from typing import Tuple, Dict, List, Optional
 from tinygrad.dtype import DType
 from tinygrad.renderer import ProgramSpec
 from tinygrad.tensor import Device, Tensor
@@ -9,17 +9,6 @@ from tinygrad.dtype import dtypes
 import json
 EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"]
-web_utils = {
-"getTensorBuffer":
-"""const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {
-return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
-}""",
-"getTensorMetadata": """const getTensorMetadata = (safetensorBuffer) => {
-const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
-const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
-return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}]));
-};"""
-}
 def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
 functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
@@ -82,14 +71,15 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
 def dtype_to_js_type(dtype: DType) -> str:
 return f"{'Uint' if dtype in dtypes.uints else 'Int' if (dtype in dtypes.sints or dtype == dtypes.bool) else 'Float'}{8*dtype.itemsize}Array"
-def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
+def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name) -> Tuple[str,int,int]:
+exported_name = "model" if model_name == None else model_name
 kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
-kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
+kernel_names = ', '.join([name for (name, _, _, _) in statements])
 create_bind_group_layouts = ",".join([
 "device.createBindGroupLayout({{entries: [{{binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'uniform' }}}}, {}]}})".format(
 ",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'storage' }} }}" for argIdx, _ in enumerate(args)])
 )
-for i, (_name, args, global_size, _local_size) in enumerate(statements)
+for _, (_, args, _, _) in enumerate(statements)
 ])
 layouts = f"const layouts=[{create_bind_group_layouts}]"
 kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
@@ -103,9 +93,16 @@ def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names,
 output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/{bufs[output_names[i]][1].itemsize});\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
 output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
 return f"""
-{web_utils["getTensorBuffer"]}
+const {exported_name} = (() => {{
+const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
+return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
+}};
-{web_utils["getTensorMetadata"]}
+const getTensorMetadata = (safetensorBuffer) => {{
+const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
+const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
+return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
+}};
 const createEmptyBuf = (device, size) => {{
 return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
@@ -187,9 +184,13 @@ const setupNet = async (device, safetensor) => {{
 return {output_return};
 }}
 }}
-""" + f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"
+const load = async (device, weight_path) => {{ return await fetch(weight_path).then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}
+return {{ load }};
+}})();
+export default {exported_name};
+"""
-def export_model(model, target:str, *inputs):
+def export_model(model, target:str, *inputs, model_name: Optional[str] = None):
 assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
 with Context(JIT=2): run,special_names = jit_model(model, *inputs)
 functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
@@ -201,7 +202,7 @@ def export_model(model, target:str, *inputs):
 if target == "clang":
 prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
 elif target == "webgpu":
-prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names)
+prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name)
 else:
 prg = json.dumps({
 "backend": Device.DEFAULT,