Bring back WebGPU (#7063)

* Start from andredaprato:webgpu-clean

* Fix infs

* inf wgsl function is not needed

* Emulated ulong for threefry, more tests passing
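
WGSL has no 64-bit integer type, so the ulong math threefry needs is emulated with pairs of u32 halves. A minimal Python sketch of the idea (helper names are illustrative, not the actual renderer code):

```python
MASK32 = 0xFFFFFFFF

def u64_add(a, b):
  """Add two emulated u64 values, each a (hi, lo) pair of u32 words."""
  lo = (a[1] + b[1]) & MASK32
  carry = 1 if lo < a[1] else 0                 # carry out of the low word
  return ((a[0] + b[0] + carry) & MASK32, lo)

def u64_mul(a, b):
  """Multiply two emulated u64 values from 32-bit partial products."""
  lo = (a[1] * b[1]) & MASK32
  # Python ints are arbitrary precision, so >> 32 just works here;
  # WGSL has to build the high half of a_lo*b_lo from 16-bit pieces.
  hi = (((a[1] * b[1]) >> 32) + a[1] * b[0] + a[0] * b[1]) & MASK32
  return (hi, lo)

assert u64_add((0, MASK32), (0, 1)) == (1, 0)          # carry propagates into hi
assert u64_mul((0, 1 << 16), (0, 1 << 16)) == (1, 0)   # 2**32 lands in the high word
```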

* Randomness tests passing

* Update model export to support new changes in webgpu, efficientnet export works again

* Simplify shift emulation in wgsl

* Delete test file

* Fix u32 literals larger than u32 max

* Why was skip copies added here?

* Python3.12 for webgpu tests

* Fix model export syntax error

* Get test ops passing with some skips

* Fix lint

* Much simpler shift
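
The emulated ulong also needs shifts expressed on its two 32-bit halves, with bits moving between words once the shift amount reaches 32. A hedged Python sketch of a left shift (illustrative only, not the actual render_ushift output):

```python
MASK32 = 0xFFFFFFFF

def u64_shl(x, n):
  """Left-shift an emulated u64 (hi, lo) pair by n, with 0 <= n < 64."""
  hi, lo = x
  if n == 0: return (hi, lo)
  if n < 32: return (((hi << n) | (lo >> (32 - n))) & MASK32, (lo << n) & MASK32)
  return ((lo << (n - 32)) & MASK32, 0)   # the low word moves entirely into the high word

assert u64_shl((0, 1), 40) == (1 << 8, 0)
assert u64_shl((0, 0x80000000), 1) == (1, 0)
```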

* Run more tests

* Timestamp queries are not supported in CI, so skip search tests

* All fancy indexing passing

* r is ctx

* Run more dtype tests by using is_dtype_supported

* Cleanup ulong shift rendering

* UPat -> Pat, UOps -> Ops

* Pat -> UPat

* Refactor render_ushift if-else

* Pattern to avoid ulong mul

* Remove vals_dtype

* is_nan trick + rewrite, test_isnan passing
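
WGSL implementations are allowed to assume floats are never NaN, so a plain x != x check can be optimized away; the robust trick is to test the IEEE-754 bit pattern instead. Whether the rewrite here does exactly this is an assumption, but a short Python sketch of the bit test:

```python
import struct

def is_nan_bits(x: float) -> bool:
  """NaN iff the exponent bits are all ones and the mantissa is non-zero."""
  bits = struct.unpack("<I", struct.pack("<f", x))[0]
  return (bits & 0x7FFFFFFF) > 0x7F800000   # drop the sign bit; inf itself is not above the threshold

assert is_nan_bits(float("nan"))
assert not is_nan_bits(float("inf")) and not is_nan_bits(1.0)
```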

* Rewrite a * select(1, nan, gate) -> select(a, nan, gate)
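
This avoids emitting a multiply by NaN just to poison the gated lane; selecting between a and NaN directly says the same thing. A tiny sketch showing the two forms agree (select follows WGSL argument order: false-value, true-value, condition):

```python
import math

def select(f, t, cond): return t if cond else f   # WGSL-style select

def before(a, gate): return a * select(1.0, math.nan, gate)
def after(a, gate): return select(a, math.nan, gate)

for a in (2.5, -3.0, 0.0):
  for gate in (False, True):
    x, y = before(a, gate), after(a, gate)
    assert x == y or (math.isnan(x) and math.isnan(y))
```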

* No arg, just op

* Support char, uchar, short, ushort

* Run test_index_mnist now that we have uint8

* Fix pylint

* Save 3 lines by using base Compiler

* No more long emulation

* Remove fixup_binops

* No more external_local_bufx WGSL-specific cstyle modification, use base extra_pm

* Simpler, faster copyin/out

* Skip some new tests that use long

* Fix typo

* copyout touchup

* Save lines by using render_cast

* WebGL is not supported in core, delete it from is_dtype_supported

* More narrow test skips for some unary tests

* TernaryOps, UnaryOps -> Ops

* tinygrad supports WebGPU

* StableDiffusion demo: use the f16-to-f32-gpu library for f16-to-f32 conversion, update UI

* Packed load/store, no more scale_size, no core tinygrad changes

* Rename copyin, copyout

* Device -> dev

* Fix lint

* Pattern matcher rule for packed load/store
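
WGSL storage buffers are not byte-addressable, so 8- and 16-bit dtypes get packed into u32 words: a load shifts and masks out the right lane, a store read-modify-writes it. A Python sketch of the lane arithmetic such a rule has to generate (illustrative, not the actual pattern):

```python
MASK32 = 0xFFFFFFFF

def packed_load(words, idx, itemsize):
  """Read element idx of an unsigned itemsize-byte array packed into u32 words."""
  per_word, bits = 4 // itemsize, 8 * itemsize
  shift = (idx % per_word) * bits
  return (words[idx // per_word] >> shift) & ((1 << bits) - 1)

def packed_store(words, idx, val, itemsize):
  """Write element idx by masking out its lane and or-ing the new value in."""
  per_word, bits = 4 // itemsize, 8 * itemsize
  shift, mask, w = (idx % per_word) * bits, (1 << bits) - 1, idx // per_word
  words[w] = (words[w] & ~(mask << shift) & MASK32) | ((val & mask) << shift)

words = [0, 0]
packed_store(words, 5, 0xAB, itemsize=1)       # the sixth uint8 lives in byte 1 of words[1]
assert packed_load(words, 5, itemsize=1) == 0xAB and words[1] == 0xAB00
```

A real shader also has to be careful when two invocations touch different lanes of the same word; this sketch ignores that.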

* Refactor

* Shorter packed load/store

* This should fix lint

* Fix mypy

* SD compile script working

* New SD webgpu UI

* New default prompt

* New SD weights

* Fix title when webgpu not available

* Run symbolic tests, simplify is_nan, use round_up
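
round_up here is presumably the usual round-to-multiple helper; a minimal version for reference (stated as an assumption, not copied from tinygrad.helpers):

```python
def round_up(num: int, amt: int) -> int:
  """Smallest multiple of amt that is >= num."""
  return (num + amt - 1) // amt * amt

assert round_up(20, 16) == 32 and round_up(32, 16) == 32
```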

* Show step time on UI

* Bump minimum wgpu version to v0.19

* Fix latent

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: Ahmed Harmouche
Date: 2024-11-26 05:26:40 +01:00
Committed by: GitHub
Parent: ff3f2a9c1a
Commit: 10618aba98
18 changed files with 659 additions and 402 deletions


@@ -1,7 +1,7 @@
from pathlib import Path
from extra.models.efficientnet import EfficientNet
from tinygrad.tensor import Tensor
from tinygrad.nn.state import safe_save
from tinygrad.nn.state import get_state_dict, safe_save, safe_load, load_state_dict
from extra.export_model import export_model
from tinygrad.helpers import getenv, fetch
import ast
@@ -9,11 +9,15 @@ import ast
if __name__ == "__main__":
model = EfficientNet(0)
model.load_from_pretrained()
dirname = Path(__file__).parent
# exporting a model that's loaded from safetensors doesn't work without loading in from safetensors first
# loading the state dict from a safetensor file changes the generated kernels
if getenv("WEBGPU") or getenv("WEBGL"):
safe_save(get_state_dict(model), (dirname / "net.safetensors").as_posix())
load_state_dict(model, safe_load(str(dirname / "net.safetensors")))
mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else "webgl" if getenv("WEBGL", "") != "" else ""
prg, inp_sizes, out_sizes, state = export_model(model, mode, Tensor.randn(1,3,224,224))
dirname = Path(__file__).parent
if getenv("CLANG", "") == "":
safe_save(state, (dirname / "net.safetensors").as_posix())
ext = "js" if getenv("WEBGPU", "") != "" or getenv("WEBGL", "") != "" else "json"
with open(dirname / f"net.{ext}", "w") as text_file:
text_file.write(prg)


@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor
from tinygrad import Device
from tinygrad.helpers import fetch
from typing import NamedTuple, Any, List
from pathlib import Path
import requests
import argparse
import numpy as np
@@ -60,16 +60,22 @@ def split_safetensor(fn):
cur_pos = 0
for i, end_pos in enumerate(part_end_offsets):
with open(f'./net_part{i}.safetensors', "wb+") as f:
with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.safetensors'), "wb+") as f:
f.write(net_bytes[cur_pos:end_pos])
cur_pos = end_pos
with open(f'./net_textmodel.safetensors', "wb+") as f:
with open(os.path.join(os.path.dirname(__file__), f'./net_textmodel.safetensors'), "wb+") as f:
f.write(net_bytes[text_model_start+8+json_len:])
return part_end_offsets
def fetch_dep(file, url):
with open(file, "w", encoding="utf-8") as f:
f.write(requests.get(url).text.replace("https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/bpe_simple_vocab_16e6.mjs", "./bpe_simple_vocab_16e6.mjs"))
if __name__ == "__main__":
fetch_dep(os.path.join(os.path.dirname(__file__), "clip_tokenizer.js"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/clip_tokenizer.js")
fetch_dep(os.path.join(os.path.dirname(__file__), "bpe_simple_vocab_16e6.mjs"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/bpe_simple_vocab_16e6.mjs")
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--remoteweights', action='store_true', help="Use safetensors from Huggingface, or from local")
args = parser.parse_args()
@@ -94,12 +100,21 @@ if __name__ == "__main__":
prg = ""
def fixup_code(code, key):
code = code.replace(key, 'main')\
.replace("var<uniform> INFINITY : f32;\n", "fn inf(a: f32) -> f32 { return a/0.0; }\n")\
.replace("@group(0) @binding(0)", "")\
.replace("INFINITY", "inf(1.0)")
for i in range(1,9): code = code.replace(f"binding({i})", f"binding({i-1})")
return code
def compile_step(model, step: Step):
run, special_names = jit_model(step, *step.input)
functions, statements, bufs, _ = compile_net(run, special_names)
state = get_state_dict(model)
weights = {id(x.lazydata.base.realized): name for name, x in state.items()}
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _, _, _) in statements])
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
@@ -148,14 +163,14 @@ if __name__ == "__main__":
if step.name == "diffusor":
if args.remoteweights:
base_url = "https://huggingface.co/wpmed/tinygrad-sd-f16/resolve/main"
base_url = "https://huggingface.co/wpmed/stable-diffusion-f16-new/resolve/main"
else:
state = get_state_dict(model)
safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
convert_f32_to_f16("./net.safetensors", "./net_conv.safetensors")
split_safetensor("./net_conv.safetensors")
os.remove("net.safetensors")
os.remove("net_conv.safetensors")
convert_f32_to_f16(os.path.join(os.path.dirname(__file__), "./net.safetensors"), os.path.join(os.path.dirname(__file__), "./net_conv.safetensors"))
split_safetensor(os.path.join(os.path.dirname(__file__), "./net_conv.safetensors"))
os.remove(os.path.join(os.path.dirname(__file__), "net.safetensors"))
os.remove(os.path.join(os.path.dirname(__file__), "net_conv.safetensors"))
base_url = "."
prekernel = f"""
@@ -185,20 +200,6 @@ if __name__ == "__main__":
counter++;
}}
let allZero = true;
let out = safetensorParts[selectedPart].subarray(...correctedOffsets);
for (let i = 0; i < out.length; i++) {{
if (out[i] !== 0) {{
allZero = false;
break;
}}
}}
if (allZero) {{
console.log("Error: weight '" + key + "' is all zero.");
}}
return safetensorParts[selectedPart].subarray(...correctedOffsets);
}}


@@ -1,64 +0,0 @@
const f16tof32 = `
fn u16_to_f16(x: u32) -> f32 {
let sign = f32((x >> 15) & 0x1);
let exponent = f32((x >> 10) & 0x1F);
let fraction = f32(x & 0x3FF);
let sign_multiplier = select(1.0, -1.0, sign == 1.0);
if (exponent == 0.0) {
return sign_multiplier * 6.103515625e-5 * (fraction / 1024.0);
} else {
return sign_multiplier * exp2(exponent - 15.0) * (1.0 + fraction / 1024.0);
}
}
@group(0) @binding(0) var<storage,read_write> data0: array<u32>;
@group(0) @binding(1) var<storage,read_write> data1: array<f32>;
@compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let gidx = gid.x;
let outgidx = gidx*2;
if (gidx >= arrayLength(&data0)) {
return;
}
let oo = data0[gidx];
let oo1 = (oo >> 16);
let oo2 = oo & 0xFFFFu;
let f1 = u16_to_f16(oo2);
let f2 = u16_to_f16(oo1);
data1[outgidx] = f1;
data1[outgidx + 1] = f2;
}`;
window.f16tof32GPU = async(device, inf16) => {
const input = device.createBuffer({size: inf16.length, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST });
const output = device.createBuffer({size: inf16.length*2, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST });
const gpuWriteBuffer = device.createBuffer({size: input.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE });
const gpuReadBuffer = device.createBuffer({ size: output.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ });
const commandEncoder = device.createCommandEncoder();
await gpuWriteBuffer.mapAsync(GPUMapMode.WRITE);
const alignedUint32View = new Uint32Array(inf16.buffer, inf16.byteOffset, inf16.length / 4);
new Uint32Array(gpuWriteBuffer.getMappedRange()).set(alignedUint32View);
gpuWriteBuffer.unmap();
commandEncoder.copyBufferToBuffer(gpuWriteBuffer, 0, input, 0, gpuWriteBuffer.size);
const pipeline = await device.createComputePipelineAsync({layout: "auto", compute: { module: device.createShaderModule({ code: f16tof32 }), entryPoint: "main" }});
addComputePass(device, commandEncoder, pipeline, [input, output], [Math.ceil(inf16.length/(4*256)), 1, 1]);
commandEncoder.copyBufferToBuffer(output, 0, gpuReadBuffer, 0, output.size);
const gpuCommands = commandEncoder.finish();
device.queue.submit([gpuCommands]);
await gpuReadBuffer.mapAsync(GPUMapMode.READ);
const resultBuffer = new Float32Array(gpuReadBuffer.size/4);
resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
gpuReadBuffer.unmap();
return resultBuffer;
}


@@ -5,103 +5,213 @@
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>tinygrad has WebGPU</title>
<style>
body {
font-family: 'Arial', sans-serif;
text-align: center;
padding: 30px;
/* General Reset */
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
a {
text-decoration: none;
color: #4A90E2;
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: #f4f7fb;
color: #333;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
padding: 20px;
flex-direction: column;
text-align: center;
}
h1 {
font-size: 36px;
font-weight: normal;
font-size: 2.2rem;
margin-bottom: 20px;
color: #4A90E2;
}
#wgpuError {
color: red;
font-size: 1.2rem;
margin-top: 20px;
display: none;
}
#sdTitle {
font-size: 1.5rem;
margin-bottom: 30px;
}
#mybox {
display: flex;
flex-direction: column;
align-items: center;
gap: 20px;
width: 50%;
margin: 0 auto;
background: #ffffff;
padding: 15px;
border-radius: 10px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
width: 120%;
max-width: 550px;
margin-bottom: 10px;
}
#promptText, #stepRange, #btnRunNet, #guidanceRange {
font-size: 18px;
input[type="text"] {
width: 100%;
padding: 12px;
margin-bottom: 20px;
font-size: 1rem;
border: 1px solid #ccc;
border-radius: 8px;
outline: none;
transition: all 0.3s ease;
}
#result {
font-size: 48px;
}
#time {
font-size: 16px;
color: grey;
}
canvas {
margin-top: 20px;
border: 1px solid #000;
input[type="text"]:focus {
border-color: #4A90E2;
}
label {
display: flex;
justify-content: space-between;
font-size: 1rem;
margin-bottom: 15px;
align-items: center;
gap: 10px;
width: 100%;
}
#sliderValue {
margin-right: 10px;/
input[type="range"] {
width: 100%;
margin-left: 10px;
-webkit-appearance: none;
appearance: none;
height: 8px;
border-radius: 4px;
background: #ddd;
outline: none;
transition: background 0.3s ease;
}
input[type="range"]:focus {
background: #4A90E2;
}
#stepRange,
#guidanceRange {
width: 80%;
}
span {
font-size: 1.1rem;
font-weight: 600;
color: #333;
}
input[type="button"] {
padding: 12px 25px;
background-color: #4A90E2;
color: #fff;
font-size: 1.2rem;
border: none;
border-radius: 8px;
cursor: pointer;
transition: background-color 0.3s ease;
width: 100%;
margin-bottom: 20px;
}
input[type="button"]:disabled {
background-color: #ccc;
cursor: not-allowed;
}
input[type="button"]:hover {
background-color: #357ABD;
}
#divModelDl, #divStepProgress {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 20px;
}
#modelDlProgressBar,
#progressBar {
width: 80%;
height: 12px;
border-radius: 6px;
background-color: #e0e0e0;
}
#modelDlProgressBar::-webkit-progress-bar,
#progressBar::-webkit-progress-bar {
border-radius: 6px;
}
#modelDlProgressValue, #progressFraction {
font-size: 1rem;
font-weight: 600;
color: #333;
}
canvas {
max-width: 100%;
max-height: 450px;
margin-top: 10px;
border-radius: 8px;
border: 1px solid #ddd;
}
</style>
<script type="module">
import ClipTokenizer from 'https://softwired.nyc3.cdn.digitaloceanspaces.com/sd/clip_tokenizer.js';
import ClipTokenizer from './clip_tokenizer.js';
window.clipTokenizer = new ClipTokenizer();
</script>
<script src="./f16_to_f32.js"></script>
<script type="module">
import { f16tof32GPU } from 'https://unpkg.com/f16-to-f32-gpu@0.1.0/src/index.js';
window.f16tof32GPU = f16tof32GPU;
</script>
<script src="./net.js"></script>
</head>
<body>
<h1 id="wgpuError" style="display: none; color: red;">WebGPU is not supported in this browser</h1>
<h1 id="sdTitle">StableDiffusion by <a href="https://github.com/tinygrad/tinygrad" target="_blank">tinygrad</a> WebGPU</h1>
<div id="mybox">
<input id="promptText" type="text" placeholder="Enter your prompt here" value="a horse sized cat eating a bagel">
<h1 id="wgpuError" style="display: none;">WebGPU is not supported in this browser</h1>
<h1 id="sdTitle">StableDiffusion powered by <a href="https://github.com/tinygrad/tinygrad" target="_blank" style="color: #4A90E2;">tinygrad</a></h1>
<a href="https://github.com/tinygrad/tinygrad" target="_blank" style="position: absolute; top: 20px; right: 20px;">
<img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg"
alt="GitHub Logo"
style="width: 32px; height: 32px;">
</a>
<label>
Steps: <span id="stepValue">8</span>
<input id="stepRange" type="range" min="5" max="20" value="8" step="1">
</label>
<div id="mybox">
<input id="promptText" type="text" placeholder="Enter your prompt here" value="a human standing on the surface of mars">
<label>
Guidance: <span id="guidanceValue">7.5</span>
<input id="guidanceRange" type="range" min="3" max="15" value="7.5" step="0.1">
</label>
<input id="btnRunNet" type="button" value="Run" disabled>
<label>
Steps: <span id="stepValue">9</span>
<input id="stepRange" type="range" min="5" max="20" value="9" step="1">
</label>
<div id="divModelDl" style="display: flex; align-items: center; width: 100%; gap: 10px;">
<span id="modelDlTitle">Downloading model</span>
<progress id="modelDlProgressBar" value="0" max="100" style="flex-grow: 1;"></progress>
<span id="modelDlProgressValue"></span>
<label>
Guidance: <span id="guidanceValue">8.0</span>
<input id="guidanceRange" type="range" min="3" max="15" value="8.0" step="0.1">
</label>
<input id="btnRunNet" type="button" value="Run" disabled>
<div id="divModelDl" style="display: flex; align-items: center; width: 100%; gap: 10px;">
<span id="modelDlTitle">Downloading model</span>
<progress id="modelDlProgressBar" value="0" max="100" style="flex-grow: 1;"></progress>
<span id="modelDlProgressValue"></span>
</div>
<div id="divStepProgress" style="display: none; align-items: center; width: 100%; gap: 10px;">
<progress id="progressBar" value="0" max="100" style="flex-grow: 1;"></progress>
<span id="progressFraction"></span>
</div>
<div id="divStepTime" style="display: none; align-items: center; width: 100%; gap: 10px;">
<span id="stepTimeValue">0 ms</span>
</div>
</div>
<div id="divStepProgress" style="display: none; align-items: center; width: 100%; gap: 10px;">
<progress id="progressBar" value="0" max="100" style="flex-grow: 1;"></progress>
<span id="progressFraction"></span>
</div>
</div>
<canvas id="canvas" width="512" height="512"></canvas>
<canvas id="canvas" width="512" height="512"></canvas>
<script>
function initDb() {
@@ -129,30 +239,36 @@
}
function saveTensorToDb(db, id, tensor) {
return new Promise((resolve, reject) => {
if (db == null) {
resolve(null);
return readTensorFromDb(db, id).then((result) => {
if (!result) {
new Promise((resolve, reject) => {
if (db == null) {
resolve(null);
}
const transaction = db.transaction(['tensors'], 'readwrite');
const store = transaction.objectStore('tensors');
const request = store.put({ id: id, content: tensor });
transaction.onabort = (event) => {
console.log("Transaction error while saving tensor: " + event.target.error);
resolve(null);
};
request.onsuccess = () => {
console.log('Tensor saved successfully.');
resolve();
};
request.onerror = (event) => {
console.error('Tensor save failed:', event.target.error);
resolve(null);
};
});
} else {
return null;
}
const transaction = db.transaction(['tensors'], 'readwrite');
const store = transaction.objectStore('tensors');
const request = store.put({ id: id, content: tensor });
transaction.onabort = (event) => {
console.log("Transaction error while saving tensor: " + event.target.error);
resolve(null);
};
request.onsuccess = () => {
console.log('Tensor saved successfully.');
resolve();
};
request.onerror = (event) => {
console.error('Tensor save failed:', event.target.error);
resolve(null);
};
});
}).catch(()=> null);
}
function readTensorFromDb(db, id) {
@@ -173,10 +289,8 @@
request.onsuccess = (event) => {
const result = event.target.result;
if (result) {
console.log("Cache hit: " + id);
resolve(result);
} else {
console.log("Cache miss: " + id);
resolve(null);
}
};
@@ -190,7 +304,7 @@
window.addEventListener('load', async function() {
if (!navigator.gpu) {
document.getElementById("wgpuError").style.display = "";
document.getElementById("wgpuError").style.display = "block";
document.getElementById("sdTitle").style.display = "none";
return;
}
@@ -247,6 +361,18 @@
let totalSize = 0;
let partSize = {};
const getPart = async(key) => {
let part = await readTensorFromDb(db, key);
if (part) {
console.log(`Cache hit: ${key}`);
return Promise.resolve(part.content);
} else {
console.log(`Cache miss: ${key}`);
return getProgressDlForPart(`${window.MODEL_BASE_URL}/${key}.safetensors`, progressCallback);
}
}
const progressCallback = (part, loaded, total) => {
totalLoaded += loaded;
@@ -258,63 +384,31 @@
progress(totalLoaded, totalSize);
};
let combinedBuffer = await readTensorFromDb(db, "net.f16");
let textModelU8 = await readTensorFromDb(db, "net.text");
let textModelFetched = false;
let netKeys = ["net_part0", "net_part1", "net_part2", "net_part3", "net_textmodel"];
let buffers = await Promise.all(netKeys.map(key => getPart(key)));
if (combinedBuffer == null) {
let dlParts = [
getProgressDlForPart(window.MODEL_BASE_URL + '/net_part0.safetensors', progressCallback),
getProgressDlForPart(window.MODEL_BASE_URL + '/net_part1.safetensors', progressCallback),
getProgressDlForPart(window.MODEL_BASE_URL + '/net_part2.safetensors', progressCallback),
getProgressDlForPart(window.MODEL_BASE_URL + '/net_part3.safetensors', progressCallback)
];
if (textModelU8 == null) {
dlParts.push(getProgressDlForPart(window.MODEL_BASE_URL + '/net_textmodel.safetensors', progressCallback));
// Combine everything except for text model, since that's already f32
const totalLength = buffers.reduce((acc, buffer, index, array) => {
if (index < 4) {
return acc + buffer.byteLength;
} else {
return acc;
}
}, 0
);
let buffers = await Promise.all(dlParts);
// Combine everything except for text model, since that's alreafy f32
const totalLength = buffers.reduce((acc, buffer, index, array) => {
if (index < 4) {
return acc + buffer.byteLength;
} else {
return acc;
}
}, 0
);
combinedBuffer = new Uint8Array(totalLength);
let offset = 0;
buffers.forEach((buffer, index) => {
if (index < 4) {
combinedBuffer.set(new Uint8Array(buffer), offset);
offset += buffer.byteLength;
buffer = null;
}
});
await saveTensorToDb(db, "net.f16", combinedBuffer);
if (textModelU8 == null) {
textModelFetched = true;
textModelU8 = new Uint8Array(buffers[4]);
await saveTensorToDb(db, "net.text", textModelU8);
combinedBuffer = new Uint8Array(totalLength);
let offset = 0;
buffers.forEach((buffer, index) => {
saveTensorToDb(db, netKeys[index], new Uint8Array(buffer));
if (index < 4) {
combinedBuffer.set(new Uint8Array(buffer), offset);
offset += buffer.byteLength;
buffer = null;
}
} else {
combinedBuffer = combinedBuffer.content;
}
if (textModelU8 == null) {
textModelU8 = new Uint8Array(await getProgressDlForPart(window.MODEL_BASE_URL + '/net_textmodel.safetensors', progressCallback));
await saveTensorToDb(db, "net.text", textModelU8);
} else if (!textModelFetched) {
textModelU8 = textModelU8.content;
}
});
let textModelU8 = new Uint8Array(buffers[4]);
document.getElementById("modelDlTitle").innerHTML = "Decompressing model";
const textModelOffset = 3772703308;
@@ -346,16 +440,7 @@
let chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
let chunkEndF16 = chunkStartF16 + decodeChunkSize;
let chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);
if (chunk.byteLength %4 != 0) {
const paddingBytes = 4 - (chunk.byteLength % 4);
const alignedBuffer = new ArrayBuffer(chunk.byteLength + paddingBytes);
const alignedView = new Uint8Array(alignedBuffer);
alignedView.set(new Uint8Array(chunk));
chunk = alignedView;
}
let result = await f16tof32GPU(device, chunk);
let result = await f16tof32GPU(chunk);
let resultUint8 = new Uint8Array(result.buffer);
let chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
let chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
@@ -421,7 +506,7 @@
document.getElementById("btnRunNet").disabled = false;
}
function runStableDiffusion(prompt, steps, guidance) {
function runStableDiffusion(prompt, steps, guidance, showStep) {
return new Promise(async (resolve, reject) => {
let context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP(prompt)));
let unconditional_context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP("")));
@@ -456,8 +541,16 @@
for (let i = timesteps.length - 1; i >= 0; i--) {
let timestep = new Float32Array([timesteps[i]]);
let x_prev = await timer(() => nets["diffusor"](unconditional_context, context, latent, timestep, new Float32Array([alphas[i]]), new Float32Array([alphas_prev[i]]), new Float32Array([guidance])));
let start = performance.now()
let x_prev = await nets["diffusor"](unconditional_context, context, latent, timestep, new Float32Array([alphas[i]]), new Float32Array([alphas_prev[i]]), new Float32Array([guidance]));
document.getElementById("divStepTime").style.display = "block";
document.getElementById("stepTimeValue").innerText = `${(performance.now() - start).toFixed(1)} ms / step`
latent = x_prev;
if (showStep != null) {
showStep(await nets["decoder"](latent));
}
document.getElementById("progressBar").value = ((steps - i) / steps) * 100
document.getElementById("progressFraction").innerHTML = (steps - i) + "/" + steps
}
@@ -466,27 +559,37 @@
});
}
function renderImage(e, image) {
let pixels = []
let pixelCounter = 0
for (let j = 0; j < 512; j++) {
for (let k = 0; k < 512; k++) {
pixels.push(image[pixelCounter])
pixels.push(image[pixelCounter+1])
pixels.push(image[pixelCounter+2])
pixels.push(255)
pixelCounter += 3
}
}
ctx.putImageData(new ImageData(new Uint8ClampedArray(pixels), 512, 512), 0, 0);
e.target.disabled = false;
}
document.getElementById("btnRunNet").addEventListener("click", function(e) {
e.target.disabled = true;
const canvas = document.getElementById("canvas");
ctx.clearRect(0, 0, canvas.width, canvas.height);
runStableDiffusion(document.getElementById("promptText").value, document.getElementById("stepRange").value, document.getElementById("guidanceRange").value).then((image) => {
let pixels = []
let pixelCounter = 0
for (let j = 0; j < 512; j++) {
for (let k = 0; k < 512; k++) {
pixels.push(image[pixelCounter])
pixels.push(image[pixelCounter+1])
pixels.push(image[pixelCounter+2])
pixels.push(255)
pixelCounter += 3
}
}
ctx.putImageData(new ImageData(new Uint8ClampedArray(pixels), 512, 512), 0, 0);
console.log(image);
console.log("Success");
e.target.disabled = false;
runStableDiffusion(
document.getElementById("promptText").value,
document.getElementById("stepRange").value,
document.getElementById("guidanceRange").value,
// Decode at each step
null
).then((image) => {
renderImage(e, image);
});
}, false);