From 6c24eda522af1fe43bd636c97f77e54ceeee6dad Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Sat, 8 Jun 2024 19:05:45 +0000 Subject: [PATCH] feat: tinychat (#4869) --- examples/llama3.py | 14 +- examples/tinychat/common.css | 130 ++++++++++++++++ examples/tinychat/index.css | 269 ++++++++++++++++++++++++++++++++++ examples/tinychat/index.html | 137 +++++++++++++++++ examples/tinychat/index.js | 277 +++++++++++++++++++++++++++++++++++ 5 files changed, 825 insertions(+), 2 deletions(-) create mode 100644 examples/tinychat/common.css create mode 100644 examples/tinychat/index.css create mode 100644 examples/tinychat/index.html create mode 100644 examples/tinychat/index.js diff --git a/examples/llama3.py b/examples/llama3.py index e1c45087ae..08d596bc64 100644 --- a/examples/llama3.py +++ b/examples/llama3.py @@ -196,6 +196,9 @@ if __name__ == "__main__": parser.add_argument("--shard", type=int, default=1) parser.add_argument("--quantize", choices=["int8", "nf4"]) parser.add_argument("--api", action="store_true") + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default=7776) + parser.add_argument("--debug", action="store_true") parser.add_argument("--seed", type=int) parser.add_argument("--timing", action="store_true", help="Print timing per token") parser.add_argument("--profile", action="store_true", help="Output profile data") @@ -215,7 +218,7 @@ if __name__ == "__main__": param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(model)) if args.api: - from bottle import Bottle, request, response, HTTPResponse, abort + from bottle import Bottle, request, response, HTTPResponse, abort, static_file app = Bottle() cors_headers = { @@ -231,6 +234,13 @@ if __name__ == "__main__": def enable_cors(): for key, value in cors_headers.items(): response.set_header(key, value) + @app.route("/") + def server_static(filename): + return static_file(filename, root=(Path(__file__).parent / 
"tinychat").as_posix()) + @app.route("/") + def index(): + return static_file("index.html", root=(Path(__file__).parent / "tinychat").as_posix()) + @app.get("/v1/models") def models(): return json.dumps([str(args.model)]) @@ -330,7 +340,7 @@ if __name__ == "__main__": } yield f"data: {json.dumps(res)}\n\n" - app.run(host="0.0.0.0", port=7776, debug=True) + app.run(host=args.host, port=args.port, debug=args.debug) else: prompt = [tokenizer.bos_id] + encode_message("system", "You are an *emotive* assistant.") diff --git a/examples/tinychat/common.css b/examples/tinychat/common.css new file mode 100644 index 0000000000..e654d6a13f --- /dev/null +++ b/examples/tinychat/common.css @@ -0,0 +1,130 @@ +/* make it responsive */ +@media(min-width: 852px) { + body { + font-size: 14px; + } +} +@media(max-width: 852px) { + body { + font-size: 12px; + } +} + +/* resets */ +html, body { + width: 100%; + height: 100%; +} + +*::-webkit-scrollbar { + display: none; +} + +* { + -ms-overflow-style: none; + scrollbar-width: none; +} + +* { + -moz-box-sizing: border-box; + -webkit-box-sizing: border-box; + box-sizing: border-box; +} + +/* default */ +body { + margin: 0; + background-color: var(--primary-bg-color); + color: var(--foreground-color); +} + +h1, h2, h3, h4, h5, h6 { + margin: 0em; +} + +hr { + width: 92%; +} + +button { + cursor: pointer; + border: none; + background-color: transparent; +} +button:hover { +} +button:active { +} + +/* components */ +.container { + margin: 0 auto; + padding: 1rem; +} + +.centered { + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; +} + +.centered-w-only { + position: absolute; + left: 50%; + transform: translateX(-50%); +} + +.centered-h-only { + position: absolute; + top: 50%; + transform: translateY(-50%); +} + +.card { + padding: 0; +} + +.card-header { + padding: 0.5rem 1rem; +} + +.card-container { + width: 96vw; + height: 100%; + gap: 1rem; + display: flex; + flex-direction: row; + flex-wrap: 
wrap; + justify-content: center; + align-items: center; +} + +.clean-a { + text-decoration: underline; + text-decoration-color: #006fc1; + text-decoration-thickness: 2px; + color: inherit; +} + +.hover-underline { + text-decoration: underline; + text-decoration-color: #228039; + text-decoration-thickness: 2px; + color: inherit; +} + +.flex-horizontal { + display: flex; + flex-direction: row; + justify-content: space-between; + align-items: center; +} + +.vertical-separator { + padding: 0 0.5rem; +} + +[x-cloak] { + display: none !important; +} diff --git a/examples/tinychat/index.css b/examples/tinychat/index.css new file mode 100644 index 0000000000..1cc7514a47 --- /dev/null +++ b/examples/tinychat/index.css @@ -0,0 +1,269 @@ +/* define colors */ +:root { + --primary-color: #a52e4d; + --primary-color-transparent: #a52e4d66; + --secondary-color: #228039; + --secondary-color-transparent: #22803966; + + --red-color: #a52e4d; + --green-color: #228039; + --silver-color: #88808e; +} +@media(prefers-color-scheme: light) { + :root { + --primary-bg-color: #f0f0f0; + --secondary-bg-color: #eeeeee; + --tertiary-bg-color: #dddddd; + --foreground-color: #111111; + --accent-color: #000000; + } +} +@media(prefers-color-scheme: dark) { + :root { + --primary-bg-color: #111111; + --secondary-bg-color: #131313; + --tertiary-bg-color: #232323; + --foreground-color: #f0f0f0; + --accent-color: #aaaaaa; + } +} + +main { + width: 100%; + height: 100%; + + display: flex; + flex-direction: column; + + place-items: center; +} + +.home { + width: 100%; + height: 90%; + + margin-bottom: 10rem; +} + +.title { + font-size: 3rem; + margin: 3rem 0; +} + +.histories-container-container { + width: 100%; + max-height: 75%; + + position: relative; +} + +.histories-container { + overflow-y: auto; + overflow-x: hidden; + width: 100%; + height: 100%; + + display: flex; + flex-direction: column; + gap: 1rem; + align-items: center; + + margin: 0; + padding: 3rem 1rem; +} + +.histories-start { + height: 
3rem; + width: 100%; + + z-index: 999; + top: 0; + position: absolute; + + background: linear-gradient(180deg, var(--primary-bg-color) 0%, transparent 100%); +} +.histories-end { + height: 3rem; + width: 100%; + + z-index: 999; + bottom: 0; + position: absolute; + + background: linear-gradient(0deg, var(--primary-bg-color) 0%, transparent 100%); +} + +.history { + padding: 1rem; + width: 100%; + max-width: 40rem; + + background-color: var(--tertiary-bg-color); + border-radius: 10px; + border-left: 2px solid var(--primary-color); + + cursor: pointer; + + transform: translateX(calc(1px * var(--tx, 0))); + opacity: var(--opacity, 1); +} +.history:hover { + background-color: var(--secondary-bg-color); +} + +.history-delete-button { + position: absolute; + top: 0; + right: 0; + padding: 0.5rem; + margin: 0; + outline: none; + border: none; + background-color: var(--secondary-bg-color); + color: var(--foreground-color); + border-radius: 0 0 0 10px; + cursor: pointer; + transition: 0.2s; +} +.history-delete-button:hover { + background-color: var(--tertiary-bg-color); + padding: 0.75rem; +} + +.messages { + overflow-y: auto; + height: 100%; + width: 100%; + + display: flex; + flex-direction: column; + gap: 1rem; + align-items: center; + padding-top: 1rem; + padding-bottom: 9rem; +} + +.message { + width: 96%; + max-width: 80rem; + + display: grid; + + background-color: var(--secondary-bg-color); + padding: 0.5rem 1rem; + border-radius: 10px; +} +.message-role-ai { + border-bottom: 2px solid var(--primary-color); + border-left: 2px solid var(--primary-color); + box-shadow: -10px 10px 20px 2px var(--primary-color-transparent); +} +.message-role-user { + border-bottom: 2px solid var(--secondary-color); + border-right: 2px solid var(--secondary-color); + box-shadow: 10px 10px 20px 2px var(--secondary-color-transparent); +} + +.message > pre { + white-space: pre-wrap; +} + +.hljs { + width: 100%; + position: relative; + border-radius: 10px; + /* wrap code blocks */ + 
white-space: pre-wrap; +} +/* put clipboard button in the top right corner of the code block */ +.clipboard-button { + position: absolute; + top: 0; + right: 0; + padding: 0.5rem; + margin: 0; + outline: none; + border: none; + background-color: var(--secondary-bg-color); + color: var(--foreground-color); + border-radius: 0 0 0 10px; + cursor: pointer; + transition: 0.2s; +} +.clipboard-button:hover { + background-color: var(--tertiary-bg-color); + padding: 0.75rem; +} + +.input-container { + position: absolute; + bottom: 0; + + /* linear gradient from background-color to transparent on the top */ + background: linear-gradient(0deg, var(--primary-bg-color) 55%, transparent 100%); + + width: 100%; + display: flex; + justify-content: center; + align-items: center; + z-index: 999; +} + +.input { + width: 90%; + min-height: 3rem; + flex-shrink: 0; + + display: flex; + flex-direction: row; + justify-content: center; + gap: 0.5rem; + + align-items: flex-end; + margin-bottom: 2rem; + margin-top: 4rem; +} + +.input-form { + width: 100%; + padding: 1rem; + min-height: 3rem; + max-height: 8rem; + + background-color: var(--tertiary-bg-color); + color: var(--foreground-color); + border-radius: 10px; + border: none; + resize: none; + outline: none; +} + +.input-button { + height: 3rem; + width: 4rem; + + background-color: var(--secondary-color); + color: var(--foreground-color); + border-radius: 10px; + padding: 0.5rem; + cursor: pointer; +} +.input-button:hover { + background-color: var(--secondary-color-transparent); +} +.input-button:disabled { + background-color: var(--secondary-bg-color); + cursor: not-allowed; +} + +/* wrap text */ +p { + white-space: pre-wrap; +} + +/* fonts */ +.megrim-regular { + font-family: "Megrim", system-ui; + font-weight: 400; + font-style: normal; +} diff --git a/examples/tinychat/index.html b/examples/tinychat/index.html new file mode 100644 index 0000000000..8cd2519078 --- /dev/null +++ b/examples/tinychat/index.html @@ -0,0 +1,137 @@ + + + + 
tinychat + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
+

