tinygrad/extra/export_model.py
hooved 01f7a4fadc tinychat in browser, Part 2: model export (#9274)
* load llama3-1B to WEBGPU device

* include compile script for loading llama3 to WEBGPU

* parametrize max_context in build_transformer fxn

* jit_model with two different args sets

* compile for webgpu, split weights

* load model weight parts in browser

* export all tensors from initialized transformer

* run transformer inference in browser

* enable tiktoken with llama bpe in browser

* count total tokens on client with tiktoken.js

* full client-side chat streaming, eliminate server

* revert change that enabled jitting with 2 argsets

* llama without Variable or cache_kv, for webgpu

* have client use mask tokens / whole context

* cleanup staged weights

* add tiktoken.js build script, README

* export CLANG for Q6_k to float32 decompression

* fix and test exported CLANG code for Q6_k to fp32

* revert changes to jit and export_model

* isolate clang export

* test Q6_K to float32 decompression in browser

* gguf_load now also returns t_infos and data_start

* prepare llama-1B Q6_K gguf chunks for browser

* cache and decompress quantized llama in browser

* enable separate deployment of large files

* fix kv cache and symbolic with llama wgpu

* eliminate browser lag during decompression

* hash metadata and weight chunks

* delete obsolete indexeddb cache to free disk

* add progress bar, track model download/decompress

* refactor progress callback

* skip buffer hash verification for speed

* Display progress for entire loading scope

* Report page load errors to user

* actually display errors

* skip prompt tokens already seen by model

* skip prefilling with last assistant message tokens

* on page load tell user if webgpu not enabled

* push deployed URL root to window.history

* make note of bug sources with TODO items

* isolate bug in CLANG with BEAM=2

* remove clang_bug.py from diff

* decompress q6k to f32 on webgpu instead of clang

* remove unused code

* inter-weight decomp with larger wgpu kernels

* parallelize decompression submissions

* refactor dequantize scheduling

* add progress bar back

* fix bug

* temp fix for loading GGUF Q6_K to fp16 not fp32

* fix rendering of exported CLANG

* remove weight casts, sketch js functions for clang

* get symbolic vars from jit_cache for model export

* include symbolic vars in exported CLANG

* render js for clang transformer

* toggle clang/webgpu deployment; refactor decomp

* compile and render clang Q6_K->fp16 and int8 quant

* fix rendered clang for abs(fp16), to work in wasm

* simplify clang js wrapping

* run compiled clang in worker

* prepare llama weights in workers, q6k to int8/fp16

* tinychat on clang in browser, f32/int8 weights

* move wasm inference to (now flexible) worker

* don't load redundant embeddings

* modest wasm perf gain with compile flags

* set default backend, enable backend choice/backup

* render symbolic vars in exported WEBGPU

* quantize webgpu llama to int8/f32

* improve UX arising from rendered WEBGPU

* clean up webgpu launch

* new weights split: smaller chunks, tinygrad quant.

* switch webgpu inference to int8 quant

* remove unneeded clang decompression

* eliminate unneeded kv cache transfer to wasm

* use 1 worker for simplified clang decompression

* display launch errors

* refactor: stream load weight chunks to WebGPU

* show loading chunk completion

* quantize embeddings to int8

* test float16 as input for quantization

* webgpu: use f16 source, int8 embed, eliminate q6k

* simplify split weights prep: all from state_dict

* revert change to nn.state.gguf_load

* remove unneeded decompression from webgpu client

* remove unneeded code

* decrease dl chunks from 47 to 16 MiB

* improve stability of webgpu loading on mobile

* autodetect mobile, improve load stability

* refactor: progress closure

* refactor: one unified progress bar

* remove unneeded code

* revert changes to tinygrad core library

* enforce ios18.3 nerfed max buf size

* BEAM=3 webgpu

* cache integrity, mobile save throttling

* improve mobile UX - no autozoom on prompt box

* clang: int8 from f16, remove q6k

* reduce concurrent dls on mobile to 2 for stability

* refactor: wasm backend with stream loading

* prevent race between wasm load and indexedb save

* split wasm kernels into separate modules

* js wrapper for multiple wasm module inference

* revert multi-module wasm to single module

* make mobile wasm load more stable/fast

* refactor: copy weights into wasm without crashes

* fix bug in download queue; increase mobile dls

* refactor exported clang wrapper, split weights

* remove unnecessary code

* greatly improve int8 quant quality with rounding

* eliminate mobile throttling

* increase webgpu context to 4096 tokens

* export webgpu js functions

* enable separate hosted weights for mobile/pc

* enable prompt-thread switching during generation

* stop generation when max_context is reached

* show progress bar for prefill

* tell user if webgpu fails, while wasm loads

* make loading messages more concise

* update font

* revert changes to tinychat python app launch

* cleanup quantization, add scale_dtype param

* cleanup kv cache code

* cleanup compile code

* link tok_embeddings with output in webgpu export

* refactor: export_model webgpu: symbolic vars

* refactor: export_model weight loading

* forgot to commit export_model.py

* change CLANG to CPU

* deal with pylint incorrectly failing tests

* simplify f-strings for older CI python version

* fix pre-python3.12 parser errors

* [Int32Array] not Int32Array

* cleanup webgpu compile after refactor export_model

* refactor WASM export into export_model

* merge WebGPU/WASM compile scripts

* simplify max_contexts for local deployment

* fix parser issues and whitespace

* deduplicate variable defs for non-wasm clang export

* cleanup code

* cleanup compile scripts

* simplify wasm inference wrapping

* simplify webgpu symbolic vars export

* refactor: unify export of symbolic variables

* simplify WASM export

* simplify clang/wasm export

* update README and build scripts

* separate files for browser/python apps

* restore original python tinychat app files

* browser and python tinychats share assets

* minor cleanup

* isolate compile/export model

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2025-03-04 15:53:30 +08:00

300 lines
16 KiB
Python
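
"""Export a TinyJit-traced tinygrad model as a standalone artifact.

Targets: "clang" (C source with the weights embedded as byte arrays), "wasm" (C source plus a
JS wrapper that expects an Emscripten-style module exposing _net/_malloc/HEAPU8), "webgpu"
(a JS module that replays the captured kernel graph with WebGPU), or, for any other supported
backend, a JSON description of the kernels, statements, and buffers.
"""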

from typing import Tuple, Dict, List, Optional
from collections import OrderedDict
import json
from tinygrad.dtype import DType, dtypes
from tinygrad.renderer import ProgramSpec
from tinygrad.tensor import Device, Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad.nn.state import get_state_dict
from tinygrad.helpers import Context
from tinygrad.ops import Ops

EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "GPU"]

def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
  functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
  for ji in run.jit_cache:
    fxn: ProgramSpec = ji.prg.p
    functions[fxn.function_name] = fxn.src # NOTE: this assumes all with the same name are the same
    cargs = []
    for i,arg in enumerate(ji.bufs):
      key = id(arg)
      if key not in bufs:
        if key in special_names:
          bufs[key] = (special_names[key], arg.size*arg.dtype.itemsize, arg.dtype, key)
        else:
          bufs[key] = (f"buf_{bufnum}", arg.size*arg.dtype.itemsize, arg.dtype, key)
          bufnum += 1
          if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
      cargs.append(bufs[key][0])
    cargs += [var for var in fxn.vars if getattr(var, "op", None) is Ops.DEFINE_VAR] # symbolic vars; is it necessary or sufficient to check for DEFINE_VAR?
    statements.append((fxn.function_name, cargs, fxn.global_size, fxn.local_size))
  return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save

def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
  assert hasattr(model, "forward") or callable(model), "model needs a forward function"
  @TinyJit
  def run(*x):
    out = model.forward(*x) if hasattr(model, "forward") else model(*x)
    assert isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor), "model output must be a Tensor, tuple, or a list of Tensors for export"
    out = [out] if isinstance(out, Tensor) else out
    return [o.realize() for o in out]

  # twice to run the JIT
  for _ in range(2): the_output = run(*args)
  special_names = {}

  # hack to put the inputs back
  for (j,i),idx in run.input_replace.items():
    realized_input = args[idx].lazydata.base.realized
    run.jit_cache[j].bufs[i] = realized_input
    special_names[id(realized_input)] = f'input{idx}'

  # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
  for i, output in enumerate(the_output):
    special_names[id(output.lazydata.base.realized)] = f'output{i}'
  return run, special_names
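
# A minimal sketch of how the two helpers above are typically chained (the lambda model here
# is hypothetical, and this assumes a backend from EXPORT_SUPPORTED_DEVICE is active):
#
#   run, special_names = jit_model(lambda x: (x @ Tensor.ones(4, 4)).relu(), Tensor.randn(1, 4))
#   functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
#
# special_names maps realized buffer ids to "input0"/"output0"-style names, and compile_net
# flattens the jit cache into kernel sources, call statements, and buffer descriptions.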

def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]],
                       bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str], weight_names={}, model_name="model", symbolic_vars={}, wasm=False) -> str:
  headers = ["#include <tgmath.h>"]
  cprog = list(functions.values())
  dtype_map = {dtypes.int: "int", dtypes.float: "float", dtypes.uchar: "unsigned char", dtypes.char: "signed char", dtypes.half: "__fp16", dtypes.uint: "unsigned int"}
  inputs = [(name, dtype_map[bufs[name][1]], bufs[name][0]) for name in input_names + list(symbolic_vars.values())]
  outputs = [(name, dtype_map[bufs[name][1]], bufs[name][0]) for name in output_names]
  forward_args = ",".join(f"{dtype}{'*' if name not in symbolic_vars.values() else ''} {name}" for name,dtype,_ in (outputs+inputs if wasm else inputs+outputs))
  if not wasm:
    for name,cl in bufs_to_save.items():
      weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
      cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
    cprog += [f"{dtype_map[dtype]} {name}[{len}];" if name not in bufs_to_save else f"{dtype_map[dtype]} *{name} = ({dtype_map[dtype]} *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in input_names+output_names]
    cprog += [f"void net({forward_args}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
    return '\n'.join(headers + cprog)
  else:
    if bufs_to_save:
      headers += ["#include <stddef.h>"]
      bufs_to_save = {k:v for k,v in bufs.items() if v[2] in weight_names} # causes random seeds to be set as zeroes, not exported as a model weight
      buf_to_name = OrderedDict((buf_name, {"name": weight_names[data[2]], "idx": i}) for i, (buf_name, data) in enumerate(bufs_to_save.items()))
      cprog.append(f"void* bufs[{len(buf_to_name)}];")
      cprog.append(f"""void set_buf(size_t index, void* ptr) {{\n bufs[index] = ptr;\n}}""")
    for name in set(bufs.keys()) - set(bufs_to_save.keys()) - set(input_names + output_names):
      n_bytes, dtype, _ = bufs[name]
      cprog += [f"{dtype_map[dtype]} {name}[{n_bytes // dtype.itemsize}];"]
    cprog += [f"void net({forward_args})"] + ["{"]
    get_weight_ptr = lambda x: f"({dtype_map[bufs_to_save[x][1]]} *)bufs[{buf_to_name[x]['idx']}]" if x in bufs_to_save else x
    cprog += [f" {name}({', '.join(map(get_weight_ptr, args))});" for (name, args, _global_size, _local_size) in statements] + ["}"]

    weightMapping = "" if not bufs_to_save else f"""\nconst weightNames = [{", ".join([f'"{weight_name}"' for weight_name in [v["name"] for v in buf_to_name.values()]])}];
const {model_name}_name_to_id = Object.fromEntries(weightNames.map((name, index) => [name, index]));\n"""
    top = f"""import {model_name}Module from './{model_name}.js'{weightMapping}"""
    whitespace = "\n "
    js_wrapper = f"""{top}\nvar {model_name} = async function() {{
  const wasm = await {model_name}Module();
  {whitespace.join(f"const {name}Ptr = wasm._malloc({n_bytes});" for name, _, n_bytes in outputs+inputs if name not in symbolic_vars.values())}
  return {{
    run: ({",".join(name for name,_,_ in inputs)}) => {{
      {(whitespace + " ").join(f"wasm.HEAPU8.set({name}, {name}Ptr);" for name,_,_ in inputs if name not in symbolic_vars.values())}
      wasm._net({", ".join(f"{name}{'Ptr' if name not in symbolic_vars.values() else ''}" for name,_,_ in outputs+inputs)});
      {(whitespace + " ").join(f"const {name} = wasm.HEAPU8.slice({name}Ptr, {name}Ptr + {n_bytes});" for name,_,n_bytes in outputs)}
      return [{", ".join(f"{name}" for name,_,_ in outputs)}];
    }},
    wasm: wasm
  }}
}}\nexport {{ {model_name}, {model_name}_name_to_id }};"""
    return '\n'.join(headers + cprog), js_wrapper

def dtype_to_js_type(dtype: DType) -> str:
  return f"{'Uint' if dtype in dtypes.uints else 'Int' if (dtype in dtypes.sints or dtype == dtypes.bool) else 'Float'}{8*dtype.itemsize}Array"

def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars={}, stream_weights=False) -> str:
  kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
  kernel_names = ', '.join([name for (name, _, _, _) in statements])
  input_names += list(symbolic_vars.values())
  input_buffer_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
  output_buffer_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
  buf_type = lambda x: "uniform" if x in set(symbolic_vars.values()) else "storage"
  create_bind_group_layouts = ",".join([
    "device.createBindGroupLayout({{entries: [{{binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'uniform' }}}}, {}]}})".format(
      ",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: '{buf_type(argName)}' }} }}" for argIdx, argName in enumerate(args)])
    )
    for _, (_, args, _, _) in enumerate(statements)
  ])
  layouts = f"const layouts=[{create_bind_group_layouts}]"
  kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], [{', '.join(str(x) for x in global_size)}]);" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
  buf_type = lambda x: "createUniformBuf" if x in set(uop.arg[0] for uop in symbolic_vars) else "createEmptyBuf"
  map_to_external_weight = lambda _key: f"state_dict['{weight_names[_key]}']" if stream_weights else f"getTensorBuffer(safetensor, metadata['{weight_names[_key]}'])"
  _bufs = '\n '.join([f"const {name} = " + (f"{buf_type(_key)}(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, {map_to_external_weight(_key)})") + ";" for name,(size,dtype,_key) in bufs.items()])
  gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
  input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buffer_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
  gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
  outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
  output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/{bufs[output_names[i]][1].itemsize});\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
  output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
  getTensorMetadata = f"""\nconst getTensorMetadata = (safetensorBuffer) => {{
  const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
  const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
  return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
}};\n""" if not stream_weights else ""

  return f"""
const {model_name} = (() => {{
  const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
    return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
  }};
{getTensorMetadata}
  const createEmptyBuf = (device, size) => {{
    return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
  }};

  const createUniformBuf = (device, size) => {{
    return device.createBuffer({{size, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST}})
  }}

  const createInfinityUniformBuf = (device) => {{
    const size = 4;
    const buf = device.createBuffer({{
      mappedAtCreation: true,
      size,
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
    }});
    new Float32Array(buf.getMappedRange())[0] = Infinity;
    buf.unmap();
    return buf;
  }};

  const createWeightBuf = (device, size, data) => {{
    const buf = device.createBuffer({{ size, usage: GPUBufferUsage.STORAGE{" | GPUBufferUsage.COPY_DST" if stream_weights else ", mappedAtCreation: true"} }});
    {"data.bytes = buf;" if stream_weights else "new Uint8Array(buf.getMappedRange()).set(data); buf.unmap();"}
    return buf;
  }};

  const addComputePass = (device, commandEncoder, pipeline, layout, infinityUniformBuf, bufs, workgroup) => {{
    const bindGroup = device.createBindGroup({{
      layout: layout,
      entries: [
        {{ binding: 0, resource: {{ buffer: infinityUniformBuf }} }},
        ...bufs.map((buffer, index) => ({{ binding: index + 1, resource: {{ buffer }} }}))
      ]
    }});
    const passEncoder = commandEncoder.beginComputePass();
    passEncoder.setPipeline(pipeline);
    passEncoder.setBindGroup(0, bindGroup);
    passEncoder.dispatchWorkgroups(...workgroup);
    passEncoder.end();
  }};

{kernel_code}

  const setupNet = async (device, {"state_dict" if stream_weights else "safetensor"}) => {{
    {"const metadata = getTensorMetadata(safetensor);" if not stream_weights else ""}
    const infinityBuf = createInfinityUniformBuf(device);
    {layouts}
    {_bufs}
    {gpu_write_bufs}
    {gpu_read_bufs}
    const kernels = [{kernel_names}];
    const pipelines = await Promise.all(kernels.map(async (name, i) => {{
      return await device.createComputePipelineAsync({{
        layout: device.createPipelineLayout({{
          bindGroupLayouts: [layouts[i]],
        }}),
        compute: {{
          module: device.createShaderModule({{
            code: name,
          }}),
          entryPoint: "main",
        }},
      }});
    }}))
    return async ({",".join([f"_{input_name}" for input_name in input_names])}) => {{
      const commandEncoder = device.createCommandEncoder();
      {input_writers}
      {kernel_calls}
      {outbuf_copies}
      const gpuCommands = commandEncoder.finish();
      device.queue.submit([gpuCommands]);
      {output_readers}
      return {output_return};
    }}
  }}
  const load = async (device, weight_path) => {{ return await fetch(weight_path).then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}
  return {{ load, setupNet }};
}})();
export default {model_name};
"""

def export_model(model, target:str, *inputs, model_name: Optional[str] = "model", stream_weights=False):
  assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CPU, CUDA, GPU are supported"
  with Context(JIT=2): run,special_names = jit_model(model, *inputs)
  functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
  state = get_state_dict(model)
  weight_names = {id(x.lazydata.base.realized): name for name, x in state.items()}
  input_names = [name for _,name in special_names.items() if "input" in name]
  output_names = [name for _,name in special_names.items() if "output" in name]

  # handle symbolic variables; TODO: refactor to fix some of this stuff upstream in tinygrad
  symbolic_vars = OrderedDict()
  for i, (_, args, global_size, _) in enumerate(statements):
    for j, var in enumerate(args):
      if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str):
        if var not in symbolic_vars:
          symbolic_vars[var] = var.arg[0]
          bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var])
        statements[i][1][j] = symbolic_vars[var]
    if global_size:
      for j, dim in enumerate(global_size):
        if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}:
          name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src)
          global_size[j] = f"_{name.arg[0]}[0] + {val.arg}"

  prg = ""
  if target == "clang":
    prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
  elif target == "wasm":
    return export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names, weight_names, model_name, symbolic_vars, wasm=True)
  elif target == "webgpu":
    prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars, stream_weights)
  else:
    prg = json.dumps({
      "backend": Device.DEFAULT,
      "inputs": [{
        "size": bufs[name][0],
        "dtype": bufs[name][1].name
      } for name in input_names],
      "outputs": [{
        "size": bufs[name][0],
        "dtype": bufs[name][1].name
      } for name in output_names],
      "functions": functions,
      "statements": [{
        "kernel": kernel,
        "args": args,
        "global_size": global_size,
        "local_size": local_size
      } for (kernel, args, global_size, local_size) in statements],
      "buffers": {
        name: {
          "size": size,
          "dtype": dtype.name,
          "id": weight_names[_key] if _key in weight_names else ""
        } for name, (size,dtype,_key) in bufs.items() if name not in ["input", "outputs"]
      }
    })
  return prg, {input:bufs[input][0] for input in input_names}, {output:bufs[output][0] for output in output_names}, state
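
# Usage sketch (illustrative): TwoLayer and the output path below are made-up examples, and
# this assumes the script runs with a backend from EXPORT_SUPPORTED_DEVICE selected
# (e.g. CPU=1 for the "clang" target).
if __name__ == "__main__":
  from tinygrad import nn

  class TwoLayer:
    def __init__(self):
      self.l1, self.l2 = nn.Linear(16, 32), nn.Linear(32, 8)
    def forward(self, x: Tensor) -> Tensor:
      return self.l2(self.l1(x).relu())

  prg, input_sizes, output_sizes, state = export_model(TwoLayer(), "clang", Tensor.randn(1, 16))
  print(f"inputs: {input_sizes}, outputs: {output_sizes}, weights: {len(state)} tensors")
  with open("/tmp/twolayer.c", "w") as f: f.write(prg)  # C source with the weights embedded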