Tidy up the logic for inpainting mask adjustment in FLUX TrajectoryGuidanceExtension.

This commit is contained in:
Ryan Dick
2024-09-20 14:48:06 +00:00
parent b6748fb1e1
commit 2f82171dff

View File

@@ -57,23 +57,32 @@ class TrajectoryGuidanceExtension:
else:
self._inpaint_mask = inpaint_mask
def _apply_mask_gradient_adjustment(self, t_prev: float) -> torch.Tensor:
    """Return the inpaint mask to use at the current timestep, after gradient adjustment.

    Gradient (partially-weighted) regions of the mask are promoted to a full weight of 1.0 once the denoising
    process reaches the cutoff timestep, which helps produce more coherent seams around the inpainted region.
    A simple hard cutoff is used here; a (small) number of alternative promotion strategies (e.g. gradual
    promotion based on timestep) were tried, and the simple cutoff was found to work well.

    Args:
        t_prev: Timestep value compared against the fixed cutoff of 0.5.

    Returns:
        `self._inpaint_mask` itself (not a copy) when `t_prev` is above the cutoff. Otherwise a new tensor in
        which every entry greater than the tolerance is replaced by 1.0 and all other entries are kept.
    """
    # Small tolerance to sidestep floating point precision issues when testing for "zero".
    zero_tolerance = 1e-4
    promotion_cutoff = 0.5

    # Early in the denoising process (above the cutoff), the mask is used unmodified.
    if t_prev > promotion_cutoff:
        return self._inpaint_mask

    # At or past the cutoff: keep entries that are (effectively) zero and promote everything else to 1.0.
    base_mask = self._inpaint_mask
    is_effectively_zero = base_mask <= zero_tolerance
    return base_mask.where(is_effectively_zero, 1.0)
def step(
self, t_curr_latents: torch.Tensor, pred_noise: torch.Tensor, t_curr: float, t_prev: float
) -> torch.Tensor:
# Handle gradient cutoff.
# TODO(ryand): This logic is a bit arbitrary. Think about how to clean it up.
timestep_cutoff = 0.5
if t_prev > timestep_cutoff:
# Early in the denoising process, use the smaller mask.
# I.e. treat gradient values as 0.0.
mask = self._inpaint_mask.where(self._inpaint_mask >= (1.0 - 1e-3), 0.0)
else:
# After the cut-off, use the larger mask.
# I.e. treat gradient values as 1.0.
mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + 1e-3), 1.0)
# mask = (self._inpaint_mask > (0.0 + 1e-5)).float()
mask = self._apply_mask_gradient_adjustment(t_prev)
# Calculate the change_ratio based on the trajectory_guidance_strength.
# These mappings from trajectory_guidance_strength have no theoretical basis - they were tuned manually.
change_ratio_at_t_1 = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.0)(self._trajectory_guidance_strength)
change_ratio_at_cutoff = 1.0
t_cutoff = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.5)(self._trajectory_guidance_strength)