Compare commits

..

100 Commits

Author SHA1 Message Date
Lincoln Stein
48cb6bd200 change workflow to deploy from v2.3 branch 2023-05-06 23:50:34 -04:00
Lincoln Stein
332ac72e0e [Bugfix] Update check failing because process disappears (#3334)
Fixes #3228, where the check to see if invokeai is running fails because
a process no longer exists.
2023-05-04 20:32:51 -04:00
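The underlying race (a process listed at enumeration time can be gone by the time it is inspected) can be sketched with invented names as follows; the real check would iterate live processes (e.g. via `psutil.process_iter`, where the equivalent of `ProcessGone` is `psutil.NoSuchProcess`):

```python
class ProcessGone(Exception):
    """Stand-in for psutil.NoSuchProcess: the process exited between
    enumeration and inspection."""

def find_running(processes, needle):
    """Scan (pid, get_cmdline) pairs for a command line containing `needle`,
    treating a process that vanishes mid-scan as simply not running.

    Illustrative only: the function and parameter names are invented,
    not the actual InvokeAI code.
    """
    for pid, get_cmdline in processes:
        try:
            if needle in get_cmdline():
                return pid
        except ProcessGone:
            continue  # raced with process exit; skip rather than crash
    return None
```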
Dan Nguyen
03bbb308c9 [Bugfix] Update check failing because process disappears
Fixes #3228, where the check to see if invokeai is running fails because
a process no longer exists.
2023-05-03 10:54:43 -05:00
Lincoln Stein
1dcac3929b Release v2.3.5 (#3309)
# Version 2.3.5
This will be the 2.3.5 release once it is merged into the `v2.3` branch.
Changes on the RC branch are:

- Bump version number
- Fix bug in LoRA path determination (do it at runtime, not at module
load time, or root will get confused); closes #3293.
- Remove dangling debug statement.
2023-05-01 12:40:47 -04:00
Lincoln Stein
d73f1c363c bump version number 2023-05-01 09:28:49 -04:00
Lincoln Stein
e52e7418bb close #3304 2023-04-29 20:07:21 -04:00
Lincoln Stein
73be58a0b5 fix issue #3293 2023-04-29 11:37:07 -04:00
Lincoln Stein
5a7d11bca8 remove debugging statement 2023-04-27 08:21:26 -04:00
Lincoln Stein
5bbf7fe34a [Bugfix] Renames in 0.15.0 diffusers (#3184)
Link to PR in diffusers repository:
https://github.com/huggingface/diffusers/pull/2691

Imports:
`diffusers.models.cross_attention ->
diffusers.models.attention_processor`

Unions:
`AttnProcessor -> AttentionProcessor`

Classes:
| Old name | New name |
| --- | --- |
| CrossAttention | Attention |
| CrossAttnProcessor | AttnProcessor |
| XFormersCrossAttnProcessor | XFormersAttnProcessor |
| CrossAttnAddedKVProcessor | AttnAddedKVProcessor |
| LoRACrossAttnProcessor | LoRAAttnProcessor |
| LoRAXFormersCrossAttnProcessor | LoRAXFormersAttnProcessor |
| FlaxCrossAttention | FlaxAttention |
| AttendExciteCrossAttnProcessor | AttendExciteAttnProcessor |
| Pix2PixZeroCrossAttnProcessor | Pix2PixZeroAttnProcessor |


Also, config values are no longer set as attributes of the object:
https://github.com/huggingface/diffusers/pull/2849
2023-04-27 11:38:27 +01:00
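Renames like those in the table above are often bridged with a try-the-new-name, fall-back-to-the-old import. A generic sketch of that pattern (the helper name is invented for illustration; InvokeAI's actual fix simply switched to the new names):

```python
import importlib

def import_with_fallback(new_module, old_module, name, old_name=None):
    """Fetch `name` from new_module, falling back to old_name (or name)
    in old_module. Useful when a dependency such as diffusers renames
    classes between releases, e.g. diffusers.models.cross_attention.
    CrossAttention becoming diffusers.models.attention_processor.Attention.
    """
    try:
        return getattr(importlib.import_module(new_module), name)
    except (ImportError, AttributeError):
        return getattr(importlib.import_module(old_module), old_name or name)
```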
Lincoln Stein
bfb968bbe8 Merge branch 'v2.3' into fix/new_diffusers_names 2023-04-26 23:54:37 +01:00
Lincoln Stein
6db72f83a2 bump version number to 2.3.5-rc1 (#3267)
Bump version number for 2.3.5 release candidate.
2023-04-26 23:53:53 +01:00
Sergey Borisov
432e526999 Revert merge changes 2023-04-25 14:49:08 +03:00
Lincoln Stein
830740b93b remove redundant/buggy restore_default_attention() method 2023-04-25 07:05:07 -04:00
StAlKeR7779
ff3f289342 Merge branch 'v2.3' into fix/new_diffusers_names 2023-04-25 13:21:26 +03:00
Lincoln Stein
34abbb3589 Merge branch 'v2.3' into release/v2.3.5 2023-04-25 04:33:09 +01:00
Lincoln Stein
c0eb1a9921 increase sha256 chunksize when calculating model hash (#3162)
- Thanks to @abdBarho, who discovered that increasing the chunksize
dramatically decreases the amount of time to calculate the hash.
2023-04-25 04:25:55 +01:00
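A larger chunksize helps because hashing a multi-gigabyte model file with small chunks costs one Python-level read/update round trip per chunk. A sketch of chunked hashing (the `2**20` default is illustrative, not necessarily the value the commit chose):

```python
import hashlib

def sha256_of_file(path, chunksize=2**20):
    """Compute the SHA-256 of a file incrementally, in fixed-size chunks.

    The result is identical for any chunksize; only the number of
    read()/update() calls changes, which dominates runtime for
    multi-gigabyte model files.
    """
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunksize):
            h.update(chunk)
    return h.hexdigest()
```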
Lincoln Stein
2ddd0301f4 bump version number to 2.3.5-rc1 2023-04-24 23:24:33 -04:00
Lincoln Stein
ce6629b6f5 Merge branch 'v2.3' into enhance/increase-sha256-chunksize 2023-04-25 03:58:30 +01:00
Lincoln Stein
994a76aeaa [Enhancement] distinguish v1 from v2 LoRA models (#3175)
# Distinguish LoRA/LyCORIS files based on what version of SD they were
built on top of

- Attempting to run a prompt with a LoRA based on SD v1.X against a
model based on v2.X will now throw an `IncompatibleModelException`. To
import this exception:
`from ldm.modules.lora_manager import IncompatibleModelException` (maybe
this should be defined in ModelManager?)
    
- Enhance `LoraManager.list_loras()` to accept an optional integer
argument, `token_vector_length`. This will filter the returned LoRA
models to return only those that match the indicated length. Use:
      ```
      768 => for models based on SD v1.X
      1024 => for models based on SD v2.X
      ```
Note that this filtering requires each LoRA file to be opened by
`torch.safetensors`. It will take ~8s to scan a directory of 40 files.
    
- Added new static methods to `ldm.modules.kohya_lora_manager`:
      - check_model_compatibility()
      - vector_length_from_checkpoint()
      - vector_length_from_checkpoint_file()

- You can now create subdirectories within the `loras` directory and
organize the model files.
2023-04-25 03:57:45 +01:00
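The compatibility rule described above can be sketched as follows (illustrative only; the real implementations are the static methods added to `ldm.modules.kohya_lora_manager`):

```python
class IncompatibleModelException(Exception):
    """Raised when a LoRA's base SD version does not match the loaded model."""

# SD v1.X text encoders emit 768-dimensional token embeddings; SD v2.X emits 1024.
def base_model_for_vector_length(token_vector_length):
    return {768: "sd-1.x", 1024: "sd-2.x"}[token_vector_length]

def check_model_compatibility(lora_vector_length, model_vector_length):
    """Sketch of the check: the vector lengths must match, otherwise the
    LoRA was built against a different SD generation than the loaded model."""
    if lora_vector_length != model_vector_length:
        raise IncompatibleModelException(
            f"LoRA expects {lora_vector_length}-dim embeddings, "
            f"model provides {model_vector_length}"
        )
```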
Lincoln Stein
144dfe4a5b Merge branch 'v2.3' into bugfix/lora-incompatibility-handling 2023-04-25 03:54:46 +01:00
Lincoln Stein
5dbc63e2ae Revert "improvements to the installation and upgrade processes" (#3266)
Reverts invoke-ai/InvokeAI#3186
2023-04-25 03:54:04 +01:00
Lincoln Stein
c6ae1edc82 Revert "improvements to the installation and upgrade processes" 2023-04-24 22:53:43 -04:00
Lincoln Stein
0f3c456d59 merge with v2.3 2023-04-24 22:51:48 -04:00
Lincoln Stein
2cd0e036ac Merge branch 'v2.3' into bugfix/lora-incompatibility-handling 2023-04-25 03:24:25 +01:00
Lincoln Stein
a45b3387c0 Merge branch 'v2.3' into enhance/increase-sha256-chunksize 2023-04-25 03:22:43 +01:00
Lincoln Stein
c088cf0344 improvements to the installation and upgrade processes (#3186)
- Moved all postinstallation config file and model munging code out of
the CLI and into a separate script named `invokeai-postinstall`

- Fixed two calls to `shutil.copytree()` so that they don't try to
preserve the file mode of the copied files. This is necessary to run
correctly in a Nix environment (see thread at
https://discord.com/channels/1020123559063990373/1091716696965918732/1095662756738371615)

- Update the installer so that an existing virtual environment will be
updated, not overwritten.

- Pin npyscreen version to see if this fixes issues people have had with
installing this module.
2023-04-25 03:20:58 +01:00
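The `shutil.copytree()` change can be sketched like so (a hedged sketch, not the literal patch; the helper name is invented and the real call sites differ):

```python
import shutil

def copy_tree_no_mode(src, dst):
    """Copy a directory tree without trying to preserve file modes.

    copytree's default copy_function is shutil.copy2, which replays
    permission bits and timestamps onto the destination; on a read-only
    store (e.g. the Nix store) that chmod step fails. shutil.copyfile
    copies the bytes only, sidestepping the problem.
    """
    return shutil.copytree(src, dst, copy_function=shutil.copyfile)
```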
Lincoln Stein
264af3c054 fix crash caused by incorrect conflict resolution 2023-04-24 22:20:12 -04:00
Lincoln Stein
b332432a88 Merge branch 'v2.3' into lstein/bugfix/improve-update-handling 2023-04-25 03:09:12 +01:00
Lincoln Stein
7f7d5894fa Merge branch 'v2.3' into bugfix/lora-incompatibility-handling 2023-04-25 02:51:27 +01:00
Lincoln Stein
96c39b61cf Enable LoRAs to patch the text_encoder as well as the unet (#3214)
Load LoRAs during compel's text embedding encode pass in case there are
requested LoRAs which also want to patch the text encoder.

Also generally cleanup the attention processor patching stuff. It's
still a mess, but at least now it's a *stateless* mess.
2023-04-24 23:22:51 +01:00
Lincoln Stein
40744ed996 Merge branch 'v2.3' into fix_inconsistent_loras 2023-04-22 20:22:32 +01:00
Lincoln Stein
2a2c86896a pull in diffusers 0.15.1
- Change diffusers dependency to `diffusers~=0.15.0` which *should*
  enforce  non-breaking changes.
2023-04-20 13:29:20 -04:00
Lincoln Stein
f36452d650 rebuild front end 2023-04-20 12:27:08 -04:00
Lincoln Stein
e5188309ec Merge branch 'v2.3' into bugfix/lora-incompatibility-handling 2023-04-20 17:25:09 +01:00
Lincoln Stein
aabe79686e Merge branch 'v2.3' into fix/new_diffusers_names 2023-04-20 17:20:33 +01:00
Lincoln Stein
a9e8005a92 CODEOWNERS update - 2.3 branch (#3230)
Both @mauwii and @keturn have been offline for some time. I am
temporarily removing them from CODEOWNERS so that they will not be
responsible for code reviews until they wish to/are able to re-engage
fully.

Note that I have volunteered @GreggHelt2 to be a codeowner of the
generation backend code, replacing @keturn . Let me know if you're
uncomfortable with this.
2023-04-20 17:19:51 +01:00
Lincoln Stein
c2e6d98e66 Merge branch 'v2.3' into dev/codeowner-fix-2.3 2023-04-20 17:19:30 +01:00
Lincoln Stein
40d9b5dc27 [Feature] Add support for LoKR LyCORIS format (#3216)
It's like LoHA but uses the Kronecker product instead of the Hadamard product.
https://github.com/KohakuBlueleaf/LyCORIS#lokr

I tested it on these two LoKRs:
https://civitai.com/models/34518/unofficial-vspo-yakumo-beni
https://civitai.com/models/35136/mika-pikazo-lokr

More tests are hard to find, as it's a new format.
Better to test with https://github.com/invoke-ai/InvokeAI/pull/3214

Also slightly refactored the forward function.
LyCORIS also has an (IA)^3 format, but I can't find examples in this
format, and even on the LyCORIS page it's marked as experimental. So,
until there are some test examples, I prefer not to add it.
2023-04-19 22:51:33 +01:00
Lincoln Stein
216b1c3a4a Merge branch 'v2.3' into fix/new_diffusers_names 2023-04-18 19:37:25 -04:00
Lincoln Stein
1a704efff1 update codeowners in response to team changes 2023-04-18 19:30:52 -04:00
Lincoln Stein
f49d2619be Merge branch 'v2.3' into fix_inconsistent_loras 2023-04-18 19:09:35 -04:00
Lincoln Stein
da96ec9dd5 Merge branch 'v2.3' into feat/lokr_support 2023-04-18 19:08:03 -04:00
Lincoln Stein
298ccda365 fix the "import from directory" function in console model installer (#3211)
- This was inadvertently broken when we stopped supporting direct
loading of checkpoint models.
- Now fixed.
- May fix #3209
2023-04-17 23:04:27 -04:00
StAlKeR7779
967d853020 Merge branch 'v2.3' into feat/lokr_support 2023-04-16 23:10:45 +03:00
StAlKeR7779
e91117bc74 Add support for lokr lycoris format 2023-04-16 23:05:13 +03:00
Damian Stewart
4d58444153 fix issues and further cleanup 2023-04-16 17:54:21 +02:00
Damian Stewart
3667eb4d0d activate LoRAs when generating prompt embeddings; also cleanup attention stuff 2023-04-16 17:03:31 +02:00
Lincoln Stein
203a7157e1 fix the "import from directory" function in console model installer
- This was inadvertently broken when we stopped supporting direct
  loading of checkpoint models.
- Now fixed.
2023-04-15 21:07:02 -04:00
Lincoln Stein
47883860a6 Merge branch 'v2.3' into enhance/increase-sha256-chunksize 2023-04-13 23:00:34 -04:00
Lincoln Stein
6365a7c790 Merge branch 'v2.3' into lstein/bugfix/improve-update-handling 2023-04-13 22:49:41 -04:00
Lincoln Stein
5fcb3d90e4 fix missing files variable 2023-04-13 22:49:04 -04:00
Lincoln Stein
8f17d17208 Merge branch 'v2.3' into fix/new_diffusers_names 2023-04-13 22:44:05 -04:00
Lincoln Stein
c6ecf3afc5 pin diffusers to 0.15.*, and fix deprecation warning on unet.in_channels 2023-04-13 22:38:50 -04:00
Lincoln Stein
2c449bfb34 Merge branch 'v2.3' into bugfix/lora-incompatibility-handling 2023-04-13 22:23:59 -04:00
Lincoln Stein
8fb4b05556 change lora and TI list dynamically when model changes 2023-04-13 22:22:43 -04:00
Lincoln Stein
4d7289b20f explicitly set permissions of config files 2023-04-13 22:03:52 -04:00
Lincoln Stein
d81584c8fd hotfix to 2.3.4 (#3188)
- Pin diffusers to 0.14
- Small fix to LoRA loading routine that was preventing placement of
LoRA files in subdirectories.
- Bump version to 2.3.4.post1
2023-04-13 12:39:16 -04:00
StAlKeR7779
0bc5dcc663 Refactor 2023-04-13 16:05:04 +03:00
Lincoln Stein
1183bf96ed hotfix to 2.3.4
- Pin diffusers to 0.14
- Small fix to LoRA loading routine that was preventing placement of
  LoRA files in subdirectories.
- Bump version to 2.3.4.post1
2023-04-13 08:48:30 -04:00
Lincoln Stein
d81394cda8 fix directory permissions after install 2023-04-13 08:39:47 -04:00
Lincoln Stein
0eda1a03e1 pin diffusers to 0.14 2023-04-13 00:40:26 -04:00
Lincoln Stein
be7e067c95 getLoraModels event filters loras by compatibility 2023-04-13 00:31:11 -04:00
Lincoln Stein
afa3cdce27 add a list_compatible_loras() method 2023-04-13 00:11:26 -04:00
Lincoln Stein
6dfbd1c677 implement caching scheme for vector length 2023-04-12 23:56:52 -04:00
Lincoln Stein
a775c7730e improvements to the installation and upgrade processes
- Moved all postinstallation config file and model munging code out
  of the CLI and into a separate script named `invokeai-postinstall`

- Fixed two calls to `shutil.copytree()` so that they don't try to preserve
  the file mode of the copied files. This is necessary to run correctly
  in a Nix environment
  (see thread at https://discord.com/channels/1020123559063990373/1091716696965918732/1095662756738371615)

- Update the installer so that an existing virtual environment will be
  updated, not overwritten.

- Pin npyscreen version to see if this fixes issues people have had with
  installing this module.
2023-04-12 22:40:53 -04:00
StAlKeR7779
16c97ca0cb Fix num_train_timesteps in config 2023-04-12 23:57:45 +03:00
StAlKeR7779
e24dd97b80 Fix config attributes no longer being accessible as object attributes 2023-04-12 23:40:14 +03:00
StAlKeR7779
5a54039dd7 Fix imports for diffusers 0.15.0
Imports:
`diffusers.models.cross_attention -> diffusers.models.attention_processor`

Unions:
`AttnProcessor -> AttentionProcessor`

Classes:
| Old name | New name|
| --- | --- |
| CrossAttention | Attention |
| CrossAttnProcessor | AttnProcessor |
| XFormersCrossAttnProcessor | XFormersAttnProcessor |
| CrossAttnAddedKVProcessor | AttnAddedKVProcessor |
| LoRACrossAttnProcessor | LoRAAttnProcessor |
| LoRAXFormersCrossAttnProcessor | LoRAXFormersAttnProcessor |

These classes keep the same names:
`SlicedAttnProcessor`, `SlicedAttnAddedKVProcessor`
2023-04-12 22:54:25 +03:00
Lincoln Stein
9385edb453 Merge branch 'v2.3' into enhance/increase-sha256-chunksize 2023-04-11 18:51:44 -04:00
Lincoln Stein
018d5dab53 [Bugfix] make invokeai-batch work on windows (#3164)
- Previous PR to truncate long filenames won't work on windows due to
lack of support for os.pathconf(). This works around the limitation by
hardcoding the value for PC_NAME_MAX when pathconf is unavailable.
- The `multiprocessing` send() and recv() methods weren't working
properly on Windows due to issues involving `utf-8` encoding and
pickling/unpickling. Changed these calls to `send_bytes()` and
`recv_bytes()` , which seems to fix the issue.

Not fully tested on Windows since I lack a GPU machine to test on, but
it is working on CPU.
2023-04-11 11:37:39 -04:00
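The `os.pathconf()` workaround described above can be sketched as follows (the helper name and the 255 fallback are illustrative; 255 is a common NTFS/ext4 filename limit, but the exact constant the commit hardcodes may differ):

```python
import os

def max_filename_length(directory, fallback=255):
    """Return the filesystem's maximum filename length for `directory`,
    with a hardcoded fallback for platforms (notably Windows) where
    os.pathconf() does not exist.
    """
    try:
        return os.pathconf(directory, "PC_NAME_MAX")
    except (AttributeError, ValueError, OSError):
        # os.pathconf missing (Windows) or PC_NAME_MAX unsupported here.
        return fallback
```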
Lincoln Stein
96a5de30e3 Merge branch 'v2.3' into bugfix/pathconf-on-windows 2023-04-11 11:11:20 -04:00
Lincoln Stein
2251d3abfe fixup relative path to devices module 2023-04-10 23:44:58 -04:00
Lincoln Stein
0b22a3f34d distinguish LoRA/LyCORIS files based on what SD model they were based on
- Attempting to run a prompt with a LoRA based on SD v1.X against a
  model based on v2.X will now throw an
  `IncompatibleModelException`. To import this exception:
  `from ldm.modules.lora_manager import IncompatibleModelException`
  (maybe this should be defined in ModelManager?)

- Enhance `LoraManager.list_loras()` to accept an optional integer
  argument, `token_vector_length`. This will filter the returned LoRA
  models to return only those that match the indicated length. Use:
  ```
  768 => for models based on SD v1.X
  1024 => for models based on SD v2.X
  ```

  Note that this filtering requires each LoRA file to be opened
  by `torch.safetensors`. It will take ~8s to scan a directory of
  40 files.

- Added new static methods to `ldm.modules.kohya_lora_manager`:
  - check_model_compatibility()
  - vector_length_from_checkpoint()
  - vector_length_from_checkpoint_file()
2023-04-10 23:33:28 -04:00
Lincoln Stein
2528e14fe9 raise generation exceptions so that frontend can catch 2023-04-10 14:26:09 -04:00
Lincoln Stein
4d62d5b802 [Bugfix] detect running invoke before updating (#3163)
This PR addresses the issue that when `invokeai-update` is run on a
Windows system, and an instance of InvokeAI is open and running, the
user's `.venv` can get corrupted.

Issue first reported here:
https://discord.com/channels/1020123559063990373/1094688269356249108/1094688434750230628
2023-04-09 22:29:46 -04:00
Lincoln Stein
17de5c7008 Merge branch 'v2.3' into bugfix/pathconf-on-windows 2023-04-09 22:10:24 -04:00
Lincoln Stein
f95403dcda Merge branch 'v2.3' into bugfix/detect-running-invoke-before-updating 2023-04-09 22:09:17 -04:00
Lincoln Stein
16ccc807cc control which revision of a diffusers model is downloaded
- Previously the user's preferred precision was used to select which
  version branch of a diffusers model would be downloaded. Half-precision
  would try to download the 'fp16' branch if it existed.

- Turns out that with waifu-diffusion this logic doesn't work, as
  'fp16' gets you waifu-diffusion v1.3, while 'main' gets you
  waifu-diffusion v1.4. Who knew?

- This PR adds a new optional "revision" field to `models.yaml`. This
  can be used to override the diffusers branch version. In the case of
  Waifu diffusion, INITIAL_MODELS.yaml now specifies the "main" branch.

- This PR also quenches the NSFW nag that downloading diffusers sometimes
  triggers.

- Closes #3160
2023-04-09 22:07:55 -04:00
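For illustration, a `models.yaml` stanza using the new optional field might look like this (values patterned on the waifu-diffusion case described above; the exact schema is defined by the model manager):

```
waifu-diffusion-1.4:
  description: An SD-2.1 model trained on 5.4M anime/manga-style images (4.27 GB)
  repo_id: hakurei/waifu-diffusion
  revision: main   # optional; overrides the precision-based branch choice
  format: diffusers
```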
Lincoln Stein
e54d060d17 send and receive messages as bytes, not objects 2023-04-09 18:17:55 -04:00
Lincoln Stein
a01f1d4940 workaround no os.pathconf() on Windows platforms
- Previous PR to truncate long filenames won't work on windows
  due to lack of support for os.pathconf(). This works around the
  limitation by hardcoding the value for PC_NAME_MAX when pathconf
  is unavailable.
2023-04-09 17:45:34 -04:00
Lincoln Stein
1873817ac9 adjustments for windows 2023-04-09 17:24:47 -04:00
Lincoln Stein
31333a736c check if invokeai is running before trying to update
- on windows systems, updating the .venv while invokeai is using it leads to
  corruption.
2023-04-09 16:57:14 -04:00
Lincoln Stein
03274b6da6 fix extracting loras from legacy blends (#3161) 2023-04-09 16:43:35 -04:00
Lincoln Stein
66364501d5 increase sha256 chunksize when calculating model hash
- Thanks to @abdBarho, who discovered that increasing the chunksize
  dramatically decreases the amount of time to calculate the hash.
2023-04-09 16:39:16 -04:00
Damian Stewart
0646649c05 fix extracting loras from legacy blends 2023-04-09 22:21:44 +02:00
Lincoln Stein
2af511c98a release 2.3.4 2023-04-09 13:31:45 -04:00
Lincoln Stein
f0039cc70a [Bugfix] truncate filenames in invokeai batch that exceed max filename length (#3143)
- This prevents `invokeai-batch` from trying to create image files whose
names would exceed PC_NAME_MAX.
- Closes #3115
2023-04-09 12:36:10 -04:00
Lincoln Stein
8fa7d5ca64 Merge branch 'v2.3' into bugfix/truncate-filenames-in-invokeai-batch 2023-04-09 12:16:06 -04:00
Lincoln Stein
d90aa42799 [WebUI] 2.3.4 UI Bug Fixes (#3139)
Some quick bug fixes related to the UI for the 2.3.4. release.

**Features:**

- Added the ability to now add Textual Inversions to the Negative Prompt
using the UI.
- Added the ability to clear Textual Inversions and Loras from Prompt
and Negative Prompt with a single click.
- Textual Inversions now have status pips - indicating whether they are
used in the Main Prompt, Negative Prompt or both.

**Fixes**

- Fixes #3138
- Fixes #3144
- Fixed `usePrompt` not updating the Lora and TI count in prompt /
negative prompt.
- Fixed the TI regex not respecting names in substrings.
- Fixed trailing spaces when adding and removing loras and TI's.
- Fixed an issue with the TI regex not respecting the `<` and `>` used
by HuggingFace concepts.
- Some other minor bug fixes.
2023-04-09 12:07:41 -04:00
Lincoln Stein
c5b34d21e5 Merge branch 'v2.3' into bugfix/truncate-filenames-in-invokeai-batch 2023-04-09 11:29:32 -04:00
blessedcoolant
40a4867143 Merge branch 'v2.3' into 234-ui-bugfixes 2023-04-09 15:56:44 +12:00
Lincoln Stein
4b25f80427 [Bugfix] Pass extra_conditioning_info in inpaint, so lora can be initialized (#3151) 2023-04-08 21:17:53 -04:00
StAlKeR7779
894e2e643d Pass extra_conditioning_info in inpaint 2023-04-09 00:50:30 +03:00
blessedcoolant
a38ff1a16b build(ui): Test Build (2.3.4 Feat Updates) 2023-04-09 07:37:41 +12:00
blessedcoolant
41f268b475 feat(ui): Improve TI & Lora UI 2023-04-09 07:35:19 +12:00
blessedcoolant
b3ae3f595f fix(ui): Fixed Use Prompt not detecting Loras / TI Count 2023-04-09 03:44:17 +12:00
blessedcoolant
29962613d8 chore(ui): Move Lora & TI Managers to Prompt Extras 2023-04-08 22:47:30 +12:00
blessedcoolant
1170cee1d8 fix(ui): Options panel sliding because of long Lora or TI names 2023-04-08 16:48:28 +12:00
Lincoln Stein
5983e65b22 invokeai-batch: truncate image filenames that exceed filesystem's max filename size
- Closes #3115
2023-04-07 18:20:32 -04:00
blessedcoolant
bc724fcdc3 fix(ui): Fix Main Width Slider being read only. 2023-04-08 04:15:55 +12:00
41 changed files with 902 additions and 484 deletions

.github/CODEOWNERS

@@ -1,13 +1,13 @@
 # continuous integration
-/.github/workflows/ @mauwii @lstein @blessedcoolant
+/.github/workflows/ @lstein @blessedcoolant
 # documentation
-/docs/ @lstein @mauwii @blessedcoolant
-mkdocs.yml @mauwii @lstein
+/docs/ @lstein @blessedcoolant
+mkdocs.yml @lstein @ebr
 # installation and configuration
-/pyproject.toml @mauwii @lstein @ebr
-/docker/ @mauwii
+/pyproject.toml @lstein @ebr
+/docker/ @lstein
 /scripts/ @ebr @lstein @blessedcoolant
 /installer/ @ebr @lstein
 ldm/invoke/config @lstein @ebr
@@ -21,13 +21,13 @@ invokeai/configs @lstein @ebr @blessedcoolant
 # generation and model management
 /ldm/*.py @lstein @blessedcoolant
-/ldm/generate.py @lstein @keturn
+/ldm/generate.py @lstein @gregghelt2
 /ldm/invoke/args.py @lstein @blessedcoolant
 /ldm/invoke/ckpt* @lstein @blessedcoolant
 /ldm/invoke/ckpt_generator @lstein @blessedcoolant
 /ldm/invoke/CLI.py @lstein @blessedcoolant
-/ldm/invoke/config @lstein @ebr @mauwii @blessedcoolant
-/ldm/invoke/generator @keturn @damian0815
+/ldm/invoke/config @lstein @ebr @blessedcoolant
+/ldm/invoke/generator @gregghelt2 @damian0815
 /ldm/invoke/globals.py @lstein @blessedcoolant
 /ldm/invoke/merge_diffusers.py @lstein @blessedcoolant
 /ldm/invoke/model_manager.py @lstein @blessedcoolant
@@ -36,17 +36,17 @@ invokeai/configs @lstein @ebr @blessedcoolant
 /ldm/invoke/restoration @lstein @blessedcoolant
 # attention, textual inversion, model configuration
-/ldm/models @damian0815 @keturn @blessedcoolant
+/ldm/models @damian0815 @gregghelt2 @blessedcoolant
 /ldm/modules/textual_inversion_manager.py @lstein @blessedcoolant
-/ldm/modules/attention.py @damian0815 @keturn
-/ldm/modules/diffusionmodules @damian0815 @keturn
-/ldm/modules/distributions @damian0815 @keturn
-/ldm/modules/ema.py @damian0815 @keturn
+/ldm/modules/attention.py @damian0815 @gregghelt2
+/ldm/modules/diffusionmodules @damian0815 @gregghelt2
+/ldm/modules/distributions @damian0815 @gregghelt2
+/ldm/modules/ema.py @damian0815 @gregghelt2
 /ldm/modules/embedding_manager.py @lstein
-/ldm/modules/encoders @damian0815 @keturn
-/ldm/modules/image_degradation @damian0815 @keturn
-/ldm/modules/losses @damian0815 @keturn
-/ldm/modules/x_transformer.py @damian0815 @keturn
+/ldm/modules/encoders @damian0815 @gregghelt2
+/ldm/modules/image_degradation @damian0815 @gregghelt2
+/ldm/modules/losses @damian0815 @gregghelt2
+/ldm/modules/x_transformer.py @damian0815 @gregghelt2
 # Nodes
 apps/ @Kyle0654 @jpphoto


@@ -41,7 +41,7 @@ jobs:
           --verbose
       - name: deploy to gh-pages
-        if: ${{ github.ref == 'refs/heads/main' }}
+        if: ${{ github.ref == 'refs/heads/v2.3' }}
         run: |
           python -m \
             mkdocs gh-deploy \

.gitignore

@@ -233,5 +233,3 @@ installer/install.sh
 installer/update.bat
 installer/update.sh
-# no longer stored in source directory
-models


@@ -30,7 +30,6 @@ from ldm.invoke.conditioning import (
     get_tokens_for_prompt_object,
     get_prompt_structure,
     split_weighted_subprompts,
-    get_tokenizer,
 )
 from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
 from ldm.invoke.generator.inpaint import infill_methods
@@ -38,11 +37,11 @@ from ldm.invoke.globals import (
     Globals,
     global_converted_ckpts_dir,
     global_models_dir,
-    global_lora_models_dir,
 )
 from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
 from compel.prompt_parser import Blend
 from ldm.invoke.merge_diffusers import merge_diffusion_models
+from ldm.modules.lora_manager import LoraManager

 # Loading Arguments
 opt = Args()
@@ -524,20 +523,12 @@ class InvokeAIWebServer:
         @socketio.on("getLoraModels")
         def get_lora_models():
             try:
-                lora_path = global_lora_models_dir()
-                loras = []
-                for root, _, files in os.walk(lora_path):
-                    models = [
-                        Path(root, x)
-                        for x in files
-                        if Path(x).suffix in [".ckpt", ".pt", ".safetensors"]
-                    ]
-                    loras = loras + models
+                model = self.generate.model
+                lora_mgr = LoraManager(model)
+                loras = lora_mgr.list_compatible_loras()
                 found_loras = []
-                for lora in sorted(loras, key=lambda s: s.stem.lower()):
-                    location = str(lora.resolve()).replace("\\", "/")
-                    found_loras.append({"name": lora.stem, "location": location})
+                for lora in sorted(loras, key=str.casefold):
+                    found_loras.append({"name":lora,"location":str(loras[lora])})
                 socketio.emit("foundLoras", found_loras)
             except Exception as e:
                 self.handle_exceptions(e)
@@ -1314,7 +1305,7 @@ class InvokeAIWebServer:
                 None
                 if type(parsed_prompt) is Blend
                 else get_tokens_for_prompt_object(
-                    get_tokenizer(self.generate.model), parsed_prompt
+                    self.generate.model.tokenizer, parsed_prompt
                 )
             )
             attention_maps_image_base64_url = (


@@ -80,7 +80,8 @@ trinart-2.0:
     repo_id: stabilityai/sd-vae-ft-mse
   recommended: False
 waifu-diffusion-1.4:
-  description: An SD-1.5 model trained on 680k anime/manga-style images (2.13 GB)
+  description: An SD-2.1 model trained on 5.4M anime/manga-style images (4.27 GB)
+  revision: main
   repo_id: hakurei/waifu-diffusion
   format: diffusers
   vae:

File diff suppressed because one or more lines are too long


@@ -5,7 +5,7 @@
     <meta name="viewport" content="width=device-width, initial-scale=1.0" />
     <title>InvokeAI - A Stable Diffusion Toolkit</title>
     <link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" />
-    <script type="module" crossorigin src="./assets/index-c1535364.js"></script>
+    <script type="module" crossorigin src="./assets/index-b12e648e.js"></script>
     <link rel="stylesheet" href="./assets/index-2ab0eb58.css">
   </head>


@@ -328,8 +328,11 @@
     "updateModel": "Update Model",
     "availableModels": "Available Models",
     "addLora": "Add Lora",
+    "clearLoras": "Clear Loras",
     "noLoraModels": "No Loras Found",
     "addTextualInversionTrigger": "Add Textual Inversion",
+    "addTIToNegative": "Add To Negative",
+    "clearTextualInversions": "Clear Textual Inversions",
     "noTextualInversionTriggers": "No Textual Inversions Found",
     "search": "Search",
     "load": "Load",


@@ -328,8 +328,11 @@
     "updateModel": "Update Model",
     "availableModels": "Available Models",
     "addLora": "Add Lora",
+    "clearLoras": "Clear Loras",
     "noLoraModels": "No Loras Found",
     "addTextualInversionTrigger": "Add Textual Inversion",
+    "addTIToNegative": "Add To Negative",
+    "clearTextualInversions": "Clear Textual Inversions",
     "noTextualInversionTriggers": "No Textual Inversions Found",
     "search": "Search",
     "load": "Load",


@@ -33,6 +33,10 @@ import {
   setIntermediateImage,
 } from 'features/gallery/store/gallerySlice';
+import {
+  getLoraModels,
+  getTextualInversionTriggers,
+} from 'app/socketio/actions';
 import type { RootState } from 'app/store';
 import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
 import {
@@ -463,6 +467,8 @@ const makeSocketIOListeners = (
       const { model_name, model_list } = data;
       dispatch(setModelList(model_list));
       dispatch(setCurrentStatus(i18n.t('common.statusModelChanged')));
+      dispatch(getLoraModels());
+      dispatch(getTextualInversionTriggers());
       dispatch(setIsProcessing(false));
       dispatch(setIsCancelable(true));
       dispatch(


@@ -92,7 +92,8 @@ export default function IAISimpleMenu(props: IAIMenuProps) {
       zIndex={15}
       padding={0}
       borderRadius="0.5rem"
-      overflowY="scroll"
+      overflow="scroll"
+      maxWidth={'22.5rem'}
       maxHeight={500}
       backgroundColor="var(--background-color-secondary)"
       color="var(--text-color-secondary)"


@@ -34,7 +34,6 @@ export default function MainWidth() {
       withSliderMarks
       sliderMarkRightOffset={-8}
       inputWidth="6.2rem"
-      inputReadOnly
       sliderNumberInputProps={{ max: 15360 }}
     />
   ) : (


@@ -1,10 +1,15 @@
-import { Box } from '@chakra-ui/react';
+import { Box, Flex } from '@chakra-ui/react';
 import { getLoraModels } from 'app/socketio/actions';
 import { useAppDispatch, useAppSelector } from 'app/storeHooks';
+import IAIIconButton from 'common/components/IAIIconButton';
 import IAISimpleMenu, { IAIMenuItem } from 'common/components/IAISimpleMenu';
-import { setLorasInUse } from 'features/parameters/store/generationSlice';
+import {
+  setClearLoras,
+  setLorasInUse,
+} from 'features/parameters/store/generationSlice';
 import { useEffect } from 'react';
 import { useTranslation } from 'react-i18next';
+import { MdClear } from 'react-icons/md';

 export default function LoraManager() {
   const dispatch = useAppDispatch();
@@ -53,11 +58,20 @@ export default function LoraManager() {
   };

   return foundLoras && foundLoras?.length > 0 ? (
-    <IAISimpleMenu
-      menuItems={makeLoraItems()}
-      menuType="regular"
-      buttonText={`${t('modelManager.addLora')} (${numOfActiveLoras()})`}
-    />
+    <Flex columnGap={2}>
+      <IAISimpleMenu
+        menuItems={makeLoraItems()}
+        menuType="regular"
+        buttonText={`${t('modelManager.addLora')} (${numOfActiveLoras()})`}
+        menuButtonProps={{ width: '100%', padding: '0 1rem' }}
+      />
+      <IAIIconButton
+        icon={<MdClear />}
+        tooltip={t('modelManager.clearLoras')}
+        aria-label={t('modelManager.clearLoras')}
+        onClick={() => dispatch(setClearLoras())}
+      />
+    </Flex>
   ) : (
     <Box
       background="var(--btn-base-color)"


@@ -0,0 +1,12 @@
+import { Flex } from '@chakra-ui/react';
+import LoraManager from './LoraManager/LoraManager';
+import TextualInversionManager from './TextualInversionManager/TextualInversionManager';
+
+export default function PromptExtras() {
+  return (
+    <Flex flexDir="column" rowGap={2}>
+      <LoraManager />
+      <TextualInversionManager />
+    </Flex>
+  );
+}


@@ -1,17 +1,28 @@
-import { Box } from '@chakra-ui/react';
+import { Box, Flex } from '@chakra-ui/react';
 import { getTextualInversionTriggers } from 'app/socketio/actions';
 import { RootState } from 'app/store';
 import { useAppDispatch, useAppSelector } from 'app/storeHooks';
+import IAIIconButton from 'common/components/IAIIconButton';
 import IAISimpleMenu, { IAIMenuItem } from 'common/components/IAISimpleMenu';
-import { setTextualInversionsInUse } from 'features/parameters/store/generationSlice';
+import {
+  setAddTIToNegative,
+  setClearTextualInversions,
+  setTextualInversionsInUse,
+} from 'features/parameters/store/generationSlice';
 import { useEffect } from 'react';
 import { useTranslation } from 'react-i18next';
+import { MdArrowDownward, MdClear } from 'react-icons/md';

 export default function TextualInversionManager() {
   const dispatch = useAppDispatch();
   const textualInversionsInUse = useAppSelector(
     (state: RootState) => state.generation.textualInversionsInUse
   );
+  const negativeTextualInversionsInUse = useAppSelector(
+    (state: RootState) => state.generation.negativeTextualInversionsInUse
+  );
   const foundLocalTextualInversionTriggers = useAppSelector(
     (state) => state.system.foundLocalTextualInversionTriggers
   );
@@ -31,6 +42,10 @@ export default function TextualInversionManager() {
     (state) => state.ui.shouldShowHuggingFaceConcepts
   );
+  const addTIToNegative = useAppSelector(
+    (state) => state.generation.addTIToNegative
+  );

   const { t } = useTranslation();

   useEffect(() => {
@@ -41,14 +56,25 @@ export default function TextualInversionManager() {
     dispatch(setTextualInversionsInUse(textual_inversion));
   };

-  const renderTextualInversionOption = (textual_inversion: string) => {
-    const thisTIExists = textualInversionsInUse.includes(textual_inversion);
-    const tiExistsStyle = {
-      fontWeight: 'bold',
-      color: 'var(--context-menu-active-item)',
-    };
+  const TIPip = ({ color }: { color: string }) => {
     return (
-      <Box style={thisTIExists ? tiExistsStyle : {}}>{textual_inversion}</Box>
+      <Box width={2} height={2} borderRadius={9999} backgroundColor={color}>
+        {' '}
+      </Box>
+    );
+  };
+
+  const renderTextualInversionOption = (textual_inversion: string) => {
+    return (
+      <Flex alignItems="center" columnGap={1}>
+        {textual_inversion}
+        {textualInversionsInUse.includes(textual_inversion) && (
+          <TIPip color="var(--context-menu-active-item)" />
+        )}
+        {negativeTextualInversionsInUse.includes(textual_inversion) && (
+          <TIPip color="var(--status-bad-color)" />
+        )}
+      </Flex>
     );
   };

@@ -56,8 +82,10 @@ export default function TextualInversionManager() {
     const allTextualInversions = localTextualInversionTriggers.concat(
       huggingFaceTextualInversionConcepts
     );
-    return allTextualInversions.filter((ti) =>
-      textualInversionsInUse.includes(ti)
+    return allTextualInversions.filter(
+      (ti) =>
+        textualInversionsInUse.includes(ti) ||
+        negativeTextualInversionsInUse.includes(ti)
     ).length;
   };

@@ -93,13 +121,34 @@ export default function TextualInversionManager() {
     (foundHuggingFaceTextualInversionTriggers &&
       foundHuggingFaceTextualInversionTriggers?.length > 0 &&
       shouldShowHuggingFaceConcepts)) ? (
-    <IAISimpleMenu
-      menuItems={makeTextualInversionItems()}
-      menuType="regular"
-      buttonText={`${t(
-        'modelManager.addTextualInversionTrigger'
-      )} (${numOfActiveTextualInversions()})`}
-    />
+    <Flex columnGap={2}>
+      <IAISimpleMenu
+        menuItems={makeTextualInversionItems()}
+        menuType="regular"
+        buttonText={`${t(
+          'modelManager.addTextualInversionTrigger'
+        )} (${numOfActiveTextualInversions()})`}
+        menuButtonProps={{
+          width: '100%',
+          padding: '0 1rem',
+        }}
+      />
+      <IAIIconButton
+        icon={<MdArrowDownward />}
+        style={{
+          backgroundColor: addTIToNegative ? 'var(--btn-delete-image)' : '',
+        }}
+        tooltip={t('modelManager.addTIToNegative')}
+        aria-label={t('modelManager.addTIToNegative')}
+        onClick={() => dispatch(setAddTIToNegative(!addTIToNegative))}
+      />
+      <IAIIconButton
+        icon={<MdClear />}
+        tooltip={t('modelManager.clearTextualInversions')}
+        aria-label={t('modelManager.clearTextualInversions')}
+        onClick={() => dispatch(setClearTextualInversions())}
+      />
+    </Flex>
   ) : (
     <Box
       background="var(--btn-base-color)"


@@ -1,24 +1,43 @@
 import { FormControl, Textarea } from '@chakra-ui/react';
 import type { RootState } from 'app/store';
 import { useAppDispatch, useAppSelector } from 'app/storeHooks';
-import { setNegativePrompt } from 'features/parameters/store/generationSlice';
+import {
+  handlePromptCheckers,
+  setNegativePrompt,
+} from 'features/parameters/store/generationSlice';
 import { useTranslation } from 'react-i18next';
+import { ChangeEvent, useState } from 'react';

 const NegativePromptInput = () => {
   const negativePrompt = useAppSelector(
     (state: RootState) => state.generation.negativePrompt
   );
+  const [promptTimer, setPromptTimer] = useState<number | undefined>(undefined);
   const dispatch = useAppDispatch();
   const { t } = useTranslation();

+  const handleNegativeChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) => {
+    dispatch(setNegativePrompt(e.target.value));
+
+    // Debounce Prompt UI Checking
+    clearTimeout(promptTimer);
+    const newPromptTimer = window.setTimeout(() => {
+      dispatch(
+        handlePromptCheckers({ prompt: e.target.value, toNegative: true })
+      );
+    }, 500);
+    setPromptTimer(newPromptTimer);
+  };
+
   return (
     <FormControl>
       <Textarea
         id="negativePrompt"
         name="negativePrompt"
         value={negativePrompt}
-        onChange={(e) => dispatch(setNegativePrompt(e.target.value))}
+        onChange={handleNegativeChangePrompt}
         background="var(--prompt-bg-color)"
         placeholder={t('parameters.negativePrompts')}
         _placeholder={{ fontSize: '0.8rem' }}
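The debounce used above (reset a timer on every keystroke so the prompt check runs only after 500 ms of inactivity) can be sketched as a standalone helper. This is an illustrative generic version, not the component's actual code: `makeDebounced` is a hypothetical name, and it omits the React state that the component uses to hold the timer id.

```typescript
// Minimal debounce sketch: each call cancels the pending timer and starts a
// new one, so `fn` fires only once the calls stop for `delayMs` milliseconds.
const makeDebounced = (fn: (value: string) => void, delayMs = 500) => {
  let timer: ReturnType<typeof setTimeout> | undefined;
  return (value: string) => {
    clearTimeout(timer);
    timer = setTimeout(() => fn(value), delayMs);
  };
};

// Rapid calls collapse into a single invocation with the latest value.
const check = makeDebounced((prompt) => console.log('checking:', prompt), 500);
check('a');
check('a c');
check('a cat'); // only this value reaches the callback, ~500 ms later
```

In the component, the Redux `dispatch(setNegativePrompt(...))` still runs on every keystroke so the textarea stays responsive; only the `handlePromptCheckers` scan is debounced.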


@@ -51,7 +51,9 @@ const PromptInput = () => {
     // Debounce Prompt UI Checking
     clearTimeout(promptTimer);
     const newPromptTimer = window.setTimeout(() => {
-      dispatch(handlePromptCheckers(e.target.value));
+      dispatch(
+        handlePromptCheckers({ prompt: e.target.value, toNegative: false })
+      );
     }, 500);
     setPromptTimer(newPromptTimer);
   };


@@ -3,7 +3,11 @@ import { getPromptAndNegative } from 'common/util/getPromptAndNegative';
 import * as InvokeAI from 'app/invokeai';
 import promptToString from 'common/util/promptToString';
 import { useAppDispatch } from 'app/storeHooks';
-import { setNegativePrompt, setPrompt } from '../store/generationSlice';
+import {
+  handlePromptCheckers,
+  setNegativePrompt,
+  setPrompt,
+} from '../store/generationSlice';

 // TECHDEBT: We have two metadata prompt formats and need to handle recalling either of them.
 // This hook provides a function to do that.
@@ -20,6 +24,10 @@ const useSetBothPrompts = () => {
     dispatch(setPrompt(prompt));
     dispatch(setNegativePrompt(negativePrompt));
+    dispatch(handlePromptCheckers({ prompt: prompt, toNegative: false }));
+    dispatch(
+      handlePromptCheckers({ prompt: negativePrompt, toNegative: true })
+    );
   };
 };


@@ -18,9 +18,11 @@ export interface GenerationState {
   prompt: string;
   negativePrompt: string;
   lorasInUse: string[];
-  localTextualInversionTriggers: string[];
   huggingFaceTextualInversionConcepts: string[];
+  localTextualInversionTriggers: string[];
   textualInversionsInUse: string[];
+  negativeTextualInversionsInUse: string[];
+  addTIToNegative: boolean;
   sampler: string;
   seamBlur: number;
   seamless: boolean;
@@ -53,9 +55,11 @@ const initialGenerationState: GenerationState = {
   prompt: '',
   negativePrompt: '',
   lorasInUse: [],
-  localTextualInversionTriggers: [],
  huggingFaceTextualInversionConcepts: [],
+  localTextualInversionTriggers: [],
   textualInversionsInUse: [],
+  negativeTextualInversionsInUse: [],
+  addTIToNegative: false,
   sampler: 'k_lms',
   seamBlur: 16,
   seamless: false,
@@ -85,15 +89,86 @@ const loraExists = (state: GenerationState, lora: string) => {
   return false;
 };

+const getTIRegex = (textualInversion: string) => {
+  if (textualInversion.includes('<' || '>')) {
+    return new RegExp(`${textualInversion}`);
+  } else {
+    return new RegExp(`\\b${textualInversion}\\b`);
+  }
+};
+
 const textualInversionExists = (
   state: GenerationState,
   textualInversion: string
 ) => {
-  const textualInversionRegex = new RegExp(textualInversion);
+  const textualInversionRegex = getTIRegex(textualInversion);

-  if (state.prompt.match(textualInversionRegex)) return true;
+  if (!state.addTIToNegative) {
+    if (state.prompt.match(textualInversionRegex)) return true;
+  } else {
+    if (state.negativePrompt.match(textualInversionRegex)) return true;
+  }

   return false;
 };

+const handleTypedTICheck = (
+  state: GenerationState,
+  newPrompt: string,
+  toNegative: boolean
+) => {
+  let textualInversionsInUse = !toNegative
+    ? [...state.textualInversionsInUse]
+    : [...state.negativeTextualInversionsInUse]; // Get Words In Prompt
+  const textualInversionRegex = /([\w<>!@%&*_-]+)/g; // Scan For Each Word
+  const textualInversionMatches = [
+    ...newPrompt.matchAll(textualInversionRegex),
+  ]; // Match All Words
+  if (textualInversionMatches.length > 0) {
+    textualInversionsInUse = []; // Reset Textual Inversions In Use
+    textualInversionMatches.forEach((textualInversionMatch) => {
+      const textualInversionName = textualInversionMatch[0];
+      if (
+        (!textualInversionsInUse.includes(textualInversionName) &&
+          state.localTextualInversionTriggers.includes(textualInversionName)) ||
+        state.huggingFaceTextualInversionConcepts.includes(textualInversionName)
+      ) {
+        textualInversionsInUse.push(textualInversionName); // Add Textual Inversions In Prompt
+      }
+    });
+  } else {
+    textualInversionsInUse = []; // If No Matches, Remove Textual Inversions In Use
+  }
+
+  if (!toNegative) {
+    state.textualInversionsInUse = textualInversionsInUse;
+  } else {
+    state.negativeTextualInversionsInUse = textualInversionsInUse;
+  }
+};
+
+const handleTypedLoraCheck = (state: GenerationState, newPrompt: string) => {
+  let lorasInUse = [...state.lorasInUse]; // Get Loras In Prompt
+  const loraRegex = /withLora\(([^\\)]+)\)/g; // Scan For Lora Syntax
+  const loraMatches = [...newPrompt.matchAll(loraRegex)]; // Match All Lora Syntaxes
+  if (loraMatches.length > 0) {
+    lorasInUse = []; // Reset Loras In Use
+    loraMatches.forEach((loraMatch) => {
+      const loraName = loraMatch[1].split(',')[0];
+      if (!lorasInUse.includes(loraName)) lorasInUse.push(loraName); // Add Loras In Prompt
+    });
+  } else {
+    lorasInUse = []; // If No Matches, Remove Loras In Use
+  }
+  state.lorasInUse = lorasInUse;
+};
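The trigger-matching idea in the helpers above (use `\b` word boundaries for plain-word triggers, but not for angle-bracketed triggers like `<gta5-artwork>`, since `\b` does not sit next to `<` or `>`) can be sketched standalone. This is a simplified illustration, not the slice's code: `buildTriggerRegex`/`hasTrigger` are hypothetical names, and unlike the diff's `getTIRegex` this version also escapes regex metacharacters before building the pattern.

```typescript
// Build a matcher for a textual-inversion trigger term.
const buildTriggerRegex = (trigger: string): RegExp => {
  // Escape regex metacharacters so the trigger is matched literally.
  const escaped = trigger.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  // Bracketed triggers contain non-word chars, so \b boundaries would never
  // match; plain words get \b so "cat" does not match inside "catalog".
  return /[<>]/.test(trigger)
    ? new RegExp(escaped)
    : new RegExp(`\\b${escaped}\\b`);
};

const hasTrigger = (prompt: string, trigger: string): boolean =>
  buildTriggerRegex(trigger).test(prompt);
```

For example, `hasTrigger('catalog page', 'cat')` is false while `hasTrigger('a cat in the rain', 'cat')` is true, which is exactly why the whole-word boundary matters when scanning free-typed prompts.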
 export const generationSlice = createSlice({
   name: 'generation',
   initialState,
@@ -118,6 +193,20 @@ export const generationSlice = createSlice({
         state.negativePrompt = promptToString(newPrompt);
       }
     },
+    handlePromptCheckers: (
+      state,
+      action: PayloadAction<{
+        prompt: string | InvokeAI.Prompt;
+        toNegative: boolean;
+      }>
+    ) => {
+      const newPrompt = action.payload.prompt;
+      if (typeof newPrompt === 'string') {
+        if (!action.payload.toNegative) handleTypedLoraCheck(state, newPrompt);
+        handleTypedTICheck(state, newPrompt, action.payload.toNegative);
+      }
+    },
     setLorasInUse: (state, action: PayloadAction<string>) => {
       const newLora = action.payload;
       const loras = [...state.lorasInUse];
@@ -128,94 +217,99 @@ export const generationSlice = createSlice({
           'g'
         );
         const newPrompt = state.prompt.replaceAll(loraRegex, '');
-        state.prompt = newPrompt;
+        state.prompt = newPrompt.trim();
         if (loras.includes(newLora)) {
           const newLoraIndex = loras.indexOf(newLora);
           if (newLoraIndex > -1) loras.splice(newLoraIndex, 1);
         }
       } else {
-        state.prompt = `${state.prompt} withLora(${newLora},0.75)`;
+        state.prompt = `${state.prompt.trim()} withLora(${newLora},0.75)`;
         if (!loras.includes(newLora)) loras.push(newLora);
       }
       state.lorasInUse = loras;
     },
-    handlePromptCheckers: (
-      state,
-      action: PayloadAction<string | InvokeAI.Prompt>
-    ) => {
-      const newPrompt = action.payload;
-      // Tackle User Typed Lora Syntax
-      let lorasInUse = [...state.lorasInUse]; // Get Loras In Prompt
-      const loraRegex = /withLora\(([^\\)]+)\)/g; // Scan For Lora Syntax
-      if (typeof newPrompt === 'string') {
-        const loraMatches = [...newPrompt.matchAll(loraRegex)]; // Match All Lora Syntaxes
-        if (loraMatches.length > 0) {
-          lorasInUse = []; // Reset Loras In Use
-          loraMatches.forEach((loraMatch) => {
-            const loraName = loraMatch[1].split(',')[0];
-            if (!lorasInUse.includes(loraName)) lorasInUse.push(loraName); // Add Loras In Prompt
-          });
-        } else {
-          lorasInUse = []; // If No Matches, Remove Loras In Use
-        }
-      }
-      state.lorasInUse = lorasInUse;
-      // Tackle User Typed Textual Inversion
-      let textualInversionsInUse = [...state.textualInversionsInUse]; // Get Words In Prompt
-      const textualInversionRegex = /([\w<>!@%&*_-]+)/g; // Scan For Each Word
-      if (typeof newPrompt === 'string') {
-        const textualInversionMatches = [
-          ...newPrompt.matchAll(textualInversionRegex),
-        ]; // Match All Words
-        if (textualInversionMatches.length > 0) {
-          textualInversionsInUse = []; // Reset Textual Inversions In Use
-          console.log(textualInversionMatches);
-          textualInversionMatches.forEach((textualInversionMatch) => {
-            const textualInversionName = textualInversionMatch[0];
-            console.log(textualInversionName);
-            if (
-              !textualInversionsInUse.includes(textualInversionName) &&
-              (state.localTextualInversionTriggers.includes(
-                textualInversionName
-              ) ||
-                state.huggingFaceTextualInversionConcepts.includes(
-                  textualInversionName
-                ))
-            )
-              textualInversionsInUse.push(textualInversionName); // Add Textual Inversions In Prompt
-          });
-        } else {
-          textualInversionsInUse = []; // If No Matches, Remove Textual Inversions In Use
-        }
-      }
-      console.log([...state.huggingFaceTextualInversionConcepts]);
-      state.textualInversionsInUse = textualInversionsInUse;
-    },
+    setClearLoras: (state) => {
+      const lorasInUse = [...state.lorasInUse];
+
+      lorasInUse.forEach((lora) => {
+        const loraRegex = new RegExp(
+          `withLora\\(${lora},?\\s*([^\\)]+)?\\)`,
+          'g'
+        );
+        const newPrompt = state.prompt.replaceAll(loraRegex, '');
+        state.prompt = newPrompt.trim();
+      });
+
+      state.lorasInUse = [];
+    },
     setTextualInversionsInUse: (state, action: PayloadAction<string>) => {
       const newTextualInversion = action.payload;
       const textualInversions = [...state.textualInversionsInUse];
+      const negativeTextualInversions = [
+        ...state.negativeTextualInversionsInUse,
+      ];
+
       if (textualInversionExists(state, newTextualInversion)) {
-        const textualInversionRegex = new RegExp(newTextualInversion, 'g');
-        const newPrompt = state.prompt.replaceAll(textualInversionRegex, '');
-        state.prompt = newPrompt;
-        if (textualInversions.includes(newTextualInversion)) {
-          const newTIIndex = textualInversions.indexOf(newTextualInversion);
-          if (newTIIndex > -1) textualInversions.splice(newTIIndex, 1);
+        const textualInversionRegex = getTIRegex(newTextualInversion);
+
+        if (!state.addTIToNegative) {
+          const newPrompt = state.prompt.replace(textualInversionRegex, '');
+          state.prompt = newPrompt.trim();
+          const newTIIndex = textualInversions.indexOf(newTextualInversion);
+          if (newTIIndex > -1) textualInversions.splice(newTIIndex, 1);
+        } else {
+          const newPrompt = state.negativePrompt.replace(
+            textualInversionRegex,
+            ''
+          );
+          state.negativePrompt = newPrompt.trim();
+          const newTIIndex =
+            negativeTextualInversions.indexOf(newTextualInversion);
+          if (newTIIndex > -1) negativeTextualInversions.splice(newTIIndex, 1);
         }
       } else {
-        state.prompt = `${state.prompt} ${newTextualInversion}`;
-        if (!textualInversions.includes(newTextualInversion))
-          textualInversions.push(newTextualInversion);
+        if (!state.addTIToNegative) {
+          state.prompt = `${state.prompt.trim()} ${newTextualInversion}`;
+          textualInversions.push(newTextualInversion);
+        } else {
+          state.negativePrompt = `${state.negativePrompt.trim()} ${newTextualInversion}`;
+          negativeTextualInversions.push(newTextualInversion);
+        }
       }
-      state.lorasInUse = textualInversions;
       state.textualInversionsInUse = textualInversions;
+      state.negativeTextualInversionsInUse = negativeTextualInversions;
+    },
+    setClearTextualInversions: (state) => {
+      const textualInversions = [...state.textualInversionsInUse];
+      const negativeTextualInversions = [
+        ...state.negativeTextualInversionsInUse,
+      ];
+
+      textualInversions.forEach((ti) => {
+        const textualInversionRegex = getTIRegex(ti);
+        const newPrompt = state.prompt.replace(textualInversionRegex, '');
+        state.prompt = newPrompt.trim();
+      });
+
+      negativeTextualInversions.forEach((ti) => {
+        const textualInversionRegex = getTIRegex(ti);
+        const newPrompt = state.negativePrompt.replace(
+          textualInversionRegex,
+          ''
+        );
+        state.negativePrompt = newPrompt.trim();
+      });
+
+      state.textualInversionsInUse = [];
+      state.negativeTextualInversionsInUse = [];
+    },
+    setAddTIToNegative: (state, action: PayloadAction<boolean>) => {
+      state.addTIToNegative = action.payload;
     },
    setLocalTextualInversionTriggers: (
       state,
@@ -509,11 +603,14 @@ export const {
   setPerlin,
   setPrompt,
   setNegativePrompt,
-  setLorasInUse,
-  setLocalTextualInversionTriggers,
-  setHuggingFaceTextualInversionConcepts,
-  setTextualInversionsInUse,
   handlePromptCheckers,
+  setLorasInUse,
+  setClearLoras,
+  setHuggingFaceTextualInversionConcepts,
+  setLocalTextualInversionTriggers,
+  setTextualInversionsInUse,
+  setAddTIToNegative,
+  setClearTextualInversions,
   setSampler,
   setSeamBlur,
   setSeamless,
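The `withLora(name[,weight])` scan that the slice performs on typed prompts can be illustrated standalone. This is a simplified sketch, not the slice's code: `extractLoraNames` is a hypothetical name, and it uses the plainer character class `[^)]` where the diff writes `[^\\)]`.

```typescript
// Extract the LoRA names referenced by withLora(...) tokens in a prompt.
const extractLoraNames = (prompt: string): string[] => {
  const loraRegex = /withLora\(([^)]+)\)/g; // capture the parenthesized args
  const names: string[] = [];
  for (const match of prompt.matchAll(loraRegex)) {
    // The name is the first comma-separated field; the optional second
    // field is the weight (e.g. 0.75) and is ignored here.
    const name = match[1].split(',')[0].trim();
    if (!names.includes(name)) names.push(name);
  }
  return names;
};

// extractLoraNames('a cat withLora(pixel-art,0.75) withLora(film)')
//   → ['pixel-art', 'film']
```

This matches the reducer's behavior of rebuilding `lorasInUse` from scratch on every check, so deleting a `withLora(...)` token from the prompt also deselects the LoRA in the UI.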


@@ -18,8 +18,7 @@ import PromptInput from 'features/parameters/components/PromptInput/PromptInput'
 import InvokeOptionsPanel from 'features/ui/components/InvokeParametersPanel';
 import { useTranslation } from 'react-i18next';
 import ImageToImageOptions from './ImageToImageOptions';
-import LoraManager from 'features/parameters/components/LoraManager/LoraManager';
-import TextualInversionManager from 'features/parameters/components/TextualInversionManager/TextualInversionManager';
+import PromptExtras from 'features/parameters/components/PromptInput/Extras/PromptExtras';

 export default function ImageToImagePanel() {
   const { t } = useTranslation();
@@ -65,8 +64,7 @@ export default function ImageToImagePanel() {
       <Flex flexDir="column" rowGap="0.5rem">
         <PromptInput />
         <NegativePromptInput />
-        <LoraManager />
-        <TextualInversionManager />
+        <PromptExtras />
       </Flex>
       <ProcessButtons />
       <MainSettings />


@@ -10,8 +10,6 @@ import UpscaleSettings from 'features/parameters/components/AdvancedParameters/U
 import UpscaleToggle from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleToggle';
 import GenerateVariationsToggle from 'features/parameters/components/AdvancedParameters/Variations/GenerateVariations';
 import VariationsSettings from 'features/parameters/components/AdvancedParameters/Variations/VariationsSettings';
-import LoraManager from 'features/parameters/components/LoraManager/LoraManager';
-import TextualInversionManager from 'features/parameters/components/TextualInversionManager/TextualInversionManager';
 import MainSettings from 'features/parameters/components/MainParameters/MainParameters';
 import ParametersAccordion from 'features/parameters/components/ParametersAccordion';
 import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
@@ -19,6 +17,7 @@ import NegativePromptInput from 'features/parameters/components/PromptInput/Nega
 import PromptInput from 'features/parameters/components/PromptInput/PromptInput';
 import InvokeOptionsPanel from 'features/ui/components/InvokeParametersPanel';
 import { useTranslation } from 'react-i18next';
+import PromptExtras from 'features/parameters/components/PromptInput/Extras/PromptExtras';

 export default function TextToImagePanel() {
   const { t } = useTranslation();
@@ -64,8 +63,7 @@ export default function TextToImagePanel() {
       <Flex flexDir="column" rowGap="0.5rem">
         <PromptInput />
         <NegativePromptInput />
-        <LoraManager />
-        <TextualInversionManager />
+        <PromptExtras />
       </Flex>
       <ProcessButtons />
       <MainSettings />


@@ -10,8 +10,6 @@ import SymmetryToggle from 'features/parameters/components/AdvancedParameters/Ou
 import SeedSettings from 'features/parameters/components/AdvancedParameters/Seed/SeedSettings';
 import GenerateVariationsToggle from 'features/parameters/components/AdvancedParameters/Variations/GenerateVariations';
 import VariationsSettings from 'features/parameters/components/AdvancedParameters/Variations/VariationsSettings';
-import LoraManager from 'features/parameters/components/LoraManager/LoraManager';
-import TextualInversionManager from 'features/parameters/components/TextualInversionManager/TextualInversionManager';
 import MainSettings from 'features/parameters/components/MainParameters/MainParameters';
 import ParametersAccordion from 'features/parameters/components/ParametersAccordion';
 import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
@@ -19,6 +17,7 @@ import NegativePromptInput from 'features/parameters/components/PromptInput/Nega
 import PromptInput from 'features/parameters/components/PromptInput/PromptInput';
 import InvokeOptionsPanel from 'features/ui/components/InvokeParametersPanel';
 import { useTranslation } from 'react-i18next';
+import PromptExtras from 'features/parameters/components/PromptInput/Extras/PromptExtras';

 export default function UnifiedCanvasPanel() {
   const { t } = useTranslation();
@@ -75,8 +74,7 @@ export default function UnifiedCanvasPanel() {
       <Flex flexDir="column" rowGap="0.5rem">
         <PromptInput />
         <NegativePromptInput />
-        <LoraManager />
-        <TextualInversionManager />
+        <PromptExtras />
       </Flex>
       <ProcessButtons />
       <MainSettings />

File diff suppressed because one or more lines are too long


@@ -633,9 +633,8 @@ class Generate:
             except RuntimeError:
                 # Clear the CUDA cache on an exception
                 self.clear_cuda_cache()
-                print("** Could not generate image.")
-                print(traceback.format_exc(), file=sys.stderr)
+                print(">> Could not generate image.")
+                raise

             toc = time.time()
             print("\n>> Usage stats:")


@@ -1 +1 @@
-__version__='2.3.4rc1'
+__version__='2.3.5'


@@ -12,21 +12,13 @@ from typing import Union, Optional, Any
 from transformers import CLIPTokenizer
 from compel import Compel
-from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser
+from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser, \
+    Conjunction

 from .devices import torch_dtype
+from .generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ldm.invoke.globals import Globals

-def get_tokenizer(model) -> CLIPTokenizer:
-    # TODO remove legacy ckpt fallback handling
-    return (getattr(model, 'tokenizer', None)  # diffusers
-            or model.cond_stage_model.tokenizer)  # ldm
-
-def get_text_encoder(model) -> Any:
-    # TODO remove legacy ckpt fallback handling
-    return (getattr(model, 'text_encoder', None)  # diffusers
-            or UnsqueezingLDMTransformer(model.cond_stage_model.transformer))  # ldm
-
 class UnsqueezingLDMTransformer:
     def __init__(self, ldm_transformer):
         self.ldm_transformer = ldm_transformer
@@ -40,48 +32,57 @@ class UnsqueezingLDMTransformer:
         return insufficiently_unsqueezed_tensor.unsqueeze(0)

-def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
+def get_uc_and_c_and_ec(prompt_string,
+                        model: StableDiffusionGeneratorPipeline,
+                        log_tokens=False, skip_normalize_legacy_blend=False):
     # lazy-load any deferred textual inversions.
     # this might take a couple of seconds the first time a textual inversion is used.
     model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)

-    tokenizer = get_tokenizer(model)
-    text_encoder = get_text_encoder(model)
-    compel = Compel(tokenizer=tokenizer,
-                    text_encoder=text_encoder,
+    compel = Compel(tokenizer=model.tokenizer,
+                    text_encoder=model.text_encoder,
                     textual_inversion_manager=model.textual_inversion_manager,
                     dtype_for_device_getter=torch_dtype)

     # get rid of any newline characters
     prompt_string = prompt_string.replace("\n", " ")
     positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)

     legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
-    positive_prompt: FlattenedPrompt|Blend
-    lora_conditions = None
+    positive_conjunction: Conjunction
     if legacy_blend is not None:
-        positive_prompt = legacy_blend
+        positive_conjunction = legacy_blend
     else:
         positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
-        positive_prompt = positive_conjunction.prompts[0]
-        should_use_lora_manager = True
-        lora_weights = positive_conjunction.lora_weights
-        if model.peft_manager:
-            should_use_lora_manager = model.peft_manager.should_use(lora_weights)
-            if not should_use_lora_manager:
-                model.peft_manager.set_loras(lora_weights)
-        if model.lora_manager and should_use_lora_manager:
-            lora_conditions = model.lora_manager.set_loras_conditions(lora_weights)
+    positive_prompt = positive_conjunction.prompts[0]
+
+    should_use_lora_manager = True
+    lora_weights = positive_conjunction.lora_weights
+    lora_conditions = None
+    if model.peft_manager:
+        should_use_lora_manager = model.peft_manager.should_use(lora_weights)
+        if not should_use_lora_manager:
+            model.peft_manager.set_loras(lora_weights)
+    if model.lora_manager and should_use_lora_manager:
+        lora_conditions = model.lora_manager.set_loras_conditions(lora_weights)

     negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
     negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]

+    tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
     if log_tokens or getattr(Globals, "log_tokenization", False):
-        log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
+        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)

-    c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
-    uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
-
-    tokens_count = get_max_token_count(tokenizer, positive_prompt)
+    # some LoRA models also mess with the text encoder, so they must be active while compel builds conditioning tensors
+    lora_conditioning_ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
+                                                                           lora_conditions=lora_conditions)
+    with InvokeAIDiffuserComponent.custom_attention_context(model.unet,
extra_conditioning_info=lora_conditioning_ec,
step_count=-1):
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
# now build the "real" ec
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count, ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
cross_attention_control_args=options.get( cross_attention_control_args=options.get(
'cross_attention_control', None), 'cross_attention_control', None),
@@ -93,12 +94,12 @@ def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = Fals
Union[FlattenedPrompt, Blend], FlattenedPrompt): Union[FlattenedPrompt, Blend], FlattenedPrompt):
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string) positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend) legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
positive_prompt: FlattenedPrompt|Blend positive_conjunction: Conjunction
if legacy_blend is not None: if legacy_blend is not None:
positive_prompt = legacy_blend positive_conjunction = legacy_blend
else: else:
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string) positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
positive_prompt = positive_conjunction.prompts[0] positive_prompt = positive_conjunction.prompts[0]
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string) negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0] negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
@@ -217,18 +218,26 @@ def log_tokenization_for_text(text, tokenizer, display_label=None):
print(f'{discarded}\x1b[0m') print(f'{discarded}\x1b[0m')
def try_parse_legacy_blend(text: str, skip_normalize: bool=False) -> Optional[Blend]: def try_parse_legacy_blend(text: str, skip_normalize: bool=False) -> Optional[Conjunction]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize) weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1: if len(weighted_subprompts) <= 1:
return None return None
strings = [x[0] for x in weighted_subprompts] strings = [x[0] for x in weighted_subprompts]
weights = [x[1] for x in weighted_subprompts]
pp = PromptParser() pp = PromptParser()
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings] parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions] flattened_prompts = []
weights = []
loras = []
for i, x in enumerate(parsed_conjunctions):
if len(x.prompts)>0:
flattened_prompts.append(x.prompts[0])
weights.append(weighted_subprompts[i][1])
if len(x.lora_weights)>0:
loras.extend(x.lora_weights)
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize) return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)],
lora_weights = loras)
def split_weighted_subprompts(text, skip_normalize=False)->list: def split_weighted_subprompts(text, skip_normalize=False)->list:
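The rewritten `try_parse_legacy_blend` leans on `split_weighted_subprompts` to break legacy `subprompt:weight` syntax into `(text, weight)` pairs. The real parser is not shown in this hunk; the sketch below (with an illustrative name and a deliberately simplified regex) only demonstrates the interface that the blend code consumes, including the weight normalization controlled by `skip_normalize`:

```python
import re

def split_weighted_subprompts_sketch(text: str, skip_normalize: bool = False) -> list:
    """Split 'cat:1 dog:3'-style legacy blend text into (subprompt, weight) pairs.

    Simplified illustration of the interface try_parse_legacy_blend() depends on;
    the real parser handles quoting, escapes and unweighted fragments.
    """
    pairs = []
    for match in re.finditer(r'([^:]+):([\d.]+)\s*', text):
        pairs.append((match.group(1).strip(), float(match.group(2))))
    if not pairs:
        # no weights present: treat the whole text as one subprompt
        return [(text, 1.0)]
    if not skip_normalize:
        total = sum(w for _, w in pairs)
        pairs = [(s, w / total) for s, w in pairs]
    return pairs
```

With two or more pairs returned, the caller builds a `Blend`; with one or zero, it falls through to normal Compel parsing.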


@@ -4,14 +4,13 @@ pip install <path_to_git_source>.
 '''
 import os
 import platform
+import psutil
 import requests

 from rich import box, print
-from rich.console import Console, Group, group
+from rich.console import Console, group
 from rich.panel import Panel
 from rich.prompt import Prompt
 from rich.style import Style
-from rich.syntax import Syntax
-from rich.text import Text

 from ldm.invoke import __version__
@@ -32,6 +31,19 @@ else:

 def get_versions()->dict:
     return requests.get(url=INVOKE_AI_REL).json()

+def invokeai_is_running()->bool:
+    for p in psutil.process_iter():
+        try:
+            cmdline = p.cmdline()
+            matches = [x for x in cmdline if x.endswith(('invokeai','invokeai.exe'))]
+            if matches:
+                print(f':exclamation: [bold red]An InvokeAI instance appears to be running as process {p.pid}[/bold red]')
+                return True
+        except (psutil.AccessDenied,psutil.NoSuchProcess):
+            continue
+    return False
+
 def welcome(versions: dict):
     @group()
@@ -62,6 +74,10 @@ def welcome(versions: dict):

 def main():
     versions = get_versions()
+    if invokeai_is_running():
+        print(f':exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/bold red]')
+        return
     welcome(versions)
     tag = None
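The new running-instance check walks every process with `psutil.process_iter()` and matches command-line tokens against the InvokeAI launcher names. The matching step itself is pure and worth isolating; the helper name below is illustrative (the real code inlines it), and the token list would come from `psutil.Process.cmdline()`:

```python
def looks_like_invokeai(cmdline) -> bool:
    """Return True if any command-line token ends with an InvokeAI launcher name.

    Mirrors the filter used inside invokeai_is_running(); str.endswith accepts a
    tuple, which covers both the POSIX and Windows executable names at once.
    """
    return any(tok.endswith(('invokeai', 'invokeai.exe')) for tok in cmdline)
```

Note the check is wrapped in `try/except (psutil.AccessDenied, psutil.NoSuchProcess)` in the real loop, since a process may vanish between enumeration and `cmdline()` — the very race that #3228 describes.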


@@ -196,16 +196,6 @@ class addModelsForm(npyscreen.FormMultiPage):
             scroll_exit=True,
         )
         self.nextrely += 1
-        self.convert_models = self.add_widget_intelligent(
-            npyscreen.TitleSelectOne,
-            name="== CONVERT IMPORTED MODELS INTO DIFFUSERS==",
-            values=["Keep original format", "Convert to diffusers"],
-            value=0,
-            begin_entry_at=4,
-            max_height=4,
-            hidden=True,  # will appear when imported models box is edited
-            scroll_exit=True,
-        )
         self.cancel = self.add_widget_intelligent(
             npyscreen.ButtonPress,
             name="CANCEL",
@@ -240,8 +230,6 @@ class addModelsForm(npyscreen.FormMultiPage):
             self.show_directory_fields.addVisibleWhenSelected(i)

         self.show_directory_fields.when_value_edited = self._clear_scan_directory
-        self.import_model_paths.when_value_edited = self._show_hide_convert
-        self.autoload_directory.when_value_edited = self._show_hide_convert

     def resize(self):
         super().resize()
@@ -252,13 +240,6 @@ class addModelsForm(npyscreen.FormMultiPage):
         if not self.show_directory_fields.value:
             self.autoload_directory.value = ""

-    def _show_hide_convert(self):
-        model_paths = self.import_model_paths.value or ""
-        autoload_directory = self.autoload_directory.value or ""
-        self.convert_models.hidden = (
-            len(model_paths) == 0 and len(autoload_directory) == 0
-        )
-
     def _get_starter_model_labels(self) -> List[str]:
         window_width, window_height = get_terminal_size()
         label_width = 25
@@ -318,7 +299,6 @@ class addModelsForm(npyscreen.FormMultiPage):
        .scan_directory: Path to a directory of models to scan and import
        .autoscan_on_startup: True if invokeai should scan and import at startup time
        .import_model_paths: list of URLs, repo_ids and file paths to import
-       .convert_to_diffusers: if True, convert legacy checkpoints into diffusers
        """
        # we're using a global here rather than storing the result in the parentapp
        # due to some bug in npyscreen that is causing attributes to be lost
@@ -354,7 +334,6 @@ class addModelsForm(npyscreen.FormMultiPage):
         # URLs and the like
         selections.import_model_paths = self.import_model_paths.value.split()
-        selections.convert_to_diffusers = self.convert_models.value[0] == 1

 class AddModelApplication(npyscreen.NPSAppManaged):
@@ -367,7 +346,6 @@ class AddModelApplication(npyscreen.NPSAppManaged):
             scan_directory=None,
             autoscan_on_startup=None,
             import_model_paths=None,
-            convert_to_diffusers=None,
         )

     def onStart(self):
@@ -387,7 +365,6 @@ def process_and_execute(opt: Namespace, selections: Namespace):
     directory_to_scan = selections.scan_directory
     scan_at_startup = selections.autoscan_on_startup
     potential_models_to_install = selections.import_model_paths
-    convert_to_diffusers = selections.convert_to_diffusers

     install_requested_models(
         install_initial_models=models_to_install,
@@ -395,7 +372,6 @@ def process_and_execute(opt: Namespace, selections: Namespace):
         scan_directory=Path(directory_to_scan) if directory_to_scan else None,
         external_models=potential_models_to_install,
         scan_at_startup=scan_at_startup,
-        convert_to_diffusers=convert_to_diffusers,
         precision="float32"
         if opt.full_precision
         else choose_precision(torch.device(choose_torch_device())),


@@ -11,6 +11,7 @@ from tempfile import TemporaryFile

 import requests
 from diffusers import AutoencoderKL
+from diffusers import logging as dlogging
 from huggingface_hub import hf_hub_url
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
@@ -68,7 +69,6 @@ def install_requested_models(
     scan_directory: Path = None,
     external_models: List[str] = None,
     scan_at_startup: bool = False,
-    convert_to_diffusers: bool = False,
     precision: str = "float16",
     purge_deleted: bool = False,
     config_file_path: Path = None,
@@ -114,17 +114,16 @@ def install_requested_models(
             try:
                 model_manager.heuristic_import(
                     path_url_or_repo,
-                    convert=convert_to_diffusers,
                     config_file_callback=_pick_configuration_file,
                     commit_to_conf=config_file_path
                 )
             except KeyboardInterrupt:
                 sys.exit(-1)
-            except Exception:
-                pass
+            except Exception as e:
+                print(f'An exception has occurred: {str(e)}')

     if scan_at_startup and scan_directory.is_dir():
-        argument = '--autoconvert' if convert_to_diffusers else '--autoimport'
+        argument = '--autoconvert'
         initfile = Path(Globals.root, Globals.initfile)
         replacement = Path(Globals.root, f'{Globals.initfile}.new')
         directory = str(scan_directory).replace('\\','/')
@@ -296,13 +295,21 @@ def _download_diffusion_weights(
     mconfig: DictConfig, access_token: str, precision: str = "float32"
 ):
     repo_id = mconfig["repo_id"]
+    revision = mconfig.get('revision',None)
     model_class = (
         StableDiffusionGeneratorPipeline
         if mconfig.get("format", None) == "diffusers"
         else AutoencoderKL
     )
-    extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
+    extra_arg_list = [{"revision": revision}] if revision \
+        else [{"revision": "fp16"}, {}] if precision == "float16" \
+        else [{}]
     path = None
+
+    # quench safety checker warnings
+    verbosity = dlogging.get_verbosity()
+    dlogging.set_verbosity_error()
     for extra_args in extra_arg_list:
         try:
             path = download_from_hf(
@@ -318,6 +325,7 @@ def _download_diffusion_weights(
             print(f"An unexpected error occurred while downloading the model: {e})")
         if path:
             break
+    dlogging.set_verbosity(verbosity)
     return path
@@ -448,6 +456,8 @@ def new_config_file_contents(
             stanza["description"] = mod["description"]
             stanza["repo_id"] = mod["repo_id"]
             stanza["format"] = mod["format"]
+            if "revision" in mod:
+                stanza["revision"] = mod["revision"]
             # diffusers don't need width and height (probably .ckpt doesn't either)
             # so we no longer require these in INITIAL_MODELS.yaml
             if "width" in mod:
@@ -472,10 +482,9 @@ def new_config_file_contents(
         conf[model] = stanza

-    # if no default model was chosen, then we select the first
-    # one in the list
+    # if no default model was chosen, then we select the first one in the list
     if not default_selected:
-        conf[list(successfully_downloaded.keys())[0]]["default"] = True
+        conf[list(conf.keys())[0]]["default"] = True

     return OmegaConf.to_yaml(conf)
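The new `extra_arg_list` logic in `_download_diffusion_weights` encodes a download-attempt ladder: a configured `revision` is tried exclusively; otherwise `float16` precision first tries the `fp16` branch and falls back to the default branch. Extracted as a pure function (the name is illustrative, not from the source), the rule is easy to check:

```python
from typing import Optional

def build_download_attempts(revision: Optional[str], precision: str) -> list:
    """Candidate kwargs for each Hugging Face download attempt, tried in order
    until one succeeds. Mirrors the extra_arg_list expression: an explicit
    revision wins outright; float16 tries 'fp16' before the default branch."""
    if revision:
        return [{"revision": revision}]
    if precision == "float16":
        return [{"revision": "fp16"}, {}]
    return [{}]
```

Each dict would be splatted into the download call; the loop breaks on the first attempt that returns a path.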


@@ -99,8 +99,9 @@ def expand_prompts(
             sequence = 0
             for command in commands:
                 sequence += 1
-                parent_conn.send(
-                    command + f' --fnformat="dp.{sequence:04}.{{prompt}}.png"'
+                format = _get_fn_format(outdir, sequence)
+                parent_conn.send_bytes(
+                    (command + f' --fnformat="{format}"').encode('utf-8')
                 )
             parent_conn.close()
         else:
@@ -110,7 +111,20 @@ def expand_prompts(
         for p in children:
             p.terminate()

+def _get_fn_format(directory:str, sequence:int)->str:
+    """
+    Get a filename that doesn't exceed filename length restrictions
+    on the current platform.
+    """
+    try:
+        max_length = os.pathconf(directory,'PC_NAME_MAX')
+    except:
+        max_length = 255
+    prefix = f'dp.{sequence:04}.'
+    suffix = '.png'
+    max_length -= len(prefix)+len(suffix)
+    return f'{prefix}{{prompt:0.{max_length}}}{suffix}'
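`_get_fn_format` works by baking a precision into the `{prompt}` placeholder: in Python's format-spec mini-language, a precision applied to a string truncates it, so the rendered filename can never exceed the platform's `NAME_MAX`. A self-contained sketch of the same idea (helper name is illustrative; it renders the template immediately instead of returning it):

```python
import os

def safe_filename(directory: str, sequence: int, prompt: str) -> str:
    """Build 'dp.NNNN.<prompt>.png', truncating the prompt so the whole name
    fits the platform's filename-length limit, as _get_fn_format() does."""
    try:
        max_length = os.pathconf(directory, 'PC_NAME_MAX')
    except (AttributeError, ValueError, OSError):
        max_length = 255  # conservative default, e.g. on Windows
    prefix = f'dp.{sequence:04}.'
    suffix = '.png'
    budget = max_length - len(prefix) - len(suffix)
    # a precision in a string format spec truncates: '{:.5}'.format('abcdefgh') -> 'abcde'
    template = f'{prefix}{{prompt:.{budget}}}{suffix}'
    return template.format(prompt=prompt)
```

Short prompts pass through unchanged; a pathological multi-kilobyte prompt is clipped to the remaining budget rather than producing an `ENAMETOOLONG` error at save time.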
 class MessageToStdin(object):
     def __init__(self, connection: Connection):
         self.connection = connection
@@ -119,7 +133,7 @@ class MessageToStdin(object):
     def readline(self) -> str:
         try:
             if len(self.linebuffer) == 0:
-                message = self.connection.recv()
+                message = self.connection.recv_bytes().decode('utf-8')
                 self.linebuffer = message.split("\n")
             result = self.linebuffer.pop(0)
             return result
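The switch from `Connection.send`/`recv` (which pickle Python objects) to `send_bytes`/`recv_bytes` moves raw bytes across the pipe, so both ends must agree on an encoding — here UTF-8, matching the `.encode('utf-8')` on the sending side. A minimal round trip showing the pattern:

```python
from multiprocessing import Pipe

# send_bytes()/recv_bytes() transfer raw bytes without pickling, so the
# encoding is explicit on both ends -- UTF-8, as in expand_prompts().
parent_conn, child_conn = Pipe()
parent_conn.send_bytes('invoke "a prompt"\ninvoke "another"'.encode('utf-8'))

message = child_conn.recv_bytes().decode('utf-8')
lines = message.split('\n')  # MessageToStdin buffers received text as lines
```

Avoiding pickling also sidesteps version-skew issues between the parent and child processes, since only text crosses the pipe.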


@@ -400,8 +400,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @property
     def _submodels(self) -> Sequence[torch.nn.Module]:
         module_names, _, _ = self.extract_init_dict(dict(self.config))
-        values = [getattr(self, name) for name in module_names.keys()]
-        return [m for m in values if isinstance(m, torch.nn.Module)]
+        submodels = []
+        for name in module_names.keys():
+            if hasattr(self, name):
+                value = getattr(self, name)
+            else:
+                value = getattr(self.config, name)
+            if isinstance(value, torch.nn.Module):
+                submodels.append(value)
+        return submodels

     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
@@ -467,11 +474,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if additional_guidance is None:
             additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
-        with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
-                                                             step_count=len(self.scheduler.timesteps)
+        with InvokeAIDiffuserComponent.custom_attention_context(self.invokeai_diffuser.model,
+                                                                extra_conditioning_info=extra_conditioning_info,
+                                                                step_count=len(self.scheduler.timesteps)
                                                                 ):
-            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
+            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.config.num_train_timesteps,
                                             latents=latents)

             batch_size = latents.shape[0]
@@ -755,7 +763,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @property
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
-        return self.unet.in_channels
+        return self.unet.config.in_channels

     def decode_latents(self, latents):
         # Explicit call to get the vae loaded, since `decode` isn't the forward method.


@@ -255,8 +255,8 @@ class Inpaint(Img2Img):
         pipeline.scheduler = sampler

         # todo: support cross-attention control
-        uc, c, _ = conditioning
-        conditioning_data = (ConditioningData(uc, c, cfg_scale)
+        uc, c, extra_conditioning_info = conditioning
+        conditioning_data = (ConditioningData(uc, c, cfg_scale, extra_conditioning_info)
                              .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))


@@ -372,12 +372,6 @@ class ModelManager(object):
         )
         from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt

-        # try:
-        #     if self.list_models()[self.current_model]['status'] == 'active':
-        #         self.offload_model(self.current_model)
-        # except Exception:
-        #     pass
-
         if self._has_cuda():
             torch.cuda.empty_cache()
         pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
@@ -423,9 +417,9 @@ class ModelManager(object):
             pipeline_args.update(cache_dir=global_cache_dir("hub"))
         if using_fp16:
             pipeline_args.update(torch_dtype=torch.float16)
-            fp_args_list = [{"revision": "fp16"}, {}]
-        else:
-            fp_args_list = [{}]
+        revision = mconfig.get('revision') or ('fp16' if using_fp16 else None)
+        fp_args_list = [{"revision": revision}] if revision else []
+        fp_args_list.append({})

         verbosity = dlogging.get_verbosity()
         dlogging.set_verbosity_error()
@@ -1162,7 +1156,7 @@ class ModelManager(object):
         return self.device.type == "cuda"

     def _diffuser_sha256(
-        self, name_or_path: Union[str, Path], chunksize=4096
+        self, name_or_path: Union[str, Path], chunksize=16777216
     ) -> Union[str, bytes]:
         path = None
         if isinstance(name_or_path, Path):
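The `chunksize` bump from 4096 to 16777216 (4 KiB to 16 MiB) matters because `_diffuser_sha256` hashes multi-gigabyte weight files: far fewer `read()` calls per file, while memory use stays bounded by one chunk. The underlying pattern is standard incremental hashing; a self-contained sketch (illustrative function name):

```python
import hashlib

def sha256_of_file(path: str, chunksize: int = 16_777_216) -> str:
    """Hash a file in fixed-size chunks (16 MiB here, matching the new
    chunksize) so arbitrarily large checkpoints never load fully into memory."""
    sha = hashlib.sha256()
    with open(path, 'rb') as f:
        while chunk := f.read(chunksize):  # b'' at EOF ends the loop
            sha.update(chunk)
    return sha.hexdigest()
```

The digest is identical regardless of chunk size; only the syscall count and per-call overhead change.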


@@ -14,7 +14,6 @@ from torch import nn

 from compel.cross_attention_control import Arguments
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
-from diffusers.models.cross_attention import AttnProcessor
 from ldm.invoke.devices import torch_dtype
@@ -163,7 +162,7 @@ class Context:

 class InvokeAICrossAttentionMixin:
     """
-    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
+    Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
     through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
     and dynamic slicing strategy selection.
     """
@@ -178,7 +177,7 @@ class InvokeAICrossAttentionMixin:
         Set custom attention calculator to be called when attention is calculated
         :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
         which returns either the suggested_attention_slice or an adjusted equivalent.
-        `module` is the current CrossAttention module for which the callback is being invoked.
+        `module` is the current Attention module for which the callback is being invoked.
         `suggested_attention_slice` is the default-calculated attention slice
         `dim` is -1 if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
         If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
@@ -288,16 +287,7 @@ class InvokeAICrossAttentionMixin:
             return self.einsum_op_tensor_mem(q, k, v, 32)

-def restore_default_cross_attention(model, is_running_diffusers: bool, processors_to_restore: Optional[AttnProcessor]=None):
-    if is_running_diffusers:
-        unet = model
-        unet.set_attn_processor(processors_to_restore or CrossAttnProcessor())
-    else:
-        remove_attention_function(model)
-
-def override_cross_attention(model, context: Context, is_running_diffusers = False):
+def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.
@@ -323,26 +313,19 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
-    if is_running_diffusers:
-        unet = model
-        old_attn_processors = unet.attn_processors
-        if torch.backends.mps.is_available():
-            # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
-            unet.set_attn_processor(SwapCrossAttnProcessor())
-        else:
-            # try to re-use an existing slice size
-            default_slice_size = 4
-            slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
-            unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
-    else:
-        context.register_cross_attention_modules(model)
-        inject_attention_function(model, context)
+    old_attn_processors = unet.attn_processors
+    if torch.backends.mps.is_available():
+        # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
+        unet.set_attn_processor(SwapCrossAttnProcessor())
+    else:
+        # try to re-use an existing slice size
+        default_slice_size = 4
+        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))

 def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
-    from ldm.modules.attention import CrossAttention  # avoid circular import
+    from ldm.modules.attention import CrossAttention  # avoid circular import  # TODO: rename as in diffusers?
     cross_attention_class: type = InvokeAIDiffusersCrossAttention if isinstance(model,UNet2DConditionModel) else CrossAttention
     which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
     attention_module_tuples = [(name,module) for name, module in model.named_modules() if
@@ -448,7 +431,7 @@ def get_mem_free_total(device):

-class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):
+class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):

     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -473,8 +456,8 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
     """
     # base implementation

-class CrossAttnProcessor:
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+class AttnProcessor:
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
@@ -503,7 +486,7 @@ from dataclasses import field, dataclass

 import torch
-from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
+from diffusers.models.attention_processor import Attention, AttnProcessor, SlicedAttnProcessor

 @dataclass
@@ -548,7 +531,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):

     # TODO: dynamically pick slice size based on memory conditions

-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
                  # kwargs
                  swap_cross_attn_context: SwapCrossAttnContext=None):
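When `setup_cross_attention_control_attention_processors` swaps in the sliced swap processor, it preserves the user's existing memory/speed trade-off by reusing the slice size of any `SlicedAttnProcessor` already installed, via `next()` over a generator with a default. The pattern in isolation, using stand-in processor classes (the class and function names below are illustrative, not the diffusers ones):

```python
class SlicedProc:
    """Stand-in for a sliced attention processor carrying a slice_size."""
    def __init__(self, slice_size):
        self.slice_size = slice_size

class PlainProc:
    """Stand-in for a non-sliced attention processor."""
    pass

def reuse_slice_size(processors: dict, default_slice_size: int = 4) -> int:
    """Pick up the slice size of an existing sliced processor, if any --
    the same next(generator, default) idiom as the cross-attention setup."""
    return next((p.slice_size for p in processors.values() if type(p) is SlicedProc),
                default_slice_size)
```

`next()` stops at the first sliced processor found, so a single `attn_processors` traversal suffices even on large UNets.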


@@ -12,17 +12,6 @@ class DDIMSampler(Sampler):
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
                                                            model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))

-    def prepare_to_sample(self, t_enc, **kwargs):
-        super().prepare_to_sample(t_enc, **kwargs)
-
-        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
-        all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
-
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
-        else:
-            self.invokeai_diffuser.restore_default_cross_attention()
-
     # This is the central routine
     @torch.no_grad()


@@ -38,15 +38,6 @@ class CFGDenoiser(nn.Module):
                                                            model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))

-    def prepare_to_sample(self, t_enc, **kwargs):
-
-        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
-
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = t_enc)
-        else:
-            self.invokeai_diffuser.restore_default_cross_attention()
-
     def forward(self, x, sigma, uncond, cond, cond_scale):
         next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)


@@ -14,17 +14,6 @@ class PLMSSampler(Sampler):
     def __init__(self, model, schedule='linear', device=None, **kwargs):
         super().__init__(model,schedule,model.num_timesteps, device)

-    def prepare_to_sample(self, t_enc, **kwargs):
-        super().prepare_to_sample(t_enc, **kwargs)
-
-        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
-        all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
-
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
-            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
-        else:
-            self.invokeai_diffuser.restore_default_cross_attention()
-
     # this is the essential routine
     @torch.no_grad()


@@ -1,18 +1,17 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union, Any, Dict
+from typing import Callable, Optional, Union, Any
 import numpy as np
 import torch
-from diffusers.models.cross_attention import AttnProcessor
+from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias
 from ldm.invoke.globals import Globals
 from ldm.models.diffusion.cross_attention_control import (
     Arguments,
-    restore_default_cross_attention,
-    override_cross_attention,
+    setup_cross_attention_control_attention_processors,
     Context,
     get_cross_attention_modules,
     CrossAttentionType,
@@ -84,66 +83,45 @@ class InvokeAIDiffuserComponent:
         self.cross_attention_control_context = None
         self.sequential_guidance = Globals.sequential_guidance

+    @classmethod
     @contextmanager
     def custom_attention_context(
-        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+        cls,
+        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
+        extra_conditioning_info: Optional[ExtraConditioningInfo],
+        step_count: int
     ):
-        old_attn_processor = None
+        old_attn_processors = None
         if extra_conditioning_info and (
             extra_conditioning_info.wants_cross_attention_control
             | extra_conditioning_info.has_lora_conditions
         ):
-            old_attn_processor = self.override_attention_processors(
-                extra_conditioning_info, step_count=step_count
-            )
+            old_attn_processors = unet.attn_processors
+            # Load lora conditions into the model
+            if extra_conditioning_info.has_lora_conditions:
+                for condition in extra_conditioning_info.lora_conditions:
+                    condition()  # target model is stored in condition state for some reason
+            if extra_conditioning_info.wants_cross_attention_control:
+                cross_attention_control_context = Context(
+                    arguments=extra_conditioning_info.cross_attention_control_args,
+                    step_count=step_count,
+                )
+                setup_cross_attention_control_attention_processors(
+                    unet,
+                    cross_attention_control_context,
+                )
         try:
             yield None
         finally:
-            if old_attn_processor is not None:
-                self.restore_default_cross_attention(old_attn_processor)
+            if old_attn_processors is not None:
+                unet.set_attn_processor(old_attn_processors)
             if extra_conditioning_info and extra_conditioning_info.has_lora_conditions:
                 for lora_condition in extra_conditioning_info.lora_conditions:
                     lora_condition.unload()
             # TODO resuscitate attention map saving
             # self.remove_attention_map_saving()

-    def override_attention_processors(
-        self, conditioning: ExtraConditioningInfo, step_count: int
-    ) -> Dict[str, AttnProcessor]:
-        """
-        setup cross attention .swap control. for diffusers this replaces the attention processor, so
-        the previous attention processor is returned so that the caller can restore it later.
-        """
-        old_attn_processors = self.model.attn_processors
-        # Load lora conditions into the model
-        if conditioning.has_lora_conditions:
-            for condition in conditioning.lora_conditions:
-                condition(self.model)
-        if conditioning.wants_cross_attention_control:
-            self.cross_attention_control_context = Context(
-                arguments=conditioning.cross_attention_control_args,
-                step_count=step_count,
-            )
-            override_cross_attention(
-                self.model,
-                self.cross_attention_control_context,
-                is_running_diffusers=self.is_running_diffusers,
-            )
-        return old_attn_processors

-    def restore_default_cross_attention(
-        self, processors_to_restore: Optional[dict[str, "AttnProcessor"]] = None
-    ):
-        self.cross_attention_control_context = None
-        restore_default_cross_attention(
-            self.model,
-            is_running_diffusers=self.is_running_diffusers,
-            processors_to_restore=processors_to_restore,
-        )

     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
             if dim is not None:
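The new classmethod version of `custom_attention_context` above reduces to a snapshot-and-restore pattern around `unet.attn_processors`: save the processor dict, install replacements, then put the saved dict back in a `finally` block so the UNet is restored even when generation raises. A minimal sketch of that pattern, using a hypothetical stand-in class rather than a real diffusers UNet:

```python
from contextlib import contextmanager

class FakeUnet:
    """Stand-in for a diffusers UNet: exposes attn_processors / set_attn_processor."""
    def __init__(self):
        self.attn_processors = {"down.attn1": "default", "up.attn1": "default"}

    def set_attn_processor(self, procs):
        self.attn_processors = dict(procs)

@contextmanager
def custom_attention_context(unet, new_proc):
    # snapshot the current processors so they can be restored afterwards
    old_attn_processors = dict(unet.attn_processors)
    unet.set_attn_processor({k: new_proc for k in unet.attn_processors})
    try:
        yield None
    finally:
        # restore even if the body raised
        unet.set_attn_processor(old_attn_processors)

unet = FakeUnet()
with custom_attention_context(unet, "swapped"):
    assert all(v == "swapped" for v in unet.attn_processors.values())
assert all(v == "default" for v in unet.attn_processors.values())
```

The same idea is why the real change stores `old_attn_processors = unet.attn_processors` before the LoRA conditions run: anything the conditions patch gets undone by a single `set_attn_processor` call on exit.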


@@ -1,15 +1,16 @@
-import re
+import json
 from pathlib import Path
 from typing import Optional
 import torch
-from compel import Compel
 from diffusers.models import UNet2DConditionModel
+from filelock import FileLock, Timeout
 from safetensors.torch import load_file
 from torch.utils.hooks import RemovableHandle
 from transformers import CLIPTextModel
-from ldm.invoke.devices import choose_torch_device
+from ..invoke.globals import global_lora_models_dir, Globals
+from ..invoke.devices import choose_torch_device

 """
 This module supports loading LoRA weights trained with https://github.com/kohya-ss/sd-scripts
@@ -17,6 +18,11 @@ To be removed once support for diffusers LoRA weights is well supported
 """

+class IncompatibleModelException(Exception):
+    "Raised when there is an attempt to load a LoRA into a model that is incompatible with it"
+    pass
+
 class LoRALayer:
     lora_name: str
     name: str
@@ -31,18 +37,14 @@ class LoRALayer:
         self.name = name
         self.scale = alpha / rank if (alpha and rank) else 1.0

-    def forward(self, lora, input_h, output):
+    def forward(self, lora, input_h):
         if self.mid is None:
-            output = (
-                output
-                + self.up(self.down(*input_h)) * lora.multiplier * self.scale
-            )
+            weight = self.up(self.down(*input_h))
         else:
-            output = (
-                output
-                + self.up(self.mid(self.down(*input_h))) * lora.multiplier * self.scale
-            )
-        return output
+            weight = self.up(self.mid(self.down(*input_h)))
+        return weight * lora.multiplier * self.scale
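The refactored `LoRALayer.forward` now returns only the low-rank delta `up(down(x)) * multiplier * (alpha / rank)`; the forward hook adds it onto the wrapped module's output. A toy numeric sketch of that residual update, with plain nested lists standing in for torch tensors:

```python
def matmul(A, B):
    # naive matrix multiply over nested lists
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

alpha, rank, multiplier = 1.0, 1, 1.0
scale = alpha / rank if (alpha and rank) else 1.0  # same rule as LoRALayer.__init__

down = [[1.0, 0.0]]     # rank x in_features: projects 2 features down to rank 1
up   = [[0.5], [0.25]]  # out_features x rank: projects rank 1 back up to 2
x    = [[2.0], [3.0]]   # input column vector

# the layer computes only the delta ...
delta = matmul(up, matmul(down, x))
delta = [[v * multiplier * scale for v in row] for row in delta]

# ... and the hook adds it to the original module's output
base_out = [[10.0], [20.0]]
out = [[b[0] + d[0]] for b, d in zip(base_out, delta)]
print(out)  # [[11.0], [20.5]]
```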
 class LoHALayer:
     lora_name: str
@@ -64,8 +66,7 @@ class LoHALayer:
         self.name = name
         self.scale = alpha / rank if (alpha and rank) else 1.0

-    def forward(self, lora, input_h, output):
+    def forward(self, lora, input_h):
         if type(self.org_module) == torch.nn.Conv2d:
             op = torch.nn.functional.conv2d
             extra_args = dict(
@@ -80,21 +81,87 @@ class LoHALayer:
             extra_args = {}

         if self.t1 is None:
-            weight = ((self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b))
+            weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
         else:
-            rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', self.t1, self.w1_b, self.w1_a)
-            rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', self.t2, self.w2_b, self.w2_a)
+            rebuild1 = torch.einsum(
+                "i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a
+            )
+            rebuild2 = torch.einsum(
+                "i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a
+            )
             weight = rebuild1 * rebuild2

         bias = self.bias if self.bias is not None else 0
-        return output + op(
+        return op(
             *input_h,
             (weight + bias).view(self.org_module.weight.shape),
             None,
             **extra_args,
         ) * lora.multiplier * self.scale
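LoHA differs from plain LoRA in how the weight delta is built: two low-rank products are combined elementwise (a Hadamard product), `weight = (w1_a @ w1_b) * (w2_a @ w2_b)`, before being applied through the original module's op. A small dependency-free sketch of that construction:

```python
def matmul(A, B):
    # naive matrix multiply over nested lists
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def hadamard(A, B):
    # elementwise product of two same-shaped matrices
    return [[a * b for a, b in zip(r1, r2)] for r1, r2 in zip(A, B)]

# two rank-1 factorizations of a 2x2 weight delta
w1_a, w1_b = [[1.0], [2.0]], [[3.0, 4.0]]
w2_a, w2_b = [[1.0], [1.0]], [[2.0, 0.5]]

weight = hadamard(matmul(w1_a, w1_b), matmul(w2_a, w2_b))
print(weight)  # [[6.0, 2.0], [12.0, 4.0]]
```

The elementwise product lets two rank-`r` factors express an effective rank up to `r * r`, which is the point of the LoHA parameterization.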
+class LoKRLayer:
+    lora_name: str
+    name: str
+    scale: float
+    w1: Optional[torch.Tensor] = None
+    w1_a: Optional[torch.Tensor] = None
+    w1_b: Optional[torch.Tensor] = None
+    w2: Optional[torch.Tensor] = None
+    w2_a: Optional[torch.Tensor] = None
+    w2_b: Optional[torch.Tensor] = None
+    t2: Optional[torch.Tensor] = None
+    bias: Optional[torch.Tensor] = None
+    org_module: torch.nn.Module
+
+    def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
+        self.lora_name = lora_name
+        self.name = name
+        self.scale = alpha / rank if (alpha and rank) else 1.0
+
+    def forward(self, lora, input_h):
+        if type(self.org_module) == torch.nn.Conv2d:
+            op = torch.nn.functional.conv2d
+            extra_args = dict(
+                stride=self.org_module.stride,
+                padding=self.org_module.padding,
+                dilation=self.org_module.dilation,
+                groups=self.org_module.groups,
+            )
+        else:
+            op = torch.nn.functional.linear
+            extra_args = {}
+
+        w1 = self.w1
+        if w1 is None:
+            w1 = self.w1_a @ self.w1_b
+
+        w2 = self.w2
+        if w2 is None:
+            if self.t2 is None:
+                w2 = self.w2_a @ self.w2_b
+            else:
+                w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)
+
+        if len(w2.shape) == 4:
+            w1 = w1.unsqueeze(2).unsqueeze(2)
+        w2 = w2.contiguous()
+        weight = torch.kron(w1, w2).reshape(self.org_module.weight.shape)
+
+        bias = self.bias if self.bias is not None else 0
+        return op(
+            *input_h,
+            (weight + bias).view(self.org_module.weight.shape),
+            None,
+            **extra_args
+        ) * lora.multiplier * self.scale
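The new `LoKRLayer` rebuilds its weight delta with a Kronecker product (`torch.kron`): an (m×n) factor crossed with a (p×q) factor yields an (mp×nq) matrix, so two small factors can parameterize a much larger layer. A dependency-free sketch of the Kronecker product itself:

```python
def kron(A, B):
    # Kronecker product of nested-list matrices: each entry of A scales a full copy of B
    return [
        [a * b for a in row_a for b in row_b]
        for row_a in A
        for row_b in B
    ]

w1 = [[1.0, 2.0]]              # 1 x 2
w2 = [[0.0, 1.0], [1.0, 0.0]]  # 2 x 2
weight = kron(w1, w2)          # (1*2) x (2*2) = 2 x 4
print(weight)  # [[0.0, 1.0, 0.0, 2.0], [1.0, 0.0, 2.0, 0.0]]
```

This matches `torch.kron` / `numpy.kron` ordering, and explains the `reshape(self.org_module.weight.shape)` in the layer: the product's flat (mp×nq) shape is folded back onto the wrapped module's weight shape.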
 class LoRAModuleWrapper:
     unet: UNet2DConditionModel
@@ -111,12 +178,22 @@ class LoRAModuleWrapper:
         self.applied_loras = {}
         self.loaded_loras = {}

-        self.UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D", "SpatialTransformer"]
-        self.TEXT_ENCODER_TARGET_REPLACE_MODULE = ["ResidualAttentionBlock", "CLIPAttention", "CLIPMLP"]
+        self.UNET_TARGET_REPLACE_MODULE = [
+            "Transformer2DModel",
+            "Attention",
+            "ResnetBlock2D",
+            "Downsample2D",
+            "Upsample2D",
+            "SpatialTransformer",
+        ]
+        self.TEXT_ENCODER_TARGET_REPLACE_MODULE = [
+            "ResidualAttentionBlock",
+            "CLIPAttention",
+            "CLIPMLP",
+        ]
         self.LORA_PREFIX_UNET = "lora_unet"
         self.LORA_PREFIX_TEXT_ENCODER = "lora_te"

         def find_modules(
             prefix, root_module: torch.nn.Module, target_replace_modules
         ) -> dict[str, torch.nn.Module]:
@@ -147,7 +224,6 @@ class LoRAModuleWrapper:
             self.LORA_PREFIX_UNET, unet, self.UNET_TARGET_REPLACE_MODULE
         )

     def lora_forward_hook(self, name):
         wrapper = self
@@ -159,7 +235,7 @@ class LoRAModuleWrapper:
                 layer = lora.layers.get(name, None)
                 if layer is None:
                     continue
-                output = layer.forward(lora, input_h, output)
+                output += layer.forward(lora, input_h)
             return output
         return lora_forward
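The hunk above changes the hook so each LoRA layer contributes an additive delta (`output += layer.forward(lora, input_h)`) rather than threading `output` through every layer. The closure pattern can be sketched without torch, using hypothetical stand-in classes:

```python
class FakeLayer:
    def __init__(self, delta):
        self.delta = delta

    def forward(self, lora, input_h):
        # return only the delta, scaled by the LoRA's multiplier
        return self.delta * lora.multiplier

class FakeLora:
    def __init__(self, layers, multiplier=1.0):
        self.layers = layers
        self.multiplier = multiplier

class Wrapper:
    def __init__(self):
        self.applied_loras = {}

    def lora_forward_hook(self, name):
        wrapper = self  # closure captures the wrapper, so later-applied LoRAs are seen

        def lora_forward(module, input_h, output):
            # accumulate the contribution of every applied LoRA for this module
            for lora in wrapper.applied_loras.values():
                layer = lora.layers.get(name, None)
                if layer is None:
                    continue
                output += layer.forward(lora, input_h)
            return output

        return lora_forward

w = Wrapper()
w.applied_loras["style"] = FakeLora({"attn1": FakeLayer(0.5)}, multiplier=2.0)
hook = w.lora_forward_hook("attn1")
print(hook(None, None, 10.0))  # 11.0 = 10.0 + 0.5 * 2.0
```

Because each layer returns an independent delta, stacking several LoRAs is just a sum of contributions, which is what the rewritten hook expresses.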
@@ -180,6 +256,7 @@ class LoRAModuleWrapper:
     def clear_loaded_loras(self):
         self.loaded_loras.clear()

 class LoRA:
     name: str
     layers: dict[str, LoRALayer]
@@ -205,7 +282,6 @@ class LoRA:
                 state_dict_groupped[stem] = dict()
             state_dict_groupped[stem][leaf] = value

         for stem, values in state_dict_groupped.items():
             if stem.startswith(self.wrapper.LORA_PREFIX_TEXT_ENCODER):
                 wrapped = self.wrapper.text_modules.get(stem, None)
@@ -226,34 +302,59 @@ class LoRA:
             if "alpha" in values:
                 alpha = values["alpha"].item()

-            if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
+            if (
+                "bias_indices" in values
+                and "bias_values" in values
+                and "bias_size" in values
+            ):
                 bias = torch.sparse_coo_tensor(
                     values["bias_indices"],
                     values["bias_values"],
                     tuple(values["bias_size"]),
                 ).to(device=self.device, dtype=self.dtype)

             # lora and locon
             if "lora_down.weight" in values:
                 value_down = values["lora_down.weight"]
                 value_mid = values.get("lora_mid.weight", None)
                 value_up = values["lora_up.weight"]

                 if type(wrapped) == torch.nn.Conv2d:
                     if value_mid is not None:
-                        layer_down = torch.nn.Conv2d(value_down.shape[1], value_down.shape[0], (1, 1), bias=False)
-                        layer_mid = torch.nn.Conv2d(value_mid.shape[1], value_mid.shape[0], wrapped.kernel_size, wrapped.stride, wrapped.padding, bias=False)
+                        layer_down = torch.nn.Conv2d(
+                            value_down.shape[1], value_down.shape[0], (1, 1), bias=False
+                        )
+                        layer_mid = torch.nn.Conv2d(
+                            value_mid.shape[1],
+                            value_mid.shape[0],
+                            wrapped.kernel_size,
+                            wrapped.stride,
+                            wrapped.padding,
+                            bias=False,
+                        )
                     else:
-                        layer_down = torch.nn.Conv2d(value_down.shape[1], value_down.shape[0], wrapped.kernel_size, wrapped.stride, wrapped.padding, bias=False)
-                        layer_mid = None
+                        layer_down = torch.nn.Conv2d(
+                            value_down.shape[1],
+                            value_down.shape[0],
+                            wrapped.kernel_size,
+                            wrapped.stride,
+                            wrapped.padding,
+                            bias=False,
+                        )
+                        layer_mid = None

-                    layer_up = torch.nn.Conv2d(value_up.shape[1], value_up.shape[0], (1, 1), bias=False)
+                    layer_up = torch.nn.Conv2d(
+                        value_up.shape[1], value_up.shape[0], (1, 1), bias=False
+                    )
                 elif type(wrapped) == torch.nn.Linear:
-                    layer_down = torch.nn.Linear(value_down.shape[1], value_down.shape[0], bias=False)
-                    layer_mid = None
-                    layer_up = torch.nn.Linear(value_up.shape[1], value_up.shape[0], bias=False)
+                    layer_down = torch.nn.Linear(
+                        value_down.shape[1], value_down.shape[0], bias=False
+                    )
+                    layer_mid = None
+                    layer_up = torch.nn.Linear(
+                        value_up.shape[1], value_up.shape[0], bias=False
+                    )
                 else:
                     print(
@@ -261,52 +362,90 @@ class LoRA:
                     )
                     return

                 with torch.no_grad():
                     layer_down.weight.copy_(value_down)
                     if layer_mid is not None:
                         layer_mid.weight.copy_(value_mid)
                     layer_up.weight.copy_(value_up)

                 layer_down.to(device=self.device, dtype=self.dtype)
                 if layer_mid is not None:
                     layer_mid.to(device=self.device, dtype=self.dtype)
                 layer_up.to(device=self.device, dtype=self.dtype)

                 rank = value_down.shape[0]
                 layer = LoRALayer(self.name, stem, rank, alpha)
-                #layer.bias = bias # TODO: find and debug lora/locon with bias
+                # layer.bias = bias # TODO: find and debug lora/locon with bias
                 layer.down = layer_down
                 layer.mid = layer_mid
                 layer.up = layer_up

             # loha
             elif "hada_w1_b" in values:
                 rank = values["hada_w1_b"].shape[0]
                 layer = LoHALayer(self.name, stem, rank, alpha)
                 layer.org_module = wrapped
                 layer.bias = bias
-                layer.w1_a = values["hada_w1_a"].to(device=self.device, dtype=self.dtype)
-                layer.w1_b = values["hada_w1_b"].to(device=self.device, dtype=self.dtype)
-                layer.w2_a = values["hada_w2_a"].to(device=self.device, dtype=self.dtype)
-                layer.w2_b = values["hada_w2_b"].to(device=self.device, dtype=self.dtype)
+                layer.w1_a = values["hada_w1_a"].to(
+                    device=self.device, dtype=self.dtype
+                )
+                layer.w1_b = values["hada_w1_b"].to(
+                    device=self.device, dtype=self.dtype
+                )
+                layer.w2_a = values["hada_w2_a"].to(
+                    device=self.device, dtype=self.dtype
+                )
+                layer.w2_b = values["hada_w2_b"].to(
+                    device=self.device, dtype=self.dtype
+                )

                 if "hada_t1" in values:
-                    layer.t1 = values["hada_t1"].to(device=self.device, dtype=self.dtype)
+                    layer.t1 = values["hada_t1"].to(
+                        device=self.device, dtype=self.dtype
+                    )
                 else:
                     layer.t1 = None

                 if "hada_t2" in values:
-                    layer.t2 = values["hada_t2"].to(device=self.device, dtype=self.dtype)
+                    layer.t2 = values["hada_t2"].to(
+                        device=self.device, dtype=self.dtype
+                    )
                 else:
                     layer.t2 = None

+            # lokr
+            elif "lokr_w1_b" in values or "lokr_w1" in values:
+                if "lokr_w1_b" in values:
+                    rank = values["lokr_w1_b"].shape[0]
+                elif "lokr_w2_b" in values:
+                    rank = values["lokr_w2_b"].shape[0]
+                else:
+                    rank = None  # unscaled
+
+                layer = LoKRLayer(self.name, stem, rank, alpha)
+                layer.org_module = wrapped
+                layer.bias = bias
+
+                if "lokr_w1" in values:
+                    layer.w1 = values["lokr_w1"].to(device=self.device, dtype=self.dtype)
+                else:
+                    layer.w1_a = values["lokr_w1_a"].to(device=self.device, dtype=self.dtype)
+                    layer.w1_b = values["lokr_w1_b"].to(device=self.device, dtype=self.dtype)
+
+                if "lokr_w2" in values:
+                    layer.w2 = values["lokr_w2"].to(device=self.device, dtype=self.dtype)
+                else:
+                    layer.w2_a = values["lokr_w2_a"].to(device=self.device, dtype=self.dtype)
+                    layer.w2_b = values["lokr_w2_b"].to(device=self.device, dtype=self.dtype)
+
+                if "lokr_t2" in values:
+                    layer.t2 = values["lokr_t2"].to(device=self.device, dtype=self.dtype)
+
             else:
                 print(
                     f">> Encountered unknown lora layer module in {self.name}: {stem} - {type(wrapped).__name__}"
@@ -317,14 +456,25 @@ class KohyaLoraManager:
 class KohyaLoraManager:
-    def __init__(self, pipe, lora_path):
-        self.lora_path = lora_path
-        self.vector_length_cache_path = self.lora_path / '.vectorlength.cache'
+    def __init__(self, pipe):
         self.unet = pipe.unet
         self.wrapper = LoRAModuleWrapper(pipe.unet, pipe.text_encoder)
         self.text_encoder = pipe.text_encoder
         self.device = torch.device(choose_torch_device())
         self.dtype = pipe.unet.dtype

+    @classmethod
+    @property
+    def lora_path(cls) -> Path:
+        return Path(global_lora_models_dir())
+
+    @classmethod
+    @property
+    def vector_length_cache_path(cls) -> Path:
+        return cls.lora_path / '.vectorlength.cache'
+
     def load_lora_module(self, name, path_file, multiplier: float = 1.0):
         print(f" | Found lora {name} at {path_file}")
         if path_file.suffix == ".safetensors":
@@ -332,6 +482,9 @@ class KohyaLoraManager:
         else:
             checkpoint = torch.load(path_file, map_location="cpu")

+        if not self.check_model_compatibility(checkpoint):
+            raise IncompatibleModelException
+
         lora = LoRA(name, self.device, self.dtype, self.wrapper, multiplier)
         lora.load_from_dict(checkpoint)
         self.wrapper.loaded_loras[name] = lora
@@ -339,12 +492,14 @@ class KohyaLoraManager:
         return lora

     def apply_lora_model(self, name, mult: float = 1.0):
+        path_file = None
         for suffix in ["ckpt", "safetensors", "pt"]:
-            path_file = Path(self.lora_path, f"{name}.{suffix}")
-            if path_file.exists():
+            path_files = [x for x in Path(self.lora_path).glob(f"**/{name}.{suffix}")]
+            if len(path_files):
+                path_file = path_files[0]
                 print(f" | Loading lora {path_file.name} with weight {mult}")
                 break
-        if not path_file.exists():
+        if not path_file:
             print(f" ** Unable to find lora: {name}")
             return
@@ -355,13 +510,89 @@ class KohyaLoraManager:
         lora.multiplier = mult
         self.wrapper.applied_loras[name] = lora

-    def unload_applied_lora(self, lora_name: str):
+    def unload_applied_lora(self, lora_name: str) -> bool:
+        """If the indicated LoRA has previously been applied then
+        unload it and return True. Return False if the LoRA was
+        not previously applied (for status reporting)
+        """
         if lora_name in self.wrapper.applied_loras:
             del self.wrapper.applied_loras[lora_name]
+            return True
+        return False

-    def unload_lora(self, lora_name: str):
+    def unload_lora(self, lora_name: str) -> bool:
         if lora_name in self.wrapper.loaded_loras:
             del self.wrapper.loaded_loras[lora_name]
+            return True
+        return False

     def clear_loras(self):
         self.wrapper.clear_applied_loras()

+    def check_model_compatibility(self, checkpoint) -> bool:
+        """Checks whether the LoRA checkpoint is compatible with the token vector
+        length of the model that this manager is associated with.
+        """
+        model_token_vector_length = (
+            self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
+        )
+        lora_token_vector_length = self.vector_length_from_checkpoint(checkpoint)
+        return model_token_vector_length == lora_token_vector_length
+
+    @staticmethod
+    def vector_length_from_checkpoint(checkpoint: dict) -> int:
+        """Return the vector token length for the passed LoRA checkpoint object.
+        This is used to determine which SD model version the LoRA was based on.
+        768 -> SDv1
+        1024 -> SDv2
+        """
+        key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
+        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
+        lora_token_vector_length = (
+            checkpoint[key1].shape[1]
+            if key1 in checkpoint
+            else checkpoint[key2].shape[0]
+            if key2 in checkpoint
+            else 768
+        )
+        return lora_token_vector_length
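The detection trick in `vector_length_from_checkpoint`: the second dimension of the text-encoder `lora_down` weight equals the token embedding width of the base model (768 for SD v1, 1024 for SD v2). A sketch with a hypothetical `FakeTensor` in place of real checkpoint tensors:

```python
class FakeTensor:
    def __init__(self, shape):
        self.shape = shape

KEY1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"

def vector_length_from_checkpoint(checkpoint):
    # shape[1] of the text-encoder lora_down weight is the token embedding width
    if KEY1 in checkpoint:
        return checkpoint[KEY1].shape[1]
    return 768  # fall back to assuming an SD v1 LoRA

sd1_lora = {KEY1: FakeTensor((4, 768))}
sd2_lora = {KEY1: FakeTensor((4, 1024))}
print(vector_length_from_checkpoint(sd1_lora))  # 768  -> SD v1
print(vector_length_from_checkpoint(sd2_lora))  # 1024 -> SD v2
```

The real method also falls back to the `hada_w1_a` key for LoHA checkpoints, which store the same width in `shape[0]`.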
+    @classmethod
+    def vector_length_from_checkpoint_file(cls, checkpoint_path: Path) -> int:
+        with LoraVectorLengthCache(cls.vector_length_cache_path) as cache:
+            if str(checkpoint_path) not in cache:
+                if checkpoint_path.suffix == ".safetensors":
+                    checkpoint = load_file(
+                        checkpoint_path.absolute().as_posix(), device="cpu"
+                    )
+                else:
+                    checkpoint = torch.load(checkpoint_path, map_location="cpu")
+                cache[str(checkpoint_path)] = KohyaLoraManager.vector_length_from_checkpoint(
+                    checkpoint
+                )
+            return cache[str(checkpoint_path)]
+
+class LoraVectorLengthCache(object):
+    def __init__(self, cache_path: Path):
+        self.cache_path = cache_path
+        self.lock = FileLock(Path(cache_path.parent, ".cachelock"))
+        self.cache = {}
+
+    def __enter__(self):
+        self.lock.acquire(timeout=10)
+        try:
+            if self.cache_path.exists():
+                with open(self.cache_path, "r") as json_file:
+                    self.cache = json.load(json_file)
+        except Timeout:
+            print(
+                "** Can't acquire lock on lora vector length cache. Operations will be slower"
+            )
+        except (json.JSONDecodeError, OSError):
+            self.cache_path.unlink()
+        return self.cache
+
+    def __exit__(self, type, value, traceback):
+        with open(self.cache_path, "w") as json_file:
+            json.dump(self.cache, json_file)
+        self.lock.release()
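`LoraVectorLengthCache` is a read-on-enter / write-on-exit JSON cache guarded by a `filelock.FileLock`. The core persistence pattern can be sketched with the standard library alone (locking omitted for brevity):

```python
import json
import tempfile
from pathlib import Path

class JsonCache:
    """Minimal sketch: load the cache on __enter__, persist it on __exit__."""

    def __init__(self, cache_path: Path):
        self.cache_path = cache_path
        self.cache = {}

    def __enter__(self):
        if self.cache_path.exists():
            try:
                self.cache = json.loads(self.cache_path.read_text())
            except json.JSONDecodeError:
                self.cache_path.unlink()  # discard a corrupt cache file
        return self.cache

    def __exit__(self, exc_type, exc, tb):
        self.cache_path.write_text(json.dumps(self.cache))

path = Path(tempfile.mkdtemp()) / ".vectorlength.cache"
with JsonCache(path) as cache:
    cache["model.safetensors"] = 768
with JsonCache(path) as cache:
    print(cache["model.safetensors"])  # 768, persisted across contexts
```

In the real class the surrounding `FileLock` makes the read-modify-write safe when several invokeai processes scan the lora directory at once; the 10-second acquire timeout bounds how long a scan can stall.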


@@ -1,66 +1,101 @@
 import os
-from diffusers import StableDiffusionPipeline
 from pathlib import Path
+from diffusers import UNet2DConditionModel, StableDiffusionPipeline
 from ldm.invoke.globals import global_lora_models_dir
-from .kohya_lora_manager import KohyaLoraManager
+from .kohya_lora_manager import KohyaLoraManager, IncompatibleModelException
 from typing import Optional, Dict

 class LoraCondition:
     name: str
     weight: float

-    def __init__(self, name, weight: float = 1.0, kohya_manager: Optional[KohyaLoraManager]=None):
+    def __init__(self,
+                 name,
+                 weight: float = 1.0,
+                 unet: UNet2DConditionModel = None,               # for diffusers-format LoRAs
+                 kohya_manager: Optional[KohyaLoraManager] = None,  # for KohyaLoraManager-compatible LoRAs
+                 ):
         self.name = name
         self.weight = weight
         self.kohya_manager = kohya_manager
+        self.unet = unet

-    def __call__(self, model):
+    def __call__(self):
         # TODO: make model able to load from huggingface, rather than just local files
         path = Path(global_lora_models_dir(), self.name)
         if path.is_dir():
-            if model.load_attn_procs:
+            if not self.unet:
+                print(f" ** Unable to load diffusers-format LoRA {self.name}: unet is None")
+                return
+            if self.unet.load_attn_procs:
                 file = Path(path, "pytorch_lora_weights.bin")
                 if file.is_file():
                     print(f">> Loading LoRA: {path}")
-                    model.load_attn_procs(path.absolute().as_posix())
+                    self.unet.load_attn_procs(path.absolute().as_posix())
                 else:
                     print(f" ** Unable to find valid LoRA at: {path}")
             else:
                 print(" ** Invalid Model to load LoRA")
         elif self.kohya_manager:
-            self.kohya_manager.apply_lora_model(self.name,self.weight)
+            try:
+                self.kohya_manager.apply_lora_model(self.name,self.weight)
+            except IncompatibleModelException:
+                print(f" ** LoRA {self.name} is incompatible with this model; will generate without the LoRA applied.")
         else:
             print(" ** Unable to load LoRA")

     def unload(self):
-        if self.kohya_manager:
+        if self.kohya_manager and self.kohya_manager.unload_applied_lora(self.name):
             print(f'>> unloading LoRA {self.name}')
-            self.kohya_manager.unload_applied_lora(self.name)

 class LoraManager:
-    def __init__(self, pipe):
+    def __init__(self, pipe: StableDiffusionPipeline):
         # Kohya class handles lora not generated through diffusers
-        self.kohya = KohyaLoraManager(pipe, global_lora_models_dir())
+        self.kohya = KohyaLoraManager(pipe)
+        self.unet = pipe.unet

     def set_loras_conditions(self, lora_weights: list):
         conditions = []
         if len(lora_weights) > 0:
             for lora in lora_weights:
-                conditions.append(LoraCondition(lora.model, lora.weight, self.kohya))
+                conditions.append(LoraCondition(lora.model, lora.weight, self.unet, self.kohya))
         if len(conditions) > 0:
             return conditions
         return None

+    def list_compatible_loras(self) -> Dict[str, Path]:
+        '''
+        List all the LoRAs in the global lora directory that
+        are compatible with the current model. Return a dictionary
+        of the lora basename and its path.
+        '''
+        model_length = self.kohya.text_encoder.get_input_embeddings().weight.data[0].shape[0]
+        return self.list_loras(model_length)
+
-    @classmethod
-    def list_loras(self) -> Dict[str, Path]:
+    @staticmethod
+    def list_loras(token_vector_length: int = None) -> Dict[str, Path]:
+        '''List the LoRAs in the global lora directory.
+        If token_vector_length is provided, then only return
+        LoRAs that have the indicated length:
+          768: v1 models
+          1024: v2 models
+        '''
         path = Path(global_lora_models_dir())
         models_found = dict()
         for root, _, files in os.walk(path):
             for x in files:
                 name = Path(x).stem
                 suffix = Path(x).suffix
-                if suffix in [".ckpt", ".pt", ".safetensors"]:
-                    models_found[name]=Path(root,x)
+                if suffix not in [".ckpt", ".pt", ".safetensors"]:
+                    continue
+                path = Path(root, x)
+                if token_vector_length is None:
+                    models_found[name] = Path(root, x)  # unconditional addition
+                elif token_vector_length == KohyaLoraManager.vector_length_from_checkpoint_file(path):
+                    models_found[name] = Path(root, x)  # conditional on the base model matching
         return models_found
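The rewritten `list_loras` walks the lora directory recursively and keeps only recognized checkpoint suffixes, keyed by basename. The filtering logic can be exercised standalone against a temporary directory:

```python
import os
import tempfile
from pathlib import Path

def list_loras(lora_dir):
    # recurse through lora_dir, keeping only checkpoint-like files
    models_found = {}
    for root, _, files in os.walk(lora_dir):
        for x in files:
            if Path(x).suffix not in [".ckpt", ".pt", ".safetensors"]:
                continue
            models_found[Path(x).stem] = Path(root, x)
    return models_found

d = Path(tempfile.mkdtemp())
(d / "sub").mkdir()
(d / "style.safetensors").touch()
(d / "sub" / "detail.pt").touch()
(d / "readme.txt").touch()  # ignored: not a checkpoint suffix
print(sorted(list_loras(d)))  # ['detail', 'style']
```

Note that basenames are the dictionary keys, so two files with the same stem in different subdirectories would collide; the real method inherits the same behavior.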


@@ -34,7 +34,7 @@ dependencies = [
     "clip_anytorch",
     "compel~=1.1.0",
     "datasets",
-    "diffusers[torch]~=0.14",
+    "diffusers[torch]~=0.15.0",
     "dnspython==2.2.1",
     "einops",
     "eventlet",