mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-09 12:45:32 -05:00
Add symmetry to generation (#2675)
Added symmetry to Invoke based on discussions with @damian0815. This can currently only be activated via the CLI with the `--h_symmetry_time_pct` and `--v_symmetry_time_pct` options. Those take values from 0.0-1.0, exclusive, indicating the percentage through generation at which symmetry is applied as a one-time operation. To have symmetry in either axis applied after the first step, use a very low value like 0.001.
This commit is contained in:
@@ -18,6 +18,8 @@ from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
# Settings consumed by the latent post-processing steps (apply_threshold and
# apply_symmetry) that run when postprocessing_settings is not None.
class PostprocessingSettings:
|
||||
# Threshold value; apply_threshold returns the latents unchanged when this is
# None or 0.0.
threshold: float
|
||||
# Warmup value read together with `threshold` by apply_threshold.
warmup: float
|
||||
# Fraction of the way through generation at which horizontal symmetry is
# applied as a one-time operation. None disables it; apply_symmetry also
# ignores values outside the range (0.0, 1.0].
h_symmetry_time_pct: Optional[float]
|
||||
# Same as h_symmetry_time_pct, but for vertical symmetry.
v_symmetry_time_pct: Optional[float]
|
||||
|
||||
|
||||
class InvokeAIDiffuserComponent:
|
||||
@@ -30,7 +32,7 @@ class InvokeAIDiffuserComponent:
|
||||
* Hybrid conditioning (used for inpainting)
|
||||
'''
|
||||
debug_thresholding = False
|
||||
|
||||
last_percent_through = 0.0
|
||||
|
||||
@dataclass
|
||||
class ExtraConditioningInfo:
|
||||
@@ -56,6 +58,7 @@ class InvokeAIDiffuserComponent:
|
||||
self.is_running_diffusers = is_running_diffusers
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
self.last_percent_through = 0.0
|
||||
|
||||
@contextmanager
|
||||
def custom_attention_context(self,
|
||||
@@ -164,6 +167,7 @@ class InvokeAIDiffuserComponent:
|
||||
if postprocessing_settings is not None:
|
||||
percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
|
||||
latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
|
||||
latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
|
||||
return latents
|
||||
|
||||
def calculate_percent_through(self, sigma, step_index, total_step_count):
|
||||
@@ -292,8 +296,12 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
postprocessing_settings: PostprocessingSettings,
|
||||
latents: torch.Tensor,
|
||||
percent_through
|
||||
percent_through: float
|
||||
) -> torch.Tensor:
|
||||
|
||||
if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
|
||||
return latents
|
||||
|
||||
threshold = postprocessing_settings.threshold
|
||||
warmup = postprocessing_settings.warmup
|
||||
|
||||
@@ -342,6 +350,56 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
return latents
|
||||
|
||||
def apply_symmetry(
    self,
    postprocessing_settings: PostprocessingSettings,
    latents: torch.Tensor,
    percent_through: float
) -> torch.Tensor:
    """Mirror the latents once generation crosses a configured progress point.

    For each axis, symmetry is applied as a one-time operation on the step
    where ``percent_through`` first reaches the configured
    ``*_symmetry_time_pct``: the previous call must have been below the
    setting and the current call at or above it. The first half of the axis
    is kept and the second half is replaced by its mirror image.

    Args:
        postprocessing_settings: Provides ``h_symmetry_time_pct`` and
            ``v_symmetry_time_pct``. A value of ``None``, or one outside
            the range (0.0, 1.0], disables that axis. If the settings
            object itself is ``None`` the latents are returned untouched.
        latents: Tensor indexed with four dimensions; dimension 3 is
            treated as width and dimension 2 as height.
        percent_through: Progress through generation for the current step.
            A value of exactly 0.0 marks the first step and resets the
            tracked progress.

    Returns:
        The (possibly mirrored) latents, on the same device as the input.

    Side effects:
        Updates ``self.last_percent_through`` to ``percent_through``
        whenever settings are provided.
    """
    # Reset our last percent through if this is our first step.
    if percent_through == 0.0:
        self.last_percent_through = 0.0

    if postprocessing_settings is None:
        return latents

    previous_percent_through = self.last_percent_through

    def crossed(time_pct: Optional[float]) -> bool:
        # True only on the single step where progress moves from below the
        # setting to at/above it; out-of-range settings never trigger.
        return (
            time_pct is not None
            and 0.0 < time_pct <= 1.0
            and previous_percent_through < time_pct <= percent_through
        )

    # NOTE: the flips run on whatever device the latents already live on.
    # The previous code called `latents.to(device='cpu')` without using the
    # result (Tensor.to is not in-place), and restored the device from
    # `device.type`, which drops the device index on multi-GPU setups.

    if crossed(postprocessing_settings.h_symmetry_time_pct):
        # Horizontal symmetry occurs on the 3rd dimension of the latent
        half_width = latents.shape[3] // 2
        x_flipped = torch.flip(latents, dims=[3])
        latents = torch.cat(
            [latents[:, :, :, :half_width], x_flipped[:, :, :, half_width:]],
            dim=3,
        )

    if crossed(postprocessing_settings.v_symmetry_time_pct):
        # Vertical symmetry occurs on the 2nd dimension of the latent
        half_height = latents.shape[2] // 2
        y_flipped = torch.flip(latents, dims=[2])
        latents = torch.cat(
            [latents[:, :, :half_height], y_flipped[:, :, half_height:]],
            dim=2,
        )

    self.last_percent_through = percent_through
    return latents
|
||||
|
||||
def estimate_percent_through(self, step_index, sigma):
|
||||
if step_index is not None and self.cross_attention_control_context is not None:
|
||||
# percent_through will never reach 1.0 (but this is intended)
|
||||
|
||||
Reference in New Issue
Block a user