diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py
index d9163468d2..2e513d3f5a 100644
--- a/ldm/models/diffusion/shared_invokeai_diffusion.py
+++ b/ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -29,7 +29,7 @@ class InvokeAIDiffuserComponent:
     * Hybrid conditioning (used for inpainting)
     '''
     debug_thresholding = False
-
+    sequential_conditioning = False
 
     @dataclass
     class ExtraConditioningInfo:
@@ -149,9 +149,13 @@ class InvokeAIDiffuserComponent:
                                                                                                            unconditioning,
                                                                                                            conditioning,
                                                                                                            cross_attention_control_types_to_do)
+        elif self.sequential_conditioning:
+            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning_sequentially(
+                x, sigma, unconditioning, conditioning)
+
         else:
-            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(x, sigma, unconditioning,
-                                                                                         conditioning)
+            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(
+                x, sigma, unconditioning, conditioning)
 
         combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
 
@@ -198,6 +198,16 @@ class InvokeAIDiffuserComponent:
 
         return unconditioned_next_x, conditioned_next_x
 
+    def _apply_standard_conditioning_sequentially(self, x: torch.Tensor, sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor):
+        # low-memory sequential path
+        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
+        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
+        if conditioned_next_x.device.type == 'mps':
+            # prevent a result filled with zeros. seems to be a torch bug.
+            conditioned_next_x = conditioned_next_x.clone()
+        return unconditioned_next_x, conditioned_next_x
+
+
     def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
         assert isinstance(conditioning, dict)
         assert isinstance(unconditioning, dict)