Merge branch 'v2.3' into lstein/bugfix/improve-update-handling
.github/CODEOWNERS (vendored): 34 changed lines
@@ -1,13 +1,13 @@
# continuous integration
/.github/workflows/ @mauwii @lstein @blessedcoolant
/.github/workflows/ @lstein @blessedcoolant

# documentation
/docs/ @lstein @mauwii @blessedcoolant
mkdocs.yml @mauwii @lstein
/docs/ @lstein @blessedcoolant
mkdocs.yml @lstein @ebr

# installation and configuration
/pyproject.toml @mauwii @lstein @ebr
/docker/ @mauwii
/pyproject.toml @lstein @ebr
/docker/ @lstein
/scripts/ @ebr @lstein @blessedcoolant
/installer/ @ebr @lstein
ldm/invoke/config @lstein @ebr
@@ -21,13 +21,13 @@ invokeai/configs @lstein @ebr @blessedcoolant

# generation and model management
/ldm/*.py @lstein @blessedcoolant
/ldm/generate.py @lstein @keturn
/ldm/generate.py @lstein @gregghelt2
/ldm/invoke/args.py @lstein @blessedcoolant
/ldm/invoke/ckpt* @lstein @blessedcoolant
/ldm/invoke/ckpt_generator @lstein @blessedcoolant
/ldm/invoke/CLI.py @lstein @blessedcoolant
/ldm/invoke/config @lstein @ebr @mauwii @blessedcoolant
/ldm/invoke/generator @keturn @damian0815
/ldm/invoke/config @lstein @ebr @blessedcoolant
/ldm/invoke/generator @gregghelt2 @damian0815
/ldm/invoke/globals.py @lstein @blessedcoolant
/ldm/invoke/merge_diffusers.py @lstein @blessedcoolant
/ldm/invoke/model_manager.py @lstein @blessedcoolant
@@ -36,17 +36,17 @@ invokeai/configs @lstein @ebr @blessedcoolant
/ldm/invoke/restoration @lstein @blessedcoolant

# attention, textual inversion, model configuration
/ldm/models @damian0815 @keturn @blessedcoolant
/ldm/models @damian0815 @gregghelt2 @blessedcoolant
/ldm/modules/textual_inversion_manager.py @lstein @blessedcoolant
/ldm/modules/attention.py @damian0815 @keturn
/ldm/modules/diffusionmodules @damian0815 @keturn
/ldm/modules/distributions @damian0815 @keturn
/ldm/modules/ema.py @damian0815 @keturn
/ldm/modules/attention.py @damian0815 @gregghelt2
/ldm/modules/diffusionmodules @damian0815 @gregghelt2
/ldm/modules/distributions @damian0815 @gregghelt2
/ldm/modules/ema.py @damian0815 @gregghelt2
/ldm/modules/embedding_manager.py @lstein
/ldm/modules/encoders @damian0815 @keturn
/ldm/modules/image_degradation @damian0815 @keturn
/ldm/modules/losses @damian0815 @keturn
/ldm/modules/x_transformer.py @damian0815 @keturn
/ldm/modules/encoders @damian0815 @gregghelt2
/ldm/modules/image_degradation @damian0815 @gregghelt2
/ldm/modules/losses @damian0815 @gregghelt2
/ldm/modules/x_transformer.py @damian0815 @gregghelt2

# Nodes
apps/ @Kyle0654 @jpphoto
@@ -30,7 +30,6 @@ from ldm.invoke.conditioning import (
    get_tokens_for_prompt_object,
    get_prompt_structure,
    split_weighted_subprompts,
    get_tokenizer,
)
from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
from ldm.invoke.generator.inpaint import infill_methods
@@ -1314,7 +1313,7 @@ class InvokeAIWebServer:
                None
                if type(parsed_prompt) is Blend
                else get_tokens_for_prompt_object(
                    get_tokenizer(self.generate.model), parsed_prompt
                    self.generate.model.tokenizer, parsed_prompt
                )
            )
            attention_maps_image_base64_url = (
@@ -15,19 +15,10 @@ from compel import Compel
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser, \
    Conjunction
from .devices import torch_dtype
from .generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.invoke.globals import Globals

def get_tokenizer(model) -> CLIPTokenizer:
    # TODO remove legacy ckpt fallback handling
    return (getattr(model, 'tokenizer', None)  # diffusers
            or model.cond_stage_model.tokenizer)  # ldm

def get_text_encoder(model) -> Any:
    # TODO remove legacy ckpt fallback handling
    return (getattr(model, 'text_encoder', None)  # diffusers
            or UnsqueezingLDMTransformer(model.cond_stage_model.transformer))  # ldm

class UnsqueezingLDMTransformer:
    def __init__(self, ldm_transformer):
        self.ldm_transformer = ldm_transformer
@@ -41,15 +32,15 @@ class UnsqueezingLDMTransformer:
        return insufficiently_unsqueezed_tensor.unsqueeze(0)


def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
def get_uc_and_c_and_ec(prompt_string,
                        model: StableDiffusionGeneratorPipeline,
                        log_tokens=False, skip_normalize_legacy_blend=False):
    # lazy-load any deferred textual inversions.
    # this might take a couple of seconds the first time a textual inversion is used.
    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)

    tokenizer = get_tokenizer(model)
    text_encoder = get_text_encoder(model)
    compel = Compel(tokenizer=tokenizer,
                    text_encoder=text_encoder,
    compel = Compel(tokenizer=model.tokenizer,
                    text_encoder=model.text_encoder,
                    textual_inversion_manager=model.textual_inversion_manager,
                    dtype_for_device_getter=torch_dtype)

@@ -78,14 +69,20 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
    negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]

    tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
    if log_tokens or getattr(Globals, "log_tokenization", False):
        log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)

    c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
    uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)

    tokens_count = get_max_token_count(tokenizer, positive_prompt)
    # some LoRA models also mess with the text encoder, so they must be active while compel builds conditioning tensors
    lora_conditioning_ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
                                                                           lora_conditions=lora_conditions)
    with InvokeAIDiffuserComponent.custom_attention_context(model.unet,
                                                            extra_conditioning_info=lora_conditioning_ec,
                                                            step_count=-1):
        c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)

    # now build the "real" ec
    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
                                                         cross_attention_control_args=options.get(
                                                             'cross_attention_control', None),
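Note: a minimal sketch of the conditioning call pattern this hunk moves to, assuming an InvokeAI v2.3 environment with the compel package installed and `model` already loaded as a diffusers pipeline exposing .tokenizer, .text_encoder and .textual_inversion_manager (the prompt string is a made-up example):

from compel import Compel
from ldm.invoke.devices import torch_dtype

compel = Compel(tokenizer=model.tokenizer,
                text_encoder=model.text_encoder,
                textual_inversion_manager=model.textual_inversion_manager,
                dtype_for_device_getter=torch_dtype)

# parse the prompt, then build a conditioning tensor for its first prompt object
conjunction = Compel.parse_prompt_string("a watercolor painting of a lighthouse")
c, options = compel.build_conditioning_tensor_for_prompt_object(conjunction.prompts[0])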
@@ -196,16 +196,6 @@ class addModelsForm(npyscreen.FormMultiPage):
            scroll_exit=True,
        )
        self.nextrely += 1
        self.convert_models = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="== CONVERT IMPORTED MODELS INTO DIFFUSERS==",
            values=["Keep original format", "Convert to diffusers"],
            value=0,
            begin_entry_at=4,
            max_height=4,
            hidden=True,  # will appear when imported models box is edited
            scroll_exit=True,
        )
        self.cancel = self.add_widget_intelligent(
            npyscreen.ButtonPress,
            name="CANCEL",
@@ -240,8 +230,6 @@ class addModelsForm(npyscreen.FormMultiPage):
            self.show_directory_fields.addVisibleWhenSelected(i)

        self.show_directory_fields.when_value_edited = self._clear_scan_directory
        self.import_model_paths.when_value_edited = self._show_hide_convert
        self.autoload_directory.when_value_edited = self._show_hide_convert

    def resize(self):
        super().resize()
@@ -252,13 +240,6 @@ class addModelsForm(npyscreen.FormMultiPage):
        if not self.show_directory_fields.value:
            self.autoload_directory.value = ""

    def _show_hide_convert(self):
        model_paths = self.import_model_paths.value or ""
        autoload_directory = self.autoload_directory.value or ""
        self.convert_models.hidden = (
            len(model_paths) == 0 and len(autoload_directory) == 0
        )

    def _get_starter_model_labels(self) -> List[str]:
        window_width, window_height = get_terminal_size()
        label_width = 25
@@ -318,7 +299,6 @@ class addModelsForm(npyscreen.FormMultiPage):
        .scan_directory: Path to a directory of models to scan and import
        .autoscan_on_startup: True if invokeai should scan and import at startup time
        .import_model_paths: list of URLs, repo_ids and file paths to import
        .convert_to_diffusers: if True, convert legacy checkpoints into diffusers
        """
        # we're using a global here rather than storing the result in the parentapp
        # due to some bug in npyscreen that is causing attributes to be lost
@@ -354,7 +334,6 @@ class addModelsForm(npyscreen.FormMultiPage):

        # URLs and the like
        selections.import_model_paths = self.import_model_paths.value.split()
        selections.convert_to_diffusers = self.convert_models.value[0] == 1


class AddModelApplication(npyscreen.NPSAppManaged):
@@ -367,7 +346,6 @@ class AddModelApplication(npyscreen.NPSAppManaged):
            scan_directory=None,
            autoscan_on_startup=None,
            import_model_paths=None,
            convert_to_diffusers=None,
        )

    def onStart(self):
@@ -387,7 +365,6 @@ def process_and_execute(opt: Namespace, selections: Namespace):
    directory_to_scan = selections.scan_directory
    scan_at_startup = selections.autoscan_on_startup
    potential_models_to_install = selections.import_model_paths
    convert_to_diffusers = selections.convert_to_diffusers

    install_requested_models(
        install_initial_models=models_to_install,
@@ -395,7 +372,6 @@ def process_and_execute(opt: Namespace, selections: Namespace):
        scan_directory=Path(directory_to_scan) if directory_to_scan else None,
        external_models=potential_models_to_install,
        scan_at_startup=scan_at_startup,
        convert_to_diffusers=convert_to_diffusers,
        precision="float32"
        if opt.full_precision
        else choose_precision(torch.device(choose_torch_device())),
@@ -68,7 +68,6 @@ def install_requested_models(
    scan_directory: Path = None,
    external_models: List[str] = None,
    scan_at_startup: bool = False,
    convert_to_diffusers: bool = False,
    precision: str = "float16",
    purge_deleted: bool = False,
    config_file_path: Path = None,
@@ -111,20 +110,20 @@ def install_requested_models(
    if len(external_models)>0:
        print("== INSTALLING EXTERNAL MODELS ==")
        for path_url_or_repo in external_models:
            print(f'DEBUG: path_url_or_repo = {path_url_or_repo}')
            try:
                model_manager.heuristic_import(
                    path_url_or_repo,
                    convert=convert_to_diffusers,
                    config_file_callback=_pick_configuration_file,
                    commit_to_conf=config_file_path
                )
            except KeyboardInterrupt:
                sys.exit(-1)
            except Exception:
                pass
            except Exception as e:
                print(f'An exception has occurred: {str(e)}')

    if scan_at_startup and scan_directory.is_dir():
        argument = '--autoconvert' if convert_to_diffusers else '--autoimport'
        argument = '--autoconvert'
        initfile = Path(Globals.root, Globals.initfile)
        replacement = Path(Globals.root, f'{Globals.initfile}.new')
        directory = str(scan_directory).replace('\\','/')
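Note: a hedged sketch of driving the installer directly with the keyword arguments shown in this hunk; the module path, URL and directory below are assumptions for illustration only, not taken from this commit:

from pathlib import Path
from ldm.invoke.config.model_install_backend import install_requested_models  # module path assumed

install_requested_models(
    external_models=["https://example.com/some-checkpoint.safetensors"],  # hypothetical URL
    scan_directory=Path("/path/to/local/models"),                         # hypothetical path
    scan_at_startup=False,
    precision="float16",
)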
@@ -467,8 +467,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        if additional_guidance is None:
            additional_guidance = []
        extra_conditioning_info = conditioning_data.extra
        with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
                                                             step_count=len(self.scheduler.timesteps)
        with InvokeAIDiffuserComponent.custom_attention_context(self.invokeai_diffuser.model,
                                                                extra_conditioning_info=extra_conditioning_info,
                                                                step_count=len(self.scheduler.timesteps)
        ):

            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
@@ -288,16 +288,7 @@ class InvokeAICrossAttentionMixin:
        return self.einsum_op_tensor_mem(q, k, v, 32)


def restore_default_cross_attention(model, is_running_diffusers: bool, processors_to_restore: Optional[AttnProcessor]=None):
    if is_running_diffusers:
        unet = model
        unet.set_attn_processor(processors_to_restore or CrossAttnProcessor())
    else:
        remove_attention_function(model)


def override_cross_attention(model, context: Context, is_running_diffusers = False):
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
    """
    Inject attention parameters and functions into the passed in model to enable cross attention editing.

@@ -323,22 +314,15 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal

    context.cross_attention_mask = mask.to(device)
    context.cross_attention_index_map = indices.to(device)
    if is_running_diffusers:
        unet = model
        old_attn_processors = unet.attn_processors
        if torch.backends.mps.is_available():
            # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
            unet.set_attn_processor(SwapCrossAttnProcessor())
        else:
            # try to re-use an existing slice size
            default_slice_size = 4
            slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
            unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
    old_attn_processors = unet.attn_processors
    if torch.backends.mps.is_available():
        # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
        unet.set_attn_processor(SwapCrossAttnProcessor())
    else:
        context.register_cross_attention_modules(model)
        inject_attention_function(model, context)


        # try to re-use an existing slice size
        default_slice_size = 4
        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))


def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
@@ -12,17 +12,6 @@ class DDIMSampler(Sampler):
        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
                                                           model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))

    def prepare_to_sample(self, t_enc, **kwargs):
        super().prepare_to_sample(t_enc, **kwargs)

        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
        all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)

        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
        else:
            self.invokeai_diffuser.restore_default_cross_attention()


    # This is the central routine
    @torch.no_grad()
@@ -38,15 +38,6 @@ class CFGDenoiser(nn.Module):
                                                           model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))


    def prepare_to_sample(self, t_enc, **kwargs):

        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)

        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = t_enc)
        else:
            self.invokeai_diffuser.restore_default_cross_attention()


    def forward(self, x, sigma, uncond, cond, cond_scale):
        next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
@@ -14,17 +14,6 @@ class PLMSSampler(Sampler):
    def __init__(self, model, schedule='linear', device=None, **kwargs):
        super().__init__(model,schedule,model.num_timesteps, device)

    def prepare_to_sample(self, t_enc, **kwargs):
        super().prepare_to_sample(t_enc, **kwargs)

        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
        all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)

        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
        else:
            self.invokeai_diffuser.restore_default_cross_attention()


    # this is the essential routine
    @torch.no_grad()
@@ -1,18 +1,18 @@
from contextlib import contextmanager
from dataclasses import dataclass
from math import ceil
from typing import Callable, Optional, Union, Any, Dict
from typing import Callable, Optional, Union, Any

import numpy as np
import torch
from diffusers.models.cross_attention import AttnProcessor

from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias

from ldm.invoke.globals import Globals
from ldm.models.diffusion.cross_attention_control import (
    Arguments,
    restore_default_cross_attention,
    override_cross_attention,
    setup_cross_attention_control_attention_processors,
    Context,
    get_cross_attention_modules,
    CrossAttentionType,
@@ -84,66 +84,45 @@ class InvokeAIDiffuserComponent:
        self.cross_attention_control_context = None
        self.sequential_guidance = Globals.sequential_guidance

    @classmethod
    @contextmanager
    def custom_attention_context(
        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
        clss,
        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
        extra_conditioning_info: Optional[ExtraConditioningInfo],
        step_count: int
    ):
        old_attn_processor = None
        old_attn_processors = None
        if extra_conditioning_info and (
            extra_conditioning_info.wants_cross_attention_control
            | extra_conditioning_info.has_lora_conditions
        ):
            old_attn_processor = self.override_attention_processors(
                extra_conditioning_info, step_count=step_count
            )
            old_attn_processors = unet.attn_processors
            # Load lora conditions into the model
            if extra_conditioning_info.has_lora_conditions:
                for condition in extra_conditioning_info.lora_conditions:
                    condition()  # target model is stored in condition state for some reason
            if extra_conditioning_info.wants_cross_attention_control:
                cross_attention_control_context = Context(
                    arguments=extra_conditioning_info.cross_attention_control_args,
                    step_count=step_count,
                )
                setup_cross_attention_control_attention_processors(
                    unet,
                    cross_attention_control_context,
                )

        try:
            yield None
        finally:
            if old_attn_processor is not None:
                self.restore_default_cross_attention(old_attn_processor)
            if old_attn_processors is not None:
                unet.set_attn_processor(old_attn_processors)
            if extra_conditioning_info and extra_conditioning_info.has_lora_conditions:
                for lora_condition in extra_conditioning_info.lora_conditions:
                    lora_condition.unload()
            # TODO resuscitate attention map saving
            # self.remove_attention_map_saving()

    def override_attention_processors(
        self, conditioning: ExtraConditioningInfo, step_count: int
    ) -> Dict[str, AttnProcessor]:
        """
        setup cross attention .swap control. for diffusers this replaces the attention processor, so
        the previous attention processor is returned so that the caller can restore it later.
        """
        old_attn_processors = self.model.attn_processors

        # Load lora conditions into the model
        if conditioning.has_lora_conditions:
            for condition in conditioning.lora_conditions:
                condition(self.model)

        if conditioning.wants_cross_attention_control:
            self.cross_attention_control_context = Context(
                arguments=conditioning.cross_attention_control_args,
                step_count=step_count,
            )
            override_cross_attention(
                self.model,
                self.cross_attention_control_context,
                is_running_diffusers=self.is_running_diffusers,
            )
        return old_attn_processors

    def restore_default_cross_attention(
        self, processors_to_restore: Optional[dict[str, "AttnProcessor"]] = None
    ):
        self.cross_attention_control_context = None
        restore_default_cross_attention(
            self.model,
            is_running_diffusers=self.is_running_diffusers,
            processors_to_restore=processors_to_restore,
        )

    def setup_attention_map_saving(self, saver: AttentionMapSaver):
        def callback(slice, dim, offset, slice_size, key):
            if dim is not None:
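Note: with this change custom_attention_context becomes a classmethod that operates on a UNet passed in explicitly. A minimal usage sketch, assuming an InvokeAI v2.3 environment where `pipeline` is a loaded StableDiffusionGeneratorPipeline and `extra_info` is an ExtraConditioningInfo built as in the hunks above:

from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

with InvokeAIDiffuserComponent.custom_attention_context(
    pipeline.unet,                       # attention processors are swapped on this UNet
    extra_conditioning_info=extra_info,  # may carry LoRA conditions and/or .swap() arguments
    step_count=len(pipeline.scheduler.timesteps),
):
    ...  # run denoising here; the original attention processors are restored on exit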
@@ -31,18 +31,13 @@ class LoRALayer:
        self.name = name
        self.scale = alpha / rank if (alpha and rank) else 1.0

    def forward(self, lora, input_h, output):
    def forward(self, lora, input_h):
        if self.mid is None:
            output = (
                output
                + self.up(self.down(*input_h)) * lora.multiplier * self.scale
            )
            weight = self.up(self.down(*input_h))
        else:
            output = (
                output
                + self.up(self.mid(self.down(*input_h))) * lora.multiplier * self.scale
            )
        return output
            weight = self.up(self.mid(self.down(*input_h)))

        return weight * lora.multiplier * self.scale

class LoHALayer:
    lora_name: str
@@ -64,7 +59,7 @@ class LoHALayer:
        self.name = name
        self.scale = alpha / rank if (alpha and rank) else 1.0

    def forward(self, lora, input_h, output):
    def forward(self, lora, input_h):

        if type(self.org_module) == torch.nn.Conv2d:
            op = torch.nn.functional.conv2d
@@ -86,9 +81,9 @@ class LoHALayer:
            rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', self.t1, self.w1_b, self.w1_a)
            rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', self.t2, self.w2_b, self.w2_a)
            weight = rebuild1 * rebuild2


        bias = self.bias if self.bias is not None else 0
        return output + op(
        return op(
            *input_h,
            (weight + bias).view(self.org_module.weight.shape),
            None,
@@ -96,6 +91,69 @@ class LoHALayer:
        ) * lora.multiplier * self.scale


class LoKRLayer:
    lora_name: str
    name: str
    scale: float

    w1: Optional[torch.Tensor] = None
    w1_a: Optional[torch.Tensor] = None
    w1_b: Optional[torch.Tensor] = None
    w2: Optional[torch.Tensor] = None
    w2_a: Optional[torch.Tensor] = None
    w2_b: Optional[torch.Tensor] = None
    t2: Optional[torch.Tensor] = None
    bias: Optional[torch.Tensor] = None

    org_module: torch.nn.Module

    def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
        self.lora_name = lora_name
        self.name = name
        self.scale = alpha / rank if (alpha and rank) else 1.0

    def forward(self, lora, input_h):

        if type(self.org_module) == torch.nn.Conv2d:
            op = torch.nn.functional.conv2d
            extra_args = dict(
                stride=self.org_module.stride,
                padding=self.org_module.padding,
                dilation=self.org_module.dilation,
                groups=self.org_module.groups,
            )

        else:
            op = torch.nn.functional.linear
            extra_args = {}

        w1 = self.w1
        if w1 is None:
            w1 = self.w1_a @ self.w1_b

        w2 = self.w2
        if w2 is None:
            if self.t2 is None:
                w2 = self.w2_a @ self.w2_b
            else:
                w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)


        if len(w2.shape) == 4:
            w1 = w1.unsqueeze(2).unsqueeze(2)
        w2 = w2.contiguous()
        weight = torch.kron(w1, w2).reshape(self.org_module.weight.shape)


        bias = self.bias if self.bias is not None else 0
        return op(
            *input_h,
            (weight + bias).view(self.org_module.weight.shape),
            None,
            **extra_args
        ) * lora.multiplier * self.scale


class LoRAModuleWrapper:
    unet: UNet2DConditionModel
    text_encoder: CLIPTextModel
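Note: LoKR reconstructs its weight delta as a Kronecker product of two (possibly factored) matrices. A self-contained toy illustration of that reconstruction for the linear case, using plain torch with made-up shapes:

import torch

w1 = torch.randn(8, 4)       # stands in for w1 (or w1_a @ w1_b when factored)
w2 = torch.randn(40, 80)     # stands in for w2 (or w2_a @ w2_b, optionally via the t2 einsum)
weight = torch.kron(w1, w2)  # Kronecker product: shape (8*40, 4*80) == (320, 320)
print(weight.shape)          # torch.Size([320, 320])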
@@ -159,7 +217,7 @@ class LoRAModuleWrapper:
                layer = lora.layers.get(name, None)
                if layer is None:
                    continue
                output = layer.forward(lora, input_h, output)
                output += layer.forward(lora, input_h)
            return output

        return lora_forward
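Note: the layer forward() methods now return only the weighted low-rank delta, which the wrapper adds onto the wrapped module's own output. A self-contained toy version of that convention in plain torch (module sizes are made up):

import torch

base = torch.nn.Linear(16, 16)
down = torch.nn.Linear(16, 4, bias=False)  # stands in for layer.down
up = torch.nn.Linear(4, 16, bias=False)    # stands in for layer.up
multiplier, scale = 0.8, 1.0

def lora_delta(x):
    # what forward() now returns: just the scaled low-rank correction
    return up(down(x)) * multiplier * scale

x = torch.randn(2, 16)
output = base(x)         # the wrapped module's original output
output += lora_delta(x)  # the wrapper accumulates each active LoRA's delta on top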
@@ -307,6 +365,36 @@ class LoRA:
                else:
                    layer.t2 = None

            # lokr
            elif "lokr_w1_b" in values or "lokr_w1" in values:

                if "lokr_w1_b" in values:
                    rank = values["lokr_w1_b"].shape[0]
                elif "lokr_w2_b" in values:
                    rank = values["lokr_w2_b"].shape[0]
                else:
                    rank = None  # unscaled

                layer = LoKRLayer(self.name, stem, rank, alpha)
                layer.org_module = wrapped
                layer.bias = bias

                if "lokr_w1" in values:
                    layer.w1 = values["lokr_w1"].to(device=self.device, dtype=self.dtype)
                else:
                    layer.w1_a = values["lokr_w1_a"].to(device=self.device, dtype=self.dtype)
                    layer.w1_b = values["lokr_w1_b"].to(device=self.device, dtype=self.dtype)

                if "lokr_w2" in values:
                    layer.w2 = values["lokr_w2"].to(device=self.device, dtype=self.dtype)
                else:
                    layer.w2_a = values["lokr_w2_a"].to(device=self.device, dtype=self.dtype)
                    layer.w2_b = values["lokr_w2_b"].to(device=self.device, dtype=self.dtype)

                if "lokr_t2" in values:
                    layer.t2 = values["lokr_t2"].to(device=self.device, dtype=self.dtype)


            else:
                print(
                    f">> Encountered unknown lora layer module in {self.name}: {stem} - {type(wrapped).__name__}"
@@ -1,5 +1,7 @@
import os
from pathlib import Path

from diffusers import UNet2DConditionModel, StableDiffusionPipeline
from ldm.invoke.globals import global_lora_models_dir
from .kohya_lora_manager import KohyaLoraManager
from typing import Optional, Dict
@@ -8,20 +10,29 @@ class LoraCondition:
    name: str
    weight: float

    def __init__(self, name, weight: float = 1.0, kohya_manager: Optional[KohyaLoraManager]=None):
    def __init__(self,
                 name,
                 weight: float = 1.0,
                 unet: UNet2DConditionModel=None,  # for diffusers format LoRAs
                 kohya_manager: Optional[KohyaLoraManager]=None,  # for KohyaLoraManager-compatible LoRAs
                 ):
        self.name = name
        self.weight = weight
        self.kohya_manager = kohya_manager
        self.unet = unet

    def __call__(self, model):
    def __call__(self):
        # TODO: make model able to load from huggingface, rather then just local files
        path = Path(global_lora_models_dir(), self.name)
        if path.is_dir():
            if model.load_attn_procs:
            if not self.unet:
                print(f" ** Unable to load diffusers-format LoRA {self.name}: unet is None")
                return
            if self.unet.load_attn_procs:
                file = Path(path, "pytorch_lora_weights.bin")
                if file.is_file():
                    print(f">> Loading LoRA: {path}")
                    model.load_attn_procs(path.absolute().as_posix())
                    self.unet.load_attn_procs(path.absolute().as_posix())
                else:
                    print(f" ** Unable to find valid LoRA at: {path}")
            else:
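Note: LoraCondition now receives the UNet at construction time instead of taking a model argument when called. A short usage sketch, assuming an InvokeAI v2.3 environment where `pipe` is a loaded diffusers pipeline and "my_lora" is a diffusers-format LoRA folder under the LoRA models directory (both names are assumptions for illustration):

condition = LoraCondition("my_lora", weight=0.75, unet=pipe.unet, kohya_manager=None)
condition()  # applies the LoRA; the target no longer has to be passed to __call__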
@@ -37,15 +48,16 @@ class LoraCondition:
            self.kohya_manager.unload_applied_lora(self.name)

class LoraManager:
    def __init__(self, pipe):
    def __init__(self, pipe: StableDiffusionPipeline):
        # Kohya class handles lora not generated through diffusers
        self.kohya = KohyaLoraManager(pipe, global_lora_models_dir())
        self.unet = pipe.unet

    def set_loras_conditions(self, lora_weights: list):
        conditions = []
        if len(lora_weights) > 0:
            for lora in lora_weights:
                conditions.append(LoraCondition(lora.model, lora.weight, self.kohya))
                conditions.append(LoraCondition(lora.model, lora.weight, self.unet, self.kohya))

        if len(conditions) > 0:
            return conditions
@@ -63,4 +75,4 @@ class LoraManager:
        if suffix in [".ckpt", ".pt", ".safetensors"]:
            models_found[name]=Path(root,x)
    return models_found