From fa873df9c149144a7c9dc8a0ebcf9e7449a38527 Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Wed, 10 Jul 2024 20:13:52 +0000
Subject: [PATCH] bring tinychat more inline with tinyos' version (#5358)

---
 examples/llama3.py            |  7 ++++---
 examples/tinychat/favicon.svg | 25 +++++++++++++++++++++++++
 examples/tinychat/index.css   |  3 ++-
 examples/tinychat/index.html  |  1 +
 examples/tinychat/index.js    | 10 +++++-----
 extra/models/llama.py         |  3 +++
 6 files changed, 40 insertions(+), 9 deletions(-)
 create mode 100644 examples/tinychat/favicon.svg

diff --git a/examples/llama3.py b/examples/llama3.py
index 04909758f7..29b5735e82 100644
--- a/examples/llama3.py
+++ b/examples/llama3.py
@@ -12,6 +12,7 @@ class Tokenizer:
   pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
   def __init__(self, model_path: str):
     mergeable_ranks = load_tiktoken_bpe(model_path)
+    self.num_base_tokens = len(mergeable_ranks)
     special_tokens = [
       "<|begin_of_text|>",
       "<|end_of_text|>",
@@ -36,7 +37,7 @@ class Tokenizer:
   @property
   def stop_tokens(self): return {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]}
 
-  def decode(self, toks): return self.model.decode(toks)
+  def decode(self, toks): return self.model.decode([t for t in toks if t < self.num_base_tokens])
   def encode(self, text, allow_special=False):
     return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set())
@@ -181,7 +182,7 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=N
 TEMPERATURE = 0.85
 TOP_K = 25
 TOP_P = 0.9
-ALPHA_F = 1.1
+ALPHA_F = 0.1
 ALPHA_P = 0.0
 
 last_seen_toks = []
@@ -320,7 +321,7 @@ if __name__ == "__main__":
       toks = [tokenizer.bos_id]
       for message in rjson["messages"]: toks += encode_message(message["role"], message["content"])
-      if message["role"] == "user":
+      if len(rjson["messages"]) > 0 and message["role"] == "user":
         toks += encode_role("assistant")
       return json.dumps(toks)
diff --git a/examples/tinychat/favicon.svg b/examples/tinychat/favicon.svg
new file mode 100644
index 0000000000..420a33e0f8
--- /dev/null
+++ b/examples/tinychat/favicon.svg
@@ -0,0 +1,25 @@
+
+
+
+
diff --git a/examples/tinychat/index.css b/examples/tinychat/index.css
index e9e353f6fd..8f0908fb96 100644
--- a/examples/tinychat/index.css
+++ b/examples/tinychat/index.css
@@ -47,7 +47,8 @@ main {
 
 .title {
   font-size: 3rem;
-  margin: 3rem 0;
+  margin: 1rem 0;
+  margin-top: 3rem;
 }
 
 .histories-container-container {
diff --git a/examples/tinychat/index.html b/examples/tinychat/index.html
index 0a6cb053de..64932b5197 100644
--- a/examples/tinychat/index.html
+++ b/examples/tinychat/index.html
@@ -3,6 +3,7 @@
   tinychat
+
diff --git a/examples/tinychat/index.js b/examples/tinychat/index.js
index 974409fbd5..9f35030e38 100644
--- a/examples/tinychat/index.js
+++ b/examples/tinychat/index.js
@@ -74,7 +74,7 @@ document.addEventListener("alpine:init", () => {
           start_time = Date.now();
           this.time_till_first = start_time - prefill_start;
         } else {
-          const diff = Date.now() - start_time
+          const diff = Date.now() - start_time;
           if (diff > 0) {
             this.tokens_per_second = tokens / (diff / 1000);
           }
@@ -108,10 +108,10 @@ document.addEventListener("alpine:init", () => {
 
     updateTotalTokens(messages) {
       fetch(`${this.endpoint}/chat/token/encode`, {
-        method: 'POST',
-        headers: { 'Content-Type': 'application/json' },
-        body: JSON.stringify({ messages })
-      }).then(response => response.json()).then(data => {
+        method: "POST",
+        headers: { "Content-Type": "application/json" },
+        body: JSON.stringify({ messages }),
+      }).then((response) => response.json()).then((data) => {
        this.total_tokens = data.length;
       }).catch(console.error);
     },
diff --git a/extra/models/llama.py b/extra/models/llama.py
index d8579c70bd..19422e8c22 100644
--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -109,6 +109,9 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
       setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
     logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0) * ap)
 
+  # replace NaNs with -inf
+  logits = (logits != logits).where(-float("inf"), logits)
+
   # softmax
   t = (logits / temp).softmax()