From c0f259173da49468426ca0385c43290c07913e14 Mon Sep 17 00:00:00 2001 From: JPPhoto Date: Sun, 26 Feb 2023 18:00:51 -0600 Subject: [PATCH] Cleaned up and refactored new symmetry. --- .../diffusion/shared_invokeai_diffusion.py | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 79e1006c63..e950bd2969 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -407,33 +407,37 @@ class InvokeAIDiffuserComponent: if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)): v_symmetry_time_pct = None + width = latents.shape[3] + height = latents.shape[2] dev = latents.device.type + dtype = latents.dtype symmetry_type = postprocessing_settings.symmetry_type or SymmetryType.FADE latents.to(device='cpu') + def make_ramp(ease_in:int, total:int) -> torch.Tensor: + ramp1 = torch.linspace(start=1.0, end=0.5, steps=ease_in, device=dev) + ramp2 = torch.linspace(start=0.5, end=1.0, steps=total - ease_in, device=dev) + ramp = torch.cat((ramp1, ramp2)) + return ramp + if ( h_symmetry_time_pct != None and self.last_percent_through < h_symmetry_time_pct and percent_through >= h_symmetry_time_pct ): + # Horizontal symmetry occurs on the 3rd dimension of the latent + x_flipped = torch.flip(latents, dims=[3]) if symmetry_type is SymmetryType.MIRROR: - # Horizontal symmetry occurs on the 3rd dimension of the latent - width = latents.shape[3] - x_flipped = torch.flip(latents, dims=[3]) + # Use the first half of latents and then the flipped one on this dimension latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3) elif symmetry_type is SymmetryType.FADE: - # Horizontal symmetry occurs on the 3rd dimension of the latent - width = latents.shape[3] - height = latents.shape[2] - dtype = latents.dtype - x_flipped = torch.flip(latents, dims=[3]) - apply_width = 2 * (width//4) - ramp1 = torch.linspace(start=1.0, end=0.5, steps=apply_width, device=latents.device) - ramp2 = torch.linspace(start=0.5, end=1.0, steps=width-(apply_width), device=latents.device) - ramp = torch.cat((ramp1,ramp2)) + apply_width = width // 2 + # Create a linear ramp so the middle gets perfect symmetry but the edges retain their original latents + ramp = make_ramp(ease_in=apply_width, total=width) fade1 = einops.repeat(tensor=ramp, pattern='m -> 1 4 k m', k=height).to(latents.device).type(dtype) fade0 = 1 - fade1 + # Multiply the crossover region to retain details and avoid a "muddy" appearance multiplier = (fade1 * fade0) * 1.25 + 1 latents = ((latents * fade1) + (x_flipped * fade0)) * multiplier @@ -442,23 +446,18 @@ class InvokeAIDiffuserComponent: self.last_percent_through < v_symmetry_time_pct and percent_through >= v_symmetry_time_pct ): + # Vertical symmetry occurs on the 3rd dimension of the latent + y_flipped = torch.flip(latents, dims=[2]) if symmetry_type is SymmetryType.MIRROR: # Vertical symmetry occurs on the 2nd dimension of the latent - height = latents.shape[2] - y_flipped = torch.flip(latents, dims=[2]) latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2) elif symmetry_type is SymmetryType.FADE: - # Vertical symmetry occurs on the 2nd dimension of the latent - width = latents.shape[3] - height = latents.shape[2] - dtype = latents.dtype - y_flipped = torch.flip(latents, dims=[2]) - apply_height = 2 * (height // 4) - ramp1 = torch.linspace(start=1.0, end=0.5, steps=apply_height, device=latents.device) - ramp2 = torch.linspace(start=0.5, end=1.0, steps=height - (apply_height), device=latents.device) - ramp = torch.cat((ramp1, ramp2)) + apply_height = height // 2 + # Create a linear ramp so the middle gets perfect symmetry but the edges retain their original latents + ramp = make_ramp(ease_in=apply_height, total=height) fade1 = einops.repeat(tensor=ramp, pattern='m -> 1 4 m k', k=width).to(latents.device).type(dtype) fade0 = 1 - fade1 + # Multiply the crossover region to retain details and avoid a "muddy" appearance multiplier = (fade1 * fade0) * 1.25 + 1 latents = ((latents * fade1) + (y_flipped * fade0)) * multiplier