Compare commits

...

2 Commits

Author SHA1 Message Date
Ryan Dick
7e6dbba8ca Apply to all modes in the frontend. 2024-09-24 21:05:45 +00:00
Ryan Dick
e11f98f128 experiment 2 2024-09-23 21:37:43 +00:00
2 changed files with 36 additions and 14 deletions

View File

@@ -56,12 +56,8 @@ class TrajectoryGuidanceExtension:
else:
self._inpaint_mask = inpaint_mask
# Calculate the params that define the trajectory guidance schedule.
# These mappings from trajectory_guidance_strength have no theoretical basis - they were tuned manually.
self._trajectory_guidance_strength = trajectory_guidance_strength
self._change_ratio_at_t_1 = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.0)(self._trajectory_guidance_strength)
self._change_ratio_at_cutoff = 1.0
self._t_cutoff = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.5)(self._trajectory_guidance_strength)
self._rescale_power = 0.5
def _apply_mask_gradient_adjustment(self, t_prev: float) -> torch.Tensor:
"""Applies inpaint mask gradient adjustment and returns the inpaint mask to be used at the current timestep."""
@@ -81,15 +77,29 @@ class TrajectoryGuidanceExtension:
return mask
def _get_change_ratio(self, t_prev: float) -> float:
def _get_change_ratio(self, t_curr: float, t_prev: float) -> float:
"""Get the change_ratio for t_prev based on the change schedule."""
change_ratio = 1.0
if t_prev > self._t_cutoff:
# If we are before the cutoff, linearly interpolate between the change_ratio at t=1.0 and the change_ratio
# at the cutoff.
change_ratio = build_line(
x1=1.0, y1=self._change_ratio_at_t_1, x2=self._t_cutoff, y2=self._change_ratio_at_cutoff
)(t_prev)
t = 1.0 - self._trajectory_guidance_strength
# Remap t scale to change more slowly at high values of t.
t = t**self._rescale_power
def change_ratio_fn(t_curr_: float, t_prev_: float):
"""Function that starts at 0.0, has a ramp up to 1.0, and then stays at 1.0."""
ramp_size = 0.25
ramp_start = build_line(x1=1.0, y1=1.0 + ramp_size, x2=0.0, y2=0.0)(t)
ramp_end = ramp_start - ramp_size
if t_curr_ > ramp_start:
return 0.0
elif ramp_start >= t_curr_ > ramp_end:
return build_line(x1=ramp_start, y1=0.0, x2=ramp_end, y2=1.0)(t_curr_)
else:
return 1.0
change_ratio = change_ratio_fn(t_curr, t_prev)
print(change_ratio)
# The change_ratio should be in the range [0, 1]. Assert that we didn't make any mistakes.
eps = 1e-5
@@ -102,7 +112,7 @@ class TrajectoryGuidanceExtension:
# Handle gradient cutoff.
mask = self._apply_mask_gradient_adjustment(t_prev)
mask = mask * self._get_change_ratio(t_prev)
mask = mask * self._get_change_ratio(t_curr, t_prev)
# NOTE(ryand): During inpainting, it is common to guide the denoising process by noising the initial latents for
# the current timestep and then blending the predicted intermediate latents with the noised initial latents.

View File

@@ -141,6 +141,12 @@ export const buildFLUXGraph = async (
denoisingValue,
false
);
if (optimizedDenoisingEnabled) {
g.updateNode(noise, {
denoising_start: 0,
trajectory_guidance_strength: denoisingValue,
});
}
} else if (generationMode === 'inpaint') {
canvasOutput = await addInpaint(
state,
@@ -175,6 +181,12 @@ export const buildFLUXGraph = async (
denoisingValue,
false
);
if (optimizedDenoisingEnabled) {
g.updateNode(noise, {
denoising_start: 0,
trajectory_guidance_strength: denoisingValue,
});
}
}
if (state.system.shouldUseNSFWChecker) {