tinychat

+
+ +
+ +
+ +
+
+
+
+
+
+ + +
+
+
// examples/tinychat/index.js (added by this patch: new file mode 100644,
// index 0000000000..8b1c43ff5d), restored to one statement per line.

// Register the Alpine "state" component that drives the chat page.
// The typeof guard only matters outside a browser: it lets the SSE parser
// further down be loaded (e.g. by a test runner) where no DOM exists.
if (typeof document !== "undefined") {
  document.addEventListener("alpine:init", () => {
    /**
     * Reads the saved conversations from localStorage.
     * @returns {Array} the stored list, or [] when nothing usable is stored.
     */
    const loadHistories = () => {
      try {
        const stored = JSON.parse(localStorage.getItem("histories"));
        return Array.isArray(stored) ? stored : [];
      } catch (err) {
        // A corrupt entry must not prevent the component from initialising.
        console.warn("tinychat: ignoring unreadable saved histories", err);
        return [];
      }
    };

    Alpine.data("state", () => ({
      // current state
      cstate: {
        time: null,
        messages: [],
      },

      // historical state
      histories: loadHistories(),

      home: 0,
      generating: false,
      endpoint: `${window.location.origin}/v1`,

      /**
       * Removes the saved conversation whose `time` matches `cstate.time`
       * and persists the shortened list.
       */
      removeHistory(cstate) {
        const index = this.histories.findIndex((state) => state.time === cstate.time);
        if (index !== -1) {
          this.histories.splice(index, 1);
          localStorage.setItem("histories", JSON.stringify(this.histories));
        }
      },

      /**
       * Sends the textarea content as a user message, streams the reply into
       * a new "ai" message and saves the conversation to localStorage.
       * Errors from the request propagate to the caller; `generating` is
       * always cleared so the UI cannot get stuck in the busy state.
       */
      async handleSend() {
        const el = document.getElementById("input-form");
        const value = el.value.trim();
        if (!value) return;

        if (this.generating) return;
        this.generating = true;

        try {
          if (this.home === 0) this.home = 1;

          // ensure that going back in history will go back to home
          window.history.pushState({}, "", "/");

          // add message to list
          this.cstate.messages.push({ role: "user", content: value });

          // clear textarea and shrink it back to its content height
          el.value = "";
          el.style.height = "auto";
          el.style.height = `${el.scrollHeight}px`;

          // start receiving server sent events
          let gottenFirstChunk = false;
          for await (const chunk of this.openaiChatCompletion(this.cstate.messages)) {
            if (!gottenFirstChunk) {
              this.cstate.messages.push({ role: "ai", content: "" });
              gottenFirstChunk = true;
            }

            // add chunk to the last message
            this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
          }

          // Look the conversation up by its previous timestamp *before*
          // stamping the new one, then update it in place or append it.
          const index = this.histories.findIndex((state) => state.time === this.cstate.time);
          this.cstate.time = Date.now();
          if (index !== -1) {
            this.histories[index] = this.cstate;
          } else {
            this.histories.push(this.cstate);
          }
          // update in local storage
          localStorage.setItem("histories", JSON.stringify(this.histories));
        } finally {
          // Previously a failed request left this flag set forever.
          this.generating = false;
        }
      },

      /** Submits on Enter; Shift+Enter keeps the default (newline) behaviour. */
      async handleEnter(event) {
        if (!event.shiftKey) {
          event.preventDefault();
          await this.handleSend();
        }
      },

      /**
       * POSTs `messages` to `${endpoint}/chat/completions` with streaming
       * enabled and yields each text fragment of the reply.
       * @param {Array<{role: string, content: string}>} messages
       * @yields {string} content fragments, in arrival order
       * @throws {Error} "Failed to fetch" when the response status is not ok
       */
      async *openaiChatCompletion(messages) {
        const response = await fetch(`${this.endpoint}/chat/completions`, {
          method: "POST",
          headers: {
            "Content-Type": "application/json",
          },
          body: JSON.stringify({
            "messages": messages,
            "stream": true,
          }),
        });
        if (!response.ok) {
          throw new Error("Failed to fetch");
        }

        const events = response.body
          .pipeThrough(new TextDecoderStream())
          .pipeThrough(new EventSourceParserStream());

        for await (const event of events) {
          if (event.type !== "event") continue;

          // NOTE(review): OpenAI-style servers may end the stream with a
          // literal "[DONE]" payload, which is not JSON — confirm whether
          // this backend sends it; handling it here is harmless either way.
          if (event.data === "[DONE]") break;

          const json = JSON.parse(event.data);
          const choice = json.choices?.[0];
          if (!choice) continue;

          // see if the completion is done
          if (choice.finish_reason === "stop") break;

          // Chunks without text (e.g. an empty delta) must not be appended,
          // otherwise the literal string "undefined" ends up in the message.
          const content = choice.delta?.content;
          if (typeof content === "string") yield content;
        }
      },
    }));
  });
}

