mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Demo for currently not MLIR lowerable v-diffusion model (including both demo path and traceable path). (#86)
This commit is contained in:
111
tank/pytorch/v_diffusion/cfg_sample.py
Normal file
111
tank/pytorch/v_diffusion/cfg_sample.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import argparse
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
import clip
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from tqdm import trange
|
||||
|
||||
# v-diffusion is distributed as a source checkout, not a pip package; fail
# loudly with a pointer to the repository when it is missing from the path.
try:
    from diffusion import get_model, sampling, utils
except ModuleNotFoundError:
    print(
        "You need to download v-diffusion source from https://github.com/crowsonkb/v-diffusion-pytorch"
    )
    raise


# Fixed seed so the demo output is reproducible from run to run.
torch.manual_seed(0)
|
||||
|
||||
|
||||
def parse_prompt(prompt, default_weight=3.0):
    """Split a ``"text:weight"`` prompt into ``(text, weight)``.

    For ``http(s)://`` prompts the scheme's own colon is preserved: the split
    looks at the last two colons and re-joins the scheme onto the address.
    When no trailing weight is given, *default_weight* is used.

    Returns a tuple of the prompt text and the weight as a ``float``.
    """
    if prompt.startswith(("http://", "https://")):
        # URLs always contain a colon after the scheme, so split on the last
        # two colons and glue the scheme back onto the address part.
        pieces = prompt.rsplit(":", 2)
        pieces = [pieces[0] + ":" + pieces[1], *pieces[2:]]
    else:
        pieces = prompt.rsplit(":", 1)
    if len(pieces) < 2:
        # No explicit weight was supplied.
        pieces.append(default_weight)
    return pieces[0], float(pieces[1])
|
||||
|
||||
|
||||
# Demo configuration, hard-coded in place of command-line parsing.
args = argparse.Namespace(
    prompts=["New York City, oil on canvas"],
    batch_size=1,
    device="cuda",
    model="cc12m_1_cfg",
    n=1,  # total number of images to sample
    steps=10,  # number of diffusion sampling steps
)

device = torch.device(args.device)
print("Using device:", device)

# Build the v-diffusion model; load its checkpoint only if it is present
# next to the working directory (otherwise random init weights are used).
model = get_model(args.model)()
_, side_y, side_x = model.shape
checkpoint = f"{args.model}.pth"
if os.path.exists(checkpoint):
    model.load_state_dict(torch.load(checkpoint, map_location="cpu"))

model = model.to(device).eval().requires_grad_(False)

# CLIP text encoder used to embed the prompt; the diffusion model may pin a
# specific CLIP variant via its `clip_model` attribute.
clip_model_name = model.clip_model if hasattr(model, "clip_model") else "ViT-B/16"
clip_model = clip.load(clip_model_name, jit=False, device=device)[0]
clip_model.eval().requires_grad_(False)
# Standard CLIP image normalization constants; defined here but not
# referenced elsewhere in this script.
normalize = transforms.Normalize(
    mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
)

# Conditioning set for classifier-free guidance: the zero embedding
# (unconditional) plus one CLIP text embedding per prompt.
zero_embed = torch.zeros([1, clip_model.visual.output_dim], device=device)
target_embeds, weights = [zero_embed], []

txt, weight = parse_prompt(args.prompts[0])
target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
weights.append(weight)

# The unconditional weight is chosen so that all guidance weights sum to 1.
weights = torch.tensor([1 - sum(weights), *weights], device=device)
|
||||
|
||||
|
||||
def cfg_model_fn(model, x, t):
    """Guidance wrapper: combine one model pass per conditioning embedding.

    Tiles ``x``/``t`` across the module-level ``target_embeds`` (the zero
    embedding plus the prompt embedding), runs the model once on the tiled
    batch, and returns the weighted sum of the per-conditioning outputs
    using the module-level ``weights`` tensor.
    """
    batch = x.shape[0]
    num_conds = len(target_embeds)
    # Repeat the inputs so every conditioning runs in a single forward pass.
    tiled_x = x.repeat([num_conds, 1, 1, 1])
    tiled_t = t.repeat([num_conds])
    embeds = torch.cat(list(target_embeds)).repeat_interleave(batch, 0)
    per_cond = model(tiled_x, tiled_t, embeds).view([num_conds, batch, *x.shape[1:]])
    # Weighted combination over the leading conditioning axis.
    return per_cond.mul(weights[:, None, None, None, None]).sum(0)
|
||||
|
||||
|
||||
# Starting noise for all samples and the descending time schedule
# (steps + 1 points on [1, 0], with the final t = 0 point dropped).
x = torch.randn([args.n, 3, side_y, side_x], device=device)
t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]
|
||||
|
||||
|
||||
def repro(model):
    """Demo path: sample ``args.n`` images with PLMS and write them to disk.

    Relies on the module-level ``args``, ``device``, ``x`` (starting noise),
    and ``t`` (time schedule).  Each result is saved as ``out_NNNNN.png``.
    """
    if device.type == "cuda":
        # The demo runs in half precision on GPU; per the note at the end of
        # the original script, this is why repro and trace can't both run.
        model = model.half()

    schedule = utils.get_spliced_ddpm_cosine_schedule(t)
    for start in trange(0, args.n, args.batch_size):
        count = min(args.n - start, args.batch_size)
        samples = sampling.plms_sample(
            partial(cfg_model_fn, model), x[start : start + count], schedule, {}
        )
        for offset, image in enumerate(samples):
            utils.to_pil_image(image).save(f"out_{start + offset:05}.png")
|
||||
|
||||
|
||||
def trace(model, x, t):
    """Traceable path: print TorchScript graphs instead of sampling.

    Builds the same tiled inputs as ``cfg_model_fn``, traces the diffusion
    model with them, and prints its graph; then loads CLIP with ``jit=True``
    and prints that graph as well.
    """
    batch = x.shape[0]
    num_conds = len(target_embeds)
    tiled_x = x.repeat([num_conds, 1, 1, 1])
    tiled_t = t.repeat([num_conds])
    embeds = torch.cat(list(target_embeds)).repeat_interleave(batch, 0)
    traced = torch.jit.trace(model, (tiled_x, tiled_t, embeds))
    print(traced.graph)

    # Reload CLIP in scripted form so its graph can be inspected too.
    clip_model = clip.load(clip_model_name, jit=True, device=device)[0]
    print(clip_model.graph)
|
||||
|
||||
|
||||
# Exactly one of the two paths is run here: `repro(model)` converts the
# model with `.half()`, after which tracing would see a different model.
# You can't run both of these because repro will `.half()` the model
# repro(model)
trace(model, x, t[0])
|
||||
BIN
tank/pytorch/v_diffusion/out_00000.png
Normal file
BIN
tank/pytorch/v_diffusion/out_00000.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 145 KiB |
Reference in New Issue
Block a user