mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
Stable diffusion WebGPU port (#1370)
* WIP: Stable diffusion WebGPU port
* Load whole model: split safetensor to avoid Chrome allocation limit
* Gitignore .DS_Store, remove debug print
* Clip tokenizer in JS
* WIP: Compile model in parts (text model, diffusor, get_x_prev_and_pred_x0, decoder), and recreate forward logic in JS
* e2e stable diffusion flow
* Create initial random latent tensor in JS
* SD working e2e
* Log if some weights were not loaded properly
* Remove latent_tensor.npy used for debugging
* Cleanup, remove useless logs
* Improve UI
* Add progress bar
* Remove .npy files used for debugging
* Add clip tokenizer as external dependency
* Remove alphas_cumprod.js and load it from safetensors
* Refactor
* Simplify a lot
* Dedup base when limiting elementwise merge (webgpu)
* Add return type to safe_load_metadata
* Do not allow run when webgpu is not supported
* Add progress bar, refactor, fix special names
* Add option to choose from local vs huggingface weights
* lowercase tinygrad :)
* fp16 model dl, decompression client side
* Cache f16 model in browser, better progress
* Cache miss recovery

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
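Note on the "Load whole model: split safetensor to avoid Chrome allocation limit" item above: below is a minimal Python sketch of one way to do that kind of split, assuming the standard safetensors layout (8-byte little-endian header length, JSON header with per-tensor "data_offsets", then the raw byte buffer). This is not the code from this PR; split_safetensors and chunk_bytes are hypothetical names.

import json, struct

def split_safetensors(path: str, chunk_bytes: int = 1 << 30) -> list[str]:
  # read the header and the raw tensor buffer of the original file
  with open(path, "rb") as f:
    header_len = struct.unpack("<Q", f.read(8))[0]
    header = json.loads(f.read(header_len))
    data = f.read()  # offsets in the header are relative to the start of this buffer
  header.pop("__metadata__", None)

  # greedily pack tensors into groups that stay under the size limit
  # (a single tensor larger than the limit still becomes its own oversized part)
  groups, current, current_size = [], {}, 0
  for name, info in header.items():
    size = info["data_offsets"][1] - info["data_offsets"][0]
    if current and current_size + size > chunk_bytes:
      groups.append(current); current, current_size = {}, 0
    current[name] = info; current_size += size
  if current: groups.append(current)

  # write each group as its own safetensors file with rebased offsets
  out_paths = []
  for i, group in enumerate(groups):
    new_header, buf, offset = {}, bytearray(), 0
    for name, info in group.items():
      begin, end = info["data_offsets"]
      buf += data[begin:end]
      new_header[name] = {**info, "data_offsets": [offset, offset + (end - begin)]}
      offset += end - begin
    header_bytes = json.dumps(new_header).encode()
    out_path = f"{path}.part{i}"
    with open(out_path, "wb") as f:
      f.write(struct.pack("<Q", len(header_bytes))); f.write(header_bytes); f.write(buf)
    out_paths.append(out_path)
  return out_paths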
@@ -538,7 +538,44 @@ class StableDiffusion:
    self.first_stage_model = AutoencoderKL()
    self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))

  # TODO: make __call__ run the model
  def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
    temperature = 1
    sigma_t = 0
    sqrt_one_minus_at = (1-a_t).sqrt()
    #print(a_t, a_prev, sigma_t, sqrt_one_minus_at)

    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

    # direction pointing to x_t
    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t

    x_prev = a_prev.sqrt() * pred_x0 + dir_xt
    return x_prev, pred_x0

  def get_model_output(self, unconditional_context, context, latent, timestep, unconditional_guidance_scale):
    # put into diffuser
    latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, unconditional_context.cat(context, dim=0))
    unconditional_latent, latent = latents[0:1], latents[1:2]

    e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
    return e_t

  def decode(self, x):
    x = self.first_stage_model.post_quant_conv(1/0.18215 * x)
    x = self.first_stage_model.decoder(x)

    # make image correct size and scale
    x = (x + 1.0) / 2.0
    x = x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255
    return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x

  def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
    e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
    x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev)
    #e_t_next = get_model_output(x_prev)
    #e_t_prime = (e_t + e_t_next) / 2
    #x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
    return x_prev.realize()

# ** ldm.models.autoencoder.AutoencoderKL (done!)
# 3x512x512 <--> 4x64x64 (16384)
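For reference, get_x_prev_and_pred_x0 above is the deterministic DDIM update with sigma_t fixed to 0: pred_x0 = (x - sqrt(1 - a_t) * e_t) / sqrt(a_t) and x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev) * e_t. A tiny NumPy restatement with toy values (not part of the diff, just the same arithmetic):

import numpy as np

# toy cumulative alphas for one step; in the real loop these come from
# model.alphas_cumprod indexed at the current and previous timestep
a_t, a_prev = 0.90, 0.95
x   = np.random.randn(1, 4, 64, 64).astype(np.float32)  # current latent
e_t = np.random.randn(1, 4, 64, 64).astype(np.float32)  # predicted noise

pred_x0 = (x - np.sqrt(1 - a_t) * e_t) / np.sqrt(a_t)   # estimated clean latent
dir_xt  = np.sqrt(1 - a_prev) * e_t                     # "direction pointing to x_t" with sigma_t = 0
x_prev  = np.sqrt(a_prev) * pred_x0 + dir_xt            # latent for the previous timestep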
@@ -595,65 +632,31 @@ if __name__ == "__main__":
  # done with clip model
  del model.cond_stage_model

  def get_model_output(latent, timestep, unconditional_guidance_scale):
    # put into diffuser
    latents = model.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, unconditional_context.cat(context, dim=0))
    unconditional_latent, latent = latents[0:1], latents[1:2]

    e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
    return e_t

  timesteps = list(range(1, 1000, 1000//args.steps))
  print(f"running for {timesteps} timesteps")
  alphas = model.alphas_cumprod[Tensor(timesteps)]
  alphas_prev = Tensor([1.0]).cat(alphas[:-1])

  def get_x_prev_and_pred_x0(x, e_t, index):
    temperature = 1
    a_t, a_prev = alphas[index], alphas_prev[index]
    sigma_t = 0
    sqrt_one_minus_at = (1-a_t).sqrt()
    #print(a_t, a_prev, sigma_t, sqrt_one_minus_at)

    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

    # direction pointing to x_t
    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t

    x_prev = a_prev.sqrt() * pred_x0 + dir_xt
    return x_prev, pred_x0

  @TinyJit
  def do_step(latent, timestep, index, guidance):
    e_t = get_model_output(latent, timestep, guidance)
    x_prev, _ = get_x_prev_and_pred_x0(latent, e_t, index)
    #e_t_next = get_model_output(x_prev)
    #e_t_prime = (e_t + e_t_next) / 2
    #x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
    return x_prev.realize()

  # start with random noise
  if args.seed is not None: Tensor._seed = args.seed
  latent = Tensor.randn(1,4,64,64)

  @TinyJit
  def run(model, *x): return model(*x).realize()

  # this is diffusion
  with Context(BEAM=getenv("LATEBEAM")):
    # this is diffusion
    for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
      GlobalCounters.reset()
      t.set_description("%3d %3d" % (index, timestep))
      with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
        latent = do_step(latent, Tensor([timestep]), Tensor([index]), Tensor([args.guidance]))
        tid = Tensor([index])
        latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance]))
      if args.timing: Device[Device.DEFAULT].synchronize()
  del do_step
  del run

  # upsample latent space to image with autoencoder
  x = model.first_stage_model.post_quant_conv(1/0.18215 * latent)
  x = model.first_stage_model.decoder(x)

  # make image correct size and scale
  x = (x + 1.0) / 2.0
  x = (x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255)
  if Device.DEFAULT != "WEBGPU": x = x.cast(dtypes.uint8)
  x = model.decode(latent)
  print(x.shape)

  # save image