encapsulate the exported webgpu model (#8203)

Author: Ahmed Harmouche
Date: 2024-12-13 10:55:37 +01:00
Committed by: GitHub
Parent: 5864627abe
Commit: 651f72442c
4 changed files with 32 additions and 28 deletions

View File

@@ -17,8 +17,11 @@ canvas { display: none; }
 * { text-align: center; font-family: monospace; }
 </style>
 <title>tinygrad has WebGPU</title>
-<script src="../../net.js"></script>
 <link rel="icon" type="image/x-icon" href="https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/logo.png">
+<script type="module">
+  import model from "../../net.js";
+  window.model = model;
+</script>
 </head>
 <body>
 <h1>WebGPU <a href="https://github.com/geohot/tinygrad">tinygrad</a> EfficientNet!</h1>
@@ -61,8 +64,6 @@
 const getLabels = async () => (await fetch("https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")).json();
-const getSavetensorBuffer = async () => new Uint8Array(await (await fetch("../../net.safetensors")).arrayBuffer());
 const reorderChannelsAndRemoveAlpha = (data) => {
   const out = [];
   let i = 0;
@@ -97,9 +98,8 @@
 try {
   resultText.innerHTML = "loading..."
   labels = await getLabels();
-  const safetensor = await getSavetensorBuffer();
   const device = await getDevice();
-  net = await timer(() => setupNet(device, safetensor), "(compilation)");
+  net = await timer(() => model.load(device, '../../net.safetensors'), "(compilation)");
   resultText.innerHTML = "ready"
 } catch (e) {
   error(e)
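
The page-side change is small: instead of pulling helpers into global scope via a classic script tag, the page imports the default-exported model object and asks it to load its own weights. A minimal consumption sketch, assuming standard WebGPU device setup (the example wraps this in its own getDevice() helper) and an illustrative input shape:

```js
import model from "../../net.js";

// Standard WebGPU device acquisition.
const adapter = await navigator.gpu.requestAdapter();
const device = await adapter.requestDevice();

// The module fetches and parses the safetensors file itself; the page no
// longer needs getSavetensorBuffer() or a global setupNet().
const net = await model.load(device, "../../net.safetensors");

// The returned net is an async function over typed arrays: one typed array
// per model input, an array of result buffers back (shape here is illustrative).
const [out] = await net(new Float32Array(1 * 3 * 224 * 224));
```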

View File

@@ -12,7 +12,7 @@ if __name__ == "__main__":
   yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
   state_dict = safe_load(get_weights_location(yolo_variant))
   load_state_dict(yolo_infer, state_dict)
-  prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416))
+  prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416), model_name="yolov8")
   dirname = Path(__file__).parent
   safe_save(state, (dirname / "net.safetensors").as_posix())
   with open(dirname / f"net.js", "w") as text_file:

View File

@@ -4,7 +4,10 @@
 <meta charset="UTF-8">
 <meta name="viewport" content="width=device-width, initial-scale=1.0">
 <title>YOLOv8 tinygrad WebGPU</title>
-<script src="./net.js"></script>
+<script type="module">
+  import yolov8 from "./net.js"
+  window.yolov8 = yolov8;
+</script>
 <style>
 body {
   text-align: center;
@@ -213,7 +216,7 @@
 wgpuError.style.display = "block";
 loadingContainer.style.display = "none";
 }
-net = await loadNet(device);
+net = await yolov8.load(device, "./net.safetensors");
 loadingContainer.style.display = "none";
 }
 let start = performance.now();
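
Because load now takes the weights path as an argument, the generated module is no longer hard-wired to a net.safetensors next to the page; the caller decides where the weights come from. A one-line sketch (the URL below is hypothetical):

```js
// Same module, weights served from elsewhere (hypothetical URL):
const net = await yolov8.load(device, "https://example.com/models/yolov8/net.safetensors");
```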

View File

@@ -1,4 +1,4 @@
-from typing import Tuple, Dict, List
+from typing import Tuple, Dict, List, Optional
 from tinygrad.dtype import DType
 from tinygrad.renderer import ProgramSpec
 from tinygrad.tensor import Device, Tensor
@@ -9,17 +9,6 @@ from tinygrad.dtype import dtypes
 import json

 EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"]
-web_utils = {
-  "getTensorBuffer":
-  """const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {
-    return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
-  }""",
-  "getTensorMetadata": """const getTensorMetadata = (safetensorBuffer) => {
-    const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
-    const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
-    return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}]));
-  };"""
-}

 def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
   functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
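
The helpers deleted here (and inlined into the generated module further down) decode the safetensors container: the first 8 bytes are a little-endian uint64 giving the length of a JSON header, the header maps tensor names to dtype/shape/data_offsets, and those offsets are relative to the end of the header, which is why getTensorMetadata shifts them by 8 + metadataLength. A self-contained round-trip sketch of that layout, with a toy one-tensor file and arbitrary values:

```js
// Build a minimal safetensors buffer: [u64 LE header length | JSON header | tensor bytes].
const enc = new TextEncoder();
const header = enc.encode(JSON.stringify({ w: { dtype: "F32", shape: [2], data_offsets: [0, 8] } }));
const file = new Uint8Array(8 + header.length + 8);
const view = new DataView(file.buffer);
view.setBigUint64(0, BigInt(header.length), true);   // length prefix
file.set(header, 8);                                 // JSON header
view.setFloat32(8 + header.length, 1.5, true);       // payload of "w"; offsets in the
view.setFloat32(8 + header.length + 4, -2.0, true);  // header are payload-relative

// What getTensorMetadata does: shift data_offsets past the prefix and header,
// so getTensorBuffer can subarray() straight into the file bytes. The slice()
// copy avoids alignment issues, since the payload offset need not be 4-aligned.
const [lo, hi] = [0, 8].map(x => 8 + header.length + x);
const w = new Float32Array(file.subarray(lo, hi).slice().buffer); // Float32Array [1.5, -2]
```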
@@ -82,14 +71,15 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
 def dtype_to_js_type(dtype: DType) -> str:
   return f"{'Uint' if dtype in dtypes.uints else 'Int' if (dtype in dtypes.sints or dtype == dtypes.bool) else 'Float'}{8*dtype.itemsize}Array"

-def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
+def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name) -> Tuple[str,int,int]:
+  exported_name = "model" if model_name == None else model_name
   kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
-  kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
+  kernel_names = ', '.join([name for (name, _, _, _) in statements])
   create_bind_group_layouts = ",".join([
     "device.createBindGroupLayout({{entries: [{{binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'uniform' }}}}, {}]}})".format(
       ",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'storage' }} }}" for argIdx, _ in enumerate(args)])
     )
-    for i, (_name, args, global_size, _local_size) in enumerate(statements)
+    for _, (_, args, _, _) in enumerate(statements)
   ])
   layouts = f"const layouts=[{create_bind_group_layouts}]"
   kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
@@ -103,9 +93,16 @@ def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names,
   output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/{bufs[output_names[i]][1].itemsize});\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
   output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
   return f"""
-{web_utils["getTensorBuffer"]}
-
-{web_utils["getTensorMetadata"]}
+const {exported_name} = (() => {{
+const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
+  return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
+}};
+const getTensorMetadata = (safetensorBuffer) => {{
+  const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
+  const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
+  return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
+}};
 const createEmptyBuf = (device, size) => {{
   return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
@@ -187,9 +184,13 @@ const setupNet = async (device, safetensor) => {{
     return {output_return};
   }}
 }}
-""" + f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"
+const load = async (device, weight_path) => {{ return await fetch(weight_path).then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}
+return {{ load }};
+}})();
+export default {exported_name};
+"""

-def export_model(model, target:str, *inputs):
+def export_model(model, target:str, *inputs, model_name: Optional[str] = None):
   assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
   with Context(JIT=2): run,special_names = jit_model(model, *inputs)
   functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
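
The net effect on the emitted file: everything that used to leak into global scope (getTensorBuffer, getTensorMetadata, setupNet, loadNet) is now closed over by an IIFE, and the module's only public surface is a default-exported object with a load method. Roughly, the generated net.js now has this shape (a sketch with the helper bodies stubbed, not the literal generated code):

```js
// Shape of the generated module; "model" is replaced by model_name when given.
const model = (() => {
  // In the real net.js, the inlined helpers, kernel sources, and compiled
  // pipelines live here. setupNet is stubbed for illustration.
  const setupNet = async (device, safetensor) => {
    // ...creates buffers and pipelines, then returns the inference function.
    return async (...inputs) => [];
  };
  const load = async (device, weight_path) => {
    const bytes = new Uint8Array(await (await fetch(weight_path)).arrayBuffer());
    return setupNet(device, bytes);
  };
  return { load };  // the only public surface
})();
export default model;
```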
@@ -201,7 +202,7 @@ def export_model(model, target:str, *inputs):
   if target == "clang":
     prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
   elif target == "webgpu":
-    prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names)
+    prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name)
   else:
     prg = json.dumps({
       "backend": Device.DEFAULT,