From 01f7a4fadcaaee6fd86e202f515c1dcd3e28fdff Mon Sep 17 00:00:00 2001
From: hooved <172129504+hooved@users.noreply.github.com>
Date: Tue, 4 Mar 2025 02:53:30 -0500
Subject: [PATCH] tinychat in browser, Part 2: model export (#9274)

* load llama3-1B to WEBGPU device
* include compile script for loading llama3 to WEBGPU
* parametrize max_context in build_transformer fxn
* jit_model with two different args sets
* compile for webgpu, split weights
* load model weight parts in browser
* export all tensors from initialized transformer
* run transformer inference in browser
* enable tiktoken with llama bpe in browser
* count total tokens on client with tiktoken.js
* full client-side chat streaming, eliminate server
* revert change that enabled jitting with 2 argsets
* llama without Variable or cache_kv, for webgpu
* have client use mask tokens / whole context
* cleanup staged weights
* add tiktoken.js build script, README
* export CLANG for Q6_k to float32 decompression
* fix and test exported CLANG code for Q6_k to fp32
* revert changes to jit and export_model
* isolate clang export
* test Q6_K to float32 decompression in browser
* gguf_load now also returns t_infos and data_start
* prepare llama-1B Q6_K gguf chunks for browser
* cache and decompress quantized llama in browser
* enable separate deployment of large files
* fix kv cache and symbolic with llama wgpu
* eliminate browser lag during decompression
* hash metadata and weight chunks
* delete obsolete indexeddb cache to free disk
* add progress bar, track model download/decompress
* refactor progress callback
* skip buffer hash verification for speed
* Display progress for entire loading scope
* Report page load errors to user
* actually display errors
* skip prompt tokens already seen by model
* skip prefilling with last assistant message tokens
* on page load tell user if webgpu not enabled
* push deployed URL root to window.history
* make note of bug sources with TODO items
* isolate bug in CLANG with BEAM=2
* remove clang_bug.py from diff
* decompress q6k to f32 on webgpu instead of clang
* remove unused code
* inter-weight decomp with larger wgpu kernels
* parallelize decompression submissions
* refactor dequantize scheduling
* add progress bar back
* fix bug
* temp fix for loading GGUF Q6_K to fp16 not fp32
* fix rendering of exported CLANG
* remove weight casts, sketch js functions for clang
* get symbolic vars from jit_cache for model export
* include symbolic vars in exported CLANG
* render js for clang transformer
* toggle clang/webgpu deployment; refactor decomp
* compile and render clang Q6_K->fp16 and int8 quant
* fix rendered clang for abs(fp16), to work in wasm
* simplify clang js wrapping
* run compiled clang in worker
* prepare llama weights in workers, q6k to int8/fp16
* tinychat on clang in browser, f32/int8 weights
* move wasm inference to (now flexible) worker
* don't load redundant embeddings
* modest wasm perf gain with compile flags
* set default backend, enable backend choice/backup
* render symbolic vars in exported WEBGPU
* quantize webgpu llama to int8/f32
* improve UX arising from rendered WEBGPU
* clean up webgpu launch
* new weights split: smaller chunks, tinygrad quant.
* switch webgpu inference to int8 quant
* remove unneeded clang decompression
* eliminate unneeded kv cache transfer to wasm
* use 1 worker for simplified clang decompression
* display launch errors
* refactor: stream load weight chunks to WebGPU
* show loading chunk completion
* quantize embeddings to int8
* test float16 as input for quantization
* webgpu: use f16 source, int8 embed, eliminate q6k
* simplify split weights prep: all from state_dict
* revert change to nn.state.gguf_load
* remove unneeded decompression from webgpu client
* remove unneeded code
* decrease dl chunks from 47 to 16 MiB
* improve stability of webgpu loading on mobile
* autodetect mobile, improve load stability
* refactor: progress closure
* refactor: one unified progress bar
* remove unneeded code
* revert changes to tinygrad core library
* enforce ios18.3 nerfed max buf size
* BEAM=3 webgpu
* cache integrity, mobile save throttling
* improve mobile UX - no autozoom on prompt box
* clang: int8 from f16, remove q6k
* reduce concurrent dls on mobile to 2 for stability
* refactor: wasm backend with stream loading
* prevent race between wasm load and indexedb save
* split wasm kernels into separate modules
* js wrapper for multiple wasm module inference
* revert multi-module wasm to single module
* make mobile wasm load more stable/fast
* refactor: copy weights into wasm without crashes
* fix bug in download queue; increase mobile dls
* refactor exported clang wrapper, split weights
* remove unnecessary code
* greatly improve int8 quant quality with rounding
* eliminate mobile throttling
* increase webgpu context to 4096 tokens
* export webgpu js functions
* enable separate hosted weights for mobile/pc
* enable prompt-thread switching during generation
* stop generation when max_context is reached
* show progress bar for prefill
* tell user if webgpu fails, while wasm loads
* make loading messages more concise
* update font
* revert changes to tinychat python app launch
* cleanup quantization, add scale_dtype param
* cleanup kv cache code
* cleanup compile code
* link tok_embeddings with output in webgpu export
* refactor: export_model webgpu: symbolic vars
* refactor: export_model weight loading
* forgot to commit export_model.py
* change CLANG to CPU
* deal with pylint incorrectly failing tests
* simplify f-strings for older CI python version
* fix pre-python3.12 parser errors
* [Int32Array] not Int32Array
* cleanup webgpu compile after refactor export_model
* refactor WASM export into export_model
* merge WebGPU/WASM compile scripts
* simplify max_contexts for local deployment
* fix parser issues and whitespace
* deduplicate variable defs for non-wasm clang export
* cleanup code
* cleanup compile scripts
* simplify wasm inference wrapping
* simplify webgpu symbolic vars export
* refactor: unify export of symbolic variables
* simplify WASM export
* simplify clang/wasm export
* update README and build scripts
* separate files for browser/python apps
* restore original python tinychat app files
* browser and python tinychats share assets
* minor cleanup
* isolate compile/export model

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
---
 examples/tinychat/tinychat-browser/README.md  |   8 ++
 examples/tinychat/tinychat-browser/compile.py | 123 ++++++++++++++++
 .../tinychat/tinychat-browser/compile_wasm.sh |  23 +++
 extra/export_model.py                         | 136 +++++++++++++-----
 4 files changed, 255 insertions(+), 35 deletions(-)
 create mode 100644 examples/tinychat/tinychat-browser/README.md
 create mode 100644 examples/tinychat/tinychat-browser/compile.py
 create mode 100755 examples/tinychat/tinychat-browser/compile_wasm.sh
diff --git a/examples/tinychat/tinychat-browser/README.md b/examples/tinychat/tinychat-browser/README.md
new file mode 100644
index 0000000000..2051cdfbda
--- /dev/null
+++ b/examples/tinychat/tinychat-browser/README.md
@@ -0,0 +1,8 @@
+# How to build and run tinychat in browser (WebGPU and WASM)
+- `PYTHONPATH=. python examples/tinychat/tinychat-browser/compile.py`
+- `./examples/tinychat/tinychat-browser/compile_wasm.sh`
+  - Prerequisite: [install emscripten](https://emscripten.org/docs/getting_started/downloads.html). This script looks for `~/emsdk/emsdk_env.sh`; adjust this path to match your installation.
+- `./examples/tinychat/tinychat-browser/make_tiktoken_js.sh`
+  - Prerequisite: install `npm` and `webpack`.
+- `cd examples/tinychat && python -m http.server 7776`
+- In browser: open either `localhost:7776/tinychat-browser` (WebGPU), or `localhost:7776/tinychat-browser/?backend=wasm` (WASM)
\ No newline at end of file
diff --git a/examples/tinychat/tinychat-browser/compile.py b/examples/tinychat/tinychat-browser/compile.py
new file mode 100644
index 0000000000..b9544dfeff
--- /dev/null
+++ b/examples/tinychat/tinychat-browser/compile.py
@@ -0,0 +1,123 @@
+import os, json, hashlib, math
+from extra.export_model import export_model
+from examples.llama3 import build_transformer
+from tinygrad.nn.state import get_state_dict, load_state_dict
+from tinygrad import Device, Variable, Tensor, dtypes
+from tinygrad.helpers import fetch, Context
+from tiktoken.load import load_tiktoken_bpe, dump_tiktoken_bpe
+
+def prepare_browser_chunks(model):
+  # split weights into browser-friendly chunks
+  state_dict = get_state_dict(model)
+  del state_dict['output.weight'], state_dict['output.scale'] # same as tok_embeddings; ensures consistency with model export
+  chunk_size = 16 * 1024 * 1024 # small chunks based on iphone browser constraints
+  metadata = {}
+  # We won't export cache_kv bytes (because we start inference on client at start_pos=0), but we will tell the client how big cache_kv needs to be
+  t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" not in k]
+  empty_t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" in k]
+
+  split_t_infos = []
+  for size, name, dtype in t_infos:
+    if size <= chunk_size:
+      split_t_infos.append((size, name, dtype, ()))
+    else: # split large weights into multiple parts
+      for i in range(0, size, chunk_size):
+        split_t_infos.append((min(chunk_size, size-i), f"{name}_part{math.ceil(i/chunk_size)}", dtype, (i, min(i+chunk_size, size))))
+
+  files = []
+  # pack weights into files with FFD bin packing
+  split_t_infos = sorted(split_t_infos, reverse=True)
+  for info in split_t_infos:
+    placed = False
+    for file in files:
+      if sum(i[0] for i in file) + info[0] <= chunk_size:
+        if info[3] and any(i[3] for i in file): continue # no two split tensors can touch the same file, due to wasm loading constraints
+        file.append(info)
+        placed = True
+        break
+    if not placed:
+      files.append([info])
+
+  tinygrad_dtypes = {dtypes.float32: "float32", dtypes.float16: "float16", dtypes.int8: "int8", dtypes.int32: "int32"}
+  for i, file in enumerate(files):
+    cursor = 0
+    with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.chunk'), "wb+") as writer:
+      for size, name, dtype, offsets in file:
int(name.split("_part")[1])) + default = {"parts": {}, "dtype": tinygrad_dtypes[dtype]} + weight_metadata = metadata.get(name, default) + weight_metadata["parts"][part_num] = {"file": i, "file_start_pos": cursor, "size": size} + metadata[name] = weight_metadata + data = bytes(state_dict[name].lazydata.base.realized.as_buffer()) + data = data if not offsets else data[offsets[0]:offsets[1]] + writer.write(data) + cursor += size + + metadata.update({name: {"parts": {0: {"empty": True, "size": size}}, "dtype": tinygrad_dtypes[dtype]} for size, name, dtype in empty_t_infos}) + + for k in metadata: + metadata[k]["parts"] = [part for part_num, part in sorted(metadata[k]["parts"].items(), key = lambda x: x[0])] + cursor = 0 + for i, part in enumerate(metadata[k]["parts"]): + metadata[k]["parts"][i]["target_start_pos"] = cursor + cursor += part["size"] + metadata[k]["size"] = cursor + + # compute hashes, which client app will check to determine whether to update with new weights and/or detect integrity issues + state_dict_hash = hashlib.sha256(json.dumps(metadata, sort_keys=True).encode("utf-8")).hexdigest() + metadata = {"state_dict": metadata, "state_dict_hash": state_dict_hash, "files": []} + for i in range(len(files)): + with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.chunk'), "rb") as reader: + metadata["files"].append({"name": f'net_part{i}.chunk', "hash": hashlib.sha256(reader.read()).hexdigest()}) + metadata_hash = hashlib.sha256(json.dumps(metadata, sort_keys=True).encode("utf-8")).hexdigest() + metadata = {"metadata": metadata, "metadata_hash": metadata_hash} + + with open(os.path.join(os.path.dirname(__file__), f'./net_metadata.json'), "w") as writer: json.dump(metadata, writer, indent=4) + return metadata + +if __name__=="__main__": + # Export BPE data for use with tiktoken.js + tokenizer_path = fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-1b-instruct") + mergeable_ranks = load_tiktoken_bpe(str(tokenizer_path)) + bpe_path = os.path.join(os.path.dirname(__file__), "llama3-2.tiktoken") + dump_tiktoken_bpe(mergeable_ranks, bpe_path) + + model_path = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-f16.gguf", "Llama-3.2-1B-Instruct-f16.gguf", subdir="llama3-1b-instruct") + Tensor.no_grad = True + max_context=1024 + tok = 128000 + TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P = 0.95, 0, 0.0, 0.0, 0.0 + start_pos = Variable("start_pos", 0, max_context).bind(0) + model_input = lambda: [Tensor([[tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P] + + Device.DEFAULT="CPU" + model = build_transformer(model_path, model_size="1B", quantize="int8", scale_dtype=dtypes.float32, device=Device.DEFAULT, max_context=max_context) + state_dict = get_state_dict(model) + out = model.forward(*model_input()) + model_name = "transformer" + + with Context(BEAM=3): + cprog, js_wrapper = export_model(model, "wasm", *model_input(), model_name=model_name) + # ensure consistency with exported weights + js_wrapper = js_wrapper.replace("output.weight", "tok_embeddings.weight").replace("output.scale", "tok_embeddings.scale") + + with open(os.path.join(os.path.dirname(__file__), f"{model_name}.c"), "w") as f: f.write(cprog) + with open(os.path.join(os.path.dirname(__file__), "net_clang.js"), "w") as f: f.write(js_wrapper) + + Device.DEFAULT="WEBGPU" + # float16 is not yet supported for dawn/Vulkan/NVIDIA stack, see: 
+  # float16 is not yet supported for dawn/Vulkan/NVIDIA stack, see: https://issues.chromium.org/issues/42251215
+  # therefore for now, we used CLANG to quantize the float16 llama to int8 with float32 scales, then load to WEBGPU
+  model = build_transformer(model_path, model_size="1B", quantize="int8", max_context=max_context, load_weights=False)
+  load_state_dict(model, state_dict)
+  # these were the same before load_state_dict
+  model.output.weight, model.output.scale = model.tok_embeddings.weight, model.tok_embeddings.scale
+
+  out = model.forward(*model_input())
+  metadata = prepare_browser_chunks(model) # export weights to disk
+
+  with Context(BEAM=3):
+    prg, input_sizes, output_sizes, state = export_model(model, "webgpu", *model_input(), model_name=model_name, stream_weights=True)
+    # ensure consistency with exported weights
+    prg = prg.replace("output.weight", "tok_embeddings.weight").replace("output.scale", "tok_embeddings.scale")
+
+  with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as f: f.write(prg)
\ No newline at end of file
diff --git a/examples/tinychat/tinychat-browser/compile_wasm.sh b/examples/tinychat/tinychat-browser/compile_wasm.sh
new file mode 100755
index 0000000000..fda21b0c4a
--- /dev/null
+++ b/examples/tinychat/tinychat-browser/compile_wasm.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+cd "$(dirname "$0")"
+
+# prereq: install emscripten: https://emscripten.org/docs/getting_started/downloads.html
+EMSCRIPTEN_PATH=~/emsdk/emsdk_env.sh
+source $EMSCRIPTEN_PATH
+step="transformer"
+initial_memory=6553600
+max_memory=1500053504
+exported_functions='["_net", "_malloc", "_free", "_set_buf"]'
+
+emcc "${step}.c" \
+  -O3 -msimd128 -ffast-math -flto \
+  -o "${step}.js" \
+  -s MODULARIZE=1 \
+  -s EXPORT_ES6=1 \
+  -s EXPORTED_FUNCTIONS="${exported_functions}" \
+  -s ENVIRONMENT='worker' \
+  -s FILESYSTEM=0 \
+  -s EVAL_CTORS \
+  -s ALLOW_MEMORY_GROWTH=1 \
+  -s INITIAL_MEMORY="$initial_memory" \
+  -s MAXIMUM_MEMORY="$max_memory"
\ No newline at end of file
diff --git a/extra/export_model.py b/extra/export_model.py
index edd5b01f49..8e60373f29 100644
--- a/extra/export_model.py
+++ b/extra/export_model.py
@@ -6,7 +6,9 @@ from tinygrad.engine.jit import TinyJit
 from tinygrad.nn.state import get_state_dict
 from tinygrad.helpers import Context
 from tinygrad.dtype import dtypes
+from tinygrad.ops import Ops
 import json
+from collections import OrderedDict
 
 EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "GPU"]
 
@@ -26,6 +28,7 @@ def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str]
           bufnum += 1
           if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
       cargs.append(bufs[key][0])
+    cargs += [var for var in fxn.vars if getattr(var, "op", None) is Ops.DEFINE_VAR] # symbolic vars; is it necessary or sufficient to check for DEFINE_VAR?
     statements.append((fxn.function_name, cargs, fxn.global_size, fxn.local_size))
 
   return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save
@@ -54,60 +57,105 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
     special_names[id(output.lazydata.base.realized)] = f'output{i}'
   return run, special_names
 
-def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str]) -> str:
-  cprog = ["#include <tgmath.h>"]
+def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]],
+                       bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str], weight_names={}, model_name="model", symbolic_vars={}, wasm=False) -> str:
+  headers = ["#include <tgmath.h>"]
+  cprog = list(functions.values())
+  dtype_map = {dtypes.int: "int", dtypes.float: "float", dtypes.uchar: "unsigned char", dtypes.char: "signed char", dtypes.half: "__fp16", dtypes.uint: "unsigned int"}
+  inputs = [(name, dtype_map[bufs[name][1]], bufs[name][0]) for name in input_names + list(symbolic_vars.values())]
+  outputs = [(name, dtype_map[bufs[name][1]], bufs[name][0]) for name in output_names]
+  forward_args = ",".join(f"{dtype}{'*' if name not in symbolic_vars.values() else ''} {name}" for name,dtype,_ in (outputs+inputs if wasm else inputs+outputs))
 
-  for name,cl in bufs_to_save.items():
-    weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
-    cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
+  if not wasm:
+    for name,cl in bufs_to_save.items():
+      weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
+      cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
+    cprog += [f"{dtype_map[dtype]} {name}[{len}];" if name not in bufs_to_save else f"{dtype_map[dtype]} *{name} = ({dtype_map[dtype]} *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in input_names+output_names]
+    cprog += [f"void net({forward_args}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
+    return '\n'.join(headers + cprog)
+  else:
+    if bufs_to_save:
+      headers += ["#include <stddef.h>"]
+    bufs_to_save = {k:v for k,v in bufs.items() if v[2] in weight_names} # causes random seeds to be set as zeroes, not exported as a model weight
+    buf_to_name = OrderedDict((buf_name, {"name": weight_names[data[2]], "idx": i}) for i, (buf_name, data) in enumerate(bufs_to_save.items()))
+    cprog.append(f"void* bufs[{len(buf_to_name)}];")
+    cprog.append(f"""void set_buf(size_t index, void* ptr) {{\n bufs[index] = ptr;\n}}""")
 
-  inputs = ", ".join([f'float* {input}' for input in input_names])
-  outputs = ", ".join([f'float* {output}' for output in output_names])
-  cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in ['input', 'outputs']]
-  cprog += list(functions.values())
-  cprog += [f"void net({inputs}, {outputs}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
-  return '\n'.join(cprog)
+    for name in set(bufs.keys()) - set(bufs_to_save.keys()) - set(input_names + output_names):
+      n_bytes, dtype, _ = bufs[name]
+      cprog += [f"{dtype_map[dtype]} {name}[{n_bytes // dtype.itemsize}];"]
+
+    cprog += [f"void net({forward_args})"] + ["{"]
+    get_weight_ptr = lambda x: f"({dtype_map[bufs_to_save[x][1]]} *)bufs[{buf_to_name[x]['idx']}]" if x in bufs_to_save else x
+    cprog += [f" {name}({', '.join(map(get_weight_ptr, args))});" for (name, args, _global_size, _local_size) in statements] + ["}"]
+    weightMapping = "" if not bufs_to_save else f"""\nconst weightNames = [{", ".join([f'"{weight_name}"' for weight_name in [v["name"] for v in buf_to_name.values()]])}];
+const {model_name}_name_to_id = Object.fromEntries(weightNames.map((name, index) => [name, index]));\n"""
+    top = f"""import {model_name}Module from './{model_name}.js'{weightMapping}"""
+    whitespace = "\n "
+    js_wrapper = f"""{top}\nvar {model_name} = async function() {{
+  const wasm = await {model_name}Module();
+  {whitespace.join(f"const {name}Ptr = wasm._malloc({n_bytes});" for name, _, n_bytes in outputs+inputs if name not in symbolic_vars.values())}
+  return {{
+    run: ({",".join(name for name,_,_ in inputs)}) => {{
+      {(whitespace + " ").join(f"wasm.HEAPU8.set({name}, {name}Ptr);" for name,_,_ in inputs if name not in symbolic_vars.values())}
+      wasm._net({", ".join(f"{name}{'Ptr' if name not in symbolic_vars.values() else ''}" for name,_,_ in outputs+inputs)});
+      {(whitespace + " ").join(f"const {name} = wasm.HEAPU8.slice({name}Ptr, {name}Ptr + {n_bytes});" for name,_,n_bytes in outputs)}
+      return [{", ".join(f"{name}" for name,_,_ in outputs)}];
+    }},
+    wasm: wasm
+  }}
+}}\nexport {{ {model_name}, {model_name}_name_to_id }};"""
+
+  return '\n'.join(headers + cprog), js_wrapper
 
 def dtype_to_js_type(dtype: DType) -> str:
   return f"{'Uint' if dtype in dtypes.uints else 'Int' if (dtype in dtypes.sints or dtype == dtypes.bool) else 'Float'}{8*dtype.itemsize}Array"
 
-def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name) -> Tuple[str,int,int]:
-  exported_name = "model" if model_name == None else model_name
+def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars={}, stream_weights=False) -> Tuple[str,int,int]:
   kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
   kernel_names = ', '.join([name for (name, _, _, _) in statements])
+  input_names += list(symbolic_vars.values())
+  input_buffer_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
+  output_buffer_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
+
+  buf_type = lambda x: "uniform" if x in set(symbolic_vars.values()) else "storage"
   create_bind_group_layouts = ",".join([
     "device.createBindGroupLayout({{entries: [{{binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'uniform' }}}}, {}]}})".format(
-      ",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'storage' }} }}" for argIdx, _ in enumerate(args)])
+      ",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: '{buf_type(argName)}' }} }}" for argIdx, argName in enumerate(args)])
     ) for _, (_, args, _, _) in enumerate(statements)
   ])
   layouts = f"const layouts=[{create_bind_group_layouts}]"
-  kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
-  _bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
+  kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], [{', '.join(str(x) for x in global_size)}]);" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
+
+  buf_type = lambda x: "createUniformBuf" if x in set(uop.arg[0] for uop in symbolic_vars) else "createEmptyBuf"
+  map_to_external_weight = lambda _key: f"state_dict['{weight_names[_key]}']" if stream_weights else f"getTensorBuffer(safetensor, metadata['{weight_names[_key]}'])"
+  _bufs = '\n '.join([f"const {name} = " + (f"{buf_type(_key)}(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, {map_to_external_weight(_key)})") + ";" for name,(size,dtype,_key) in bufs.items()])
   gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
-  input_buffer_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
-  output_buffer_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
   input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buffer_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
   gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
   outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
   output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/{bufs[output_names[i]][1].itemsize});\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
   output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
-  return f"""
-const {exported_name} = (() => {{
-const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
-  return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
-}};
-
-const getTensorMetadata = (safetensorBuffer) => {{
+  getTensorMetadata = f"""\nconst getTensorMetadata = (safetensorBuffer) => {{
   const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
   const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
   return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
+}};\n""" if not stream_weights else ""
+  return f"""
+const {model_name} = (() => {{
+const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
+  return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
 }};
-
+{getTensorMetadata}
 const createEmptyBuf = (device, size) => {{
   return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
 }};
 
+const createUniformBuf = (device, size) => {{
+  return device.createBuffer({{size, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST}})
+}}
+
 const createInfinityUniformBuf = (device) => {{
   const size = 4;
   const buf = device.createBuffer({{
@@ -121,9 +169,8 @@ const createInfinityUniformBuf = (device) => {{
 }};
 
 const createWeightBuf = (device, size, data) => {{
-  const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
-  new Uint8Array(buf.getMappedRange()).set(data);
-  buf.unmap();
+  const buf = device.createBuffer({{ size, usage: GPUBufferUsage.STORAGE{" | GPUBufferUsage.COPY_DST" if stream_weights else ", mappedAtCreation: true"} }});
+  {"data.bytes = buf;" if stream_weights else "new Uint8Array(buf.getMappedRange()).set(data); buf.unmap();"}
   return buf;
 }};
 
@@ -145,8 +192,8 @@ const addComputePass = (device, commandEncoder, pipeline, layout, infinityUnifor
 
 {kernel_code}
 
-const setupNet = async (device, safetensor) => {{
-    const metadata = getTensorMetadata(safetensor);
+const setupNet = async (device, {"state_dict" if stream_weights else "safetensor"}) => {{
+    {"const metadata = getTensorMetadata(safetensor);" if not stream_weights else ""}
     const infinityBuf = createInfinityUniformBuf(device);
 
     {layouts}
@@ -185,12 +232,12 @@ const setupNet = async (device, safetensor) => {{
         }}
     }}
 const load = async (device, weight_path) => {{ return await fetch(weight_path).then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}
-return {{ load }};
+return {{ load, setupNet }};
 }})();
-export default {exported_name};
+export default {model_name};
 """
 
-def export_model(model, target:str, *inputs, model_name: Optional[str] = None):
+def export_model(model, target:str, *inputs, model_name: Optional[str] = "model", stream_weights=False):
   assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CPU, CUDA, GPU, METAL are supported"
   with Context(JIT=2): run,special_names = jit_model(model, *inputs)
   functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
@@ -198,11 +245,30 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = None):
   weight_names = {id(x.lazydata.base.realized): name for name, x in state.items()}
   input_names = [name for _,name in special_names.items() if "input" in name]
   output_names = [name for _,name in special_names.items() if "output" in name]
+
+  # handle symbolic variables; TODO: refactor to fix some of this stuff upstream in tinygrad
+  symbolic_vars = OrderedDict()
+  for i, (_, args, global_size, _) in enumerate(statements):
+    for j, var in enumerate(args):
+      if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str):
+        if var not in symbolic_vars:
+          symbolic_vars[var] = var.arg[0]
+          bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var])
+        statements[i][1][j] = symbolic_vars[var]
+
+    if global_size:
+      for j, dim in enumerate(global_size):
+        if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}:
+          name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src)
+          global_size[j] = f"_{name.arg[0]}[0] + {val.arg}"
+
   prg = ""
   if target == "clang":
     prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
+  elif target == "wasm":
+    return export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names, weight_names, model_name, symbolic_vars, wasm=True)
   elif target == "webgpu":
-    prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name)
+    prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars, stream_weights)
   else:
     prg = json.dumps({
       "backend": Device.DEFAULT,
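
Editor's note (not part of the patch): below is a rough, hypothetical JavaScript sketch of how a browser client might consume the artifacts this patch produces — net_metadata.json, the net_partN.chunk files, and the exported WASM wrapper net_clang.js. The actual tinychat client is not included in this diff; BASE_URL, tokenBytes, and startPos are illustrative placeholders, and the run() argument list is assumed to mirror the exported net() inputs.

// Illustrative sketch only — not part of this patch.
import { transformer, transformer_name_to_id } from './net_clang.js';

const BASE_URL = '.';
const meta = (await (await fetch(`${BASE_URL}/net_metadata.json`)).json()).metadata;

// download every weight chunk file listed in the metadata
const chunks = await Promise.all(meta.files.map(async f =>
  new Uint8Array(await (await fetch(`${BASE_URL}/${f.name}`)).arrayBuffer())));

// stitch each tensor back together from its parts (empty parts, e.g. cache_kv, stay zero-filled)
const stateDict = {};
for (const [name, w] of Object.entries(meta.state_dict)) {
  const buf = new Uint8Array(w.size);
  for (const p of w.parts) {
    if (p.empty) continue;
    buf.set(chunks[p.file].subarray(p.file_start_pos, p.file_start_pos + p.size), p.target_start_pos);
  }
  stateDict[name] = buf;
}

// hand each buffer to the WASM module by index, then run one decode step
const model = await transformer();
for (const [name, bytes] of Object.entries(stateDict)) {
  if (!(name in transformer_name_to_id)) continue;   // only buffers the exported kernels reference
  const ptr = model.wasm._malloc(bytes.length);
  model.wasm.HEAPU8.set(bytes, ptr);
  model.wasm._set_buf(transformer_name_to_id[name], ptr);
}
const startPos = 0;
const tokenBytes = new Uint8Array(new Int32Array([128000]).buffer);  // BOS token, as in compile.py
const [outBytes] = model.run(tokenBytes, startPos);  // assumed to mirror the exported net() inputs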