Mirror of https://github.com/tinygrad/tinygrad.git

encapsulate the exported webgpu model (#8203)
@@ -17,8 +17,11 @@ canvas { display: none; }
 * { text-align: center; font-family: monospace; }
 </style>
 <title>tinygrad has WebGPU</title>
-<script src="../../net.js"></script>
 <link rel="icon" type="image/x-icon" href="https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/logo.png">
+<script type="module">
+  import model from "../../net.js";
+  window.model = model;
+</script>
 </head>
 <body>
 <h1>WebGPU <a href="https://github.com/geohot/tinygrad">tinygrad</a> EfficientNet!</h1>
@@ -61,8 +64,6 @@ canvas { display: none; }
 
 const getLabels = async () => (await fetch("https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")).json();
 
-const getSavetensorBuffer = async () => new Uint8Array(await (await fetch("../../net.safetensors")).arrayBuffer());
-
 const reorderChannelsAndRemoveAlpha = (data) => {
   const out = [];
   let i = 0;
@@ -97,9 +98,8 @@ canvas { display: none; }
 try {
   resultText.innerHTML = "loading..."
   labels = await getLabels();
-  const safetensor = await getSavetensorBuffer();
   const device = await getDevice();
-  net = await timer(() => setupNet(device, safetensor), "(compilation)");
+  net = await timer(() => model.load(device, '../../net.safetensors'), "(compilation)");
   resultText.innerHTML = "ready"
 } catch (e) {
   error(e)
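Taken together, the EfficientNet page hunks swap the page-level getSavetensorBuffer/setupNet plumbing for a single default export. A minimal sketch of a consumer under the new shape follows; it assumes model.load() resolves to the same callable inference function that setupNet used to return, and uses plain navigator.gpu setup in place of the page's getDevice helper (input size illustrative):

import model from "./net.js";

const adapter = await navigator.gpu.requestAdapter();
const device = await adapter.requestDevice();

// load() fetches the safetensors weights, compiles the kernels, and returns the inference function.
const net = await model.load(device, "./net.safetensors");

// Assumption: like the old setupNet result, net takes typed-array inputs and resolves to output buffers.
const [out] = await net(new Float32Array(1 * 3 * 224 * 224));
console.log(out.length);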
@@ -12,7 +12,7 @@ if __name__ == "__main__":
   yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
   state_dict = safe_load(get_weights_location(yolo_variant))
   load_state_dict(yolo_infer, state_dict)
-  prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416))
+  prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416), model_name="yolov8")
   dirname = Path(__file__).parent
   safe_save(state, (dirname / "net.safetensors").as_posix())
   with open(dirname / f"net.js", "w") as text_file:
@@ -4,7 +4,10 @@
 <meta charset="UTF-8">
 <meta name="viewport" content="width=device-width, initial-scale=1.0">
 <title>YOLOv8 tinygrad WebGPU</title>
-<script src="./net.js"></script>
+<script type="module">
+  import yolov8 from "./net.js"
+  window.yolov8 = yolov8;
+</script>
 <style>
   body {
     text-align: center;
@@ -213,7 +216,7 @@
       wgpuError.style.display = "block";
       loadingContainer.style.display = "none";
     }
-    net = await loadNet(device);
+    net = await yolov8.load(device, "./net.safetensors");
     loadingContainer.style.display = "none";
   }
   let start = performance.now();
@@ -1,4 +1,4 @@
-from typing import Tuple, Dict, List
+from typing import Tuple, Dict, List, Optional
 from tinygrad.dtype import DType
 from tinygrad.renderer import ProgramSpec
 from tinygrad.tensor import Device, Tensor
@@ -9,17 +9,6 @@ from tinygrad.dtype import dtypes
 import json
 
 EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"]
-web_utils = {
-  "getTensorBuffer":
-  """const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {
-  return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
-}""",
-  "getTensorMetadata": """const getTensorMetadata = (safetensorBuffer) => {
-  const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
-  const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
-  return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}]));
-};"""
-}
 
 def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
   functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
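These web_utils snippets move inline into the WebGPU template below, but the logic is unchanged: a safetensors file is an 8-byte little-endian header length, a JSON header, then raw tensor bytes, and data_offsets must be rebased by 8 + metadataLength to address the whole file. A self-contained round-trip sketch (tensor name, dtype, and values made up for illustration):

// Build a tiny safetensors buffer: 8-byte LE header length, JSON header, then tensor bytes.
const header = JSON.stringify({ w: { dtype: "F32", shape: [2], data_offsets: [0, 8] } });
const headerBytes = new TextEncoder().encode(header);
const buf = new Uint8Array(8 + headerBytes.length + 8);
new DataView(buf.buffer).setBigUint64(0, BigInt(headerBytes.length), true);
buf.set(headerBytes, 8);
new DataView(buf.buffer).setFloat32(8 + headerBytes.length, 1.5, true);
new DataView(buf.buffer).setFloat32(8 + headerBytes.length + 4, 2.5, true);

// Parse it the way getTensorMetadata does: offsets are rebased past the 8-byte length and the header.
const metadataLength = Number(new DataView(buf.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(buf.subarray(8, 8 + metadataLength)));
const offsets = metadata.w.data_offsets.map(x => 8 + metadataLength + x);
const w = new Float32Array(buf.buffer.slice(...offsets)); // Float32Array [1.5, 2.5]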
@@ -82,14 +71,15 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
 def dtype_to_js_type(dtype: DType) -> str:
   return f"{'Uint' if dtype in dtypes.uints else 'Int' if (dtype in dtypes.sints or dtype == dtypes.bool) else 'Float'}{8*dtype.itemsize}Array"
 
-def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
+def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name) -> Tuple[str,int,int]:
+  exported_name = "model" if model_name == None else model_name
   kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
-  kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
+  kernel_names = ', '.join([name for (name, _, _, _) in statements])
   create_bind_group_layouts = ",".join([
     "device.createBindGroupLayout({{entries: [{{binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'uniform' }}}}, {}]}})".format(
       ",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'storage' }} }}" for argIdx, _ in enumerate(args)])
     )
-    for i, (_name, args, global_size, _local_size) in enumerate(statements)
+    for _, (_, args, _, _) in enumerate(statements)
   ])
   layouts = f"const layouts=[{create_bind_group_layouts}]"
   kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
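Each entry this comprehension emits pins a uniform buffer at binding 0 (the shared infinityBuf passed to every compute pass) followed by one storage binding per kernel argument. For a kernel taking two buffers, the generated JS comes out roughly like this (a sketch of the generator's output, assuming a device in scope, not code in the repo):

// Approximate output of one create_bind_group_layouts entry for a kernel with two buffer args:
const layout = device.createBindGroupLayout({entries: [
  {binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' }}, // infinityBuf
  {binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' }}, // arg 0
  {binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' }}, // arg 1
]});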
@@ -103,9 +93,16 @@ def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names,
   output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/{bufs[output_names[i]][1].itemsize});\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
   output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
   return f"""
-{web_utils["getTensorBuffer"]}
+const {exported_name} = (() => {{
+const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
+  return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
+}};
 
-{web_utils["getTensorMetadata"]}
+const getTensorMetadata = (safetensorBuffer) => {{
+  const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
+  const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
+  return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
+}};
 
 const createEmptyBuf = (device, size) => {{
   return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
@@ -187,9 +184,13 @@ const setupNet = async (device, safetensor) => {{
     return {output_return};
   }}
 }}
-""" + f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"
+const load = async (device, weight_path) => {{ return await fetch(weight_path).then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}
+return {{ load }};
+}})();
+export default {exported_name};
+"""
 
-def export_model(model, target:str, *inputs):
+def export_model(model, target:str, *inputs, model_name: Optional[str] = None):
   assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
   with Context(JIT=2): run,special_names = jit_model(model, *inputs)
   functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
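Net effect on the emitted file: everything, including the safetensors helpers and setupNet, now lives inside an IIFE, and only load escapes through the default export. For model_name="yolov8" the generated net.js reduces to roughly this skeleton (kernel and buffer plumbing elided; the setupNet line is sketched from the surrounding template, not quoted from it):

const yolov8 = (() => {
  // Private helpers: getTensorBuffer, getTensorMetadata, createEmptyBuf, addComputePass, ...
  const setupNet = async (device, safetensor) => { /* create buffers, compile kernels, return inference fn */ };
  const load = async (device, weight_path) => { return await fetch(weight_path).then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }
  return { load };
})();
export default yolov8;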
@@ -201,7 +202,7 @@ def export_model(model, target:str, *inputs):
   if target == "clang":
     prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
   elif target == "webgpu":
-    prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names)
+    prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name)
   else:
     prg = json.dumps({
       "backend": Device.DEFAULT,