Merge branch 'main' into enhance/update-menu

Lincoln Stein
2023-02-20 07:38:56 -05:00
committed by GitHub
502 changed files with 214003 additions and 18781 deletions

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -2,3 +2,12 @@ from ._version import __version__
__app_id__ = 'invoke-ai/InvokeAI'
__app_name__ = 'InvokeAI'
def _ignore_xformers_triton_message_on_windows():
import logging
logging.getLogger("xformers").addFilter(
lambda record: 'A matching Triton is not available' not in record.getMessage())
# In order to be effective, this needs to happen before anything could possibly import xformers.
_ignore_xformers_triton_message_on_windows()

View File

@@ -272,6 +272,10 @@ class Args(object):
switches.append('--seamless')
if a['hires_fix']:
switches.append('--hires_fix')
if a['h_symmetry_time_pct']:
switches.append(f'--h_symmetry_time_pct {a["h_symmetry_time_pct"]}')
if a['v_symmetry_time_pct']:
switches.append(f'--v_symmetry_time_pct {a["v_symmetry_time_pct"]}')
# img2img generations have parameters relevant only to them and have special handling
if a['init_img'] and len(a['init_img'])>0:
@@ -751,6 +755,9 @@ class Args(object):
!fix applies upscaling/facefixing to a previously-generated image.
invoke> !fix 0000045.4829112.png -G1 -U4 -ft codeformer
*embeddings*
invoke> !triggers -- return all trigger phrases contained in loaded embedding files
*History manipulation*
!fetch retrieves the command used to generate an earlier image. Provide
a directory wildcard and the name of a file to write and all the commands
@@ -842,6 +849,18 @@ class Args(object):
type=float,
help='Perlin noise scale (0.0 - 1.0) - add perlin noise to the initialization instead of the usual gaussian noise.',
)
render_group.add_argument(
'--h_symmetry_time_pct',
default=None,
type=float,
help='Horizontal symmetry point (0.0 - 1.0) - apply horizontal symmetry at this point in image generation.',
)
render_group.add_argument(
'--v_symmetry_time_pct',
default=None,
type=float,
help='Vertical symmetry point (0.0 - 1.0) - apply vertical symmetry at this point in image generation.',
)
render_group.add_argument(
'--fnformat',
default='{prefix}.{seed}.png',
@@ -1148,7 +1167,8 @@ def metadata_dumps(opt,
# remove any image keys not mentioned in RFC #266
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
'cfg_scale','threshold','perlin','step_number','width','height','extra','strength','seamless'
'init_img','init_mask','facetool','facetool_strength','upscale']
'init_img','init_mask','facetool','facetool_strength','upscale','h_symmetry_time_pct',
'v_symmetry_time_pct']
rfc_dict ={}
for item in image_dict.items():
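For context, the two new switches apply latent-space mirroring once generation crosses the given fraction of total steps. A hypothetical invocation in the CLI's usual style (prompt and values are illustrative, not taken from this commit):

invoke> "a symmetrical stone archway" --h_symmetry_time_pct 0.3 --v_symmetry_time_pct 0.5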

View File

@@ -53,6 +53,7 @@ from diffusers import (
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import is_safetensors_available
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
@@ -984,6 +985,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']:
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
safety_checker = StableDiffusionSafetyChecker.from_pretrained('CompVis/stable-diffusion-safety-checker',cache_dir=global_cache_dir("hub"))
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
pipe = pipeline_class(
vae=vae,
@@ -991,7 +993,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
else:

View File

@@ -93,7 +93,7 @@ def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], p
Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
"""
if log_tokens or Globals.log_tokenization:
if log_tokens or getattr(Globals, "log_tokenization", False):
print(f"\n>> [TOKENLOG] Parsed Prompt: {parsed_prompt}")
print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {parsed_negative_prompt}")
@@ -236,7 +236,7 @@ def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedProm
fragments = [x.text for x in flattened_prompt.children]
weights = [x.weight for x in flattened_prompt.children]
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
if log_tokens or Globals.log_tokenization:
if log_tokens or getattr(Globals, "log_tokenization", False):
text = " ".join(fragments)
log_tokenization(text, model, display_label=log_display_label)
@@ -296,4 +296,4 @@ def log_tokenization(text, model, display_label=None):
if discarded != "":
print(f'\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):')
print(f'{discarded}\x1b[0m')

View File

@@ -40,7 +40,6 @@ from ldm.invoke.globals import Globals, global_cache_dir, global_config_dir
from ldm.invoke.readline import generic_completer
warnings.filterwarnings("ignore")
import torch
transformers.logging.set_verbosity_error()
@@ -764,7 +763,7 @@ def download_weights(opt: dict) -> Union[str, None]:
precision = (
"float32"
if opt.full_precision
else choose_precision(torch.device(choose_torch_device()))
else choose_precision(choose_torch_device())
)
if opt.yes_to_all:

View File

@@ -1,19 +1,25 @@
from __future__ import annotations
from contextlib import nullcontext
import torch
from torch import autocast
from contextlib import nullcontext
from ldm.invoke.globals import Globals
def choose_torch_device() -> str:
CPU_DEVICE = torch.device("cpu")
def choose_torch_device() -> torch.device:
'''Convenience routine for guessing which GPU device to run model on'''
if Globals.always_use_cpu:
return "cpu"
return CPU_DEVICE
if torch.cuda.is_available():
return 'cuda'
return torch.device('cuda')
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return 'mps'
return 'cpu'
return torch.device('mps')
return CPU_DEVICE
def choose_precision(device) -> str:
def choose_precision(device: torch.device) -> str:
'''Returns an appropriate precision for the given torch device'''
if device.type == 'cuda':
device_name = torch.cuda.get_device_name(device)
@@ -21,7 +27,7 @@ def choose_precision(device) -> str:
return 'float16'
return 'float32'
def torch_dtype(device) -> torch.dtype:
def torch_dtype(device: torch.device) -> torch.dtype:
if Globals.full_precision:
return torch.float32
if choose_precision(device) == 'float16':
@@ -36,3 +42,13 @@ def choose_autocast(precision):
if precision == 'autocast' or precision == 'float16':
return autocast
return nullcontext
def normalize_device(device: str | torch.device) -> torch.device:
"""Ensure device has a device index defined, if appropriate."""
device = torch.device(device)
if device.index is None:
# cuda might be the only torch backend that currently uses the device index?
# I don't see anything like `current_device` for cpu or mps.
if device.type == 'cuda':
device = torch.device(device.type, torch.cuda.current_device())
return device
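A minimal sketch of the reworked helpers in use, assuming the module layout above (the CUDA-specific behavior only appears on a machine with a visible GPU):

import torch
from ldm.invoke.devices import choose_torch_device, choose_precision, normalize_device

device = choose_torch_device()      # now a torch.device, no longer a bare string
print(choose_precision(device))     # 'float16' on supported CUDA GPUs, else 'float32'
print(normalize_device("cpu"))      # device(type='cpu') - no index is added
# on a CUDA machine, normalize_device("cuda") fills in the current device index,
# e.g. device(type='cuda', index=0)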

View File

@@ -64,6 +64,7 @@ class Generator:
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
h_symmetry_time_pct=None, v_symmetry_time_pct=None,
safety_checker:dict=None,
free_gpu_mem: bool=False,
**kwargs):
@@ -81,6 +82,8 @@ class Generator:
step_callback = step_callback,
threshold = threshold,
perlin = perlin,
h_symmetry_time_pct = h_symmetry_time_pct,
v_symmetry_time_pct = v_symmetry_time_pct,
attention_maps_callback = attention_maps_callback,
**kwargs
)
@@ -247,11 +250,14 @@ class Generator:
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(self.latent_channels, 4)
# round up to the nearest block of 8
temp_width = int((width + 7) / 8) * 8
temp_height = int((height + 7) / 8) * 8
noise = torch.stack([
rand_perlin_2d((height, width),
rand_perlin_2d((temp_height, temp_width),
(8, 8),
device = self.model.device).to(fixdevice) for _ in range(input_channels)], dim=0).to(self.model.device)
return noise
return noise[0:4, 0:height, 0:width]
def new_seed(self):
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
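The temp_width/temp_height arithmetic above rounds up to the next multiple of 8, apparently so the padded dimensions divide evenly by the (8, 8) Perlin grid resolution, and the final slice crops back to the requested size. An equivalent integer-only form, as a standalone check:

def round_up_to_multiple_of_8(n: int) -> int:
    return (n + 7) // 8 * 8

assert round_up_to_multiple_of_8(512) == 512
assert round_up_to_multiple_of_8(513) == 520  # padded, then cropped by noise[0:4, 0:height, 0:width]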

View File

@@ -3,39 +3,35 @@ from __future__ import annotations
import dataclasses
import inspect
import secrets
import sys
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any
if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
else:
from typing import ParamSpec
import PIL.Image
import einops
import psutil
import torch
import torchvision.transforms as T
from diffusers.utils.import_utils import is_xformers_available
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec
from ldm.invoke.globals import Globals
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager
from ..devices import normalize_device, CPU_DEVICE
from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
@dataclass
@@ -264,6 +260,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_model_group: ModelGroup
ID_LENGTH = 8
@@ -273,7 +270,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
@@ -303,8 +300,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
textual_inversion_manager=self.textual_inversion_manager
)
self._model_group = FullyLoadedModelGroup(self.unet.device)
self._model_group.install(*self._submodels)
def _adjust_memory_efficient_attention(self, latents: Torch.tensor):
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
"""
if xformers is available, use it, otherwise use sliced attention.
"""
@@ -320,9 +320,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if self.device.type == 'cpu' or self.device.type == 'mps':
mem_free = psutil.virtual_memory().free
elif self.device.type == 'cuda':
mem_free, _ = torch.cuda.mem_get_info(self.device)
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
else:
raise ValueError(f"unrecognized device {device}")
raise ValueError(f"unrecognized device {self.device}")
# input tensor of [1, 4, h/8, w/8]
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
@@ -336,6 +336,67 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self.disable_attention_slicing()
def enable_offload_submodels(self, device: torch.device):
"""
Offload each submodel when it's not in use.
Useful for low-vRAM situations where the size of the model in memory is a big chunk of
the total available resource, and you want to free up as much for inference as possible.
This requires more moving parts and may add some delay as the U-Net is swapped out for the
VAE and vice-versa.
"""
models = self._submodels
if self._model_group is not None:
self._model_group.uninstall(*models)
group = LazilyLoadedModelGroup(device)
group.install(*models)
self._model_group = group
def disable_offload_submodels(self):
"""
Leave all submodels loaded.
Appropriate for cases where the size of the model in memory is small compared to the memory
required for inference. Avoids the delay and complexity of shuffling the submodels to and
from the GPU.
"""
models = self._submodels
if self._model_group is not None:
self._model_group.uninstall(*models)
group = FullyLoadedModelGroup(self._model_group.execution_device)
group.install(*models)
self._model_group = group
def offload_all(self):
"""Offload all this pipeline's models to CPU."""
self._model_group.offload_current()
def ready(self):
"""
Ready this pipeline's models.
i.e. pre-load them to the GPU if appropriate.
"""
self._model_group.ready()
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
# overridden method; types match the superclass.
if torch_device is None:
return self
self._model_group.set_device(torch.device(torch_device))
self._model_group.ready()
@property
def device(self) -> torch.device:
return self._model_group.execution_device
@property
def _submodels(self) -> Sequence[torch.nn.Module]:
module_names, _, _ = self.extract_init_dict(dict(self.config))
values = [getattr(self, name) for name in module_names.keys()]
return [m for m in values if isinstance(m, torch.nn.Module)]
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
conditioning_data: ConditioningData,
*,
@@ -377,7 +438,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
callback: Callable[[PipelineIntermediateState], None] = None
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if timesteps is None:
self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
self.scheduler.set_timesteps(num_inference_steps, device=self._model_group.device_for(self.unet))
timesteps = self.scheduler.timesteps
infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
result: PipelineIntermediateState = infer_latents_from_embeddings(
@@ -409,7 +470,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
batch_size = latents.shape[0]
batched_t = torch.full((batch_size,), timesteps[0],
dtype=timesteps.dtype, device=self.unet.device)
dtype=timesteps.dtype, device=self._model_group.device_for(self.unet))
latents = self.scheduler.add_noise(latents, noise, batched_t)
attention_map_saver: Optional[AttentionMapSaver] = None
@@ -493,9 +554,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
).add_mask_channels(latents)
return self.unet(sample=latents,
timestep=t,
encoder_hidden_states=text_embeddings,
# First three args should be positional, not keywords, so torch hooks can see them.
return self.unet(latents, t, text_embeddings,
cross_attention_kwargs=cross_attention_kwargs).sample
def img2img_from_embeddings(self,
@@ -514,9 +574,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
# 6. Prepare latent variables
device = self.unet.device
latents_dtype = self.unet.dtype
initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
initial_latents = self.non_noised_latents_from_image(
init_image, device=self._model_group.device_for(self.unet),
dtype=self.unet.dtype)
noise = noise_func(initial_latents)
return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
@@ -529,7 +589,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
strength,
noise: torch.Tensor, run_id=None, callback=None
) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, self.unet.device)
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength,
device=self._model_group.device_for(self.unet))
result_latents, result_attention_maps = self.latents_from_embeddings(
initial_latents, num_inference_steps, conditioning_data,
timesteps=timesteps,
@@ -568,7 +629,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id=None,
noise_func=None,
) -> InvokeAIStableDiffusionPipelineOutput:
device = self.unet.device
device = self._model_group.device_for(self.unet)
latents_dtype = self.unet.dtype
if isinstance(init_image, PIL.Image.Image):
@@ -630,8 +691,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if device.type == 'mps':
# workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
# TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
self.vae.to('cpu')
init_image = init_image.to('cpu')
self.vae.to(CPU_DEVICE)
init_image = init_image.to(CPU_DEVICE)
else:
self._model_group.load(self.vae)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible!
if device.type == 'mps':
@@ -643,8 +706,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def check_for_safety(self, output, dtype):
with torch.inference_mode():
screened_images, has_nsfw_concept = self.run_safety_checker(
output.images, device=self._execution_device, dtype=dtype)
screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
screened_attention_map_saver = None
if has_nsfw_concept is None or not has_nsfw_concept:
screened_attention_map_saver = output.attention_map_saver
@@ -653,6 +715,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# block the attention maps if NSFW content is detected
attention_map_saver=screened_attention_map_saver)
def run_safety_checker(self, image, device=None, dtype=None):
# overriding to use the model group for device info instead of requiring the caller to know.
if self.safety_checker is not None:
device = self._model_group.device_for(self.safety_checker)
return super().run_safety_checker(image, device, dtype)
@torch.inference_mode()
def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
"""
@@ -662,7 +730,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
text=c,
fragment_weights=fragment_weights,
should_return_tokens=return_tokens,
device=self.device)
device=self._model_group.device_for(self.unet))
@property
def cond_stage_model(self):
@@ -683,6 +751,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
"""Compatible with DiffusionWrapper"""
return self.unet.in_channels
def decode_latents(self, latents):
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
self._model_group.load(self.vae)
return super().decode_latents(latents)
def debug_latents(self, latents, msg):
with torch.inference_mode():
from ldm.util import debug_image

View File

@@ -16,8 +16,8 @@ class Img2Img(Generator):
self.init_latent = None # by get_noise()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,
attention_maps_callback=None,
conditioning,init_image,strength,step_callback=None,threshold=0.0,warmup=0.2,perlin=0.0,
h_symmetry_time_pct=None,v_symmetry_time_pct=None,attention_maps_callback=None,
**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
@@ -33,8 +33,13 @@ class Img2Img(Generator):
conditioning_data = (
ConditioningData(
uc, c, cfg_scale, extra_conditioning_info,
postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
postprocessing_settings=PostprocessingSettings(
threshold=threshold,
warmup=warmup,
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct
)
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
def make_image(x_T):

View File

@@ -15,8 +15,8 @@ class Txt2Img(Generator):
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
attention_maps_callback=None,
conditioning,width,height,step_callback=None,threshold=0.0,warmup=0.2,perlin=0.0,
h_symmetry_time_pct=None,v_symmetry_time_pct=None,attention_maps_callback=None,
**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
@@ -33,8 +33,13 @@ class Txt2Img(Generator):
conditioning_data = (
ConditioningData(
uc, c, cfg_scale, extra_conditioning_info,
postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
postprocessing_settings=PostprocessingSettings(
threshold=threshold,
warmup=warmup,
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct
)
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
def make_image(x_T) -> PIL.Image.Image:
pipeline_output = pipeline.image_from_embeddings(
@@ -44,8 +49,10 @@ class Txt2Img(Generator):
conditioning_data=conditioning_data,
callback=step_callback,
)
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
return make_image

View File

@@ -21,12 +21,14 @@ class Txt2Img2Img(Generator):
def get_make_image(self, prompt:str, sampler, steps:int, cfg_scale:float, ddim_eta,
conditioning, width:int, height:int, strength:float,
step_callback:Optional[Callable]=None, threshold=0.0, **kwargs):
step_callback:Optional[Callable]=None, threshold=0.0, warmup=0.2, perlin=0.0,
h_symmetry_time_pct=None, v_symmetry_time_pct=None, attention_maps_callback=None, **kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
kwargs are 'width' and 'height'
"""
self.perlin = perlin
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
@@ -36,8 +38,13 @@ class Txt2Img2Img(Generator):
conditioning_data = (
ConditioningData(
uc, c, cfg_scale, extra_conditioning_info,
postprocessing_settings = PostprocessingSettings(threshold=threshold, warmup=0.2) if threshold else None)
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
postprocessing_settings = PostprocessingSettings(
threshold=threshold,
warmup=warmup,
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct
)
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
def make_image(x_T):
@@ -69,19 +76,28 @@ class Txt2Img2Img(Generator):
if clear_cuda_cache is not None:
clear_cuda_cache()
second_pass_noise = self.get_noise_like(resized_latents)
second_pass_noise = self.get_noise_like(resized_latents, override_perlin=True)
# Clear symmetry for the second pass
from dataclasses import replace
new_postprocessing_settings = replace(conditioning_data.postprocessing_settings, h_symmetry_time_pct=None)
new_postprocessing_settings = replace(new_postprocessing_settings, v_symmetry_time_pct=None)
new_conditioning_data = replace(conditioning_data, postprocessing_settings=new_postprocessing_settings)
verbosity = get_verbosity()
set_verbosity_error()
pipeline_output = pipeline.img2img_from_latents_and_embeddings(
resized_latents,
num_inference_steps=steps,
conditioning_data=conditioning_data,
conditioning_data=new_conditioning_data,
strength=strength,
noise=second_pass_noise,
callback=step_callback)
set_verbosity(verbosity)
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
@@ -95,13 +111,13 @@ class Txt2Img2Img(Generator):
return make_image
def get_noise_like(self, like: torch.Tensor):
def get_noise_like(self, like: torch.Tensor, override_perlin: bool=False):
device = like.device
if device.type == 'mps':
x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device)
else:
x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
if self.perlin > 0.0:
if self.perlin > 0.0 and not override_perlin:
shape = like.shape
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
return x
@@ -139,6 +155,9 @@ class Txt2Img2Img(Generator):
shape = (1, channels,
scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor)
if self.use_mps_noise or device.type == 'mps':
return torch.randn(shape, dtype=self.torch_dtype(), device='cpu').to(device)
tensor = torch.empty(size=shape, device='cpu')
tensor = self.get_noise_like(like=tensor).to(device)
else:
return torch.randn(shape, dtype=self.torch_dtype(), device=device)
tensor = torch.empty(size=shape, device=device)
tensor = self.get_noise_like(like=tensor)
return tensor
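The perlin branch above is a plain linear blend of Gaussian and Perlin noise, now skipped when override_perlin is set for the second pass. A tensor-only sketch with a random stand-in for the Perlin field:

import torch

p = 0.25                                   # self.perlin
gauss = torch.randn(1, 4, 64, 64)
perlin_field = torch.randn(1, 4, 64, 64)   # stand-in for self.get_perlin_noise(w, h)
x = (1 - p) * gauss + p * perlin_field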

View File

@@ -33,7 +33,7 @@ Globals.models_file = 'models.yaml'
Globals.models_dir = 'models'
Globals.config_dir = 'configs'
Globals.autoscan_dir = 'weights'
Globals.converted_ckpts_dir = 'converted-ckpts'
Globals.converted_ckpts_dir = 'converted_ckpts'
# Try loading patchmatch
Globals.try_patchmatch = True
@@ -54,6 +54,9 @@ Globals.full_precision = False
# whether we should convert ckpt files into diffusers models on the fly
Globals.ckpt_convert = False
# logging tokenization everywhere
Globals.log_tokenization = False
def global_config_file()->Path:
return Path(Globals.root, Globals.config_dir, Globals.models_file)
@@ -66,6 +69,9 @@ def global_models_dir()->Path:
def global_autoscan_dir()->Path:
return Path(Globals.root, Globals.autoscan_dir)
def global_converted_ckpts_dir()->Path:
return Path(global_models_dir(), Globals.converted_ckpts_dir)
def global_set_root(root_dir:Union[str,Path]):
Globals.root = root_dir

View File

@@ -79,8 +79,8 @@ def merge_diffusion_models_and_commit(
merged_model_name = name for new model
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation, which is weighted sum. For merging three checkpoints, only "add_difference" (A+(B-C)) is supported.
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
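In the docstring's terms, with checkpoints A, B, C and interpolation weight alpha, the default weighted sum is (1-alpha)*A + alpha*B, and "add_difference" applies A+(B-C). A tensor-only sketch of the arithmetic (the alpha scaling on the difference term is an assumption about the underlying diffusers merger, not spelled out here):

import torch

a, b, c = (torch.randn(8) for _ in range(3))  # stand-ins for matching checkpoint tensors
alpha = 0.5
weighted_sum = (1 - alpha) * a + alpha * b    # default interpolation
add_difference = a + alpha * (b - c)          # "A+(B-C)", difference assumed scaled by alpha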
@@ -173,7 +173,6 @@ def _parse_args() -> Namespace:
# ------------------------- GUI HERE -------------------------
class FloatSlider(npyscreen.Slider):
# this is supposed to adjust display precision, but doesn't
def translate_value(self):
stri = "%3.2f / %3.2f" % (self.value, self.out_of)
l = (len(str(self.out_of))) * 2 + 4
@@ -186,7 +185,7 @@ class FloatTitleSlider(npyscreen.TitleText):
class mergeModelsForm(npyscreen.FormMultiPageAction):
interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"]
interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]
def __init__(self, parentApp, name):
self.parentApp = parentApp
@@ -305,8 +304,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
self.alpha = self.add_widget_intelligent(
FloatTitleSlider,
name="Weight (alpha) to assign to second and third models:",
out_of=1,
step=0.05,
out_of=1.0,
step=0.01,
lowest=0,
value=0.5,
labelColor="CONTROL",
@@ -323,7 +322,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
self.merged_model_name.value = merged_model_name
if selected_model3 > 0:
self.merge_method.values = (["add_difference"],)
self.merge_method.values = ['add_difference ( A+(B-C) )']
self.merged_model_name.value += f"+{models[selected_model3]}"
else:
self.merge_method.values = self.interpolations
@@ -349,11 +348,14 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
]
if self.model3.value[0] > 0:
models.append(model_names[self.model3.value[0] - 1])
interp='add_difference'
else:
interp=self.interpolations[self.merge_method.value[0]]
args = dict(
models=models,
alpha=self.alpha.value,
interp=self.interpolations[self.merge_method.value[0]],
interp=interp,
force=self.force.value,
merged_model_name=self.merged_model_name.value,
)

View File

@@ -25,19 +25,18 @@ import torch
import transformers
from diffusers import AutoencoderKL
from diffusers import logging as dlogging
from diffusers.utils.logging import (get_verbosity, set_verbosity,
set_verbosity_error)
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path
from ldm.invoke.devices import CPU_DEVICE
from ldm.invoke.generator.diffusers_pipeline import \
StableDiffusionGeneratorPipeline
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
global_models_dir)
from ldm.util import (ask_user, download_with_progress_bar,
instantiate_from_config)
from ldm.util import (ask_user, download_with_resume,
url_attachment_name, instantiate_from_config)
DEFAULT_MAX_MODELS = 2
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
@@ -49,9 +48,10 @@ class ModelManager(object):
def __init__(
self,
config: OmegaConf,
device_type: str = "cpu",
device_type: torch.device = CPU_DEVICE,
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload = False
):
"""
Initialize with the path to the models.yaml config file,
@@ -69,6 +69,7 @@ class ModelManager(object):
self.models = {}
self.stack = [] # this is an LRU FIFO
self.current_model = None
self.sequential_offload = sequential_offload
def valid_model(self, model_name: str) -> bool:
"""
@@ -529,7 +530,10 @@ class ModelManager(object):
dlogging.set_verbosity(verbosity)
assert pipeline is not None, OSError(f'"{name_or_path}" could not be loaded')
pipeline.to(self.device)
if self.sequential_offload:
pipeline.enable_offload_submodels(self.device)
else:
pipeline.to(self.device)
model_hash = self._diffuser_sha256(name_or_path)
@@ -670,15 +674,18 @@ class ModelManager(object):
path to the configuration file, then the new entry will be committed to the
models.yaml file.
"""
if str(weights).startswith(("http:", "https:")):
model_name = model_name or url_attachment_name(weights)
weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
config_path = self._resolve_path(config, "configs/stable-diffusion")
if weights_path is None or not weights_path.exists():
return False
if config_path is None or not config_path.exists():
return False
model_name = model_name or Path(weights).stem
model_name = model_name or Path(weights).stem # note this gives ugly pathnames if used on a URL without a Content-Disposition header
model_description = (
model_description or f"imported stable diffusion weights file {model_name}"
)
@@ -748,7 +755,6 @@ class ModelManager(object):
into models.yaml.
"""
new_config = None
import transformers
from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
@@ -759,7 +765,7 @@ class ModelManager(object):
return
model_name = model_name or diffusers_path.name
model_description = model_description or "Optimized version of {model_name}"
model_description = model_description or f"Optimized version of {model_name}"
print(f">> Optimizing {model_name} (30-60s)")
try:
# By passing the specified VAE too the conversion function, the autoencoder
@@ -799,15 +805,17 @@ class ModelManager(object):
models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
ckpt_files = [x for x in models_folder_ckpt if x.is_file()]
safetensor_files = [x for x in models_folder_safetensors if x.is_file]
safetensor_files = [x for x in models_folder_safetensors if x.is_file()]
files = ckpt_files + safetensor_files
found_models = []
for file in files:
found_models.append(
{"name": file.stem, "location": str(file.resolve()).replace("\\", "/")}
)
location = str(file.resolve()).replace("\\", "/")
if 'model.safetensors' not in location and 'diffusion_pytorch_model.safetensors' not in location:
found_models.append(
{"name": file.stem, "location": location}
)
return search_folder, found_models
@@ -967,16 +975,15 @@ class ModelManager(object):
print("** Migration is done. Continuing...")
def _resolve_path(
self, source: Union[str, Path], dest_directory: str
) -> Optional[Path]:
resolved_path = None
if str(source).startswith(("http:", "https:", "ftp:")):
basename = os.path.basename(source)
if not os.path.isabs(dest_directory):
dest_directory = os.path.join(Globals.root, dest_directory)
dest = os.path.join(dest_directory, basename)
if download_with_progress_bar(str(source), Path(dest)):
resolved_path = Path(dest)
dest_directory = Path(dest_directory)
if not dest_directory.is_absolute():
dest_directory = Globals.root / dest_directory
dest_directory.mkdir(parents=True, exist_ok=True)
resolved_path = download_with_resume(str(source), dest_directory)
else:
if not os.path.isabs(source):
source = os.path.join(Globals.root, source)
@@ -990,25 +997,29 @@ class ModelManager(object):
self.models.pop(model_name, None)
def _model_to_cpu(self, model):
if self.device == "cpu":
if self.device == CPU_DEVICE:
return model
# diffusers really really doesn't like us moving a float16 model onto CPU
verbosity = get_verbosity()
set_verbosity_error()
model.cond_stage_model.device = "cpu"
model.to("cpu")
set_verbosity(verbosity)
if isinstance(model, StableDiffusionGeneratorPipeline):
model.offload_all()
return model
model.cond_stage_model.device = CPU_DEVICE
model.to(CPU_DEVICE)
for submodel in ("first_stage_model", "cond_stage_model", "model"):
try:
getattr(model, submodel).to("cpu")
getattr(model, submodel).to(CPU_DEVICE)
except AttributeError:
pass
return model
def _model_from_cpu(self, model):
if self.device == "cpu":
if self.device == CPU_DEVICE:
return model
if isinstance(model, StableDiffusionGeneratorPipeline):
model.ready()
return model
model.to(self.device)
@@ -1161,7 +1172,7 @@ class ModelManager(object):
strategy.execute()
@staticmethod
def _abs_path(path: Union(str, Path)) -> Path:
def _abs_path(path: str | Path) -> Path:
if path is None or Path(path).is_absolute():
return path
return Path(Globals.root, path).resolve()

ldm/invoke/offloading.py (new file, 247 lines)
View File

@@ -0,0 +1,247 @@
from __future__ import annotations
import warnings
import weakref
from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping
from typing import Callable
import torch
from accelerate.utils import send_to_device
from torch.utils.hooks import RemovableHandle
OFFLOAD_DEVICE = torch.device("cpu")
class _NoModel:
"""Symbol that indicates no model is loaded.
(We can't weakref.ref(None), so this was my best idea at the time to come up with something
type-checkable.)
"""
def __bool__(self):
return False
def to(self, device: torch.device):
pass
def __repr__(self):
return "<NO MODEL>"
NO_MODEL = _NoModel()
class ModelGroup(metaclass=ABCMeta):
"""
A group of models.
The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline,
e.g. its text encoder, U-net, VAE, etc.
Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with
:py:class:`torch.nn.Module` here.
"""
def __init__(self, execution_device: torch.device):
self.execution_device = execution_device
@abstractmethod
def install(self, *models: torch.nn.Module):
"""Add models to this group."""
pass
@abstractmethod
def uninstall(self, models: torch.nn.Module):
"""Remove models from this group."""
pass
@abstractmethod
def uninstall_all(self):
"""Remove all models from this group."""
@abstractmethod
def load(self, model: torch.nn.Module):
"""Load this model to the execution device."""
pass
@abstractmethod
def offload_current(self):
"""Offload the current model(s) from the execution device."""
pass
@abstractmethod
def ready(self):
"""Ready this group for use."""
pass
@abstractmethod
def set_device(self, device: torch.device):
"""Change which device models from this group will execute on."""
pass
@abstractmethod
def device_for(self, model) -> torch.device:
"""Get the device the given model will execute on.
The model should already be a member of this group.
"""
pass
@abstractmethod
def __contains__(self, model):
"""Check if the model is a member of this group."""
pass
def __repr__(self) -> str:
return f"<{self.__class__.__name__} object at {id(self):x}: " \
f"device={self.execution_device} >"
class LazilyLoadedModelGroup(ModelGroup):
"""
Only one model from this group is loaded on the GPU at a time.
Running the forward method of a model will displace the previously-loaded model,
offloading it to CPU.
If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``,
you will need to explicitly load it with :py:meth:`.load(model)`.
This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments
to the appropriate execution device, as long as they are positional arguments and not keyword
arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.)
"""
_hooks: MutableMapping[torch.nn.Module, RemovableHandle]
_current_model_ref: Callable[[], torch.nn.Module | _NoModel]
def __init__(self, execution_device: torch.device):
super().__init__(execution_device)
self._hooks = weakref.WeakKeyDictionary()
self._current_model_ref = weakref.ref(NO_MODEL)
def install(self, *models: torch.nn.Module):
for model in models:
self._hooks[model] = model.register_forward_pre_hook(self._pre_hook)
def uninstall(self, *models: torch.nn.Module):
for model in models:
hook = self._hooks.pop(model)
hook.remove()
if self.is_current_model(model):
# no longer hooked by this object, so don't claim to manage it
self.clear_current_model()
def uninstall_all(self):
self.uninstall(*self._hooks.keys())
def _pre_hook(self, module: torch.nn.Module, forward_input):
self.load(module)
if len(forward_input) == 0:
warnings.warn(f"Hook for {module.__class__.__name__} got no input. "
f"Inputs must be positional, not keywords.", stacklevel=3)
return send_to_device(forward_input, self.execution_device)
def load(self, module):
if not self.is_current_model(module):
self.offload_current()
self._load(module)
def offload_current(self):
module = self._current_model_ref()
if module is not NO_MODEL:
module.to(device=OFFLOAD_DEVICE)
self.clear_current_model()
def _load(self, module: torch.nn.Module) -> torch.nn.Module:
assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
module = module.to(self.execution_device)
self.set_current_model(module)
return module
def is_current_model(self, model: torch.nn.Module) -> bool:
"""Is the given model the one currently loaded on the execution device?"""
return self._current_model_ref() is model
def is_empty(self):
"""Are none of this group's models loaded on the execution device?"""
return self._current_model_ref() is NO_MODEL
def set_current_model(self, value):
self._current_model_ref = weakref.ref(value)
def clear_current_model(self):
self._current_model_ref = weakref.ref(NO_MODEL)
def set_device(self, device: torch.device):
if device == self.execution_device:
return
self.execution_device = device
current = self._current_model_ref()
if current is not NO_MODEL:
current.to(device)
def device_for(self, model):
if model not in self:
raise KeyError(f"This does not manage this model {type(model).__name__}", model)
return self.execution_device # this implementation only dispatches to one device
def ready(self):
pass # always ready to load on-demand
def __contains__(self, model):
return model in self._hooks
def __repr__(self) -> str:
return f"<{self.__class__.__name__} object at {id(self):x}: " \
f"current_model={type(self._current_model_ref()).__name__} >"
class FullyLoadedModelGroup(ModelGroup):
"""
A group of models without any implicit loading or unloading.
:py:meth:`.ready` loads _all_ the models to the execution device at once.
"""
_models: weakref.WeakSet
def __init__(self, execution_device: torch.device):
super().__init__(execution_device)
self._models = weakref.WeakSet()
def install(self, *models: torch.nn.Module):
for model in models:
self._models.add(model)
model.to(device=self.execution_device)
def uninstall(self, *models: torch.nn.Module):
for model in models:
self._models.remove(model)
def uninstall_all(self):
self.uninstall(*self._models)
def load(self, model):
model.to(device=self.execution_device)
def offload_current(self):
for model in self._models:
model.to(device=OFFLOAD_DEVICE)
def ready(self):
for model in self._models:
self.load(model)
def set_device(self, device: torch.device):
self.execution_device = device
for model in self._models:
if model.device != OFFLOAD_DEVICE:
model.to(device=device)
def device_for(self, model):
if model not in self:
raise KeyError(f"This does not manage this model {type(model).__name__}", model)
return self.execution_device # this implementation only dispatches to one device
def __contains__(self, model):
return model in self._models

View File

@@ -58,9 +58,11 @@ COMMANDS = (
'--inpaint_replace','-r',
'--png_compression','-z',
'--text_mask','-tm',
'--h_symmetry_time_pct',
'--v_symmetry_time_pct',
'!fix','!fetch','!replay','!history','!search','!clear',
'!models','!switch','!import_model','!optimize_model','!convert_model','!edit_model','!del_model',
'!mask',
'!mask','!triggers',
)
MODEL_COMMANDS = (
'!switch',
@@ -138,7 +140,7 @@ class Completer(object):
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
self.matches= self._model_completions(text, state)
# looking for a ckpt model
elif re.match('^'+'|'.join(CKPT_MODEL_COMMANDS),buffer):
self.matches= self._model_completions(text, state, ckpt_only=True)
@@ -255,7 +257,7 @@ class Completer(object):
update our list of models
'''
self.models = models
def _seed_completions(self, text, state):
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
if m:

View File

@@ -18,6 +18,8 @@ from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
class PostprocessingSettings:
threshold: float
warmup: float
h_symmetry_time_pct: Optional[float]
v_symmetry_time_pct: Optional[float]
class InvokeAIDiffuserComponent:
@@ -30,7 +32,7 @@ class InvokeAIDiffuserComponent:
* Hybrid conditioning (used for inpainting)
'''
debug_thresholding = False
last_percent_through = 0.0
@dataclass
class ExtraConditioningInfo:
@@ -56,6 +58,7 @@ class InvokeAIDiffuserComponent:
self.is_running_diffusers = is_running_diffusers
self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
self.last_percent_through = 0.0
@contextmanager
def custom_attention_context(self,
@@ -164,6 +167,7 @@ class InvokeAIDiffuserComponent:
if postprocessing_settings is not None:
percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
return latents
def calculate_percent_through(self, sigma, step_index, total_step_count):
@@ -292,8 +296,12 @@ class InvokeAIDiffuserComponent:
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
percent_through
percent_through: float
) -> torch.Tensor:
if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
return latents
threshold = postprocessing_settings.threshold
warmup = postprocessing_settings.warmup
@@ -342,6 +350,56 @@ class InvokeAIDiffuserComponent:
return latents
def apply_symmetry(
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
percent_through: float
) -> torch.Tensor:
# Reset our last percent through if this is our first step.
if percent_through == 0.0:
self.last_percent_through = 0.0
if postprocessing_settings is None:
return latents
# Check for out of bounds
h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
if (h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0)):
h_symmetry_time_pct = None
v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)):
v_symmetry_time_pct = None
dev = latents.device.type
latents = latents.to(device='cpu')
if (
h_symmetry_time_pct is not None and
self.last_percent_through < h_symmetry_time_pct and
percent_through >= h_symmetry_time_pct
):
# Horizontal symmetry occurs on the 3rd dimension of the latent
width = latents.shape[3]
x_flipped = torch.flip(latents, dims=[3])
latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3)
if (
v_symmetry_time_pct is not None and
self.last_percent_through < v_symmetry_time_pct and
percent_through >= v_symmetry_time_pct
):
# Vertical symmetry occurs on the 2nd dimension of the latent
height = latents.shape[2]
y_flipped = torch.flip(latents, dims=[2])
latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2)
self.last_percent_through = percent_through
return latents.to(device=dev)
def estimate_percent_through(self, step_index, sigma):
if step_index is not None and self.cross_attention_control_context is not None:
# percent_through will never reach 1.0 (but this is intended)
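A standalone sketch of the horizontal mirroring step above, outside the diffusion loop (the latent shape [batch, channels, height, width] is illustrative):

import torch

latents = torch.randn(1, 4, 8, 6)
width = latents.shape[3]
x_flipped = torch.flip(latents, dims=[3])
# keep the left half; replace the right half with the mirrored left half
latents = torch.cat([latents[:, :, :, 0:width // 2],
                     x_flipped[:, :, :, width // 2:width]], dim=3)
assert torch.equal(latents[..., 0], latents[..., -1])  # outermost columns now mirror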

View File

@@ -214,7 +214,7 @@ class WeightedPromptFragmentsToEmbeddingsConverter():
def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor) -> torch.Tensor:
'''
Build a tensor that embeds the passed-in token IDs and applyies the given per_token weights
Build a tensor that embeds the passed-in token IDs and applies the given per_token weights
:param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints)
:param per_token_weights: A tensor of shape `[self.max_length]` containing weights (floats)
:return: A tensor of shape `[1, self.max_length, token_dim]` representing the requested weighted embeddings
@@ -224,13 +224,12 @@ class WeightedPromptFragmentsToEmbeddingsConverter():
if token_ids.shape != torch.Size([self.max_length]):
raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")
z = self.text_encoder.forward(input_ids=token_ids.unsqueeze(0),
return_dict=False)[0]
z = self.text_encoder(token_ids.unsqueeze(0), return_dict=False)[0]
empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] +
[self.tokenizer.pad_token_id] * (self.max_length-2) +
[self.tokenizer.eos_token_id], dtype=torch.int, device=token_ids.device).unsqueeze(0)
empty_z = self.text_encoder(input_ids=empty_token_ids).last_hidden_state
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
[self.tokenizer.eos_token_id], dtype=torch.int, device=z.device).unsqueeze(0)
empty_z = self.text_encoder(empty_token_ids).last_hidden_state
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape).to(z)
z_delta_from_empty = z - empty_z
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
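The closing lines are a per-token lerp away from the empty-prompt embedding: weighted_z = empty_z + w*(z - empty_z), so a weight of 1.0 reproduces the token's embedding and 0.0 collapses it toward the empty prompt's. A tensor-only sketch with illustrative shapes (77 tokens, 768-dim embeddings):

import torch

z = torch.randn(1, 77, 768)        # embeddings of the prompt's token ids
empty_z = torch.randn(1, 77, 768)  # embeddings of the empty (bos/pad/eos) prompt
w = torch.full((77,), 0.8)         # per-token weights
weights = w.reshape(77, 1).expand(z.shape).to(z)
weighted_z = empty_z + (z - empty_z) * weights
assert weighted_z.shape == z.shape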

View File

@@ -1,11 +1,12 @@
import os
import traceback
from typing import Optional
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
import torch
from dataclasses import dataclass
from picklescan.scanner import scan_file_path
from transformers import CLIPTokenizer, CLIPTextModel
from transformers import CLIPTextModel, CLIPTokenizer
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
@@ -21,11 +22,14 @@ class TextualInversion:
def embedding_vector_length(self) -> int:
return self.embedding.shape[0]
class TextualInversionManager():
def __init__(self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
full_precision: bool=True):
class TextualInversionManager:
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
full_precision: bool = True,
):
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.full_precision = full_precision
@@ -38,47 +42,73 @@ class TextualInversionManager():
if concept_name in self.hf_concepts_library.concepts_loaded:
continue
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
if self.has_textual_inversion_for_trigger_string(trigger) \
or self.has_textual_inversion_for_trigger_string(concept_name) \
or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered
print(f'>> Loaded local embedding for trigger {concept_name}')
if (
self.has_textual_inversion_for_trigger_string(trigger)
or self.has_textual_inversion_for_trigger_string(concept_name)
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
): # in case a token with literal angle brackets encountered
print(f">> Loaded local embedding for trigger {concept_name}")
continue
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
if not bin_file:
continue
print(f'>> Loaded remote embedding for trigger {concept_name}')
print(f">> Loaded remote embedding for trigger {concept_name}")
self.load_textual_inversion(bin_file)
self.hf_concepts_library.concepts_loaded[concept_name]=True
self.hf_concepts_library.concepts_loaded[concept_name] = True
def get_all_trigger_strings(self) -> list[str]:
return [ti.trigger_string for ti in self.textual_inversions]
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
if str(ckpt_path).endswith('.DS_Store'):
def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
ckpt_path = Path(ckpt_path)
if str(ckpt_path).endswith(".DS_Store"):
return
try:
scan_result = scan_file_path(ckpt_path)
scan_result = scan_file_path(str(ckpt_path))
if scan_result.infected_files == 1:
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
print('### For your safety, InvokeAI will not load this embed.')
print(
f"\n### Security Issues Found in Model: {scan_result.issues_count}"
)
print("### For your safety, InvokeAI will not load this embed.")
return
except Exception:
print(f"### WARNING::: Invalid or corrupt embeddings found. Ignoring: {ckpt_path}")
print(
f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
)
return
embedding_info = self._parse_embedding(str(ckpt_path))
if embedding_info is None:
# We've already put out an error message about the bad embedding in _parse_embedding, so just return.
return
elif (
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
!= embedding_info["embedding"].shape[0]
):
print(
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model."
)
return
embedding_info = self._parse_embedding(ckpt_path)
if embedding_info:
try:
self._add_textual_inversion(embedding_info['name'],
embedding_info['embedding'],
defer_injecting_tokens=defer_injecting_tokens)
self._add_textual_inversion(
embedding_info["name"],
embedding_info["embedding"],
defer_injecting_tokens=defer_injecting_tokens,
)
except ValueError as e:
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
print(f' | The error was {str(e)}')
print(f" | The error was {str(e)}")
else:
print(f'>> Failed to load embedding located at {ckpt_path}. Unsupported file.')
print(
f">> Failed to load embedding located at {str(ckpt_path)}. Unsupported file."
)
def _add_textual_inversion(self, trigger_str, embedding, defer_injecting_tokens=False) -> TextualInversion:
def _add_textual_inversion(
self, trigger_str, embedding, defer_injecting_tokens=False
) -> TextualInversion:
"""
Add a textual inversion to be recognised.
:param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
@@ -86,46 +116,59 @@ class TextualInversionManager():
:return: The token id for the added embedding, either existing or newly-added.
"""
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
print(f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'")
print(
f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
)
return
if not self.full_precision:
embedding = embedding.half()
if len(embedding.shape) == 1:
embedding = embedding.unsqueeze(0)
elif len(embedding.shape) > 2:
raise ValueError(f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2.")
raise ValueError(
f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
)
try:
ti = TextualInversion(
trigger_string=trigger_str,
embedding=embedding
)
ti = TextualInversion(trigger_string=trigger_str, embedding=embedding)
if not defer_injecting_tokens:
self._inject_tokens_and_assign_embeddings(ti)
self.textual_inversions.append(ti)
return ti
except ValueError as e:
if str(e).startswith('Warning'):
if str(e).startswith("Warning"):
print(f">> {str(e)}")
else:
traceback.print_exc()
print(f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}.")
print(
f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
)
raise
def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:
if ti.trigger_token_id is not None:
raise ValueError(f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'")
raise ValueError(
f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'"
)
trigger_token_id = self._get_or_create_token_id_and_assign_embedding(ti.trigger_string, ti.embedding[0])
trigger_token_id = self._get_or_create_token_id_and_assign_embedding(
ti.trigger_string, ti.embedding[0]
)
if ti.embedding_vector_length > 1:
# for embeddings with vector length > 1
pad_token_strings = [ti.trigger_string + "-!pad-" + str(pad_index) for pad_index in range(1, ti.embedding_vector_length)]
pad_token_strings = [
ti.trigger_string + "-!pad-" + str(pad_index)
for pad_index in range(1, ti.embedding_vector_length)
]
# todo: batched UI for faster loading when vector length >2
pad_token_ids = [self._get_or_create_token_id_and_assign_embedding(pad_token_str, ti.embedding[1 + i]) \
for (i, pad_token_str) in enumerate(pad_token_strings)]
pad_token_ids = [
self._get_or_create_token_id_and_assign_embedding(
pad_token_str, ti.embedding[1 + i]
)
for (i, pad_token_str) in enumerate(pad_token_strings)
]
else:
pad_token_ids = []
@@ -133,7 +176,6 @@ class TextualInversionManager():
ti.pad_token_ids = pad_token_ids
return ti.trigger_token_id
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
try:
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
@@ -141,32 +183,43 @@ class TextualInversionManager():
except StopIteration:
return False
    def get_textual_inversion_for_trigger_string(
        self, trigger_string: str
    ) -> TextualInversion:
        return next(
            ti for ti in self.textual_inversions if ti.trigger_string == trigger_string
        )
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
        return next(
            ti for ti in self.textual_inversions if ti.trigger_token_id == token_id
        )
    def create_deferred_token_ids_for_any_trigger_terms(
        self, prompt_string: str
    ) -> list[int]:
injected_token_ids = []
for ti in self.textual_inversions:
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
if ti.embedding_vector_length > 1:
print(f">> Preparing tokens for textual inversion {ti.trigger_string}...")
print(
f">> Preparing tokens for textual inversion {ti.trigger_string}..."
)
try:
self._inject_tokens_and_assign_embeddings(ti)
except ValueError as e:
                    print(
                        f" | Ignoring incompatible embedding trigger {ti.trigger_string}"
                    )
                    print(f" | The error was {str(e)}")
continue
injected_token_ids.append(ti.trigger_token_id)
injected_token_ids.extend(ti.pad_token_ids)
return injected_token_ids
    def expand_textual_inversion_token_ids_if_necessary(
        self, prompt_token_ids: list[int]
    ) -> list[int]:
"""
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
@@ -181,20 +234,31 @@ class TextualInversionManager():
raise ValueError("prompt_token_ids must not start with bos_token_id")
if prompt_token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("prompt_token_ids must not end with eos_token_id")
        textual_inversion_trigger_token_ids = [
            ti.trigger_token_id for ti in self.textual_inversions
        ]
prompt_token_ids = prompt_token_ids.copy()
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
if token_id in textual_inversion_trigger_token_ids:
                textual_inversion = next(
                    ti
                    for ti in self.textual_inversions
                    if ti.trigger_token_id == token_id
                )
                for pad_idx in range(0, textual_inversion.embedding_vector_length - 1):
                    prompt_token_ids.insert(
                        i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx]
                    )
return prompt_token_ids
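    # Hedged illustration (not from this diff): for a hypothetical 3-vector
    # trigger whose trigger_token_id is 49408 and whose pad_token_ids are
    # [49409, 49410], the loop above rewrites
    #     [320, 49408, 525]                  # "a <trigger> photo"
    # into
    #     [320, 49408, 49409, 49410, 525]
    # by inserting the pad ids right after the trigger id. The prompt is walked
    # in reverse so insertions never shift the indices still to be visited.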
    def _get_or_create_token_id_and_assign_embedding(
        self, token_str: str, embedding: torch.Tensor
    ) -> int:
if len(embedding.shape) != 1:
raise ValueError("Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2")
raise ValueError(
"Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2"
)
existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str)
if existing_token_id == self.tokenizer.unk_token_id:
num_tokens_added = self.tokenizer.add_tokens(token_str)
@@ -207,66 +271,79 @@ class TextualInversionManager():
token_id = self.tokenizer.convert_tokens_to_ids(token_str)
if token_id == self.tokenizer.unk_token_id:
raise RuntimeError(f"Unable to find token id for token '{token_str}'")
        if (
            self.text_encoder.get_input_embeddings().weight.data[token_id].shape
            != embedding.shape
        ):
            raise ValueError(
                f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}."
            )
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
return token_id
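    # A minimal sketch (not from this diff) of the transformers pattern the
    # method above relies on; `tokenizer`, `text_encoder`, and `embedding` are
    # assumed to be a CLIPTokenizer, a CLIPTextModel, and a trained vector:
    #     num_added = tokenizer.add_tokens("<my-new-token>")
    #     if num_added > 0:
    #         text_encoder.resize_token_embeddings(len(tokenizer))
    #     token_id = tokenizer.convert_tokens_to_ids("<my-new-token>")
    #     text_encoder.get_input_embeddings().weight.data[token_id] = embedding
    # i.e. trigger and pad strings become real vocabulary entries whose rows in
    # the embedding matrix are overwritten with the trained vectors.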
def _parse_embedding(self, embedding_file: str):
        file_type = embedding_file.split(".")[-1]
        if file_type == "pt":
            return self._parse_embedding_pt(embedding_file)
        elif file_type == "bin":
            return self._parse_embedding_bin(embedding_file)
        else:
            print(f">> Not a recognized embedding file: {embedding_file}")
return None
def _parse_embedding_pt(self, embedding_file):
        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
embedding_info = {}
# Check if valid embedding file
        # Both keys must be tested for membership; the original
        # `if "string_to_token" and "string_to_param" in ...` only tested the second.
        if "string_to_token" in embedding_ckpt and "string_to_param" in embedding_ckpt:
# Catch variants that do not have the expected keys or values.
try:
                embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
                    os.path.splitext(embedding_file)[0]
                )
# Check num of embeddings and warn user only the first will be used
                embedding_info["num_of_embeddings"] = len(
                    embedding_ckpt["string_to_token"]
                )
                if embedding_info["num_of_embeddings"] > 1:
                    print(">> More than 1 embedding found. Will use the first one")
                embedding = list(embedding_ckpt["string_to_param"].values())[0]
            except (AttributeError, KeyError):
return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
            embedding_info["embedding"] = embedding
            embedding_info["num_vectors_per_token"] = embedding.size()[0]
            embedding_info["token_dim"] = embedding.size()[1]
try:
                embedding_info["trained_steps"] = embedding_ckpt["step"]
                embedding_info["trained_model_name"] = embedding_ckpt[
                    "sd_checkpoint_name"
                ]
                embedding_info["trained_model_checksum"] = embedding_ckpt[
                    "sd_checkpoint"
                ]
            except (AttributeError, KeyError):
                # missing metadata keys on a dict raise KeyError, which the
                # original bare AttributeError handler would not have caught
                print(">> No Training Details Found. Passing ...")
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
# They are actually .bin files
        elif len(embedding_ckpt.keys()) == 1:
            print(">> Detected .bin file masquerading as .pt file")
embedding_info = self._parse_embedding_bin(embedding_file)
else:
            print(">> Invalid embedding format")
embedding_info = None
return embedding_info
def _parse_embedding_bin(self, embedding_file):
        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
embedding_info = {}
        if len(embedding_ckpt.keys()) == 0:  # len(), not list(): a list never equals 0
@@ -274,27 +351,45 @@ class TextualInversionManager():
embedding_info = None
else:
for token in list(embedding_ckpt.keys()):
                embedding_info["name"] = token or os.path.basename(
                    os.path.splitext(embedding_file)[0]
                )
                embedding_info["embedding"] = embedding_ckpt[token]
                embedding_info[
                    "num_vectors_per_token"
                ] = 1  # All Concepts seem to default to 1
                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
return embedding_info
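    # Hedged illustration (not from this diff) of the two on-disk layouts the
    # parsers above handle; key names follow the upstream trainers:
    #     .pt  (original textual-inversion trainer):
    #         {"string_to_token": {"*": ...},
    #          "string_to_param": {"*": tensor of shape [V, token_dim]},
    #          "name": ..., "step": ..., "sd_checkpoint": ...}
    #     .bin (diffusers / Concepts library):
    #         {"<my-token>": tensor of shape [token_dim]}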
    def _handle_broken_pt_variants(
        self, embedding_ckpt: dict, embedding_file: str
    ) -> dict:
        """
        This handles the broken .pt file variants. We only know of one at present.
        """
embedding_info = {}
        if isinstance(
            list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
        ):
            print(
                ">> Detected .pt file variant 1"
            )  # example at https://github.com/invoke-ai/InvokeAI/issues/1829
            for token in list(embedding_ckpt["string_to_token"].keys()):
                embedding_info["name"] = (
                    token
                    if token != "*"
                    else os.path.basename(os.path.splitext(embedding_file)[0])
                )
                embedding_info["embedding"] = embedding_ckpt[
                    "string_to_param"
                ].state_dict()[token]
                embedding_info["num_vectors_per_token"] = embedding_info[
                    "embedding"
                ].shape[0]
                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
else:
            print(">> Invalid embedding format")
embedding_info = None
return embedding_info

View File

@@ -1,20 +1,21 @@
import importlib
import math
import multiprocessing as mp
import os
import re
from collections import abc
from inspect import isfunction
from pathlib import Path
from queue import Queue
from threading import Thread
from urllib import request

import numpy as np
import requests
import torch
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

from ldm.invoke.devices import torch_dtype
def log_txt_as_img(wh, xc, size=10):
@@ -23,18 +24,18 @@ def log_txt_as_img(wh, xc, size=10):
b = len(xc)
txts = list()
for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256))
        lines = "\n".join(
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
)
try:
            draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
            print("Can't encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@@ -77,25 +78,23 @@ def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(
            f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
)
return total_params
def instantiate_from_config(config, **kwargs):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
@@ -111,14 +110,14 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
else:
res = func(data)
Q.put([idx, res])
    Q.put("Done")
def parallel_data_prefetch(
func: callable,
data,
n_proc,
    target_data_type="ndarray",
cpu_intensive=True,
use_worker_id=False,
):
@@ -126,21 +125,21 @@ def parallel_data_prefetch(
# raise ValueError(
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
# )
    if isinstance(data, np.ndarray) and target_data_type == "list":
        raise ValueError("list expected but function got ndarray.")
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
print(
f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data = list(data.values())
        if target_data_type == "ndarray":
data = np.asarray(data)
else:
data = list(data)
else:
raise TypeError(
            f"The data to be processed in parallel must be an np.ndarray or an Iterable, but is actually {type(data)}."
)
if cpu_intensive:
@@ -150,7 +149,7 @@ def parallel_data_prefetch(
Q = Queue(1000)
proc = Thread
# spawn processes
    if target_data_type == "ndarray":
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(np.array_split(data, n_proc))
@@ -173,7 +172,7 @@ def parallel_data_prefetch(
processes += [p]
# start processes
    print("Start prefetching...")
import time
start = time.time()
@@ -186,13 +185,13 @@ def parallel_data_prefetch(
while k < n_proc:
# get result
res = Q.get()
            if res == "Done":
k += 1
else:
gather_res[res[0]] = res[1]
except Exception as e:
        print("Exception: ", e)
for p in processes:
p.terminate()
@@ -200,15 +199,15 @@ def parallel_data_prefetch(
finally:
for p in processes:
p.join()
        print(f"Prefetching complete. [{time.time() - start} sec.]")
    if target_data_type == "ndarray":
if not isinstance(gather_res[0], np.ndarray):
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
# order outputs
return np.concatenate(gather_res, axis=0)
    elif target_data_type == "list":
out = []
for r in gather_res:
out.extend(r)
@@ -216,49 +215,79 @@ def parallel_data_prefetch(
else:
return gather_res
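# Hedged usage sketch (not from this diff): fan a CPU-bound transform out over
# four worker processes; the worker function and data are illustrative only.
#     def square_all(chunk):
#         return [v * v for v in chunk]
#
#     results = parallel_data_prefetch(
#         square_all, list(range(1000)), n_proc=4, target_data_type="list"
#     )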
def rand_perlin_2d(
    shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
    grid = (
        torch.stack(
            torch.meshgrid(
                torch.arange(0, res[0], delta[0]),
                torch.arange(0, res[1], delta[1]),
                indexing="ij",
            ),
            dim=-1,
        ).to(device)
        % 1
    )
    rand_val = torch.rand(res[0] + 1, res[1] + 1)
    angles = 2 * math.pi * rand_val
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)
    tile_grads = (
        lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
        .repeat_interleave(d[0], 0)
        .repeat_interleave(d[1], 1)
    )
    dot = lambda grad, shift: (
        torch.stack(
            (
                grid[: shape[0], : shape[1], 0] + shift[0],
                grid[: shape[0], : shape[1], 1] + shift[1],
            ),
            dim=-1,
        )
        * grad[: shape[0], : shape[1]]
    ).sum(dim=-1)
    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device)
    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device)
    t = fade(grid[: shape[0], : shape[1]])
    noise = math.sqrt(2) * torch.lerp(
        torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
    ).to(device)
return noise.to(dtype=torch_dtype(device))
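# Hedged usage sketch (not from this diff): a 64x64 Perlin tile over a 4x4
# gradient lattice; each shape dimension must be divisible by the matching res.
#     noise = rand_perlin_2d((64, 64), (4, 4), device=torch.device("cpu"))
#     noise.shape == (64, 64); values lie roughly in [-1, 1]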
def ask_user(question: str, answers: list):
from itertools import chain, repeat
    user_prompt = f"\n>> {question} {answers}: "
    invalid_answer_msg = "Invalid answer. Please try again."
    pose_question = chain(
        [user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt]))
    )
user_answers = map(input, pose_question)
valid_response = next(filter(answers.__contains__, user_answers))
return valid_response
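# Hedged usage sketch (not from this diff):
#     answer = ask_user("Overwrite the existing model file?", ["y", "n"])
# blocks until the user types one of the listed answers, then returns it.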
def debug_image(
    debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
):
if not debug_status:
return
image_copy = debug_image.copy().convert("RGBA")
    ImageDraw.Draw(image_copy).text((5, 5), debug_text, (255, 0, 0))
if debug_show:
image_copy.show()
@@ -266,31 +295,84 @@ def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, de
if debug_result:
return image_copy
# -------------------------------------
def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
    """
    Download a model file.
    :param url: https, http or ftp URL
    :param dest: A Path object. If path exists and is a directory, then we try to derive the filename
                 from the URL's Content-Disposition header and copy the URL contents into
                 dest/filename
    :param access_token: Access token to access this resource
    """
resp = requests.get(url, stream=True)
total = int(resp.headers.get("content-length", 0))
if dest.is_dir():
try:
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
except:
file_name = os.path.basename(url)
dest = dest / file_name
else:
dest.parent.mkdir(parents=True, exist_ok=True)
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
open_mode = "wb"
exist_size = 0
if dest.exists():
exist_size = dest.stat().st_size
header["Range"] = f"bytes={exist_size}-"
open_mode = "ab"
    if resp.status_code == 416:  # "range not satisfiable", which means nothing to return
print(f"* {dest}: complete file found. Skipping.")
return dest
elif resp.status_code != 200:
print(f"** An error occurred during downloading {dest}: {resp.reason}")
elif exist_size > 0:
print(f"* {dest}: partial file found. Resuming...")
else:
print(f"* {dest}: Downloading...")
    if total < 2000:
        print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
        return None

    try:
with open(dest, open_mode) as file, tqdm(
desc=str(dest),
initial=exist_size,
total=total + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
) as bar:
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
except Exception as e:
print(f"An error occurred while downloading {dest}: {str(e)}")
return None
return dest
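# Hedged usage sketch (not from this diff); the URL is illustrative only.
#     dest = download_with_resume(
#         "https://example.com/models/some-model.ckpt", Path("./models")
#     )
# A re-run after an interrupted transfer sends "Range: bytes=<partial size>-"
# and reopens the file in append mode instead of starting over.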
def url_attachment_name(url: str) -> str:
    try:
        resp = requests.get(url, stream=True)
        match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
        return match.group(1)
    except (AttributeError, TypeError, requests.RequestException):
        # no header / no filename / network failure: return None rather than raise
        return None
def download_with_progress_bar(url: str, dest: Path) -> bool:
result = download_with_resume(url, dest, access_token=None)
return result is not None