From 3b9950241e83c0b437c3057645056c23dfbfeb61 Mon Sep 17 00:00:00 2001
From: hooved <172129504+hooved@users.noreply.github.com>
Date: Thu, 27 Feb 2025 15:57:37 -0500
Subject: [PATCH] tinychat in browser, Part 1: llama (#9273)

* load llama3-1B to WEBGPU device
* include compile script for loading llama3 to WEBGPU
* parametrize max_context in build_transformer fxn
* jit_model with two different args sets
* compile for webgpu, split weights
* load model weight parts in browser
* export all tensors from initialized transformer
* run transformer inference in browser
* enable tiktoken with llama bpe in browser
* count total tokens on client with tiktoken.js
* full client-side chat streaming, eliminate server
* revert change that enabled jitting with 2 argsets
* llama without Variable or cache_kv, for webgpu
* have client use mask tokens / whole context
* cleanup staged weights
* add tiktoken.js build script, README
* export CLANG for Q6_k to float32 decompression
* fix and test exported CLANG code for Q6_k to fp32
* revert changes to jit and export_model
* isolate clang export
* test Q6_K to float32 decompression in browser
* gguf_load now also returns t_infos and data_start
* prepare llama-1B Q6_K gguf chunks for browser
* cache and decompress quantized llama in browser
* enable separate deployment of large files
* fix kv cache and symbolic with llama wgpu
* eliminate browser lag during decompression
* hash metadata and weight chunks
* delete obsolete indexeddb cache to free disk
* add progress bar, track model download/decompress
* refactor progress callback
* skip buffer hash verification for speed
* Display progress for entire loading scope
* Report page load errors to user
* actually display errors
* skip prompt tokens already seen by model
* skip prefilling with last assistant message tokens
* on page load tell user if webgpu not enabled
* push deployed URL root to window.history
* make note of bug sources with TODO items
* isolate bug in CLANG with BEAM=2
* remove clang_bug.py from diff
* decompress q6k to f32 on webgpu instead of clang
* remove unused code
* inter-weight decomp with larger wgpu kernels
* parallelize decompression submissions
* refactor dequantize scheduling
* add progress bar back
* fix bug
* temp fix for loading GGUF Q6_K to fp16 not fp32
* fix rendering of exported CLANG
* remove weight casts, sketch js functions for clang
* get symbolic vars from jit_cache for model export
* include symbolic vars in exported CLANG
* render js for clang transformer
* toggle clang/webgpu deployment; refactor decomp
* compile and render clang Q6_K->fp16 and int8 quant
* fix rendered clang for abs(fp16), to work in wasm
* simplify clang js wrapping
* run compiled clang in worker
* prepare llama weights in workers, q6k to int8/fp16
* tinychat on clang in browser, f32/int8 weights
* move wasm inference to (now flexible) worker
* don't load redundant embeddings
* modest wasm perf gain with compile flags
* set default backend, enable backend choice/backup
* render symbolic vars in exported WEBGPU
* quantize webgpu llama to int8/f32
* improve UX arising from rendered WEBGPU
* clean up webgpu launch
* new weights split: smaller chunks, tinygrad quant.
* switch webgpu inference to int8 quant
* remove unneeded clang decompression
* eliminate unneeded kv cache transfer to wasm
* use 1 worker for simplified clang decompression
* display launch errors
* refactor: stream load weight chunks to WebGPU
* show loading chunk completion
* quantize embeddings to int8
* test float16 as input for quantization
* webgpu: use f16 source, int8 embed, eliminate q6k
* simplify split weights prep: all from state_dict
* revert change to nn.state.gguf_load
* remove unneeded decompression from webgpu client
* remove unneeded code
* decrease dl chunks from 47 to 16 MiB
* improve stability of webgpu loading on mobile
* autodetect mobile, improve load stability
* refactor: progress closure
* refactor: one unified progress bar
* remove unneeded code
* revert changes to tinygrad core library
* enforce ios18.3 nerfed max buf size
* BEAM=3 webgpu
* cache integrity, mobile save throttling
* improve mobile UX - no autozoom on prompt box
* clang: int8 from f16, remove q6k
* reduce concurrent dls on mobile to 2 for stability
* refactor: wasm backend with stream loading
* prevent race between wasm load and indexedb save
* split wasm kernels into separate modules
* js wrapper for multiple wasm module inference
* revert multi-module wasm to single module
* make mobile wasm load more stable/fast
* refactor: copy weights into wasm without crashes
* fix bug in download queue; increase mobile dls
* refactor exported clang wrapper, split weights
* remove unnecessary code
* greatly improve int8 quant quality with rounding
* eliminate mobile throttling
* increase webgpu context to 4096 tokens
* export webgpu js functions
* enable separate hosted weights for mobile/pc
* enable prompt-thread switching during generation
* stop generation when max_context is reached
* show progress bar for prefill
* tell user if webgpu fails, while wasm loads
* make loading messages more concise
* update font
* revert changes to tinychat python app launch
* cleanup quantization, add scale_dtype param
* cleanup kv cache code
* cleanup compile code
* link tok_embeddings with output in webgpu export
* refactor: export_model webgpu: symbolic vars
* refactor: export_model weight loading
* forgot to commit export_model.py
* change CLANG to CPU
* deal with pylint incorrectly failing tests
* simplify f-strings for older CI python version
* fix pre-python3.12 parser errors
* [Int32Array] not Int32Array
* cleanup webgpu compile after refactor export_model
* refactor WASM export into export_model
* merge WebGPU/WASM compile scripts
* simplify max_contexts for local deployment
* fix parser issues and whitespace
* deduplicate variable defs for non-wasm clang export
* cleanup code
* cleanup compile scripts
* simplify wasm inference wrapping
* simplify webgpu symbolic vars export
* refactor: unify export of symbolic variables
* simplify WASM export
* simplify clang/wasm export
* update README and build scripts
* separate files for browser/python apps
* restore original python tinychat app files
* browser and python tinychats share assets
* minor cleanup
* isolate diffs to llama files
* minor cleanup
* set default scale_dtype
* set default scale_dtype for NF4 quantization
* make quantization of tok_embeds optional
* match output with tok_embeds if not quantizing
* minor change

---
 examples/llama3.py    | 38 ++++++++++++++++++++++++++------------
 extra/models/llama.py | 10 +++++-----
 2 files changed, 31 insertions(+), 17 deletions(-)
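Several of the quantization commits above (int8 weights, int8 embeddings, "greatly improve int8 quant quality with rounding") come down to a per-row scale of max(|w|)/127 followed by a round-to-nearest before the int8 cast. A minimal sketch of why the rounding step matters, with numpy standing in for tinygrad Tensors and purely illustrative values:

```python
# Sketch only (numpy stands in for tinygrad Tensors; values are illustrative).
# Casting straight to int8 truncates toward zero, e.g. -34.9 -> -34, while
# rounding first picks the nearest quantization level.
import numpy as np

w = np.array([[-34.9, 12.6, 127.0]], dtype=np.float16)   # one weight row
scale = np.abs(w).max(axis=1) / 127.0                    # per-row scale, as in Int8Linear.quantize

truncated = (w.T / scale).T.astype(np.int8)              # [-34, 12, 127]
rounded   = np.round(w.T / scale).T.astype(np.int8)      # [-35, 13, 127]

print(np.abs((truncated.T * scale).T - w).max())         # ~0.9: worst-case reconstruction error
print(np.abs((rounded.T * scale).T - w).max())           # ~0.4: roughly half the error
```

Truncation biases every weight toward zero by up to a full quantization step; rounding caps the error at half a step, which is where the quality improvement comes from.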
diff --git a/examples/llama3.py b/examples/llama3.py
index 1579624ce9..3cc8c7666c 100644
--- a/examples/llama3.py
+++ b/examples/llama3.py
@@ -73,16 +73,17 @@ class Int8Linear:
     self.scale = Tensor.ones(out_features, dtype=dtypes.half)
 
   def __call__(self, x):
-    return x.dot(self.weight.cast(dtype=dtypes.half).T*self.scale)
+    return x.dot(self.weight.cast(self.scale.dtype).T*self.scale)
 
   @staticmethod
-  def quantize(tensors, device):
+  def quantize(tensors, device, scale_dtype=dtypes.float16, quantize_embeds=False):
     new_tensors = {}
     for name,v in tensors.items():
-      if "feed_forward" in name or "attention.w" in name:
+      if "feed_forward" in name or "attention.w" in name or (quantize_embeds and "tok_embeddings.weight" in name):
         assert "weight" in name, name
+        v = v.cast(scale_dtype)
         scale = v.abs().max(axis=1) / 127.0
-        int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
+        int8_weight = (v.T/scale).T.round().cast(dtype=dtypes.int8) # without round(), cast truncates -34.9 to -34
         new_tensors[name] = int8_weight
         new_tensors[name.replace('weight', 'scale')] = scale
         if isinstance(device, tuple):
@@ -90,8 +91,20 @@ class Int8Linear:
           new_tensors[name.replace('weight', 'scale')].shard_(device, axis=None)
       else:
         new_tensors[name] = v
+    if quantize_embeds: new_tensors.update({"output.weight": new_tensors["tok_embeddings.weight"], "output.scale": new_tensors["tok_embeddings.scale"]})
     return new_tensors
 
+class Int8Embedding:
+  def __init__(self, vocab_size:int, embed_size:int):
+    self.vocab_sz, self.embed_sz = vocab_size, embed_size
+    self.weight, self.scale = Tensor.ones(vocab_size, embed_size, dtype=dtypes.int8), Tensor.ones(vocab_size, dtype=dtypes.half)
+
+  def __call__(self, idx:Tensor) -> Tensor:
+    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
+    big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
+    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), (self.weight.cast(self.scale.dtype).T*self.scale).T
+    return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)
+
 def NF4Linear(block_size):
   _CODE = [
     -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
@@ -113,7 +126,7 @@ def NF4Linear(block_size):
       return x.linear(unscaled.reshape(self.out_features, self.in_features).T)
 
     @staticmethod
-    def quantize(state_dict: dict[str, Tensor], device) -> dict[str, Tensor]:
+    def quantize(state_dict: dict[str, Tensor], device, scale_dtype=dtypes.float16) -> dict[str, Tensor]:
       new_state_dict = {}
       for k, v in state_dict.items():
         if "feed_forward" in k or "attention.w" in k:
@@ -121,7 +134,7 @@ def NF4Linear(block_size):
           scale = (grouped.abs().max(axis=1, keepdim=True))
           coded = ((grouped / scale).unsqueeze(-1) - CODE.to(v.device)).abs().argmin(axis=-1).cast(dtypes.uint8).flatten()
           new_state_dict[k] = coded[::2] * 2 ** 4 + coded[1::2]
-          new_state_dict[k.replace(".weight", ".scale")] = scale.cast(dtypes.float16)
+          new_state_dict[k.replace(".weight", ".scale")] = scale.cast(scale_dtype)
           if isinstance(device, tuple):
             new_state_dict[k].shard_(device, axis=-1)
             new_state_dict[k.replace('weight', 'scale')].shard_(device, axis=None)
@@ -144,13 +157,14 @@ MODEL_PARAMS = {
     "files": 8
   }
 }
-def build_transformer(model_path: Path, model_size="8B", quantize=None, device=None):
+def build_transformer(model_path: Path, model_size="8B", quantize=None, scale_dtype=dtypes.float16, device=None, max_context=8192, load_weights=True):
   # build model
-  if quantize == "int8": linear = Int8Linear
-  elif quantize == "nf4": linear = NF4Linear(64)
-  else: linear = nn.Linear
-  model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True)
+  if quantize == "int8": linear, embedding, quantize_embeds = Int8Linear, Int8Embedding, True
+  elif quantize == "nf4": linear, embedding, quantize_embeds = NF4Linear(64), nn.Embedding, False
+  else: linear, embedding, quantize_embeds = nn.Linear, nn.Embedding, False
+  model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, embedding=embedding, max_context=max_context, jit=True)
+  if not load_weights: return model
 
   # load weights
   if model_path.is_dir():
     if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"))
@@ -168,7 +182,7 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=N
   # quantize
   if quantize == "float16": weights = {k:v.cast(quantize).contiguous() for k,v in weights.items()}
   elif quantize is not None:
-    weights = linear.quantize(weights, device)
+    weights = linear.quantize(weights, device, scale_dtype, quantize_embeds)
     for _,v in weights.items(): v.realize()
 
   # shard
diff --git a/extra/models/llama.py b/extra/models/llama.py
index 808d4ed018..d7db5f63ad 100644
--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -70,8 +70,8 @@ class Attention:
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
     self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
 
-    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
+    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None))
+    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None))
 
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -151,11 +151,11 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   return output_token
 
 class Transformer:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
+  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, embedding=nn.Embedding, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
-    self.tok_embeddings = nn.Embedding(vocab_size, dim)
-    self.output = nn.Linear(dim, vocab_size, bias=False)
+    self.tok_embeddings = embedding(vocab_size, dim)
+    self.output = nn.Linear(dim, vocab_size, bias=False) if embedding == nn.Embedding else linear(dim, vocab_size, bias=False)
     self.max_context = max_context
     self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous()
     self.forward_jit = TinyJit(self.forward) if jit else None