mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
@@ -19,6 +19,7 @@ from functools import partial
|
||||
from tqdm import tqdm
|
||||
from torchvision.utils import make_grid
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from omegaconf import ListConfig
|
||||
import urllib
|
||||
|
||||
from ldm.util import (
|
||||
@@ -120,7 +121,7 @@ class DDPM(pl.LightningModule):
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self.model)
|
||||
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||
print(f' | Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||
|
||||
self.use_scheduler = scheduler_config is not None
|
||||
if self.use_scheduler:
|
||||
@@ -1883,6 +1884,24 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def get_unconditional_conditioning(self, batch_size, null_label=None):
|
||||
if null_label is not None:
|
||||
xc = null_label
|
||||
if isinstance(xc, ListConfig):
|
||||
xc = list(xc)
|
||||
if isinstance(xc, dict) or isinstance(xc, list):
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
if hasattr(xc, "to"):
|
||||
xc = xc.to(self.device)
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
# todo: get null label from cond_stage_model
|
||||
raise NotImplementedError()
|
||||
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
|
||||
return c
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(
|
||||
self,
|
||||
@@ -2147,8 +2166,8 @@ class DiffusionWrapper(pl.LightningModule):
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(x, t, context=cc)
|
||||
elif self.conditioning_key == 'hybrid':
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
out = self.diffusion_model(xc, t, context=cc)
|
||||
elif self.conditioning_key == 'adm':
|
||||
cc = c_crossattn[0]
|
||||
@@ -2187,3 +2206,58 @@ class Layout2ImgDiffusion(LatentDiffusion):
|
||||
cond_img = torch.stack(bbox_imgs, dim=0)
|
||||
logs['bbox_image'] = cond_img
|
||||
return logs
|
||||
|
||||
class LatentInpaintDiffusion(LatentDiffusion):
|
||||
def __init__(
|
||||
self,
|
||||
concat_keys=("mask", "masked_image"),
|
||||
masked_image_key="masked_image",
|
||||
finetune_keys=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.masked_image_key = masked_image_key
|
||||
assert self.masked_image_key in concat_keys
|
||||
self.concat_keys = concat_keys
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_input(
|
||||
self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
|
||||
):
|
||||
# note: restricted to non-trainable encoders currently
|
||||
assert (
|
||||
not self.cond_stage_trainable
|
||||
), "trainable cond stages not yet supported for inpainting"
|
||||
z, c, x, xrec, xc = super().get_input(
|
||||
batch,
|
||||
self.first_stage_key,
|
||||
return_first_stage_outputs=True,
|
||||
force_c_encode=True,
|
||||
return_original_cond=True,
|
||||
bs=bs,
|
||||
)
|
||||
|
||||
assert exists(self.concat_keys)
|
||||
c_cat = list()
|
||||
for ck in self.concat_keys:
|
||||
cc = (
|
||||
rearrange(batch[ck], "b h w c -> b c h w")
|
||||
.to(memory_format=torch.contiguous_format)
|
||||
.float()
|
||||
)
|
||||
if bs is not None:
|
||||
cc = cc[:bs]
|
||||
cc = cc.to(self.device)
|
||||
bchw = z.shape
|
||||
if ck != self.masked_image_key:
|
||||
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
||||
else:
|
||||
cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
|
||||
c_cat.append(cc)
|
||||
c_cat = torch.cat(c_cat, dim=1)
|
||||
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
|
||||
if return_first_stage_outputs:
|
||||
return z, all_conds, x, xrec, xc
|
||||
return z, all_conds
|
||||
|
||||
Reference in New Issue
Block a user