Apply black

This commit is contained in:
Martin Kristiansen
2023-07-27 10:54:01 -04:00
parent 2183dba5c5
commit 218b6d0546
148 changed files with 5486 additions and 6296 deletions

View File

@@ -2,6 +2,7 @@ from torchvision.datasets.utils import download_url
from ldm.util import instantiate_from_config
import torch
import os
# todo ?
from google.colab import files
from IPython.display import Image as ipyimg
@@ -16,21 +17,21 @@ import time
from omegaconf import OmegaConf
from ldm.invoke.devices import choose_torch_device
def download_models(mode):
def download_models(mode):
if mode == "superresolution":
# this is the small bsr light model
url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'
url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'
url_conf = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
url_ckpt = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
path_conf = 'logs/diffusion/superresolution_bsr/configs/project.yaml'
path_ckpt = 'logs/diffusion/superresolution_bsr/checkpoints/last.ckpt'
path_conf = "logs/diffusion/superresolution_bsr/configs/project.yaml"
path_ckpt = "logs/diffusion/superresolution_bsr/checkpoints/last.ckpt"
download_url(url_conf, path_conf)
download_url(url_ckpt, path_ckpt)
path_conf = path_conf + '/?dl=1' # fix it
path_ckpt = path_ckpt + '/?dl=1' # fix it
path_conf = path_conf + "/?dl=1" # fix it
path_ckpt = path_ckpt + "/?dl=1" # fix it
return path_conf, path_ckpt
else:
@@ -62,20 +63,20 @@ def get_custom_cond(mode):
if mode == "superresolution":
uploaded_img = files.upload()
filename = next(iter(uploaded_img))
name, filetype = filename.split(".") # todo assumes just one dot in name !
name, filetype = filename.split(".") # todo assumes just one dot in name !
os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}.{filetype}")
elif mode == "text_conditional":
w = widgets.Text(value='A cake with cream!', disabled=True)
w = widgets.Text(value="A cake with cream!", disabled=True)
display(w)
with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", 'w') as f:
with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", "w") as f:
f.write(w.value)
elif mode == "class_conditional":
w = widgets.IntSlider(min=0, max=1000)
display(w)
with open(f"{dest}/{mode}/custom.txt", 'w') as f:
with open(f"{dest}/{mode}/custom.txt", "w") as f:
f.write(w.value)
else:
@@ -94,11 +95,7 @@ def select_cond_path(mode):
path = os.path.join(path, mode)
onlyfiles = [f for f in sorted(os.listdir(path))]
selected = widgets.RadioButtons(
options=onlyfiles,
description='Select conditioning:',
disabled=False
)
selected = widgets.RadioButtons(options=onlyfiles, description="Select conditioning:", disabled=False)
display(selected)
selected_path = os.path.join(path, selected.value)
return selected_path
@@ -113,9 +110,9 @@ def get_cond(mode, selected_path):
c = Image.open(selected_path)
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
c = rearrange(c, '1 c h w -> 1 h w c')
c = 2. * c - 1.
c_up = rearrange(c_up, "1 c h w -> 1 h w c")
c = rearrange(c, "1 c h w -> 1 h w c")
c = 2.0 * c - 1.0
device = choose_torch_device()
c = c.to(device)
@@ -130,7 +127,6 @@ def visualize_cond_img(path):
def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):
example = get_cond(task, selected_path)
save_intermediate_vid = False
@@ -138,10 +134,10 @@ def run(model, selected_path, task, custom_steps, resize_enabled=False, classifi
masked = False
guider = None
ckwargs = None
mode = 'ddim'
mode = "ddim"
ddim_use_x0_pred = False
temperature = 1.
eta = 1.
temperature = 1.0
eta = 1.0
make_progrow = True
custom_shape = None
@@ -152,14 +148,17 @@ def run(model, selected_path, task, custom_steps, resize_enabled=False, classifi
ks = 128
stride = 64
vqf = 4 #
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
"vqf": vqf,
"patch_distributed_vq": True,
"tie_braker": False,
"clip_max_weight": 0.5,
"clip_min_weight": 0.01,
"clip_max_tie_weight": 0.5,
"clip_min_tie_weight": 0.01}
model.split_input_params = {
"ks": (ks, ks),
"stride": (stride, stride),
"vqf": vqf,
"patch_distributed_vq": True,
"tie_braker": False,
"clip_max_weight": 0.5,
"clip_min_weight": 0.01,
"clip_max_tie_weight": 0.5,
"clip_min_tie_weight": 0.01,
}
else:
if hasattr(model, "split_input_params"):
delattr(model, "split_input_params")
@@ -170,53 +169,112 @@ def run(model, selected_path, task, custom_steps, resize_enabled=False, classifi
for n in range(n_runs):
if custom_shape is not None:
x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])
x_T = repeat(x_T, "1 c h w -> b c h w", b=custom_shape[0])
logs = make_convolutional_sample(example, model,
mode=mode, custom_steps=custom_steps,
eta=eta, swap_mode=False , masked=masked,
invert_mask=invert_mask, quantize_x0=False,
custom_schedule=None, decode_interval=10,
resize_enabled=resize_enabled, custom_shape=custom_shape,
temperature=temperature, noise_dropout=0.,
corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid,
make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred
)
logs = make_convolutional_sample(
example,
model,
mode=mode,
custom_steps=custom_steps,
eta=eta,
swap_mode=False,
masked=masked,
invert_mask=invert_mask,
quantize_x0=False,
custom_schedule=None,
decode_interval=10,
resize_enabled=resize_enabled,
custom_shape=custom_shape,
temperature=temperature,
noise_dropout=0.0,
corrector=guider,
corrector_kwargs=ckwargs,
x_T=x_T,
save_intermediate_vid=save_intermediate_vid,
make_progrow=make_progrow,
ddim_use_x0_pred=ddim_use_x0_pred,
)
return logs
@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
mask=None, x0=None, quantize_x0=False, img_callback=None,
temperature=1., noise_dropout=0., score_corrector=None,
corrector_kwargs=None, x_T=None, log_every_t=None
):
def convsample_ddim(
model,
cond,
steps,
shape,
eta=1.0,
callback=None,
normals_sequence=None,
mask=None,
x0=None,
quantize_x0=False,
img_callback=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
x_T=None,
log_every_t=None,
):
ddim = DDIMSampler(model)
bs = shape[0] # dont know where this comes from but wayne
shape = shape[1:] # cut batch dim
print(f"Sampling with eta = {eta}; steps: {steps}")
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
mask=mask, x0=x0, temperature=temperature, verbose=False,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs, x_T=x_T)
samples, intermediates = ddim.sample(
steps,
batch_size=bs,
shape=shape,
conditioning=cond,
callback=callback,
normals_sequence=normals_sequence,
quantize_x0=quantize_x0,
eta=eta,
mask=mask,
x0=x0,
temperature=temperature,
verbose=False,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
)
return samples, intermediates
@torch.no_grad()
def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0, swap_mode=False, masked=False,
invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,
resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False):
def make_convolutional_sample(
batch,
model,
mode="vanilla",
custom_steps=None,
eta=1.0,
swap_mode=False,
masked=False,
invert_mask=True,
quantize_x0=False,
custom_schedule=None,
decode_interval=1000,
resize_enabled=False,
custom_shape=None,
temperature=1.0,
noise_dropout=0.0,
corrector=None,
corrector_kwargs=None,
x_T=None,
save_intermediate_vid=False,
make_progrow=True,
ddim_use_x0_pred=False,
):
log = dict()
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=not (hasattr(model, 'split_input_params')
and model.cond_stage_key == 'coordinates_bbox'),
return_original_cond=True)
z, c, x, xrec, xc = model.get_input(
batch,
model.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=not (hasattr(model, "split_input_params") and model.cond_stage_key == "coordinates_bbox"),
return_original_cond=True,
)
log_every_t = 1 if save_intermediate_vid else None
@@ -231,30 +289,41 @@ def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, e
if ismap(xc):
log["original_conditioning"] = model.to_rgb(xc)
if hasattr(model, 'cond_stage_key'):
if hasattr(model, "cond_stage_key"):
log[model.cond_stage_key] = model.to_rgb(xc)
else:
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_model:
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_key =='class_label':
if model.cond_stage_key == "class_label":
log[model.cond_stage_key] = xc[model.cond_stage_key]
with model.ema_scope("Plotting"):
t0 = time.time()
img_cb = None
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
eta=eta,
quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0,
temperature=temperature, noise_dropout=noise_dropout,
score_corrector=corrector, corrector_kwargs=corrector_kwargs,
x_T=x_T, log_every_t=log_every_t)
sample, intermediates = convsample_ddim(
model,
c,
steps=custom_steps,
shape=z.shape,
eta=eta,
quantize_x0=quantize_x0,
img_callback=img_cb,
mask=None,
x0=z0,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
)
t1 = time.time()
if ddim_use_x0_pred:
sample = intermediates['pred_x0'][-1]
sample = intermediates["pred_x0"][-1]
x_sample = model.decode_first_stage(sample)