mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Moving from ControlNet guess_mode to separate booleans for cfg_injection and soft_injection for testing control modes
This commit is contained in:
@@ -223,6 +223,8 @@ class ControlNetData:
|
||||
end_step_percent: float = Field(default=1.0)
|
||||
# FIXME: replace guess_mode with enum control_mode: BALANCED, MORE_PROMPT, MORE_CONTROL
|
||||
guess_mode: bool = Field(default=False) # guess_mode can work with or without prompt
|
||||
cfg_injection: bool = Field(default=False)
|
||||
soft_injection: bool = Field(default=False)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConditioningData:
|
||||
@@ -695,7 +697,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
|
||||
# only apply controlnet if current step is within the controlnet's begin/end step range
|
||||
if step_index >= first_control_step and step_index <= last_control_step:
|
||||
guess_mode = control_datum.guess_mode
|
||||
# guess_mode = control_datum.guess_mode
|
||||
guess_mode = control_datum.cfg_injection
|
||||
if guess_mode:
|
||||
control_latent_input = unet_latent_input
|
||||
else:
|
||||
@@ -740,7 +743,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
controlnet_cond=control_datum.image_tensor,
|
||||
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
|
||||
# cross_attention_kwargs,
|
||||
guess_mode=guess_mode,
|
||||
# guess_mode=guess_mode,
|
||||
guess_mode=control_datum.soft_injection,
|
||||
return_dict=False,
|
||||
)
|
||||
print("finished ControlNetModel() call, step", step_index)
|
||||
@@ -1100,6 +1104,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
dtype=torch.float16,
|
||||
do_classifier_free_guidance=True,
|
||||
guess_mode=False,
|
||||
soft_injection=False,
|
||||
cfg_injection=False,
|
||||
):
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
@@ -1130,6 +1136,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
repeat_by = num_images_per_prompt
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if do_classifier_free_guidance and not guess_mode:
|
||||
# if do_classifier_free_guidance and not guess_mode:
|
||||
if do_classifier_free_guidance and not cfg_injection:
|
||||
image = torch.cat([image] * 2)
|
||||
return image
|
||||
|
||||
Reference in New Issue
Block a user