Merge branch 'main' into install/refactor-configure-and-model-select

This commit is contained in:
Lincoln Stein
2023-02-16 21:51:15 -05:00
86 changed files with 707 additions and 318 deletions

View File

@@ -4,39 +4,34 @@ import dataclasses
import inspect
import psutil
import secrets
import sys
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any
if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
else:
from typing import ParamSpec
import PIL.Image
import einops
import psutil
import torch
import torchvision.transforms as T
from diffusers.utils.import_utils import is_xformers_available
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec
from ldm.invoke.globals import Globals
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager
from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
@dataclass
@@ -265,6 +260,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_model_group: ModelGroup
ID_LENGTH = 8
@@ -274,7 +270,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
@@ -304,8 +300,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
textual_inversion_manager=self.textual_inversion_manager
)
self._model_group = FullyLoadedModelGroup(self.unet.device)
self._model_group.install(*self._submodels)
def _adjust_memory_efficient_attention(self, latents: Torch.tensor):
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
"""
if xformers is available, use it, otherwise use sliced attention.
"""
@@ -323,7 +322,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
elif self.device.type == 'cuda':
mem_free, _ = torch.cuda.mem_get_info(self.device)
else:
raise ValueError(f"unrecognized device {device}")
raise ValueError(f"unrecognized device {self.device}")
# input tensor of [1, 4, h/8, w/8]
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
@@ -337,6 +336,66 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self.disable_attention_slicing()
def enable_offload_submodels(self, device: torch.device):
"""
Offload each submodel when it's not in use.
Useful for low-vRAM situations where the size of the model in memory is a big chunk of
the total available resource, and you want to free up as much for inference as possible.
This requires more moving parts and may add some delay as the U-Net is swapped out for the
VAE and vice-versa.
"""
models = self._submodels
if self._model_group is not None:
self._model_group.uninstall(*models)
group = LazilyLoadedModelGroup(device)
group.install(*models)
self._model_group = group
def disable_offload_submodels(self):
"""
Leave all submodels loaded.
Appropriate for cases where the size of the model in memory is small compared to the memory
required for inference. Avoids the delay and complexity of shuffling the submodels to and
from the GPU.
"""
models = self._submodels
if self._model_group is not None:
self._model_group.uninstall(*models)
group = FullyLoadedModelGroup(self._model_group.execution_device)
group.install(*models)
self._model_group = group
def offload_all(self):
"""Offload all this pipeline's models to CPU."""
self._model_group.offload_current()
def ready(self):
"""
Ready this pipeline's models.
i.e. pre-load them to the GPU if appropriate.
"""
self._model_group.ready()
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
if torch_device is None:
return self
self._model_group.set_device(torch_device)
self._model_group.ready()
@property
def device(self) -> torch.device:
return self._model_group.execution_device
@property
def _submodels(self) -> Sequence[torch.nn.Module]:
module_names, _, _ = self.extract_init_dict(dict(self.config))
values = [getattr(self, name) for name in module_names.keys()]
return [m for m in values if isinstance(m, torch.nn.Module)]
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
conditioning_data: ConditioningData,
*,
@@ -378,7 +437,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
callback: Callable[[PipelineIntermediateState], None] = None
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if timesteps is None:
self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
self.scheduler.set_timesteps(num_inference_steps, device=self._model_group.device_for(self.unet))
timesteps = self.scheduler.timesteps
infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
result: PipelineIntermediateState = infer_latents_from_embeddings(
@@ -410,7 +469,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
batch_size = latents.shape[0]
batched_t = torch.full((batch_size,), timesteps[0],
dtype=timesteps.dtype, device=self.unet.device)
dtype=timesteps.dtype, device=self._model_group.device_for(self.unet))
latents = self.scheduler.add_noise(latents, noise, batched_t)
attention_map_saver: Optional[AttentionMapSaver] = None
@@ -494,9 +553,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
).add_mask_channels(latents)
return self.unet(sample=latents,
timestep=t,
encoder_hidden_states=text_embeddings,
# First three args should be positional, not keywords, so torch hooks can see them.
return self.unet(latents, t, text_embeddings,
cross_attention_kwargs=cross_attention_kwargs).sample
def img2img_from_embeddings(self,
@@ -515,9 +573,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
# 6. Prepare latent variables
device = self.unet.device
latents_dtype = self.unet.dtype
initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
initial_latents = self.non_noised_latents_from_image(
init_image, device=self._model_group.device_for(self.unet),
dtype=self.unet.dtype)
noise = noise_func(initial_latents)
return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
@@ -530,7 +588,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
strength,
noise: torch.Tensor, run_id=None, callback=None
) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, self.unet.device)
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength,
device=self._model_group.device_for(self.unet))
result_latents, result_attention_maps = self.latents_from_embeddings(
initial_latents, num_inference_steps, conditioning_data,
timesteps=timesteps,
@@ -569,7 +628,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id=None,
noise_func=None,
) -> InvokeAIStableDiffusionPipelineOutput:
device = self.unet.device
device = self._model_group.device_for(self.unet)
latents_dtype = self.unet.dtype
if isinstance(init_image, PIL.Image.Image):
@@ -633,6 +692,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
self.vae.to('cpu')
init_image = init_image.to('cpu')
else:
self._model_group.load(self.vae)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible!
if device.type == 'mps':
@@ -644,8 +705,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def check_for_safety(self, output, dtype):
with torch.inference_mode():
screened_images, has_nsfw_concept = self.run_safety_checker(
output.images, device=self._execution_device, dtype=dtype)
screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
screened_attention_map_saver = None
if has_nsfw_concept is None or not has_nsfw_concept:
screened_attention_map_saver = output.attention_map_saver
@@ -654,6 +714,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# block the attention maps if NSFW content is detected
attention_map_saver=screened_attention_map_saver)
def run_safety_checker(self, image, device=None, dtype=None):
# overriding to use the model group for device info instead of requiring the caller to know.
if self.safety_checker is not None:
device = self._model_group.device_for(self.safety_checker)
return super().run_safety_checker(image, device, dtype)
@torch.inference_mode()
def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
"""
@@ -663,7 +729,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
text=c,
fragment_weights=fragment_weights,
should_return_tokens=return_tokens,
device=self.device)
device=self._model_group.device_for(self.unet))
@property
def cond_stage_model(self):
@@ -684,6 +750,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
"""Compatible with DiffusionWrapper"""
return self.unet.in_channels
def decode_latents(self, latents):
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
self._model_group.load(self.vae)
return super().decode_latents(latents)
def debug_latents(self, latents, msg):
with torch.inference_mode():
from ldm.util import debug_image