AMD-SHARK-Studio/tank/pytorch/v_diffusion/cfg_sample.py
Last commit 0dcf387089 by Prashant Kumar (2022-07-12 20:38:19 -07:00): Add shark_importer for torch_models. (#183)
All the torch_models are imported to gs::shark_tank. Scripts have been updated.

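# Classifier-free-guidance sampling repro for crowsonkb's v-diffusion-pytorch:
# builds the cc12m_1_cfg model, embeds one text prompt with CLIP, and either
# samples images (repro) or traces the diffusion model to TorchScript (trace).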

import argparse
import os
from functools import partial
import clip
import torch
from torchvision import transforms
from tqdm import trange
try:
    from diffusion import get_model, sampling, utils
except ModuleNotFoundError:
    print(
        "You need to download v-diffusion source from https://github.com/crowsonkb/v-diffusion-pytorch"
    )
    raise
torch.manual_seed(0)
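# Split an optional ":weight" suffix off a prompt (re-joining the scheme colon
# for URL prompts); prompts without an explicit weight default to 3.0.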
def parse_prompt(prompt, default_weight=3.0):
    if prompt.startswith("http://") or prompt.startswith("https://"):
        vals = prompt.rsplit(":", 2)
        vals = [vals[0] + ":" + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(":", 1)
    vals = vals + ["", default_weight][len(vals) :]
    return vals[0], float(vals[1])
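# Fixed sampling settings (an argparse.Namespace standing in for command-line flags).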
args = argparse.Namespace(
    prompts=["New York City, oil on canvas"],
    batch_size=1,
    device="cuda",
    model="cc12m_1_cfg",
    n=1,
    steps=10,
)
device = torch.device(args.device)
print("Using device:", device)
model = get_model(args.model)()
_, side_y, side_x = model.shape
checkpoint = f"{args.model}.pth"
if os.path.exists(checkpoint):
    model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
model = model.to(device).eval().requires_grad_(False)
clip_model_name = (
    model.clip_model if hasattr(model, "clip_model") else "ViT-B/16"
)
clip_model = clip.load(clip_model_name, jit=False, device=device)[0]
clip_model.eval().requires_grad_(False)
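# Standard CLIP image-preprocessing statistics; defined here but not used by this script.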
normalize = transforms.Normalize(
    mean=[0.48145466, 0.4578275, 0.40821073],
    std=[0.26862954, 0.26130258, 0.27577711],
)
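# Conditioning targets: a zero embedding for the unconditional branch plus the
# CLIP text embedding of the prompt.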
zero_embed = torch.zeros([1, clip_model.visual.output_dim], device=device)
target_embeds, weights = [zero_embed], []
txt, weight = parse_prompt(args.prompts[0])
target_embeds.append(
    clip_model.encode_text(clip.tokenize(txt).to(device)).float()
)
weights.append(weight)
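# The unconditional branch gets weight 1 - sum(prompt weights), so all weights sum to 1.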
weights = torch.tensor([1 - sum(weights), *weights], device=device)
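# Classifier-free guidance wrapper: evaluate the model once per conditioning
# (stacked along the batch dimension) and combine the v predictions with the weights above.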
def cfg_model_fn(model, x, t):
    n = x.shape[0]
    n_conds = len(target_embeds)
    x_in = x.repeat([n_conds, 1, 1, 1])
    t_in = t.repeat([n_conds])
    clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
    vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
    v = vs.mul(weights[:, None, None, None, None]).sum(0)
    return v
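# Starting Gaussian noise and a descending timestep grid in (0, 1].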
x = torch.randn([args.n, 3, side_y, side_x], device=device)
t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]
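# Reference sampling loop: PLMS sampling with the spliced DDPM/cosine schedule,
# writing out_<index>.png files. On CUDA the model is cast to half precision first.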
def repro(model):
    if device.type == "cuda":
        model = model.half()
    steps = utils.get_spliced_ddpm_cosine_schedule(t)
    for i in trange(0, args.n, args.batch_size):
        cur_batch_size = min(args.n - i, args.batch_size)
        outs = sampling.plms_sample(
            partial(cfg_model_fn, model), x[i : i + cur_batch_size], steps, {}
        )
        for j, out in enumerate(outs):
            utils.to_pil_image(out).save(f"out_{i + j:05}.png")
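# TorchScript path: trace the diffusion model with example inputs and print its
# graph, then load the JIT-compiled CLIP model and print its graph as well.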
def trace(model, x, t):
    n = x.shape[0]
    n_conds = len(target_embeds)
    x_in = x.repeat([n_conds, 1, 1, 1])
    t_in = t.repeat([n_conds])
    clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
    ts_mod = torch.jit.trace(model, (x_in, t_in, clip_embed_in))
    print(ts_mod.graph)
    clip_model = clip.load(clip_model_name, jit=True, device=device)[0]
    print(clip_model.graph)
# Run only one of these per process: repro() casts the model to half precision in place.
# repro(model)
trace(model, x, t[0])