tinychat in browser, Part 1: llama (#9273)

* load llama3-1B to WEBGPU device

* include compile script for loading llama3 to WEBGPU

* parametrize max_context in build_transformer fxn

* jit_model with two different args sets

* compile for webgpu, split weights

* load model weight parts in browser

* export all tensors from initialized transformer

* run transformer inference in browser

* enable tiktoken with llama bpe in browser

* count total tokens on client with tiktoken.js

* full client-side chat streaming, eliminate server

* revert change that enabled jitting with 2 argsets

* llama without Variable or cache_kv, for webgpu

* have client use mask tokens / whole context

* cleanup staged weights

* add tiktoken.js build script, README

* export CLANG for Q6_k to float32 decompression

* fix and test exported CLANG code for Q6_k to fp32

* revert changes to jit and export_model

* isolate clang export

* test Q6_K to float32 decompression in browser

* gguf_load now also returns t_infos and data_start

* prepare llama-1B Q6_K gguf chunks for browser

* cache and decompress quantized llama in browser

* enable separate deployment of large files

* fix kv cache and symbolic with llama wgpu

* eliminate browser lag during decompression

* hash metadata and weight chunks

* delete obsolete indexeddb cache to free disk

* add progress bar, track model download/decompress

* refactor progress callback

* skip buffer hash verification for speed

* Display progress for entire loading scope

* Report page load errors to user

* actually display errors

* skip prompt tokens already seen by model

* skip prefilling with last assistant message tokens

* on page load tell user if webgpu not enabled

* push deployed URL root to window.history

* make note of bug sources with TODO items

* isolate bug in CLANG with BEAM=2

* remove clang_bug.py from diff

* decompress q6k to f32 on webgpu instead of clang

* remove unused code

* inter-weight decomp with larger wgpu kernels

* parallelize decompression submissions

* refactor dequantize scheduling

* add progress bar back

* fix bug

* temp fix for loading GGUF Q6_K to fp16 not fp32

* fix rendering of exported CLANG

* remove weight casts, sketch js functions for clang

* get symbolic vars from jit_cache for model export

* include symbolic vars in exported CLANG

* render js for clang transformer

* toggle clang/webgpu deployment; refactor decomp

* compile and render clang Q6_K->fp16 and int8 quant

* fix rendered clang for abs(fp16), to work in wasm

* simplify clang js wrapping

* run compiled clang in worker

* prepare llama weights in workers, q6k to int8/fp16

* tinychat on clang in browser, f32/int8 weights

* move wasm inference to (now flexible) worker

* don't load redundant embeddings

* modest wasm perf gain with compile flags

* set default backend, enable backend choice/backup

* render symbolic vars in exported WEBGPU

* quantize webgpu llama to int8/f32

* improve UX arising from rendered WEBGPU

* clean up webgpu launch

* new weights split: smaller chunks, tinygrad quant.

* switch webgpu inference to int8 quant

* remove unneeded clang decompression

* eliminate unneeded kv cache transfer to wasm

* use 1 worker for simplified clang decompression

* display launch errors

* refactor: stream load weight chunks to WebGPU

* show loading chunk completion

* quantize embeddings to int8

* test float16 as input for quantization

* webgpu: use f16 source, int8 embed, eliminate q6k

* simplify split weights prep: all from state_dict

* revert change to nn.state.gguf_load

* remove unneeded decompression from webgpu client

* remove unneeded code

* decrease dl chunks from 47 to 16 MiB

* improve stability of webgpu loading on mobile

* autodetect mobile, improve load stability

* refactor: progress closure

* refactor: one unified progress bar

* remove unneeded code

* revert changes to tinygrad core library

* enforce ios18.3 nerfed max buf size

* BEAM=3 webgpu

* cache integrity, mobile save throttling

* improve mobile UX - no autozoom on prompt box

* clang: int8 from f16, remove q6k

* reduce concurrent dls on mobile to 2 for stability

* refactor: wasm backend with stream loading

* prevent race between wasm load and indexedb save

* split wasm kernels into separate modules

* js wrapper for multiple wasm module inference

* revert multi-module wasm to single module

* make mobile wasm load more stable/fast

* refactor: copy weights into wasm without crashes

* fix bug in download queue; increase mobile dls

* refactor exported clang wrapper, split weights

* remove unnecessary code

* greatly improve int8 quant quality with rounding

* eliminate mobile throttling

* increase webgpu context to 4096 tokens

* export webgpu js functions

* enable separate hosted weights for mobile/pc

* enable prompt-thread switching during generation

* stop generation when max_context is reached

* show progress bar for prefill

* tell user if webgpu fails, while wasm loads

* make loading messages more concise

* update font

* revert changes to tinychat python app launch

* cleanup quantization, add scale_dtype param

* cleanup kv cache code

* cleanup compile code

* link tok_embeddings with output in webgpu export

* refactor: export_model webgpu: symbolic vars

* refactor: export_model weight loading

* forgot to commit export_model.py

* change CLANG to CPU

* deal with pylint incorrectly failing tests

* simplify f-strings for older CI python version

* fix pre-python3.12 parser errors

* [Int32Array] not Int32Array

* cleanup webgpu compile after refactor export_model

* refactor WASM export into export_model

* merge WebGPU/WASM compile scripts

* simplify max_contexts for local deployment

* fix parser issues and whitespace

* deduplicate variable defs for non-wasm clang export

* cleanup code

* cleanup compile scripts

* simplify wasm inference wrapping

* simplify webgpu symbolic vars export

* refactor: unify export of symbolic variables

* simplify WASM export

* simplify clang/wasm export

* update README and build scripts

* separate files for browser/python apps

* restore original python tinychat app files

* browser and python tinychats share assets

* minor cleanup

* isolate diffs to llama files

* minor cleanup

* set default scale_dtype

* set default scale_dtype for NF4 quantization

* make quantization of tok_embeds optional

* match output with tok_embeds if not quantizing

* minor change
This commit is contained in:
hooved
2025-02-27 15:57:37 -05:00
committed by GitHub
parent 184030168d
commit 3b9950241e
2 changed files with 31 additions and 17 deletions

View File

@@ -73,16 +73,17 @@ class Int8Linear:
self.scale = Tensor.ones(out_features, dtype=dtypes.half)
def __call__(self, x):
return x.dot(self.weight.cast(dtype=dtypes.half).T*self.scale)
return x.dot(self.weight.cast(self.scale.dtype).T*self.scale)
@staticmethod
def quantize(tensors, device):
def quantize(tensors, device, scale_dtype=dtypes.float16, quantize_embeds=False):
new_tensors = {}
for name,v in tensors.items():
if "feed_forward" in name or "attention.w" in name:
if "feed_forward" in name or "attention.w" in name or (quantize_embeds and "tok_embeddings.weight" in name):
assert "weight" in name, name
v = v.cast(scale_dtype)
scale = v.abs().max(axis=1) / 127.0
int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
int8_weight = (v.T/scale).T.round().cast(dtype=dtypes.int8) # without round(), cast truncates -34.9 to -34
new_tensors[name] = int8_weight
new_tensors[name.replace('weight', 'scale')] = scale
if isinstance(device, tuple):
@@ -90,8 +91,20 @@ class Int8Linear:
new_tensors[name.replace('weight', 'scale')].shard_(device, axis=None)
else:
new_tensors[name] = v
if quantize_embeds: new_tensors.update({"output.weight": new_tensors["tok_embeddings.weight"], "output.scale": new_tensors["tok_embeddings.scale"]})
return new_tensors
class Int8Embedding:
def __init__(self, vocab_size:int, embed_size:int):
self.vocab_sz, self.embed_sz = vocab_size, embed_size
self.weight, self.scale = Tensor.ones(vocab_size, embed_size, dtype=dtypes.int8), Tensor.ones(vocab_size, dtype=dtypes.half)
def __call__(self, idx:Tensor) -> Tensor:
if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), (self.weight.cast(self.scale.dtype).T*self.scale).T
return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)
def NF4Linear(block_size):
_CODE = [
-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
@@ -113,7 +126,7 @@ def NF4Linear(block_size):
return x.linear(unscaled.reshape(self.out_features, self.in_features).T)
@staticmethod
def quantize(state_dict: dict[str, Tensor], device) -> dict[str, Tensor]:
def quantize(state_dict: dict[str, Tensor], device, scale_dtype=dtypes.float16) -> dict[str, Tensor]:
new_state_dict = {}
for k, v in state_dict.items():
if "feed_forward" in k or "attention.w" in k:
@@ -121,7 +134,7 @@ def NF4Linear(block_size):
scale = (grouped.abs().max(axis=1, keepdim=True))
coded = ((grouped / scale).unsqueeze(-1) - CODE.to(v.device)).abs().argmin(axis=-1).cast(dtypes.uint8).flatten()
new_state_dict[k] = coded[::2] * 2 ** 4 + coded[1::2]
new_state_dict[k.replace(".weight", ".scale")] = scale.cast(dtypes.float16)
new_state_dict[k.replace(".weight", ".scale")] = scale.cast(scale_dtype)
if isinstance(device, tuple):
new_state_dict[k].shard_(device, axis=-1)
new_state_dict[k.replace('weight', 'scale')].shard_(device, axis=None)
@@ -144,13 +157,14 @@ MODEL_PARAMS = {
"files": 8
}
}
def build_transformer(model_path: Path, model_size="8B", quantize=None, device=None):
def build_transformer(model_path: Path, model_size="8B", quantize=None, scale_dtype=dtypes.float16, device=None, max_context=8192, load_weights=True):
# build model
if quantize == "int8": linear = Int8Linear
elif quantize == "nf4": linear = NF4Linear(64)
else: linear = nn.Linear
model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True)
if quantize == "int8": linear, embedding, quantize_embeds = Int8Linear, Int8Embedding, True
elif quantize == "nf4": linear, embedding, quantize_embeds = NF4Linear(64), nn.Embedding, False
else: linear, embedding, quantize_embeds = nn.Linear, nn.Embedding, False
model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, embedding=embedding, max_context=max_context, jit=True)
if not load_weights: return model
# load weights
if model_path.is_dir():
if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"))
@@ -168,7 +182,7 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=N
# quantize
if quantize == "float16": weights = {k:v.cast(quantize).contiguous() for k,v in weights.items()}
elif quantize is not None:
weights = linear.quantize(weights, device)
weights = linear.quantize(weights, device, scale_dtype, quantize_embeds)
for _,v in weights.items(): v.realize()
# shard

View File

@@ -70,8 +70,8 @@ class Attention:
assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None))
values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None))
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -151,11 +151,11 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
return output_token
class Transformer:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, embedding=nn.Embedding, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
self.norm = nn.RMSNorm(dim, norm_eps)
self.tok_embeddings = nn.Embedding(vocab_size, dim)
self.output = nn.Linear(dim, vocab_size, bias=False)
self.tok_embeddings = embedding(vocab_size, dim)
self.output = nn.Linear(dim, vocab_size, bias=False) if embedding == nn.Embedding else linear(dim, vocab_size, bias=False)
self.max_context = max_context
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous()
self.forward_jit = TinyJit(self.forward) if jit else None