Compare commits

...

1 Commit

Author SHA1 Message Date
Ryan Dick
c53ec3de45 checkpoint 2024-09-17 21:29:05 +00:00
2 changed files with 40 additions and 3 deletions

View File

@@ -28,8 +28,19 @@ class InpaintExtension:
This function should be called after each denoising step.
"""
timestep_cutoff = 0.5
if timestep > timestep_cutoff:
# Early in the denoising process, use the smaller mask.
# I.e. treat gradient values as 0.0.
mask = self._inpaint_mask.where(self._inpaint_mask >= (1.0 - 1e-3), 0.0)
else:
# After the cut-off, use the larger mask.
# I.e. treat gradient values as 1.0.
mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + 1e-3), 1.0)
# mask = (self._inpaint_mask > (0.0 + 1e-5)).float()
# Noise the init latents for the current timestep.
noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)
return intermediate_latents * mask + noised_init_latents * (1.0 - mask)

View File

@@ -31,10 +31,24 @@ def get_noise(
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    """Shift the timestep schedule.

    This is a similar idea to the beta schedule introduced in https://arxiv.org/abs/2305.08898, but the function
    used for remapping timesteps in [0, 1] is different. The mapping applied is:

        exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)

    Properties of this function:
    - Recommended sigma values: 1.0 <= sigma <= 3.0.
    - When sigma=1.0 and mu=0.0, the conversion is the identity function.
    - Increasing sigma results in an increasingly steep logistic function.
    - Adjusting mu shifts the midpoint of the logistic function.

    Args:
        mu: Shift parameter; it enters the mapping as exp(mu).
        sigma: Exponent applied to the (1 / t - 1) term.
        t: Timesteps to remap. The arithmetic is applied element-wise.

    Returns:
        The remapped timesteps.
    """
    # exp(mu) appears in both the numerator and the denominator; evaluate it once.
    exp_mu = math.exp(mu)
    return exp_mu / (exp_mu + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    """Return a linear function that maps x to y, given the coordinates of two points on the line.

    Args:
        x1: x-coordinate of the first point.
        y1: y-coordinate of the first point.
        x2: x-coordinate of the second point.
        y2: y-coordinate of the second point.

    Returns:
        A callable computing ``slope * x + intercept`` for the line through (x1, y1) and (x2, y2).

    Raises:
        ZeroDivisionError: If x1 == x2 (for plain Python numbers), since the slope is undefined.
    """
    slope = (y2 - y1) / (x2 - x1)
    # Solve y1 = slope * x1 + intercept for the intercept.
    intercept = y1 - slope * x1

    def line(x: float) -> float:
        return slope * x + intercept

    return line
@@ -52,9 +66,17 @@ def get_schedule(
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu based on linear estimation between two points
# Select mu based on linear interpolation between two points.
# Point 1: (image_seq_len=256, mu=0.5)
# Point 2: (image_seq_len=4096, mu=1.15)
# This has the effect of increasing mu as the image size increases. image_seq_len=4096 corresponds to an image
# size of 1024x1024.
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
# Shift the timesteps based on mu. Higher values of mu mean that there will be more timesteps early in the
# denoising process (i.e. many small steps in the timestep range 1.0-0.9, and fewer large steps in the timestep
# range 0.1-0.0).
timesteps = time_shift(mu=mu, sigma=1.0, t=timesteps)
return timesteps.tolist()
@@ -94,6 +116,10 @@ def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoi
clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]
# clipped_timesteps = torch.tensor(timesteps)
# clipped_timesteps = clipped_timesteps * (t_start_val - t_end_val) + t_end_val
# clipped_timesteps = clipped_timesteps.tolist()
return clipped_timesteps