Merge branch 'main' into perf/lowmem_sequential_guidance

This commit is contained in:
Lincoln Stein
2023-02-20 07:42:35 -05:00
committed by GitHub
10 changed files with 140 additions and 22 deletions

View File

@@ -24,6 +24,8 @@ ModelForwardCallback: TypeAlias = Union[
class PostprocessingSettings:
threshold: float
warmup: float
h_symmetry_time_pct: Optional[float]
v_symmetry_time_pct: Optional[float]
class InvokeAIDiffuserComponent:
@@ -179,6 +181,7 @@ class InvokeAIDiffuserComponent:
if postprocessing_settings is not None:
percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
return latents
def calculate_percent_through(self, sigma, step_index, total_step_count):
@@ -320,8 +323,12 @@ class InvokeAIDiffuserComponent:
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
percent_through
percent_through: float
) -> torch.Tensor:
if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
return latents
threshold = postprocessing_settings.threshold
warmup = postprocessing_settings.warmup
@@ -370,6 +377,56 @@ class InvokeAIDiffuserComponent:
return latents
def apply_symmetry(
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
percent_through: float
) -> torch.Tensor:
# Reset our last percent through if this is our first step.
if percent_through == 0.0:
self.last_percent_through = 0.0
if postprocessing_settings is None:
return latents
# Check for out of bounds
h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
if (h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0)):
h_symmetry_time_pct = None
v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)):
v_symmetry_time_pct = None
dev = latents.device.type
latents.to(device='cpu')
if (
h_symmetry_time_pct != None and
self.last_percent_through < h_symmetry_time_pct and
percent_through >= h_symmetry_time_pct
):
# Horizontal symmetry occurs on the 3rd dimension of the latent
width = latents.shape[3]
x_flipped = torch.flip(latents, dims=[3])
latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3)
if (
v_symmetry_time_pct != None and
self.last_percent_through < v_symmetry_time_pct and
percent_through >= v_symmetry_time_pct
):
# Vertical symmetry occurs on the 2nd dimension of the latent
height = latents.shape[2]
y_flipped = torch.flip(latents, dims=[2])
latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2)
self.last_percent_through = percent_through
return latents.to(device=dev)
def estimate_percent_through(self, step_index, sigma):
if step_index is not None and self.cross_attention_control_context is not None:
# percent_through will never reach 1.0 (but this is intended)