From 1b359b55cb57c7ad9de3a5a026f393f3f82fc7e2 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Mon, 22 Jul 2024 22:17:29 +0300 Subject: [PATCH] Suggested changes Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- .../stable_diffusion/denoise_context.py | 20 +++++++++---------- .../extensions/rescale_cfg.py | 6 +++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py index 2b43d3fb0f..9060d54977 100644 --- a/invokeai/backend/stable_diffusion/denoise_context.py +++ b/invokeai/backend/stable_diffusion/denoise_context.py @@ -83,47 +83,47 @@ class DenoiseContext: unet: Optional[UNet2DConditionModel] = None # Current state of latent-space image in denoising process. - # None until `pre_denoise_loop` callback. + # None until `PRE_DENOISE_LOOP` callback. # Shape: [batch, channels, latent_height, latent_width] latents: Optional[torch.Tensor] = None # Current denoising step index. - # None until `pre_step` callback. + # None until `PRE_STEP` callback. step_index: Optional[int] = None # Current denoising step timestep. - # None until `pre_step` callback. + # None until `PRE_STEP` callback. timestep: Optional[torch.Tensor] = None # Arguments which will be passed to UNet model. - # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None. + # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None. unet_kwargs: Optional[UNetKwargs] = None # SchedulerOutput class returned from step function(normally, generated by scheduler). - # Supposed to be used only in `post_step` callback, otherwise can be None. + # Supposed to be used only in `POST_STEP` callback, otherwise can be None. step_output: Optional[SchedulerOutput] = None # Scaled version of `latents`, which will be passed to unet_kwargs initialization. - # Available in events inside step(between `pre_step` and `post_stop`). + # Available in events inside step(between `PRE_STEP` and `POST_STEP`). # Shape: [batch, channels, latent_height, latent_width] latent_model_input: Optional[torch.Tensor] = None # [TMP] Defines on which conditionings current unet call will be runned. - # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None. + # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None. conditioning_mode: Optional[ConditioningMode] = None # [TMP] Noise predictions from negative conditioning. - # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. + # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None. # Shape: [batch, channels, latent_height, latent_width] negative_noise_pred: Optional[torch.Tensor] = None # [TMP] Noise predictions from positive conditioning. - # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. + # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None. # Shape: [batch, channels, latent_height, latent_width] positive_noise_pred: Optional[torch.Tensor] = None # Combined noise prediction from passed conditionings. - # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. + # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None. # Shape: [batch, channels, latent_height, latent_width] noise_pred: Optional[torch.Tensor] = None diff --git a/invokeai/backend/stable_diffusion/extensions/rescale_cfg.py b/invokeai/backend/stable_diffusion/extensions/rescale_cfg.py index 51fad975e7..7cccbb8a2b 100644 --- a/invokeai/backend/stable_diffusion/extensions/rescale_cfg.py +++ b/invokeai/backend/stable_diffusion/extensions/rescale_cfg.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: class RescaleCFGExt(ExtensionBase): def __init__(self, rescale_multiplier: float): super().__init__() - self.rescale_multiplier = rescale_multiplier + self._rescale_multiplier = rescale_multiplier @staticmethod def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7): @@ -28,9 +28,9 @@ class RescaleCFGExt(ExtensionBase): @callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS) def rescale_noise_pred(self, ctx: DenoiseContext): - if self.rescale_multiplier > 0: + if self._rescale_multiplier > 0: ctx.noise_pred = self._rescale_cfg( ctx.noise_pred, ctx.positive_noise_pred, - self.rescale_multiplier, + self._rescale_multiplier, )