tinygrad/examples/tinychat/tinychat-browser/compile.py
hooved 304afe0d55 tinychat in browser, Part 3: browser app (#9276)
* load llama3-1B to WEBGPU device

* include compile script for loading llama3 to WEBGPU

* parametrize max_context in build_transformer fxn

* jit_model with two different args sets

* compile for webgpu, split weights

* load model weight parts in browser

* export all tensors from initialized transformer

* run transformer inference in browser

* enable tiktoken with llama bpe in browser

* count total tokens on client with tiktoken.js

* full client-side chat streaming, eliminate server

* revert change that enabled jitting with 2 argsets

* llama without Variable or cache_kv, for webgpu

* have client use mask tokens / whole context

* cleanup staged weights

* add tiktoken.js build script, README

* export CLANG for Q6_K to float32 decompression

* fix and test exported CLANG code for Q6_K to fp32

* revert changes to jit and export_model

* isolate clang export

* test Q6_K to float32 decompression in browser

* gguf_load now also returns t_infos and data_start

* prepare llama-1B Q6_K gguf chunks for browser

* cache and decompress quantized llama in browser

* enable separate deployment of large files

* fix kv cache and symbolic with llama wgpu

* eliminate browser lag during decompression

* hash metadata and weight chunks

* delete obsolete indexeddb cache to free disk

* add progress bar, track model download/decompress

* refactor progress callback

* skip buffer hash verification for speed

* Display progress for entire loading scope

* Report page load errors to user

* actually display errors

* skip prompt tokens already seen by model

* skip prefilling with last assistant message tokens

* on page load tell user if webgpu not enabled

* push deployed URL root to window.history

* make note of bug sources with TODO items

* isolate bug in CLANG with BEAM=2

* remove clang_bug.py from diff

* decompress q6k to f32 on webgpu instead of clang

* remove unused code

* inter-weight decomp with larger wgpu kernels

* parallelize decompression submissions

* refactor dequantize scheduling

* add progress bar back

* fix bug

* temp fix for loading GGUF Q6_K to fp16 not fp32

* fix rendering of exported CLANG

* remove weight casts, sketch js functions for clang

* get symbolic vars from jit_cache for model export

* include symbolic vars in exported CLANG

* render js for clang transformer

* toggle clang/webgpu deployment; refactor decomp

* compile and render clang Q6_K->fp16 and int8 quant

* fix rendered clang for abs(fp16), to work in wasm

* simplify clang js wrapping

* run compiled clang in worker

* prepare llama weights in workers, q6k to int8/fp16

* tinychat on clang in browser, f32/int8 weights

* move wasm inference to (now flexible) worker

* don't load redundant embeddings

* modest wasm perf gain with compile flags

* set default backend, enable backend choice/backup

* render symbolic vars in exported WEBGPU

* quantize webgpu llama to int8/f32

* improve UX arising from rendered WEBGPU

* clean up webgpu launch

* new weights split: smaller chunks, tinygrad quant.

* switch webgpu inference to int8 quant

* remove unneeded clang decompression

* eliminate unneeded kv cache transfer to wasm

* use 1 worker for simplified clang decompression

* display launch errors

* refactor: stream load weight chunks to WebGPU

* show loading chunk completion

* quantize embeddings to int8

* test float16 as input for quantization

* webgpu: use f16 source, int8 embed, eliminate q6k

* simplify split weights prep: all from state_dict

* revert change to nn.state.gguf_load

* remove unneeded decompression from webgpu client

* remove unneeded code

* decrease dl chunks from 47 to 16 MiB

* improve stability of webgpu loading on mobile

* autodetect mobile, improve load stability

* refactor: progress closure

* refactor: one unified progress bar

* remove unneeded code

* revert changes to tinygrad core library

* enforce ios18.3 nerfed max buf size

* BEAM=3 webgpu

* cache integrity, mobile save throttling

* improve mobile UX - no autozoom on prompt box

* clang: int8 from f16, remove q6k

* reduce concurrent dls on mobile to 2 for stability

* refactor: wasm backend with stream loading

* prevent race between wasm load and indexeddb save

* split wasm kernels into separate modules

* js wrapper for multiple wasm module inference

* revert multi-module wasm to single module

* make mobile wasm load more stable/fast

* refactor: copy weights into wasm without crashes

* fix bug in download queue; increase mobile dls

* refactor exported clang wrapper, split weights

* remove unnecessary code

* greatly improve int8 quant quality with rounding

* eliminate mobile throttling

* increase webgpu context to 4096 tokens

* export webgpu js functions

* enable separate hosted weights for mobile/pc

* enable prompt-thread switching during generation

* stop generation when max_context is reached

* show progress bar for prefill

* tell user if webgpu fails, while wasm loads

* make loading messages more concise

* update font

* revert changes to tinychat python app launch

* cleanup quantization, add scale_dtype param

* cleanup kv cache code

* cleanup compile code

* link tok_embeddings with output in webgpu export

* refactor: export_model webgpu: symbolic vars

* refactor: export_model weight loading

* forgot to commit export_model.py

* change CLANG to CPU

* deal with pylint incorrectly failing tests

* simplify f-strings for older CI python version

* fix pre-python3.12 parser errors

* [Int32Array] not Int32Array

* cleanup webgpu compile after refactor export_model

* refactor WASM export into export_model

* merge WebGPU/WASM compile scripts

* simplify max_contexts for local deployment

* fix parser issues and whitespace

* deduplicate variable defs for non-wasm clang export

* cleanup code

* cleanup compile scripts

* simplify wasm inference wrapping

* simplify webgpu symbolic vars export

* refactor: unify export of symbolic variables

* simplify WASM export

* simplify clang/wasm export

* update README and build scripts

* separate files for browser/python apps

* restore original python tinychat app files

* browser and python tinychats share assets

* minor cleanup

* isolate app layer diff

* add .gitignore for generated files

* validate CPU/WEBGPU models in python

* prevent infinite generation if validation fails

* check if exported weight files are unique

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2025-03-07 15:07:33 +08:00

149 lines
8.1 KiB
Python

import os, json, hashlib, math
from extra.export_model import export_model
from examples.llama3 import build_transformer, Tokenizer
from tinygrad.nn.state import get_state_dict, load_state_dict
from tinygrad import Device, Variable, Tensor, dtypes, TinyJit
from tinygrad.helpers import fetch, Context
from tiktoken.load import load_tiktoken_bpe, dump_tiktoken_bpe

def prepare_browser_chunks(model):
  # split weights into browser-friendly chunks
  state_dict = get_state_dict(model)
  del state_dict['output.weight'], state_dict['output.scale'] # same as tok_embeddings; ensures consistency with model export
  chunk_size = 16 * 1024 * 1024 # small chunks based on iphone browser constraints
  metadata = {}
  # We won't export cache_kv bytes (because we start inference on client at start_pos=0), but we will tell the client how big cache_kv needs to be
  t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" not in k]
  empty_t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" in k]
  split_t_infos = []
  for size, name, dtype in t_infos:
    if size <= chunk_size:
      split_t_infos.append((size, name, dtype, ()))
    else: # split large weights into multiple parts
      for i in range(0, size, chunk_size):
        split_t_infos.append((min(chunk_size, size-i), f"{name}_part{math.ceil(i/chunk_size)}", dtype, (i, min(i+chunk_size, size))))

  files = []
  # pack weights into files with FFD bin packing
  split_t_infos = sorted(split_t_infos, reverse=True)
  for info in split_t_infos:
    placed = False
    for file in files:
      if sum(i[0] for i in file) + info[0] <= chunk_size:
        if info[3] and any(i[3] for i in file): continue # no two split tensors can touch the same file, due to wasm loading constraints
        file.append(info)
        placed = True
        break
    if not placed:
      files.append([info])

  tinygrad_dtypes = {dtypes.float32: "float32", dtypes.float16: "float16", dtypes.int8: "int8", dtypes.int32: "int32"}
  for i, file in enumerate(files):
    cursor = 0
    with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.chunk'), "wb+") as writer:
      for size, name, dtype, offsets in file:
        name, part_num = (name, 0) if "_part" not in name else (name.split("_part")[0], int(name.split("_part")[1]))
        default = {"parts": {}, "dtype": tinygrad_dtypes[dtype]}
        weight_metadata = metadata.get(name, default)
        weight_metadata["parts"][part_num] = {"file": i, "file_start_pos": cursor, "size": size}
        metadata[name] = weight_metadata
        data = bytes(state_dict[name].lazydata.base.realized.as_buffer())
        data = data if not offsets else data[offsets[0]:offsets[1]]
        writer.write(data)
        cursor += size
  metadata.update({name: {"parts": {0: {"empty": True, "size": size}}, "dtype": tinygrad_dtypes[dtype]} for size, name, dtype in empty_t_infos})

  for k in metadata:
    metadata[k]["parts"] = [part for part_num, part in sorted(metadata[k]["parts"].items(), key = lambda x: x[0])]
    cursor = 0
    for i, part in enumerate(metadata[k]["parts"]):
      metadata[k]["parts"][i]["target_start_pos"] = cursor
      cursor += part["size"]
    metadata[k]["size"] = cursor

  # compute hashes, which client app will check to determine whether to update with new weights and/or detect integrity issues
  state_dict_hash = hashlib.sha256(json.dumps(metadata, sort_keys=True).encode("utf-8")).hexdigest()
  metadata = {"state_dict": metadata, "state_dict_hash": state_dict_hash, "files": []}
  hashes = set()
  for i in range(len(files)):
    with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.chunk'), "rb") as reader:
      hash = hashlib.sha256(reader.read()).hexdigest()
      hashes.add(hash)
      metadata["files"].append({"name": f'net_part{i}.chunk', "hash": hash})
  if len(hashes) != len(files): print(f"WARNING: {len(files)} files were exported, but only {len(hashes)} are unique: something may have gone wrong")
  metadata_hash = hashlib.sha256(json.dumps(metadata, sort_keys=True).encode("utf-8")).hexdigest()
  metadata = {"metadata": metadata, "metadata_hash": metadata_hash}
  with open(os.path.join(os.path.dirname(__file__), f'./net_metadata.json'), "w") as writer: json.dump(metadata, writer, indent=4)
  return metadata
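
# Illustrative sketch (not called anywhere in this script): how a client could use the
# metadata written above to reassemble one tensor's bytes from the exported chunk files.
# The real tinychat client does this in JavaScript; this function and its "chunk_dir"
# parameter are hypothetical and only document the metadata layout produced above.
def reassemble_tensor(metadata, name, chunk_dir="."):
  entry = metadata["metadata"]["state_dict"][name]
  out = bytearray(entry["size"])
  for part in entry["parts"]:
    if part.get("empty"): continue  # cache_kv placeholders carry sizes only, no bytes
    with open(os.path.join(chunk_dir, f'net_part{part["file"]}.chunk'), "rb") as f:
      f.seek(part["file_start_pos"])
      out[part["target_start_pos"]:part["target_start_pos"] + part["size"]] = f.read(part["size"])
  return bytes(out)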

def validate_model(model, tokenizer):
  prompt = "yo"
  toks = [tokenizer.bos_id]
  toks += [tokenizer.special_tokens["<|start_header_id|>"]] + tokenizer.encode("user") + [tokenizer.special_tokens["<|end_header_id|>"]] + tokenizer.encode("\n\n")
  toks += tokenizer.encode(prompt) + [tokenizer.special_tokens["<|eot_id|>"]]
  toks += [tokenizer.special_tokens["<|start_header_id|>"]] + tokenizer.encode("assistant") + [tokenizer.special_tokens["<|end_header_id|>"]] + tokenizer.encode("\n\n")
  start_pos = 0
  run = TinyJit(model.forward)
  for tok in toks[:-1]:
    run(Tensor([[tok]]), Variable("start_pos", 0, model.max_context).bind(start_pos), 0.0, 0, 0.0, 0.0, 0.0).realize()
    start_pos += 1
  tok = toks[-1]
  result = ""
  expected = "How's it going?"
  while True:
    tok = run(Tensor([[tok]]), Variable("start_pos", 0, model.max_context).bind(start_pos), 0.0, 0, 0.0, 0.0, 0.0).item()
    start_pos += 1
    if tok in tokenizer.stop_tokens or len(result) > len(expected): break
    result += tokenizer.decode([tok])
  assert result == expected, f"Model validation failed, expected output: {expected}, actual output: {result}"
if __name__=="__main__":
# Export BPE data for use with tiktoken.js
tokenizer_path = fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-1b-instruct")
mergeable_ranks = load_tiktoken_bpe(str(tokenizer_path))
bpe_path = os.path.join(os.path.dirname(__file__), "llama3-2.tiktoken")
dump_tiktoken_bpe(mergeable_ranks, bpe_path)
tokenizer = Tokenizer(str(tokenizer_path))
model_path = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-f16.gguf", "Llama-3.2-1B-Instruct-f16.gguf", subdir="llama3-1b-instruct")
Tensor.no_grad = True
max_context=1024
tok = 128000
TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P = 0.95, 0, 0.0, 0.0, 0.0
start_pos = Variable("start_pos", 0, max_context).bind(0)
model_input = lambda: [Tensor([[tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P]
Device.DEFAULT="CPU"
model = build_transformer(model_path, model_size="1B", quantize="int8", scale_dtype=dtypes.float32, device=Device.DEFAULT, max_context=max_context)
state_dict = get_state_dict(model)
validate_model(model, tokenizer)
model_name = "transformer"
with Context(BEAM=3):
cprog, js_wrapper = export_model(model, "wasm", *model_input(), model_name=model_name)
# ensure consistency with exported weights
js_wrapper = js_wrapper.replace("output.weight", "tok_embeddings.weight").replace("output.scale", "tok_embeddings.scale")
with open(os.path.join(os.path.dirname(__file__), f"{model_name}.c"), "w") as f: f.write(cprog)
with open(os.path.join(os.path.dirname(__file__), "net_clang.js"), "w") as f: f.write(js_wrapper)
Device.DEFAULT="WEBGPU"
# float16 is not yet supported for dawn/Vulkan/NVIDIA stack, see: https://issues.chromium.org/issues/42251215
# therefore for now, we used CLANG to quantize the float16 llama to int8 with float32 scales, then load to WEBGPU
model = build_transformer(model_path, model_size="1B", quantize="int8", max_context=max_context, load_weights=False)
load_state_dict(model, state_dict)
# these were the same before load_state_dict
model.output.weight, model.output.scale = model.tok_embeddings.weight, model.tok_embeddings.scale
validate_model(model, tokenizer)
metadata = prepare_browser_chunks(model) # export weights to disk
with Context(BEAM=3):
prg, input_sizes, output_sizes, state = export_model(model, "webgpu", *model_input(), model_name=model_name, stream_weights=True)
# ensure consistency with exported weights
prg = prg.replace("output.weight", "tok_embeddings.weight").replace("output.scale", "tok_embeddings.scale")
with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as f: f.write(prg)
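
# Illustrative sketch of the quantization referenced in the comments above (int8 weights
# with float32 scales, symmetric, with rounding). This is not tinygrad's Int8Linear
# implementation; it only shows the idea behind quantize="int8" with
# scale_dtype=dtypes.float32. The function name and per-row scaling choice are assumptions.
def int8_quantize_sketch(w: Tensor, scale_dtype=dtypes.float32):
  scale = w.abs().max(axis=-1, keepdim=True) / 127.0               # one float scale per output row
  qweight = (w / scale).round().clip(-128, 127).cast(dtypes.int8)  # rounding (vs. truncation) improves quant quality
  return qweight, scale.cast(scale_dtype)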