tinychat in browser, Part 2: model export (#9274)

* load llama3-1B to WEBGPU device

* include compile script for loading llama3 to WEBGPU

* parametrize max_context in build_transformer fxn

* jit_model with two different args sets

* compile for webgpu, split weights

* load model weight parts in browser

* export all tensors from initialized transformer

* run transformer inference in browser

* enable tiktoken with llama bpe in browser

* count total tokens on client with tiktoken.js

* full client-side chat streaming, eliminate server

* revert change that enabled jitting with 2 argsets

* llama without Variable or cache_kv, for webgpu

* have client use mask tokens / whole context

* cleanup staged weights

* add tiktoken.js build script, README

* export CLANG for Q6_k to float32 decompression

* fix and test exported CLANG code for Q6_k to fp32

* revert changes to jit and export_model

* isolate clang export

* test Q6_K to float32 decompression in browser

* gguf_load now also returns t_infos and data_start

* prepare llama-1B Q6_K gguf chunks for browser

* cache and decompress quantized llama in browser

* enable separate deployment of large files

* fix kv cache and symbolic with llama wgpu

* eliminate browser lag during decompression

* hash metadata and weight chunks

* delete obsolete indexeddb cache to free disk

* add progress bar, track model download/decompress

* refactor progress callback

* skip buffer hash verification for speed

* Display progress for entire loading scope

* Report page load errors to user

* actually display errors

* skip prompt tokens already seen by model

* skip prefilling with last assistant message tokens

* on page load tell user if webgpu not enabled

* push deployed URL root to window.history

* make note of bug sources with TODO items

* isolate bug in CLANG with BEAM=2

* remove clang_bug.py from diff

* decompress q6k to f32 on webgpu instead of clang

* remove unused code

* inter-weight decomp with larger wgpu kernels

* parallelize decompression submissions

* refactor dequantize scheduling

* add progress bar back

* fix bug

* temp fix for loading GGUF Q6_K to fp16 not fp32

* fix rendering of exported CLANG

* remove weight casts, sketch js functions for clang

* get symbolic vars from jit_cache for model export

* include symbolic vars in exported CLANG

* render js for clang transformer

* toggle clang/webgpu deployment; refactor decomp

* compile and render clang Q6_K->fp16 and int8 quant

* fix rendered clang for abs(fp16), to work in wasm

* simplify clang js wrapping

* run compiled clang in worker

* prepare llama weights in workers, q6k to int8/fp16

* tinychat on clang in browser, f32/int8 weights

* move wasm inference to (now flexible) worker

* don't load redundant embeddings

* modest wasm perf gain with compile flags

* set default backend, enable backend choice/backup

* render symbolic vars in exported WEBGPU

* quantize webgpu llama to int8/f32

* improve UX arising from rendered WEBGPU

* clean up webgpu launch

* new weights split: smaller chunks, tinygrad quant.

* switch webgpu inference to int8 quant

* remove unneeded clang decompression

* eliminate unneeded kv cache transfer to wasm

* use 1 worker for simplified clang decompression

* display launch errors

* refactor: stream load weight chunks to WebGPU

* show loading chunk completion

* quantize embeddings to int8

* test float16 as input for quantization

* webgpu: use f16 source, int8 embed, eliminate q6k

* simplify split weights prep: all from state_dict

* revert change to nn.state.gguf_load

* remove unneeded decompression from webgpu client

* remove unneeded code

* decrease dl chunks from 47 to 16 MiB

* improve stability of webgpu loading on mobile

* autodetect mobile, improve load stability

* refactor: progress closure

* refactor: one unified progress bar

* remove unneeded code

* revert changes to tinygrad core library

* enforce ios18.3 nerfed max buf size

* BEAM=3 webgpu

* cache integrity, mobile save throttling

* improve mobile UX - no autozoom on prompt box

* clang: int8 from f16, remove q6k

* reduce concurrent dls on mobile to 2 for stability

* refactor: wasm backend with stream loading

* prevent race between wasm load and indexedb save

* split wasm kernels into separate modules

* js wrapper for multiple wasm module inference

* revert multi-module wasm to single module

* make mobile wasm load more stable/fast

* refactor: copy weights into wasm without crashes

* fix bug in download queue; increase mobile dls

* refactor exported clang wrapper, split weights

* remove unnecessary code

* greatly improve int8 quant quality with rounding

* eliminate mobile throttling

* increase webgpu context to 4096 tokens

* export webgpu js functions

* enable separate hosted weights for mobile/pc

* enable prompt-thread switching during generation

* stop generation when max_context is reached

* show progress bar for prefill

* tell user if webgpu fails, while wasm loads

* make loading messages more concise

* update font

* revert changes to tinychat python app launch

* cleanup quantization, add scale_dtype param

* cleanup kv cache code

* cleanup compile code

* link tok_embeddings with output in webgpu export

* refactor: export_model webgpu: symbolic vars

* refactor: export_model weight loading

* forgot to commit export_model.py

* change CLANG to CPU

* deal with pylint incorrectly failing tests

* simplify f-strings for older CI python version

* fix pre-python3.12 parser errors

* [Int32Array] not Int32Array

* cleanup webgpu compile after refactor export_model

* refactor WASM export into export_model

* merge WebGPU/WASM compile scripts

* simplify max_contexts for local deployment

* fix parser issues and whitespace

* deduplicate variable defs for non-wasm clang export

* cleanup code

* cleanup compile scripts

* simplify wasm inference wrapping

* simplify webgpu symbolic vars export

* refactor: unify export of symbolic variables

* simplify WASM export

* simplify clang/wasm export

* update README and build scripts

* separate files for browser/python apps

* restore original python tinychat app files

* browser and python tinychats share assets

* minor cleanup

* isolate compile/export model

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
hooved
2025-03-04 02:53:30 -05:00
committed by GitHub
parent 94db8426cb
commit 01f7a4fadc
4 changed files with 255 additions and 35 deletions

View File

@@ -6,7 +6,9 @@ from tinygrad.engine.jit import TinyJit
from tinygrad.nn.state import get_state_dict
from tinygrad.helpers import Context
from tinygrad.dtype import dtypes
from tinygrad.ops import Ops
import json
from collections import OrderedDict
EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "GPU"]
@@ -26,6 +28,7 @@ def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str]
bufnum += 1
if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
cargs.append(bufs[key][0])
cargs += [var for var in fxn.vars if getattr(var, "op", None) is Ops.DEFINE_VAR] # symbolic vars; is it necessary or sufficient to check for DEFINE_VAR?
statements.append((fxn.function_name, cargs, fxn.global_size, fxn.local_size))
return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save
@@ -54,60 +57,105 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
special_names[id(output.lazydata.base.realized)] = f'output{i}'
return run, special_names
def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str]) -> str:
cprog = ["#include <tgmath.h>"]
def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]],
bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str], weight_names={}, model_name="model", symbolic_vars={}, wasm=False) -> str:
headers = ["#include <tgmath.h>"]
cprog = list(functions.values())
dtype_map = {dtypes.int: "int", dtypes.float: "float", dtypes.uchar: "unsigned char", dtypes.char: "signed char", dtypes.half: "__fp16", dtypes.uint: "unsigned int"}
inputs = [(name, dtype_map[bufs[name][1]], bufs[name][0]) for name in input_names + list(symbolic_vars.values())]
outputs = [(name, dtype_map[bufs[name][1]], bufs[name][0]) for name in output_names]
forward_args = ",".join(f"{dtype}{'*' if name not in symbolic_vars.values() else ''} {name}" for name,dtype,_ in (outputs+inputs if wasm else inputs+outputs))
for name,cl in bufs_to_save.items():
weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
if not wasm:
for name,cl in bufs_to_save.items():
weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
cprog += [f"{dtype_map[dtype]} {name}[{len}];" if name not in bufs_to_save else f"{dtype_map[dtype]} *{name} = ({dtype_map[dtype]} *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in input_names+output_names]
cprog += [f"void net({forward_args}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
return '\n'.join(headers + cprog)
else:
if bufs_to_save:
headers += ["#include <stddef.h>"]
bufs_to_save = {k:v for k,v in bufs.items() if v[2] in weight_names} # causes random seeds to be set as zeroes, not exported as a model weight
buf_to_name = OrderedDict((buf_name, {"name": weight_names[data[2]], "idx": i}) for i, (buf_name, data) in enumerate(bufs_to_save.items()))
cprog.append(f"void* bufs[{len(buf_to_name)}];")
cprog.append(f"""void set_buf(size_t index, void* ptr) {{\n bufs[index] = ptr;\n}}""")
inputs = ", ".join([f'float* {input}' for input in input_names])
outputs = ", ".join([f'float* {output}' for output in output_names])
cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in ['input', 'outputs']]
cprog += list(functions.values())
cprog += [f"void net({inputs}, {outputs}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
return '\n'.join(cprog)
for name in set(bufs.keys()) - set(bufs_to_save.keys()) - set(input_names + output_names):
n_bytes, dtype, _ = bufs[name]
cprog += [f"{dtype_map[dtype]} {name}[{n_bytes // dtype.itemsize}];"]
cprog += [f"void net({forward_args})"] + ["{"]
get_weight_ptr = lambda x: f"({dtype_map[bufs_to_save[x][1]]} *)bufs[{buf_to_name[x]['idx']}]" if x in bufs_to_save else x
cprog += [f" {name}({', '.join(map(get_weight_ptr, args))});" for (name, args, _global_size, _local_size) in statements] + ["}"]
weightMapping = "" if not bufs_to_save else f"""\nconst weightNames = [{", ".join([f'"{weight_name}"' for weight_name in [v["name"] for v in buf_to_name.values()]])}];
const {model_name}_name_to_id = Object.fromEntries(weightNames.map((name, index) => [name, index]));\n"""
top = f"""import {model_name}Module from './{model_name}.js'{weightMapping}"""
whitespace = "\n "
js_wrapper = f"""{top}\nvar {model_name} = async function() {{
const wasm = await {model_name}Module();
{whitespace.join(f"const {name}Ptr = wasm._malloc({n_bytes});" for name, _, n_bytes in outputs+inputs if name not in symbolic_vars.values())}
return {{
run: ({",".join(name for name,_,_ in inputs)}) => {{
{(whitespace + " ").join(f"wasm.HEAPU8.set({name}, {name}Ptr);" for name,_,_ in inputs if name not in symbolic_vars.values())}
wasm._net({", ".join(f"{name}{'Ptr' if name not in symbolic_vars.values() else ''}" for name,_,_ in outputs+inputs)});
{(whitespace + " ").join(f"const {name} = wasm.HEAPU8.slice({name}Ptr, {name}Ptr + {n_bytes});" for name,_,n_bytes in outputs)}
return [{", ".join(f"{name}" for name,_,_ in outputs)}];
}},
wasm: wasm
}}
}}\nexport {{ {model_name}, {model_name}_name_to_id }};"""
return '\n'.join(headers + cprog), js_wrapper
def dtype_to_js_type(dtype: DType) -> str:
return f"{'Uint' if dtype in dtypes.uints else 'Int' if (dtype in dtypes.sints or dtype == dtypes.bool) else 'Float'}{8*dtype.itemsize}Array"
def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name) -> Tuple[str,int,int]:
exported_name = "model" if model_name == None else model_name
def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars={}, stream_weights=False) -> Tuple[str,int,int]:
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _, _, _) in statements])
input_names += list(symbolic_vars.values())
input_buffer_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
output_buffer_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
buf_type = lambda x: "uniform" if x in set(symbolic_vars.values()) else "storage"
create_bind_group_layouts = ",".join([
"device.createBindGroupLayout({{entries: [{{binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'uniform' }}}}, {}]}})".format(
",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'storage' }} }}" for argIdx, _ in enumerate(args)])
",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: '{buf_type(argName)}' }} }}" for argIdx, argName in enumerate(args)])
)
for _, (_, args, _, _) in enumerate(statements)
])
layouts = f"const layouts=[{create_bind_group_layouts}]"
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
_bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], [{', '.join(str(x) for x in global_size)}]);" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
buf_type = lambda x: "createUniformBuf" if x in set(uop.arg[0] for uop in symbolic_vars) else "createEmptyBuf"
map_to_external_weight = lambda _key: f"state_dict['{weight_names[_key]}']" if stream_weights else f"getTensorBuffer(safetensor, metadata['{weight_names[_key]}'])"
_bufs = '\n '.join([f"const {name} = " + (f"{buf_type(_key)}(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, {map_to_external_weight(_key)})") + ";" for name,(size,dtype,_key) in bufs.items()])
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
input_buffer_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
output_buffer_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buffer_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/{bufs[output_names[i]][1].itemsize});\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
return f"""
const {exported_name} = (() => {{
const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}};
const getTensorMetadata = (safetensorBuffer) => {{
getTensorMetadata = f"""\nconst getTensorMetadata = (safetensorBuffer) => {{
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
}};\n""" if not stream_weights else ""
return f"""
const {model_name} = (() => {{
const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}};
{getTensorMetadata}
const createEmptyBuf = (device, size) => {{
return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
}};
const createUniformBuf = (device, size) => {{
return device.createBuffer({{size, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST}})
}}
const createInfinityUniformBuf = (device) => {{
const size = 4;
const buf = device.createBuffer({{
@@ -121,9 +169,8 @@ const createInfinityUniformBuf = (device) => {{
}};
const createWeightBuf = (device, size, data) => {{
const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
new Uint8Array(buf.getMappedRange()).set(data);
buf.unmap();
const buf = device.createBuffer({{ size, usage: GPUBufferUsage.STORAGE{" | GPUBufferUsage.COPY_DST" if stream_weights else ", mappedAtCreation: true"} }});
{"data.bytes = buf;" if stream_weights else "new Uint8Array(buf.getMappedRange()).set(data); buf.unmap();"}
return buf;
}};
@@ -145,8 +192,8 @@ const addComputePass = (device, commandEncoder, pipeline, layout, infinityUnifor
{kernel_code}
const setupNet = async (device, safetensor) => {{
const metadata = getTensorMetadata(safetensor);
const setupNet = async (device, {"state_dict" if stream_weights else "safetensor"}) => {{
{"const metadata = getTensorMetadata(safetensor);" if not stream_weights else ""}
const infinityBuf = createInfinityUniformBuf(device);
{layouts}
@@ -185,12 +232,12 @@ const setupNet = async (device, safetensor) => {{
}}
}}
const load = async (device, weight_path) => {{ return await fetch(weight_path).then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}
return {{ load }};
return {{ load, setupNet }};
}})();
export default {exported_name};
export default {model_name};
"""
def export_model(model, target:str, *inputs, model_name: Optional[str] = None):
def export_model(model, target:str, *inputs, model_name: Optional[str] = "model", stream_weights=False):
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CPU, CUDA, GPU, METAL are supported"
with Context(JIT=2): run,special_names = jit_model(model, *inputs)
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
@@ -198,11 +245,30 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = None):
weight_names = {id(x.lazydata.base.realized): name for name, x in state.items()}
input_names = [name for _,name in special_names.items() if "input" in name]
output_names = [name for _,name in special_names.items() if "output" in name]
# handle symbolic variables; TODO: refactor to fix some of this stuff upstream in tinygrad
symbolic_vars = OrderedDict()
for i, (_, args, global_size, _) in enumerate(statements):
for j, var in enumerate(args):
if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str):
if var not in symbolic_vars:
symbolic_vars[var] = var.arg[0]
bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var])
statements[i][1][j] = symbolic_vars[var]
if global_size:
for j, dim in enumerate(global_size):
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}:
name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src)
global_size[j] = f"_{name.arg[0]}[0] + {val.arg}"
prg = ""
if target == "clang":
prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
elif target == "wasm":
return export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names, weight_names, model_name, symbolic_vars, wasm=True)
elif target == "webgpu":
prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name)
prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars, stream_weights)
else:
prg = json.dumps({
"backend": Device.DEFAULT,