mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Bring back WebGPU (#7063)
* Start from andredaprato:webgpu-clean * Fix infs * inf wgsl function is not needed * Emulated ulong for threefry, more tests passing * Randomness tests passing * Update model export to support new changes in webgpu, efficientnet export works again * Simplify shift emulation in wgsl * Delete test file * Fix bigger than u32 u32 literal * Why was skip copies added here? * Python3.12 for webgpu tests * Fix model export syntax error * Get test ops passing with some skips * Fix lint * Much simpler shift * Run more tests * Timestamp queries are not supported in CI, so skip search tests * All fancy indexing passing * r is ctx * Run more dtype tests by using is_dtype_supported * Cleanup ulong shift rendering * UPat -> Pat, UOps -> Ops * Pat -> UPat * Refactor render_ushift if-else * Pattern to avoid ulong mul * Remove vals_dtype * is_nan trick + rewrite, test_isnan passing * Rewrite a * select(1, nan, gate) -> select(a, nan, gate) * No arg, just op * Support char, uchar, short, ushort * Run test_index_mnis now that we have uint8 * Fix pyling * Save 3 lines by using base Compiler * No more long emulation * Remove fixup_binops * No more external_local_bufx wgsl specific cstyle modif, use base extra_pm * Simpler, faster copyin/out * Skip some new tests that use long * Fix typo * copyout touchup * Save lines by using render_cast * WebGL is not supported in core, delete it from is_dtype_supported * More narrow test skips for some unary tests * TernaryOps, UnaryOps -> Ops * TinyGrad supports WebGPU * StableDiffusion demo: f16tof32 gpu is a lib, update UI * Packed load/store, no more scale_size, no core tinygrad changes * Rename copyin, copyout * Device -> dev * Fix lint * Pattern matcher rule for packed load/store * Refactor * Shorter packed load/store * this should fix lint * Fix mypy * SD compile script working * New SD webgpu UI * New default prompt * New SD weights * Fix title when webgpu not available * Run symbolic tests, simplify is_nan, use round_up * Show step time on 
UI * Bump minimum wgpu version to v0.19 * Fix latent --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
from extra.models.efficientnet import EfficientNet
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn.state import safe_save
|
||||
from tinygrad.nn.state import get_state_dict, safe_save, safe_load, load_state_dict
|
||||
from extra.export_model import export_model
|
||||
from tinygrad.helpers import getenv, fetch
|
||||
import ast
|
||||
@@ -9,11 +9,15 @@ import ast
|
||||
if __name__ == "__main__":
|
||||
model = EfficientNet(0)
|
||||
model.load_from_pretrained()
|
||||
dirname = Path(__file__).parent
|
||||
# exporting a model that's loaded from safetensors doesn't work without loading in from safetensors first
|
||||
# loading the state dict from a safetensor file changes the generated kernels
|
||||
if getenv("WEBGPU") or getenv("WEBGL"):
|
||||
safe_save(get_state_dict(model), (dirname / "net.safetensors").as_posix())
|
||||
load_state_dict(model, safe_load(str(dirname / "net.safetensors")))
|
||||
mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else "webgl" if getenv("WEBGL", "") != "" else ""
|
||||
prg, inp_sizes, out_sizes, state = export_model(model, mode, Tensor.randn(1,3,224,224))
|
||||
dirname = Path(__file__).parent
|
||||
if getenv("CLANG", "") == "":
|
||||
safe_save(state, (dirname / "net.safetensors").as_posix())
|
||||
ext = "js" if getenv("WEBGPU", "") != "" or getenv("WEBGL", "") != "" else "json"
|
||||
with open(dirname / f"net.{ext}", "w") as text_file:
|
||||
text_file.write(prg)
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor
|
||||
from tinygrad import Device
|
||||
from tinygrad.helpers import fetch
|
||||
from typing import NamedTuple, Any, List
|
||||
from pathlib import Path
|
||||
import requests
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
@@ -60,16 +60,22 @@ def split_safetensor(fn):
|
||||
cur_pos = 0
|
||||
|
||||
for i, end_pos in enumerate(part_end_offsets):
|
||||
with open(f'./net_part{i}.safetensors', "wb+") as f:
|
||||
with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.safetensors'), "wb+") as f:
|
||||
f.write(net_bytes[cur_pos:end_pos])
|
||||
cur_pos = end_pos
|
||||
|
||||
with open(f'./net_textmodel.safetensors', "wb+") as f:
|
||||
with open(os.path.join(os.path.dirname(__file__), f'./net_textmodel.safetensors'), "wb+") as f:
|
||||
f.write(net_bytes[text_model_start+8+json_len:])
|
||||
|
||||
return part_end_offsets
|
||||
|
||||
def fetch_dep(file, url):
  """Download a JavaScript dependency from `url` and write it to `file` as UTF-8 text.

  Any occurrence of the absolute Hugging Face URL of `bpe_simple_vocab_16e6.mjs` in the
  downloaded text is rewritten to the relative path `./bpe_simple_vocab_16e6.mjs`.

  Raises `requests.HTTPError` on a non-2xx response and `requests.Timeout` if the server
  does not answer within 60 seconds.
  """
  # Download before opening the file: a failed request must not truncate an existing
  # copy or save an HTTP error page as if it were the script.
  response = requests.get(url, timeout=60)
  response.raise_for_status()
  text = response.text.replace("https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/bpe_simple_vocab_16e6.mjs", "./bpe_simple_vocab_16e6.mjs")
  with open(file, "w", encoding="utf-8") as f:
    f.write(text)
|
||||
|
||||
if __name__ == "__main__":
|
||||
fetch_dep(os.path.join(os.path.dirname(__file__), "clip_tokenizer.js"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/clip_tokenizer.js")
|
||||
fetch_dep(os.path.join(os.path.dirname(__file__), "bpe_simple_vocab_16e6.mjs"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/bpe_simple_vocab_16e6.mjs")
|
||||
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--remoteweights', action='store_true', help="Use safetensors from Huggingface, or from local")
|
||||
args = parser.parse_args()
|
||||
@@ -94,12 +100,21 @@ if __name__ == "__main__":
|
||||
|
||||
prg = ""
|
||||
|
||||
def fixup_code(code, key):
  """Rewrite one rendered WGSL kernel so it can be embedded in the exported JavaScript.

  Steps, in order:
    1. every occurrence of `key` (the kernel name) becomes `main`;
    2. the `INFINITY` uniform declaration is replaced by an `inf()` helper function;
    3. the `@group(0) @binding(0)` attribute is removed;
    4. remaining uses of `INFINITY` become `inf(1.0)`;
    5. every `binding(N)` with N >= 1 is renumbered to `binding(N-1)`.

  Returns the rewritten source string.
  """
  import re  # function-scope import: the file's top-level import list is not fully visible here

  code = code.replace(key, 'main')\
    .replace("var<uniform> INFINITY : f32;\n", "fn inf(a: f32) -> f32 { return a/0.0; }\n")\
    .replace("@group(0) @binding(0)", "")\
    .replace("INFINITY", "inf(1.0)")

  # One regex pass shifts every binding index down by one. The previous fixed loop only
  # covered indices 1..8, leaving a gap in the numbering for kernels with more buffers.
  return re.sub(r"binding\(([1-9]\d*)\)", lambda m: f"binding({int(m.group(1)) - 1})", code)
|
||||
|
||||
def compile_step(model, step: Step):
|
||||
run, special_names = jit_model(step, *step.input)
|
||||
functions, statements, bufs, _ = compile_net(run, special_names)
|
||||
state = get_state_dict(model)
|
||||
weights = {id(x.lazydata.base.realized): name for name, x in state.items()}
|
||||
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
|
||||
kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
|
||||
kernel_names = ', '.join([name for (name, _, _, _) in statements])
|
||||
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
|
||||
bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
|
||||
@@ -148,14 +163,14 @@ if __name__ == "__main__":
|
||||
|
||||
if step.name == "diffusor":
|
||||
if args.remoteweights:
|
||||
base_url = "https://huggingface.co/wpmed/tinygrad-sd-f16/resolve/main"
|
||||
base_url = "https://huggingface.co/wpmed/stable-diffusion-f16-new/resolve/main"
|
||||
else:
|
||||
state = get_state_dict(model)
|
||||
safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
|
||||
convert_f32_to_f16("./net.safetensors", "./net_conv.safetensors")
|
||||
split_safetensor("./net_conv.safetensors")
|
||||
os.remove("net.safetensors")
|
||||
os.remove("net_conv.safetensors")
|
||||
convert_f32_to_f16(os.path.join(os.path.dirname(__file__), "./net.safetensors"), os.path.join(os.path.dirname(__file__), "./net_conv.safetensors"))
|
||||
split_safetensor(os.path.join(os.path.dirname(__file__), "./net_conv.safetensors"))
|
||||
os.remove(os.path.join(os.path.dirname(__file__), "net.safetensors"))
|
||||
os.remove(os.path.join(os.path.dirname(__file__), "net_conv.safetensors"))
|
||||
base_url = "."
|
||||
|
||||
prekernel = f"""
|
||||
@@ -185,20 +200,6 @@ if __name__ == "__main__":
|
||||
counter++;
|
||||
}}
|
||||
|
||||
let allZero = true;
|
||||
let out = safetensorParts[selectedPart].subarray(...correctedOffsets);
|
||||
|
||||
for (let i = 0; i < out.length; i++) {{
|
||||
if (out[i] !== 0) {{
|
||||
allZero = false;
|
||||
break;
|
||||
}}
|
||||
}}
|
||||
|
||||
if (allZero) {{
|
||||
console.log("Error: weight '" + key + "' is all zero.");
|
||||
}}
|
||||
|
||||
return safetensorParts[selectedPart].subarray(...correctedOffsets);
|
||||
}}
|
||||
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
// WGSL compute shader that expands packed half-precision values to f32.
// Each u32 element of `data0` holds two 16-bit halves; the low half is written to
// data1[2*i] and the high half to data1[2*i + 1].
// The conversion uses the WGSL builtin `unpack2x16float` instead of hand-rolled bit
// arithmetic, which decoded exponent 31 (Inf/NaN) as a finite number.
const f16tof32 = `
@group(0) @binding(0) var<storage,read_write> data0: array<u32>;
@group(0) @binding(1) var<storage,read_write> data1: array<f32>;
@compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let gidx = gid.x;

  if (gidx >= arrayLength(&data0)) {
    return;
  }

  let halves = unpack2x16float(data0[gidx]);
  let outgidx = gidx * 2u;

  data1[outgidx] = halves.x;
  data1[outgidx + 1u] = halves.y;
}`;
|
||||
|
||||
/**
 * Convert a buffer of packed f16 values to f32 on the GPU using the `f16tof32` shader.
 *
 * @param {GPUDevice} device - WebGPU device used for buffers, the pipeline and submission.
 * @param {Uint8Array} inf16 - raw f16 bytes; `length` is used as the byte size and
 *   `length / 4` as the u32 element count, so the length must be a multiple of 4
 *   (presumably a Uint8Array view; confirm against the caller).
 * @returns {Promise<Float32Array>} a CPU-side copy of the converted values.
 */
window.f16tof32GPU = async (device, inf16) => {
  const storageUsage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;

  // Track every buffer created here so all of them are released on every exit path.
  // This function runs once per chunk, so leaked buffers add up.
  const allocated = [];
  const allocate = (descriptor) => {
    const buffer = device.createBuffer(descriptor);
    allocated.push(buffer);
    return buffer;
  };

  try {
    const input = allocate({ size: inf16.length, usage: storageUsage });
    // Every 2-byte half becomes a 4-byte float, so the output is twice as large.
    const output = allocate({ size: inf16.length * 2, usage: storageUsage });

    const gpuWriteBuffer = allocate({ size: input.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE });
    const gpuReadBuffer = allocate({ size: output.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ });
    const commandEncoder = device.createCommandEncoder();

    // Upload: copy the input bytes into the mappable staging buffer as u32 words.
    await gpuWriteBuffer.mapAsync(GPUMapMode.WRITE);
    const alignedUint32View = new Uint32Array(inf16.buffer, inf16.byteOffset, inf16.length / 4);
    new Uint32Array(gpuWriteBuffer.getMappedRange()).set(alignedUint32View);
    gpuWriteBuffer.unmap();

    commandEncoder.copyBufferToBuffer(gpuWriteBuffer, 0, input, 0, gpuWriteBuffer.size);
    const pipeline = await device.createComputePipelineAsync({
      layout: "auto",
      compute: { module: device.createShaderModule({ code: f16tof32 }), entryPoint: "main" },
    });

    // One invocation per u32 word, 256 invocations per workgroup.
    addComputePass(device, commandEncoder, pipeline, [input, output], [Math.ceil(inf16.length / (4 * 256)), 1, 1]);

    commandEncoder.copyBufferToBuffer(output, 0, gpuReadBuffer, 0, output.size);
    device.queue.submit([commandEncoder.finish()]);

    // Download: copy out of the mapped range, since the mapping is invalid after unmap().
    await gpuReadBuffer.mapAsync(GPUMapMode.READ);
    const resultBuffer = new Float32Array(gpuReadBuffer.size / 4);
    resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
    gpuReadBuffer.unmap();

    return resultBuffer;
  } finally {
    for (const buffer of allocated) {
      buffer.destroy();
    }
  }
};
|
||||
@@ -5,103 +5,213 @@
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>tinygrad has WebGPU</title>
|
||||
|
||||
<style>
|
||||
body {
|
||||
font-family: 'Arial', sans-serif;
|
||||
text-align: center;
|
||||
padding: 30px;
|
||||
/* General Reset */
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
a {
|
||||
text-decoration: none;
|
||||
color: #4A90E2;
|
||||
body {
|
||||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||
background: #f4f7fb;
|
||||
color: #333;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
min-height: 100vh;
|
||||
padding: 20px;
|
||||
flex-direction: column;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 36px;
|
||||
font-weight: normal;
|
||||
font-size: 2.2rem;
|
||||
margin-bottom: 20px;
|
||||
color: #4A90E2;
|
||||
}
|
||||
|
||||
#wgpuError {
|
||||
color: red;
|
||||
font-size: 1.2rem;
|
||||
margin-top: 20px;
|
||||
display: none;
|
||||
}
|
||||
|
||||
#sdTitle {
|
||||
font-size: 1.5rem;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
|
||||
#mybox {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 20px;
|
||||
width: 50%;
|
||||
margin: 0 auto;
|
||||
background: #ffffff;
|
||||
padding: 15px;
|
||||
border-radius: 10px;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
|
||||
width: 120%;
|
||||
max-width: 550px;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
#promptText, #stepRange, #btnRunNet, #guidanceRange {
|
||||
font-size: 18px;
|
||||
input[type="text"] {
|
||||
width: 100%;
|
||||
padding: 12px;
|
||||
margin-bottom: 20px;
|
||||
font-size: 1rem;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 8px;
|
||||
outline: none;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
#result {
|
||||
font-size: 48px;
|
||||
}
|
||||
|
||||
#time {
|
||||
font-size: 16px;
|
||||
color: grey;
|
||||
}
|
||||
|
||||
canvas {
|
||||
margin-top: 20px;
|
||||
border: 1px solid #000;
|
||||
input[type="text"]:focus {
|
||||
border-color: #4A90E2;
|
||||
}
|
||||
|
||||
label {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
font-size: 1rem;
|
||||
margin-bottom: 15px;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
#sliderValue {
|
||||
margin-right: 10px;/
|
||||
|
||||
input[type="range"] {
|
||||
width: 100%;
|
||||
margin-left: 10px;
|
||||
-webkit-appearance: none;
|
||||
appearance: none;
|
||||
height: 8px;
|
||||
border-radius: 4px;
|
||||
background: #ddd;
|
||||
outline: none;
|
||||
transition: background 0.3s ease;
|
||||
}
|
||||
|
||||
input[type="range"]:focus {
|
||||
background: #4A90E2;
|
||||
}
|
||||
|
||||
#stepRange,
|
||||
#guidanceRange {
|
||||
width: 80%;
|
||||
}
|
||||
|
||||
span {
|
||||
font-size: 1.1rem;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
}
|
||||
|
||||
input[type="button"] {
|
||||
padding: 12px 25px;
|
||||
background-color: #4A90E2;
|
||||
color: #fff;
|
||||
font-size: 1.2rem;
|
||||
border: none;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.3s ease;
|
||||
width: 100%;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
input[type="button"]:disabled {
|
||||
background-color: #ccc;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
input[type="button"]:hover {
|
||||
background-color: #357ABD;
|
||||
}
|
||||
|
||||
#divModelDl, #divStepProgress {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
#modelDlProgressBar,
|
||||
#progressBar {
|
||||
width: 80%;
|
||||
height: 12px;
|
||||
border-radius: 6px;
|
||||
background-color: #e0e0e0;
|
||||
}
|
||||
|
||||
#modelDlProgressBar::-webkit-progress-bar,
|
||||
#progressBar::-webkit-progress-bar {
|
||||
border-radius: 6px;
|
||||
}
|
||||
|
||||
#modelDlProgressValue, #progressFraction {
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
}
|
||||
|
||||
canvas {
|
||||
max-width: 100%;
|
||||
max-height: 450px;
|
||||
margin-top: 10px;
|
||||
border-radius: 8px;
|
||||
border: 1px solid #ddd;
|
||||
}
|
||||
</style>
|
||||
|
||||
<script type="module">
|
||||
import ClipTokenizer from 'https://softwired.nyc3.cdn.digitaloceanspaces.com/sd/clip_tokenizer.js';
|
||||
import ClipTokenizer from './clip_tokenizer.js';
|
||||
window.clipTokenizer = new ClipTokenizer();
|
||||
</script>
|
||||
<script src="./f16_to_f32.js"></script>
|
||||
<script type="module">
|
||||
import { f16tof32GPU } from 'https://unpkg.com/f16-to-f32-gpu@0.1.0/src/index.js';
|
||||
window.f16tof32GPU = f16tof32GPU;
|
||||
</script>
|
||||
<script src="./net.js"></script>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<h1 id="wgpuError" style="display: none; color: red;">WebGPU is not supported in this browser</h1>
|
||||
<h1 id="sdTitle">StableDiffusion by <a href="https://github.com/tinygrad/tinygrad" target="_blank">tinygrad</a> WebGPU</h1>
|
||||
<div id="mybox">
|
||||
<input id="promptText" type="text" placeholder="Enter your prompt here" value="a horse sized cat eating a bagel">
|
||||
<h1 id="wgpuError" style="display: none;">WebGPU is not supported in this browser</h1>
|
||||
<h1 id="sdTitle">StableDiffusion powered by <a href="https://github.com/tinygrad/tinygrad" target="_blank" style="color: #4A90E2;">tinygrad</a></h1>
|
||||
<a href="https://github.com/tinygrad/tinygrad" target="_blank" style="position: absolute; top: 20px; right: 20px;">
|
||||
<img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg"
|
||||
alt="GitHub Logo"
|
||||
style="width: 32px; height: 32px;">
|
||||
</a>
|
||||
|
||||
<label>
|
||||
Steps: <span id="stepValue">8</span>
|
||||
<input id="stepRange" type="range" min="5" max="20" value="8" step="1">
|
||||
</label>
|
||||
<div id="mybox">
|
||||
<input id="promptText" type="text" placeholder="Enter your prompt here" value="a human standing on the surface of mars">
|
||||
|
||||
<label>
|
||||
Guidance: <span id="guidanceValue">7.5</span>
|
||||
<input id="guidanceRange" type="range" min="3" max="15" value="7.5" step="0.1">
|
||||
</label>
|
||||
|
||||
<input id="btnRunNet" type="button" value="Run" disabled>
|
||||
<label>
|
||||
Steps: <span id="stepValue">9</span>
|
||||
<input id="stepRange" type="range" min="5" max="20" value="9" step="1">
|
||||
</label>
|
||||
|
||||
<div id="divModelDl" style="display: flex; align-items: center; width: 100%; gap: 10px;">
|
||||
<span id="modelDlTitle">Downloading model</span>
|
||||
<progress id="modelDlProgressBar" value="0" max="100" style="flex-grow: 1;"></progress>
|
||||
<span id="modelDlProgressValue"></span>
|
||||
<label>
|
||||
Guidance: <span id="guidanceValue">8.0</span>
|
||||
<input id="guidanceRange" type="range" min="3" max="15" value="8.0" step="0.1">
|
||||
</label>
|
||||
|
||||
<input id="btnRunNet" type="button" value="Run" disabled>
|
||||
|
||||
<div id="divModelDl" style="display: flex; align-items: center; width: 100%; gap: 10px;">
|
||||
<span id="modelDlTitle">Downloading model</span>
|
||||
<progress id="modelDlProgressBar" value="0" max="100" style="flex-grow: 1;"></progress>
|
||||
<span id="modelDlProgressValue"></span>
|
||||
</div>
|
||||
|
||||
<div id="divStepProgress" style="display: none; align-items: center; width: 100%; gap: 10px;">
|
||||
<progress id="progressBar" value="0" max="100" style="flex-grow: 1;"></progress>
|
||||
<span id="progressFraction"></span>
|
||||
</div>
|
||||
|
||||
<div id="divStepTime" style="display: none; align-items: center; width: 100%; gap: 10px;">
|
||||
<span id="stepTimeValue">0 ms</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="divStepProgress" style="display: none; align-items: center; width: 100%; gap: 10px;">
|
||||
<progress id="progressBar" value="0" max="100" style="flex-grow: 1;"></progress>
|
||||
<span id="progressFraction"></span>
|
||||
</div>
|
||||
</div>
|
||||
<canvas id="canvas" width="512" height="512"></canvas>
|
||||
<canvas id="canvas" width="512" height="512"></canvas>
|
||||
|
||||
<script>
|
||||
function initDb() {
|
||||
@@ -129,30 +239,36 @@
|
||||
}
|
||||
|
||||
/**
 * Store `tensor` under `id` in the `tensors` object store, unless an entry with that id
 * already exists. Saving is best-effort: failures are logged and never thrown.
 *
 * @param {IDBDatabase|null} db - open database, or null/undefined when no cache is available.
 * @param {string} id - key of the record.
 * @param {*} tensor - value stored in the record's `content` field.
 * @returns {Promise<null|undefined>} settles once the write has finished; resolves to
 *   `undefined` after a successful write and to `null` when nothing was written
 *   (no database, entry already present, or the write failed).
 */
function saveTensorToDb(db, id, tensor) {
  if (db == null) {
    return Promise.resolve(null);
  }

  return readTensorFromDb(db, id).then((result) => {
    if (result) {
      return null;
    }

    // Return the write promise so callers that await saveTensorToDb wait for the write itself.
    return new Promise((resolve) => {
      const transaction = db.transaction(['tensors'], 'readwrite');
      const store = transaction.objectStore('tensors');
      const request = store.put({ id: id, content: tensor });

      transaction.onabort = (event) => {
        console.log("Transaction error while saving tensor: " + event.target.error);
        resolve(null);
      };

      request.onsuccess = () => {
        console.log('Tensor saved successfully.');
        resolve();
      };

      request.onerror = (event) => {
        console.error('Tensor save failed:', event.target.error);
        resolve(null);
      };
    });
  }).catch((err) => {
    console.error('Tensor save failed:', err);
    return null;
  });
}
|
||||
|
||||
function readTensorFromDb(db, id) {
|
||||
@@ -173,10 +289,8 @@
|
||||
request.onsuccess = (event) => {
|
||||
const result = event.target.result;
|
||||
if (result) {
|
||||
console.log("Cache hit: " + id);
|
||||
resolve(result);
|
||||
} else {
|
||||
console.log("Cache miss: " + id);
|
||||
resolve(null);
|
||||
}
|
||||
};
|
||||
@@ -190,7 +304,7 @@
|
||||
|
||||
window.addEventListener('load', async function() {
|
||||
if (!navigator.gpu) {
|
||||
document.getElementById("wgpuError").style.display = "";
|
||||
document.getElementById("wgpuError").style.display = "block";
|
||||
document.getElementById("sdTitle").style.display = "none";
|
||||
return;
|
||||
}
|
||||
@@ -247,6 +361,18 @@
|
||||
let totalSize = 0;
|
||||
let partSize = {};
|
||||
|
||||
// Fetch one model part by key: use the copy stored in the tensor database when
// there is one, otherwise download `<MODEL_BASE_URL>/<key>.safetensors`.
const getPart = async(key) => {
  const cached = await readTensorFromDb(db, key);

  if (!cached) {
    console.log(`Cache miss: ${key}`);
    return getProgressDlForPart(`${window.MODEL_BASE_URL}/${key}.safetensors`, progressCallback);
  }

  console.log(`Cache hit: ${key}`);
  return cached.content;
};
|
||||
|
||||
const progressCallback = (part, loaded, total) => {
|
||||
totalLoaded += loaded;
|
||||
|
||||
@@ -258,63 +384,31 @@
|
||||
progress(totalLoaded, totalSize);
|
||||
};
|
||||
|
||||
let combinedBuffer = await readTensorFromDb(db, "net.f16");
|
||||
let textModelU8 = await readTensorFromDb(db, "net.text");
|
||||
let textModelFetched = false;
|
||||
let netKeys = ["net_part0", "net_part1", "net_part2", "net_part3", "net_textmodel"];
|
||||
let buffers = await Promise.all(netKeys.map(key => getPart(key)));
|
||||
|
||||
if (combinedBuffer == null) {
|
||||
let dlParts = [
|
||||
getProgressDlForPart(window.MODEL_BASE_URL + '/net_part0.safetensors', progressCallback),
|
||||
getProgressDlForPart(window.MODEL_BASE_URL + '/net_part1.safetensors', progressCallback),
|
||||
getProgressDlForPart(window.MODEL_BASE_URL + '/net_part2.safetensors', progressCallback),
|
||||
getProgressDlForPart(window.MODEL_BASE_URL + '/net_part3.safetensors', progressCallback)
|
||||
];
|
||||
|
||||
if (textModelU8 == null) {
|
||||
dlParts.push(getProgressDlForPart(window.MODEL_BASE_URL + '/net_textmodel.safetensors', progressCallback));
|
||||
// Combine everything except for text model, since that's already f32
|
||||
const totalLength = buffers.reduce((acc, buffer, index, array) => {
|
||||
if (index < 4) {
|
||||
return acc + buffer.byteLength;
|
||||
} else {
|
||||
return acc;
|
||||
}
|
||||
}, 0
|
||||
);
|
||||
|
||||
let buffers = await Promise.all(dlParts);
|
||||
|
||||
// Combine everything except for text model, since that's already f32
|
||||
const totalLength = buffers.reduce((acc, buffer, index, array) => {
|
||||
if (index < 4) {
|
||||
return acc + buffer.byteLength;
|
||||
} else {
|
||||
return acc;
|
||||
}
|
||||
}, 0
|
||||
);
|
||||
|
||||
combinedBuffer = new Uint8Array(totalLength);
|
||||
|
||||
let offset = 0;
|
||||
buffers.forEach((buffer, index) => {
|
||||
if (index < 4) {
|
||||
combinedBuffer.set(new Uint8Array(buffer), offset);
|
||||
offset += buffer.byteLength;
|
||||
buffer = null;
|
||||
}
|
||||
});
|
||||
|
||||
await saveTensorToDb(db, "net.f16", combinedBuffer);
|
||||
|
||||
if (textModelU8 == null) {
|
||||
textModelFetched = true;
|
||||
textModelU8 = new Uint8Array(buffers[4]);
|
||||
await saveTensorToDb(db, "net.text", textModelU8);
|
||||
combinedBuffer = new Uint8Array(totalLength);
|
||||
let offset = 0;
|
||||
buffers.forEach((buffer, index) => {
|
||||
saveTensorToDb(db, netKeys[index], new Uint8Array(buffer));
|
||||
if (index < 4) {
|
||||
combinedBuffer.set(new Uint8Array(buffer), offset);
|
||||
offset += buffer.byteLength;
|
||||
buffer = null;
|
||||
}
|
||||
} else {
|
||||
combinedBuffer = combinedBuffer.content;
|
||||
}
|
||||
|
||||
if (textModelU8 == null) {
|
||||
textModelU8 = new Uint8Array(await getProgressDlForPart(window.MODEL_BASE_URL + '/net_textmodel.safetensors', progressCallback));
|
||||
await saveTensorToDb(db, "net.text", textModelU8);
|
||||
} else if (!textModelFetched) {
|
||||
textModelU8 = textModelU8.content;
|
||||
}
|
||||
});
|
||||
|
||||
let textModelU8 = new Uint8Array(buffers[4]);
|
||||
document.getElementById("modelDlTitle").innerHTML = "Decompressing model";
|
||||
|
||||
const textModelOffset = 3772703308;
|
||||
@@ -346,16 +440,7 @@
|
||||
let chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
|
||||
let chunkEndF16 = chunkStartF16 + decodeChunkSize;
|
||||
let chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);
|
||||
|
||||
if (chunk.byteLength %4 != 0) {
|
||||
const paddingBytes = 4 - (chunk.byteLength % 4);
|
||||
const alignedBuffer = new ArrayBuffer(chunk.byteLength + paddingBytes);
|
||||
const alignedView = new Uint8Array(alignedBuffer);
|
||||
alignedView.set(new Uint8Array(chunk));
|
||||
chunk = alignedView;
|
||||
}
|
||||
|
||||
let result = await f16tof32GPU(device, chunk);
|
||||
let result = await f16tof32GPU(chunk);
|
||||
let resultUint8 = new Uint8Array(result.buffer);
|
||||
let chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
|
||||
let chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
|
||||
@@ -421,7 +506,7 @@
|
||||
document.getElementById("btnRunNet").disabled = false;
|
||||
}
|
||||
|
||||
function runStableDiffusion(prompt, steps, guidance) {
|
||||
function runStableDiffusion(prompt, steps, guidance, showStep) {
|
||||
return new Promise(async (resolve, reject) => {
|
||||
let context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP(prompt)));
|
||||
let unconditional_context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP("")));
|
||||
@@ -456,8 +541,16 @@
|
||||
|
||||
for (let i = timesteps.length - 1; i >= 0; i--) {
|
||||
let timestep = new Float32Array([timesteps[i]]);
|
||||
let x_prev = await timer(() => nets["diffusor"](unconditional_context, context, latent, timestep, new Float32Array([alphas[i]]), new Float32Array([alphas_prev[i]]), new Float32Array([guidance])));
|
||||
let start = performance.now()
|
||||
let x_prev = await nets["diffusor"](unconditional_context, context, latent, timestep, new Float32Array([alphas[i]]), new Float32Array([alphas_prev[i]]), new Float32Array([guidance]));
|
||||
document.getElementById("divStepTime").style.display = "block";
|
||||
document.getElementById("stepTimeValue").innerText = `${(performance.now() - start).toFixed(1)} ms / step`
|
||||
latent = x_prev;
|
||||
|
||||
if (showStep != null) {
|
||||
showStep(await nets["decoder"](latent));
|
||||
}
|
||||
|
||||
document.getElementById("progressBar").value = ((steps - i) / steps) * 100
|
||||
document.getElementById("progressFraction").innerHTML = (steps - i) + "/" + steps
|
||||
}
|
||||
@@ -466,27 +559,37 @@
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Paint a decoded 512x512 RGB image onto the canvas and re-enable the triggering control.
 *
 * @param {Event} e - event whose `target` is re-enabled after drawing.
 * @param {ArrayLike<number>} image - interleaved R,G,B values, 3 per pixel, row-major.
 */
function renderImage(e, image) {
  const width = 512;
  const height = 512;

  // Write straight into the RGBA buffer instead of pushing ~1M numbers into a plain
  // array and copying it afterwards. Element assignment applies the same clamping
  // and rounding as constructing a Uint8ClampedArray from an array.
  const rgba = new Uint8ClampedArray(width * height * 4);
  for (let src = 0, dst = 0; dst < rgba.length; src += 3, dst += 4) {
    rgba[dst] = image[src];
    rgba[dst + 1] = image[src + 1];
    rgba[dst + 2] = image[src + 2];
    rgba[dst + 3] = 255; // fully opaque
  }

  ctx.putImageData(new ImageData(rgba, width, height), 0, 0);
  e.target.disabled = false;
}
|
||||
|
||||
// Run the pipeline when the "Run" button is clicked: disable the button, clear the
// canvas, start the run, then draw the returned image (renderImage re-enables the button).
document.getElementById("btnRunNet").addEventListener("click", function(e) {
  e.target.disabled = true;
  const canvas = document.getElementById("canvas");
  ctx.clearRect(0, 0, canvas.width, canvas.height);

  runStableDiffusion(
    document.getElementById("promptText").value,
    document.getElementById("stepRange").value,
    document.getElementById("guidanceRange").value,
    // showStep callback; null means the latent is not decoded after each step
    null
  ).then((image) => {
    renderImage(e, image);
  }).catch((err) => {
    // Without this handler a failed run would leave the button disabled for good.
    console.error("Stable Diffusion run failed:", err);
    e.target.disabled = false;
  });
}, false);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user