Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-16 01:58:14 -05:00)

Compare commits: bugfix/con… and v2.3 (48 commits)
Commits in this compare (48, by abbreviated SHA1):

aa1538bd70, 9e87a080a8, f3b2e02921, fab5df9317, 2e21e5b8f3, 0ce628b22e,
ddcf9a3396, 53f5dfbce1, 060ea144a1, 31a65b1e5d, bdc75be33f, 628d307c69,
18f0cbd9a7, d83d69ccc1, 332ac72e0e, fa886ee9e0, 03bbb308c9, 1dcac3929b,
d73f1c363c, e52e7418bb, 73be58a0b5, 5a7d11bca8, 5bbf7fe34a, bfb968bbe8,
6db72f83a2, 432e526999, 830740b93b, ff3f289342, 34abbb3589, c0eb1a9921,
2ddd0301f4, ce6629b6f5, 0f3c456d59, a45b3387c0, 2a2c86896a, aabe79686e,
23d9361528, ce22a1577c, 216b1c3a4a, 47883860a6, 8f17d17208, c6ecf3afc5,
0bc5dcc663, 16c97ca0cb, e24dd97b80, 5a54039dd7, 9385edb453, 66364501d5
.gitignore (vendored): 2 lines changed

@@ -233,5 +233,3 @@ installer/install.sh
installer/update.bat
installer/update.sh

# no longer stored in source directory
models
@@ -41,6 +41,16 @@ Windows systems). If the `loras` folder does not already exist, just
 create it. The vast majority of LoRA models use the Kohya file format,
 which is a type of `.safetensors` file.
 
+!!! warning "LoRA Naming Restrictions"
+
+    InvokeAI will only recognize LoRA files that contain the
+    characters a-z, A-Z, 0-9 and the underscore character
+    _. Other characters, including the hyphen, will cause the
+    LoRA file not to load. These naming restrictions may be
+    relaxed in the future, but for now you will need to rename
+    files that contain hyphens, commas, brackets, and other
+    non-word characters.
+
 You may change where InvokeAI looks for the `loras` folder by passing the
 `--lora_directory` option to the `invoke.sh`/`invoke.bat` launcher, or
 by placing the option in `invokeai.init`. For example:
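As a quick illustration of the naming rule documented above, here is a minimal sketch (not part of this diff; the helper name and regex are mine) that screens a file name before it is dropped into the `loras` folder:

import re
from pathlib import Path

# Hypothetical helper illustrating the documented rule: only a-z, A-Z, 0-9
# and underscore are accepted in the file stem (\w with re.ASCII).
VALID_LORA_NAME = re.compile(r"^\w+$", re.ASCII)

def is_recognized_lora_name(path: Path) -> bool:
    return bool(VALID_LORA_NAME.match(path.stem))

print(is_recognized_lora_name(Path("analog_film_v2.safetensors")))  # True
print(is_recognized_lora_name(Path("analog-film-v2.safetensors")))  # False: hyphens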
@@ -33,6 +33,11 @@ title: Overview
 Restore mangled faces and make images larger with upscaling. Also see
 the [Embiggen Upscaling Guide](EMBIGGEN.md).
 
+- The [Using LoRA Models](LORAS.md)
+
+  Add custom subjects and styles using HuggingFace's repository of
+  embeddings.
+
 - The [Concepts Library](CONCEPTS.md)
 
   Add custom subjects and styles using HuggingFace's repository of
@@ -79,7 +79,7 @@ title: Manual Installation, Linux
 and obtaining an access token for downloading. It will then download and
 install the weights files for you.
 
-Please look [here](../INSTALL_MANUAL.md) for a manual process for doing
+Please look [here](../020_INSTALL_MANUAL.md) for a manual process for doing
 the same thing.
 
 7. Start generating images!
@@ -75,7 +75,7 @@ Note that you will need NVIDIA drivers, Python 3.10, and Git installed beforehand
 obtaining an access token for downloading. It will then download and install the
 weights files for you.
 
-Please look [here](../INSTALL_MANUAL.md) for a manual process for doing the
+Please look [here](../020_INSTALL_MANUAL.md) for a manual process for doing the
 same thing.
 
 8. Start generating images!
docs/requirements-mkdocs.txt (new file): 5 lines

@@ -0,0 +1,5 @@
+mkdocs
+mkdocs-material>=8, <9
+mkdocs-git-revision-date-localized-plugin
+mkdocs-redirects==1.2.0
@@ -243,16 +243,15 @@ class InvokeAiInstance:
 
         # Note that we're installing pinned versions of torch and
         # torchvision here, which *should* correspond to what is
-        # in pyproject.toml. This is to prevent torch 2.0 from
-        # being installed and immediately uninstalled and replaced with 1.13
+        # in pyproject.toml.
         pip = local[self.pip]
 
         (
             pip[
                 "install",
                 "--require-virtualenv",
-                "torch~=1.13.1",
-                "torchvision~=0.14.1",
+                "torch~=2.0.0",
+                "torchvision>=0.14.1",
                 "--force-reinstall",
                 "--find-links" if find_links is not None else None,
                 find_links,
@@ -25,7 +25,7 @@ from invokeai.backend.modules.parameters import parameters_to_command
 import invokeai.frontend.dist as frontend
 from ldm.generate import Generate
 from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
-from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
+from ldm.invoke.concepts_lib import get_hf_concepts_lib
 from ldm.invoke.conditioning import (
     get_tokens_for_prompt_object,
     get_prompt_structure,

@@ -538,7 +538,7 @@ class InvokeAIWebServer:
         try:
             local_triggers = self.generate.model.textual_inversion_manager.get_all_trigger_strings()
             locals = [{'name': x} for x in sorted(local_triggers, key=str.casefold)]
-            concepts = HuggingFaceConceptsLibrary().list_concepts(minimum_likes=5)
+            concepts = get_hf_concepts_lib().list_concepts(minimum_likes=5)
             concepts = [{'name': f'<{x}>'} for x in sorted(concepts, key=str.casefold) if f'<{x}>' not in local_triggers]
             socketio.emit("foundTextualInversionTriggers", {'local_triggers': locals, 'huggingface_concepts': concepts})
         except Exception as e:
@@ -13,11 +13,16 @@ import time
 import traceback
 from typing import List
 
+import warnings
+with warnings.catch_warnings():
+    warnings.filterwarnings("ignore", category=UserWarning)
+    import torch
+
 import cv2
 import diffusers
 import numpy as np
 import skimage
-import torch
 
 import transformers
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.utils.import_utils import is_xformers_available

@@ -979,13 +984,15 @@ class Generate:
             seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
         if self.embedding_path and not model_data.get("ti_embeddings_loaded"):
             print(f'>> Loading embeddings from {self.embedding_path}')
-            for root, _, files in os.walk(self.embedding_path):
-                for name in files:
-                    ti_path = os.path.join(root, name)
-                    self.model.textual_inversion_manager.load_textual_inversion(
-                        ti_path, defer_injecting_tokens=True
-                    )
-            model_data["ti_embeddings_loaded"] = True
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=UserWarning)
+                for root, _, files in os.walk(self.embedding_path):
+                    for name in files:
+                        ti_path = os.path.join(root, name)
+                        self.model.textual_inversion_manager.load_textual_inversion(
+                            ti_path, defer_injecting_tokens=True
+                        )
+                model_data["ti_embeddings_loaded"] = True
             print(
                 f'>> Textual inversion triggers: {", ".join(sorted(self.model.textual_inversion_manager.get_all_trigger_strings()))}'
             )
@@ -9,7 +9,6 @@ from pathlib import Path
from typing import Union

import click

from compel import PromptParser

if sys.platform == "darwin":
@@ -1 +1,3 @@
-__version__='2.3.4.post1'
+__version__='2.3.5.post2'
@@ -620,7 +620,10 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     for key in keys:
         if key.startswith(vae_key):
             vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+    new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict,config)
+    return new_checkpoint
+
+def convert_ldm_vae_state_dict(vae_state_dict, config):
     new_checkpoint = {}
 
     new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
@@ -12,6 +12,14 @@ from urllib import request, error as ul_error
 from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
 from ldm.invoke.globals import Globals
 
+singleton = None
+
+def get_hf_concepts_lib():
+    global singleton
+    if singleton is None:
+        singleton = HuggingFaceConceptsLibrary()
+    return singleton
+
 class HuggingFaceConceptsLibrary(object):
     def __init__(self, root=None):
         '''
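This hunk adds a lazily created module-level singleton so every caller shares one concepts-library instance; the remaining hunks below switch call sites from `HuggingFaceConceptsLibrary()` to `get_hf_concepts_lib()`. A self-contained sketch of the same pattern (the class here is a stand-in, not the real library):

_singleton = None

class ConceptsLibrary:  # stand-in for HuggingFaceConceptsLibrary
    def __init__(self):
        print("building the concepts index once")

def get_concepts_lib() -> ConceptsLibrary:
    # Lazily create the shared instance on first use, then reuse it.
    global _singleton
    if _singleton is None:
        _singleton = ConceptsLibrary()
    return _singleton

assert get_concepts_lib() is get_concepts_lib()  # same object every time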
@@ -6,6 +6,7 @@ import os
 import platform
 import psutil
 import requests
+import pkg_resources
 from rich import box, print
 from rich.console import Console, group
 from rich.panel import Panel

@@ -39,7 +40,7 @@ def invokeai_is_running()->bool:
             if matches:
                 print(f':exclamation: [bold red]An InvokeAI instance appears to be running as process {p.pid}[/red bold]')
                 return True
-        except psutil.AccessDenied:
+        except (psutil.AccessDenied,psutil.NoSuchProcess):
             continue
     return False
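The widened except clause also swallows `psutil.NoSuchProcess`, raised when a process exits between enumeration and inspection. A minimal sketch of such a scan (the command-line match is an assumption, not the project's exact check):

import psutil

def invokeai_is_running() -> bool:
    for p in psutil.process_iter():
        try:
            # Processes may deny access or disappear while we inspect them.
            if any("invokeai" in part for part in p.cmdline()):
                return True
        except (psutil.AccessDenied, psutil.NoSuchProcess):
            continue
    return False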
@@ -72,10 +73,20 @@ def welcome(versions: dict):
     )
     console.line()
 
+def get_extras():
+    extras = ''
+    try:
+        dist = pkg_resources.get_distribution('xformers')
+        extras = '[xformers]'
+    except pkg_resources.DistributionNotFound:
+        pass
+    return extras
+
 def main():
     versions = get_versions()
     if invokeai_is_running():
         print(f':exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]')
         input('Press any key to continue...')
         return
 
     welcome(versions)

@@ -94,13 +105,15 @@ def main():
     elif choice=='4':
         branch = Prompt.ask('Enter an InvokeAI branch name')
 
+    extras = get_extras()
+
     print(f':crossed_fingers: Upgrading to [yellow]{tag if tag else release}[/yellow]')
     if release:
-        cmd = f'pip install {INVOKE_AI_SRC}/{release}.zip --use-pep517 --upgrade'
+        cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip' --use-pep517 --upgrade"
     elif tag:
-        cmd = f'pip install {INVOKE_AI_TAG}/{tag}.zip --use-pep517 --upgrade'
+        cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip' --use-pep517 --upgrade"
     else:
-        cmd = f'pip install {INVOKE_AI_BRANCH}/{branch}.zip --use-pep517 --upgrade'
+        cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip' --use-pep517 --upgrade"
     print('')
     print('')
     if os.system(cmd)==0:
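With `get_extras()` in place, an existing xformers install is carried across the upgrade by requesting the `invokeai[xformers]` extra. A sketch of the command the update tool now builds (the archive URL below is a placeholder, not the real `INVOKE_AI_TAG` value):

import pkg_resources

def get_extras() -> str:
    try:
        pkg_resources.get_distribution("xformers")
        return "[xformers]"
    except pkg_resources.DistributionNotFound:
        return ""

# Placeholder URL for illustration only.
INVOKE_AI_TAG = "https://example.invalid/invoke-ai/InvokeAI/archive/refs/tags"
tag = "v2.3.5"
cmd = f"pip install 'invokeai{get_extras()} @ {INVOKE_AI_TAG}/{tag}.zip' --use-pep517 --upgrade"
print(cmd)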
@@ -111,7 +111,6 @@ def install_requested_models(
     if len(external_models)>0:
         print("== INSTALLING EXTERNAL MODELS ==")
         for path_url_or_repo in external_models:
-            print(f'DEBUG: path_url_or_repo = {path_url_or_repo}')
             try:
                 model_manager.heuristic_import(
                     path_url_or_repo,
@@ -400,8 +400,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @property
     def _submodels(self) -> Sequence[torch.nn.Module]:
         module_names, _, _ = self.extract_init_dict(dict(self.config))
-        values = [getattr(self, name) for name in module_names.keys()]
-        return [m for m in values if isinstance(m, torch.nn.Module)]
+        submodels = []
+        for name in module_names.keys():
+            if hasattr(self, name):
+                value = getattr(self, name)
+            else:
+                value = getattr(self.config, name)
+            if isinstance(value, torch.nn.Module):
+                submodels.append(value)
+        return submodels
 
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,

@@ -472,7 +479,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 step_count=len(self.scheduler.timesteps)
         ):
-            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
+            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.config.num_train_timesteps,
                                             latents=latents)
 
             batch_size = latents.shape[0]

@@ -756,7 +763,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @property
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
-        return self.unet.in_channels
+        return self.unet.config.in_channels
 
     def decode_latents(self, latents):
         # Explicit call to get the vae loaded, since `decode` isn't the forward method.
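These hunks follow diffusers in reading model hyperparameters from each component's `.config` rather than from the module itself. A small sketch, assuming diffusers ~0.16 is installed (the checkpoint is downloaded on first use):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Hyperparameters live on the config objects, not on the modules themselves.
print(pipe.unet.config.in_channels)                # typically 4 for SD 1.x
print(pipe.scheduler.config.num_train_timesteps)   # typically 1000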
@@ -9,7 +9,6 @@ from __future__ import annotations
import contextlib
import gc
import hashlib
import io
import os
import re
import sys

@@ -31,11 +30,10 @@ from huggingface_hub import scan_cache_dir
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 from picklescan.scanner import scan_file_path
 
 from ldm.invoke.devices import CPU_DEVICE
 from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
 from ldm.invoke.globals import Globals, global_cache_dir
-from ldm.util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name
+from ldm.util import ask_user, download_with_resume, url_attachment_name
 
 
 class SDLegacyType(Enum):

@@ -370,8 +368,9 @@ class ModelManager(object):
             print(
                 f">> Converting legacy checkpoint {model_name} into a diffusers model..."
             )
-            from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt
-
+            from .ckpt_to_diffuser import (
+                load_pipeline_from_original_stable_diffusion_ckpt,
+            )
             if self._has_cuda():
                 torch.cuda.empty_cache()
             pipeline = load_pipeline_from_original_stable_diffusion_ckpt(

@@ -433,7 +432,7 @@ class ModelManager(object):
                 **fp_args,
             )
         except OSError as e:
-            if str(e).startswith("fp16 is not a valid"):
+            if 'Revision Not Found' in str(e):
                 pass
             else:
                 print(

@@ -1156,7 +1155,7 @@ class ModelManager(object):
         return self.device.type == "cuda"
 
     def _diffuser_sha256(
-        self, name_or_path: Union[str, Path], chunksize=4096
+        self, name_or_path: Union[str, Path], chunksize=16777216
     ) -> Union[str, bytes]:
         path = None
         if isinstance(name_or_path, Path):
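The only change here is the read size: hashing in 16 MiB chunks (16777216 bytes) instead of 4 KiB greatly reduces the number of reads on multi-gigabyte model files. A standalone sketch of chunked hashing with the new size:

import hashlib
from pathlib import Path

def file_sha256(path: Path, chunksize: int = 16_777_216) -> str:
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        # Read the file in 16 MiB pieces so huge checkpoints never have to
        # fit in memory at once.
        while chunk := f.read(chunksize):
            sha.update(chunk)
    return sha.hexdigest()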
@@ -1230,6 +1229,17 @@ class ModelManager(object):
         return vae_path
 
     def _load_vae(self, vae_config) -> AutoencoderKL:
+        using_fp16 = self.precision == "float16"
+        dtype = torch.float16 if using_fp16 else torch.float32
+
+        # Handle the common case of a user shoving a VAE .ckpt into
+        # the vae field for a diffusers. We convert it into diffusers
+        # format and use it.
+        if isinstance(vae_config,(str,Path)):
+            return self.convert_vae(vae_config).to(dtype=dtype)
+        elif isinstance(vae_config,DictConfig) and (vae_path := vae_config.get('path')):
+            return self.convert_vae(vae_path).to(dtype=dtype)
+
         vae_args = {}
         try:
             name_or_path = self.model_name_or_path(vae_config)

@@ -1237,7 +1247,6 @@ class ModelManager(object):
             return None
         if name_or_path is None:
             return None
-        using_fp16 = self.precision == "float16"
 
         vae_args.update(
             cache_dir=global_cache_dir("hub"),

@@ -1277,6 +1286,32 @@ class ModelManager(object):
 
         return vae
 
+    @staticmethod
+    def convert_vae(vae_path: Union[Path,str])->AutoencoderKL:
+        print(" | A checkpoint VAE was detected. Converting to diffusers format.")
+        vae_path = Path(Globals.root,vae_path).resolve()
+
+        from .ckpt_to_diffuser import (
+            create_vae_diffusers_config,
+            convert_ldm_vae_state_dict,
+        )
+
+        vae_path = Path(vae_path)
+        if vae_path.suffix in ['.pt','.ckpt']:
+            vae_state_dict = torch.load(vae_path, map_location="cpu")
+        else:
+            vae_state_dict = safetensors.torch.load_file(vae_path)
+        if 'state_dict' in vae_state_dict:
+            vae_state_dict = vae_state_dict['state_dict']
+        # TODO: see if this works with 1.x inpaint models and 2.x models
+        config_file_path = Path(Globals.root,"configs/stable-diffusion/v1-inference.yaml")
+        original_conf = OmegaConf.load(config_file_path)
+        vae_config = create_vae_diffusers_config(original_conf, image_size=512) # TODO: fix
+        diffusers_vae = convert_ldm_vae_state_dict(vae_state_dict,vae_config)
+        vae = AutoencoderKL(**vae_config)
+        vae.load_state_dict(diffusers_vae)
+        return vae
+
     @staticmethod
     def _delete_model_from_cache(repo_id):
         cache_info = scan_cache_dir(global_cache_dir("diffusers"))
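The new `convert_vae()` helper above first loads the raw VAE weights, choosing the loader from the file suffix and unwrapping a nested `state_dict` key, before handing them to the `ckpt_to_diffuser` converters shown earlier. That loading branch, extracted as a self-contained sketch:

from pathlib import Path

import safetensors.torch
import torch

def load_vae_state_dict(vae_path: Path) -> dict:
    # .pt/.ckpt files are pickled torch checkpoints; anything else is
    # treated as a safetensors file.
    if vae_path.suffix in (".pt", ".ckpt"):
        state_dict = torch.load(vae_path, map_location="cpu")
    else:
        state_dict = safetensors.torch.load_file(vae_path)
    # Some checkpoints nest the weights under a 'state_dict' key.
    return state_dict.get("state_dict", state_dict)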
@@ -13,7 +13,7 @@ import re
 import atexit
 from typing import List
 from ldm.invoke.args import Args
-from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
+from ldm.invoke.concepts_lib import get_hf_concepts_lib
 from ldm.invoke.globals import Globals
 from ldm.modules.lora_manager import LoraManager

@@ -287,7 +287,7 @@ class Completer(object):
     def _concept_completions(self, text, state):
         if self.concepts is None:
             # cache Concepts() instance so we can check for updates in concepts_list during runtime.
-            self.concepts = HuggingFaceConceptsLibrary()
+            self.concepts = get_hf_concepts_lib()
             self.embedding_terms.update(set(self.concepts.list_concepts()))
         else:
             self.embedding_terms.update(set(self.concepts.list_concepts()))
@@ -14,7 +14,6 @@ from torch import nn
 
 from compel.cross_attention_control import Arguments
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
-from diffusers.models.cross_attention import AttnProcessor
 from ldm.invoke.devices import torch_dtype

@@ -163,7 +162,7 @@ class Context:
 
 class InvokeAICrossAttentionMixin:
     """
-    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
+    Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
     through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
     and dymamic slicing strategy selection.
     """

@@ -178,7 +177,7 @@ class InvokeAICrossAttentionMixin:
         Set custom attention calculator to be called when attention is calculated
         :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
                          which returns either the suggested_attention_slice or an adjusted equivalent.
-            `module` is the current CrossAttention module for which the callback is being invoked.
+            `module` is the current Attention module for which the callback is being invoked.
             `suggested_attention_slice` is the default-calculated attention slice
             `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
             If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.

@@ -326,7 +325,7 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel,
 
 
 def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
-    from ldm.modules.attention import CrossAttention # avoid circular import
+    from ldm.modules.attention import CrossAttention # avoid circular import # TODO: rename as in diffusers?
     cross_attention_class: type = InvokeAIDiffusersCrossAttention if isinstance(model,UNet2DConditionModel) else CrossAttention
     which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
     attention_module_tuples = [(name,module) for name, module in model.named_modules() if

@@ -432,7 +431,7 @@ def get_mem_free_total(device):
 
 
 
-class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):
+class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

@@ -457,8 +456,8 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention,
     """
     # base implementation
 
-    class CrossAttnProcessor:
-        def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    class AttnProcessor:
+        def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
             batch_size, sequence_length, _ = hidden_states.shape
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

@@ -487,7 +486,7 @@ from dataclasses import field, dataclass
 
 import torch
 
-from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
+from diffusers.models.attention_processor import Attention, AttnProcessor, SlicedAttnProcessor
 
 
 @dataclass

@@ -532,7 +531,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
 
     # TODO: dynamically pick slice size based on memory conditions
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
                  # kwargs
                  swap_cross_attn_context: SwapCrossAttnContext=None):
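All of these hunks track the same rename in diffusers ~0.16: `diffusers.models.cross_attention` became `diffusers.models.attention_processor`, and `CrossAttention`/`CrossAttnProcessor` became `Attention`/`AttnProcessor`. A one-line check of the new location, assuming diffusers 0.16.x is installed:

from diffusers.models.attention_processor import Attention, AttnProcessor, SlicedAttnProcessor

print(Attention.__module__)  # diffusers.models.attention_processor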
@@ -5,7 +5,6 @@ from typing import Callable, Optional, Union, Any

import numpy as np
import torch

from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias
@@ -6,7 +6,7 @@ from torch import nn
 
 import sys
 
-from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
+from ldm.invoke.concepts_lib import get_hf_concepts_lib
 from ldm.data.personalized import per_img_token_list
 from transformers import CLIPTokenizer
 from functools import partial

@@ -39,7 +39,7 @@ class EmbeddingManager(nn.Module):
         super().__init__()
 
         self.embedder = embedder
-        self.concepts_library=HuggingFaceConceptsLibrary()
+        self.concepts_library=get_hf_concepts_lib()
 
         self.string_to_token_dict = {}
         self.string_to_param_dict = nn.ParameterDict()
@@ -1,6 +1,6 @@
 import json
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Dict, Tuple
 
 import torch
 from diffusers.models import UNet2DConditionModel

@@ -9,7 +9,7 @@ from safetensors.torch import load_file
 from torch.utils.hooks import RemovableHandle
 from transformers import CLIPTextModel
 
-from ..invoke.globals import global_lora_models_dir
+from ..invoke.globals import global_lora_models_dir, Globals
 from ..invoke.devices import choose_torch_device
 
 """
@@ -166,12 +166,12 @@ class LoKRLayer:
 class LoRAModuleWrapper:
     unet: UNet2DConditionModel
     text_encoder: CLIPTextModel
-    hooks: list[RemovableHandle]
+    hooks: Dict[str, Tuple[torch.nn.Module, RemovableHandle]]
 
     def __init__(self, unet, text_encoder):
         self.unet = unet
         self.text_encoder = text_encoder
-        self.hooks = []
+        self.hooks = dict()
         self.text_modules = None
         self.unet_modules = None

@@ -228,7 +228,7 @@ class LoRAModuleWrapper:
         wrapper = self
 
         def lora_forward(module, input_h, output):
-            if len(wrapper.loaded_loras) == 0:
+            if len(wrapper.applied_loras) == 0:
                 return output
 
             for lora in wrapper.applied_loras.values():

@@ -241,11 +241,18 @@ class LoRAModuleWrapper:
         return lora_forward
 
     def apply_module_forward(self, module, name):
-        handle = module.register_forward_hook(self.lora_forward_hook(name))
-        self.hooks.append(handle)
+        if name in self.hooks:
+            registered_module, _ = self.hooks[name]
+            if registered_module != module:
+                raise Exception(f"Trying to register multiple modules to lora key: {name}")
+            # else it's just double hook creation - nothing to do
+
+        else:
+            handle = module.register_forward_hook(self.lora_forward_hook(name))
+            self.hooks[name] = (module, handle)
 
     def clear_hooks(self):
-        for hook in self.hooks:
+        for _, hook in self.hooks.values():
             hook.remove()
 
         self.hooks.clear()
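Switching `hooks` from a list to a dict keyed by LoRA module name makes registration idempotent: hooking the same module twice under one key is a no-op, while hooking a different module under an existing key raises. A self-contained sketch of that bookkeeping (the helper names are illustrative, not the project's):

import torch
from torch.utils.hooks import RemovableHandle

hooks: dict[str, tuple[torch.nn.Module, RemovableHandle]] = {}

def register_hook(name: str, module: torch.nn.Module, hook_fn) -> None:
    if name in hooks:
        registered_module, _ = hooks[name]
        if registered_module is not module:
            raise Exception(f"Trying to register multiple modules to lora key: {name}")
        return  # same module hooked twice, nothing to do
    handle = module.register_forward_hook(hook_fn)
    hooks[name] = (module, handle)

def clear_hooks() -> None:
    for _, handle in hooks.values():
        handle.remove()
    hooks.clear()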
@@ -456,16 +463,25 @@ class LoRA:
 
 
 class KohyaLoraManager:
-    lora_path = Path(global_lora_models_dir())
-    vector_length_cache_path = lora_path / '.vectorlength.cache'
 
     def __init__(self, pipe):
-        self.vector_length_cache_path = self.lora_path / '.vectorlength.cache'
         self.unet = pipe.unet
         self.wrapper = LoRAModuleWrapper(pipe.unet, pipe.text_encoder)
         self.text_encoder = pipe.text_encoder
         self.device = torch.device(choose_torch_device())
         self.dtype = pipe.unet.dtype
 
+    @classmethod
+    @property
+    def lora_path(cls)->Path:
+        return Path(global_lora_models_dir())
+
+    @classmethod
+    @property
+    def vector_length_cache_path(cls)->Path:
+        return cls.lora_path / '.vectorlength.cache'
+
     def load_lora_module(self, name, path_file, multiplier: float = 1.0):
         print(f" | Found lora {name} at {path_file}")
         if path_file.suffix == ".safetensors":
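Turning `lora_path` into a stacked `@classmethod` `@property` defers the `global_lora_models_dir()` lookup from class-definition time to first access, so a directory chosen at startup (for example via the `--lora_directory` option documented earlier) can take effect. Chaining the two decorators works on Python 3.9/3.10, the range pinned by `requires-python` below, though it was deprecated again in Python 3.11. A toy sketch with a placeholder path:

from pathlib import Path

class LazyPaths:
    @classmethod
    @property
    def lora_path(cls) -> Path:
        # Re-evaluated on every access instead of once at import time.
        return Path("/tmp/loras")  # placeholder directory

print(LazyPaths.lora_path)  # /tmp/loras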
@@ -3,14 +3,16 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Union
 
-import safetensors.torch
-import torch
+import warnings
+with warnings.catch_warnings():
+    warnings.filterwarnings("ignore", category=UserWarning)
+    import safetensors.torch
+    import torch
 
 from picklescan.scanner import scan_file_path
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from compel.embeddings_provider import BaseTextualInversionManager
-from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
+from ldm.invoke.concepts_lib import get_hf_concepts_lib
 
 @dataclass
 class TextualInversion:

@@ -34,7 +36,7 @@ class TextualInversionManager(BaseTextualInversionManager):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.full_precision = full_precision
-        self.hf_concepts_library = HuggingFaceConceptsLibrary()
+        self.hf_concepts_library = get_hf_concepts_lib()
         self.trigger_to_sourcefile = dict()
         default_textual_inversions: list[TextualInversion] = []
         self.textual_inversions = default_textual_inversions
@@ -32,9 +32,9 @@ dependencies = [
     "albumentations",
     "click",
     "clip_anytorch",
-    "compel~=1.1.0",
+    "compel~=1.1.5",
     "datasets",
-    "diffusers[torch]==0.14",
+    "diffusers[torch]~=0.16.1",
     "dnspython==2.2.1",
     "einops",
     "eventlet",

@@ -76,7 +76,7 @@ dependencies = [
     "taming-transformers-rom1504",
     "test-tube>=0.7.5",
     "torch-fidelity",
-    "torch~=1.13.1",
+    "torch~=2.0.0",
     "torchmetrics",
     "torchvision>=0.14.1",
     "transformers~=4.26",

@@ -108,7 +108,7 @@ requires-python = ">=3.9, <3.11"
 "test" = ["pytest-cov", "pytest>6.0.0"]
 "xformers" = [
     "triton; sys_platform=='linux'",
-    "xformers~=0.0.16; sys_platform!='darwin'",
+    "xformers~=0.0.19; sys_platform!='darwin'",
 ]
 
 [project.scripts]