Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-17 05:58:04 -05:00)

Compare commits: v4.2.9.dev ... ryan/flux- (71 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 0ff5355ce3 | |
| | 3ad5fc060d | |
| | 17d5c85454 | |
| | 4698649cc9 | |
| | d41c075768 | |
| | 6b129aaba6 | |
| | f58546fd53 | |
| | de3edf47fb | |
| | 6dc4baa925 | |
| | 943fa6da4b | |
| | bfe31838cc | |
| | 16b76f7e7f | |
| | cb115743e7 | |
| | 8f4279ba51 | |
| | 22a207b50d | |
| | 638c6003e3 | |
| | 8d35af946e | |
| | 24065ec6b6 | |
| | 627b0bf644 | |
| | b43da46b82 | |
| | 4255a01c64 | |
| | 23adbd4002 | |
| | fb5a24fcc6 | |
| | cfdd5a1900 | |
| | 2313f326df | |
| | 2e092a2313 | |
| | 763ef06c18 | |
| | 8292f6cd42 | |
| | 278bba499e | |
| | dd99ed28e0 | |
| | 9a8aca69bf | |
| | 7ad62512eb | |
| | bd466661ec | |
| | 7ebb509d05 | |
| | 0aa13c046c | |
| | a7a33d73f5 | |
| | ffa39857d3 | |
| | e85c3bc465 | |
| | 8185ba7054 | |
| | d501865bec | |
| | d62310bb5f | |
| | 1835bff196 | |
| | 87261bdbc9 | |
| | 4e4b6c6dbc | |
| | 5e8cf9fb6a | |
| | c738fe051f | |
| | 29fe1533f2 | |
| | 77090070bd | |
| | 6ba9b1b6b0 | |
| | c578b8df1e | |
| | cad9a41433 | |
| | 5fefb3b0f4 | |
| | 5284a870b0 | |
| | e064377c05 | |
| | 3e569c8312 | |
| | 16825ee6e9 | |
| | 3f5340fa53 | |
| | f2a1a39b33 | |
| | 326de55d3e | |
| | b2df909570 | |
| | 026ac36b06 | |
| | 92125e5fd2 | |
| | c0c139da88 | |
| | 404ad6a7fd | |
| | fc39086fb4 | |
| | cd215700fe | |
| | e97fd85904 | |
| | 0a263fa5b1 | |
| | fae3836a8d | |
| | b3d2eb4178 | |
| | 576f1cbb75 | |
@@ -196,6 +196,22 @@ tips to reduce the problem:

=== "12GB VRAM GPU"

    This should be sufficient to generate larger images up to about 1280x1280.

## Checkpoint Models Load Slowly or Use Too Much RAM

The difference between diffusers models (a folder containing multiple subfolders) and checkpoint models (a file ending with .safetensors or .ckpt) is that InvokeAI is able to load diffusers models into memory incrementally, while checkpoint models must be loaded all at once. With very large models, or on systems with limited RAM, you may experience slowdowns and other memory-related issues when loading checkpoint models.

To solve this, go to the Model Manager tab (the cube), select the checkpoint model that's giving you trouble, and press the "Convert" button in the upper right of your browser window. This will convert the checkpoint into a diffusers model, after which loading should be faster and less memory-intensive.
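In effect, the conversion re-saves the single-file checkpoint in the multi-folder diffusers layout. If you want to do the same thing outside the UI, a minimal sketch using the Hugging Face diffusers library directly (not InvokeAI's own converter; the file and folder names below are placeholders) looks roughly like this:

```python
# Rough standalone sketch, not InvokeAI's internal conversion code.
# Pick the pipeline class that matches the checkpoint (SD 1.x shown here).
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_single_file("my-model.safetensors")  # placeholder path
pipe.save_pretrained("my-model-diffusers")  # writes the multi-folder diffusers layout
```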
## Memory Leak (Linux)
@@ -3,8 +3,10 @@
import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from enum import Enum
from tempfile import TemporaryDirectory
from typing import List, Optional, Type

@@ -17,6 +19,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated

from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.config import get_config
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (

@@ -31,6 +34,7 @@ from invokeai.backend.model_manager.config import (
    ModelFormat,
    ModelType,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch

@@ -50,6 +54,13 @@ class ModelsList(BaseModel):
    model_config = ConfigDict(use_enum_values=True)


class CacheType(str, Enum):
    """Cache type - one of vram or ram."""

    RAM = "RAM"
    VRAM = "VRAM"


def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
    """Add a cover image URL to a model configuration."""
    cover_image = dependencies.invoker.services.model_images.get_url(config.key)

@@ -797,3 +808,83 @@ async def get_starter_models() -> list[StarterModel]:
            model.dependencies = missing_deps

    return starter_models


@model_manager_router.get(
    "/model_cache",
    operation_id="get_cache_size",
    response_model=float,
    summary="Get maximum size of model manager RAM or VRAM cache.",
)
async def get_cache_size(cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM)) -> float:
    """Return the current RAM or VRAM cache size setting (in GB)."""
    cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
    value = 0.0
    if cache_type == CacheType.RAM:
        value = cache.max_cache_size
    elif cache_type == CacheType.VRAM:
        value = cache.max_vram_cache_size
    return value


@model_manager_router.put(
    "/model_cache",
    operation_id="set_cache_size",
    response_model=float,
    summary="Set maximum size of model manager RAM or VRAM cache, optionally writing new value out to invokeai.yaml config file.",
)
async def set_cache_size(
    value: float = Query(description="The new value for the maximum cache size"),
    cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM),
    persist: bool = Query(description="Write new value out to invokeai.yaml", default=False),
) -> float:
    """Set the current RAM or VRAM cache size setting (in GB)."""
    cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
    app_config = get_config()
    # Record initial state.
    vram_old = app_config.vram
    ram_old = app_config.ram

    # Prepare target state.
    vram_new = vram_old
    ram_new = ram_old
    if cache_type == CacheType.RAM:
        ram_new = value
    elif cache_type == CacheType.VRAM:
        vram_new = value
    else:
        raise ValueError(f"Unexpected {cache_type=}.")

    config_path = app_config.config_file_path
    new_config_path = config_path.with_suffix(".yaml.new")

    try:
        # Try to apply the target state.
        cache.max_vram_cache_size = vram_new
        cache.max_cache_size = ram_new
        app_config.ram = ram_new
        app_config.vram = vram_new
        if persist:
            app_config.write_file(new_config_path)
            shutil.move(new_config_path, config_path)
    except Exception as e:
        # If there was a failure, restore the initial state.
        cache.max_cache_size = ram_old
        cache.max_vram_cache_size = vram_old
        app_config.ram = ram_old
        app_config.vram = vram_old

        raise RuntimeError("Failed to update cache size") from e
    return value


@model_manager_router.get(
    "/stats",
    operation_id="get_stats",
    response_model=Optional[CacheStats],
    summary="Get model manager RAM cache performance statistics.",
)
async def get_stats() -> Optional[CacheStats]:
    """Return performance statistics on the model manager's RAM cache. Will return null if no models have been loaded."""

    return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats
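The two cache endpoints above are ordinary FastAPI routes, so they can be exercised over HTTP. A minimal usage sketch, assuming the model manager router is mounted at /api/v2/models and the server is listening locally on port 9090 (both assumptions, not shown in this diff):

```python
# Hedged sketch of calling the new cache-size endpoints; the base URL is an assumption.
import requests

BASE = "http://127.0.0.1:9090/api/v2/models"  # assumed mount point and port

# Read the current RAM cache limit (in GB).
ram_gb = requests.get(f"{BASE}/model_cache", params={"cache_type": "RAM"}).json()

# Raise the VRAM cache limit to 8 GB and persist the change to invokeai.yaml.
vram_gb = requests.put(
    f"{BASE}/model_cache",
    params={"value": 8.0, "cache_type": "VRAM", "persist": True},
).json()

print(ram_gb, vram_gb)
```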
@@ -19,8 +19,8 @@ from invokeai.app.invocations.model import CLIPField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    BasicConditioningInfo,
    ConditioningFieldData,
@@ -36,9 +36,9 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (

@@ -185,7 +185,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
    )
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None,
        description=FieldDescriptions.mask,
        description=FieldDescriptions.denoise_mask,
        input=Input.Connection,
        ui_order=8,
    )
@@ -45,11 +45,13 @@ class UIType(str, Enum, metaclass=MetaEnum):
    SDXLRefinerModel = "SDXLRefinerModelField"
    ONNXModel = "ONNXModelField"
    VAEModel = "VAEModelField"
    FluxVAEModel = "FluxVAEModelField"
    LoRAModel = "LoRAModelField"
    ControlNetModel = "ControlNetModelField"
    IPAdapterModel = "IPAdapterModelField"
    T2IAdapterModel = "T2IAdapterModelField"
    T5EncoderModel = "T5EncoderModelField"
    CLIPEmbedModel = "CLIPEmbedModelField"
    SpandrelImageToImageModel = "SpandrelImageToImageModelField"
    # endregion

@@ -128,6 +130,7 @@ class FieldDescriptions:
    noise = "Noise tensor"
    clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
    t5_encoder = "T5 tokenizer and text encoder"
    clip_embed_model = "CLIP Embed loader"
    unet = "UNet (scheduler, LoRAs)"
    transformer = "Transformer"
    vae = "VAE"

@@ -178,7 +181,7 @@ class FieldDescriptions:
    )
    num_1 = "The first number"
    num_2 = "The second number"
    mask = "The mask to use for the operation"
    denoise_mask = "A mask of the region to apply the denoising process to."
    board = "The board to save the image to"
    image = "The image to process"
    tile_size = "Tile size"
292 invokeai/app/invocations/flux_denoise.py Normal file
@@ -0,0 +1,292 @@
|
||||
from typing import Callable, Iterator, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.denoise import denoise
|
||||
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.sampling_utils import (
|
||||
clip_timestep_schedule,
|
||||
generate_img_ids,
|
||||
get_noise,
|
||||
get_schedule,
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.peft.lora import LoRAModelRaw
|
||||
from invokeai.backend.peft.peft_patcher import PeftPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_denoise",
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run denoising process with a FLUX transformer model."""
|
||||
|
||||
# If latents is provided, this means we are doing image-to-image.
|
||||
latents: Optional[LatentsField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.denoise_mask,
|
||||
input=Input.Connection,
|
||||
)
|
||||
denoising_start: float = InputField(
|
||||
default=0.0,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
positive_text_conditioning: FluxConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
num_steps: int = InputField(
|
||||
default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
|
||||
)
|
||||
guidance: float = InputField(
|
||||
default=4.0,
|
||||
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
|
||||
)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
latents = latents.detach().to("cpu")
|
||||
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
):
|
||||
inference_dtype = torch.bfloat16
|
||||
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
|
||||
# Load the input latents, if provided.
|
||||
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
|
||||
if init_latents is not None:
|
||||
init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
|
||||
|
||||
# Prepare input noise.
|
||||
noise = get_noise(
|
||||
num_samples=1,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
# Calculate the timestep schedule.
|
||||
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=image_seq_len,
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
# Clip the timesteps schedule based on denoising_start and denoising_end.
|
||||
timesteps = clip_timestep_schedule(timesteps, self.denoising_start, self.denoising_end)
|
||||
|
||||
# Prepare input latent image.
|
||||
if init_latents is not None:
|
||||
# If init_latents is provided, we are doing image-to-image.
|
||||
|
||||
if is_schnell:
|
||||
context.logger.warning(
|
||||
"Running image-to-image with a FLUX schnell model. This is not recommended. The results are likely "
|
||||
"to be poor. Consider using a FLUX dev model instead."
|
||||
)
|
||||
|
||||
# Noise the orig_latents by the appropriate amount for the first timestep.
|
||||
t_0 = timesteps[0]
|
||||
x = t_0 * noise + (1.0 - t_0) * init_latents
|
||||
else:
|
||||
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
|
||||
if self.denoising_start > 1e-5:
|
||||
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
|
||||
|
||||
x = noise
|
||||
|
||||
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
|
||||
# denoising steps.
|
||||
if len(timesteps) <= 1:
|
||||
return x
|
||||
|
||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||
|
||||
b, _c, h, w = x.shape
|
||||
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
|
||||
|
||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||
|
||||
# Pack all latent tensors.
|
||||
init_latents = pack(init_latents) if init_latents is not None else None
|
||||
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
|
||||
noise = pack(noise)
|
||||
x = pack(x)
|
||||
|
||||
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
|
||||
assert image_seq_len == x.shape[1]
|
||||
|
||||
# Prepare inpaint extension.
|
||||
inpaint_extension: InpaintExtension | None = None
|
||||
if inpaint_mask is not None:
|
||||
assert init_latents is not None
|
||||
inpaint_extension = InpaintExtension(
|
||||
init_latents=init_latents,
|
||||
inpaint_mask=inpaint_mask,
|
||||
noise=noise,
|
||||
)
|
||||
|
||||
with (
|
||||
transformer_info.model_on_device() as (cached_weights, transformer),
|
||||
# Apply the LoRA after transformer has been moved to its target device for faster patching.
|
||||
PeftPatcher.apply_peft_patches(
|
||||
model=transformer,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix="",
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
):
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=x,
|
||||
img_ids=img_ids,
|
||||
txt=t5_embeddings,
|
||||
txt_ids=txt_ids,
|
||||
vec=clip_embeddings,
|
||||
timesteps=timesteps,
|
||||
step_callback=self._build_step_callback(context),
|
||||
guidance=self.guidance,
|
||||
inpaint_extension=inpaint_extension,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
return x
|
||||
|
||||
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
|
||||
"""Prepare the inpaint mask.
|
||||
|
||||
- Loads the mask
|
||||
- Resizes if necessary
|
||||
- Casts to same device/dtype as latents
|
||||
- Expands mask to the same shape as latents so that they line up after 'packing'
|
||||
|
||||
Args:
|
||||
context (InvocationContext): The invocation context, for loading the inpaint mask.
|
||||
latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
|
||||
device, and dtype for the inpaint mask.
|
||||
|
||||
Returns:
|
||||
torch.Tensor | None: Inpaint mask.
|
||||
"""
|
||||
if self.denoise_mask is None:
|
||||
return None
|
||||
|
||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
mask = tv_resize(
|
||||
img=mask,
|
||||
size=[latent_height, latent_width],
|
||||
interpolation=tv_transforms.InterpolationMode.BILINEAR,
|
||||
antialias=False,
|
||||
)
|
||||
|
||||
mask = mask.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
|
||||
# `latents`.
|
||||
return mask.expand_as(latents)
|
||||
|
||||
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.transformer.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[], None]:
|
||||
def step_callback() -> None:
|
||||
if context.util.is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
# TODO: Make this look like the image before re-enabling
|
||||
# latent_image = unpack(img.float(), self.height, self.width)
|
||||
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions
|
||||
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
|
||||
|
||||
# # Create a new tensor of the required shape [255, 255, 3]
|
||||
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
|
||||
|
||||
# # Convert to a NumPy array and then to a PIL Image
|
||||
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
|
||||
|
||||
# (width, height) = image.size
|
||||
# width *= 8
|
||||
# height *= 8
|
||||
|
||||
# dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
# # TODO: move this whole function to invocation context to properly reference these variables
|
||||
# context._services.events.emit_invocation_denoise_progress(
|
||||
# context._data.queue_item,
|
||||
# context._data.invocation,
|
||||
# state,
|
||||
# ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
# )
|
||||
|
||||
return step_callback
|
||||
53 invokeai/app/invocations/flux_lora_loader.py Normal file
@@ -0,0 +1,53 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation_output("flux_lora_loader_output")
class FluxLoRALoaderOutput(BaseInvocationOutput):
    """FLUX LoRA Loader Output"""

    transformer: TransformerField = OutputField(
        default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
    )


@invocation(
    "flux_lora_loader",
    title="FLUX LoRA",
    tags=["lora", "model", "flux"],
    category="model",
    version="1.0.0",
)
class FluxLoRALoaderInvocation(BaseInvocation):
    """Apply a LoRA model to a FLUX transformer."""

    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
    )
    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
    transformer: TransformerField = InputField(
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="FLUX Transformer",
    )

    def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
        lora_key = self.lora.key

        if not context.models.exists(lora_key):
            raise ValueError(f"Unknown lora: {lora_key}!")

        if any(lora.lora.key == lora_key for lora in self.transformer.loras):
            raise Exception(f'LoRA "{lora_key}" already applied to transformer.')

        transformer = self.transformer.model_copy(deep=True)
        transformer.loras.append(
            LoRAField(
                lora=self.lora,
                weight=self.weight,
            )
        )

        return FluxLoRALoaderOutput(transformer=transformer)
@@ -40,7 +40,10 @@ class FluxTextEncoderInvocation(BaseInvocation):

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
        t5_embeddings, clip_embeddings = self._encode_prompt(context)
        # Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
        # scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
        t5_embeddings = self._t5_encode(context)
        clip_embeddings = self._clip_encode(context)
        conditioning_data = ConditioningFieldData(
            conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
        )

@@ -48,12 +51,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
        conditioning_name = context.conditioning.save(conditioning_data)
        return FluxConditioningOutput.build(conditioning_name)

    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
        # Load CLIP.
        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
        clip_text_encoder_info = context.models.load(self.clip.text_encoder)

        # Load T5.
    def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)

@@ -70,6 +68,15 @@ class FluxTextEncoderInvocation(BaseInvocation):

            prompt_embeds = t5_encoder(prompt)

        assert isinstance(prompt_embeds, torch.Tensor)
        return prompt_embeds

    def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
        clip_text_encoder_info = context.models.load(self.clip.text_encoder)

        prompt = [self.prompt]

        with (
            clip_text_encoder_info as clip_text_encoder,
            clip_tokenizer_info as clip_tokenizer,

@@ -81,6 +88,5 @@ class FluxTextEncoderInvocation(BaseInvocation):

            pooled_prompt_embeds = clip_encoder(prompt)

        assert isinstance(prompt_embeds, torch.Tensor)
        assert isinstance(pooled_prompt_embeds, torch.Tensor)
        return prompt_embeds, pooled_prompt_embeds
        return pooled_prompt_embeds
@@ -1,172 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
Input,
|
||||
InputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_text_to_image",
|
||||
title="FLUX Text to Image",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Text-to-image generation using a FLUX model."""
|
||||
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
positive_text_conditioning: FluxConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
num_steps: int = InputField(
|
||||
default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
|
||||
)
|
||||
guidance: float = InputField(
|
||||
default=4.0,
|
||||
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
|
||||
)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
|
||||
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
|
||||
image = self._run_vae_decoding(context, latents)
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
clip_embeddings: torch.Tensor,
|
||||
t5_embeddings: torch.Tensor,
|
||||
):
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
inference_dtype = torch.bfloat16
|
||||
|
||||
# Prepare input noise.
|
||||
x = get_noise(
|
||||
num_samples=1,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
img, img_ids = prepare_latent_img_patches(x)
|
||||
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=img.shape[1],
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||
|
||||
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
|
||||
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
|
||||
# if the cache is not empty.
|
||||
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
||||
|
||||
with transformer_info as transformer:
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
def step_callback() -> None:
|
||||
if context.util.is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
# TODO: Make this look like the image before re-enabling
|
||||
# latent_image = unpack(img.float(), self.height, self.width)
|
||||
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions
|
||||
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
|
||||
|
||||
# # Create a new tensor of the required shape [255, 255, 3]
|
||||
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
|
||||
|
||||
# # Convert to a NumPy array and then to a PIL Image
|
||||
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
|
||||
|
||||
# (width, height) = image.size
|
||||
# width *= 8
|
||||
# height *= 8
|
||||
|
||||
# dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
# # TODO: move this whole function to invocation context to properly reference these variables
|
||||
# context._services.events.emit_invocation_denoise_progress(
|
||||
# context._data.queue_item,
|
||||
# context._data.invocation,
|
||||
# state,
|
||||
# ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
# )
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=t5_embeddings,
|
||||
txt_ids=txt_ids,
|
||||
vec=clip_embeddings,
|
||||
timesteps=timesteps,
|
||||
step_callback=step_callback,
|
||||
guidance=self.guidance,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
|
||||
return x
|
||||
|
||||
def _run_vae_decoding(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
latents: torch.Tensor,
|
||||
) -> Image.Image:
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
|
||||
img = vae.decode(latents)
|
||||
|
||||
img = img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c")
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
|
||||
return img_pil
|
||||
60 invokeai/app/invocations/flux_vae_decode.py Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_vae_decode",
|
||||
title="FLUX Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i", "flux"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
|
||||
img = vae.decode(latents)
|
||||
|
||||
img = img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
return img_pil
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
image = self._vae_decode(vae_info=vae_info, latents=latents)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
67 invokeai/app/invocations/flux_vae_encode.py Normal file
@@ -0,0 +1,67 @@
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_vae_encode",
|
||||
title="FLUX Image to Latents",
|
||||
tags=["latents", "image", "vae", "i2l", "flux"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
image: ImageField = InputField(
|
||||
description="The image to encode.",
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
# TODO(ryand): Expose seed parameter at the invocation level.
|
||||
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
|
||||
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
|
||||
# should be used for VAE encode sampling.
|
||||
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
image_tensor = image_tensor.to(
|
||||
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
|
||||
)
|
||||
latents = vae.encode(image_tensor, sample=True, generator=generator)
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
@@ -126,7 +126,7 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
    title="Tensor Mask to Image",
    tags=["mask"],
    category="mask",
    version="1.0.0",
    version="1.1.0",
)
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Convert a mask tensor to an image."""

@@ -135,6 +135,11 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

    def invoke(self, context: InvocationContext) -> ImageOutput:
        mask = context.tensors.load(self.mask.tensor_name)

        # Squeeze the channel dimension if it exists.
        if mask.dim() == 3:
            mask = mask.squeeze(0)

        # Ensure that the mask is binary.
        if mask.dtype != torch.bool:
            mask = mask > 0.5
@@ -69,6 +69,7 @@ class CLIPField(BaseModel):
|
||||
|
||||
class TransformerField(BaseModel):
|
||||
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class T5EncoderField(BaseModel):
|
||||
@@ -157,7 +158,7 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||
title="Flux Main Model",
|
||||
tags=["model", "flux"],
|
||||
category="model",
|
||||
version="1.0.3",
|
||||
version="1.0.4",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxModelLoaderInvocation(BaseInvocation):
|
||||
@@ -169,80 +170,46 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
input=Input.Direct,
|
||||
)
|
||||
|
||||
t5_encoder: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
ui_type=UIType.T5EncoderModel,
|
||||
t5_encoder_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
|
||||
)
|
||||
|
||||
clip_embed_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP Embed",
|
||||
)
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
||||
model_key = self.model.key
|
||||
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
|
||||
if not context.models.exists(key):
|
||||
raise ValueError(f"Unknown model: {key}")
|
||||
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
|
||||
if not context.models.exists(model_key):
|
||||
raise ValueError(f"Unknown model: {model_key}")
|
||||
transformer = self._get_model(context, SubModelType.Transformer)
|
||||
tokenizer = self._get_model(context, SubModelType.Tokenizer)
|
||||
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
|
||||
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
|
||||
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
|
||||
vae = self._get_model(context, SubModelType.VAE)
|
||||
transformer_config = context.models.get_config(transformer)
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer),
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
)
|
||||
|
||||
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
|
||||
match submodel:
|
||||
case SubModelType.Transformer:
|
||||
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
case SubModelType.VAE:
|
||||
return self._pull_model_from_mm(
|
||||
context,
|
||||
SubModelType.VAE,
|
||||
"FLUX.1-schnell_ae",
|
||||
ModelType.VAE,
|
||||
BaseModelType.Flux,
|
||||
)
|
||||
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
|
||||
return self._pull_model_from_mm(
|
||||
context,
|
||||
submodel,
|
||||
"clip-vit-large-patch14",
|
||||
ModelType.CLIPEmbed,
|
||||
BaseModelType.Any,
|
||||
)
|
||||
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
|
||||
return self._pull_model_from_mm(
|
||||
context,
|
||||
submodel,
|
||||
self.t5_encoder.name,
|
||||
ModelType.T5Encoder,
|
||||
BaseModelType.Any,
|
||||
)
|
||||
case _:
|
||||
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
|
||||
|
||||
def _pull_model_from_mm(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
submodel: SubModelType,
|
||||
name: str,
|
||||
type: ModelType,
|
||||
base: BaseModelType,
|
||||
):
|
||||
if models := context.models.search_by_attrs(name=name, base=base, type=type):
|
||||
if len(models) != 1:
|
||||
raise Exception(f"Multiple models detected for selected model with name {name}")
|
||||
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
|
||||
else:
|
||||
raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
|
||||
@@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
    MultiDiffusionPipeline,
@@ -103,7 +103,7 @@ class HFModelSource(StringLikeSource):
        if self.variant:
            base += f":{self.variant or ''}"
        if self.subfolder:
            base += f":{self.subfolder}"
            base += f"::{self.subfolder.as_posix()}"
        return base
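For reference, after this change a source string is built from the repo id, an optional ":variant", and now "::subfolder" (previously ":subfolder"). A tiny sketch with hypothetical values, reproducing only the concatenation shown above:

```python
# Hypothetical values; only the ":variant" / "::subfolder" separators come from the diff.
repo_id, variant, subfolder = "org/some-repo", "fp16", "text_encoder"
base = repo_id
if variant:
    base += f":{variant}"
if subfolder:
    base += f"::{subfolder}"
print(base)  # org/some-repo:fp16::text_encoder
```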
@@ -0,0 +1,407 @@
|
||||
{
|
||||
"name": "FLUX Image to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "A simple image-to-image workflow using a FLUX dev model. ",
|
||||
"version": "1.0.4",
|
||||
"contact": "",
|
||||
"tags": "image2image, flux, image-to-image",
|
||||
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend using FLUX dev models for image-to-image workflows. The image-to-image performance with FLUX schnell models is poor.",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "t5_encoder_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "clip_embed_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "vae_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"fieldName": "denoising_start"
|
||||
},
|
||||
{
|
||||
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"fieldName": "num_steps"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"version": "3.0.0",
|
||||
"category": "default"
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"type": "flux_vae_encode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"image": {
|
||||
"name": "image",
|
||||
"label": "",
|
||||
"value": {
|
||||
"image_name": "8a5c62aa-9335-45d2-9c71-89af9fc1f8d4.png"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 732.7680166609682,
|
||||
"y": -24.37398171806909
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"type": "flux_denoise",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"denoise_mask": {
|
||||
"name": "denoise_mask",
|
||||
"label": ""
|
||||
},
|
||||
"denoising_start": {
|
||||
"name": "denoising_start",
|
||||
"label": "",
|
||||
"value": 0.04
|
||||
},
|
||||
"denoising_end": {
|
||||
"name": "denoising_end",
|
||||
"label": "",
|
||||
"value": 1
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"positive_text_conditioning": {
|
||||
"name": "positive_text_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"num_steps": {
|
||||
"name": "num_steps",
|
||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||
"value": 30
|
||||
},
|
||||
"guidance": {
|
||||
"name": "guidance",
|
||||
"label": "",
|
||||
"value": 4
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1182.8836633018684,
|
||||
"y": -251.38882958913183
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "flux_vae_decode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1575.5797431839133,
|
||||
"y": -209.00150975507415
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "flux_model_loader",
|
||||
"version": "1.0.4",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": "Model (dev variant recommended for Image-to-Image)"
|
||||
},
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_embed_model": {
|
||||
"name": "clip_embed_model",
|
||||
"label": "",
|
||||
"value": {
|
||||
"key": "fa23a584-b623-415d-832a-21b5098ff1a1",
|
||||
"hash": "blake3:17c19f0ef941c3b7609a9c94a659ca5364de0be364a91d4179f0e39ba17c3b70",
|
||||
"name": "clip-vit-large-patch14",
|
||||
"base": "any",
|
||||
"type": "clip_embed"
|
||||
}
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": "",
|
||||
"value": {
|
||||
"key": "74fc82ba-c0a8-479d-a890-2126f82da758",
|
||||
"hash": "blake3:ce21cb76364aa6e2421311cf4a4b5eb052a76c4f1cd207b50703d8978198a068",
|
||||
"name": "FLUX.1-schnell_ae",
|
||||
"base": "flux",
|
||||
"type": "vae"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 328.1809894659957,
|
||||
"y": -90.2241133566946
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "flux_text_encoder",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"clip": {
|
||||
"name": "clip",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": ""
|
||||
},
|
||||
"t5_max_seq_len": {
|
||||
"name": "t5_max_seq_len",
|
||||
"label": "T5 Max Seq Len",
|
||||
"value": 256
|
||||
},
|
||||
"prompt": {
|
||||
"name": "prompt",
|
||||
"label": "",
|
||||
"value": "a cat wearing a birthday hat"
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 745.8823365057267,
|
||||
"y": -299.60249175851914
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "rand_int",
|
||||
"version": "1.0.1",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"low": {
|
||||
"name": "low",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"name": "high",
|
||||
"label": "",
|
||||
"value": 2147483647
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 725.834098928012,
|
||||
"y": 496.2710031089931
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bheight-ace0258f-67d7-4eee-a218-6fff27065214height",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "height",
|
||||
"targetHandle": "height"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bwidth-ace0258f-67d7-4eee-a218-6fff27065214width",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "width",
|
||||
"targetHandle": "width"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912blatents-ace0258f-67d7-4eee-a218-6fff27065214latents",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-2981a67c-480f-4237-9384-26b68dbf912bvae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-ace0258f-67d7-4eee-a218-6fff27065214latents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
|
||||
"type": "default",
|
||||
"source": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-ace0258f-67d7-4eee-a218-6fff27065214seed",
|
||||
"type": "default",
|
||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-ace0258f-67d7-4eee-a218-6fff27065214transformer",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-ace0258f-67d7-4eee-a218-6fff27065214positive_text_conditioning",
|
||||
"type": "default",
|
||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_text_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "max_seq_len",
|
||||
"targetHandle": "t5_max_seq_len"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
}
|
||||
]
|
||||
}
|
||||
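The edge records above follow reactflow's naming convention: each edge `id` is the literal prefix `reactflow__edge-`, followed by the source node id and source handle, a dash, then the target node id and target handle. A minimal sketch of that convention (the helper function is illustrative only, not part of the workflow schema):

```python
def reactflow_edge_id(source: str, source_handle: str, target: str, target_handle: str) -> str:
    """Build an edge id in the format used by the workflow JSON above."""
    return f"reactflow__edge-{source}{source_handle}-{target}{target_handle}"


# Reproduces the VAE edge between the model loader and the decode node shown above.
edge_id = reactflow_edge_id(
    "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", "vae",
    "7e5172eb-48c1-44db-a770-8fd83e1435d1", "vae",
)
# -> "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae"
```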
@@ -1,27 +1,35 @@
|
||||
{
|
||||
"name": "FLUX Text to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
|
||||
"version": "1.0.0",
|
||||
"description": "A simple text-to-image workflow using FLUX dev or schnell models.",
|
||||
"version": "1.0.4",
|
||||
"contact": "",
|
||||
"tags": "text2image, flux",
|
||||
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "t5_encoder_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "clip_embed_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "vae_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"nodeId": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"fieldName": "num_steps"
|
||||
},
|
||||
{
|
||||
"nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"fieldName": "t5_encoder"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
@@ -30,12 +38,127 @@
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"type": "flux_denoise",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"denoise_mask": {
|
||||
"name": "denoise_mask",
|
||||
"label": ""
|
||||
},
|
||||
"denoising_start": {
|
||||
"name": "denoising_start",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"denoising_end": {
|
||||
"name": "denoising_end",
|
||||
"label": "",
|
||||
"value": 1
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"positive_text_conditioning": {
|
||||
"name": "positive_text_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"num_steps": {
|
||||
"name": "num_steps",
|
||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||
"value": 30
|
||||
},
|
||||
"guidance": {
|
||||
"name": "guidance",
|
||||
"label": "",
|
||||
"value": 4
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1186.1868226120378,
|
||||
"y": -214.9459927686657
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "flux_vae_decode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1575.5797431839133,
|
||||
"y": -209.00150975507415
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "flux_model_loader",
|
||||
"version": "1.0.3",
|
||||
"version": "1.0.4",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
@@ -44,31 +167,25 @@
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": "Model (Starter Models can be found in Model Manager)",
|
||||
"value": {
|
||||
"key": "f04a7a2f-c74d-4538-8d5e-879a53501662",
|
||||
"hash": "random:4875da7a9508444ffa706f61961c260d0c6729f6181a86b31fad06df1277b850",
|
||||
"name": "FLUX Dev (Quantized)",
|
||||
"base": "flux",
|
||||
"type": "main"
|
||||
}
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": "T 5 Encoder (Starter Models can be found in Model Manager)",
|
||||
"value": {
|
||||
"key": "20dcd9ec-5fbb-4012-8401-049e707da5e5",
|
||||
"hash": "random:f986be43ff3502169e4adbdcee158afb0e0a65a1edc4cab16ae59963630cfd8f",
|
||||
"name": "t5_bnb_int8_quantized_encoder",
|
||||
"base": "any",
|
||||
"type": "t5_encoder"
|
||||
}
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_embed_model": {
|
||||
"name": "clip_embed_model",
|
||||
"label": ""
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 337.09365228062825,
|
||||
"y": 40.63469521079861
|
||||
"x": 381.1882713063478,
|
||||
"y": -95.89663532854017
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -105,8 +222,8 @@
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 824.1970602278849,
|
||||
"y": 146.98251001061735
|
||||
"x": 778.4899149328337,
|
||||
"y": -100.36469216659502
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -135,132 +252,75 @@
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 822.9899179655476,
|
||||
"y": 360.9657214885052
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"type": "flux_text_to_image",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
},
|
||||
"positive_text_conditioning": {
|
||||
"name": "positive_text_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"num_steps": {
|
||||
"name": "num_steps",
|
||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||
"value": 30
|
||||
},
|
||||
"guidance": {
|
||||
"name": "guidance",
|
||||
"label": "",
|
||||
"value": 4
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1216.3900791301849,
|
||||
"y": 5.500841807102248
|
||||
"x": 800.9667463219505,
|
||||
"y": 285.8297267547506
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33amax_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-4fe24f07-f906-4f55-ab2c-9beee56ef5bdtransformer",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-4fe24f07-f906-4f55-ab2c-9beee56ef5bdpositive_text_conditioning",
|
||||
"type": "default",
|
||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_text_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-4fe24f07-f906-4f55-ab2c-9beee56ef5bdseed",
|
||||
"type": "default",
|
||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4fe24f07-f906-4f55-ab2c-9beee56ef5bdlatents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
|
||||
"type": "default",
|
||||
"source": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "max_seq_len",
|
||||
"targetHandle": "t5_max_seq_len"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33avae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33atransformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33at5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33aclip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
|
||||
"type": "default",
|
||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_text_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-159bdf1b-79e7-4174-b86e-d40e646964c8seed",
|
||||
"type": "default",
|
||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
}
|
||||
]
|
||||
}
invokeai/backend/flux/denoise.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from typing import Callable

import torch
from tqdm import tqdm

from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux


def denoise(
    model: Flux,
    # model input
    img: torch.Tensor,
    img_ids: torch.Tensor,
    txt: torch.Tensor,
    txt_ids: torch.Tensor,
    vec: torch.Tensor,
    # sampling parameters
    timesteps: list[float],
    step_callback: Callable[[], None],
    guidance: float,
    inpaint_extension: InpaintExtension | None,
):
    # guidance_vec is ignored for schnell.
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )

        img = img + (t_prev - t_curr) * pred

        if inpaint_extension is not None:
            img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)

        step_callback()

    return img
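A rough sketch of how this denoise loop could be driven for a plain text-to-image run. The names `transformer`, `txt`, `txt_ids`, and `vec` are assumed to come from the model loader and the T5/CLIP text encoders (they are not constructed here), and the device/dtype choices are illustrative, not prescribed by this PR:

```python
import torch

from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.sampling_utils import generate_img_ids, get_noise, get_schedule, pack, unpack

height, width = 1024, 1024
device, dtype = torch.device("cuda"), torch.bfloat16

noise = get_noise(1, height, width, device=device, dtype=dtype, seed=0)   # (1, 16, 128, 128)
img = pack(noise)                                                         # (1, 4096, 64) patch embeddings
img_ids = generate_img_ids(h=noise.shape[-2], w=noise.shape[-1], batch_size=1, device=device, dtype=dtype)

timesteps = get_schedule(num_steps=30, image_seq_len=img.shape[1])        # 30 steps recommended for FLUX dev

latents = denoise(
    model=transformer,                # assumed: an already-loaded Flux transformer
    img=img,
    img_ids=img_ids,
    txt=txt,                          # assumed: T5 encoder output
    txt_ids=txt_ids,
    vec=vec,                          # assumed: CLIP pooled embedding
    timesteps=timesteps,
    step_callback=lambda: None,
    guidance=4.0,
    inpaint_extension=None,           # pass an InpaintExtension for inpainting (see below)
)
latent_image = unpack(latents.float(), height, width)                     # (1, 16, 128, 128), ready for VAE decode
```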
invokeai/backend/flux/inpaint_extension.py (new file, 35 lines)
@@ -0,0 +1,35 @@
import torch


class InpaintExtension:
    """A class for managing inpainting with FLUX."""

    def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
        """Initialize InpaintExtension.

        Args:
            init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). In 'packed' format.
            inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
                re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
                inpainted region with the background. In 'packed' format.
            noise (torch.Tensor): The noise tensor used to noise the init_latents. In 'packed' format.
        """
        assert init_latents.shape == inpaint_mask.shape == noise.shape
        self._init_latents = init_latents
        self._inpaint_mask = inpaint_mask
        self._noise = noise

    def merge_intermediate_latents_with_init_latents(
        self, intermediate_latents: torch.Tensor, timestep: float
    ) -> torch.Tensor:
        """Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
        update the intermediate latents to keep the regions that are not being inpainted on the correct noise
        trajectory.

        This function should be called after each denoising step.
        """
        # Noise the init latents for the current timestep.
        noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents

        # Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
        return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)
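A quick numeric check of the merge rule above, using made-up tensors (shapes and values are purely illustrative):

```python
import torch

from invokeai.backend.flux.inpaint_extension import InpaintExtension

init_latents = torch.full((1, 8, 4), 2.0)  # 'packed' init latents
noise = torch.zeros_like(init_latents)
mask = torch.zeros_like(init_latents)
mask[:, :4, :] = 1.0                       # first half of the patches gets re-generated

ext = InpaintExtension(init_latents=init_latents, inpaint_mask=mask, noise=noise)

intermediate = torch.full_like(init_latents, -1.0)
merged = ext.merge_intermediate_latents_with_init_latents(intermediate, timestep=0.5)

# Where mask == 1 the intermediate value (-1.0) is kept. Where mask == 0 we get the
# re-noised init latents: noise * 0.5 + (1 - 0.5) * init = 0.0 + 0.5 * 2.0 = 1.0.
assert torch.allclose(merged[:, :4, :], torch.tensor(-1.0))
assert torch.allclose(merged[:, 4:, :], torch.tensor(1.0))
```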
@@ -258,16 +258,17 @@ class Decoder(nn.Module):
|
||||
|
||||
|
||||
class DiagonalGaussian(nn.Module):
|
||||
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
||||
def __init__(self, chunk_dim: int = 1):
|
||||
super().__init__()
|
||||
self.sample = sample
|
||||
self.chunk_dim = chunk_dim
|
||||
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
def forward(self, z: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
|
||||
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
||||
if self.sample:
|
||||
if sample:
|
||||
std = torch.exp(0.5 * logvar)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
# Unfortunately, torch.randn_like(...) does not accept a generator argument at the time of writing, so we
|
||||
# have to use torch.randn(...) instead.
|
||||
return mean + std * torch.randn(size=mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)
|
||||
else:
|
||||
return mean
|
||||
|
||||
@@ -297,8 +298,21 @@ class AutoEncoder(nn.Module):
|
||||
self.scale_factor = params.scale_factor
|
||||
self.shift_factor = params.shift_factor
|
||||
|
||||
def encode(self, x: Tensor) -> Tensor:
|
||||
z = self.reg(self.encoder(x))
|
||||
def encode(self, x: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
|
||||
"""Run VAE encoding on input tensor x.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input image tensor. Shape: (batch_size, in_channels, height, width).
|
||||
sample (bool, optional): If True, sample from the encoded distribution, else, return the distribution mean.
|
||||
Defaults to True.
|
||||
generator (torch.Generator | None, optional): Optional random number generator for reproducibility.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
|
||||
"""
|
||||
|
||||
z = self.reg(self.encoder(x), sample=sample, generator=generator)
|
||||
z = self.scale_factor * (z - self.shift_factor)
|
||||
return z
|
||||
|
||||
|
||||
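The new `encode` signature above threads an optional `sample` flag and `generator` through to `DiagonalGaussian`, which makes latent sampling reproducible. A hedged usage sketch, assuming `ae` is an already-loaded flux `AutoEncoder` and `image` is a suitably pre-processed `(1, 3, H, W)` tensor on the same device as the generator:

```python
import torch

generator = torch.Generator(device=image.device).manual_seed(42)
latents_a = ae.encode(image, sample=True, generator=generator)

generator = torch.Generator(device=image.device).manual_seed(42)
latents_b = ae.encode(image, sample=True, generator=generator)

assert torch.allclose(latents_a, latents_b)  # identical seed -> identical sampled latents

mean_latents = ae.encode(image, sample=False)  # skip sampling and return the distribution mean
```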
@@ -1,176 +0,0 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
|
||||
|
||||
def get_noise(
|
||||
num_samples: int,
|
||||
height: int,
|
||||
width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
):
|
||||
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
|
||||
rand_device = "cpu"
|
||||
rand_dtype = torch.float16
|
||||
return torch.randn(
|
||||
num_samples,
|
||||
16,
|
||||
# allow for packing
|
||||
2 * math.ceil(height / 16),
|
||||
2 * math.ceil(width / 16),
|
||||
device=rand_device,
|
||||
dtype=rand_dtype,
|
||||
generator=torch.Generator(device=rand_device).manual_seed(seed),
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
||||
bs, c, h, w = img.shape
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
txt = t5(prompt)
|
||||
if txt.shape[0] == 1 and bs > 1:
|
||||
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
||||
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
||||
|
||||
vec = clip(prompt)
|
||||
if vec.shape[0] == 1 and bs > 1:
|
||||
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"img_ids": img_ids.to(img.device),
|
||||
"txt": txt.to(img.device),
|
||||
"txt_ids": txt_ids.to(img.device),
|
||||
"vec": vec.to(img.device),
|
||||
}
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: Tensor):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
||||
m = (y2 - y1) / (x2 - x1)
|
||||
b = y1 - m * x1
|
||||
return lambda x: m * x + b
|
||||
|
||||
|
||||
def get_schedule(
|
||||
num_steps: int,
|
||||
image_seq_len: int,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
shift: bool = True,
|
||||
) -> list[float]:
|
||||
# extra step for zero
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
|
||||
# shifting the schedule to favor high timesteps for higher signal images
|
||||
if shift:
|
||||
# eastimate mu based on linear estimation between two points
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
def denoise(
|
||||
model: Flux,
|
||||
# model input
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
vec: Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[], None],
|
||||
guidance: float = 4.0,
|
||||
):
|
||||
dtype = model.txt_in.bias.dtype
|
||||
|
||||
# TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
|
||||
img = img.to(dtype=dtype)
|
||||
img_ids = img_ids.to(dtype=dtype)
|
||||
txt = txt.to(dtype=dtype)
|
||||
txt_ids = txt_ids.to(dtype=dtype)
|
||||
vec = vec.to(dtype=dtype)
|
||||
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
step_callback()
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
h=math.ceil(height / 16),
|
||||
w=math.ceil(width / 16),
|
||||
ph=2,
|
||||
pw=2,
|
||||
)
|
||||
|
||||
|
||||
def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert an input image in latent space to patches for diffusion.
|
||||
|
||||
This implementation was extracted from:
|
||||
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
|
||||
"""
|
||||
bs, c, h, w = latent_img.shape
|
||||
|
||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
# Generate patch position ids.
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
return img, img_ids
|
||||
invokeai/backend/flux/sampling_utils.py (new file, 135 lines)
@@ -0,0 +1,135 @@
# Initially pulled from https://github.com/black-forest-labs/flux

import math
from typing import Callable

import torch
from einops import rearrange, repeat


def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    # We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
    rand_device = "cpu"
    rand_dtype = torch.float16
    return torch.randn(
        num_samples,
        16,
        # allow for packing
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=rand_device,
        dtype=rand_dtype,
        generator=torch.Generator(device=rand_device).manual_seed(seed),
    ).to(device=device, dtype=dtype)


def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()


def _find_last_index_ge_val(timesteps: list[float], val: float, eps: float = 1e-6) -> int:
    """Find the last index in timesteps that is >= val.

    We use epsilon-close equality to avoid potential floating point errors.
    """
    idx = len(list(filter(lambda t: t >= (val - eps), timesteps))) - 1
    assert idx >= 0
    return idx


def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoising_end: float) -> list[float]:
    """Clip the timestep schedule to the denoising range.

    Args:
        timesteps (list[float]): The original timestep schedule: [1.0, ..., 0.0].
        denoising_start (float): A value in [0, 1] specifying the start of the denoising process. E.g. a value of 0.2
            would mean that the denoising process starts at the last timestep in the schedule >= 0.8.
        denoising_end (float): A value in [0, 1] specifying the end of the denoising process. E.g. a value of 0.8 would
            mean that the denoising process ends at the last timestep in the schedule >= 0.2.

    Returns:
        list[float]: The clipped timestep schedule.
    """
    assert 0.0 <= denoising_start <= 1.0
    assert 0.0 <= denoising_end <= 1.0
    assert denoising_start <= denoising_end

    t_start_val = 1.0 - denoising_start
    t_end_val = 1.0 - denoising_end

    t_start_idx = _find_last_index_ge_val(timesteps, t_start_val)
    t_end_idx = _find_last_index_ge_val(timesteps, t_end_val)

    clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]

    return clipped_timesteps


def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Unpack flat array of patch embeddings to latent image."""
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )


def pack(x: torch.Tensor) -> torch.Tensor:
    """Pack latent image to flattened array of patch embeddings."""
    # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)


def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Generate tensor of image position ids.

    Args:
        h (int): Height of image in latent space.
        w (int): Width of image in latent space.
        batch_size (int): Batch size.
        device (torch.device): Device.
        dtype (torch.dtype): dtype.

    Returns:
        torch.Tensor: Image position ids.
    """
    img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
    return img_ids
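A small worked example of the schedule helpers and the pack/unpack round trip (using `shift=False` so the numbers stay easy to follow):

```python
import torch

from invokeai.backend.flux.sampling_utils import clip_timestep_schedule, get_schedule, pack, unpack

# Without shifting, 4 steps give an evenly spaced schedule from 1.0 down to 0.0.
timesteps = get_schedule(num_steps=4, image_seq_len=4096, shift=False)
# timesteps == [1.0, 0.75, 0.5, 0.25, 0.0]

# Start denoising 25% of the way in (e.g. image-to-image): keep the last timestep >= 0.75 onwards.
clipped = clip_timestep_schedule(timesteps, denoising_start=0.25, denoising_end=1.0)
# clipped == [0.75, 0.5, 0.25, 0.0]

# pack/unpack are exact inverses for latents whose spatial dims are divisible by 2.
latent = torch.randn(1, 16, 128, 128)   # latent for a 1024x1024 image
patches = pack(latent)                  # (1, 4096, 64)
assert torch.equal(unpack(patches, 1024, 1024), latent)
```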
@@ -1,672 +0,0 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development team
|
||||
"""LoRA model support."""
|
||||
|
||||
import bisect
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from typing_extensions import Self
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
# alpha: Optional[float]
|
||||
# bias: Optional[torch.Tensor]
|
||||
# layer_key: str
|
||||
|
||||
# @property
|
||||
# def scale(self):
|
||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
if "alpha" in values:
|
||||
self.alpha = values["alpha"].item()
|
||||
else:
|
||||
self.alpha = None
|
||||
|
||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
|
||||
values["bias_indices"],
|
||||
values["bias_values"],
|
||||
tuple(values["bias_size"]),
|
||||
)
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
return self.bias
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
||||
params = {"weight": self.get_weight(orig_module.weight)}
|
||||
bias = self.get_bias(orig_module.bias)
|
||||
if bias is not None:
|
||||
params["bias"] = bias
|
||||
return params
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for val in [self.bias]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
|
||||
"""Log a warning if values contains unhandled keys."""
|
||||
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
|
||||
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
|
||||
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
|
||||
unknown_keys = set(values.keys()) - all_known_keys
|
||||
if unknown_keys:
|
||||
logger.warning(
|
||||
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
|
||||
)
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
class LoRALayer(LoRALayerBase):
|
||||
# up: torch.Tensor
|
||||
# mid: Optional[torch.Tensor]
|
||||
# down: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
self.mid = values.get("lora_mid.weight", None)
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lora_up.weight",
|
||||
"lora_down.weight",
|
||||
"lora_mid.weight",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.up, self.mid, self.down]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
# w1_a: torch.Tensor
|
||||
# w1_b: torch.Tensor
|
||||
# w2_a: torch.Tensor
|
||||
# w2_b: torch.Tensor
|
||||
# t1: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
self.w2_a = values["hada_w2_a"]
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
self.t1 = values.get("hada_t1", None)
|
||||
self.t2 = values.get("hada_t2", None)
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"hada_w1_a",
|
||||
"hada_w1_b",
|
||||
"hada_w2_a",
|
||||
"hada_w2_b",
|
||||
"hada_t1",
|
||||
"hada_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.t1 is None:
|
||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
||||
weight = rebuild1 * rebuild2
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
# w1: Optional[torch.Tensor] = None
|
||||
# w1_a: Optional[torch.Tensor] = None
|
||||
# w1_b: Optional[torch.Tensor] = None
|
||||
# w2: Optional[torch.Tensor] = None
|
||||
# w2_a: Optional[torch.Tensor] = None
|
||||
# w2_b: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1 = values.get("lokr_w1", None)
|
||||
if self.w1 is None:
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
else:
|
||||
self.w1_b = None
|
||||
self.w1_a = None
|
||||
|
||||
self.w2 = values.get("lokr_w2", None)
|
||||
if self.w2 is None:
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
else:
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
|
||||
self.t2 = values.get("lokr_t2", None)
|
||||
|
||||
if self.w1_b is not None:
|
||||
self.rank = self.w1_b.shape[0]
|
||||
elif self.w2_b is not None:
|
||||
self.rank = self.w2_b.shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lokr_w1",
|
||||
"lokr_w1_a",
|
||||
"lokr_w1_b",
|
||||
"lokr_w2",
|
||||
"lokr_w2_a",
|
||||
"lokr_w2_b",
|
||||
"lokr_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
w1: Optional[torch.Tensor] = self.w1
|
||||
if w1 is None:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
assert w1 is not None
|
||||
assert w2 is not None
|
||||
weight = torch.kron(w1, w2)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["diff"]
|
||||
self.bias = values.get("diff_b", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"diff", "diff_b"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
# on_input: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["weight"]
|
||||
self.on_input = values["on_input"]
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"weight", "on_input"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
assert orig_weight is not None
|
||||
return orig_weight * weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class NormLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["w_norm"]
|
||||
self.bias = values.get("b_norm", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"w_norm", "b_norm"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
|
||||
|
||||
|
||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
_name: str
|
||||
layers: Dict[str, AnyLoRALayer]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
layers: Dict[str, AnyLoRALayer],
|
||||
):
|
||||
self._name = name
|
||||
self.layers = layers
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
# TODO: try revert if exception?
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for _, layer in self.layers.items():
|
||||
model_size += layer.calc_size()
|
||||
return model_size
|
||||
|
||||
@classmethod
|
||||
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
diffusers format, then this function will have no effect.
|
||||
|
||||
This function is adapted from:
|
||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
|
||||
|
||||
Args:
|
||||
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
|
||||
|
||||
Returns:
|
||||
Dict[str, Tensor]: The diffusers-format state_dict.
|
||||
"""
|
||||
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
|
||||
not_converted_count = 0 # The number of keys that were not converted.
|
||||
|
||||
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
|
||||
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
|
||||
# `input_blocks_4_1_proj_in`.
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
|
||||
position = bisect.bisect_right(stability_unet_keys, search_key)
|
||||
map_key = stability_unet_keys[position - 1]
|
||||
# Now, check if the map_key *actually* matches the search_key.
|
||||
if search_key.startswith(map_key):
|
||||
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
|
||||
new_state_dict[new_key] = value
|
||||
converted_count += 1
|
||||
else:
|
||||
new_state_dict[full_key] = value
|
||||
not_converted_count += 1
|
||||
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
|
||||
new_state_dict[full_key] = value
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
|
||||
|
||||
if converted_count > 0 and not_converted_count > 0:
|
||||
raise ValueError(
|
||||
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
|
||||
f" not_converted={not_converted_count}"
|
||||
)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
file_path: Union[str, Path],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
) -> Self:
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
|
||||
model = cls(
|
||||
name=file_path.stem,
|
||||
layers={},
|
||||
)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
sd = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
sd = torch.load(file_path, map_location="cpu")
|
||||
|
||||
state_dict = cls._group_state(sd)
|
||||
|
||||
if base_model == BaseModelType.StableDiffusionXL:
|
||||
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
|
||||
for layer_key, values in state_dict.items():
|
||||
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
|
||||
|
||||
# lora and locon
|
||||
if "lora_up.weight" in values:
|
||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||
|
||||
# loha
|
||||
elif "hada_w1_a" in values:
|
||||
layer = LoHALayer(layer_key, values)
|
||||
|
||||
# lokr
|
||||
elif "lokr_w1" in values or "lokr_w1_a" in values:
|
||||
layer = LoKRLayer(layer_key, values)
|
||||
|
||||
# diff
|
||||
elif "diff" in values:
|
||||
layer = FullLayer(layer_key, values)
|
||||
|
||||
# ia3
|
||||
elif "on_input" in values:
|
||||
layer = IA3Layer(layer_key, values)
|
||||
|
||||
# norms
|
||||
elif "w_norm" in values:
|
||||
layer = NormLayer(layer_key, values)
|
||||
|
||||
else:
|
||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||
raise Exception("Unknown lora format!")
|
||||
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype)
|
||||
model.layers[layer_key] = layer
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
if stem not in state_dict_groupped:
|
||||
state_dict_groupped[stem] = {}
|
||||
state_dict_groupped[stem][leaf] = value
|
||||
|
||||
return state_dict_groupped
|
||||
|
||||
|
||||
# code from
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
||||
unet_conversion_map_layer = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
# if i > 0: commentout for sdxl
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0.", "norm1."),
|
||||
("in_layers.2.", "conv1."),
|
||||
("out_layers.0.", "norm2."),
|
||||
("out_layers.3.", "conv2."),
|
||||
("emb_layers.1.", "time_emb_proj."),
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||
else:
|
||||
unet_conversion_map.append((sd, hf))
|
||||
|
||||
for j in range(2):
|
||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||
|
||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||
|
||||
return unet_conversion_map
|
||||
|
||||
|
||||
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
||||
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
|
||||
}
|
||||
@@ -66,12 +66,14 @@ class ModelLoader(ModelLoaderBase):
|
||||
return (model_base / config.path).resolve()
|
||||
|
||||
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
|
||||
stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
|
||||
try:
|
||||
return self._ram_cache.get(config.key, submodel_type)
|
||||
return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
config.path = str(self._get_model_path(config))
|
||||
self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
|
||||
loaded_model = self._load_model(config, submodel_type)
|
||||
|
||||
self._ram_cache.put(
|
||||
@@ -83,7 +85,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
return self._ram_cache.get(
|
||||
key=config.key,
|
||||
submodel_type=submodel_type,
|
||||
stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
|
||||
stats_name=stats_name,
|
||||
)
|
||||
|
||||
def get_size_fs(
|
||||
|
||||
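For reference, the `stats_name` computed above is just a colon-joined identifier built from the model config; a hypothetical example of what it might look like for an SD-1 UNet submodel (the exact field values depend on the installed model):

```python
stats_name = ":".join(["sd-1", "main", "stable-diffusion-v1-5", "unet"])
# -> "sd-1:main:stable-diffusion-v1-5:unet"
```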
@@ -128,7 +128,24 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
"""Return the maximum size the RAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@max_cache_size.setter
|
||||
@abstractmethod
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the maximum size the VRAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
@abstractmethod
|
||||
def max_vram_cache_size(self, value: float) -> float:
|
||||
"""Set the maximum size the VRAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -193,15 +210,6 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> bool:
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
|
||||
@@ -1,22 +1,6 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
|
||||
The cache returns context manager generators designed to load the
|
||||
model into the GPU within the context, and unload outside the
|
||||
context. Use like this:
|
||||
|
||||
cache = ModelCache(max_cache_size=7.5)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
||||
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
||||
do_something_in_GPU(SD1,SD2)
|
||||
|
||||
|
||||
"""
|
||||
""" """
|
||||
|
||||
import gc
|
||||
import math
|
||||
@@ -40,53 +24,74 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0

# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75

# actual size of a gig
GIG = 1073741824
# Size of a GB in bytes.
GB = 2**30

# Size of a MB in bytes.
MB = 2**20


class ModelCache(ModelCacheBase[AnyModel]):
    """Implementation of ModelCacheBase."""
    """A cache for managing models in memory.

    The cache is based on two levels of model storage:
    - execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
    - storage_device: The device where models are offloaded when not in active use (typically "cpu").

    The model cache is based on the following assumptions:
    - storage_device_mem_size > execution_device_mem_size
    - disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time

    A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
    the execution_device.

    Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
    on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
    policy. The storage_device cache uses a least-recently-used (LRU) offload policy.

    Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
    policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
    configuration.

    The cache returns context manager generators designed to load the model into the execution device (often GPU) within
    the context, and unload outside the context.

    Example usage:
    ```
    cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
    with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
        do_something_on_gpu(SD1)
    ```
    """

def __init__(
|
||||
self,
|
||||
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
||||
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||
max_cache_size: float,
|
||||
max_vram_cache_size: float,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
sequential_offload: bool = False,
|
||||
lazy_offloading: bool = True,
|
||||
sha_chunksize: int = 16777216,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
:param max_cache_size: Maximum size of the storage_device cache in GBs.
|
||||
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
:param logger: InvokeAILogger to use (otherwise creates one)
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._precision: torch.dtype = precision
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
@@ -128,6 +133,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""Set the cap on cache size."""
|
||||
self._max_cache_size = value
|
||||
|
||||
@property
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the cap on vram cache size."""
|
||||
return self._max_vram_cache_size
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
def max_vram_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
self._max_vram_cache_size = value
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
@@ -145,15 +160,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
total += cache_record.size
|
||||
return total
|
||||
|
||||
def exists(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> bool:
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
return key in self._cached_models
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
@@ -203,7 +209,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||
self.stats.cache_size = int(self._max_cache_size * GB)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
@@ -231,10 +237,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
return model_key
|
||||
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Move any unused models from VRAM."""
|
||||
reserved = self._max_vram_cache_size * GIG
|
||||
"""Offload models from the execution_device to make room for size_required.
|
||||
|
||||
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
|
||||
"""
|
||||
reserved = self._max_vram_cache_size * GB
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
@@ -245,7 +254,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
|
||||
)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
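The offload loop above visits unlocked cache entries in ascending size order until VRAM usage falls back under the reserve, which is the smallest-first policy described in the ModelCache docstring. A minimal standalone sketch of that selection logic, using a simplified Entry record and made-up sizes rather than the real cache records (the actual code re-reads torch.cuda.memory_allocated() on each iteration):

```python
# Illustrative sketch only, not the InvokeAI implementation.
from dataclasses import dataclass

GB = 2**30  # size of a GB in bytes, matching the constant introduced in this diff


@dataclass
class Entry:
    key: str
    size: int  # bytes currently resident on the execution device
    locked: bool = False


def plan_offload(entries: list[Entry], vram_in_use: int, reserved: int) -> list[str]:
    """Return the keys that would be offloaded, smallest models first."""
    to_offload: list[str] = []
    for entry in sorted(entries, key=lambda e: e.size):
        if vram_in_use <= reserved:
            break
        if entry.locked:
            continue
        to_offload.append(entry.key)
        vram_in_use -= entry.size
    return to_offload


if __name__ == "__main__":
    entries = [Entry("unet", 5 * GB), Entry("vae", 1 * GB), Entry("clip", 2 * GB)]
    # 8 GB in use, 2.75 GB allowed: the 1 GB and 2 GB models go first, then the 5 GB model.
    print(plan_offload(entries, vram_in_use=8 * GB, reserved=int(2.75 * GB)))
```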
@@ -303,7 +312,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self.logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
|
||||
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
@@ -326,14 +335,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f" {(cache_entry.size/GB):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log CUDA diagnostics."""
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||
ram = "%4.2fG" % (self.cache_size() / GIG)
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
|
||||
ram = "%4.2fG" % (self.cache_size() / GB)
|
||||
|
||||
in_ram_models = 0
|
||||
in_vram_models = 0
|
||||
@@ -353,17 +362,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
)
|
||||
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
# calculate how much memory this model will require
|
||||
# multiplier = 2 if self.precision==torch.float32 else 1
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size.
|
||||
|
||||
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
|
||||
external references to the model, there's nothing that the cache can do about it, and those models will not be
|
||||
garbage-collected.
|
||||
"""
|
||||
bytes_needed = size
|
||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
|
||||
current_size = self.cache_size()
|
||||
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self.logger.debug(
|
||||
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GIG):.2f} GB"
|
||||
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GB):.2f} GB"
|
||||
)
|
||||
|
||||
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
|
||||
@@ -380,7 +392,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
|
||||
if not cache_entry.locked:
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
|
||||
@@ -5,8 +5,10 @@ from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@@ -18,6 +20,11 @@ from invokeai.backend.model_manager import (
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.peft.conversions.flux_kohya_lora_conversion_utils import (
|
||||
lora_model_from_flux_kohya_state_dict,
|
||||
)
|
||||
from invokeai.backend.peft.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
|
||||
from invokeai.backend.peft.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
|
||||
@@ -45,14 +52,28 @@ class LoRALoader(ModelLoader):
|
||||
raise ValueError("There are no submodels in a LoRA model.")
|
||||
model_path = Path(config.path)
|
||||
assert self._model_base is not None
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
dtype=self._torch_dtype,
|
||||
base_model=self._model_base,
|
||||
)
|
||||
|
||||
# Load the state dict from the model file.
|
||||
if model_path.suffix == ".safetensors":
|
||||
state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
|
||||
# Apply state_dict key conversions, if necessary.
|
||||
if self._model_base == BaseModelType.StableDiffusionXL:
|
||||
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
model = lora_model_from_sd_state_dict(state_dict=state_dict)
|
||||
elif self._model_base == BaseModelType.Flux:
|
||||
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
|
||||
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
# Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
|
||||
model = lora_model_from_sd_state_dict(state_dict=state_dict)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
|
||||
|
||||
model.to(dtype=self._torch_dtype)
|
||||
return model
|
||||
|
||||
# override
|
||||
def _get_model_path(self, config: AnyModelConfig) -> Path:
|
||||
# cheating a little - we remember this variable for using in the subsequent call to _load_model()
|
||||
self._model_base = config.base
|
||||
|
||||
@@ -15,9 +15,9 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.peft.lora import LoRAModelRaw
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from invokeai.backend.model_manager.config import (
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
from invokeai.backend.peft.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
@@ -528,9 +529,11 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
return ModelFormat("lycoris")
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
token_vector_length = lora_token_vector_length(checkpoint)
|
||||
if is_state_dict_likely_in_flux_kohya_format(self.checkpoint):
|
||||
return BaseModelType.Flux
|
||||
|
||||
# If we've gotten here, we assume that the model is a Stable Diffusion model.
|
||||
token_vector_length = lora_token_vector_length(self.checkpoint)
|
||||
if token_vector_length == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_vector_length == 1024:
|
||||
|
||||
@@ -13,10 +13,10 @@ from diffusers import OnnxRuntimeModel, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.peft.lora import LoRAModelRaw
|
||||
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
||||
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
invokeai/backend/peft/__init__.py (Normal file, 0 lines changed)
invokeai/backend/peft/conversions/__init__.py (Normal file, 0 lines changed)
invokeai/backend/peft/conversions/flux_kohya_lora_conversion_utils.py (Normal file, 84 lines changed)
@@ -0,0 +1,84 @@
import re
from typing import Any, Dict, TypeVar

import torch

from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.peft.layers.utils import peft_layer_from_state_dict
from invokeai.backend.peft.lora import LoRAModelRaw

# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_unet_double_blocks_0_img_attn_proj.alpha
# lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
# lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
FLUX_KOHYA_KEY_REGEX = (
    r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)


def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
    """Checks if the provided state dict is likely in the Kohya FLUX LoRA format.

    This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
    perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
    """
    for k in state_dict.keys():
        if not re.match(FLUX_KOHYA_KEY_REGEX, k):
            return False
    return True


def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
    # Group keys by layer.
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
    for key, value in state_dict.items():
        layer_name, param_name = key.split(".", 1)
        if layer_name not in grouped_state_dict:
            grouped_state_dict[layer_name] = {}
        grouped_state_dict[layer_name][param_name] = value

    # Convert the state dict to the InvokeAI format.
    grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)

    # Create LoRA layers.
    layers: dict[str, AnyLoRALayer] = {}
    for layer_key, layer_state_dict in grouped_state_dict.items():
        layer = peft_layer_from_state_dict(layer_key, layer_state_dict)
        layers[layer_key] = layer

    # Create and return the LoRAModelRaw.
    return LoRAModelRaw(layers=layers)


T = TypeVar("T")


def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
    """Converts a state dict from the Kohya FLUX LoRA format to the LoRA weight format used internally by InvokeAI.

    Example key conversions:
    "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
    "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
    """

    def replace_func(match: re.Match[str]) -> str:
        s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
        if match.group(4):
            s += f".{match.group(4)}"
        return s

    converted_dict: dict[str, T] = {}
    for k, v in state_dict.items():
        match = re.match(FLUX_KOHYA_KEY_REGEX, k)
        if match:
            new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
            converted_dict[new_key] = v
        else:
            raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")

    return converted_dict
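To make the regex-driven key mapping above concrete, here is a small self-contained sketch of the same transformation. The regex is the one defined in this new module; the sample layer names are hypothetical.

```python
import re

FLUX_KOHYA_KEY_REGEX = (
    r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)


def _to_invoke_key(kohya_layer_name: str) -> str:
    # Rebuild the dotted InvokeAI-style key from the regex groups, as replace_func() does above.
    match = re.match(FLUX_KOHYA_KEY_REGEX, kohya_layer_name)
    assert match is not None
    parts = [match.group(1), match.group(2), match.group(3)]
    if match.group(4):
        parts.append(match.group(4))
    return ".".join(parts)


print(_to_invoke_key("lora_unet_double_blocks_0_img_attn_proj"))  # double_blocks.0.img_attn.proj
print(_to_invoke_key("lora_unet_single_blocks_7_linear1"))        # single_blocks.7.linear1
```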
invokeai/backend/peft/conversions/sd_lora_conversion_utils.py (Normal file, 30 lines changed)
@@ -0,0 +1,30 @@
from typing import Dict

import torch

from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.peft.layers.utils import peft_layer_from_state_dict
from invokeai.backend.peft.lora import LoRAModelRaw


def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)

    layers: dict[str, AnyLoRALayer] = {}
    for layer_key, values in grouped_state_dict.items():
        layer = peft_layer_from_state_dict(layer_key, values)
        layers[layer_key] = layer

    return LoRAModelRaw(layers=layers)


def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
    state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}

    for key, value in state_dict.items():
        stem, leaf = key.split(".", 1)
        if stem not in state_dict_groupped:
            state_dict_groupped[stem] = {}
        state_dict_groupped[stem][leaf] = value

    return state_dict_groupped
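A short sketch of the grouping step performed by _group_state(): flat LoRA keys are split on the first "." into a per-layer dict of parameter tensors. The key names below are typical of SD LoRA checkpoints but are made up for this example.

```python
import torch

flat = {
    "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight": torch.zeros(4, 768),
    "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight": torch.zeros(3072, 4),
    "lora_te_text_model_encoder_layers_0_mlp_fc1.alpha": torch.tensor(4.0),
}

grouped: dict[str, dict[str, torch.Tensor]] = {}
for key, value in flat.items():
    stem, leaf = key.split(".", 1)
    grouped.setdefault(stem, {})[leaf] = value

# One layer with three params: alpha, lora_down.weight, lora_up.weight.
print({stem: sorted(leaves) for stem, leaves in grouped.items()})
```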
invokeai/backend/peft/conversions/sdxl_lora_conversion_utils.py (Normal file, 154 lines changed)
@@ -0,0 +1,154 @@
|
||||
import bisect
|
||||
from typing import Dict, List, Tuple, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
diffusers format, then this function will have no effect.
|
||||
|
||||
This function is adapted from:
|
||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
|
||||
|
||||
Args:
|
||||
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
|
||||
|
||||
Returns:
|
||||
Dict[str, Tensor]: The diffusers-format state_dict.
|
||||
"""
|
||||
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
|
||||
not_converted_count = 0 # The number of keys that were not converted.
|
||||
|
||||
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
|
||||
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
|
||||
# `input_blocks_4_1_proj_in`.
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict: dict[str, T] = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
|
||||
position = bisect.bisect_right(stability_unet_keys, search_key)
|
||||
map_key = stability_unet_keys[position - 1]
|
||||
# Now, check if the map_key *actually* matches the search_key.
|
||||
if search_key.startswith(map_key):
|
||||
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
|
||||
new_state_dict[new_key] = value
|
||||
converted_count += 1
|
||||
else:
|
||||
new_state_dict[full_key] = value
|
||||
not_converted_count += 1
|
||||
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
|
||||
new_state_dict[full_key] = value
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
|
||||
|
||||
if converted_count > 0 and not_converted_count > 0:
|
||||
raise ValueError(
|
||||
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
|
||||
f" not_converted={not_converted_count}"
|
||||
)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# code from
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
||||
unet_conversion_map_layer: list[tuple[str, str]] = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
# if i > 0: commentout for sdxl
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0.", "norm1."),
|
||||
("in_layers.2.", "conv1."),
|
||||
("out_layers.0.", "norm2."),
|
||||
("out_layers.3.", "conv2."),
|
||||
("emb_layers.1.", "time_emb_proj."),
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map: list[tuple[str, str]] = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||
else:
|
||||
unet_conversion_map.append((sd, hf))
|
||||
|
||||
for j in range(2):
|
||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||
|
||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||
|
||||
return unet_conversion_map
|
||||
|
||||
|
||||
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
||||
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in _make_sdxl_unet_conversion_map()
|
||||
}
|
||||
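The comments in convert_sdxl_keys_to_diffusers_format() above describe a bisect-based prefix lookup over the sorted Stability AI keys. The sketch below isolates that trick with a tiny hand-picked subset of the prefix map; it is illustrative only, not the full conversion.

```python
import bisect

# A few entries in the same shape as SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP (subset for illustration).
prefix_map = {
    "input_blocks_4_0": "down_blocks_1_resnets_0",
    "input_blocks_4_1": "down_blocks_1_attentions_0",
    "middle_block_1": "mid_block_attentions_0",
}
sorted_prefixes = sorted(prefix_map)

search_key = "input_blocks_4_1_proj_in"
# bisect_right finds the only sorted entry that could be a prefix of search_key.
position = bisect.bisect_right(sorted_prefixes, search_key)
candidate = sorted_prefixes[position - 1]

if search_key.startswith(candidate):
    converted = search_key.replace(candidate, prefix_map[candidate])
    print(converted)  # -> "down_blocks_1_attentions_0_proj_in"
```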
invokeai/backend/peft/layers/__init__.py (Normal file, 0 lines changed)
invokeai/backend/peft/layers/any_lora_layer.py (Normal file, 10 lines changed)
@@ -0,0 +1,10 @@
|
||||
from typing import Union
|
||||
|
||||
from invokeai.backend.peft.layers.full_layer import FullLayer
|
||||
from invokeai.backend.peft.layers.ia3_layer import IA3Layer
|
||||
from invokeai.backend.peft.layers.loha_layer import LoHALayer
|
||||
from invokeai.backend.peft.layers.lokr_layer import LoKRLayer
|
||||
from invokeai.backend.peft.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.peft.layers.norm_layer import NormLayer
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
|
||||
invokeai/backend/peft/layers/full_layer.py (Normal file, 37 lines changed)
@@ -0,0 +1,37 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["diff"]
|
||||
self.bias = values.get("diff_b", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"diff", "diff_b"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
invokeai/backend/peft/layers/ia3_layer.py (Normal file, 42 lines changed)
@@ -0,0 +1,42 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
# on_input: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["weight"]
|
||||
self.on_input = values["on_input"]
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"weight", "on_input"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
assert orig_weight is not None
|
||||
return orig_weight * weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
invokeai/backend/peft/layers/loha_layer.py (Normal file, 68 lines changed)
@@ -0,0 +1,68 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
# w1_a: torch.Tensor
|
||||
# w1_b: torch.Tensor
|
||||
# w2_a: torch.Tensor
|
||||
# w2_b: torch.Tensor
|
||||
# t1: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
self.w2_a = values["hada_w2_a"]
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
self.t1 = values.get("hada_t1", None)
|
||||
self.t2 = values.get("hada_t2", None)
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"hada_w1_a",
|
||||
"hada_w1_b",
|
||||
"hada_w2_a",
|
||||
"hada_w2_b",
|
||||
"hada_t1",
|
||||
"hada_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.t1 is None:
|
||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
||||
weight = rebuild1 * rebuild2
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
invokeai/backend/peft/layers/lokr_layer.py (Normal file, 114 lines changed)
@@ -0,0 +1,114 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
# w1: Optional[torch.Tensor] = None
|
||||
# w1_a: Optional[torch.Tensor] = None
|
||||
# w1_b: Optional[torch.Tensor] = None
|
||||
# w2: Optional[torch.Tensor] = None
|
||||
# w2_a: Optional[torch.Tensor] = None
|
||||
# w2_b: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1 = values.get("lokr_w1", None)
|
||||
if self.w1 is None:
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
else:
|
||||
self.w1_b = None
|
||||
self.w1_a = None
|
||||
|
||||
self.w2 = values.get("lokr_w2", None)
|
||||
if self.w2 is None:
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
else:
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
|
||||
self.t2 = values.get("lokr_t2", None)
|
||||
|
||||
if self.w1_b is not None:
|
||||
self.rank = self.w1_b.shape[0]
|
||||
elif self.w2_b is not None:
|
||||
self.rank = self.w2_b.shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lokr_w1",
|
||||
"lokr_w1_a",
|
||||
"lokr_w1_b",
|
||||
"lokr_w2",
|
||||
"lokr_w2_a",
|
||||
"lokr_w2_b",
|
||||
"lokr_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
w1: Optional[torch.Tensor] = self.w1
|
||||
if w1 is None:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
assert w1 is not None
|
||||
assert w2 is not None
|
||||
weight = torch.kron(w1, w2)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
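For intuition on the LoKR reconstruction above: the delta weight is the Kronecker product of two small factors, so its shape is the elementwise product of the factor shapes. A shape-only sketch with arbitrary dimensions:

```python
import torch

w1 = torch.randn(4, 8)    # e.g. lokr_w1, or lokr_w1_a @ lokr_w1_b
w2 = torch.randn(80, 96)  # e.g. lokr_w2, or lokr_w2_a @ lokr_w2_b

delta = torch.kron(w1, w2)
print(delta.shape)  # torch.Size([320, 768]): a 320x768 delta built from 4*8 + 80*96 = 7712 factor parameters
```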
invokeai/backend/peft/layers/lora_layer.py (Normal file, 59 lines changed)
@@ -0,0 +1,59 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
class LoRALayer(LoRALayerBase):
|
||||
# up: torch.Tensor
|
||||
# mid: Optional[torch.Tensor]
|
||||
# down: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
self.mid = values.get("lora_mid.weight", None)
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lora_up.weight",
|
||||
"lora_down.weight",
|
||||
"lora_mid.weight",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.up, self.mid, self.down]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
invokeai/backend/peft/layers/lora_layer_base.py (Normal file, 74 lines changed)
@@ -0,0 +1,74 @@
|
||||
from typing import Dict, Optional, Set
|
||||
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
# alpha: Optional[float]
|
||||
# bias: Optional[torch.Tensor]
|
||||
# layer_key: str
|
||||
|
||||
# @property
|
||||
# def scale(self):
|
||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
if "alpha" in values:
|
||||
self.alpha = values["alpha"].item()
|
||||
else:
|
||||
self.alpha = None
|
||||
|
||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
|
||||
values["bias_indices"],
|
||||
values["bias_values"],
|
||||
tuple(values["bias_size"]),
|
||||
)
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
return self.bias
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
||||
params = {"weight": self.get_weight(orig_module.weight)}
|
||||
bias = self.get_bias(orig_module.bias)
|
||||
if bias is not None:
|
||||
params["bias"] = bias
|
||||
return params
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for val in [self.bias]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
|
||||
"""Log a warning if values contains unhandled keys."""
|
||||
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
|
||||
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
|
||||
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
|
||||
unknown_keys = set(values.keys()) - all_known_keys
|
||||
if unknown_keys:
|
||||
logger.warning(
|
||||
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
|
||||
)
|
||||
invokeai/backend/peft/layers/norm_layer.py (Normal file, 37 lines changed)
@@ -0,0 +1,37 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class NormLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["w_norm"]
|
||||
self.bias = values.get("b_norm", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"w_norm", "b_norm"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
invokeai/backend/peft/layers/utils.py (Normal file, 33 lines changed)
@@ -0,0 +1,33 @@
from typing import Dict

import torch

from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.peft.layers.full_layer import FullLayer
from invokeai.backend.peft.layers.ia3_layer import IA3Layer
from invokeai.backend.peft.layers.loha_layer import LoHALayer
from invokeai.backend.peft.layers.lokr_layer import LoKRLayer
from invokeai.backend.peft.layers.lora_layer import LoRALayer
from invokeai.backend.peft.layers.norm_layer import NormLayer


def peft_layer_from_state_dict(layer_key: str, state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
    # Detect layers according to LyCORIS detection logic(`weight_list_det`)
    # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules

    if "lora_up.weight" in state_dict:
        # LoRA a.k.a LoCon
        return LoRALayer(layer_key, state_dict)
    elif "hada_w1_a" in state_dict:
        return LoHALayer(layer_key, state_dict)
    elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
        return LoKRLayer(layer_key, state_dict)
    elif "diff" in state_dict:
        # Full a.k.a Diff
        return FullLayer(layer_key, state_dict)
    elif "on_input" in state_dict:
        return IA3Layer(layer_key, state_dict)
    elif "w_norm" in state_dict:
        return NormLayer(layer_key, state_dict)
    else:
        raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
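A hypothetical usage of the detector above; it assumes this branch is installed, the tensors are dummies, and only the key names drive the detection.

```python
import torch

from invokeai.backend.peft.layers.utils import peft_layer_from_state_dict

locon_layer = peft_layer_from_state_dict(
    "double_blocks.0.img_attn.qkv",  # layer key after Kohya-to-Invoke conversion (illustrative)
    {
        "lora_down.weight": torch.zeros(16, 3072),
        "lora_up.weight": torch.zeros(9216, 16),
        "alpha": torch.tensor(16.0),
    },
)
print(type(locon_layer).__name__)  # LoRALayer

norm_layer = peft_layer_from_state_dict("norm", {"w_norm": torch.ones(320), "b_norm": torch.zeros(320)})
print(type(norm_layer).__name__)  # NormLayer
```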
invokeai/backend/peft/lora.py (Normal file, 22 lines changed)
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development team
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
def __init__(self, layers: Dict[str, AnyLoRALayer]):
|
||||
self.layers = layers
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for _, layer in self.layers.items():
|
||||
model_size += layer.calc_size()
|
||||
return model_size
|
||||
invokeai/backend/peft/peft_patcher.py (Normal file, 102 lines changed)
@@ -0,0 +1,102 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, Iterator, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.peft.lora import LoRAModelRaw
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
|
||||
class PeftPatcher:
|
||||
@classmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
def apply_peft_patches(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
patches: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Apply one or more PEFT patches to a model.
|
||||
|
||||
:param model: The model to patch.
|
||||
:param loras: An iterator that returns tuples of PEFT patches and associated weights. An iterator is used so
|
||||
that the PEFT patches do not need to be loaded into memory all at once.
|
||||
:param prefix: The keys in the patches will be filtered to only include weights with this prefix.
|
||||
:cached_weights: Read-only copy of the model's state dict in CPU, for efficient unpatching purposes.
|
||||
"""
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
try:
|
||||
for patch, patch_weight in patches:
|
||||
cls._apply_peft_patch(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
patch=patch,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
|
||||
yield
|
||||
finally:
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
|
||||
@classmethod
|
||||
@torch.no_grad()
|
||||
def _apply_peft_patch(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
patch: LoRAModelRaw,
|
||||
patch_weight: float,
|
||||
original_weights: OriginalWeightsStorage,
|
||||
):
|
||||
"""
|
||||
Apply one a LoRA to a model.
|
||||
:param model: The model to patch.
|
||||
:param patch: LoRA model to patch in.
|
||||
:param patch_weight: LoRA patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
|
||||
"""
|
||||
|
||||
if patch_weight == 0:
|
||||
return
|
||||
|
||||
for layer_key, layer in patch.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
module = model.get_submodule(layer_key)
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, lora_param_weight in layer.get_parameters(module).items():
|
||||
param_key = layer_key + "." + param_name
|
||||
module_param = module.get_parameter(param_name)
|
||||
|
||||
# Save original weight
|
||||
original_weights.save(param_key, module_param)
|
||||
|
||||
if module_param.shape != lora_param_weight.shape:
|
||||
lora_param_weight = lora_param_weight.reshape(module_param.shape)
|
||||
|
||||
lora_param_weight *= patch_weight * layer_scale
|
||||
module_param += lora_param_weight.to(dtype=dtype)
|
||||
|
||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||
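A hypothetical usage sketch for the context manager above; it assumes this branch is installed, and transformer, lora, and run_denoise are placeholders for the caller's own objects.

```python
from invokeai.backend.peft.peft_patcher import PeftPatcher


def denoise_with_lora(transformer, lora, weight: float = 0.8):
    # Weights are patched in-place on entry and restored from the stored originals on exit.
    with PeftPatcher.apply_peft_patches(
        model=transformer,
        patches=iter([(lora, weight)]),
        prefix="",  # no key filtering in this sketch
    ):
        return run_denoise(transformer)  # placeholder for the actual sampling loop
```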
@@ -54,8 +54,10 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
|
||||
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
|
||||
scb = state_dict.pop(prefix + "SCB", None)
|
||||
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
|
||||
_weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||
|
||||
# Currently, we only support weight_format=0.
|
||||
weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||
assert weight_format == 0
|
||||
|
||||
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||
# rather than raising an exception to correctly implement this API.
|
||||
@@ -89,6 +91,14 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
|
||||
# Reset the state. The persisted fields are based on the initialization behaviour in
|
||||
# `bnb.nn.Linear8bitLt.__init__()`.
|
||||
new_state = bnb.MatmulLtState()
|
||||
new_state.threshold = self.state.threshold
|
||||
new_state.has_fp16_weights = False
|
||||
new_state.use_pool = self.state.use_pool
|
||||
self.state = new_state
|
||||
|
||||
|
||||
def _convert_linear_layers_to_llm_8bit(
|
||||
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
|
||||
|
||||
@@ -43,6 +43,11 @@ class FLUXConditioningInfo:
|
||||
clip_embeds: torch.Tensor
|
||||
t5_embeds: torch.Tensor
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
|
||||
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
|
||||
@@ -12,7 +12,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.peft.lora import LoRAModelRaw
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
|
||||
|
||||
@@ -3,10 +3,9 @@ Initialization file for invokeai.backend.util
|
||||
"""
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.util import GIG, Chdir, directory_size
|
||||
from invokeai.backend.util.util import Chdir, directory_size
|
||||
|
||||
__all__ = [
|
||||
"GIG",
|
||||
"directory_size",
|
||||
"Chdir",
|
||||
"InvokeAILogger",
|
||||
|
||||
@@ -7,9 +7,6 @@ from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
def slugify(value: str, allow_unicode: bool = False) -> str:
|
||||
"""
|
||||
|
||||
@@ -696,6 +696,8 @@
|
||||
"availableModels": "Available Models",
|
||||
"baseModel": "Base Model",
|
||||
"cancel": "Cancel",
|
||||
"clipEmbed": "CLIP Embed",
|
||||
"clipVision": "CLIP Vision",
|
||||
"config": "Config",
|
||||
"convert": "Convert",
|
||||
"convertingModelBegin": "Converting Model. Please wait.",
|
||||
@@ -783,6 +785,7 @@
|
||||
"settings": "Settings",
|
||||
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
|
||||
"source": "Source",
|
||||
"spandrelImageToImage": "Image to Image (Spandrel)",
|
||||
"starterModels": "Starter Models",
|
||||
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
|
||||
"syncModels": "Sync Models",
|
||||
@@ -791,6 +794,7 @@
|
||||
"loraTriggerPhrases": "LoRA Trigger Phrases",
|
||||
"mainModelTriggerPhrases": "Main Model Trigger Phrases",
|
||||
"typePhraseHere": "Type phrase here",
|
||||
"t5Encoder": "T5 Encoder",
|
||||
"upcastAttention": "Upcast Attention",
|
||||
"uploadImage": "Upload Image",
|
||||
"urlOrLocalPath": "URL or Local Path",
|
||||
|
||||
@@ -14,6 +14,7 @@ import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageMo
|
||||
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
|
||||
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
|
||||
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
|
||||
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { languageSelector } from 'features/system/store/systemSelectors';
|
||||
import InvokeTabs from 'features/ui/components/InvokeTabs';
|
||||
@@ -39,10 +40,17 @@ interface Props {
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
selectedWorkflowId?: string;
|
||||
selectedStylePresetId?: string;
|
||||
destination?: InvokeTabName | undefined;
|
||||
}
|
||||
|
||||
const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
|
||||
const App = ({
|
||||
config = DEFAULT_CONFIG,
|
||||
selectedImage,
|
||||
selectedWorkflowId,
|
||||
selectedStylePresetId,
|
||||
destination,
|
||||
}: Props) => {
|
||||
const language = useAppSelector(languageSelector);
|
||||
const logger = useLogger('system');
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -81,6 +89,12 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
|
||||
}
|
||||
}, [selectedWorkflowId, getAndLoadWorkflow]);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedStylePresetId) {
|
||||
dispatch(activeStylePresetIdChanged(selectedStylePresetId));
|
||||
}
|
||||
}, [dispatch, selectedStylePresetId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (destination) {
|
||||
dispatch(setActiveTab(destination));
|
||||
|
||||
@@ -45,6 +45,7 @@ interface Props extends PropsWithChildren {
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
selectedWorkflowId?: string;
|
||||
selectedStylePresetId?: string;
|
||||
destination?: InvokeTabName;
|
||||
customStarUi?: CustomStarUi;
|
||||
socketOptions?: Partial<ManagerOptions & SocketOptions>;
|
||||
@@ -66,6 +67,7 @@ const InvokeAIUI = ({
|
||||
queueId,
|
||||
selectedImage,
|
||||
selectedWorkflowId,
|
||||
selectedStylePresetId,
|
||||
destination,
|
||||
customStarUi,
|
||||
socketOptions,
|
||||
@@ -227,6 +229,7 @@ const InvokeAIUI = ({
|
||||
config={config}
|
||||
selectedImage={selectedImage}
|
||||
selectedWorkflowId={selectedWorkflowId}
|
||||
selectedStylePresetId={selectedStylePresetId}
|
||||
destination={destination}
|
||||
/>
|
||||
</AppDndContext>
|
||||
|
||||
@@ -175,12 +175,12 @@ const ModelList = () => {
|
||||
{/* T5 Encoders List */}
|
||||
{isLoadingT5EncoderModels && <FetchingModelsLoader loadingMessage="Loading T5 Encoder Models..." />}
|
||||
{!isLoadingT5EncoderModels && filteredT5EncoderModels.length > 0 && (
|
||||
<ModelListWrapper title="T5 Encoder" modelList={filteredT5EncoderModels} key="t5-encoder" />
|
||||
<ModelListWrapper title={t('modelManager.t5Encoder')} modelList={filteredT5EncoderModels} key="t5-encoder" />
|
||||
)}
|
||||
{/* Clip Embed List */}
|
||||
{isLoadingClipEmbedModels && <FetchingModelsLoader loadingMessage="Loading Clip Embed Models..." />}
|
||||
{!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && (
|
||||
<ModelListWrapper title="Clip Embed" modelList={filteredClipEmbedModels} key="clip-embed" />
|
||||
<ModelListWrapper title={t('modelManager.clipEmbed')} modelList={filteredClipEmbedModels} key="clip-embed" />
|
||||
)}
|
||||
{/* Spandrel Image to Image List */}
|
||||
{isLoadingSpandrelImageToImageModels && (
|
||||
@@ -188,7 +188,7 @@ const ModelList = () => {
|
||||
)}
|
||||
{!isLoadingSpandrelImageToImageModels && filteredSpandrelImageToImageModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="Image-to-Image"
|
||||
title={t('modelManager.spandrelImageToImage')}
|
||||
modelList={filteredSpandrelImageToImageModels}
|
||||
key="spandrel-image-to-image"
|
||||
/>
|
||||
|
||||
@@ -19,11 +19,10 @@ export const ModelTypeFilter = memo(() => {
|
||||
controlnet: 'ControlNet',
|
||||
vae: 'VAE',
|
||||
t2i_adapter: t('common.t2iAdapter'),
|
||||
t5_encoder: 'T5Encoder',
|
||||
clip_embed: 'Clip Embed',
|
||||
t5_encoder: t('modelManager.t5Encoder'),
|
||||
clip_embed: t('modelManager.clipEmbed'),
|
||||
ip_adapter: t('common.ipAdapter'),
|
||||
clip_vision: 'Clip Vision',
|
||||
spandrel_image_to_image: 'Image-to-Image',
|
||||
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
|
||||
}),
|
||||
[t]
|
||||
);
|
||||
|
||||
@@ -6,6 +6,8 @@ import {
|
||||
isBoardFieldInputTemplate,
|
||||
isBooleanFieldInputInstance,
|
||||
isBooleanFieldInputTemplate,
|
||||
isCLIPEmbedModelFieldInputInstance,
|
||||
isCLIPEmbedModelFieldInputTemplate,
|
||||
isColorFieldInputInstance,
|
||||
isColorFieldInputTemplate,
|
||||
isControlNetModelFieldInputInstance,
|
||||
@@ -16,6 +18,8 @@ import {
|
||||
isFloatFieldInputTemplate,
|
||||
isFluxMainModelFieldInputInstance,
|
||||
isFluxMainModelFieldInputTemplate,
|
||||
isFluxVAEModelFieldInputInstance,
|
||||
isFluxVAEModelFieldInputTemplate,
|
||||
isImageFieldInputInstance,
|
||||
isImageFieldInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
@@ -49,10 +53,12 @@ import { memo } from 'react';
|
||||
|
||||
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
|
||||
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
||||
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
|
||||
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
|
||||
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
|
||||
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
|
||||
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
|
||||
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
|
||||
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
|
||||
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
|
||||
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
|
||||
@@ -122,6 +128,13 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useClipEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import type { ClipEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate>;
|
||||
|
||||
const CLIPEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useClipEmbedModels();
|
||||
const _onChange = useCallback(
|
||||
(value: ClipEmbedModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldCLIPEmbedValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CLIPEmbedModelFieldInputComponent);
|
||||
@@ -0,0 +1,60 @@
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate>;
|
||||
|
||||
const FluxVAEModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useFluxVAEModels();
|
||||
const _onChange = useCallback(
|
||||
(value: VAEModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldFluxVAEModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FluxVAEModelFieldInputComponent);
|
||||
@@ -6,11 +6,13 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
BoardFieldValue,
|
||||
BooleanFieldValue,
|
||||
CLIPEmbedModelFieldValue,
|
||||
ColorFieldValue,
|
||||
ControlNetModelFieldValue,
|
||||
EnumFieldValue,
|
||||
FieldValue,
|
||||
FloatFieldValue,
|
||||
FluxVAEModelFieldValue,
|
||||
ImageFieldValue,
|
||||
IntegerFieldValue,
|
||||
IPAdapterModelFieldValue,
|
||||
@@ -29,10 +31,12 @@ import type {
|
||||
import {
|
||||
zBoardFieldValue,
|
||||
zBooleanFieldValue,
|
||||
zCLIPEmbedModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zControlNetModelFieldValue,
|
||||
zEnumFieldValue,
|
||||
zFloatFieldValue,
|
||||
zFluxVAEModelFieldValue,
|
||||
zImageFieldValue,
|
||||
zIntegerFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
@@ -346,6 +350,12 @@ export const nodesSlice = createSlice({
|
||||
fieldT5EncoderValueChanged: (state, action: FieldValueAction<T5EncoderModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zT5EncoderModelFieldValue);
|
||||
},
|
||||
fieldCLIPEmbedValueChanged: (state, action: FieldValueAction<CLIPEmbedModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zCLIPEmbedModelFieldValue);
|
||||
},
|
||||
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
|
||||
},
|
||||
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
|
||||
fieldValueReducer(state, action, zEnumFieldValue);
|
||||
},
|
||||
@@ -408,6 +418,8 @@ export const {
|
||||
fieldStringValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
fieldFluxVAEModelValueChanged,
|
||||
nodeEditorReset,
|
||||
nodeIsIntermediateChanged,
|
||||
nodeIsOpenChanged,
|
||||
@@ -521,6 +533,8 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
fieldStringValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
fieldFluxVAEModelValueChanged,
|
||||
nodesChanged,
|
||||
nodeIsIntermediateChanged,
|
||||
nodeIsOpenChanged,
|
||||
|
||||
@@ -151,6 +151,14 @@ const zT5EncoderModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('T5EncoderModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zCLIPEmbedModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('CLIPEmbedModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FluxVAEModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SchedulerField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -175,6 +183,8 @@ const zStatefulFieldType = z.union([
|
||||
zT2IAdapterModelFieldType,
|
||||
zSpandrelImageToImageModelFieldType,
|
||||
zT5EncoderModelFieldType,
|
||||
zCLIPEmbedModelFieldType,
|
||||
zFluxVAEModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
]);
|
||||
@@ -667,7 +677,53 @@ export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5Encod
|
||||
export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate =>
|
||||
zT5EncoderModelFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
// #endregio
|
||||
// #endregion
|
||||
|
||||
// #region FluxVAEModelField
|
||||
|
||||
export const zFluxVAEModelFieldValue = zModelIdentifierField.optional();
|
||||
const zFluxVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zFluxVAEModelFieldValue,
|
||||
});
|
||||
const zFluxVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zFluxVAEModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zFluxVAEModelFieldValue,
|
||||
});
|
||||
|
||||
export type FluxVAEModelFieldValue = z.infer<typeof zFluxVAEModelFieldValue>;
|
||||
|
||||
export type FluxVAEModelFieldInputInstance = z.infer<typeof zFluxVAEModelFieldInputInstance>;
|
||||
export type FluxVAEModelFieldInputTemplate = z.infer<typeof zFluxVAEModelFieldInputTemplate>;
|
||||
export const isFluxVAEModelFieldInputInstance = (val: unknown): val is FluxVAEModelFieldInputInstance =>
|
||||
zFluxVAEModelFieldInputInstance.safeParse(val).success;
|
||||
export const isFluxVAEModelFieldInputTemplate = (val: unknown): val is FluxVAEModelFieldInputTemplate =>
|
||||
zFluxVAEModelFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
// #endregion
|
||||
|
||||
// #region CLIPEmbedModelField
|
||||
|
||||
export const zCLIPEmbedModelFieldValue = zModelIdentifierField.optional();
|
||||
const zCLIPEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zCLIPEmbedModelFieldValue,
|
||||
});
|
||||
const zCLIPEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zCLIPEmbedModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zCLIPEmbedModelFieldValue,
|
||||
});
|
||||
|
||||
export type CLIPEmbedModelFieldValue = z.infer<typeof zCLIPEmbedModelFieldValue>;
|
||||
|
||||
export type CLIPEmbedModelFieldInputInstance = z.infer<typeof zCLIPEmbedModelFieldInputInstance>;
|
||||
export type CLIPEmbedModelFieldInputTemplate = z.infer<typeof zCLIPEmbedModelFieldInputTemplate>;
|
||||
export const isCLIPEmbedModelFieldInputInstance = (val: unknown): val is CLIPEmbedModelFieldInputInstance =>
|
||||
zCLIPEmbedModelFieldInputInstance.safeParse(val).success;
|
||||
export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmbedModelFieldInputTemplate =>
|
||||
zCLIPEmbedModelFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
// #endregion
|
||||
|
||||
// #region SchedulerField
|
||||
|
||||
@@ -758,6 +814,8 @@ export const zStatefulFieldValue = z.union([
|
||||
zT2IAdapterModelFieldValue,
|
||||
zSpandrelImageToImageModelFieldValue,
|
||||
zT5EncoderModelFieldValue,
|
||||
zFluxVAEModelFieldValue,
|
||||
zCLIPEmbedModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
]);
|
||||
@@ -788,6 +846,8 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zT2IAdapterModelFieldInputInstance,
|
||||
zSpandrelImageToImageModelFieldInputInstance,
|
||||
zT5EncoderModelFieldInputInstance,
|
||||
zFluxVAEModelFieldInputInstance,
|
||||
zCLIPEmbedModelFieldInputInstance,
|
||||
zColorFieldInputInstance,
|
||||
zSchedulerFieldInputInstance,
|
||||
]);
|
||||
@@ -819,6 +879,8 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zT2IAdapterModelFieldInputTemplate,
|
||||
zSpandrelImageToImageModelFieldInputTemplate,
|
||||
zT5EncoderModelFieldInputTemplate,
|
||||
zFluxVAEModelFieldInputTemplate,
|
||||
zCLIPEmbedModelFieldInputTemplate,
|
||||
zColorFieldInputTemplate,
|
||||
zSchedulerFieldInputTemplate,
|
||||
zStatelessFieldInputTemplate,
|
||||
|
||||
@@ -23,6 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
VAEModelField: undefined,
|
||||
ControlNetModelField: undefined,
|
||||
T5EncoderModelField: undefined,
|
||||
FluxVAEModelField: undefined,
|
||||
CLIPEmbedModelField: undefined,
|
||||
};
|
||||
|
||||
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {
|
||||
|
||||
@@ -2,6 +2,7 @@ import { FieldParseError } from 'features/nodes/types/error';
|
||||
import type {
|
||||
BoardFieldInputTemplate,
|
||||
BooleanFieldInputTemplate,
|
||||
CLIPEmbedModelFieldInputTemplate,
|
||||
ColorFieldInputTemplate,
|
||||
ControlNetModelFieldInputTemplate,
|
||||
EnumFieldInputTemplate,
|
||||
@@ -9,6 +10,7 @@ import type {
|
||||
FieldType,
|
||||
FloatFieldInputTemplate,
|
||||
FluxMainModelFieldInputTemplate,
|
||||
FluxVAEModelFieldInputTemplate,
|
||||
ImageFieldInputTemplate,
|
||||
IntegerFieldInputTemplate,
|
||||
IPAdapterModelFieldInputTemplate,
|
||||
@@ -238,6 +240,34 @@ const buildT5EncoderModelFieldInputTemplate: FieldInputTemplateBuilder<T5Encoder
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildCLIPEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPEmbedModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: CLIPEmbedModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: FluxVAEModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -423,6 +453,8 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
|
||||
VAEModelField: buildVAEModelFieldInputTemplate,
|
||||
T5EncoderModelField: buildT5EncoderModelFieldInputTemplate,
|
||||
CLIPEmbedModelField: buildCLIPEmbedModelFieldInputTemplate,
|
||||
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
|
||||
} as const;
|
||||
|
||||
export const buildFieldInputTemplate = (
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
isControlNetModelConfig,
|
||||
isControlNetOrT2IAdapterModelConfig,
|
||||
isFluxMainModelModelConfig,
|
||||
isFluxVAEModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
isLoRAModelConfig,
|
||||
isNonRefinerMainModelConfig,
|
||||
@@ -52,3 +53,4 @@ export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToIm
|
||||
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
|
||||
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
|
||||
export const useVAEModels = buildModelsHook(isVAEModelConfig);
|
||||
export const useFluxVAEModels = buildModelsHook(isFluxVAEModelConfig);
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -51,7 +51,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
|
||||
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
|
||||
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
|
||||
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
||||
type ClipEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
|
||||
export type ClipEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
|
||||
export type T5EncoderModelConfig = S['T5EncoderConfig'];
|
||||
export type T5EncoderBnbQuantizedLlmInt8bModelConfig = S['T5EncoderBnbQuantizedLlmInt8bConfig'];
|
||||
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
|
||||
@@ -82,6 +82,10 @@ export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConf
|
||||
return config.type === 'vae';
|
||||
};
|
||||
|
||||
export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
|
||||
return config.type === 'vae' && config.base === 'flux';
|
||||
};
|
||||
|
||||
export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
|
||||
return config.type === 'controlnet';
|
||||
};
|
||||
|
||||
@@ -1 +1 @@
__version__ = "4.2.8post1"
__version__ = "4.2.9rc1"

@@ -130,8 +130,6 @@ dependencies = [

[project.scripts]
"invokeai-web" = "invokeai.app.run_app:run_app"
"invokeai-import-images" = "invokeai.frontend.install.import_images:main"
"invokeai-db-maintenance" = "invokeai.backend.util.db_maintenance:main"

[project.urls]
"Homepage" = "https://invoke-ai.github.io/InvokeAI/"

scripts/allocate_vram.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import argparse

import torch


def display_vram_usage():
"""Displays the total, allocated, and free VRAM on the current CUDA device."""

assert torch.cuda.is_available(), "CUDA is not available"
device = torch.device("cuda")

total_vram = torch.cuda.get_device_properties(device).total_memory
allocated_vram = torch.cuda.memory_allocated(device)
free_vram = total_vram - allocated_vram

print(f"Total VRAM: {total_vram / (1024 * 1024 * 1024):.2f} GB")
print(f"Allocated VRAM: {allocated_vram / (1024 * 1024 * 1024):.2f} GB")
print(f"Free VRAM: {free_vram / (1024 * 1024 * 1024):.2f} GB")


def allocate_vram(target_gb: float, target_free: bool = False):
"""Allocates VRAM on the current CUDA device. After allocation, the script will pause until the user presses Enter
or ends the script, at which point the VRAM will be released.

Args:
target_gb (float): Amount of VRAM to allocate in GB.
target_free (bool, optional): Instead of allocating <target_gb> VRAM, enough VRAM will be allocated so the system has <target_gb> of VRAM free. For example, if <target_gb> is 2 GB, the script will allocate VRAM until the free VRAM is 2 GB.
"""
assert torch.cuda.is_available(), "CUDA is not available"
device = torch.device("cuda")

if target_free:
total_vram = torch.cuda.get_device_properties(device).total_memory
free_vram = total_vram - torch.cuda.memory_allocated(device)
target_free_bytes = target_gb * 1024 * 1024 * 1024
bytes_to_allocate = free_vram - target_free_bytes

if bytes_to_allocate <= 0:
print(f"Already at or below the target free VRAM of {target_gb} GB")
return
else:
bytes_to_allocate = target_gb * 1024 * 1024 * 1024

# FloatTensor (4 bytes per element)
_tensor = torch.empty(int(bytes_to_allocate / 4), dtype=torch.float, device="cuda")

display_vram_usage()

input("Press Enter to release VRAM allocation and exit...")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Allocate VRAM for testing purposes. Only works on CUDA devices.")
parser.add_argument("target_gb", type=float, help="Amount of VRAM to allocate in GB.")
parser.add_argument(
"--target-free",
action="store_true",
help="Instead of allocating <target_gb> VRAM, enough VRAM will be allocated so the system has <target_gb> of VRAM free. For example, if <target_gb> is 2 GB, the script will allocate VRAM until the free VRAM is 2 GB.",
)

args = parser.parse_args()

allocate_vram(target_gb=args.target_gb, target_free=args.target_free)
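The script above is run as, for example, `python scripts/allocate_vram.py 2` to pin 2 GB, or `python scripts/allocate_vram.py 2 --target-free` to allocate until only 2 GB remain free. The sizing is plain byte arithmetic; a quick check of the 2 GB case:

```python
target_gb = 2.0
bytes_to_allocate = target_gb * 1024 * 1024 * 1024  # 2147483648.0 bytes
num_float32_elements = int(bytes_to_allocate / 4)   # float32 tensors use 4 bytes per element
assert num_float32_elements == 536_870_912
```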
tests/backend/flux/test_sampling_utils.py (new file, 42 lines)
@@ -0,0 +1,42 @@
import pytest
import torch

from invokeai.backend.flux.sampling_utils import clip_timestep_schedule


def float_lists_almost_equal(list1: list[float], list2: list[float], tol: float = 1e-6) -> bool:
return all(abs(a - b) < tol for a, b in zip(list1, list2, strict=True))


@pytest.mark.parametrize(
["denoising_start", "denoising_end", "expected_timesteps", "raises"],
[
(0.0, 1.0, [1.0, 0.75, 0.5, 0.25, 0.0], False), # Default case.
(-0.1, 1.0, [], True), # Negative denoising_start should raise.
(0.0, 1.1, [], True), # denoising_end > 1 should raise.
(0.5, 0.0, [], True), # denoising_start > denoising_end should raise.
(0.0, 0.0, [1.0], False), # denoising_end == 0.
(1.0, 1.0, [0.0], False), # denoising_start == 1.
(0.2, 0.8, [1.0, 0.75, 0.5, 0.25], False), # Middle of the schedule.
# If we denoise from 0.0 to x, then from x to 1.0, it is important that denoise_end = x and denoise_start = x
# map to the same timestep. We test this first when x is equal to a timestep, then when it falls between two
# timesteps.
# x = 0.5
(0.0, 0.5, [1.0, 0.75, 0.5], False),
(0.5, 1.0, [0.5, 0.25, 0.0], False),
# x = 0.3
(0.0, 0.3, [1.0, 0.75], False),
(0.3, 1.0, [0.75, 0.5, 0.25, 0.0], False),
],
)
def test_clip_timestep_schedule(
denoising_start: float, denoising_end: float, expected_timesteps: list[float], raises: bool
):
timesteps = torch.linspace(1, 0, 5).tolist()
if raises:
with pytest.raises(AssertionError):
clip_timestep_schedule(timesteps, denoising_start, denoising_end)
else:
assert float_lists_almost_equal(
clip_timestep_schedule(timesteps, denoising_start, denoising_end), expected_timesteps
)
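The parametrized cases above also double as usage documentation: splitting a denoise at 0.5 produces two sub-schedules that share the boundary timestep. A usage sketch taken directly from those expectations (results shown as comments, matching the test's tolerance-based comparison):

```python
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule

timesteps = [1.0, 0.75, 0.5, 0.25, 0.0]

# First half of the denoise, then the second half; both halves include t = 0.5.
first = clip_timestep_schedule(timesteps, denoising_start=0.0, denoising_end=0.5)   # ~[1.0, 0.75, 0.5]
second = clip_timestep_schedule(timesteps, denoising_start=0.5, denoising_end=1.0)  # ~[0.5, 0.25, 0.0]
```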
@@ -1,12 +1,9 @@
# test that if the model's device changes while the lora is applied, the weights can still be restored

# test that LoRA patching works on both CPU and CUDA

import pytest
import torch

from invokeai.backend.lora import LoRALayer, LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.peft.layers.lora_layer import LoRALayer
from invokeai.backend.peft.lora import LoRAModelRaw


@pytest.mark.parametrize(
@@ -38,7 +35,7 @@ def test_apply_lora(device):
},
)
}
lora = LoRAModelRaw("lora_name", lora_layers)
lora = LoRAModelRaw(lora_layers)

lora_weight = 0.5
orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
@@ -82,7 +79,7 @@ def test_apply_lora_change_device():
},
)
}
lora = LoRAModelRaw("lora_name", lora_layers)
lora = LoRAModelRaw(lora_layers)

orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()

@@ -0,0 +1,990 @@
|
||||
state_dict_keys = [
|
||||
"transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.0.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.0.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.0.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.0.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.0.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.0.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.0.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.1.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.1.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.1.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.1.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.1.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.1.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.1.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.1.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.1.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.1.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.1.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.1.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.10.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.10.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.10.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.10.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.10.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.10.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.10.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.10.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.10.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.10.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.10.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.10.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.11.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.11.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.11.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.11.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.11.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.11.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.11.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.11.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.11.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.11.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.11.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.11.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.12.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.12.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.12.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.12.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.12.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.12.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.12.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.12.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.12.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.12.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.12.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.12.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.13.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.13.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.13.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.13.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.13.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.13.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.13.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.13.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.13.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.13.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.13.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.13.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.14.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.14.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.14.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.14.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.14.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.14.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.14.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.14.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.14.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.14.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.14.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.14.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.15.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.15.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.15.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.15.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.15.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.15.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.15.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.15.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.15.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.15.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.15.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.15.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.16.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.16.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.16.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.16.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.16.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.16.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.16.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.16.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.16.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.16.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.16.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.16.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.17.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.17.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.17.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.17.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.17.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.17.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.17.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.17.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.17.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.17.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.17.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.17.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.18.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.18.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.18.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.18.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.18.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.18.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.18.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.18.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.18.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.18.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.18.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.18.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.19.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.19.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.19.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.19.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.19.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.19.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.19.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.19.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.19.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.19.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.19.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.19.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.2.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.2.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.2.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.2.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.2.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.2.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.2.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.2.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.2.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.2.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.2.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.2.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.20.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.20.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.20.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.20.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.20.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.20.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.20.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.20.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.20.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.20.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.20.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.20.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.21.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.21.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.21.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.21.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.21.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.21.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.21.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.21.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.21.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.21.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.21.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.21.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.22.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.22.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.22.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.22.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.22.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.22.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.22.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.22.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.22.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.22.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.22.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.22.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.23.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.23.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.23.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.23.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.23.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.23.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.23.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.23.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.23.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.23.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.23.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.23.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.24.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.24.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.24.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.24.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.24.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.24.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.24.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.24.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.24.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.24.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.24.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.24.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.25.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.25.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.25.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.25.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.25.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.25.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.25.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.25.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.25.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.25.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.25.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.25.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.26.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.26.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.26.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.26.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.26.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.26.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.26.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.26.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.26.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.26.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.26.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.26.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.27.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.27.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.27.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.27.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.27.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.27.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.27.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.27.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.27.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.27.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.27.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.27.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.28.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.28.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.28.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.28.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.28.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.28.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.28.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.28.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.28.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.28.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.28.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.28.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.29.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.29.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.29.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.29.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.29.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.29.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.29.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.29.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.29.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.29.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.29.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.29.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.3.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.3.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.3.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.3.attn.to_q.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.3.attn.to_v.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.3.attn.to_v.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.3.norm.linear.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.3.norm.linear.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.3.proj_mlp.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.3.proj_mlp.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.3.proj_out.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.3.proj_out.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.30.attn.to_k.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.30.attn.to_k.lora_B.weight",
|
||||
"transformer.single_transformer_blocks.30.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.30.attn.to_q.lora_B.weight",
# ... lora_A / lora_B weight keys continue in the same pattern for the rest of
# single_transformer_blocks 30, blocks 31-37 and blocks 4-9, covering
# attn.to_k, attn.to_q, attn.to_v, norm.linear, proj_mlp and proj_out ...
"transformer.transformer_blocks.0.attn.add_k_proj.lora_A.weight",
# ... lora_A / lora_B weight keys continue for transformer_blocks 0-18, covering
# attn.{add_k_proj, add_q_proj, add_v_proj, to_add_out, to_k, to_out.0, to_q, to_v},
# ff.net.{0.proj, 2}, ff_context.net.{0.proj, 2}, norm1.linear and norm1_context.linear ...
]
@@ -0,0 +1,914 @@
state_dict_keys = [
"lora_unet_double_blocks_0_img_attn_proj.alpha",
# ... alpha / lora_down.weight / lora_up.weight keys continue in the same pattern for
# double_blocks 0 and 10-15 (img/txt attn_proj, attn_qkv, mlp_0, mlp_2, mod_lin),
# then the start of double_blocks 16 ...
"lora_unet_double_blocks_16_img_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_16_img_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_16_img_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_16_img_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_16_img_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_16_img_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_16_img_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_16_img_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_16_txt_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_16_txt_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_16_txt_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_16_txt_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_16_txt_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_16_txt_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_16_txt_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_16_txt_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_16_txt_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_16_txt_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_16_txt_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_16_txt_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_16_txt_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_16_txt_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_16_txt_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_17_img_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_17_img_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_17_img_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_17_img_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_17_img_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_17_img_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_17_img_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_17_img_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_17_img_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_17_img_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_17_img_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_17_img_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_17_img_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_17_img_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_17_img_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_17_txt_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_17_txt_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_17_txt_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_17_txt_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_17_txt_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_17_txt_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_17_txt_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_17_txt_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_17_txt_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_17_txt_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_17_txt_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_17_txt_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_17_txt_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_17_txt_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_17_txt_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_18_img_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_18_img_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_18_img_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_18_img_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_18_img_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_18_img_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_18_img_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_18_img_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_18_img_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_18_img_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_18_img_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_18_img_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_18_img_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_18_img_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_18_img_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_18_txt_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_18_txt_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_18_txt_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_18_txt_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_18_txt_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_18_txt_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_18_txt_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_18_txt_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_18_txt_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_18_txt_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_18_txt_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_18_txt_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_18_txt_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_18_txt_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_18_txt_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_1_img_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_1_img_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_1_img_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_1_img_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_1_img_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_1_img_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_1_img_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_1_img_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_1_img_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_1_img_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_1_img_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_1_img_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_1_img_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_1_img_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_1_img_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_1_txt_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_1_txt_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_1_txt_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_1_txt_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_1_txt_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_1_txt_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_1_txt_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_1_txt_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_1_txt_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_1_txt_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_1_txt_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_1_txt_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_1_txt_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_1_txt_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_1_txt_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_2_img_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_2_img_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_2_img_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_2_img_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_2_img_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_2_img_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_2_img_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_2_img_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_2_img_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_2_img_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_2_img_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_2_img_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_2_img_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_2_img_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_2_img_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_2_txt_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_2_txt_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_2_txt_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_2_txt_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_2_txt_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_2_txt_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_2_txt_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_2_txt_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_2_txt_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_2_txt_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_2_txt_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_2_txt_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_2_txt_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_2_txt_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_2_txt_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_3_img_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_3_img_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_3_img_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_3_img_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_3_img_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_3_img_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_3_img_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_3_img_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_3_img_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_3_img_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_3_img_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_3_img_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_3_img_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_3_img_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_3_img_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_3_txt_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_3_txt_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_3_txt_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_3_txt_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_3_txt_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_3_txt_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_3_txt_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_3_txt_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_3_txt_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_3_txt_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_3_txt_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_3_txt_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_3_txt_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_3_txt_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_3_txt_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_4_img_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_4_img_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_4_img_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_4_img_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_4_img_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_4_img_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_4_img_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_4_img_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_4_img_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_4_img_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_4_img_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_4_img_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_4_img_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_4_img_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_4_img_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_4_txt_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_4_txt_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_4_txt_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_4_txt_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_4_txt_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_4_txt_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_4_txt_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_4_txt_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_4_txt_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_4_txt_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_4_txt_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_4_txt_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_4_txt_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_4_txt_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_4_txt_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_5_img_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_5_img_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_5_img_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_5_img_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_5_img_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_5_img_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_5_img_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_5_img_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_5_img_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_5_img_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_5_img_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_5_img_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_5_img_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_5_img_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_5_img_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_5_txt_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_5_txt_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_5_txt_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_5_txt_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_5_txt_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_5_txt_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_5_txt_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_5_txt_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_5_txt_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_5_txt_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_5_txt_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_5_txt_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_5_txt_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_5_txt_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_5_txt_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_6_img_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_6_img_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_6_img_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_6_img_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_6_img_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_6_img_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_6_img_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_6_img_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_6_img_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_6_img_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_6_img_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_6_img_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_6_img_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_6_img_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_6_img_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_6_txt_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_6_txt_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_6_txt_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_6_txt_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_6_txt_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_6_txt_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_6_txt_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_6_txt_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_6_txt_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_6_txt_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_6_txt_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_6_txt_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_6_txt_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_6_txt_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_6_txt_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_7_img_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_7_img_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_7_img_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_7_img_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_7_img_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_7_img_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_7_img_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_7_img_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_7_img_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_7_img_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_7_img_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_7_img_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_7_img_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_7_img_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_7_img_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_7_txt_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_7_txt_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_7_txt_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_7_txt_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_7_txt_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_7_txt_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_7_txt_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_7_txt_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_7_txt_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_7_txt_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_7_txt_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_7_txt_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_7_txt_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_7_txt_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_7_txt_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_8_img_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_8_img_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_8_img_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_8_img_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_8_img_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_8_img_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_8_img_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_8_img_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_8_img_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_8_img_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_8_img_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_8_img_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_8_img_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_8_img_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_8_img_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_8_txt_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_8_txt_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_8_txt_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_8_txt_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_8_txt_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_8_txt_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_8_txt_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_8_txt_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_8_txt_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_8_txt_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_8_txt_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_8_txt_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_8_txt_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_8_txt_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_8_txt_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_9_img_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_9_img_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_9_img_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_9_img_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_9_img_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_9_img_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_9_img_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_9_img_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_9_img_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_9_img_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_9_img_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_9_img_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_9_img_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_9_img_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_9_img_mod_lin.lora_up.weight",
|
||||
"lora_unet_double_blocks_9_txt_attn_proj.alpha",
|
||||
"lora_unet_double_blocks_9_txt_attn_proj.lora_down.weight",
|
||||
"lora_unet_double_blocks_9_txt_attn_proj.lora_up.weight",
|
||||
"lora_unet_double_blocks_9_txt_attn_qkv.alpha",
|
||||
"lora_unet_double_blocks_9_txt_attn_qkv.lora_down.weight",
|
||||
"lora_unet_double_blocks_9_txt_attn_qkv.lora_up.weight",
|
||||
"lora_unet_double_blocks_9_txt_mlp_0.alpha",
|
||||
"lora_unet_double_blocks_9_txt_mlp_0.lora_down.weight",
|
||||
"lora_unet_double_blocks_9_txt_mlp_0.lora_up.weight",
|
||||
"lora_unet_double_blocks_9_txt_mlp_2.alpha",
|
||||
"lora_unet_double_blocks_9_txt_mlp_2.lora_down.weight",
|
||||
"lora_unet_double_blocks_9_txt_mlp_2.lora_up.weight",
|
||||
"lora_unet_double_blocks_9_txt_mod_lin.alpha",
|
||||
"lora_unet_double_blocks_9_txt_mod_lin.lora_down.weight",
|
||||
"lora_unet_double_blocks_9_txt_mod_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_0_linear1.alpha",
|
||||
"lora_unet_single_blocks_0_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_0_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_0_linear2.alpha",
|
||||
"lora_unet_single_blocks_0_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_0_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_0_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_0_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_0_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_10_linear1.alpha",
|
||||
"lora_unet_single_blocks_10_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_10_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_10_linear2.alpha",
|
||||
"lora_unet_single_blocks_10_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_10_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_10_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_10_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_10_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_11_linear1.alpha",
|
||||
"lora_unet_single_blocks_11_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_11_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_11_linear2.alpha",
|
||||
"lora_unet_single_blocks_11_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_11_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_11_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_11_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_11_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_12_linear1.alpha",
|
||||
"lora_unet_single_blocks_12_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_12_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_12_linear2.alpha",
|
||||
"lora_unet_single_blocks_12_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_12_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_12_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_12_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_12_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_13_linear1.alpha",
|
||||
"lora_unet_single_blocks_13_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_13_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_13_linear2.alpha",
|
||||
"lora_unet_single_blocks_13_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_13_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_13_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_13_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_13_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_14_linear1.alpha",
|
||||
"lora_unet_single_blocks_14_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_14_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_14_linear2.alpha",
|
||||
"lora_unet_single_blocks_14_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_14_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_14_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_14_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_14_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_15_linear1.alpha",
|
||||
"lora_unet_single_blocks_15_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_15_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_15_linear2.alpha",
|
||||
"lora_unet_single_blocks_15_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_15_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_15_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_15_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_15_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_16_linear1.alpha",
|
||||
"lora_unet_single_blocks_16_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_16_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_16_linear2.alpha",
|
||||
"lora_unet_single_blocks_16_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_16_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_16_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_16_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_16_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_17_linear1.alpha",
|
||||
"lora_unet_single_blocks_17_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_17_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_17_linear2.alpha",
|
||||
"lora_unet_single_blocks_17_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_17_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_17_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_17_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_17_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_18_linear1.alpha",
|
||||
"lora_unet_single_blocks_18_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_18_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_18_linear2.alpha",
|
||||
"lora_unet_single_blocks_18_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_18_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_18_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_18_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_18_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_19_linear1.alpha",
|
||||
"lora_unet_single_blocks_19_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_19_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_19_linear2.alpha",
|
||||
"lora_unet_single_blocks_19_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_19_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_19_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_19_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_19_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_1_linear1.alpha",
|
||||
"lora_unet_single_blocks_1_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_1_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_1_linear2.alpha",
|
||||
"lora_unet_single_blocks_1_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_1_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_1_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_1_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_1_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_20_linear1.alpha",
|
||||
"lora_unet_single_blocks_20_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_20_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_20_linear2.alpha",
|
||||
"lora_unet_single_blocks_20_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_20_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_20_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_20_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_20_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_21_linear1.alpha",
|
||||
"lora_unet_single_blocks_21_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_21_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_21_linear2.alpha",
|
||||
"lora_unet_single_blocks_21_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_21_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_21_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_21_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_21_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_22_linear1.alpha",
|
||||
"lora_unet_single_blocks_22_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_22_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_22_linear2.alpha",
|
||||
"lora_unet_single_blocks_22_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_22_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_22_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_22_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_22_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_23_linear1.alpha",
|
||||
"lora_unet_single_blocks_23_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_23_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_23_linear2.alpha",
|
||||
"lora_unet_single_blocks_23_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_23_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_23_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_23_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_23_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_24_linear1.alpha",
|
||||
"lora_unet_single_blocks_24_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_24_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_24_linear2.alpha",
|
||||
"lora_unet_single_blocks_24_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_24_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_24_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_24_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_24_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_25_linear1.alpha",
|
||||
"lora_unet_single_blocks_25_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_25_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_25_linear2.alpha",
|
||||
"lora_unet_single_blocks_25_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_25_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_25_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_25_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_25_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_26_linear1.alpha",
|
||||
"lora_unet_single_blocks_26_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_26_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_26_linear2.alpha",
|
||||
"lora_unet_single_blocks_26_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_26_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_26_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_26_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_26_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_27_linear1.alpha",
|
||||
"lora_unet_single_blocks_27_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_27_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_27_linear2.alpha",
|
||||
"lora_unet_single_blocks_27_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_27_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_27_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_27_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_27_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_28_linear1.alpha",
|
||||
"lora_unet_single_blocks_28_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_28_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_28_linear2.alpha",
|
||||
"lora_unet_single_blocks_28_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_28_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_28_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_28_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_28_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_29_linear1.alpha",
|
||||
"lora_unet_single_blocks_29_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_29_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_29_linear2.alpha",
|
||||
"lora_unet_single_blocks_29_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_29_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_29_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_29_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_29_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_2_linear1.alpha",
|
||||
"lora_unet_single_blocks_2_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_2_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_2_linear2.alpha",
|
||||
"lora_unet_single_blocks_2_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_2_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_2_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_2_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_2_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_30_linear1.alpha",
|
||||
"lora_unet_single_blocks_30_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_30_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_30_linear2.alpha",
|
||||
"lora_unet_single_blocks_30_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_30_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_30_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_30_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_30_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_31_linear1.alpha",
|
||||
"lora_unet_single_blocks_31_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_31_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_31_linear2.alpha",
|
||||
"lora_unet_single_blocks_31_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_31_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_31_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_31_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_31_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_32_linear1.alpha",
|
||||
"lora_unet_single_blocks_32_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_32_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_32_linear2.alpha",
|
||||
"lora_unet_single_blocks_32_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_32_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_32_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_32_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_32_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_33_linear1.alpha",
|
||||
"lora_unet_single_blocks_33_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_33_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_33_linear2.alpha",
|
||||
"lora_unet_single_blocks_33_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_33_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_33_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_33_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_33_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_34_linear1.alpha",
|
||||
"lora_unet_single_blocks_34_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_34_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_34_linear2.alpha",
|
||||
"lora_unet_single_blocks_34_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_34_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_34_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_34_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_34_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_35_linear1.alpha",
|
||||
"lora_unet_single_blocks_35_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_35_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_35_linear2.alpha",
|
||||
"lora_unet_single_blocks_35_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_35_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_35_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_35_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_35_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_36_linear1.alpha",
|
||||
"lora_unet_single_blocks_36_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_36_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_36_linear2.alpha",
|
||||
"lora_unet_single_blocks_36_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_36_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_36_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_36_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_36_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_37_linear1.alpha",
|
||||
"lora_unet_single_blocks_37_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_37_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_37_linear2.alpha",
|
||||
"lora_unet_single_blocks_37_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_37_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_37_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_37_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_37_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_3_linear1.alpha",
|
||||
"lora_unet_single_blocks_3_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_3_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_3_linear2.alpha",
|
||||
"lora_unet_single_blocks_3_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_3_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_3_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_3_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_3_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_4_linear1.alpha",
|
||||
"lora_unet_single_blocks_4_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_4_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_4_linear2.alpha",
|
||||
"lora_unet_single_blocks_4_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_4_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_4_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_4_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_4_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_5_linear1.alpha",
|
||||
"lora_unet_single_blocks_5_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_5_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_5_linear2.alpha",
|
||||
"lora_unet_single_blocks_5_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_5_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_5_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_5_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_5_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_6_linear1.alpha",
|
||||
"lora_unet_single_blocks_6_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_6_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_6_linear2.alpha",
|
||||
"lora_unet_single_blocks_6_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_6_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_6_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_6_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_6_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_7_linear1.alpha",
|
||||
"lora_unet_single_blocks_7_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_7_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_7_linear2.alpha",
|
||||
"lora_unet_single_blocks_7_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_7_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_7_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_7_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_7_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_8_linear1.alpha",
|
||||
"lora_unet_single_blocks_8_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_8_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_8_linear2.alpha",
|
||||
"lora_unet_single_blocks_8_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_8_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_8_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_8_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_8_modulation_lin.lora_up.weight",
|
||||
"lora_unet_single_blocks_9_linear1.alpha",
|
||||
"lora_unet_single_blocks_9_linear1.lora_down.weight",
|
||||
"lora_unet_single_blocks_9_linear1.lora_up.weight",
|
||||
"lora_unet_single_blocks_9_linear2.alpha",
|
||||
"lora_unet_single_blocks_9_linear2.lora_down.weight",
|
||||
"lora_unet_single_blocks_9_linear2.lora_up.weight",
|
||||
"lora_unet_single_blocks_9_modulation_lin.alpha",
|
||||
"lora_unet_single_blocks_9_modulation_lin.lora_down.weight",
|
||||
"lora_unet_single_blocks_9_modulation_lin.lora_up.weight",
|
||||
]
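For reference, the Kohya FLUX LoRA key names in the fixture above follow a regular pattern: a "lora_unet_" prefix, an underscore-separated module path into the FLUX transformer (double_blocks_{i} with img/txt attn_qkv, attn_proj, mlp_0, mlp_2, mod_lin modules, or single_blocks_{i} with linear1, linear2, modulation_lin), and one of the suffixes ".alpha", ".lora_down.weight", or ".lora_up.weight". The snippet below is an illustrative sketch only (not part of the fixture file); it assumes flux-dev dimensions, i.e. 19 double blocks and 38 single blocks as seen in the listing, and does not reproduce the fixture's exact lexicographic ordering.

# Illustrative sketch: regenerate Kohya-style FLUX LoRA key names.
# Assumes flux-dev dimensions (19 double blocks, 38 single blocks).
DOUBLE_MODULES = [
    "img_attn_proj", "img_attn_qkv", "img_mlp_0", "img_mlp_2", "img_mod_lin",
    "txt_attn_proj", "txt_attn_qkv", "txt_mlp_0", "txt_mlp_2", "txt_mod_lin",
]
SINGLE_MODULES = ["linear1", "linear2", "modulation_lin"]
SUFFIXES = ["alpha", "lora_down.weight", "lora_up.weight"]

kohya_keys = [
    f"lora_unet_double_blocks_{i}_{m}.{s}" for i in range(19) for m in DOUBLE_MODULES for s in SUFFIXES
] + [
    f"lora_unet_single_blocks_{i}_{m}.{s}" for i in range(38) for m in SINGLE_MODULES for s in SUFFIXES
]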
@@ -0,0 +1,97 @@
import pytest
import torch

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.peft.conversions.flux_kohya_lora_conversion_utils import (
    convert_flux_kohya_state_dict_to_invoke_format,
    is_state_dict_likely_in_flux_kohya_format,
    lora_model_from_flux_kohya_state_dict,
)
from tests.backend.peft.conversions.lora_state_dicts.flux_lora_kohya_format import state_dict_keys


def test_is_state_dict_likely_in_flux_kohya_format_true():
    """Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
    # Construct a state dict that is in the Kohya FLUX LoRA format.
    state_dict: dict[str, torch.Tensor] = {}
    for k in state_dict_keys:
        state_dict[k] = torch.empty(1)
    assert is_state_dict_likely_in_flux_kohya_format(state_dict)


def test_is_state_dict_likely_in_flux_kohya_format_false():
    """Test that is_state_dict_likely_in_flux_kohya_format() returns False for a state dict that is not in the Kohya FLUX LoRA format."""
    state_dict: dict[str, torch.Tensor] = {
        "unexpected_key.lora_up.weight": torch.empty(1),
    }
    assert not is_state_dict_likely_in_flux_kohya_format(state_dict)


def test_convert_flux_kohya_state_dict_to_invoke_format():
    # Construct state_dict from state_dict_keys.
    state_dict: dict[str, torch.Tensor] = {}
    for k in state_dict_keys:
        state_dict[k] = torch.empty(1)

    converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)

    # Extract the prefixes from the converted state dict (i.e. without the .lora_up.weight, .lora_down.weight, and
    # .alpha suffixes).
    converted_key_prefixes: list[str] = []
    for k in converted_state_dict.keys():
        k = k.replace(".lora_up.weight", "")
        k = k.replace(".lora_down.weight", "")
        k = k.replace(".alpha", "")
        converted_key_prefixes.append(k)

    # Initialize a FLUX model on the meta device.
    with torch.device("meta"):
        model = Flux(params["flux-dev"])
    model_keys = set(model.state_dict().keys())

    # Assert that the converted state dict matches the keys in the actual model.
    for converted_key_prefix in converted_key_prefixes:
        found_match = False
        for model_key in model_keys:
            if model_key.startswith(converted_key_prefix):
                found_match = True
                break
        if not found_match:
            raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")


def test_convert_flux_kohya_state_dict_to_invoke_format_error():
    """Test that an error is raised by convert_flux_kohya_state_dict_to_invoke_format() if the input state_dict contains
    unexpected keys.
    """
    state_dict = {
        "unexpected_key.lora_up.weight": torch.empty(1),
    }

    with pytest.raises(ValueError):
        convert_flux_kohya_state_dict_to_invoke_format(state_dict)


def test_lora_model_from_flux_kohya_state_dict():
    """Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
    # Construct state_dict from state_dict_keys.
    state_dict: dict[str, torch.Tensor] = {}
    for k in state_dict_keys:
        state_dict[k] = torch.empty(1)

    lora_model = lora_model_from_flux_kohya_state_dict(state_dict)

    # Prepare expected layer keys.
    expected_layer_keys: set[str] = set()
    for k in state_dict_keys:
        k = k.replace("lora_unet_", "")
        k = k.replace(".lora_up.weight", "")
        k = k.replace(".lora_down.weight", "")
        k = k.replace(".alpha", "")
        expected_layer_keys.add(k)

    # Assert that the lora_model has the expected layers.
    lora_model_keys = set(lora_model.layers.keys())
    lora_model_keys = {k.replace(".", "_") for k in lora_model_keys}
    assert lora_model_keys == expected_layer_keys
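Taken together, these utilities define the load path for Kohya-format FLUX LoRAs: detect the format, then build a LoRA model from the state dict. A minimal usage sketch follows; the file name "my_flux_lora.safetensors" is a hypothetical placeholder, not a file from this change.

# Sketch: load a Kohya-format FLUX LoRA file and build an InvokeAI LoRA model from it.
from safetensors.torch import load_file

from invokeai.backend.peft.conversions.flux_kohya_lora_conversion_utils import (
    is_state_dict_likely_in_flux_kohya_format,
    lora_model_from_flux_kohya_state_dict,
)

# Placeholder path for illustration only.
state_dict = load_file("my_flux_lora.safetensors")
if is_state_dict_likely_in_flux_kohya_format(state_dict):
    lora_model = lora_model_from_flux_kohya_state_dict(state_dict)
else:
    raise ValueError("Not a Kohya-format FLUX LoRA.")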