Merge branch 'add_lora_support' of https://github.com/jordanramstad/InvokeAI into add_lora_support

This commit is contained in:
Jordan
2023-02-20 16:50:16 -07:00
118 changed files with 198390 additions and 4784 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1 +1 @@
__version__='2.3.0'
__version__='2.3.1+a0'

View File

@@ -272,6 +272,10 @@ class Args(object):
switches.append('--seamless')
if a['hires_fix']:
switches.append('--hires_fix')
if a['h_symmetry_time_pct']:
switches.append(f'--h_symmetry_time_pct {a["h_symmetry_time_pct"]}')
if a['v_symmetry_time_pct']:
switches.append(f'--v_symmetry_time_pct {a["v_symmetry_time_pct"]}')
# img2img generations have parameters relevant only to them and have special handling
if a['init_img'] and len(a['init_img'])>0:
@@ -751,6 +755,9 @@ class Args(object):
!fix applies upscaling/facefixing to a previously-generated image.
invoke> !fix 0000045.4829112.png -G1 -U4 -ft codeformer
*embeddings*
invoke> !triggers -- return all trigger phrases contained in loaded embedding files
*History manipulation*
!fetch retrieves the command used to generate an earlier image. Provide
a directory wildcard and the name of a file to write and all the commands
@@ -842,6 +849,18 @@ class Args(object):
type=float,
help='Perlin noise scale (0.0 - 1.0) - add perlin noise to the initialization instead of the usual gaussian noise.',
)
render_group.add_argument(
'--h_symmetry_time_pct',
default=None,
type=float,
help='Horizontal symmetry point (0.0 - 1.0) - apply horizontal symmetry at this point in image generation.',
)
render_group.add_argument(
'--v_symmetry_time_pct',
default=None,
type=float,
help='Vertical symmetry point (0.0 - 1.0) - apply vertical symmetry at this point in image generation.',
)
render_group.add_argument(
'--fnformat',
default='{prefix}.{seed}.png',
@@ -1148,7 +1167,8 @@ def metadata_dumps(opt,
# remove any image keys not mentioned in RFC #266
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
'cfg_scale','threshold','perlin','step_number','width','height','extra','strength','seamless'
'init_img','init_mask','facetool','facetool_strength','upscale']
'init_img','init_mask','facetool','facetool_strength','upscale','h_symmetry_time_pct',
'v_symmetry_time_pct']
rfc_dict ={}
for item in image_dict.items():

View File

