From 54e6a68acbf314ddbcbda450bb0c5d39b39af4a3 Mon Sep 17 00:00:00 2001
From: Damian at mba
Date: Tue, 18 Oct 2022 22:09:06 +0200
Subject: [PATCH] wip bringing cross-attention to PLMS and DDIM

---
 ldm/invoke/generator/txt2img.py         |  4 +-
 ldm/models/diffusion/cross_attention.py | 52 ++++++++++++++++++++-
 ldm/models/diffusion/ddim.py            | 21 +++++++--
 ldm/models/diffusion/ksampler.py        | 62 +++++++------------------
 ldm/models/diffusion/plms.py            | 28 ++++++++---
 ldm/models/diffusion/sampler.py         |  8 ++--
 6 files changed, 112 insertions(+), 63 deletions(-)

diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py
index 9f066745f7..669f3d81ff 100644
--- a/ldm/invoke/generator/txt2img.py
+++ b/ldm/invoke/generator/txt2img.py
@@ -19,7 +19,7 @@ class Txt2Img(Generator):
         kwargs are 'width' and 'height'
         """
         self.perlin = perlin
-        uc, c, ec, edit_index_map = conditioning
+        uc, c, ec, edit_opcodes = conditioning
 
         @torch.no_grad()
         def make_image(x_T):
@@ -44,7 +44,7 @@ class Txt2Img(Generator):
                 unconditional_guidance_scale = cfg_scale,
                 unconditional_conditioning = uc,
                 edited_conditioning = ec,
-                edit_token_index_map = edit_index_map,
+                conditioning_edit_opcodes = edit_opcodes,
                 eta = ddim_eta,
                 img_callback = step_callback,
                 threshold = threshold,
diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/cross_attention.py
index 71d5995b4a..c0760fff47 100644
--- a/ldm/models/diffusion/cross_attention.py
+++ b/ldm/models/diffusion/cross_attention.py
@@ -2,6 +2,55 @@ from enum import Enum
 
 import torch
 
+
+class CrossAttentionControllableDiffusionMixin:
+
+    def setup_cross_attention_control_if_appropriate(self, model, edited_conditioning, edit_opcodes):
+        self.edited_conditioning = edited_conditioning
+
+        if edited_conditioning is not None:
+            # a cat sitting on a car
+            CrossAttentionControl.setup_attention_editing(model, edited_conditioning, edit_opcodes)
+        else:
+            # pass through the attention func but don't act on it
+            CrossAttentionControl.clear_attention_editing(model)
+
+    def cleanup_cross_attention_control(self, model):
+        CrossAttentionControl.clear_attention_editing(model)
+
+    def do_cross_attention_controllable_diffusion_step(self, x, sigma, unconditioning, conditioning, model, model_forward_callback):
+
+        CrossAttentionControl.clear_requests(model)
+
+        if self.edited_conditioning is None:
+            # faster batched path
+            x_twice = torch.cat([x]*2)
+            sigma_twice = torch.cat([sigma]*2)
+            both_conditionings = torch.cat([unconditioning, conditioning])
+            unconditioned_next_x, conditioned_next_x = model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
+        else:
+            # slower non-batched path (20% slower on mac MPS)
+            # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
+            # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
+            # This messes up their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
+            # (For the batched invocation the `wrangler` function gets an attention tensor with shape[0]=16,
+            # representing batched uncond + cond, but then when it comes to applying the saved attention, the
+            # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditioning.)
+            # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
+            unconditioned_next_x = model_forward_callback(x, sigma, unconditioning)
+
+            # process x using the original prompt, saving the attention maps
+            CrossAttentionControl.request_save_attention_maps(model)
+            _ = model_forward_callback(x, sigma, cond=conditioning)
+            CrossAttentionControl.clear_requests(model)
+
+            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
+            CrossAttentionControl.request_apply_saved_attention_maps(model)
+            conditioned_next_x = model_forward_callback(x, sigma, self.edited_conditioning)
+            CrossAttentionControl.clear_requests(model)
+
+        return unconditioned_next_x, conditioned_next_x
+
+
 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl
 
@@ -27,7 +76,8 @@ class CrossAttentionControl:
 
         # adapted from init_attention_edit
        device = substitute_conditioning.device
-        max_length = model.inner_model.cond_stage_model.max_length
+        # urgh. should this be hardcoded?
+        max_length = 77
         # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
         mask = torch.zeros(max_length)
         indices_target = torch.arange(max_length, dtype=torch.long)
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index f5dada8627..4980b03c42 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -5,13 +5,23 @@ import numpy as np
 from tqdm import tqdm
 from functools import partial
 from ldm.invoke.devices import choose_torch_device
+from ldm.models.diffusion.cross_attention import CrossAttentionControllableDiffusionMixin
 from ldm.models.diffusion.sampler import Sampler
 from ldm.modules.diffusionmodules.util import noise_like
 
-class DDIMSampler(Sampler):
+class DDIMSampler(Sampler, CrossAttentionControllableDiffusionMixin):
     def __init__(self, model, schedule='linear', device=None, **kwargs):
         super().__init__(model,schedule,model.num_timesteps,device)
 
+    def prepare_to_sample(self, t_enc, **kwargs):
+        super().prepare_to_sample(t_enc, **kwargs)
+
+        edited_conditioning = kwargs.get('edited_conditioning', None)
+        edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
+
+        self.setup_cross_attention_control_if_appropriate(self.model, edited_conditioning, edit_opcodes)
+
+
     # This is the central routine
     @torch.no_grad()
     def p_sample(
@@ -37,12 +47,13 @@ class DDIMSampler(Sampler):
             unconditional_conditioning is None
             or unconditional_guidance_scale == 1.0
         ):
+            # damian0815 does not think this code path is ever used
            e_t = self.model.apply_model(x, t, c)
         else:
-            x_in = torch.cat([x] * 2)
-            t_in = torch.cat([t] * 2)
-            c_in = torch.cat([unconditional_conditioning, c])
-            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+
+            e_t_uncond, e_t = self.do_cross_attention_controllable_diffusion_step(x, t, unconditional_conditioning, c, self.model,
+                                                                                  model_forward_callback=lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
+
             e_t = e_t_uncond + unconditional_guidance_scale * (
                 e_t - e_t_uncond
             )
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
index 7459e2e7cc..78d5978efe 100644
--- a/ldm/models/diffusion/ksampler.py
+++ b/ldm/models/diffusion/ksampler.py
@@ -13,7 +13,8 @@ from ldm.modules.diffusionmodules.util import (
     noise_like,
     extract_into_tensor,
 )
-from ldm.models.diffusion.cross_attention import CrossAttentionControl
+from ldm.models.diffusion.cross_attention import CrossAttentionControl, CrossAttentionControllableDiffusionMixin
+
 
 def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
     if threshold <= 0.0:
@@ -29,53 +30,26 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
     return torch.clamp(result, min=minval, max=maxval)
 
 
-class CFGDenoiser(nn.Module):
-    def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None, edit_opcodes = None):
+class CFGDenoiser(nn.Module, CrossAttentionControllableDiffusionMixin):
+    def __init__(self, model, threshold = 0, warmup = 0):
         super().__init__()
         self.inner_model = model
         self.threshold = threshold
         self.warmup_max = warmup
         self.warmup = max(warmup / 10, 1)
-        self.edited_conditioning = edited_conditioning
 
+    def prepare_to_sample(self, t_enc, **kwargs):
+
+        edited_conditioning = kwargs.get('edited_conditioning', None)
+        conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
+
+        self.setup_cross_attention_control_if_appropriate(self.inner_model, edited_conditioning, conditioning_edit_opcodes)
 
-        if edited_conditioning is not None:
-            # a cat sitting on a car
-            CrossAttentionControl.setup_attention_editing(self.inner_model, edited_conditioning, edit_opcodes)
-        else:
-            # pass through the attention func but don't act on it
-            CrossAttentionControl.clear_attention_editing(self.inner_model)
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
-        CrossAttentionControl.clear_requests(self.inner_model)
-
-        if self.edited_conditioning is None:
-            # faster batch path
-            x_twice = torch.cat([x]*2)
-            sigma_twice = torch.cat([sigma]*2)
-            both_conditionings = torch.cat([uncond, cond])
-            unconditioned_next_x, conditioned_next_x = self.inner_model(x_twice, sigma_twice, cond=both_conditionings).chunk(2)
-        else:
-            # slower non-batched path (20% slower on mac MPS)
-            # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
-            # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
-            # This messes app their application later, due to mismatched shape of dim 0 (16 vs. 8)
-            # (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
-            # representing batched uncond + cond, but then when it comes to applying the saved attention, the
-            # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
-            # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
-            unconditioned_next_x = self.inner_model(x, sigma, cond=uncond)
-
-            # process x using the original prompt, saving the attention maps
-            CrossAttentionControl.request_save_attention_maps(self.inner_model)
-            _ = self.inner_model(x, sigma, cond=cond)
-            CrossAttentionControl.clear_requests(self.inner_model)
-
-            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
-            CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
-            conditioned_next_x = self.inner_model(x, sigma, cond=self.edited_conditioning)
-            CrossAttentionControl.clear_requests(self.inner_model)
+        unconditioned_next_x, conditioned_next_x = self.do_cross_attention_controllable_diffusion_step(x, sigma, uncond, cond, self.inner_model,
+                                                                                                        model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
 
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
@@ -204,7 +178,7 @@ class KSampler(Sampler):
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
         edited_conditioning=None,
-        edit_token_index_map=None,
+        conditioning_edit_opcodes=None,
         threshold = 0,
         perlin = 0,
         # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@@ -236,21 +210,22 @@ class KSampler(Sampler):
         else:
             x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
 
-        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10),
-                                     edited_conditioning=edited_conditioning, edit_opcodes=edit_token_index_map)
+        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
+        model_wrap_cfg.prepare_to_sample(S, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes)
         extra_args = {
             'cond': conditioning,
             'uncond': unconditional_conditioning,
             'cond_scale': unconditional_guidance_scale,
         }
         print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
-        return (
+        sampling_result = (
             K.sampling.__dict__[f'sample_{self.schedule}'](
                 model_wrap_cfg, x, sigmas, extra_args=extra_args, callback=route_callback
             ),
             None,
         )
+        return sampling_result
 
     # this code will support inpainting if and when ksampler API modified or
     # a workaround is found.
@@ -312,7 +287,7 @@ class KSampler(Sampler):
         else:
             return x
 
-    def prepare_to_sample(self,t_enc):
+    def prepare_to_sample(self,t_enc,**kwargs):
         self.t_enc = t_enc
         self.model_wrap = None
         self.ds = None
@@ -323,4 +298,3 @@ class KSampler(Sampler):
         Overrides parent method to return the q_sample of the inner model.
         '''
         return self.model.inner_model.q_sample(x0,ts)
-
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 9e722eb932..eb778813a0 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -5,14 +5,24 @@ import numpy as np
 from tqdm import tqdm
 from functools import partial
 from ldm.invoke.devices import choose_torch_device
+from ldm.models.diffusion.cross_attention import CrossAttentionControllableDiffusionMixin
 from ldm.models.diffusion.sampler import Sampler
 from ldm.modules.diffusionmodules.util import noise_like
 
-class PLMSSampler(Sampler):
+class PLMSSampler(Sampler, CrossAttentionControllableDiffusionMixin):
     def __init__(self, model, schedule='linear', device=None, **kwargs):
         super().__init__(model,schedule,model.num_timesteps, device)
 
+    def prepare_to_sample(self, t_enc, **kwargs):
+        super().prepare_to_sample(t_enc, **kwargs)
+
+        edited_conditioning = kwargs.get('edited_conditioning', None)
+        edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
+
+        self.setup_cross_attention_control_if_appropriate(self.model, edited_conditioning, edit_opcodes)
+
+
     # this is the essential routine
     @torch.no_grad()
     def p_sample(
@@ -41,14 +51,18 @@ class PLMSSampler(Sampler):
             unconditional_conditioning is None
             or unconditional_guidance_scale == 1.0
         ):
+            # damian0815 does not think this code path is ever used
            e_t = self.model.apply_model(x, t, c)
         else:
-            x_in = torch.cat([x] * 2)
-            t_in = torch.cat([t] * 2)
-            c_in = torch.cat([unconditional_conditioning, c])
-            e_t_uncond, e_t = self.model.apply_model(
-                x_in, t_in, c_in
-            ).chunk(2)
+            #x_in = torch.cat([x] * 2)
+            #t_in = torch.cat([t] * 2)
+            #c_in = torch.cat([unconditional_conditioning, c])
+            #e_t_uncond, e_t = self.model.apply_model(
+            #    x_in, t_in, c_in
+            #).chunk(2)
+            e_t_uncond, e_t = self.do_cross_attention_controllable_diffusion_step(x, t, unconditional_conditioning, c, self.model,
+                                                                                  model_forward_callback=lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
+
             e_t = e_t_uncond + unconditional_guidance_scale * (
                 e_t - e_t_uncond
             )
diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py
index eb7caebba0..b8377ebb39 100644
--- a/ldm/models/diffusion/sampler.py
+++ b/ldm/models/diffusion/sampler.py
@@ -192,6 +192,7 @@ class Sampler(object):
             unconditional_guidance_scale=unconditional_guidance_scale,
             unconditional_conditioning=unconditional_conditioning,
             steps=S,
+            **kwargs
         )
         return samples, intermediates
 
@@ -216,6 +217,7 @@ class Sampler(object):
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
         steps=None,
+        **kwargs
     ):
         b = shape[0]
         time_range = (
@@ -233,7 +235,7 @@ class Sampler(object):
             dynamic_ncols=True,
         )
         old_eps = []
-        self.prepare_to_sample(t_enc=total_steps)
+        self.prepare_to_sample(t_enc=total_steps,**kwargs)
         img = self.get_initial_image(x_T,shape,total_steps)
 
         # probably don't need this at all
@@ -323,7 +325,7 @@ class Sampler(object):
         iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
         x_dec = x_latent
         x0 = init_latent
-        self.prepare_to_sample(t_enc=total_steps)
+        self.prepare_to_sample(t_enc=total_steps,**kwargs)
 
         for i, step in enumerate(iterator):
             index = total_steps - i - 1
@@ -414,5 +416,3 @@ class Sampler(object):
         '''
         return self.model.q_sample(x0,ts)
-
-
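
Note: the wiring pattern is the same for DDIM, PLMS and the k-sampler's CFGDenoiser. A minimal sketch of what a sampler needs to do to opt in, using only names that appear in this patch; `MySampler` and its `cfg_step` helper are illustrative and not part of the patch:

class MySampler(Sampler, CrossAttentionControllableDiffusionMixin):

    def prepare_to_sample(self, t_enc, **kwargs):
        super().prepare_to_sample(t_enc, **kwargs)
        # pull the edited conditioning and its opcodes out of the kwargs that
        # Sampler.sample() now forwards through to prepare_to_sample()
        edited_conditioning = kwargs.get('edited_conditioning', None)
        edit_opcodes = kwargs.get('conditioning_edit_opcodes', None)
        self.setup_cross_attention_control_if_appropriate(self.model, edited_conditioning, edit_opcodes)

    def cfg_step(self, x, t, c, uc, guidance_scale):
        # the mixin picks either the fast batched path or the slower
        # save-then-apply attention-map path and returns both noise predictions
        e_t_uncond, e_t = self.do_cross_attention_controllable_diffusion_step(
            x, t, uc, c, self.model,
            model_forward_callback=lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
        return e_t_uncond + guidance_scale * (e_t - e_t_uncond)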
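
The patch does not show where `edit_opcodes` come from; txt2img.py only unpacks `(uc, c, ec, edit_opcodes)` from `conditioning`. The name suggests SequenceMatcher-style opcodes over the original and edited token sequences. The sketch below is an assumption about that caller, not code from this patch: `get_learned_conditioning` is the usual LDM conditioning call, and the tokenizer access path is illustrative only.

from difflib import SequenceMatcher

def build_conditioning(model, prompt, edited_prompt, negative_prompt=''):
    # assumption: conditioning is the 4-tuple that txt2img.py unpacks as
    #   uc, c, ec, edit_opcodes = conditioning
    uc = model.get_learned_conditioning([negative_prompt])
    c  = model.get_learned_conditioning([prompt])
    ec = model.get_learned_conditioning([edited_prompt])

    # assumption: opcodes follow difflib.SequenceMatcher.get_opcodes() form,
    # i.e. ('equal'|'replace'|'insert'|'delete', a0, a1, b0, b1) spans mapping
    # original-prompt token positions to edited-prompt token positions
    tokenizer = model.cond_stage_model.tokenizer   # illustrative access path
    original_tokens = tokenizer(prompt)['input_ids']
    edited_tokens = tokenizer(edited_prompt)['input_ids']
    edit_opcodes = SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()

    return (uc, c, ec, edit_opcodes)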
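
For context on the `mask`/`indices_target` lines in setup_attention_editing: with opcodes of that shape, the rest of the routine presumably walks the opcodes and marks the token positions that are unchanged between the two prompts, so the wrangler knows where to keep the saved base-prompt attention. A hedged sketch of that loop, which is not shown in the hunk above:

# hedged sketch: mask=1 -> keep base prompt attention, mask=0 -> use edited prompt attention
indices = torch.zeros(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in edit_opcodes:
    if name == 'equal':
        # tokens b0..b1 of the edited prompt match tokens a0..a1 of the
        # original prompt: reuse the saved attention at those positions
        indices[b0:b1] = indices_target[a0:a1]
        mask[b0:b1] = 1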