Merge branch 'add_lora_support' of https://github.com/jordanramstad/InvokeAI into add_lora_support

ldm/generate.py: 980 changes (file diff suppressed because it is too large)
@@ -1 +1 @@
-__version__='2.3.0'
+__version__='2.3.1+a0'
@@ -272,6 +272,10 @@ class Args(object):
             switches.append('--seamless')
         if a['hires_fix']:
             switches.append('--hires_fix')
+        if a['h_symmetry_time_pct']:
+            switches.append(f'--h_symmetry_time_pct {a["h_symmetry_time_pct"]}')
+        if a['v_symmetry_time_pct']:
+            switches.append(f'--v_symmetry_time_pct {a["v_symmetry_time_pct"]}')

         # img2img generations have parameters relevant only to them and have special handling
         if a['init_img'] and len(a['init_img'])>0:
@@ -751,6 +755,9 @@ class Args(object):
 !fix applies upscaling/facefixing to a previously-generated image.
    invoke> !fix 0000045.4829112.png -G1 -U4 -ft codeformer

+*embeddings*
+   invoke> !triggers -- return all trigger phrases contained in loaded embedding files
+
 *History manipulation*
 !fetch retrieves the command used to generate an earlier image. Provide
 a directory wildcard and the name of a file to write and all the commands
@@ -842,6 +849,18 @@ class Args(object):
         type=float,
         help='Perlin noise scale (0.0 - 1.0) - add perlin noise to the initialization instead of the usual gaussian noise.',
     )
+    render_group.add_argument(
+        '--h_symmetry_time_pct',
+        default=None,
+        type=float,
+        help='Horizontal symmetry point (0.0 - 1.0) - apply horizontal symmetry at this point in image generation.',
+    )
+    render_group.add_argument(
+        '--v_symmetry_time_pct',
+        default=None,
+        type=float,
+        help='Vertical symmetry point (0.0 - 1.0) - apply vertical symmetry at this point in image generation.',
+    )
     render_group.add_argument(
         '--fnformat',
         default='{prefix}.{seed}.png',
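For illustration only (this example is not part of the commit): the two new switches take a fraction of the sampling schedule, so a command such as

    invoke> "portrait of a palace" -s 50 --h_symmetry_time_pct 0.25 --v_symmetry_time_pct 0.5

would mirror the latent horizontally a quarter of the way through denoising and vertically at the halfway point.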
@@ -1148,7 +1167,8 @@ def metadata_dumps(opt,
     # remove any image keys not mentioned in RFC #266
     rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
                         'cfg_scale','threshold','perlin','step_number','width','height','extra','strength','seamless',
-                        'init_img','init_mask','facetool','facetool_strength','upscale']
+                        'init_img','init_mask','facetool','facetool_strength','upscale','h_symmetry_time_pct',
+                        'v_symmetry_time_pct']
     rfc_dict ={}

     for item in image_dict.items():
@@ -53,6 +53,7 @@ from diffusers import (
 )
 from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
 from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.utils import is_safetensors_available
 from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig

@@ -984,6 +985,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']:
         text_model = convert_ldm_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir)
+        safety_checker = StableDiffusionSafetyChecker.from_pretrained('CompVis/stable-diffusion-safety-checker',cache_dir=global_cache_dir("hub"))
         feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
         pipe = pipeline_class(
             vae=vae,
@@ -991,7 +993,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
             tokenizer=tokenizer,
             unet=unet,
             scheduler=scheduler,
-            safety_checker=None,
+            safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
     else:
@@ -40,7 +40,6 @@ from ldm.invoke.globals import Globals, global_cache_dir, global_config_dir
 from ldm.invoke.readline import generic_completer

 warnings.filterwarnings("ignore")
-import torch

 transformers.logging.set_verbosity_error()

@@ -764,7 +763,7 @@ def download_weights(opt: dict) -> Union[str, None]:
     precision = (
         "float32"
         if opt.full_precision
-        else choose_precision(torch.device(choose_torch_device()))
+        else choose_precision(choose_torch_device())
     )

     if opt.yes_to_all:
102  ldm/invoke/config/invokeai_update.py  (new file)
@@ -0,0 +1,102 @@
'''
Minimalist updater script. Prompts user for the tag or branch to update to and runs
pip install <path_to_git_source>.
'''

import platform
import requests
import subprocess
from rich import box, print
from rich.console import Console, group
from rich.panel import Panel
from rich.prompt import Prompt
from rich.style import Style
from rich.text import Text
from rich.live import Live
from rich.table import Table

from ldm.invoke import __version__

INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive"
INVOKE_AI_REL="https://api.github.com/repos/invoke-ai/InvokeAI/releases"

OS = platform.uname().system
ARCH = platform.uname().machine

ORANGE_ON_DARK_GREY = Style(bgcolor="grey23", color="orange1")

if OS == "Windows":
    # Windows terminals look better without a background colour
    console = Console(style=Style(color="grey74"))
else:
    console = Console(style=Style(color="grey74", bgcolor="grey23"))

def get_versions()->dict:
    return requests.get(url=INVOKE_AI_REL).json()

def welcome(versions: dict):

    @group()
    def text():
        yield f'InvokeAI Version: [bold yellow]{__version__}'
        yield ''
        yield 'This script will update InvokeAI to the latest release, or to a development version of your choice.'
        yield ''
        yield '[bold yellow]Options:'
        yield f'''[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
[2] Update to the bleeding-edge development version ([italic]main[/italic])
[3] Manually enter the tag or branch name you wish to update'''

    console.rule()
    console.print(
        Panel(
            title="[bold wheat1]InvokeAI Updater",
            renderable=text(),
            box=box.DOUBLE,
            expand=True,
            padding=(1, 2),
            style=ORANGE_ON_DARK_GREY,
            subtitle=f"[bold grey39]{OS}-{ARCH}",
        )
    )
    # console.rule is used instead of console.line to maintain dark background
    # on terminals where light background is the default
    console.rule(characters=" ")

def main():
    versions = get_versions()
    welcome(versions)

    tag = None
    choice = Prompt.ask(Text.from_markup(('[grey74 on grey23]Choice:')),choices=['1','2','3'],default='1')

    if choice=='1':
        tag = versions[0]['tag_name']
    elif choice=='2':
        tag = 'main'
    elif choice=='3':
        tag = Prompt.ask('[grey74 on grey23]Enter an InvokeAI tag or branch name')

    console.print(Panel(f':crossed_fingers: Upgrading to [yellow]{tag}[/yellow]', box=box.MINIMAL, style=ORANGE_ON_DARK_GREY))

    cmd = f'pip install {INVOKE_AI_SRC}/{tag}.zip --use-pep517'

    progress = Table.grid(expand=True)
    progress_panel = Panel(progress, box=box.MINIMAL, style=ORANGE_ON_DARK_GREY)

    with subprocess.Popen(['bash', '-c', cmd], stdout=subprocess.PIPE, stderr=subprocess.PIPE) as proc:
        progress.add_column()
        with Live(progress_panel, console=console, vertical_overflow='visible'):
            while proc.poll() is None:
                for l in iter(proc.stdout.readline, b''):
                    progress.add_row(l.decode().strip(), style=ORANGE_ON_DARK_GREY)
    if proc.returncode == 0:
        console.rule(f':heavy_check_mark: Upgrade successful')
    else:
        console.rule(f':exclamation: [bold red]Upgrade failed[/red bold]')

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        pass
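A note on running the new updater: because the file guards main() behind __name__ == "__main__", it can be launched as a module. The invocation below is inferred from the file path shown above; any console-script alias is not visible in this diff:

    python -m ldm.invoke.config.invokeai_update

One caveat visible in the code itself: the pip command is executed via ['bash', '-c', cmd], so a bash executable must be on PATH, even though the styling logic above special-cases Windows terminals.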
@@ -1,19 +1,25 @@
 from __future__ import annotations

+from contextlib import nullcontext
+
 import torch
 from torch import autocast
-from contextlib import nullcontext

 from ldm.invoke.globals import Globals

-def choose_torch_device() -> str:
+CPU_DEVICE = torch.device("cpu")
+
+def choose_torch_device() -> torch.device:
     '''Convenience routine for guessing which GPU device to run model on'''
     if Globals.always_use_cpu:
-        return "cpu"
+        return CPU_DEVICE
     if torch.cuda.is_available():
-        return 'cuda'
+        return torch.device('cuda')
     if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-        return 'mps'
-    return 'cpu'
+        return torch.device('mps')
+    return CPU_DEVICE

-def choose_precision(device) -> str:
+def choose_precision(device: torch.device) -> str:
     '''Returns an appropriate precision for the given torch device'''
     if device.type == 'cuda':
         device_name = torch.cuda.get_device_name(device)
@@ -21,7 +27,7 @@ def choose_precision(device) -> str:
             return 'float16'
     return 'float32'

-def torch_dtype(device) -> torch.dtype:
+def torch_dtype(device: torch.device) -> torch.dtype:
     if Globals.full_precision:
         return torch.float32
     if choose_precision(device) == 'float16':
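A minimal usage sketch for the refactored helpers above (the device and dtype you get back depend on the machine; the helper names are the ones defined in this file):

    device = choose_torch_device()   # now a torch.device, e.g. device(type='cuda')
    dtype = torch_dtype(device)      # torch.float16 or torch.float32
    x = torch.zeros(8, device=device, dtype=dtype)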
@@ -36,3 +42,13 @@ def choose_autocast(precision):
     if precision == 'autocast' or precision == 'float16':
         return autocast
     return nullcontext
+
+def normalize_device(device: str | torch.device) -> torch.device:
+    """Ensure device has a device index defined, if appropriate."""
+    device = torch.device(device)
+    if device.index is None:
+        # cuda might be the only torch backend that currently uses the device index?
+        # I don't see anything like `current_device` for cpu or mps.
+        if device.type == 'cuda':
+            device = torch.device(device.type, torch.cuda.current_device())
+    return device
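The index normalization matters for calls such as torch.cuda.mem_get_info(), which is exactly how the diffusers_pipeline hunk further down uses it; presumably some call paths require an explicit device index rather than a bare 'cuda'. A sketch, assuming a CUDA machine:

    dev = normalize_device('cuda')    # device(type='cuda', index=0) on the current GPU
    free_bytes, total_bytes = torch.cuda.mem_get_info(dev)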
@@ -64,6 +64,7 @@ class Generator:

     def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
+                 h_symmetry_time_pct=None, v_symmetry_time_pct=None,
                  safety_checker:dict=None,
                  free_gpu_mem: bool=False,
                  **kwargs):
@@ -81,6 +82,8 @@ class Generator:
                 step_callback = step_callback,
                 threshold = threshold,
                 perlin = perlin,
+                h_symmetry_time_pct = h_symmetry_time_pct,
+                v_symmetry_time_pct = v_symmetry_time_pct,
                 attention_maps_callback = attention_maps_callback,
                 **kwargs
             )
@@ -29,6 +29,7 @@ from ldm.invoke.globals import Globals
 from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
 from ldm.modules.textual_inversion_manager import TextualInversionManager
 from ldm.modules.lora_manager import LoraManager
+from ..devices import normalize_device, CPU_DEVICE
 from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
 from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
@@ -322,7 +323,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.device.type == 'cpu' or self.device.type == 'mps':
             mem_free = psutil.virtual_memory().free
         elif self.device.type == 'cuda':
-            mem_free, _ = torch.cuda.mem_get_info(self.device)
+            mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
         else:
             raise ValueError(f"unrecognized device {self.device}")
         # input tensor of [1, 4, h/8, w/8]
@@ -383,9 +384,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         self._model_group.ready()

     def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+        # overridden method; types match the superclass.
         if torch_device is None:
             return self
-        self._model_group.set_device(torch_device)
+        self._model_group.set_device(torch.device(torch_device))
         self._model_group.ready()

     @property
@@ -692,8 +694,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if device.type == 'mps':
             # workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
             # TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
-            self.vae.to('cpu')
-            init_image = init_image.to('cpu')
+            self.vae.to(CPU_DEVICE)
+            init_image = init_image.to(CPU_DEVICE)
         else:
             self._model_group.load(self.vae)
         init_latent_dist = self.vae.encode(init_image).latent_dist
@@ -16,8 +16,8 @@ class Img2Img(Generator):
         self.init_latent = None    # by get_noise()

     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,
-                       attention_maps_callback=None,
+                       conditioning,init_image,strength,step_callback=None,threshold=0.0,warmup=0.2,perlin=0.0,
+                       h_symmetry_time_pct=None,v_symmetry_time_pct=None,attention_maps_callback=None,
                        **kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
@@ -33,8 +33,13 @@ class Img2Img(Generator):
         conditioning_data = (
             ConditioningData(
                 uc, c, cfg_scale, extra_conditioning_info,
-                postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
-            .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
+                postprocessing_settings=PostprocessingSettings(
+                    threshold=threshold,
+                    warmup=warmup,
+                    h_symmetry_time_pct=h_symmetry_time_pct,
+                    v_symmetry_time_pct=v_symmetry_time_pct
+                )
+            ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))


         def make_image(x_T):
@@ -15,8 +15,8 @@ class Txt2Img(Generator):

     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
-                       attention_maps_callback=None,
+                       conditioning,width,height,step_callback=None,threshold=0.0,warmup=0.2,perlin=0.0,
+                       h_symmetry_time_pct=None,v_symmetry_time_pct=None,attention_maps_callback=None,
                        **kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
@@ -33,8 +33,13 @@ class Txt2Img(Generator):
         conditioning_data = (
             ConditioningData(
                 uc, c, cfg_scale, extra_conditioning_info,
-                postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
-            .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
+                postprocessing_settings=PostprocessingSettings(
+                    threshold=threshold,
+                    warmup=warmup,
+                    h_symmetry_time_pct=h_symmetry_time_pct,
+                    v_symmetry_time_pct=v_symmetry_time_pct
+                )
+            ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))

         def make_image(x_T) -> PIL.Image.Image:
             pipeline_output = pipeline.image_from_embeddings(
@@ -44,8 +49,10 @@ class Txt2Img(Generator):
                 conditioning_data=conditioning_data,
                 callback=step_callback,
             )
+
             if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
                 attention_maps_callback(pipeline_output.attention_map_saver)
+
             return pipeline.numpy_to_pil(pipeline_output.images)[0]

         return make_image
@@ -21,12 +21,14 @@ class Txt2Img2Img(Generator):

     def get_make_image(self, prompt:str, sampler, steps:int, cfg_scale:float, ddim_eta,
                        conditioning, width:int, height:int, strength:float,
-                       step_callback:Optional[Callable]=None, threshold=0.0, **kwargs):
+                       step_callback:Optional[Callable]=None, threshold=0.0, warmup=0.2, perlin=0.0,
+                       h_symmetry_time_pct=None, v_symmetry_time_pct=None, attention_maps_callback=None, **kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it
         kwargs are 'width' and 'height'
         """
+        self.perlin = perlin

         # noinspection PyTypeChecker
         pipeline: StableDiffusionGeneratorPipeline = self.model
@@ -36,8 +38,13 @@ class Txt2Img2Img(Generator):
         conditioning_data = (
             ConditioningData(
                 uc, c, cfg_scale, extra_conditioning_info,
-                postprocessing_settings = PostprocessingSettings(threshold=threshold, warmup=0.2) if threshold else None)
-            .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
+                postprocessing_settings = PostprocessingSettings(
+                    threshold=threshold,
+                    warmup=0.2,
+                    h_symmetry_time_pct=h_symmetry_time_pct,
+                    v_symmetry_time_pct=v_symmetry_time_pct
+                )
+            ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))

         def make_image(x_T):
@@ -69,19 +76,28 @@ class Txt2Img2Img(Generator):
             if clear_cuda_cache is not None:
                 clear_cuda_cache()

-            second_pass_noise = self.get_noise_like(resized_latents)
+            second_pass_noise = self.get_noise_like(resized_latents, override_perlin=True)
+
+            # Clear symmetry for the second pass
+            from dataclasses import replace
+            new_postprocessing_settings = replace(conditioning_data.postprocessing_settings, h_symmetry_time_pct=None)
+            new_postprocessing_settings = replace(new_postprocessing_settings, v_symmetry_time_pct=None)
+            new_conditioning_data = replace(conditioning_data, postprocessing_settings=new_postprocessing_settings)

             verbosity = get_verbosity()
             set_verbosity_error()
             pipeline_output = pipeline.img2img_from_latents_and_embeddings(
                 resized_latents,
                 num_inference_steps=steps,
-                conditioning_data=conditioning_data,
+                conditioning_data=new_conditioning_data,
                 strength=strength,
                 noise=second_pass_noise,
                 callback=step_callback)
             set_verbosity(verbosity)

+            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
+                attention_maps_callback(pipeline_output.attention_map_saver)
+
             return pipeline.numpy_to_pil(pipeline_output.images)[0]

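The dataclasses.replace() calls above return modified copies rather than mutating the settings in place; the two calls could equally be one, since replace() accepts several field overrides at once. An illustrative stand-in (Settings here is a dummy, not the real PostprocessingSettings):

    from dataclasses import dataclass, replace
    from typing import Optional

    @dataclass
    class Settings:
        h_symmetry_time_pct: Optional[float] = 0.3
        v_symmetry_time_pct: Optional[float] = 0.6

    s = Settings()
    s2 = replace(s, h_symmetry_time_pct=None, v_symmetry_time_pct=None)  # s itself is unchanged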
@@ -95,13 +111,13 @@ class Txt2Img2Img(Generator):

         return make_image

-    def get_noise_like(self, like: torch.Tensor):
+    def get_noise_like(self, like: torch.Tensor, override_perlin: bool=False):
         device = like.device
         if device.type == 'mps':
             x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device)
         else:
             x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
-        if self.perlin > 0.0:
+        if self.perlin > 0.0 and not override_perlin:
             shape = like.shape
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
         return x
@@ -139,6 +155,9 @@ class Txt2Img2Img(Generator):
         shape = (1, channels,
                  scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor)
         if self.use_mps_noise or device.type == 'mps':
-            return torch.randn(shape, dtype=self.torch_dtype(), device='cpu').to(device)
+            tensor = torch.empty(size=shape, device='cpu')
+            tensor = self.get_noise_like(like=tensor).to(device)
         else:
-            return torch.randn(shape, dtype=self.torch_dtype(), device=device)
+            tensor = torch.empty(size=shape, device=device)
+            tensor = self.get_noise_like(like=tensor)
+        return tensor
@@ -323,7 +323,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):

         if selected_model3 > 0:
             self.merge_method.values = ['add_difference ( A+(B-C) )']
-            self.merged_model_name.value += f"+{models[selected_model3]}"
+            self.merged_model_name.value += f"+{models[selected_model3 -1]}"  # In model3 there is one more element in the list (None). So we have to subtract one.
         else:
             self.merge_method.values = self.interpolations
             self.merge_method.value = 0
@@ -419,8 +419,7 @@ def run_gui(args: Namespace):
     mergeapp.run()

     args = mergeapp.merge_arguments
-    print(f'DEBUG: {args}')
-    #merge_diffusion_models_and_commit(**args)
+    merge_diffusion_models_and_commit(**args)
     print(f'>> Models merged into new model: "{args["merged_model_name"]}".')

@@ -30,6 +30,7 @@ from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 from picklescan.scanner import scan_file_path

+from ldm.invoke.devices import CPU_DEVICE
 from ldm.invoke.generator.diffusers_pipeline import \
     StableDiffusionGeneratorPipeline
 from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
@@ -47,7 +48,7 @@ class ModelManager(object):
     def __init__(
         self,
         config: OmegaConf,
-        device_type: str | torch.device = "cpu",
+        device_type: torch.device = CPU_DEVICE,
         precision: str = "float16",
         max_loaded_models=DEFAULT_MAX_MODELS,
         sequential_offload = False
@@ -675,7 +676,7 @@ class ModelManager(object):
         """
         if str(weights).startswith(("http:", "https:")):
             model_name = model_name or url_attachment_name(weights)

         weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
         config_path = self._resolve_path(config, "configs/stable-diffusion")

@@ -996,25 +997,25 @@ class ModelManager(object):
         self.models.pop(model_name, None)

     def _model_to_cpu(self, model):
-        if self.device == "cpu":
+        if self.device == CPU_DEVICE:
             return model

         if isinstance(model, StableDiffusionGeneratorPipeline):
             model.offload_all()
             return model

-        model.cond_stage_model.device = "cpu"
-        model.to("cpu")
+        model.cond_stage_model.device = CPU_DEVICE
+        model.to(CPU_DEVICE)

         for submodel in ("first_stage_model", "cond_stage_model", "model"):
             try:
-                getattr(model, submodel).to("cpu")
+                getattr(model, submodel).to(CPU_DEVICE)
             except AttributeError:
                 pass
         return model

     def _model_from_cpu(self, model):
-        if self.device == "cpu":
+        if self.device == CPU_DEVICE:
             return model

         if isinstance(model, StableDiffusionGeneratorPipeline):
@@ -58,9 +58,11 @@ COMMANDS = (
     '--inpaint_replace','-r',
     '--png_compression','-z',
     '--text_mask','-tm',
+    '--h_symmetry_time_pct',
+    '--v_symmetry_time_pct',
     '!fix','!fetch','!replay','!history','!search','!clear',
     '!models','!switch','!import_model','!optimize_model','!convert_model','!edit_model','!del_model',
-    '!mask',
+    '!mask','!triggers',
 )
 MODEL_COMMANDS = (
     '!switch',
@@ -138,7 +140,7 @@ class Completer(object):
         elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
             self.matches= self._model_completions(text, state)

-            # looking for a ckpt model
+        # looking for a ckpt model
         elif re.match('^'+'|'.join(CKPT_MODEL_COMMANDS),buffer):
             self.matches= self._model_completions(text, state, ckpt_only=True)

@@ -255,7 +257,7 @@ class Completer(object):
         update our list of models
         '''
         self.models = models

     def _seed_completions(self, text, state):
         m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
         if m:
@@ -441,6 +441,7 @@ class TextualInversionDataset(Dataset):
             self.image_paths = [
                 os.path.join(self.data_root, file_path)
                 for file_path in os.listdir(self.data_root)
+                if os.path.isfile(file_path) and file_path.endswith(('.png','.PNG','.jpg','.JPG','.jpeg','.JPEG','.gif','.GIF'))
             ]

             self.num_images = len(self.image_paths)
@@ -584,7 +584,9 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         # print(f"SwapCrossAttnContext for {attention_type} active")

         batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+        attention_mask = attn.prepare_attention_mask(
+            attention_mask=attention_mask, target_length=sequence_length,
+            batch_size=batch_size)

         query = attn.to_q(hidden_states)
         dim = query.shape[-1]
@@ -18,6 +18,8 @@ from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 class PostprocessingSettings:
     threshold: float
     warmup: float
+    h_symmetry_time_pct: Optional[float]
+    v_symmetry_time_pct: Optional[float]


 class InvokeAIDiffuserComponent:
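For reference, the expanded dataclass is constructed like this elsewhere in this diff (the values below are illustrative):

    settings = PostprocessingSettings(
        threshold=0.0,
        warmup=0.2,
        h_symmetry_time_pct=0.3,
        v_symmetry_time_pct=None,   # None disables that axis
    )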
@@ -30,7 +32,7 @@ class InvokeAIDiffuserComponent:
     * Hybrid conditioning (used for inpainting)
     '''
     debug_thresholding = False
+    last_percent_through = 0.0

     @dataclass
     class ExtraConditioningInfo:
@@ -56,6 +58,7 @@ class InvokeAIDiffuserComponent:
         self.is_running_diffusers = is_running_diffusers
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None
+        self.last_percent_through = 0.0

     @contextmanager
     def custom_attention_context(self,
@@ -164,6 +167,7 @@ class InvokeAIDiffuserComponent:
         if postprocessing_settings is not None:
             percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
             latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
+            latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
         return latents

     def calculate_percent_through(self, sigma, step_index, total_step_count):
@@ -292,8 +296,12 @@ class InvokeAIDiffuserComponent:
         self,
         postprocessing_settings: PostprocessingSettings,
         latents: torch.Tensor,
-        percent_through
+        percent_through: float
     ) -> torch.Tensor:

+        if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
+            return latents
+
         threshold = postprocessing_settings.threshold
         warmup = postprocessing_settings.warmup
@@ -342,6 +350,56 @@ class InvokeAIDiffuserComponent:

         return latents

+    def apply_symmetry(
+        self,
+        postprocessing_settings: PostprocessingSettings,
+        latents: torch.Tensor,
+        percent_through: float
+    ) -> torch.Tensor:
+
+        # Reset our last percent through if this is our first step.
+        if percent_through == 0.0:
+            self.last_percent_through = 0.0
+
+        if postprocessing_settings is None:
+            return latents
+
+        # Check for out of bounds
+        h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
+        if (h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0)):
+            h_symmetry_time_pct = None
+
+        v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
+        if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)):
+            v_symmetry_time_pct = None
+
+        dev = latents.device.type
+
+        latents = latents.to(device='cpu')
+
+        if (
+            h_symmetry_time_pct is not None and
+            self.last_percent_through < h_symmetry_time_pct and
+            percent_through >= h_symmetry_time_pct
+        ):
+            # Horizontal symmetry occurs on the 3rd dimension of the latent
+            width = latents.shape[3]
+            x_flipped = torch.flip(latents, dims=[3])
+            latents = torch.cat([latents[:, :, :, 0:int(width/2)], x_flipped[:, :, :, int(width/2):int(width)]], dim=3)
+
+        if (
+            v_symmetry_time_pct is not None and
+            self.last_percent_through < v_symmetry_time_pct and
+            percent_through >= v_symmetry_time_pct
+        ):
+            # Vertical symmetry occurs on the 2nd dimension of the latent
+            height = latents.shape[2]
+            y_flipped = torch.flip(latents, dims=[2])
+            latents = torch.cat([latents[:, :, 0:int(height / 2)], y_flipped[:, :, int(height / 2):int(height)]], dim=2)
+
+        self.last_percent_through = percent_through
+        return latents.to(device=dev)
+
     def estimate_percent_through(self, step_index, sigma):
         if step_index is not None and self.cross_attention_control_context is not None:
             # percent_through will never reach 1.0 (but this is intended)
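A minimal, self-contained illustration of the horizontal mirroring performed above (dummy latent shape; not code from this commit):

    import torch

    latents = torch.randn(1, 4, 64, 64)        # stand-in for a [batch, channel, height, width] latent
    width = latents.shape[3]
    x_flipped = torch.flip(latents, dims=[3])  # mirror along the width axis
    latents = torch.cat([latents[:, :, :, :width // 2],
                         x_flipped[:, :, :, width // 2:]], dim=3)
    # the two outermost columns now match
    assert torch.equal(latents[:, :, :, 0], latents[:, :, :, -1])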
@@ -1,11 +1,12 @@
 import os
 import traceback
-from typing import Optional
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, Union

 import torch
-from dataclasses import dataclass
 from picklescan.scanner import scan_file_path
-from transformers import CLIPTokenizer, CLIPTextModel
+from transformers import CLIPTextModel, CLIPTokenizer

 from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
@@ -21,11 +22,14 @@ class TextualInversion:
     def embedding_vector_length(self) -> int:
         return self.embedding.shape[0]

-class TextualInversionManager():
-    def __init__(self,
-                 tokenizer: CLIPTokenizer,
-                 text_encoder: CLIPTextModel,
-                 full_precision: bool=True):
+class TextualInversionManager:
+    def __init__(
+        self,
+        tokenizer: CLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        full_precision: bool = True,
+    ):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.full_precision = full_precision
@@ -38,47 +42,78 @@ class TextualInversionManager():
             if concept_name in self.hf_concepts_library.concepts_loaded:
                 continue
             trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
-            if self.has_textual_inversion_for_trigger_string(trigger) \
-               or self.has_textual_inversion_for_trigger_string(concept_name) \
-               or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered
-                print(f'>> Loaded local embedding for trigger {concept_name}')
+            if (
+                self.has_textual_inversion_for_trigger_string(trigger)
+                or self.has_textual_inversion_for_trigger_string(concept_name)
+                or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
+            ):  # in case a token with literal angle brackets encountered
+                print(f">> Loaded local embedding for trigger {concept_name}")
                 continue
             bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
             if not bin_file:
                 continue
-            print(f'>> Loaded remote embedding for trigger {concept_name}')
+            print(f">> Loaded remote embedding for trigger {concept_name}")
             self.load_textual_inversion(bin_file)
-            self.hf_concepts_library.concepts_loaded[concept_name]=True
+            self.hf_concepts_library.concepts_loaded[concept_name] = True

     def get_all_trigger_strings(self) -> list[str]:
         return [ti.trigger_string for ti in self.textual_inversions]

-    def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
-        if str(ckpt_path).endswith('.DS_Store'):
+    def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
+        ckpt_path = Path(ckpt_path)
+
+        if not ckpt_path.is_file():
+            return
+
+        if str(ckpt_path).endswith(".DS_Store"):
             return

         try:
-            scan_result = scan_file_path(ckpt_path)
+            scan_result = scan_file_path(str(ckpt_path))
             if scan_result.infected_files == 1:
-                print(f'\n### Security Issues Found in Model: {scan_result.issues_count}')
-                print('### For your safety, InvokeAI will not load this embed.')
+                print(
+                    f"\n### Security Issues Found in Model: {scan_result.issues_count}"
+                )
+                print("### For your safety, InvokeAI will not load this embed.")
                 return
         except Exception:
-            print(f"### WARNING::: Invalid or corrupt embeddings found. Ignoring: {ckpt_path}")
+            print(
+                f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
+            )
             return

+        embedding_info = self._parse_embedding(str(ckpt_path))
+
+        if embedding_info is None:
+            # We've already put out an error message about the bad embedding in _parse_embedding, so just return.
+            return
+        elif (
+            self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
+            != embedding_info['token_dim']
+        ):
+            print(
+                f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
+            )
+            return
+
-        embedding_info = self._parse_embedding(ckpt_path)
         if embedding_info:
             try:
-                self._add_textual_inversion(embedding_info['name'],
-                                            embedding_info['embedding'],
-                                            defer_injecting_tokens=defer_injecting_tokens)
+                self._add_textual_inversion(
+                    embedding_info["name"],
+                    embedding_info["embedding"],
+                    defer_injecting_tokens=defer_injecting_tokens,
+                )
             except ValueError as e:
                 print(f'   | Ignoring incompatible embedding {embedding_info["name"]}')
-                print(f'   | The error was {str(e)}')
+                print(f"   | The error was {str(e)}")
         else:
-            print(f'>> Failed to load embedding located at {ckpt_path}. Unsupported file.')
+            print(
+                f">> Failed to load embedding located at {str(ckpt_path)}. Unsupported file."
+            )

-    def _add_textual_inversion(self, trigger_str, embedding, defer_injecting_tokens=False) -> TextualInversion:
+    def _add_textual_inversion(
+        self, trigger_str, embedding, defer_injecting_tokens=False
+    ) -> TextualInversion:
         """
         Add a textual inversion to be recognised.
         :param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
@@ -86,46 +121,59 @@ class TextualInversionManager():
         :return: The token id for the added embedding, either existing or newly-added.
         """
         if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
-            print(f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'")
+            print(
+                f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
+            )
             return
         if not self.full_precision:
             embedding = embedding.half()
         if len(embedding.shape) == 1:
             embedding = embedding.unsqueeze(0)
         elif len(embedding.shape) > 2:
-            raise ValueError(f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2.")
+            raise ValueError(
+                f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
+            )

         try:
-            ti = TextualInversion(
-                trigger_string=trigger_str,
-                embedding=embedding
-            )
+            ti = TextualInversion(trigger_string=trigger_str, embedding=embedding)
             if not defer_injecting_tokens:
                 self._inject_tokens_and_assign_embeddings(ti)
             self.textual_inversions.append(ti)
             return ti

         except ValueError as e:
-            if str(e).startswith('Warning'):
+            if str(e).startswith("Warning"):
                 print(f">> {str(e)}")
             else:
                 traceback.print_exc()
-                print(f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}.")
+                print(
+                    f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
+                )
             raise

     def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:

         if ti.trigger_token_id is not None:
-            raise ValueError(f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'")
+            raise ValueError(
+                f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'"
+            )

-        trigger_token_id = self._get_or_create_token_id_and_assign_embedding(ti.trigger_string, ti.embedding[0])
+        trigger_token_id = self._get_or_create_token_id_and_assign_embedding(
+            ti.trigger_string, ti.embedding[0]
+        )

         if ti.embedding_vector_length > 1:
             # for embeddings with vector length > 1
-            pad_token_strings = [ti.trigger_string + "-!pad-" + str(pad_index) for pad_index in range(1, ti.embedding_vector_length)]
+            pad_token_strings = [
+                ti.trigger_string + "-!pad-" + str(pad_index)
+                for pad_index in range(1, ti.embedding_vector_length)
+            ]
             # todo: batched UI for faster loading when vector length >2
-            pad_token_ids = [self._get_or_create_token_id_and_assign_embedding(pad_token_str, ti.embedding[1 + i]) \
-                             for (i, pad_token_str) in enumerate(pad_token_strings)]
+            pad_token_ids = [
+                self._get_or_create_token_id_and_assign_embedding(
+                    pad_token_str, ti.embedding[1 + i]
+                )
+                for (i, pad_token_str) in enumerate(pad_token_strings)
+            ]
         else:
             pad_token_ids = []

@@ -133,7 +181,6 @@ class TextualInversionManager():
         ti.pad_token_ids = pad_token_ids
         return ti.trigger_token_id

     def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
         try:
             ti = self.get_textual_inversion_for_trigger_string(trigger_string)
@@ -141,32 +188,43 @@ class TextualInversionManager():
         except StopIteration:
             return False

-    def get_textual_inversion_for_trigger_string(self, trigger_string: str) -> TextualInversion:
-        return next(ti for ti in self.textual_inversions if ti.trigger_string == trigger_string)
+    def get_textual_inversion_for_trigger_string(
+        self, trigger_string: str
+    ) -> TextualInversion:
+        return next(
+            ti for ti in self.textual_inversions if ti.trigger_string == trigger_string
+        )

     def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
-        return next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id)
+        return next(
+            ti for ti in self.textual_inversions if ti.trigger_token_id == token_id
+        )

-    def create_deferred_token_ids_for_any_trigger_terms(self, prompt_string: str) -> list[int]:
+    def create_deferred_token_ids_for_any_trigger_terms(
+        self, prompt_string: str
+    ) -> list[int]:
         injected_token_ids = []
         for ti in self.textual_inversions:
             if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
                 if ti.embedding_vector_length > 1:
-                    print(f">> Preparing tokens for textual inversion {ti.trigger_string}...")
+                    print(
+                        f">> Preparing tokens for textual inversion {ti.trigger_string}..."
+                    )
                 try:
                     self._inject_tokens_and_assign_embeddings(ti)
                 except ValueError as e:
-                    print(f'   | Ignoring incompatible embedding trigger {ti.trigger_string}')
-                    print(f'   | The error was {str(e)}')
+                    print(
+                        f"   | Ignoring incompatible embedding trigger {ti.trigger_string}"
+                    )
+                    print(f"   | The error was {str(e)}")
                     continue
                 injected_token_ids.append(ti.trigger_token_id)
                 injected_token_ids.extend(ti.pad_token_ids)
         return injected_token_ids

-    def expand_textual_inversion_token_ids_if_necessary(self, prompt_token_ids: list[int]) -> list[int]:
+    def expand_textual_inversion_token_ids_if_necessary(
+        self, prompt_token_ids: list[int]
+    ) -> list[int]:
         """
         Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.

@@ -181,20 +239,31 @@ class TextualInversionManager():
             raise ValueError("prompt_token_ids must not start with bos_token_id")
         if prompt_token_ids[-1] == self.tokenizer.eos_token_id:
             raise ValueError("prompt_token_ids must not end with eos_token_id")
-        textual_inversion_trigger_token_ids = [ti.trigger_token_id for ti in self.textual_inversions]
+        textual_inversion_trigger_token_ids = [
+            ti.trigger_token_id for ti in self.textual_inversions
+        ]
         prompt_token_ids = prompt_token_ids.copy()
         for i, token_id in reversed(list(enumerate(prompt_token_ids))):
             if token_id in textual_inversion_trigger_token_ids:
-                textual_inversion = next(ti for ti in self.textual_inversions if ti.trigger_token_id == token_id)
-                for pad_idx in range(0, textual_inversion.embedding_vector_length-1):
-                    prompt_token_ids.insert(i+pad_idx+1, textual_inversion.pad_token_ids[pad_idx])
+                textual_inversion = next(
+                    ti
+                    for ti in self.textual_inversions
+                    if ti.trigger_token_id == token_id
+                )
+                for pad_idx in range(0, textual_inversion.embedding_vector_length - 1):
+                    prompt_token_ids.insert(
+                        i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx]
+                    )

         return prompt_token_ids

-    def _get_or_create_token_id_and_assign_embedding(self, token_str: str, embedding: torch.Tensor) -> int:
+    def _get_or_create_token_id_and_assign_embedding(
+        self, token_str: str, embedding: torch.Tensor
+    ) -> int:
         if len(embedding.shape) != 1:
-            raise ValueError("Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2")
+            raise ValueError(
+                "Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2"
+            )
         existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str)
         if existing_token_id == self.tokenizer.unk_token_id:
             num_tokens_added = self.tokenizer.add_tokens(token_str)
@@ -207,66 +276,78 @@ class TextualInversionManager():
         token_id = self.tokenizer.convert_tokens_to_ids(token_str)
         if token_id == self.tokenizer.unk_token_id:
             raise RuntimeError(f"Unable to find token id for token '{token_str}'")
-        if self.text_encoder.get_input_embeddings().weight.data[token_id].shape != embedding.shape:
-            raise ValueError(f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}.")
+        if (
+            self.text_encoder.get_input_embeddings().weight.data[token_id].shape
+            != embedding.shape
+        ):
+            raise ValueError(
+                f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}."
+            )
         self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

         return token_id

     def _parse_embedding(self, embedding_file: str):
-        file_type = embedding_file.split('.')[-1]
-        if file_type == 'pt':
+        file_type = embedding_file.split(".")[-1]
+        if file_type == "pt":
             return self._parse_embedding_pt(embedding_file)
-        elif file_type == 'bin':
+        elif file_type == "bin":
             return self._parse_embedding_bin(embedding_file)
         else:
-            print(f'>> Not a recognized embedding file: {embedding_file}')
+            print(f">> Not a recognized embedding file: {embedding_file}")
+            return None

     def _parse_embedding_pt(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location='cpu')
+        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
         embedding_info = {}

         # Check if valid embedding file
-        if 'string_to_token' and 'string_to_param' in embedding_ckpt:
+        if "string_to_token" and "string_to_param" in embedding_ckpt:
+            # Catch variants that do not have the expected keys or values.
             try:
-                embedding_info['name'] = embedding_ckpt['name'] or os.path.basename(os.path.splitext(embedding_file)[0])
+                embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
+                    os.path.splitext(embedding_file)[0]
+                )

                 # Check num of embeddings and warn user only the first will be used
-                embedding_info['num_of_embeddings'] = len(embedding_ckpt["string_to_token"])
-                if embedding_info['num_of_embeddings'] > 1:
-                    print('>> More than 1 embedding found. Will use the first one')
+                embedding_info["num_of_embeddings"] = len(
+                    embedding_ckpt["string_to_token"]
+                )
+                if embedding_info["num_of_embeddings"] > 1:
+                    print(">> More than 1 embedding found. Will use the first one")

-                embedding = list(embedding_ckpt['string_to_param'].values())[0]
-            except (AttributeError,KeyError):
+                embedding = list(embedding_ckpt["string_to_param"].values())[0]
+            except (AttributeError, KeyError):
                 return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)

-            embedding_info['embedding'] = embedding
-            embedding_info['num_vectors_per_token'] = embedding.size()[0]
-            embedding_info['token_dim'] = embedding.size()[1]
+            embedding_info["embedding"] = embedding
+            embedding_info["num_vectors_per_token"] = embedding.size()[0]
+            embedding_info["token_dim"] = embedding.size()[1]

             try:
-                embedding_info['trained_steps'] = embedding_ckpt['step']
-                embedding_info['trained_model_name'] = embedding_ckpt['sd_checkpoint_name']
-                embedding_info['trained_model_checksum'] = embedding_ckpt['sd_checkpoint']
+                embedding_info["trained_steps"] = embedding_ckpt["step"]
+                embedding_info["trained_model_name"] = embedding_ckpt[
+                    "sd_checkpoint_name"
+                ]
+                embedding_info["trained_model_checksum"] = embedding_ckpt[
+                    "sd_checkpoint"
+                ]
             except AttributeError:
                 print(">> No Training Details Found. Passing ...")

         # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
         # They are actually .bin files
-        elif len(embedding_ckpt.keys())==1:
-            print('>> Detected .bin file masquerading as .pt file')
+        elif len(embedding_ckpt.keys()) == 1:
             embedding_info = self._parse_embedding_bin(embedding_file)

         else:
-            print('>> Invalid embedding format')
+            print(">> Invalid embedding format")
             embedding_info = None

         return embedding_info

     def _parse_embedding_bin(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location='cpu')
+        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
         embedding_info = {}

         if list(embedding_ckpt.keys()) == 0:
@@ -274,27 +355,42 @@ class TextualInversionManager():
             embedding_info = None
         else:
             for token in list(embedding_ckpt.keys()):
-                embedding_info['name'] = token or os.path.basename(os.path.splitext(embedding_file)[0])
-                embedding_info['embedding'] = embedding_ckpt[token]
-                embedding_info['num_vectors_per_token'] = 1 # All Concepts seem to default to 1
-                embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
+                embedding_info["name"] = token or os.path.basename(
+                    os.path.splitext(embedding_file)[0]
+                )
+                embedding_info["embedding"] = embedding_ckpt[token]
+                embedding_info[
+                    "num_vectors_per_token"
+                ] = 1  # All Concepts seem to default to 1
+                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]

         return embedding_info

-    def _handle_broken_pt_variants(self, embedding_ckpt:dict, embedding_file:str)->dict:
-        '''
+    def _handle_broken_pt_variants(
+        self, embedding_ckpt: dict, embedding_file: str
+    ) -> dict:
+        """
         This handles the broken .pt file variants. We only know of one at present.
-        '''
+        """
         embedding_info = {}
-        if isinstance(list(embedding_ckpt['string_to_token'].values())[0],torch.Tensor):
-            print('>> Detected .pt file variant 1') # example at https://github.com/invoke-ai/InvokeAI/issues/1829
-            for token in list(embedding_ckpt['string_to_token'].keys()):
-                embedding_info['name'] = token if token != '*' else os.path.basename(os.path.splitext(embedding_file)[0])
-                embedding_info['embedding'] = embedding_ckpt['string_to_param'].state_dict()[token]
-                embedding_info['num_vectors_per_token'] = embedding_info['embedding'].shape[0]
-                embedding_info['token_dim'] = embedding_info['embedding'].size()[0]
+        if isinstance(
+            list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
+        ):
+            for token in list(embedding_ckpt["string_to_token"].keys()):
+                embedding_info["name"] = (
+                    token
+                    if token != "*"
+                    else os.path.basename(os.path.splitext(embedding_file)[0])
+                )
+                embedding_info["embedding"] = embedding_ckpt[
+                    "string_to_param"
+                ].state_dict()[token]
+                embedding_info["num_vectors_per_token"] = embedding_info[
+                    "embedding"
+                ].shape[0]
+                embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
         else:
-            print('>> Invalid embedding format')
+            print(">> Invalid embedding format")
             embedding_info = None

         return embedding_info