From bb6772498a2e574623a1a02486bc29802820b2be Mon Sep 17 00:00:00 2001 From: JPPhoto Date: Sat, 25 Feb 2023 09:05:49 -0600 Subject: [PATCH] Add symmetry types and a new symmetry implementation. --- .../diffusion/shared_invokeai_diffusion.py | 54 ++++++++++++++++--- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index cddddd3e86..79e1006c63 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -1,3 +1,4 @@ +from enum import Enum from contextlib import contextmanager from dataclasses import dataclass from math import ceil @@ -6,6 +7,7 @@ from typing import Callable, Optional, Union, Any, Dict import numpy as np import torch from diffusers.models.cross_attention import AttnProcessor +from einops import einops from typing_extensions import TypeAlias from ldm.invoke.globals import Globals @@ -20,12 +22,17 @@ ModelForwardCallback: TypeAlias = Union[ Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] ] + +SymmetryType = Enum('SymmetryType', ['MIRROR', 'FADE']) + + @dataclass(frozen=True) class PostprocessingSettings: threshold: float warmup: float h_symmetry_time_pct: Optional[float] v_symmetry_time_pct: Optional[float] + symmetry_type: Optional[SymmetryType] = None class InvokeAIDiffuserComponent: @@ -401,6 +408,7 @@ class InvokeAIDiffuserComponent: v_symmetry_time_pct = None dev = latents.device.type + symmetry_type = postprocessing_settings.symmetry_type or SymmetryType.FADE latents.to(device='cpu') @@ -409,20 +417,50 @@ class InvokeAIDiffuserComponent: self.last_percent_through < h_symmetry_time_pct and percent_through >= h_symmetry_time_pct ): - # Horizontal symmetry occurs on the 3rd dimension of the latent - width = latents.shape[3] - x_flipped = torch.flip(latents, dims=[3]) - latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3) + if symmetry_type is SymmetryType.MIRROR: + # Horizontal symmetry occurs on the 3rd dimension of the latent + width = latents.shape[3] + x_flipped = torch.flip(latents, dims=[3]) + latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3) + elif symmetry_type is SymmetryType.FADE: + # Horizontal symmetry occurs on the 3rd dimension of the latent + width = latents.shape[3] + height = latents.shape[2] + dtype = latents.dtype + x_flipped = torch.flip(latents, dims=[3]) + apply_width = 2 * (width//4) + ramp1 = torch.linspace(start=1.0, end=0.5, steps=apply_width, device=latents.device) + ramp2 = torch.linspace(start=0.5, end=1.0, steps=width-(apply_width), device=latents.device) + ramp = torch.cat((ramp1,ramp2)) + fade1 = einops.repeat(tensor=ramp, pattern='m -> 1 4 k m', k=height).to(latents.device).type(dtype) + fade0 = 1 - fade1 + multiplier = (fade1 * fade0) * 1.25 + 1 + latents = ((latents * fade1) + (x_flipped * fade0)) * multiplier if ( v_symmetry_time_pct != None and self.last_percent_through < v_symmetry_time_pct and percent_through >= v_symmetry_time_pct ): - # Vertical symmetry occurs on the 2nd dimension of the latent - height = latents.shape[2] - y_flipped = torch.flip(latents, dims=[2]) - latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2) + if symmetry_type is SymmetryType.MIRROR: + # Vertical symmetry occurs on the 2nd dimension of the latent + height = latents.shape[2] + y_flipped = torch.flip(latents, dims=[2]) + latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2) + elif symmetry_type is SymmetryType.FADE: + # Vertical symmetry occurs on the 2nd dimension of the latent + width = latents.shape[3] + height = latents.shape[2] + dtype = latents.dtype + y_flipped = torch.flip(latents, dims=[2]) + apply_height = 2 * (height // 4) + ramp1 = torch.linspace(start=1.0, end=0.5, steps=apply_height, device=latents.device) + ramp2 = torch.linspace(start=0.5, end=1.0, steps=height - (apply_height), device=latents.device) + ramp = torch.cat((ramp1, ramp2)) + fade1 = einops.repeat(tensor=ramp, pattern='m -> 1 4 m k', k=width).to(latents.device).type(dtype) + fade0 = 1 - fade1 + multiplier = (fade1 * fade0) * 1.25 + 1 + latents = ((latents * fade1) + (y_flipped * fade0)) * multiplier self.last_percent_through = percent_through return latents.to(device=dev)