mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-02-19 09:54:24 -05:00)

Merge branch 'main' into enhance/update-menu
ldm/generate.py (984 lines changed): file diff suppressed because it is too large
ldm/invoke/CLI.py (1020 lines changed): file diff suppressed because it is too large
@@ -2,3 +2,12 @@ from ._version import __version__

__app_id__= 'invoke-ai/InvokeAI'
__app_name__= 'InvokeAI'


def _ignore_xformers_triton_message_on_windows():
    import logging
    logging.getLogger("xformers").addFilter(
        lambda record: 'A matching Triton is not available' not in record.getMessage())

# In order to be effective, this needs to happen before anything could possibly import xformers.
_ignore_xformers_triton_message_on_windows()
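
Note on the technique above: Logger.addFilter accepts any callable that takes a LogRecord and returns a truthy value to keep it. A minimal standalone sketch of the same pattern (independent of InvokeAI):

import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("xformers")
# Drop records whose message contains the noisy substring; keep everything else.
log.addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())

log.warning("A matching Triton is not available")   # suppressed
log.warning("xformers attention enabled")           # printed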
@@ -272,6 +272,10 @@ class Args(object):
            switches.append('--seamless')
        if a['hires_fix']:
            switches.append('--hires_fix')
        if a['h_symmetry_time_pct']:
            switches.append(f'--h_symmetry_time_pct {a["h_symmetry_time_pct"]}')
        if a['v_symmetry_time_pct']:
            switches.append(f'--v_symmetry_time_pct {a["v_symmetry_time_pct"]}')

        # img2img generations have parameters relevant only to them and have special handling
        if a['init_img'] and len(a['init_img'])>0:
@@ -751,6 +755,9 @@ class Args(object):
!fix applies upscaling/facefixing to a previously-generated image.
    invoke> !fix 0000045.4829112.png -G1 -U4 -ft codeformer

*embeddings*
    invoke> !triggers -- return all trigger phrases contained in loaded embedding files

*History manipulation*
!fetch retrieves the command used to generate an earlier image. Provide
a directory wildcard and the name of a file to write and all the commands
@@ -842,6 +849,18 @@ class Args(object):
            type=float,
            help='Perlin noise scale (0.0 - 1.0) - add perlin noise to the initialization instead of the usual gaussian noise.',
        )
        render_group.add_argument(
            '--h_symmetry_time_pct',
            default=None,
            type=float,
            help='Horizontal symmetry point (0.0 - 1.0) - apply horizontal symmetry at this point in image generation.',
        )
        render_group.add_argument(
            '--v_symmetry_time_pct',
            default=None,
            type=float,
            help='Vertical symmetry point (0.0 - 1.0) - apply vertical symmetry at this point in image generation.',
        )
        render_group.add_argument(
            '--fnformat',
            default='{prefix}.{seed}.png',
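
For reference, the two new switches follow the standard argparse pattern used throughout this file; a self-contained sketch of equivalent behavior (the parser object here is hypothetical, only the option names come from the diff):

import argparse

parser = argparse.ArgumentParser()
render_group = parser.add_argument_group('render')
render_group.add_argument('--h_symmetry_time_pct', default=None, type=float,
                          help='Horizontal symmetry point (0.0 - 1.0)')
render_group.add_argument('--v_symmetry_time_pct', default=None, type=float,
                          help='Vertical symmetry point (0.0 - 1.0)')

opt = parser.parse_args(['--h_symmetry_time_pct', '0.25'])
assert opt.h_symmetry_time_pct == 0.25 and opt.v_symmetry_time_pct is None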
@@ -1148,7 +1167,8 @@ def metadata_dumps(opt,
    # remove any image keys not mentioned in RFC #266
    rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
                         'cfg_scale','threshold','perlin','step_number','width','height','extra','strength','seamless'
                         'init_img','init_mask','facetool','facetool_strength','upscale']
                         'init_img','init_mask','facetool','facetool_strength','upscale','h_symmetry_time_pct',
                         'v_symmetry_time_pct']
    rfc_dict ={}

    for item in image_dict.items():
@@ -53,6 +53,7 @@ from diffusers import (
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import is_safetensors_available
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
@@ -984,6 +985,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
    elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']:
        text_model = convert_ldm_clip_checkpoint(checkpoint)
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
        safety_checker = StableDiffusionSafetyChecker.from_pretrained('CompVis/stable-diffusion-safety-checker',cache_dir=global_cache_dir("hub"))
        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
        pipe = pipeline_class(
            vae=vae,
@@ -991,7 +993,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
    else:
@@ -93,7 +93,7 @@ def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], p
    Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
    """

    if log_tokens or Globals.log_tokenization:
    if log_tokens or getattr(Globals, "log_tokenization", False):
        print(f"\n>> [TOKENLOG] Parsed Prompt: {parsed_prompt}")
        print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {parsed_negative_prompt}")
@@ -236,7 +236,7 @@ def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedProm
    fragments = [x.text for x in flattened_prompt.children]
    weights = [x.weight for x in flattened_prompt.children]
    embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
    if log_tokens or Globals.log_tokenization:
    if log_tokens or getattr(Globals, "log_tokenization", False):
        text = " ".join(fragments)
        log_tokenization(text, model, display_label=log_display_label)
@@ -296,4 +296,4 @@ def log_tokenization(text, model, display_label=None):

    if discarded != "":
        print(f'\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):')
        print(f'{discarded}\x1b[0m')
        print(f'{discarded}\x1b[0m')
@@ -40,7 +40,6 @@ from ldm.invoke.globals import Globals, global_cache_dir, global_config_dir
from ldm.invoke.readline import generic_completer

warnings.filterwarnings("ignore")
import torch

transformers.logging.set_verbosity_error()
@@ -764,7 +763,7 @@ def download_weights(opt: dict) -> Union[str, None]:
    precision = (
        "float32"
        if opt.full_precision
        else choose_precision(torch.device(choose_torch_device()))
        else choose_precision(choose_torch_device())
    )

    if opt.yes_to_all:
@@ -1,19 +1,25 @@
from __future__ import annotations

from contextlib import nullcontext

import torch
from torch import autocast
from contextlib import nullcontext

from ldm.invoke.globals import Globals

def choose_torch_device() -> str:
CPU_DEVICE = torch.device("cpu")

def choose_torch_device() -> torch.device:
    '''Convenience routine for guessing which GPU device to run model on'''
    if Globals.always_use_cpu:
        return "cpu"
        return CPU_DEVICE
    if torch.cuda.is_available():
        return 'cuda'
        return torch.device('cuda')
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'
        return torch.device('mps')
    return CPU_DEVICE

def choose_precision(device) -> str:
def choose_precision(device: torch.device) -> str:
    '''Returns an appropriate precision for the given torch device'''
    if device.type == 'cuda':
        device_name = torch.cuda.get_device_name(device)
@@ -21,7 +27,7 @@ def choose_precision(device) -> str:
            return 'float16'
    return 'float32'

def torch_dtype(device) -> torch.dtype:
def torch_dtype(device: torch.device) -> torch.dtype:
    if Globals.full_precision:
        return torch.float32
    if choose_precision(device) == 'float16':
@@ -36,3 +42,13 @@ def choose_autocast(precision):
    if precision == 'autocast' or precision == 'float16':
        return autocast
    return nullcontext

def normalize_device(device: str | torch.device) -> torch.device:
    """Ensure device has a device index defined, if appropriate."""
    device = torch.device(device)
    if device.index is None:
        # cuda might be the only torch backend that currently uses the device index?
        # I don't see anything like `current_device` for cpu or mps.
        if device.type == 'cuda':
            device = torch.device(device.type, torch.cuda.current_device())
    return device
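
Taken together, these changes move devices.py from returning device strings to returning torch.device objects. A rough usage sketch of the new API (assuming the post-patch module as shown in the hunks above):

import torch
from ldm.invoke.devices import choose_torch_device, torch_dtype, normalize_device

device = choose_torch_device()   # now a torch.device, not a bare string
dtype = torch_dtype(device)      # float16 on capable GPUs, float32 otherwise
if device.type == 'cuda':
    # normalize_device fills in a missing index (cuda -> cuda:0), which
    # index-sensitive APIs such as torch.cuda.mem_get_info expect.
    free, total = torch.cuda.mem_get_info(normalize_device(device))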
@@ -64,6 +64,7 @@ class Generator:

    def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                 image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                 h_symmetry_time_pct=None, v_symmetry_time_pct=None,
                 safety_checker:dict=None,
                 free_gpu_mem: bool=False,
                 **kwargs):
@@ -81,6 +82,8 @@ class Generator:
            step_callback = step_callback,
            threshold = threshold,
            perlin = perlin,
            h_symmetry_time_pct = h_symmetry_time_pct,
            v_symmetry_time_pct = v_symmetry_time_pct,
            attention_maps_callback = attention_maps_callback,
            **kwargs
        )
@@ -247,11 +250,14 @@ class Generator:
        fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
        # limit noise to only the diffusion image channels, not the mask channels
        input_channels = min(self.latent_channels, 4)
        # round up to the nearest block of 8
        temp_width = int((width + 7) / 8) * 8
        temp_height = int((height + 7) / 8) * 8
        noise = torch.stack([
            rand_perlin_2d((height, width),
            rand_perlin_2d((temp_height, temp_width),
                           (8, 8),
                           device = self.model.device).to(fixdevice) for _ in range(input_channels)], dim=0).to(self.model.device)
        return noise
        return noise[0:4, 0:height, 0:width]

    def new_seed(self):
        self.seed = random.randrange(0, np.iinfo(np.uint32).max)
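
The perlin fix works by rounding the requested size up to the next multiple of 8 before generating noise, then cropping back down. A quick worked example of that arithmetic:

def round_up_to_8(n: int) -> int:
    # same as int((n + 7) / 8) * 8 in the hunk: next multiple of 8 >= n
    return (n + 7) // 8 * 8

assert round_up_to_8(512) == 512
assert round_up_to_8(513) == 520
# Noise is generated at (temp_height, temp_width) and then sliced back down
# with noise[0:4, 0:height, 0:width], so non-multiple-of-8 sizes no longer
# break rand_perlin_2d.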
@@ -3,39 +3,35 @@ from __future__ import annotations
import dataclasses
import inspect
import secrets
import sys
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any

if sys.version_info < (3, 10):
    from typing_extensions import ParamSpec
else:
    from typing import ParamSpec

import PIL.Image
import einops
import psutil
import torch
import torchvision.transforms as T
from diffusers.utils.import_utils import is_xformers_available

from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter


from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec

from ldm.invoke.globals import Globals
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager
from ..devices import normalize_device, CPU_DEVICE
from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter


@dataclass
@@ -264,6 +260,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    _model_group: ModelGroup

    ID_LENGTH = 8
@@ -273,7 +270,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: Optional[StableDiffusionSafetyChecker],
        feature_extractor: Optional[CLIPFeatureExtractor],
        requires_safety_checker: bool = False,
@@ -303,8 +300,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            textual_inversion_manager=self.textual_inversion_manager
        )

        self._model_group = FullyLoadedModelGroup(self.unet.device)
        self._model_group.install(*self._submodels)

    def _adjust_memory_efficient_attention(self, latents: Torch.tensor):
    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
        """
        if xformers is available, use it, otherwise use sliced attention.
        """
@@ -320,9 +320,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            if self.device.type == 'cpu' or self.device.type == 'mps':
                mem_free = psutil.virtual_memory().free
            elif self.device.type == 'cuda':
                mem_free, _ = torch.cuda.mem_get_info(self.device)
                mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
            else:
                raise ValueError(f"unrecognized device {device}")
                raise ValueError(f"unrecognized device {self.device}")
            # input tensor of [1, 4, h/8, w/8]
            # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
            bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
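
The free-memory probe differs per backend, which is what the normalize_device call fixes here. A standalone sketch of the same dispatch, outside the pipeline:

import psutil
import torch

def free_memory_bytes(device: torch.device) -> int:
    if device.type in ('cpu', 'mps'):
        return psutil.virtual_memory().free
    if device.type == 'cuda':
        free, _total = torch.cuda.mem_get_info(device)  # device should carry an index
        return free
    raise ValueError(f"unrecognized device {device}")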
@@ -336,6 +336,67 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            self.disable_attention_slicing()

    def enable_offload_submodels(self, device: torch.device):
        """
        Offload each submodel when it's not in use.

        Useful for low-vRAM situations where the size of the model in memory is a big chunk of
        the total available resource, and you want to free up as much for inference as possible.

        This requires more moving parts and may add some delay as the U-Net is swapped out for the
        VAE and vice-versa.
        """
        models = self._submodels
        if self._model_group is not None:
            self._model_group.uninstall(*models)
        group = LazilyLoadedModelGroup(device)
        group.install(*models)
        self._model_group = group

    def disable_offload_submodels(self):
        """
        Leave all submodels loaded.

        Appropriate for cases where the size of the model in memory is small compared to the memory
        required for inference. Avoids the delay and complexity of shuffling the submodels to and
        from the GPU.
        """
        models = self._submodels
        if self._model_group is not None:
            self._model_group.uninstall(*models)
        group = FullyLoadedModelGroup(self._model_group.execution_device)
        group.install(*models)
        self._model_group = group

    def offload_all(self):
        """Offload all this pipeline's models to CPU."""
        self._model_group.offload_current()

    def ready(self):
        """
        Ready this pipeline's models.

        i.e. pre-load them to the GPU if appropriate.
        """
        self._model_group.ready()

    def to(self, torch_device: Optional[Union[str, torch.device]] = None):
        # overridden method; types match the superclass.
        if torch_device is None:
            return self
        self._model_group.set_device(torch.device(torch_device))
        self._model_group.ready()

    @property
    def device(self) -> torch.device:
        return self._model_group.execution_device

    @property
    def _submodels(self) -> Sequence[torch.nn.Module]:
        module_names, _, _ = self.extract_init_dict(dict(self.config))
        values = [getattr(self, name) for name in module_names.keys()]
        return [m for m in values if isinstance(m, torch.nn.Module)]

    def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                              conditioning_data: ConditioningData,
                              *,
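
A hedged usage sketch of the two new switches (pipe is assumed to be a StableDiffusionGeneratorPipeline instance as defined in this diff):

import torch

# Low-VRAM mode: only one submodel (U-Net, VAE, text encoder, ...) occupies
# the GPU at a time; the rest wait on the CPU until their forward pass runs.
pipe.enable_offload_submodels(torch.device('cuda', 0))

# Plenty of VRAM: keep every submodel resident and skip the swap overhead.
pipe.disable_offload_submodels()

# Park everything on the CPU, e.g. before switching to a different model.
pipe.offload_all()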
@@ -377,7 +438,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                callback: Callable[[PipelineIntermediateState], None] = None
                                ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
        if timesteps is None:
            self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
            self.scheduler.set_timesteps(num_inference_steps, device=self._model_group.device_for(self.unet))
            timesteps = self.scheduler.timesteps
        infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
        result: PipelineIntermediateState = infer_latents_from_embeddings(
@@ -409,7 +470,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

        batch_size = latents.shape[0]
        batched_t = torch.full((batch_size,), timesteps[0],
                               dtype=timesteps.dtype, device=self.unet.device)
                               dtype=timesteps.dtype, device=self._model_group.device_for(self.unet))
        latents = self.scheduler.add_noise(latents, noise, batched_t)

        attention_map_saver: Optional[AttentionMapSaver] = None
@@ -493,9 +554,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
            ).add_mask_channels(latents)

        return self.unet(sample=latents,
                         timestep=t,
                         encoder_hidden_states=text_embeddings,
        # First three args should be positional, not keywords, so torch hooks can see them.
        return self.unet(latents, t, text_embeddings,
                         cross_attention_kwargs=cross_attention_kwargs).sample

    def img2img_from_embeddings(self,
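
The "positional, not keywords" comment matters because torch forward-pre-hooks (the mechanism offloading.py relies on) only see positional inputs; keyword arguments bypass them in the torch 1.13-era API. A tiny demonstration:

import torch

class Dummy(torch.nn.Module):
    def forward(self, x, scale=1.0):
        return x * scale

m = Dummy()
seen = []
m.register_forward_pre_hook(lambda module, args: seen.append(args))

m(torch.ones(1), scale=2.0)  # hook sees only (tensor,); 'scale' is invisible to it
m(torch.ones(1), 2.0)        # hook sees both arguments
print(seen)                  # [(tensor([1.]),), (tensor([1.]), 2.0)]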
@@ -514,9 +574,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')

        # 6. Prepare latent variables
        device = self.unet.device
        latents_dtype = self.unet.dtype
        initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
        initial_latents = self.non_noised_latents_from_image(
            init_image, device=self._model_group.device_for(self.unet),
            dtype=self.unet.dtype)
        noise = noise_func(initial_latents)

        return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
@@ -529,7 +589,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                          strength,
                                          noise: torch.Tensor, run_id=None, callback=None
                                          ) -> InvokeAIStableDiffusionPipelineOutput:
        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, self.unet.device)
        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength,
                                                  device=self._model_group.device_for(self.unet))
        result_latents, result_attention_maps = self.latents_from_embeddings(
            initial_latents, num_inference_steps, conditioning_data,
            timesteps=timesteps,
@@ -568,7 +629,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                              run_id=None,
                              noise_func=None,
                              ) -> InvokeAIStableDiffusionPipelineOutput:
        device = self.unet.device
        device = self._model_group.device_for(self.unet)
        latents_dtype = self.unet.dtype

        if isinstance(init_image, PIL.Image.Image):
@@ -630,8 +691,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        if device.type == 'mps':
            # workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
            # TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
            self.vae.to('cpu')
            init_image = init_image.to('cpu')
            self.vae.to(CPU_DEVICE)
            init_image = init_image.to(CPU_DEVICE)
        else:
            self._model_group.load(self.vae)
        init_latent_dist = self.vae.encode(init_image).latent_dist
        init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
        if device.type == 'mps':
@@ -643,8 +706,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

    def check_for_safety(self, output, dtype):
        with torch.inference_mode():
            screened_images, has_nsfw_concept = self.run_safety_checker(
                output.images, device=self._execution_device, dtype=dtype)
            screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
        screened_attention_map_saver = None
        if has_nsfw_concept is None or not has_nsfw_concept:
            screened_attention_map_saver = output.attention_map_saver
@@ -653,6 +715,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                # block the attention maps if NSFW content is detected
                attention_map_saver=screened_attention_map_saver)

    def run_safety_checker(self, image, device=None, dtype=None):
        # overriding to use the model group for device info instead of requiring the caller to know.
        if self.safety_checker is not None:
            device = self._model_group.device_for(self.safety_checker)
        return super().run_safety_checker(image, device, dtype)

    @torch.inference_mode()
    def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
        """
@@ -662,7 +730,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            text=c,
            fragment_weights=fragment_weights,
            should_return_tokens=return_tokens,
            device=self.device)
            device=self._model_group.device_for(self.unet))

    @property
    def cond_stage_model(self):
@@ -683,6 +751,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        """Compatible with DiffusionWrapper"""
        return self.unet.in_channels

    def decode_latents(self, latents):
        # Explicit call to get the vae loaded, since `decode` isn't the forward method.
        self._model_group.load(self.vae)
        return super().decode_latents(latents)

    def debug_latents(self, latents, msg):
        with torch.inference_mode():
            from ldm.util import debug_image
@@ -16,8 +16,8 @@ class Img2Img(Generator):
        self.init_latent = None    # by get_noise()

    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                       conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,
                       attention_maps_callback=None,
                       conditioning,init_image,strength,step_callback=None,threshold=0.0,warmup=0.2,perlin=0.0,
                       h_symmetry_time_pct=None,v_symmetry_time_pct=None,attention_maps_callback=None,
                       **kwargs):
        """
        Returns a function returning an image derived from the prompt and the initial image
@@ -33,8 +33,13 @@ class Img2Img(Generator):
        conditioning_data = (
            ConditioningData(
                uc, c, cfg_scale, extra_conditioning_info,
                postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
            .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
                postprocessing_settings=PostprocessingSettings(
                    threshold=threshold,
                    warmup=warmup,
                    h_symmetry_time_pct=h_symmetry_time_pct,
                    v_symmetry_time_pct=v_symmetry_time_pct
                )
            ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))

        def make_image(x_T):
@@ -15,8 +15,8 @@ class Txt2Img(Generator):

    @torch.no_grad()
    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                       conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
                       attention_maps_callback=None,
                       conditioning,width,height,step_callback=None,threshold=0.0,warmup=0.2,perlin=0.0,
                       h_symmetry_time_pct=None,v_symmetry_time_pct=None,attention_maps_callback=None,
                       **kwargs):
        """
        Returns a function returning an image derived from the prompt and the initial image
@@ -33,8 +33,13 @@ class Txt2Img(Generator):
        conditioning_data = (
            ConditioningData(
                uc, c, cfg_scale, extra_conditioning_info,
                postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
            .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
                postprocessing_settings=PostprocessingSettings(
                    threshold=threshold,
                    warmup=warmup,
                    h_symmetry_time_pct=h_symmetry_time_pct,
                    v_symmetry_time_pct=v_symmetry_time_pct
                )
            ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))

        def make_image(x_T) -> PIL.Image.Image:
            pipeline_output = pipeline.image_from_embeddings(
@@ -44,8 +49,10 @@ class Txt2Img(Generator):
                conditioning_data=conditioning_data,
                callback=step_callback,
            )

            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
                attention_maps_callback(pipeline_output.attention_map_saver)

            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        return make_image
@@ -21,12 +21,14 @@ class Txt2Img2Img(Generator):

    def get_make_image(self, prompt:str, sampler, steps:int, cfg_scale:float, ddim_eta,
                       conditioning, width:int, height:int, strength:float,
                       step_callback:Optional[Callable]=None, threshold=0.0, **kwargs):
                       step_callback:Optional[Callable]=None, threshold=0.0, warmup=0.2, perlin=0.0,
                       h_symmetry_time_pct=None, v_symmetry_time_pct=None, attention_maps_callback=None, **kwargs):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it
        kwargs are 'width' and 'height'
        """
        self.perlin = perlin

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
@@ -36,8 +38,13 @@ class Txt2Img2Img(Generator):
        conditioning_data = (
            ConditioningData(
                uc, c, cfg_scale, extra_conditioning_info,
                postprocessing_settings = PostprocessingSettings(threshold=threshold, warmup=0.2) if threshold else None)
            .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
                postprocessing_settings = PostprocessingSettings(
                    threshold=threshold,
                    warmup=0.2,
                    h_symmetry_time_pct=h_symmetry_time_pct,
                    v_symmetry_time_pct=v_symmetry_time_pct
                )
            ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))

        def make_image(x_T):
@@ -69,19 +76,28 @@ class Txt2Img2Img(Generator):
            if clear_cuda_cache is not None:
                clear_cuda_cache()

            second_pass_noise = self.get_noise_like(resized_latents)
            second_pass_noise = self.get_noise_like(resized_latents, override_perlin=True)

            # Clear symmetry for the second pass
            from dataclasses import replace
            new_postprocessing_settings = replace(conditioning_data.postprocessing_settings, h_symmetry_time_pct=None)
            new_postprocessing_settings = replace(new_postprocessing_settings, v_symmetry_time_pct=None)
            new_conditioning_data = replace(conditioning_data, postprocessing_settings=new_postprocessing_settings)

            verbosity = get_verbosity()
            set_verbosity_error()
            pipeline_output = pipeline.img2img_from_latents_and_embeddings(
                resized_latents,
                num_inference_steps=steps,
                conditioning_data=conditioning_data,
                conditioning_data=new_conditioning_data,
                strength=strength,
                noise=second_pass_noise,
                callback=step_callback)
            set_verbosity(verbosity)

            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
                attention_maps_callback(pipeline_output.attention_map_saver)

            return pipeline.numpy_to_pil(pipeline_output.images)[0]
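
dataclasses.replace builds a modified copy and leaves the original untouched, which is why the first pass keeps its symmetry settings while the second pass runs without them. A minimal sketch of the pattern (the Settings class below is a hypothetical stand-in for PostprocessingSettings):

from dataclasses import dataclass, replace
from typing import Optional

@dataclass
class Settings:
    threshold: float
    h_symmetry_time_pct: Optional[float] = None
    v_symmetry_time_pct: Optional[float] = None

first_pass = Settings(threshold=0.5, h_symmetry_time_pct=0.25)
second_pass = replace(first_pass, h_symmetry_time_pct=None, v_symmetry_time_pct=None)
assert first_pass.h_symmetry_time_pct == 0.25   # original is untouched
assert second_pass.h_symmetry_time_pct is None  # symmetry cleared for pass two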
@@ -95,13 +111,13 @@ class Txt2Img2Img(Generator):

        return make_image

    def get_noise_like(self, like: torch.Tensor):
    def get_noise_like(self, like: torch.Tensor, override_perlin: bool=False):
        device = like.device
        if device.type == 'mps':
            x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device)
        else:
            x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
        if self.perlin > 0.0:
        if self.perlin > 0.0 and override_perlin == False:
            shape = like.shape
            x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
        return x
@@ -139,6 +155,9 @@ class Txt2Img2Img(Generator):
        shape = (1, channels,
                 scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor)
        if self.use_mps_noise or device.type == 'mps':
            return torch.randn(shape, dtype=self.torch_dtype(), device='cpu').to(device)
            tensor = torch.empty(size=shape, device='cpu')
            tensor = self.get_noise_like(like=tensor).to(device)
        else:
            return torch.randn(shape, dtype=self.torch_dtype(), device=device)
            tensor = torch.empty(size=shape, device=device)
            tensor = self.get_noise_like(like=tensor)
        return tensor
@@ -33,7 +33,7 @@ Globals.models_file = 'models.yaml'
Globals.models_dir = 'models'
Globals.config_dir = 'configs'
Globals.autoscan_dir = 'weights'
Globals.converted_ckpts_dir = 'converted-ckpts'
Globals.converted_ckpts_dir = 'converted_ckpts'

# Try loading patchmatch
Globals.try_patchmatch = True
@@ -54,6 +54,9 @@ Globals.full_precision = False
# whether we should convert ckpt files into diffusers models on the fly
Globals.ckpt_convert = False

# logging tokenization everywhere
Globals.log_tokenization = False

def global_config_file()->Path:
    return Path(Globals.root, Globals.config_dir, Globals.models_file)
@@ -66,6 +69,9 @@ def global_models_dir()->Path:
def global_autoscan_dir()->Path:
    return Path(Globals.root, Globals.autoscan_dir)

def global_converted_ckpts_dir()->Path:
    return Path(global_models_dir(), Globals.converted_ckpts_dir)

def global_set_root(root_dir:Union[str,Path]):
    Globals.root = root_dir
@@ -79,8 +79,8 @@ def merge_diffusion_models_and_commit(
    merged_model_name = name for new model
    alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
            would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
    interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
             Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
    interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
             Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
    force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.

    **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
@@ -173,7 +173,6 @@ def _parse_args() -> Namespace:

# ------------------------- GUI HERE -------------------------
class FloatSlider(npyscreen.Slider):
    # this is supposed to adjust display precision, but doesn't
    def translate_value(self):
        stri = "%3.2f / %3.2f" % (self.value, self.out_of)
        l = (len(str(self.out_of))) * 2 + 4
@@ -186,7 +185,7 @@ class FloatTitleSlider(npyscreen.TitleText):


class mergeModelsForm(npyscreen.FormMultiPageAction):
    interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"]
    interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]

    def __init__(self, parentApp, name):
        self.parentApp = parentApp
@@ -305,8 +304,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
        self.alpha = self.add_widget_intelligent(
            FloatTitleSlider,
            name="Weight (alpha) to assign to second and third models:",
            out_of=1,
            step=0.05,
            out_of=1.0,
            step=0.01,
            lowest=0,
            value=0.5,
            labelColor="CONTROL",
@@ -323,7 +322,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
            self.merged_model_name.value = merged_model_name

        if selected_model3 > 0:
            self.merge_method.values = (["add_difference"],)
            self.merge_method.values = ['add_difference ( A+(B-C) )']
            self.merged_model_name.value += f"+{models[selected_model3]}"
        else:
            self.merge_method.values = self.interpolations
@@ -349,11 +348,14 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
        ]
        if self.model3.value[0] > 0:
            models.append(model_names[self.model3.value[0] - 1])
            interp='add_difference'
        else:
            interp=self.interpolations[self.merge_method.value[0]]

        args = dict(
            models=models,
            alpha=self.alpha.value,
            interp=self.interpolations[self.merge_method.value[0]],
            interp=interp,
            force=self.force.value,
            merged_model_name=self.merged_model_name.value,
        )
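
For reference, the interpolation modes offered by the form reduce to simple tensor arithmetic; a sketch of the math (not InvokeAI's actual merge implementation):

import torch

def weighted_sum(a: torch.Tensor, b: torch.Tensor, alpha: float) -> torch.Tensor:
    # alpha weights the *second* model, matching the slider label above
    return (1 - alpha) * a + alpha * b

def add_difference(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
                   alpha: float = 1.0) -> torch.Tensor:
    # "A+(B-C)": add to model A the (scaled) delta between models B and C
    return a + alpha * (b - c)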
@@ -25,19 +25,18 @@ import torch
import transformers
from diffusers import AutoencoderKL
from diffusers import logging as dlogging
from diffusers.utils.logging import (get_verbosity, set_verbosity,
                                     set_verbosity_error)
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path

from ldm.invoke.devices import CPU_DEVICE
from ldm.invoke.generator.diffusers_pipeline import \
    StableDiffusionGeneratorPipeline
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
                                global_models_dir)
from ldm.util import (ask_user, download_with_progress_bar,
                      instantiate_from_config)
from ldm.util import (ask_user, download_with_resume,
                      url_attachment_name, instantiate_from_config)

DEFAULT_MAX_MODELS = 2
VAE_TO_REPO_ID = {  # hack, see note in convert_and_import()
@@ -49,9 +48,10 @@ class ModelManager(object):
    def __init__(
        self,
        config: OmegaConf,
        device_type: str = "cpu",
        device_type: torch.device = CPU_DEVICE,
        precision: str = "float16",
        max_loaded_models=DEFAULT_MAX_MODELS,
        sequential_offload = False
    ):
        """
        Initialize with the path to the models.yaml config file,
@@ -69,6 +69,7 @@ class ModelManager(object):
        self.models = {}
        self.stack = []  # this is an LRU FIFO
        self.current_model = None
        self.sequential_offload = sequential_offload

    def valid_model(self, model_name: str) -> bool:
        """
@@ -529,7 +530,10 @@ class ModelManager(object):
            dlogging.set_verbosity(verbosity)
        assert pipeline is not None, OSError(f'"{name_or_path}" could not be loaded')

        pipeline.to(self.device)
        if self.sequential_offload:
            pipeline.enable_offload_submodels(self.device)
        else:
            pipeline.to(self.device)

        model_hash = self._diffuser_sha256(name_or_path)
@@ -670,15 +674,18 @@ class ModelManager(object):
        path to the configuration file, then the new entry will be committed to the
        models.yaml file.
        """
        if str(weights).startswith(("http:", "https:")):
            model_name = model_name or url_attachment_name(weights)

        weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
        config_path = self._resolve_path(config, "configs/stable-diffusion")
        config_path = self._resolve_path(config, "configs/stable-diffusion")

        if weights_path is None or not weights_path.exists():
            return False
        if config_path is None or not config_path.exists():
            return False

        model_name = model_name or Path(weights).stem
        model_name = model_name or Path(weights).stem  # note this gives ugly pathnames if used on a URL without a Content-Disposition header
        model_description = (
            model_description or f"imported stable diffusion weights file {model_name}"
        )
@@ -748,7 +755,6 @@ class ModelManager(object):
        into models.yaml.
        """
        new_config = None
        import transformers

        from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
@@ -759,7 +765,7 @@ class ModelManager(object):
            return

        model_name = model_name or diffusers_path.name
        model_description = model_description or "Optimized version of {model_name}"
        model_description = model_description or f"Optimized version of {model_name}"
        print(f">> Optimizing {model_name} (30-60s)")
        try:
            # By passing the specified VAE to the conversion function, the autoencoder
@@ -799,15 +805,17 @@ class ModelManager(object):
        models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")

        ckpt_files = [x for x in models_folder_ckpt if x.is_file()]
        safetensor_files = [x for x in models_folder_safetensors if x.is_file]
        safetensor_files = [x for x in models_folder_safetensors if x.is_file()]

        files = ckpt_files + safetensor_files

        found_models = []
        for file in files:
            found_models.append(
                {"name": file.stem, "location": str(file.resolve()).replace("\\", "/")}
            )
            location = str(file.resolve()).replace("\\", "/")
            if 'model.safetensors' not in location and 'diffusion_pytorch_model.safetensors' not in location:
                found_models.append(
                    {"name": file.stem, "location": location}
                )

        return search_folder, found_models
@@ -967,16 +975,15 @@ class ModelManager(object):
        print("** Migration is done. Continuing...")

    def _resolve_path(
            self, source: Union[str, Path], dest_directory: str
        self, source: Union[str, Path], dest_directory: str
    ) -> Optional[Path]:
        resolved_path = None
        if str(source).startswith(("http:", "https:", "ftp:")):
            basename = os.path.basename(source)
            if not os.path.isabs(dest_directory):
                dest_directory = os.path.join(Globals.root, dest_directory)
            dest = os.path.join(dest_directory, basename)
            if download_with_progress_bar(str(source), Path(dest)):
                resolved_path = Path(dest)
            dest_directory = Path(dest_directory)
            if not dest_directory.is_absolute():
                dest_directory = Globals.root / dest_directory
            dest_directory.mkdir(parents=True, exist_ok=True)
            resolved_path = download_with_resume(str(source), dest_directory)
        else:
            if not os.path.isabs(source):
                source = os.path.join(Globals.root, source)
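
The rewritten URL branch swaps os.path string manipulation for pathlib and makes sure the destination directory exists before downloading. A condensed sketch of just the path-resolution part (download_with_resume is the project helper named in the hunk and is omitted here):

from pathlib import Path

def resolve_under_root(dest_directory: str, root: Path) -> Path:
    dest = Path(dest_directory)
    if not dest.is_absolute():
        dest = root / dest  # anchor relative paths at the InvokeAI root
    dest.mkdir(parents=True, exist_ok=True)  # ensure the target exists before download
    return dest

print(resolve_under_root("models/ldm/stable-diffusion-v1", Path("/opt/invokeai")))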
@@ -990,25 +997,29 @@ class ModelManager(object):
        self.models.pop(model_name, None)

    def _model_to_cpu(self, model):
        if self.device == "cpu":
        if self.device == CPU_DEVICE:
            return model

        # diffusers really really doesn't like us moving a float16 model onto CPU
        verbosity = get_verbosity()
        set_verbosity_error()
        model.cond_stage_model.device = "cpu"
        model.to("cpu")
        set_verbosity(verbosity)
        if isinstance(model, StableDiffusionGeneratorPipeline):
            model.offload_all()
            return model

        model.cond_stage_model.device = CPU_DEVICE
        model.to(CPU_DEVICE)

        for submodel in ("first_stage_model", "cond_stage_model", "model"):
            try:
                getattr(model, submodel).to("cpu")
                getattr(model, submodel).to(CPU_DEVICE)
            except AttributeError:
                pass
        return model

    def _model_from_cpu(self, model):
        if self.device == "cpu":
        if self.device == CPU_DEVICE:
            return model

        if isinstance(model, StableDiffusionGeneratorPipeline):
            model.ready()
            return model

        model.to(self.device)
@@ -1161,7 +1172,7 @@ class ModelManager(object):
        strategy.execute()

    @staticmethod
    def _abs_path(path: Union(str, Path)) -> Path:
    def _abs_path(path: str | Path) -> Path:
        if path is None or Path(path).is_absolute():
            return path
        return Path(Globals.root, path).resolve()
ldm/invoke/offloading.py (new file, 247 lines)
@@ -0,0 +1,247 @@
from __future__ import annotations

import warnings
import weakref
from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping
from typing import Callable

import torch
from accelerate.utils import send_to_device
from torch.utils.hooks import RemovableHandle

OFFLOAD_DEVICE = torch.device("cpu")

class _NoModel:
    """Symbol that indicates no model is loaded.

    (We can't weakref.ref(None), so this was my best idea at the time to come up with something
    type-checkable.)
    """

    def __bool__(self):
        return False

    def to(self, device: torch.device):
        pass

    def __repr__(self):
        return "<NO MODEL>"

NO_MODEL = _NoModel()


class ModelGroup(metaclass=ABCMeta):
    """
    A group of models.

    The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline,
    e.g. its text encoder, U-net, VAE, etc.

    Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with
    :py:class:`torch.nn.Module` here.
    """

    def __init__(self, execution_device: torch.device):
        self.execution_device = execution_device

    @abstractmethod
    def install(self, *models: torch.nn.Module):
        """Add models to this group."""
        pass

    @abstractmethod
    def uninstall(self, models: torch.nn.Module):
        """Remove models from this group."""
        pass

    @abstractmethod
    def uninstall_all(self):
        """Remove all models from this group."""

    @abstractmethod
    def load(self, model: torch.nn.Module):
        """Load this model to the execution device."""
        pass

    @abstractmethod
    def offload_current(self):
        """Offload the current model(s) from the execution device."""
        pass

    @abstractmethod
    def ready(self):
        """Ready this group for use."""
        pass

    @abstractmethod
    def set_device(self, device: torch.device):
        """Change which device models from this group will execute on."""
        pass

    @abstractmethod
    def device_for(self, model) -> torch.device:
        """Get the device the given model will execute on.

        The model should already be a member of this group.
        """
        pass

    @abstractmethod
    def __contains__(self, model):
        """Check if the model is a member of this group."""
        pass

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} object at {id(self):x}: " \
               f"device={self.execution_device} >"


class LazilyLoadedModelGroup(ModelGroup):
    """
    Only one model from this group is loaded on the GPU at a time.

    Running the forward method of a model will displace the previously-loaded model,
    offloading it to CPU.

    If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``,
    you will need to explicitly load it with :py:method:`.load(model)`.

    This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments
    to the appropriate execution device, as long as they are positional arguments and not keyword
    arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.)
    """

    _hooks: MutableMapping[torch.nn.Module, RemovableHandle]
    _current_model_ref: Callable[[], torch.nn.Module | _NoModel]

    def __init__(self, execution_device: torch.device):
        super().__init__(execution_device)
        self._hooks = weakref.WeakKeyDictionary()
        self._current_model_ref = weakref.ref(NO_MODEL)

    def install(self, *models: torch.nn.Module):
        for model in models:
            self._hooks[model] = model.register_forward_pre_hook(self._pre_hook)

    def uninstall(self, *models: torch.nn.Module):
        for model in models:
            hook = self._hooks.pop(model)
            hook.remove()
            if self.is_current_model(model):
                # no longer hooked by this object, so don't claim to manage it
                self.clear_current_model()

    def uninstall_all(self):
        self.uninstall(*self._hooks.keys())

    def _pre_hook(self, module: torch.nn.Module, forward_input):
        self.load(module)
        if len(forward_input) == 0:
            warnings.warn(f"Hook for {module.__class__.__name__} got no input. "
                          f"Inputs must be positional, not keywords.", stacklevel=3)
        return send_to_device(forward_input, self.execution_device)

    def load(self, module):
        if not self.is_current_model(module):
            self.offload_current()
            self._load(module)

    def offload_current(self):
        module = self._current_model_ref()
        if module is not NO_MODEL:
            module.to(device=OFFLOAD_DEVICE)
        self.clear_current_model()

    def _load(self, module: torch.nn.Module) -> torch.nn.Module:
        assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
        module = module.to(self.execution_device)
        self.set_current_model(module)
        return module

    def is_current_model(self, model: torch.nn.Module) -> bool:
        """Is the given model the one currently loaded on the execution device?"""
        return self._current_model_ref() is model

    def is_empty(self):
        """Are none of this group's models loaded on the execution device?"""
        return self._current_model_ref() is NO_MODEL

    def set_current_model(self, value):
        self._current_model_ref = weakref.ref(value)

    def clear_current_model(self):
        self._current_model_ref = weakref.ref(NO_MODEL)

    def set_device(self, device: torch.device):
        if device == self.execution_device:
            return
        self.execution_device = device
        current = self._current_model_ref()
        if current is not NO_MODEL:
            current.to(device)

    def device_for(self, model):
        if model not in self:
            raise KeyError(f"This does not manage this model {type(model).__name__}", model)
        return self.execution_device  # this implementation only dispatches to one device

    def ready(self):
        pass  # always ready to load on-demand

    def __contains__(self, model):
        return model in self._hooks

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} object at {id(self):x}: " \
               f"current_model={type(self._current_model_ref()).__name__} >"


class FullyLoadedModelGroup(ModelGroup):
    """
    A group of models without any implicit loading or unloading.

    :py:meth:`.ready` loads _all_ the models to the execution device at once.
    """

    _models: weakref.WeakSet

    def __init__(self, execution_device: torch.device):
        super().__init__(execution_device)
        self._models = weakref.WeakSet()

    def install(self, *models: torch.nn.Module):
        for model in models:
            self._models.add(model)
            model.to(device=self.execution_device)

    def uninstall(self, *models: torch.nn.Module):
        for model in models:
            self._models.remove(model)

    def uninstall_all(self):
        self.uninstall(*self._models)

    def load(self, model):
        model.to(device=self.execution_device)

    def offload_current(self):
        for model in self._models:
            model.to(device=OFFLOAD_DEVICE)

    def ready(self):
        for model in self._models:
            self.load(model)

    def set_device(self, device: torch.device):
        self.execution_device = device
        for model in self._models:
            if model.device != OFFLOAD_DEVICE:
                model.to(device=device)

    def device_for(self, model):
        if model not in self:
            raise KeyError(f"This does not manage this model {type(model).__name__}", model)
        return self.execution_device  # this implementation only dispatches to one device

    def __contains__(self, model):
        return model in self._models
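
An end-to-end sketch of the new module with toy modules standing in for real submodels (requires a CUDA device; behavior follows the hooks described above):

import torch
from ldm.invoke.offloading import LazilyLoadedModelGroup

unet_like = torch.nn.Linear(8, 8)
vae_like = torch.nn.Linear(8, 8)

group = LazilyLoadedModelGroup(torch.device("cuda", 0))
group.install(unet_like, vae_like)

x = torch.randn(1, 8)
unet_like(x)  # pre-hook moves unet_like (and x) to cuda:0
vae_like(x)   # unet_like is offloaded back to CPU; vae_like takes its place

assert group.device_for(vae_like) == torch.device("cuda", 0)
group.offload_current()  # everything back on the CPU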
@@ -58,9 +58,11 @@ COMMANDS = (
    '--inpaint_replace','-r',
    '--png_compression','-z',
    '--text_mask','-tm',
    '--h_symmetry_time_pct',
    '--v_symmetry_time_pct',
    '!fix','!fetch','!replay','!history','!search','!clear',
    '!models','!switch','!import_model','!optimize_model','!convert_model','!edit_model','!del_model',
    '!mask',
    '!mask','!triggers',
)
MODEL_COMMANDS = (
    '!switch',
@@ -138,7 +140,7 @@ class Completer(object):
        elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
            self.matches= self._model_completions(text, state)

        # looking for a ckpt model
        # looking for a ckpt model
        elif re.match('^'+'|'.join(CKPT_MODEL_COMMANDS),buffer):
            self.matches= self._model_completions(text, state, ckpt_only=True)
@@ -255,7 +257,7 @@ class Completer(object):
        update our list of models
        '''
        self.models = models


    def _seed_completions(self, text, state):
        m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
        if m:
@@ -18,6 +18,8 @@ from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
class PostprocessingSettings:
    threshold: float
    warmup: float
    h_symmetry_time_pct: Optional[float]
    v_symmetry_time_pct: Optional[float]


class InvokeAIDiffuserComponent:
@@ -30,7 +32,7 @@ class InvokeAIDiffuserComponent:
    * Hybrid conditioning (used for inpainting)
    '''
    debug_thresholding = False

    last_percent_through = 0.0

    @dataclass
    class ExtraConditioningInfo:
@@ -56,6 +58,7 @@ class InvokeAIDiffuserComponent:
        self.is_running_diffusers = is_running_diffusers
        self.model_forward_callback = model_forward_callback
        self.cross_attention_control_context = None
        self.last_percent_through = 0.0

    @contextmanager
    def custom_attention_context(self,
@@ -164,6 +167,7 @@ class InvokeAIDiffuserComponent:
        if postprocessing_settings is not None:
            percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
            latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
            latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
        return latents

    def calculate_percent_through(self, sigma, step_index, total_step_count):
@@ -292,8 +296,12 @@ class InvokeAIDiffuserComponent:
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
        percent_through
        percent_through: float
    ) -> torch.Tensor:

        if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
            return latents

        threshold = postprocessing_settings.threshold
        warmup = postprocessing_settings.warmup
@@ -342,6 +350,56 @@ class InvokeAIDiffuserComponent:

        return latents

    def apply_symmetry(
        self,
        postprocessing_settings: PostprocessingSettings,
        latents: torch.Tensor,
        percent_through: float
    ) -> torch.Tensor:

        # Reset our last percent through if this is our first step.
        if percent_through == 0.0:
            self.last_percent_through = 0.0

        if postprocessing_settings is None:
            return latents

        # Check for out of bounds
        h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
        if (h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0)):
            h_symmetry_time_pct = None

        v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
        if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)):
            v_symmetry_time_pct = None

        dev = latents.device.type

        latents.to(device='cpu')

        if (
            h_symmetry_time_pct != None and
            self.last_percent_through < h_symmetry_time_pct and
            percent_through >= h_symmetry_time_pct
        ):
            # Horizontal symmetry occurs on the 3rd dimension of the latent
            width = latents.shape[3]
            x_flipped = torch.flip(latents, dims=[3])
            latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3)

        if (
            v_symmetry_time_pct != None and
            self.last_percent_through < v_symmetry_time_pct and
            percent_through >= v_symmetry_time_pct
        ):
            # Vertical symmetry occurs on the 2nd dimension of the latent
            height = latents.shape[2]
            y_flipped = torch.flip(latents, dims=[2])
            latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2)

        self.last_percent_through = percent_through
        return latents.to(device=dev)

    def estimate_percent_through(self, step_index, sigma):
        if step_index is not None and self.cross_attention_control_context is not None:
            # percent_through will never reach 1.0 (but this is intended)
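
The mirroring itself is a flip-and-concatenate on one latent axis; a minimal standalone demonstration on a tiny stand-in "latent":

import torch

latents = torch.arange(16.).reshape(1, 1, 2, 8)  # [batch, channel, height, width]
width = latents.shape[3]
x_flipped = torch.flip(latents, dims=[3])
# keep the left half, replace the right half with the mirror of the left
mirrored = torch.cat([latents[:, :, :, :width // 2],
                      x_flipped[:, :, :, width // 2:]], dim=3)
print(mirrored[0, 0, 0])  # tensor([0., 1., 2., 3., 3., 2., 1., 0.])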
@@ -214,7 +214,7 @@ class WeightedPromptFragmentsToEmbeddingsConverter():

    def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor) -> torch.Tensor:
        '''
        Build a tensor that embeds the passed-in token IDs and applyies the given per_token weights
        Build a tensor that embeds the passed-in token IDs and applies the given per_token weights
        :param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints)
        :param per_token_weights: A tensor of shape `[self.max_length]` containing weights (floats)
        :return: A tensor of shape `[1, self.max_length, token_dim]` representing the requested weighted embeddings
@@ -224,13 +224,12 @@ class WeightedPromptFragmentsToEmbeddingsConverter():
|
||||
if token_ids.shape != torch.Size([self.max_length]):
|
||||
raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")
|
||||
|
||||
z = self.text_encoder.forward(input_ids=token_ids.unsqueeze(0),
|
||||
return_dict=False)[0]
|
||||
z = self.text_encoder(token_ids.unsqueeze(0), return_dict=False)[0]
|
||||
empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] +
|
||||
[self.tokenizer.pad_token_id] * (self.max_length-2) +
|
||||
[self.tokenizer.eos_token_id], dtype=torch.int, device=token_ids.device).unsqueeze(0)
|
||||
empty_z = self.text_encoder(input_ids=empty_token_ids).last_hidden_state
|
||||
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
|
||||
[self.tokenizer.eos_token_id], dtype=torch.int, device=z.device).unsqueeze(0)
|
||||
empty_z = self.text_encoder(empty_token_ids).last_hidden_state
|
||||
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape).to(z)
|
||||
z_delta_from_empty = z - empty_z
|
||||
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
|
||||
|
||||
|
||||
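The weighting scheme in build_weighted_embedding_tensor is a per-token linear interpolation between the embedding of an empty prompt and the embedding of the real prompt; a minimal sketch with dummy tensors (torch only; the 77/768 shapes are illustrative SD1 values):

import torch

z = torch.randn(1, 77, 768)        # encoder output for the real prompt
empty_z = torch.randn(1, 77, 768)  # encoder output for a BOS/PAD/EOS-only prompt
per_token_weights = torch.ones(77)
per_token_weights[5] = 1.5         # up-weight token 5

# Weight 1.0 keeps z unchanged; 0.0 collapses a token to the empty embedding.
weights = per_token_weights.reshape(77, 1).expand(z.shape)
weighted_z = empty_z + (z - empty_z) * weights
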
@@ -1,11 +1,12 @@
import os
import traceback
from typing import Optional
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

import torch
from dataclasses import dataclass
from picklescan.scanner import scan_file_path
from transformers import CLIPTokenizer, CLIPTextModel
from transformers import CLIPTextModel, CLIPTokenizer

from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary

@@ -21,11 +22,14 @@ class TextualInversion:
    def embedding_vector_length(self) -> int:
        return self.embedding.shape[0]

class TextualInversionManager():
    def __init__(self,
                 tokenizer: CLIPTokenizer,
                 text_encoder: CLIPTextModel,
                 full_precision: bool=True):

class TextualInversionManager:
    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        full_precision: bool = True,
    ):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.full_precision = full_precision
@@ -38,47 +42,73 @@ class TextualInversionManager():
            if concept_name in self.hf_concepts_library.concepts_loaded:
                continue
            trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
            if self.has_textual_inversion_for_trigger_string(trigger) \
               or self.has_textual_inversion_for_trigger_string(concept_name) \
               or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered
                print(f'>> Loaded local embedding for trigger {concept_name}')
            if (
                self.has_textual_inversion_for_trigger_string(trigger)
                or self.has_textual_inversion_for_trigger_string(concept_name)
                or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
            ):  # in case a token with literal angle brackets encountered
                print(f">> Loaded local embedding for trigger {concept_name}")
                continue
            bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
            if not bin_file:
                continue
            print(f'>> Loaded remote embedding for trigger {concept_name}')
            print(f">> Loaded remote embedding for trigger {concept_name}")
            self.load_textual_inversion(bin_file)
            self.hf_concepts_library.concepts_loaded[concept_name]=True
            self.hf_concepts_library.concepts_loaded[concept_name] = True

    def get_all_trigger_strings(self) -> list[str]:
        return [ti.trigger_string for ti in self.textual_inversions]

    def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
        if str(ckpt_path).endswith('.DS_Store'):
    def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
        ckpt_path = Path(ckpt_path)
        if str(ckpt_path).endswith(".DS_Store"):
            return
        try:
            scan_result = scan_file_path(ckpt_path)
            scan_result = scan_file_path(str(ckpt_path))
            if scan_result.infected_files == 1:
                print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
                print('### For your safety, InvokeAI will not load this embed.')
                print(
                    f"\n### Security Issues Found in Model: {scan_result.issues_count}"
                )
                print("### For your safety, InvokeAI will not load this embed.")
                return
        except Exception:
            print(f"### WARNING::: Invalid or corrupt embeddings found. Ignoring: {ckpt_path}")
            print(
                f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
            )
            return

        embedding_info = self._parse_embedding(str(ckpt_path))

        if embedding_info is None:
            # We've already put out an error message about the bad embedding in _parse_embedding, so just return.
            return
        elif (
            self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
            != embedding_info["embedding"].shape[0]
        ):
            print(
                f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model."
            )
            return

        embedding_info = self._parse_embedding(ckpt_path)
        if embedding_info:
            try:
                self._add_textual_inversion(embedding_info['name'],
                                            embedding_info['embedding'],
                                            defer_injecting_tokens=defer_injecting_tokens)
                self._add_textual_inversion(
                    embedding_info["name"],
                    embedding_info["embedding"],
                    defer_injecting_tokens=defer_injecting_tokens,
                )
            except ValueError as e:
                print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
                print(f' | The error was {str(e)}')
                print(f" | The error was {str(e)}")
        else:
            print(f'>> Failed to load embedding located at {ckpt_path}. Unsupported file.')
            print(
                f">> Failed to load embedding located at {str(ckpt_path)}. Unsupported file."
            )

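The try/except above wraps picklescan; a sketch of the same scan-before-load pattern in isolation (safe_load is a hypothetical helper; it relies only on the scan_file_path fields the hunk already uses):

import torch
from picklescan.scanner import scan_file_path

def safe_load(path: str):
    # Refuse to unpickle anything picklescan flags as malicious.
    scan_result = scan_file_path(path)
    if scan_result.infected_files > 0:
        raise RuntimeError(f"refusing to load {path}: {scan_result.issues_count} issue(s)")
    return torch.load(path, map_location="cpu")
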
    def _add_textual_inversion(self, trigger_str, embedding, defer_injecting_tokens=False) -> TextualInversion:
    def _add_textual_inversion(
        self, trigger_str, embedding, defer_injecting_tokens=False
    ) -> TextualInversion:
        """
        Add a textual inversion to be recognised.
        :param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
@@ -86,46 +116,59 @@ class TextualInversionManager():
        :return: The token id for the added embedding, either existing or newly-added.
        """
        if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
            print(f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'")
            print(
                f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
            )
            return
        if not self.full_precision:
            embedding = embedding.half()
        if len(embedding.shape) == 1:
            embedding = embedding.unsqueeze(0)
        elif len(embedding.shape) > 2:
            raise ValueError(f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2.")
            raise ValueError(
                f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
            )

        try:
            ti = TextualInversion(
                trigger_string=trigger_str,
                embedding=embedding
            )
            ti = TextualInversion(trigger_string=trigger_str, embedding=embedding)
            if not defer_injecting_tokens:
                self._inject_tokens_and_assign_embeddings(ti)
            self.textual_inversions.append(ti)
            return ti

        except ValueError as e:
            if str(e).startswith('Warning'):
            if str(e).startswith("Warning"):
                print(f">> {str(e)}")
            else:
                traceback.print_exc()
                print(f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}.")
                print(
                    f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
                )
                raise

    def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:

        if ti.trigger_token_id is not None:
            raise ValueError(f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'")
            raise ValueError(
                f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'"
            )

        trigger_token_id = self._get_or_create_token_id_and_assign_embedding(ti.trigger_string, ti.embedding[0])
        trigger_token_id = self._get_or_create_token_id_and_assign_embedding(
            ti.trigger_string, ti.embedding[0]
        )

        if ti.embedding_vector_length > 1:
            # for embeddings with vector length > 1
            pad_token_strings = [ti.trigger_string + "-!pad-" + str(pad_index) for pad_index in range(1, ti.embedding_vector_length)]
            pad_token_strings = [
                ti.trigger_string + "-!pad-" + str(pad_index)
                for pad_index in range(1, ti.embedding_vector_length)
            ]
            # todo: batched UI for faster loading when vector length >2
            pad_token_ids = [self._get_or_create_token_id_and_assign_embedding(pad_token_str, ti.embedding[1 + i]) \
                             for (i, pad_token_str) in enumerate(pad_token_strings)]
            pad_token_ids = [
                self._get_or_create_token_id_and_assign_embedding(
                    pad_token_str, ti.embedding[1 + i]
                )
                for (i, pad_token_str) in enumerate(pad_token_strings)
            ]
        else:
            pad_token_ids = []

@@ -133,7 +176,6 @@ class TextualInversionManager():
        ti.pad_token_ids = pad_token_ids
        return ti.trigger_token_id


    def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
        try:
            ti = self.get_textual_inversion_for_trigger_string(trigger_string)
@@ -141,32 +183,43 @@ class TextualInversionManager():
        except StopIteration:
            return False


    def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
        return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)

    def get_textual_inversion_for_trigger_string(
        self, trigger_string: str
    ) -> TextualInversion:
        return next(
            ti for ti in self.textual_inversions if ti.trigger_string == trigger_string
        )

    def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
        return next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id)
        return next(
            ti for ti in self.textual_inversions if ti.trigger_token_id == token_id
        )

    def create_deferred_token_ids_for_any_trigger_terms(self, prompt_string: str) -> list[int]:
    def create_deferred_token_ids_for_any_trigger_terms(
        self, prompt_string: str
    ) -> list[int]:
        injected_token_ids = []
        for ti in self.textual_inversions:
            if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
                if ti.embedding_vector_length > 1:
                    print(f">> Preparing tokens for textual inversion {ti.trigger_string}...")
                    print(
                        f">> Preparing tokens for textual inversion {ti.trigger_string}..."
                    )
                try:
                    self._inject_tokens_and_assign_embeddings(ti)
                except ValueError as e:
                    print(f' | Ignoring incompatible embedding trigger {ti.trigger_string}')
                    print(f' | The error was {str(e)}')
                    print(
                        f" | Ignoring incompatible embedding trigger {ti.trigger_string}"
                    )
                    print(f" | The error was {str(e)}")
                    continue
                injected_token_ids.append(ti.trigger_token_id)
                injected_token_ids.extend(ti.pad_token_ids)
        return injected_token_ids


    def expand_textual_inversion_token_ids_if_necessary(self, prompt_token_ids: list[int]) -> list[int]:
    def expand_textual_inversion_token_ids_if_necessary(
        self, prompt_token_ids: list[int]
    ) -> list[int]:
        """
        Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.

@@ -181,20 +234,31 @@ class TextualInversionManager():
            raise ValueError("prompt_token_ids must not start with bos_token_id")
        if prompt_token_ids[-1] == self.tokenizer.eos_token_id:
            raise ValueError("prompt_token_ids must not end with eos_token_id")
        textual_inversion_trigger_token_ids = [ti.trigger_token_id for ti in self.textual_inversions]
        textual_inversion_trigger_token_ids = [
            ti.trigger_token_id for ti in self.textual_inversions
        ]
        prompt_token_ids = prompt_token_ids.copy()
        for i, token_id in reversed(list(enumerate(prompt_token_ids))):
            if token_id in textual_inversion_trigger_token_ids:
                textual_inversion = next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id)
                for pad_idx in range(0, textual_inversion.embedding_vector_length-1):
                    prompt_token_ids.insert(i+pad_idx+1, textual_inversion.pad_token_ids[pad_idx])
                textual_inversion = next(
                    ti
                    for ti in self.textual_inversions
                    if ti.trigger_token_id == token_id
                )
                for pad_idx in range(0, textual_inversion.embedding_vector_length - 1):
                    prompt_token_ids.insert(
                        i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx]
                    )

        return prompt_token_ids

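A plain-list illustration of what expand_textual_inversion_token_ids_if_necessary does (all ids are hypothetical; a trigger whose embedding spans three vectors pulls in two pad tokens):

trigger_token_id, pad_token_ids = 501, [502, 503]  # hypothetical ids

prompt_token_ids = [12, 501, 99]
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
    if token_id == trigger_token_id:
        for pad_idx, pad_id in enumerate(pad_token_ids):
            prompt_token_ids.insert(i + pad_idx + 1, pad_id)

assert prompt_token_ids == [12, 501, 502, 503, 99]

Iterating in reverse keeps earlier indices valid while the later insertions happen.
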
    def _get_or_create_token_id_and_assign_embedding(self, token_str: str, embedding: torch.Tensor) -> int:
    def _get_or_create_token_id_and_assign_embedding(
        self, token_str: str, embedding: torch.Tensor
    ) -> int:
        if len(embedding.shape) != 1:
            raise ValueError("Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2")
            raise ValueError(
                "Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2"
            )
        existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str)
        if existing_token_id == self.tokenizer.unk_token_id:
            num_tokens_added = self.tokenizer.add_tokens(token_str)
@@ -207,66 +271,79 @@ class TextualInversionManager():
        token_id = self.tokenizer.convert_tokens_to_ids(token_str)
        if token_id == self.tokenizer.unk_token_id:
            raise RuntimeError(f"Unable to find token id for token '{token_str}'")
        if self.text_encoder.get_input_embeddings().weight.data[token_id].shape != embedding.shape:
            raise ValueError(f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}.")
        if (
            self.text_encoder.get_input_embeddings().weight.data[token_id].shape
            != embedding.shape
        ):
            raise ValueError(
                f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}."
            )
        self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

        return token_id

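A sketch of the inject-a-token pattern the method above implements, reduced to stock transformers calls (the model name, trigger string, and random vector are placeholders; the hunk elides its own equivalent of the resize step, so resize_token_embeddings here is an assumption about that elided code):

import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

token_str = "<my-style>"      # placeholder trigger string
embedding = torch.randn(768)  # placeholder; the real code loads this from a .pt/.bin file

if tokenizer.add_tokens(token_str) > 0:
    # Grow the embedding matrix to cover the newly added vocabulary entry.
    text_encoder.resize_token_embeddings(len(tokenizer))
token_id = tokenizer.convert_tokens_to_ids(token_str)
text_encoder.get_input_embeddings().weight.data[token_id] = embedding
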
    def _parse_embedding(self, embedding_file: str):
        file_type = embedding_file.split('.')[-1]
        if file_type == 'pt':
        file_type = embedding_file.split(".")[-1]
        if file_type == "pt":
            return self._parse_embedding_pt(embedding_file)
        elif file_type == 'bin':
        elif file_type == "bin":
            return self._parse_embedding_bin(embedding_file)
        else:
            print(f'>> Not a recognized embedding file: {embedding_file}')
            print(f">> Not a recognized embedding file: {embedding_file}")
            return None

    def _parse_embedding_pt(self, embedding_file):
        embedding_ckpt = torch.load(embedding_file, map_location='cpu')
        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
        embedding_info = {}

        # Check if valid embedding file
        if 'string_to_token' and 'string_to_param' in embedding_ckpt:

        if "string_to_token" and "string_to_param" in embedding_ckpt:
            # Catch variants that do not have the expected keys or values.
            try:
                embedding_info['name'] = embedding_ckpt['name'] or os.path.basename(os.path.splitext(embedding_file)[0])
                embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
                    os.path.splitext(embedding_file)[0]
                )

                # Check num of embeddings and warn user only the first will be used
                embedding_info['num_of_embeddings'] = len(embedding_ckpt["string_to_token"])
                if embedding_info['num_of_embeddings'] > 1:
                    print('>> More than 1 embedding found. Will use the first one')
                embedding_info["num_of_embeddings"] = len(
                    embedding_ckpt["string_to_token"]
                )
                if embedding_info["num_of_embeddings"] > 1:
                    print(">> More than 1 embedding found. Will use the first one")

                embedding = list(embedding_ckpt['string_to_param'].values())[0]
            except (AttributeError,KeyError):
                embedding = list(embedding_ckpt["string_to_param"].values())[0]
            except (AttributeError, KeyError):
                return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)

            embedding_info['embedding'] = embedding
            embedding_info['num_vectors_per_token'] = embedding.size()[0]
            embedding_info['token_dim'] = embedding.size()[1]
            embedding_info["embedding"] = embedding
            embedding_info["num_vectors_per_token"] = embedding.size()[0]
            embedding_info["token_dim"] = embedding.size()[1]

            try:
                embedding_info['trained_steps'] = embedding_ckpt['step']
                embedding_info['trained_model_name'] = embedding_ckpt['sd_checkpoint_name']
                embedding_info['trained_model_checksum'] = embedding_ckpt['sd_checkpoint']
                embedding_info["trained_steps"] = embedding_ckpt["step"]
                embedding_info["trained_model_name"] = embedding_ckpt[
                    "sd_checkpoint_name"
                ]
                embedding_info["trained_model_checksum"] = embedding_ckpt[
                    "sd_checkpoint"
                ]
            except AttributeError:
                print(">> No Training Details Found. Passing ...")

        # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
        # They are actually .bin files
        elif len(embedding_ckpt.keys())==1:
            print('>> Detected .bin file masquerading as .pt file')
        elif len(embedding_ckpt.keys()) == 1:
            print(">> Detected .bin file masquerading as .pt file")
            embedding_info = self._parse_embedding_bin(embedding_file)

        else:
            print('>> Invalid embedding format')
            print(">> Invalid embedding format")
            embedding_info = None

        return embedding_info

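For reference, the checkpoint layout _parse_embedding_pt expects. Note that the membership test above (if 'string_to_token' and 'string_to_param' in embedding_ckpt:) short-circuits on the always-truthy first string literal, so a strict version tests both keys, as in this sketch (the path is a placeholder):

import torch

ckpt = torch.load("some-embedding.pt", map_location="cpu")  # placeholder path
if "string_to_token" in ckpt and "string_to_param" in ckpt:
    embedding = list(ckpt["string_to_param"].values())[0]
    print(embedding.shape)  # [num_vectors_per_token, token_dim]
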
    def _parse_embedding_bin(self, embedding_file):
        embedding_ckpt = torch.load(embedding_file, map_location='cpu')
        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
        embedding_info = {}

        if list(embedding_ckpt.keys()) == 0:
@@ -274,27 +351,45 @@ class TextualInversionManager():
            embedding_info = None
        else:
            for token in list(embedding_ckpt.keys()):
                embedding_info['name'] = token or os.path.basename(os.path.splitext(embedding_file)[0])
                embedding_info['embedding'] = embedding_ckpt[token]
                embedding_info['num_vectors_per_token'] = 1 # All Concepts seem to default to 1
                embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
                embedding_info["name"] = token or os.path.basename(
                    os.path.splitext(embedding_file)[0]
                )
                embedding_info["embedding"] = embedding_ckpt[token]
                embedding_info[
                    "num_vectors_per_token"
                ] = 1  # All Concepts seem to default to 1
                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]

        return embedding_info

    def _handle_broken_pt_variants(self, embedding_ckpt:dict, embedding_file:str)->dict:
        '''
    def _handle_broken_pt_variants(
        self, embedding_ckpt: dict, embedding_file: str
    ) -> dict:
        """
        This handles the broken .pt file variants. We only know of one at present.
        '''
        """
        embedding_info = {}
        if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor):
            print('>> Detected .pt file variant 1') # example at https://github.com/invoke-ai/InvokeAI/issues/1829
            for token in list(embedding_ckpt['string_to_token'].keys()):
                embedding_info['name'] = token if token != '*' else os.path.basename(os.path.splitext(embedding_file)[0])
                embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token]
                embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0]
                embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
        if isinstance(
            list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
        ):
            print(
                ">> Detected .pt file variant 1"
            )  # example at https://github.com/invoke-ai/InvokeAI/issues/1829
            for token in list(embedding_ckpt["string_to_token"].keys()):
                embedding_info["name"] = (
                    token
                    if token != "*"
                    else os.path.basename(os.path.splitext(embedding_file)[0])
                )
                embedding_info["embedding"] = embedding_ckpt[
                    "string_to_param"
                ].state_dict()[token]
                embedding_info["num_vectors_per_token"] = embedding_info[
                    "embedding"
                ].shape[0]
                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
        else:
            print('>> Invalid embedding format')
            print(">> Invalid embedding format")
            embedding_info = None

        return embedding_info

238
ldm/util.py
@@ -1,20 +1,21 @@
import importlib
import math
import multiprocessing as mp
import os
import re
from collections import abc
from inspect import isfunction
from pathlib import Path
from queue import Queue
from threading import Thread
from urllib import request
from tqdm import tqdm
from pathlib import Path
from ldm.invoke.devices import torch_dtype

import numpy as np
import requests
import torch
import os
import traceback
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

from ldm.invoke.devices import torch_dtype


def log_txt_as_img(wh, xc, size=10):
@@ -23,18 +24,18 @@ def log_txt_as_img(wh, xc, size=10):
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new('RGB', wh, color='white')
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        font = ImageFont.load_default()
        nc = int(40 * (wh[0] / 256))
        lines = '\n'.join(
        lines = "\n".join(
            xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
        )

        try:
            draw.text((0, 0), lines, fill='black', font=font)
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print('Cant encode string for logging. Skipping.')
            print("Cant encode string for logging. Skipping.")

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
@@ -77,25 +78,23 @@ def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(
            f' | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
            f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
        )
    return total_params


def instantiate_from_config(config, **kwargs):
    if not 'target' in config:
        if config == '__is_first_stage__':
    if not "target" in config:
        if config == "__is_first_stage__":
            return None
        elif config == '__is_unconditional__':
        elif config == "__is_unconditional__":
            return None
        raise KeyError('Expected key `target` to instantiate.')
    return get_obj_from_str(config['target'])(
        **config.get('params', dict()), **kwargs
    )
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit('.', 1)
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
@@ -111,14 +110,14 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    else:
        res = func(data)
    Q.put([idx, res])
    Q.put('Done')
    Q.put("Done")


def parallel_data_prefetch(
    func: callable,
    data,
    n_proc,
    target_data_type='ndarray',
    target_data_type="ndarray",
    cpu_intensive=True,
    use_worker_id=False,
):
@@ -126,21 +125,21 @@ def parallel_data_prefetch(
    #     raise ValueError(
    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
    #     )
    if isinstance(data, np.ndarray) and target_data_type == 'list':
        raise ValueError('list expected but function got ndarray.')
    if isinstance(data, np.ndarray) and target_data_type == "list":
        raise ValueError("list expected but function got ndarray.")
    elif isinstance(data, abc.Iterable):
        if isinstance(data, dict):
            print(
                f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
                'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
            )
            data = list(data.values())
        if target_data_type == 'ndarray':
        if target_data_type == "ndarray":
            data = np.asarray(data)
        else:
            data = list(data)
    else:
        raise TypeError(
            f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.'
            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
        )

    if cpu_intensive:
@@ -150,7 +149,7 @@ def parallel_data_prefetch(
        Q = Queue(1000)
        proc = Thread
    # spawn processes
    if target_data_type == 'ndarray':
    if target_data_type == "ndarray":
        arguments = [
            [func, Q, part, i, use_worker_id]
            for i, part in enumerate(np.array_split(data, n_proc))
@@ -173,7 +172,7 @@ def parallel_data_prefetch(
        processes += [p]

    # start processes
    print(f'Start prefetching...')
    print("Start prefetching...")
    import time

    start = time.time()
@@ -186,13 +185,13 @@ def parallel_data_prefetch(
        while k < n_proc:
            # get result
            res = Q.get()
            if res == 'Done':
            if res == "Done":
                k += 1
            else:
                gather_res[res[0]] = res[1]

    except Exception as e:
        print('Exception: ', e)
        print("Exception: ", e)
        for p in processes:
            p.terminate()

@@ -200,15 +199,15 @@ def parallel_data_prefetch(
    finally:
        for p in processes:
            p.join()
        print(f'Prefetching complete. [{time.time() - start} sec.]')
        print(f"Prefetching complete. [{time.time() - start} sec.]")

    if target_data_type == 'ndarray':
    if target_data_type == "ndarray":
        if not isinstance(gather_res[0], np.ndarray):
            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)

        # order outputs
        return np.concatenate(gather_res, axis=0)
    elif target_data_type == 'list':
    elif target_data_type == "list":
        out = []
        for r in gather_res:
            out.extend(r)
@@ -216,49 +215,79 @@ def parallel_data_prefetch(
    else:
        return gather_res

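A hedged usage sketch for parallel_data_prefetch (assumes it is importable from ldm.util; cpu_intensive=False selects threads, which sidesteps pickling of the worker function):

from ldm.util import parallel_data_prefetch

def square(batch):
    return [int(x) * int(x) for x in batch]

results = parallel_data_prefetch(
    square, list(range(100)), n_proc=4, target_data_type="list", cpu_intensive=False
)
assert sorted(results) == [x * x for x in range(100)]
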
def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):


def rand_perlin_2d(
    shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
):
    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])

    grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1
    grid = (
        torch.stack(
            torch.meshgrid(
                torch.arange(0, res[0], delta[0]),
                torch.arange(0, res[1], delta[1]),
                indexing="ij",
            ),
            dim=-1,
        ).to(device)
        % 1
    )

    rand_val = torch.rand(res[0]+1, res[1]+1)
    rand_val = torch.rand(res[0] + 1, res[1] + 1)

    angles = 2*math.pi*rand_val
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1).to(device)
    angles = 2 * math.pi * rand_val
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)

    tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
    tile_grads = (
        lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
        .repeat_interleave(d[0], 0)
        .repeat_interleave(d[1], 1)
    )

    dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1)
    dot = lambda grad, shift: (
        torch.stack(
            (
                grid[: shape[0], : shape[1], 0] + shift[0],
                grid[: shape[0], : shape[1], 1] + shift[1],
            ),
            dim=-1,
        )
        * grad[: shape[0], : shape[1]]
    ).sum(dim=-1)

    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
    n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device)
    n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
    t = fade(grid[:shape[0], :shape[1]])
    noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device)
    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device)
    t = fade(grid[: shape[0], : shape[1]])
    noise = math.sqrt(2) * torch.lerp(
        torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
    ).to(device)
    return noise.to(dtype=torch_dtype(device))

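A usage sketch for rand_perlin_2d (assumes it is importable from ldm.util; shape should be divisible by res, since d = shape // res sizes the gradient tiles):

import torch
from ldm.util import rand_perlin_2d

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 512x512 noise field built from an 8x8 grid of gradient cells.
noise = rand_perlin_2d((512, 512), (8, 8), device)
print(noise.shape)  # torch.Size([512, 512])
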
def ask_user(question: str, answers: list):
    from itertools import chain, repeat
    user_prompt = f'\n>> {question} {answers}: '
    invalid_answer_msg = 'Invalid answer. Please try again.'
    pose_question = chain([user_prompt], repeat('\n'.join([invalid_answer_msg, user_prompt])))

    user_prompt = f"\n>> {question} {answers}: "
    invalid_answer_msg = "Invalid answer. Please try again."
    pose_question = chain(
        [user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt]))
    )
    user_answers = map(input, pose_question)
    valid_response = next(filter(answers.__contains__, user_answers))
    return valid_response

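A usage sketch for ask_user; it re-prompts until the reply is one of the allowed answers:

from ldm.util import ask_user

if ask_user("Overwrite the existing file?", ["y", "n"]) == "y":
    print("overwriting")
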
def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False ):
def debug_image(
    debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
):
    if not debug_status:
        return

    image_copy = debug_image.copy().convert("RGBA")
    ImageDraw.Draw(image_copy).text(
        (5, 5),
        debug_text,
        (255, 0, 0)
    )
    ImageDraw.Draw(image_copy).text((5, 5), debug_text, (255, 0, 0))

    if debug_show:
        image_copy.show()
@@ -266,31 +295,84 @@ def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, de
    if debug_result:
        return image_copy

#-------------------------------------
class ProgressBar():
    def __init__(self,model_name='file'):
        self.pbar = None
        self.name = model_name

    def __call__(self, block_num, block_size, total_size):
        if not self.pbar:
            self.pbar=tqdm(desc=self.name,
                           initial=0,
                           unit='iB',
                           unit_scale=True,
                           unit_divisor=1000,
                           total=total_size)
        self.pbar.update(block_size)
# -------------------------------------
def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
    '''
    Download a model file.
    :param url: https, http or ftp URL
    :param dest: A Path object. If path exists and is a directory, then we try to derive the filename
                 from the URL's Content-Disposition header and copy the URL contents into
                 dest/filename
    :param access_token: Access token to access this resource
    '''
    resp = requests.get(url, stream=True)
    total = int(resp.headers.get("content-length", 0))

    if dest.is_dir():
        try:
            file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
        except:
            file_name = os.path.basename(url)
        dest = dest / file_name
    else:
        dest.parent.mkdir(parents=True, exist_ok=True)

    print(f'DEBUG: after many manipulations, dest={dest}')

    header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
    open_mode = "wb"
    exist_size = 0

    if dest.exists():
        exist_size = dest.stat().st_size
        header["Range"] = f"bytes={exist_size}-"
        open_mode = "ab"

    if (
        resp.status_code == 416
    ):  # "range not satisfiable", which means nothing to return
        print(f"* {dest}: complete file found. Skipping.")
        return dest
    elif resp.status_code != 200:
        print(f"** An error occurred during downloading {dest}: {resp.reason}")
    elif exist_size > 0:
        print(f"* {dest}: partial file found. Resuming...")
    else:
        print(f"* {dest}: Downloading...")

def download_with_progress_bar(url:str, dest:Path)->bool:
    try:
        if not dest.exists():
            dest.parent.mkdir(parents=True, exist_ok=True)
            request.urlretrieve(url,dest,ProgressBar(dest.stem))
            return True
        else:
            return True
    except OSError:
        print(traceback.format_exc())
        return False
    if total < 2000:
        print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
        return None

    with open(dest, open_mode) as file, tqdm(
        desc=str(dest),
        initial=exist_size,
        total=total + exist_size,
        unit="iB",
        unit_scale=True,
        unit_divisor=1000,
    ) as bar:
        for data in resp.iter_content(chunk_size=1024):
            size = file.write(data)
            bar.update(size)
    except Exception as e:
        print(f"An error occurred while downloading {dest}: {str(e)}")
        return None

    return dest

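A usage sketch for download_with_resume (URL and destination are placeholders); re-running after an interrupted transfer sends a Range header and appends to the partial file instead of starting over:

from pathlib import Path
from ldm.util import download_with_resume

dest = download_with_resume(
    "https://example.com/models/some-model.ckpt",  # placeholder URL
    Path("models"),
)
if dest is None:
    print("download failed")
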
def url_attachment_name(url: str) -> dict:
    try:
        resp = requests.get(url, stream=True)
        match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
        return match.group(1)
    except:
        return None


def download_with_progress_bar(url: str, dest: Path) -> bool:
    result = download_with_resume(url, dest, access_token=None)
    return result is not None
