diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index d48c9f922e..1239d578d9 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -16,7 +16,7 @@ from ..util.step_callback import stable_diffusion_step_callback from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext from .image import ImageOutput -from ...backend.model_management.lora import ModelPatcher +from ...backend.model_management import ModelPatcher, BaseModelType from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline from .model import UNetField, VaeField from .compel import ConditioningField @@ -140,6 +140,7 @@ class InpaintInvocation(BaseInvocation): self, context: InvocationContext, source_node_id: str, + base_model: BaseModelType, intermediate_state: PipelineIntermediateState, ) -> None: stable_diffusion_step_callback( @@ -147,15 +148,16 @@ class InpaintInvocation(BaseInvocation): intermediate_state=intermediate_state, node=self.dict(), source_node_id=source_node_id, + base_model=base_model, ) def get_conditioning(self, context, unet): positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) - c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype) - extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning + c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) + extra_conditioning_info = c.extra_conditioning negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) - uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype) + uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) return (uc, c, extra_conditioning_info) @@ -225,7 +227,7 @@ class InpaintInvocation(BaseInvocation): scheduler=scheduler, init_image=image, mask_image=mask, - 
step_callback=partial(self.dispatch_progress, context, source_node_id), + step_callback=partial(self.dispatch_progress, context, source_node_id, self.unet.unet.base_model), **self.dict( exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"} ), # Shorthand for passing all of the parameters above manually diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index fef3bcbf6f..25e411074a 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -24,7 +24,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import ( ) from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP -from ...backend.model_management import ModelPatcher +from ...backend.model_management import ModelPatcher, BaseModelType from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision from ..models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext @@ -160,12 +160,14 @@ class TextToLatentsInvocation(BaseInvocation): context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState, + base_model: BaseModelType, ) -> None: stable_diffusion_step_callback( context=context, intermediate_state=intermediate_state, node=self.dict(), source_node_id=source_node_id, + base_model=base_model, ) def get_conditioning_data( @@ -340,7 +342,7 @@ class TextToLatentsInvocation(BaseInvocation): source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): - self.dispatch_progress(context, source_node_id, state) + self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model) def _lora_loader(): for lora in self.unet.loras: @@ -379,7 +381,7 @@ class 
TextToLatentsInvocation(BaseInvocation): do_classifier_free_guidance=True, exit_stack=exit_stack, ) - + num_inference_steps, timesteps = self.init_scheduler( scheduler, device=unet.device, @@ -448,7 +450,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): - self.dispatch_progress(context, source_node_id, state) + self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model) def _lora_loader(): for lora in self.unet.loras: diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 994d83e705..aae06913fd 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -7,6 +7,7 @@ from ...backend.util.util import image_to_dataURL from ...backend.generator.base import Generator from ...backend.stable_diffusion import PipelineIntermediateState from invokeai.app.services.config import InvokeAIAppConfig +from ...backend.model_management.models import BaseModelType def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None): @@ -29,6 +30,7 @@ def stable_diffusion_step_callback( intermediate_state: PipelineIntermediateState, node: dict, source_node_id: str, + base_model: BaseModelType, ): if context.services.queue.is_canceled(context.graph_execution_state_id): raise CanceledException @@ -56,23 +58,51 @@ def stable_diffusion_step_callback( # TODO: only output a preview image when requested - # origingally adapted from code by @erucipe and @keturn here: - # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7 + if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]: + sdxl_latent_rgb_factors = torch.tensor( + [ + # R G B + [0.3816, 0.4930, 0.5320], + [-0.3753, 0.1631, 0.1739], + [0.1770, 0.3588, -0.2048], + [-0.4350, -0.2644, -0.4289], + ], + dtype=sample.dtype, + 
device=sample.device, + ) - # these updated numbers for v1.5 are from @torridgristle - v1_5_latent_rgb_factors = torch.tensor( - [ - # R G B - [0.3444, 0.1385, 0.0670], # L1 - [0.1247, 0.4027, 0.1494], # L2 - [-0.3192, 0.2513, 0.2103], # L3 - [-0.1307, -0.1874, -0.7445], # L4 - ], - dtype=sample.dtype, - device=sample.device, - ) + sdxl_smooth_matrix = torch.tensor( + [ + # [ 0.0478, 0.1285, 0.0478], + # [ 0.1285, 0.2948, 0.1285], + # [ 0.0478, 0.1285, 0.0478], + [0.0358, 0.0964, 0.0358], + [0.0964, 0.4711, 0.0964], + [0.0358, 0.0964, 0.0358], + ], + dtype=sample.dtype, + device=sample.device, + ) - image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors) + else: + # originally adapted from code by @erucipe and @keturn here: + # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7 + + # these updated numbers for v1.5 are from @torridgristle + v1_5_latent_rgb_factors = torch.tensor( + [ + # R G B + [0.3444, 0.1385, 0.0670], # L1 + [0.1247, 0.4027, 0.1494], # L2 + [-0.3192, 0.2513, 0.2103], # L3 + [-0.1307, -0.1874, -0.7445], # L4 + ], + dtype=sample.dtype, + device=sample.device, + ) + + image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors) (width, height) = image.size width *= 8 diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index ed1c8deeb5..9d080e648d 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -50,7 +50,6 @@ from .offloading import FullyLoadedModelGroup, ModelGroup @dataclass class PipelineIntermediateState: - run_id: str step: int timestep: int latents: torch.Tensor @@ -407,7 +406,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise: Optional[torch.Tensor], timesteps=None, additional_guidance: List[Callable] = 
None, - run_id=None, callback: Callable[[PipelineIntermediateState], None] = None, control_data: List[ControlNetData] = None, ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: @@ -427,7 +425,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): timesteps, conditioning_data, noise=noise, - run_id=run_id, additional_guidance=additional_guidance, control_data=control_data, callback=callback, @@ -441,13 +438,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): conditioning_data: ConditioningData, *, noise: Optional[torch.Tensor], - run_id: str = None, additional_guidance: List[Callable] = None, control_data: List[ControlNetData] = None, ): self._adjust_memory_efficient_attention(latents) - if run_id is None: - run_id = secrets.token_urlsafe(self.ID_LENGTH) if additional_guidance is None: additional_guidance = [] extra_conditioning_info = conditioning_data.extra @@ -468,7 +462,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): latents = self.scheduler.add_noise(latents, noise, batched_t) yield PipelineIntermediateState( - run_id=run_id, step=-1, timestep=self.scheduler.config.num_train_timesteps, latents=latents, @@ -507,7 +500,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver) yield PipelineIntermediateState( - run_id=run_id, step=i, timestep=int(t), latents=latents, @@ -619,7 +611,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): conditioning_data: ConditioningData, *, callback: Callable[[PipelineIntermediateState], None] = None, - run_id=None, noise_func=None, seed=None, ) -> InvokeAIStableDiffusionPipelineOutput: @@ -645,7 +636,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): conditioning_data, strength, noise, - run_id, callback, ) @@ -678,7 +668,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): conditioning_data: ConditioningData, *, callback: 
Callable[[PipelineIntermediateState], None] = None, - run_id=None, noise_func=None, seed=None, ) -> InvokeAIStableDiffusionPipelineOutput: @@ -737,7 +726,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise=noise, timesteps=timesteps, additional_guidance=guidance, - run_id=run_id, callback=callback, ) finally: