mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00
tinychat in browser, Part 2: model export (#9274)

* load llama3-1B to WEBGPU device
* include compile script for loading llama3 to WEBGPU
* parametrize max_context in build_transformer fxn
* jit_model with two different args sets
* compile for webgpu, split weights
* load model weight parts in browser
* export all tensors from initialized transformer
* run transformer inference in browser
* enable tiktoken with llama bpe in browser
* count total tokens on client with tiktoken.js
* full client-side chat streaming, eliminate server
* revert change that enabled jitting with 2 argsets
* llama without Variable or cache_kv, for webgpu
* have client use mask tokens / whole context
* cleanup staged weights
* add tiktoken.js build script, README
* export CLANG for Q6_k to float32 decompression
* fix and test exported CLANG code for Q6_k to fp32
* revert changes to jit and export_model
* isolate clang export
* test Q6_K to float32 decompression in browser
* gguf_load now also returns t_infos and data_start
* prepare llama-1B Q6_K gguf chunks for browser
* cache and decompress quantized llama in browser
* enable separate deployment of large files
* fix kv cache and symbolic with llama wgpu
* eliminate browser lag during decompression
* hash metadata and weight chunks
* delete obsolete indexeddb cache to free disk
* add progress bar, track model download/decompress
* refactor progress callback
* skip buffer hash verification for speed
* Display progress for entire loading scope
* Report page load errors to user
* actually display errors
* skip prompt tokens already seen by model
* skip prefilling with last assistant message tokens
* on page load tell user if webgpu not enabled
* push deployed URL root to window.history
* make note of bug sources with TODO items
* isolate bug in CLANG with BEAM=2
* remove clang_bug.py from diff
* decompress q6k to f32 on webgpu instead of clang
* remove unused code
* inter-weight decomp with larger wgpu kernels
* parallelize decompression submissions
* refactor dequantize scheduling
* add progress bar back
* fix bug
* temp fix for loading GGUF Q6_K to fp16 not fp32
* fix rendering of exported CLANG
* remove weight casts, sketch js functions for clang
* get symbolic vars from jit_cache for model export
* include symbolic vars in exported CLANG
* render js for clang transformer
* toggle clang/webgpu deployment; refactor decomp
* compile and render clang Q6_K->fp16 and int8 quant
* fix rendered clang for abs(fp16), to work in wasm
* simplify clang js wrapping
* run compiled clang in worker
* prepare llama weights in workers, q6k to int8/fp16
* tinychat on clang in browser, f32/int8 weights
* move wasm inference to (now flexible) worker
* don't load redundant embeddings
* modest wasm perf gain with compile flags
* set default backend, enable backend choice/backup
* render symbolic vars in exported WEBGPU
* quantize webgpu llama to int8/f32
* improve UX arising from rendered WEBGPU
* clean up webgpu launch
* new weights split: smaller chunks, tinygrad quant.
* switch webgpu inference to int8 quant
* remove unneeded clang decompression
* eliminate unneeded kv cache transfer to wasm
* use 1 worker for simplified clang decompression
* display launch errors
* refactor: stream load weight chunks to WebGPU
* show loading chunk completion
* quantize embeddings to int8
* test float16 as input for quantization
* webgpu: use f16 source, int8 embed, eliminate q6k
* simplify split weights prep: all from state_dict
* revert change to nn.state.gguf_load
* remove unneeded decompression from webgpu client
* remove unneeded code
* decrease dl chunks from 47 to 16 MiB
* improve stability of webgpu loading on mobile
* autodetect mobile, improve load stability
* refactor: progress closure
* refactor: one unified progress bar
* remove unneeded code
* revert changes to tinygrad core library
* enforce ios18.3 nerfed max buf size
* BEAM=3 webgpu
* cache integrity, mobile save throttling
* improve mobile UX - no autozoom on prompt box
* clang: int8 from f16, remove q6k
* reduce concurrent dls on mobile to 2 for stability
* refactor: wasm backend with stream loading
* prevent race between wasm load and indexedb save
* split wasm kernels into separate modules
* js wrapper for multiple wasm module inference
* revert multi-module wasm to single module
* make mobile wasm load more stable/fast
* refactor: copy weights into wasm without crashes
* fix bug in download queue; increase mobile dls
* refactor exported clang wrapper, split weights
* remove unnecessary code
* greatly improve int8 quant quality with rounding
* eliminate mobile throttling
* increase webgpu context to 4096 tokens
* export webgpu js functions
* enable separate hosted weights for mobile/pc
* enable prompt-thread switching during generation
* stop generation when max_context is reached
* show progress bar for prefill
* tell user if webgpu fails, while wasm loads
* make loading messages more concise
* update font
* revert changes to tinychat python app launch
* cleanup quantization, add scale_dtype param
* cleanup kv cache code
* cleanup compile code
* link tok_embeddings with output in webgpu export
* refactor: export_model webgpu: symbolic vars
* refactor: export_model weight loading
* forgot to commit export_model.py
* change CLANG to CPU
* deal with pylint incorrectly failing tests
* simplify f-strings for older CI python version
* fix pre-python3.12 parser errors
* [Int32Array] not Int32Array
* cleanup webgpu compile after refactor export_model
* refactor WASM export into export_model
* merge WebGPU/WASM compile scripts
* simplify max_contexts for local deployment
* fix parser issues and whitespace
* deduplicate variable defs for non-wasm clang export
* cleanup code
* cleanup compile scripts
* simplify wasm inference wrapping
* simplify webgpu symbolic vars export
* refactor: unify export of symbolic variables
* simplify WASM export
* simplify clang/wasm export
* update README and build scripts
* separate files for browser/python apps
* restore original python tinychat app files
* browser and python tinychats share assets
* minor cleanup
* isolate compile/export model

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
examples/tinychat/tinychat-browser/README.md (new file, +8 lines)
@@ -0,0 +1,8 @@
# How to build and run tinychat in browser (WebGPU and WASM)
- `PYTHONPATH=. python examples/tinychat/tinychat-browser/compile.py`
- `./examples/tinychat/tinychat-browser/compile_wasm.sh`
  - Prerequisite: [install emscripten](https://emscripten.org/docs/getting_started/downloads.html). This script looks for `~/emsdk/emsdk_env.sh`; adjust this based on your installation.
- `./examples/tinychat/tinychat-browser/make_tiktoken_js.sh`
  - Prerequisite: install `npm` and `webpack`.
- `cd examples/tinychat && python -m http.server 7776`
- In browser, open either `localhost:7776/tinychat-browser` (WebGPU) or `localhost:7776/tinychat-browser/?backend=wasm` (WASM)
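The scripts above write their outputs into `examples/tinychat/tinychat-browser/`. As a quick sanity check before serving, something like the following hypothetical helper (not part of this PR) can confirm the build produced the files the browser client will request; the names are taken from `compile.py` and `compile_wasm.sh` below, and the tiktoken.js bundle from `make_tiktoken_js.sh` is omitted because that script is not shown in this diff.

# check_build.py -- hypothetical helper, not part of this PR.
# Confirms that compile.py and compile_wasm.sh produced their expected outputs.
import json, os, sys

BROWSER_DIR = "examples/tinychat/tinychat-browser"  # run from the repo root

expected = [
  "llama3-2.tiktoken",   # BPE ranks dumped for tiktoken.js
  "net_metadata.json",   # weight/chunk metadata from prepare_browser_chunks
  "net.js",              # exported WebGPU model
  "net_clang.js",        # JS wrapper for the WASM export
  "transformer.c",       # C program fed to emcc
  "transformer.js",      # emcc glue code (emcc also emits transformer.wasm alongside it)
  "transformer.wasm",
]
missing = [name for name in expected if not os.path.exists(os.path.join(BROWSER_DIR, name))]

if "net_metadata.json" not in missing:
  # every weight chunk listed in the metadata should exist too
  with open(os.path.join(BROWSER_DIR, "net_metadata.json")) as f:
    files = json.load(f)["metadata"]["files"]
  missing += [fi["name"] for fi in files if not os.path.exists(os.path.join(BROWSER_DIR, fi["name"]))]

if missing: sys.exit(f"missing build outputs: {missing}")
print("all expected build outputs are present")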
examples/tinychat/tinychat-browser/compile.py (new file, +123 lines)
@@ -0,0 +1,123 @@
import os, json, hashlib, math
from extra.export_model import export_model
from examples.llama3 import build_transformer
from tinygrad.nn.state import get_state_dict, load_state_dict
from tinygrad import Device, Variable, Tensor, dtypes
from tinygrad.helpers import fetch, Context
from tiktoken.load import load_tiktoken_bpe, dump_tiktoken_bpe

def prepare_browser_chunks(model):
  # split weights into browser-friendly chunks
  state_dict = get_state_dict(model)
  del state_dict['output.weight'], state_dict['output.scale'] # same as tok_embeddings; ensures consistency with model export
  chunk_size = 16 * 1024 * 1024 # small chunks based on iphone browser constraints
  metadata = {}
  # We won't export cache_kv bytes (because we start inference on client at start_pos=0), but we will tell the client how big cache_kv needs to be
  t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" not in k]
  empty_t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" in k]

  split_t_infos = []
  for size, name, dtype in t_infos:
    if size <= chunk_size:
      split_t_infos.append((size, name, dtype, ()))
    else: # split large weights into multiple parts
      for i in range(0, size, chunk_size):
        split_t_infos.append((min(chunk_size, size-i), f"{name}_part{math.ceil(i/chunk_size)}", dtype, (i, min(i+chunk_size, size))))

  files = []
  # pack weights into files with FFD bin packing
  split_t_infos = sorted(split_t_infos, reverse=True)
  for info in split_t_infos:
    placed = False
    for file in files:
      if sum(i[0] for i in file) + info[0] <= chunk_size:
        if info[3] and any(i[3] for i in file): continue # no two split tensors can touch the same file, due to wasm loading constraints
        file.append(info)
        placed = True
        break
    if not placed:
      files.append([info])

  tinygrad_dtypes = {dtypes.float32: "float32", dtypes.float16: "float16", dtypes.int8: "int8", dtypes.int32: "int32"}
  for i, file in enumerate(files):
    cursor = 0
    with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.chunk'), "wb+") as writer:
      for size, name, dtype, offsets in file:
        name, part_num = (name, 0) if "_part" not in name else (name.split("_part")[0], int(name.split("_part")[1]))
        default = {"parts": {}, "dtype": tinygrad_dtypes[dtype]}
        weight_metadata = metadata.get(name, default)
        weight_metadata["parts"][part_num] = {"file": i, "file_start_pos": cursor, "size": size}
        metadata[name] = weight_metadata
        data = bytes(state_dict[name].lazydata.base.realized.as_buffer())
        data = data if not offsets else data[offsets[0]:offsets[1]]
        writer.write(data)
        cursor += size

  metadata.update({name: {"parts": {0: {"empty": True, "size": size}}, "dtype": tinygrad_dtypes[dtype]} for size, name, dtype in empty_t_infos})

  for k in metadata:
    metadata[k]["parts"] = [part for part_num, part in sorted(metadata[k]["parts"].items(), key = lambda x: x[0])]
    cursor = 0
    for i, part in enumerate(metadata[k]["parts"]):
      metadata[k]["parts"][i]["target_start_pos"] = cursor
      cursor += part["size"]
    metadata[k]["size"] = cursor

  # compute hashes, which client app will check to determine whether to update with new weights and/or detect integrity issues
  state_dict_hash = hashlib.sha256(json.dumps(metadata, sort_keys=True).encode("utf-8")).hexdigest()
  metadata = {"state_dict": metadata, "state_dict_hash": state_dict_hash, "files": []}
  for i in range(len(files)):
    with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.chunk'), "rb") as reader:
      metadata["files"].append({"name": f'net_part{i}.chunk', "hash": hashlib.sha256(reader.read()).hexdigest()})
  metadata_hash = hashlib.sha256(json.dumps(metadata, sort_keys=True).encode("utf-8")).hexdigest()
  metadata = {"metadata": metadata, "metadata_hash": metadata_hash}

  with open(os.path.join(os.path.dirname(__file__), f'./net_metadata.json'), "w") as writer: json.dump(metadata, writer, indent=4)
  return metadata

if __name__=="__main__":
  # Export BPE data for use with tiktoken.js
  tokenizer_path = fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-1b-instruct")
  mergeable_ranks = load_tiktoken_bpe(str(tokenizer_path))
  bpe_path = os.path.join(os.path.dirname(__file__), "llama3-2.tiktoken")
  dump_tiktoken_bpe(mergeable_ranks, bpe_path)

  model_path = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-f16.gguf", "Llama-3.2-1B-Instruct-f16.gguf", subdir="llama3-1b-instruct")
  Tensor.no_grad = True
  max_context=1024
  tok = 128000
  TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P = 0.95, 0, 0.0, 0.0, 0.0
  start_pos = Variable("start_pos", 0, max_context).bind(0)
  model_input = lambda: [Tensor([[tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P]

  Device.DEFAULT="CPU"
  model = build_transformer(model_path, model_size="1B", quantize="int8", scale_dtype=dtypes.float32, device=Device.DEFAULT, max_context=max_context)
  state_dict = get_state_dict(model)
  out = model.forward(*model_input())
  model_name = "transformer"

  with Context(BEAM=3):
    cprog, js_wrapper = export_model(model, "wasm", *model_input(), model_name=model_name)
    # ensure consistency with exported weights
    js_wrapper = js_wrapper.replace("output.weight", "tok_embeddings.weight").replace("output.scale", "tok_embeddings.scale")

  with open(os.path.join(os.path.dirname(__file__), f"{model_name}.c"), "w") as f: f.write(cprog)
  with open(os.path.join(os.path.dirname(__file__), "net_clang.js"), "w") as f: f.write(js_wrapper)

  Device.DEFAULT="WEBGPU"
  # float16 is not yet supported for dawn/Vulkan/NVIDIA stack, see: https://issues.chromium.org/issues/42251215
  # therefore for now, we used CLANG to quantize the float16 llama to int8 with float32 scales, then load to WEBGPU
  model = build_transformer(model_path, model_size="1B", quantize="int8", max_context=max_context, load_weights=False)
  load_state_dict(model, state_dict)
  # these were the same before load_state_dict
  model.output.weight, model.output.scale = model.tok_embeddings.weight, model.tok_embeddings.scale

  out = model.forward(*model_input())
  metadata = prepare_browser_chunks(model) # export weights to disk

  with Context(BEAM=3):
    prg, input_sizes, output_sizes, state = export_model(model, "webgpu", *model_input(), model_name=model_name, stream_weights=True)
    # ensure consistency with exported weights
    prg = prg.replace("output.weight", "tok_embeddings.weight").replace("output.scale", "tok_embeddings.scale")

  with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as f: f.write(prg)
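To make the on-disk format concrete: `prepare_browser_chunks` records, for every weight, a list of parts with `file`, `file_start_pos`, `size`, and `target_start_pos` (plus `empty` parts for the cache_kv buffers), and a per-chunk sha256 in `files`. The following is a hedged, hypothetical sketch (not part of the PR; the `load_weight` helper and paths are illustrative) of how a client or test could verify the chunks and reassemble one weight from its parts.

# reassemble_weight.py -- hypothetical illustration, not part of this PR.
# Mirrors what the browser client does with net_metadata.json: verify each chunk's
# hash, then stitch a weight's bytes back together from its (possibly split) parts.
import hashlib, json, os

d = "examples/tinychat/tinychat-browser"
meta = json.load(open(os.path.join(d, "net_metadata.json")))["metadata"]

# integrity check: every net_part{i}.chunk must match the sha256 recorded by compile.py
for finfo in meta["files"]:
  blob = open(os.path.join(d, finfo["name"]), "rb").read()
  assert hashlib.sha256(blob).hexdigest() == finfo["hash"], f"corrupt chunk {finfo['name']}"

def load_weight(name: str) -> bytes:
  info = meta["state_dict"][name]
  buf = bytearray(info["size"])
  for part in info["parts"]:
    if part.get("empty"): continue  # cache_kv entries ship no bytes; the client just allocates them
    with open(os.path.join(d, f"net_part{part['file']}.chunk"), "rb") as f:
      f.seek(part["file_start_pos"])
      buf[part["target_start_pos"]:part["target_start_pos"] + part["size"]] = f.read(part["size"])
  return bytes(buf)

emb = load_weight("tok_embeddings.weight")
print(f"tok_embeddings.weight: {len(emb)} bytes, dtype {meta['state_dict']['tok_embeddings.weight']['dtype']}")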
examples/tinychat/tinychat-browser/compile_wasm.sh (new executable file, +23 lines)
@@ -0,0 +1,23 @@
#!/usr/bin/env bash
cd "$(dirname "$0")"

# prereq: install emscripten: https://emscripten.org/docs/getting_started/downloads.html
EMSCRIPTEN_PATH=~/emsdk/emsdk_env.sh
source $EMSCRIPTEN_PATH
step="transformer"
initial_memory=6553600
max_memory=1500053504
exported_functions='["_net", "_malloc", "_free", "_set_buf"]'

emcc "${step}.c" \
  -O3 -msimd128 -ffast-math -flto \
  -o "${step}.js" \
  -s MODULARIZE=1 \
  -s EXPORT_ES6=1 \
  -s EXPORTED_FUNCTIONS="${exported_functions}" \
  -s ENVIRONMENT='worker' \
  -s FILESYSTEM=0 \
  -s EVAL_CTORS \
  -s ALLOW_MEMORY_GROWTH=1 \
  -s INITIAL_MEMORY="$initial_memory" \
  -s MAXIMUM_MEMORY="$max_memory"
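The memory numbers above look arbitrary but are simply multiples of the 64 KiB WebAssembly page size, which emscripten expects for INITIAL_MEMORY and MAXIMUM_MEMORY: the module starts at 6.25 MiB and, with ALLOW_MEMORY_GROWTH=1, may grow to roughly 1.4 GiB, presumably sized to hold the streamed llama weights plus working buffers. A quick arithmetic check, as a hedged sketch:

# hypothetical sketch, not part of this PR: why initial_memory / max_memory take those values.
# WebAssembly linear memory grows in 64 KiB pages, and emscripten expects
# INITIAL_MEMORY / MAXIMUM_MEMORY to be multiples of that page size.
WASM_PAGE = 64 * 1024

initial_memory = 6553600      # from compile_wasm.sh
max_memory = 1500053504       # from compile_wasm.sh

assert initial_memory % WASM_PAGE == 0 and max_memory % WASM_PAGE == 0
print(initial_memory // WASM_PAGE, "pages ->", initial_memory / 2**20, "MiB initial")    # 100 pages, 6.25 MiB
print(max_memory // WASM_PAGE, "pages ->", round(max_memory / 2**30, 2), "GiB maximum")  # 22889 pages, ~1.4 GiB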