Cleaned up and refactored new symmetry.

This commit is contained in:
JPPhoto
2023-02-26 18:00:51 -06:00
parent 9d97b106b0
commit c0f259173d

View File

@@ -407,33 +407,37 @@ class InvokeAIDiffuserComponent:
if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)):
v_symmetry_time_pct = None
width = latents.shape[3]
height = latents.shape[2]
dev = latents.device.type
dtype = latents.dtype
symmetry_type = postprocessing_settings.symmetry_type or SymmetryType.FADE
latents.to(device='cpu')
def make_ramp(ease_in:int, total:int) -> torch.Tensor:
ramp1 = torch.linspace(start=1.0, end=0.5, steps=ease_in, device=dev)
ramp2 = torch.linspace(start=0.5, end=1.0, steps=total - ease_in, device=dev)
ramp = torch.cat((ramp1, ramp2))
return ramp
if (
h_symmetry_time_pct != None and
self.last_percent_through < h_symmetry_time_pct and
percent_through >= h_symmetry_time_pct
):
# Horizontal symmetry occurs on the 3rd dimension of the latent
x_flipped = torch.flip(latents, dims=[3])
if symmetry_type is SymmetryType.MIRROR:
# Horizontal symmetry occurs on the 3rd dimension of the latent
width = latents.shape[3]
x_flipped = torch.flip(latents, dims=[3])
# Use the first half of latents and then the flipped one on this dimension
latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3)
elif symmetry_type is SymmetryType.FADE:
# Horizontal symmetry occurs on the 3rd dimension of the latent
width = latents.shape[3]
height = latents.shape[2]
dtype = latents.dtype
x_flipped = torch.flip(latents, dims=[3])
apply_width = 2 * (width//4)
ramp1 = torch.linspace(start=1.0, end=0.5, steps=apply_width, device=latents.device)
ramp2 = torch.linspace(start=0.5, end=1.0, steps=width-(apply_width), device=latents.device)
ramp = torch.cat((ramp1,ramp2))
apply_width = width // 2
# Create a linear ramp so the middle gets perfect symmetry but the edges retain their original latents
ramp = make_ramp(ease_in=apply_width, total=width)
fade1 = einops.repeat(tensor=ramp, pattern='m -> 1 4 k m', k=height).to(latents.device).type(dtype)
fade0 = 1 - fade1
# Multiply the crossover region to retain details and avoid a "muddy" appearance
multiplier = (fade1 * fade0) * 1.25 + 1
latents = ((latents * fade1) + (x_flipped * fade0)) * multiplier
@@ -442,23 +446,18 @@ class InvokeAIDiffuserComponent:
self.last_percent_through < v_symmetry_time_pct and
percent_through >= v_symmetry_time_pct
):
# Vertical symmetry occurs on the 3rd dimension of the latent
y_flipped = torch.flip(latents, dims=[2])
if symmetry_type is SymmetryType.MIRROR:
# Vertical symmetry occurs on the 2nd dimension of the latent
height = latents.shape[2]
y_flipped = torch.flip(latents, dims=[2])
latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2)
elif symmetry_type is SymmetryType.FADE:
# Vertical symmetry occurs on the 2nd dimension of the latent
width = latents.shape[3]
height = latents.shape[2]
dtype = latents.dtype
y_flipped = torch.flip(latents, dims=[2])
apply_height = 2 * (height // 4)
ramp1 = torch.linspace(start=1.0, end=0.5, steps=apply_height, device=latents.device)
ramp2 = torch.linspace(start=0.5, end=1.0, steps=height - (apply_height), device=latents.device)
ramp = torch.cat((ramp1, ramp2))
apply_height = height // 2
# Create a linear ramp so the middle gets perfect symmetry but the edges retain their original latents
ramp = make_ramp(ease_in=apply_height, total=height)
fade1 = einops.repeat(tensor=ramp, pattern='m -> 1 4 m k', k=width).to(latents.device).type(dtype)
fade0 = 1 - fade1
# Multiply the crossover region to retain details and avoid a "muddy" appearance
multiplier = (fade1 * fade0) * 1.25 + 1
latents = ((latents * fade1) + (y_flipped * fade0)) * multiplier