From d0abe13b6060d07e7946df16a64686f9d359ac94 Mon Sep 17 00:00:00 2001
From: Kevin Turner <83819+keturn@users.noreply.github.com>
Date: Sun, 19 Feb 2023 16:04:54 -0800
Subject: [PATCH] performance(InvokeAIDiffuserComponent): add low-memory path
 for calculating conditioned and unconditioned predictions sequentially

Proof of concept. Still needs to be wired up to options or heuristics.
---
 .../diffusion/shared_invokeai_diffusion.py   | 20 ++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py
index d9163468d2..2e513d3f5a 100644
--- a/ldm/models/diffusion/shared_invokeai_diffusion.py
+++ b/ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -29,7 +29,7 @@ class InvokeAIDiffuserComponent:
     * Hybrid conditioning (used for inpainting)
     '''
     debug_thresholding = False
-
+    sequential_conditioning = False
 
     @dataclass
     class ExtraConditioningInfo:
@@ -149,9 +149,13 @@ class InvokeAIDiffuserComponent:
                 unconditioning,
                 conditioning,
                 cross_attention_control_types_to_do)
+        elif self.sequential_conditioning:
+            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning_sequentially(
+                x, sigma, unconditioning, conditioning)
+
         else:
-            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(x, sigma, unconditioning,
-                                                                                         conditioning)
+            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(
+                x, sigma, unconditioning, conditioning)
 
         combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
 
@@ -198,6 +202,16 @@ class InvokeAIDiffuserComponent:
 
         return unconditioned_next_x, conditioned_next_x
 
+    def _apply_standard_conditioning_sequentially(self, x: torch.Tensor, sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor):
+        # low-memory sequential path
+        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
+        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
+        if conditioned_next_x.device.type == 'mps':
+            # prevent a result filled with zeros. seems to be a torch bug.
+            conditioned_next_x = conditioned_next_x.clone()
+        return unconditioned_next_x, conditioned_next_x
+
+
     def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
         assert isinstance(conditioning, dict)
         assert isinstance(unconditioning, dict)