Merge branch 'v2.3' into lstein/bugfix/improve-update-handling

This commit is contained in:
Lincoln Stein
2023-04-25 03:09:12 +01:00
committed by GitHub
13 changed files with 198 additions and 194 deletions

34
.github/CODEOWNERS vendored
View File

@@ -1,13 +1,13 @@
# continuous integration
/.github/workflows/ @mauwii @lstein @blessedcoolant
/.github/workflows/ @lstein @blessedcoolant
# documentation
/docs/ @lstein @mauwii @blessedcoolant
mkdocs.yml @mauwii @lstein
/docs/ @lstein @blessedcoolant
mkdocs.yml @lstein @ebr
# installation and configuration
/pyproject.toml @mauwii @lstein @ebr
/docker/ @mauwii
/pyproject.toml @lstein @ebr
/docker/ @lstein
/scripts/ @ebr @lstein @blessedcoolant
/installer/ @ebr @lstein
ldm/invoke/config @lstein @ebr
@@ -21,13 +21,13 @@ invokeai/configs @lstein @ebr @blessedcoolant
# generation and model management
/ldm/*.py @lstein @blessedcoolant
/ldm/generate.py @lstein @keturn
/ldm/generate.py @lstein @gregghelt2
/ldm/invoke/args.py @lstein @blessedcoolant
/ldm/invoke/ckpt* @lstein @blessedcoolant
/ldm/invoke/ckpt_generator @lstein @blessedcoolant
/ldm/invoke/CLI.py @lstein @blessedcoolant
/ldm/invoke/config @lstein @ebr @mauwii @blessedcoolant
/ldm/invoke/generator @keturn @damian0815
/ldm/invoke/config @lstein @ebr @blessedcoolant
/ldm/invoke/generator @gregghelt2 @damian0815
/ldm/invoke/globals.py @lstein @blessedcoolant
/ldm/invoke/merge_diffusers.py @lstein @blessedcoolant
/ldm/invoke/model_manager.py @lstein @blessedcoolant
@@ -36,17 +36,17 @@ invokeai/configs @lstein @ebr @blessedcoolant
/ldm/invoke/restoration @lstein @blessedcoolant
# attention, textual inversion, model configuration
/ldm/models @damian0815 @keturn @blessedcoolant
/ldm/models @damian0815 @gregghelt2 @blessedcoolant
/ldm/modules/textual_inversion_manager.py @lstein @blessedcoolant
/ldm/modules/attention.py @damian0815 @keturn
/ldm/modules/diffusionmodules @damian0815 @keturn
/ldm/modules/distributions @damian0815 @keturn
/ldm/modules/ema.py @damian0815 @keturn
/ldm/modules/attention.py @damian0815 @gregghelt2
/ldm/modules/diffusionmodules @damian0815 @gregghelt2
/ldm/modules/distributions @damian0815 @gregghelt2
/ldm/modules/ema.py @damian0815 @gregghelt2
/ldm/modules/embedding_manager.py @lstein
/ldm/modules/encoders @damian0815 @keturn
/ldm/modules/image_degradation @damian0815 @keturn
/ldm/modules/losses @damian0815 @keturn
/ldm/modules/x_transformer.py @damian0815 @keturn
/ldm/modules/encoders @damian0815 @gregghelt2
/ldm/modules/image_degradation @damian0815 @gregghelt2
/ldm/modules/losses @damian0815 @gregghelt2
/ldm/modules/x_transformer.py @damian0815 @gregghelt2
# Nodes
apps/ @Kyle0654 @jpphoto

View File

@@ -30,7 +30,6 @@ from ldm.invoke.conditioning import (
get_tokens_for_prompt_object,
get_prompt_structure,
split_weighted_subprompts,
get_tokenizer,
)
from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
from ldm.invoke.generator.inpaint import infill_methods
@@ -1314,7 +1313,7 @@ class InvokeAIWebServer:
None
if type(parsed_prompt) is Blend
else get_tokens_for_prompt_object(
get_tokenizer(self.generate.model), parsed_prompt
self.generate.model.tokenizer, parsed_prompt
)
)
attention_maps_image_base64_url = (

View File

@@ -15,19 +15,10 @@ from compel import Compel
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser, \
Conjunction
from .devices import torch_dtype
from .generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.invoke.globals import Globals
def get_tokenizer(model) -> CLIPTokenizer:
# Return the tokenizer belonging to `model`.
# The `tokenizer` attribute is used when it exists and is truthy; `getattr(..., None)`
# keeps a missing attribute from raising so the `or` can fall through to
# `model.cond_stage_model.tokenizer`.
# TODO remove legacy ckpt fallback handling
return (getattr(model, 'tokenizer', None) # diffusers
or model.cond_stage_model.tokenizer) # ldm
def get_text_encoder(model) -> Any:
# Return the text encoder belonging to `model`.
# The `text_encoder` attribute is used when it exists and is truthy; otherwise
# `model.cond_stage_model.transformer` is wrapped in UnsqueezingLDMTransformer and returned.
# TODO remove legacy ckpt fallback handling
return (getattr(model, 'text_encoder', None) # diffusers
or UnsqueezingLDMTransformer(model.cond_stage_model.transformer)) # ldm
class UnsqueezingLDMTransformer:
def __init__(self, ldm_transformer):
self.ldm_transformer = ldm_transformer
@@ -41,15 +32,15 @@ class UnsqueezingLDMTransformer:
return insufficiently_unsqueezed_tensor.unsqueeze(0)
def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
def get_uc_and_c_and_ec(prompt_string,
model: StableDiffusionGeneratorPipeline,
log_tokens=False, skip_normalize_legacy_blend=False):
# lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used.
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
tokenizer = get_tokenizer(model)
text_encoder = get_text_encoder(model)
compel = Compel(tokenizer=tokenizer,
text_encoder=text_encoder,
compel = Compel(tokenizer=model.tokenizer,
text_encoder=model.text_encoder,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype)
@@ -78,14 +69,20 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
if log_tokens or getattr(Globals, "log_tokenization", False):
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
tokens_count = get_max_token_count(tokenizer, positive_prompt)
# some LoRA models also mess with the text encoder, so they must be active while compel builds conditioning tensors
lora_conditioning_ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
lora_conditions=lora_conditions)
with InvokeAIDiffuserComponent.custom_attention_context(model.unet,
extra_conditioning_info=lora_conditioning_ec,
step_count=-1):
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
# now build the "real" ec
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
cross_attention_control_args=options.get(
'cross_attention_control', None),

View File

@@ -196,16 +196,6 @@ class addModelsForm(npyscreen.FormMultiPage):
scroll_exit=True,
)
self.nextrely += 1
self.convert_models = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="== CONVERT IMPORTED MODELS INTO DIFFUSERS==",
values=["Keep original format", "Convert to diffusers"],
value=0,
begin_entry_at=4,
max_height=4,
hidden=True, # will appear when imported models box is edited
scroll_exit=True,
)
self.cancel = self.add_widget_intelligent(
npyscreen.ButtonPress,
name="CANCEL",
@@ -240,8 +230,6 @@ class addModelsForm(npyscreen.FormMultiPage):
self.show_directory_fields.addVisibleWhenSelected(i)
self.show_directory_fields.when_value_edited = self._clear_scan_directory
self.import_model_paths.when_value_edited = self._show_hide_convert
self.autoload_directory.when_value_edited = self._show_hide_convert
def resize(self):
super().resize()
@@ -252,13 +240,6 @@ class addModelsForm(npyscreen.FormMultiPage):
if not self.show_directory_fields.value:
self.autoload_directory.value = ""
def _show_hide_convert(self):
# Toggle visibility of the `convert_models` widget: it is hidden only while both the
# import-paths field and the autoload-directory field are empty.
# `or ""` substitutes an empty string for a falsy value (e.g. None) so len() is safe.
model_paths = self.import_model_paths.value or ""
autoload_directory = self.autoload_directory.value or ""
self.convert_models.hidden = (
len(model_paths) == 0 and len(autoload_directory) == 0
)
def _get_starter_model_labels(self) -> List[str]:
window_width, window_height = get_terminal_size()
label_width = 25
@@ -318,7 +299,6 @@ class addModelsForm(npyscreen.FormMultiPage):
.scan_directory: Path to a directory of models to scan and import
.autoscan_on_startup: True if invokeai should scan and import at startup time
.import_model_paths: list of URLs, repo_ids and file paths to import
.convert_to_diffusers: if True, convert legacy checkpoints into diffusers
"""
# we're using a global here rather than storing the result in the parentapp
# due to some bug in npyscreen that is causing attributes to be lost
@@ -354,7 +334,6 @@ class addModelsForm(npyscreen.FormMultiPage):
# URLs and the like
selections.import_model_paths = self.import_model_paths.value.split()
selections.convert_to_diffusers = self.convert_models.value[0] == 1
class AddModelApplication(npyscreen.NPSAppManaged):
@@ -367,7 +346,6 @@ class AddModelApplication(npyscreen.NPSAppManaged):
scan_directory=None,
autoscan_on_startup=None,
import_model_paths=None,
convert_to_diffusers=None,
)
def onStart(self):
@@ -387,7 +365,6 @@ def process_and_execute(opt: Namespace, selections: Namespace):
directory_to_scan = selections.scan_directory
scan_at_startup = selections.autoscan_on_startup
potential_models_to_install = selections.import_model_paths
convert_to_diffusers = selections.convert_to_diffusers
install_requested_models(
install_initial_models=models_to_install,
@@ -395,7 +372,6 @@ def process_and_execute(opt: Namespace, selections: Namespace):
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
external_models=potential_models_to_install,
scan_at_startup=scan_at_startup,
convert_to_diffusers=convert_to_diffusers,
precision="float32"
if opt.full_precision
else choose_precision(torch.device(choose_torch_device())),

View File

@@ -68,7 +68,6 @@ def install_requested_models(
scan_directory: Path = None,
external_models: List[str] = None,
scan_at_startup: bool = False,
convert_to_diffusers: bool = False,
precision: str = "float16",
purge_deleted: bool = False,
config_file_path: Path = None,
@@ -111,20 +110,20 @@ def install_requested_models(
if len(external_models)>0:
print("== INSTALLING EXTERNAL MODELS ==")
for path_url_or_repo in external_models:
print(f'DEBUG: path_url_or_repo = {path_url_or_repo}')
try:
model_manager.heuristic_import(
path_url_or_repo,
convert=convert_to_diffusers,
config_file_callback=_pick_configuration_file,
commit_to_conf=config_file_path
)
except KeyboardInterrupt:
sys.exit(-1)
except Exception:
pass
except Exception as e:
print(f'An exception has occurred: {str(e)}')
if scan_at_startup and scan_directory.is_dir():
argument = '--autoconvert' if convert_to_diffusers else '--autoimport'
argument = '--autoconvert'
initfile = Path(Globals.root, Globals.initfile)
replacement = Path(Globals.root, f'{Globals.initfile}.new')
directory = str(scan_directory).replace('\\','/')

View File

@@ -467,8 +467,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if additional_guidance is None:
additional_guidance = []
extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps)
with InvokeAIDiffuserComponent.custom_attention_context(self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps)
):
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,

View File

@@ -288,16 +288,7 @@ class InvokeAICrossAttentionMixin:
return self.einsum_op_tensor_mem(q, k, v, 32)
def restore_default_cross_attention(model, is_running_diffusers: bool, processors_to_restore: Optional[AttnProcessor]=None):
# Undo a cross-attention override on `model`.
# Diffusers path: `model` is treated as the unet and its attention processors are set to
# `processors_to_restore`, or to a fresh CrossAttnProcessor() when that argument is falsy.
# Non-diffusers path: the work is delegated to remove_attention_function(model).
if is_running_diffusers:
unet = model
unet.set_attn_processor(processors_to_restore or CrossAttnProcessor())
else:
remove_attention_function(model)
def override_cross_attention(model, context: Context, is_running_diffusers = False):
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
@@ -323,22 +314,15 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal
context.cross_attention_mask = mask.to(device)
context.cross_attention_index_map = indices.to(device)
if is_running_diffusers:
unet = model
old_attn_processors = unet.attn_processors
if torch.backends.mps.is_available():
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
unet.set_attn_processor(SwapCrossAttnProcessor())
else:
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
old_attn_processors = unet.attn_processors
if torch.backends.mps.is_available():
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
unet.set_attn_processor(SwapCrossAttnProcessor())
else:
context.register_cross_attention_modules(model)
inject_attention_function(model, context)
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:

View File

@@ -12,17 +12,6 @@ class DDIMSampler(Sampler):
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
def prepare_to_sample(self, t_enc, **kwargs):
# Run the parent class's preparation first, then configure cross-attention for this run.
super().prepare_to_sample(t_enc, **kwargs)
# Both settings are optional keyword arguments; the step count falls back to `t_enc`.
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
# Override the attention processors only when conditioning info is present and asks for
# cross-attention control; in every other case reset to the default cross-attention.
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
else:
self.invokeai_diffuser.restore_default_cross_attention()
# This is the central routine
@torch.no_grad()

View File

@@ -38,15 +38,6 @@ class CFGDenoiser(nn.Module):
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
def prepare_to_sample(self, t_enc, **kwargs):
# Configure cross-attention on the wrapped diffuser component before sampling.
# `extra_conditioning_info` is an optional keyword argument (None when absent).
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
# Override the attention processors, using `t_enc` as the step count, only when
# cross-attention control is requested; otherwise reset to the default cross-attention.
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = t_enc)
else:
self.invokeai_diffuser.restore_default_cross_attention()
def forward(self, x, sigma, uncond, cond, cond_scale):
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)

View File

@@ -14,17 +14,6 @@ class PLMSSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__(model,schedule,model.num_timesteps, device)
def prepare_to_sample(self, t_enc, **kwargs):
# Run the parent class's preparation first, then configure cross-attention for this run.
super().prepare_to_sample(t_enc, **kwargs)
# Both settings are optional keyword arguments; the step count falls back to `t_enc`.
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
# Override the attention processors only when conditioning info is present and asks for
# cross-attention control; in every other case reset to the default cross-attention.
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
else:
self.invokeai_diffuser.restore_default_cross_attention()
# this is the essential routine
@torch.no_grad()

View File

@@ -1,18 +1,18 @@
from contextlib import contextmanager
from dataclasses import dataclass
from math import ceil
from typing import Callable, Optional, Union, Any, Dict
from typing import Callable, Optional, Union, Any
import numpy as np
import torch
from diffusers.models.cross_attention import AttnProcessor
from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias
from ldm.invoke.globals import Globals
from ldm.models.diffusion.cross_attention_control import (
Arguments,
restore_default_cross_attention,
override_cross_attention,
setup_cross_attention_control_attention_processors,
Context,
get_cross_attention_modules,
CrossAttentionType,
@@ -84,66 +84,45 @@ class InvokeAIDiffuserComponent:
self.cross_attention_control_context = None
self.sequential_guidance = Globals.sequential_guidance
@classmethod
@contextmanager
def custom_attention_context(
self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
clss,
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int
):
old_attn_processor = None
old_attn_processors = None
if extra_conditioning_info and (
extra_conditioning_info.wants_cross_attention_control
| extra_conditioning_info.has_lora_conditions
):
old_attn_processor = self.override_attention_processors(
extra_conditioning_info, step_count=step_count
)
old_attn_processors = unet.attn_processors
# Load lora conditions into the model
if extra_conditioning_info.has_lora_conditions:
for condition in extra_conditioning_info.lora_conditions:
condition() # target model is stored in condition state for some reason
if extra_conditioning_info.wants_cross_attention_control:
cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
cross_attention_control_context,
)
try:
yield None
finally:
if old_attn_processor is not None:
self.restore_default_cross_attention(old_attn_processor)
if old_attn_processors is not None:
unet.set_attn_processor(old_attn_processors)
if extra_conditioning_info and extra_conditioning_info.has_lora_conditions:
for lora_condition in extra_conditioning_info.lora_conditions:
lora_condition.unload()
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()
def override_attention_processors(
self, conditioning: ExtraConditioningInfo, step_count: int
) -> Dict[str, AttnProcessor]:
"""
setup cross attention .swap control. for diffusers this replaces the attention processor, so
the previous attention processor is returned so that the caller can restore it later.
"""
# Capture the current processors before anything below modifies the model.
old_attn_processors = self.model.attn_processors
# Load lora conditions into the model
if conditioning.has_lora_conditions:
# each condition is a callable that receives the model it should be applied to
for condition in conditioning.lora_conditions:
condition(self.model)
if conditioning.wants_cross_attention_control:
# The context is stored on self; it carries the control arguments and the step count.
self.cross_attention_control_context = Context(
arguments=conditioning.cross_attention_control_args,
step_count=step_count,
)
override_cross_attention(
self.model,
self.cross_attention_control_context,
is_running_diffusers=self.is_running_diffusers,
)
# Returned unconditionally, even when neither branch above ran.
return old_attn_processors
def restore_default_cross_attention(
self, processors_to_restore: Optional[dict[str, "AttnProcessor"]] = None
):
# Drop the stored cross-attention control context, then delegate to the module-level
# restore_default_cross_attention() with this component's model and diffusers flag.
# `processors_to_restore` is forwarded unchanged (None by default).
self.cross_attention_control_context = None
restore_default_cross_attention(
self.model,
is_running_diffusers=self.is_running_diffusers,
processors_to_restore=processors_to_restore,
)
def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key):
if dim is not None:

View File

@@ -31,18 +31,13 @@ class LoRALayer:
self.name = name
self.scale = alpha / rank if (alpha and rank) else 1.0
def forward(self, lora, input_h, output):
def forward(self, lora, input_h):
if self.mid is None:
output = (
output
+ self.up(self.down(*input_h)) * lora.multiplier * self.scale
)
weight = self.up(self.down(*input_h))
else:
output = (
output
+ self.up(self.mid(self.down(*input_h))) * lora.multiplier * self.scale
)
return output
weight = self.up(self.mid(self.down(*input_h)))
return weight * lora.multiplier * self.scale
class LoHALayer:
lora_name: str
@@ -64,7 +59,7 @@ class LoHALayer:
self.name = name
self.scale = alpha / rank if (alpha and rank) else 1.0
def forward(self, lora, input_h, output):
def forward(self, lora, input_h):
if type(self.org_module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
@@ -86,9 +81,9 @@ class LoHALayer:
rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
bias = self.bias if self.bias is not None else 0
return output + op(
return op(
*input_h,
(weight + bias).view(self.org_module.weight.shape),
None,
@@ -96,6 +91,69 @@ class LoHALayer:
) * lora.multiplier * self.scale
class LoKRLayer:
    """LoRA layer whose weight delta is a Kronecker product (LoKr).

    The delta applied through ``org_module``'s operation is ``kron(w1, w2)``.
    Each factor is either stored directly (``w1`` / ``w2``) or rebuilt from a
    low-rank pair (``w1_a @ w1_b`` / ``w2_a @ w2_b``); when ``t2`` is set,
    ``w2`` is rebuilt from ``t2``, ``w2_a`` and ``w2_b`` with an einsum.
    The tensors and ``org_module`` are assigned by the loader after
    construction.
    """

    lora_name: str
    name: str
    scale: float

    w1: Optional[torch.Tensor] = None
    w1_a: Optional[torch.Tensor] = None
    w1_b: Optional[torch.Tensor] = None
    w2: Optional[torch.Tensor] = None
    w2_a: Optional[torch.Tensor] = None
    w2_b: Optional[torch.Tensor] = None
    t2: Optional[torch.Tensor] = None
    bias: Optional[torch.Tensor] = None

    org_module: torch.nn.Module

    def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
        """Store the identifying names and derive the output scale.

        ``scale`` is ``alpha / rank`` when both are truthy, otherwise 1.0
        (i.e. a missing/zero rank or alpha means "unscaled").
        """
        self.lora_name = lora_name
        self.name = name
        self.scale = alpha / rank if (alpha and rank) else 1.0

    def forward(self, lora, input_h):
        """Return this layer's contribution for the inputs ``input_h``.

        ``input_h`` is unpacked as the positional inputs of the functional op
        (conv2d for a Conv2d ``org_module``, linear otherwise). The result is
        multiplied by ``lora.multiplier`` and ``self.scale``; the caller is
        responsible for adding it to the original module's output.
        """
        org_module = self.org_module

        # isinstance rather than an exact type comparison, so that Conv2d
        # subclasses use the conv path instead of falling through to linear.
        if isinstance(org_module, torch.nn.Conv2d):
            op = torch.nn.functional.conv2d
            extra_args = dict(
                stride=org_module.stride,
                padding=org_module.padding,
                dilation=org_module.dilation,
                groups=org_module.groups,
            )
        else:
            op = torch.nn.functional.linear
            extra_args = {}

        # First Kronecker factor: stored directly or as a low-rank pair.
        w1 = self.w1
        if w1 is None:
            w1 = self.w1_a @ self.w1_b

        # Second Kronecker factor: direct, low-rank pair, or rebuilt with t2.
        w2 = self.w2
        if w2 is None:
            if self.t2 is None:
                w2 = self.w2_a @ self.w2_b
            else:
                w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)

        # A 4-D w2 (conv kernel) needs w1 padded with two trailing unit dims
        # so that both kron operands have the same rank.
        if len(w2.shape) == 4:
            w1 = w1.unsqueeze(2).unsqueeze(2)
        w2 = w2.contiguous()  # memory layout only; values are unchanged

        target_shape = org_module.weight.shape
        weight = torch.kron(w1, w2).reshape(target_shape)

        bias = self.bias if self.bias is not None else 0
        return op(
            *input_h,
            (weight + bias).view(target_shape),
            None,
            **extra_args
        ) * lora.multiplier * self.scale
class LoRAModuleWrapper:
unet: UNet2DConditionModel
text_encoder: CLIPTextModel
@@ -159,7 +217,7 @@ class LoRAModuleWrapper:
layer = lora.layers.get(name, None)
if layer is None:
continue
output = layer.forward(lora, input_h, output)
output += layer.forward(lora, input_h)
return output
return lora_forward
@@ -307,6 +365,36 @@ class LoRA:
else:
layer.t2 = None
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
if "lokr_w1_b" in values:
rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
rank = values["lokr_w2_b"].shape[0]
else:
rank = None # unscaled
layer = LoKRLayer(self.name, stem, rank, alpha)
layer.org_module = wrapped
layer.bias = bias
if "lokr_w1" in values:
layer.w1 = values["lokr_w1"].to(device=self.device, dtype=self.dtype)
else:
layer.w1_a = values["lokr_w1_a"].to(device=self.device, dtype=self.dtype)
layer.w1_b = values["lokr_w1_b"].to(device=self.device, dtype=self.dtype)
if "lokr_w2" in values:
layer.w2 = values["lokr_w2"].to(device=self.device, dtype=self.dtype)
else:
layer.w2_a = values["lokr_w2_a"].to(device=self.device, dtype=self.dtype)
layer.w2_b = values["lokr_w2_b"].to(device=self.device, dtype=self.dtype)
if "lokr_t2" in values:
layer.t2 = values["lokr_t2"].to(device=self.device, dtype=self.dtype)
else:
print(
f">> Encountered unknown lora layer module in {self.name}: {stem} - {type(wrapped).__name__}"

View File

@@ -1,5 +1,7 @@
import os
from pathlib import Path
from diffusers import UNet2DConditionModel, StableDiffusionPipeline
from ldm.invoke.globals import global_lora_models_dir
from .kohya_lora_manager import KohyaLoraManager
from typing import Optional, Dict
@@ -8,20 +10,29 @@ class LoraCondition:
name: str
weight: float
def __init__(self, name, weight: float = 1.0, kohya_manager: Optional[KohyaLoraManager]=None):
def __init__(self,
name,
weight: float = 1.0,
unet: UNet2DConditionModel=None, # for diffusers format LoRAs
kohya_manager: Optional[KohyaLoraManager]=None, # for KohyaLoraManager-compatible LoRAs
):
self.name = name
self.weight = weight
self.kohya_manager = kohya_manager
self.unet = unet
def __call__(self, model):
def __call__(self):
# TODO: make model able to load from huggingface, rather than just local files
path = Path(global_lora_models_dir(), self.name)
if path.is_dir():
if model.load_attn_procs:
if not self.unet:
print(f" ** Unable to load diffusers-format LoRA {self.name}: unet is None")
return
if self.unet.load_attn_procs:
file = Path(path, "pytorch_lora_weights.bin")
if file.is_file():
print(f">> Loading LoRA: {path}")
model.load_attn_procs(path.absolute().as_posix())
self.unet.load_attn_procs(path.absolute().as_posix())
else:
print(f" ** Unable to find valid LoRA at: {path}")
else:
@@ -37,15 +48,16 @@ class LoraCondition:
self.kohya_manager.unload_applied_lora(self.name)
class LoraManager:
def __init__(self, pipe):
def __init__(self, pipe: StableDiffusionPipeline):
# Kohya class handles lora not generated through diffusers
self.kohya = KohyaLoraManager(pipe, global_lora_models_dir())
self.unet = pipe.unet
def set_loras_conditions(self, lora_weights: list):
conditions = []
if len(lora_weights) > 0:
for lora in lora_weights:
conditions.append(LoraCondition(lora.model, lora.weight, self.kohya))
conditions.append(LoraCondition(lora.model, lora.weight, self.unet, self.kohya))
if len(conditions) > 0:
return conditions
@@ -63,4 +75,4 @@ class LoraManager:
if suffix in [".ckpt", ".pt", ".safetensors"]:
models_found[name]=Path(root,x)
return models_found