// Syntax highlighting for rendered markdown. Kept optional: if one of the
// helper scripts is missing, throwing here would abort the rest of this file
// and leave EventSourceParserStream undefined, breaking chat streaming too.
const { markedHighlight } = globalThis.markedHighlight ?? {};
if (markedHighlight && typeof marked !== "undefined" && typeof hljs !== "undefined") {
  marked.use(markedHighlight({
    langPrefix: "hljs language-",
    highlight(code, lang, _info) {
      const language = hljs.getLanguage(lang) ? lang : "plaintext";
      return hljs.highlight(code, { language }).value;
    },
  }));
} else {
  console.warn("tinychat: marked / marked-highlight / highlight.js not loaded; code highlighting disabled");
}

// **** eventsource-parser ****

/**
 * TransformStream that turns decoded text chunks of a server-sent-event
 * stream into parsed `{ type: "event", id, event, data }` objects.
 * Other parser notifications (reconnect intervals) are dropped.
 */
class EventSourceParserStream extends TransformStream {
  constructor() {
    let parser;

    super({
      start(controller) {
        parser = createParser((event) => {
          if (event.type === "event") {
            controller.enqueue(event);
          }
        });
      },

      transform(chunk) {
        parser.feed(chunk);
      },
    });
  }
}

/**
 * Creates an incremental parser for the text/event-stream format.
 *
 * @param {(event: object) => void} onParse called with
 *   `{ type: "event", id, event, data }` for every dispatched event and with
 *   `{ type: "reconnect-interval", value }` for every valid `retry` field.
 * @returns {{ feed: (chunk: string) => void, reset: () => void }}
 *   `feed` accepts arbitrary string fragments (lines may be split anywhere);
 *   `reset` discards all buffered state.
 */
