Compare commits

...

71 Commits

Author SHA1 Message Date
Ryan Dick
0ff5355ce3 Rename flux_kohya_lora_conversion_utils.py 2024-09-05 14:18:17 +00:00
Ryan Dick
3ad5fc060d Fixup FLUX LoRA unit tests. 2024-09-05 14:12:56 +00:00
Ryan Dick
17d5c85454 WIP 2024-09-04 22:52:18 +00:00
Ryan Dick
4698649cc9 WIP - add invocations to support FLUX LORAs. 2024-09-04 19:55:06 +00:00
Ryan Dick
d41c075768 Get probing of FLUX LoRA kohya models working. 2024-09-04 16:03:55 +00:00
Ryan Dick
6b129aaba6 Add utility function for detecting whether a state_dict is in the FLUX kohya LoRA format. 2024-09-04 15:51:15 +00:00
Ryan Dick
f58546fd53 Update convert_flux_kohya_state_dict_to_invoke_format() to raise an exception if an unexpected key is encountered, and add a corresponding unit test. 2024-09-04 15:34:31 +00:00
Ryan Dick
de3edf47fb Move the responsibilities of 1) state_dict loading from file, and 2) SDXL lora key conversions, out of LoRAModelRaw and into LoRALoader. 2024-09-04 15:18:43 +00:00
Ryan Dick
6dc4baa925 Remove unused LoRAModelRaw.name attribute. 2024-09-04 14:52:30 +00:00
Ryan Dick
943fa6da4b Fix type errors in sdxl_conversion_utils.py 2024-09-04 14:35:38 +00:00
Ryan Dick
bfe31838cc Start moving SDXL-specific LoRA conversions out of the general-purpose LoRAModelRaw class. 2024-09-04 14:30:18 +00:00
Ryan Dick
16b76f7e7f Get convert_flux_kohya_state_dict_to_invoke_format(...) working, with unit tests. 2024-09-04 13:42:12 +00:00
Ryan Dick
cb115743e7 WIP - FLUX LoRA conversion logic. 2024-09-03 22:36:16 +00:00
Ryan Dick
8f4279ba51 Add state_dict keys for two FLUX LoRA formats to be used in unit tests. 2024-09-03 22:01:00 +00:00
Ryan Dick
22a207b50d Move lora.py to peft/ subdir. 2024-09-03 18:13:21 +00:00
Ryan Dick
638c6003e3 Split PEFT layer implementations into separate files. 2024-09-03 18:06:21 +00:00
Lincoln Stein
8d35af946e [MM] add API routes for getting & setting MM cache sizes (#6523)
* [MM] add API routes for getting & setting MM cache sizes, and retrieving MM stats

* Update invokeai/app/api/routers/model_manager.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* code cleanup after @ryand review

* Update invokeai/app/api/routers/model_manager.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* fix merge conflicts; tested and working

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2024-09-02 12:18:21 -04:00
Ryan Dick
24065ec6b6 Add FLUX image-to-image and inpainting (#6798)
## Summary

This PR adds support for Image-to-Image and inpainting workflows with
the FLUX model.

Full changelog:
- Split out `FLUX VAE Encode` and `FLUX VAE Decode` nodes
- Renamed `FLUX Text-to-Image` node to `FLUX Denoise` (since it now
supports image-to-image too). This is a workflow-breaking change.
- Added support for FLUX image-to-image via the `Latents` param on the
FLUX denoising node.
- Added support for FLUX masked inpainting via the `Denoise Mask` param
on the FLUX denoising node.
- Added "Denoise Start" and "Denoise End" params to the "FLUX Denoise"
node.
- Updated the "FLUX Text to Image" default workflow.
- Added a "FLUX Image to Image" default workflow.

### Example

FLUX inpainting workflow
<img width="1282" alt="image"
src="https://github.com/user-attachments/assets/86fc1170-e620-4412-8fd8-e119f875fc2e">

Input image

![image](https://github.com/user-attachments/assets/9c381b86-9f87-4257-bd2e-da22c56ca26c)

Mask

![image](https://github.com/user-attachments/assets/8f774c5c-2a25-45fe-9d4b-b233e3d58d2c)

Output image

![image](https://github.com/user-attachments/assets/8576a630-24ce-4a00-8052-e86bab59c855)


### Callouts for reviewers:
- I renamed FLUXTextToImageInvocation -> FLUXDenoisingInvocation. This
is, of course, a breaking change. It feels like the right move and now
is the right time to do it. Any objection?
- I added new `FLUX VAE Encode` and `FLUX VAE Decode` nodes.
Alternatively, I could have tried to match these names to the
corresponding SD nodes (e.g. `FLUX Image to Latents`, `FLUX Latents to
Image`). Personally, I prefer the current names, but want to hear other
opinions.

### Usage notes:
- With the default dev timestep scheduler, the image structure is
largely determined in the first ~3 steps. A consequence of this is that
the denoise_start parameter provides limited 'granularity' of control.
This will likely be improved in the future as we add more scheduler
options. In the meantime, you will likely want to use small values for
`denoise_start` (e.g. 0.03) to start denoising on step ~1-4 out of ~30.
- Currently, there is no 'noise' parameter on the `FLUX Denoise` node,
so the `denoise_end` parameter has limited utility. This will be added
in the future.

## QA Instructions

Test the following workflows:
- [x] Vanilla FLUX text-to-image behaviour is unchanged
- [x] Image-to-image with FLUX dev, no mask
- [x] Image-to-image with FLUX dev, with mask
- [x] Image-to-image with FLUX schnell, no mask (smoke test, not
expected to work well)

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-09-02 09:50:31 -04:00
Ryan Dick
627b0bf644 Expose all FLUX model params in the default FLUX models. 2024-09-02 09:38:17 -04:00
Ryan Dick
b43da46b82 Rename 'FLUX VAE Encode'/'FLUX VAE Decode' to 'FLUX Image to Latents'/'FLUX Latents to Image' 2024-09-02 09:38:17 -04:00
Ryan Dick
4255a01c64 Restore line that was accidentally removed during development. 2024-09-02 09:38:17 -04:00
Ryan Dick
23adbd4002 Update schema.ts. 2024-09-02 09:38:17 -04:00
Ryan Dick
fb5a24fcc6 Update default workflows for FLUX. 2024-09-02 09:38:17 -04:00
Ryan Dick
cfdd5a1900 Rename flux_text_to_image.py -> flex_denoise.py 2024-09-02 09:38:17 -04:00
Ryan Dick
2313f326df Add denoise_end param to FluxDenoiseInvocation. 2024-09-02 09:38:17 -04:00
Ryan Dick
2e092a2313 Rename FluxTextToImageInvocation -> FluxDenoiseInvocation. 2024-09-02 09:38:17 -04:00
Ryan Dick
763ef06c18 Use the existence of initial latents to decide whether we are doing image-to-image in the FLUX denoising node. Previously we were using the denoising_start value, but in some cases with an inpaintin mask you may want to run image-to-image from densoising_start=0. 2024-09-02 09:38:17 -04:00
Ryan Dick
8292f6cd42 Code cleanup and documentation around FLUX inpainting. 2024-09-02 09:38:17 -04:00
Ryan Dick
278bba499e Split FLUX VAE decoding out into its own node from LatentsToImageInvocation. 2024-09-02 09:38:17 -04:00
Ryan Dick
dd99ed28e0 Split FLUX VAE encoding out into its own node from ImageToLatentsInvocation. 2024-09-02 09:38:17 -04:00
Ryan Dick
9a8aca69bf Get a rough version of FLUX inpainting working. 2024-09-02 09:38:17 -04:00
Ryan Dick
7ad62512eb Update MaskTensorToImageInvocation to support input mask tensors with or without a channel dimension. 2024-09-02 09:38:17 -04:00
Ryan Dick
bd466661ec Remove unused vae field from FLUXTextToImageInvocation. 2024-09-02 09:38:17 -04:00
Ryan Dick
7ebb509d05 Bump FLUX node versions after splitting out VAE encode/decode. 2024-09-02 09:38:17 -04:00
Ryan Dick
0aa13c046c Split VAE decoding out from the FLUXTextToImageInvocation. 2024-09-02 09:38:17 -04:00
Ryan Dick
a7a33d73f5 Get FLUX non-masked image-to-image working - still rough. 2024-09-02 09:38:17 -04:00
Ryan Dick
ffa39857d3 Add FLUX VAE decoding support to LatentsToImageInvocation. 2024-09-02 09:38:17 -04:00
Ryan Dick
e85c3bc465 Add FLUX VAE support to ImageToLatentsInvocation. 2024-09-02 09:38:17 -04:00
psychedelicious
8185ba7054 scripts: add allocate_vram script
Allocates the specified amount of VRAM, or allocates enough VRAM such that you have the specified amount of VRAM free.

Useful to simulate an environment with a specific amount of VRAM.
2024-09-02 18:18:26 +10:00
Lincoln Stein
d501865bec add a new FAQ for converting safetensors (#6736)
Co-authored-by: Lincoln Stein <lstein@gmail.com>
2024-08-31 18:56:08 +00:00
Brandon Rising
d62310bb5f Support HF repos with subfolders in source on windows OS 2024-08-30 19:31:42 -04:00
Brandon Rising
1835bff196 Fix source string in hugging face installs with subfolders 2024-08-30 19:31:42 -04:00
Ryan Dick
87261bdbc9 FLUX memory management improvements (#6791)
## Summary

This PR contains several improvements to memory management for FLUX
workflows.

It is now possible to achieve better FLUX model caching performance, but
this still requires users to manually configure their `ram`/`vram`
settings. E.g. a `vram` setting of 16.0 should allow for all quantized
FLUX models to be kept in memory on the GPU.

Changes:
- Check the size of a model on disk and free the requisite space in the
model cache before loading it. (This behaviour existed previously, but
was removed in https://github.com/invoke-ai/InvokeAI/pull/6072/files.
The removal did not seem to be intentional).
- Removed the hack to free 24GB of space in the cache before loading the
FLUX model.
- Split the T5 embedding and CLIP embedding steps into separate
functions so that the two models don't both have to be held in RAM at
the same time.
- Fix a bug in `InvokeLinear8bitLt` that was causing some tensors to be
left on the GPU when the model was offloaded to the CPU. (This class is
getting very messy due to the non-standard state_dict handling in
`bnb.nn.Linear8bitLt`. )
- Tidy up some dtype handling in FluxTextToImageInvocation to avoid
situations where we hold references to two copies of the same tensor
unnecessarily.
- (minor) Misc cleanup of ModelCache: improve docs and remove unused
vars.

Future:
We should revisit our default ram/vram configs. The current defaults are
very conservative, and users could see major performance improvements
from tuning these values.

## QA Instructions

I tested the FLUX workflow with the following configurations and
verified that the cache hit rates and memory usage matched the expected
behaviour:
- `ram = 16` and `vram = 16`
- `ram = 16` and `vram = 1`
- `ram = 1` and `vram = 1`

Note that the changes in this PR are not isolated to FLUX. Since we now
check the size of models on disk, we may see slight changes in model
cache offload patterns for other models as well.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-08-29 15:17:45 -04:00
Ryan Dick
4e4b6c6dbc Tidy variable management and dtype handling in FluxTextToImageInvocation. 2024-08-29 19:08:18 +00:00
Ryan Dick
5e8cf9fb6a Remove hack to clear cache from the FluxTextToImageInvocation. We now clear the cache based on the on-disk model size. 2024-08-29 19:08:18 +00:00
Ryan Dick
c738fe051f Split T5 encoding and CLIP encoding into separate functions to ensure that all model references are locally-scoped so that the two models don't have to be help in memory at the same time. 2024-08-29 19:08:18 +00:00
Ryan Dick
29fe1533f2 Fix bug in InvokeLinear8bitLt that was causing old state information to persist after loading from a state dict. This manifested as state tensors being left on the GPU even when a model had been offloaded to the CPU cache. 2024-08-29 19:08:18 +00:00
Ryan Dick
77090070bd Check the size of a model on disk and make room for it in the cache before loading it. 2024-08-29 19:08:18 +00:00
Ryan Dick
6ba9b1b6b0 Tidy up GIG -> GB and remove unused GIG constant. 2024-08-29 19:08:18 +00:00
Ryan Dick
c578b8df1e Improve ModelCache docs. 2024-08-29 19:08:18 +00:00
Ryan Dick
cad9a41433 Remove unused MOdelCache.exists(...) function. 2024-08-29 19:08:18 +00:00
Ryan Dick
5fefb3b0f4 Remove unused param from ModelCache. 2024-08-29 19:08:18 +00:00
Ryan Dick
5284a870b0 Remove unused constructor params from ModelCache. 2024-08-29 19:08:18 +00:00
Ryan Dick
e064377c05 Remove default model cache sizes from model_cache_default.py. These defaults were misleading, because the config defaults take precedence over them. 2024-08-29 19:08:18 +00:00
Mary Hipp
3e569c8312 feat(ui): add fields for CLIP embed models and Flux VAE models in workflows 2024-08-29 11:52:51 -04:00
maryhipp
16825ee6e9 feat(nodes): bump version of flux model node, update default workflow 2024-08-29 11:52:51 -04:00
Mary Hipp
3f5340fa53 feat(nodes): add submodels as inputs to FLUX main model node instead of hardcoded names 2024-08-29 11:52:51 -04:00
chainchompa
f2a1a39b33 Add selectedStylePreset to app parameters (#6787)
## Summary
- Add selectedStylePreset to app parameters
<!--A description of the changes in this PR. Include the kind of change
(fix, feature, docs, etc), the "why" and the "how". Screenshots or
videos are useful for frontend changes.-->

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-08-28 10:53:07 -04:00
chainchompa
326de55d3e remove api changes and only preselect style preset 2024-08-28 09:53:29 -04:00
chainchompa
b2df909570 added selectedStylePreset to preload presets when app loads 2024-08-28 09:50:44 -04:00
chainchompa
026ac36b06 Revert "added selectedStylePreset to preload presets when app loads"
This reverts commit e97fd85904.
2024-08-28 09:44:08 -04:00
chainchompa
92125e5fd2 bug fixes 2024-08-27 16:13:38 -04:00
chainchompa
c0c139da88 formatting ruff 2024-08-27 15:46:51 -04:00
chainchompa
404ad6a7fd cleanup 2024-08-27 15:42:42 -04:00
chainchompa
fc39086fb4 call stylePresetSelected 2024-08-27 15:34:31 -04:00
chainchompa
cd215700fe added route for selecting style preset 2024-08-27 15:34:07 -04:00
chainchompa
e97fd85904 added selectedStylePreset to preload presets when app loads 2024-08-27 15:33:24 -04:00
Brandon Rising
0a263fa5b1 chore: bump version to v4.2.9rc1 2024-08-27 12:09:27 -04:00
Mary Hipp
fae3836a8d fix CLIP 2024-08-27 10:29:10 -04:00
Mary Hipp
b3d2eb4178 add translations for new model types in MM, remove clip vision from filter since its not displayed in list 2024-08-27 10:29:10 -04:00
psychedelicious
576f1cbb75 build: remove broken scripts
These two scripts are broken and can cause data loss. Remove them.

They are not in the launcher script, but _are_ available to users in the terminal/file browser.

Hopefully, when we removing them here, `pip` will delete them on next installation of the package...
2024-08-27 22:01:45 +10:00
75 changed files with 5056 additions and 1404 deletions

View File

@@ -196,6 +196,22 @@ tips to reduce the problem:
=== "12GB VRAM GPU"
This should be sufficient to generate larger images up to about 1280x1280.
## Checkpoint Models Load Slowly or Use Too Much RAM
The difference between diffusers models (a folder containing multiple
subfolders) and checkpoint models (a file ending with .safetensors or
.ckpt) is that InvokeAI is able to load diffusers models into memory
incrementally, while checkpoint models must be loaded all at
once. With very large models, or systems with limited RAM, you may
experience slowdowns and other memory-related issues when loading
checkpoint models.
To solve this, go to the Model Manager tab (the cube), select the
checkpoint model that's giving you trouble, and press the "Convert"
button in the upper right of your browser window. This will conver the
checkpoint into a diffusers model, after which loading should be
faster and less memory-intensive.
## Memory Leak (Linux)

View File

@@ -3,8 +3,10 @@
import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from enum import Enum
from tempfile import TemporaryDirectory
from typing import List, Optional, Type
@@ -17,6 +19,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.config import get_config
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (
@@ -31,6 +34,7 @@ from invokeai.backend.model_manager.config import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch
@@ -50,6 +54,13 @@ class ModelsList(BaseModel):
model_config = ConfigDict(use_enum_values=True)
class CacheType(str, Enum):
"""Cache type - one of vram or ram."""
RAM = "RAM"
VRAM = "VRAM"
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
"""Add a cover image URL to a model configuration."""
cover_image = dependencies.invoker.services.model_images.get_url(config.key)
@@ -797,3 +808,83 @@ async def get_starter_models() -> list[StarterModel]:
model.dependencies = missing_deps
return starter_models
@model_manager_router.get(
"/model_cache",
operation_id="get_cache_size",
response_model=float,
summary="Get maximum size of model manager RAM or VRAM cache.",
)
async def get_cache_size(cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM)) -> float:
"""Return the current RAM or VRAM cache size setting (in GB)."""
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
value = 0.0
if cache_type == CacheType.RAM:
value = cache.max_cache_size
elif cache_type == CacheType.VRAM:
value = cache.max_vram_cache_size
return value
@model_manager_router.put(
"/model_cache",
operation_id="set_cache_size",
response_model=float,
summary="Set maximum size of model manager RAM or VRAM cache, optionally writing new value out to invokeai.yaml config file.",
)
async def set_cache_size(
value: float = Query(description="The new value for the maximum cache size"),
cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM),
persist: bool = Query(description="Write new value out to invokeai.yaml", default=False),
) -> float:
"""Set the current RAM or VRAM cache size setting (in GB). ."""
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
app_config = get_config()
# Record initial state.
vram_old = app_config.vram
ram_old = app_config.ram
# Prepare target state.
vram_new = vram_old
ram_new = ram_old
if cache_type == CacheType.RAM:
ram_new = value
elif cache_type == CacheType.VRAM:
vram_new = value
else:
raise ValueError(f"Unexpected {cache_type=}.")
config_path = app_config.config_file_path
new_config_path = config_path.with_suffix(".yaml.new")
try:
# Try to apply the target state.
cache.max_vram_cache_size = vram_new
cache.max_cache_size = ram_new
app_config.ram = ram_new
app_config.vram = vram_new
if persist:
app_config.write_file(new_config_path)
shutil.move(new_config_path, config_path)
except Exception as e:
# If there was a failure, restore the initial state.
cache.max_cache_size = ram_old
cache.max_vram_cache_size = vram_old
app_config.ram = ram_old
app_config.vram = vram_old
raise RuntimeError("Failed to update cache size") from e
return value
@model_manager_router.get(
"/stats",
operation_id="get_stats",
response_model=Optional[CacheStats],
summary="Get model manager RAM cache performance statistics.",
)
async def get_stats() -> Optional[CacheStats]:
"""Return performance statistics on the model manager's RAM cache. Will return null if no models have been loaded."""
return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats

View File

@@ -19,8 +19,8 @@ from invokeai.app.invocations.model import CLIPField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,

View File

@@ -36,9 +36,9 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
@@ -185,7 +185,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.mask,
description=FieldDescriptions.denoise_mask,
input=Input.Connection,
ui_order=8,
)

View File

@@ -45,11 +45,13 @@ class UIType(str, Enum, metaclass=MetaEnum):
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
VAEModel = "VAEModelField"
FluxVAEModel = "FluxVAEModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
T5EncoderModel = "T5EncoderModelField"
CLIPEmbedModel = "CLIPEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
# endregion
@@ -128,6 +130,7 @@ class FieldDescriptions:
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
t5_encoder = "T5 tokenizer and text encoder"
clip_embed_model = "CLIP Embed loader"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
vae = "VAE"
@@ -178,7 +181,7 @@ class FieldDescriptions:
)
num_1 = "The first number"
num_2 = "The second number"
mask = "The mask to use for the operation"
denoise_mask = "A mask of the region to apply the denoising process to."
board = "The board to save the image to"
image = "The image to process"
tile_size = "Tile size"

View File

@@ -0,0 +1,292 @@
from typing import Callable, Iterator, Optional, Tuple
import torch
import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
FluxConditioningField,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.sampling_utils import (
clip_timestep_schedule,
generate_img_ids,
get_noise,
get_schedule,
pack,
unpack,
)
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.peft.peft_patcher import PeftPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_denoise",
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a FLUX transformer model."""
# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.denoise_mask,
input=Input.Connection,
)
denoising_start: float = InputField(
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
title="Transformer",
)
positive_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(
default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
)
guidance: float = InputField(
default=4.0,
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
latents = latents.detach().to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = torch.bfloat16
# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
# Prepare input noise.
noise = get_noise(
num_samples=1,
height=self.height,
width=self.width,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
seed=self.seed,
)
transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in transformer_info.config.config_path
# Calculate the timestep schedule.
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=image_seq_len,
shift=not is_schnell,
)
# Clip the timesteps schedule based on denoising_start and denoising_end.
timesteps = clip_timestep_schedule(timesteps, self.denoising_start, self.denoising_end)
# Prepare input latent image.
if init_latents is not None:
# If init_latents is provided, we are doing image-to-image.
if is_schnell:
context.logger.warning(
"Running image-to-image with a FLUX schnell model. This is not recommended. The results are likely "
"to be poor. Consider using a FLUX dev model instead."
)
# Noise the orig_latents by the appropriate amount for the first timestep.
t_0 = timesteps[0]
x = t_0 * noise + (1.0 - t_0) * init_latents
else:
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
if self.denoising_start > 1e-5:
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
x = noise
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
# denoising steps.
if len(timesteps) <= 1:
return x
inpaint_mask = self._prep_inpaint_mask(context, x)
b, _c, h, w = x.shape
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
# Pack all latent tensors.
init_latents = pack(init_latents) if init_latents is not None else None
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
noise = pack(noise)
x = pack(x)
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
assert image_seq_len == x.shape[1]
# Prepare inpaint extension.
inpaint_extension: InpaintExtension | None = None
if inpaint_mask is not None:
assert init_latents is not None
inpaint_extension = InpaintExtension(
init_latents=init_latents,
inpaint_mask=inpaint_mask,
noise=noise,
)
with (
transformer_info.model_on_device() as (cached_weights, transformer),
# Apply the LoRA after transformer has been moved to its target device for faster patching.
PeftPatcher.apply_peft_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix="",
cached_weights=cached_weights,
),
):
assert isinstance(transformer, Flux)
x = denoise(
model=transformer,
img=x,
img_ids=img_ids,
txt=t5_embeddings,
txt_ids=txt_ids,
vec=clip_embeddings,
timesteps=timesteps,
step_callback=self._build_step_callback(context),
guidance=self.guidance,
inpaint_extension=inpaint_extension,
)
x = unpack(x.float(), self.height, self.width)
return x
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
"""Prepare the inpaint mask.
- Loads the mask
- Resizes if necessary
- Casts to same device/dtype as latents
- Expands mask to the same shape as latents so that they line up after 'packing'
Args:
context (InvocationContext): The invocation context, for loading the inpaint mask.
latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
device, and dtype for the inpaint mask.
Returns:
torch.Tensor | None: Inpaint mask.
"""
if self.denoise_mask is None:
return None
mask = context.tensors.load(self.denoise_mask.mask_name)
_, _, latent_height, latent_width = latents.shape
mask = tv_resize(
img=mask,
size=[latent_height, latent_width],
interpolation=tv_transforms.InterpolationMode.BILINEAR,
antialias=False,
)
mask = mask.to(device=latents.device, dtype=latents.dtype)
# Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
# `latents`.
return mask.expand_as(latents)
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
def _build_step_callback(self, context: InvocationContext) -> Callable[[], None]:
def step_callback() -> None:
if context.util.is_canceled():
raise CanceledException
# TODO: Make this look like the image before re-enabling
# latent_image = unpack(img.float(), self.height, self.width)
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
# # Create a new tensor of the required shape [255, 255, 3]
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
# # Convert to a NumPy array and then to a PIL Image
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
# (width, height) = image.size
# width *= 8
# height *= 8
# dataURL = image_to_dataURL(image, image_format="JPEG")
# # TODO: move this whole function to invocation context to properly reference these variables
# context._services.events.emit_invocation_denoise_progress(
# context._data.queue_item,
# context._data.invocation,
# state,
# ProgressImage(dataURL=dataURL, width=width, height=height),
# )
return step_callback

View File

@@ -0,0 +1,53 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_lora_loader_output")
class FluxLoRALoaderOutput(BaseInvocationOutput):
"""FLUX LoRA Loader Output"""
transformer: TransformerField = OutputField(
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
@invocation(
"flux_lora_loader",
title="FLUX LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.0.0",
)
class FluxLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
transformer: TransformerField = InputField(
description=FieldDescriptions.transformer,
input=Input.Connection,
title="FLUX Transformer",
)
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
if any(lora.lora.key == lora_key for lora in self.transformer.loras):
raise Exception(f'LoRA "{lora_key}" already applied to transformer.')
transformer = self.transformer.model_copy(deep=True)
transformer.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
return FluxLoRALoaderOutput(transformer=transformer)

View File

@@ -40,7 +40,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
t5_embeddings, clip_embeddings = self._encode_prompt(context)
# Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
# scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
t5_embeddings = self._t5_encode(context)
clip_embeddings = self._clip_encode(context)
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
)
@@ -48,12 +51,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput.build(conditioning_name)
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
# Load CLIP.
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
# Load T5.
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
@@ -70,6 +68,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
prompt_embeds = t5_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
prompt = [self.prompt]
with (
clip_text_encoder_info as clip_text_encoder,
clip_tokenizer_info as clip_tokenizer,
@@ -81,6 +88,5 @@ class FluxTextEncoderInvocation(BaseInvocation):
pooled_prompt_embeds = clip_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return prompt_embeds, pooled_prompt_embeds
return pooled_prompt_embeds

View File

@@ -1,172 +0,0 @@
import torch
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_text_to_image",
title="FLUX Text to Image",
tags=["image", "flux"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Text-to-image generation using a FLUX model."""
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
title="Transformer",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
positive_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(
default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
)
guidance: float = InputField(
default=4.0,
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
image = self._run_vae_decoding(context, latents)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
def _run_diffusion(
self,
context: InvocationContext,
clip_embeddings: torch.Tensor,
t5_embeddings: torch.Tensor,
):
transformer_info = context.models.load(self.transformer.transformer)
inference_dtype = torch.bfloat16
# Prepare input noise.
x = get_noise(
num_samples=1,
height=self.height,
width=self.width,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
seed=self.seed,
)
img, img_ids = prepare_latent_img_patches(x)
is_schnell = "schnell" in transformer_info.config.config_path
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=img.shape[1],
shift=not is_schnell,
)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
# if the cache is not empty.
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
with transformer_info as transformer:
assert isinstance(transformer, Flux)
def step_callback() -> None:
if context.util.is_canceled():
raise CanceledException
# TODO: Make this look like the image before re-enabling
# latent_image = unpack(img.float(), self.height, self.width)
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
# # Create a new tensor of the required shape [255, 255, 3]
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
# # Convert to a NumPy array and then to a PIL Image
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
# (width, height) = image.size
# width *= 8
# height *= 8
# dataURL = image_to_dataURL(image, image_format="JPEG")
# # TODO: move this whole function to invocation context to properly reference these variables
# context._services.events.emit_invocation_denoise_progress(
# context._data.queue_item,
# context._data.invocation,
# state,
# ProgressImage(dataURL=dataURL, width=width, height=height),
# )
x = denoise(
model=transformer,
img=img,
img_ids=img_ids,
txt=t5_embeddings,
txt_ids=txt_ids,
vec=clip_embeddings,
timesteps=timesteps,
step_callback=step_callback,
guidance=self.guidance,
)
x = unpack(x.float(), self.height, self.width)
return x
def _run_vae_decoding(
self,
context: InvocationContext,
latents: torch.Tensor,
) -> Image.Image:
vae_info = context.models.load(self.vae.vae)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
img = vae.decode(latents)
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
return img_pil

View File

@@ -0,0 +1,60 @@
import torch
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_vae_decode",
title="FLUX Latents to Image",
tags=["latents", "image", "vae", "l2i", "flux"],
category="latents",
version="1.0.0",
)
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
img = vae.decode(latents)
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
return img_pil
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
image = self._vae_decode(vae_info=vae_info, latents=latents)
TorchDevice.empty_cache()
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,67 @@
import einops
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_vae_encode",
title="FLUX Image to Latents",
tags=["latents", "image", "vae", "i2l", "flux"],
category="latents",
version="1.0.0",
)
class FluxVaeEncodeInvocation(BaseInvocation):
"""Encodes an image into latents."""
image: ImageField = InputField(
description="The image to encode.",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
# TODO(ryand): Expose seed parameter at the invocation level.
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
# should be used for VAE encode sampling.
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
image_tensor = image_tensor.to(
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
)
latents = vae.encode(image_tensor, sample=True, generator=generator)
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

View File

@@ -126,7 +126,7 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
title="Tensor Mask to Image",
tags=["mask"],
category="mask",
version="1.0.0",
version="1.1.0",
)
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Convert a mask tensor to an image."""
@@ -135,6 +135,11 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
mask = context.tensors.load(self.mask.tensor_name)
# Squeeze the channel dimension if it exists.
if mask.dim() == 3:
mask = mask.squeeze(0)
# Ensure that the mask is binary.
if mask.dtype != torch.bool:
mask = mask > 0.5

View File

@@ -69,6 +69,7 @@ class CLIPField(BaseModel):
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class T5EncoderField(BaseModel):
@@ -157,7 +158,7 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.3",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
@@ -169,80 +170,46 @@ class FluxModelLoaderInvocation(BaseInvocation):
input=Input.Direct,
)
t5_encoder: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder,
ui_type=UIType.T5EncoderModel,
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
model_key = self.model.key
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
if not context.models.exists(model_key):
raise ValueError(f"Unknown model: {model_key}")
transformer = self._get_model(context, SubModelType.Transformer)
tokenizer = self._get_model(context, SubModelType.Tokenizer)
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
vae = self._get_model(context, SubModelType.VAE)
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
match submodel:
case SubModelType.Transformer:
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
case SubModelType.VAE:
return self._pull_model_from_mm(
context,
SubModelType.VAE,
"FLUX.1-schnell_ae",
ModelType.VAE,
BaseModelType.Flux,
)
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
return self._pull_model_from_mm(
context,
submodel,
"clip-vit-large-patch14",
ModelType.CLIPEmbed,
BaseModelType.Any,
)
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
return self._pull_model_from_mm(
context,
submodel,
self.t5_encoder.name,
ModelType.T5Encoder,
BaseModelType.Any,
)
case _:
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
def _pull_model_from_mm(
self,
context: InvocationContext,
submodel: SubModelType,
name: str,
type: ModelType,
base: BaseModelType,
):
if models := context.models.search_by_attrs(name=name, base=base, type=type):
if len(models) != 1:
raise Exception(f"Multiple models detected for selected model with name {name}")
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
else:
raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")
@invocation(
"main_model_loader",

View File

@@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,

View File

@@ -103,7 +103,7 @@ class HFModelSource(StringLikeSource):
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
base += f"::{self.subfolder.as_posix()}"
return base

View File

@@ -0,0 +1,407 @@
{
"name": "FLUX Image to Image",
"author": "InvokeAI",
"description": "A simple image-to-image workflow using a FLUX dev model. ",
"version": "1.0.4",
"contact": "",
"tags": "image2image, flux, image-to-image",
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend using FLUX dev models for image-to-image workflows. The image-to-image performance with FLUX schnell models is poor.",
"exposedFields": [
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "t5_encoder_model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "clip_embed_model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "vae_model"
},
{
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
"fieldName": "denoising_start"
},
{
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"fieldName": "prompt"
},
{
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
"fieldName": "num_steps"
}
],
"meta": {
"version": "3.0.0",
"category": "default"
},
"nodes": [
{
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
"type": "invocation",
"data": {
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
"type": "flux_vae_encode",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"image": {
"name": "image",
"label": "",
"value": {
"image_name": "8a5c62aa-9335-45d2-9c71-89af9fc1f8d4.png"
}
},
"vae": {
"name": "vae",
"label": ""
}
}
},
"position": {
"x": 732.7680166609682,
"y": -24.37398171806909
}
},
{
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
"type": "invocation",
"data": {
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
"type": "flux_denoise",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"denoise_mask": {
"name": "denoise_mask",
"label": ""
},
"denoising_start": {
"name": "denoising_start",
"label": "",
"value": 0.04
},
"denoising_end": {
"name": "denoising_end",
"label": "",
"value": 1
},
"transformer": {
"name": "transformer",
"label": ""
},
"positive_text_conditioning": {
"name": "positive_text_conditioning",
"label": ""
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"num_steps": {
"name": "num_steps",
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
"value": 30
},
"guidance": {
"name": "guidance",
"label": "",
"value": 4
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 1182.8836633018684,
"y": -251.38882958913183
}
},
{
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"type": "invocation",
"data": {
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"type": "flux_vae_decode",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
}
}
},
"position": {
"x": 1575.5797431839133,
"y": -209.00150975507415
}
},
{
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "invocation",
"data": {
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "flux_model_loader",
"version": "1.0.4",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"inputs": {
"model": {
"name": "model",
"label": "Model (dev variant recommended for Image-to-Image)"
},
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
},
"clip_embed_model": {
"name": "clip_embed_model",
"label": "",
"value": {
"key": "fa23a584-b623-415d-832a-21b5098ff1a1",
"hash": "blake3:17c19f0ef941c3b7609a9c94a659ca5364de0be364a91d4179f0e39ba17c3b70",
"name": "clip-vit-large-patch14",
"base": "any",
"type": "clip_embed"
}
},
"vae_model": {
"name": "vae_model",
"label": "",
"value": {
"key": "74fc82ba-c0a8-479d-a890-2126f82da758",
"hash": "blake3:ce21cb76364aa6e2421311cf4a4b5eb052a76c4f1cd207b50703d8978198a068",
"name": "FLUX.1-schnell_ae",
"base": "flux",
"type": "vae"
}
}
}
},
"position": {
"x": 328.1809894659957,
"y": -90.2241133566946
}
},
{
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"type": "invocation",
"data": {
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"type": "flux_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"clip": {
"name": "clip",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"t5_max_seq_len": {
"name": "t5_max_seq_len",
"label": "T5 Max Seq Len",
"value": 256
},
"prompt": {
"name": "prompt",
"label": "",
"value": "a cat wearing a birthday hat"
}
}
},
"position": {
"x": 745.8823365057267,
"y": -299.60249175851914
}
},
{
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
"type": "invocation",
"data": {
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
"type": "rand_int",
"version": "1.0.1",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"inputs": {
"low": {
"name": "low",
"label": "",
"value": 0
},
"high": {
"name": "high",
"label": "",
"value": 2147483647
}
}
},
"position": {
"x": 725.834098928012,
"y": 496.2710031089931
}
}
],
"edges": [
{
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bheight-ace0258f-67d7-4eee-a218-6fff27065214height",
"type": "default",
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "height",
"targetHandle": "height"
},
{
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bwidth-ace0258f-67d7-4eee-a218-6fff27065214width",
"type": "default",
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "width",
"targetHandle": "width"
},
{
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912blatents-ace0258f-67d7-4eee-a218-6fff27065214latents",
"type": "default",
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-2981a67c-480f-4237-9384-26b68dbf912bvae",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "2981a67c-480f-4237-9384-26b68dbf912b",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-ace0258f-67d7-4eee-a218-6fff27065214latents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
"type": "default",
"source": "ace0258f-67d7-4eee-a218-6fff27065214",
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-ace0258f-67d7-4eee-a218-6fff27065214seed",
"type": "default",
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-ace0258f-67d7-4eee-a218-6fff27065214transformer",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-ace0258f-67d7-4eee-a218-6fff27065214positive_text_conditioning",
"type": "default",
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
"sourceHandle": "conditioning",
"targetHandle": "positive_text_conditioning"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "max_seq_len",
"targetHandle": "t5_max_seq_len"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "clip",
"targetHandle": "clip"
}
]
}

View File

@@ -1,27 +1,35 @@
{
"name": "FLUX Text to Image",
"author": "InvokeAI",
"description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
"version": "1.0.0",
"description": "A simple text-to-image workflow using FLUX dev or schnell models.",
"version": "1.0.4",
"contact": "",
"tags": "text2image, flux",
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
"exposedFields": [
{
"nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "t5_encoder_model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "clip_embed_model"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "vae_model"
},
{
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"fieldName": "prompt"
},
{
"nodeId": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"nodeId": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"fieldName": "num_steps"
},
{
"nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"fieldName": "t5_encoder"
}
],
"meta": {
@@ -30,12 +38,127 @@
},
"nodes": [
{
"id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"type": "invocation",
"data": {
"id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"type": "flux_denoise",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"denoise_mask": {
"name": "denoise_mask",
"label": ""
},
"denoising_start": {
"name": "denoising_start",
"label": "",
"value": 0
},
"denoising_end": {
"name": "denoising_end",
"label": "",
"value": 1
},
"transformer": {
"name": "transformer",
"label": ""
},
"positive_text_conditioning": {
"name": "positive_text_conditioning",
"label": ""
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"num_steps": {
"name": "num_steps",
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
"value": 30
},
"guidance": {
"name": "guidance",
"label": "",
"value": 4
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 1186.1868226120378,
"y": -214.9459927686657
}
},
{
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"type": "invocation",
"data": {
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"type": "flux_vae_decode",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
}
}
},
"position": {
"x": 1575.5797431839133,
"y": -209.00150975507415
}
},
{
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "invocation",
"data": {
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "flux_model_loader",
"version": "1.0.3",
"version": "1.0.4",
"label": "",
"notes": "",
"isOpen": true,
@@ -44,31 +167,25 @@
"inputs": {
"model": {
"name": "model",
"label": "Model (Starter Models can be found in Model Manager)",
"value": {
"key": "f04a7a2f-c74d-4538-8d5e-879a53501662",
"hash": "random:4875da7a9508444ffa706f61961c260d0c6729f6181a86b31fad06df1277b850",
"name": "FLUX Dev (Quantized)",
"base": "flux",
"type": "main"
}
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": "T 5 Encoder (Starter Models can be found in Model Manager)",
"value": {
"key": "20dcd9ec-5fbb-4012-8401-049e707da5e5",
"hash": "random:f986be43ff3502169e4adbdcee158afb0e0a65a1edc4cab16ae59963630cfd8f",
"name": "t5_bnb_int8_quantized_encoder",
"base": "any",
"type": "t5_encoder"
}
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
},
"clip_embed_model": {
"name": "clip_embed_model",
"label": ""
},
"vae_model": {
"name": "vae_model",
"label": ""
}
}
},
"position": {
"x": 337.09365228062825,
"y": 40.63469521079861
"x": 381.1882713063478,
"y": -95.89663532854017
}
},
{
@@ -105,8 +222,8 @@
}
},
"position": {
"x": 824.1970602278849,
"y": 146.98251001061735
"x": 778.4899149328337,
"y": -100.36469216659502
}
},
{
@@ -135,132 +252,75 @@
}
},
"position": {
"x": 822.9899179655476,
"y": 360.9657214885052
}
},
{
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"type": "invocation",
"data": {
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"type": "flux_text_to_image",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"transformer": {
"name": "transformer",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
},
"positive_text_conditioning": {
"name": "positive_text_conditioning",
"label": ""
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"num_steps": {
"name": "num_steps",
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
"value": 30
},
"guidance": {
"name": "guidance",
"label": "",
"value": 4
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 1216.3900791301849,
"y": 5.500841807102248
"x": 800.9667463219505,
"y": 285.8297267547506
}
}
],
"edges": [
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33amax_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-4fe24f07-f906-4f55-ab2c-9beee56ef5bdtransformer",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-4fe24f07-f906-4f55-ab2c-9beee56ef5bdpositive_text_conditioning",
"type": "default",
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"sourceHandle": "conditioning",
"targetHandle": "positive_text_conditioning"
},
{
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-4fe24f07-f906-4f55-ab2c-9beee56ef5bdseed",
"type": "default",
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-4fe24f07-f906-4f55-ab2c-9beee56ef5bdlatents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
"type": "default",
"source": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "max_seq_len",
"targetHandle": "t5_max_seq_len"
},
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33avae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33atransformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33at5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33aclip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
"type": "default",
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "conditioning",
"targetHandle": "positive_text_conditioning"
},
{
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-159bdf1b-79e7-4174-b86e-d40e646964c8seed",
"type": "default",
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "value",
"targetHandle": "seed"
}
]
}

View File

@@ -0,0 +1,45 @@
from typing import Callable
import torch
from tqdm import tqdm
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
def denoise(
model: Flux,
# model input
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
vec: torch.Tensor,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[], None],
guidance: float,
inpaint_extension: InpaintExtension | None,
):
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
img = img + (t_prev - t_curr) * pred
if inpaint_extension is not None:
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
step_callback()
return img

View File

@@ -0,0 +1,35 @@
import torch
class InpaintExtension:
"""A class for managing inpainting with FLUX."""
def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
"""Initialize InpaintExtension.
Args:
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). In 'packed' format.
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
inpainted region with the background. In 'packed' format.
noise (torch.Tensor): The noise tensor used to noise the init_latents. In 'packed' format.
"""
assert init_latents.shape == inpaint_mask.shape == noise.shape
self._init_latents = init_latents
self._inpaint_mask = inpaint_mask
self._noise = noise
def merge_intermediate_latents_with_init_latents(
self, intermediate_latents: torch.Tensor, timestep: float
) -> torch.Tensor:
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
trajectory.
This function should be called after each denoising step.
"""
# Noise the init latents for the current timestep.
noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)

View File

@@ -258,16 +258,17 @@ class Decoder(nn.Module):
class DiagonalGaussian(nn.Module):
def __init__(self, sample: bool = True, chunk_dim: int = 1):
def __init__(self, chunk_dim: int = 1):
super().__init__()
self.sample = sample
self.chunk_dim = chunk_dim
def forward(self, z: Tensor) -> Tensor:
def forward(self, z: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
if sample:
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
# Unfortunately, torch.randn_like(...) does not accept a generator argument at the time of writing, so we
# have to use torch.randn(...) instead.
return mean + std * torch.randn(size=mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)
else:
return mean
@@ -297,8 +298,21 @@ class AutoEncoder(nn.Module):
self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor
def encode(self, x: Tensor) -> Tensor:
z = self.reg(self.encoder(x))
def encode(self, x: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
"""Run VAE encoding on input tensor x.
Args:
x (Tensor): Input image tensor. Shape: (batch_size, in_channels, height, width).
sample (bool, optional): If True, sample from the encoded distribution, else, return the distribution mean.
Defaults to True.
generator (torch.Generator | None, optional): Optional random number generator for reproducibility.
Defaults to None.
Returns:
Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
"""
z = self.reg(self.encoder(x), sample=sample, generator=generator)
z = self.scale_factor * (z - self.shift_factor)
return z

View File

@@ -1,176 +0,0 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import math
from typing import Callable
import torch
from einops import rearrange, repeat
from torch import Tensor
from tqdm import tqdm
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.conditioner import HFEncoder
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}
def time_shift(mu: float, sigma: float, t: Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# eastimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: Flux,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[], None],
guidance: float = 4.0,
):
dtype = model.txt_in.bias.dtype
# TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
img = img.to(dtype=dtype)
img_ids = img_ids.to(dtype=dtype)
txt = txt.to(dtype=dtype)
txt_ids = txt_ids.to(dtype=dtype)
vec = vec.to(dtype=dtype)
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
img = img + (t_prev - t_curr) * pred
step_callback()
return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert an input image in latent space to patches for diffusion.
This implementation was extracted from:
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
Returns:
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
"""
bs, c, h, w = latent_img.shape
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
# Generate patch position ids.
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
return img, img_ids

View File

@@ -0,0 +1,135 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import math
from typing import Callable
import torch
from einops import rearrange, repeat
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# estimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def _find_last_index_ge_val(timesteps: list[float], val: float, eps: float = 1e-6) -> int:
"""Find the last index in timesteps that is >= val.
We use epsilon-close equality to avoid potential floating point errors.
"""
idx = len(list(filter(lambda t: t >= (val - eps), timesteps))) - 1
assert idx >= 0
return idx
def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoising_end: float) -> list[float]:
"""Clip the timestep schedule to the denoising range.
Args:
timesteps (list[float]): The original timestep schedule: [1.0, ..., 0.0].
denoising_start (float): A value in [0, 1] specifying the start of the denoising process. E.g. a value of 0.2
would mean that the denoising process start at the last timestep in the schedule >= 0.8.
denoising_end (float): A value in [0, 1] specifying the end of the denoising process. E.g. a value of 0.8 would
mean that the denoising process end at the last timestep in the schedule >= 0.2.
Returns:
list[float]: The clipped timestep schedule.
"""
assert 0.0 <= denoising_start <= 1.0
assert 0.0 <= denoising_end <= 1.0
assert denoising_start <= denoising_end
t_start_val = 1.0 - denoising_start
t_end_val = 1.0 - denoising_end
t_start_idx = _find_last_index_ge_val(timesteps, t_start_val)
t_end_idx = _find_last_index_ge_val(timesteps, t_end_val)
clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]
return clipped_timesteps
def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""Unpack flat array of patch embeddings to latent image."""
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
def pack(x: torch.Tensor) -> torch.Tensor:
"""Pack latent image to flattented array of patch embeddings."""
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
"""Generate tensor of image position ids.
Args:
h (int): Height of image in latent space.
w (int): Width of image in latent space.
batch_size (int): Batch size.
device (torch.device): Device.
dtype (torch.dtype): dtype.
Returns:
torch.Tensor: Image position ids.
"""
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
return img_ids

View File

@@ -1,672 +0,0 @@
# Copyright (c) 2024 The InvokeAI Development team
"""LoRA model support."""
import bisect
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
from safetensors.torch import load_file
from typing_extensions import Self
import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.raw_model import RawModel
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias
return params
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys:
logger.warning(
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
self.mid = values.get("lora_mid.weight", None)
self.rank = self.down.shape[0]
self.check_keys(
values,
{
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
self.t1 = values.get("hada_t1", None)
self.t2 = values.get("hada_t2", None)
self.rank = self.w1_b.shape[0]
self.check_keys(
values,
{
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.w1 = values.get("lokr_w1", None)
if self.w1 is None:
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
else:
self.w1_b = None
self.w1_a = None
self.w2 = values.get("lokr_w2", None)
if self.w2 is None:
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
else:
self.w2_a = None
self.w2_b = None
self.t2 = values.get("lokr_t2", None)
if self.w1_b is not None:
self.rank = self.w1_b.shape[0]
elif self.w2_b is not None:
self.rank = self.w2_b.shape[0]
else:
self.rank = None # unscaled
self.check_keys(
values,
{
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
assert self.w1_b is not None
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
assert self.w2_a is not None
assert self.w2_b is not None
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
assert w1 is not None
assert w2 is not None
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class FullLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["diff"]
self.bias = values.get("diff_b", None)
self.rank = None # unscaled
self.check_keys(values, {"diff", "diff_b"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
class IA3Layer(LoRALayerBase):
# weight: torch.Tensor
# on_input: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["weight"]
self.on_input = values["on_input"]
self.rank = None # unscaled
self.check_keys(values, {"weight", "on_input"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
assert orig_weight is not None
return orig_weight * weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
class NormLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["w_norm"]
self.bias = values.get("b_norm", None)
self.rank = None # unscaled
self.check_keys(values, {"w_norm", "b_norm"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
class LoRAModelRaw(RawModel): # (torch.nn.Module):
_name: str
layers: Dict[str, AnyLoRALayer]
def __init__(
self,
name: str,
layers: Dict[str, AnyLoRALayer],
):
self._name = name
self.layers = layers
@property
def name(self) -> str:
return self._name
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
# TODO: try revert if exception?
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
return new_state_dict
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
) -> Self:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
name=file_path.stem,
layers={},
)
if file_path.suffix == ".safetensors":
sd = load_file(file_path.absolute().as_posix(), device="cpu")
else:
sd = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(sd)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items():
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
# lora and locon
if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_a" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1" in values or "lokr_w1_a" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "on_input" in values:
layer = IA3Layer(layer_key, values)
# norms
elif "w_norm" in values:
layer = NormLayer(layer_key, values)
else:
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
}

View File

@@ -66,12 +66,14 @@ class ModelLoader(ModelLoaderBase):
return (model_base / config.path).resolve()
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
try:
return self._ram_cache.get(config.key, submodel_type)
return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
except IndexError:
pass
config.path = str(self._get_model_path(config))
self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
loaded_model = self._load_model(config, submodel_type)
self._ram_cache.put(
@@ -83,7 +85,7 @@ class ModelLoader(ModelLoaderBase):
return self._ram_cache.get(
key=config.key,
submodel_type=submodel_type,
stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
stats_name=stats_name,
)
def get_size_fs(

View File

@@ -128,7 +128,24 @@ class ModelCacheBase(ABC, Generic[T]):
@property
@abstractmethod
def max_cache_size(self) -> float:
"""Return true if the cache is configured to lazily offload models in VRAM."""
"""Return the maximum size the RAM cache can grow to."""
pass
@max_cache_size.setter
@abstractmethod
def max_cache_size(self, value: float) -> None:
"""Set the cap on vram cache size."""
@property
@abstractmethod
def max_vram_cache_size(self) -> float:
"""Return the maximum size the VRAM cache can grow to."""
pass
@max_vram_cache_size.setter
@abstractmethod
def max_vram_cache_size(self, value: float) -> float:
"""Set the maximum size the VRAM cache can grow to."""
pass
@abstractmethod
@@ -193,15 +210,6 @@ class ModelCacheBase(ABC, Generic[T]):
"""
pass
@abstractmethod
def exists(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
) -> bool:
"""Return true if the model identified by key and submodel_type is in the cache."""
pass
@abstractmethod
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""

View File

@@ -1,22 +1,6 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:
cache = ModelCache(max_cache_size=7.5)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2)
"""
""" """
import gc
import math
@@ -40,53 +24,74 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
# Size of a GB in bytes.
GB = 2**30
# Size of a MB in bytes.
MB = 2**20
class ModelCache(ModelCacheBase[AnyModel]):
"""Implementation of ModelCacheBase."""
"""A cache for managing models in memory.
The cache is based on two levels of model storage:
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
The model cache is based on the following assumptions:
- storage_device_mem_size > execution_device_mem_size
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
the execution_device.
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
configuration.
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
the context, and unload outside the context.
Example usage:
```
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
do_something_on_gpu(SD1)
```
"""
def __init__(
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
max_cache_size: float,
max_vram_cache_size: float,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
"""
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param max_cache_size: Maximum size of the storage_device cache in GBs.
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
:param logger: InvokeAILogger to use (otherwise creates one)
"""
# allow lazy offloading only when vram cache enabled
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
@@ -128,6 +133,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""Set the cap on cache size."""
self._max_cache_size = value
@property
def max_vram_cache_size(self) -> float:
"""Return the cap on vram cache size."""
return self._max_vram_cache_size
@max_vram_cache_size.setter
def max_vram_cache_size(self, value: float) -> None:
"""Set the cap on vram cache size."""
self._max_vram_cache_size = value
@property
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
@@ -145,15 +160,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
total += cache_record.size
return total
def exists(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
) -> bool:
"""Return true if the model identified by key and submodel_type is in the cache."""
key = self._make_cache_key(key, submodel_type)
return key in self._cached_models
def put(
self,
key: str,
@@ -203,7 +209,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
# more stats
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.cache_size = int(self._max_cache_size * GB)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
@@ -231,10 +237,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
return model_key
def offload_unlocked_models(self, size_required: int) -> None:
"""Move any unused models from VRAM."""
reserved = self._max_vram_cache_size * GIG
"""Offload models from the execution_device to make room for size_required.
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
"""
reserved = self._max_vram_cache_size * GB
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
@@ -245,7 +254,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
)
TorchDevice.empty_cache()
@@ -303,7 +312,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self.logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
@@ -326,14 +335,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f" {(cache_entry.size/GB):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % (self.cache_size() / GIG)
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
ram = "%4.2fG" % (self.cache_size() / GB)
in_ram_models = 0
in_vram_models = 0
@@ -353,17 +362,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
)
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size."""
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
"""Make enough room in the cache to accommodate a new model of indicated size.
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
external references to the model, there's nothing that the cache can do about it, and those models will not be
garbage-collected.
"""
bytes_needed = size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
current_size = self.cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GIG):.2f} GB"
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GB):.2f} GB"
)
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
@@ -380,7 +392,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
if not cache_entry.locked:
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1

View File

@@ -5,8 +5,10 @@ from logging import Logger
from pathlib import Path
from typing import Optional
import torch
from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -18,6 +20,11 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.peft.conversions.flux_kohya_lora_conversion_utils import (
lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.peft.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.peft.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@@ -45,14 +52,28 @@ class LoRALoader(ModelLoader):
raise ValueError("There are no submodels in a LoRA model.")
model_path = Path(config.path)
assert self._model_base is not None
model = LoRAModelRaw.from_checkpoint(
file_path=model_path,
dtype=self._torch_dtype,
base_model=self._model_base,
)
# Load the state dict from the model file.
if model_path.suffix == ".safetensors":
state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(model_path, map_location="cpu")
# Apply state_dict key conversions, if necessary.
if self._model_base == BaseModelType.StableDiffusionXL:
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
model = lora_model_from_sd_state_dict(state_dict=state_dict)
elif self._model_base == BaseModelType.Flux:
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
# Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
model = lora_model_from_sd_state_dict(state_dict=state_dict)
else:
raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
model.to(dtype=self._torch_dtype)
return model
# override
def _get_model_path(self, config: AnyModelConfig) -> Path:
# cheating a little - we remember this variable for using in the subsequent call to _load_model()
self._model_base = config.base

View File

@@ -15,9 +15,9 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw

View File

@@ -26,6 +26,7 @@ from invokeai.backend.model_manager.config import (
SchedulerPredictionType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.peft.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -528,9 +529,11 @@ class LoRACheckpointProbe(CheckpointProbeBase):
return ModelFormat("lycoris")
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
token_vector_length = lora_token_vector_length(checkpoint)
if is_state_dict_likely_in_flux_kohya_format(self.checkpoint):
return BaseModelType.Flux
# If we've gotten here, we assume that the model is a Stable Diffusion model.
token_vector_length = lora_token_vector_length(self.checkpoint)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:

View File

@@ -13,10 +13,10 @@ from diffusers import OnnxRuntimeModel, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

View File

View File

@@ -0,0 +1,84 @@
import re
from typing import Any, Dict, TypeVar
import torch
from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.peft.layers.utils import peft_layer_from_state_dict
from invokeai.backend.peft.lora import LoRAModelRaw
# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
# Example keys:
# lora_unet_double_blocks_0_img_attn_proj.alpha
# lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
# lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
FLUX_KOHYA_KEY_REGEX = (
r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
for k in state_dict.keys():
if not re.match(FLUX_KOHYA_KEY_REGEX, k):
return False
return True
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
layer_name, param_name = key.split(".", 1)
if layer_name not in grouped_state_dict:
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Convert the state dict to the InvokeAI format.
grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
# Create LoRA layers.
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in grouped_state_dict.items():
layer = peft_layer_from_state_dict(layer_key, layer_state_dict)
layers[layer_key] = layer
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)
T = TypeVar("T")
def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
Example key conversions:
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
"""
def replace_func(match: re.Match[str]) -> str:
s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
if match.group(4):
s += f".{match.group(4)}"
return s
converted_dict: dict[str, T] = {}
for k, v in state_dict.items():
match = re.match(FLUX_KOHYA_KEY_REGEX, k)
if match:
new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
converted_dict[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_dict

View File

@@ -0,0 +1,30 @@
from typing import Dict
import torch
from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.peft.layers.utils import peft_layer_from_state_dict
from invokeai.backend.peft.lora import LoRAModelRaw
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)
layers: dict[str, AnyLoRALayer] = {}
for layer_key, values in grouped_state_dict.items():
layer = peft_layer_from_state_dict(layer_key, values)
layers[layer_key] = layer
return LoRAModelRaw(layers=layers)
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped

View File

@@ -0,0 +1,154 @@
import bisect
from typing import Dict, List, Tuple, TypeVar
T = TypeVar("T")
def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict: dict[str, T] = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
return new_state_dict
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer: list[tuple[str, str]] = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map: list[tuple[str, str]] = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in _make_sdxl_unet_conversion_map()
}

View File

View File

@@ -0,0 +1,10 @@
from typing import Union
from invokeai.backend.peft.layers.full_layer import FullLayer
from invokeai.backend.peft.layers.ia3_layer import IA3Layer
from invokeai.backend.peft.layers.loha_layer import LoHALayer
from invokeai.backend.peft.layers.lokr_layer import LoKRLayer
from invokeai.backend.peft.layers.lora_layer import LoRALayer
from invokeai.backend.peft.layers.norm_layer import NormLayer
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]

View File

@@ -0,0 +1,37 @@
from typing import Dict, Optional
import torch
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
class FullLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["diff"]
self.bias = values.get("diff_b", None)
self.rank = None # unscaled
self.check_keys(values, {"diff", "diff_b"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)

View File

@@ -0,0 +1,42 @@
from typing import Dict, Optional
import torch
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
class IA3Layer(LoRALayerBase):
# weight: torch.Tensor
# on_input: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["weight"]
self.on_input = values["on_input"]
self.rank = None # unscaled
self.check_keys(values, {"weight", "on_input"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
assert orig_weight is not None
return orig_weight * weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)

View File

@@ -0,0 +1,68 @@
from typing import Dict, Optional
import torch
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
self.t1 = values.get("hada_t1", None)
self.t2 = values.get("hada_t2", None)
self.rank = self.w1_b.shape[0]
self.check_keys(
values,
{
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)

View File

@@ -0,0 +1,114 @@
from typing import Dict, Optional
import torch
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.w1 = values.get("lokr_w1", None)
if self.w1 is None:
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
else:
self.w1_b = None
self.w1_a = None
self.w2 = values.get("lokr_w2", None)
if self.w2 is None:
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
else:
self.w2_a = None
self.w2_b = None
self.t2 = values.get("lokr_t2", None)
if self.w1_b is not None:
self.rank = self.w1_b.shape[0]
elif self.w2_b is not None:
self.rank = self.w2_b.shape[0]
else:
self.rank = None # unscaled
self.check_keys(
values,
{
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
assert self.w1_b is not None
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
assert self.w2_a is not None
assert self.w2_b is not None
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
assert w1 is not None
assert w2 is not None
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)

View File

@@ -0,0 +1,59 @@
from typing import Dict, Optional
import torch
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
self.mid = values.get("lora_mid.weight", None)
self.rank = self.down.shape[0]
self.check_keys(
values,
{
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)

View File

@@ -0,0 +1,74 @@
from typing import Dict, Optional, Set
import torch
import invokeai.backend.util.logging as logger
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias
return params
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys:
logger.warning(
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
)

View File

@@ -0,0 +1,37 @@
from typing import Dict, Optional
import torch
from invokeai.backend.peft.layers.lora_layer_base import LoRALayerBase
class NormLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["w_norm"]
self.bias = values.get("b_norm", None)
self.rank = None # unscaled
self.check_keys(values, {"w_norm", "b_norm"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)

View File

@@ -0,0 +1,33 @@
from typing import Dict
import torch
from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.peft.layers.full_layer import FullLayer
from invokeai.backend.peft.layers.ia3_layer import IA3Layer
from invokeai.backend.peft.layers.loha_layer import LoHALayer
from invokeai.backend.peft.layers.lokr_layer import LoKRLayer
from invokeai.backend.peft.layers.lora_layer import LoRALayer
from invokeai.backend.peft.layers.norm_layer import NormLayer
def peft_layer_from_state_dict(layer_key: str, state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
if "lora_up.weight" in state_dict:
# LoRA a.k.a LoCon
return LoRALayer(layer_key, state_dict)
elif "hada_w1_a" in state_dict:
return LoHALayer(layer_key, state_dict)
elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
return LoKRLayer(layer_key, state_dict)
elif "diff" in state_dict:
# Full a.k.a Diff
return FullLayer(layer_key, state_dict)
elif "on_input" in state_dict:
return IA3Layer(layer_key, state_dict)
elif "w_norm" in state_dict:
return NormLayer(layer_key, state_dict)
else:
raise ValueError(f"Unsupported lora format: {state_dict.keys()}")

View File

@@ -0,0 +1,22 @@
# Copyright (c) 2024 The InvokeAI Development team
from typing import Dict, Optional
import torch
from invokeai.backend.peft.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.raw_model import RawModel
class LoRAModelRaw(RawModel): # (torch.nn.Module):
def __init__(self, layers: Dict[str, AnyLoRALayer]):
self.layers = layers
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size

View File

@@ -0,0 +1,102 @@
from contextlib import contextmanager
from typing import Dict, Iterator, Optional, Tuple
import torch
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
class PeftPatcher:
@classmethod
@torch.no_grad()
@contextmanager
def apply_peft_patches(
cls,
model: torch.nn.Module,
patches: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
):
"""Apply one or more PEFT patches to a model.
:param model: The model to patch.
:param loras: An iterator that returns tuples of PEFT patches and associated weights. An iterator is used so
that the PEFT patches do not need to be loaded into memory all at once.
:param prefix: The keys in the patches will be filtered to only include weights with this prefix.
:cached_weights: Read-only copy of the model's state dict in CPU, for efficient unpatching purposes.
"""
original_weights = OriginalWeightsStorage(cached_weights)
try:
for patch, patch_weight in patches:
cls._apply_peft_patch(
model=model,
prefix=prefix,
patch=patch,
patch_weight=patch_weight,
original_weights=original_weights,
)
yield
finally:
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
@classmethod
@torch.no_grad()
def _apply_peft_patch(
cls,
model: torch.nn.Module,
prefix: str,
patch: LoRAModelRaw,
patch_weight: float,
original_weights: OriginalWeightsStorage,
):
"""
Apply one a LoRA to a model.
:param model: The model to patch.
:param patch: LoRA model to patch in.
:param patch_weight: LoRA patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
"""
if patch_weight == 0:
return
for layer_key, layer in patch.layers.items():
if not layer_key.startswith(prefix):
continue
module = model.get_submodule(layer_key)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = layer_key + "." + param_name
module_param = module.get_parameter(param_name)
# Save original weight
original_weights.save(param_key, module_param)
if module_param.shape != lora_param_weight.shape:
lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= patch_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)

View File

@@ -54,8 +54,10 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
scb = state_dict.pop(prefix + "SCB", None)
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
_weight_format = state_dict.pop(prefix + "weight_format", None)
# Currently, we only support weight_format=0.
weight_format = state_dict.pop(prefix + "weight_format", None)
assert weight_format == 0
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
# rather than raising an exception to correctly implement this API.
@@ -89,6 +91,14 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
# Reset the state. The persisted fields are based on the initialization behaviour in
# `bnb.nn.Linear8bitLt.__init__()`.
new_state = bnb.MatmulLtState()
new_state.threshold = self.state.threshold
new_state.has_fp16_weights = False
new_state.use_pool = self.state.use_pool
self.state = new_state
def _convert_linear_layers_to_llm_8bit(
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""

View File

@@ -43,6 +43,11 @@ class FLUXConditioningInfo:
clip_embeds: torch.Tensor
t5_embeds: torch.Tensor
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
return self
@dataclass
class ConditioningFieldData:

View File

@@ -12,7 +12,7 @@ from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.peft.lora import LoRAModelRaw
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

View File

@@ -3,10 +3,9 @@ Initialization file for invokeai.backend.util
"""
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import GIG, Chdir, directory_size
from invokeai.backend.util.util import Chdir, directory_size
__all__ = [
"GIG",
"directory_size",
"Chdir",
"InvokeAILogger",

View File

@@ -7,9 +7,6 @@ from pathlib import Path
from PIL import Image
# actual size of a gig
GIG = 1073741824
def slugify(value: str, allow_unicode: bool = False) -> str:
"""

View File

@@ -696,6 +696,8 @@
"availableModels": "Available Models",
"baseModel": "Base Model",
"cancel": "Cancel",
"clipEmbed": "CLIP Embed",
"clipVision": "CLIP Vision",
"config": "Config",
"convert": "Convert",
"convertingModelBegin": "Converting Model. Please wait.",
@@ -783,6 +785,7 @@
"settings": "Settings",
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
"source": "Source",
"spandrelImageToImage": "Image to Image (Spandrel)",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"syncModels": "Sync Models",
@@ -791,6 +794,7 @@
"loraTriggerPhrases": "LoRA Trigger Phrases",
"mainModelTriggerPhrases": "Main Model Trigger Phrases",
"typePhraseHere": "Type phrase here",
"t5Encoder": "T5 Encoder",
"upcastAttention": "Upcast Attention",
"uploadImage": "Upload Image",
"urlOrLocalPath": "URL or Local Path",

View File

@@ -14,6 +14,7 @@ import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageMo
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors';
import InvokeTabs from 'features/ui/components/InvokeTabs';
@@ -39,10 +40,17 @@ interface Props {
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
selectedWorkflowId?: string;
selectedStylePresetId?: string;
destination?: InvokeTabName | undefined;
}
const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
const App = ({
config = DEFAULT_CONFIG,
selectedImage,
selectedWorkflowId,
selectedStylePresetId,
destination,
}: Props) => {
const language = useAppSelector(languageSelector);
const logger = useLogger('system');
const dispatch = useAppDispatch();
@@ -81,6 +89,12 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
}
}, [selectedWorkflowId, getAndLoadWorkflow]);
useEffect(() => {
if (selectedStylePresetId) {
dispatch(activeStylePresetIdChanged(selectedStylePresetId));
}
}, [dispatch, selectedStylePresetId]);
useEffect(() => {
if (destination) {
dispatch(setActiveTab(destination));

View File

@@ -45,6 +45,7 @@ interface Props extends PropsWithChildren {
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
selectedWorkflowId?: string;
selectedStylePresetId?: string;
destination?: InvokeTabName;
customStarUi?: CustomStarUi;
socketOptions?: Partial<ManagerOptions & SocketOptions>;
@@ -66,6 +67,7 @@ const InvokeAIUI = ({
queueId,
selectedImage,
selectedWorkflowId,
selectedStylePresetId,
destination,
customStarUi,
socketOptions,
@@ -227,6 +229,7 @@ const InvokeAIUI = ({
config={config}
selectedImage={selectedImage}
selectedWorkflowId={selectedWorkflowId}
selectedStylePresetId={selectedStylePresetId}
destination={destination}
/>
</AppDndContext>

View File

@@ -175,12 +175,12 @@ const ModelList = () => {
{/* T5 Encoders List */}
{isLoadingT5EncoderModels && <FetchingModelsLoader loadingMessage="Loading T5 Encoder Models..." />}
{!isLoadingT5EncoderModels && filteredT5EncoderModels.length > 0 && (
<ModelListWrapper title="T5 Encoder" modelList={filteredT5EncoderModels} key="t5-encoder" />
<ModelListWrapper title={t('modelManager.t5Encoder')} modelList={filteredT5EncoderModels} key="t5-encoder" />
)}
{/* Clip Embed List */}
{isLoadingClipEmbedModels && <FetchingModelsLoader loadingMessage="Loading Clip Embed Models..." />}
{!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && (
<ModelListWrapper title="Clip Embed" modelList={filteredClipEmbedModels} key="clip-embed" />
<ModelListWrapper title={t('modelManager.clipEmbed')} modelList={filteredClipEmbedModels} key="clip-embed" />
)}
{/* Spandrel Image to Image List */}
{isLoadingSpandrelImageToImageModels && (
@@ -188,7 +188,7 @@ const ModelList = () => {
)}
{!isLoadingSpandrelImageToImageModels && filteredSpandrelImageToImageModels.length > 0 && (
<ModelListWrapper
title="Image-to-Image"
title={t('modelManager.spandrelImageToImage')}
modelList={filteredSpandrelImageToImageModels}
key="spandrel-image-to-image"
/>

View File

@@ -19,11 +19,10 @@ export const ModelTypeFilter = memo(() => {
controlnet: 'ControlNet',
vae: 'VAE',
t2i_adapter: t('common.t2iAdapter'),
t5_encoder: 'T5Encoder',
clip_embed: 'Clip Embed',
t5_encoder: t('modelManager.t5Encoder'),
clip_embed: t('modelManager.clipEmbed'),
ip_adapter: t('common.ipAdapter'),
clip_vision: 'Clip Vision',
spandrel_image_to_image: 'Image-to-Image',
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
}),
[t]
);

View File

@@ -6,6 +6,8 @@ import {
isBoardFieldInputTemplate,
isBooleanFieldInputInstance,
isBooleanFieldInputTemplate,
isCLIPEmbedModelFieldInputInstance,
isCLIPEmbedModelFieldInputTemplate,
isColorFieldInputInstance,
isColorFieldInputTemplate,
isControlNetModelFieldInputInstance,
@@ -16,6 +18,8 @@ import {
isFloatFieldInputTemplate,
isFluxMainModelFieldInputInstance,
isFluxMainModelFieldInputTemplate,
isFluxVAEModelFieldInputInstance,
isFluxVAEModelFieldInputTemplate,
isImageFieldInputInstance,
isImageFieldInputTemplate,
isIntegerFieldInputInstance,
@@ -49,10 +53,12 @@ import { memo } from 'react';
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
@@ -122,6 +128,13 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;

View File

@@ -0,0 +1,60 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useClipEmbedModels } from 'services/api/hooks/modelsByType';
import type { ClipEmbedModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate>;
const CLIPEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useClipEmbedModels();
const _onChange = useCallback(
(value: ClipEmbedModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldCLIPEmbedValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};
export default memo(CLIPEmbedModelFieldInputComponent);

View File

@@ -0,0 +1,60 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate>;
const FluxVAEModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useFluxVAEModels();
const _onChange = useCallback(
(value: VAEModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldFluxVAEModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};
export default memo(FluxVAEModelFieldInputComponent);

View File

@@ -6,11 +6,13 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
BoardFieldValue,
BooleanFieldValue,
CLIPEmbedModelFieldValue,
ColorFieldValue,
ControlNetModelFieldValue,
EnumFieldValue,
FieldValue,
FloatFieldValue,
FluxVAEModelFieldValue,
ImageFieldValue,
IntegerFieldValue,
IPAdapterModelFieldValue,
@@ -29,10 +31,12 @@ import type {
import {
zBoardFieldValue,
zBooleanFieldValue,
zCLIPEmbedModelFieldValue,
zColorFieldValue,
zControlNetModelFieldValue,
zEnumFieldValue,
zFloatFieldValue,
zFluxVAEModelFieldValue,
zImageFieldValue,
zIntegerFieldValue,
zIPAdapterModelFieldValue,
@@ -346,6 +350,12 @@ export const nodesSlice = createSlice({
fieldT5EncoderValueChanged: (state, action: FieldValueAction<T5EncoderModelFieldValue>) => {
fieldValueReducer(state, action, zT5EncoderModelFieldValue);
},
fieldCLIPEmbedValueChanged: (state, action: FieldValueAction<CLIPEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPEmbedModelFieldValue);
},
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
},
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue);
},
@@ -408,6 +418,8 @@ export const {
fieldStringValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
fieldFluxVAEModelValueChanged,
nodeEditorReset,
nodeIsIntermediateChanged,
nodeIsOpenChanged,
@@ -521,6 +533,8 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldStringValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
fieldFluxVAEModelValueChanged,
nodesChanged,
nodeIsIntermediateChanged,
nodeIsOpenChanged,

View File

@@ -151,6 +151,14 @@ const zT5EncoderModelFieldType = zFieldTypeBase.extend({
name: z.literal('T5EncoderModelField'),
originalType: zStatelessFieldType.optional(),
});
const zCLIPEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('FluxVAEModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(),
@@ -175,6 +183,8 @@ const zStatefulFieldType = z.union([
zT2IAdapterModelFieldType,
zSpandrelImageToImageModelFieldType,
zT5EncoderModelFieldType,
zCLIPEmbedModelFieldType,
zFluxVAEModelFieldType,
zColorFieldType,
zSchedulerFieldType,
]);
@@ -667,7 +677,53 @@ export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5Encod
export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate =>
zT5EncoderModelFieldInputTemplate.safeParse(val).success;
// #endregio
// #endregion
// #region FluxVAEModelField
export const zFluxVAEModelFieldValue = zModelIdentifierField.optional();
const zFluxVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zFluxVAEModelFieldValue,
});
const zFluxVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zFluxVAEModelFieldType,
originalType: zFieldType.optional(),
default: zFluxVAEModelFieldValue,
});
export type FluxVAEModelFieldValue = z.infer<typeof zFluxVAEModelFieldValue>;
export type FluxVAEModelFieldInputInstance = z.infer<typeof zFluxVAEModelFieldInputInstance>;
export type FluxVAEModelFieldInputTemplate = z.infer<typeof zFluxVAEModelFieldInputTemplate>;
export const isFluxVAEModelFieldInputInstance = (val: unknown): val is FluxVAEModelFieldInputInstance =>
zFluxVAEModelFieldInputInstance.safeParse(val).success;
export const isFluxVAEModelFieldInputTemplate = (val: unknown): val is FluxVAEModelFieldInputTemplate =>
zFluxVAEModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region CLIPEmbedModelField
export const zCLIPEmbedModelFieldValue = zModelIdentifierField.optional();
const zCLIPEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zCLIPEmbedModelFieldValue,
});
const zCLIPEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zCLIPEmbedModelFieldType,
originalType: zFieldType.optional(),
default: zCLIPEmbedModelFieldValue,
});
export type CLIPEmbedModelFieldValue = z.infer<typeof zCLIPEmbedModelFieldValue>;
export type CLIPEmbedModelFieldInputInstance = z.infer<typeof zCLIPEmbedModelFieldInputInstance>;
export type CLIPEmbedModelFieldInputTemplate = z.infer<typeof zCLIPEmbedModelFieldInputTemplate>;
export const isCLIPEmbedModelFieldInputInstance = (val: unknown): val is CLIPEmbedModelFieldInputInstance =>
zCLIPEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmbedModelFieldInputTemplate =>
zCLIPEmbedModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SchedulerField
@@ -758,6 +814,8 @@ export const zStatefulFieldValue = z.union([
zT2IAdapterModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zT5EncoderModelFieldValue,
zFluxVAEModelFieldValue,
zCLIPEmbedModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
]);
@@ -788,6 +846,8 @@ const zStatefulFieldInputInstance = z.union([
zT2IAdapterModelFieldInputInstance,
zSpandrelImageToImageModelFieldInputInstance,
zT5EncoderModelFieldInputInstance,
zFluxVAEModelFieldInputInstance,
zCLIPEmbedModelFieldInputInstance,
zColorFieldInputInstance,
zSchedulerFieldInputInstance,
]);
@@ -819,6 +879,8 @@ const zStatefulFieldInputTemplate = z.union([
zT2IAdapterModelFieldInputTemplate,
zSpandrelImageToImageModelFieldInputTemplate,
zT5EncoderModelFieldInputTemplate,
zFluxVAEModelFieldInputTemplate,
zCLIPEmbedModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,

View File

@@ -23,6 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
VAEModelField: undefined,
ControlNetModelField: undefined,
T5EncoderModelField: undefined,
FluxVAEModelField: undefined,
CLIPEmbedModelField: undefined,
};
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {

View File

@@ -2,6 +2,7 @@ import { FieldParseError } from 'features/nodes/types/error';
import type {
BoardFieldInputTemplate,
BooleanFieldInputTemplate,
CLIPEmbedModelFieldInputTemplate,
ColorFieldInputTemplate,
ControlNetModelFieldInputTemplate,
EnumFieldInputTemplate,
@@ -9,6 +10,7 @@ import type {
FieldType,
FloatFieldInputTemplate,
FluxMainModelFieldInputTemplate,
FluxVAEModelFieldInputTemplate,
ImageFieldInputTemplate,
IntegerFieldInputTemplate,
IPAdapterModelFieldInputTemplate,
@@ -238,6 +240,34 @@ const buildT5EncoderModelFieldInputTemplate: FieldInputTemplateBuilder<T5Encoder
return template;
};
const buildCLIPEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPEmbedModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: CLIPEmbedModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FluxVAEModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -423,6 +453,8 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
VAEModelField: buildVAEModelFieldInputTemplate,
T5EncoderModelField: buildT5EncoderModelFieldInputTemplate,
CLIPEmbedModelField: buildCLIPEmbedModelFieldInputTemplate,
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
} as const;
export const buildFieldInputTemplate = (

View File

@@ -7,6 +7,7 @@ import {
isControlNetModelConfig,
isControlNetOrT2IAdapterModelConfig,
isFluxMainModelModelConfig,
isFluxVAEModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
@@ -52,3 +53,4 @@ export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToIm
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
export const useVAEModels = buildModelsHook(isVAEModelConfig);
export const useFluxVAEModels = buildModelsHook(isFluxVAEModelConfig);

File diff suppressed because one or more lines are too long

View File

@@ -51,7 +51,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
type ClipEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
export type ClipEmbedModelConfig = S['CLIPEmbedDiffusersConfig'];
export type T5EncoderModelConfig = S['T5EncoderConfig'];
export type T5EncoderBnbQuantizedLlmInt8bModelConfig = S['T5EncoderBnbQuantizedLlmInt8bConfig'];
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
@@ -82,6 +82,10 @@ export const isVAEModelConfig = (config: AnyModelConfig): config is VAEModelConf
return config.type === 'vae';
};
export const isFluxVAEModelConfig = (config: AnyModelConfig): config is VAEModelConfig => {
return config.type === 'vae' && config.base === 'flux';
};
export const isControlNetModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig => {
return config.type === 'controlnet';
};

View File

@@ -1 +1 @@
__version__ = "4.2.8post1"
__version__ = "4.2.9rc1"

View File

@@ -130,8 +130,6 @@ dependencies = [
[project.scripts]
"invokeai-web" = "invokeai.app.run_app:run_app"
"invokeai-import-images" = "invokeai.frontend.install.import_images:main"
"invokeai-db-maintenance" = "invokeai.backend.util.db_maintenance:main"
[project.urls]
"Homepage" = "https://invoke-ai.github.io/InvokeAI/"

63
scripts/allocate_vram.py Normal file
View File

@@ -0,0 +1,63 @@
import argparse
import torch
def display_vram_usage():
"""Displays the total, allocated, and free VRAM on the current CUDA device."""
assert torch.cuda.is_available(), "CUDA is not available"
device = torch.device("cuda")
total_vram = torch.cuda.get_device_properties(device).total_memory
allocated_vram = torch.cuda.memory_allocated(device)
free_vram = total_vram - allocated_vram
print(f"Total VRAM: {total_vram / (1024 * 1024 * 1024):.2f} GB")
print(f"Allocated VRAM: {allocated_vram / (1024 * 1024 * 1024):.2f} GB")
print(f"Free VRAM: {free_vram / (1024 * 1024 * 1024):.2f} GB")
def allocate_vram(target_gb: float, target_free: bool = False):
"""Allocates VRAM on the current CUDA device. After allocation, the script will pause until the user presses Enter
or ends the script, at which point the VRAM will be released.
Args:
target_gb (float): Amount of VRAM to allocate in GB.
target_free (bool, optional): Instead of allocating <target_gb> VRAM, enough VRAM will be allocated so the system has <target_gb> of VRAM free. For example, if <target_gb> is 2 GB, the script will allocate VRAM until the free VRAM is 2 GB.
"""
assert torch.cuda.is_available(), "CUDA is not available"
device = torch.device("cuda")
if target_free:
total_vram = torch.cuda.get_device_properties(device).total_memory
free_vram = total_vram - torch.cuda.memory_allocated(device)
target_free_bytes = target_gb * 1024 * 1024 * 1024
bytes_to_allocate = free_vram - target_free_bytes
if bytes_to_allocate <= 0:
print(f"Already at or below the target free VRAM of {target_gb} GB")
return
else:
bytes_to_allocate = target_gb * 1024 * 1024 * 1024
# FloatTensor (4 bytes per element)
_tensor = torch.empty(int(bytes_to_allocate / 4), dtype=torch.float, device="cuda")
display_vram_usage()
input("Press Enter to release VRAM allocation and exit...")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Allocate VRAM for testing purposes. Only works on CUDA devices.")
parser.add_argument("target_gb", type=float, help="Amount of VRAM to allocate in GB.")
parser.add_argument(
"--target-free",
action="store_true",
help="Instead of allocating <target_gb> VRAM, enough VRAM will be allocated so the system has <target_gb> of VRAM free. For example, if <target_gb> is 2 GB, the script will allocate VRAM until the free VRAM is 2 GB.",
)
args = parser.parse_args()
allocate_vram(target_gb=args.target_gb, target_free=args.target_free)

View File

@@ -0,0 +1,42 @@
import pytest
import torch
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule
def float_lists_almost_equal(list1: list[float], list2: list[float], tol: float = 1e-6) -> bool:
return all(abs(a - b) < tol for a, b in zip(list1, list2, strict=True))
@pytest.mark.parametrize(
["denoising_start", "denoising_end", "expected_timesteps", "raises"],
[
(0.0, 1.0, [1.0, 0.75, 0.5, 0.25, 0.0], False), # Default case.
(-0.1, 1.0, [], True), # Negative denoising_start should raise.
(0.0, 1.1, [], True), # denoising_end > 1 should raise.
(0.5, 0.0, [], True), # denoising_start > denoising_end should raise.
(0.0, 0.0, [1.0], False), # denoising_end == 0.
(1.0, 1.0, [0.0], False), # denoising_start == 1.
(0.2, 0.8, [1.0, 0.75, 0.5, 0.25], False), # Middle of the schedule.
# If we denoise from 0.0 to x, then from x to 1.0, it is important that denoise_end = x and denoise_start = x
# map to the same timestep. We test this first when x is equal to a timestep, then when it falls between two
# timesteps.
# x = 0.5
(0.0, 0.5, [1.0, 0.75, 0.5], False),
(0.5, 1.0, [0.5, 0.25, 0.0], False),
# x = 0.3
(0.0, 0.3, [1.0, 0.75], False),
(0.3, 1.0, [0.75, 0.5, 0.25, 0.0], False),
],
)
def test_clip_timestep_schedule(
denoising_start: float, denoising_end: float, expected_timesteps: list[float], raises: bool
):
timesteps = torch.linspace(1, 0, 5).tolist()
if raises:
with pytest.raises(AssertionError):
clip_timestep_schedule(timesteps, denoising_start, denoising_end)
else:
assert float_lists_almost_equal(
clip_timestep_schedule(timesteps, denoising_start, denoising_end), expected_timesteps
)

View File

@@ -1,12 +1,9 @@
# test that if the model's device changes while the lora is applied, the weights can still be restored
# test that LoRA patching works on both CPU and CUDA
import pytest
import torch
from invokeai.backend.lora import LoRALayer, LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.peft.layers.lora_layer import LoRALayer
from invokeai.backend.peft.lora import LoRAModelRaw
@pytest.mark.parametrize(
@@ -38,7 +35,7 @@ def test_apply_lora(device):
},
)
}
lora = LoRAModelRaw("lora_name", lora_layers)
lora = LoRAModelRaw(lora_layers)
lora_weight = 0.5
orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
@@ -82,7 +79,7 @@ def test_apply_lora_change_device():
},
)
}
lora = LoRAModelRaw("lora_name", lora_layers)
lora = LoRAModelRaw(lora_layers)
orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()

View File

@@ -0,0 +1,990 @@
state_dict_keys = [
"transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.0.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.0.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.0.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.0.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.0.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.0.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.0.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.1.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.1.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.1.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.1.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.1.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.1.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.1.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.1.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.1.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.1.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.1.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.1.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.10.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.10.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.10.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.10.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.10.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.10.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.10.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.10.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.10.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.10.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.10.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.10.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.11.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.11.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.11.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.11.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.11.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.11.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.11.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.11.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.11.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.11.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.11.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.11.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.12.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.12.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.12.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.12.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.12.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.12.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.12.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.12.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.12.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.12.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.12.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.12.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.13.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.13.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.13.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.13.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.13.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.13.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.13.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.13.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.13.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.13.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.13.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.13.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.14.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.14.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.14.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.14.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.14.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.14.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.14.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.14.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.14.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.14.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.14.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.14.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.15.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.15.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.15.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.15.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.15.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.15.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.15.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.15.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.15.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.15.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.15.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.15.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.16.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.16.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.16.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.16.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.16.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.16.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.16.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.16.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.16.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.16.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.16.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.16.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.17.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.17.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.17.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.17.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.17.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.17.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.17.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.17.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.17.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.17.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.17.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.17.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.18.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.18.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.18.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.18.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.18.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.18.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.18.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.18.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.18.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.18.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.18.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.18.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.19.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.19.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.19.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.19.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.19.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.19.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.19.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.19.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.19.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.19.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.19.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.19.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.2.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.2.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.2.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.2.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.2.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.2.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.2.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.2.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.2.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.2.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.2.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.2.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.20.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.20.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.20.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.20.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.20.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.20.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.20.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.20.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.20.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.20.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.20.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.20.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.21.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.21.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.21.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.21.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.21.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.21.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.21.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.21.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.21.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.21.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.21.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.21.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.22.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.22.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.22.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.22.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.22.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.22.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.22.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.22.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.22.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.22.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.22.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.22.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.23.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.23.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.23.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.23.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.23.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.23.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.23.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.23.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.23.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.23.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.23.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.23.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.24.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.24.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.24.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.24.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.24.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.24.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.24.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.24.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.24.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.24.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.24.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.24.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.25.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.25.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.25.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.25.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.25.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.25.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.25.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.25.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.25.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.25.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.25.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.25.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.26.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.26.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.26.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.26.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.26.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.26.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.26.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.26.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.26.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.26.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.26.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.26.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.27.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.27.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.27.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.27.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.27.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.27.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.27.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.27.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.27.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.27.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.27.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.27.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.28.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.28.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.28.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.28.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.28.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.28.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.28.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.28.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.28.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.28.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.28.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.28.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.29.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.29.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.29.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.29.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.29.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.29.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.29.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.29.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.29.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.29.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.29.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.29.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.3.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.3.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.3.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.3.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.3.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.3.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.3.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.3.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.3.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.3.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.3.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.3.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.30.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.30.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.30.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.30.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.30.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.30.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.30.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.30.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.30.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.30.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.30.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.30.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.31.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.31.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.31.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.31.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.31.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.31.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.31.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.31.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.31.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.31.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.31.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.31.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.32.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.32.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.32.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.32.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.32.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.32.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.32.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.32.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.32.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.32.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.32.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.32.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.33.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.33.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.33.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.33.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.33.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.33.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.33.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.33.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.33.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.33.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.33.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.33.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.34.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.34.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.34.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.34.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.34.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.34.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.34.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.34.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.34.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.34.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.34.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.34.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.35.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.35.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.35.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.35.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.35.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.35.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.35.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.35.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.35.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.35.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.35.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.35.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.36.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.36.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.36.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.36.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.36.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.36.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.36.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.36.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.36.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.36.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.36.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.36.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.37.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.37.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.37.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.37.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.37.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.37.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.37.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.37.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.37.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.37.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.37.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.37.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.4.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.4.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.4.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.4.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.4.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.4.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.4.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.4.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.4.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.4.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.4.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.4.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.5.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.5.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.5.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.5.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.5.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.5.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.5.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.5.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.5.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.5.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.5.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.5.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.6.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.6.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.6.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.6.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.6.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.6.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.6.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.6.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.6.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.6.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.6.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.6.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.7.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.7.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.7.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.7.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.7.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.7.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.7.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.7.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.7.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.7.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.7.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.7.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.8.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.8.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.8.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.8.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.8.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.8.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.8.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.8.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.8.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.8.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.8.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.8.proj_out.lora_B.weight",
"transformer.single_transformer_blocks.9.attn.to_k.lora_A.weight",
"transformer.single_transformer_blocks.9.attn.to_k.lora_B.weight",
"transformer.single_transformer_blocks.9.attn.to_q.lora_A.weight",
"transformer.single_transformer_blocks.9.attn.to_q.lora_B.weight",
"transformer.single_transformer_blocks.9.attn.to_v.lora_A.weight",
"transformer.single_transformer_blocks.9.attn.to_v.lora_B.weight",
"transformer.single_transformer_blocks.9.norm.linear.lora_A.weight",
"transformer.single_transformer_blocks.9.norm.linear.lora_B.weight",
"transformer.single_transformer_blocks.9.proj_mlp.lora_A.weight",
"transformer.single_transformer_blocks.9.proj_mlp.lora_B.weight",
"transformer.single_transformer_blocks.9.proj_out.lora_A.weight",
"transformer.single_transformer_blocks.9.proj_out.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.0.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.0.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.0.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.0.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.0.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.0.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.0.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.0.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.0.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.0.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.0.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.0.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.0.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.0.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.0.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.0.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.1.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.1.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.1.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.1.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.1.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.1.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.1.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.1.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.1.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.1.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.1.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.1.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.1.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.1.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.1.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.1.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.1.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.1.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.1.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.1.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.10.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.10.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.10.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.10.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.10.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.10.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.10.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.10.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.10.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.10.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.10.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.10.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.10.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.10.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.10.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.10.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.10.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.10.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.10.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.10.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.11.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.11.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.11.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.11.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.11.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.11.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.11.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.11.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.11.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.11.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.11.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.11.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.11.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.11.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.11.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.11.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.11.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.11.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.11.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.11.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.12.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.12.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.12.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.12.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.12.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.12.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.12.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.12.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.12.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.12.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.12.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.12.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.12.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.12.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.12.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.12.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.12.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.12.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.12.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.12.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.13.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.13.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.13.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.13.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.13.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.13.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.13.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.13.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.13.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.13.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.13.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.13.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.13.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.13.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.13.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.13.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.13.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.13.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.13.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.13.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.14.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.14.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.14.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.14.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.14.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.14.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.14.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.14.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.14.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.14.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.14.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.14.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.14.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.14.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.14.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.14.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.14.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.14.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.14.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.14.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.15.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.15.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.15.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.15.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.15.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.15.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.15.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.15.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.15.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.15.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.15.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.15.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.15.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.15.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.15.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.15.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.15.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.15.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.15.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.15.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.16.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.16.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.16.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.16.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.16.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.16.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.16.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.16.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.16.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.16.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.16.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.16.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.16.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.16.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.16.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.16.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.16.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.16.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.16.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.16.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.17.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.17.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.17.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.17.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.17.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.17.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.17.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.17.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.17.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.17.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.17.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.17.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.17.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.17.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.17.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.17.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.17.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.17.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.17.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.17.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.18.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.18.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.18.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.18.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.18.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.18.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.18.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.18.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.18.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.18.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.18.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.18.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.18.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.18.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.18.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.18.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.18.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.18.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.18.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.18.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.2.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.2.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.2.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.2.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.2.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.2.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.2.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.2.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.2.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.2.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.2.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.2.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.2.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.2.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.2.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.2.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.2.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.2.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.2.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.2.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.3.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.3.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.3.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.3.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.3.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.3.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.3.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.3.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.3.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.3.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.3.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.3.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.3.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.3.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.3.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.3.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.3.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.3.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.3.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.3.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.4.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.4.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.4.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.4.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.4.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.4.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.4.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.4.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.4.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.4.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.4.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.4.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.4.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.4.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.4.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.4.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.4.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.4.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.4.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.4.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.5.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.5.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.5.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.5.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.5.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.5.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.5.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.5.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.5.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.5.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.5.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.5.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.5.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.5.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.5.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.5.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.5.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.5.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.5.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.5.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.6.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.6.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.6.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.6.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.6.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.6.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.6.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.6.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.6.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.6.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.6.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.6.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.6.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.6.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.6.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.6.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.6.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.6.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.6.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.6.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.7.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.7.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.7.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.7.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.7.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.7.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.7.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.7.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.7.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.7.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.7.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.7.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.7.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.7.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.7.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.7.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.7.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.7.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.7.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.7.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.8.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.8.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.8.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.8.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.8.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.8.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.8.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.8.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.8.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.8.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.8.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.8.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.8.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.8.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.8.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.8.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.8.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.8.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.8.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.8.norm1_context.linear.lora_B.weight",
"transformer.transformer_blocks.9.attn.add_k_proj.lora_A.weight",
"transformer.transformer_blocks.9.attn.add_k_proj.lora_B.weight",
"transformer.transformer_blocks.9.attn.add_q_proj.lora_A.weight",
"transformer.transformer_blocks.9.attn.add_q_proj.lora_B.weight",
"transformer.transformer_blocks.9.attn.add_v_proj.lora_A.weight",
"transformer.transformer_blocks.9.attn.add_v_proj.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_add_out.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_add_out.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_k.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_k.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_out.0.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_out.0.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_q.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_q.lora_B.weight",
"transformer.transformer_blocks.9.attn.to_v.lora_A.weight",
"transformer.transformer_blocks.9.attn.to_v.lora_B.weight",
"transformer.transformer_blocks.9.ff.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.9.ff.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.9.ff.net.2.lora_A.weight",
"transformer.transformer_blocks.9.ff.net.2.lora_B.weight",
"transformer.transformer_blocks.9.ff_context.net.0.proj.lora_A.weight",
"transformer.transformer_blocks.9.ff_context.net.0.proj.lora_B.weight",
"transformer.transformer_blocks.9.ff_context.net.2.lora_A.weight",
"transformer.transformer_blocks.9.ff_context.net.2.lora_B.weight",
"transformer.transformer_blocks.9.norm1.linear.lora_A.weight",
"transformer.transformer_blocks.9.norm1.linear.lora_B.weight",
"transformer.transformer_blocks.9.norm1_context.linear.lora_A.weight",
"transformer.transformer_blocks.9.norm1_context.linear.lora_B.weight",
]

View File

@@ -0,0 +1,914 @@
state_dict_keys = [
"lora_unet_double_blocks_0_img_attn_proj.alpha",
"lora_unet_double_blocks_0_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_0_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_0_img_attn_qkv.alpha",
"lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_0_img_mlp_0.alpha",
"lora_unet_double_blocks_0_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_0_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_0_img_mlp_2.alpha",
"lora_unet_double_blocks_0_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_0_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_0_img_mod_lin.alpha",
"lora_unet_double_blocks_0_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_0_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_0_txt_attn_proj.alpha",
"lora_unet_double_blocks_0_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_0_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_0_txt_attn_qkv.alpha",
"lora_unet_double_blocks_0_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_0_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_0_txt_mlp_0.alpha",
"lora_unet_double_blocks_0_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_0_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_0_txt_mlp_2.alpha",
"lora_unet_double_blocks_0_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_0_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_0_txt_mod_lin.alpha",
"lora_unet_double_blocks_0_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_0_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_10_img_attn_proj.alpha",
"lora_unet_double_blocks_10_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_10_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_10_img_attn_qkv.alpha",
"lora_unet_double_blocks_10_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_10_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_10_img_mlp_0.alpha",
"lora_unet_double_blocks_10_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_10_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_10_img_mlp_2.alpha",
"lora_unet_double_blocks_10_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_10_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_10_img_mod_lin.alpha",
"lora_unet_double_blocks_10_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_10_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_10_txt_attn_proj.alpha",
"lora_unet_double_blocks_10_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_10_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_10_txt_attn_qkv.alpha",
"lora_unet_double_blocks_10_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_10_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_10_txt_mlp_0.alpha",
"lora_unet_double_blocks_10_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_10_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_10_txt_mlp_2.alpha",
"lora_unet_double_blocks_10_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_10_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_10_txt_mod_lin.alpha",
"lora_unet_double_blocks_10_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_10_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_11_img_attn_proj.alpha",
"lora_unet_double_blocks_11_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_11_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_11_img_attn_qkv.alpha",
"lora_unet_double_blocks_11_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_11_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_11_img_mlp_0.alpha",
"lora_unet_double_blocks_11_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_11_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_11_img_mlp_2.alpha",
"lora_unet_double_blocks_11_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_11_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_11_img_mod_lin.alpha",
"lora_unet_double_blocks_11_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_11_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_11_txt_attn_proj.alpha",
"lora_unet_double_blocks_11_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_11_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_11_txt_attn_qkv.alpha",
"lora_unet_double_blocks_11_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_11_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_11_txt_mlp_0.alpha",
"lora_unet_double_blocks_11_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_11_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_11_txt_mlp_2.alpha",
"lora_unet_double_blocks_11_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_11_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_11_txt_mod_lin.alpha",
"lora_unet_double_blocks_11_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_11_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_12_img_attn_proj.alpha",
"lora_unet_double_blocks_12_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_12_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_12_img_attn_qkv.alpha",
"lora_unet_double_blocks_12_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_12_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_12_img_mlp_0.alpha",
"lora_unet_double_blocks_12_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_12_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_12_img_mlp_2.alpha",
"lora_unet_double_blocks_12_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_12_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_12_img_mod_lin.alpha",
"lora_unet_double_blocks_12_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_12_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_12_txt_attn_proj.alpha",
"lora_unet_double_blocks_12_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_12_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_12_txt_attn_qkv.alpha",
"lora_unet_double_blocks_12_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_12_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_12_txt_mlp_0.alpha",
"lora_unet_double_blocks_12_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_12_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_12_txt_mlp_2.alpha",
"lora_unet_double_blocks_12_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_12_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_12_txt_mod_lin.alpha",
"lora_unet_double_blocks_12_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_12_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_13_img_attn_proj.alpha",
"lora_unet_double_blocks_13_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_13_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_13_img_attn_qkv.alpha",
"lora_unet_double_blocks_13_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_13_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_13_img_mlp_0.alpha",
"lora_unet_double_blocks_13_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_13_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_13_img_mlp_2.alpha",
"lora_unet_double_blocks_13_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_13_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_13_img_mod_lin.alpha",
"lora_unet_double_blocks_13_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_13_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_13_txt_attn_proj.alpha",
"lora_unet_double_blocks_13_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_13_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_13_txt_attn_qkv.alpha",
"lora_unet_double_blocks_13_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_13_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_13_txt_mlp_0.alpha",
"lora_unet_double_blocks_13_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_13_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_13_txt_mlp_2.alpha",
"lora_unet_double_blocks_13_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_13_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_13_txt_mod_lin.alpha",
"lora_unet_double_blocks_13_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_13_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_14_img_attn_proj.alpha",
"lora_unet_double_blocks_14_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_14_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_14_img_attn_qkv.alpha",
"lora_unet_double_blocks_14_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_14_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_14_img_mlp_0.alpha",
"lora_unet_double_blocks_14_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_14_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_14_img_mlp_2.alpha",
"lora_unet_double_blocks_14_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_14_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_14_img_mod_lin.alpha",
"lora_unet_double_blocks_14_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_14_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_14_txt_attn_proj.alpha",
"lora_unet_double_blocks_14_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_14_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_14_txt_attn_qkv.alpha",
"lora_unet_double_blocks_14_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_14_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_14_txt_mlp_0.alpha",
"lora_unet_double_blocks_14_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_14_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_14_txt_mlp_2.alpha",
"lora_unet_double_blocks_14_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_14_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_14_txt_mod_lin.alpha",
"lora_unet_double_blocks_14_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_14_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_15_img_attn_proj.alpha",
"lora_unet_double_blocks_15_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_15_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_15_img_attn_qkv.alpha",
"lora_unet_double_blocks_15_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_15_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_15_img_mlp_0.alpha",
"lora_unet_double_blocks_15_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_15_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_15_img_mlp_2.alpha",
"lora_unet_double_blocks_15_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_15_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_15_img_mod_lin.alpha",
"lora_unet_double_blocks_15_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_15_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_15_txt_attn_proj.alpha",
"lora_unet_double_blocks_15_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_15_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_15_txt_attn_qkv.alpha",
"lora_unet_double_blocks_15_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_15_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_15_txt_mlp_0.alpha",
"lora_unet_double_blocks_15_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_15_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_15_txt_mlp_2.alpha",
"lora_unet_double_blocks_15_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_15_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_15_txt_mod_lin.alpha",
"lora_unet_double_blocks_15_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_15_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_16_img_attn_proj.alpha",
"lora_unet_double_blocks_16_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_16_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_16_img_attn_qkv.alpha",
"lora_unet_double_blocks_16_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_16_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_16_img_mlp_0.alpha",
"lora_unet_double_blocks_16_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_16_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_16_img_mlp_2.alpha",
"lora_unet_double_blocks_16_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_16_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_16_img_mod_lin.alpha",
"lora_unet_double_blocks_16_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_16_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_16_txt_attn_proj.alpha",
"lora_unet_double_blocks_16_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_16_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_16_txt_attn_qkv.alpha",
"lora_unet_double_blocks_16_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_16_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_16_txt_mlp_0.alpha",
"lora_unet_double_blocks_16_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_16_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_16_txt_mlp_2.alpha",
"lora_unet_double_blocks_16_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_16_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_16_txt_mod_lin.alpha",
"lora_unet_double_blocks_16_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_16_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_17_img_attn_proj.alpha",
"lora_unet_double_blocks_17_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_17_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_17_img_attn_qkv.alpha",
"lora_unet_double_blocks_17_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_17_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_17_img_mlp_0.alpha",
"lora_unet_double_blocks_17_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_17_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_17_img_mlp_2.alpha",
"lora_unet_double_blocks_17_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_17_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_17_img_mod_lin.alpha",
"lora_unet_double_blocks_17_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_17_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_17_txt_attn_proj.alpha",
"lora_unet_double_blocks_17_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_17_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_17_txt_attn_qkv.alpha",
"lora_unet_double_blocks_17_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_17_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_17_txt_mlp_0.alpha",
"lora_unet_double_blocks_17_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_17_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_17_txt_mlp_2.alpha",
"lora_unet_double_blocks_17_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_17_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_17_txt_mod_lin.alpha",
"lora_unet_double_blocks_17_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_17_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_18_img_attn_proj.alpha",
"lora_unet_double_blocks_18_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_18_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_18_img_attn_qkv.alpha",
"lora_unet_double_blocks_18_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_18_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_18_img_mlp_0.alpha",
"lora_unet_double_blocks_18_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_18_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_18_img_mlp_2.alpha",
"lora_unet_double_blocks_18_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_18_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_18_img_mod_lin.alpha",
"lora_unet_double_blocks_18_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_18_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_18_txt_attn_proj.alpha",
"lora_unet_double_blocks_18_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_18_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_18_txt_attn_qkv.alpha",
"lora_unet_double_blocks_18_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_18_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_18_txt_mlp_0.alpha",
"lora_unet_double_blocks_18_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_18_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_18_txt_mlp_2.alpha",
"lora_unet_double_blocks_18_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_18_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_18_txt_mod_lin.alpha",
"lora_unet_double_blocks_18_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_18_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_1_img_attn_proj.alpha",
"lora_unet_double_blocks_1_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_1_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_1_img_attn_qkv.alpha",
"lora_unet_double_blocks_1_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_1_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_1_img_mlp_0.alpha",
"lora_unet_double_blocks_1_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_1_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_1_img_mlp_2.alpha",
"lora_unet_double_blocks_1_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_1_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_1_img_mod_lin.alpha",
"lora_unet_double_blocks_1_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_1_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_1_txt_attn_proj.alpha",
"lora_unet_double_blocks_1_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_1_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_1_txt_attn_qkv.alpha",
"lora_unet_double_blocks_1_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_1_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_1_txt_mlp_0.alpha",
"lora_unet_double_blocks_1_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_1_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_1_txt_mlp_2.alpha",
"lora_unet_double_blocks_1_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_1_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_1_txt_mod_lin.alpha",
"lora_unet_double_blocks_1_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_1_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_2_img_attn_proj.alpha",
"lora_unet_double_blocks_2_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_2_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_2_img_attn_qkv.alpha",
"lora_unet_double_blocks_2_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_2_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_2_img_mlp_0.alpha",
"lora_unet_double_blocks_2_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_2_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_2_img_mlp_2.alpha",
"lora_unet_double_blocks_2_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_2_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_2_img_mod_lin.alpha",
"lora_unet_double_blocks_2_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_2_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_2_txt_attn_proj.alpha",
"lora_unet_double_blocks_2_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_2_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_2_txt_attn_qkv.alpha",
"lora_unet_double_blocks_2_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_2_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_2_txt_mlp_0.alpha",
"lora_unet_double_blocks_2_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_2_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_2_txt_mlp_2.alpha",
"lora_unet_double_blocks_2_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_2_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_2_txt_mod_lin.alpha",
"lora_unet_double_blocks_2_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_2_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_3_img_attn_proj.alpha",
"lora_unet_double_blocks_3_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_3_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_3_img_attn_qkv.alpha",
"lora_unet_double_blocks_3_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_3_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_3_img_mlp_0.alpha",
"lora_unet_double_blocks_3_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_3_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_3_img_mlp_2.alpha",
"lora_unet_double_blocks_3_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_3_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_3_img_mod_lin.alpha",
"lora_unet_double_blocks_3_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_3_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_3_txt_attn_proj.alpha",
"lora_unet_double_blocks_3_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_3_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_3_txt_attn_qkv.alpha",
"lora_unet_double_blocks_3_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_3_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_3_txt_mlp_0.alpha",
"lora_unet_double_blocks_3_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_3_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_3_txt_mlp_2.alpha",
"lora_unet_double_blocks_3_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_3_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_3_txt_mod_lin.alpha",
"lora_unet_double_blocks_3_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_3_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_4_img_attn_proj.alpha",
"lora_unet_double_blocks_4_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_4_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_4_img_attn_qkv.alpha",
"lora_unet_double_blocks_4_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_4_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_4_img_mlp_0.alpha",
"lora_unet_double_blocks_4_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_4_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_4_img_mlp_2.alpha",
"lora_unet_double_blocks_4_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_4_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_4_img_mod_lin.alpha",
"lora_unet_double_blocks_4_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_4_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_4_txt_attn_proj.alpha",
"lora_unet_double_blocks_4_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_4_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_4_txt_attn_qkv.alpha",
"lora_unet_double_blocks_4_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_4_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_4_txt_mlp_0.alpha",
"lora_unet_double_blocks_4_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_4_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_4_txt_mlp_2.alpha",
"lora_unet_double_blocks_4_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_4_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_4_txt_mod_lin.alpha",
"lora_unet_double_blocks_4_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_4_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_5_img_attn_proj.alpha",
"lora_unet_double_blocks_5_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_5_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_5_img_attn_qkv.alpha",
"lora_unet_double_blocks_5_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_5_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_5_img_mlp_0.alpha",
"lora_unet_double_blocks_5_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_5_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_5_img_mlp_2.alpha",
"lora_unet_double_blocks_5_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_5_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_5_img_mod_lin.alpha",
"lora_unet_double_blocks_5_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_5_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_5_txt_attn_proj.alpha",
"lora_unet_double_blocks_5_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_5_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_5_txt_attn_qkv.alpha",
"lora_unet_double_blocks_5_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_5_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_5_txt_mlp_0.alpha",
"lora_unet_double_blocks_5_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_5_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_5_txt_mlp_2.alpha",
"lora_unet_double_blocks_5_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_5_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_5_txt_mod_lin.alpha",
"lora_unet_double_blocks_5_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_5_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_6_img_attn_proj.alpha",
"lora_unet_double_blocks_6_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_6_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_6_img_attn_qkv.alpha",
"lora_unet_double_blocks_6_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_6_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_6_img_mlp_0.alpha",
"lora_unet_double_blocks_6_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_6_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_6_img_mlp_2.alpha",
"lora_unet_double_blocks_6_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_6_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_6_img_mod_lin.alpha",
"lora_unet_double_blocks_6_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_6_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_6_txt_attn_proj.alpha",
"lora_unet_double_blocks_6_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_6_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_6_txt_attn_qkv.alpha",
"lora_unet_double_blocks_6_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_6_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_6_txt_mlp_0.alpha",
"lora_unet_double_blocks_6_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_6_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_6_txt_mlp_2.alpha",
"lora_unet_double_blocks_6_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_6_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_6_txt_mod_lin.alpha",
"lora_unet_double_blocks_6_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_6_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_7_img_attn_proj.alpha",
"lora_unet_double_blocks_7_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_7_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_7_img_attn_qkv.alpha",
"lora_unet_double_blocks_7_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_7_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_7_img_mlp_0.alpha",
"lora_unet_double_blocks_7_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_7_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_7_img_mlp_2.alpha",
"lora_unet_double_blocks_7_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_7_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_7_img_mod_lin.alpha",
"lora_unet_double_blocks_7_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_7_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_7_txt_attn_proj.alpha",
"lora_unet_double_blocks_7_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_7_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_7_txt_attn_qkv.alpha",
"lora_unet_double_blocks_7_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_7_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_7_txt_mlp_0.alpha",
"lora_unet_double_blocks_7_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_7_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_7_txt_mlp_2.alpha",
"lora_unet_double_blocks_7_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_7_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_7_txt_mod_lin.alpha",
"lora_unet_double_blocks_7_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_7_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_8_img_attn_proj.alpha",
"lora_unet_double_blocks_8_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_8_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_8_img_attn_qkv.alpha",
"lora_unet_double_blocks_8_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_8_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_8_img_mlp_0.alpha",
"lora_unet_double_blocks_8_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_8_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_8_img_mlp_2.alpha",
"lora_unet_double_blocks_8_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_8_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_8_img_mod_lin.alpha",
"lora_unet_double_blocks_8_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_8_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_8_txt_attn_proj.alpha",
"lora_unet_double_blocks_8_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_8_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_8_txt_attn_qkv.alpha",
"lora_unet_double_blocks_8_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_8_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_8_txt_mlp_0.alpha",
"lora_unet_double_blocks_8_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_8_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_8_txt_mlp_2.alpha",
"lora_unet_double_blocks_8_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_8_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_8_txt_mod_lin.alpha",
"lora_unet_double_blocks_8_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_8_txt_mod_lin.lora_up.weight",
"lora_unet_double_blocks_9_img_attn_proj.alpha",
"lora_unet_double_blocks_9_img_attn_proj.lora_down.weight",
"lora_unet_double_blocks_9_img_attn_proj.lora_up.weight",
"lora_unet_double_blocks_9_img_attn_qkv.alpha",
"lora_unet_double_blocks_9_img_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_9_img_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_9_img_mlp_0.alpha",
"lora_unet_double_blocks_9_img_mlp_0.lora_down.weight",
"lora_unet_double_blocks_9_img_mlp_0.lora_up.weight",
"lora_unet_double_blocks_9_img_mlp_2.alpha",
"lora_unet_double_blocks_9_img_mlp_2.lora_down.weight",
"lora_unet_double_blocks_9_img_mlp_2.lora_up.weight",
"lora_unet_double_blocks_9_img_mod_lin.alpha",
"lora_unet_double_blocks_9_img_mod_lin.lora_down.weight",
"lora_unet_double_blocks_9_img_mod_lin.lora_up.weight",
"lora_unet_double_blocks_9_txt_attn_proj.alpha",
"lora_unet_double_blocks_9_txt_attn_proj.lora_down.weight",
"lora_unet_double_blocks_9_txt_attn_proj.lora_up.weight",
"lora_unet_double_blocks_9_txt_attn_qkv.alpha",
"lora_unet_double_blocks_9_txt_attn_qkv.lora_down.weight",
"lora_unet_double_blocks_9_txt_attn_qkv.lora_up.weight",
"lora_unet_double_blocks_9_txt_mlp_0.alpha",
"lora_unet_double_blocks_9_txt_mlp_0.lora_down.weight",
"lora_unet_double_blocks_9_txt_mlp_0.lora_up.weight",
"lora_unet_double_blocks_9_txt_mlp_2.alpha",
"lora_unet_double_blocks_9_txt_mlp_2.lora_down.weight",
"lora_unet_double_blocks_9_txt_mlp_2.lora_up.weight",
"lora_unet_double_blocks_9_txt_mod_lin.alpha",
"lora_unet_double_blocks_9_txt_mod_lin.lora_down.weight",
"lora_unet_double_blocks_9_txt_mod_lin.lora_up.weight",
"lora_unet_single_blocks_0_linear1.alpha",
"lora_unet_single_blocks_0_linear1.lora_down.weight",
"lora_unet_single_blocks_0_linear1.lora_up.weight",
"lora_unet_single_blocks_0_linear2.alpha",
"lora_unet_single_blocks_0_linear2.lora_down.weight",
"lora_unet_single_blocks_0_linear2.lora_up.weight",
"lora_unet_single_blocks_0_modulation_lin.alpha",
"lora_unet_single_blocks_0_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_0_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_10_linear1.alpha",
"lora_unet_single_blocks_10_linear1.lora_down.weight",
"lora_unet_single_blocks_10_linear1.lora_up.weight",
"lora_unet_single_blocks_10_linear2.alpha",
"lora_unet_single_blocks_10_linear2.lora_down.weight",
"lora_unet_single_blocks_10_linear2.lora_up.weight",
"lora_unet_single_blocks_10_modulation_lin.alpha",
"lora_unet_single_blocks_10_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_10_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_11_linear1.alpha",
"lora_unet_single_blocks_11_linear1.lora_down.weight",
"lora_unet_single_blocks_11_linear1.lora_up.weight",
"lora_unet_single_blocks_11_linear2.alpha",
"lora_unet_single_blocks_11_linear2.lora_down.weight",
"lora_unet_single_blocks_11_linear2.lora_up.weight",
"lora_unet_single_blocks_11_modulation_lin.alpha",
"lora_unet_single_blocks_11_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_11_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_12_linear1.alpha",
"lora_unet_single_blocks_12_linear1.lora_down.weight",
"lora_unet_single_blocks_12_linear1.lora_up.weight",
"lora_unet_single_blocks_12_linear2.alpha",
"lora_unet_single_blocks_12_linear2.lora_down.weight",
"lora_unet_single_blocks_12_linear2.lora_up.weight",
"lora_unet_single_blocks_12_modulation_lin.alpha",
"lora_unet_single_blocks_12_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_12_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_13_linear1.alpha",
"lora_unet_single_blocks_13_linear1.lora_down.weight",
"lora_unet_single_blocks_13_linear1.lora_up.weight",
"lora_unet_single_blocks_13_linear2.alpha",
"lora_unet_single_blocks_13_linear2.lora_down.weight",
"lora_unet_single_blocks_13_linear2.lora_up.weight",
"lora_unet_single_blocks_13_modulation_lin.alpha",
"lora_unet_single_blocks_13_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_13_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_14_linear1.alpha",
"lora_unet_single_blocks_14_linear1.lora_down.weight",
"lora_unet_single_blocks_14_linear1.lora_up.weight",
"lora_unet_single_blocks_14_linear2.alpha",
"lora_unet_single_blocks_14_linear2.lora_down.weight",
"lora_unet_single_blocks_14_linear2.lora_up.weight",
"lora_unet_single_blocks_14_modulation_lin.alpha",
"lora_unet_single_blocks_14_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_14_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_15_linear1.alpha",
"lora_unet_single_blocks_15_linear1.lora_down.weight",
"lora_unet_single_blocks_15_linear1.lora_up.weight",
"lora_unet_single_blocks_15_linear2.alpha",
"lora_unet_single_blocks_15_linear2.lora_down.weight",
"lora_unet_single_blocks_15_linear2.lora_up.weight",
"lora_unet_single_blocks_15_modulation_lin.alpha",
"lora_unet_single_blocks_15_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_15_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_16_linear1.alpha",
"lora_unet_single_blocks_16_linear1.lora_down.weight",
"lora_unet_single_blocks_16_linear1.lora_up.weight",
"lora_unet_single_blocks_16_linear2.alpha",
"lora_unet_single_blocks_16_linear2.lora_down.weight",
"lora_unet_single_blocks_16_linear2.lora_up.weight",
"lora_unet_single_blocks_16_modulation_lin.alpha",
"lora_unet_single_blocks_16_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_16_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_17_linear1.alpha",
"lora_unet_single_blocks_17_linear1.lora_down.weight",
"lora_unet_single_blocks_17_linear1.lora_up.weight",
"lora_unet_single_blocks_17_linear2.alpha",
"lora_unet_single_blocks_17_linear2.lora_down.weight",
"lora_unet_single_blocks_17_linear2.lora_up.weight",
"lora_unet_single_blocks_17_modulation_lin.alpha",
"lora_unet_single_blocks_17_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_17_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_18_linear1.alpha",
"lora_unet_single_blocks_18_linear1.lora_down.weight",
"lora_unet_single_blocks_18_linear1.lora_up.weight",
"lora_unet_single_blocks_18_linear2.alpha",
"lora_unet_single_blocks_18_linear2.lora_down.weight",
"lora_unet_single_blocks_18_linear2.lora_up.weight",
"lora_unet_single_blocks_18_modulation_lin.alpha",
"lora_unet_single_blocks_18_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_18_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_19_linear1.alpha",
"lora_unet_single_blocks_19_linear1.lora_down.weight",
"lora_unet_single_blocks_19_linear1.lora_up.weight",
"lora_unet_single_blocks_19_linear2.alpha",
"lora_unet_single_blocks_19_linear2.lora_down.weight",
"lora_unet_single_blocks_19_linear2.lora_up.weight",
"lora_unet_single_blocks_19_modulation_lin.alpha",
"lora_unet_single_blocks_19_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_19_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_1_linear1.alpha",
"lora_unet_single_blocks_1_linear1.lora_down.weight",
"lora_unet_single_blocks_1_linear1.lora_up.weight",
"lora_unet_single_blocks_1_linear2.alpha",
"lora_unet_single_blocks_1_linear2.lora_down.weight",
"lora_unet_single_blocks_1_linear2.lora_up.weight",
"lora_unet_single_blocks_1_modulation_lin.alpha",
"lora_unet_single_blocks_1_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_1_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_20_linear1.alpha",
"lora_unet_single_blocks_20_linear1.lora_down.weight",
"lora_unet_single_blocks_20_linear1.lora_up.weight",
"lora_unet_single_blocks_20_linear2.alpha",
"lora_unet_single_blocks_20_linear2.lora_down.weight",
"lora_unet_single_blocks_20_linear2.lora_up.weight",
"lora_unet_single_blocks_20_modulation_lin.alpha",
"lora_unet_single_blocks_20_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_20_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_21_linear1.alpha",
"lora_unet_single_blocks_21_linear1.lora_down.weight",
"lora_unet_single_blocks_21_linear1.lora_up.weight",
"lora_unet_single_blocks_21_linear2.alpha",
"lora_unet_single_blocks_21_linear2.lora_down.weight",
"lora_unet_single_blocks_21_linear2.lora_up.weight",
"lora_unet_single_blocks_21_modulation_lin.alpha",
"lora_unet_single_blocks_21_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_21_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_22_linear1.alpha",
"lora_unet_single_blocks_22_linear1.lora_down.weight",
"lora_unet_single_blocks_22_linear1.lora_up.weight",
"lora_unet_single_blocks_22_linear2.alpha",
"lora_unet_single_blocks_22_linear2.lora_down.weight",
"lora_unet_single_blocks_22_linear2.lora_up.weight",
"lora_unet_single_blocks_22_modulation_lin.alpha",
"lora_unet_single_blocks_22_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_22_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_23_linear1.alpha",
"lora_unet_single_blocks_23_linear1.lora_down.weight",
"lora_unet_single_blocks_23_linear1.lora_up.weight",
"lora_unet_single_blocks_23_linear2.alpha",
"lora_unet_single_blocks_23_linear2.lora_down.weight",
"lora_unet_single_blocks_23_linear2.lora_up.weight",
"lora_unet_single_blocks_23_modulation_lin.alpha",
"lora_unet_single_blocks_23_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_23_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_24_linear1.alpha",
"lora_unet_single_blocks_24_linear1.lora_down.weight",
"lora_unet_single_blocks_24_linear1.lora_up.weight",
"lora_unet_single_blocks_24_linear2.alpha",
"lora_unet_single_blocks_24_linear2.lora_down.weight",
"lora_unet_single_blocks_24_linear2.lora_up.weight",
"lora_unet_single_blocks_24_modulation_lin.alpha",
"lora_unet_single_blocks_24_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_24_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_25_linear1.alpha",
"lora_unet_single_blocks_25_linear1.lora_down.weight",
"lora_unet_single_blocks_25_linear1.lora_up.weight",
"lora_unet_single_blocks_25_linear2.alpha",
"lora_unet_single_blocks_25_linear2.lora_down.weight",
"lora_unet_single_blocks_25_linear2.lora_up.weight",
"lora_unet_single_blocks_25_modulation_lin.alpha",
"lora_unet_single_blocks_25_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_25_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_26_linear1.alpha",
"lora_unet_single_blocks_26_linear1.lora_down.weight",
"lora_unet_single_blocks_26_linear1.lora_up.weight",
"lora_unet_single_blocks_26_linear2.alpha",
"lora_unet_single_blocks_26_linear2.lora_down.weight",
"lora_unet_single_blocks_26_linear2.lora_up.weight",
"lora_unet_single_blocks_26_modulation_lin.alpha",
"lora_unet_single_blocks_26_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_26_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_27_linear1.alpha",
"lora_unet_single_blocks_27_linear1.lora_down.weight",
"lora_unet_single_blocks_27_linear1.lora_up.weight",
"lora_unet_single_blocks_27_linear2.alpha",
"lora_unet_single_blocks_27_linear2.lora_down.weight",
"lora_unet_single_blocks_27_linear2.lora_up.weight",
"lora_unet_single_blocks_27_modulation_lin.alpha",
"lora_unet_single_blocks_27_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_27_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_28_linear1.alpha",
"lora_unet_single_blocks_28_linear1.lora_down.weight",
"lora_unet_single_blocks_28_linear1.lora_up.weight",
"lora_unet_single_blocks_28_linear2.alpha",
"lora_unet_single_blocks_28_linear2.lora_down.weight",
"lora_unet_single_blocks_28_linear2.lora_up.weight",
"lora_unet_single_blocks_28_modulation_lin.alpha",
"lora_unet_single_blocks_28_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_28_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_29_linear1.alpha",
"lora_unet_single_blocks_29_linear1.lora_down.weight",
"lora_unet_single_blocks_29_linear1.lora_up.weight",
"lora_unet_single_blocks_29_linear2.alpha",
"lora_unet_single_blocks_29_linear2.lora_down.weight",
"lora_unet_single_blocks_29_linear2.lora_up.weight",
"lora_unet_single_blocks_29_modulation_lin.alpha",
"lora_unet_single_blocks_29_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_29_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_2_linear1.alpha",
"lora_unet_single_blocks_2_linear1.lora_down.weight",
"lora_unet_single_blocks_2_linear1.lora_up.weight",
"lora_unet_single_blocks_2_linear2.alpha",
"lora_unet_single_blocks_2_linear2.lora_down.weight",
"lora_unet_single_blocks_2_linear2.lora_up.weight",
"lora_unet_single_blocks_2_modulation_lin.alpha",
"lora_unet_single_blocks_2_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_2_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_30_linear1.alpha",
"lora_unet_single_blocks_30_linear1.lora_down.weight",
"lora_unet_single_blocks_30_linear1.lora_up.weight",
"lora_unet_single_blocks_30_linear2.alpha",
"lora_unet_single_blocks_30_linear2.lora_down.weight",
"lora_unet_single_blocks_30_linear2.lora_up.weight",
"lora_unet_single_blocks_30_modulation_lin.alpha",
"lora_unet_single_blocks_30_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_30_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_31_linear1.alpha",
"lora_unet_single_blocks_31_linear1.lora_down.weight",
"lora_unet_single_blocks_31_linear1.lora_up.weight",
"lora_unet_single_blocks_31_linear2.alpha",
"lora_unet_single_blocks_31_linear2.lora_down.weight",
"lora_unet_single_blocks_31_linear2.lora_up.weight",
"lora_unet_single_blocks_31_modulation_lin.alpha",
"lora_unet_single_blocks_31_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_31_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_32_linear1.alpha",
"lora_unet_single_blocks_32_linear1.lora_down.weight",
"lora_unet_single_blocks_32_linear1.lora_up.weight",
"lora_unet_single_blocks_32_linear2.alpha",
"lora_unet_single_blocks_32_linear2.lora_down.weight",
"lora_unet_single_blocks_32_linear2.lora_up.weight",
"lora_unet_single_blocks_32_modulation_lin.alpha",
"lora_unet_single_blocks_32_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_32_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_33_linear1.alpha",
"lora_unet_single_blocks_33_linear1.lora_down.weight",
"lora_unet_single_blocks_33_linear1.lora_up.weight",
"lora_unet_single_blocks_33_linear2.alpha",
"lora_unet_single_blocks_33_linear2.lora_down.weight",
"lora_unet_single_blocks_33_linear2.lora_up.weight",
"lora_unet_single_blocks_33_modulation_lin.alpha",
"lora_unet_single_blocks_33_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_33_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_34_linear1.alpha",
"lora_unet_single_blocks_34_linear1.lora_down.weight",
"lora_unet_single_blocks_34_linear1.lora_up.weight",
"lora_unet_single_blocks_34_linear2.alpha",
"lora_unet_single_blocks_34_linear2.lora_down.weight",
"lora_unet_single_blocks_34_linear2.lora_up.weight",
"lora_unet_single_blocks_34_modulation_lin.alpha",
"lora_unet_single_blocks_34_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_34_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_35_linear1.alpha",
"lora_unet_single_blocks_35_linear1.lora_down.weight",
"lora_unet_single_blocks_35_linear1.lora_up.weight",
"lora_unet_single_blocks_35_linear2.alpha",
"lora_unet_single_blocks_35_linear2.lora_down.weight",
"lora_unet_single_blocks_35_linear2.lora_up.weight",
"lora_unet_single_blocks_35_modulation_lin.alpha",
"lora_unet_single_blocks_35_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_35_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_36_linear1.alpha",
"lora_unet_single_blocks_36_linear1.lora_down.weight",
"lora_unet_single_blocks_36_linear1.lora_up.weight",
"lora_unet_single_blocks_36_linear2.alpha",
"lora_unet_single_blocks_36_linear2.lora_down.weight",
"lora_unet_single_blocks_36_linear2.lora_up.weight",
"lora_unet_single_blocks_36_modulation_lin.alpha",
"lora_unet_single_blocks_36_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_36_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_37_linear1.alpha",
"lora_unet_single_blocks_37_linear1.lora_down.weight",
"lora_unet_single_blocks_37_linear1.lora_up.weight",
"lora_unet_single_blocks_37_linear2.alpha",
"lora_unet_single_blocks_37_linear2.lora_down.weight",
"lora_unet_single_blocks_37_linear2.lora_up.weight",
"lora_unet_single_blocks_37_modulation_lin.alpha",
"lora_unet_single_blocks_37_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_37_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_3_linear1.alpha",
"lora_unet_single_blocks_3_linear1.lora_down.weight",
"lora_unet_single_blocks_3_linear1.lora_up.weight",
"lora_unet_single_blocks_3_linear2.alpha",
"lora_unet_single_blocks_3_linear2.lora_down.weight",
"lora_unet_single_blocks_3_linear2.lora_up.weight",
"lora_unet_single_blocks_3_modulation_lin.alpha",
"lora_unet_single_blocks_3_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_3_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_4_linear1.alpha",
"lora_unet_single_blocks_4_linear1.lora_down.weight",
"lora_unet_single_blocks_4_linear1.lora_up.weight",
"lora_unet_single_blocks_4_linear2.alpha",
"lora_unet_single_blocks_4_linear2.lora_down.weight",
"lora_unet_single_blocks_4_linear2.lora_up.weight",
"lora_unet_single_blocks_4_modulation_lin.alpha",
"lora_unet_single_blocks_4_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_4_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_5_linear1.alpha",
"lora_unet_single_blocks_5_linear1.lora_down.weight",
"lora_unet_single_blocks_5_linear1.lora_up.weight",
"lora_unet_single_blocks_5_linear2.alpha",
"lora_unet_single_blocks_5_linear2.lora_down.weight",
"lora_unet_single_blocks_5_linear2.lora_up.weight",
"lora_unet_single_blocks_5_modulation_lin.alpha",
"lora_unet_single_blocks_5_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_5_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_6_linear1.alpha",
"lora_unet_single_blocks_6_linear1.lora_down.weight",
"lora_unet_single_blocks_6_linear1.lora_up.weight",
"lora_unet_single_blocks_6_linear2.alpha",
"lora_unet_single_blocks_6_linear2.lora_down.weight",
"lora_unet_single_blocks_6_linear2.lora_up.weight",
"lora_unet_single_blocks_6_modulation_lin.alpha",
"lora_unet_single_blocks_6_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_6_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_7_linear1.alpha",
"lora_unet_single_blocks_7_linear1.lora_down.weight",
"lora_unet_single_blocks_7_linear1.lora_up.weight",
"lora_unet_single_blocks_7_linear2.alpha",
"lora_unet_single_blocks_7_linear2.lora_down.weight",
"lora_unet_single_blocks_7_linear2.lora_up.weight",
"lora_unet_single_blocks_7_modulation_lin.alpha",
"lora_unet_single_blocks_7_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_7_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_8_linear1.alpha",
"lora_unet_single_blocks_8_linear1.lora_down.weight",
"lora_unet_single_blocks_8_linear1.lora_up.weight",
"lora_unet_single_blocks_8_linear2.alpha",
"lora_unet_single_blocks_8_linear2.lora_down.weight",
"lora_unet_single_blocks_8_linear2.lora_up.weight",
"lora_unet_single_blocks_8_modulation_lin.alpha",
"lora_unet_single_blocks_8_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_8_modulation_lin.lora_up.weight",
"lora_unet_single_blocks_9_linear1.alpha",
"lora_unet_single_blocks_9_linear1.lora_down.weight",
"lora_unet_single_blocks_9_linear1.lora_up.weight",
"lora_unet_single_blocks_9_linear2.alpha",
"lora_unet_single_blocks_9_linear2.lora_down.weight",
"lora_unet_single_blocks_9_linear2.lora_up.weight",
"lora_unet_single_blocks_9_modulation_lin.alpha",
"lora_unet_single_blocks_9_modulation_lin.lora_down.weight",
"lora_unet_single_blocks_9_modulation_lin.lora_up.weight",
]

View File

@@ -0,0 +1,97 @@
import pytest
import torch
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.peft.conversions.flux_kohya_lora_conversion_utils import (
convert_flux_kohya_state_dict_to_invoke_format,
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
from tests.backend.peft.conversions.lora_state_dicts.flux_lora_kohya_format import state_dict_keys
def test_is_state_dict_likely_in_flux_kohya_format_true():
"""Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
# Construct a state dict that is in the Kohya FLUX LoRA format.
state_dict: dict[str, torch.Tensor] = {}
for k in state_dict_keys:
state_dict[k] = torch.empty(1)
assert is_state_dict_likely_in_flux_kohya_format(state_dict)
def test_is_state_dict_likely_in_flux_kohya_format_false():
"""Test that is_state_dict_likely_in_flux_kohya_format() returns False for a state dict that is not in the Kohya FLUX LoRA format."""
state_dict: dict[str, torch.Tensor] = {
"unexpected_key.lora_up.weight": torch.empty(1),
}
assert not is_state_dict_likely_in_flux_kohya_format(state_dict)
def test_convert_flux_kohya_state_dict_to_invoke_format():
# Construct state_dict from state_dict_keys.
state_dict: dict[str, torch.Tensor] = {}
for k in state_dict_keys:
state_dict[k] = torch.empty(1)
converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)
# Extract the prefixes from the converted state dict (i.e. without the .lora_up.weight, .lora_down.weight, and
# .alpha suffixes).
converted_key_prefixes: list[str] = []
for k in converted_state_dict.keys():
k = k.replace(".lora_up.weight", "")
k = k.replace(".lora_down.weight", "")
k = k.replace(".alpha", "")
converted_key_prefixes.append(k)
# Initialize a FLUX model on the meta device.
with torch.device("meta"):
model = Flux(params["flux-dev"])
model_keys = set(model.state_dict().keys())
# Assert that the converted state dict matches the keys in the actual model.
for converted_key_prefix in converted_key_prefixes:
found_match = False
for model_key in model_keys:
if model_key.startswith(converted_key_prefix):
found_match = True
break
if not found_match:
raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")
def test_convert_flux_kohya_state_dict_to_invoke_format_error():
"""Test that an error is raised by convert_flux_kohya_state_dict_to_invoke_format() if the input state_dict contains
unexpected keys.
"""
state_dict = {
"unexpected_key.lora_up.weight": torch.empty(1),
}
with pytest.raises(ValueError):
convert_flux_kohya_state_dict_to_invoke_format(state_dict)
def test_lora_model_from_flux_kohya_state_dict():
"""Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
# Construct state_dict from state_dict_keys.
state_dict: dict[str, torch.Tensor] = {}
for k in state_dict_keys:
state_dict[k] = torch.empty(1)
lora_model = lora_model_from_flux_kohya_state_dict(state_dict)
# Prepare expected layer keys.
expected_layer_keys: set[str] = set()
for k in state_dict_keys:
k = k.replace("lora_unet_", "")
k = k.replace(".lora_up.weight", "")
k = k.replace(".lora_down.weight", "")
k = k.replace(".alpha", "")
expected_layer_keys.add(k)
# Assert that the lora_model has the expected layers.
lora_model_keys = set(lora_model.layers.keys())
lora_model_keys = {k.replace(".", "_") for k in lora_model_keys}
assert lora_model_keys == expected_layer_keys