mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
tinychat in browser, Part 3: browser app (#9276)
* load llama3-1B to WEBGPU device
* include compile script for loading llama3 to WEBGPU
* parametrize max_context in build_transformer fxn
* jit_model with two different args sets
* compile for webgpu, split weights
* load model weight parts in browser
* export all tensors from initialized transformer
* run transformer inference in browser
* enable tiktoken with llama bpe in browser
* count total tokens on client with tiktoken.js
* full client-side chat streaming, eliminate server
* revert change that enabled jitting with 2 argsets
* llama without Variable or cache_kv, for webgpu
* have client use mask tokens / whole context
* cleanup staged weights
* add tiktoken.js build script, README
* export CLANG for Q6_k to float32 decompression
* fix and test exported CLANG code for Q6_k to fp32
* revert changes to jit and export_model
* isolate clang export
* test Q6_K to float32 decompression in browser
* gguf_load now also returns t_infos and data_start
* prepare llama-1B Q6_K gguf chunks for browser
* cache and decompress quantized llama in browser
* enable separate deployment of large files
* fix kv cache and symbolic with llama wgpu
* eliminate browser lag during decompression
* hash metadata and weight chunks
* delete obsolete indexeddb cache to free disk
* add progress bar, track model download/decompress
* refactor progress callback
* skip buffer hash verification for speed
* Display progress for entire loading scope
* Report page load errors to user
* actually display errors
* skip prompt tokens already seen by model
* skip prefilling with last assistant message tokens
* on page load tell user if webgpu not enabled
* push deployed URL root to window.history
* make note of bug sources with TODO items
* isolate bug in CLANG with BEAM=2
* remove clang_bug.py from diff
* decompress q6k to f32 on webgpu instead of clang
* remove unused code
* inter-weight decomp with larger wgpu kernels
* parallelize decompression submissions
* refactor dequantize scheduling
* add progress bar back
* fix bug
* temp fix for loading GGUF Q6_K to fp16 not fp32
* fix rendering of exported CLANG
* remove weight casts, sketch js functions for clang
* get symbolic vars from jit_cache for model export
* include symbolic vars in exported CLANG
* render js for clang transformer
* toggle clang/webgpu deployment; refactor decomp
* compile and render clang Q6_K->fp16 and int8 quant
* fix rendered clang for abs(fp16), to work in wasm
* simplify clang js wrapping
* run compiled clang in worker
* prepare llama weights in workers, q6k to int8/fp16
* tinychat on clang in browser, f32/int8 weights
* move wasm inference to (now flexible) worker
* don't load redundant embeddings
* modest wasm perf gain with compile flags
* set default backend, enable backend choice/backup
* render symbolic vars in exported WEBGPU
* quantize webgpu llama to int8/f32
* improve UX arising from rendered WEBGPU
* clean up webgpu launch
* new weights split: smaller chunks, tinygrad quant.
* switch webgpu inference to int8 quant
* remove unneeded clang decompression
* eliminate unneeded kv cache transfer to wasm
* use 1 worker for simplified clang decompression
* display launch errors
* refactor: stream load weight chunks to WebGPU
* show loading chunk completion
* quantize embeddings to int8
* test float16 as input for quantization
* webgpu: use f16 source, int8 embed, eliminate q6k
* simplify split weights prep: all from state_dict
* revert change to nn.state.gguf_load
* remove unneeded decompression from webgpu client
* remove unneeded code
* decrease dl chunks from 47 to 16 MiB
* improve stability of webgpu loading on mobile
* autodetect mobile, improve load stability
* refactor: progress closure
* refactor: one unified progress bar
* remove unneeded code
* revert changes to tinygrad core library
* enforce ios18.3 nerfed max buf size
* BEAM=3 webgpu
* cache integrity, mobile save throttling
* improve mobile UX - no autozoom on prompt box
* clang: int8 from f16, remove q6k
* reduce concurrent dls on mobile to 2 for stability
* refactor: wasm backend with stream loading
* prevent race between wasm load and indexedb save
* split wasm kernels into separate modules
* js wrapper for multiple wasm module inference
* revert multi-module wasm to single module
* make mobile wasm load more stable/fast
* refactor: copy weights into wasm without crashes
* fix bug in download queue; increase mobile dls
* refactor exported clang wrapper, split weights
* remove unnecessary code
* greatly improve int8 quant quality with rounding
* eliminate mobile throttling
* increase webgpu context to 4096 tokens
* export webgpu js functions
* enable separate hosted weights for mobile/pc
* enable prompt-thread switching during generation
* stop generation when max_context is reached
* show progress bar for prefill
* tell user if webgpu fails, while wasm loads
* make loading messages more concise
* update font
* revert changes to tinychat python app launch
* cleanup quantization, add scale_dtype param
* cleanup kv cache code
* cleanup compile code
* link tok_embeddings with output in webgpu export
* refactor: export_model webgpu: symbolic vars
* refactor: export_model weight loading
* forgot to commit export_model.py
* change CLANG to CPU
* deal with pylint incorrectly failing tests
* simplify f-strings for older CI python version
* fix pre-python3.12 parser errors
* [Int32Array] not Int32Array
* cleanup webgpu compile after refactor export_model
* refactor WASM export into export_model
* merge WebGPU/WASM compile scripts
* simplify max_contexts for local deployment
* fix parser issues and whitespace
* deduplicate variable defs for non-wasm clang export
* cleanup code
* cleanup compile scripts
* simplify wasm inference wrapping
* simplify webgpu symbolic vars export
* refactor: unify export of symbolic variables
* simplify WASM export
* simplify clang/wasm export
* update README and build scripts
* separate files for browser/python apps
* restore original python tinychat app files
* browser and python tinychats share assets
* minor cleanup
* isolate app layer diff
* add .gitignore for generated files
* validate CPU/WEBGPU models in python
* prevent infinite generation if validation fails
* check if exported weight files are unique

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
5
examples/tinychat/tinychat-browser/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
net_*
llama3-2.tiktoken
tiktoken.js
tiktoken_bg.wasm
transformer*
@@ -1,8 +1,8 @@
import os, json, hashlib, math
from extra.export_model import export_model
from examples.llama3 import build_transformer
from examples.llama3 import build_transformer, Tokenizer
from tinygrad.nn.state import get_state_dict, load_state_dict
from tinygrad import Device, Variable, Tensor, dtypes
from tinygrad import Device, Variable, Tensor, dtypes, TinyJit
from tinygrad.helpers import fetch, Context
from tiktoken.load import load_tiktoken_bpe, dump_tiktoken_bpe

@@ -66,21 +66,47 @@ def prepare_browser_chunks(model):
# compute hashes, which client app will check to determine whether to update with new weights and/or detect integrity issues
state_dict_hash = hashlib.sha256(json.dumps(metadata, sort_keys=True).encode("utf-8")).hexdigest()
metadata = {"state_dict": metadata, "state_dict_hash": state_dict_hash, "files": []}
hashes = set()
for i in range(len(files)):
with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.chunk'), "rb") as reader:
metadata["files"].append({"name": f'net_part{i}.chunk', "hash": hashlib.sha256(reader.read()).hexdigest()})
hash = hashlib.sha256(reader.read()).hexdigest()
hashes.add(hash)
metadata["files"].append({"name": f'net_part{i}.chunk', "hash": hash})
if len(hashes) != len(files): print(f"WARNING: {len(files)} files were exported, but only {len(hashes)} are unique: something may have gone wrong")
metadata_hash = hashlib.sha256(json.dumps(metadata, sort_keys=True).encode("utf-8")).hexdigest()
metadata = {"metadata": metadata, "metadata_hash": metadata_hash}

with open(os.path.join(os.path.dirname(__file__), f'./net_metadata.json'), "w") as writer: json.dump(metadata, writer, indent=4)
return metadata

def validate_model(model, tokenizer):
prompt = "yo"
toks = [tokenizer.bos_id]
toks += [tokenizer.special_tokens["<|start_header_id|>"]] + tokenizer.encode("user") + [tokenizer.special_tokens["<|end_header_id|>"]] + tokenizer.encode("\n\n")
toks += tokenizer.encode(prompt) + [tokenizer.special_tokens["<|eot_id|>"]]
toks += [tokenizer.special_tokens["<|start_header_id|>"]] + tokenizer.encode("assistant") + [tokenizer.special_tokens["<|end_header_id|>"]] + tokenizer.encode("\n\n")
start_pos = 0
run = TinyJit(model.forward)
for tok in toks[:-1]:
run(Tensor([[tok]]), Variable("start_pos", 0, model.max_context).bind(start_pos), 0.0, 0, 0.0, 0.0, 0.0).realize()
start_pos += 1
tok = toks[-1]
result = ""
expected = "How's it going?"
while True:
tok = run(Tensor([[tok]]), Variable("start_pos", 0, model.max_context).bind(start_pos), 0.0, 0, 0.0, 0.0, 0.0).item()
start_pos += 1
if tok in tokenizer.stop_tokens or len(result) > len(expected): break
result += tokenizer.decode([tok])
assert result == expected, f"Model validation failed, expected output: {expected}, actual output: {result}"

if __name__=="__main__":
# Export BPE data for use with tiktoken.js
tokenizer_path = fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-1b-instruct")
mergeable_ranks = load_tiktoken_bpe(str(tokenizer_path))
bpe_path = os.path.join(os.path.dirname(__file__), "llama3-2.tiktoken")
dump_tiktoken_bpe(mergeable_ranks, bpe_path)
tokenizer = Tokenizer(str(tokenizer_path))

model_path = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-f16.gguf", "Llama-3.2-1B-Instruct-f16.gguf", subdir="llama3-1b-instruct")
Tensor.no_grad = True
@@ -93,7 +119,7 @@ if __name__=="__main__":
Device.DEFAULT="CPU"
model = build_transformer(model_path, model_size="1B", quantize="int8", scale_dtype=dtypes.float32, device=Device.DEFAULT, max_context=max_context)
state_dict = get_state_dict(model)
out = model.forward(*model_input())
validate_model(model, tokenizer)
model_name = "transformer"

with Context(BEAM=3):
@@ -112,7 +138,7 @@ if __name__=="__main__":
# these were the same before load_state_dict
model.output.weight, model.output.scale = model.tok_embeddings.weight, model.tok_embeddings.scale

out = model.forward(*model_input())
validate_model(model, tokenizer)
metadata = prepare_browser_chunks(model) # export weights to disk

with Context(BEAM=3):
322
examples/tinychat/tinychat-browser/index.css
Normal file
@@ -0,0 +1,322 @@
|
||||
/* define colors */
|
||||
:root {
|
||||
--primary-color: #fff;
|
||||
--secondary-color: #2a2a2a;
|
||||
--secondary-color-transparent: #ffffff66;
|
||||
--primary-bg-color: #1a1a1a;
|
||||
--foreground-color: #f0f0f0;
|
||||
}
|
||||
|
||||
main {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
||||
place-items: center;
|
||||
}
|
||||
|
||||
.home {
|
||||
width: 100%;
|
||||
height: 90%;
|
||||
|
||||
margin-bottom: 10rem;
|
||||
}
|
||||
|
||||
.title {
|
||||
font-size: 3rem;
|
||||
margin: 1rem 0;
|
||||
margin-top: 3rem;
|
||||
}
|
||||
|
||||
.histories-container-container {
|
||||
width: 100%;
|
||||
max-height: 75%;
|
||||
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.histories-container {
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
align-items: center;
|
||||
|
||||
margin: 0;
|
||||
padding: 3rem 1rem;
|
||||
}
|
||||
|
||||
.histories-start {
|
||||
height: 3rem;
|
||||
width: 100%;
|
||||
|
||||
z-index: 999;
|
||||
top: 0;
|
||||
position: absolute;
|
||||
|
||||
background: linear-gradient(
|
||||
180deg,
|
||||
var(--primary-bg-color) 0%,
|
||||
transparent 100%
|
||||
);
|
||||
}
|
||||
.histories-end {
|
||||
height: 3rem;
|
||||
width: 100%;
|
||||
|
||||
z-index: 999;
|
||||
bottom: 0;
|
||||
position: absolute;
|
||||
|
||||
background: linear-gradient(
|
||||
0deg,
|
||||
var(--primary-bg-color) 0%,
|
||||
transparent 100%
|
||||
);
|
||||
}
|
||||
|
||||
.history {
|
||||
padding: 1rem;
|
||||
width: 100%;
|
||||
max-width: 40rem;
|
||||
|
||||
background-color: var(--secondary-color);
|
||||
border-radius: 10px;
|
||||
border-left: 2px solid var(--primary-color);
|
||||
|
||||
cursor: pointer;
|
||||
|
||||
transform: translateX(calc(1px * var(--tx, 0)));
|
||||
opacity: var(--opacity, 1);
|
||||
}
|
||||
.history:hover {
|
||||
background-color: var(--secondary-color);
|
||||
}
|
||||
|
||||
.history-delete-button {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
right: 0;
|
||||
padding: 0.5rem;
|
||||
margin: 0;
|
||||
outline: none;
|
||||
border: none;
|
||||
background-color: var(--secondary-color);
|
||||
color: var(--foreground-color);
|
||||
border-radius: 0 0 0 10px;
|
||||
cursor: pointer;
|
||||
transition: 0.2s;
|
||||
}
|
||||
.history-delete-button:hover {
|
||||
background-color: var(--secondary-color);
|
||||
padding: 0.75rem;
|
||||
}
|
||||
|
||||
.messages {
|
||||
overflow-y: auto;
|
||||
height: 100%;
|
||||
width: 100%;
|
||||
max-width: 1200px;
|
||||
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
align-items: center;
|
||||
padding-top: 1rem;
|
||||
padding-bottom: 11rem;
|
||||
}
|
||||
|
||||
.message {
|
||||
max-width: 75%;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 20px;
|
||||
}
|
||||
.message-role-assistant {
|
||||
background-color: var(--secondary-color);
|
||||
margin-right: auto;
|
||||
color: #fff;
|
||||
}
|
||||
.message-role-user {
|
||||
margin-left: auto;
|
||||
background-color: var(--primary-color);
|
||||
color: #000;
|
||||
}
|
||||
|
||||
.message > pre {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.hljs {
|
||||
width: 100%;
|
||||
position: relative;
|
||||
border-radius: 10px;
|
||||
/* wrap code blocks */
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
/* put clipboard button in the top right corner of the code block */
|
||||
.clipboard-button {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
right: 0;
|
||||
padding: 0.5rem;
|
||||
margin: 0;
|
||||
outline: none;
|
||||
border: none;
|
||||
background-color: var(--secondary-color);
|
||||
color: var(--foreground-color);
|
||||
border-radius: 0 0 0 10px;
|
||||
cursor: pointer;
|
||||
transition: 0.2s;
|
||||
}
|
||||
.clipboard-button:hover {
|
||||
background-color: var(--secondary-color);
|
||||
padding: 0.75rem;
|
||||
}
|
||||
|
||||
.input-container {
|
||||
position: absolute;
|
||||
bottom: 0;
|
||||
|
||||
/* linear gradient from background-color to transparent on the top */
|
||||
background: linear-gradient(
|
||||
0deg,
|
||||
var(--primary-bg-color) 55%,
|
||||
transparent 100%
|
||||
);
|
||||
|
||||
width: 100%;
|
||||
max-width: 1200px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
z-index: 999;
|
||||
}
|
||||
|
||||
.input-performance {
|
||||
margin-top: 4rem;
|
||||
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.input-performance-point {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
place-items: center;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
.input-performance-point > p {
|
||||
height: 1rem;
|
||||
line-height: normal;
|
||||
}
|
||||
|
||||
.input {
|
||||
width: 90%;
|
||||
min-height: 3rem;
|
||||
flex-shrink: 0;
|
||||
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
justify-content: center;
|
||||
gap: 0.5rem;
|
||||
|
||||
align-items: flex-end;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
.input-form {
|
||||
width: 100%;
|
||||
padding: 1rem;
|
||||
min-height: 3rem;
|
||||
max-height: 8rem;
|
||||
|
||||
background-color: var(--secondary-color);
|
||||
color: var(--foreground-color);
|
||||
border-radius: 10px;
|
||||
border: none;
|
||||
resize: none;
|
||||
outline: none;
|
||||
}
|
||||
.mobile .input-form { /* prevent auto-zoom on touching prompt box */
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.input-button {
|
||||
height: 3rem;
|
||||
width: 4rem;
|
||||
|
||||
background-color: var(--primary-color);
|
||||
color: var(--secondary-color);
|
||||
border-radius: 10px;
|
||||
padding: 0.5rem;
|
||||
cursor: pointer;
|
||||
}
|
||||
.input-button:hover {
|
||||
background-color: var(--secondary-color-transparent);
|
||||
}
|
||||
.input-button:disabled {
|
||||
background-color: var(--secondary-color);
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
/* wrap text */
|
||||
p {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
/* fonts */
|
||||
.megrim-regular {
|
||||
font-family: monospace;
|
||||
font-weight: 400;
|
||||
font-style: normal;
|
||||
}
|
||||
|
||||
.monospace {
|
||||
font-family: monospace;
|
||||
}
|
||||
|
||||
.loading-bar {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
width: 100%;
|
||||
min-height: 3rem;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
.loading-text {
|
||||
color: var(--foreground-color);
|
||||
font-size: 1rem;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
#progress-percentage {
|
||||
color: var(--foreground-color);
|
||||
font-size: 1rem;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.progress-bar {
|
||||
flex-grow: 1;
|
||||
height: 0.5rem;
|
||||
background-color: var(--secondary-color);
|
||||
border-radius: 5px;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.progress {
|
||||
width: 0%;
|
||||
height: 100%;
|
||||
background-color: var(--primary-color);
|
||||
transition: width 0.2s ease-in-out;
|
||||
}
|
||||
182
examples/tinychat/tinychat-browser/index.html
Normal file
@@ -0,0 +1,182 @@
|
||||
<!DOCTYPE html>
|
||||
|
||||
<head>
|
||||
<title>tinychat</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<link rel="icon" href="../favicon.svg" type="image/svg+xml">
|
||||
|
||||
<script defer src="../assets/cdn.jsdelivr.net/npm/@alpine-collective/toolkit@1.0.2/dist/cdn.min.js"></script>
|
||||
<script defer src="../assets/cdn.jsdelivr.net/npm/@alpinejs/intersect@3.x.x/dist/cdn.min.js"></script>
|
||||
<script defer src="../assets/cdn.jsdelivr.net/npm/@alpinejs/focus@3.x.x/dist/cdn.min.js"></script>
|
||||
<script defer src="../assets/unpkg.com/@marcreichel/alpine-autosize@1.3.x/dist/alpine-autosize.min.js"></script>
|
||||
<script defer src="../assets/unpkg.com/alpinejs@3.x.x/dist/cdn.min.js"></script>
|
||||
|
||||
<script src="../assets/unpkg.com/dompurify@3.1.5/dist/purify.min.js"></script>
|
||||
<script src="../assets/unpkg.com/marked@13.0.0/marked.min.js"></script>
|
||||
<script src="../assets/unpkg.com/marked-highlight@2.1.2/lib/index.umd.js"></script>
|
||||
<script src="../assets/unpkg.com/@highlightjs/cdn-assets@11.9.0/highlight.min.js"></script>
|
||||
|
||||
<script src="index.js"></script>
|
||||
|
||||
<link rel="stylesheet" href="../assets/cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css">
|
||||
<link rel="stylesheet" href="../assets/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css"
|
||||
integrity="sha512-SnH5WK+bZxgPHs44uWIX+LLJAJ9/2PkPKZ5QiAj6Ta86w+fsb2TkcmfRyVX3pBnMFcV7oQPJkl9QevSCWr3W6A=="
|
||||
crossorigin="anonymous" referrerpolicy="no-referrer" />
|
||||
<link rel="stylesheet" href="../assets/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css">
|
||||
|
||||
<link rel="stylesheet" href="index.css">
|
||||
<link rel="stylesheet" href="../common.css">
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<main x-data="state" x-init="console.log(endpoint)">
|
||||
<div class="home centered" x-show="home === 0" x-transition x-effect="
|
||||
$refs.inputForm.focus();
|
||||
if (home === 1) setTimeout(() => home = 2, 100);
|
||||
if (home === -1) setTimeout(() => home = 0, 100);
|
||||
" @popstate.window="
|
||||
if (home === 2) {
|
||||
cancelGeneration = true;
|
||||
if (maxContextReached) generating = false;
|
||||
if (!generating) cstate = { time: null, messages: [] };
|
||||
home = -1;
|
||||
time_till_first = 0;
|
||||
tokens_per_second = 0;
|
||||
total_tokens = 0;
|
||||
}
|
||||
">
|
||||
<h1 class="title megrim-regular">tinychat</h1>
|
||||
<div class="histories-container-container">
|
||||
<template x-if="histories.length">
|
||||
<div class="histories-start"></div>
|
||||
</template>
|
||||
<div class="histories-container" x-intersect="
|
||||
$el.scrollTo({ top: 0, behavior: 'smooth' });
|
||||
">
|
||||
<template x-for="_state in histories.toSorted((a, b) => b.time - a.time)">
|
||||
<div x-data="{ otx: 0, trigger: 75 }" class="history" @click="
|
||||
cstate = _state;
|
||||
updateTotalTokens(cstate.messages);
|
||||
home = 1;
|
||||
// ensure that going back in history will go back to home
|
||||
window.history.pushState({}, '', window.TINYCHAT_ROOT || '/');
|
||||
" @touchstart="
|
||||
otx = $event.changedTouches[0].clientX;
|
||||
" @touchmove="
|
||||
$el.style.setProperty('--tx', $event.changedTouches[0].clientX - otx);
|
||||
$el.style.setProperty('--opacity', 1 - (Math.abs($event.changedTouches[0].clientX - otx) / trigger));
|
||||
" @touchend="
|
||||
if (Math.abs($event.changedTouches[0].clientX - otx) > trigger) removeHistory(_state);
|
||||
$el.style.setProperty('--tx', 0);
|
||||
$el.style.setProperty('--opacity', 1);
|
||||
">
|
||||
<h3 x-text="new Date(_state.time).toLocaleString()"></h3>
|
||||
<p x-text="$truncate(_state.messages[0].content, 80)"></p>
|
||||
<!-- delete button -->
|
||||
<button class="history-delete-button" @click.stop="removeHistory(_state);">
|
||||
<i class=" fas fa-trash"></i>
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
<template x-if="histories.length">
|
||||
<div class="histories-end"></div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
<div x-ref="messages" class="messages" x-init="
|
||||
$watch('cstate', value => {
|
||||
$el.innerHTML = '';
|
||||
value.messages.forEach(({ role, content }) => {
|
||||
const div = document.createElement('div');
|
||||
div.className = `message message-role-${role}`;
|
||||
try {
|
||||
div.innerHTML = DOMPurify.sanitize(marked.parse(content));
|
||||
} catch (e) {
|
||||
console.log(content);
|
||||
console.error(e);
|
||||
}
|
||||
|
||||
// add a clipboard button to all code blocks
|
||||
const codeBlocks = div.querySelectorAll('.hljs');
|
||||
codeBlocks.forEach(codeBlock => {
|
||||
const button = document.createElement('button');
|
||||
button.className = 'clipboard-button';
|
||||
button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
|
||||
button.onclick = () => {
|
||||
// navigator.clipboard.writeText(codeBlock.textContent);
|
||||
const range = document.createRange();
|
||||
range.setStartBefore(codeBlock);
|
||||
range.setEndAfter(codeBlock);
|
||||
window.getSelection()?.removeAllRanges();
|
||||
window.getSelection()?.addRange(range);
|
||||
document.execCommand('copy');
|
||||
window.getSelection()?.removeAllRanges();
|
||||
|
||||
button.innerHTML = '<i class=\'fas fa-check\'></i>';
|
||||
setTimeout(() => button.innerHTML = '<i class=\'fas fa-clipboard\'></i>', 1000);
|
||||
};
|
||||
codeBlock.appendChild(button);
|
||||
});
|
||||
|
||||
$el.appendChild(div);
|
||||
});
|
||||
|
||||
$el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
|
||||
});
|
||||
" x-intersect="
|
||||
$el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
|
||||
" x-show="home === 2" x-transition>
|
||||
</div>
|
||||
<div class="input-container">
|
||||
<div class="input-performance">
|
||||
<span class="input-performance-point">
|
||||
<p class="monospace" x-text="(time_till_first / 1000).toFixed(2)"></p>
|
||||
<p class="megrim-regular">SEC TO FIRST TOKEN</p>
|
||||
</span>
|
||||
<span class="input-performance-point">
|
||||
<p class="monospace" x-text="tokens_per_second.toFixed(1)"></p>
|
||||
<p class="megrim-regular">TOKENS/SEC</p>
|
||||
</span>
|
||||
<span class="input-performance-point">
|
||||
<p class="monospace" x-text="total_tokens"></p>
|
||||
<p class="megrim-regular">TOKENS</p>
|
||||
</span>
|
||||
</div>
|
||||
<div class="loading-bar" x-show="loadingMessage !== ''">
|
||||
<p class="loading-text" id="loading-message">Loading:</p>
|
||||
<span id="progress-percentage">0%</span>
|
||||
<div class="progress-bar">
|
||||
<div class="progress"></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="input" x-show="loadingMessage === ''">
|
||||
<textarea x-ref="inputForm" id="input-form" class="input-form" autofocus rows=1 x-autosize
|
||||
:placeholder="generating ? placeholderText : 'Say something'" :disabled="generating" @input="
|
||||
home = (home === 0) ? 1 : home
|
||||
if (cstate.messages.length === 0 && $el.value === '') home = -1;
|
||||
|
||||
if ($el.value !== '') {
|
||||
const messages = [...cstate.messages];
|
||||
messages.push({ role: 'user', content: $el.value });
|
||||
updateTotalTokens(messages);
|
||||
} else {
|
||||
if (cstate.messages.length === 0) total_tokens = 0;
|
||||
else updateTotalTokens(cstate.messages);
|
||||
}
|
||||
" x-effect="
|
||||
console.log(generating);
|
||||
if (!generating) $nextTick(() => {
|
||||
$el.focus();
|
||||
setTimeout(() => $refs.messages.scrollTo({ top: $refs.messages.scrollHeight, behavior: 'smooth' }), 100);
|
||||
});
|
||||
" @keydown.enter="await handleEnter($event)" @keydown.escape.window="$focus.focus($el)"></textarea>
|
||||
<button class="input-button" :disabled="generating" @click="await handleSend()">
|
||||
<i class="fas" :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'"></i>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
927
examples/tinychat/tinychat-browser/index.js
Normal file
@@ -0,0 +1,927 @@
|
||||
window.TINYCHAT_ROOT = "/tinychat-browser/";
|
||||
const queryParams = new URLSearchParams(window.location.search);
|
||||
const normalizedParams = Object.fromEntries([...queryParams].map(([key, value]) => [key.toUpperCase(), value.toUpperCase()]));
|
||||
window.BACKEND = (normalizedParams["BACKEND"] === "WASM") ? "WASM" : "WebGPU";
|
||||
const isMobileAgent = /Mobi|Android|iPhone|iPad|iPod/i.test(navigator.userAgent);
|
||||
const hasTouchScreen = 'ontouchstart' in window || navigator.maxTouchPoints > 0;
|
||||
window.isMobile = isMobileAgent || hasTouchScreen;
|
||||
if (window.isMobile) document.documentElement.classList.add('mobile'); // prevent annoying auto-zoom when entering prompt on mobile
|
||||
// MODEL_BASE_URL is where the weights are hosted, WEBGPU_EXPORT is the JS-wrapped WebGPU code exported from tinygrad
|
||||
window.PC_MODEL_BASE_URL = ".";
|
||||
window.PC_WEBGPU_EXPORT = './net.js'
|
||||
window.PC_MAX_CONTEXT = 1024;
|
||||
window.MOBILE_MODEL_BASE_URL = ".";
|
||||
window.MOBILE_WEBGPU_EXPORT = './net.js'
|
||||
window.MOBILE_MAX_CONTEXT = 1024;
|
||||
|
||||
const tiktokenReady = (async () => {
|
||||
const { init, get_encoding, Tiktoken, load } = await import('./tiktoken.js');
|
||||
window.Tiktoken = Tiktoken;
|
||||
window.tiktokenInit = init;
|
||||
window.tiktokenGetEncoding = get_encoding;
|
||||
window.tiktokenLoad = load;
|
||||
})();
|
||||
|
||||
async function getDevice() {
|
||||
let adapter;
|
||||
try {
|
||||
adapter = await navigator.gpu.requestAdapter();
|
||||
if (!adapter) {
|
||||
this.loadingMessage = "Loading WASM (WebGPU not enabled):";
|
||||
throw new Error("No WebGPU adapter found");
|
||||
}
|
||||
} catch(error) {
|
||||
this.loadingMessage = "Loading WASM (WebGPU not enabled):";
|
||||
throw error;
|
||||
}
|
||||
const requiredLimits = {};
|
||||
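// per the commit notes ("enforce ios18.3 nerfed max buf size"): ~307 MiB, sized to stay under the reduced WebGPU buffer limits reported by iOS 18.3 Safari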
const maxBufferSize = 322122544;
|
||||
requiredLimits.maxStorageBufferBindingSize = maxBufferSize;
|
||||
requiredLimits.maxBufferSize = maxBufferSize;
|
||||
requiredLimits.maxComputeInvocationsPerWorkgroup = 512; // may need to vary based on what the WEBGPU backend produces
|
||||
|
||||
try {
|
||||
return await adapter.requestDevice({ requiredLimits });
|
||||
} catch(error) {
|
||||
this.loadingMessage = "Loading WASM (WebGPU error):";
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
// copied from examples/webgpu/stable_diffusion/index.html
|
||||
function initDb() {
|
||||
return new Promise((resolve, reject) => {
|
||||
let db;
|
||||
const request = indexedDB.open('tinydb', 1);
|
||||
request.onerror = (event) => {
|
||||
console.error('Database error:', event.target.error);
|
||||
resolve(null);
|
||||
};
|
||||
|
||||
request.onsuccess = (event) => {
|
||||
db = event.target.result;
|
||||
console.log("Db initialized.");
|
||||
resolve(db);
|
||||
};
|
||||
|
||||
request.onupgradeneeded = (event) => {
|
||||
db = event.target.result;
|
||||
if (!db.objectStoreNames.contains('tensors')) {
|
||||
db.createObjectStore('tensors', { keyPath: 'id' });
|
||||
}
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
// copied from examples/webgpu/stable_diffusion/index.html
|
||||
function readTensorFromDb(db, id) {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (db == null) {
|
||||
resolve(null);
|
||||
}
|
||||
|
||||
const transaction = db.transaction(['tensors'], 'readonly');
|
||||
const store = transaction.objectStore('tensors');
|
||||
const request = store.get(id);
|
||||
|
||||
transaction.onabort = (event) => {
|
||||
console.log("Transaction error while reading tensor: " + event.target.error);
|
||||
resolve(null);
|
||||
};
|
||||
|
||||
request.onsuccess = (event) => {
|
||||
const result = event.target.result;
|
||||
if (result) {
|
||||
resolve(result);
|
||||
} else {
|
||||
resolve(null);
|
||||
}
|
||||
};
|
||||
|
||||
request.onerror = (event) => {
|
||||
console.error('Tensor retrieve failed: ', event.target.error);
|
||||
resolve(null);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function getAllKeysFromDb(db) {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (db == null) {resolve([]);}
|
||||
const transaction = db.transaction(['tensors'], 'readonly');
|
||||
const store = transaction.objectStore('tensors');
|
||||
const request = store.getAllKeys();
|
||||
transaction.onabort = (event) => {
|
||||
console.log("Transaction error while reading IndexedDB keys: " + event.target.error);
|
||||
resolve([]);
|
||||
};
|
||||
request.onsuccess = function (event) {resolve(event.target.result);};
|
||||
request.onerror = (event) => {
|
||||
console.error('Retrieval of IndexedDB keys failed: ', event.target.error);
|
||||
resolve([]);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
// modified from examples/webgpu/stable_diffusion/index.html
|
||||
function saveTensorToDb(db, id, tensor) {
|
||||
return readTensorFromDb(db, id).then((result) => {
|
||||
if (!result) {
|
||||
return new Promise((resolve, reject) => { // returned so callers that await the save actually wait for the put to finish
|
||||
if (db == null) {
|
||||
resolve(null);
|
||||
}
|
||||
|
||||
const transaction = db.transaction(['tensors'], 'readwrite');
|
||||
const store = transaction.objectStore('tensors');
|
||||
const request = store.put({ id: id, content: tensor });
|
||||
|
||||
transaction.onabort = (event) => {
|
||||
console.log("Transaction error while saving tensor: " + event.target.error);
|
||||
resolve(null);
|
||||
};
|
||||
|
||||
request.onsuccess = () => {
|
||||
console.log('Tensor saved successfully.');
|
||||
resolve();
|
||||
};
|
||||
|
||||
request.onerror = (event) => {
|
||||
console.error('Tensor save failed:', event.target.error);
|
||||
resolve(null);
|
||||
};
|
||||
});
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}).catch(()=> null);
|
||||
}
|
||||
|
||||
function deleteTensorFromDb(db, id) {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (db == null) {
|
||||
console.error("Database is not initialized.");
|
||||
resolve(null);
|
||||
return;
|
||||
}
|
||||
|
||||
const transaction = db.transaction(['tensors'], 'readwrite');
|
||||
const store = transaction.objectStore('tensors');
|
||||
const request = store.delete(id);
|
||||
|
||||
transaction.oncomplete = () => {
|
||||
console.log(`Tensor with ID '${id}' deleted successfully.`);
|
||||
resolve();
|
||||
};
|
||||
|
||||
transaction.onerror = (event) => {
|
||||
console.error("Transaction error while deleting tensor:", event.target.error);
|
||||
resolve(null);
|
||||
};
|
||||
|
||||
request.onerror = (event) => {
|
||||
console.error('Tensor deletion failed:', event.target.error);
|
||||
resolve(null);
|
||||
};
|
||||
|
||||
request.onsuccess = () => {
|
||||
console.log(`Delete request for tensor with ID '${id}' succeeded.`);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function makeProgress(total) {
|
||||
let acc = 0;
|
||||
const ret = function progress(amount, message) {
|
||||
if (amount >= 0) { // allow updating message only
|
||||
acc += amount;
|
||||
const percentage = total ? Math.trunc((acc / total) * 100) : 0;
|
||||
document.querySelector('.progress').style.width = `${percentage}%`;
|
||||
document.getElementById('progress-percentage').textContent = `${percentage}%`;
|
||||
}
|
||||
if (message) {
|
||||
this.loadingMessage = message;
|
||||
document.getElementById('loading-message').textContent = this.loadingMessage;
|
||||
}
|
||||
}.bind(this);
|
||||
ret.total = total;
|
||||
return ret;
|
||||
}
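// usage elsewhere in this file: the closure accumulates bytes against `total`, and the optional
// second argument swaps the loading message, e.g.:
//   this.progress = makeProgress.call(this, totalSize);
//   this.progress(value.byteLength);          // advance the bar by a downloaded chunk
//   this.progress(0, "Loading tokenizer:");   // update the message without advancing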
|
||||
|
||||
function sendMessageToWorker(worker, message) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const onMessage = (event) => {
|
||||
resolve(event.data);
|
||||
worker.removeEventListener('message', onMessage);
|
||||
worker.removeEventListener('error', onError);
|
||||
};
|
||||
|
||||
const onError = (error) => {
|
||||
reject(error);
|
||||
worker.removeEventListener('message', onMessage);
|
||||
worker.removeEventListener('error', onError);
|
||||
};
|
||||
|
||||
worker.addEventListener('message', onMessage);
|
||||
worker.addEventListener('error', onError);
|
||||
|
||||
if (message.header === "token") worker.postMessage(message.data);
|
||||
else if (message.header === "load_state_dict") {
|
||||
if (message.data === "done") worker.postMessage(message.data);
|
||||
else worker.postMessage(message.data, message.data.map(file => file.bytes.buffer));
|
||||
}
|
||||
else if (message.header === "init") worker.postMessage("init");
|
||||
});
|
||||
}
|
||||
|
||||
async function load_state_dict (data, device, progress) {
|
||||
let state_dict = data.metadata.state_dict;
|
||||
let completed = 0;
|
||||
|
||||
// modified from examples/webgpu/stable_diffusion/index.html getProgressDlForPart
|
||||
const loadPart = async (part) => {
|
||||
const response = await fetch(part);
|
||||
const res = new Response(new ReadableStream({
|
||||
async start(controller) {
|
||||
const reader = response.body.getReader();
|
||||
for (;;) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
progress(value.byteLength);
|
||||
controller.enqueue(value);
|
||||
}
|
||||
controller.close();
|
||||
},
|
||||
}));
|
||||
|
||||
return res.arrayBuffer();
|
||||
};
|
||||
|
||||
let db = await initDb();
|
||||
|
||||
const getPart = async(filename, hash) => {
|
||||
let part = await readTensorFromDb(db, hash);
|
||||
|
||||
if (part) {
|
||||
console.log(`Cache hit: ${filename}, hash: ${hash}`);
|
||||
progress(part.content.byteLength);
|
||||
return Promise.resolve(part.content);
|
||||
} else {
|
||||
console.log(`Cache miss: ${filename}, hash: ${hash}`);
|
||||
return loadPart(`${window.MODEL_BASE_URL}/${filename}`);
|
||||
}
|
||||
}
|
||||
|
||||
const correctHashes = data.metadata.files.map(file => file.hash)
|
||||
// delete unused cached buffers to free disk space -- if we update weights, user will otherwise have obsolete cached buffers
|
||||
const dbKeys = await getAllKeysFromDb(db);
|
||||
const correctHashesSet = new Set(correctHashes);
|
||||
const notInCorrectHashes = dbKeys.filter(key => !correctHashesSet.has(key));
|
||||
// await these right before starting to save new stuff
|
||||
const deletionPromises = notInCorrectHashes.map(async (hash) => deleteTensorFromDb(db, hash));
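// Each entry in data.metadata.files carries a sha256 hex digest; a hedged sketch of how a downloaded
// chunk could be verified with the Web Crypto API is shown below. The shipped app skips this per-buffer
// check for speed (see commit notes) and uses the hashes only as cache keys and for obsolete-cache cleanup.
// async function verifyChunk(bytes, expectedHash) {   // hypothetical helper, not part of the app
//   const digest = await crypto.subtle.digest("SHA-256", bytes);
//   const hex = Array.from(new Uint8Array(digest)).map(b => b.toString(16).padStart(2, "0")).join("");
//   return hex === expectedHash;
// }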
|
||||
|
||||
// instantiates empty weight buffers on WebGPU, attaches buffers to state_dict
|
||||
let model;
|
||||
if (window.BACKEND === "WebGPU") {
|
||||
//model = await transformer().setup(device, state_dict, progress);
|
||||
model = await transformer.setupNet(device, state_dict);
|
||||
progress(0.15 * progress.total);
|
||||
|
||||
}
|
||||
else if (window.BACKEND === "WASM") {
|
||||
progress(0.02 * progress.total);
|
||||
model = new Worker(`./worker.js?version=${Date.now()}`);
|
||||
await sendMessageToWorker(model, {header: "init"});
|
||||
progress(0.02 * progress.total);
|
||||
progress(0.11 * progress.total);
|
||||
}
|
||||
|
||||
const downloaded = [];
|
||||
const triggerChainDownload = async (toDownload) => {
|
||||
const numDownloaders = window.isMobile ? 4 : toDownload.length; // TODO: dynamically base this on DL file size? current assumption is 16 MiB chunks
|
||||
|
||||
const chainDownload = async() => {
|
||||
const file = toDownload.shift();
|
||||
loadPart(`${window.MODEL_BASE_URL}/${file.name}`) // triggers download
|
||||
.then(async (arraybuf) => {
|
||||
downloaded.push({ ...file, bytes: new Uint8Array(arraybuf)});
|
||||
// pause downloads if further processing is a bottleneck
|
||||
while (toDownload.length && downloaded.length >= numDownloaders) await new Promise(resolve => setTimeout(resolve, 5));
|
||||
if (toDownload.length && downloaded.length < numDownloaders) chainDownload(); // start next download
|
||||
})
|
||||
}
|
||||
for (let i=0; i<numDownloaders; i++) if (toDownload.length) chainDownload();
|
||||
}
|
||||
|
||||
const loadFileToStateDict = async(file) => {
|
||||
if (window.BACKEND === "WebGPU") {
|
||||
for (const part of file.parts) {
|
||||
if (part.empty) continue;
|
||||
part.bytes = (part.size === file.bytes.length) ? file.bytes : file.bytes.slice(part.file_start_pos, part.file_start_pos + part.size);
|
||||
device.queue.writeBuffer(state_dict[part.key].bytes, part.target_start_pos, part.bytes); // improves stability over mappedAtCreation writing
|
||||
part.bytes = null;
|
||||
}
|
||||
}
|
||||
else if (window.BACKEND === "WASM") {
|
||||
await sendMessageToWorker(model, {header: "load_state_dict", data: [file]});
|
||||
}
|
||||
file.bytes = null;
|
||||
}
|
||||
|
||||
if (window.BACKEND === "WebGPU") { // contiguous loading not needed for WebGPU stability
|
||||
const files = data.tensor_file_groups.flatMap(obj => obj.files);
|
||||
data.tensor_file_groups = [{contiguous: false, files: files}];
|
||||
}
|
||||
|
||||
for (const group of data.tensor_file_groups) {
|
||||
const contiguous = group.contiguous;
|
||||
const files = group.files;
|
||||
const tensor_file_indices = files.map(file => file.index);
|
||||
const contiguousFiles = [];
|
||||
const fileHashes = new Set(files.map(file => file.hash));
|
||||
const cachedFileHashes = new Set(dbKeys.filter(key => fileHashes.has(key)));
|
||||
const cachedFiles = files.filter(file => cachedFileHashes.has(file.hash));
|
||||
const toDownload = files.filter(file => !cachedFileHashes.has(file.hash));
|
||||
triggerChainDownload(toDownload);
|
||||
|
||||
const loadDelay = 5;
|
||||
await Promise.all(deletionPromises);
|
||||
|
||||
while (completed < files.length) {
|
||||
const start = performance.now();
|
||||
// prioritize files from downloaded queue, so we can continue downloading more files
|
||||
if (downloaded.length) {
|
||||
const file = downloaded.shift();
|
||||
await saveTensorToDb(db, file.hash, file.bytes); // for wasm, must await to prevent race between indexedDB and transfer to worker
|
||||
if (!contiguous) await loadFileToStateDict(file);
|
||||
else contiguousFiles.push(file);
|
||||
completed += 1;
|
||||
}
|
||||
else if (!downloaded.length && cachedFiles.length) {
|
||||
const file = cachedFiles.shift();
|
||||
file.bytes = await getPart(file.name, file.hash); // reads data from IndexedDB
|
||||
if (!contiguous) await loadFileToStateDict(file);
|
||||
else contiguousFiles.push(file);
|
||||
completed += 1;
|
||||
}
|
||||
const end = performance.now();
|
||||
const elapsed = end - start;
|
||||
if (elapsed < loadDelay) await new Promise(resolve => setTimeout(resolve, loadDelay - elapsed));
|
||||
}
|
||||
if (contiguous) {
|
||||
const orderMap = tensor_file_indices.reduce((acc, id, index) => {acc[id] = index; return acc;}, {});
|
||||
contiguousFiles.sort((a, b) => orderMap[a.index] - orderMap[b.index]); // glue files together in the right order
|
||||
await sendMessageToWorker(model, {header: "load_state_dict", data: contiguousFiles});
|
||||
}
|
||||
completed = 0;
|
||||
}
|
||||
|
||||
// initialize empty kv_caches, which were part of exported model's state_dict, but which we didn't want to package/download
|
||||
if (window.BACKEND === "WASM") {
|
||||
for (const [k, v] of Object.entries(state_dict).filter(([_, v]) => v.empty === true)) {
|
||||
v.parts[0].file_start_pos = 0;
|
||||
const file = { parts: v.parts, size: v.size, bytes: new Uint8Array(v.size).fill(0) };
|
||||
await loadFileToStateDict(file);
|
||||
}
|
||||
}
|
||||
|
||||
return model;
|
||||
};
|
||||
|
||||
document.addEventListener("alpine:init", () => {
|
||||
Alpine.data("state", () => ({
|
||||
// loadingMessage updates the user on page load progress, including weights download and decompression
|
||||
// if loadingMessage is not '', then prompt box will be hidden: this is default behavior on page load
|
||||
placeholderText: "Generating...",
|
||||
loadingMessage: `Loading ${window.BACKEND} model:`,
|
||||
// model
|
||||
nets: {},
|
||||
tokenizer: null,
|
||||
max_context: 1024,
|
||||
lastSeenToks: [],
|
||||
|
||||
progress: null,
|
||||
|
||||
async init() {
|
||||
var device = null;
|
||||
var webgpuErrorMessage = null;
|
||||
if (window.BACKEND === "WebGPU") {
|
||||
try {
|
||||
device = await getDevice.call(this);
|
||||
console.log("WebGPU device initialized");
|
||||
} catch (error) {
|
||||
window.BACKEND = "WASM";
|
||||
console.log(`error: ${error}\nFailed to launch WebGPU. Loading WASM model instead...`); // return;
|
||||
webgpuErrorMessage = this.loadingMessage;
|
||||
}
|
||||
}
|
||||
|
||||
window.MODEL_BASE_URL = (window.BACKEND === "WebGPU" && !window.isMobile) ? window.PC_MODEL_BASE_URL : window.MOBILE_MODEL_BASE_URL;
|
||||
this.max_context = (window.BACKEND === "WebGPU" && !window.isMobile) ? window.PC_MAX_CONTEXT : window.MOBILE_MAX_CONTEXT;
|
||||
|
||||
const kernelsReady = (async () => {
|
||||
if (window.BACKEND === "WASM") {var exports = await import(`./net_clang.js?version=${Date.now()}`);}
|
||||
else if (window.BACKEND === "WebGPU" && !window.isMobile) {var exports = await import(`${PC_WEBGPU_EXPORT}?version=${Date.now()}`);}
|
||||
else if (window.BACKEND === "WebGPU" && window.isMobile) {var exports = await import(`${MOBILE_WEBGPU_EXPORT}?version=${Date.now()}`);}
|
||||
self.transformer = exports.default;
|
||||
})();
|
||||
|
||||
const response = await fetch(`${window.MODEL_BASE_URL}/net_metadata.json`);
|
||||
// TODO: cache metadata (and everything else, including tokenizer)
|
||||
// TODO: use service worker to reload page when offline
|
||||
const data = await response.json();
|
||||
data.metadata.files = data.metadata.files.map((file, index) => ({...file, index}));
|
||||
const state_dict = data.metadata.state_dict;
|
||||
|
||||
/*
|
||||
- allocating memory to WASM on mobile has longstanding issues: https://github.com/WebAssembly/design/issues/1397
|
||||
|
||||
- the below pattern, while yielding a successfully-functioning model when it doesn't crash, causes regular crashes on iOS Safari (iPhone 15, iOS 18.3):
|
||||
- call WASM malloc (to fit all tensors, or one per tensor) for all tensors up front, then load tensor byte chunks into the buffers in random order
|
||||
|
||||
- the below pattern has been stable on iOS Safari (iPhone 15, iOS 18.3):
|
||||
- call only one WASM malloc at a time before filling the allocated bytes, as small as possible (malloc up to 256 MiB has been tested)
|
||||
- fill the malloc'd memory in linear order from start to end (what has been tested is calling wasm.HEAPU8.set on 16 MiB chunks from start to end)
|
||||
- use ALLOW_MEMORY_GROWTH=1 in wasm compilation, minimize initial memory
|
||||
|
||||
- additional considerations affecting loading design, for WASM:
|
||||
- it seems that copying bytes into wasm memory cannot be zero-copy without sharedarraybuffer, which isn't currently used due to increased hosting complexity
|
||||
- non-zero copies create memory pressure, which is not reliably capped because of lack of control over garbage collection
|
||||
- to minimize peak memory pressure if GC is delayed, we process (i.e. download + copy into WASM) large tensors (> 16 MiB) one at a time, in descending size order
|
||||
*/
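/* A minimal sketch of the stable pattern described above, assuming an Emscripten-style build that
   exposes _malloc and HEAPU8 (exact export names depend on the build flags); nextChunk is a
   hypothetical helper yielding successive ~16 MiB slices of one tensor:
     const ptr = wasm._malloc(tensorSize);                    // one allocation at a time, as small as possible
     for (let off = 0; off < tensorSize; off += CHUNK_SIZE) {
       const bytes = await nextChunk(off);                    // next ~16 MiB slice
       wasm.HEAPU8.set(bytes, ptr + off);                     // fill linearly, start to end
     }
*/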
|
||||
data.tensor_file_groups = []; // see above: for WASM, limit processing of multi-file Tensors to one at a time, in descending order based on Tensor size
|
||||
const unsplit_tensors = [];
|
||||
const sortedEntries = Object.entries(state_dict).sort(([, objA], [, objB]) => objB.size - objA.size);
|
||||
|
||||
let totalSize = 0;
|
||||
const seen = new Set();
|
||||
for (const [k,v] of sortedEntries) {
|
||||
const files_in_tensor = [];
|
||||
for (const part of v.parts) {
|
||||
part.key = k;
|
||||
if (part.empty) state_dict[k].empty = true; // assumes no other parts of this weight exist and are non-empty
|
||||
else {
|
||||
const file = data.metadata.files[part.file];
|
||||
if (!seen.has(file.index)) {
|
||||
seen.add(file.index);
|
||||
files_in_tensor.push(file);
|
||||
}
|
||||
totalSize += part.size;
|
||||
part.dtype = v.dtype;
|
||||
if (!data.metadata.files[part.file].parts) data.metadata.files[part.file].parts = [];
|
||||
data.metadata.files[part.file].size ??= 0;
|
||||
data.metadata.files[part.file].size += part.size;
|
||||
data.metadata.files[part.file].parts.push(part);
|
||||
}
|
||||
}
|
||||
if (files_in_tensor.length > 1) data.tensor_file_groups.push({contiguous: true, files: files_in_tensor}); // [tensorN_file0, tensorN_file1, ...]
|
||||
else if (files_in_tensor.length > 0) unsplit_tensors.push(files_in_tensor[0]);
|
||||
}
|
||||
data.tensor_file_groups.push({contiguous: false, files: unsplit_tensors});
|
||||
|
||||
data.totalSize = totalSize;
|
||||
totalSize = totalSize / 0.8; // give space in progress bar for initializing model bufs, and tokenizer
|
||||
this.progress = makeProgress.call(this, totalSize); // creates closure with totalSize
|
||||
|
||||
try {
|
||||
this.progress(0.01 * totalSize, "Loading tokenizer:");
|
||||
const wasmResponse = await fetch(`${window.MODEL_BASE_URL}/tiktoken_bg.wasm`);
|
||||
this.progress(0.01 * totalSize);
|
||||
const wasmBytes = await wasmResponse.arrayBuffer();
|
||||
await tiktokenReady;
|
||||
await window.tiktokenInit((imports) => WebAssembly.instantiate(wasmBytes, imports));
|
||||
this.progress(0.01 * totalSize);
|
||||
|
||||
this.tokenizer = await createTokenizer(`${window.MODEL_BASE_URL}/llama3-2.tiktoken`);
|
||||
const tokenizer_works = (new TextDecoder().decode(this.tokenizer.decode(this.tokenizer.encode("hello world"))) === "hello world");
|
||||
console.log("tokenizer works:", tokenizer_works)
|
||||
this.progress(0.01 * totalSize);
|
||||
} catch (error) {this.progress(-1, `Error launching tokenizer: ${error}`); console.log(error); return;}
|
||||
|
||||
try {
|
||||
const loadModelMessage = (webgpuErrorMessage) ? webgpuErrorMessage : `Loading ${window.BACKEND} model:`
|
||||
this.progress(0, loadModelMessage);
|
||||
await kernelsReady;
|
||||
const model = await load_state_dict(data, device, this.progress);
|
||||
|
||||
if (window.BACKEND === "WebGPU") {
|
||||
this.nets = {"transformer": model};
|
||||
}
|
||||
else if (window.BACKEND === "WASM") {
|
||||
const msg = await sendMessageToWorker(model, {header: "load_state_dict", data: "done"});
|
||||
this.nets = {"transformer": async (tok, start_pos) => sendMessageToWorker(model, {header: "token", data: [tok, start_pos]})};
|
||||
}
|
||||
this.progress(0.01 * totalSize, `Launching ${window.BACKEND} model:`);
|
||||
this.loadingMessage = ""; // Triggers removal of loading bar, display of prompt box
|
||||
} catch (error) {this.progress(-1, `Error launching model: ${error}`); console.log(error); return;}
|
||||
},
|
||||
|
||||
// current state
|
||||
cstate: {
|
||||
time: null,
|
||||
messages: [],
|
||||
},
|
||||
|
||||
// historical state
|
||||
histories: JSON.parse(localStorage.getItem("histories")) || [],
|
||||
|
||||
home: 0,
|
||||
generating: false,
|
||||
maxContextReached: false,
|
||||
cancelGeneration: false,
|
||||
endpoint: `${window.location.origin}/v1`,
|
||||
|
||||
// performance tracking
|
||||
time_till_first: 0,
|
||||
tokens_per_second: 0,
|
||||
total_tokens: 0,
|
||||
max_context: 0,
|
||||
|
||||
removeHistory(cstate) {
|
||||
const index = this.histories.findIndex((state) => {
|
||||
return state.time === cstate.time;
|
||||
});
|
||||
if (index !== -1) {
|
||||
this.histories.splice(index, 1);
|
||||
localStorage.setItem("histories", JSON.stringify(this.histories));
|
||||
}
|
||||
},
|
||||
|
||||
async handleSend() {
|
||||
const el = document.getElementById("input-form");
|
||||
const value = el.value.trim();
|
||||
if (!value) return;
|
||||
|
||||
if (this.generating) return;
|
||||
this.maxContextReached = false;
|
||||
this.placeholderText = "Generating...";
|
||||
this.generating = true;
|
||||
this.cancelGeneration = false;
|
||||
if (this.home === 0) this.home = 1;
|
||||
|
||||
// ensure that going back in history will go back to home
|
||||
window.history.pushState({}, "", window.TINYCHAT_ROOT || "/");
|
||||
|
||||
// add message to list
|
||||
this.cstate.messages.push({ role: "user", content: value });
|
||||
|
||||
// clear textarea
|
||||
el.value = "";
|
||||
el.style.height = "auto";
|
||||
el.style.height = el.scrollHeight + "px";
|
||||
|
||||
// reset performance tracking
|
||||
const prefill_start = Date.now();
|
||||
let start_time = 0;
|
||||
let tokens = 0;
|
||||
this.tokens_per_second = 0;
|
||||
|
||||
let gottenFirstChunk = false;
|
||||
try {
|
||||
for await (
|
||||
const chunk of this.openaiChatCompletion(this.cstate.messages)
|
||||
) {
|
||||
if (!gottenFirstChunk) {
|
||||
this.cstate.messages.push({ role: "assistant", content: "" });
|
||||
gottenFirstChunk = true;
|
||||
}
|
||||
|
||||
// add chunk to the last message
|
||||
// TODO: handle localStorage overflow
|
||||
// possible example: this.cstate.messages[...] was undefined when trying to prompt within an old cstate (chat session)
|
||||
this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
|
||||
|
||||
// calculate performance tracking
|
||||
tokens += 1;
|
||||
this.total_tokens += 1;
|
||||
if (start_time === 0) {
|
||||
start_time = Date.now();
|
||||
this.time_till_first = start_time - prefill_start;
|
||||
} else {
|
||||
const diff = Date.now() - start_time;
|
||||
if (diff > 0) {
|
||||
this.tokens_per_second = tokens / (diff / 1000);
|
||||
}
|
||||
}
|
||||
this.checkMaxContext(this.total_tokens);
|
||||
if (this.cancelGeneration) break;
|
||||
}
|
||||
} finally {
|
||||
// update the state in histories or add it if it doesn't exist
|
||||
const index = this.histories.findIndex((cstate) => {
|
||||
return cstate.time === this.cstate.time;
|
||||
});
|
||||
this.cstate.time = Date.now();
|
||||
if (index !== -1) {
|
||||
// update the time
|
||||
this.histories[index] = this.cstate;
|
||||
} else {
|
||||
this.histories.push(this.cstate);
|
||||
}
|
||||
// update in local storage
|
||||
localStorage.setItem("histories", JSON.stringify(this.histories));
|
||||
|
||||
if (!this.maxContextReached) this.generating = false;
|
||||
if (this.cancelGeneration && !this.maxContextReached) this.cstate = { time: null, messages: [] };
|
||||
}
|
||||
},
|
||||
|
||||
async handleEnter(event) {
|
||||
// if shift is not pressed
|
||||
if (!event.shiftKey) {
|
||||
event.preventDefault();
|
||||
await this.handleSend();
|
||||
}
|
||||
},
|
||||
|
||||
updateTotalTokens(messages) {
|
||||
try {
|
||||
let toks = [this.tokenizer.bos_id];
|
||||
messages.forEach((message) => {
|
||||
if (!message.role || !message.content) {
|
||||
throw new Error("Each message must have a 'role' and 'content' property.");
|
||||
}
|
||||
toks = toks.concat(this.tokenizer.encodeMessage(message.role, message.content));
});

// append the assistant role header once, after all messages have been counted
if (messages.length > 0 && messages[messages.length - 1].role === "user") {
toks = toks.concat(this.tokenizer.encodeRole("assistant"));
}
this.total_tokens = toks.length;
|
||||
} catch (error) {
|
||||
console.error("Error updating total tokens:", error);
|
||||
}
|
||||
},
|
||||
|
||||
checkMaxContext(num_tokens) {
|
||||
if (num_tokens >= this.max_context) {
|
||||
this.cancelGeneration = true;
|
||||
this.maxContextReached = true;
|
||||
this.placeholderText = `Max context reached: ${this.max_context} tokens`;
|
||||
}
|
||||
},
|
||||
|
||||
async *openaiChatCompletion(messages) {
|
||||
let tokens = [this.tokenizer.bos_id];
|
||||
for (const message of messages) {
|
||||
tokens = tokens.concat(this.tokenizer.encodeMessage(message.role, message.content));
|
||||
}
|
||||
tokens = tokens.concat(this.tokenizer.encodeRole("assistant"));
|
||||
this.checkMaxContext(tokens.length); // don't waste time prefilling if we know we're over the token limit
|
||||
let startPos = 0
|
||||
const prefillToks = tokens.slice(0, -1);
|
||||
|
||||
// Skip the largest possible sequence of tokens already represented at the beginning of the model's kv caches
|
||||
for (let i=0; i <= prefillToks.length; i++) {
|
||||
startPos = i;
|
||||
if (i == prefillToks.length) break;
|
||||
if (i == this.lastSeenToks.length) break;
|
||||
if (prefillToks[i] !== this.lastSeenToks[i]) break;
|
||||
}
|
||||
//this.lastSeenToks = prefillToks;
|
||||
//prefillToks = prefillToks.slice(startPos);
|
||||
const unprocessedPrefillToks = prefillToks.slice(startPos);
|
||||
this.lastSeenToks = prefillToks.slice(0, startPos);
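// e.g. if lastSeenToks = [1,2,3,4] and prefillToks = [1,2,3,9], the loop stops at startPos = 3,
// so unprocessedPrefillToks = [9] and lastSeenToks is reset to the shared prefix [1,2,3]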
|
||||
|
||||
this.progress = makeProgress(unprocessedPrefillToks.length);
|
||||
this.loadingMessage = (window.BACKEND === "WebGPU") ? "Reading input:" : "Loading (enable WebGPU for speed):";
|
||||
this.progress(0, this.loadingMessage);
|
||||
for (const tok of unprocessedPrefillToks) {
|
||||
if (this.cancelGeneration) {this.loadingMessage=""; return;}
|
||||
if (window.BACKEND === "WebGPU") {await this.nets["transformer"](new Int32Array([tok]), new Int32Array([startPos]));}
|
||||
else {await this.nets["transformer"](tok, startPos);}
|
||||
this.lastSeenToks.push(tok)
|
||||
startPos += 1;
|
||||
this.progress(1);
|
||||
}
|
||||
this.loadingMessage = ""; // hides progress bar
|
||||
|
||||
let lastTok = tokens[tokens.length - 1];
|
||||
while (true) {
|
||||
if (window.BACKEND === "WebGPU") {var tok = await this.nets["transformer"](new Int32Array([lastTok]), new Int32Array([startPos])); tok = tok[0][0];}
|
||||
else {var tok = await this.nets["transformer"](lastTok, startPos);}
|
||||
this.lastSeenToks.push(lastTok); // lets us skip prefilling with these tokens at the next prompt in this chain
|
||||
startPos += 1;
|
||||
lastTok = tok;
|
||||
if (this.tokenizer.stop_tokens.has(lastTok)) break;
|
||||
yield new TextDecoder().decode(this.tokenizer.decode([lastTok]));
|
||||
}
|
||||
},
  }));
});

const { markedHighlight } = globalThis.markedHighlight;
marked.use(markedHighlight({
  langPrefix: "hljs language-",
  highlight(code, lang, _info) {
    const language = hljs.getLanguage(lang) ? lang : "plaintext";
    return hljs.highlight(code, { language }).value;
  },
}));

// **** eventsource-parser ****
class EventSourceParserStream extends TransformStream {
  constructor() {
    let parser;

    super({
      start(controller) {
        parser = createParser((event) => {
          if (event.type === "event") {
            controller.enqueue(event);
          }
        });
      },

      transform(chunk) {
        parser.feed(chunk);
      },
    });
  }
}
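// (usage sketch) an SSE response body can be piped through this stream, e.g.
//   response.body.pipeThrough(new TextDecoderStream()).pipeThrough(new EventSourceParserStream())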

function createParser(onParse) {
  let isFirstChunk;
  let buffer;
  let startingPosition;
  let startingFieldLength;
  let eventId;
  let eventName;
  let data;
  reset();
  return {
    feed,
    reset,
  };
  function reset() {
    isFirstChunk = true;
    buffer = "";
    startingPosition = 0;
    startingFieldLength = -1;
    eventId = void 0;
    eventName = void 0;
    data = "";
  }
  function feed(chunk) {
    buffer = buffer ? buffer + chunk : chunk;
    if (isFirstChunk && hasBom(buffer)) {
      buffer = buffer.slice(BOM.length);
    }
    isFirstChunk = false;
    const length = buffer.length;
    let position = 0;
    let discardTrailingNewline = false;
    while (position < length) {
      if (discardTrailingNewline) {
        if (buffer[position] === "\n") {
          ++position;
        }
        discardTrailingNewline = false;
      }
      let lineLength = -1;
      let fieldLength = startingFieldLength;
      let character;
      for (
        let index = startingPosition;
        lineLength < 0 && index < length;
        ++index
      ) {
        character = buffer[index];
        if (character === ":" && fieldLength < 0) {
          fieldLength = index - position;
        } else if (character === "\r") {
          discardTrailingNewline = true;
          lineLength = index - position;
        } else if (character === "\n") {
          lineLength = index - position;
        }
      }
      if (lineLength < 0) {
        startingPosition = length - position;
        startingFieldLength = fieldLength;
        break;
      } else {
        startingPosition = 0;
        startingFieldLength = -1;
      }
      parseEventStreamLine(buffer, position, fieldLength, lineLength);
      position += lineLength + 1;
    }
    if (position === length) {
      buffer = "";
    } else if (position > 0) {
      buffer = buffer.slice(position);
    }
  }
  function parseEventStreamLine(lineBuffer, index, fieldLength, lineLength) {
    if (lineLength === 0) {
      if (data.length > 0) {
        onParse({
          type: "event",
          id: eventId,
          event: eventName || void 0,
          data: data.slice(0, -1),
          // remove trailing newline
        });

        data = "";
        eventId = void 0;
      }
      eventName = void 0;
      return;
    }
    const noValue = fieldLength < 0;
    const field = lineBuffer.slice(
      index,
      index + (noValue ? lineLength : fieldLength),
    );
    let step = 0;
    if (noValue) {
      step = lineLength;
    } else if (lineBuffer[index + fieldLength + 1] === " ") {
      step = fieldLength + 2;
    } else {
      step = fieldLength + 1;
    }
    const position = index + step;
    const valueLength = lineLength - step;
    const value = lineBuffer.slice(position, position + valueLength).toString();
    if (field === "data") {
      data += value ? "".concat(value, "\n") : "\n";
    } else if (field === "event") {
      eventName = value;
    } else if (field === "id" && !value.includes("\0")) {
      eventId = value;
    } else if (field === "retry") {
      const retry = parseInt(value, 10);
      if (!Number.isNaN(retry)) {
        onParse({
          type: "reconnect-interval",
          value: retry,
        });
      }
    }
  }
}
const BOM = [239, 187, 191];
function hasBom(buffer) {
  return BOM.every((charCode, index) => buffer.charCodeAt(index) === charCode);
}
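
// e.g. feeding the parser above the SSE chunk 'data: {"text":"hi"}\n\n' produces one "event" whose data is '{"text":"hi"}'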
const PAT_STR = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
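// PAT_STR above is the tiktoken-style token-splitting regex (the pattern used by the Llama 3 tokenizer)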

async function createTokenizer(bpeUrl) {
  const num_base_tokens = 128000;
  const special_tokens = {
    "<|begin_of_text|>": 128000,
    "<|end_of_text|>": 128001,
    "<|start_header_id|>": 128006,
    "<|end_header_id|>": 128007,
    "<|eot_id|>": 128009
  };
  const model = await window.tiktokenLoad({
    "load_tiktoken_bpe": bpeUrl,
    "special_tokens": special_tokens,
    "pat_str": PAT_STR
  });
  const tokenizer = new window.Tiktoken(model.bpe_ranks, model.special_tokens, model.pat_str);

  return {
    get bos_id() {
      return special_tokens["<|begin_of_text|>"];
    },

    get stop_tokens() {
      return new Set([
        special_tokens["<|end_of_text|>"],
        special_tokens["<|eot_id|>"],
      ]);
    },

    decode(toks) {
      const filtered = toks.filter((t) => t < num_base_tokens);
      return tokenizer.decode(filtered);
    },

    encode(text, allow_special = false) {
      const allowedSpecial = allow_special ? "all" : new Set();
      const disallowedSpecial = new Set();
      return tokenizer.encode(text, allowedSpecial, disallowedSpecial);
    },

    encodeRole(role) {
      const tokens = [];
      tokens.push(special_tokens["<|start_header_id|>"]);
      tokens.push(...this.encode(role));
      tokens.push(special_tokens["<|end_header_id|>"]);
      tokens.push(...this.encode("\n\n"));
      return tokens;
    },

    encodeMessage(role, content) {
      const roleTokens = this.encodeRole(role);
      const contentTokens = this.encode(content.trim());
      return [...roleTokens, ...contentTokens, special_tokens["<|eot_id|>"]];
    },
  };
}
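
// (usage sketch, hypothetical URL/values) const tokenizer = await createTokenizer("./llama3_bpe_url");
// tokenizer.encodeMessage("user", "hi") -> [128006, <"user" tokens>, 128007, <"\n\n" tokens>, <"hi" tokens>, 128009]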

examples/tinychat/tinychat-browser/make_tiktoken_js.sh (new executable file, 11 lines)
@@ -0,0 +1,11 @@
#!/usr/bin/env bash
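# bundles the tiktoken npm package (JS + WASM) into a single ES module (tiktoken.js + tiktoken_bg.wasm) for the browser app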
cd "$(dirname "$0")"
npm init -y && \
npm install --save-dev webpack webpack-cli && \
npm install tiktoken && \
jq '.scripts.build = "webpack"' package.json > package.tmp.json && \
mv package.tmp.json package.json && \
npm run build && \
mv dist/*.wasm ./tiktoken_bg.wasm && \
mv dist/* ./ && \
rm -rf dist node_modules package-lock.json package.json

examples/tinychat/tinychat-browser/tiktoken-export.js (new file, 5 lines)
@@ -0,0 +1,5 @@
// Force Webpack to copy the WASM
import 'tiktoken/tiktoken_bg.wasm';
import { init, get_encoding, encoding_for_model, Tiktoken } from 'tiktoken/init';
import { load } from 'tiktoken/load';
export { init, get_encoding, encoding_for_model, Tiktoken, load };

examples/tinychat/tinychat-browser/webpack.config.js (new file, 25 lines)
@@ -0,0 +1,25 @@
const path = require("path");

module.exports = {
  mode: "production",
  entry: "./tiktoken-export.js",
  output: {
    filename: "tiktoken.js",
    path: path.resolve(__dirname, "dist"),
    library: {
      type: "module"
    }
  },
  experiments: {
    outputModule: true,
    asyncWebAssembly: true
  },
  module: {
    rules: [
      {
        test: /\.wasm$/,
        type: "asset/resource",
      }
    ]
  }
};

examples/tinychat/tinychat-browser/worker.js (new file, 62 lines)
@@ -0,0 +1,62 @@
const kernelsReady = (async () => {
  // can't get browser to use updated versions except with cache-busting query string
  const exports = await import(`./net_clang.js?version=${Date.now()}`);
  Object.assign(self, exports);
})();

async function init(event) {
  await kernelsReady;
  self.model = await self.transformer();
  self.addEventListener("message", loadStateDict);
  self.removeEventListener("message", init);
  self.postMessage("success");
}

function loadStateDict(event) {
  if (event.data === "done") {
    self.addEventListener("message", inference);
    self.removeEventListener("message", loadStateDict);
  }
  else {
    if (event.data.length > 1) {
      // the bytes from files are set contiguously in WASM memory
      const malloc_size = event.data.reduce((sum, file) => sum + file.bytes.length, 0);
      const malloc_ptr = self.model.wasm._malloc(malloc_size);
      let cursor = 0;
      for (const file of event.data) {
        self.model.wasm.HEAPU8.set(file.bytes, malloc_ptr + cursor);
        for (const part of file.parts) {
          if (part.target_start_pos === 0) {
            // tell WASM code where the tensor is in memory
            self.model.wasm._set_buf(self.transformer_name_to_id[part.key], malloc_ptr + cursor);
          }
          cursor += part.size;
        }
        file.bytes = null;
      }
    }
    else {
      // the bytes from files are not guaranteed to be set contiguously in WASM memory
      const file = event.data[0];
      const malloc_ptr = self.model.wasm._malloc(file.size);
      self.model.wasm.HEAPU8.set(file.bytes, malloc_ptr);
      for (const part of file.parts) {
        if (part.target_start_pos === 0) {
          self.model.wasm._set_buf(self.transformer_name_to_id[part.key], malloc_ptr + part.file_start_pos);
        }
      }
      file.bytes = null;
    }
  }
  self.postMessage("success");
}

function inference(event) {
  const [tok, start_pos] = event.data;
  const int32tok = new Int32Array([tok]);
  const model_out = self.model.run(new Uint8Array(int32tok.buffer), start_pos);
  const int32nextTok = new Int32Array(model_out[0].buffer);
  self.postMessage(int32nextTok[0]);
}

self.addEventListener("message", init);