function createParser(onParse) {
  let isFirstChunk;
  let buffer;
  let startingPosition;
  let startingFieldLength;
  let discardTrailingNewline;
  let eventId;
  let eventName;
  let data;
  reset();
  return {
    feed,
    reset,
  };

  function reset() {
    isFirstChunk = true;
    buffer = "";
    startingPosition = 0;
    startingFieldLength = -1;
    discardTrailingNewline = false;
    eventId = void 0;
    eventName = void 0;
    data = "";
  }

  function feed(chunk) {
    buffer = buffer ? buffer + chunk : chunk;
    if (isFirstChunk) {
      if (buffer.charCodeAt(0) === 0xfeff) {
        // A BOM that went through a text decoder is the single char U+FEFF.
        buffer = buffer.slice(1);
      } else if (hasBom(buffer)) {
        buffer = buffer.slice(BOM.length);
      }
    }
    isFirstChunk = false;

    const length = buffer.length;
    let position = 0;
    while (position < length) {
      // After a line that ended in "\r", swallow the "\n" of a CRLF pair.
      // The flag lives in parser state so it survives a chunk boundary
      // falling between the "\r" and the "\n".
      if (discardTrailingNewline) {
        if (buffer[position] === "\n") {
          ++position;
        }
        discardTrailingNewline = false;
      }

      let lineLength = -1;
      let fieldLength = startingFieldLength;
      // startingPosition is the number of characters of the current line
      // already scanned by a previous feed(), so it is relative to
      // `position`. Starting at 0 would rescan earlier lines and let a
      // stale "\r" re-arm discardTrailingNewline.
      for (
        let index = position + startingPosition;
        lineLength < 0 && index < length;
        ++index
      ) {
        const character = buffer[index];
        if (character === ":" && fieldLength < 0) {
          fieldLength = index - position;
        } else if (character === "\r") {
          discardTrailingNewline = true;
          lineLength = index - position;
        } else if (character === "\n") {
          lineLength = index - position;
        }
      }

      if (lineLength < 0) {
        // Incomplete line: remember how far we got and wait for more input.
        startingPosition = length - position;
        startingFieldLength = fieldLength;
        break;
      }
      startingPosition = 0;
      startingFieldLength = -1;

      parseEventStreamLine(buffer, position, fieldLength, lineLength);
      position += lineLength + 1;
    }

    // Keep only the unconsumed tail for the next call.
    if (position === length) {
      buffer = "";
    } else if (position > 0) {
      buffer = buffer.slice(position);
    }
  }

  // Handles one complete line: a blank line dispatches the pending event,
  // any other line updates the pending data / event name / id, or reports
  // a retry interval.
  function parseEventStreamLine(lineBuffer, index, fieldLength, lineLength) {
    if (lineLength === 0) {
      if (data.length > 0) {
        onParse({
          type: "event",
          id: eventId,
          event: eventName || void 0,
          // remove trailing newline
          data: data.slice(0, -1),
        });

        data = "";
        eventId = void 0;
      }
      eventName = void 0;
      return;
    }

    const noValue = fieldLength < 0;
    const field = lineBuffer.slice(
      index,
      index + (noValue ? lineLength : fieldLength),
    );

    // Number of characters before the value: the field name, the colon and
    // at most one space following the colon.
    let step = 0;
    if (noValue) {
      step = lineLength;
    } else if (lineBuffer[index + fieldLength + 1] === " ") {
      step = fieldLength + 2;
    } else {
      step = fieldLength + 1;
    }

    const position = index + step;
    const valueLength = lineLength - step;
    const value = lineBuffer.slice(position, position + valueLength);

    if (field === "data") {
      data += `${value}\n`;
    } else if (field === "event") {
      eventName = value;
    } else if (field === "id" && !value.includes("\0")) {
      eventId = value;
    } else if (field === "retry") {
      const retry = Number.parseInt(value, 10);
      if (!Number.isNaN(retry)) {
        onParse({
          type: "reconnect-interval",
          value: retry,
        });
      }
    }
  }
}

// The three UTF-8 BOM bytes (EF BB BF) as they appear when each byte ended up
// as its own character in the string.
const BOM = [239, 187, 191];

/** Returns true when `buffer` starts with the three BOM byte values. */
function hasBom(buffer) {
  return BOM.every((charCode, index) => buffer.charCodeAt(index) === charCode);
}