mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-22 10:37:55 -05:00
Add symmetry types and a new symmetry implementation.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from enum import Enum
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
@@ -6,6 +7,7 @@ from typing import Callable, Optional, Union, Any, Dict
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.models.cross_attention import AttnProcessor
|
||||
from einops import einops
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from ldm.invoke.globals import Globals
|
||||
@@ -20,12 +22,17 @@ ModelForwardCallback: TypeAlias = Union[
|
||||
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
]
|
||||
|
||||
|
||||
class SymmetryType(Enum):
    """Strategy used when symmetrizing a latent.

    MIRROR splices one half of the latent with the matching half of its
    flipped copy; FADE blends the latent with its flipped copy using ramped
    weights.
    """

    # Explicit values match the 1-based auto-numbering of the functional Enum API.
    MIRROR = 1
    FADE = 2
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PostprocessingSettings:
    """Immutable bundle of latent post-processing options.

    Instances are frozen, so a settings object cannot be mutated after it is
    created. Field order is part of the positional constructor signature and
    must not be changed.
    """

    # NOTE(review): the meaning/units of `threshold` and `warmup` are not
    # visible in this part of the file — confirm against the code that reads them.
    threshold: float
    warmup: float
    # Presumably the fraction of the denoising run at which horizontal /
    # vertical symmetry is applied (compared against `percent_through`);
    # None presumably disables that axis — verify against the caller.
    h_symmetry_time_pct: Optional[float]
    v_symmetry_time_pct: Optional[float]
    # Which symmetry implementation to use. When left as None the consumer
    # falls back to SymmetryType.FADE.
    symmetry_type: Optional[SymmetryType] = None
|
||||
|
||||
|
||||
class InvokeAIDiffuserComponent:
|
||||
@@ -401,6 +408,7 @@ class InvokeAIDiffuserComponent:
|
||||
v_symmetry_time_pct = None
|
||||
|
||||
dev = latents.device.type
|
||||
symmetry_type = postprocessing_settings.symmetry_type or SymmetryType.FADE
|
||||
|
||||
latents.to(device='cpu')
|
||||
|
||||
@@ -409,20 +417,50 @@ class InvokeAIDiffuserComponent:
|
||||
self.last_percent_through < h_symmetry_time_pct and
|
||||
percent_through >= h_symmetry_time_pct
|
||||
):
|
||||
# Horizontal symmetry is applied along dim 3 (zero-indexed; the width axis) of the latent
|
||||
width = latents.shape[3]
|
||||
x_flipped = torch.flip(latents, dims=[3])
|
||||
latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3)
|
||||
if symmetry_type is SymmetryType.MIRROR:
|
||||
# Horizontal symmetry is applied along dim 3 (zero-indexed; the width axis) of the latent
|
||||
width = latents.shape[3]
|
||||
x_flipped = torch.flip(latents, dims=[3])
|
||||
latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3)
|
||||
elif symmetry_type is SymmetryType.FADE:
|
||||
# Horizontal symmetry is applied along dim 3 (zero-indexed; the width axis) of the latent
|
||||
width = latents.shape[3]
|
||||
height = latents.shape[2]
|
||||
dtype = latents.dtype
|
||||
x_flipped = torch.flip(latents, dims=[3])
|
||||
apply_width = 2 * (width//4)
|
||||
ramp1 = torch.linspace(start=1.0, end=0.5, steps=apply_width, device=latents.device)
|
||||
ramp2 = torch.linspace(start=0.5, end=1.0, steps=width-(apply_width), device=latents.device)
|
||||
ramp = torch.cat((ramp1,ramp2))
|
||||
fade1 = einops.repeat(tensor=ramp, pattern='m -> 1 4 k m', k=height).to(latents.device).type(dtype)
|
||||
fade0 = 1 - fade1
|
||||
multiplier = (fade1 * fade0) * 1.25 + 1
|
||||
latents = ((latents * fade1) + (x_flipped * fade0)) * multiplier
|
||||
|
||||
if (
|
||||
v_symmetry_time_pct != None and
|
||||
self.last_percent_through < v_symmetry_time_pct and
|
||||
percent_through >= v_symmetry_time_pct
|
||||
):
|
||||
# Vertical symmetry is applied along dim 2 (zero-indexed; the height axis) of the latent
|
||||
height = latents.shape[2]
|
||||
y_flipped = torch.flip(latents, dims=[2])
|
||||
latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2)
|
||||
if symmetry_type is SymmetryType.MIRROR:
|
||||
# Vertical symmetry is applied along dim 2 (zero-indexed; the height axis) of the latent
|
||||
height = latents.shape[2]
|
||||
y_flipped = torch.flip(latents, dims=[2])
|
||||
latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2)
|
||||
elif symmetry_type is SymmetryType.FADE:
|
||||
# Vertical symmetry is applied along dim 2 (zero-indexed; the height axis) of the latent
|
||||
width = latents.shape[3]
|
||||
height = latents.shape[2]
|
||||
dtype = latents.dtype
|
||||
y_flipped = torch.flip(latents, dims=[2])
|
||||
apply_height = 2 * (height // 4)
|
||||
ramp1 = torch.linspace(start=1.0, end=0.5, steps=apply_height, device=latents.device)
|
||||
ramp2 = torch.linspace(start=0.5, end=1.0, steps=height - (apply_height), device=latents.device)
|
||||
ramp = torch.cat((ramp1, ramp2))
|
||||
fade1 = einops.repeat(tensor=ramp, pattern='m -> 1 4 m k', k=width).to(latents.device).type(dtype)
|
||||
fade0 = 1 - fade1
|
||||
multiplier = (fade1 * fade0) * 1.25 + 1
|
||||
latents = ((latents * fade1) + (y_flipped * fade0)) * multiplier
|
||||
|
||||
self.last_percent_through = percent_through
|
||||
return latents.to(device=dev)
|
||||
|
||||
Reference in New Issue
Block a user