diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 6268af369f..98c3c22423 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -5,7 +5,7 @@ import inspect import math import secrets from dataclasses import dataclass, field -from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Generic, List, Optional, Type, Union import PIL.Image import einops @@ -27,7 +27,6 @@ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.outputs import BaseOutput from pydantic import Field from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -from typing_extensions import ParamSpec from invokeai.app.services.config import InvokeAIAppConfig from .diffusion import ( @@ -161,33 +160,6 @@ def is_inpainting_model(unet: UNet2DConditionModel): return unet.conv_in.in_channels == 9 -CallbackType = TypeVar("CallbackType") -ReturnType = TypeVar("ReturnType") -ParamType = ParamSpec("ParamType") - - -@dataclass(frozen=True) -class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]): - """Convert a generator to a function with a callback and a return value.""" - - generator_method: Callable[ParamType, ReturnType] - callback_arg_type: Type[CallbackType] - - def __call__( - self, - *args: ParamType.args, - callback: Callable[[CallbackType], Any] = None, - **kwargs: ParamType.kwargs, - ) -> ReturnType: - result = None - for result in self.generator_method(*args, **kwargs): - if callback is not None and isinstance(result, self.callback_arg_type): - callback(result) - if result is None: - raise AssertionError("why was that an empty generator?") - return result - - @dataclass class ControlNetData: model: ControlNetModel = Field(default=None) @@ -375,10 +347,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if init_timestep.shape[0] == 0: return latents, None - infer_latents_from_embeddings = GeneratorToCallbackinator( - self.generate_latents_from_embeddings, PipelineIntermediateState - ) - if additional_guidance is None: additional_guidance = [] @@ -417,7 +385,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise)) try: - result: PipelineIntermediateState = infer_latents_from_embeddings( + latents, attention_map_saver = self.generate_latents_from_embeddings( latents, timesteps, conditioning_data, @@ -428,13 +396,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): finally: self.invokeai_diffuser.model_forward_callback = self._unet_forward - latents = result.latents - # restore unmasked part if mask is not None: latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)) - return latents, result.attention_map_saver + return latents, attention_map_saver def generate_latents_from_embeddings( self, @@ -444,6 +410,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): *, additional_guidance: List[Callable] = None, control_data: List[ControlNetData] = None, + callback: Callable[[PipelineIntermediateState], None] = None, ): self._adjust_memory_efficient_attention(latents) if additional_guidance is None: @@ -461,13 +428,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): extra_conditioning_info=extra_conditioning_info, step_count=len(self.scheduler.timesteps), ): - yield PipelineIntermediateState( - step=-1, - order=self.scheduler.order, - total_steps=len(timesteps), - timestep=self.scheduler.config.num_train_timesteps, - latents=latents, - ) + if callback is not None: + callback(PipelineIntermediateState( + step=-1, + order=self.scheduler.order, + total_steps=len(timesteps), + timestep=self.scheduler.config.num_train_timesteps, + latents=latents, + )) # print("timesteps:", timesteps) for i, t in enumerate(self.progress_bar(timesteps)): @@ -500,15 +468,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:]) # self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver) - yield PipelineIntermediateState( - step=i, - order=self.scheduler.order, - total_steps=len(timesteps), - timestep=int(t), - latents=latents, - predicted_original=predicted_original, - attention_map_saver=attention_map_saver, - ) + if callback is not None: + callback(PipelineIntermediateState( + step=i, + order=self.scheduler.order, + total_steps=len(timesteps), + timestep=int(t), + latents=latents, + predicted_original=predicted_original, + attention_map_saver=attention_map_saver, + )) return latents, attention_map_saver