@@ -53,6 +53,7 @@ from diffusers import (
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import is_safetensors_available
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig
@@ -984,6 +985,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']:
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
safety_checker = StableDiffusionSafetyChecker.from_pretrained('CompVis/stable-diffusion-safety-checker',cache_dir=global_cache_dir("hub"))
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
pipe = pipeline_class(
vae=vae,
@@ -991,7 +993,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
else:

View File

@@ -40,7 +40,6 @@ from ldm.invoke.globals import Globals, global_cache_dir, global_config_dir
from ldm.invoke.readline import generic_completer
warnings.filterwarnings("ignore")
import torch
transformers.logging.set_verbosity_error()
@@ -764,7 +763,7 @@ def download_weights(opt: dict) -> Union[str, None]:
precision = (
"float32"
if opt.full_precision
else choose_precision(torch.device(choose_torch_device()))
else choose_precision(choose_torch_device())
)
if opt.yes_to_all:

View File

@@ -0,0 +1,102 @@
'''
Minimalist updater script. Prompts user for the tag or branch to update to and runs
pip install <path_to_git_source>.
'''
import platform
import requests
import subprocess
from rich import box, print
from rich.console import Console, group
from rich.panel import Panel
from rich.prompt import Prompt
from rich.style import Style
from rich.text import Text
from rich.live import Live
from rich.table import Table
from ldm.invoke import __version__
INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive"
INVOKE_AI_REL="https://api.github.com/repos/invoke-ai/InvokeAI/releases"
OS = platform.uname().system
ARCH = platform.uname().machine
ORANGE_ON_DARK_GREY = Style(bgcolor="grey23", color="orange1")
if OS == "Windows":
# Windows terminals look better without a background colour
console = Console(style=Style(color="grey74"))
else:
console = Console(style=Style(color="grey74", bgcolor="grey23"))
def get_versions()->dict:
    """Fetch the InvokeAI release list from the GitHub releases API.

    Issues a GET against ``INVOKE_AI_REL`` and returns the decoded JSON body.
    Callers in this file index the result as ``versions[0]['tag_name']``,
    i.e. they treat it as a sequence of release records; the ``dict``
    annotation is kept unchanged for compatibility.

    Raises:
        requests.HTTPError: if the API answers with a non-2xx status (for
            example when rate-limited), instead of handing an error payload
            to callers that would then fail on ``versions[0]``.
        requests.Timeout: if the server does not respond within the timeout.
    """
    # Without a timeout, requests waits indefinitely on a stalled connection.
    response = requests.get(url=INVOKE_AI_REL, timeout=30)
    response.raise_for_status()
    return response.json()
def welcome(versions: dict):
    """Print the updater's banner panel to the module-level console.

    The panel shows the installed ``__version__`` and the three update
    options; option 1 quotes ``versions[0]['tag_name']`` as the latest
    official release.
    """

    @group()
    def banner():
        # Each yielded string becomes one renderable inside the panel.
        yield f'InvokeAI Version: [bold yellow]{__version__}'
        yield ''
        yield 'This script will update InvokeAI to the latest release, or to a development version of your choice.'
        yield ''
        yield '[bold yellow]Options:'
        yield f'''[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
[2] Update to the bleeding-edge development version ([italic]main[/italic])
[3] Manually enter the tag or branch name you wish to update'''

    console.rule()
    # The panel is built after the first rule so evaluation order is unchanged.
    panel = Panel(
        title="[bold wheat1]InvokeAI Updater",
        renderable=banner(),
        box=box.DOUBLE,
        expand=True,
        padding=(1, 2),
        style=ORANGE_ON_DARK_GREY,
        subtitle=f"[bold grey39]{OS}-{ARCH}",
    )
    console.print(panel)
    # console.rule is used instead of console.line to maintain dark background
    # on terminals where light background is the default
    console.rule(characters=" ")
def main():
    """Ask which version to install, then run pip and stream its output.

    The user picks the latest release, ``main``, or a hand-typed tag/branch.
    The matching source archive under ``INVOKE_AI_SRC`` is installed with pip,
    and every output line is appended to a live-updating panel. A final rule
    reports success or failure based on pip's exit status.
    """
    # Local import so this function does not depend on a file-level change.
    import sys

    versions = get_versions()
    welcome(versions)

    tag = None
    choice = Prompt.ask(Text.from_markup(('[grey74 on grey23]Choice:')),choices=['1','2','3'],default='1')

    if choice=='1':
        tag = versions[0]['tag_name']
    elif choice=='2':
        tag = 'main'
    elif choice=='3':
        tag = Prompt.ask('[grey74 on grey23]Enter an InvokeAI tag or branch name')

    console.print(Panel(f':crossed_fingers: Upgrading to [yellow]{tag}[/yellow]', box=box.MINIMAL, style=ORANGE_ON_DARK_GREY))

    # Run pip through the current interpreter with an argument list:
    #  * no dependency on bash, which is normally absent on Windows;
    #  * the user-typed tag is never interpreted by a shell;
    #  * the install targets the environment this script is running in.
    cmd = [
        sys.executable, '-m', 'pip', 'install',
        f'{INVOKE_AI_SRC}/{tag}.zip', '--use-pep517',
    ]

    progress = Table.grid(expand=True)
    progress_panel = Panel(progress, box=box.MINIMAL, style=ORANGE_ON_DARK_GREY)

    # stderr is merged into stdout so a single reader drains both streams.
    # A separate, unread stderr pipe can fill up and block pip.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc:
        progress.add_column()
        with Live(progress_panel, console=console, vertical_overflow='visible'):
            # Iterating the pipe ends at EOF, i.e. when pip closes its output.
            for line in proc.stdout:
                # errors='replace' keeps odd bytes in pip output from aborting the update display.
                progress.add_row(line.decode(errors='replace').strip(), style=ORANGE_ON_DARK_GREY)
    # Leaving the Popen context waits for the process, so returncode is set.
    if proc.returncode == 0:
        console.rule(':heavy_check_mark: Upgrade successful')
    else:
        console.rule(':exclamation: [bold red]Upgrade failed[/red bold]')
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
pass

View File

@@ -1,19 +1,25 @@
from __future__ import annotations
from contextlib import nullcontext
import torch
from torch import autocast
from contextlib import nullcontext
from ldm.invoke.globals import Globals
def choose_torch_device() -> str:
CPU_DEVICE = torch.device("cpu")
def choose_torch_device() -> torch.device:
'''Convenience routine for guessing which GPU device to run model on'''
if Globals.always_use_cpu:
return "cpu"
return CPU_DEVICE
if torch.cuda.is_available():
return 'cuda'
return torch.device('cuda')
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return 'mps'
return 'cpu'
return torch.device('mps')
return CPU_DEVICE
def choose_precision(device) -> str:
def choose_precision(device: torch.device) -> str:
'''Returns an appropriate precision for the given torch device'''
if device.type == 'cuda':
device_name = torch.cuda.get_device_name(device)
@@ -21,7 +27,7 @@ def choose_precision(device) -> str:
return 'float16'
return 'float32'
def torch_dtype(device) -> torch.dtype:
def torch_dtype(device: torch.device) -> torch.dtype:
if Globals.full_precision:
return torch.float32
if choose_precision(device) == 'float16':
@@ -36,3 +42,13 @@ def choose_autocast(precision):
if precision == 'autocast' or precision == 'float16':
return autocast
return nullcontext
def normalize_device(device: str | torch.device) -> torch.device:
    """Return ``device`` as a ``torch.device``, adding an index where one applies.

    Only an index-less CUDA device is rewritten: it receives the index
    reported by ``torch.cuda.current_device()``. Every other device (already
    indexed, or of another type such as cpu/mps) is returned as constructed.
    """
    resolved = torch.device(device)
    # cuda appears to be the only backend with a "current device" to query;
    # nothing comparable is used here for cpu or mps.
    lacks_cuda_index = resolved.index is None and resolved.type == 'cuda'
    if not lacks_cuda_index:
        return resolved
    return torch.device(resolved.type, torch.cuda.current_device())

View File

@@ -64,6 +64,7 @@ class Generator:
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
h_symmetry_time_pct=None, v_symmetry_time_pct=None,
safety_checker:dict=None,
free_gpu_mem: bool=False,
**kwargs):
@@ -81,6 +82,8 @@ class Generator:
step_callback = step_callback,
threshold = threshold,
perlin = perlin,
h_symmetry_time_pct = h_symmetry_time_pct,
v_symmetry_time_pct = v_symmetry_time_pct,
attention_maps_callback = attention_maps_callback,
**kwargs
)

