Files
tinygrad/examples/webgpu/stable_diffusion/compile.py
Ahmed Harmouche 10618aba98 Bring back WebGPU (#7063)
* Start from andredaprato:webgpu-clean

* Fix infs

* inf wgsl function is not needed

* Emulated ulong for threefry, more tests passing

* Randomness tests passing

* Update model export to support new changes in webgpu, efficientnet export works again

* Simplify shift emulation in wgsl

* Delete test file

* Fix bigger than u32 u32 literal

* Why was skip copies added here?

* Python3.12 for webgpu tests

* Fix model export syntax error

* Get test ops passing with some skips

* Fix lint

* Much simpler shift

* Run more tests

* Timestamp queries are not supported in CI, so skip search tests

* All fancy indexing passing

* r is ctx

* Run more dtype tests by using is_dtype_supported

* Cleanup ulong shift rendering

* UPat -> Pat, UOps -> Ops

* Pat -> UPat

* Refactor render_ushift if-else

* Pattern to avoid ulong mul

* Remove vals_dtype

* is_nan trick + rewrite, test_isnan passing

* Rewrite a * select(1, nan, gate) -> select(a, nan, gate)

* No arg, just op

* Support char, uchar, short, ushort

* Run test_index_mnist now that we have uint8

* Fix pylint

* Save 3 lines by using base Compiler

* No more long emulation

* Remove fixup_binops

* No more WGSL-specific external_local_bufx cstyle modification; use base extra_pm

* Simpler, faster copyin/out

* Skip some new tests that use long

* Fix typo

* copyout touchup

* Save lines by using render_cast

* WebGL is not supported in core, delete it from is_dtype_supported

* More narrow test skips for some unary tests

* TernaryOps, UnaryOps -> Ops

* TinyGrad supports WebGPU

* StableDiffusion demo: f16tof32 gpu is a lib, update UI

* Packed load/store, no more scale_size, no core tinygrad changes

* Rename copyin, copyout

* Device -> dev

* Fix lint

* Pattern matcher rule for packed load/store

* Refactor

* Shorter packed load/store

* this should fix lint

* Fix mypy

* SD compile script working

* New SD webgpu UI

* New default prompt

* New SD weights

* Fix title when webgpu not available

* Run symbolic tests, simplify is_nan, use round_up

* Show step time on UI

* Bump minimum wgpu version to v0.19

* Fix latent

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2024-11-26 12:26:40 +08:00

233 lines
10 KiB
Python

import os
from extra.export_model import compile_net, jit_model
from examples.stable_diffusion import StableDiffusion
from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
from tinygrad.tensor import Tensor
from tinygrad import Device
from tinygrad.helpers import fetch
from typing import NamedTuple, Any, List
import requests
import argparse
import numpy as np
def convert_f32_to_f16(input_file, output_file, text_model_offset=3772703308):
  """Rewrite a float32 safetensors file so that its leading tensor data is stored as float16.

  The 8-byte little-endian header length and the JSON header are copied unchanged (so the header
  still describes float32 data offsets). The raw float32 payload is split at byte offset
  ``text_model_offset`` (relative to the start of the payload): everything before it is downcast
  to float16, everything from it on is written back as float32.

  Args:
    input_file: path of the float32 safetensors file to read.
    output_file: path the mixed float16/float32 file is written to.
    text_model_offset: payload byte offset from which data is kept as float32. The default is the
      offset this script uses for the Stable Diffusion text model.

  Raises:
    ValueError: if ``text_model_offset`` does not fall on a float32 element boundary.
  """
  if text_model_offset % 4:
    raise ValueError(f"text_model_offset must be a multiple of 4 bytes (float32), got {text_model_offset}")
  with open(input_file, 'rb') as f:
    metadata_length_bytes = f.read(8)
    metadata_length = int.from_bytes(metadata_length_bytes, byteorder='little', signed=False)
    metadata_json_bytes = f.read(metadata_length)
    # the rest of the file is the raw float32 payload
    float32_values = np.fromfile(f, dtype=np.float32)
  # integer division: float division would round silently for offsets beyond 2**53
  num_elements = text_model_offset // 4
  with open(output_file, 'wb') as f:
    f.write(metadata_length_bytes)
    f.write(metadata_json_bytes)
    float32_values[:num_elements].astype(np.float16).tofile(f)
    float32_values[num_elements:].tofile(f)
def split_safetensor(fn, text_model_offset=3772703308, chunk_size=536870912, out_dir=None):
  """Split a converted safetensors file into float16 parts plus a float32 text-model file.

  Expects the file produced by ``convert_f32_to_f16``: its JSON header still holds float32 data
  offsets, so every tensor stored before ``text_model_offset`` has its offsets halved here to
  locate it in the float16 payload. The header stored in the file is not modified.

  Writes ``net_part{i}.safetensors`` (part 0 also carries the 8-byte length + JSON header) and
  ``net_textmodel.safetensors`` (everything from the text model on) into ``out_dir``.

  Args:
    fn: path of the converted safetensors file.
    text_model_offset: float32 payload offset of the first tensor kept as float32.
    chunk_size: a new part starts at the first tensor at least this many payload bytes past the
      previous cut.
    out_dir: output directory; defaults to the directory containing this script.

  Returns:
    The absolute file offsets at which each ``net_part{i}`` file ends; the last one is the start
    of the text-model data.
  """
  if out_dir is None: out_dir = os.path.dirname(__file__)
  _, json_len, metadata = safe_load_metadata(fn)
  header_len = 8 + json_len  # 8-byte length prefix + JSON header
  # "__metadata__" is free-form header info without data_offsets; only tensor entries are relocated
  tensors = [v for k, v in metadata.items() if k != "__metadata__"]
  for v in tensors:
    # safetensor is in fp16, except for text model: halve the offsets of the converted tensors
    if v["data_offsets"][0] < text_model_offset:
      v["data_offsets"] = [v["data_offsets"][0] // 2, v["data_offsets"][1] // 2]
  last_offset, part_end_offsets = 0, []
  for v in tensors:
    offset = v["data_offsets"][0]
    if offset == text_model_offset: break
    if offset - last_offset >= chunk_size:
      part_end_offsets.append(header_len + offset)
      last_offset = offset
  # the float16 section is half its float32 size, so the text model starts at half its original offset
  text_model_start = header_len + text_model_offset // 2
  part_end_offsets.append(text_model_start)
  # `with` closes the handle; memoryview slices avoid copying the (multi-GB) parts before writing
  with open(fn, 'rb') as f: net_bytes = memoryview(f.read())
  cur_pos = 0
  for i, end_pos in enumerate(part_end_offsets):
    with open(os.path.join(out_dir, f'net_part{i}.safetensors'), "wb") as f: f.write(net_bytes[cur_pos:end_pos])
    cur_pos = end_pos
  with open(os.path.join(out_dir, 'net_textmodel.safetensors'), "wb") as f: f.write(net_bytes[text_model_start:])
  return part_end_offsets
def fetch_dep(file, url, timeout=60):
  """Download ``url`` and save its text to ``file`` as UTF-8.

  The absolute Hugging Face URL of ``bpe_simple_vocab_16e6.mjs`` is rewritten to the relative path
  ``./bpe_simple_vocab_16e6.mjs``. Presumably this makes the saved file load the copy that
  ``__main__`` also downloads next to this script.

  Args:
    file: destination path.
    url: URL to download.
    timeout: seconds passed to ``requests.get`` as its timeout.

  Raises:
    requests.HTTPError: if the server answers with an error status. ``file`` is left untouched in
      that case.
  """
  response = requests.get(url, timeout=timeout)
  # fail loudly instead of saving an HTML error page as JavaScript
  response.raise_for_status()
  text = response.text.replace("https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/bpe_simple_vocab_16e6.mjs", "./bpe_simple_vocab_16e6.mjs")
  # download before opening: opening with "w" first would truncate an existing copy even if the request failed
  with open(file, "w", encoding="utf-8") as f:
    f.write(text)
if __name__ == "__main__":
  # Compile Stable Diffusion into WebGPU kernels and write net.js (kernels + JS glue) next to this
  # script. Without --remoteweights it also writes the split weight files that net.js loads.
  import re
  parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--remoteweights', action='store_true', help="Use safetensors from Huggingface, or from local")
  # parse first so --help or a bad flag exits before any network download
  args = parser.parse_args()
  out_dir = os.path.dirname(__file__)
  fetch_dep(os.path.join(out_dir, "clip_tokenizer.js"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/clip_tokenizer.js")
  fetch_dep(os.path.join(out_dir, "bpe_simple_vocab_16e6.mjs"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/bpe_simple_vocab_16e6.mjs")
  Device.DEFAULT = "WEBGPU"
  Tensor.no_grad = True
  model = StableDiffusion()
  # load in weights
  load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)

  class Step(NamedTuple):
    name: str = ""  # name of the JS variable the compiled step is exported as
    input: List[Tensor] = []  # example inputs passed to jit_model to trace the step
    forward: Any = None  # callable for the step; presumably invoked by jit_model via step.forward

  sub_steps = [
    Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
    Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
    Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode)
  ]

  def fixup_code(code, key):
    """Rename kernel ``key`` to ``main`` (the JS entryPoint) and replace the INFINITY uniform with an inf() helper."""
    code = (code.replace(key, 'main')
                .replace("var<uniform> INFINITY : f32;\n", "fn inf(a: f32) -> f32 { return a/0.0; }\n")
                .replace("@group(0) @binding(0)", "")
                .replace("INFINITY", "inf(1.0)"))
    # binding(0) was the uniform removed above, so shift every remaining binding down by one.
    # The previous fixed loop only handled bindings 1..8; binding(0) itself stays 0 as before.
    return re.sub(r"binding\((\d+)\)", lambda m: f"binding({max(int(m.group(1)) - 1, 0)})", code)

  def compile_step(model, step: Step):
    """JIT ``step`` and return the JS source defining ``var <step.name>``.

    That var exposes ``setup(device, safetensor)``, which returns an async runner taking one
    Float32 array per input and resolving to output0 as a Float32Array.
    """
    run, special_names = jit_model(step, *step.input)
    functions, statements, bufs, _ = compile_net(run, special_names)
    # weights are matched by realized buffer, so this must run after jit_model realized them
    state = get_state_dict(model)
    weights = {id(x.lazydata.base.realized): name for name, x in state.items()}
    # one shared predicate for "is a JS input" (everything that is not an output);
    # the old code mixed `"output" not in name` with `name != "output0"`
    input_idxs = [i for i, name in enumerate(special_names.values()) if not name.startswith("output")]

    def buf_decl(name, size, key):
      init = f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[key]}'], '{weights[key]}'))" if key in weights else f"createEmptyBuf(device, {size})"
      return f"const {name} = {init};"

    kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
    kernel_names = ', '.join([name for (name, _, _, _) in statements])
    kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], [{', '.join(buf_args)}], {global_size});" for i, (_name, buf_args, global_size, _local_size) in enumerate(statements)])
    buf_decls = '\n '.join([buf_decl(name, size, key) for name, (size, _dtype, key) in bufs.items()])
    gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i in input_idxs])
    input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(data{i});\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i in input_idxs])
    data_params = ",".join([f"data{i}" for i in input_idxs])
    return f"""\n var {step.name} = function() {{
{kernel_code}
return {{
"setup": async (device, safetensor) => {{
const metadata = getTensorMetadata(safetensor[0]);
{buf_decls}
{gpu_write_bufs}
const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
const kernels = [{kernel_names}];
const pipelines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));
return async ({data_params}) => {{
const commandEncoder = device.createCommandEncoder();
{input_writer}
{kernel_calls}
commandEncoder.copyBufferToBuffer(output0, 0, gpuReadBuffer, 0, output0.size);
const gpuCommands = commandEncoder.finish();
device.queue.submit([gpuCommands]);
await gpuReadBuffer.mapAsync(GPUMapMode.READ);
const resultBuffer = new Float32Array(gpuReadBuffer.size/4);
resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
gpuReadBuffer.unmap();
return resultBuffer;
}}
}}
}}
}}
"""

  prg = ""
  for step in sub_steps:
    print(f'Executing step={step.name}')
    prg += compile_step(model, step)
    if step.name == "diffusor":
      if args.remoteweights:
        base_url = "https://huggingface.co/wpmed/stable-diffusion-f16-new/resolve/main"
      else:
        # save, downcast and split the weights locally, then drop the intermediate full-size files
        state = get_state_dict(model)
        safe_save(state, os.path.join(out_dir, "net.safetensors"))
        convert_f32_to_f16(os.path.join(out_dir, "./net.safetensors"), os.path.join(out_dir, "./net_conv.safetensors"))
        split_safetensor(os.path.join(out_dir, "./net_conv.safetensors"))
        os.remove(os.path.join(out_dir, "net.safetensors"))
        os.remove(os.path.join(out_dir, "net_conv.safetensors"))
        base_url = "."

  # NOTE(review): partStartOffsets below is hard-coded instead of derived from split_safetensor's
  # return value; confirm it matches the weight files served from MODEL_BASE_URL.
  prekernel = f"""
window.MODEL_BASE_URL= "{base_url}";
const getTensorMetadata = (safetensorBuffer) => {{
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
}};
const getTensorBuffer = (safetensorParts, tensorMetadata, key) => {{
let selectedPart = 0;
let counter = 0;
let partStartOffsets = [1131408336, 2227518416, 3308987856, 4265298864];
let correctedOffsets = tensorMetadata.data_offsets;
let prev_offset = 0;
for (let start of partStartOffsets) {{
prev_offset = (counter == 0) ? 0 : partStartOffsets[counter-1];
if (tensorMetadata.data_offsets[0] < start) {{
selectedPart = counter;
correctedOffsets = [correctedOffsets[0]-prev_offset, correctedOffsets[1]-prev_offset];
break;
}}
counter++;
}}
return safetensorParts[selectedPart].subarray(...correctedOffsets);
}}
const getWeight = (safetensors, key) => {{
let uint8Data = getTensorBuffer(safetensors, getTensorMetadata(safetensors[0])[key], key);
return new Float32Array(uint8Data.buffer, uint8Data.byteOffset, uint8Data.byteLength / Float32Array.BYTES_PER_ELEMENT);
}}
const createEmptyBuf = (device, size) => {{
return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
}};
const createWeightBuf = (device, size, data) => {{
const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
new Uint8Array(buf.getMappedRange()).set(data);
buf.unmap();
return buf;
}};
const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(pipeline);
passEncoder.setBindGroup(0, bindGroup);
passEncoder.dispatchWorkgroups(...workgroup);
passEncoder.end();
}};"""
  with open(os.path.join(out_dir, "net.js"), "w", encoding="utf-8") as text_file:
    text_file.write(prekernel + prg)