diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index ca3e608fc0..d9163468d2 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -1,4 +1,3 @@ -import math from contextlib import contextmanager from dataclasses import dataclass from math import ceil @@ -6,8 +5,8 @@ from typing import Callable, Optional, Union, Any, Dict import numpy as np import torch - from diffusers.models.cross_attention import AttnProcessor + from ldm.models.diffusion.cross_attention_control import Arguments, \ restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \ CrossAttentionType, SwapCrossAttnContext @@ -143,11 +142,16 @@ class InvokeAIDiffuserComponent: wants_hybrid_conditioning = isinstance(conditioning, dict) if wants_hybrid_conditioning: - unconditioned_next_x, conditioned_next_x = self.apply_hybrid_conditioning(x, sigma, unconditioning, conditioning) + unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(x, sigma, unconditioning, + conditioning) elif wants_cross_attention_control: - unconditioned_next_x, conditioned_next_x = self.apply_cross_attention_controlled_conditioning(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do) + unconditioned_next_x, conditioned_next_x = self._apply_cross_attention_controlled_conditioning(x, sigma, + unconditioning, + conditioning, + cross_attention_control_types_to_do) else: - unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning) + unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(x, sigma, unconditioning, + conditioning) combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale) @@ -181,7 +185,7 @@ class InvokeAIDiffuserComponent: # methods below are called from do_diffusion_step and should be considered private to this class. - def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning): + def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning): # fast batched path x_twice = torch.cat([x] * 2) sigma_twice = torch.cat([sigma] * 2) @@ -194,7 +198,7 @@ class InvokeAIDiffuserComponent: return unconditioned_next_x, conditioned_next_x - def apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning): + def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning): assert isinstance(conditioning, dict) assert isinstance(unconditioning, dict) x_twice = torch.cat([x] * 2) @@ -212,18 +216,21 @@ class InvokeAIDiffuserComponent: return unconditioned_next_x, conditioned_next_x - def apply_cross_attention_controlled_conditioning(self, + def _apply_cross_attention_controlled_conditioning(self, x: torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do): if self.is_running_diffusers: - return self.apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do) + return self._apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, + conditioning, + cross_attention_control_types_to_do) else: - return self.apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do) + return self._apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, + cross_attention_control_types_to_do) - def apply_cross_attention_controlled_conditioning__diffusers(self, + def _apply_cross_attention_controlled_conditioning__diffusers(self, x: torch.Tensor, sigma, unconditioning, @@ -246,7 +253,7 @@ class InvokeAIDiffuserComponent: return unconditioned_next_x, conditioned_next_x - def apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do): + def _apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do): # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) # slower non-batched path (20% slower on mac MPS) # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of