View File

@@ -29,6 +29,7 @@ from ldm.invoke.globals import Globals
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager
from ldm.modules.lora_manager import LoraManager
from ..devices import normalize_device, CPU_DEVICE
from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
@@ -322,7 +323,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if self.device.type == 'cpu' or self.device.type == 'mps':
mem_free = psutil.virtual_memory().free
elif self.device.type == 'cuda':
mem_free, _ = torch.cuda.mem_get_info(self.device)
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
else:
raise ValueError(f"unrecognized device {self.device}")
# input tensor of [1, 4, h/8, w/8]
@@ -383,9 +384,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self._model_group.ready()
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
# overridden method; types match the superclass.
if torch_device is None:
return self
self._model_group.set_device(torch_device)
self._model_group.set_device(torch.device(torch_device))
self._model_group.ready()
@property
@@ -692,8 +694,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if device.type == 'mps':
# workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
# TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
self.vae.to('cpu')
init_image = init_image.to('cpu')
self.vae.to(CPU_DEVICE)
init_image = init_image.to(CPU_DEVICE)
else:
self._model_group.load(self.vae)
init_latent_dist = self.vae.encode(init_image).latent_dist

View File

@@ -16,8 +16,8 @@ class Img2Img(Generator):
self.init_latent = None # by get_noise()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,
attention_maps_callback=None,
conditioning,init_image,strength,step_callback=None,threshold=0.0,warmup=0.2,perlin=0.0,
h_symmetry_time_pct=None,v_symmetry_time_pct=None,attention_maps_callback=None,
**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
@@ -33,8 +33,13 @@ class Img2Img(Generator):
conditioning_data = (
ConditioningData(
uc, c, cfg_scale, extra_conditioning_info,
postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
postprocessing_settings=PostprocessingSettings(
threshold=threshold,
warmup=warmup,
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct
)
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
def make_image(x_T):

View File

@@ -15,8 +15,8 @@ class Txt2Img(Generator):
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
attention_maps_callback=None,
conditioning,width,height,step_callback=None,threshold=0.0,warmup=0.2,perlin=0.0,
h_symmetry_time_pct=None,v_symmetry_time_pct=None,attention_maps_callback=None,
**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
@@ -33,8 +33,13 @@ class Txt2Img(Generator):
conditioning_data = (
ConditioningData(
uc, c, cfg_scale, extra_conditioning_info,
postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
postprocessing_settings=PostprocessingSettings(
threshold=threshold,
warmup=warmup,
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct
)
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
def make_image(x_T) -> PIL.Image.Image:
pipeline_output = pipeline.image_from_embeddings(
@@ -44,8 +49,10 @@ class Txt2Img(Generator):
conditioning_data=conditioning_data,
callback=step_callback,
)
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
return make_image

View File

@@ -21,12 +21,14 @@ class Txt2Img2Img(Generator):
def get_make_image(self, prompt:str, sampler, steps:int, cfg_scale:float, ddim_eta,
conditioning, width:int, height:int, strength:float,
step_callback:Optional[Callable]=None, threshold=0.0, **kwargs):
step_callback:Optional[Callable]=None, threshold=0.0, warmup=0.2, perlin=0.0,
h_symmetry_time_pct=None, v_symmetry_time_pct=None, attention_maps_callback=None, **kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
kwargs are 'width' and 'height'
"""
self.perlin = perlin
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
@@ -36,8 +38,13 @@ class Txt2Img2Img(Generator):
conditioning_data = (
ConditioningData(
uc, c, cfg_scale, extra_conditioning_info,
postprocessing_settings = PostprocessingSettings(threshold=threshold, warmup=0.2) if threshold else None)
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
postprocessing_settings = PostprocessingSettings(
threshold=threshold,
warmup=0.2,
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct
)
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
def make_image(x_T):
@@ -69,19 +76,28 @@ class Txt2Img2Img(Generator):
if clear_cuda_cache is not None:
clear_cuda_cache()
second_pass_noise = self.get_noise_like(resized_latents)
second_pass_noise = self.get_noise_like(resized_latents, override_perlin=True)
# Clear symmetry for the second pass
from dataclasses import replace
new_postprocessing_settings = replace(conditioning_data.postprocessing_settings, h_symmetry_time_pct=None)
new_postprocessing_settings = replace(new_postprocessing_settings, v_symmetry_time_pct=None)
new_conditioning_data = replace(conditioning_data, postprocessing_settings=new_postprocessing_settings)
verbosity = get_verbosity()
set_verbosity_error()
pipeline_output = pipeline.img2img_from_latents_and_embeddings(
resized_latents,
num_inference_steps=steps,
conditioning_data=conditioning_data,
conditioning_data=new_conditioning_data,
strength=strength,
noise=second_pass_noise,
callback=step_callback)
set_verbosity(verbosity)
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
@@ -95,13 +111,13 @@ class Txt2Img2Img(Generator):
return make_image
def get_noise_like(self, like: torch.Tensor):
def get_noise_like(self, like: torch.Tensor, override_perlin: bool=False):
device = like.device
if device.type == 'mps':
x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device)
else:
x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
if self.perlin > 0.0:
if self.perlin > 0.0 and override_perlin == False:
shape = like.shape
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
return x
@@ -139,6 +155,9 @@ class Txt2Img2Img(Generator):
shape = (1, channels,
scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor)
if self.use_mps_noise or device.type == 'mps':
return torch.randn(shape, dtype=self.torch_dtype(), device='cpu').to(device)
tensor = torch.empty(size=shape, device='cpu')
tensor = self.get_noise_like(like=tensor).to(device)
else:
return torch.randn(shape, dtype=self.torch_dtype(), device=device)
tensor = torch.empty(size=shape, device=device)
tensor = self.get_noise_like(like=tensor)
return tensor

View File

@@ -323,7 +323,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
if selected_model3 > 0:
self.merge_method.values = ['add_difference ( A+(B-C) )']
self.merged_model_name.value += f"+{models[selected_model3]}"
self.merged_model_name.value += f"+{models[selected_model3 -1]}" # In model3 there is one more element in the list (None). So we have to subtract one.
else:
self.merge_method.values = self.interpolations
self.merge_method.value = 0
@@ -419,8 +419,7 @@ def run_gui(args: Namespace):
mergeapp.run()
args = mergeapp.merge_arguments
print(f'DEBUG: {args}')
#merge_diffusion_models_and_commit(**args)
merge_diffusion_models_and_commit(**args)
print(f'>> Models merged into new model: "{args["merged_model_name"]}".')

View File

@@ -30,6 +30,7 @@ from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path
from ldm.invoke.devices import CPU_DEVICE
from ldm.invoke.generator.diffusers_pipeline import \
StableDiffusionGeneratorPipeline
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
@@ -47,7 +48,7 @@ class ModelManager(object):
def __init__(
self,
config: OmegaConf,
device_type: str | torch.device = "cpu",
device_type: torch.device = CPU_DEVICE,
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload = False
@@ -675,7 +676,7 @@ class ModelManager(object):
"""
if str(weights).startswith(("http:", "https:")):
model_name = model_name or url_attachment_name(weights)
weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
config_path = self._resolve_path(config, "configs/stable-diffusion")
@@ -996,25 +997,25 @@ class ModelManager(object):
self.models.pop(model_name, None)
def _model_to_cpu(self, model):
if self.device == "cpu":
if self.device == CPU_DEVICE:
return model
if isinstance(model, StableDiffusionGeneratorPipeline):
model.offload_all()
return model
model.cond_stage_model.device = "cpu"
model.to("cpu")
model.cond_stage_model.device = CPU_DEVICE
model.to(CPU_DEVICE)
for submodel in ("first_stage_model", "cond_stage_model", "model"):
try:
getattr(model, submodel).to("cpu")
getattr(model, submodel).to(CPU_DEVICE)
except AttributeError:
pass
return model
def _model_from_cpu(self, model):
if self.device == "cpu":
if self.device == CPU_DEVICE:
return model
if isinstance(model, StableDiffusionGeneratorPipeline):

View File

@@ -58,9 +58,11 @@ COMMANDS = (
'--inpaint_replace','-r',
'--png_compression','-z',
'--text_mask','-tm',
'--h_symmetry_time_pct',
'--v_symmetry_time_pct',
'!fix','!fetch','!replay','!history','!search','!clear',
'!models','!switch','!import_model','!optimize_model','!convert_model','!edit_model','!del_model',
'!mask',
'!mask','!triggers',
)
MODEL_COMMANDS = (
'!switch',
@@ -138,7 +140,7 @@ class Completer(object):
elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
self.matches= self._model_completions(text, state)
# looking for a ckpt model
# looking for a ckpt model
elif re.match('^'+'|'.join(CKPT_MODEL_COMMANDS),buffer):
self.matches= self._model_completions(text, state, ckpt_only=True)
@@ -255,7 +257,7 @@ class Completer(object):
update our list of models
'''
self.models = models
def _seed_completions(self, text, state):
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
if m:

View File

@@ -441,6 +441,7 @@ class TextualInversionDataset(Dataset):
# Collect the training images: regular files in data_root with a known image
# extension. os.listdir() yields bare names, so the isfile() check must be
# made on the path joined with data_root; testing the bare name would resolve
# it against the current working directory and reject nearly every image.
self.image_paths = [
    os.path.join(self.data_root, file_path)
    for file_path in os.listdir(self.data_root)
    if file_path.endswith(('.png','.PNG','.jpg','.JPG','.jpeg','.JPEG','.gif','.GIF'))
    and os.path.isfile(os.path.join(self.data_root, file_path))
]
self.num_images = len(self.image_paths)

View File

@@ -584,7 +584,9 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
# print(f"SwapCrossAttnContext for {attention_type} active")
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
attention_mask = attn.prepare_attention_mask(
attention_mask=attention_mask, target_length=sequence_length,
batch_size=batch_size)
query = attn.to_q(hidden_states)
dim = query.shape[-1]

View File

@@ -18,6 +18,8 @@ from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
class PostprocessingSettings:
threshold: float
warmup: float
h_symmetry_time_pct: Optional[float]
v_symmetry_time_pct: Optional[float]
class InvokeAIDiffuserComponent:
@@ -30,7 +32,7 @@ class InvokeAIDiffuserComponent:
* Hybrid conditioning (used for inpainting)
'''
debug_thresholding = False
last_percent_through = 0.0
@dataclass
class ExtraConditioningInfo:
@@ -56,6 +58,7 @@ class InvokeAIDiffuserComponent:
self.is_running_diffusers = is_running_diffusers
self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
self.last_percent_through = 0.0
@contextmanager
def custom_attention_context(self,
@@ -164,6 +167,7 @@ class InvokeAIDiffuserComponent:
if postprocessing_settings is not None:
percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
return latents
def calculate_percent_through(self, sigma, step_index, total_step_count):
@@ -292,8 +296,12 @@ class InvokeAIDiffuserComponent:
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
percent_through
percent_through: float
) -> torch.Tensor:
if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
return latents
threshold = postprocessing_settings.threshold
warmup = postprocessing_settings.warmup
@@ -342,6 +350,56 @@ class InvokeAIDiffuserComponent:
return latents
def apply_symmetry(
    self,
    postprocessing_settings: PostprocessingSettings,
    latents: torch.Tensor,
    percent_through: float
) -> torch.Tensor:
    """Mirror the latents once, at the step where a symmetry threshold is crossed.

    Horizontal symmetry keeps the left half of the last dimension and replaces
    the right half with its mirror image; vertical symmetry does the same for
    the top/bottom halves of dimension 2. Each is applied only on the call
    where ``percent_through`` first reaches the configured ``*_time_pct``,
    tracked through ``self.last_percent_through``.

    Args:
        postprocessing_settings: object providing ``h_symmetry_time_pct`` and
            ``v_symmetry_time_pct``; ``None`` disables symmetry entirely.
        latents: tensor indexed with four dimensions; dims 2 and 3 are the
            ones mirrored.
        percent_through: progress of the current step; ``0.0`` resets the
            stored progress.

    Returns:
        The input tensor, or a new mirrored tensor on the same device.
    """
    # Reset our last percent through if this is our first step.
    if percent_through == 0.0:
        self.last_percent_through = 0.0

    if postprocessing_settings is None:
        return latents

    # Thresholds outside (0.0, 1.0] are treated as "symmetry disabled".
    h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
    if h_symmetry_time_pct is not None and not (0.0 < h_symmetry_time_pct <= 1.0):
        h_symmetry_time_pct = None

    v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
    if v_symmetry_time_pct is not None and not (0.0 < v_symmetry_time_pct <= 1.0):
        v_symmetry_time_pct = None

    # No device round-trip here: flip/cat keep the tensor on its own device.
    # Restoring by device *type* alone would drop the device index.

    if (
        h_symmetry_time_pct is not None
        and self.last_percent_through < h_symmetry_time_pct <= percent_through
    ):
        # Horizontal symmetry occurs on the 3rd dimension of the latent
        width = latents.shape[3]
        half_width = width // 2
        x_flipped = torch.flip(latents, dims=[3])
        latents = torch.cat(
            [latents[:, :, :, 0:half_width], x_flipped[:, :, :, half_width:width]],
            dim=3,
        )

    if (
        v_symmetry_time_pct is not None
        and self.last_percent_through < v_symmetry_time_pct <= percent_through
    ):
        # Vertical symmetry occurs on the 2nd dimension of the latent
        height = latents.shape[2]
        half_height = height // 2
        y_flipped = torch.flip(latents, dims=[2])
        latents = torch.cat(
            [latents[:, :, 0:half_height], y_flipped[:, :, half_height:height]],
            dim=2,
        )

    self.last_percent_through = percent_through
    return latents
def estimate_percent_through(self, step_index, sigma):
if step_index is not None and self.cross_attention_control_context is not None:
# percent_through will never reach 1.0 (but this is intended)

View File

@@ -1,11 +1,12 @@
import os
import traceback
from typing import Optional
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
import torch
from dataclasses import dataclass
from picklescan.scanner import scan_file_path
from transformers import CLIPTokenizer, CLIPTextModel
from transformers import CLIPTextModel, CLIPTokenizer
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
@@ -21,11 +22,14 @@ class TextualInversion:
def embedding_vector_length(self) -> int:
return self.embedding.shape[0]
class TextualInversionManager():
def __init__(self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
full_precision: bool=True):
class TextualInversionManager:
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
full_precision: bool = True,
):
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.full_precision = full_precision
@@ -38,47 +42,78 @@ class TextualInversionManager():
if concept_name in self.hf_concepts_library.concepts_loaded:
continue
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
if self.has_textual_inversion_for_trigger_string(trigger) \
or self.has_textual_inversion_for_trigger_string(concept_name) \
or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered
print(f'>> Loaded local embedding for trigger {concept_name}')
if (
self.has_textual_inversion_for_trigger_string(trigger)
or self.has_textual_inversion_for_trigger_string(concept_name)
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
): # in case a token with literal angle brackets encountered
print(f">> Loaded local embedding for trigger {concept_name}")
continue
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
if not bin_file:
continue
print(f'>> Loaded remote embedding for trigger {concept_name}')
print(f">> Loaded remote embedding for trigger {concept_name}")
self.load_textual_inversion(bin_file)
self.hf_concepts_library.concepts_loaded[concept_name]=True
self.hf_concepts_library.concepts_loaded[concept_name] = True
def get_all_trigger_strings(self) -> list[str]:
return [ti.trigger_string for ti in self.textual_inversions]
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
if str(ckpt_path).endswith('.DS_Store'):
def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
ckpt_path = Path(ckpt_path)
if not ckpt_path.is_file():
return
if str(ckpt_path).endswith(".DS_Store"):
return
try:
scan_result = scan_file_path(ckpt_path)
scan_result = scan_file_path(str(ckpt_path))
if scan_result.infected_files == 1:
print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
print('### For your safety, InvokeAI will not load this embed.')
print(
f"\n### Security Issues Found in Model: {scan_result.issues_count}"
)
print("### For your safety, InvokeAI will not load this embed.")
return
except Exception:
print(f"### WARNING::: Invalid or corrupt embeddings found. Ignoring: {ckpt_path}")
print(
f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
)
return
embedding_info = self._parse_embedding(str(ckpt_path))
if embedding_info is None:
# We've already put out an error message about the bad embedding in _parse_embedding, so just return.
return
elif (
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
!= embedding_info['token_dim']
):
print(
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
)
return
embedding_info = self._parse_embedding(ckpt_path)
if embedding_info:
try:
self._add_textual_inversion(embedding_info['name'],
embedding_info['embedding'],
defer_injecting_tokens=defer_injecting_tokens)
self._add_textual_inversion(
embedding_info["name"],
embedding_info["embedding"],
defer_injecting_tokens=defer_injecting_tokens,
)
except ValueError as e:
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
print(f' | The error was {str(e)}')
print(f" | The error was {str(e)}")
else:
print(f'>> Failed to load embedding located at {ckpt_path}. Unsupported file.')
print(
f">> Failed to load embedding located at {str(ckpt_path)}. Unsupported file."
)
def _add_textual_inversion(self, trigger_str, embedding, defer_injecting_tokens=False) -> TextualInversion:
def _add_textual_inversion(
self, trigger_str, embedding, defer_injecting_tokens=False
) -> TextualInversion:
"""
Add a textual inversion to be recognised.
:param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
@@ -86,46 +121,59 @@ class TextualInversionManager():
:return: The token id for the added embedding, either existing or newly-added.
"""
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
print(f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'")
print(
f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
)
return
if not self.full_precision:
embedding = embedding.half()
if len(embedding.shape) == 1:
embedding = embedding.unsqueeze(0)
elif len(embedding.shape) > 2:
raise ValueError(f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2.")
raise ValueError(
f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
)
try:
ti = TextualInversion(
trigger_string=trigger_str,
embedding=embedding
)
ti = TextualInversion(trigger_string=trigger_str, embedding=embedding)
if not defer_injecting_tokens:
self._inject_tokens_and_assign_embeddings(ti)
self.textual_inversions.append(ti)
return ti
except ValueError as e:
if str(e).startswith('Warning'):
if str(e).startswith("Warning"):
print(f">> {str(e)}")
else:
traceback.print_exc()
print(f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}.")
print(
f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
)
raise
# Create the tokenizer entries for `ti` and store its embedding rows: row 0 is assigned to
# the trigger string itself, and every further row to a synthetic "<trigger>-!pad-<n>" token.
# Raises ValueError when `ti` already carries a trigger token id.
# NOTE(review): this view interleaves the pre- and post-reformat copy of each statement, and
# the lines behind the hunk marker near the end are not shown — presumably that is where
# ti.trigger_token_id gets assigned; confirm against the full file.
def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:
if ti.trigger_token_id is not None:
raise ValueError(f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'")
raise ValueError(
f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'"
)
# The trigger string always maps to the first embedding row.
trigger_token_id = self._get_or_create_token_id_and_assign_embedding(ti.trigger_string, ti.embedding[0])
trigger_token_id = self._get_or_create_token_id_and_assign_embedding(
ti.trigger_string, ti.embedding[0]
)
if ti.embedding_vector_length > 1:
# for embeddings with vector length > 1
pad_token_strings = [ti.trigger_string + "-!pad-" + str(pad_index) for pad_index in range(1, ti.embedding_vector_length)]
pad_token_strings = [
ti.trigger_string + "-!pad-" + str(pad_index)
for pad_index in range(1, ti.embedding_vector_length)
]
# todo: batched UI for faster loading when vector length >2
pad_token_ids = [self._get_or_create_token_id_and_assign_embedding(pad_token_str, ti.embedding[1 + i]) \
for (i, pad_token_str) in enumerate(pad_token_strings)]
pad_token_ids = [
self._get_or_create_token_id_and_assign_embedding(
pad_token_str, ti.embedding[1 + i]
)
for (i, pad_token_str) in enumerate(pad_token_strings)
]
else:
# Single-vector embeddings need no pad tokens.
pad_token_ids = []
@@ -133,7 +181,6 @@ class TextualInversionManager():
ti.pad_token_ids = pad_token_ids
return ti.trigger_token_id
# Report whether a textual inversion is loaded for `trigger_string`.
# The lookup helper uses next() without a default, so a miss surfaces as StopIteration,
# which is translated to False here.
# NOTE(review): the success-path return sits behind the hunk marker and is not visible in this view.
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
try:
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
@@ -141,32 +188,43 @@ class TextualInversionManager():
except StopIteration:
return False
def get_textual_inversion_for_trigger_string(
    self, trigger_string: str
) -> TextualInversion:
    """
    Return the loaded textual inversion whose trigger string equals `trigger_string`.

    :param trigger_string: Exact trigger text to look up.
    :return: The first matching entry of `self.textual_inversions`.
    :raises StopIteration: If nothing matches. `has_textual_inversion_for_trigger_string`
        catches exactly this exception, so the exception type must not change.
    """
    return next(
        ti for ti in self.textual_inversions if ti.trigger_string == trigger_string
    )
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
    """
    Return the loaded textual inversion whose trigger token id equals `token_id`.

    :param token_id: Tokenizer id of a trigger token.
    :return: The first matching entry of `self.textual_inversions`.
    :raises StopIteration: If no loaded textual inversion uses this token id.
    """
    return next(
        ti for ti in self.textual_inversions if ti.trigger_token_id == token_id
    )
def create_deferred_token_ids_for_any_trigger_terms(
    self, prompt_string: str
) -> list[int]:
    """
    Inject tokens for every not-yet-injected textual inversion whose trigger occurs in the prompt.

    :param prompt_string: The prompt text to scan for trigger strings (plain substring match).
    :return: The trigger and pad token ids that were injected by this call, in load order.
        Textual inversions that already have a trigger token id contribute nothing.
    """
    injected_token_ids = []
    for ti in self.textual_inversions:
        # Skip anything already injected or not mentioned in this prompt.
        if ti.trigger_token_id is not None or ti.trigger_string not in prompt_string:
            continue
        if ti.embedding_vector_length > 1:
            print(
                f">> Preparing tokens for textual inversion {ti.trigger_string}..."
            )
        try:
            self._inject_tokens_and_assign_embeddings(ti)
        except ValueError as e:
            # Best effort: an incompatible embedding is reported and skipped so the
            # remaining triggers in the prompt can still be prepared.
            print(
                f" | Ignoring incompatible embedding trigger {ti.trigger_string}"
            )
            print(f" | The error was {str(e)}")
            continue
        injected_token_ids.append(ti.trigger_token_id)
        injected_token_ids.extend(ti.pad_token_ids)
    return injected_token_ids
# Return a copy of `prompt_token_ids` in which every textual-inversion trigger token is
# followed by that inversion's pad tokens (embedding_vector_length - 1 of them).
# Raises ValueError when the list ends with the tokenizer's eos token; the message of the
# first visible raise indicates a matching check for a leading bos token.
# The list is walked back-to-front so insertions never shift positions still to be visited.
# NOTE(review): the rest of the docstring and the opening checks are hidden behind the hunk
# marker below, and pre-/post-reformat copies of each statement appear side by side.
def expand_textual_inversion_token_ids_if_necessary(self, prompt_token_ids: list[int]) -> list[int]:
def expand_textual_inversion_token_ids_if_necessary(
self, prompt_token_ids: list[int]
) -> list[int]:
"""
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
@@ -181,20 +239,31 @@ class TextualInversionManager():
raise ValueError("prompt_token_ids must not start with bos_token_id")
if prompt_token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("prompt_token_ids must not end with eos_token_id")
textual_inversion_trigger_token_ids = [ti.trigger_token_id for ti in self.textual_inversions]
textual_inversion_trigger_token_ids = [
ti.trigger_token_id for ti in self.textual_inversions
]
prompt_token_ids = prompt_token_ids.copy()
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
if token_id in textual_inversion_trigger_token_ids:
textual_inversion = next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id)
for pad_idx in range(0, textual_inversion.embedding_vector_length-1):
prompt_token_ids.insert(i+pad_idx+1, textual_inversion.pad_token_ids[pad_idx])
textual_inversion = next(
ti
for ti in self.textual_inversions
if ti.trigger_token_id == token_id
)
for pad_idx in range(0, textual_inversion.embedding_vector_length - 1):
prompt_token_ids.insert(
i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx]
)
return prompt_token_ids
# Resolve `token_str` to a tokenizer id — adding the token when the tokenizer only knows it
# as the unknown token — and write `embedding` into that id's row of the text encoder's
# input-embedding weights. Returns the token id.
# Raises ValueError for a non-1-D embedding or a row-shape mismatch, and RuntimeError when
# the token still resolves to the unknown id at the second lookup.
# NOTE(review): the lines between add_tokens() and the second id lookup are hidden behind
# the hunk marker; pre-/post-reformat copies of each statement appear side by side.
def _get_or_create_token_id_and_assign_embedding(self, token_str: str, embedding: torch.Tensor) -> int:
def _get_or_create_token_id_and_assign_embedding(
self, token_str: str, embedding: torch.Tensor
) -> int:
if len(embedding.shape) != 1:
raise ValueError("Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2")
raise ValueError(
"Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2"
)
existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str)
# An unknown-token id means the tokenizer has no entry for token_str yet.
if existing_token_id == self.tokenizer.unk_token_id:
num_tokens_added = self.tokenizer.add_tokens(token_str)
@@ -207,66 +276,78 @@ class TextualInversionManager():
token_id = self.tokenizer.convert_tokens_to_ids(token_str)
if token_id == self.tokenizer.unk_token_id:
raise RuntimeError(f"Unable to find token id for token '{token_str}'")
# The message below starts with "Warning"; the ValueError handler earlier in this class
# prints messages with that prefix instead of dumping a traceback.
if self.text_encoder.get_input_embeddings().weight.data[token_id].shape != embedding.shape:
raise ValueError(f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}.")
if (
self.text_encoder.get_input_embeddings().weight.data[token_id].shape
!= embedding.shape
):
raise ValueError(
f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}."
)
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
return token_id
def _parse_embedding(self, embedding_file: str):
file_type = embedding_file.split('.')[-1]
if file_type == 'pt':
file_type = embedding_file.split(".")[-1]
if file_type == "pt":
return self._parse_embedding_pt(embedding_file)
elif file_type == 'bin':
elif file_type == "bin":
return self._parse_embedding_bin(embedding_file)
else:
print(f'>> Not a recognized embedding file: {embedding_file}')
print(f">> Not a recognized embedding file: {embedding_file}")
return None
def _parse_embedding_pt(self, embedding_file):
embedding_ckpt = torch.load(embedding_file, map_location='cpu')
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
embedding_info = {}
# Check if valid embedding file
if 'string_to_token' and 'string_to_param' in embedding_ckpt:
if "string_to_token" and "string_to_param" in embedding_ckpt:
# Catch variants that do not have the expected keys or values.
try:
embedding_info['name'] = embedding_ckpt['name'] or os.path.basename(os.path.splitext(embedding_file)[0])
embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
os.path.splitext(embedding_file)[0]
)
# Check num of embeddings and warn user only the first will be used
embedding_info['num_of_embeddings'] = len(embedding_ckpt["string_to_token"])
if embedding_info['num_of_embeddings'] > 1:
print('>> More than 1 embedding found. Will use the first one')
embedding_info["num_of_embeddings"] = len(
embedding_ckpt["string_to_token"]
)
if embedding_info["num_of_embeddings"] > 1:
print(">> More than 1 embedding found. Will use the first one")
embedding = list(embedding_ckpt['string_to_param'].values())[0]
except (AttributeError,KeyError):
embedding = list(embedding_ckpt["string_to_param"].values())[0]
except (AttributeError, KeyError):
return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
embedding_info['embedding'] = embedding
embedding_info['num_vectors_per_token'] = embedding.size()[0]
embedding_info['token_dim'] = embedding.size()[1]
embedding_info["embedding"] = embedding
embedding_info["num_vectors_per_token"] = embedding.size()[0]
embedding_info["token_dim"] = embedding.size()[1]
try:
embedding_info['trained_steps'] = embedding_ckpt['step']
embedding_info['trained_model_name'] = embedding_ckpt['sd_checkpoint_name']
embedding_info['trained_model_checksum'] = embedding_ckpt['sd_checkpoint']
embedding_info["trained_steps"] = embedding_ckpt["step"]
embedding_info["trained_model_name"] = embedding_ckpt[
"sd_checkpoint_name"
]
embedding_info["trained_model_checksum"] = embedding_ckpt[
"sd_checkpoint"
]
except AttributeError:
print(">> No Training Details Found. Passing ...")
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
# They are actually .bin files
elif len(embedding_ckpt.keys())==1:
print('>> Detected .bin file masquerading as .pt file')
elif len(embedding_ckpt.keys()) == 1:
embedding_info = self._parse_embedding_bin(embedding_file)
else:
print('>> Invalid embedding format')
print(">> Invalid embedding format")
embedding_info = None
return embedding_info
# Parse a ".bin"-style embedding checkpoint, which is read here as a mapping from token
# string to embedding tensor.
# Produces a dict with "name", "embedding", "num_vectors_per_token" (always 1) and "token_dim".
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
# NOTE(review): a line of the empty-checkpoint branch is hidden behind the hunk marker, and
# pre-/post-reformat copies of each statement appear side by side.
def _parse_embedding_bin(self, embedding_file):
embedding_ckpt = torch.load(embedding_file, map_location='cpu')
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
embedding_info = {}
# NOTE(review): a list never compares equal to 0, so this emptiness check can never be
# true; len(embedding_ckpt) == 0 looks like what was intended.
if list(embedding_ckpt.keys()) == 0:
@@ -274,27 +355,42 @@ class TextualInversionManager():
embedding_info = None
else:
# Every iteration overwrites the same keys, so the last token in the checkpoint wins.
for token in list(embedding_ckpt.keys()):
embedding_info['name'] = token or os.path.basename(os.path.splitext(embedding_file)[0])
embedding_info['embedding'] = embedding_ckpt[token]
embedding_info['num_vectors_per_token'] = 1 # All Concepts seem to default to 1
embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
embedding_info["name"] = token or os.path.basename(
os.path.splitext(embedding_file)[0]
)
embedding_info["embedding"] = embedding_ckpt[token]
embedding_info[
"num_vectors_per_token"
] = 1 # All Concepts seem to default to 1
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
return embedding_info
def _handle_broken_pt_variants(self, embedding_ckpt:dict, embedding_file:str)->dict:
'''
def _handle_broken_pt_variants(
self, embedding_ckpt: dict, embedding_file: str
) -> dict:
"""
This handles the broken .pt file variants. We only know of one at present.
'''
"""
embedding_info = {}
if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor):
print('>> Detected .pt file variant 1') # example at https://github.com/invoke-ai/InvokeAI/issues/1829
for token in list(embedding_ckpt['string_to_token'].keys()):
embedding_info['name'] = token if token != '*' else os.path.basename(os.path.splitext(embedding_file)[0])
embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token]
embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0]
embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
if isinstance(
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
):
for token in list(embedding_ckpt["string_to_token"].keys()):
embedding_info["name"] = (
token
if token != "*"
else os.path.basename(os.path.splitext(embedding_file)[0])
)
embedding_info["embedding"] = embedding_ckpt[
"string_to_param"
].state_dict()[token]
embedding_info["num_vectors_per_token"] = embedding_info[
"embedding"
].shape[0]
embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
else:
print('>> Invalid embedding format')
print(">> Invalid embedding format")
embedding_info = None
return embedding_info