Stable diffusion WebGPU port (#1370)

* WIP: Stable diffusion WebGPU port

* Load whole model: split safetensor to avoid Chrome allocation limit

* Gitignore .DS_Store, remove debug print

* Clip tokenizer in JS

* WIP: Compile model in parts (text model, diffusor, get_x_prev_and_pred_x0, decoder), and recreate forward logic in JS

* e2e stable diffusion flow

* Create initial random latent tensor in JS

* SD working e2e

* Log if some weights were not loaded properly

* Remove latent_tensor.npy used for debugging

* Cleanup, remove useless logs

* Improve UI

* Add progress bar

* Remove .npy files used for debugging

* Add clip tokenizer as external dependency

* Remove alphas_cumprod.js and load it from safetensors

* Refactor

* Simplify a lot

* Dedup base when limiting elementwise merge (webgpu)

* Add return type to safe_load_metadata

* Do not allow run when webgpu is not supported

* Add progress bar, refactor, fix special names

* Add option to choose from local vs huggingface weights

* lowercase tinygrad :)

* fp16 model dl, decompression client side

* Cache f16 model in browser, better progress

* Cache miss recovery

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Ahmed Harmouche authored on 2023-11-04 02:29:16 +01:00, committed by GitHub
parent f582ec56d5
commit 265304e7fd
8 changed files with 869 additions and 50 deletions

.gitignore

@@ -1,6 +1,7 @@
__pycache__
.venv/
.vscode
.DS_Store
notebooks
.*.swp
.*.swo
@@ -31,7 +32,8 @@ extra/datasets/kits/
extra/datasets/COCO/
extra/datasets/audio*
venv
examples/net.*[js,json,safetensors]
examples/**/net.*[js,json]
examples/**/*.safetensors
node_modules
package.json
package-lock.json

examples/stable_diffusion.py

@@ -538,7 +538,44 @@ class StableDiffusion:
self.first_stage_model = AutoencoderKL()
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))
# TODO: make __call__ run the model
def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
temperature = 1
sigma_t = 0
sqrt_one_minus_at = (1-a_t).sqrt()
#print(a_t, a_prev, sigma_t, sqrt_one_minus_at)
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
return x_prev, pred_x0
def get_model_output(self, unconditional_context, context, latent, timestep, unconditional_guidance_scale):
# put into diffuser
latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, unconditional_context.cat(context, dim=0))
unconditional_latent, latent = latents[0:1], latents[1:2]
e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
return e_t
def decode(self, x):
x = self.first_stage_model.post_quant_conv(1/0.18215 * x)
x = self.first_stage_model.decoder(x)
# make image correct size and scale
x = (x + 1.0) / 2.0
x = x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255
return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x
def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev)
#e_t_next = get_model_output(x_prev)
#e_t_prime = (e_t + e_t_next) / 2
#x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
return x_prev.realize()
# ** ldm.models.autoencoder.AutoencoderKL (done!)
# 3x512x512 <--> 4x64x64 (16384)
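For reference, get_x_prev_and_pred_x0 above implements the deterministic DDIM update (temperature 1, \sigma_t = 0), with e_t the classifier-free-guided noise estimate from get_model_output:

  \hat{x}_0 = (x - \sqrt{1 - a_t}\, e_t) / \sqrt{a_t}, \qquad x_{prev} = \sqrt{a_{prev}}\, \hat{x}_0 + \sqrt{1 - a_{prev} - \sigma_t^2}\, e_t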
@@ -595,65 +632,31 @@ if __name__ == "__main__":
# done with clip model
del model.cond_stage_model
def get_model_output(latent, timestep, unconditional_guidance_scale):
# put into diffuser
latents = model.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, unconditional_context.cat(context, dim=0))
unconditional_latent, latent = latents[0:1], latents[1:2]
e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
return e_t
timesteps = list(range(1, 1000, 1000//args.steps))
print(f"running for {timesteps} timesteps")
alphas = model.alphas_cumprod[Tensor(timesteps)]
alphas_prev = Tensor([1.0]).cat(alphas[:-1])
def get_x_prev_and_pred_x0(x, e_t, index):
temperature = 1
a_t, a_prev = alphas[index], alphas_prev[index]
sigma_t = 0
sqrt_one_minus_at = (1-a_t).sqrt()
#print(a_t, a_prev, sigma_t, sqrt_one_minus_at)
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
return x_prev, pred_x0
@TinyJit
def do_step(latent, timestep, index, guidance):
e_t = get_model_output(latent, timestep, guidance)
x_prev, _ = get_x_prev_and_pred_x0(latent, e_t, index)
#e_t_next = get_model_output(x_prev)
#e_t_prime = (e_t + e_t_next) / 2
#x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
return x_prev.realize()
# start with random noise
if args.seed is not None: Tensor._seed = args.seed
latent = Tensor.randn(1,4,64,64)
@TinyJit
def run(model, *x): return model(*x).realize()
# this is diffusion
with Context(BEAM=getenv("LATEBEAM")):
# this is diffusion
for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
GlobalCounters.reset()
t.set_description("%3d %3d" % (index, timestep))
with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
latent = do_step(latent, Tensor([timestep]), Tensor([index]), Tensor([args.guidance]))
tid = Tensor([index])
latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance]))
if args.timing: Device[Device.DEFAULT].synchronize()
del do_step
del run
# upsample latent space to image with autoencoder
x = model.first_stage_model.post_quant_conv(1/0.18215 * latent)
x = model.first_stage_model.decoder(x)
# make image correct size and scale
x = (x + 1.0) / 2.0
x = (x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255)
if Device.DEFAULT != "WEBGPU": x = x.cast(dtypes.uint8)
x = model.decode(latent)
print(x.shape)
# save image


@@ -0,0 +1,234 @@
import os
from extra.export_model import compile_net, jit_model
from examples.stable_diffusion import StableDiffusion
from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
from tinygrad.tensor import Tensor
from tinygrad.ops import Device
from extra.utils import download_file
from typing import NamedTuple, Any, List
from pathlib import Path
import argparse
import numpy as np
def convert_f32_to_f16(input_file, output_file):
with open(input_file, 'rb') as f:
metadata_length_bytes = f.read(8)
metadata_length = int.from_bytes(metadata_length_bytes, byteorder='little', signed=False)
metadata_json_bytes = f.read(metadata_length)
float32_values = np.fromfile(f, dtype=np.float32)
first_text_model_offset = 3772703308
num_elements = int((first_text_model_offset)/4)
front_float16_values = float32_values[:num_elements].astype(np.float16)
rest_float32_values = float32_values[num_elements:]
with open(output_file, 'wb') as f:
f.write(metadata_length_bytes)
f.write(metadata_json_bytes)
front_float16_values.tofile(f)
rest_float32_values.tofile(f)
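Rough size arithmetic for the conversion above, as a sketch (assuming first_text_model_offset is a byte offset into the f32 payload, consistent with split_safetensor below):

first_text_model_offset = 3772703308          # bytes of f32 weights that get halved to f16
f32_gb = first_text_model_offset / 1e9        # ~3.77 GB before conversion
f16_gb = first_text_model_offset / 2 / 1e9    # ~1.89 GB after conversion
print(f"diffusor+decoder: {f32_gb:.2f} GB f32 -> {f16_gb:.2f} GB f16; CLIP text model stays f32")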
FILENAME = Path(__file__).parent.parent.parent.parent / "weights/sd-v1-4.ckpt"
def split_safetensor(fn):
_, json_len, metadata = safe_load_metadata(fn)
text_model_offset = 3772703308
chunk_size = 536870912
for k in metadata:
# safetensor is in fp16, except for text model
if (metadata[k]["data_offsets"][0] < text_model_offset):
metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0]/2)
metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1]/2)
last_offset = 0
part_end_offsets = []
for k in metadata:
offset = metadata[k]['data_offsets'][0]
if offset == text_model_offset:
break
part_offset = offset - last_offset
if (part_offset >= chunk_size):
part_end_offsets.append(8+json_len+offset)
last_offset = offset
text_model_start = int(text_model_offset/2)
net_bytes = bytes(open(fn, 'rb').read())
part_end_offsets.append(text_model_start+8+json_len)
cur_pos = 0
for i, end_pos in enumerate(part_end_offsets):
with open(f'./net_part{i}.safetensors', "wb+") as f:
f.write(net_bytes[cur_pos:end_pos])
cur_pos = end_pos
with open(f'./net_textmodel.safetensors', "wb+") as f:
f.write(net_bytes[text_model_start+8+json_len:])
return part_end_offsets
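Each emitted part ends on a tensor boundary at roughly chunk_size (512 MiB) granularity, and the f32 CLIP text model goes into its own file, so no single browser fetch or allocation has to hold the whole multi-gigabyte model. A hypothetical post-split check (file names as written above; four parts here, matching the partStartOffsets hard-coded in the generated JS below):

import os
for i in range(4):
  p = f"./net_part{i}.safetensors"
  print(p, f"{os.path.getsize(p)/2**20:.1f} MiB")
print("./net_textmodel.safetensors", f"{os.path.getsize('./net_textmodel.safetensors')/2**20:.1f} MiB")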
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--remoteweights', action='store_true', help="Use safetensors hosted on Hugging Face instead of generating them locally")
args = parser.parse_args()
Device.DEFAULT = "WEBGPU"
Tensor.no_grad = True
model = StableDiffusion()
# load in weights
download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
class Step(NamedTuple):
name: str = ""
input: List[Tensor] = []
forward: Any = None
sub_steps = [
Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode)
]
prg = ""
def compile_step(model, step: Step):
run, special_names = jit_model(step, *step.input)
functions, statements, bufs, _ = compile_net(run, special_names)
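# compile_net gives: functions = WGSL kernel source keyed by name, statements = ordered launches as (name, buffer args, global size, local size), bufs = buffer name -> (size, dtype, key); these drive the JS codegen below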
state = get_state_dict(model)
weights = {id(x.lazydata.realized): name for name, x in state.items()}
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _, _, _) in statements])
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,(_,value) in enumerate(special_names.items()) if value != "output0"])
return f"""\n var {step.name} = function() {{
{kernel_code}
return {{
"setup": async (device, safetensor) => {{
const metadata = getTensorMetadata(safetensor[0]);
{bufs}
{gpu_write_bufs}
const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
const kernels = [{kernel_names}];
const pipelines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));
return async ({",".join([f'data{i}' for i,(k,v) in enumerate(special_names.items()) if v != "output0"])}) => {{
const commandEncoder = device.createCommandEncoder();
{input_writer}
{kernel_calls}
commandEncoder.copyBufferToBuffer(output0, 0, gpuReadBuffer, 0, output0.size);
const gpuCommands = commandEncoder.finish();
device.queue.submit([gpuCommands]);
await gpuReadBuffer.mapAsync(GPUMapMode.READ);
const resultBuffer = new Float32Array(gpuReadBuffer.size/4);
resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
gpuReadBuffer.unmap();
return resultBuffer;
}}
}}
}}
}}
"""
for step in sub_steps:
print(f'Executing step={step.name}')
prg += compile_step(model, step)
if step.name == "diffusor":
if args.remoteweights:
base_url = "https://huggingface.co/wpmed/tinygrad-sd-f16/resolve/main"
else:
state = get_state_dict(model)
safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
convert_f32_to_f16("./net.safetensors", "./net_conv.safetensors")
split_safetensor("./net_conv.safetensors")
os.remove("net.safetensors")
os.remove("net_conv.safetensors")
base_url = "."
prekernel = f"""
window.MODEL_BASE_URL= "{base_url}";
const getTensorMetadata = (safetensorBuffer) => {{
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
}};
const getTensorBuffer = (safetensorParts, tensorMetadata, key) => {{
let selectedPart = 0;
let counter = 0;
let partStartOffsets = [1131408336, 2227518416, 3308987856, 4265298864];
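// End offsets (in bytes) of each decompressed f32 part; a tensor's global data_offset is mapped to a part index and a part-local offset below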
let correctedOffsets = tensorMetadata.data_offsets;
let prev_offset = 0;
for (let start of partStartOffsets) {{
prev_offset = (counter == 0) ? 0 : partStartOffsets[counter-1];
if (tensorMetadata.data_offsets[0] < start) {{
selectedPart = counter;
correctedOffsets = [correctedOffsets[0]-prev_offset, correctedOffsets[1]-prev_offset];
break;
}}
counter++;
}}
let allZero = true;
let out = safetensorParts[selectedPart].subarray(...correctedOffsets);
for (let i = 0; i < out.length; i++) {{
if (out[i] !== 0) {{
allZero = false;
break;
}}
}}
if (allZero) {{
console.log("Error: weight '" + key + "' is all zero.");
}}
return safetensorParts[selectedPart].subarray(...correctedOffsets);
}}
const getWeight = (safetensors, key) => {{
let uint8Data = getTensorBuffer(safetensors, getTensorMetadata(safetensors[0])[key], key);
return new Float32Array(uint8Data.buffer, uint8Data.byteOffset, uint8Data.byteLength / Float32Array.BYTES_PER_ELEMENT);
}}
const createEmptyBuf = (device, size) => {{
return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
}};
const createWeightBuf = (device, size, data) => {{
const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
new Uint8Array(buf.getMappedRange()).set(data);
buf.unmap();
return buf;
}};
const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(pipeline);
passEncoder.setBindGroup(0, bindGroup);
passEncoder.dispatchWorkgroups(...workgroup);
passEncoder.end();
}};"""
with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as text_file:
text_file.write(prekernel + prg)

f16_to_f32.js

@@ -0,0 +1,64 @@
const f16tof32 = `
fn u16_to_f16(x: u32) -> f32 {
let sign = f32((x >> 15) & 0x1);
let exponent = f32((x >> 10) & 0x1F);
let fraction = f32(x & 0x3FF);
let sign_multiplier = select(1.0, -1.0, sign == 1.0);
if (exponent == 0.0) {
return sign_multiplier * 6.103515625e-5 * (fraction / 1024.0);
} else {
return sign_multiplier * exp2(exponent - 15.0) * (1.0 + fraction / 1024.0);
}
}
@group(0) @binding(0) var<storage,read_write> data0: array<u32>;
@group(0) @binding(1) var<storage,read_write> data1: array<f32>;
@compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let gidx = gid.x;
let outgidx = gidx*2;
if (gidx >= arrayLength(&data0)) {
return;
}
let oo = data0[gidx];
let oo1 = (oo >> 16);
let oo2 = oo & 0xFFFFu;
let f1 = u16_to_f16(oo2);
let f2 = u16_to_f16(oo1);
data1[outgidx] = f1;
data1[outgidx + 1] = f2;
}`;
window.f16tof32GPU = async(device, inf16) => {
const input = device.createBuffer({size: inf16.length, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST });
const output = device.createBuffer({size: inf16.length*2, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST });
const gpuWriteBuffer = device.createBuffer({size: input.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE });
const gpuReadBuffer = device.createBuffer({ size: output.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ });
const commandEncoder = device.createCommandEncoder();
await gpuWriteBuffer.mapAsync(GPUMapMode.WRITE);
const alignedUint32View = new Uint32Array(inf16.buffer, inf16.byteOffset, inf16.length / 4);
new Uint32Array(gpuWriteBuffer.getMappedRange()).set(alignedUint32View);
gpuWriteBuffer.unmap();
commandEncoder.copyBufferToBuffer(gpuWriteBuffer, 0, input, 0, gpuWriteBuffer.size);
const pipeline = await device.createComputePipelineAsync({layout: "auto", compute: { module: device.createShaderModule({ code: f16tof32 }), entryPoint: "main" }});
addComputePass(device, commandEncoder, pipeline, [input, output], [Math.ceil(inf16.length/(4*256)), 1, 1]);
commandEncoder.copyBufferToBuffer(output, 0, gpuReadBuffer, 0, output.size);
const gpuCommands = commandEncoder.finish();
device.queue.submit([gpuCommands]);
await gpuReadBuffer.mapAsync(GPUMapMode.READ);
const resultBuffer = new Float32Array(gpuReadBuffer.size/4);
resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
gpuReadBuffer.unmap();
return resultBuffer;
}
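The shader above unpacks two half-precision values from each u32 and converts them with the standard IEEE-754 binary16 decoding, which u16_to_f16 (despite the name, it returns an f32) spells out as

  value = (-1)^{sign} \cdot 2^{e-15} \cdot (1 + frac/1024) for e > 0, and value = (-1)^{sign} \cdot 2^{-14} \cdot (frac/1024) for e = 0 (6.103515625e-5 = 2^{-14}).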


@@ -0,0 +1,511 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>tinygrad has WebGPU</title>
<style>
body {
font-family: 'Arial', sans-serif;
text-align: center;
padding: 30px;
}
a {
text-decoration: none;
color: #4A90E2;
}
h1 {
font-size: 36px;
font-weight: normal;
margin-bottom: 20px;
}
#mybox {
display: flex;
flex-direction: column;
align-items: center;
gap: 20px;
width: 50%;
margin: 0 auto;
}
#promptText, #stepRange, #btnRunNet, #guidanceRange {
font-size: 18px;
width: 100%;
}
#result {
font-size: 48px;
}
#time {
font-size: 16px;
color: grey;
}
canvas {
margin-top: 20px;
border: 1px solid #000;
}
label {
display: flex;
align-items: center;
gap: 10px;
width: 100%;
}
#sliderValue {
margin-right: 10px;
}
</style>
<script type="module">
import ClipTokenizer from 'https://softwired.nyc3.cdn.digitaloceanspaces.com/sd/clip_tokenizer.js';
window.clipTokenizer = new ClipTokenizer();
</script>
<script src="./f16_to_f32.js"></script>
<script src="./net.js"></script>
</head>
<body>
<h1 id="wgpuError" style="display: none; color: red;">WebGPU is not supported in this browser</h1>
<h1 id="sdTitle">StableDiffusion by <a href="https://github.com/tinygrad/tinygrad" target="_blank">tinygrad</a> WebGPU</h1>
<div id="mybox">
<input id="promptText" type="text" placeholder="Enter your prompt here" value="a horse sized cat eating a bagel">
<label>
Steps: <span id="stepValue">8</span>
<input id="stepRange" type="range" min="5" max="20" value="8" step="1">
</label>
<label>
Guidance: <span id="guidanceValue">7.5</span>
<input id="guidanceRange" type="range" min="3" max="15" value="7.5" step="0.1">
</label>
<input id="btnRunNet" type="button" value="Run" disabled>
<div id="divModelDl" style="display: flex; align-items: center; width: 100%; gap: 10px;">
<span id="modelDlTitle">Downloading model</span>
<progress id="modelDlProgressBar" value="0" max="100" style="flex-grow: 1;"></progress>
<span id="modelDlProgressValue"></span>
</div>
<div id="divStepProgress" style="display: none; align-items: center; width: 100%; gap: 10px;">
<progress id="progressBar" value="0" max="100" style="flex-grow: 1;"></progress>
<span id="progressFraction"></span>
</div>
</div>
<canvas id="canvas" width="512" height="512"></canvas>
<script>
function initDb() {
return new Promise((resolve, reject) => {
let db;
const request = indexedDB.open('tinydb', 1);
request.onerror = (event) => {
console.error('Database error:', event.target.error);
resolve(null);
};
request.onsuccess = (event) => {
db = event.target.result;
console.log("Db initialized.");
resolve(db);
};
request.onupgradeneeded = (event) => {
db = event.target.result;
if (!db.objectStoreNames.contains('tensors')) {
db.createObjectStore('tensors', { keyPath: 'id' });
}
};
});
}
function saveTensorToDb(db, id, tensor) {
return new Promise((resolve, reject) => {
if (db == null) {
resolve(null);
return;
}
const transaction = db.transaction(['tensors'], 'readwrite');
const store = transaction.objectStore('tensors');
const request = store.put({ id: id, content: tensor });
transaction.onabort = (event) => {
console.log("Transaction error while saving tensor: " + event.target.error);
resolve(null);
};
request.onsuccess = () => {
console.log('Tensor saved successfully.');
resolve();
};
request.onerror = (event) => {
console.error('Tensor save failed:', event.target.error);
resolve(null);
};
});
}
function readTensorFromDb(db, id) {
return new Promise((resolve, reject) => {
if (db == null) {
resolve(null);
return;
}
const transaction = db.transaction(['tensors'], 'readonly');
const store = transaction.objectStore('tensors');
const request = store.get(id);
transaction.onabort = (event) => {
console.log("Transaction error while reading tensor: " + event.target.error);
resolve(null);
};
request.onsuccess = (event) => {
const result = event.target.result;
if (result) {
console.log("Cache hit: " + id);
resolve(result);
} else {
console.log("Cache miss: " + id);
resolve(null);
}
};
request.onerror = (event) => {
console.error('Tensor retrieve failed: ', event.target.error);
resolve(null);
};
});
}
window.addEventListener('load', async function() {
if (!navigator.gpu) {
document.getElementById("wgpuError").style.display = "";
document.getElementById("sdTitle").style.display = "none";
return;
}
let db = await initDb();
const ctx = document.getElementById("canvas").getContext("2d", { willReadFrequently: true });
let labels, nets, safetensorParts;
const getDevice = async () => {
const adapter = await navigator.gpu.requestAdapter();
const requiredLimits = {};
const maxBufferSizeInSDModel = 1073741824;
requiredLimits.maxStorageBufferBindingSize = maxBufferSizeInSDModel;
requiredLimits.maxBufferSize = maxBufferSizeInSDModel;
return await adapter.requestDevice({
requiredLimits
});
};
const timer = async (func, label = "") => {
const start = performance.now();
const out = await func();
const delta = (performance.now() - start).toFixed(1)
console.log(`${delta} ms ${label}`);
return out;
}
const getProgressDlForPart = async (part, progressCallback) => {
const response = await fetch(part);
const contentLength = response.headers.get('content-length');
const total = parseInt(contentLength, 10);
const res = new Response(new ReadableStream({
async start(controller) {
const reader = response.body.getReader();
for (;;) {
const { done, value } = await reader.read();
if (done) break;
progressCallback(part, value.byteLength, total);
controller.enqueue(value);
}
controller.close();
},
}));
return res.arrayBuffer();
};
const getAndDecompressF16Safetensors = async (device, progress) => {
let totalLoaded = 0;
let totalSize = 0;
let partSize = {};
const progressCallback = (part, loaded, total) => {
totalLoaded += loaded;
if (!partSize[part]) {
totalSize += total;
partSize[part] = true;
}
progress(totalLoaded, totalSize);
};
let combinedBuffer = await readTensorFromDb(db, "net.f16");
let textModelU8 = await readTensorFromDb(db, "net.text");
let textModelFetched = false;
if (combinedBuffer == null) {
let dlParts = [
getProgressDlForPart(window.MODEL_BASE_URL + '/net_part0.safetensors', progressCallback),
getProgressDlForPart(window.MODEL_BASE_URL + '/net_part1.safetensors', progressCallback),
getProgressDlForPart(window.MODEL_BASE_URL + '/net_part2.safetensors', progressCallback),
getProgressDlForPart(window.MODEL_BASE_URL + '/net_part3.safetensors', progressCallback)
];
if (textModelU8 == null) {
dlParts.push(getProgressDlForPart(window.MODEL_BASE_URL + '/net_textmodel.safetensors', progressCallback));
}
let buffers = await Promise.all(dlParts);
// Combine everything except for the text model, since that's already f32
const totalLength = buffers.reduce((acc, buffer, index, array) => {
if (index < 4) {
return acc + buffer.byteLength;
} else {
return acc;
}
}, 0
);
combinedBuffer = new Uint8Array(totalLength);
let offset = 0;
buffers.forEach((buffer, index) => {
if (index < 4) {
combinedBuffer.set(new Uint8Array(buffer), offset);
offset += buffer.byteLength;
buffer = null;
}
});
await saveTensorToDb(db, "net.f16", combinedBuffer);
if (textModelU8 == null) {
textModelFetched = true;
textModelU8 = new Uint8Array(buffers[4]);
await saveTensorToDb(db, "net.text", textModelU8);
}
} else {
combinedBuffer = combinedBuffer.content;
}
if (textModelU8 == null) {
textModelU8 = new Uint8Array(await getProgressDlForPart(window.MODEL_BASE_URL + '/net_textmodel.safetensors', progressCallback));
await saveTensorToDb(db, "net.text", textModelU8);
} else if (!textModelFetched) {
textModelU8 = textModelU8.content;
}
document.getElementById("modelDlTitle").innerHTML = "Decompressing model";
const textModelOffset = 3772703308;
const metadataLength = Number(new DataView(combinedBuffer.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));
const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength);
const decodeChunkSize = 67107840;
const numChunks = Math.ceil(allToDecomp/decodeChunkSize);
console.log(allToDecomp + " bytes to decompress");
console.log("Will be decompressed in " + numChunks+ " chunks");
let partOffsets = [{start: 0, end: 1131408336}, {start: 1131408336, end: 2227518416}, {start: 2227518416, end: 3308987856}, {start: 3308987856, end: 4265298864}];
let parts = [];
for (let offsets of partOffsets) {
parts.push(new Uint8Array(offsets.end-offsets.start));
}
parts[0].set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
parts[0].set(combinedBuffer.subarray(8, 8 + metadataLength), 8);
parts[3].set(textModelU8, textModelOffset+8+metadataLength - partOffsets[3].start);
let start = Date.now();
let cursor = 0;
for (let i = 0; i < numChunks; i++) {
progress(i, numChunks);
let chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
let chunkEndF16 = chunkStartF16 + decodeChunkSize;
let chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);
if (chunk.byteLength %4 != 0) {
const paddingBytes = 4 - (chunk.byteLength % 4);
const alignedBuffer = new ArrayBuffer(chunk.byteLength + paddingBytes);
const alignedView = new Uint8Array(alignedBuffer);
alignedView.set(new Uint8Array(chunk));
chunk = alignedView;
}
let result = await f16tof32GPU(device, chunk);
let resultUint8 = new Uint8Array(result.buffer);
let chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
let chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
let offsetInPart = chunkStartF32 - partOffsets[cursor].start;
if (chunkEndF32 < partOffsets[cursor].end || cursor === parts.length - 1) {
parts[cursor].set(resultUint8, offsetInPart);
} else {
let spaceLeftInCurrentPart = partOffsets[cursor].end - chunkStartF32;
parts[cursor].set(resultUint8.subarray(0, spaceLeftInCurrentPart), offsetInPart);
cursor++;
if (cursor < parts.length) {
let nextPartOffset = spaceLeftInCurrentPart;
let nextPartLength = resultUint8.length - nextPartOffset;
parts[cursor].set(resultUint8.subarray(nextPartOffset, nextPartOffset + nextPartLength), 0);
}
}
resultUint8 = null;
result = null;
}
combinedBuffer = null;
let end = Date.now();
console.log("Decoding took: " + ((end - start) / 1000) + " s");
console.log("Avarage " + ((end - start) / numChunks) + " ms per chunk");
return parts;
};
const loadNet = async () => {
const modelDlTitle = document.getElementById("modelDlTitle");
const progress = (loaded, total) => {
document.getElementById("modelDlProgressBar").value = (loaded/total) * 100
document.getElementById("modelDlProgressValue").innerHTML = Math.trunc((loaded/total) * 100) + "%"
}
const device = await getDevice();
safetensorParts = await getAndDecompressF16Safetensors(device, progress);
modelDlTitle.innerHTML = "Compiling model"
let models = ["textModel", "diffusor", "decoder"];
nets = await timer(() => Promise.all([
textModel().setup(device, safetensorParts),
diffusor().setup(device, safetensorParts),
decoder().setup(device, safetensorParts)
]).then((loadedModels) => loadedModels.reduce((acc, model, index) => { acc[models[index]] = model; return acc; }, {})), "(compilation)")
progress(1, 1);
modelDlTitle.innerHTML = "Model ready"
setTimeout(() => {
document.getElementById("modelDlProgressBar").style.display = "none";
document.getElementById("modelDlProgressValue").style.display = "none";
document.getElementById("divStepProgress").style.display = "flex";
}, 1000);
document.getElementById("btnRunNet").disabled = false;
}
function runStableDiffusion(prompt, steps, guidance) {
return new Promise(async (resolve, reject) => {
let context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP(prompt)));
let unconditional_context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP("")));
let timesteps = [];
for (let i = 1; i < 1000; i += (1000/steps)) {
timesteps.push(i);
}
console.log("Timesteps: " + timesteps);
let alphasCumprod = getWeight(safetensorParts,"alphas_cumprod");
let alphas = [];
for (let t of timesteps) {
alphas.push(alphasCumprod[Math.floor(t)]);
}
let alphas_prev = [1.0];
for (let i = 0; i < alphas.length-1; i++) {
alphas_prev.push(alphas[i]);
}
let inpSize = 4*64*64;
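// Draw the initial 1x4x64x64 latent from N(0, 1) with the Box-Muller transform (stands in for Tensor.randn on the Python side).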
let latent = new Float32Array(inpSize);
for (let i = 0; i < inpSize; i++) {
latent[i] = Math.sqrt(-2.0 * Math.log(Math.random())) * Math.cos(2.0 * Math.PI * Math.random());
}
for (let i = timesteps.length - 1; i >= 0; i--) {
let timestep = new Float32Array([timesteps[i]]);
let x_prev = await timer(() => nets["diffusor"](unconditional_context, context, latent, timestep, new Float32Array([alphas[i]]), new Float32Array([alphas_prev[i]]), new Float32Array([guidance])));
latent = x_prev;
document.getElementById("progressBar").value = ((steps - i) / steps) * 100
document.getElementById("progressFraction").innerHTML = (steps - i) + "/" + steps
}
resolve(await timer(() => nets["decoder"](latent)));
});
}
document.getElementById("btnRunNet").addEventListener("click", function(e) {
e.target.disabled = true;
runStableDiffusion(document.getElementById("promptText").value, document.getElementById("stepRange").value, document.getElementById("guidanceRange").value).then((image) => {
let pixels = []
let pixelCounter = 0
for (let j = 0; j < 512; j++) {
for (let k = 0; k < 512; k++) {
pixels.push(image[pixelCounter])
pixels.push(image[pixelCounter+1])
pixels.push(image[pixelCounter+2])
pixels.push(255)
pixelCounter += 3
}
}
ctx.putImageData(new ImageData(new Uint8ClampedArray(pixels), 512, 512), 0, 0);
console.log(image);
console.log("Success");
e.target.disabled = false;
});
}, false);
const stepSlider = document.getElementById('stepRange');
const stepValue = document.getElementById('stepValue');
stepSlider.addEventListener('input', function() {
stepValue.textContent = stepSlider.value;
});
const guidanceSlider = document.getElementById('guidanceRange');
const guidanceValue = document.getElementById('guidanceValue');
guidanceSlider.addEventListener('input', function() {
guidanceValue.textContent = guidanceSlider.value;
});
loadNet();
});
</script>
</body>
</html>


@@ -229,7 +229,9 @@ class LazyBuffer:
if MERGE_ELEMENTWISE_OPS:
# remove the buffers from any (childless) BinaryOps that feed into this
srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore
_srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore
# TODO: needs general merge limiting
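# (likely motivation for the cap below: WebGPU's default maxStorageBuffersPerShaderStage is 8, so keep deduped sources under 7 plus the output buffer)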
if out_device != "WEBGPU" or len(dedup([x.base for _src in _srcs for x in _src.buffers if not x.is_unrealized_const()])) < 7: srcs = _srcs # type: ignore
return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype)

tinygrad/nn/state.py

@@ -1,6 +1,6 @@
import os, json, pathlib, zipfile, pickle
from tqdm import tqdm
from typing import Dict, Union, List, Optional, Any
from typing import Dict, Union, List, Optional, Any, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI
from tinygrad.shape.view import strides_for_shape
@@ -9,11 +9,14 @@ from tinygrad.ops import Device
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
json_len = t[0:1].cast(dtypes.int64).numpy()[0]
headers = json.loads(t[8:8+json_len].numpy().tobytes())
return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in headers.items() if k != "__metadata__"}
return (t, json_len, json.loads(t[8:8+json_len].numpy().tobytes()))
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
t, json_len, metadata = safe_load_metadata(fn)
return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"}
def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
headers, offset = {}, 0
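A minimal usage sketch of the new safe_load_metadata helper (hypothetical file name):

t, json_len, metadata = safe_load_metadata("net.safetensors")
print(json_len, list(metadata.keys())[:3])  # JSON header length in bytes and a few tensor names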


@@ -42,4 +42,4 @@ class RawWebGPUBuffer(RawBufferCopyIn):
def toCPU(self) -> np.ndarray: return np.frombuffer(wgpu_device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
renderer = functools.partial(uops_to_cstyle, WGSLLanguage())
WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, lambda x: x, WebGPUProgram)
WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(device="WEBGPU", supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, lambda x: x, WebGPUProgram)