mirror of https://github.com/tinygrad/tinygrad.git
Stable diffusion WebGPU port (#1370)
* WIP: Stable diffusion WebGPU port
* Load whole model: split safetensor to avoid Chrome allocation limit
* Gitignore .DS_Store, remove debug print
* Clip tokenizer in JS
* WIP: Compile model in parts (text model, diffusor, get_x_prev_and_pred_x0, decoder), and recreate forward logic in JS
* e2e stable diffusion flow
* Create initial random latent tensor in JS
* SD working e2e
* Log if some weights were not loaded properly
* Remove latent_tensor.npy used for debugging
* Cleanup, remove useless logs
* Improve UI
* Add progress bar
* Remove .npy files used for debugging
* Add clip tokenizer as external dependency
* Remove alphas_cumprod.js and load it from safetensors
* Refactor
* Simplify a lot
* Dedup base when limiting elementwise merge (webgpu)
* Add return type to safe_load_metadata
* Do not allow run when webgpu is not supported
* Add progress bar, refactor, fix special names
* Add option to choose from local vs huggingface weights
* lowercase tinygrad :)
* fp16 model dl, decompression client side
* Cache f16 model in browser, better progress
* Cache miss recovery

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
.gitignore (vendored) | 4
@@ -1,6 +1,7 @@
__pycache__
.venv/
.vscode
.DS_Store
notebooks
.*.swp
.*.swo
@@ -31,7 +32,8 @@ extra/datasets/kits/
extra/datasets/COCO/
extra/datasets/audio*
venv
examples/net.*[js,json,safetensors]
examples/**/net.*[js,json]
examples/**/*.safetensors
node_modules
package.json
package-lock.json

examples/stable_diffusion.py
@@ -538,7 +538,44 @@ class StableDiffusion:
    self.first_stage_model = AutoencoderKL()
    self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))

  # TODO: make __call__ run the model
  def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
    temperature = 1
    sigma_t = 0
    sqrt_one_minus_at = (1-a_t).sqrt()
    #print(a_t, a_prev, sigma_t, sqrt_one_minus_at)

    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

    # direction pointing to x_t
    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t

    x_prev = a_prev.sqrt() * pred_x0 + dir_xt
    return x_prev, pred_x0

  def get_model_output(self, unconditional_context, context, latent, timestep, unconditional_guidance_scale):
    # put into diffuser
    latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, unconditional_context.cat(context, dim=0))
    unconditional_latent, latent = latents[0:1], latents[1:2]

    e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
    return e_t

  def decode(self, x):
    x = self.first_stage_model.post_quant_conv(1/0.18215 * x)
    x = self.first_stage_model.decoder(x)

    # make image correct size and scale
    x = (x + 1.0) / 2.0
    x = x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255
    return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x

  def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
    e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
    x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev)
    #e_t_next = get_model_output(x_prev)
    #e_t_prime = (e_t + e_t_next) / 2
    #x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
    return x_prev.realize()

# ** ldm.models.autoencoder.AutoencoderKL (done!)
# 3x512x512 <--> 4x64x64 (16384)
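
For reference, get_x_prev_and_pred_x0 is the deterministic DDIM update with eta = 0 (hence sigma_t = 0). A minimal NumPy sketch of the same step, assuming scalar a_t and a_prev drawn from alphas_cumprod:

import numpy as np

def ddim_step(x, e_t, a_t, a_prev):
  # estimate of the fully denoised latent implied by the noise prediction e_t
  pred_x0 = (x - np.sqrt(1 - a_t) * e_t) / np.sqrt(a_t)
  # direction pointing back toward x_t (the sigma_t**2 term vanishes since sigma_t = 0)
  dir_xt = np.sqrt(1 - a_prev) * e_t
  return np.sqrt(a_prev) * pred_x0 + dir_xt  # x_prev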

@@ -595,65 +632,31 @@ if __name__ == "__main__":
  # done with clip model
  del model.cond_stage_model

  def get_model_output(latent, timestep, unconditional_guidance_scale):
    # put into diffuser
    latents = model.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, unconditional_context.cat(context, dim=0))
    unconditional_latent, latent = latents[0:1], latents[1:2]

    e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
    return e_t

  timesteps = list(range(1, 1000, 1000//args.steps))
  print(f"running for {timesteps} timesteps")
  alphas = model.alphas_cumprod[Tensor(timesteps)]
  alphas_prev = Tensor([1.0]).cat(alphas[:-1])

  def get_x_prev_and_pred_x0(x, e_t, index):
    temperature = 1
    a_t, a_prev = alphas[index], alphas_prev[index]
    sigma_t = 0
    sqrt_one_minus_at = (1-a_t).sqrt()
    #print(a_t, a_prev, sigma_t, sqrt_one_minus_at)

    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

    # direction pointing to x_t
    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t

    x_prev = a_prev.sqrt() * pred_x0 + dir_xt
    return x_prev, pred_x0

  @TinyJit
  def do_step(latent, timestep, index, guidance):
    e_t = get_model_output(latent, timestep, guidance)
    x_prev, _ = get_x_prev_and_pred_x0(latent, e_t, index)
    #e_t_next = get_model_output(x_prev)
    #e_t_prime = (e_t + e_t_next) / 2
    #x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
    return x_prev.realize()

  # start with random noise
  if args.seed is not None: Tensor._seed = args.seed
  latent = Tensor.randn(1,4,64,64)

  @TinyJit
  def run(model, *x): return model(*x).realize()

  # this is diffusion
  with Context(BEAM=getenv("LATEBEAM")):
    # this is diffusion
    for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
      GlobalCounters.reset()
      t.set_description("%3d %3d" % (index, timestep))
      with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
        latent = do_step(latent, Tensor([timestep]), Tensor([index]), Tensor([args.guidance]))
        tid = Tensor([index])
        latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance]))
        if args.timing: Device[Device.DEFAULT].synchronize()
  del do_step
  del run

  # upsample latent space to image with autoencoder
  x = model.first_stage_model.post_quant_conv(1/0.18215 * latent)
  x = model.first_stage_model.decoder(x)

  # make image correct size and scale
  x = (x + 1.0) / 2.0
  x = (x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255)
  if Device.DEFAULT != "WEBGPU": x = x.cast(dtypes.uint8)
  x = model.decode(latent)
  print(x.shape)

  # save image
examples/webgpu/stable_diffusion/compile.py (new file) | 234
@@ -0,0 +1,234 @@
import os
from extra.export_model import compile_net, jit_model
from examples.stable_diffusion import StableDiffusion
from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
from tinygrad.tensor import Tensor
from tinygrad.ops import Device
from extra.utils import download_file
from typing import NamedTuple, Any, List
from pathlib import Path
import argparse
import numpy as np

def convert_f32_to_f16(input_file, output_file):
  with open(input_file, 'rb') as f:
    metadata_length_bytes = f.read(8)
    metadata_length = int.from_bytes(metadata_length_bytes, byteorder='little', signed=False)
    metadata_json_bytes = f.read(metadata_length)
    float32_values = np.fromfile(f, dtype=np.float32)

  first_text_model_offset = 3772703308
  num_elements = int((first_text_model_offset)/4)
  front_float16_values = float32_values[:num_elements].astype(np.float16)
  rest_float32_values = float32_values[num_elements:]

  with open(output_file, 'wb') as f:
    f.write(metadata_length_bytes)
    f.write(metadata_json_bytes)
    front_float16_values.tofile(f)
    rest_float32_values.tofile(f)
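
# note: this relies on the safetensors container layout: an 8-byte little-endian
# header length, then the JSON header, then raw tensor bytes. Everything before
# first_text_model_offset in the tensor data is halved to f16; the text-model
# weights that follow stay f32.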

FILENAME = Path(__file__).parent.parent.parent.parent / "weights/sd-v1-4.ckpt"

def split_safetensor(fn):
  _, json_len, metadata = safe_load_metadata(fn)
  text_model_offset = 3772703308
  chunk_size = 536870912

  for k in metadata:
    # safetensor is in fp16, except for the text model
    if (metadata[k]["data_offsets"][0] < text_model_offset):
      metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0]/2)
      metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1]/2)

  last_offset = 0
  part_end_offsets = []

  for k in metadata:
    offset = metadata[k]['data_offsets'][0]

    if offset == text_model_offset:
      break

    part_offset = offset - last_offset

    if (part_offset >= chunk_size):
      part_end_offsets.append(8+json_len+offset)
      last_offset = offset

  text_model_start = int(text_model_offset/2)
  net_bytes = bytes(open(fn, 'rb').read())
  part_end_offsets.append(text_model_start+8+json_len)
  cur_pos = 0

  for i, end_pos in enumerate(part_end_offsets):
    with open(f'./net_part{i}.safetensors', "wb+") as f:
      f.write(net_bytes[cur_pos:end_pos])
    cur_pos = end_pos

  with open(f'./net_textmodel.safetensors', "wb+") as f:
    f.write(net_bytes[text_model_start+8+json_len:])

  return part_end_offsets
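
# note: parts are cut only at tensor boundaries spaced at least chunk_size
# (512 MiB) apart, keeping each piece under Chrome's per-ArrayBuffer allocation
# limit; the f32 text model is written out as its own part.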

if __name__ == "__main__":
  parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--remoteweights', action='store_true', help="Use safetensors from Hugging Face instead of local weights")
  args = parser.parse_args()
  Device.DEFAULT = "WEBGPU"

  Tensor.no_grad = True
  model = StableDiffusion()

  # load in weights
  download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
  load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)

  class Step(NamedTuple):
    name: str = ""
    input: List[Tensor] = []
    forward: Any = None

  sub_steps = [
    Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
    Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
    Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode)
  ]
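  # note: the "diffusor" example inputs mirror StableDiffusion.__call__ above:
  # (unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance)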

  prg = ""

  def compile_step(model, step: Step):
    run, special_names = jit_model(step, *step.input)
    functions, statements, bufs, _ = compile_net(run, special_names)
    state = get_state_dict(model)
    weights = {id(x.lazydata.realized): name for name, x in state.items()}
    kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
    kernel_names = ', '.join([name for (name, _, _, _) in statements])
    kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements)])
    bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
    gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
    input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,(_,value) in enumerate(special_names.items()) if value != "output0"])
    return f"""\n var {step.name} = function() {{

    {kernel_code}

    return {{
      "setup": async (device, safetensor) => {{
        const metadata = getTensorMetadata(safetensor[0]);

        {bufs}

        {gpu_write_bufs}
        const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});

        const kernels = [{kernel_names}];
        const pipelines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));

        return async ({",".join([f'data{i}' for i,(k,v) in enumerate(special_names.items()) if v != "output0"])}) => {{
          const commandEncoder = device.createCommandEncoder();

          {input_writer}

          {kernel_calls}
          commandEncoder.copyBufferToBuffer(output0, 0, gpuReadBuffer, 0, output0.size);
          const gpuCommands = commandEncoder.finish();
          device.queue.submit([gpuCommands]);

          await gpuReadBuffer.mapAsync(GPUMapMode.READ);
          const resultBuffer = new Float32Array(gpuReadBuffer.size/4);
          resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
          gpuReadBuffer.unmap();
          return resultBuffer;
        }}
      }}
    }}
  }}
  """

  for step in sub_steps:
    print(f'Executing step={step.name}')
    prg += compile_step(model, step)

    if step.name == "diffusor":
      if args.remoteweights:
        base_url = "https://huggingface.co/wpmed/tinygrad-sd-f16/resolve/main"
      else:
        state = get_state_dict(model)
        safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
        convert_f32_to_f16("./net.safetensors", "./net_conv.safetensors")
        split_safetensor("./net_conv.safetensors")
        os.remove("net.safetensors")
        os.remove("net_conv.safetensors")
        base_url = "."

  prekernel = f"""
    window.MODEL_BASE_URL = "{base_url}";
    const getTensorMetadata = (safetensorBuffer) => {{
      const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
      const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
      return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
    }};

    const getTensorBuffer = (safetensorParts, tensorMetadata, key) => {{
      let selectedPart = 0;
      let counter = 0;
      let partStartOffsets = [1131408336, 2227518416, 3308987856, 4265298864];
      let correctedOffsets = tensorMetadata.data_offsets;
      let prev_offset = 0;

      for (let start of partStartOffsets) {{
        prev_offset = (counter == 0) ? 0 : partStartOffsets[counter-1];

        if (tensorMetadata.data_offsets[0] < start) {{
          selectedPart = counter;
          correctedOffsets = [correctedOffsets[0]-prev_offset, correctedOffsets[1]-prev_offset];
          break;
        }}

        counter++;
      }}

      let allZero = true;
      let out = safetensorParts[selectedPart].subarray(...correctedOffsets);

      for (let i = 0; i < out.length; i++) {{
        if (out[i] !== 0) {{
          allZero = false;
          break;
        }}
      }}

      if (allZero) {{
        console.log("Error: weight '" + key + "' is all zero.");
      }}

      return safetensorParts[selectedPart].subarray(...correctedOffsets);
    }}

    const getWeight = (safetensors, key) => {{
      let uint8Data = getTensorBuffer(safetensors, getTensorMetadata(safetensors[0])[key], key);
      return new Float32Array(uint8Data.buffer, uint8Data.byteOffset, uint8Data.byteLength / Float32Array.BYTES_PER_ELEMENT);
    }}

    const createEmptyBuf = (device, size) => {{
      return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
    }};

    const createWeightBuf = (device, size, data) => {{
      const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
      new Uint8Array(buf.getMappedRange()).set(data);
      buf.unmap();
      return buf;
    }};

    const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
      const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
      const passEncoder = commandEncoder.beginComputePass();
      passEncoder.setPipeline(pipeline);
      passEncoder.setBindGroup(0, bindGroup);
      passEncoder.dispatchWorkgroups(...workgroup);
      passEncoder.end();
    }};"""

  with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as text_file:
    text_file.write(prekernel + prg)
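
Each compiled step thus becomes a self-contained module in the generated net.js: its setup(device, safetensor) allocates empty and weight-backed buffers, builds one compute pipeline per kernel, and returns an async runner that uploads the inputs, replays the recorded kernel sequence, and reads back output0.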
examples/webgpu/stable_diffusion/f16_to_f32.js (new file) | 64
@@ -0,0 +1,64 @@
const f16tof32 = `
fn u16_to_f16(x: u32) -> f32 {
  let sign = f32((x >> 15) & 0x1);
  let exponent = f32((x >> 10) & 0x1F);
  let fraction = f32(x & 0x3FF);

  let sign_multiplier = select(1.0, -1.0, sign == 1.0);
  if (exponent == 0.0) {
    return sign_multiplier * 6.103515625e-5 * (fraction / 1024.0);
  } else {
    return sign_multiplier * exp2(exponent - 15.0) * (1.0 + fraction / 1024.0);
  }
}

@group(0) @binding(0) var<storage,read_write> data0: array<u32>;
@group(0) @binding(1) var<storage,read_write> data1: array<f32>;
@compute @workgroup_size(256) fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let gidx = gid.x;
  let outgidx = gidx*2;

  if (gidx >= arrayLength(&data0)) {
    return;
  }

  let oo = data0[gidx];
  let oo1 = (oo >> 16);
  let oo2 = oo & 0xFFFFu;

  let f1 = u16_to_f16(oo2);
  let f2 = u16_to_f16(oo1);

  data1[outgidx] = f1;
  data1[outgidx + 1] = f2;
}`;

window.f16tof32GPU = async (device, inf16) => {
  const input = device.createBuffer({size: inf16.length, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST });
  const output = device.createBuffer({size: inf16.length*2, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST });

  const gpuWriteBuffer = device.createBuffer({size: input.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE });
  const gpuReadBuffer = device.createBuffer({ size: output.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ });
  const commandEncoder = device.createCommandEncoder();
  await gpuWriteBuffer.mapAsync(GPUMapMode.WRITE);

  const alignedUint32View = new Uint32Array(inf16.buffer, inf16.byteOffset, inf16.length / 4);
  new Uint32Array(gpuWriteBuffer.getMappedRange()).set(alignedUint32View);

  gpuWriteBuffer.unmap();
  commandEncoder.copyBufferToBuffer(gpuWriteBuffer, 0, input, 0, gpuWriteBuffer.size);
  const pipeline = await device.createComputePipelineAsync({layout: "auto", compute: { module: device.createShaderModule({ code: f16tof32 }), entryPoint: "main" }});

  addComputePass(device, commandEncoder, pipeline, [input, output], [Math.ceil(inf16.length/(4*256)), 1, 1]);

  commandEncoder.copyBufferToBuffer(output, 0, gpuReadBuffer, 0, output.size);
  const gpuCommands = commandEncoder.finish();
  device.queue.submit([gpuCommands]);

  await gpuReadBuffer.mapAsync(GPUMapMode.READ);
  const resultBuffer = new Float32Array(gpuReadBuffer.size/4);
  resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
  gpuReadBuffer.unmap();

  return resultBuffer;
}
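
For reference, a CPU sketch of the same bit manipulation the u16_to_f16 WGSL helper performs (like the shader, it ignores the inf/NaN case where the exponent bits are all ones); numpy's float16 can sanity-check it:

import numpy as np

def u16_bits_to_f32(x: int) -> float:
  sign = -1.0 if (x >> 15) & 1 else 1.0
  exponent = (x >> 10) & 0x1F
  fraction = x & 0x3FF
  if exponent == 0:  # subnormal: 6.103515625e-5 == 2**-14
    return sign * 6.103515625e-5 * (fraction / 1024.0)
  return sign * 2.0 ** (exponent - 15) * (1.0 + fraction / 1024.0)

for bits in (0x3C00, 0xC000, 0x0001, 0x7BFF):  # 1.0, -2.0, min subnormal, 65504.0
  assert np.isclose(u16_bits_to_f32(bits), np.frombuffer(np.uint16(bits).tobytes(), np.float16)[0].item())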
examples/webgpu/stable_diffusion/index.html (new file) | 511
@@ -0,0 +1,511 @@
<!DOCTYPE html>
<html lang="en">

<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>tinygrad has WebGPU</title>

  <style>
    body {
      font-family: 'Arial', sans-serif;
      text-align: center;
      padding: 30px;
    }

    a {
      text-decoration: none;
      color: #4A90E2;
    }

    h1 {
      font-size: 36px;
      font-weight: normal;
      margin-bottom: 20px;
    }

    #mybox {
      display: flex;
      flex-direction: column;
      align-items: center;
      gap: 20px;
      width: 50%;
      margin: 0 auto;
    }

    #promptText, #stepRange, #btnRunNet, #guidanceRange {
      font-size: 18px;
      width: 100%;
    }

    #result {
      font-size: 48px;
    }

    #time {
      font-size: 16px;
      color: grey;
    }

    canvas {
      margin-top: 20px;
      border: 1px solid #000;
    }

    label {
      display: flex;
      align-items: center;
      gap: 10px;
      width: 100%;
    }

    #sliderValue {
      margin-right: 10px;
    }
  </style>

  <script type="module">
    import ClipTokenizer from 'https://softwired.nyc3.cdn.digitaloceanspaces.com/sd/clip_tokenizer.js';
    window.clipTokenizer = new ClipTokenizer();
  </script>
  <script src="./f16_to_f32.js"></script>
  <script src="./net.js"></script>
</head>

<body>
  <h1 id="wgpuError" style="display: none; color: red;">WebGPU is not supported in this browser</h1>
  <h1 id="sdTitle">StableDiffusion by <a href="https://github.com/tinygrad/tinygrad" target="_blank">tinygrad</a> WebGPU</h1>
  <div id="mybox">
    <input id="promptText" type="text" placeholder="Enter your prompt here" value="a horse sized cat eating a bagel">

    <label>
      Steps: <span id="stepValue">8</span>
      <input id="stepRange" type="range" min="5" max="20" value="8" step="1">
    </label>

    <label>
      Guidance: <span id="guidanceValue">7.5</span>
      <input id="guidanceRange" type="range" min="3" max="15" value="7.5" step="0.1">
    </label>

    <input id="btnRunNet" type="button" value="Run" disabled>

    <div id="divModelDl" style="display: flex; align-items: center; width: 100%; gap: 10px;">
      <span id="modelDlTitle">Downloading model</span>
      <progress id="modelDlProgressBar" value="0" max="100" style="flex-grow: 1;"></progress>
      <span id="modelDlProgressValue"></span>
    </div>

    <div id="divStepProgress" style="display: none; align-items: center; width: 100%; gap: 10px;">
      <progress id="progressBar" value="0" max="100" style="flex-grow: 1;"></progress>
      <span id="progressFraction"></span>
    </div>
  </div>
  <canvas id="canvas" width="512" height="512"></canvas>

  <script>
    function initDb() {
      return new Promise((resolve, reject) => {
        let db;
        const request = indexedDB.open('tinydb', 1);
        request.onerror = (event) => {
          console.error('Database error:', event.target.error);
          resolve(null);
        };

        request.onsuccess = (event) => {
          db = event.target.result;
          console.log("Db initialized.");
          resolve(db);
        };

        request.onupgradeneeded = (event) => {
          db = event.target.result;
          if (!db.objectStoreNames.contains('tensors')) {
            db.createObjectStore('tensors', { keyPath: 'id' });
          }
        };
      });
    }

    function saveTensorToDb(db, id, tensor) {
      return new Promise((resolve, reject) => {
        if (db == null) {
          resolve(null);
          return;
        }

        const transaction = db.transaction(['tensors'], 'readwrite');
        const store = transaction.objectStore('tensors');
        const request = store.put({ id: id, content: tensor });

        transaction.onabort = (event) => {
          console.log("Transaction error while saving tensor: " + event.target.error);
          resolve(null);
        };

        request.onsuccess = () => {
          console.log('Tensor saved successfully.');
          resolve();
        };

        request.onerror = (event) => {
          console.error('Tensor save failed:', event.target.error);
          resolve(null);
        };
      });
    }

    function readTensorFromDb(db, id) {
      return new Promise((resolve, reject) => {
        if (db == null) {
          resolve(null);
          return;
        }

        const transaction = db.transaction(['tensors'], 'readonly');
        const store = transaction.objectStore('tensors');
        const request = store.get(id);

        transaction.onabort = (event) => {
          console.log("Transaction error while reading tensor: " + event.target.error);
          resolve(null);
        };

        request.onsuccess = (event) => {
          const result = event.target.result;
          if (result) {
            console.log("Cache hit: " + id);
            resolve(result);
          } else {
            console.log("Cache miss: " + id);
            resolve(null);
          }
        };

        request.onerror = (event) => {
          console.error('Tensor retrieve failed: ', event.target.error);
          resolve(null);
        };
      });
    }
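
    // note: initDb, saveTensorToDb and readTensorFromDb all resolve null on failure,
    // so the IndexedDB cache is strictly best-effort; a cache miss below just falls
    // back to fetching the parts over the network.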

    window.addEventListener('load', async function() {
      if (!navigator.gpu) {
        document.getElementById("wgpuError").style.display = "";
        document.getElementById("sdTitle").style.display = "none";
        return;
      }

      let db = await initDb();

      const ctx = document.getElementById("canvas").getContext("2d", { willReadFrequently: true });
      let labels, nets, safetensorParts;

      const getDevice = async () => {
        const adapter = await navigator.gpu.requestAdapter();
        const requiredLimits = {};
        const maxBufferSizeInSDModel = 1073741824;
        requiredLimits.maxStorageBufferBindingSize = maxBufferSizeInSDModel;
        requiredLimits.maxBufferSize = maxBufferSizeInSDModel;

        return await adapter.requestDevice({
          requiredLimits
        });
      };

      const timer = async (func, label = "") => {
        const start = performance.now();
        const out = await func();
        const delta = (performance.now() - start).toFixed(1);
        console.log(`${delta} ms ${label}`);
        return out;
      }

      const getProgressDlForPart = async (part, progressCallback) => {
        const response = await fetch(part);
        const contentLength = response.headers.get('content-length');
        const total = parseInt(contentLength, 10);

        const res = new Response(new ReadableStream({
          async start(controller) {
            const reader = response.body.getReader();
            for (;;) {
              const { done, value } = await reader.read();
              if (done) break;
              progressCallback(part, value.byteLength, total);
              controller.enqueue(value);
            }

            controller.close();
          },
        }));

        return res.arrayBuffer();
      };

      const getAndDecompressF16Safetensors = async (device, progress) => {
        let totalLoaded = 0;
        let totalSize = 0;
        let partSize = {};

        const progressCallback = (part, loaded, total) => {
          totalLoaded += loaded;

          if (!partSize[part]) {
            totalSize += total;
            partSize[part] = true;
          }

          progress(totalLoaded, totalSize);
        };

        let combinedBuffer = await readTensorFromDb(db, "net.f16");
        let textModelU8 = await readTensorFromDb(db, "net.text");
        let textModelFetched = false;

        if (combinedBuffer == null) {
          let dlParts = [
            getProgressDlForPart(window.MODEL_BASE_URL + '/net_part0.safetensors', progressCallback),
            getProgressDlForPart(window.MODEL_BASE_URL + '/net_part1.safetensors', progressCallback),
            getProgressDlForPart(window.MODEL_BASE_URL + '/net_part2.safetensors', progressCallback),
            getProgressDlForPart(window.MODEL_BASE_URL + '/net_part3.safetensors', progressCallback)
          ];

          if (textModelU8 == null) {
            dlParts.push(getProgressDlForPart(window.MODEL_BASE_URL + '/net_textmodel.safetensors', progressCallback));
          }

          let buffers = await Promise.all(dlParts);

          // Combine everything except for the text model, since that's already f32
          const totalLength = buffers.reduce((acc, buffer, index, array) => {
            if (index < 4) {
              return acc + buffer.byteLength;
            } else {
              return acc;
            }
          }, 0);

          combinedBuffer = new Uint8Array(totalLength);

          let offset = 0;
          buffers.forEach((buffer, index) => {
            if (index < 4) {
              combinedBuffer.set(new Uint8Array(buffer), offset);
              offset += buffer.byteLength;
              buffer = null;
            }
          });

          await saveTensorToDb(db, "net.f16", combinedBuffer);

          if (textModelU8 == null) {
            textModelFetched = true;
            textModelU8 = new Uint8Array(buffers[4]);
            await saveTensorToDb(db, "net.text", textModelU8);
          }
        } else {
          combinedBuffer = combinedBuffer.content;
        }

        if (textModelU8 == null) {
          textModelU8 = new Uint8Array(await getProgressDlForPart(window.MODEL_BASE_URL + '/net_textmodel.safetensors', progressCallback));
          await saveTensorToDb(db, "net.text", textModelU8);
        } else if (!textModelFetched) {
          textModelU8 = textModelU8.content;
        }

        document.getElementById("modelDlTitle").innerHTML = "Decompressing model";

        const textModelOffset = 3772703308;
        const metadataLength = Number(new DataView(combinedBuffer.buffer).getBigUint64(0, true));
        const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));

        const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength);
        const decodeChunkSize = 67107840;
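        // 67107840 = 65535 * 256 * 4: the largest 4-byte-aligned slice one dispatch
        // (at most 65535 workgroups of 256 threads, one u32 per thread) can decode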
        const numChunks = Math.ceil(allToDecomp/decodeChunkSize);

        console.log(allToDecomp + " bytes to decompress");
        console.log("Will be decompressed in " + numChunks + " chunks");

        let partOffsets = [{start: 0, end: 1131408336}, {start: 1131408336, end: 2227518416}, {start: 2227518416, end: 3308987856}, {start: 3308987856, end: 4265298864}];
        let parts = [];

        for (let offsets of partOffsets) {
          parts.push(new Uint8Array(offsets.end-offsets.start));
        }
        parts[0].set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
        parts[0].set(combinedBuffer.subarray(8, 8 + metadataLength), 8);
        parts[3].set(textModelU8, textModelOffset+8+metadataLength - partOffsets[3].start);

        let start = Date.now();
        let cursor = 0;

        for (let i = 0; i < numChunks; i++) {
          progress(i, numChunks);
          let chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
          let chunkEndF16 = chunkStartF16 + decodeChunkSize;
          let chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);

          if (chunk.byteLength % 4 != 0) {
            const paddingBytes = 4 - (chunk.byteLength % 4);
            const alignedBuffer = new ArrayBuffer(chunk.byteLength + paddingBytes);
            const alignedView = new Uint8Array(alignedBuffer);
            alignedView.set(new Uint8Array(chunk));
            chunk = alignedView;
          }

          let result = await f16tof32GPU(device, chunk);
          let resultUint8 = new Uint8Array(result.buffer);
          let chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
          let chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
          let offsetInPart = chunkStartF32 - partOffsets[cursor].start;

          if (chunkEndF32 < partOffsets[cursor].end || cursor === parts.length - 1) {
            parts[cursor].set(resultUint8, offsetInPart);
          } else {
            let spaceLeftInCurrentPart = partOffsets[cursor].end - chunkStartF32;
            parts[cursor].set(resultUint8.subarray(0, spaceLeftInCurrentPart), offsetInPart);

            cursor++;

            if (cursor < parts.length) {
              let nextPartOffset = spaceLeftInCurrentPart;
              let nextPartLength = resultUint8.length - nextPartOffset;
              parts[cursor].set(resultUint8.subarray(nextPartOffset, nextPartOffset + nextPartLength), 0);
            }
          }

          resultUint8 = null;
          result = null;
        }

        combinedBuffer = null;

        let end = Date.now();
        console.log("Decoding took: " + ((end - start) / 1000) + " s");
        console.log("Average " + ((end - start) / numChunks) + " ms per chunk");

        return parts;
      };

      const loadNet = async () => {
        const modelDlTitle = document.getElementById("modelDlTitle");

        const progress = (loaded, total) => {
          document.getElementById("modelDlProgressBar").value = (loaded/total) * 100;
          document.getElementById("modelDlProgressValue").innerHTML = Math.trunc((loaded/total) * 100) + "%";
        }

        const device = await getDevice();
        safetensorParts = await getAndDecompressF16Safetensors(device, progress);

        modelDlTitle.innerHTML = "Compiling model";

        let models = ["textModel", "diffusor", "decoder"];

        nets = await timer(() => Promise.all([
          textModel().setup(device, safetensorParts),
          diffusor().setup(device, safetensorParts),
          decoder().setup(device, safetensorParts)
        ]).then((loadedModels) => loadedModels.reduce((acc, model, index) => { acc[models[index]] = model; return acc; }, {})), "(compilation)");

        progress(1, 1);

        modelDlTitle.innerHTML = "Model ready";
        setTimeout(() => {
          document.getElementById("modelDlProgressBar").style.display = "none";
          document.getElementById("modelDlProgressValue").style.display = "none";
          document.getElementById("divStepProgress").style.display = "flex";
        }, 1000);
        document.getElementById("btnRunNet").disabled = false;
      }

      function runStableDiffusion(prompt, steps, guidance) {
        return new Promise(async (resolve, reject) => {
          let context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP(prompt)));
          let unconditional_context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP("")));

          let timesteps = [];

          for (let i = 1; i < 1000; i += (1000/steps)) {
            timesteps.push(i);
          }

          console.log("Timesteps: " + timesteps);

          let alphasCumprod = getWeight(safetensorParts, "alphas_cumprod");
          let alphas = [];

          for (const t of timesteps) {
            alphas.push(alphasCumprod[Math.floor(t)]);
          }

          let alphas_prev = [1.0];

          for (let i = 0; i < alphas.length-1; i++) {
            alphas_prev.push(alphas[i]);
          }

          let inpSize = 4*64*64;
          // sample the initial latent from N(0, 1) via the Box-Muller transform
          let latent = new Float32Array(inpSize);

          for (let i = 0; i < inpSize; i++) {
            latent[i] = Math.sqrt(-2.0 * Math.log(Math.random())) * Math.cos(2.0 * Math.PI * Math.random());
          }

          for (let i = timesteps.length - 1; i >= 0; i--) {
            let timestep = new Float32Array([timesteps[i]]);
            let x_prev = await timer(() => nets["diffusor"](unconditional_context, context, latent, timestep, new Float32Array([alphas[i]]), new Float32Array([alphas_prev[i]]), new Float32Array([guidance])));
            latent = x_prev;
            document.getElementById("progressBar").value = ((steps - i) / steps) * 100;
            document.getElementById("progressFraction").innerHTML = (steps - i) + "/" + steps;
          }

          resolve(await timer(() => nets["decoder"](latent)));
        });
      }

      document.getElementById("btnRunNet").addEventListener("click", function(e) {
        e.target.disabled = true;

        runStableDiffusion(document.getElementById("promptText").value, document.getElementById("stepRange").value, document.getElementById("guidanceRange").value).then((image) => {
          // decoder output is HWC RGB in 0-255; add the alpha channel for the canvas
          let pixels = [];
          let pixelCounter = 0;

          for (let j = 0; j < 512; j++) {
            for (let k = 0; k < 512; k++) {
              pixels.push(image[pixelCounter]);
              pixels.push(image[pixelCounter+1]);
              pixels.push(image[pixelCounter+2]);
              pixels.push(255);
              pixelCounter += 3;
            }
          }

          ctx.putImageData(new ImageData(new Uint8ClampedArray(pixels), 512, 512), 0, 0);
          console.log(image);
          console.log("Success");
          e.target.disabled = false;
        });
      }, false);

      const stepSlider = document.getElementById('stepRange');
      const stepValue = document.getElementById('stepValue');

      stepSlider.addEventListener('input', function() {
        stepValue.textContent = stepSlider.value;
      });

      const guidanceSlider = document.getElementById('guidanceRange');
      const guidanceValue = document.getElementById('guidanceValue');

      guidanceSlider.addEventListener('input', function() {
        guidanceValue.textContent = guidanceSlider.value;
      });

      loadNet();
    });
  </script>
</body>
</html>

tinygrad/lazy.py
@@ -229,7 +229,9 @@ class LazyBuffer:

    if MERGE_ELEMENTWISE_OPS:
      # remove the buffers from any (childless) BinaryOps that feed into this
      srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore
      _srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore
      # TODO: needs general merge limiting
      if out_device != "WEBGPU" or len(dedup([x.base for _src in _srcs for x in _src.buffers if not x.is_unrealized_const()])) < 7: srcs = _srcs # type: ignore

    return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype)
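
The `< 7` cap on deduped source buffers presumably tracks WebGPU's default limit of 8 storage buffers per shader stage: one binding is needed for the output, which leaves room for at most seven distinct input bases in a fused elementwise kernel.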

tinygrad/nn/state.py
@@ -1,6 +1,6 @@
import os, json, pathlib, zipfile, pickle
from tqdm import tqdm
from typing import Dict, Union, List, Optional, Any
from typing import Dict, Union, List, Optional, Any, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI
from tinygrad.shape.view import strides_for_shape
@@ -9,11 +9,14 @@ from tinygrad.ops import Device
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}

def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
  t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
  json_len = t[0:1].cast(dtypes.int64).numpy()[0]
  headers = json.loads(t[8:8+json_len].numpy().tobytes())
  return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in headers.items() if k != "__metadata__"}
  return (t, json_len, json.loads(t[8:8+json_len].numpy().tobytes()))

def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
  t, json_len, metadata = safe_load_metadata(fn)
  return {k:t[8+json_len+v['data_offsets'][0]:].cast(safe_dtypes[v['dtype']])[:prod(v['shape'])].reshape(v['shape']) for k,v in metadata.items() if k != "__metadata__"}
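
# sketch: callers that only need the header, like split_safetensor in the new
# compile.py, can now do
#   t, json_len, metadata = safe_load_metadata(fn)
# and inspect raw data_offsets without materializing any tensors.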

def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
  headers, offset = {}, 0
tinygrad/runtime/ops_webgpu.py
@@ -42,4 +42,4 @@ class RawWebGPUBuffer(RawBufferCopyIn):
  def toCPU(self) -> np.ndarray: return np.frombuffer(wgpu_device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore

renderer = functools.partial(uops_to_cstyle, WGSLLanguage())
WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, lambda x: x, WebGPUProgram)
WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(device="WEBGPU", supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, lambda x: x, WebGPUProgram)