Compare commits

...

353 Commits

Author SHA1 Message Date
psychedelicious
3aca35c932 wip upscale node 2023-07-15 21:13:44 +10:00
psychedelicious
ee7d700ae4 chore(ui): regen types 2023-06-27 17:49:19 +10:00
psychedelicious
ca1b96f1df feat(nodes): add WIP real-esrgan node 2023-06-27 17:49:06 +10:00
Lincoln Stein
3c30368c62 Configure and model install TUI tweaks (#3519)
The installer TUI requires a minimum window width and height to provide
a satisfactory user experience. If, after trying and exhausting all
means of enlarging the window (on Linux, Mac and Windows), the window is
still too small, this PR generates a message telling the user to enlarge
the window and pauses until they do so. If the user fails to enlarge
the window, the program will proceed and either issue an error message
that it can't continue (on Windows), or show a clipped display that the
user can remedy by enlarging the window.
2023-06-26 16:08:56 -04:00
Lincoln Stein
ea15d037f9 Merge branch 'main' into lstein/tweak-installer-ui 2023-06-26 15:05:16 -04:00
Lincoln Stein
bf1f2eb128 Bypass failing tests (#3593)
"Fixes" the test suite generally so it doesn't fail CI, but some tests
needed to be skipped/xfailed due to recent refactor.

- ignore three test suites that broke following the model manager
refactor
- move `InvocationServices` fixture to `conftest.py`
- add `boards` items to the `InvocationServices`  fixture

This PR makes the unit tests work, but end-to-end tests are temporarily
commented out due to `invokeai-configure` being broken in `main` -
pending #3547

Looks like a lot of the tests need to be rewritten as they reference
`TextToImageInvocation` / `ImageToImageInvocation`
2023-06-26 14:41:56 -04:00
Lincoln Stein
16829682c8 Merge branch 'main' into ebr/make-tests-pass 2023-06-26 14:27:31 -04:00
Lincoln Stein
befd95eb19 rename root_dir to root_path attributes to emphasize return of a Path 2023-06-26 13:52:25 -04:00
Eugene Brodsky
cc400c9fa5 (ci) temporarily comment out end-to-end tests 2023-06-26 13:08:43 -04:00
Eugene Brodsky
4eb7a5fc60 (ci) clean up pip tests 2023-06-26 13:08:43 -04:00
Eugene Brodsky
587203d589 (tests) make fixture reusable; support boards
fixes the test suite generally, but some tests needed to be
skipped/xfailed due to the recent refactor

- ignore three test suites that broke following the model manager
  refactor
- move InvocationServices fixture to conftest.py
- add `boards` InvocationServices to the fixture
2023-06-26 13:08:34 -04:00
blessedcoolant
d905d0e42a feat(ui): only show canvas image fallback on loading error (#3589) 2023-06-26 21:40:10 +12:00
psychedelicious
6ccf62a863 feat(ui): only show canvas image fallback on loading error 2023-06-26 19:20:05 +10:00
psychedelicious
6390af229d feat(ui): add dynamic prompts to t2i tab
- add param accordion for dynamic prompts
- update graphs
2023-06-26 19:15:54 +10:00
blessedcoolant
9cfac4175f feat(ui): improved node parsing (#3584)
- use `swagger-parser` to dereference openapi schema
- tidy vite plugins
- use mantine select for node add menu
2023-06-26 17:38:23 +12:00
blessedcoolant
3a19be1606 fix: Add missing IAIMantineSelect disabled styles 2023-06-26 17:37:47 +12:00
blessedcoolant
b51ab056f2 Merge branch 'main' into feat/ui/update-node-parsing 2023-06-26 17:32:44 +12:00
blessedcoolant
e206fad22a fix(ui): fix controlnet image size (#3585) 2023-06-26 17:32:07 +12:00
psychedelicious
60780e990d fix(ui): fix controlnet image size 2023-06-26 12:03:11 +10:00
psychedelicious
8d43cf92f6 feat(ui): update action santizer for schema actions 2023-06-26 12:00:38 +10:00
psychedelicious
862bf7546c feat(ui): improved node parsing
- use `swagger-parser` to dereference openapi schema
- tidy vite plugins
- use mantine select for node add menu
2023-06-26 11:53:54 +10:00
blessedcoolant
922468b836 Add control_mode parameter to ControlNet (#3535)
This PR adds the "control_mode" option to ControlNet implementation. 
Possible control_mode options are: 

- balanced -- this is the default, same as previous implementation
without control_mode
- more_prompt -- pays more attention to the prompt
- more_control -- pays more attention to the ControlNet (in earlier
implementations this was called "guess_mode")
- unbalanced -- pays even more attention to the ControlNet 

balanced, more_prompt, and more_control should be nearly identical to
the equivalent options in the [auto1111 sd-webui-controlnet
extension](https://github.com/Mikubill/sd-webui-controlnet#more-control-modes-previously-called-guess-mode)

The changes to enable balanced, more_prompt, and more_control are
managed deeper in the code by two booleans, "soft_injection" and
"cfg_injection". The three control mode options in sd-webui-controlnet
map to these booleans like:
 
!soft_injection && !cfg_injection ⇒  BALANCED            
 soft_injection &&  cfg_injection ⇒  MORE_CONTROL 
 soft_injection && !cfg_injection ⇒  MORE_PROMPT   
 
The "unbalanced" option simply exposes the fourth possible combination
of these two booleans:
!soft_injection &&  cfg_injection ⇒ UNBALANCED

With "unbalanced" mode it is very easy to overdrive the controlnet
inputs. It's recommended to use a cfg_scale between 2 and 4 to mitigate
this, along with lowering controlnet weight and possibly lowering "end
step percent". With those caveats, "unbalanced" can yield interesting
results.

Example of all four modes using Canny edge detection ControlNet with
prompt "old man", identical params except for control_mode:

![Screenshot from 2023-06-11
23-53-00](https://github.com/invoke-ai/InvokeAI/assets/303100/c9e31e7f-50de-4d85-94f2-b5a4af3d067b)
Top middle: BALANCED
Top right: MORE_CONTROL
Bottom middle: MORE_PROMPT
Bottom right: UNBALANCED

I chose this seed because it shows fairly rough results with BALANCED
(the default), but in my opinion better results with both MORE_CONTROL
and MORE_PROMPT. You can clearly see how MORE_PROMPT pays more attention
to the prompt, and MORE_CONTROL pays more attention to the control
image. It also shows that UNBALANCED with the default cfg_scale etc. is
unusable.

But here are four examples from the same series (same seed etc.); all
have control_mode = UNBALANCED, but now cfg_scale is set to 3.
![Screenshot from 2023-06-11
23-48-44](https://github.com/invoke-ai/InvokeAI/assets/303100/5a495306-2164-40aa-9cc8-ce737d7671e7)
The param differences are:
Top middle: prompt="old man", control_weight=0.3, end_step_percent=0.5
Top right: prompt="old man", control_weight=0.4, end_step_percent=1.0
Bottom middle: prompt=None, control_weight=0.3, end_step_percent=0.5
Bottom right: prompt=None, control_weight=0.4, end_step_percent=1.0

So with the right settings UNBALANCED seems useful.
2023-06-25 16:09:26 +12:00
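To make the boolean mapping described above concrete, here is a minimal TypeScript sketch of how the four `control_mode` values could translate into the two internal booleans; the names are illustrative and not the project's actual implementation.

```typescript
// Hypothetical names for illustration only - not the actual InvokeAI code.
type ControlMode = 'balanced' | 'more_prompt' | 'more_control' | 'unbalanced';

interface InjectionFlags {
  softInjection: boolean;
  cfgInjection: boolean;
}

function toInjectionFlags(mode: ControlMode): InjectionFlags {
  switch (mode) {
    case 'balanced':
      return { softInjection: false, cfgInjection: false };
    case 'more_control':
      return { softInjection: true, cfgInjection: true };
    case 'more_prompt':
      return { softInjection: true, cfgInjection: false };
    case 'unbalanced':
      // Exposes the fourth boolean combination; easy to overdrive, so pair
      // with a low cfg_scale (2-4) and reduced controlnet weight.
      return { softInjection: false, cfgInjection: true };
  }
}
```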
psychedelicious
57e719702d fix(ui): add missing ControlNetInvocation type; tidy schema-derived types 2023-06-25 14:04:53 +10:00
psychedelicious
11378a9236 chore(ui): regen api schema 2023-06-25 14:04:16 +10:00
psychedelicious
132829c88f fix(ui): fix path of generated schema types 2023-06-25 14:04:00 +10:00
blessedcoolant
4d4b5b56dc Merge branch 'main' into feat/controlnet-control-modes 2023-06-25 15:48:07 +12:00
blessedcoolant
a9334128c9 chore(ui): bump all packages (#3579)
Everything seems to be working.

- Due to a change to `reactflow`, I regenerated `yarn.lock`
- New chakra CLI fixes issue I had made a patch for; removed the patch
- Change to fontsource changed how we import that font
- Change to fontawesome means we lost the txt2img tab icon, just chose a
similar one
2023-06-25 15:45:39 +12:00
psychedelicious
6b276587d8 chore(ui): bump all packages
Everything seems to be working.

- Due to a change to `reactflow`, I regenerated `yarn.lock`
- New chakra CLI fixes issue I had made a patch for; removed the patch
- Change to fontsource changed how we import that font
- Change to fontawesome means we lost the txt2img tab icon, just chose a similar one
2023-06-25 13:44:10 +10:00
user1
c5faffc18b Merge branch 'main' of github.com:invoke-ai/InvokeAI into feat/controlnet-control-modes
Only "real" conflicts were in:
     invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx
     invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts
2023-06-24 17:05:57 -07:00
psychedelicious
3ae996ebcb fix(ui): fix metadata viewer too stronk 2023-06-24 18:15:49 +10:00
psychedelicious
3d16605762 fix(ui): fix controlnet upload button 2023-06-24 18:15:49 +10:00
psychedelicious
b6dec2b826 fix(ui): fix controlnet dnd overlay not showing on dragover 2023-06-24 18:15:49 +10:00
psychedelicious
013e2aa2a1 fix(ui): fix control image sizes
they were all weird
2023-06-24 18:15:49 +10:00
psychedelicious
8f9fa15fc8 fix(ui): fix image fetching query string 2023-06-24 18:15:49 +10:00
psychedelicious
dde497404b fix(ui): fix init image display buttons
- Reset and Upload buttons along top of initial image
- Also had to mess around with the control net & DnD image stuff after changing the styles
- Abstract image upload logic into hook - does not handle native HTML drag and drop upload - only the button click upload
2023-06-24 18:15:49 +10:00
psychedelicious
0472b33164 fix(ui): fix duplicate is_intermediate query param when fetching images 2023-06-24 17:57:39 +10:00
psychedelicious
a6c615a98c fix(ui): fix canvas staging area
Missed some of the `imageUpdated` stuff
2023-06-24 17:57:39 +10:00
psychedelicious
bab3a9504e fix(nodes): fix LatentsToImage not using is_intermediate when creating images
Appears this was removed during a merge conflict resolution.
2023-06-24 17:57:39 +10:00
psychedelicious
13f25edb1e fix(ui): fix incorrect boards endpoint matchers being used
Should fix some stale-data issues with the auto-adding of images to selected boards, and deleting images from boards.
2023-06-24 17:57:39 +10:00
psychedelicious
8bacee115a fix(ui): fix thunks not using configured api client 2023-06-24 17:57:39 +10:00
psychedelicious
3619c86f07 fix(ui): fix deleting image does not refresh board
I had some wonkiness in the thunks
2023-06-24 17:57:39 +10:00
psychedelicious
8e724b5abe fix(ui): fix image upload
`openapi-fetch` does not handle non-JSON `body`s, always stringifying them, and sets the `content-type` to `application/json`.

The patch here does two things:
- Do not stringify `body` if it is one of the types that should not be stringified (https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API/Using_Fetch#body)
- Do not add `content-type: application/json` unless it really is stringified JSON.

Upstream issue: https://github.com/drwpow/openapi-typescript/issues/1123

I'm a bit lost on fixing the types and adding tests, so I'm not raising a PR upstream.
2023-06-24 17:57:39 +10:00
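The patch described in the commit above boils down to a body-type check before serialising. A hedged sketch of that logic, using only standard Fetch API types (this is not the actual `openapi-fetch` source):

```typescript
// Sketch of the patched request preparation; not the actual openapi-fetch code.
const isRawBody = (body: unknown): boolean =>
  typeof body === 'string' ||
  body instanceof Blob ||
  body instanceof FormData ||
  body instanceof URLSearchParams ||
  body instanceof ArrayBuffer ||
  body instanceof ReadableStream;

function prepareBody(body: unknown, headers: Headers): BodyInit | undefined {
  if (body === undefined || body === null) {
    return undefined;
  }
  if (isRawBody(body)) {
    // Leave the body alone and let fetch() choose the content-type
    // (e.g. the multipart boundary for FormData uploads).
    return body as BodyInit;
  }
  // Only plain objects get stringified, and only then do we claim JSON.
  headers.set('content-type', 'application/json');
  return JSON.stringify(body);
}
```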
psychedelicious
e076231398 fix(ui): fix node editor image fields
I had broken this when converting to rtk-query
2023-06-24 17:57:39 +10:00
psychedelicious
e386b5dc53 feat(ui): api layer refactor
*migrate from `openapi-typescript-codegen` to `openapi-typescript` and `openapi-fetch`*

`openapi-typescript-codegen` is not very actively maintained - it's been over a year since the last update.
`openapi-typescript` and `openapi-fetch` are part of the actively maintained repo. key differences:

- provides a `fetch` client instead of `axios`, which means we need to be a bit more verbose with typing thunks
- fetch client is created at runtime and has a very nice typescript DX
- generates a single file with all types in it, from which we then extract individual types. i don't like how verbose this is, but i do like how it is more explicit.
- removed npm api generation scripts - now we have a single `typegen` script

overall i have more confidence in this new library.

*use nanostores for api base and token*

very simple reactive store for api base url and token. this was suggested in the `openapi-fetch` docs and i quite like the strategy.

*organise rtk-query api*

split out each endpoint (models, images, boards, boardImages) into their own api extensions. tidy!
2023-06-24 17:57:39 +10:00
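As a rough illustration of the nanostores strategy mentioned in the commit above (store names and usage here are assumptions, not the exact app code):

```typescript
// Minimal reactive config store; names are illustrative.
import { atom } from 'nanostores';

export const $baseUrl = atom<string>(''); // e.g. 'http://localhost:9090'
export const $authToken = atom<string | undefined>(undefined);

// A fetch wrapper can read the current values at request time, so a host app
// can repoint the client or swap tokens without rebuilding anything.
export const apiFetch = (path: string, init?: RequestInit): Promise<Response> => {
  const token = $authToken.get();
  const headers = new Headers(init?.headers);
  if (token) {
    headers.set('Authorization', `Bearer ${token}`);
  }
  return fetch(`${$baseUrl.get()}${path}`, { ...init, headers });
};

// The host application configures the store at runtime:
$baseUrl.set('https://my-invokeai-host.example.com/api/v1');
```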
Mary Hipp
8137a99981 simplify 2023-06-24 17:57:39 +10:00
Mary Hipp
878847defd use BASE and TOKEN from OpenAPI if they are set 2023-06-24 17:57:39 +10:00
Lincoln Stein
9de54b2266 Fix vae conversion (#3555)
Unsure at which point it broke, but now I can't convert a VAE (or a
model that the VAE is part of) without this fix.
Needs further research - maybe it's a breaking change in `transformers`?
2023-06-23 15:55:26 +01:00
Sergey Borisov
5aaaaf64a1 Fix ckpt conversion 2023-06-23 17:29:54 +03:00
StAlKeR7779
9140e2c0f2 Merge branch 'main' into fix/vae_conversion 2023-06-23 15:03:59 +03:00
Lincoln Stein
83e2b7578b fix(linux): installer script prints maximum python version usable (#3546)
Changes:
* Linux `install.sh` now prints the maximum python version to use in
case no installed python version matches

Commits:
fix(linux): installer script prints maximum python version usable
2023-06-23 02:16:01 +01:00
Lincoln Stein
df1907e849 Merge branch 'main' into install-script-python-version-error-prompt-fix 2023-06-23 02:15:36 +01:00
blessedcoolant
22c337b1aa Update UI To Use New Model Manager (#3548)
PR for the Model Manager UI work related to 3.0

[DONE]

- Update ModelType Config names to be specific so that the front end can
parse them correctly.
- Rebuild frontend schema to reflect these changes.
- Update Linear UI Text To Image and Image to Image to work with the new
model loader.
- Updated the ModelInput component in the Node Editor to work with the
new changes.

[TODO REMEMBER]

- Add proper types for ModelLoaderType in `ModelSelect.tsx`

[TODO] 

- Everything else.
2023-06-22 22:06:26 +12:00
psychedelicious
339e7ce213 feat(ui): initial implementation of model loading
- Update model listing code to use `rtk-query`
- Update all graph generation to use new `pipeline_model_loader` node
2023-06-22 17:48:57 +10:00
psychedelicious
2a178f5a25 chore(ui): regen api client 2023-06-22 17:48:13 +10:00
psychedelicious
1bc170727b tidy(nodes): rename sd_model_loader to pipeline_model_loader
this is more accurate because it can load e.g. Kandinsky as well
2023-06-22 17:47:58 +10:00
psychedelicious
3722cdf5d6 chore(ui): regen api client 2023-06-22 17:36:20 +10:00
psychedelicious
42a59aa147 feat(nodes): add sd_model_loader node
Loads any pipeline model.

Also introduced is `PipelineModelField`, which includes a model name and base model.
2023-06-22 17:36:05 +10:00
psychedelicious
b937b7da01 feat(models): update model manager service & route to return list of models 2023-06-22 17:34:12 +10:00
Sergey Borisov
21245a0fb2 Set model type to a const value in the openapi schema; add model format enums to the model schema (as they are not referenced in the case of a Literal definition) 2023-06-22 16:51:53 +10:00
Sergey Borisov
da566b59e8 Update model format field to use enums 2023-06-22 16:51:53 +10:00
Sergey Borisov
e4dc9c5a04 Rename format to model_format(still named format when work with config) 2023-06-22 16:51:53 +10:00
Sergey Borisov
aceadacad4 Remove default model logic 2023-06-22 16:51:53 +10:00
blessedcoolant
d3dec59cc3 tweak: UI colors 2023-06-22 16:51:53 +10:00
blessedcoolant
6c98700740 fix: Adjust the Scheduler select width
So the long names do not get cut off.
2023-06-22 16:51:53 +10:00
blessedcoolant
c4c3c96062 Revert "feat: Port Schedulers to Mantine"
This reverts commit e0c105f413.
2023-06-22 16:51:35 +10:00
blessedcoolant
6256be480c fix: Remove type from Model type name 2023-06-22 16:48:35 +10:00
blessedcoolant
7033071934 fix: Unserialization key issue 2023-06-22 16:48:35 +10:00
blessedcoolant
e48528bbef revert: getModels to receivedModels 2023-06-22 16:48:35 +10:00
blessedcoolant
6bdf68dd4c feat: Port Schedulers to Mantine 2023-06-22 16:48:35 +10:00
blessedcoolant
0c3616229e cleanup: Updated model slice names to be more descriptive
Basically updated all slices to have more descriptive names. Did so in order to make sure there's a good naming scheme available for secondary models.
2023-06-22 16:43:14 +10:00
blessedcoolant
604cc1adcd wip: Move Model Selector to own file 2023-06-22 16:43:14 +10:00
blessedcoolant
4847212d5b feat: Enable 2.x Model Generation in Linear UI 2023-06-22 16:43:14 +10:00
blessedcoolant
727293d722 fix: 2.1 models breaking generation
Co-Authored-By: StAlKeR7779 <7768370+StAlKeR7779@users.noreply.github.com>
2023-06-22 16:42:59 +10:00
blessedcoolant
d2f3500e1b chore: Rebuild API - base_model and type added 2023-06-22 16:42:59 +10:00
Sergey Borisov
ef83a2fffe Add name, base_model, type fields to model info 2023-06-22 16:42:51 +10:00
blessedcoolant
f8d7477c7a wip: Add 2.x Models to the Model List 2023-06-22 16:42:51 +10:00
blessedcoolant
e374211313 chore: Rebuild API with new Model API names 2023-06-22 16:41:31 +10:00
Sergey Borisov
01d17601b8 Generate config names for openapi 2023-06-22 16:41:19 +10:00
blessedcoolant
bf0d5f4cfc fix: Update missing name types to new names 2023-06-22 16:41:02 +10:00
blessedcoolant
663f4935f5 chore: Rebuild API 2023-06-22 16:41:02 +10:00
blessedcoolant
9838dda1b7 chore: Update model config type names 2023-06-22 16:40:40 +10:00
psychedelicious
2d889e133d chore(ui): regen api client 2023-06-22 16:25:49 +10:00
psychedelicious
6779f1a5ad fix(db): update models for boards w/ nullable deleted_at 2023-06-22 16:25:49 +10:00
psychedelicious
19a6e5dad8 chore(ui): regen api client 2023-06-22 16:25:49 +10:00
psychedelicious
285195bf72 feat(api): add get_board route 2023-06-22 16:25:49 +10:00
psychedelicious
10008859a4 tidy(ui): remove all refs to boards thunks 2023-06-22 16:25:49 +10:00
psychedelicious
3c04340f3f tidy(ui): tidy up update image board modal 2023-06-22 16:25:49 +10:00
psychedelicious
79f0c4d3c4 feat(ui): add remove from board to image context menu 2023-06-22 16:25:49 +10:00
psychedelicious
37d4e05838 fix(ui): fix board's image list not updating when image removed from board 2023-06-22 16:25:49 +10:00
psychedelicious
a00ad6ac03 feat(ui): dropping image on All Images board removes it from board 2023-06-22 16:25:49 +10:00
psychedelicious
2ffead000c tidy(ui): remove console.log() 2023-06-22 16:25:49 +10:00
psychedelicious
922319cb84 fix(ui): fix first added board doesn't show until refresh
Had incorrect `invalidatesTags` array for the mutation.
2023-06-22 16:25:49 +10:00
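The fix above is the usual rtk-query cache-tag pattern: the create mutation must invalidate the tag that the list query provides. A standalone sketch, with endpoint paths and DTO shapes assumed for illustration:

```typescript
import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';

interface BoardDTO {
  board_id: string;
  board_name: string;
}

export const boardsApi = createApi({
  reducerPath: 'boardsApi',
  baseQuery: fetchBaseQuery({ baseUrl: '/api/v1/' }),
  tagTypes: ['Board'],
  endpoints: (builder) => ({
    listBoards: builder.query<BoardDTO[], void>({
      query: () => 'boards/',
      providesTags: [{ type: 'Board', id: 'LIST' }],
    }),
    createBoard: builder.mutation<BoardDTO, string>({
      query: (boardName) => ({
        url: 'boards/',
        method: 'POST',
        params: { board_name: boardName },
      }),
      // Without this entry the list query is never re-fetched, so a newly
      // created board only appears after a full page refresh.
      invalidatesTags: [{ type: 'Board', id: 'LIST' }],
    }),
  }),
});

export const { useListBoardsQuery, useCreateBoardMutation } = boardsApi;
```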
psychedelicious
6ee0e197bb feat(db): add deleted_at to board_images 2023-06-22 16:25:49 +10:00
psychedelicious
d3e6f0130c fix(ui): fix issue with gallery not letting you load more images
To determine whether the Load More button should work, we need to keep track of how many images are left to load for a given board or category.

The Assets tab doesn't work, though. Need to figure out a better way to handle this.
2023-06-22 16:25:49 +10:00
psychedelicious
421c23d3ea fix(ui): fix gallery image fetching for board categories 2023-06-22 16:25:49 +10:00
psychedelicious
4545f3209f fix(ui): fix bug with image deletion not removing image from gallery 2023-06-22 16:25:49 +10:00
psychedelicious
e2ee8102c2 tidy(db): tidy image_record_storage.py 2023-06-22 16:25:49 +10:00
psychedelicious
083a0fc4cf tidy(ui): remove references to boardsAdapter 2023-06-22 16:25:49 +10:00
psychedelicious
26b75b85f7 fix(ui): if deleting selected board, deselect it 2023-06-22 16:25:49 +10:00
psychedelicious
f560a462a0 feat(ui): rudimentary categorized gallery image fetching 2023-06-22 16:25:49 +10:00
psychedelicious
d501986610 chore(ui): regen api client 2023-06-22 16:25:49 +10:00
psychedelicious
67a75f6895 feat(api, db): support board_id filter on images service get_many() 2023-06-22 16:25:49 +10:00
psychedelicious
3c032c0767 feat(ui): only auto-add image to board if is not intermediate 2023-06-22 16:25:49 +10:00
psychedelicious
abd6561140 feat(ui): just fetch all boards instead of paginating them 2023-06-22 16:25:49 +10:00
psychedelicious
bd533426fc feat(ui): first pass at boards styling 2023-06-22 16:25:49 +10:00
psychedelicious
2489d5459f chore(ui): regen api client 2023-06-22 16:25:49 +10:00
psychedelicious
ac477cf5d6 fix(ui): improve image deletion handling 2023-06-22 16:25:49 +10:00
psychedelicious
be3bdae847 fix: resolve rebase conflicts 2023-06-22 16:25:49 +10:00
psychedelicious
3e0ee838cf fix(ui): add initial image dimensions to state
We need to access the initial image dimensions during the creation of the `ImageToImage` graph to determine if we need to resize the image.

Because the `initialImage` is now just an image name, we need to either store (easy) or dynamically retrieve its dimensions during graph creation (a bit less easy).

Took the easiest path. May need to revise this in the future.
2023-06-22 16:25:49 +10:00
psychedelicious
8d3bec57d5 feat(ui): store only image name in parameters
Images that are used as parameters (e.g. init image, canvas images) are stored as full `ImageDTO` objects in state, separate from and duplicating any object representing those same objects in the `imagesSlice`.

We cannot store only image names as parameters, then pull the full `ImageDTO` from `imagesSlice`, because if an image is not on a loaded page, it doesn't exist in `imagesSlice`. For example, if you scroll down a few pages in the gallery and send that image to canvas, on reloading the app, the canvas will be unable to load that image.

We solved this temporarily by storing the full `ImageDTO` object wherever it was needed, but this is both inefficient and allows for stale `ImageDTO`s across the app.

One other possible solution was to just fetch the `ImageDTO` for all images at startup, and insert them into the `imagesSlice`, but then we run into an issue where we are displaying images in the gallery totally out of context.

For example, if an image from several pages into the gallery was sent to canvas, and the user refreshes, we'd display the first 20 images in the gallery. Then to populate the canvas, we'd fetch the image we sent to canvas and add it to `imagesSlice`. Now we'd have 21 images in the gallery: 1 to 20 and whichever image we sent to canvas. Weird.

Using `rtk-query` solves this by allowing us to very easily fetch individual images in the components that need them, and not directly interact with `imagesSlice`.

This commit changes all references to images-as-parameters to store only the name of the image, and not the full `ImageDTO` object. Then, we use an `rtk-query` generated `useGetImageDTOQuery()` hook in each of those components to fetch the image.

We can use cache invalidation when we mutate any image to trigger automated re-running of the query and all the images are automatically kept up to date.

This also obviates the need for the convoluted URL fetching scheme for images that are used as parameters. The `imagesSlice` still needs this handling, unfortunately.
2023-06-22 16:25:49 +10:00
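A hedged sketch of the pattern described in the commit above: components keep only an image name in state and fetch the `ImageDTO` with a generated `useGetImageDTOQuery()` hook. The endpoint path and DTO fields are assumptions for illustration:

```typescript
import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';

interface ImageDTO {
  image_name: string;
  image_url: string;
  width: number;
  height: number;
  is_intermediate: boolean;
}

export const imagesApi = createApi({
  reducerPath: 'imagesApi',
  baseQuery: fetchBaseQuery({ baseUrl: '/api/v1/' }),
  tagTypes: ['Image'],
  endpoints: (builder) => ({
    getImageDTO: builder.query<ImageDTO, string>({
      query: (imageName) => `images/${imageName}`,
      // Any mutation that invalidates ['Image', imageName] re-runs this query,
      // so every component rendering the image stays up to date automatically.
      providesTags: (result, error, imageName) => [{ type: 'Image', id: imageName }],
    }),
  }),
});

// Components store only the name and call the generated hook to resolve it:
export const { useGetImageDTOQuery } = imagesApi;
```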
psychedelicious
cfda128e06 feat(ui): wip boards via rtk-query 2023-06-22 16:25:49 +10:00
psychedelicious
661a94b3de feat(db): add get_all() method for boards
This is needed to show the full list of boards in the update boards modal.
2023-06-22 16:25:49 +10:00
psychedelicious
9ef64016c7 feat(db): sort board by created_at 2023-06-22 16:25:49 +10:00
psychedelicious
21f0d0b0c1 fix(db): fix deserialize_board_record()
It was not adding `cover_image_name`
2023-06-22 16:25:49 +10:00
psychedelicious
8bce234542 feat(db): update image-board relationships on add
Functionally, `add_image_to_board()` now moves images between boards.
2023-06-22 16:25:49 +10:00
psychedelicious
daadf6ebfd feat(ui): add board image count badge 2023-06-22 16:25:49 +10:00
Mary Hipp
fe10a9f747 render cover image based on URL in image entities 2023-06-22 16:25:49 +10:00
Mary Hipp
7a2d3f628a add boardToAddTo state so that result can be added to board when generation is complete 2023-06-22 16:25:49 +10:00
Mary Hipp
4defb92105 handle long board names 2023-06-22 16:25:49 +10:00
Mary Hipp
f9f3c91a83 drag and drop to move image to board, a bit of board list UI 2023-06-22 16:25:49 +10:00
maryhipp
95b9c8e505 return cover_image_name since urls change, override one from db for now 2023-06-22 16:25:49 +10:00
psychedelicious
49a02c157b feat(ui): fix UpdateImageBoardModal select 2023-06-22 16:25:49 +10:00
psychedelicious
d604d986f9 feat(db, api): update get_board_for_image & service dependencies
- previously `get_boards_for_image`, which returned a list of `BoardDTO`s; now returns a single `board_id`
2023-06-22 16:25:49 +10:00
psychedelicious
70cc037a9c fix(ui): do not persist boards 2023-06-22 16:25:49 +10:00
psychedelicious
e4893e4031 fix(db): return board records from CRUD methods 2023-06-22 16:25:49 +10:00
maryhipp
4a0a718b96 foiled by a comma 2023-06-22 16:25:49 +10:00
maryhipp
ca8f1a7828 (api) use most recently generated image for cover photo 2023-06-22 16:25:49 +10:00
Mary Hipp
2e41af2109 [half-baked] adding image to board modal 2023-06-22 16:25:49 +10:00
Mary Hipp
bd29e5e655 UI tweaks 2023-06-22 16:25:49 +10:00
Mary Hipp
dcfee2e1e4 add searching to boards list 2023-06-22 16:25:49 +10:00
Mary Hipp
8aac683319 can delete and rename boards 2023-06-22 16:25:49 +10:00
psychedelicious
d306a84447 feat(ui): rough out boards UI 2023-06-22 16:25:49 +10:00
psychedelicious
5865ecd530 feat(db): add FK for boards.cover_image_name 2023-06-22 16:25:49 +10:00
psychedelicious
e1f9685b02 feat(db): add index for boards 2023-06-22 16:25:49 +10:00
psychedelicious
498bf0d0ba feat(db): add indices for board_images 2023-06-22 16:25:49 +10:00
psychedelicious
163ef2c941 feat(ui): remove refs to BoardRecord in UI
UI should only work w/ BoardDTO
2023-06-22 16:25:49 +10:00
psychedelicious
48193b7fa7 chore(ui): regen api client 2023-06-22 16:25:49 +10:00
psychedelicious
dd1b3c9f35 fix(api): update API models to use BoardDTOs 2023-06-22 16:25:49 +10:00
psychedelicious
4b32322a58 feat(nodes): make board <> images a one-to-many relationship
we can extend this to many-to-many in the future if desired.
2023-06-22 16:25:49 +10:00
Mary Hipp
e06c43adc8 lint fix 2023-06-22 16:25:49 +10:00
Mary Hipp
c009f46b00 regenerate api schema 2023-06-22 16:25:49 +10:00
maryhipp
748016bdab routes working 2023-06-22 16:25:49 +10:00
psychedelicious
72e9ced889 feat(nodes): add boards and board_images services 2023-06-22 16:25:49 +10:00
maryhipp
3833304f57 [WIP] board list endpoint w cover photos 2023-06-22 16:25:49 +10:00
maryhipp
4bfaae6617 fix type 2023-06-22 16:25:49 +10:00
maryhipp
499a174832 some more 2023-06-22 16:25:49 +10:00
maryhipp
6ca5ad9075 filter images by board_id 2023-06-22 16:25:49 +10:00
maryhipp
a121e6b3a0 add board_id association to image 2023-06-22 16:25:49 +10:00
maryhipp
207602f425 remove unused 2023-06-22 16:25:49 +10:00
maryhipp
a1671519d5 board CRUD 2023-06-22 16:25:49 +10:00
Lincoln Stein
257e972599 fix failing pytest for config module 2023-06-20 13:26:01 -04:00
Lincoln Stein
8639794c12 Merge branch 'main' into install-script-python-version-error-prompt-fix 2023-06-20 18:24:54 +01:00
blessedcoolant
d339c8627f feat: Upgrade to Diffusers 0.17.1 (#3545)
Just syncing up with diffusers upstream.
2023-06-19 23:25:22 +12:00
blessedcoolant
a53e0dce6c Merge branch 'upgrade-diffusers' of https://github.com/blessedcoolant/InvokeAI into upgrade-diffusers 2023-06-19 23:21:06 +12:00
blessedcoolant
0ae6325353 chore: Add torchsde as a dependency for the SDE schedulers 2023-06-19 23:20:53 +12:00
blessedcoolant
12299120ab Merge branch 'main' into upgrade-diffusers 2023-06-19 23:16:39 +12:00
blessedcoolant
1a7fe172ca Fix inpaint node to new manager (#3550)
The inpaint node is still used by the canvas, so it has been updated to the new model manager API.
Other old generation code has been deleted.
2023-06-19 23:01:05 +12:00
blessedcoolant
4f5693040e Merge branch 'main' into fix/inpaint_new_manager 2023-06-19 22:55:00 +12:00
blessedcoolant
bb2df88c06 Add dpmpp_sde and dpmpp_2m_sde schedulers(with karras) (#3554)
Added SDE schedulers.
Problem: they add randomness on each step, so to get a consistent image
we need to provide a seed or generator. I've done this, but if you think
it's better done another way, feel free to change it.

Also made the ancestral schedulers reproducible; this is done the same
way as for the SDE schedulers.
2023-06-19 22:52:33 +12:00
psychedelicious
41442eb7f6 feat(ui): convert canvas txt2img & img2img to latents
- Add graph builders for canvas txt2img & img2img - they are mostly copy and paste from the linear graph builders but different in a few ways that are very tricky to work around. Just made totally new functions for them.
- Canvas txt2img and img2img support ControlNet (not inpaint/outpaint). There's no way to determine in real-time which mode the canvas is in just yet, so we cannot disable the ControlNet UI when the mode will be inpaint/outpaint - it will always display. It's possible to determine this in near-real-time, will add this at some point.
- Canvas inpaint/outpaint migrated to use model loader, though inpaint/outpaint are still using the non-latents nodes.
2023-06-19 15:57:28 +10:00
psychedelicious
223a679ac1 chore(ui): regen api client 2023-06-19 15:57:28 +10:00
psychedelicious
3c60616b4d feat(ui): simplify linear graph creation logic
Instead of manually creating every node and edge, we can simply copy/paste the base graph from node editor, then sub in parameters.

This is a much more intelligible process. We still need to handle seed, img2img fit and controlnet separately.
2023-06-19 15:57:28 +10:00
Sergey Borisov
a01998d095 Remove more old logic 2023-06-19 15:57:28 +10:00
Sergey Borisov
7b35162b9e Remove old logic except for inpaint, add support for lora and ti to inpaint node 2023-06-19 15:57:28 +10:00
Sergey Borisov
c26e1a9271 Rewrite inpaint node to new model manager, remove TextToImage and ImageToImage nodes 2023-06-19 15:57:28 +10:00
Sergey Borisov
9b32407744 Provide generator to all schedulers' step functions to make both ancestral and sde schedulers reproducible 2023-06-19 00:34:01 +03:00
Sergey Borisov
82091b9a66 Fix vae conversion 2023-06-18 23:46:07 +03:00
Sergey Borisov
f3d9797ebe Add dpmpp_sde and dpmpp_2m_sde schedulers(with karras) 2023-06-18 23:38:15 +03:00
DrGunnarMallon
f312e1448f Update index.md
fixed typo
2023-06-18 10:39:02 -04:00
blessedcoolant
a11946f0ad feat: Port Schedulers to Mantine (#3552)
- Ports Schedulers to use IAIMantineSelect.
- Adds ability to favorite schedulers in Settings. Favorited schedulers
show up at the top of the list.
- Adds IAIMantineMultiSelect component.
- Change SettingsSchedulers component to use IAIMantineMultiSelect
instead of Chakra Menus.
2023-06-18 22:22:03 +12:00
blessedcoolant
80a8d3ef28 style: Theme placeholder style for IAIMantineMultiSelect 2023-06-18 22:17:09 +12:00
blessedcoolant
f4ca9d0e09 Merge branch 'scheduler-select' of https://github.com/blessedcoolant/InvokeAI into scheduler-select 2023-06-18 22:05:12 +12:00
blessedcoolant
a960fa009d fix: Fix some styling issues with IAIMantineMultiSelect 2023-06-18 22:04:12 +12:00
psychedelicious
b96b95bc95 feat(ui): enabledSchedulers -> favoriteSchedulers 2023-06-18 20:01:05 +10:00
psychedelicious
450641c414 fix(ui): enable all schedulers by default 2023-06-18 19:39:31 +10:00
psychedelicious
94cfcdc411 feat(ui): improve scheduler selection logic
- remove UI-specific state (the enabled schedulers) from redux, instead derive it in a selector
- simplify logic by putting schedulers in an object instead of an array
- rename `activeSchedulers` to `enabledSchedulers`
- remove need for `useEffect()` when `enabledSchedulers` changes by adding a listener for the `enabledSchedulersChanged` action/event to `generationSlice`
- increase type safety by making `enabledSchedulers` an array of `SchedulerParam`, which is created by the zod schema for scheduler
2023-06-18 19:34:37 +10:00
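A rough sketch of the "derive it in a selector" idea from the commit above; the slice shape, scheduler names, and sorting rule are assumptions for illustration:

```typescript
import { createSelector } from '@reduxjs/toolkit';

// In the app this type comes from a zod schema (SchedulerParam).
type SchedulerParam = 'euler' | 'euler_a' | 'dpmpp_2m' | 'dpmpp_sde' | 'lms';

interface UIState {
  favoriteSchedulers: SchedulerParam[];
}

interface RootState {
  ui: UIState;
}

const ALL_SCHEDULERS: SchedulerParam[] = ['euler', 'euler_a', 'dpmpp_2m', 'dpmpp_sde', 'lms'];

// The dropdown options are derived data, so they live in a memoised selector
// instead of being duplicated (and kept in sync by hand) in redux state.
export const selectSchedulerOptions = createSelector(
  (state: RootState) => state.ui.favoriteSchedulers,
  (favorites) => {
    const favoriteSet = new Set(favorites);
    // Favorited schedulers sort to the top of the list.
    return [
      ...ALL_SCHEDULERS.filter((s) => favoriteSet.has(s)),
      ...ALL_SCHEDULERS.filter((s) => !favoriteSet.has(s)),
    ];
  }
);
```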
psychedelicious
150059f704 fix(ui): create all scheduler constants up-front 2023-06-18 18:49:10 +10:00
psychedelicious
f1a8b9daee fix(ui): clarify scheduler logic
- use full conditional syntax with `{}`
- do not mutate `action.payload` in a reducer
2023-06-18 18:47:59 +10:00
blessedcoolant
be8c0bb952 feat: Use Labels for Schedulers 2023-06-18 20:17:51 +12:00
blessedcoolant
dae5b9b259 fix: Minor styling fix to the IAIMantineMultiSelect component 2023-06-18 20:06:56 +12:00
blessedcoolant
06428fac67 fix: Revert scheduler back to zod validation 2023-06-18 20:02:36 +12:00
blessedcoolant
59b5dfc3e0 feat: Port Schedulers to Mantine 2023-06-18 19:47:27 +12:00
blessedcoolant
fd981a90be Add lms and dpmpp2_s karras scheduler (#3551)
Karras sigmas support was added to the lms and dpmpp2_s schedulers in
diffusers 0.17.0.
2023-06-18 17:36:47 +12:00
Sergey Borisov
6b7cf3f3be Add lms and dpmpp2_s karras scheduler 2023-06-17 21:00:16 +03:00
Stephan Koglin-Fischer
469dae8c88 fix(linux): installer script prints maximum python version usable 2023-06-16 15:18:23 +02:00
blessedcoolant
9d4b84ef68 feat: Upgrade to Diffusers 0.17.1 2023-06-16 23:57:57 +12:00
blessedcoolant
4cbc802e36 Model manager fixes (#3541)
Fix lora import
Fix sd2 config - `variant` field not added
Fix list models api - `base_model` arg not provided, redundant assert
check
2023-06-16 06:43:00 +12:00
Sergey Borisov
5f2d07917d Fix lora import, fix sd2 config, fix list models api 2023-06-15 21:30:15 +03:00
Lincoln Stein
5c740452f6 Model Manager rewrite (#3335) 2023-06-14 08:44:04 -07:00
Lincoln Stein
82c2498043 Merge branch 'main' into lstein/new-model-manager 2023-06-14 08:41:40 -07:00
blessedcoolant
4ca325e8e6 chore: Rebuild API 2023-06-15 03:20:49 +12:00
blessedcoolant
6b8e88ad7f Merge branch 'main' into feat/controlnet-control-modes 2023-06-15 03:18:41 +12:00
psychedelicious
0497bea264 fix: add dynamicprompts to pyproject.toml 2023-06-15 01:05:16 +10:00
psychedelicious
b8e32fa459 chore(ui): regen api client 2023-06-15 01:05:16 +10:00
psychedelicious
34ebee67b7 fix(nodes): fix revert conflict 2023-06-15 01:05:16 +10:00
psychedelicious
e0c998d192 Revert "feat(ui): add warning socket event handling"
This reverts commit e7a61e631a42190e4b64e0d5e22771c669c5b30c.
2023-06-15 01:05:16 +10:00
psychedelicious
b51e9a6bdb Revert "feat(nodes): add warning socket event"
This reverts commit cefdd9d634e515239bd85666c872a0d64bb9d772.
2023-06-15 01:05:16 +10:00
psychedelicious
09f396ce84 feat(ui): add warning socket event handling 2023-06-15 01:05:16 +10:00
psychedelicious
abee37eab3 feat(nodes): add warning socket event 2023-06-15 01:05:16 +10:00
psychedelicious
42e48b2bef feat(nodes): add dynamic prompt node 2023-06-15 01:05:16 +10:00
blessedcoolant
70ece4364c refactor(minor): Image & Latent File Storage (#3538)
- `DiskImageStorage` and `DiskLatentsStorage` have now both been updated
to exclusively work with `Path` objects and not rely on the `os` lib to
handle pathing related functions.
- We now also validate the existence of the required image output
folders and latent output folders to ensure that the app does not break
in case the required folders get tampered with mid-session.
- Just overall general cleanup.

Tested it. Nothing seems to be breaking.
2023-06-15 02:43:27 +12:00
psychedelicious
f9d5f9d52c fix(nodes): minor fixes for folder validation
- fix type for `__output_folder`
- prefix `validate_storage_folders()` with `__` to indicate private method
2023-06-15 00:40:39 +10:00
StAlKeR7779
d0ee3558d1 Merge branch 'main' into lstein/new-model-manager 2023-06-14 17:29:01 +03:00
blessedcoolant
587297878a refactor(minor): Latent Disk Storage 2023-06-15 02:21:49 +12:00
blessedcoolant
b4c998a9ae refactor(minor): Image File Storage 2023-06-15 01:58:58 +12:00
psychedelicious
88e8e3977b feat(ui): update UI to not use image_origin
see commit 8ad8de8, `feat(nodes): remove image_origin from most places`, for details.
2023-06-14 23:08:27 +10:00
psychedelicious
24b86cffe9 chore(ui): regen api client & types 2023-06-14 23:08:27 +10:00
psychedelicious
a1773197e9 feat(nodes): remove image_origin from most places
- remove `image_origin` from most places where we interact with images
- consolidate image file storage into a single `images/` dir

Images have an `image_origin` attribute but it is not actually used when retrieving images, nor will it ever be. It is still used when creating images and helps to differentiate between internally generated images and uploads.

It was included in e.g. API routes and image service methods as a holdover from the previous app implementation where images were not managed in a database. Now that we have images in a db, we can do away with this and simplify basically everything that touches images.

The one potentially controversial change is to no longer separate internal and external images on disk. If we retain this separation, we have to keep `image_origin` around in a number of spots, and it makes getting image paths on disk painful.

So, I have gotten rid of this organisation. Images are now all stored in `images`, regardless of their origin. As we improve the image management features, this change will hopefully become transparent.
2023-06-14 23:08:27 +10:00
blessedcoolant
6c53abc034 feat: Add ControlMode to Linear UI 2023-06-14 20:01:17 +12:00
blessedcoolant
eb7047b21d chore: Rebuild WebAPI 2023-06-14 19:26:02 +12:00
blessedcoolant
43419ac761 Merge branch 'main' into feat/controlnet-control-modes 2023-06-14 19:04:42 +12:00
user1
5cd0e90816 Renamed ControlNet control_mode option "even_more_control" to "unbalanced" 2023-06-13 22:30:17 -07:00
user1
cfd49e3921 Removing vestigial comments. 2023-06-13 21:33:15 -07:00
user1
a8e0490133 Merge branch 'feat/controlnet-control-modes' of https://github.com/invoke-ai/InvokeAI into feat/controlnet-control-modes 2023-06-13 21:21:13 -07:00
psychedelicious
1e08d865c9 chore: dummy commit to trigger actions 2023-06-14 14:14:24 +10:00
blessedcoolant
f8bb650cc1 revert: IAIScrollArea 2023-06-14 14:14:24 +10:00
psychedelicious
2cee8bebb2 fix(ui): revert offset scrollbars
The wonky padding is too janky. Just overlay for now.
2023-06-14 14:14:24 +10:00
psychedelicious
ade4ec5fd8 fix(ui): fix crash when toggling pinned parameters panel 2023-06-14 14:14:24 +10:00
psychedelicious
70ffd6b03f fix(ui): fix controlnet selects data types 2023-06-14 14:14:24 +10:00
psychedelicious
6c551df311 fix(ui): fix rebase conflicts 2023-06-14 14:14:24 +10:00
blessedcoolant
24f605629e cleanup: Remove OverlayScrollable component 2023-06-14 14:14:24 +10:00
blessedcoolant
2af1ec9d02 fix: Minor padding issue in unpinned drawer 2023-06-14 14:14:24 +10:00
blessedcoolant
79d53341de fix: Stretch scroll area so it retains parent width 2023-06-14 14:14:24 +10:00
blessedcoolant
e40b3506c4 fix: Options squishing on accordion collapse 2023-06-14 14:14:24 +10:00
blessedcoolant
33912382e3 feat: Introduce Mantine's ScrollArea 2023-06-14 14:14:24 +10:00
blessedcoolant
d282810e53 cleanup: Remove IAICustomSelect and port types 2023-06-14 14:14:24 +10:00
psychedelicious
9df502fc77 fix(ui): fix mantine select props 2023-06-14 14:14:24 +10:00
psychedelicious
705573f0a8 feat(ui): even more pedantic mantine select theming 2023-06-14 14:14:24 +10:00
blessedcoolant
1878ea94f6 feat: Port Canvas Layer Select to IAIMantineSelect 2023-06-14 14:14:24 +10:00
psychedelicious
4ba5086b9a feat(ui): add tooltip to IAIMantineSelect 2023-06-14 14:14:24 +10:00
psychedelicious
4a991b4daa feat(ui): more pedantic mantine select theming 2023-06-14 14:14:24 +10:00
psychedelicious
80474d26f9 feat(ui): mantine scrollbar theming 2023-06-14 14:14:24 +10:00
blessedcoolant
9a77bd9140 feat: Port IAISelect's to IAIMantineSelect's
Ported everything except Model Manager selects and the Canvas Layer Select (this needs tooltip support)
2023-06-14 14:14:24 +10:00
psychedelicious
14cdc800c3 feat(ui): pedantic mantine select theming 2023-06-14 14:14:24 +10:00
blessedcoolant
9cfbea4c25 feat: Match styling of Mantine Select with InvokeAI 2023-06-14 14:14:24 +10:00
blessedcoolant
5fe674e223 feat: Standardize IAIMantineSelect Component 2023-06-14 14:14:24 +10:00
blessedcoolant
32200efce8 feat: Change default font to Inter 2023-06-14 14:14:24 +10:00
blessedcoolant
68a02da990 feat: Use Mantine Select for Scheduler 2023-06-14 14:14:24 +10:00
blessedcoolant
5b20766ea3 chore: Move Mantine Theme Override to own file 2023-06-14 14:14:24 +10:00
blessedcoolant
9a914250a0 feat: Change Model Select To Mantine 2023-06-14 14:14:24 +10:00
blessedcoolant
0e3106f631 feat: Add Mantine Support 2023-06-14 14:14:24 +10:00
user1
de3e6cdb02 Switched over to ControlNet control_mode with 4 options: balanced, more_prompt, more_control, even_more_control. Based on True/False combinations of internal booleans cfg_injection and soft_injection 2023-06-13 21:08:34 -07:00
Sergey Borisov
6c5954f9d1 Add controlnet to model manager, fixes 2023-06-14 04:26:21 +03:00
Sergey Borisov
740c05a0bb Save models on rescan, uncache model on edit/delete, fixes 2023-06-14 03:12:12 +03:00
Sergey Borisov
26090011c4 Fix conflict resolve, add model configs to type annotation 2023-06-14 00:26:37 +03:00
StAlKeR7779
c9ae26a176 Merge branch 'main' into lstein/new-model-manager 2023-06-13 23:37:52 +03:00
Sergey Borisov
e7db6d8120 Fix ckpt and vae conversion, migrate script, remove sd2-base 2023-06-13 18:05:12 +03:00
user1
8495764d45 Moving from ControlNet guess_mode to separate booleans for cfg_injection and soft_injection for testing control modes 2023-06-13 00:41:36 -07:00
user1
8b7fac75ed First pass at ControlNet "guess mode" implementation. 2023-06-13 00:41:36 -07:00
user1
9e0e26f4c4 Moving from ControlNet guess_mode to separate booleans for cfg_injection and soft_injection for testing control modes 2023-06-12 23:57:57 -07:00
Lincoln Stein
a6af7e8824 use format "diffusers" rather than format "folder" in models.yaml 2023-06-13 01:43:05 -04:00
Lincoln Stein
87ba17a1f5 add migration script and update convert and face restoration paths 2023-06-13 01:27:51 -04:00
Lincoln Stein
c7ea46a5da use latest version of transformers to avoid deprecation warnings 2023-06-12 16:07:39 -04:00
Lincoln Stein
1439dc7712 Add SchedulerPredictionType and ModelVariantType enums 2023-06-12 16:07:04 -04:00
blessedcoolant
46cac6468e Upgrade to Diffusers 0.17.0 (#3514)
Diffusers is due for an update soon. #3512

Opening up a PR now with the required changes for when the new version
is live.

I've tested it out on Windows and nothing has broken from what I could
tell. I'd like someone to run some tests on Linux / Mac just to make
sure. Refer to the PR above on how to test it or install the release
branch.

```
pip install diffusers[torch]==0.17.0
```

Feel free to push any other changes to this PR you see fit.
2023-06-13 07:11:02 +12:00
blessedcoolant
2a814d886b Merge branch 'main' into diffusers-upgrade 2023-06-13 05:29:15 +12:00
Sergey Borisov
36eb1bd893 Fixes 2023-06-12 16:14:09 +03:00
Sergey Borisov
9fa78443de Fixes, add sd variant detection 2023-06-12 05:52:30 +03:00
Lincoln Stein
893f776f1d model_probe working; model_install incomplete 2023-06-11 19:51:53 -04:00
Lincoln Stein
085ab54124 remove modified models.py and migrate code to models/base.py 2023-06-11 16:10:15 -04:00
Lincoln Stein
8e1a56875e remove defunct code 2023-06-11 12:57:06 -04:00
Lincoln Stein
000626ab2e move all installation code out of model_manager 2023-06-11 12:51:50 -04:00
Sergey Borisov
694fd0c92f Fixes, first runable version 2023-06-11 16:42:40 +03:00
user1
fd715026a7 First pass at ControlNet "guess mode" implementation. 2023-06-11 02:00:39 -07:00
Sergey Borisov
738ba40f51 Fixes 2023-06-11 06:12:21 +03:00
Sergey Borisov
3ce3a7ee72 Rewrite model configs, separate models 2023-06-11 04:49:09 +03:00
Lincoln Stein
74b43c9bdf fix incorrect variable/typenames in model_cache 2023-06-10 10:41:48 -04:00
Lincoln Stein
3d2ff7755e resolve conflicts 2023-06-10 10:13:54 -04:00
Lincoln Stein
a87d52a389 resolve conflicts between lstein & sttalker changes 2023-06-10 09:59:19 -04:00
Lincoln Stein
959e64c9b3 start removing repo_id support 2023-06-10 09:57:23 -04:00
Sergey Borisov
2c056ead42 New models structure draft 2023-06-10 03:14:10 +03:00
blessedcoolant
7bce455d16 Merge branch 'main' into diffusers-upgrade 2023-06-09 16:27:52 +12:00
Lincoln Stein
887576d217 add directory scanning for loras, controlnets and textual_inversions 2023-06-08 23:11:53 -04:00
Lincoln Stein
6652f3405b merge with main 2023-06-08 21:08:43 -04:00
Lincoln Stein
27b5e43ea4 add messages to the user to tell them to enlarge window 2023-06-08 16:37:10 -04:00
blessedcoolant
68405910ba Upgrade to Diffusers 0.17.0 2023-06-08 04:42:52 +12:00
Lincoln Stein
04f9757f8d prevent crash when trying to calculate size of missing safety_checker
- Also fixed up order in which logger is created in invokeai-web
  so that handlers are installed after command-line options are
  parsed (and not before!)
2023-06-06 22:57:49 -04:00
Lincoln Stein
1f9e1eb964 merge with main 2023-06-06 22:18:41 -04:00
Lincoln Stein
8285fbb0b1 Merge branch 'lstein/new-model-manager' of github.com:invoke-ai/InvokeAI into lstein/new-model-manager 2023-06-02 22:48:00 -04:00
Lincoln Stein
951e6b746c remove model cache test; should be replaced with something else 2023-06-02 22:47:48 -04:00
Lincoln Stein
44a6623094 Merge branch 'main' into lstein/new-model-manager 2023-06-02 22:40:51 -04:00
Lincoln Stein
98773b20ac merge with main 2023-06-01 18:09:49 -04:00
Sergey Borisov
b47786e846 First working TI draft 2023-05-31 02:12:27 +03:00
Sergey Borisov
69ccd3a0b5 Fixes for checkpoint models 2023-05-30 19:12:47 +03:00
Sergey Borisov
420a76ecdd Add lora loader node 2023-05-30 02:12:33 +03:00
Sergey Borisov
79de9047b5 First working lora implementation 2023-05-30 01:11:00 +03:00
Lincoln Stein
f50293920e correct typo in tiled_vae field definition 2023-05-25 23:29:16 -04:00
Lincoln Stein
1e2db3a17f hook tiled_decode up to configuration 2023-05-25 23:28:15 -04:00
Lincoln Stein
5f8f51436a merge with main; fix conflicts 2023-05-25 22:40:45 -04:00
Sergey Borisov
8e419a4f97 Revert weak references as can be done without it 2023-05-23 04:29:40 +03:00
Sergey Borisov
2533209326 Rewrite cache to weak references 2023-05-23 03:48:22 +03:00
StAlKeR7779
165c1adcf8 Merge branch 'main' into lstein/new-model-manager 2023-05-22 21:51:07 +03:00
Lincoln Stein
bdf33f13b3 fix bad merge in compel 2023-05-18 18:08:45 -04:00
Lincoln Stein
27241cdde1 port more globals changes over 2023-05-18 17:17:45 -04:00
Lincoln Stein
259d6ec90d fixup cachedir call 2023-05-18 14:52:16 -04:00
Lincoln Stein
a77c4c87b2 fixed logic error in resolution of model path 2023-05-18 14:35:34 -04:00
Lincoln Stein
d96175d127 resolve some undefined symbols in model_cache 2023-05-18 14:31:47 -04:00
Lincoln Stein
b1a99d772c added method to convert vaes 2023-05-18 13:31:11 -04:00
Sergey Borisov
fd82763412 Model manager draft 2023-05-18 03:56:52 +03:00
Lincoln Stein
e971a7f35c when migrating models.yaml, rename original models.yaml.orig 2023-05-16 22:37:53 -04:00
psychedelicious
6ab84741a0 fix(nodes): make ModelsList an enum-keyed dict
The `ModelsList` OpenAPI schema is generated as being keyed by plain strings. This means that API consumers do not know the shape of the dict. It _should_ be keyed by the `SDModelType` enum.

Unfortunately, `fastapi` does not actually handle this correctly yet; it still generates the schema with plain string keys.

Adding this anyway, though, in hopes that it will be resolved upstream and we can get the correct schema. Until then, I'll implement the (simple but annoying) logic on the frontend.

https://github.com/pydantic/pydantic/issues/4393
2023-05-16 15:02:58 +10:00
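Until the upstream schema issue is resolved, the "simple but annoying" frontend logic mentioned above could look something like this sketch, which narrows the generated plain-string keys to known model types (the enum members here are assumptions):

```typescript
// Hypothetical enum members - the real SDModelType values live in the backend.
const MODEL_TYPES = ['diffusers', 'vae', 'lora', 'textual_inversion'] as const;
type SDModelType = (typeof MODEL_TYPES)[number];

// Shape as generated today (string keys) vs. the shape we actually want.
type RawModelsList = Record<string, Record<string, unknown>>;
type TypedModelsList = Partial<Record<SDModelType, Record<string, unknown>>>;

const isSDModelType = (key: string): key is SDModelType =>
  (MODEL_TYPES as readonly string[]).includes(key);

export const narrowModelsList = (raw: RawModelsList): TypedModelsList => {
  const narrowed: TypedModelsList = {};
  for (const [key, value] of Object.entries(raw)) {
    if (isSDModelType(key)) {
      narrowed[key] = value;
    }
  }
  return narrowed;
};
```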
Lincoln Stein
cd16857f38 fix None in model_type 2023-05-16 00:13:44 -04:00
Lincoln Stein
1442f1cb8d change model filter to None in second place 2023-05-16 00:03:57 -04:00
Lincoln Stein
eea0d6f7bc default to no filter in list_models() 2023-05-15 23:52:29 -04:00
Lincoln Stein
4fe94a9315 list_models() now returns a dict of {type,{name: info}} 2023-05-15 23:44:08 -04:00
Lincoln Stein
c8f765cc06 improve debugging messages 2023-05-14 18:29:55 -04:00
Lincoln Stein
b9e9087dbe do not manage GPU for pipelines if sequential_offloading is True 2023-05-14 18:09:38 -04:00
Lincoln Stein
63e465eb5c tweaks to get_model() behavior
1. If an external VAE is specified in config file, then
   get_model(submodel=vae) will return the external VAE, not the one
   burnt into the parent diffusers pipeline.

2. The mechanism in (1) is generalized such that you can now have
   "unet:", "text_encoder:" and similar stanzas in the config file.
   Valid formats of these subsections:

       unet:
          repo_id: foo/bar

       unet:
          path: /path/to/local/folder

       unet:
          repo_id: foo/bar
          subfolder: unet

    In the near future, these will also be used to attach external
    parts to the pipeline, generalizing VAE behavior.

3. Accommodate callers (i.e. the WebUI) that are passing the
   model key ("diffusers/stable-diffusion-1.5") to get_model()
   instead of the tuple of model_name and model_type.

4. Fixed bug in VAE model attaching code.

5. Rebuilt web front end.
2023-05-14 16:50:59 -04:00
Lincoln Stein
426f4eaf7e adjusted regression tests to work with new SDModelTypes 2023-05-13 22:29:33 -04:00
Lincoln Stein
baf5451fa0 Merge branch 'main' into lstein/new-model-manager 2023-05-13 22:01:34 -04:00
Lincoln Stein
b31a6ff605 fix reversed args in _model_key() call 2023-05-13 21:11:06 -04:00
Sergey Borisov
1f602e6143 Fix - apply precision to text_encoder 2023-05-14 03:46:13 +03:00
Sergey Borisov
039fa73269 Change SDModelType enum to string; fixes: model unload negative locks count, scheduler load error, safetensors convert, wrong logic in del_model, wrong metadata parsing in web 2023-05-14 03:06:26 +03:00
Lincoln Stein
2204e47596 allow submodels to be fetched independent of parent pipeline 2023-05-13 16:54:47 -04:00
Lincoln Stein
d8b1f29066 proxy SDModelInfo so that it can be used directly as context 2023-05-13 16:29:18 -04:00
Lincoln Stein
b23c9f1da5 get Tuple type hint syntax right 2023-05-13 14:59:21 -04:00
Lincoln Stein
5e8e3cf464 correct typos in model_manager_service 2023-05-13 14:55:59 -04:00
Lincoln Stein
72967bf118 convert add_model(), del_model(), list_models() etc to use bifurcated names 2023-05-13 14:44:44 -04:00
Sergey Borisov
bc96727cbe Rewrite latent nodes to new model manager 2023-05-13 16:08:03 +03:00
Sergey Borisov
3b2a054f7a Add model loader node; unet, clip, vae fields; change compel node to clip field 2023-05-13 04:37:20 +03:00
Sergey Borisov
131145eab1 A big refactor of model manager(according to IMHO) 2023-05-12 23:13:34 +03:00
Sergey Borisov
4492044d29 Redo compel node to separate model loading 2023-05-12 23:09:33 +03:00
Sergey Borisov
5431dd5f50 Fix event args 2023-05-12 23:08:03 +03:00
Sergey Borisov
79fecba274 Fix model manager initialization in web ui 2023-05-12 23:05:08 +03:00
Lincoln Stein
2ef79b8bf3 fix bug in persistent model scheme 2023-05-12 00:14:56 -04:00
Lincoln Stein
11ecf438f5 latents.py converted to use model manager service; events emitted 2023-05-11 23:33:24 -04:00
Lincoln Stein
df5b968954 model manager now running as a service 2023-05-11 21:24:29 -04:00
Lincoln Stein
8ad8c5c67a resolve conflicts with main 2023-05-11 00:19:20 -04:00
Lincoln Stein
590942edd7 Merge branch 'main' into lstein/new-model-manager 2023-05-11 00:16:03 -04:00
Lincoln Stein
4627910c5d added a wrapper model_manager_service and model events 2023-05-11 00:09:19 -04:00
Lincoln Stein
fa6a580452 merge with main 2023-05-10 00:03:32 -04:00
Lincoln Stein
99c692f397 check that model name matches format 2023-05-09 23:46:59 -04:00
Lincoln Stein
3d85e769ce clean up ckpt handling
- remove legacy ckpt loading code from model_cache
- added placeholders for lora and textual inversion model loading
2023-05-09 22:44:58 -04:00
Lincoln Stein
9cb962cad7 ckpt model conversion now done in ModelCache 2023-05-08 23:39:44 -04:00
Lincoln Stein
a108155544 added StALKeR779's great model size calculating routine 2023-05-08 21:47:03 -04:00
Lincoln Stein
c15b49c805 implement StALKeR7779 requested API for fetching submodels 2023-05-07 23:18:17 -04:00
Lincoln Stein
fd63e36822 optimize subfolder so that it returns submodel if parent is in RAM 2023-05-07 21:39:11 -04:00
Lincoln Stein
4649920074 adjust t2i to work with new model structure 2023-05-07 19:06:49 -04:00
Lincoln Stein
667171ed90 cap model cache size using bytes, not # models 2023-05-07 18:07:28 -04:00
Lincoln Stein
647ffb2a0f defined abstract baseclass for model manager service 2023-05-06 22:41:19 -04:00
Lincoln Stein
05a27bda5e generalize model loading support, include loras/embeds 2023-05-06 15:58:44 -04:00
Lincoln Stein
a8cfa3565c Merge branch 'lstein/new-model-manager' of github.com:invoke-ai/InvokeAI into lstein/new-model-manager 2023-05-06 08:14:15 -04:00
Lincoln Stein
e0214a32bc mostly ported to new manager API; needs testing 2023-05-06 00:44:12 -04:00
Lincoln Stein
af8c7c7d29 model manager rewritten to use model_cache; API changed! 2023-05-05 19:32:28 -04:00
Lincoln Stein
a4e36bc02a when model is forcibly moved into RAM update loaded_models set 2023-05-04 23:28:03 -04:00
Lincoln Stein
2e9bec15e7 Merge branch 'main' into lstein/new-model-manager 2023-05-04 23:19:38 -04:00
Lincoln Stein
68bc0112fa implement lazy GPU offloading and ref counting 2023-05-04 23:15:32 -04:00
Lincoln Stein
a273bdbdc1 Merge branch 'main' into lstein/new-model-manager 2023-05-03 18:09:29 -04:00
Lincoln Stein
8a0ec0fa0f Merge branch 'main' into lstein/new-model-manager 2023-05-03 13:30:50 -04:00
Lincoln Stein
e1fed52c66 work on model cache and its regression test finished 2023-05-03 12:38:18 -04:00
Lincoln Stein
bb959448c1 implement hashing for local & remote models 2023-05-02 16:52:27 -04:00
Lincoln Stein
2e2abf6ea6 caching of subparts working 2023-05-01 22:57:30 -04:00
Lincoln Stein
956ad6bcf5 add redesigned model cache for diffusers & transformers 2023-04-28 00:41:52 -04:00
419 changed files with 19600 additions and 15088 deletions

View File

@@ -1,10 +1,16 @@
name: Test invoke.py pip
# This is a dummy stand-in for the actual tests
# we don't need to run python tests on non-Python changes
# But PRs require passing tests to be mergeable
on:
pull_request:
paths:
- '**'
- '!pyproject.toml'
- '!invokeai/**'
- '!tests/**'
- 'invokeai/frontend/web/**'
merge_group:
workflow_dispatch:
@@ -19,48 +25,26 @@ jobs:
strategy:
matrix:
python-version:
# - '3.9'
- '3.10'
pytorch:
# - linux-cuda-11_6
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
# - windows-cuda-11_6
# - windows-cuda-11_7
include:
# - pytorch: linux-cuda-11_6
# os: ubuntu-22.04
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $GITHUB_ENV
- pytorch: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- pytorch: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- pytorch: linux-cpu
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- pytorch: macos-default
os: macOS-12
github-env: $GITHUB_ENV
- pytorch: windows-cpu
os: windows-2022
github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_6
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_7
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu117'
# github-env: $env:GITHUB_ENV
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
steps:
- run: 'echo "No build required"'
- name: skip
run: echo "no build required"

View File

@@ -11,6 +11,7 @@ on:
paths:
- 'pyproject.toml'
- 'invokeai/**'
- 'tests/**'
- '!invokeai/frontend/web/**'
types:
- 'ready_for_review'
@@ -32,19 +33,12 @@ jobs:
# - '3.9'
- '3.10'
pytorch:
# - linux-cuda-11_6
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
# - windows-cuda-11_6
# - windows-cuda-11_7
include:
# - pytorch: linux-cuda-11_6
# os: ubuntu-22.04
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $GITHUB_ENV
- pytorch: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
@@ -62,14 +56,6 @@ jobs:
- pytorch: windows-cpu
os: windows-2022
github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_6
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_7
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu117'
# github-env: $env:GITHUB_ENV
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
env:
@@ -100,40 +86,38 @@ jobs:
id: run-pytest
run: pytest
- name: run invokeai-configure
id: run-preload-models
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
run: >
invokeai-configure
--yes
--default_only
--full-precision
# can't use fp16 weights without a GPU
# - name: run invokeai-configure
# env:
# HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
# run: >
# invokeai-configure
# --yes
# --default_only
# --full-precision
# # can't use fp16 weights without a GPU
- name: run invokeai
id: run-invokeai
env:
# Set offline mode to make sure configure preloaded successfully.
HF_HUB_OFFLINE: 1
HF_DATASETS_OFFLINE: 1
TRANSFORMERS_OFFLINE: 1
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
run: >
invokeai
--no-patchmatch
--no-nsfw_checker
--precision=float32
--always_use_cpu
--use_memory_db
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
--from_file ${{ env.TEST_PROMPTS }}
# - name: run invokeai
# id: run-invokeai
# env:
# # Set offline mode to make sure configure preloaded successfully.
# HF_HUB_OFFLINE: 1
# HF_DATASETS_OFFLINE: 1
# TRANSFORMERS_OFFLINE: 1
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
# run: >
# invokeai
# --no-patchmatch
# --no-nsfw_checker
# --precision=float32
# --always_use_cpu
# --use_memory_db
# --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
# --from_file ${{ env.TEST_PROMPTS }}
- name: Archive results
id: archive-results
env:
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
uses: actions/upload-artifact@v3
with:
name: results
path: ${{ env.INVOKEAI_OUTDIR }}
# - name: Archive results
# env:
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
# uses: actions/upload-artifact@v3
# with:
# name: results
# path: ${{ env.INVOKEAI_OUTDIR }}

View File

@@ -43,6 +43,23 @@ _Note: InvokeAI is rapidly evolving. Please use the
[Issues](https://github.com/invoke-ai/InvokeAI/issues) tab to report bugs and make feature
requests. Be sure to use the provided templates. They will help us diagnose issues faster._
## FOR DEVELOPERS - MIGRATING TO THE 3.0.0 MODELS FORMAT
The models directory and models.yaml have changed. To migrate to the
new layout, please follow this recipe:
1. Run `python scripts/migrate_models_to_3.0.py <path_to_root_directory>`.
2. This will create a new models directory named `models-3.0` and a
new config file named `models.yaml-3.0`, both in the current
working directory. If you prefer to name them something else, pass
the `--dest-directory` and/or `--dest-yaml` arguments.
3. Check that the new models directory and yaml file look ok.
4. Replace the existing directory and file, keeping backup copies just in
case.
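As an illustration, the following is a minimal sketch of driving step 1 from Python rather than the shell; the root path is a placeholder, and the `--dest-directory`/`--dest-yaml` values simply restate the defaults described in step 2.

```python
# Hedged sketch of the migration recipe above, assumed to be run from the
# repository checkout. The root path is a placeholder for your existing root.
import subprocess
import sys

root_dir = "/path/to/invokeai-root"  # placeholder: existing InvokeAI root directory

subprocess.run(
    [
        sys.executable,
        "scripts/migrate_models_to_3.0.py",
        root_dir,
        "--dest-directory", "models-3.0",   # optional override (default shown)
        "--dest-yaml", "models.yaml-3.0",   # optional override (default shown)
    ],
    check=True,  # raise if the migration script exits non-zero
)
```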
<div align="center">
![canvas preview](https://github.com/invoke-ai/InvokeAI/raw/main/docs/assets/canvas_preview.png)

View File

@@ -67,7 +67,7 @@ title: Home
implementation of Stable Diffusion, the open source text-to-image and
image-to-image generator. It provides a streamlined process with various new
features and options to aid the image generation process. It runs on Windows,
Mac and Linux machines, and runs on GPU cards with as little as 4 GB or RAM.
Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
**Quick links**: [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>]
[<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a

View File

@@ -38,6 +38,7 @@ echo https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist
echo.
echo See %INSTRUCTIONS% for more details.
echo.
echo "For the best user experience we suggest enlarging or maximizing this window now."
pause
@rem ---------------------------- check Python version ---------------

View File

@@ -25,7 +25,8 @@ done
if [ -z "$PYTHON" ]; then
echo "A suitable Python interpreter could not be found"
echo "Please install Python 3.9 or higher before running this script. See instructions at $INSTRUCTIONS for help."
echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help."
echo "For the best user experience we suggest enlarging or maximizing this window now."
read -p "Press any key to exit"
exit -1
fi

View File

@@ -293,6 +293,8 @@ def introduction() -> None:
"3. Create initial configuration files.",
"",
"[i]At any point you may interrupt this program and resume later.",
"",
"[b]For the best user experience, please enlarge or maximize this window",
),
)
)

View File

@@ -2,8 +2,17 @@
from logging import Logger
import os
from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
)
from invokeai.app.services.board_images import (
BoardImagesService,
BoardImagesServiceDependencies,
)
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
@@ -11,7 +20,6 @@ from invokeai.backend.util.logging import InvokeAILogger
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_file_storage import DiskImageFileStorage
@@ -20,6 +28,7 @@ from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.model_manager_service import ModelManagerService
from .events import FastAPIEventService
@@ -57,7 +66,7 @@ class ApiDependencies:
# TODO: build a file/path manager?
db_location = config.db_path
db_location.parent.mkdir(parents=True,exist_ok=True)
db_location.parent.mkdir(parents=True, exist_ok=True)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
@@ -72,21 +81,49 @@ class ApiDependencies:
DiskLatentsStorage(f"{output_folder}/latents")
)
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
boards = BoardService(
services=BoardServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
board_images = BoardImagesService(
services=BoardImagesServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
images = ImageService(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
services=ImageServiceDependencies(
board_image_record_storage=board_image_record_storage,
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
)
services = InvocationServices(
model_manager=get_model_manager(config, logger),
model_manager=ModelManagerService(config,logger),
events=events,
latents=latents,
images=images,
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"

View File

@@ -0,0 +1,69 @@
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.models.image_record import ImageDTO
from ..dependencies import ApiDependencies
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
@board_images_router.post(
"/",
operation_id="create_board_image",
responses={
201: {"description": "The image was added to a board successfully"},
},
status_code=201,
)
async def create_board_image(
board_id: str = Body(description="The id of the board to add to"),
image_name: str = Body(description="The name of the image to add"),
):
"""Creates a board_image"""
try:
result = ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to add to board")
@board_images_router.delete(
"/",
operation_id="remove_board_image",
responses={
201: {"description": "The image was removed from the board successfully"},
},
status_code=201,
)
async def remove_board_image(
board_id: str = Body(description="The id of the board"),
image_name: str = Body(description="The name of the image to remove"),
):
"""Deletes a board_image"""
try:
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(board_id=board_id, image_name=image_name)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to update board")
@board_images_router.get(
"/{board_id}",
operation_id="list_board_images",
response_model=OffsetPaginatedResults[ImageDTO],
)
async def list_board_images(
board_id: str = Path(description="The id of the board"),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of boards per page"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of images for a board"""
results = ApiDependencies.invoker.services.board_images.get_images_for_board(
board_id,
)
return results
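As a usage illustration, the sketch below exercises the three new board_images routes with the `requests` library; the host/port are assumptions for a locally running server, and the board and image identifiers are placeholders.

```python
# Hedged sketch: drive the new board_images routes over HTTP. Host/port are
# assumptions for a local server; the ids below are placeholders.
import requests

BASE = "http://127.0.0.1:9090/api/v1/board_images"  # assumed local server

board_id = "example-board-id"      # placeholder
image_name = "example-image.png"   # placeholder

# Add an image to a board (both values are Body parameters, sent as JSON)
requests.post(f"{BASE}/", json={"board_id": board_id, "image_name": image_name})

# Page through the images on the board
page = requests.get(f"{BASE}/{board_id}", params={"offset": 0, "limit": 10}).json()

# Remove the image from the board again
requests.delete(f"{BASE}/", json={"board_id": board_id, "image_name": image_name})
```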

View File

@@ -0,0 +1,108 @@
from typing import Optional, Union
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.services.board_record_storage import BoardChanges
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import BoardDTO
from ..dependencies import ApiDependencies
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
@boards_router.post(
"/",
operation_id="create_board",
responses={
201: {"description": "The board was created successfully"},
},
status_code=201,
response_model=BoardDTO,
)
async def create_board(
board_name: str = Query(description="The name of the board to create"),
) -> BoardDTO:
"""Creates a board"""
try:
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to create board")
@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO)
async def get_board(
board_id: str = Path(description="The id of board to get"),
) -> BoardDTO:
"""Gets a board"""
try:
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
return result
except Exception as e:
raise HTTPException(status_code=404, detail="Board not found")
@boards_router.patch(
"/{board_id}",
operation_id="update_board",
responses={
201: {
"description": "The board was updated successfully",
},
},
status_code=201,
response_model=BoardDTO,
)
async def update_board(
board_id: str = Path(description="The id of board to update"),
changes: BoardChanges = Body(description="The changes to apply to the board"),
) -> BoardDTO:
"""Updates a board"""
try:
result = ApiDependencies.invoker.services.boards.update(
board_id=board_id, changes=changes
)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to update board")
@boards_router.delete("/{board_id}", operation_id="delete_board")
async def delete_board(
board_id: str = Path(description="The id of board to delete"),
) -> None:
"""Deletes a board"""
try:
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
@boards_router.get(
"/",
operation_id="list_boards",
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
)
async def list_boards(
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
offset: Optional[int] = Query(default=None, description="The page offset"),
limit: Optional[int] = Query(
default=None, description="The number of boards per page"
),
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
"""Gets a list of boards"""
if all:
return ApiDependencies.invoker.services.boards.get_all()
elif offset is not None and limit is not None:
return ApiDependencies.invoker.services.boards.get_many(
offset,
limit,
)
else:
raise HTTPException(
status_code=400,
detail="Invalid request: Must provide either 'all' or both 'offset' and 'limit'",
)
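A companion sketch for the boards routes, again using `requests` against an assumed local server; the `board_id` field read off the returned `BoardDTO` is an assumption, since the DTO definition is not shown in this hunk.

```python
# Hedged sketch of the new boards endpoints. Host/port and the BoardDTO field
# name "board_id" are assumptions; routes and parameters come from the router.
import requests

BASE = "http://127.0.0.1:9090/api/v1/boards"  # assumed local server

# Create a board; board_name is a query parameter on create_board
board = requests.post(f"{BASE}/", params={"board_name": "portraits"}).json()

# List boards: pass either all=true, or both offset and limit
every_board = requests.get(f"{BASE}/", params={"all": "true"}).json()
one_page = requests.get(f"{BASE}/", params={"offset": 0, "limit": 10}).json()

# Fetch and finally delete the board created above
board_id = board["board_id"]  # assumed DTO field name
fetched = requests.get(f"{BASE}/{board_id}").json()
requests.delete(f"{BASE}/{board_id}")
```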

View File

@@ -70,27 +70,25 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image")
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
@images_router.delete("/{image_name}", operation_id="delete_image")
async def delete_image(
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image"""
try:
ApiDependencies.invoker.services.images.delete(image_origin, image_name)
ApiDependencies.invoker.services.images.delete(image_name)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
@images_router.patch(
"/{image_origin}/{image_name}",
"/{image_name}",
operation_id="update_image",
response_model=ImageDTO,
)
async def update_image(
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image"
@@ -99,32 +97,29 @@ async def update_image(
"""Updates an image"""
try:
return ApiDependencies.invoker.services.images.update(
image_origin, image_name, image_changes
)
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
except Exception as e:
raise HTTPException(status_code=400, detail="Failed to update image")
@images_router.get(
"/{image_origin}/{image_name}/metadata",
"/{image_name}/metadata",
operation_id="get_image_metadata",
response_model=ImageDTO,
)
async def get_image_metadata(
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
image_name: str = Path(description="The name of image to get"),
) -> ImageDTO:
"""Gets an image's metadata"""
try:
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
return ApiDependencies.invoker.services.images.get_dto(image_name)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_origin}/{image_name}",
"/{image_name}",
operation_id="get_image_full",
response_class=Response,
responses={
@@ -136,15 +131,12 @@ async def get_image_metadata(
},
)
async def get_image_full(
image_origin: ResourceOrigin = Path(
description="The type of full-resolution image file to get"
),
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> FileResponse:
"""Gets a full-resolution image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
path = ApiDependencies.invoker.services.images.get_path(image_name)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
@@ -160,7 +152,7 @@ async def get_image_full(
@images_router.get(
"/{image_origin}/{image_name}/thumbnail",
"/{image_name}/thumbnail",
operation_id="get_image_thumbnail",
response_class=Response,
responses={
@@ -172,14 +164,13 @@ async def get_image_full(
},
)
async def get_image_thumbnail(
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
image_name: str = Path(description="The name of thumbnail image file to get"),
) -> FileResponse:
"""Gets a thumbnail image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(
image_origin, image_name, thumbnail=True
image_name, thumbnail=True
)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
@@ -192,25 +183,21 @@ async def get_image_thumbnail(
@images_router.get(
"/{image_origin}/{image_name}/urls",
"/{image_name}/urls",
operation_id="get_image_urls",
response_model=ImageUrlsDTO,
)
async def get_image_urls(
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL"""
try:
image_url = ApiDependencies.invoker.services.images.get_url(
image_origin, image_name
)
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_origin, image_name, thumbnail=True
image_name, thumbnail=True
)
return ImageUrlsDTO(
image_origin=image_origin,
image_name=image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
@@ -234,6 +221,9 @@ async def list_images_with_metadata(
is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images"
),
board_id: Optional[str] = Query(
default=None, description="The board id to filter by"
),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"),
) -> OffsetPaginatedResults[ImageDTO]:
@@ -245,6 +235,7 @@ async def list_images_with_metadata(
image_origin,
categories,
is_intermediate,
board_id,
)
return image_dtos
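The reworked image routes drop `image_origin` from the path; the sketch below shows the calls a client would now make, assuming a local server, a placeholder image name, and that the listing endpoint sits at the router root.

```python
# Hedged sketch of the simplified image routes (no image_origin in the path).
# Host/port, the image name, and the listing path are assumptions.
import requests

BASE = "http://127.0.0.1:9090/api/v1/images"  # assumed local server
name = "example-image.png"                    # placeholder image name

metadata = requests.get(f"{BASE}/{name}/metadata").json()    # ImageDTO
urls = requests.get(f"{BASE}/{name}/urls").json()            # ImageUrlsDTO
thumbnail_bytes = requests.get(f"{BASE}/{name}/thumbnail").content

# List images, optionally filtered by the new board_id query parameter
listing = requests.get(f"{BASE}/", params={"board_id": "example-board-id", "limit": 10}).json()

# Delete the image by name alone
requests.delete(f"{BASE}/{name}")
```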

View File

@@ -1,13 +1,14 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
import shutil
import asyncio
from typing import Annotated, Any, List, Literal, Optional, Union
from typing import Annotated, Literal, Optional, Union, Dict
from fastapi import Query
from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from pathlib import Path
from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@@ -19,6 +20,15 @@ class VaeRepo(BaseModel):
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
model_name: str = Field(description="The name of the model")
model_type: str = Field(description="The type of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['folder'] = 'folder'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
@@ -29,12 +39,8 @@ class CkptModelInfo(ModelInfo):
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['diffusers'] = 'diffusers'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class SafetensorsModelInfo(CkptModelInfo):
format: Literal['safetensors'] = 'safetensors'
class CreateModelRequest(BaseModel):
name: str = Field(description="The name of the model")
@@ -56,7 +62,7 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
models: list[MODEL_CONFIGS]
@models_router.get(
@@ -64,9 +70,16 @@ class ModelsList(BaseModel):
operation_id="list_models",
responses={200: {"model": ModelsList }},
)
async def list_models() -> ModelsList:
async def list_models(
base_model: Optional[BaseModelType] = Query(
default=None, description="Base model"
),
model_type: Optional[ModelType] = Query(
default=None, description="The type of model to get"
),
) -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
@@ -121,7 +134,7 @@ async def delete_model(model_name: str) -> None:
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
else:
logger.error(f"Model not found")
logger.error("Model not found")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
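To show what the new filtering looks like from a client's perspective, here is a hedged sketch; the base URL and route root are assumptions, `sd-1` is one of the `BaseModelType` values referenced later in this diff, and any `model_type` value must be a member of the `ModelType` enum (not shown in this hunk).

```python
# Hedged sketch of the reworked list_models query. Host/port are assumptions;
# base_model (and model_type) map to the Query parameters added above.
import requests

BASE = "http://127.0.0.1:9090/api/v1/models"  # assumed local server

# Unfiltered listing; ModelsList serializes as {"models": [...]}
all_models = requests.get(f"{BASE}/").json()["models"]

# Restrict to SD 1.x models ("sd-1" appears in the BaseModelType enum elsewhere
# in this diff). Add model_type=<ModelType value> to narrow further.
sd1_models = requests.get(f"{BASE}/", params={"base_model": "sd-1"}).json()["models"]
```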

View File

@@ -24,7 +24,7 @@ logger = InvokeAILogger.getLogger(config=app_config)
import invokeai.frontend.web as web_dir
from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images
from .api.routers import sessions, models, images, boards, board_images
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
@@ -78,6 +78,10 @@ app.include_router(models.models_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi():
@@ -116,6 +120,22 @@ def custom_openapi():
invoker_schema["output"] = outputs_ref
from invokeai.backend.model_management.models import get_model_config_enums
for model_config_format_enum in set(get_model_config_enums()):
name = model_config_format_enum.__qualname__
if name in openapi_schema["components"]["schemas"]:
# print(f"Config with name {name} already defined")
continue
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
openapi_schema["components"]["schemas"][name] = dict(
title=name,
description="An enumeration.",
type="string",
enum=list(v.value for v in model_config_format_enum),
)
app.openapi_schema = openapi_schema
return app.openapi_schema

View File

@@ -6,10 +6,7 @@ import re
import shlex
import sys
import time
from typing import (
Union,
get_type_hints,
)
from typing import Union, get_type_hints
from pydantic import BaseModel, ValidationError
from pydantic.fields import Field
@@ -26,23 +23,25 @@ from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from .services.default_graphs import create_system_graphs
from .services.default_graphs import (default_text_to_image_graph_id,
create_system_graphs)
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, SortedHelpFormatter
from .cli.commands import (BaseCommand, CliContext, ExitCli,
SortedHelpFormatter, add_graph_parsers, add_parsers)
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.default_graphs import default_text_to_image_graph_id
from .services.graph import (Edge, EdgeConnection, GraphExecutionState,
GraphInvocation, LibraryGraph,
are_connection_types_compatible)
from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor
from .services.restoration_services import RestorationServices
from .services.sqlite import SqliteItemStorage
@@ -197,7 +196,6 @@ def invoke_all(context: CliContext):
raise SessionError()
def invoke_cli():
# get the optional list of invocations to execute on the command line
parser = config.get_parser()
parser.add_argument('commands',nargs='*')
@@ -208,8 +206,8 @@ def invoke_cli():
if infile := config.from_file:
sys.stdin = open(infile,"r")
model_manager = get_model_manager(config,logger=logger)
model_manager = ModelManagerService(config,logger)
events = EventServiceBase()
output_folder = config.output_path
@@ -257,9 +255,11 @@ def invoke_cli():
logger=logger,
configuration=config,
)
system_graphs = create_system_graphs(services.graph_library)
system_graph_names = set([g.name for g in system_graphs])
set_autocompleter(services)
invoker = Invoker(services)
session: GraphExecutionState = invoker.create_execution_state()

View File

@@ -1,13 +1,15 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from contextlib import ExitStack
import re
from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from ...backend.prompting.conditioning import try_parse_legacy_blend
from .model import ClipField
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_management.lora import ModelPatcher
from compel import Compel
from compel.prompt_parser import (
@@ -40,7 +42,7 @@ class CompelInvocation(BaseInvocation):
type: Literal["compel"] = "compel"
prompt: str = Field(default="", description="Prompt")
model: str = Field(default="", description="Model to use")
clip: ClipField = Field(None, description="Clip to use")
# Schema customisation
class Config(InvocationConfig):
@@ -56,73 +58,74 @@ class CompelInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> CompelOutput:
# TODO: load without model
model = choose_model(context.services.model_manager, self.model)
pipeline = model["model"]
tokenizer = pipeline.tokenizer
text_encoder = pipeline.text_encoder
# TODO: global? input?
#use_full_precision = precision == "float32" or precision == "autocast"
#use_full_precision = False
# TODO: redo TI when separate model loding implemented
#textual_inversion_manager = TextualInversionManager(
# tokenizer=tokenizer,
# text_encoder=text_encoder,
# full_precision=use_full_precision,
#)
def load_huggingface_concepts(concepts: list[str]):
pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)
# apply the concepts library to the prompt
prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
self.prompt,
lambda concepts: load_huggingface_concepts(concepts),
pipeline.textual_inversion_manager.get_all_trigger_strings(),
tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(),
)
# lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used.
pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
prompt_str
text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(),
)
with tokenizer_info as orig_tokenizer,\
text_encoder_info as text_encoder,\
ExitStack() as stack:
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=pipeline.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False,
)
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
legacy_blend = try_parse_legacy_blend(prompt_str, skip_normalize=False)
if legacy_blend is not None:
conjunction = legacy_blend
else:
conjunction = Compel.parse_prompt_string(prompt_str)
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
stack.enter_context(
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
)
)
)
except Exception:
#print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
if context.services.configuration.log_tokenization:
log_tokenization_for_conjunction(conjunction, tokenizer)
with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
)
conjunction = Compel.parse_prompt_string(self.prompt)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
# TODO: long prompt support
#if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (c, ec))
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (c, ec))
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
def get_max_token_count(

View File

@@ -1,7 +1,7 @@
# InvokeAI nodes for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float
from builtins import float, bool
import numpy as np
from typing import Literal, Optional, Union, List
@@ -94,6 +94,7 @@ CONTROLNET_DEFAULT_MODELS = [
]
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
@@ -104,6 +105,8 @@ class ControlField(BaseModel):
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
@validator("control_weight")
def abs_le_one(cls, v):
"""validate that all abs(values) are <=1"""
@@ -144,11 +147,11 @@ class ControlNetInvocation(BaseInvocation):
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
# fmt: on
class Config(InvocationConfig):
@@ -166,7 +169,6 @@ class ControlNetInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(
image=self.image,
@@ -174,6 +176,7 @@ class ControlNetInvocation(BaseInvocation):
control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
control_mode=self.control_mode,
),
)
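For illustration, a hedged sketch of building a `ControlField` that carries the new `control_mode` value; the import paths and the exact field set are inferred from this file and may differ, and the image name is a placeholder.

```python
# Hedged sketch: a ControlField with the new control_mode. Import paths and
# the full field list are inferred from this hunk; the image name is a placeholder.
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.models.image import ImageField

control = ControlField(
    image=ImageField(image_name="canny-edges.png"),   # preprocessed control image
    control_model="lllyasviel/sd-controlnet-canny",
    control_weight=0.8,
    begin_step_percent=0.0,
    end_step_percent=0.75,
    control_mode="more_prompt",  # one of: balanced, more_prompt, more_control, unbalanced
)
```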
@@ -193,9 +196,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
return image
def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
raw_image = context.services.images.get_pil_image(self.image.image_name)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
@@ -216,10 +217,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
)
"""Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
)
processed_image_field = ImageField(image_name=image_dto.image_name)
return ImageOutput(
image=processed_image_field,
# width=processed_image.width,

View File

@@ -36,12 +36,8 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
mask = context.services.images.get_pil_image(
self.mask.image_origin, self.mask.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
mask = context.services.images.get_pil_image(self.mask.image_name)
# Convert to cv image/mask
# TODO: consider making these utility functions
@@ -65,10 +61,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -3,23 +3,27 @@
from functools import partial
from typing import Literal, Optional, Union, get_args
import numpy as np
from diffusers import ControlNetModel
from torch import Tensor
import torch
from diffusers import ControlNetModel
from pydantic import BaseModel, Field
from invokeai.app.models.image import ColorField, ImageField, ResourceOrigin
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
ResourceOrigin)
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.generator import Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.step_callback import stable_diffusion_step_callback
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
from .image import ImageOutput
import re
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from .model import UNetField, VaeField
from .compel import ConditioningField
from contextlib import contextmanager, ExitStack, ContextDecorator
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())]
@@ -28,119 +32,48 @@ DEFAULT_INFILL_METHOD = (
)
class SDImageInvocation(BaseModel):
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
from .latent import get_scheduler
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["stable-diffusion", "image"],
"type_hints": {
"model": "model",
},
},
}
class OldModelContext(ContextDecorator):
model: StableDiffusionGeneratorPipeline
def __init__(self, model):
self.model = model
def __enter__(self):
return self.model
def __exit__(self, *exc):
return False
class OldModelInfo:
name: str
hash: str
context: OldModelContext
def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
self.name = name
self.hash = hash
self.context = OldModelContext(
model=model,
)
# Text to image
class TextToImageInvocation(BaseInvocation, SDImageInvocation):
"""Generates an image using text2img."""
class InpaintInvocation(BaseInvocation):
"""Generates an image using inpaint."""
type: Literal["txt2img"] = "txt2img"
type: Literal["inpaint"] = "inpaint"
# Inputs
# TODO: consider making prompt optional to enable providing prompt through a link
# fmt: off
prompt: Optional[str] = Field(description="The prompt to generate an image from")
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
control_model: Optional[str] = Field(default=None, description="The control model to use")
control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
# fmt: on
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# loading controlnet image (currently requires pre-processed image)
control_image = (
None if self.control_image is None
else context.services.images.get_pil_image(
self.control_image.image_origin, self.control_image.image_name
)
)
# loading controlnet model
if (self.control_model is None or self.control_model==''):
control_model = None
else:
# FIXME: change this to dropdown menu?
# FIXME: generalize so don't have to hardcode torch_dtype and device
control_model = ControlNetModel.from_pretrained(self.control_model,
torch_dtype=torch.float16).to("cuda")
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
txt2img = Txt2Img(model, control_model=control_model)
outputs = txt2img.generate(
prompt=self.prompt,
step_callback=partial(self.dispatch_progress, context, source_node_id),
control_image=control_image,
**self.dict(
exclude={"prompt", "control_image" }
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generate_output = next(outputs)
image_dto = context.services.images.create(
image=generate_output.image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
class ImageToImageInvocation(TextToImageInvocation):
"""Generates an image using img2img."""
type: Literal["img2img"] = "img2img"
unet: UNetField = Field(default=None, description="UNet model")
vae: VaeField = Field(default=None, description="Vae model")
# Inputs
image: Union[ImageField, None] = Field(description="The input image")
@@ -152,77 +85,6 @@ class ImageToImageInvocation(TextToImageInvocation):
description="Whether or not the result should be fit to the aspect ratio of the input image",
)
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
None
if self.image is None
else context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
)
if self.fit:
image = image.resize((self.width, self.height))
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Img2Img(model).generate(
prompt=self.prompt,
init_image=image,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generator_output = next(outputs)
image_dto = context.services.images.create(
image=generator_output.image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
type: Literal["inpaint"] = "inpaint"
# Inputs
mask: Union[ImageField, None] = Field(description="The mask")
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
@@ -265,6 +127,14 @@ class InpaintInvocation(ImageToImageInvocation):
description="The amount by which to replace masked areas with latent noise",
)
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["stable-diffusion", "image"],
},
}
def dispatch_progress(
self,
context: InvocationContext,
@@ -278,39 +148,86 @@ class InpaintInvocation(ImageToImageInvocation):
source_node_id=source_node_id,
)
def get_conditioning(self, context):
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
return (uc, c, extra_conditioning_info)
@contextmanager
def load_model_old_way(self, context, scheduler):
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
#unet = unet_info.context.model
#vae = vae_info.context.model
with ExitStack() as stack:
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
with vae_info as vae,\
unet_info as unet,\
ModelPatcher.apply_lora_unet(unet, loras):
device = context.services.model_manager.mgr.cache.execution_device
dtype = context.services.model_manager.mgr.cache.precision
pipeline = StableDiffusionGeneratorPipeline(
vae=vae,
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
precision="float16" if dtype == torch.float16 else "float32",
execution_device=device,
)
yield OldModelInfo(
name=self.unet.unet.model_name,
hash="<NO-HASH>",
model=pipeline,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
None
if self.image is None
else context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
else context.services.images.get_pil_image(self.image.image_name)
)
mask = (
None
if self.mask is None
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
else context.services.images.get_pil_image(self.mask.image_name)
)
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Inpaint(model).generate(
prompt=self.prompt,
init_image=image,
mask_image=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
conditioning = self.get_conditioning(context)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
with self.load_model_old_way(context, scheduler) as model:
outputs = Inpaint(model).generate(
conditioning=conditioning,
scheduler=scheduler,
init_image=image,
mask_image=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generator_output = next(outputs)
@@ -325,10 +242,7 @@ class InpaintInvocation(ImageToImageInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -72,13 +72,10 @@ class LoadImageInvocation(BaseInvocation):
)
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
image = context.services.images.get_pil_image(self.image.image_name)
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image_origin=self.image.image_origin,
),
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
@@ -95,19 +92,14 @@ class ShowImageInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
if image:
image.show()
# TODO: how to handle failure?
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image_origin=self.image.image_origin,
),
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
@@ -128,9 +120,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_crop = Image.new(
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
@@ -147,10 +137,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -171,19 +158,13 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image(
self.base_image.image_origin, self.base_image.image_name
)
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
base_image = context.services.images.get_pil_image(self.base_image.image_name)
image = context.services.images.get_pil_image(self.image.image_name)
mask = (
None
if self.mask is None
else ImageOps.invert(
context.services.images.get_pil_image(
self.mask.image_origin, self.mask.image_name
)
context.services.images.get_pil_image(self.mask.image_name)
)
)
# TODO: probably shouldn't invert mask here... should user be required to do it?
@@ -209,10 +190,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -230,9 +208,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_mask = image.split()[-1]
if self.invert:
@@ -248,9 +224,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
)
return MaskOutput(
mask=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
mask=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -268,12 +242,8 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image(
self.image1.image_origin, self.image1.image_name
)
image2 = context.services.images.get_pil_image(
self.image2.image_origin, self.image2.image_name
)
image1 = context.services.images.get_pil_image(self.image1.image_name)
image2 = context.services.images.get_pil_image(self.image2.image_name)
multiply_image = ImageChops.multiply(image1, image2)
@@ -287,9 +257,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -310,9 +278,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
channel_image = image.getchannel(self.channel)
@@ -326,9 +292,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -349,9 +313,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
converted_image = image.convert(self.mode)
@@ -365,9 +327,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -386,9 +346,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
blur = (
ImageFilter.GaussianBlur(self.radius)
@@ -407,10 +365,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -450,9 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@@ -471,10 +424,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -493,9 +443,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width * self.scale_factor)
@@ -516,10 +464,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -538,9 +483,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
image_arr = image_arr * (self.max - self.min) + self.max
@@ -557,10 +500,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -579,9 +519,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32)
image_arr = (
@@ -603,10 +541,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -134,9 +134,7 @@ class InfillColorInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
@@ -153,10 +151,7 @@ class InfillColorInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -179,9 +174,7 @@ class InfillTileInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
infilled = tile_fill_missing(
image.copy(), seed=self.seed, tile_size=self.tile_size
@@ -198,10 +191,7 @@ class InfillTileInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@@ -217,9 +207,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
if PatchMatch.patchmatch_available():
infilled = infill_patchmatch(image.copy())
@@ -236,10 +224,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -1,43 +1,36 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import random
import einops
from typing import Literal, Optional, Union, List
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
from compel import Compel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
import einops
from pydantic import BaseModel, Field, validator
import torch
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import ImageCategory
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from .controlnet_image_processors import ControlField
from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_file_storage import ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from .compel import ConditioningField
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
from diffusers import DiffusionPipeline, ControlNetModel
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.image_util.seamless import configure_model_padding
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
image_resized_to_grid_as_tensor)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management.lora import ModelPatcher
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
@@ -90,15 +83,22 @@ SAMPLER_NAME_VALUES = Literal[
]
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_config = model.scheduler.config
def get_scheduler(
context: InvocationContext,
scheduler_info: ModelInfo,
scheduler_name: str,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
@@ -128,7 +128,6 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
return x
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
@@ -176,10 +175,10 @@ class TextToLatentsInvocation(BaseInvocation):
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance scale; higher values give results that adhere more closely to the prompt",)
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
control: Union[ControlField, List[ControlField]] = Field(default=None, description="The control to use")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
unet: UNetField = Field(default=None, description="UNet submodel")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
@@ -219,44 +218,10 @@ class TextToLatentsInvocation(BaseInvocation):
source_node_id=source_node_id,
)
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
model_info = choose_model(model_manager, self.model)
model_name = model_info['model_name']
model_hash = model_info['hash']
model: StableDiffusionGeneratorPipeline = model_info['model']
model.scheduler = get_scheduler(
model=model,
scheduler_name=self.scheduler
)
# if isinstance(model, DiffusionPipeline):
# for component in [model.unet, model.vae]:
# configure_model_padding(component,
# self.seamless,
# self.seamless_axes
# )
# else:
# configure_model_padding(model,
# self.seamless,
# self.seamless_axes
# )
return model
def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
compel = Compel(
tokenizer=model.tokenizer,
text_encoder=model.text_encoder,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False,
)
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
text_embeddings=c,
@@ -268,33 +233,68 @@ class TextToLatentsInvocation(BaseInvocation):
h_symmetry_time_pct=None,#h_symmetry_time_pct,
v_symmetry_time_pct=None#v_symmetry_time_pct,
),
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
)
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
scheduler,
# for ddim scheduler
eta=0.0, #ddim_eta
# for ancestral and sde schedulers
generator=torch.Generator(device=uc.device).manual_seed(0),
)
return conditioning_data
def prep_control_data(self,
context: InvocationContext,
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
control_input: List[ControlField],
latents_shape: List[int],
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
# TODO:
#configure_model_padding(
# unet,
# self.seamless,
# self.seamless_axes,
#)
class FakeVae:
class FakeVaeConfig:
def __init__(self):
self.block_out_channels = [0]
def __init__(self):
self.config = FakeVae.FakeVaeConfig()
return StableDiffusionGeneratorPipeline(
vae=FakeVae(), # TODO: oh...
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
precision="float16" if unet.dtype == torch.float16 else "float32",
)
def prep_control_data(
self,
context: InvocationContext,
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
control_input: List[ControlField],
latents_shape: List[int],
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
# assuming fixed dimensional scaling of 8:1 for image:latents
control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8
if control_input is None:
# print("control input is None")
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
# print("control input is empty list")
control_list = None
elif isinstance(control_input, ControlField):
# print("control input is ControlField")
control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
# print("control input is list[ControlField]")
control_list = control_input
else:
# print("input control is unrecognized:", type(self.control))
control_list = None
if (control_list is None):
control_data = None
@@ -321,8 +321,7 @@ class TextToLatentsInvocation(BaseInvocation):
torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_origin,
control_image_field.image_name)
input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
@@ -337,12 +336,15 @@ class TextToLatentsInvocation(BaseInvocation):
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
)
control_item = ControlNetData(model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent)
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
@@ -357,23 +359,38 @@ class TextToLatentsInvocation(BaseInvocation):
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(context, model)
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
with unet_info as unet,\
ExitStack() as stack:
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
)
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@@ -382,7 +399,6 @@ class TextToLatentsInvocation(BaseInvocation):
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
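The list comprehension in this hunk leans on contextlib.ExitStack to keep every LoRA's model context open until denoising finishes. A standalone sketch of that pattern with stand-in objects (nothing below is InvokeAI API):

from contextlib import ExitStack, contextmanager

@contextmanager
def fake_model(name):                 # stand-in for model_manager.get_model(...)
    print(f"load {name}")
    try:
        yield name
    finally:
        print(f"unload {name}")

lora_names = ["lora_a", "lora_b"]     # stand-ins for self.unet.loras
with ExitStack() as stack:
    loras = [(stack.enter_context(fake_model(n)), 0.75) for n in lora_names]
    print("denoise with", loras)      # every LoRA stays loaded for this whole block
# leaving the with-block unloads the LoRAs in reverse order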
class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image."""
@@ -416,32 +432,52 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(context, model)
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=model.device, dtype=latent.dtype
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
)
timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)
with unet_info as unet,\
ExitStack() as stack:
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback
)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=unet.device, dtype=latent.dtype
)
timesteps, _ = pipeline.get_img2img_timesteps(
self.steps,
self.strength,
device=unet.device,
)
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@@ -459,16 +495,14 @@ class LatentsToImageInvocation(BaseInvocation):
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
model: str = Field(default="", description="The model to use")
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Decode latents by overlapping tiles (reduces memory consumption)")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
"type_hints": {
"model": "model"
}
},
}
@@ -476,40 +510,45 @@ class LatentsToImageInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO: this only really needs the vae
model_info = choose_model(context.services.model_manager, self.model)
model: StableDiffusionGeneratorPipeline = model_info['model']
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
)
with torch.inference_mode():
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
# what happened to metadata?
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
with vae_info as vae:
if self.tiled or context.services.configuration.tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()
# clear memory as vae decode can request a lot
torch.cuda.empty_cache()
# new (post Image service refactor) way of using services to save image
# and generate a unique image_name
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate
)
with torch.inference_mode():
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
torch.cuda.empty_cache()
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
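For reference, the denormalize step above maps the decoder's [-1, 1] output into [0, 1] before conversion to uint8; a tiny standalone check of that transform:

import torch

decoded = torch.tensor([-1.0, 0.0, 1.0])        # toy decoder output, range [-1, 1]
pixels = (decoded / 2 + 0.5).clamp(0, 1)        # same transform as the hunk above
assert torch.equal(pixels, torch.tensor([0.0, 0.5, 1.0]))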
LATENTS_INTERPOLATION_MODE = Literal[
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
@@ -585,14 +624,14 @@ class ImageToLatentsInvocation(BaseInvocation):
# Inputs
image: Union[ImageField, None] = Field(description="The image to encode")
model: str = Field(default="", description="The model to use")
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Encode latents by overlapping tiles (reduces memory consumption)")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
"type_hints": {"model": "model"},
},
}
@@ -601,24 +640,32 @@ class ImageToLatentsInvocation(BaseInvocation):
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
image = context.services.images.get_pil_image(self.image.image_name)
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
)
# TODO: this only really needs the vae
model_info = choose_model(context.services.model_manager, self.model)
model: StableDiffusionGeneratorPipeline = model_info["model"]
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = model.non_noised_latents_from_image(
image_tensor,
device=model._model_group.device_for(model.unet),
dtype=model.unet.dtype,
)
with vae_info as vae:
if self.tiled:
vae.enable_tiling()
else:
vae.disable_tiling()
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
latents = 0.18215 * latents
name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, latents)
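Note that this encode path hard-codes the SD-1.x latent scaling factor (0.18215) while the decode path above reads vae.config.scaling_factor. A minimal self-contained sketch of the scaling pair, using a stand-in config object:

import torch

class _Cfg:                                        # stand-in for vae.config
    scaling_factor = 0.18215                       # value reported by SD-1.x VAEs

cfg = _Cfg()
image_latents = torch.randn(1, 4, 64, 64)          # pretend vae.encode(...).latent_dist.sample()
scaled = cfg.scaling_factor * image_latents        # what ImageToLatentsInvocation stores
restored = scaled / cfg.scaling_factor             # what LatentsToImageInvocation feeds to vae.decode
assert torch.allclose(restored, image_latents)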

View File

@@ -0,0 +1,217 @@
from typing import Literal, Optional, Union, List
from pydantic import BaseModel, Field
import copy
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.model_management import BaseModelType, ModelType, SubModelType
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(description="Info to load submodel")
class LoraInfo(ModelInfo):
weight: float = Field(description="The weight with which to apply this LoRA to the model")
class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
#fmt: off
type: Literal["model_loader_output"] = "model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
#fmt: on
class PipelineModelField(BaseModel):
"""Pipeline model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
class PipelineModelLoaderInvocation(BaseInvocation):
"""Loads a pipeline model, outputting its submodels."""
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
model: PipelineModelField = Field(description="The model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"model": "model"
}
},
}
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Pipeline
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
)
"""
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Vae,
),
)
)
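A rough sketch of how this output is consumed downstream, mirroring the unet_info pattern in the latents hunks above. The base-model enum member and the context object are assumptions here, not shown in this file:

# sketch: load a pipeline model and resolve its UNet submodel (enum member name assumed)
model_field = PipelineModelField(model_name="stable-diffusion-1.5", base_model=BaseModelType.StableDiffusion1)
loader = PipelineModelLoaderInvocation(id="model_loader", model=model_field)
output = loader.invoke(context)   # context: InvocationContext provided by the graph executor

unet_info = context.services.model_manager.get_model(**output.unet.unet.dict())
with unet_info as unet:           # same context-manager pattern used by TextToLatentsInvocation
    print(unet.dtype)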
class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
#fmt: off
type: Literal["lora_loader_output"] = "lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
#fmt: on
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["lora_loader"] = "lora_loader"
lora_name: str = Field(description="Lora model name")
weight: float = Field(default=0.75, description="The weight with which to apply the LoRA")
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if not context.services.model_manager.model_exists(
model_name=self.lora_name,
model_type=SDModelType.Lora,
):
raise Exception(f"Unkown lora name: {self.lora_name}!")
if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras):
raise Exception(f"Lora \"{self.lora_name}\" already applied to unet")
if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras):
raise Exception(f"Lora \"{self.lora_name}\" already applied to clip")
output = LoraLoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
model_name=self.lora_name,
model_type=SDModelType.Lora,
submodel=None,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
model_name=self.lora_name,
model_type=SDModelType.Lora,
submodel=None,
weight=self.weight,
)
)
return output

View File

@@ -2,8 +2,8 @@ from typing import Literal
from pydantic.fields import Field
from .baseinvocation import BaseInvocationOutput
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt"""
@@ -20,3 +20,38 @@ class PromptOutput(BaseInvocationOutput):
'prompt',
]
}
class PromptCollectionOutput(BaseInvocationOutput):
"""Base class for invocations that output a collection of prompts"""
# fmt: off
type: Literal["prompt_collection_output"] = "prompt_collection_output"
prompt_collection: list[str] = Field(description="The output prompt collection")
count: int = Field(description="The size of the prompt collection")
# fmt: on
class Config:
schema_extra = {"required": ["type", "prompt_collection", "count"]}
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
type: Literal["dynamic_prompt"] = "dynamic_prompt"
prompt: str = Field(description="The prompt to parse with dynamicprompts")
max_prompts: int = Field(default=1, description="The number of prompts to generate")
combinatorial: bool = Field(
default=False, description="Whether to use the combinatorial generator"
)
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
if self.combinatorial:
generator = CombinatorialPromptGenerator()
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
else:
generator = RandomPromptGenerator()
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
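A short sketch of the node in isolation; the variant syntax comes from the dynamicprompts library and the prompt text is illustrative:

node = DynamicPromptInvocation(
    id="dyn1",
    prompt="a {red|green|blue} ball on a {wooden|marble} table",
    max_prompts=3,
    combinatorial=True,
)
result = node.invoke(context)     # context is unused beyond the base invocation signature
print(result.count)               # number of expanded prompts (capped at max_prompts)
print(result.prompt_collection)   # e.g. ["a red ball on a wooden table", ...]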

View File

@@ -28,9 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=None,
@@ -51,10 +49,7 @@ class RestoreFaceInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -1,8 +1,14 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Union
import cv2 as cv
import numpy as np
from pydantic import Field
from realesrgan import RealESRGANer
from PIL import Image
from basicsr.archs.rrdbnet_arch import RRDBNet
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
@@ -30,9 +36,7 @@ class UpscaleInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
@@ -53,10 +57,77 @@ class UpscaleInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
REALESRGAN_MODELS = Literal[
"RealESRGAN_x4plus",
"RealESRGAN_x4plus_anime_6B",
"ESRGAN_SRx4_DF2KOST_official-ff704c30",
]
class RealESRGANInvocation(BaseInvocation):
"""Upscales an image using Real-ESRGAN."""
# fmt: off
type: Literal["realesrgan"] = "realesrgan"
image: Union[ImageField, None] = Field(default=None, description="The input image" )
model_name: REALESRGAN_MODELS = Field(default="RealESRGAN_x4plus", description="The Real-ESRGAN model to use")
scale: Literal[2, 4] = Field(default=4, description="The final upsampling scale")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
model = None
netscale = None
model_path = None
if self.model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
model_path = 'core/upscaling/realesrgan/RealESRGAN_x4plus.pth'
elif self.model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
netscale = 4
model_path = 'core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth'
# elif self.model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
# model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
# netscale = 2
elif self.model_name == 'ESRGAN_SRx4_DF2KOST_official-ff704c30': # x4 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
model_path = 'core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth'
if not model or not netscale or not model_path:
raise Exception(f"Invalid model {self.model_name}")
upsampler = RealESRGANer(
scale=netscale,
model_path=model_path,
model=model,
half=False,
)
# Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
upscaled_image, img_mode = upsampler.enhance(cv_image, outscale=self.scale)
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert('RGBA')
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
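Sketch of invoking the new node on a gallery image (the image name is illustrative; weights are expected at the model_path locations listed above):

node = RealESRGANInvocation(
    id="esrgan1",
    image=ImageField(image_name="00042.abcdef12.png"),
    model_name="RealESRGAN_x4plus",
    scale=4,
)
output = node.invoke(context)     # runs RealESRGANer, then saves the RGBA result via the image service
print(output.width, output.height)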

View File

@@ -1,14 +0,0 @@
from invokeai.backend.model_management.model_manager import ModelManager
def choose_model(model_manager: ModelManager, model_name: str):
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
logger = model_manager.logger
if model_name and not model_manager.valid_model(model_name):
default_model_name = model_manager.default_model()
logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
model = model_manager.get_model()
else:
model = model_manager.get_model(model_name)
return model

View File

@@ -66,13 +66,10 @@ class InvalidImageCategoryException(ValueError):
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_origin: ResourceOrigin = Field(
default=ResourceOrigin.INTERNAL, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_origin", "image_name"]}
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):

View File

@@ -0,0 +1,254 @@
from abc import ABC, abstractmethod
import sqlite3
import threading
from typing import Union, cast
from invokeai.app.services.board_record_storage import BoardRecord
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import (
ImageRecord,
deserialize_image_record,
)
class BoardImageRecordStorageBase(ABC):
"""Abstract base class for the one-to-many board-image relationship record storage."""
@abstractmethod
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
pass
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
pass
@abstractmethod
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets images for a board."""
pass
@abstractmethod
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
"""Gets an image's board id, if it has one."""
pass
@abstractmethod
def get_image_count_for_board(
self,
board_id: str,
) -> int:
"""Gets the number of images for a board."""
pass
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `board_images` junction table."""
# Create the `board_images` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS board_images (
board_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between boards and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
)
# Add index for board id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
"""
)
# Add index for board id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
AFTER UPDATE
ON board_images FOR EACH ROW
BEGIN
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id AND image_name = old.image_name;
END;
"""
)
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
VALUES (?, ?)
ON CONFLICT (image_name) DO UPDATE SET board_id = ?;
""",
(board_id, image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM board_images
WHERE board_id = ? AND image_name = ?;
""",
(board_id, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_images_for_board(
self,
board_id: str,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
# TODO: this isn't paginated yet?
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM board_images WHERE board_id = ?;
""",
(board_id,),
)
count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(
items=images, offset=offset, limit=limit, total=count
)
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = self._cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_image_count_for_board(self, board_id: str) -> int:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM board_images WHERE board_id = ?;
""",
(board_id,),
)
count = cast(int, self._cursor.fetchone()[0])
return count
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
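A minimal usage sketch, assuming the database already holds the referenced boards and images rows (both foreign keys are enforced); identifiers are illustrative:

storage = SqliteBoardImageRecordStorage("invokeai.db")

storage.add_image_to_board(board_id="board-a", image_name="00042.abcdef12.png")
# ON CONFLICT (image_name) turns a second add into a move between boards
storage.add_image_to_board(board_id="board-b", image_name="00042.abcdef12.png")

assert storage.get_board_for_image("00042.abcdef12.png") == "board-b"
storage.remove_image_from_board("board-b", "00042.abcdef12.png")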

View File

@@ -0,0 +1,142 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import List, Union
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import (
BoardRecord,
BoardRecordStorageBase,
)
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
from invokeai.app.services.urls import UrlServiceBase
class BoardImagesServiceABC(ABC):
"""High-level service for board-image relationship management."""
@abstractmethod
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
pass
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
pass
@abstractmethod
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets images for a board."""
pass
@abstractmethod
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
"""Gets an image's board id, if it has one."""
pass
class BoardImagesServiceDependencies:
"""Service dependencies for the BoardImagesService."""
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
class BoardImagesService(BoardImagesServiceABC):
_services: BoardImagesServiceDependencies
def __init__(self, services: BoardImagesServiceDependencies):
self._services = services
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.add_image_to_board(board_id, image_name)
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.remove_image_from_board(board_id, image_name)
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageDTO]:
image_records = self._services.board_image_records.get_images_for_board(
board_id
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
board_id,
),
image_records.items,
)
)
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,
offset=image_records.offset,
limit=image_records.limit,
total=image_records.total,
)
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
board_id = self._services.board_image_records.get_board_for_image(image_name)
return board_id
def board_record_to_dto(
board_record: BoardRecord, cover_image_name: str | None, image_count: int
) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
**board_record.dict(exclude={'cover_image_name'}),
cover_image_name=cover_image_name,
image_count=image_count,
)

View File

@@ -0,0 +1,329 @@
from abc import ABC, abstractmethod
from typing import Optional, cast
import sqlite3
import threading
from typing import Optional, Union
import uuid
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import (
BoardRecord,
deserialize_board_record,
)
from pydantic import BaseModel, Field, Extra
class BoardChanges(BaseModel, extra=Extra.forbid):
board_name: Optional[str] = Field(description="The board's new name.")
cover_image_name: Optional[str] = Field(
description="The name of the board's new cover image."
)
class BoardRecordNotFoundException(Exception):
"""Raised when an board record is not found."""
def __init__(self, message="Board record not found"):
super().__init__(message)
class BoardRecordSaveException(Exception):
"""Raised when an board record cannot be saved."""
def __init__(self, message="Board record not saved"):
super().__init__(message)
class BoardRecordDeleteException(Exception):
"""Raised when an board record cannot be deleted."""
def __init__(self, message="Board record not deleted"):
super().__init__(message)
class BoardRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the board record store."""
@abstractmethod
def delete(self, board_id: str) -> None:
"""Deletes a board record."""
pass
@abstractmethod
def save(
self,
board_name: str,
) -> BoardRecord:
"""Saves a board record."""
pass
@abstractmethod
def get(
self,
board_id: str,
) -> BoardRecord:
"""Gets a board record."""
pass
@abstractmethod
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
"""Updates a board record."""
pass
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records."""
pass
@abstractmethod
def get_all(
self,
) -> list[BoardRecord]:
"""Gets all board records."""
pass
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `boards` table and `board_images` junction table."""
# Create the `boards` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS boards (
board_id TEXT NOT NULL PRIMARY KEY,
board_name TEXT NOT NULL,
cover_image_name TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
UPDATE boards SET updated_at = current_timestamp
WHERE board_id = old.board_id;
END;
"""
)
def delete(self, board_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
except Exception as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
board_name: str,
) -> BoardRecord:
try:
board_id = str(uuid.uuid4())
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return self.get(board_id)
def get(
self,
board_id: str,
) -> BoardRecord:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
try:
self._lock.acquire()
# Change the name of a board
if changes.board_name is not None:
self._cursor.execute(
f"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
self._cursor.execute(
f"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return self.get(board_id)
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()
# Get all the boards
self._cursor.execute(
"""--sql
SELECT *
FROM boards
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
""",
(limit, offset),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
# Get the total number of boards
self._cursor.execute(
"""--sql
SELECT COUNT(*)
FROM boards
WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](
items=boards, offset=offset, limit=limit, total=count
)
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_all(
self,
) -> list[BoardRecord]:
try:
self._lock.acquire()
# Get all the boards
self._cursor.execute(
"""--sql
SELECT *
FROM boards
ORDER BY created_at DESC
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
return boards
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
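Sketch of the record-level API (board name illustrative):

records = SqliteBoardRecordStorage("invokeai.db")

board = records.save("Landscapes")                 # board_id is a freshly generated uuid4
records.update(board.board_id, BoardChanges(board_name="Landscapes v2", cover_image_name=None))

page = records.get_many(offset=0, limit=10)        # OffsetPaginatedResults[BoardRecord]
print(page.total, [b.board_name for b in page.items])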

View File

@@ -0,0 +1,185 @@
from abc import ABC, abstractmethod
from logging import Logger
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_images import board_record_to_dto
from invokeai.app.services.board_record_storage import (
BoardChanges,
BoardRecordStorageBase,
)
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.urls import UrlServiceBase
class BoardServiceABC(ABC):
"""High-level service for board management."""
@abstractmethod
def create(
self,
board_name: str,
) -> BoardDTO:
"""Creates a board."""
pass
@abstractmethod
def get_dto(
self,
board_id: str,
) -> BoardDTO:
"""Gets a board."""
pass
@abstractmethod
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardDTO:
"""Updates a board."""
pass
@abstractmethod
def delete(
self,
board_id: str,
) -> None:
"""Deletes a board."""
pass
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many boards."""
pass
@abstractmethod
def get_all(
self,
) -> list[BoardDTO]:
"""Gets all boards."""
pass
class BoardServiceDependencies:
"""Service dependencies for the BoardService."""
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
class BoardService(BoardServiceABC):
_services: BoardServiceDependencies
def __init__(self, services: BoardServiceDependencies):
self._services = services
def create(
self,
board_name: str,
) -> BoardDTO:
board_record = self._services.board_records.save(board_name)
return board_record_to_dto(board_record, None, 0)
def get_dto(self, board_id: str) -> BoardDTO:
board_record = self._services.board_records.get(board_id)
cover_image = self._services.image_records.get_most_recent_image_for_board(
board_record.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
board_id
)
return board_record_to_dto(board_record, cover_image_name, image_count)
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardDTO:
board_record = self._services.board_records.update(board_id, changes)
cover_image = self._services.image_records.get_most_recent_image_for_board(
board_record.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
board_id
)
return board_record_to_dto(board_record, cover_image_name, image_count)
def delete(self, board_id: str) -> None:
self._services.board_records.delete(board_id)
def get_many(
self, offset: int = 0, limit: int = 10
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self._services.board_records.get_many(offset, limit)
board_dtos = []
for r in board_records.items:
cover_image = self._services.image_records.get_most_recent_image_for_board(
r.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return OffsetPaginatedResults[BoardDTO](
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
)
def get_all(self) -> list[BoardDTO]:
board_records = self._services.board_records.get_all()
board_dtos = []
for r in board_records:
cover_image = self._services.image_records.get_most_recent_image_for_board(
r.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return board_dtos

View File

@@ -15,10 +15,7 @@ InvokeAI:
conf_path: configs/models.yaml
legacy_conf_dir: configs/stable-diffusion
outdir: outputs
embedding_dir: embeddings
lora_dir: loras
autoconvert_dir: null
gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth
Models:
model: stable-diffusion-1.5
embeddings: true
@@ -171,7 +168,7 @@ from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path('invokeai.yaml')
DB_FILE = Path('invokeai.db')
@@ -374,24 +371,20 @@ setting environment variables INVOKEAI_<setting>.
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI textual inversion embeddings directory', category='Paths')
gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.', category='Paths')
controlnet_dir : Path = Field(default="controlnets", description='Path to directory of ControlNet models.', category='Paths')
models_dir : Path = Field(default='./models', description='Path to the models directory', category='Paths')
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
# note - would be better to read the log_format values from logging.py, but this creates circular dependency issues
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
@@ -492,39 +485,11 @@ setting environment variables INVOKEAI_<setting>.
return self._resolve(self.legacy_conf_dir)
@property
def cache_dir(self)->Path:
'''
Path to the global cache directory for HuggingFace hub-managed models
'''
return self.models_dir / "hub"
@property
def models_dir(self)->Path:
def models_path(self)->Path:
'''
Path to the models directory
'''
return self._resolve("models")
@property
def embedding_path(self)->Path:
'''
Path to the textual inversion embeddings directory.
'''
return self._resolve(self.embedding_dir) if self.embedding_dir else None
@property
def lora_path(self)->Path:
'''
Path to the LoRA models directory.
'''
return self._resolve(self.lora_dir) if self.lora_dir else None
@property
def controlnet_path(self)->Path:
'''
Path to the controlnet models directory.
'''
return self._resolve(self.controlnet_dir) if self.controlnet_dir else None
return self._resolve(self.models_dir)
@property
def autoconvert_path(self)->Path:
@@ -533,13 +498,6 @@ setting environment variables INVOKEAI_<setting>.
'''
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
@property
def gfpgan_model_path(self)->Path:
'''
Path to the GFPGAN model.
'''
return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None
# the following methods support legacy calls leftover from the Globals era
@property
def full_precision(self)->bool:
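The net effect of this hunk is that the per-type directories (embeddings, loras, controlnets, GFPGAN) disappear from the Paths category and everything hangs off a single models tree. A rough sketch of reading the new layout, assuming InvokeAIAppConfig can be instantiated with its defaults like any other pydantic BaseSettings (the absolute module path is inferred from the relative import in model_manager_service.py below):

from invokeai.app.services.config import InvokeAIAppConfig  # import path assumed

config = InvokeAIAppConfig()      # instantiation with defaults assumed
print(config.models_path)         # resolved <root>/models (replaces the old models_dir property)
print(config.cache_dir)           # HuggingFace hub cache, now under the models tree
# embedding_path, lora_path, controlnet_path and gfpgan_model_path are gone;
# those model types are expected to live under models_path instead.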

View File

@@ -3,7 +3,8 @@
from typing import Any
from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
from invokeai.app.models.exceptions import CanceledException
class EventServiceBase:
session_event: str = "session_event"
@@ -101,3 +102,53 @@ class EventServiceBase:
graph_execution_state_id=graph_execution_state_id,
),
)
def emit_model_load_started(
self,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
) -> None:
"""Emitted when a model is requested"""
self.__emit_session_event(
event_name="model_load_started",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
),
)
def emit_model_load_completed(
self,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
model_info: ModelInfo,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_session_event(
event_name="model_load_completed",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info,
),
)
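The two new events are meant to bracket a model load: started just before the manager is asked for the model, completed once a ModelInfo comes back. A sketch of that pattern, with the enum members assumed for illustration (the concrete call site is ModelManagerService._emit_load_event, further down in this compare):

from invokeai.app.services.model_manager_service import (
    BaseModelType, ModelType, SubModelType,
)

def load_with_events(events, mgr, state_id, node, source_node_id):
    # Enum member names below are assumed; the payload shape matches the
    # emitters added above.
    common = dict(
        graph_execution_state_id=state_id,
        node=node.dict(),
        source_node_id=source_node_id,
        model_name="stable-diffusion-1.5",
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
        submodel=SubModelType.UNet,
    )
    events.emit_model_load_started(**common)
    model_info = mgr.get_model(
        common["model_name"], common["base_model"],
        common["model_type"], common["submodel"],
    )
    events.emit_model_load_completed(**common, model_info=model_info)
    return model_info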

View File

@@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
@@ -40,14 +39,12 @@ class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get(self, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the internal path to an image or thumbnail."""
pass
@@ -62,7 +59,6 @@ class ImageFileStorageBase(ABC):
def save(
self,
image: PILImageType,
image_origin: ResourceOrigin,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
@@ -71,7 +67,7 @@ class ImageFileStorageBase(ABC):
pass
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
@@ -79,31 +75,26 @@ class ImageFileStorageBase(ABC):
class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk"""
__output_folder: str
__output_folder: Path
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, PILImageType]
__cache: Dict[Path, PILImageType]
__max_cache_size: int
def __init__(self, output_folder: str):
self.__output_folder = output_folder
def __init__(self, output_folder: str | Path):
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
Path(output_folder).mkdir(parents=True, exist_ok=True)
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__thumbnails_folder = self.__output_folder / 'thumbnails'
# TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_origin in ResourceOrigin:
Path(os.path.join(output_folder, image_origin)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
# Validate required output folders at launch
self.__validate_storage_folders()
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get(self, image_name: str) -> PILImageType:
try:
image_path = self.get_path(image_origin, image_name)
image_path = self.get_path(image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
@@ -117,13 +108,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
def save(
self,
image: PILImageType,
image_origin: ResourceOrigin,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
) -> None:
try:
image_path = self.get_path(image_origin, image_name)
self.__validate_storage_folders()
image_path = self.get_path(image_name)
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
@@ -133,7 +124,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path)
@@ -142,20 +133,19 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e:
raise ImageFileSaveException from e
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
try:
basename = os.path.basename(image_name)
image_path = self.get_path(image_origin, basename)
image_path = self.get_path(image_name)
if os.path.exists(image_path):
if image_path.exists():
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
thumbnail_path = self.get_path(thumbnail_name, True)
if os.path.exists(thumbnail_path):
if thumbnail_path.exists():
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
@@ -163,41 +153,33 @@ class DiskImageFileStorage(ImageFileStorageBase):
raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
path = self.__output_folder / image_name
if thumbnail:
thumbnail_name = get_thumbnail_name(basename)
path = os.path.join(
self.__output_folder, image_origin, "thumbnails", thumbnail_name
)
else:
path = os.path.join(self.__output_folder, image_origin, basename)
thumbnail_name = get_thumbnail_name(image_name)
path = self.__thumbnails_folder / thumbnail_name
abspath = os.path.abspath(path)
return path
return abspath
def validate_path(self, path: str) -> bool:
def validate_path(self, path: str | Path) -> bool:
"""Validates the path given for an image or thumbnail."""
try:
os.stat(path)
return True
except:
return False
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def __validate_storage_folders(self) -> None:
"""Checks if the required output folders exist and create them if they don't"""
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]
for folder in folders:
folder.mkdir(parents=True, exist_ok=True)
def __get_cache(self, image_name: str) -> PILImageType | None:
def __get_cache(self, image_name: Path) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: PILImageType):
def __set_cache(self, image_name: Path, image: PILImageType):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:
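With the ResourceOrigin parameter gone, DiskImageFileStorage keys everything on the bare image name under one output folder plus a thumbnails subfolder, and re-validates those folders on each save. A usage sketch; the output folder value and import path are assumptions:

from PIL import Image
from invokeai.app.services.image_file_storage import DiskImageFileStorage  # import path assumed

storage = DiskImageFileStorage(output_folder="outputs/images")
img = Image.new("RGB", (64, 64))
storage.save(image=img, image_name="abc123.png")        # also writes the thumbnail
print(storage.get_path("abc123.png"))                   # outputs/images/abc123.png (Path)
print(storage.get_path("abc123.png", thumbnail=True))   # outputs/images/thumbnails/<thumbnail name>
storage.delete("abc123.png")                            # trashes image, thumbnail, and cache entries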

View File

@@ -21,6 +21,7 @@ from invokeai.app.services.models.image_record import (
T = TypeVar("T", bound=BaseModel)
class OffsetPaginatedResults(GenericModel, Generic[T]):
"""Offset-paginated results"""
@@ -60,7 +61,7 @@ class ImageRecordStorageBase(ABC):
# TODO: Implement an `update()` method
@abstractmethod
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
def get(self, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@@ -68,7 +69,6 @@ class ImageRecordStorageBase(ABC):
def update(
self,
image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges,
) -> None:
"""Updates an image record."""
@@ -82,6 +82,7 @@ class ImageRecordStorageBase(ABC):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass
@@ -89,7 +90,7 @@ class ImageRecordStorageBase(ABC):
# TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
"""Deletes an image record."""
pass
@@ -109,6 +110,11 @@ class ImageRecordStorageBase(ABC):
"""Saves an image record."""
pass
@abstractmethod
def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None:
"""Gets the most recent image for a board."""
pass
class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str
@@ -135,7 +141,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._lock.release()
def _create_tables(self) -> None:
"""Creates the tables for the `images` database."""
"""Creates the `images` table."""
# Create the `images` table.
self._cursor.execute(
@@ -152,6 +158,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
board_id TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
@@ -190,15 +197,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = current_timestamp
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
)
def get(
self, image_origin: ResourceOrigin, image_name: str
) -> Union[ImageRecord, None]:
def get(self, image_name: str) -> Union[ImageRecord, None]:
try:
self._lock.acquire()
@@ -225,7 +230,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def update(
self,
image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges,
) -> None:
try:
@@ -262,6 +266,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""",
(changes.is_intermediate, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
@@ -276,40 +281,66 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
images_query = """--sql
SELECT images.*
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params = []
if image_origin is not None:
query_conditions += f"""AND image_origin = ?\n"""
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
## Convert the enum values to unique list of strings
category_strings = list(
map(lambda c: c.value, set(categories))
)
# Convert the enum values to unique list of strings
category_strings = list(map(lambda c: c.value, set(categories)))
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"AND image_category IN ( {placeholders} )\n"
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += f"""AND is_intermediate = ?\n"""
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
if board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
query_pagination = """--sql
ORDER BY images.created_at DESC LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
@@ -326,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
count_query += query_conditions + ";"
count_params = query_params.copy()
self._cursor.execute(count_query, count_params)
count = self._cursor.fetchone()[0]
count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
@@ -337,7 +368,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
items=images, offset=offset, limit=limit, total=count
)
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
@@ -417,3 +448,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_most_recent_image_for_board(
self, board_id: str
) -> Union[ImageRecord, None]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
ORDER BY images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
finally:
self._lock.release()
if result is None:
return None
return deserialize_image_record(dict(result))
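get_many now left-joins board_images, so a page of records can be filtered to a single board while the count query reuses the same conditions. A call sketch, given an instance of SqliteImageRecordStorage (here image_record_storage); the ImageCategory member and import path are assumptions:

from invokeai.app.models.image import ImageCategory  # import path assumed

results = image_record_storage.get_many(
    offset=0,
    limit=20,
    categories=[ImageCategory.GENERAL],   # enum member assumed
    is_intermediate=False,
    board_id="some-board-id",
)
for record in results.items:
    print(record.image_name, record.created_at)
print(results.total)   # total matching rows, not just this page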

View File

@@ -10,6 +10,7 @@ from invokeai.app.models.image import (
InvalidOriginException,
)
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException,
ImageRecordNotFoundException,
@@ -49,7 +50,7 @@ class ImageServiceABC(ABC):
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
intermediate: bool = False,
is_intermediate: bool = False,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@@ -57,7 +58,6 @@ class ImageServiceABC(ABC):
@abstractmethod
def update(
self,
image_origin: ResourceOrigin,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
@@ -65,22 +65,22 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get_pil_image(self, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
pass
@abstractmethod
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
def get_record(self, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
def get_dto(self, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's path."""
pass
@@ -90,9 +90,7 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's or thumbnail's URL."""
pass
@@ -104,12 +102,13 @@ class ImageServiceABC(ABC):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str):
def delete(self, image_name: str):
"""Deletes an image."""
pass
@@ -117,8 +116,9 @@ class ImageServiceABC(ABC):
class ImageServiceDependencies:
"""Service dependencies for the ImageService."""
records: ImageRecordStorageBase
files: ImageFileStorageBase
image_records: ImageRecordStorageBase
image_files: ImageFileStorageBase
board_image_records: BoardImageRecordStorageBase
metadata: MetadataServiceBase
urls: UrlServiceBase
logger: Logger
@@ -129,14 +129,16 @@ class ImageServiceDependencies:
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
board_image_record_storage: BoardImageRecordStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self.records = image_record_storage
self.files = image_file_storage
self.image_records = image_record_storage
self.image_files = image_file_storage
self.board_image_records = board_image_record_storage
self.metadata = metadata
self.urls = url
self.logger = logger
@@ -147,25 +149,8 @@ class ImageServiceDependencies:
class ImageService(ImageServiceABC):
_services: ImageServiceDependencies
def __init__(
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self._services = ImageServiceDependencies(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=url,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
def __init__(self, services: ImageServiceDependencies):
self._services = services
def create(
self,
@@ -190,7 +175,7 @@ class ImageService(ImageServiceABC):
try:
# TODO: Consider using a transaction here to ensure consistency between storage and database
created_at = self._services.records.save(
self._services.image_records.save(
# Non-nullable fields
image_name=image_name,
image_origin=image_origin,
@@ -205,38 +190,15 @@ class ImageService(ImageServiceABC):
metadata=metadata,
)
self._services.files.save(
image_origin=image_origin,
self._services.image_files.save(
image_name=image_name,
image=image,
metadata=metadata,
)
image_url = self._services.urls.get_image_url(image_origin, image_name)
thumbnail_url = self._services.urls.get_image_url(
image_origin, image_name, True
)
image_dto = self.get_dto(image_name)
return ImageDTO(
# Non-nullable fields
image_name=image_name,
image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
# Nullable fields
node_id=node_id,
session_id=session_id,
metadata=metadata,
# Meta fields
created_at=created_at,
updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None,
is_intermediate=is_intermediate,
# Extra non-nullable fields for DTO
image_url=image_url,
thumbnail_url=thumbnail_url,
)
return image_dto
except ImageRecordSaveException:
self._services.logger.error("Failed to save image record")
raise
@@ -249,13 +211,12 @@ class ImageService(ImageServiceABC):
def update(
self,
image_origin: ResourceOrigin,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
try:
self._services.records.update(image_name, image_origin, changes)
return self.get_dto(image_origin, image_name)
self._services.image_records.update(image_name, changes)
return self.get_dto(image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
@@ -263,9 +224,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem updating image record")
raise e
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get_pil_image(self, image_name: str) -> PILImageType:
try:
return self._services.files.get(image_origin, image_name)
return self._services.image_files.get(image_name)
except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
@@ -273,9 +234,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image file")
raise e
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
def get_record(self, image_name: str) -> ImageRecord:
try:
return self._services.records.get(image_origin, image_name)
return self._services.image_records.get(image_name)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
@@ -283,14 +244,15 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image record")
raise e
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
def get_dto(self, image_name: str) -> ImageDTO:
try:
image_record = self._services.records.get(image_origin, image_name)
image_record = self._services.image_records.get(image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_origin, image_name),
self._services.urls.get_image_url(image_origin, image_name, True),
self._services.urls.get_image_url(image_name),
self._services.urls.get_image_url(image_name, True),
self._services.board_image_records.get_board_for_image(image_name),
)
return image_dto
@@ -301,27 +263,23 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image DTO")
raise e
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
return self._services.files.get_path(image_origin, image_name, thumbnail)
return self._services.image_files.get_path(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def validate_path(self, path: str) -> bool:
try:
return self._services.files.validate_path(path)
return self._services.image_files.validate_path(path)
except Exception as e:
self._services.logger.error("Problem validating image path")
raise e
def get_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
try:
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
return self._services.urls.get_image_url(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
@@ -333,23 +291,26 @@ class ImageService(ImageServiceABC):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self._services.records.get_many(
results = self._services.image_records.get_many(
offset,
limit,
image_origin,
categories,
is_intermediate,
board_id,
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(r.image_origin, r.image_name),
self._services.urls.get_image_url(
r.image_origin, r.image_name, True
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
self._services.board_image_records.get_board_for_image(
r.image_name
),
),
results.items,
@@ -366,10 +327,10 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting paginated image DTOs")
raise e
def delete(self, image_origin: ResourceOrigin, image_name: str):
def delete(self, image_name: str):
try:
self._services.files.delete(image_origin, image_name)
self._services.records.delete(image_origin, image_name)
self._services.image_files.delete(image_name)
self._services.image_records.delete(image_name)
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise
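ImageService is now constructed from a single ImageServiceDependencies bundle instead of a long positional list, and every public method takes only the image name. A wiring sketch; the storage constructors' exact arguments, the import paths, and the surrounding collaborators are assumptions:

from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage  # import paths assumed
from invokeai.app.services.image_file_storage import DiskImageFileStorage
from invokeai.app.services.urls import LocalUrlService

def build_image_service(db_path, output_folder, board_image_records,
                        metadata_service, logger, names, graph_execution_manager):
    deps = ImageServiceDependencies(
        image_record_storage=SqliteImageRecordStorage(db_path),   # constructor arg assumed
        image_file_storage=DiskImageFileStorage(output_folder),
        board_image_record_storage=board_image_records,
        metadata=metadata_service,
        url=LocalUrlService(),
        logger=logger,
        names=names,
        graph_execution_manager=graph_execution_manager,
    )
    return ImageService(services=deps)

# Afterwards: images.get_dto("abc123.png") -- no ResourceOrigin anywhere in the public API.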

View File

@@ -4,7 +4,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from logging import Logger
from invokeai.app.services.images import ImageService
from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC
from invokeai.backend import ModelManager
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase
@@ -26,9 +28,9 @@ class InvocationServices:
model_manager: "ModelManager"
restoration: "RestorationServices"
configuration: "InvokeAISettings"
images: "ImageService"
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
images: "ImageServiceABC"
boards: "BoardServiceABC"
board_images: "BoardImagesServiceABC"
graph_library: "ItemStorageABC"["LibraryGraph"]
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
processor: "InvocationProcessorABC"
@@ -39,7 +41,9 @@ class InvocationServices:
events: "EventServiceBase",
logger: "Logger",
latents: "LatentsStorageBase",
images: "ImageService",
images: "ImageServiceABC",
boards: "BoardServiceABC",
board_images: "BoardImagesServiceABC",
queue: "InvocationQueueABC",
graph_library: "ItemStorageABC"["LibraryGraph"],
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
@@ -52,9 +56,12 @@ class InvocationServices:
self.logger = logger
self.latents = latents
self.images = images
self.boards = boards
self.board_images = board_images
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration
self.configuration = configuration

View File

@@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
@@ -70,24 +69,26 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching"""
__output_folder: str
__output_folder: str | Path
def __init__(self, output_folder: str):
self.__output_folder = output_folder
Path(output_folder).mkdir(parents=True, exist_ok=True)
def __init__(self, output_folder: str | Path):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True)
def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)
def save(self, name: str, data: torch.Tensor) -> None:
self.__output_folder.mkdir(parents=True, exist_ok=True)
latent_path = self.get_path(name)
torch.save(data, latent_path)
def delete(self, name: str) -> None:
latent_path = self.get_path(name)
os.remove(latent_path)
latent_path.unlink()
def get_path(self, name: str) -> str:
return os.path.join(self.__output_folder, name)
def get_path(self, name: str) -> Path:
return self.__output_folder / name
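DiskLatentsStorage now works purely in pathlib terms and re-creates its folder before each save. A short sketch; the folder value is illustrative:

import torch
from invokeai.app.services.latent_storage import DiskLatentsStorage

latents = DiskLatentsStorage("outputs/latents")        # str or Path both accepted
latents.save("session-1-noise", torch.randn(1, 4, 64, 64))
tensor = latents.get("session-1-noise")                # torch.load from <folder>/<name>
latents.delete("session-1-noise")                      # Path.unlink instead of os.remove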

View File

@@ -1,104 +0,0 @@
import os
import sys
import torch
from argparse import Namespace
from omegaconf import OmegaConf
from pathlib import Path
from typing import types
import invokeai.version
from .config import InvokeAISettings
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device
# TODO: Replace with an abstract class base ModelManagerBase
def get_model_manager(config: InvokeAISettings, logger: types.ModuleType) -> ModelManager:
model_config = config.model_conf_path
if not model_config.exists():
report_model_error(
config, FileNotFoundError(f"The file {model_config} could not be found."), logger
)
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
logger.info(f'InvokeAI runtime directory is "{config.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers # type: ignore
transformers.logging.set_verbosity_error()
import diffusers
diffusers.logging.set_verbosity_error()
embedding_path = config.embedding_path
# migrate legacy models
ModelManager.migrate_models()
# creating the model manager
try:
device = torch.device(choose_torch_device())
precision = 'float16' if config.precision=='float16' \
else 'float32' if config.precision=='float32' \
else choose_precision(device)
model_manager = ModelManager(
OmegaConf.load(config.model_conf_path),
precision=precision,
device_type=device,
max_loaded_models=config.max_loaded_models,
embedding_path = embedding_path,
logger = logger,
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(config, e, logger)
except (IOError, KeyError) as e:
logger.error(f"{e}. Aborting.")
sys.exit(-1)
# try to autoconvert new models
# autoimport new .ckpt files
if config.autoconvert_path:
model_manager.heuristic_import(
config.autoconvert_path,
)
return model_manager
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
logger.error(
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
)
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
if yes_to_all:
logger.warning(
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
)
else:
response = input(
"Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
)
if response.startswith(("n", "N")):
return
logger.info("invokeai-configure is launching....\n")
# Match arguments that were set on the CLI
# only the arguments accepted by the configuration script are parsed
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
config = ["--config", opt.conf] if opt.conf is not None else []
sys.argv = ["invokeai-configure"]
sys.argv.extend(root_dir)
sys.argv.extend(config.to_dict())
if yes_to_all is not None:
for arg in yes_to_all.split():
sys.argv.append(arg)
from invokeai.frontend.install import invokeai_configure
invokeai_configure()
# TODO: Figure out how to restart
# print('** InvokeAI will now restart')
# sys.argv = previous_args
# main() # would rather do a os.exec(), but doesn't exist?
# sys.exit(0)

View File

@@ -0,0 +1,363 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from __future__ import annotations
import torch
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
from dataclasses import dataclass
from invokeai.backend.model_management.model_manager import (
ModelManager,
BaseModelType,
ModelType,
SubModelType,
ModelInfo,
)
from invokeai.app.models.exceptions import CanceledException
from .config import InvokeAIAppConfig
from ...backend.util import choose_precision, choose_torch_device
if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
class ModelManagerServiceBase(ABC):
"""Responsible for managing models on disk and in memory"""
@abstractmethod
def __init__(
self,
config: InvokeAIAppConfig,
logger: types.ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
pass
@abstractmethod
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae)
of a diffusers pipeline."""
pass
@property
@abstractmethod
def logger(self):
pass
@abstractmethod
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
pass
@abstractmethod
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name, returns a dict-like (OmegaConf) object describing it.
"""
pass
@abstractmethod
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
pass
@abstractmethod
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
"""
Return a dict of models in the format:
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
"""
pass
@abstractmethod
def add_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False
) -> None:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass
@abstractmethod
def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
"""
pass
@abstractmethod
def commit(self, conf_file: Path = None) -> None:
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
original file/database used to initialize the object.
"""
pass
# simple implementation
class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory"""
def __init__(
self,
config: InvokeAIAppConfig,
logger: types.ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
if config.model_conf_path and config.model_conf_path.exists():
config_file = config.model_conf_path
else:
config_file = config.root_dir / "configs/models.yaml"
if not config_file.exists():
raise IOError(f"The file {config_file} could not be found.")
logger.debug(f'config file={config_file}')
device = torch.device(choose_torch_device())
precision = config.precision
if precision == "auto":
precision = choose_precision(device)
dtype = torch.float32 if precision == 'float32' else torch.float16
# this is transitional backward compatibility
# support for the deprecated `max_loaded_models`
# configuration value. If present, then the
# cache size is set to 2.5 GB times
# the number of max_loaded_models. Otherwise
# use new `max_cache_size` config setting
max_cache_size = config.max_cache_size \
if hasattr(config,'max_cache_size') \
else config.max_loaded_models * 2.5
sequential_offload = config.sequential_guidance
self.mgr = ModelManager(
config=config_file,
device_type=device,
precision=dtype,
max_cache_size=max_cache_size,
sequential_offload=sequential_offload,
logger=logger,
)
logger.info('Model manager service initialized')
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""
Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers model.
"""
# if we are called from within a node, then we get to emit
# load start and complete events
if node and context:
self._emit_load_event(
node=node,
context=context,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)
model_info = self.mgr.get_model(
model_name,
base_model,
model_type,
submodel,
)
if node and context:
self._emit_load_event(
node=node,
context=context,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info
)
return model_info
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
"""
Given a model name, returns True if it is a valid
identifier.
"""
return self.mgr.model_exists(
model_name,
base_model,
model_type,
)
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name, returns a dict-like (OmegaConf) object describing it.
"""
return self.mgr.model_info(model_name, base_model, model_type)
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
return self.mgr.model_names()
def list_models(
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None
) -> list[dict]:
# ) -> dict:
"""
Return a list of models.
"""
return self.mgr.list_models(base_model, model_type)
def add_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
)->None:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
"""
self.mgr.del_model(model_name, base_model, model_type)
def commit(self, conf_file: Optional[Path]=None):
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
original file/database used to initialize the object.
"""
return self.mgr.commit(conf_file)
def _emit_load_event(
self,
node,
context,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
model_info: Optional[ModelInfo] = None,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[node.id]
if model_info:
context.services.events.emit_model_load_completed(
graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info
)
else:
context.services.events.emit_model_load_started(
graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)
@property
def logger(self):
return self.mgr.logger
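From an invocation's point of view the new service is reached through the services container; passing the node and context makes get_model emit the load events defined earlier. A sketch of that flow, with enum members assumed for illustration:

from invokeai.app.services.model_manager_service import (
    BaseModelType, ModelType, SubModelType,
)

def get_unet(self, context):   # imagined body of an invocation method
    mm = context.services.model_manager
    if not mm.model_exists("stable-diffusion-1.5",
                           BaseModelType.StableDiffusion1, ModelType.Main):
        raise ValueError("model not installed")
    info = mm.get_model(
        model_name="stable-diffusion-1.5",
        base_model=BaseModelType.StableDiffusion1,   # enum members assumed
        model_type=ModelType.Main,
        submodel=SubModelType.UNet,
        node=self,        # node + context => model_load_started/completed events
        context=context,
    )
    with info.context as unet:   # ModelInfo.context is a context manager (see the generator diff below)
        ...  # use the loaded submodel while it is held in the cache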

View File

@@ -0,0 +1,62 @@
from typing import Optional, Union
from datetime import datetime
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.util.misc import get_iso_timestamp
class BoardRecord(BaseModel):
"""Deserialized board record."""
board_id: str = Field(description="The unique ID of the board.")
"""The unique ID of the board."""
board_name: str = Field(description="The name of the board.")
"""The name of the board."""
created_at: Union[datetime, str] = Field(
description="The created timestamp of the board."
)
"""The created timestamp of the image."""
updated_at: Union[datetime, str] = Field(
description="The updated timestamp of the board."
)
"""The updated timestamp of the image."""
deleted_at: Union[datetime, str, None] = Field(
description="The deleted timestamp of the board."
)
"""The updated timestamp of the image."""
cover_image_name: Optional[str] = Field(
description="The name of the cover image of the board."
)
"""The name of the cover image of the board."""
class BoardDTO(BoardRecord):
"""Deserialized board record with cover image URL and image count."""
cover_image_name: Optional[str] = Field(
description="The name of the board's cover image."
)
"""The URL of the thumbnail of the most recent image in the board."""
image_count: int = Field(description="The number of images in the board.")
"""The number of images in the board."""
def deserialize_board_record(board_dict: dict) -> BoardRecord:
"""Deserializes a board record."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
board_id = board_dict.get("board_id", "unknown")
board_name = board_dict.get("board_name", "unknown")
cover_image_name = board_dict.get("cover_image_name", "unknown")
created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
return BoardRecord(
board_id=board_id,
board_name=board_name,
cover_image_name=cover_image_name,
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
)
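deserialize_board_record tolerates missing keys by substituting defaults ("unknown" names, current timestamps). A quick round-trip sketch on a row dict:

row = {
    "board_id": "b-1",
    "board_name": "portraits",
    "created_at": "2023-06-26 12:00:00",
    "updated_at": "2023-06-26 12:00:00",
}
record = deserialize_board_record(row)
print(record.board_name)        # "portraits"
print(record.cover_image_name)  # "unknown" (default when the column is absent)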

View File

@@ -79,8 +79,6 @@ class ImageUrlsDTO(BaseModel):
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The origin of the image."""
image_url: str = Field(description="The URL of the image.")
"""The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
@@ -88,19 +86,24 @@ class ImageUrlsDTO(BaseModel):
class ImageDTO(ImageRecord, ImageUrlsDTO):
"""Deserialized image record, enriched for the frontend with URLs."""
"""Deserialized image record, enriched for the frontend."""
board_id: Union[str, None] = Field(
description="The id of the board the image belongs to, if one exists."
)
"""The id of the board the image belongs to, if one exists."""
pass
def image_record_to_dto(
image_record: ImageRecord, image_url: str, thumbnail_url: str
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None]
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
**image_record.dict(),
image_url=image_url,
thumbnail_url=thumbnail_url,
board_id=board_id,
)

View File

@@ -16,13 +16,14 @@ class RestorationServices:
gfpgan, codeformer, esrgan = None, None, None
if args.restore or args.esrgan:
restoration = Restoration()
if args.restore:
# TODO: redo for new model structure
if False and args.restore:
gfpgan, codeformer = restoration.load_face_restore_models(
args.gfpgan_model_path
)
else:
logger.info("Face restoration disabled")
if args.esrgan:
if False and args.esrgan:
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
else:
logger.info("Upscaling disabled")

View File

@@ -1,17 +1,12 @@
import os
from abc import ABC, abstractmethod
from invokeai.app.models.image import ResourceOrigin
from invokeai.app.util.thumbnails import get_thumbnail_name
class UrlServiceBase(ABC):
"""Responsible for building URLs for resources."""
@abstractmethod
def get_image_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the URL for an image or thumbnail."""
pass
@@ -20,15 +15,11 @@ class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"):
self._base_url = base_url
def get_image_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
image_basename = os.path.basename(image_name)
# These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail:
return (
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
)
return f"{self._base_url}/images/{image_basename}/thumbnail"
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"
return f"{self._base_url}/images/{image_basename}"

View File

@@ -5,9 +5,11 @@ from .generator import (
InvokeAIGeneratorBasicParams,
InvokeAIGenerator,
InvokeAIGeneratorOutput,
Txt2Img,
Img2Img,
Inpaint
)
from .model_management import ModelManager, SDModelComponent
from .model_management import (
ModelManager, ModelCache, BaseModelType,
ModelType, SubModelType, ModelInfo
)
from .safety_checker import SafetyChecker

View File

@@ -5,7 +5,6 @@ from .base import (
InvokeAIGenerator,
InvokeAIGeneratorBasicParams,
InvokeAIGeneratorOutput,
Txt2Img,
Img2Img,
Inpaint,
Generator,

View File

@@ -29,7 +29,6 @@ import invokeai.backend.util.logging as logger
from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..stable_diffusion.schedulers import SCHEDULER_MAP
@@ -81,13 +80,15 @@ class InvokeAIGenerator(metaclass=ABCMeta):
self.params=params
self.kwargs = kwargs
def generate(self,
prompt: str='',
callback: Optional[Callable]=None,
step_callback: Optional[Callable]=None,
iterations: int=1,
**keyword_args,
)->Iterator[InvokeAIGeneratorOutput]:
def generate(
self,
conditioning: tuple,
scheduler,
callback: Optional[Callable]=None,
step_callback: Optional[Callable]=None,
iterations: int=1,
**keyword_args,
)->Iterator[InvokeAIGeneratorOutput]:
'''
Return an iterator across the indicated number of generations.
Each time the iterator is called it will return an InvokeAIGeneratorOutput
@@ -113,54 +114,46 @@ class InvokeAIGenerator(metaclass=ABCMeta):
generator_args.update(keyword_args)
model_info = self.model_info
model_name = model_info['model_name']
model:StableDiffusionGeneratorPipeline = model_info['model']
model_hash = model_info['hash']
scheduler: Scheduler = self.get_scheduler(
model=model,
scheduler_name=generator_args.get('scheduler')
)
model_name = model_info.name
model_hash = model_info.hash
with model_info.context as model:
gen_class = self._generator_class()
generator = gen_class(model, self.params.precision, **self.kwargs)
if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'),
generator_args.get('with_variations')
)
# get conditioning from prompt via Compel package
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt, model=model)
gen_class = self._generator_class()
generator = gen_class(model, self.params.precision, **self.kwargs)
if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'),
generator_args.get('with_variations')
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
else:
configure_model_padding(model,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
else:
configure_model_padding(model,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
for i in iteration_count:
results = generator.generate(prompt,
conditioning=(uc, c, extra_conditioning_info),
step_callback=step_callback,
sampler=scheduler,
**generator_args,
)
output = InvokeAIGeneratorOutput(
image=results[0][0],
seed=results[0][1],
attention_maps_images=results[0][2],
model_hash = model_hash,
params=Namespace(model_name=model_name,**generator_args),
)
if callback:
callback(output)
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
for i in iteration_count:
results = generator.generate(
conditioning=conditioning,
step_callback=step_callback,
sampler=scheduler,
**generator_args,
)
output = InvokeAIGeneratorOutput(
image=results[0][0],
seed=results[0][1],
attention_maps_images=results[0][2],
model_hash = model_hash,
params=Namespace(model_name=model_name,**generator_args),
)
if callback:
callback(output)
yield output
@classmethod
@@ -173,20 +166,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
return generator_class(model, self.params.precision)
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_config = model.scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
return scheduler
@classmethod
def _generator_class(cls)->Type[Generator]:
'''
@@ -196,13 +175,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
'''
return Generator
# ------------------------------------
class Txt2Img(InvokeAIGenerator):
@classmethod
def _generator_class(cls):
from .txt2img import Txt2Img
return Txt2Img
# ------------------------------------
class Img2Img(InvokeAIGenerator):
def generate(self,
@@ -256,25 +228,6 @@ class Inpaint(Img2Img):
from .inpaint import Inpaint
return Inpaint
# ------------------------------------
class Embiggen(Txt2Img):
def generate(
self,
embiggen: list=None,
embiggen_tiles: list = None,
strength: float=0.75,
**kwargs)->Iterator[InvokeAIGeneratorOutput]:
return super().generate(embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
strength=strength,
**kwargs)
@classmethod
def _generator_class(cls):
from .embiggen import Embiggen
return Embiggen
class Generator:
downsampling_factor: int
latent_channels: int
@@ -285,7 +238,7 @@ class Generator:
self.model = model
self.precision = precision
self.seed = None
self.latent_channels = model.channels
self.latent_channels = model.unet.config.in_channels
self.downsampling_factor = downsampling # BUG: should come from model or config
self.safety_checker = None
self.perlin = 0.0
@@ -296,7 +249,7 @@ class Generator:
self.free_gpu_mem = None
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self, prompt, **kwargs):
def get_make_image(self, **kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
@@ -312,7 +265,6 @@ class Generator:
def generate(
self,
prompt,
width,
height,
sampler,
@@ -337,7 +289,6 @@ class Generator:
saver.get_stacked_maps_image()
)
make_image = self.get_make_image(
prompt,
sampler=sampler,
init_image=init_image,
width=width,

View File

@@ -1,559 +0,0 @@
"""
invokeai.backend.generator.embiggen descends from .generator
and generates with .generator.img2img
"""
import numpy as np
import torch
from PIL import Image
from tqdm import trange
import invokeai.backend.util.logging as logger
from .base import Generator
from .img2img import Img2Img
class Embiggen(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
self.init_latent = None
# Replace generate because Embiggen doesn't need/use most of what it does normally
def generate(
self,
prompt,
iterations=1,
seed=None,
image_callback=None,
step_callback=None,
**kwargs,
):
make_image = self.get_make_image(prompt, step_callback=step_callback, **kwargs)
results = []
seed = seed if seed else self.new_seed()
# Noise will be generated by the Img2Img generator when called
for _ in trange(iterations, desc="Generating"):
# make_image will call Img2Img which will do the equivalent of get_noise itself
image = make_image()
results.append([image, seed])
if image_callback is not None:
image_callback(image, seed, prompt_in=prompt)
seed = self.new_seed()
return results
@torch.no_grad()
def get_make_image(
self,
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
init_img,
strength,
width,
height,
embiggen,
embiggen_tiles,
step_callback=None,
**kwargs,
):
"""
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
Return value depends on the seed at the time you call it
"""
assert (
not sampler.uses_inpainting_model()
), "--embiggen is not supported by inpainting models"
# Construct embiggen arg array, and sanity check arguments
if embiggen is None:  # embiggen can also be called with just embiggen_tiles
embiggen = [1.0] # If not specified, assume no scaling
elif embiggen[0] < 0:
embiggen[0] = 1.0
logger.warning(
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
)
if len(embiggen) < 2:
embiggen.append(0.75)
elif embiggen[1] > 1.0 or embiggen[1] < 0:
embiggen[1] = 0.75
logger.warning(
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
)
if len(embiggen) < 3:
embiggen.append(0.25)
elif embiggen[2] < 0:
embiggen[2] = 0.25
logger.warning(
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
)
# Convert tiles from their user-friendly count-from-one to count-from-zero, because we need to do modulo math
# and then sort them, because... people.
if embiggen_tiles:
embiggen_tiles = list(map(lambda n: n - 1, embiggen_tiles))
embiggen_tiles.sort()
if strength >= 0.5:
logger.warning(
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
)
# Prep img2img generator, since we wrap over it
gen_img2img = Img2Img(self.model, self.precision)
# Open original init image (not a tensor) to manipulate
with Image.open(init_img) as img:
initsuperimage = img.convert("RGB")
# Size of the target super init image in pixels
initsuperwidth, initsuperheight = initsuperimage.size
# Increase by scaling factor if not already resized, using ESRGAN as able
if embiggen[0] != 1.0:
initsuperwidth = round(initsuperwidth * embiggen[0])
initsuperheight = round(initsuperheight * embiggen[0])
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
from ..restoration.realesrgan import ESRGAN
esrgan = ESRGAN()
logger.info(
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
)
if embiggen[0] > 2:
initsuperimage = esrgan.process(
initsuperimage,
embiggen[1], # upscale strength
self.seed,
4, # upscale scale
)
else:
initsuperimage = esrgan.process(
initsuperimage,
embiggen[1], # upscale strength
self.seed,
2, # upscale scale
)
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
# but from personal experience it doesn't greatly improve anything after 4x
# Resize to target scaling factor resolution
initsuperimage = initsuperimage.resize(
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS
)
# Use width and height as tile widths and height
# Determine buffer size in pixels
if embiggen[2] < 1:
if embiggen[2] < 0:
embiggen[2] = 0
overlap_size_x = round(embiggen[2] * width)
overlap_size_y = round(embiggen[2] * height)
else:
overlap_size_x = round(embiggen[2])
overlap_size_y = round(embiggen[2])
# With overall image width and height known, determine how many tiles we need
def ceildiv(a, b):
return -1 * (-a // b)
# X and Y need to be determined independently (we may have savings on one based on the buffer pixel count)
# (initsuperwidth - width) is the area remaining to the right that we need to lay tiles over to fill
# (width - overlap_size_x) is how much new we can fill with a single tile
emb_tiles_x = 1
emb_tiles_y = 1
if (initsuperwidth - width) > 0:
emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1
if (initsuperheight - height) > 0:
emb_tiles_y = ceildiv(initsuperheight - height, height - overlap_size_y) + 1
# Sanity
assert (
emb_tiles_x > 1 or emb_tiles_y > 1
), f"ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don't need to Embiggen! Check your arguments."
# Prep alpha layers --------------
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
# agradientL is Left-side transparent
agradientL = (
Image.linear_gradient("L").rotate(90).resize((overlap_size_x, height))
)
# agradientT is Top-side transparent
agradientT = Image.linear_gradient("L").resize((width, overlap_size_y))
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
agradientC = Image.new("L", (256, 256))
for y in range(256):
for x in range(256):
# Find distance to lower right corner (numpy takes arrays)
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
# Clamp values to max 255
if distanceToLR > 255:
distanceToLR = 255
# Place the pixel as invert of distance
agradientC.putpixel((x, y), round(255 - distanceToLR))
# Create alternative asymmetric diagonal corner to use on "trailing" intersections to prevent hard edges
# Fits for a left-fading gradient on the bottom side and full opacity on the right side.
agradientAsymC = Image.new("L", (256, 256))
for y in range(256):
for x in range(256):
value = round(max(0, x - (255 - y)) * (255 / max(1, y)))
# Clamp values
value = max(0, value)
value = min(255, value)
agradientAsymC.putpixel((x, y), value)
# Create alpha layers default fully white
alphaLayerL = Image.new("L", (width, height), 255)
alphaLayerT = Image.new("L", (width, height), 255)
alphaLayerLTC = Image.new("L", (width, height), 255)
# Paste gradients into alpha layers
alphaLayerL.paste(agradientL, (0, 0))
alphaLayerT.paste(agradientT, (0, 0))
alphaLayerLTC.paste(agradientL, (0, 0))
alphaLayerLTC.paste(agradientT, (0, 0))
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
# make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
# to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
alphaLayerTaC = alphaLayerT.copy()
alphaLayerTaC.paste(
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, 0),
)
alphaLayerLTaC = alphaLayerLTC.copy()
alphaLayerLTaC.paste(
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, 0),
)
if embiggen_tiles:
# Individual unconnected sides
alphaLayerR = Image.new("L", (width, height), 255)
alphaLayerR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerB = Image.new("L", (width, height), 255)
alphaLayerB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerTB = Image.new("L", (width, height), 255)
alphaLayerTB.paste(agradientT, (0, 0))
alphaLayerTB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerLR = Image.new("L", (width, height), 255)
alphaLayerLR.paste(agradientL, (0, 0))
alphaLayerLR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
# Sides and corner Layers
alphaLayerRBC = Image.new("L", (width, height), 255)
alphaLayerRBC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerRBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerRBC.paste(
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, height - overlap_size_y),
)
alphaLayerLBC = Image.new("L", (width, height), 255)
alphaLayerLBC.paste(agradientL, (0, 0))
alphaLayerLBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerLBC.paste(
agradientC.rotate(90).resize((overlap_size_x, overlap_size_y)),
(0, height - overlap_size_y),
)
alphaLayerRTC = Image.new("L", (width, height), 255)
alphaLayerRTC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerRTC.paste(agradientT, (0, 0))
alphaLayerRTC.paste(
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, 0),
)
# All but X layers
alphaLayerABT = Image.new("L", (width, height), 255)
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
alphaLayerABT.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
alphaLayerABT.paste(
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, height - overlap_size_y),
)
alphaLayerABL = Image.new("L", (width, height), 255)
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
alphaLayerABL.paste(agradientT.rotate(180), (0, height - overlap_size_y))
alphaLayerABL.paste(
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, height - overlap_size_y),
)
alphaLayerABR = Image.new("L", (width, height), 255)
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
alphaLayerABR.paste(agradientT, (0, 0))
alphaLayerABR.paste(
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
)
alphaLayerABB = Image.new("L", (width, height), 255)
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
alphaLayerABB.paste(agradientL, (0, 0))
alphaLayerABB.paste(
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
)
# All-around layer
alphaLayerAA = Image.new("L", (width, height), 255)
alphaLayerAA.paste(alphaLayerABT, (0, 0))
alphaLayerAA.paste(agradientT, (0, 0))
alphaLayerAA.paste(
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
)
alphaLayerAA.paste(
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
(width - overlap_size_x, 0),
)
# Clean up temporary gradients
del agradientL
del agradientT
del agradientC
def make_image():
# Make main tiles -------------------------------------------------
if embiggen_tiles:
logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
else:
logger.info(
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
)
emb_tile_store = []
# Although we could use the same seed for every tile for determinism, at higher strengths this may
# produce duplicated structures for each tile and make the tiling effect more obvious
# instead track and iterate a local seed we pass to Img2Img
seed = self.seed
seedintlimit = (
np.iinfo(np.uint32).max - 1
) # only retrieve this one from numpy
for tile in range(emb_tiles_x * emb_tiles_y):
# Don't iterate on first tile
if tile != 0:
if seed < seedintlimit:
seed += 1
else:
seed = 0
# Determine if this is a re-run and replace
if embiggen_tiles and tile not in embiggen_tiles:
continue
# Get row and column entries
emb_row_i = tile // emb_tiles_x
emb_column_i = tile % emb_tiles_x
# Determine bounds to cut up the init image
# Determine upper-left point
if emb_column_i + 1 == emb_tiles_x:
left = initsuperwidth - width
else:
left = round(emb_column_i * (width - overlap_size_x))
if emb_row_i + 1 == emb_tiles_y:
top = initsuperheight - height
else:
top = round(emb_row_i * (height - overlap_size_y))
right = left + width
bottom = top + height
# Cropped image of above dimension (does not modify the original)
newinitimage = initsuperimage.crop((left, top, right, bottom))
# DEBUG:
# newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
# newinitimage.save(newinitimagepath)
if embiggen_tiles:
logger.debug(
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
)
else:
logger.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
# create a torch tensor from an Image
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
newinitimage = torch.from_numpy(newinitimage)
newinitimage = 2.0 * newinitimage - 1.0
newinitimage = newinitimage.to(self.model.device)
clear_cuda_cache = (
kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
)
tile_results = gen_img2img.generate(
prompt,
iterations=1,
seed=seed,
sampler=sampler,
steps=steps,
cfg_scale=cfg_scale,
conditioning=conditioning,
ddim_eta=ddim_eta,
image_callback=None, # called only after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
width=width,
height=height,
init_image=newinitimage, # notice that init_image is different from init_img
mask_image=None,
strength=strength,
clear_cuda_cache=clear_cuda_cache,
)
emb_tile_store.append(tile_results[0][0])
# DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
# emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
del newinitimage
# Sanity check we have them all
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (
embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)
):
outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight))
if embiggen_tiles:
outputsuperimage.alpha_composite(
initsuperimage.convert("RGBA"), (0, 0)
)
for tile in range(emb_tiles_x * emb_tiles_y):
if embiggen_tiles:
if tile in embiggen_tiles:
intileimage = emb_tile_store.pop(0)
else:
continue
else:
intileimage = emb_tile_store[tile]
intileimage = intileimage.convert("RGBA")
# Get row and column entries
emb_row_i = tile // emb_tiles_x
emb_column_i = tile % emb_tiles_x
if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles:
left = 0
top = 0
else:
# Determine upper-left point
if emb_column_i + 1 == emb_tiles_x:
left = initsuperwidth - width
else:
left = round(emb_column_i * (width - overlap_size_x))
if emb_row_i + 1 == emb_tiles_y:
top = initsuperheight - height
else:
top = round(emb_row_i * (height - overlap_size_y))
# Handle gradients for various conditions
# Handle emb_rerun case
if embiggen_tiles:
# top of image
if emb_row_i == 0:
if emb_column_i == 0:
if (tile + 1) in embiggen_tiles: # Look-ahead right
if (
tile + emb_tiles_x
) not in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerB)
# Otherwise do nothing on this tile
elif (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerR)
else:
intileimage.putalpha(alphaLayerRBC)
elif emb_column_i == emb_tiles_x - 1:
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerL)
else:
intileimage.putalpha(alphaLayerLBC)
else:
if (tile + 1) in embiggen_tiles: # Look-ahead right
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerL)
else:
intileimage.putalpha(alphaLayerLBC)
elif (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerLR)
else:
intileimage.putalpha(alphaLayerABT)
# bottom of image
elif emb_row_i == emb_tiles_y - 1:
if emb_column_i == 0:
if (tile + 1) in embiggen_tiles: # Look-ahead right
intileimage.putalpha(alphaLayerTaC)
else:
intileimage.putalpha(alphaLayerRTC)
elif emb_column_i == emb_tiles_x - 1:
# No tiles to look ahead to
intileimage.putalpha(alphaLayerLTC)
else:
if (tile + 1) in embiggen_tiles: # Look-ahead right
intileimage.putalpha(alphaLayerLTaC)
else:
intileimage.putalpha(alphaLayerABB)
# vertical middle of image
else:
if emb_column_i == 0:
if (tile + 1) in embiggen_tiles: # Look-ahead right
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerTaC)
else:
intileimage.putalpha(alphaLayerTB)
elif (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerRTC)
else:
intileimage.putalpha(alphaLayerABL)
elif emb_column_i == emb_tiles_x - 1:
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerLTC)
else:
intileimage.putalpha(alphaLayerABR)
else:
if (tile + 1) in embiggen_tiles: # Look-ahead right
if (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down
intileimage.putalpha(alphaLayerLTaC)
else:
intileimage.putalpha(alphaLayerABR)
elif (
tile + emb_tiles_x
) in embiggen_tiles: # Look-ahead down only
intileimage.putalpha(alphaLayerABB)
else:
intileimage.putalpha(alphaLayerAA)
# Handle normal tiling case (much simpler - since we tile left to right, top to bottom)
else:
if emb_row_i == 0 and emb_column_i >= 1:
intileimage.putalpha(alphaLayerL)
elif emb_row_i >= 1 and emb_column_i == 0:
if (
emb_column_i + 1 == emb_tiles_x
): # If we don't have anything that can be placed to the right
intileimage.putalpha(alphaLayerT)
else:
intileimage.putalpha(alphaLayerTaC)
else:
if (
emb_column_i + 1 == emb_tiles_x
): # If we don't have anything that can be placed to the right
intileimage.putalpha(alphaLayerLTC)
else:
intileimage.putalpha(alphaLayerLTaC)
# Layer tile onto final image
outputsuperimage.alpha_composite(intileimage, (left, top))
else:
logger.error(
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
)
# after internal loops and patching up return Embiggen image
return outputsuperimage
# end of function declaration
return make_image
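# Hypothetical usage sketch (argument names follow the signatures above; values are illustrative):
#   gen = Embiggen(model, precision)
#   results = gen.generate(
#       prompt="a castle on a hill",
#       init_img="castle_512.png",
#       embiggen=[2.0, 0.75, 0.25],      # scale factor, ESRGAN strength, overlap ratio
#       strength=0.4,
#       sampler=sampler, steps=30, cfg_scale=7.5, ddim_eta=0.0,
#       conditioning=conditioning, width=512, height=512,
#   )
#   image, seed = results[0]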

View File

@@ -22,7 +22,6 @@ class Img2Img(Generator):
def get_make_image(
self,
prompt,
sampler,
steps,
cfg_scale,

View File

@@ -161,9 +161,7 @@ class Inpaint(Img2Img):
im: Image.Image,
seam_size: int,
seam_blur: int,
prompt,
seed,
sampler,
steps,
cfg_scale,
ddim_eta,
@@ -177,8 +175,6 @@ class Inpaint(Img2Img):
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
make_image = self.get_make_image(
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
@@ -203,8 +199,6 @@ class Inpaint(Img2Img):
@torch.no_grad()
def get_make_image(
self,
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
@@ -306,7 +300,6 @@ class Inpaint(Img2Img):
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.scheduler = sampler
# todo: support cross-attention control
uc, c, _ = conditioning
@@ -345,9 +338,7 @@ class Inpaint(Img2Img):
result,
seam_size,
seam_blur,
prompt,
seed,
sampler,
seam_steps,
cfg_scale,
ddim_eta,
@@ -360,8 +351,6 @@ class Inpaint(Img2Img):
# Restore original settings
self.get_make_image(
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,

View File

@@ -1,125 +0,0 @@
"""
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
"""
import PIL.Image
import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from ..stable_diffusion import (
ConditioningData,
PostprocessingSettings,
StableDiffusionGeneratorPipeline,
)
from .base import Generator
class Txt2Img(Generator):
def __init__(self, model, precision,
control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
**kwargs):
self.control_model = control_model
if isinstance(self.control_model, list):
self.control_model = MultiControlNetModel(self.control_model)
super().__init__(model, precision, **kwargs)
@torch.no_grad()
def get_make_image(
self,
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
width,
height,
step_callback=None,
threshold=0.0,
warmup=0.2,
perlin=0.0,
h_symmetry_time_pct=None,
v_symmetry_time_pct=None,
attention_maps_callback=None,
**kwargs,
):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
kwargs are 'width' and 'height'
"""
self.perlin = perlin
control_image = kwargs.get("control_image", None)
do_classifier_free_guidance = cfg_scale > 1.0
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.control_model = self.control_model
pipeline.scheduler = sampler
uc, c, extra_conditioning_info = conditioning
conditioning_data = ConditioningData(
uc,
c,
cfg_scale,
extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=threshold,
warmup=warmup,
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct,
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
if control_image is not None:
if isinstance(self.control_model, ControlNetModel):
control_image = pipeline.prepare_control_image(
image=control_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=self.control_model.device,
dtype=self.control_model.dtype,
)
elif isinstance(self.control_model, MultiControlNetModel):
images = []
for image_ in control_image:
image_ = self.model.prepare_control_image(
image=image_,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=self.control_model.device,
dtype=self.control_model.dtype,
)
images.append(image_)
control_image = images
kwargs["control_image"] = control_image
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
pipeline_output = pipeline.image_from_embeddings(
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
noise=x_T,
num_inference_steps=steps,
conditioning_data=conditioning_data,
callback=step_callback,
**kwargs,
)
if (
pipeline_output.attention_map_saver is not None
and attention_maps_callback is not None
):
attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
return make_image

View File

@@ -1,209 +0,0 @@
"""
invokeai.backend.generator.txt2img2img inherits from invokeai.backend.generator
"""
import math
from typing import Callable, Optional
import torch
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
from ..stable_diffusion import PostprocessingSettings
from .base import Generator
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..stable_diffusion.diffusers_pipeline import ConditioningData
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
import invokeai.backend.util.logging as logger
class Txt2Img2Img(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
self.init_latent = None # for get_noise()
def get_make_image(
self,
prompt: str,
sampler,
steps: int,
cfg_scale: float,
ddim_eta,
conditioning,
width: int,
height: int,
strength: float,
step_callback: Optional[Callable] = None,
threshold=0.0,
warmup=0.2,
perlin=0.0,
h_symmetry_time_pct=None,
v_symmetry_time_pct=None,
attention_maps_callback=None,
**kwargs,
):
"""
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
kwargs are 'width' and 'height'
"""
self.perlin = perlin
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.scheduler = sampler
uc, c, extra_conditioning_info = conditioning
conditioning_data = ConditioningData(
uc,
c,
cfg_scale,
extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=threshold,
warmup=warmup,
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct,
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T: torch.Tensor, _: int):
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
latents=torch.zeros_like(x_T),
num_inference_steps=steps,
conditioning_data=conditioning_data,
noise=x_T,
callback=step_callback,
)
# Get our initial generation width and height directly from the latent output so
# the message below is accurate.
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
logger.info(
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
)
# resizing
resized_latents = torch.nn.functional.interpolate(
first_pass_latent_output,
size=(
height // self.downsampling_factor,
width // self.downsampling_factor,
),
mode="bilinear",
)
# Free up memory from the last generation.
clear_cuda_cache = kwargs.get("clear_cuda_cache", None)
if clear_cuda_cache is not None:
clear_cuda_cache()
second_pass_noise = self.get_noise_like(
resized_latents, override_perlin=True
)
# Clear symmetry for the second pass
from dataclasses import replace
new_postprocessing_settings = replace(
conditioning_data.postprocessing_settings, h_symmetry_time_pct=None
)
new_postprocessing_settings = replace(
new_postprocessing_settings, v_symmetry_time_pct=None
)
new_conditioning_data = replace(
conditioning_data, postprocessing_settings=new_postprocessing_settings
)
verbosity = get_verbosity()
set_verbosity_error()
pipeline_output = pipeline.img2img_from_latents_and_embeddings(
resized_latents,
num_inference_steps=steps,
conditioning_data=new_conditioning_data,
strength=strength,
noise=second_pass_noise,
callback=step_callback,
)
set_verbosity(verbosity)
if (
pipeline_output.attention_map_saver is not None
and attention_maps_callback is not None
):
attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
# FIXME: do we really need something entirely different for the inpainting model?
# in the case of the inpainting model being loaded, the trick of
# providing an interpolated latent doesn't work, so we transiently
# create a 512x512 PIL image, upscale it, and run the inpainting
# over it in img2img mode. Because the inpainting model is so conservative
# it doesn't change the image (much)
return make_image
def get_noise_like(self, like: torch.Tensor, override_perlin: bool = False):
device = like.device
if device.type == "mps":
x = torch.randn_like(like, device="cpu", dtype=self.torch_dtype()).to(
device
)
else:
x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
if self.perlin > 0.0 and not override_perlin:
shape = like.shape
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
shape[3], shape[2]
)
return x
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self, width, height, scale=True):
# print(f"Get noise: {width}x{height}")
if scale:
# Scale the input width and height for the initial generation
# Make their area equivalent to the model's resolution area (e.g. 512*512 = 262144),
# while keeping the minimum dimension at least 0.5 * resolution (e.g. 512*0.5 = 256)
aspect = width / height
dimension = self.model.unet.config.sample_size * self.model.vae_scale_factor
min_dimension = math.floor(dimension * 0.5)
model_area = (
dimension * dimension
) # hardcoded for now since all models are trained on square images
if aspect > 1.0:
init_height = max(min_dimension, math.sqrt(model_area / aspect))
init_width = init_height * aspect
else:
init_width = max(min_dimension, math.sqrt(model_area * aspect))
init_height = init_width / aspect
scaled_width, scaled_height = trim_to_multiple_of(
math.floor(init_width), math.floor(init_height)
)
else:
scaled_width = width
scaled_height = height
device = self.model.device
channels = self.latent_channels
if channels == 9:
channels = 4 # we don't really want noise for all the mask channels
shape = (
1,
channels,
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor,
)
if self.use_mps_noise or device.type == "mps":
tensor = torch.empty(size=shape, device="cpu")
tensor = self.get_noise_like(like=tensor).to(device)
else:
tensor = torch.empty(size=shape, device=device)
tensor = self.get_noise_like(like=tensor)
return tensor

View File

@@ -9,6 +9,7 @@ SAMPLER_CHOICES = [
"ddpm",
"deis",
"lms",
"lms_k",
"pndm",
"heun",
"heun_k",
@@ -18,8 +19,13 @@ SAMPLER_CHOICES = [
"kdpm_2",
"kdpm_2_a",
"dpmpp_2s",
"dpmpp_2s_k",
"dpmpp_2m",
"dpmpp_2m_k",
"dpmpp_2m_sde",
"dpmpp_2m_sde_k",
"dpmpp_sde",
"dpmpp_sde_k",
"unipc",
]
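# Each of these names is resolved to a diffusers scheduler elsewhere (see the SCHEDULER_MAP
# lookup in get_scheduler above); a plausible shape for that table, entries illustrative only:
#   SCHEDULER_MAP = {
#       "lms":   (LMSDiscreteScheduler, {}),
#       "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
#       "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler,
#                          {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True}),
#   }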

View File

@@ -1,11 +1,6 @@
"""
Initialization file for invokeai.backend.model_management
"""
from .convert_ckpt_to_diffusers import (
convert_ckpt_to_diffusers,
load_pipeline_from_original_stable_diffusion_ckpt,
)
from .model_manager import ModelManager,SDModelComponent
from .model_manager import ModelManager, ModelInfo
from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType

View File

@@ -28,10 +28,13 @@ from safetensors.torch import load_file
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from .model_manager import ModelManager, SDLegacyType
from .model_manager import ModelManager
from .model_cache import ModelCache
from .models import SchedulerPredictionType, BaseModelType, ModelVariantType
try:
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
except ImportError:
raise ImportError(
"OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
@@ -56,10 +59,6 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
LDMBertConfig,
LDMBertModel,
)
from diffusers.pipelines.paint_by_example import (
PaintByExampleImageEncoder,
PaintByExamplePipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
@@ -74,6 +73,8 @@ from transformers import (
from ..stable_diffusion import StableDiffusionGeneratorPipeline
MODEL_ROOT = None
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
@@ -158,17 +159,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
new_item = new_item.replace("q.weight", "to_q.weight")
new_item = new_item.replace("q.bias", "to_q.bias")
new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("k.weight", "to_k.weight")
new_item = new_item.replace("k.bias", "to_k.bias")
new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("v.weight", "to_v.weight")
new_item = new_item.replace("v.bias", "to_v.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
new_item = shave_segments(
new_item, n_shave_prefix_segments=n_shave_prefix_segments
@@ -183,7 +184,6 @@ def assign_to_checkpoint(
paths,
checkpoint,
old_checkpoint,
attention_paths_to_split=None,
additional_replacements=None,
config=None,
):
@@ -198,35 +198,9 @@ def assign_to_checkpoint(
paths, list
), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape(
(num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
)
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)
for path in paths:
new_path = path["new"]
# These have already been assigned
if (
attention_paths_to_split is not None
and new_path in attention_paths_to_split
):
continue
# Global renaming happens here
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
@@ -245,14 +219,14 @@ def assign_to_checkpoint(
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
elif "to_out.0.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
checkpoint[key] = checkpoint[key][:, :, 0, 0]
def create_unet_diffusers_config(original_config, image_size: int):
@@ -612,16 +586,29 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
return new_checkpoint
def convert_ldm_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
vae_key = "first_stage_model."
keys = list(checkpoint.keys())
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
# Extract state dict for VAE. Works both with burnt-in
# VAEs, and with standalone VAEs.
# checkpoint can either be an all-in-one stable diffusion
# model, or an isolated vae .ckpt. This tests for
# a key that will be present in the all-in-one model
# that isn't present in the isolated ckpt.
probe_key = "first_stage_model.encoder.conv_in.weight"
if probe_key in checkpoint:
vae_state_dict = {}
vae_key = "first_stage_model."
keys = list(checkpoint.keys())
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
else:
vae_state_dict = checkpoint
new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict,config)
return new_checkpoint
def convert_ldm_vae_state_dict(vae_state_dict, config):
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
@@ -841,10 +828,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
def convert_ldm_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=InvokeAIAppConfig.get_config().cache_dir
)
text_model = CLIPTextModel.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14')
keys = list(checkpoint.keys())
text_model_dict = {}
@@ -896,82 +880,10 @@ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
def convert_paint_by_example_checkpoint(checkpoint):
cache_dir = InvokeAIAppConfig.get_config().cache_dir
config = CLIPVisionConfig.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=cache_dir
)
model = PaintByExampleImageEncoder(config)
keys = list(checkpoint.keys())
text_model_dict = {}
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[
key
]
# load clip vision
model.model.load_state_dict(text_model_dict)
# load mapper
keys_mapper = {
k[len("cond_stage_model.mapper.res") :]: v
for k, v in checkpoint.items()
if k.startswith("cond_stage_model.mapper")
}
MAPPING = {
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
"attn.c_proj": ["attn1.to_out.0"],
"ln_1": ["norm1"],
"ln_2": ["norm3"],
"mlp.c_fc": ["ff.net.0.proj"],
"mlp.c_proj": ["ff.net.2"],
}
mapped_weights = {}
for key, value in keys_mapper.items():
prefix = key[: len("blocks.i")]
suffix = key.split(prefix)[-1].split(".")[-1]
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
mapped_names = MAPPING[name]
num_splits = len(mapped_names)
for i, mapped_name in enumerate(mapped_names):
new_name = ".".join([prefix, mapped_name, suffix])
shape = value.shape[0] // num_splits
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
model.mapper.load_state_dict(mapped_weights)
# load final layer norm
model.final_layer_norm.load_state_dict(
{
"bias": checkpoint["cond_stage_model.final_ln.bias"],
"weight": checkpoint["cond_stage_model.final_ln.weight"],
}
)
# load final proj
model.proj_out.load_state_dict(
{
"bias": checkpoint["proj_out.bias"],
"weight": checkpoint["proj_out.weight"],
}
)
# load uncond vector
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
return model
def convert_open_clip_checkpoint(checkpoint):
cache_dir = InvokeAIAppConfig.get_config().cache_dir
text_model = CLIPTextModel.from_pretrained(
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
MODEL_ROOT / 'stable-diffusion-2-clip',
subfolder='text_encoder',
)
keys = list(checkpoint.keys())
@@ -1047,22 +959,30 @@ def replace_checkpoint_vae(checkpoint, vae_path:str):
new_key = f'first_stage_model.{vae_key}'
checkpoint[new_key] = state_dict[vae_key]
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int)->AutoencoderKL:
vae_config = create_vae_diffusers_config(
vae_config, image_size=image_size
)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
checkpoint, vae_config
)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
return vae
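# Hypothetical usage sketch (paths are illustrative; config object as loaded via OmegaConf above):
#   ckpt = torch.load("vae-ft-mse-840000-ema-pruned.ckpt", map_location="cpu")
#   ckpt = ckpt.get("state_dict", ckpt)   # some .ckpt files nest weights under "state_dict"
#   original_config = OmegaConf.load("v1-inference.yaml")
#   vae = convert_ldm_vae_to_diffusers(ckpt, original_config, image_size=512)
#   vae.save_pretrained("converted-vae")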
def load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path: str,
original_config_file: str = None,
num_in_channels: int = None,
scheduler_type: str = "pndm",
pipeline_type: str = None,
image_size: int = None,
prediction_type: str = None,
model_version: BaseModelType,
model_variant: ModelVariantType,
original_config_file: str,
extract_ema: bool = True,
upcast_attn: bool = False,
vae: AutoencoderKL = None,
vae_path: str = None,
precision: torch.dtype = torch.float32,
return_generator_pipeline: bool = False,
scan_needed:bool=True,
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
upcast_attention: bool = False,
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon,
scan_needed: bool = True,
) -> StableDiffusionPipeline:
"""
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
config file.
@@ -1074,148 +994,68 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
:param checkpoint_path: Path to `.ckpt` file.
:param original_config_file: Path to `.yaml` config file corresponding to the original architecture.
If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models.
:param image_size: The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
Base. Use 768 for Stable Diffusion v2.
:param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
v1.X and Stable Diffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2.
:param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
inferred.
:param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
"euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. :param extract_ema: Only relevant for
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder"]`. :param extract_ema: Only relevant for
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
running stable diffusion 2.1.
:param vae: A diffusers VAE to load into the pipeline.
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
"""
config = InvokeAIAppConfig.get_config()
cache_dir = config.cache_dir
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
if Path(checkpoint_path).suffix == '.ckpt':
if scan_needed:
ModelManager.scan_model(checkpoint_path,checkpoint_path)
checkpoint = torch.load(checkpoint_path)
else:
if str(checkpoint_path).endswith(".safetensors"):
checkpoint = load_file(checkpoint_path)
pipeline_class = (
StableDiffusionGeneratorPipeline
if return_generator_pipeline
else StableDiffusionPipeline
)
# Sometimes models don't have the global_step item
if "global_step" in checkpoint:
global_step = checkpoint["global_step"]
else:
logger.debug("global_step key not found in model")
global_step = None
if scan_needed:
ModelCache.scan_model(checkpoint_path, checkpoint_path)
checkpoint = torch.load(checkpoint_path)
# sometimes there is a state_dict key and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
upcast_attention = False
if original_config_file is None:
model_type = ModelManager.probe_model_type(checkpoint)
if model_type == SDLegacyType.V2_v:
original_config_file = (
config.legacy_conf_path / "v2-inference-v.yaml"
)
if global_step == 110000:
# v2.1 needs to upcast attention
upcast_attention = True
elif model_type == SDLegacyType.V2_e:
original_config_file = (
config.legacy_conf_path / "v2-inference.yaml"
)
elif model_type == SDLegacyType.V1_INPAINT:
original_config_file = (
config.legacy_conf_path / "v1-inpainting-inference.yaml"
)
elif model_type == SDLegacyType.V1:
original_config_file = (
config.legacy_conf_path / "v1-inference.yaml"
)
else:
raise Exception("Unknown checkpoint type")
original_config = OmegaConf.load(original_config_file)
if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"][
"in_channels"
] = num_in_channels
if (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
):
if prediction_type is None:
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
# as it relies on a brittle global step parameter here
prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
if image_size is None:
# NOTE: For stable diffusion 2 base one has to pass `image_size==512`
# as it relies on a brittle global step parameter here
image_size = 512 if global_step == 875000 else 768
if model_version == BaseModelType.StableDiffusion2 and prediction_type == SchedulerPredictionType.VPrediction:
image_size = 768
else:
if prediction_type is None:
prediction_type = "epsilon"
if image_size is None:
image_size = 512
image_size = 512
#
# convert scheduler
#
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end
scheduler = DDIMScheduler(
scheduler = PNDMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
beta_start=beta_start,
num_train_timesteps=num_train_timesteps,
steps_offset=1,
clip_sample=False,
set_alpha_to_one=False,
prediction_type=prediction_type,
skip_prk_steps=True
)
# make sure scheduler works correctly with DDIM
scheduler.register_to_config(clip_sample=False)
if scheduler_type == "pndm":
config = dict(scheduler.config)
config["skip_prk_steps"] = True
scheduler = PNDMScheduler.from_config(config)
elif scheduler_type == "lms":
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "heun":
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "euler":
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "euler-ancestral":
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == 'unipc':
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == "ddim":
scheduler = scheduler
else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
#
# convert unet
#
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(
original_config, image_size=image_size
)
@@ -1228,44 +1068,25 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
unet.load_state_dict(converted_unet_checkpoint)
# If a replacement VAE path was specified, we'll incorporate that into
# the checkpoint model and then convert it
if vae_path:
logger.debug(f"Converting VAE {vae_path}")
replace_checkpoint_vae(checkpoint,vae_path)
# otherwise we use the original VAE, provided that
# an externally loaded diffusers VAE was not passed
elif not vae:
logger.debug("Using checkpoint model's original VAE")
#
# convert vae
#
if vae:
logger.debug("Using replacement diffusers VAE")
else: # convert the original or replacement VAE
vae_config = create_vae_diffusers_config(
original_config, image_size=image_size
)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
checkpoint, vae_config
)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
vae = convert_ldm_vae_to_diffusers(
checkpoint,
original_config,
image_size,
)
# Convert the text model.
model_type = pipeline_type
if model_type is None:
model_type = original_config.model.params.cond_stage_config.target.split(
"."
)[-1]
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-2",
subfolder="tokenizer",
cache_dir=cache_dir,
MODEL_ROOT / 'stable-diffusion-2-clip',
subfolder='tokenizer',
)
pipe = pipeline_class(
pipe = StableDiffusionPipeline(
vae=vae.to(precision),
text_encoder=text_model.to(precision),
tokenizer=tokenizer,
@@ -1275,49 +1096,26 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
feature_extractor=None,
requires_safety_checker=False,
)
elif model_type == "PaintByExample":
vision_model = convert_paint_by_example_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=cache_dir
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir
)
pipe = PaintByExamplePipeline(
vae=vae,
image_encoder=vision_model,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=feature_extractor,
)
elif model_type in ["FrozenCLIPEmbedder", "WeightedFrozenCLIPEmbedder"]:
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=cache_dir
)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker",
cache_dir=cache_dir,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir
)
pipe = pipeline_class(
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14')
safety_checker = StableDiffusionSafetyChecker.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
pipe = StableDiffusionPipeline(
vae=vae.to(precision),
text_encoder=text_model.to(precision),
tokenizer=tokenizer,
unet=unet.to(precision),
scheduler=scheduler,
safety_checker=None if return_generator_pipeline else safety_checker.to(precision),
safety_checker=safety_checker.to(precision),
feature_extractor=feature_extractor,
)
else:
text_config = create_ldm_bert_config(original_config)
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
tokenizer = BertTokenizerFast.from_pretrained(
"bert-base-uncased", cache_dir=cache_dir
)
tokenizer = BertTokenizerFast.from_pretrained(MODEL_ROOT / "bert-base-uncased")
pipe = LDMTextToImagePipeline(
vqvae=vae,
bert=text_model,
@@ -1331,15 +1129,19 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
def convert_ckpt_to_diffusers(
checkpoint_path: Union[str, Path],
dump_path: Union[str, Path],
**kwargs,
checkpoint_path: Union[str, Path],
dump_path: Union[str, Path],
model_root: Union[str, Path],
**kwargs,
):
"""
Takes all the arguments of load_pipeline_from_original_stable_diffusion_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
# setting global here to avoid massive changes late at night
global MODEL_ROOT
MODEL_ROOT = Path(model_root) / 'core/convert'
pipe = load_pipeline_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs)
pipe.save_pretrained(

View File

@@ -0,0 +1,678 @@
from __future__ import annotations
import copy
from pathlib import Path
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from safetensors.torch import load_file
from torch.utils.hooks import RemovableHandle
from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from compel.embeddings_provider import BaseTextualInversionManager
class LoRALayerBase:
#rank: Optional[int]
#alpha: Optional[float]
#bias: Optional[torch.Tensor]
#layer_key: str
#@property
#def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: dict,
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if (
"bias_indices" in values
and "bias_values" in values
and "bias_size" in values
):
self.bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def forward(
self,
module: torch.nn.Module,
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
multiplier: float,
):
if type(module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)
else:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight(module)
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
) * multiplier * scale
def get_weight(self, module: torch.nn.Module):
raise NotImplementedError()
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
#up: torch.Tensor
#mid: Optional[torch.Tensor]
#down: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid = values["lora_mid.weight"]
else:
self.mid = None
self.rank = self.down.shape[0]
def get_weight(self, module: torch.nn.Module):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
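# Shape sketch of the low-rank product above (dims illustrative):
#   up:   (out_features, rank)   e.g. (320, 16)
#   down: (rank, in_features)    e.g. (16, 320)
#   weight = up @ down           -> (320, 320), matching the wrapped layer's weight shape
# (for Conv2d layers the flattened product is later .view()-ed back to (out, in, kh, kw) in forward()).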
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
#w1_a: torch.Tensor
#w1_b: torch.Tensor
#w2_a: torch.Tensor
#w2_b: torch.Tensor
#t1: Optional[torch.Tensor] = None
#t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1 = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2 = values["hada_t2"]
else:
self.t2 = None
self.rank = self.w1_b.shape[0]
def get_weight(self, module: torch.nn.Module):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum(
"i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a
)
rebuild2 = torch.einsum(
"i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a
)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
#w1: Optional[torch.Tensor] = None
#w1_a: Optional[torch.Tensor] = None
#w1_b: Optional[torch.Tensor] = None
#w2: Optional[torch.Tensor] = None
#w2_a: Optional[torch.Tensor] = None
#w2_b: Optional[torch.Tensor] = None
#t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1 = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2 = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2 = values["lokr_t2"]
else:
self.t2 = None
if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
else:
self.rank = None # unscaled
def get_weight(self, module: torch.nn.Module):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape?
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoRAModel: #(torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
_device: torch.device
_dtype: torch.dtype
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
device: torch.device,
dtype: torch.dtype,
):
self._name = name
self._device = device or torch.device("cpu")
self._dtype = dtype or torch.float32
self.layers = layers
@property
def name(self):
return self._name
@property
def device(self):
return self._device
@property
def dtype(self):
return self._dtype
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> LoRAModel:
# TODO: try revert if exception?
for key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
self._device = device
self._dtype = dtype
return self
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
device=device,
dtype=dtype,
name=file_path.stem, # TODO:
layers=dict(),
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
else:
# TODO: diff/ia3/... format
print(
f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
)
return
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: dict):
state_dict_grouped = dict()
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_grouped:
state_dict_grouped[stem] = dict()
state_dict_grouped[stem][leaf] = value
return state_dict_grouped
"""
loras = [
(lora_model1, 0.7),
(lora_model2, 0.4),
]
with LoRAHelper.apply_lora_unet(unet, loras):
# unet with applied loras
# unmodified unet
"""
# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
if not lora_key.startswith(prefix):
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
module = model
module_key = ""
key_parts = lora_key[len(prefix):].split('_')
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)
@staticmethod
def _lora_forward_hook(
applied_loras: List[Tuple[LoRAModel, float]],
layer_name: str,
):
def lora_forward(module, input_h, output):
if len(applied_loras) == 0:
return output
for lora, weight in applied_loras:
layer = lora.layers.get(layer_name, None)
if layer is None:
continue
output += layer.forward(module, input_h, weight)
return output
return lora_forward
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: UNet2DConditionModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@classmethod
@contextmanager
def apply_lora(
cls,
model: torch.nn.Module,
loras: List[Tuple[LoRAModel, float]],
prefix: str,
):
hooks = dict()
try:
for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
if module_key not in hooks:
hooks[module_key] = module.register_forward_hook(cls._lora_forward_hook(loras, layer_key))
yield # wait for context manager exit
finally:
for module_key, hook in hooks.items():
hook.remove()
hooks.clear()
@classmethod
@contextmanager
def apply_ti(
cls,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
ti_list: List[Any],
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
init_tokens_count = None
new_tokens_added = None
try:
ti_tokenizer = copy.deepcopy(tokenizer)
ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
def _get_trigger(ti, index):
trigger = ti.name
if index > 0:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
# modify tokenizer
new_tokens_added = 0
for ti in ti_list:
for i in range(ti.embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
# modify text_encoder
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
model_embeddings = text_encoder.get_input_embeddings()
for ti in ti_list:
ti_tokens = []
for i in range(ti.embedding.shape[0]):
embedding = ti.embedding[i]
trigger = _get_trigger(ti, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id:
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
if model_embeddings.weight.data[token_id].shape != embedding.shape:
raise ValueError(
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
)
model_embeddings.weight.data[token_id] = embedding
ti_tokens.append(token_id)
if len(ti_tokens) > 1:
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
yield ti_tokenizer, ti_manager
finally:
if init_tokens_count and new_tokens_added:
text_encoder.resize_token_embeddings(init_tokens_count)
class TextualInversionModel:
name: str
embedding: torch.Tensor # [n, 768]|[n, 1280]
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if not isinstance(file_path, Path):
file_path = Path(file_path)
result = cls() # TODO:
result.name = file_path.stem # TODO:
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
# both v1 and v2 format embeddings
# difference mostly in metadata
if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1:
print(f"Warn: Embedding \"{file_path.name}\" contains multiple tokens, which is not supported. The first token will be used.")
result.embedding = next(iter(state_dict["string_to_param"].values()))
# v3 (easynegative)
elif "emb_params" in state_dict:
result.embedding = state_dict["emb_params"]
# v4(diffusers bin files)
else:
result.embedding = next(iter(state_dict.values()))
if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}")
return result
class TextualInversionManager(BaseTextualInversionManager):
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = dict()
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(
self, token_ids: list[int]
) -> list[int]:
if len(self.pad_tokens) == 0:
return token_ids
if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id")
new_token_ids = []
for token_id in token_ids:
new_token_ids.append(token_id)
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
return new_token_ids
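# Illustrative sketch: how the classes above are meant to compose when patching a
# diffusers pipeline. `pipe`, the file paths and the 0.75 weight are assumptions
# used only for illustration; only APIs defined in this module are called.
def _example_patch_pipeline(pipe, lora_path: str, ti_path: str):
    lora = LoRAModel.from_checkpoint(lora_path)
    ti = TextualInversionModel.from_checkpoint(ti_path)
    loras = [(lora, 0.75)]
    with ModelPatcher.apply_lora_unet(pipe.unet, loras), \
         ModelPatcher.apply_lora_text_encoder(pipe.text_encoder, loras), \
         ModelPatcher.apply_ti(pipe.tokenizer, pipe.text_encoder, [ti]) as (ti_tokenizer, ti_manager):
        # within the context the forward hooks add the weighted LoRA deltas and the
        # copied tokenizer resolves the "<name>" trigger tokens added above
        ...  # run inference here using ti_tokenizer / ti_manager
    # on exit the hooks are removed and the token embeddings are resized back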


@@ -0,0 +1,391 @@
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:
cache = ModelCache(max_cache_size=6)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2)
"""
import gc
import os
import sys
import hashlib
from contextlib import suppress
from pathlib import Path
import types
from typing import Dict, Union, Optional, Type, Any
import torch
import logging
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from .lora import LoRAModel, TextualInversionModel
from .models import BaseModelType, ModelType, SubModelType, ModelBase
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
# actual size of a gig
GIG = 1073741824
class ModelLocker(object):
"Forward declaration"
pass
class ModelCache(object):
"Forward declaration"
pass
class _CacheRecord:
size: int
model: Any
cache: ModelCache
_locks: int
def __init__(self, cache, model: Any, size: int):
self.size = size
self.model = model
self.cache = cache
self._locks = 0
def lock(self):
self._locks += 1
def unlock(self):
self._locks -= 1
assert self._locks >= 0
@property
def locked(self):
return self._locks > 0
@property
def loaded(self):
if self.model is not None and hasattr(self.model, "device"):
return self.model.device != self.cache.storage_device
else:
return False
class ModelCache(object):
def __init__(
self,
max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
execution_device: torch.device=torch.device('cuda'),
storage_device: torch.device=torch.device('cpu'),
precision: torch.dtype=torch.float16,
sequential_offload: bool=False,
lazy_offloading: bool=True,
sha_chunksize: int = 16777216,
logger: types.ModuleType = logger
):
'''
:param max_cache_size: Maximum size of the RAM model cache, in GB [6.0]
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
'''
#max_cache_size = 9999
#execution_device = torch.device('cuda')
self.model_infos: Dict[str, ModelBase] = dict()
self.lazy_offloading = lazy_offloading
#self.sequential_offload: bool=sequential_offload
self.precision: torch.dtype=precision
self.max_cache_size: float=max_cache_size
self.execution_device: torch.device=execution_device
self.storage_device: torch.device=storage_device
self.sha_chunksize=sha_chunksize
self.logger = logger
self._cached_models = dict()
self._cache_stack = list()
def get_key(
self,
model_path: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
):
key = f"{model_path}:{base_model}:{model_type}"
if submodel_type:
key += f":{submodel_type}"
return key
#def get_model(
# self,
# repo_id_or_path: Union[str, Path],
# model_type: ModelType = ModelType.Diffusers,
# subfolder: Path = None,
# submodel: ModelType = None,
# revision: str = None,
# attach_model_part: Tuple[ModelType, str] = (None, None),
# gpu_load: bool = True,
#) -> ModelLocker: # ?? what does it return
def _get_model_info(
self,
model_path: str,
model_class: Type[ModelBase],
base_model: BaseModelType,
model_type: ModelType,
):
model_info_key = self.get_key(
model_path=model_path,
base_model=base_model,
model_type=model_type,
submodel_type=None,
)
if model_info_key not in self.model_infos:
self.model_infos[model_info_key] = model_class(
model_path,
base_model,
model_type,
)
return self.model_infos[model_info_key]
# TODO: args
def get_model(
self,
model_path: Union[str, Path],
model_class: Type[ModelBase],
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
gpu_load: bool = True,
) -> Any:
if not isinstance(model_path, Path):
model_path = Path(model_path)
if not os.path.exists(model_path):
raise Exception(f"Model not found: {model_path}")
model_info = self._get_model_info(
model_path=model_path,
model_class=model_class,
base_model=base_model,
model_type=model_type,
)
key = self.get_key(
model_path=model_path,
base_model=base_model,
model_type=model_type,
submodel_type=submodel,
)
# TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(f'Loading model {model_path}, type {base_model}:{model_type}:{submodel}')
# this will remove older cached models until
# there is sufficient room to load the requested model
self._make_cache_room(model_info.get_size(submodel))
# clean memory to make MemoryUsage() more accurate
gc.collect()
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
if mem_used := model_info.get_size(submodel):
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
cache_entry = _CacheRecord(self, model, mem_used)
self._cached_models[key] = cache_entry
with suppress(Exception):
self._cache_stack.remove(key)
self._cache_stack.append(key)
return self.ModelLocker(self, key, cache_entry.model, gpu_load)
class ModelLocker(object):
def __init__(self, cache, key, model, gpu_load):
self.gpu_load = gpu_load
self.cache = cache
self.key = key
self.model = model
self.cache_entry = self.cache._cached_models[self.key]
def __enter__(self) -> Any:
if not hasattr(self.model, 'to'):
return self.model
# NOTE that the model has to have the to() method in order for this
# code to move it into GPU!
if self.gpu_load:
self.cache_entry.lock()
try:
if self.cache.lazy_offloading:
self.cache._offload_unlocked_models()
if self.model.device != self.cache.execution_device:
self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}')
with VRAMUsage() as mem:
self.model.to(self.cache.execution_device) # move into GPU
self.cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
self.cache.logger.debug(f'Locking {self.key} in {self.cache.execution_device}')
self.cache._print_cuda_stats()
except:
self.cache_entry.unlock()
raise
# TODO: not fully understand
# in the event that the caller wants the model in RAM, we
# move it into CPU if it is in GPU and not locked
elif self.cache_entry.loaded and not self.cache_entry.locked:
self.model.to(self.cache.storage_device)
return self.model
def __exit__(self, type, value, traceback):
if not hasattr(self.model, 'to'):
return
self.cache_entry.unlock()
if not self.cache.lazy_offloading:
self.cache._offload_unlocked_models()
self.cache._print_cuda_stats()
# TODO: should it be called untrack_model?
def uncache_model(self, cache_id: str):
with suppress(ValueError):
self._cache_stack.remove(cache_id)
self._cached_models.pop(cache_id, None)
def model_hash(
self,
model_path: Union[str, Path],
) -> str:
'''
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
:param model_path: Path to model file/directory on disk.
'''
return self._local_model_hash(model_path)
def cache_size(self) -> float:
"Return the current size of the cache, in GB"
current_cache_size = sum([m.size for m in self._cached_models.values()])
return current_cache_size / GIG
def _has_cuda(self) -> bool:
return self.execution_device.type == 'cuda'
def _print_cuda_stats(self):
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % self.cache_size()
cached_models = 0
loaded_models = 0
locked_models = 0
for model_info in self._cached_models.values():
cached_models += 1
if model_info.loaded:
loaded_models += 1
if model_info.locked:
locked_models += 1
self.logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}")
def _make_cache_room(self, model_size):
# calculate how much memory this model will require
#multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = sum([m.size for m in self._cached_models.values()])
if current_size + bytes_needed > maximum_size:
self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
pos = 0
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
refs = sys.getrefcount(cache_entry.model)
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")
# 2 refs:
# 1 from cache_entry
# 1 from getrefcount function
if not cache_entry.locked and refs <= 2:
self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
current_size -= cache_entry.size
del self._cache_stack[pos]
del self._cached_models[model_key]
del cache_entry
else:
pos += 1
gc.collect()
torch.cuda.empty_cache()
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
def _offload_unlocked_models(self):
for model_key, cache_entry in self._cached_models.items():
if not cache_entry.locked and cache_entry.loaded:
self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
cache_entry.model.to(self.storage_device)
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
sha = hashlib.sha256()
path = Path(model_path)
hashpath = path / "checksum.sha256"
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
with open(hashpath) as f:
hash = f.read()
return hash
self.logger.debug(f'computing hash of model {path.name}')
for file in list(path.rglob("*.ckpt")) \
+ list(path.rglob("*.safetensors")) \
+ list(path.rglob("*.pth")):
with open(file, "rb") as f:
while chunk := f.read(self.sha_chunksize):
sha.update(chunk)
hash = sha.hexdigest()
with open(hashpath, "w") as f:
f.write(hash)
return hash
class VRAMUsage(object):
def __init__(self):
self.vram = None
self.vram_used = 0
def __enter__(self):
self.vram = torch.cuda.memory_allocated()
return self
def __exit__(self, *args):
self.vram_used = torch.cuda.memory_allocated() - self.vram
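# Illustrative sketch: loading one submodel through the cache. MODEL_CLASSES is the
# registry defined in the sibling `models` package shown later in this diff; the
# model path is an assumption for illustration only.
def _example_load_unet(model_path: str):
    from .models import MODEL_CLASSES  # maps (base model, model type) -> wrapper class
    model_class = MODEL_CLASSES[BaseModelType.StableDiffusion1][ModelType.Pipeline]
    cache = ModelCache(max_cache_size=6.0, precision=torch.float16)
    locker = cache.get_model(
        model_path=model_path,
        model_class=model_class,
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Pipeline,
        submodel=SubModelType.UNet,
    )
    with locker as unet:
        # while the context is held the UNet is locked on the execution device
        return unet.config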

View File

@@ -0,0 +1,118 @@
"""
Routines for downloading and installing models.
"""
import json
import safetensors
import safetensors.torch
import shutil
import tempfile
import torch
import traceback
from dataclasses import dataclass
from diffusers import ModelMixin
from enum import Enum
from typing import Callable
from pathlib import Path
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from . import ModelManager
from .models import BaseModelType, ModelType, VariantType
from .model_probe import ModelProbe, ModelVariantInfo
from .model_cache import SilenceWarnings
class ModelInstall(object):
'''
This class is able to download and install several different kinds of
InvokeAI models. The helper function, if provided, is called on to distinguish
between v2-base and v2-768 stable diffusion pipelines. This usually involves
asking the user to select the proper type, as there is no way of distinguishing
the two types of v2 files programmatically (as far as I know).
'''
def __init__(self,
config: InvokeAIAppConfig,
model_base_helper: Callable[[Path],BaseModelType]=None,
clobber:bool = False
):
'''
:param config: InvokeAI configuration object
:param model_base_helper: A callable that accepts the Path to a checkpoint model and returns a BaseModelType enum
:param clobber: If true, models with colliding names will be overwritten
'''
self.config = config
self.clobber = clobber
self.helper = model_base_helper
self.prober = ModelProbe()
def install_checkpoint_file(self, checkpoint: Path)->dict:
'''
Install the checkpoint file at path and return a
configuration entry that can be added to `models.yaml`.
Model checkpoints and VAEs will be converted into
diffusers before installation. Note that the model manager
does not hold entries for anything but diffusers pipelines,
and the configuration file stanzas returned from such models
can be safely ignored.
'''
model_info = self.prober.probe(checkpoint, self.helper)
if not model_info:
raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")
key = ModelManager.create_key(
model_name = checkpoint.stem,
base_model = model_info.base_type,
model_type = model_info.model_type,
)
destination_path = self._dest_path(model_info) / checkpoint
destination_path.parent.mkdir(parents=True, exist_ok=True)
self._check_for_collision(destination_path)
stanza = {
key: dict(
name = checkpoint.stem,
description = f'{model_info.model_type} model {checkpoint.stem}',
base = model_info.base_type.value,
type = model_info.model_type.value,
variant = model_info.variant_type.value,
path = str(destination_path),
)
}
# non-pipeline; no conversion needed, just copy into right place
if model_info.model_type != ModelType.Pipeline:
shutil.copyfile(checkpoint, destination_path)
stanza[key].update({'format': 'checkpoint'})
# pipeline - conversion needed here
else:
destination_path = self._dest_path(model_info) / checkpoint.stem
config_file = self._pipeline_type_to_config_file(model_info.model_type)
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings():
convert_ckpt_to_diffusers(
checkpoint,
destination_path,
extract_ema=True,
original_config_file=config_file,
scan_needed=False,
)
stanza[key].update({'format': 'folder',
'path': destination_path, # no suffix on this
})
return stanza
def _check_for_collision(self, path: Path):
if not path.exists():
return
if self.clobber:
shutil.rmtree(path)
else:
raise ValueError(f"Destination {path} already exists. Won't overwrite unless clobber=True.")
def _staging_directory(self)->tempfile.TemporaryDirectory:
return tempfile.TemporaryDirectory(dir=self.config.root_path)
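# Illustrative sketch: installing a single checkpoint and collecting the models.yaml
# stanza it produces. InvokeAIAppConfig.get_config() is assumed to be the standard
# accessor for the application config; the checkpoint path is an assumption.
def _example_install(checkpoint_path: str) -> dict:
    config = InvokeAIAppConfig.get_config()
    installer = ModelInstall(config=config, clobber=False)
    return installer.install_checkpoint_file(Path(checkpoint_path))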

File diff suppressed because it is too large


@@ -0,0 +1,417 @@
import json
import traceback
import torch
import safetensors.torch
from dataclasses import dataclass
from enum import Enum
from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel
from pathlib import Path
from typing import Callable, Literal, Union, Dict
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings
@dataclass
class ModelVariantInfo(object):
model_type: ModelType
base_type: BaseModelType
variant_type: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
format: Literal['folder','checkpoint']
image_size: int
class ProbeBase(object):
'''forward declaration'''
pass
class ModelProbe(object):
PROBES = {
'folder': { },
'checkpoint': { },
}
CLASS2TYPE = {
'StableDiffusionPipeline' : ModelType.Pipeline,
'AutoencoderKL' : ModelType.Vae,
'ControlNetModel' : ModelType.ControlNet,
}
@classmethod
def register_probe(cls,
format: Literal['folder','file'],
model_type: ModelType,
probe_class: ProbeBase):
cls.PROBES[format][model_type] = probe_class
@classmethod
def heuristic_probe(cls,
model: Union[Dict, ModelMixin, Path],
prediction_type_helper: Callable[[Path],BaseModelType]=None,
)->ModelVariantInfo:
if isinstance(model,Path):
return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper)
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
else:
raise Exception("model parameter {model} is neither a Path, nor a model")
@classmethod
def probe(cls,
model_path: Path,
model: Union[Dict, ModelMixin] = None,
prediction_type_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo:
'''
Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. If the model is
already loaded into memory, you may provide it as model in order to avoid
opening it a second time. The prediction_type_helper callable is a function that receives
the path to the model and returns the BaseModelType. It is called to distinguish
between V2-Base and V2-768 SD models.
'''
if model_path:
format = 'folder' if model_path.is_dir() else 'checkpoint'
else:
format = 'folder' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
model_info = None
try:
model_type = cls.get_model_type_from_folder(model_path, model) \
if format == 'folder' \
else cls.get_model_type_from_checkpoint(model_path, model)
probe_class = cls.PROBES[format].get(model_type)
if not probe_class:
return None
probe = probe_class(model_path, model, prediction_type_helper)
base_type = probe.get_base_type()
variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type()
model_info = ModelVariantInfo(
model_type = model_type,
base_type = base_type,
variant_type = variant_type,
prediction_type = prediction_type,
upcast_attention = (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction),
format = format,
image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction \
) else 512,
)
except Exception as e:
return None
return model_info
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict)->ModelType:
if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors'):
return None
if model_path.name=='learned_embeds.bin':
return ModelType.TextualInversion
checkpoint = checkpoint or cls._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
return ModelType.Pipeline
if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
return ModelType.Vae
if "string_to_token" in state_dict or "emb_params" in state_dict:
return ModelType.TextualInversion
if any([x.startswith("lora") for x in state_dict.keys()]):
return ModelType.Lora
if any([x.startswith("control_model") for x in state_dict.keys()]):
return ModelType.ControlNet
if any([x.startswith("input_blocks") for x in state_dict.keys()]):
return ModelType.ControlNet
return None # give up
@classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
'''
Get the model type of a hugging-face style folder.
'''
class_name = None
if model:
class_name = model.__class__.__name__
else:
if (folder_path / 'learned_embeds.bin').exists():
return ModelType.TextualInversion
if (folder_path / 'pytorch_lora_weights.bin').exists():
return ModelType.Lora
i = folder_path / 'model_index.json'
c = folder_path / 'config.json'
config_path = i if i.exists() else c if c.exists() else None
if config_path:
with open(config_path,'r') as file:
conf = json.load(file)
class_name = conf['_class_name']
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type
# give up
raise ValueError("Unable to determine model type")
@classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model_path, model_path)
return torch.load(model_path)
else:
return safetensors.torch.load_file(model_path)
@classmethod
def _scan_model(cls, model_name, checkpoint):
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
"""
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise "The model {model_name} is potentially infected by malware. Aborting import."
###################################################
# Checkpoint probing
###################################################
class ProbeBase(object):
def get_base_type(self)->BaseModelType:
pass
def get_variant_type(self)->ModelVariantType:
pass
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
pass
class CheckpointProbeBase(ProbeBase):
def __init__(self,
checkpoint_path: Path,
checkpoint: dict,
helper: Callable[[Path],BaseModelType] = None
):
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
self.checkpoint_path = checkpoint_path
self.helper = helper
def get_base_type(self)->BaseModelType:
pass
def get_variant_type(self)-> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
if model_type != ModelType.Pipeline:
return ModelVariantType.Normal
state_dict = self.checkpoint.get('state_dict') or self.checkpoint
in_channels = state_dict[
"model.diffusion_model.input_blocks.0.0.weight"
].shape[1]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
else:
raise Exception("Cannot determine variant type")
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get('state_dict') or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
raise Exception("Cannot determine base type")
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
type = self.get_base_type()
if type == BaseModelType.StableDiffusion1:
return SchedulerPredictionType.Epsilon
checkpoint = self.checkpoint
state_dict = self.checkpoint.get('state_dict') or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if 'global_step' in checkpoint:
if checkpoint['global_step'] == 220000:
return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000:
return SchedulerPredictionType.VPrediction
if self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
else:
return None
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
lora_token_vector_length = (
checkpoint[key1].shape[1]
if key1 in checkpoint
else checkpoint[key2].shape[0]
if key2 in checkpoint
else 768
)
if lora_token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif lora_token_vector_length == 1024:
return BaseModelType.StableDiffusion2
else:
return None
class TextualInversionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
if 'string_to_token' in checkpoint:
token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1]
elif 'emb_params' in checkpoint:
token_dim = checkpoint['emb_params'].shape[-1]
else:
token_dim = list(checkpoint.values())[0].shape[0]
if token_dim == 768:
return BaseModelType.StableDiffusion1
elif token_dim == 1024:
return BaseModelType.StableDiffusion2
else:
return None
class ControlNetCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight',
'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight'
):
if key_name not in checkpoint:
continue
if checkpoint[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
elif checkpoint[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
elif self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
raise Exception("Unable to determine base type for {self.checkpoint_path}")
########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
def __init__(self,
folder_path: Path,
model: ModelMixin = None,
helper: Callable=None # not used
):
self.model = model
self.folder_path = folder_path
def get_variant_type(self)->ModelVariantType:
return ModelVariantType.Normal
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
if self.model:
unet_conf = self.model.unet.config
scheduler_conf = self.model.scheduler.config
else:
with open(self.folder_path / 'unet' / 'config.json','r') as file:
unet_conf = json.load(file)
with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
scheduler_conf = json.load(file)
if unet_conf['cross_attention_dim'] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf['cross_attention_dim'] == 1024:
return BaseModelType.StableDiffusion2
else:
raise ValueError(f'Unknown base model for {self.folder_path}')
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
if self.model:
scheduler_conf = self.model.scheduler.config
else:
with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
scheduler_conf = json.load(file)
if scheduler_conf['prediction_type'] == "v_prediction":
return SchedulerPredictionType.VPrediction
elif scheduler_conf['prediction_type'] == 'epsilon':
return SchedulerPredictionType.Epsilon
else:
return None
def get_variant_type(self)->ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the
# "normal" variant type
try:
if self.model:
conf = self.model.unet.config
else:
config_file = self.folder_path / 'unet' / 'config.json'
with open(config_file,'r') as file:
conf = json.load(file)
in_channels = conf['in_channels']
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
except:
pass
return ModelVariantType.Normal
class VaeFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
return BaseModelType.StableDiffusion1
class TextualInversionFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
path = self.folder_path / 'learned_embeds.bin'
if not path.exists():
return None
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
return TextualInversionCheckpointProbe(None,checkpoint=checkpoint).get_base_type()
class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
config_file = self.folder_path / 'config.json'
if not config_file.exists():
raise Exception(f"Cannot determine base type for {self.folder_path}")
with open(config_file,'r') as file:
config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768
return BaseModelType.StableDiffusion1 \
if config['cross_attention_dim']==768 \
else BaseModelType.StableDiffusion2
class LoRAFolderProbe(FolderProbeBase):
# I've never seen one of these in the wild, so this is a noop
pass
############## register probe classes ######
ModelProbe.register_probe('folder', ModelType.Pipeline, PipelineFolderProbe)
ModelProbe.register_probe('folder', ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe('folder', ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe('folder', ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe('folder', ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe('checkpoint', ModelType.Pipeline, PipelineCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.ControlNet, ControlNetCheckpointProbe)
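# Illustrative sketch: probing a model file on disk with the probe classes registered
# above. The path and the printed summary are assumptions; the prediction-type helper
# is omitted here.
def _example_probe(model_path: Path) -> ModelVariantInfo:
    info = ModelProbe.probe(model_path)
    if info is None:
        raise ValueError(f"Could not determine the type of {model_path}")
    print(
        f"{model_path.name}: {info.base_type.value}/{info.model_type.value}, "
        f"variant={info.variant_type.value}, format={info.format}"
    )
    return info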


@@ -0,0 +1,95 @@
import inspect
from enum import Enum
from pydantic import BaseModel
from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel
from .lora import LoRAModel
from .controlnet import ControlNetModel # TODO:
from .textual_inversion import TextualInversionModel
MODEL_CLASSES = {
BaseModelType.StableDiffusion1: {
ModelType.Pipeline: StableDiffusion1Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusion2: {
ModelType.Pipeline: StableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModelType.Kandinsky2_1: {
# ModelType.Pipeline: Kandinsky2_1Model,
# ModelType.MoVQ: MoVQModel,
# ModelType.Lora: LoRAModel,
# ModelType.ControlNet: ControlNetModel,
# ModelType.TextualInversion: TextualInversionModel,
#},
}
MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list()
class OpenAPIModelInfoBase(BaseModel):
name: str
base_model: BaseModelType
type: ModelType
for base_model, models in MODEL_CLASSES.items():
for model_type, model_class in models.items():
model_configs = set(model_class._get_configs().values())
model_configs.discard(None)
MODEL_CONFIGS.extend(model_configs)
for cfg in model_configs:
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars():
continue
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
__annotations__ = dict(
type=Literal[model_type.value],
),
))
#globals()[openapi_cfg_name] = api_wrapper
vars()[openapi_cfg_name] = api_wrapper
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
def get_model_config_enums():
enums = list()
for model_config in MODEL_CONFIGS:
fields = inspect.get_annotations(model_config)
try:
field = fields["model_format"]
except:
raise Exception("format field not found")
# model_format: None
# model_format: SomeModelFormat
# model_format: Literal[SomeModelFormat.Diffusers]
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
enums.append(field)
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
enums.append(type(field.__args__[0]))
elif field is None:
pass
else:
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
return enums
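# Illustrative sketch: resolving the wrapper class for a (base model, model type)
# pair from MODEL_CLASSES and probing a config for a model on disk; the path
# argument is an assumption for illustration.
def _example_resolve_class(base_model: BaseModelType, model_type: ModelType, path: str):
    model_class = MODEL_CLASSES[base_model][model_type]
    # probe_config() detects the on-disk format and returns the matching *Config
    return model_class.probe_config(path)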


@@ -0,0 +1,415 @@
import os
import sys
import json
import typing
import inspect
from enum import Enum
from abc import ABCMeta, abstractmethod
import torch
import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin
from contextlib import suppress
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
#Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
Pipeline = "pipeline"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
class SubModelType(str, Enum):
UNet = "unet"
TextEncoder = "text_encoder"
Tokenizer = "tokenizer"
Vae = "vae"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
#MoVQ = "movq"
class ModelVariantType(str, Enum):
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class SchedulerPredictionType(str, Enum):
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelError(str, Enum):
NotFound = "not_found"
class ModelConfigBase(BaseModel):
path: str # or Path
description: Optional[str] = Field(None)
model_format: Optional[str] = Field(None)
# do not save to config
error: Optional[ModelError] = Field(None)
class Config:
use_enum_values = True
class EmptyConfigLoader(ConfigMixin):
@classmethod
def load_config(cls, *args, **kwargs):
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
T_co = TypeVar('T_co', covariant=True)
class classproperty(Generic[T_co]):
def __init__(self, fget: Callable[[Any], T_co]) -> None:
self.fget = fget
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
return self.fget(owner)
def __set__(self, instance: Optional[Any], value: Any) -> None:
raise AttributeError('cannot set attribute')
class ModelBase(metaclass=ABCMeta):
#model_path: str
#base_model: BaseModelType
#model_type: ModelType
def __init__(
self,
model_path: str,
base_model: BaseModelType,
model_type: ModelType,
):
self.model_path = model_path
self.base_model = base_model
self.model_type = model_type
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
if all(t is None for t in subtypes):
return None
elif any(t is None for t in subtypes):
raise Exception(f"Unsupported definition: {subtypes}")
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
else:
res_type = sys.modules["diffusers"]
res_type = getattr(res_type, "pipelines")
for subtype in subtypes:
res_type = getattr(res_type, subtype)
return res_type
@classmethod
def _get_configs(cls):
with suppress(Exception):
return cls.__configs
configs = dict()
for name in dir(cls):
if name.startswith("__"):
continue
value = getattr(cls, name)
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
continue
fields = inspect.get_annotations(value)
try:
field = fields["model_format"]
except:
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
for model_format in field:
configs[model_format.value] = value
elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
for model_format in field.__args__:
configs[model_format.value] = value
elif field is None:
configs[None] = value
else:
raise Exception(f"Unsupported format definition in {cls.__qualname__}")
cls.__configs = configs
return cls.__configs
@classmethod
def create_config(cls, **kwargs) -> ModelConfigBase:
if "model_format" not in kwargs:
raise Exception("Field 'model_format' not found in model config")
configs = cls._get_configs()
return configs[kwargs["model_format"]](**kwargs)
@classmethod
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
return cls.create_config(
path=path,
model_format=cls.detect_format(path),
)
@classmethod
@abstractmethod
def detect_format(cls, path: str) -> str:
raise NotImplementedError()
@classproperty
@abstractmethod
def save_to_config(cls) -> bool:
raise NotImplementedError()
@abstractmethod
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
raise NotImplementedError()
@abstractmethod
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> Any:
raise NotImplementedError()
class DiffusersModel(ModelBase):
#child_types: Dict[str, Type]
#child_sizes: Dict[str, int]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
try:
config_data = DiffusionPipeline.load_config(self.model_path)
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
except:
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
config_data.pop("_ignore_files", None)
# retrieve all folder_names that contain relevant files
child_components = [k for k, v in config_data.items() if isinstance(v, list)]
for child_name in child_components:
child_type = self._hf_definition_to_type(config_data[child_name])
self.child_types[child_name] = child_type
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
return self.child_sizes[child_type]
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
# return pipeline in different function to pass more arguments
if child_type is None:
raise Exception("Child model type can't be null on diffusers model")
if child_type not in self.child_types:
return None # TODO: or raise
if torch_dtype == torch.float16:
variants = ["fp16", None]
else:
variants = [None, "fp16"]
# TODO: better error handling(differentiate not found from others)
for variant in variants:
try:
# TODO: set cache_dir to /dev/null to be sure that cache not used?
model = self.child_types[child_type].from_pretrained(
self.model_path,
subfolder=child_type.value,
torch_dtype=torch_dtype,
variant=variant,
local_files_only=True,
)
break
except Exception as e:
#print("====ERR LOAD====")
#print(f"{variant}: {e}")
pass
else:
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
# calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model)
return model
#def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
def calc_model_size_by_fs(
model_path: str,
subfolder: Optional[str] = None,
variant: Optional[str] = None
):
if subfolder is not None:
model_path = os.path.join(model_path, subfolder)
# this can happen when, for example, the safety checker
# is not downloaded.
if not os.path.exists(model_path):
return 0
all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
other_files = set(all_files) - fp16_files - bit8_files
if variant is None:
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.endswith(index_postfix):
continue
try:
with open(os.path.join(model_path, file), "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.endswith(file_format)]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = os.stat(os.path.join(model_path, model_file))
model_size += file_stats.st_size
return model_size
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
def calc_model_size_by_data(model) -> int:
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return _calc_model_by_data(model)
else:
return 0
def _calc_pipeline_by_data(pipeline) -> int:
res = 0
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
return res
def _calc_model_by_data(model) -> int:
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
return mem
def _fast_safetensors_reader(path: str):
checkpoint = dict()
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), 'little')
definition_json = f.read(definition_len)
definition = json.loads(definition_json)
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
raise Exception("Supported only pytorch safetensors files")
definition.pop("__metadata__", None)
for key, info in definition.items():
dtype = {
"I8": torch.int8,
"I16": torch.int16,
"I32": torch.int32,
"I64": torch.int64,
"F16": torch.float16,
"F32": torch.float32,
"F64": torch.float64,
}[info["dtype"]]
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
return checkpoint
def read_checkpoint_meta(path: str):
if path.endswith(".safetensors"):
try:
checkpoint = _fast_safetensors_reader(path)
except:
# TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu")
else:
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
import warnings
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
class SilenceWarnings(object):
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self):
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter('ignore')
def __exit__(self, type, value, traceback):
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter('default')
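# Illustrative sketch: inspecting a checkpoint without materialising its weights and
# estimating the size of a diffusers subfolder on disk. The paths are assumptions.
def _example_inspect(checkpoint_path: str, diffusers_path: str) -> None:
    with SilenceWarnings():
        meta = read_checkpoint_meta(checkpoint_path)
    state_dict = meta.get("state_dict", meta)
    print(f"{len(state_dict)} tensors in {checkpoint_path}")
    unet_bytes = calc_model_size_by_fs(diffusers_path, subfolder="unet", variant="fp16")
    print(f"fp16 unet weights: {unet_bytes / 2**30:.2f} GB")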


@@ -0,0 +1,92 @@
import os
import torch
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
EmptyConfigLoader,
calc_model_size_by_fs,
calc_model_size_by_data,
classproperty,
)
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
class Config(ModelConfigBase):
model_format: ControlNetModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
super().__init__(model_path, base_model, model_type)
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
model_class_name = config.get("_class_name", None)
if model_class_name not in {"ControlNetModel"}:
raise Exception(f"Invalid ControlNet model! Unknown _class_name: {model_class_name}")
try:
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except:
raise Exception("Invalid ControlNet model!")
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in controlnet model")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in controlnet model")
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return ControlNetModelFormat.Diffusers
else:
return ControlNetModelFormat.Checkpoint
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else:
return model_path
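# Illustrative sketch: loading a diffusers-format ControlNet through this wrapper
# (checkpoint-format ControlNets are not converted yet, per convert_if_required
# above); the path argument is an assumption.
def _example_load_controlnet(path: str):
    wrapper = ControlNetModel(path, BaseModelType.StableDiffusion1, ModelType.ControlNet)
    return wrapper.get_model(torch_dtype=torch.float16)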


@@ -0,0 +1,76 @@
import os
import torch
from enum import Enum
from typing import Optional, Union, Literal
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
classproperty,
)
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
Diffusers = "diffusers"
class LoRAModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
model_format: LoRAModelFormat # TODO:
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in lora")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in lora")
model = LoRAModelRaw.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
)
self.model_size = model.calc_size()
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return LoRAModelFormat.Diffusers
else:
return LoRAModelFormat.LyCORIS
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
# TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported")
else:
return model_path


@@ -0,0 +1,321 @@
import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
ModelVariantType,
DiffusersModel,
SchedulerPredictionType,
SilenceWarnings,
read_checkpoint_meta,
classproperty,
)
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.Pipeline
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.Pipeline,
)
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == StableDiffusion1ModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusion1ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
else:
raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 1.* model format")
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return StableDiffusion1ModelFormat.Diffusers
else:
return StableDiffusion1ModelFormat.Checkpoint
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
assert model_path == config.path
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion1,
model_config=config,
output_path=output_path,
) # TODO: args
else:
return model_path
class StableDiffusion2ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs are overwritten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion2
assert model_type == ModelType.Pipeline
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion2,
model_type=ModelType.Pipeline,
)
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == StableDiffusion2ModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusion2ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
else:
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 5:
variant = ModelVariantType.Depth
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 2.* model format")
if variant == ModelVariantType.Normal:
prediction_type = SchedulerPredictionType.VPrediction
upcast_attention = True
else:
prediction_type = SchedulerPredictionType.Epsilon
upcast_attention = False
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return StableDiffusion2ModelFormat.Diffusers
else:
return StableDiffusion2ModelFormat.Checkpoint
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
assert model_path == config.path
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2,
model_config=config,
output_path=output_path,
) # TODO: args
else:
return model_path
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ckpt_configs = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
# code further down manually sets upcast_attention and v_prediction
ModelVariantType.Normal: "v2-inference.yaml",
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
}
}
try:
# TODO: path
#model_config.config = app_config.config_dir / "stable-diffusion" / ckpt_configs[version][model_config.variant]
#return InvokeAIAppConfig.get_config().legacy_conf_dir / ckpt_configs[version][variant]
return InvokeAIAppConfig.get_config().root_dir / "configs" / "stable-diffusion" / ckpt_configs[version][variant]
except:
return None
# TODO: rework
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
output_path: str,
) -> str:
"""
Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
if model_config.config is None:
model_config.config = _select_ckpt_config(version, model_config.variant)
if model_config.config is None:
raise Exception(f"Model variant {model_config.variant} not supported for {version}")
weights = app_config.root_path / model_config.path
config_file = app_config.root_path / model_config.config
output_path = Path(output_path)
if version == BaseModelType.StableDiffusion1:
upcast_attention = False
prediction_type = SchedulerPredictionType.Epsilon
elif version == BaseModelType.StableDiffusion2:
upcast_attention = model_config.upcast_attention
prediction_type = model_config.prediction_type
else:
raise Exception(f"Unknown model provided: {version}")
# return cached version if it exists
if output_path.exists():
return output_path
# TODO: it is probably more correct to convert with the embedded vae,
# otherwise if the user later deletes the custom vae they are left without the embedded one as well
#vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
output_path,
model_version=version,
model_variant=model_config.variant,
original_config_file=config_file,
extract_ema=True,
upcast_attention=upcast_attention,
prediction_type=prediction_type,
scan_needed=True,
model_root=app_config.models_path,
)
return output_path
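As a minimal sketch of the in_channels probe used by probe_config() above (not part of the diff; the helper name guess_variant is hypothetical and only illustrates how the channel count maps to a variant):
def guess_variant(in_channels: int) -> str:
    # 9 input channels -> inpainting UNet, 5 -> depth-conditioned (SD2 only), 4 -> normal
    if in_channels == 9:
        return "inpaint"
    if in_channels == 5:
        return "depth"
    if in_channels == 4:
        return "normal"
    raise ValueError(f"unexpected in_channels: {in_channels}")

assert guess_variant(4) == "normal"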

View File

@@ -0,0 +1,64 @@
import os
import torch
from typing import Optional
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
classproperty,
)
# TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw
class TextualInversionModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
model_format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.TextualInversion
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
model = TextualInversionModelRaw.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
)
self.model_size = model.embedding.nelement() * model.embedding.element_size()
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
return None
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
return model_path
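As a worked example of the size calculation in get_model() above (a sketch assuming torch is installed; the tensor shape is arbitrary):
import torch

embedding = torch.zeros(2, 768, dtype=torch.float16)  # e.g. a two-vector, SD1-sized embedding
size_bytes = embedding.nelement() * embedding.element_size()
assert size_bytes == 2 * 768 * 2  # fp16 = 2 bytes per element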

View File

@@ -0,0 +1,166 @@
import os
import torch
import safetensors
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
ModelVariantType,
EmptyConfigLoader,
calc_model_size_by_fs,
calc_model_size_by_data,
classproperty,
)
from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
class VaeModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class VaeModel(ModelBase):
#vae_class: Type
#model_size: int
class Config(ModelConfigBase):
model_format: VaeModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae
super().__init__(model_path, base_model, model_type)
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid vae model! (config.json not found or invalid)")
try:
vae_class_name = config.get("_class_name", "AutoencoderKL")
self.vae_class = self._hf_definition_to_type(["diffusers", vae_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except:
raise Exception("Invalid vae model! (Unkown vae type)")
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in vae model")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in vae model")
model = self.vae_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return VaeModelFormat.Diffusers
else:
return VaeModelFormat.Checkpoint
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
return _convert_vae_ckpt_and_cache(
weights_path=model_path,
output_path=output_path,
base_model=base_model,
model_config=config,
)
else:
return model_path
# TODO: rework
def _convert_vae_ckpt_and_cache(
weights_path: str,
output_path: str,
base_model: BaseModelType,
model_config: ModelConfigBase,
) -> str:
"""
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
object, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
weights_path = app_config.root_dir / weights_path
output_path = Path(output_path)
"""
this size is used only when tiling is enabled, to split the input into tiles
the sizes in the configs from the stable diffusion repos (1 and 2) are set to 256
on huggingface they are:
1.5 - 512
1.5-inpainting - 256
2-inpainting - 512
2-depth - 256
2-base - 512
2 - 768
2.1-base - 768
2.1 - 768
"""
image_size = 512
# return cached version if it exists
if output_path.exists():
return output_path
if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
from .stable_diffusion import _select_ckpt_config
# all sd models use same vae settings
config_file = _select_ckpt_config(base_model, ModelVariantType.Normal)
else:
raise Exception(f"Vae conversion not supported for model type: {base_model}")
# this avoids circular import error
from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
if weights_path.suffix == '.safetensors':
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
else:
checkpoint = torch.load(weights_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
config = OmegaConf.load(config_file)
vae_model = convert_ldm_vae_to_diffusers(
checkpoint=checkpoint,
vae_config=config,
image_size=image_size,
)
vae_model.save_pretrained(
output_path,
safe_serialization=is_safetensors_available()
)
return output_path
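The conversion helpers above follow a convert-once-and-cache pattern; a minimal sketch of that pattern, with `convert` standing in for the real conversion call:
from pathlib import Path
from typing import Callable

def convert_and_cache(weights: Path, output: Path, convert: Callable[[Path, Path], None]) -> Path:
    if output.exists():       # reuse the cached diffusers conversion if it already exists
        return output
    convert(weights, output)  # otherwise convert once and persist the result
    return output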

View File

@@ -1,9 +0,0 @@
"""
Initialization file for invokeai.backend.prompting
"""
from .conditioning import (
get_prompt_structure,
get_tokens_for_prompt_object,
get_uc_and_c_and_ec,
split_weighted_subprompts,
)

View File

@@ -1,296 +0,0 @@
"""
This module handles the generation of the conditioning tensors.
Useful function exports:
get_uc_and_c_and_ec() gets the conditioned and unconditioned conditioning tensors, plus the edited conditioning if we're doing cross-attention control
"""
import re
from typing import Optional, Union
from compel import Compel
from compel.prompt_parser import (
Blend,
CrossAttentionControlSubstitute,
FlattenedPrompt,
Fragment,
PromptParser,
Conjunction,
)
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from ..stable_diffusion import InvokeAIDiffuserComponent
from ..util import torch_dtype
config = InvokeAIAppConfig.get_config()
def get_uc_and_c_and_ec(prompt_string,
model: InvokeAIDiffuserComponent,
log_tokens=False, skip_normalize_legacy_blend=False):
# lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used.
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
compel = Compel(tokenizer=model.tokenizer,
text_encoder=model.text_encoder,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False,
)
# get rid of any newline characters
prompt_string = prompt_string.replace("\n", " ")
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
positive_conjunction: Conjunction
if legacy_blend is not None:
positive_conjunction = legacy_blend
else:
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
positive_prompt = positive_conjunction.prompts[0]
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
if log_tokens or config.log_tokenization:
log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
cross_attention_control_args=options.get(
'cross_attention_control', None))
return uc, c, ec
def get_prompt_structure(
prompt_string, skip_normalize_legacy_blend: bool = False
) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
(
positive_prompt_string,
negative_prompt_string,
) = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(
positive_prompt_string, skip_normalize_legacy_blend
)
positive_prompt: Conjunction
if legacy_blend is not None:
positive_conjunction = legacy_blend
else:
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
positive_prompt = positive_conjunction.prompts[0]
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
return positive_prompt, negative_prompt
def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
) -> int:
if type(prompt) is Blend:
blend: Blend = prompt
return max(
[
get_max_token_count(tokenizer, c, truncate_if_too_long)
for c in blend.prompts
]
)
else:
return len(
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
)
def get_tokens_for_prompt_object(
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> [str]:
if type(parsed_prompt) is Blend:
raise ValueError(
"Blend is not supported here - you need to get tokens for each of its .children"
)
text_fragments = [
x.text
if type(x) is Fragment
else (
" ".join([f.text for f in x.original])
if type(x) is CrossAttentionControlSubstitute
else str(x)
)
for x in parsed_prompt.children
]
text = " ".join(text_fragments)
tokens = tokenizer.tokenize(text)
if truncate_if_too_long:
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
tokens = tokens[0:max_tokens_length]
return tokens
def split_prompt_to_positive_and_negative(prompt_string_uncleaned: str):
unconditioned_words = ""
unconditional_regex = r"\[(.*?)\]"
unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
if len(unconditionals) > 0:
unconditioned_words = " ".join(unconditionals)
# Remove Unconditioned Words From Prompt
unconditional_regex_compile = re.compile(unconditional_regex)
clean_prompt = unconditional_regex_compile.sub(" ", prompt_string_uncleaned)
prompt_string_cleaned = re.sub(" +", " ", clean_prompt)
else:
prompt_string_cleaned = prompt_string_uncleaned
return prompt_string_cleaned, unconditioned_words
def log_tokenization(
positive_prompt: Union[Blend, FlattenedPrompt],
negative_prompt: Union[Blend, FlattenedPrompt],
tokenizer,
):
logger.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
logger.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
log_tokenization_for_prompt_object(
negative_prompt, tokenizer, display_label_prefix="(negative prompt)"
)
def log_tokenization_for_prompt_object(
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
display_label_prefix = display_label_prefix or ""
if type(p) is Blend:
blend: Blend = p
for i, c in enumerate(blend.prompts):
log_tokenization_for_prompt_object(
c,
tokenizer,
display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
)
elif type(p) is FlattenedPrompt:
flattened_prompt: FlattenedPrompt = p
if flattened_prompt.wants_cross_attention_control:
original_fragments = []
edited_fragments = []
for f in flattened_prompt.children:
if type(f) is CrossAttentionControlSubstitute:
original_fragments += f.original
edited_fragments += f.edited
else:
original_fragments.append(f)
edited_fragments.append(f)
original_text = " ".join([x.text for x in original_fragments])
log_tokenization_for_text(
original_text,
tokenizer,
display_label=f"{display_label_prefix}(.swap originals)",
)
edited_text = " ".join([x.text for x in edited_fragments])
log_tokenization_for_text(
edited_text,
tokenizer,
display_label=f"{display_label_prefix}(.swap replacements)",
)
else:
text = " ".join([x.text for x in flattened_prompt.children])
log_tokenization_for_text(
text, tokenizer, display_label=display_label_prefix
)
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
"""shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
"""
tokens = tokenizer.tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace("</w>", " ")
# alternate color
s = (usedTokens % 6) + 1
if truncate_if_too_long and i >= tokenizer.model_max_length:
discarded = discarded + f"\x1b[0;3{s};40m{token}"
else:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
if usedTokens > 0:
logger.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
logger.debug(f"{tokenized}\x1b[0m")
if discarded != "":
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
logger.debug(f"{discarded}\x1b[0m")
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1:
return None
strings = [x[0] for x in weighted_subprompts]
pp = PromptParser()
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
flattened_prompts = []
weights = []
for i, x in enumerate(parsed_conjunctions):
if len(x.prompts)>0:
flattened_prompts.append(x.prompts[0])
weights.append(weighted_subprompts[i][1])
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
def split_weighted_subprompts(text, skip_normalize=False) -> list:
"""
Legacy blend parsing.
grabs all text up to the first occurrence of ':'
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
if ':' has no value defined, defaults to 1.0
repeats until no text remains
"""
prompt_parser = re.compile(
"""
(?P<prompt> # capture group for 'prompt'
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
) # end 'prompt'
(?: # non-capture group
:+ # match one or more ':' characters
(?P<weight> # capture group for 'weight'
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
)? # end weight capture group, make optional
\s* # strip spaces after weight
| # OR
$ # else, if no ':' then match end of line
) # end non-capture group
""",
re.VERBOSE,
)
parsed_prompts = [
(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
for match in re.finditer(prompt_parser, text)
]
if len(parsed_prompts) == 0:
return []
if skip_normalize:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0:
logger.warning(
"Subprompt weights add up to zero. Discarding and using even weights instead."
)
equal_weight = 1 / max(len(parsed_prompts), 1)
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
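As a worked example of the legacy blend syntax handled by split_weighted_subprompts() above (a sketch, not calling the module itself): the prompt "mountain:1 lake:3" yields sub-prompts with weights 1.0 and 3.0, which normalize to 0.25 and 0.75.
weights = [1.0, 3.0]
normalized = [w / sum(weights) for w in weights]
assert normalized == [0.25, 0.75]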

View File

@@ -5,7 +5,7 @@ class Restoration:
pass
def load_face_restore_models(
self, gfpgan_model_path="./models/gfpgan/GFPGANv1.4.pth"
self, gfpgan_model_path="./models/core/face_restoration/gfpgan/GFPGANv1.4.pth"
):
# Load GFPGAN
gfpgan = self.load_gfpgan(gfpgan_model_path)

View File

@@ -15,7 +15,7 @@ pretrained_model_url = (
class CodeFormerRestoration:
def __init__(
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
self, codeformer_dir="./models/core/face_restoration/codeformer", codeformer_model_path="codeformer.pth"
) -> None:
self.globals = InvokeAIAppConfig.get_config()
@@ -24,7 +24,7 @@ class CodeFormerRestoration:
self.codeformer_model_exists = self.model_path.exists()
if not self.codeformer_model_exists:
logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
logger.error(f"NOT FOUND: CodeFormer model not found at {self.model_path}")
sys.path.append(os.path.abspath(codeformer_dir))
def process(self, image, strength, device, seed=None, fidelity=0.75):
@@ -71,7 +71,7 @@ class CodeFormerRestoration:
upscale_factor=1,
use_parse=True,
device=device,
model_rootpath = self.globals.root_dir / "gfpgan" / "weights"
model_rootpath = self.globals.model_path / 'core/face_restoration/gfpgan/weights'
)
face_helper.clean_all()
face_helper.read_image(bgr_image_array)

View File

@@ -18,7 +18,7 @@ class GFPGAN:
self.gfpgan_model_exists = os.path.isfile(self.model_path)
if not self.gfpgan_model_exists:
logger.error("NOT FOUND: GFPGAN model not found at " + self.model_path)
logger.error(f"NOT FOUND: GFPGAN model not found at {self.model_path}")
return None
def model_exists(self):

View File

@@ -30,8 +30,8 @@ class ESRGAN:
upscale=4,
act_type="prelu",
)
model_path = config.root_dir / "models/realesrgan/realesr-general-x4v3.pth"
wdn_model_path = config.root_dir / "models/realesrgan/realesr-general-wdn-x4v3.pth"
model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-x4v3.pth"
wdn_model_path = config.models_path / "core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
scale = 4
bg_upsampler = RealESRGANer(

View File

@@ -30,18 +30,10 @@ class SafetyChecker(object):
self.device = device
try:
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_model_path = config.cache_dir
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
safety_model_id,
local_files_only=True,
cache_dir=safety_model_path,
)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
safety_model_id,
local_files_only=True,
cache_dir=safety_model_path,
)
safety_model_id = config.models_path / 'core/convert/stable-diffusion-safety-checker'
feature_extractor_id = config.models_path / 'core/convert/stable-diffusion-safety-checker-extractor'
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_id)
except Exception:
logger.error(
"An error was encountered while installing the safety checker:"

View File

@@ -1,7 +1,6 @@
"""
Initialization file for the invokeai.backend.stable_diffusion package
"""
from .concepts_lib import HuggingFaceConceptsLibrary
from .diffusers_pipeline import (
ConditioningData,
PipelineIntermediateState,
@@ -10,4 +9,3 @@ from .diffusers_pipeline import (
from .diffusion import InvokeAIDiffuserComponent
from .diffusion.cross_attention_map_saving import AttentionMapSaver
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
from .textual_inversion_manager import TextualInversionManager

View File

@@ -1,275 +0,0 @@
"""
Query and install embeddings from the HuggingFace SD Concepts Library
at https://huggingface.co/sd-concepts-library.
The interface is through the Concepts() object.
"""
import os
import re
from typing import Callable
from urllib import error as ul_error
from urllib import request
from huggingface_hub import (
HfApi,
HfFolder,
ModelFilter,
hf_hub_url,
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.app.services.config import InvokeAIAppConfig
logger = InvokeAILogger.getLogger()
class HuggingFaceConceptsLibrary(object):
def __init__(self, root=None):
"""
Initialize the Concepts object. May optionally pass a root directory.
"""
self.config = InvokeAIAppConfig.get_config()
self.root = root or self.config.root
self.hf_api = HfApi()
self.local_concepts = dict()
self.concept_list = None
self.concepts_loaded = dict()
self.triggers = dict() # concept name to trigger phrase
self.concept_names = dict() # trigger phrase to concept name
self.match_trigger = re.compile(
"(<[\w\- >]+>)"
) # trigger is slightly less restrictive than HF concept name
self.match_concept = re.compile(
"<([\w\-]+)>"
) # HF concept name can only contain A-Za-z0-9_-
def list_concepts(self) -> list:
"""
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
Also adds local concepts in invokeai/embeddings folder.
"""
local_concepts_now = self.get_local_concepts(
os.path.join(self.root, "embeddings")
)
local_concepts_to_add = set(local_concepts_now).difference(
set(self.local_concepts)
)
self.local_concepts.update(local_concepts_now)
if self.concept_list is not None:
if local_concepts_to_add:
self.concept_list.extend(list(local_concepts_to_add))
return self.concept_list
return self.concept_list
elif self.config.internet_available is True:
try:
models = self.hf_api.list_models(
filter=ModelFilter(model_name="sd-concepts-library/")
)
self.concept_list = [a.id.split("/")[1] for a in models]
# when init, add all in dir. when not init, add only concepts added between init and now
self.concept_list.extend(list(local_concepts_to_add))
except Exception as e:
logger.warning(
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
)
logger.warning(
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
)
return self.concept_list
else:
return self.concept_list
def get_concept_model_path(self, concept_name: str) -> str:
"""
Returns the path to the 'learned_embeds.bin' file in
the named concept. Returns None if invalid or cannot
be downloaded.
"""
if not concept_name in self.list_concepts():
logger.warning(
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
)
return None
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
def concept_to_trigger(self, concept_name: str) -> str:
"""
Given a concept name returns its trigger by looking in the
"token_identifier.txt" file.
"""
if concept_name in self.triggers:
return self.triggers[concept_name]
elif self.concept_is_local(concept_name):
trigger = f"<{concept_name}>"
self.triggers[concept_name] = trigger
self.concept_names[trigger] = concept_name
return trigger
file = self.get_concept_file(
concept_name, "token_identifier.txt", local_only=True
)
if not file:
return None
with open(file, "r") as f:
trigger = f.readline()
trigger = trigger.strip()
self.triggers[concept_name] = trigger
self.concept_names[trigger] = concept_name
return trigger
def trigger_to_concept(self, trigger: str) -> str:
"""
Given a trigger phrase, maps it to the concept library name.
Only works if concept_to_trigger() has previously been called
on this library. There needs to be a persistent database for
this.
"""
concept = self.concept_names.get(trigger, None)
return f"<{concept}>" if concept else f"{trigger}"
def replace_triggers_with_concepts(self, prompt: str) -> str:
"""
Given a prompt string that contains <trigger> tags, replace these
tags with the concept name. The reason for this is so that the
concept names get stored in the prompt metadata. There is no
control over colliding triggers in the SD concepts library, so it is
better to store the concept name (unique) than the concept trigger
(not necessarily unique!)
"""
if not prompt:
return prompt
triggers = self.match_trigger.findall(prompt)
if not triggers:
return prompt
def do_replace(match) -> str:
return self.trigger_to_concept(match.group(1)) or f"<{match.group(1)}>"
return self.match_trigger.sub(do_replace, prompt)
def replace_concepts_with_triggers(
self,
prompt: str,
load_concepts_callback: Callable[[list], any],
excluded_tokens: list[str],
) -> str:
"""
Given a prompt string that contains `<concept_name>` tags, replace
these tags with the appropriate trigger.
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
of `concepts_name` strings.
`excluded_tokens` are any tokens that should not be replaced, typically because they
are trigger tokens from a locally-loaded embedding.
"""
concepts = self.match_concept.findall(prompt)
if not concepts:
return prompt
load_concepts_callback(concepts)
def do_replace(match) -> str:
if excluded_tokens and f"<{match.group(1)}>" in excluded_tokens:
return f"<{match.group(1)}>"
return self.concept_to_trigger(match.group(1)) or f"<{match.group(1)}>"
return self.match_concept.sub(do_replace, prompt)
def get_concept_file(
self,
concept_name: str,
file_name: str = "learned_embeds.bin",
local_only: bool = False,
) -> str:
if not (
self.concept_is_downloaded(concept_name)
or self.concept_is_local(concept_name)
or local_only
):
self.download_concept(concept_name)
# get local path in invokeai/embeddings if local concept
if self.concept_is_local(concept_name):
concept_path = self._concept_local_path(concept_name)
path = concept_path
else:
concept_path = self._concept_path(concept_name)
path = os.path.join(concept_path, file_name)
return path if os.path.exists(path) else None
def concept_is_local(self, concept_name) -> bool:
return concept_name in self.local_concepts
def concept_is_downloaded(self, concept_name) -> bool:
concept_directory = self._concept_path(concept_name)
return os.path.exists(concept_directory)
def download_concept(self, concept_name) -> bool:
repo_id = self._concept_id(concept_name)
dest = self._concept_path(concept_name)
access_token = HfFolder.get_token()
header = [("Authorization", f"Bearer {access_token}")] if access_token else []
opener = request.build_opener()
opener.addheaders = header
request.install_opener(opener)
os.makedirs(dest, exist_ok=True)
succeeded = True
bytes = 0
def tally_download_size(chunk, size, total):
nonlocal bytes
if chunk == 0:
bytes += total
logger.info(f"Downloading {repo_id}...", end="")
try:
for file in (
"README.md",
"learned_embeds.bin",
"token_identifier.txt",
"type_of_concept.txt",
):
url = hf_hub_url(repo_id, file)
request.urlretrieve(
url, os.path.join(dest, file), reporthook=tally_download_size
)
except ul_error.HTTPError as e:
if e.code == 404:
logger.warning(
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
)
else:
logger.warning(
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
)
os.rmdir(dest)
return False
except ul_error.URLError as e:
logger.error(
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
)
os.rmdir(dest)
return False
logger.info("...{:.2f}Kb".format(bytes / 1024))
return succeeded
def _concept_id(self, concept_name: str) -> str:
return f"sd-concepts-library/{concept_name}"
def _concept_path(self, concept_name: str) -> str:
return os.path.join(self.root, "models", "sd-concepts-library", concept_name)
def _concept_local_path(self, concept_name: str) -> str:
filename = self.local_concepts[concept_name]
return os.path.join(self.root, "embeddings", filename)
def get_local_concepts(self, loc_dir: str):
locs_dic = dict()
if os.path.isdir(loc_dir):
for file in os.listdir(loc_dir):
f = os.path.splitext(file)
if f[1] == ".bin" or f[1] == ".pt":
locs_dic[f[0]] = file
return locs_dic
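A minimal standalone sketch of the trigger and concept regexes defined in __init__ above; the prompt string is illustrative only.
import re

match_trigger = re.compile(r"(<[\w\- >]+>)")
match_concept = re.compile(r"<([\w\-]+)>")
prompt = "a painting of <my-style> mountains"
assert match_trigger.findall(prompt) == ["<my-style>"]
assert match_concept.findall(prompt) == ["my-style"]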

View File

@@ -16,14 +16,13 @@ from accelerate.utils import set_seed
import psutil
import torch
import torchvision.transforms as T
from compel import EmbeddingsProvider
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline,
@@ -48,7 +47,6 @@ from .diffusion import (
PostprocessingSettings,
)
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
from .textual_inversion_manager import TextualInversionManager
@dataclass
class PipelineIntermediateState:
@@ -217,10 +215,12 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None)
weight: Union[float, List[float]]= Field(default=1.0)
image_tensor: torch.Tensor = Field(default=None)
weight: Union[float, List[float]] = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
control_mode: str = Field(default="balanced")
@dataclass(frozen=True)
class ConditioningData:
@@ -317,6 +317,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
requires_safety_checker: bool = False,
precision: str = "float32",
control_model: ControlNetModel = None,
execution_device: Optional[torch.device] = None,
):
super().__init__(
vae,
@@ -341,22 +342,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# control_model=control_model,
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(
self.unet, self._unet_forward, is_running_diffusers=True
)
use_full_precision = precision == "float32" or precision == "autocast"
self.textual_inversion_manager = TextualInversionManager(
tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
full_precision=use_full_precision,
)
# InvokeAI's interface for text embeddings and whatnot
self.embeddings_provider = EmbeddingsProvider(
tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
textual_inversion_manager=self.textual_inversion_manager,
self.unet, self._unet_forward
)
self._model_group = FullyLoadedModelGroup(self.unet.device)
self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
self._model_group.install(*self._submodels)
self.control_model = control_model
@@ -404,50 +393,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
else:
self.disable_attention_slicing()
def enable_offload_submodels(self, device: torch.device):
"""
Offload each submodel when it's not in use.
Useful for low-vRAM situations where the size of the model in memory is a big chunk of
the total available resource, and you want to free up as much for inference as possible.
This requires more moving parts and may add some delay as the U-Net is swapped out for the
VAE and vice-versa.
"""
models = self._submodels
if self._model_group is not None:
self._model_group.uninstall(*models)
group = LazilyLoadedModelGroup(device)
group.install(*models)
self._model_group = group
def disable_offload_submodels(self):
"""
Leave all submodels loaded.
Appropriate for cases where the size of the model in memory is small compared to the memory
required for inference. Avoids the delay and complexity of shuffling the submodels to and
from the GPU.
"""
models = self._submodels
if self._model_group is not None:
self._model_group.uninstall(*models)
group = FullyLoadedModelGroup(self._model_group.execution_device)
group.install(*models)
self._model_group = group
def offload_all(self):
"""Offload all this pipeline's models to CPU."""
self._model_group.offload_current()
def ready(self):
"""
Ready this pipeline's models.
i.e. preload them to the GPU if appropriate.
"""
self._model_group.ready()
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
# overridden method; types match the superclass.
if torch_device is None:
@@ -656,48 +601,68 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# TODO: should this scaling happen here or inside self._unet_forward?
# i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
unet_latent_input = self.scheduler.scale_model_input(latents, timestep)
# default is no controlnet, so set controlnet processing output to None
down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None:
# FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
# if conditioning_data.guidance_scale > 1.0:
if conditioning_data.guidance_scale is not None:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
latent_control_input = torch.cat([latent_model_input] * 2)
else:
latent_control_input = latent_model_input
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data):
# print("controlnet", i, "==>", type(control_datum))
control_mode = control_datum.control_mode
# soft_injection and cfg_injection are the two ControlNet control_mode booleans
# that are combined at higher level to make control_mode enum
# soft_injection determines whether to do per-layer re-weighting adjustment (if True)
# or default weighting (if False)
soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
# cfg_injection determines whether to apply ControlNet only to the conditional (if True)
# or the default both conditional and unconditional (if False)
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
# print("running controlnet", i, "for step", step_index)
if cfg_injection:
control_latent_input = unet_latent_input
else:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
control_latent_input = torch.cat([unet_latent_input] * 2)
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings])
else:
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings])
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index]
else:
# if controlnet has a single weight, use it for all steps
controlnet_weight = control_datum.weight
# controlnet(s) inference
down_samples, mid_sample = control_datum.model(
sample=latent_control_input,
sample=control_latent_input,
timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]),
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight,
# cross_attention_kwargs,
guess_mode=False,
conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)
if cfg_injection:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
@@ -710,11 +675,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# predict the noise residual
noise_pred = self.invokeai_diffuser.do_diffusion_step(
latent_model_input,
t,
conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings,
conditioning_data.guidance_scale,
x=unet_latent_input,
sigma=t,
unconditioning=conditioning_data.unconditioned_embeddings,
conditioning=conditioning_data.text_embeddings,
unconditional_guidance_scale=conditioning_data.guidance_scale,
step_index=step_index,
total_step_count=total_step_count,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
@@ -991,25 +956,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
device = self._model_group.device_for(self.safety_checker)
return super().run_safety_checker(image, device, dtype)
@torch.inference_mode()
def get_learned_conditioning(
self, c: List[List[str]], *, return_tokens=True, fragment_weights=None
):
"""
Compatibility function for invokeai.models.diffusion.ddpm.LatentDiffusion.
"""
return self.embeddings_provider.get_embeddings_for_weighted_prompt_fragments(
text_batch=c,
fragment_weights_batch=fragment_weights,
should_return_tokens=return_tokens,
device=self._model_group.device_for(self.unet),
)
@property
def channels(self) -> int:
"""Compatible with DiffusionWrapper"""
return self.unet.config.in_channels
def decode_latents(self, latents):
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
self._model_group.load(self.vae)
@@ -1026,8 +972,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
# Returns torch.Tensor of shape (batch_size, 3, height, width)
@staticmethod
def prepare_control_image(
self,
image,
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
# latents,
@@ -1038,6 +984,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced"
):
if not isinstance(image, torch.Tensor):
@@ -1068,6 +1015,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
if do_classifier_free_guidance and not cfg_injection:
image = torch.cat([image] * 2)
return image
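A minimal sketch (assuming torch is installed) of the cfg_injection handling above: ControlNet residuals computed only for the conditional batch are padded with zeros for the unconditional batch, so that batch is left unchanged.
import torch

cond_residual = torch.ones(1, 4, 8, 8)  # hypothetical residual for the conditional batch only
both = torch.cat([torch.zeros_like(cond_residual), cond_residual])
assert both.shape[0] == 2 and both[0].abs().sum() == 0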

View File

@@ -18,7 +18,6 @@ from .cross_attention_control import (
CrossAttentionType,
SwapCrossAttnContext,
get_cross_attention_modules,
restore_default_cross_attention,
setup_cross_attention_control_attention_processors,
)
from .cross_attention_map_saving import AttentionMapSaver
@@ -66,7 +65,6 @@ class InvokeAIDiffuserComponent:
self,
model,
model_forward_callback: ModelForwardCallback,
is_running_diffusers: bool = False,
):
"""
:param model: the unet model to pass through to cross attention control
@@ -75,7 +73,6 @@ class InvokeAIDiffuserComponent:
config = InvokeAIAppConfig.get_config()
self.conditioning = None
self.model = model
self.is_running_diffusers = is_running_diffusers
self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
self.sequential_guidance = config.sequential_guidance
@@ -112,37 +109,6 @@ class InvokeAIDiffuserComponent:
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()
# apparently unused code
# TODO: delete
# def override_cross_attention(
# self, conditioning: ExtraConditioningInfo, step_count: int
# ) -> Dict[str, AttentionProcessor]:
# """
# setup cross attention .swap control. for diffusers this replaces the attention processor, so
# the previous attention processor is returned so that the caller can restore it later.
# """
# self.conditioning = conditioning
# self.cross_attention_control_context = Context(
# arguments=self.conditioning.cross_attention_control_args,
# step_count=step_count,
# )
# return override_cross_attention(
# self.model,
# self.cross_attention_control_context,
# is_running_diffusers=self.is_running_diffusers,
# )
def restore_default_cross_attention(
self, restore_attention_processor: Optional["AttentionProcessor"] = None
):
self.conditioning = None
self.cross_attention_control_context = None
restore_default_cross_attention(
self.model,
is_running_diffusers=self.is_running_diffusers,
restore_attention_processor=restore_attention_processor,
)
def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key):
if dim is not None:
@@ -204,9 +170,7 @@ class InvokeAIDiffuserComponent:
cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
percent_through = self.calculate_percent_through(
sigma, step_index, total_step_count
)
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = (
context.get_active_cross_attention_control_types_for_step(
percent_through
@@ -264,9 +228,7 @@ class InvokeAIDiffuserComponent:
total_step_count,
) -> torch.Tensor:
if postprocessing_settings is not None:
percent_through = self.calculate_percent_through(
sigma, step_index, total_step_count
)
percent_through = step_index / total_step_count
latents = self.apply_threshold(
postprocessing_settings, latents, percent_through
)
@@ -275,22 +237,6 @@ class InvokeAIDiffuserComponent:
)
return latents
def calculate_percent_through(self, sigma, step_index, total_step_count):
if step_index is not None and total_step_count is not None:
# 🧨diffusers codepath
percent_through = (
step_index / total_step_count
) # will never reach 1.0 - this is deliberate
else:
# legacy compvis codepath
# TODO remove when compvis codepath support is dropped
if step_index is None and sigma is None:
raise ValueError(
"Either step_index or sigma is required when doing cross attention control, but both are None."
)
percent_through = self.estimate_percent_through(step_index, sigma)
return percent_through
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
@@ -323,6 +269,7 @@ class InvokeAIDiffuserComponent:
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x
# TODO: looks unused
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
assert isinstance(conditioning, dict)
assert isinstance(unconditioning, dict)
@@ -350,34 +297,6 @@ class InvokeAIDiffuserComponent:
conditioning,
cross_attention_control_types_to_do,
**kwargs,
):
if self.is_running_diffusers:
return self._apply_cross_attention_controlled_conditioning__diffusers(
x,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do,
**kwargs,
)
else:
return self._apply_cross_attention_controlled_conditioning__compvis(
x,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do,
**kwargs,
)
def _apply_cross_attention_controlled_conditioning__diffusers(
self,
x: torch.Tensor,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do,
**kwargs,
):
context: Context = self.cross_attention_control_context
@@ -409,54 +328,6 @@ class InvokeAIDiffuserComponent:
)
return unconditioned_next_x, conditioned_next_x
def _apply_cross_attention_controlled_conditioning__compvis(
self,
x: torch.Tensor,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do,
**kwargs,
):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS)
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
# This messes up their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
context: Context = self.cross_attention_control_context
try:
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
# process x using the original prompt, saving the attention maps
# print("saving attention maps for", cross_attention_control_types_to_do)
for ca_type in cross_attention_control_types_to_do:
context.request_save_attention_maps(ca_type)
_ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
context.clear_requests(cleanup=False)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
# print("applying saved attention maps for", cross_attention_control_types_to_do)
for ca_type in cross_attention_control_types_to_do:
context.request_apply_saved_attention_maps(ca_type)
edited_conditioning = (
self.conditioning.cross_attention_control_args.edited_conditioning
)
conditioned_next_x = self.model_forward_callback(
x, sigma, edited_conditioning, **kwargs,
)
context.clear_requests(cleanup=True)
except:
context.clear_requests(cleanup=True)
raise
return unconditioned_next_x, conditioned_next_x
def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
# to scale how much effect conditioning has, calculate the changes it does and then scale that
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
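The _combine() step above is the standard classifier-free-guidance mix; a minimal sketch assuming torch is installed:
import torch

uncond = torch.zeros(2)
cond = torch.ones(2)
guidance_scale = 7.5
combined = uncond + guidance_scale * (cond - uncond)  # scale the conditioned delta, add back the unconditioned prediction
assert torch.allclose(combined, torch.full((2,), 7.5))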

View File

@@ -157,7 +157,7 @@ class LazilyLoadedModelGroup(ModelGroup):
def offload_current(self):
module = self._current_model_ref()
if module is not NO_MODEL:
module.to(device=OFFLOAD_DEVICE)
module.to(OFFLOAD_DEVICE)
self.clear_current_model()
def _load(self, module: torch.nn.Module) -> torch.nn.Module:
@@ -228,7 +228,7 @@ class FullyLoadedModelGroup(ModelGroup):
def install(self, *models: torch.nn.Module):
for model in models:
self._models.add(model)
model.to(device=self.execution_device)
model.to(self.execution_device)
def uninstall(self, *models: torch.nn.Module):
for model in models:
@@ -238,11 +238,11 @@ class FullyLoadedModelGroup(ModelGroup):
self.uninstall(*self._models)
def load(self, model):
model.to(device=self.execution_device)
model.to(self.execution_device)
def offload_current(self):
for model in self._models:
model.to(device=OFFLOAD_DEVICE)
model.to(OFFLOAD_DEVICE)
def ready(self):
for model in self._models:
@@ -252,7 +252,7 @@ class FullyLoadedModelGroup(ModelGroup):
self.execution_device = device
for model in self._models:
if model.device != OFFLOAD_DEVICE:
model.to(device=device)
model.to(device)
def device_for(self, model):
if model not in self:

View File

@@ -1,13 +1,14 @@
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler, DPMSolverSDEScheduler
SCHEDULER_MAP = dict(
ddim=(DDIMScheduler, dict()),
ddpm=(DDPMScheduler, dict()),
deis=(DEISMultistepScheduler, dict()),
lms=(LMSDiscreteScheduler, dict()),
lms=(LMSDiscreteScheduler, dict(use_karras_sigmas=False)),
lms_k=(LMSDiscreteScheduler, dict(use_karras_sigmas=True)),
pndm=(PNDMScheduler, dict()),
heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
@@ -16,8 +17,13 @@ SCHEDULER_MAP = dict(
euler_a=(EulerAncestralDiscreteScheduler, dict()),
kdpm_2=(KDPM2DiscreteScheduler, dict()),
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
dpmpp_2s=(DPMSolverSinglestepScheduler, dict()),
dpmpp_2s=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=False)),
dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)),
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type='sde-dpmsolver++')),
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type='sde-dpmsolver++')),
dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)),
dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)),
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
)
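A minimal sketch (assuming diffusers is installed) of how one SCHEDULER_MAP entry could be instantiated; the class and kwargs below are just the dpmpp_2m_k example from the map above.
from diffusers import DPMSolverMultistepScheduler

scheduler_class, extra_kwargs = (DPMSolverMultistepScheduler, dict(use_karras_sigmas=True))
scheduler = scheduler_class(**extra_kwargs)
print(type(scheduler).__name__)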

View File

@@ -1,429 +0,0 @@
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union, List
import safetensors.torch
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from picklescan.scanner import scan_file_path
from transformers import CLIPTextModel, CLIPTokenizer
import invokeai.backend.util.logging as logger
from .concepts_lib import HuggingFaceConceptsLibrary
@dataclass
class EmbeddingInfo:
name: str
embedding: torch.Tensor
num_vectors_per_token: int
token_dim: int
trained_steps: int = None
trained_model_name: str = None
trained_model_checksum: str = None
@dataclass
class TextualInversion:
trigger_string: str
embedding: torch.Tensor
trigger_token_id: Optional[int] = None
pad_token_ids: Optional[list[int]] = None
@property
def embedding_vector_length(self) -> int:
return self.embedding.shape[0]
class TextualInversionManager(BaseTextualInversionManager):
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
full_precision: bool = True,
):
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.full_precision = full_precision
self.hf_concepts_library = HuggingFaceConceptsLibrary()
self.trigger_to_sourcefile = dict()
default_textual_inversions: list[TextualInversion] = []
self.textual_inversions = default_textual_inversions
def load_huggingface_concepts(self, concepts: list[str]):
for concept_name in concepts:
if concept_name in self.hf_concepts_library.concepts_loaded:
continue
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
if (
self.has_textual_inversion_for_trigger_string(trigger)
or self.has_textual_inversion_for_trigger_string(concept_name)
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
): # in case a token with literal angle brackets encountered
logger.info(f"Loaded local embedding for trigger {concept_name}")
continue
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
if not bin_file:
continue
logger.info(f"Loaded remote embedding for trigger {concept_name}")
self.load_textual_inversion(bin_file)
self.hf_concepts_library.concepts_loaded[concept_name] = True
def get_all_trigger_strings(self) -> list[str]:
return [ti.trigger_string for ti in self.textual_inversions]
def load_textual_inversion(
self, ckpt_path: Union[str, Path], defer_injecting_tokens: bool = False
):
ckpt_path = Path(ckpt_path)
if not ckpt_path.is_file():
return
if str(ckpt_path).endswith(".DS_Store"):
return
embedding_list = self._parse_embedding(str(ckpt_path))
for embedding_info in embedding_list:
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
logger.warning(
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
)
continue
# Resolve the situation in which an earlier embedding has claimed the same
# trigger string. We replace the trigger with '<source_file>', as we used to.
trigger_str = embedding_info.name
sourcefile = (
f"{ckpt_path.parent.name}/{ckpt_path.name}"
if ckpt_path.name == "learned_embeds.bin"
else ckpt_path.name
)
if trigger_str in self.trigger_to_sourcefile:
replacement_trigger_str = (
f"<{ckpt_path.parent.name}>"
if ckpt_path.name == "learned_embeds.bin"
else f"<{ckpt_path.stem}>"
)
logger.info(
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
)
trigger_str = replacement_trigger_str
try:
self._add_textual_inversion(
trigger_str,
embedding_info.embedding,
defer_injecting_tokens=defer_injecting_tokens,
)
# remember which source file claims this trigger
self.trigger_to_sourcefile[trigger_str] = sourcefile
except ValueError as e:
logger.debug(f'Ignoring incompatible embedding {embedding_info.name}')
logger.debug(f"The error was {str(e)}")
def _add_textual_inversion(
self, trigger_str, embedding, defer_injecting_tokens=False
) -> Optional[TextualInversion]:
"""
Add a textual inversion to be recognised.
:param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the trigger_str appears.
:return: The TextualInversion that was added, or None if an inversion with this trigger string was already loaded.
"""
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
logger.warning(
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
)
return
if not self.full_precision:
embedding = embedding.half()
if len(embedding.shape) == 1:
embedding = embedding.unsqueeze(0)
elif len(embedding.shape) > 2:
raise ValueError(
f"** TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
)
try:
ti = TextualInversion(trigger_string=trigger_str, embedding=embedding)
if not defer_injecting_tokens:
self._inject_tokens_and_assign_embeddings(ti)
self.textual_inversions.append(ti)
return ti
except ValueError as e:
if str(e).startswith("Warning"):
logger.warning(f"{str(e)}")
else:
traceback.print_exc()
logger.error(
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
)
raise
def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:
if ti.trigger_token_id is not None:
raise ValueError(
f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'"
)
trigger_token_id = self._get_or_create_token_id_and_assign_embedding(
ti.trigger_string, ti.embedding[0]
)
if ti.embedding_vector_length > 1:
# for embeddings with vector length > 1
pad_token_strings = [
ti.trigger_string + "-!pad-" + str(pad_index)
for pad_index in range(1, ti.embedding_vector_length)
]
# todo: batched UI for faster loading when vector length >2
pad_token_ids = [
self._get_or_create_token_id_and_assign_embedding(
pad_token_str, ti.embedding[1 + i]
)
for (i, pad_token_str) in enumerate(pad_token_strings)
]
else:
pad_token_ids = []
ti.trigger_token_id = trigger_token_id
ti.pad_token_ids = pad_token_ids
return ti.trigger_token_id
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
try:
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
return ti is not None
except StopIteration:
return False
def get_textual_inversion_for_trigger_string(
self, trigger_string: str
) -> TextualInversion:
return next(
ti for ti in self.textual_inversions if ti.trigger_string == trigger_string
)
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
return next(
ti for ti in self.textual_inversions if ti.trigger_token_id == token_id
)
def create_deferred_token_ids_for_any_trigger_terms(
self, prompt_string: str
) -> list[int]:
injected_token_ids = []
for ti in self.textual_inversions:
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
if ti.embedding_vector_length > 1:
logger.info(
f"Preparing tokens for textual inversion {ti.trigger_string}..."
)
try:
self._inject_tokens_and_assign_embeddings(ti)
except ValueError as e:
logger.debug(
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
)
logger.debug(f"The error was {str(e)}")
continue
injected_token_ids.append(ti.trigger_token_id)
injected_token_ids.extend(ti.pad_token_ids)
return injected_token_ids
def expand_textual_inversion_token_ids_if_necessary(
self, prompt_token_ids: list[int]
) -> list[int]:
"""
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
:param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
:return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
long - caller is responsible for prepending/appending eos and bos token ids, and truncating if necessary.
"""
if len(prompt_token_ids) == 0:
return prompt_token_ids
if prompt_token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("prompt_token_ids must not start with bos_token_id")
if prompt_token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("prompt_token_ids must not end with eos_token_id")
textual_inversion_trigger_token_ids = [
ti.trigger_token_id for ti in self.textual_inversions
]
prompt_token_ids = prompt_token_ids.copy()
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
if token_id in textual_inversion_trigger_token_ids:
textual_inversion = next(
ti
for ti in self.textual_inversions
if ti.trigger_token_id == token_id
)
for pad_idx in range(0, textual_inversion.embedding_vector_length - 1):
prompt_token_ids.insert(
i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx]
)
return prompt_token_ids
def _get_or_create_token_id_and_assign_embedding(
self, token_str: str, embedding: torch.Tensor
) -> int:
if len(embedding.shape) != 1:
raise ValueError(
"Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2"
)
existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str)
if existing_token_id == self.tokenizer.unk_token_id:
num_tokens_added = self.tokenizer.add_tokens(token_str)
current_embeddings = self.text_encoder.resize_token_embeddings(None)
current_token_count = current_embeddings.num_embeddings
new_token_count = current_token_count + num_tokens_added
# the following call is slow - todo make batched for better performance with vector length >1
self.text_encoder.resize_token_embeddings(new_token_count)
token_id = self.tokenizer.convert_tokens_to_ids(token_str)
if token_id == self.tokenizer.unk_token_id:
raise RuntimeError(f"Unable to find token id for token '{token_str}'")
if (
self.text_encoder.get_input_embeddings().weight.data[token_id].shape
!= embedding.shape
):
raise ValueError(
f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}."
)
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
return token_id
def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
suffix = Path(embedding_file).suffix
try:
if suffix in [".pt",".ckpt",".bin"]:
scan_result = scan_file_path(embedding_file)
if scan_result.infected_files > 0:
logger.critical(
f"Security Issues Found in Model: {scan_result.issues_count}"
)
logger.critical("For your safety, InvokeAI will not load this embed.")
return list()
ckpt = torch.load(embedding_file,map_location="cpu")
else:
ckpt = safetensors.torch.load_file(embedding_file)
except Exception as e:
logger.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
return list()
# try to figure out what kind of embedding file it is and parse accordingly
keys = list(ckpt.keys())
if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt
elif all(x in keys for x in ['string_to_token','string_to_param']):
return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt
elif 'emb_params' in keys:
return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors
else:
return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
basename = Path(file_path).stem
logger.debug(f'Loading v1 embedding file: {basename}')
embeddings = list()
token_counter = -1
for token,embedding in embedding_ckpt["string_to_param"].items():
if token_counter < 0:
trigger = embedding_ckpt["name"]
elif token_counter == 0:
trigger = f'<{basename}>'
else:
trigger = f'<{basename}-{token_counter}>'
token_counter += 1
embedding_info = EmbeddingInfo(
name = trigger,
embedding = embedding,
num_vectors_per_token = embedding.size()[0],
token_dim = embedding.size()[1],
trained_steps = embedding_ckpt["step"],
trained_model_name = embedding_ckpt["sd_checkpoint_name"],
trained_model_checksum = embedding_ckpt["sd_checkpoint"]
)
embeddings.append(embedding_info)
return embeddings
def _parse_embedding_v2 (
self, embedding_ckpt: dict, file_path: str
) -> List[EmbeddingInfo]:
"""
This handles embedding .pt file variant #2.
"""
basename = Path(file_path).stem
logger.debug(f'Loading v2 embedding file: {basename}')
embeddings = list()
if isinstance(
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
):
token_counter = 0
for token,embedding in embedding_ckpt["string_to_param"].items():
trigger = token if token != '*' \
else f'<{basename}>' if token_counter == 0 \
else f'<{basename}-{int(token_counter:=token_counter+1)}>'
embedding_info = EmbeddingInfo(
name = trigger,
embedding = embedding,
num_vectors_per_token = embedding.size()[0],
token_dim = embedding.size()[1],
)
embeddings.append(embedding_info)
else:
logger.warning(f"{basename}: Unrecognized embedding format")
return embeddings
def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
"""
Parse 'version 3' of the .pt textual inversion embedding files.
"""
basename = Path(file_path).stem
logger.debug(f'Loading v3 embedding file: {basename}')
embedding = embedding_ckpt['emb_params']
embedding_info = EmbeddingInfo(
name = f'<{basename}>',
embedding = embedding,
num_vectors_per_token = embedding.size()[0],
token_dim = embedding.size()[1],
)
return [embedding_info]
def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
"""
Parse 'version 4' of the textual inversion embedding files. This one
is usually associated with .bin files trained by HuggingFace diffusers.
"""
basename = Path(filepath).stem
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
logger.debug(f'Loading v4 embedding file: {short_path}')
embeddings = list()
if len(embedding_ckpt.keys()) == 0:
logger.warning(f"Invalid embeddings file: {short_path}")
else:
for token,embedding in embedding_ckpt.items():
embedding_info = EmbeddingInfo(
name = token or f"<{basename}>",
embedding = embedding,
num_vectors_per_token = 1, # All Concepts seem to default to 1
token_dim = embedding.size()[0],
)
embeddings.append(embedding_info)
return embeddings
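A note on the mechanics of this (now removed) manager: a multi-vector embedding is registered as one trigger token plus N-1 synthetic "-!pad-" tokens, and expand_textual_inversion_token_ids_if_necessary() re-inserts the pad token ids immediately after each trigger id at encode time. A toy illustration of that expansion with made-up token ids (not code from this diff):

trigger_id = 49408              # id assigned to the trigger token
pad_ids = [49409, 49410]        # ids of the two pad tokens for a 3-vector embedding
prompt_ids = [320, 49408, 525]  # tokenized prompt containing the trigger

expanded = []
for tok in prompt_ids:
    expanded.append(tok)
    if tok == trigger_id:
        expanded.extend(pad_ids)  # pad ids follow their trigger, as the manager does

assert expanded == [320, 49408, 49409, 49410, 525]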

View File

@@ -358,7 +358,6 @@ class InvokeAILogger(object):
elif handler_name=='syslog':
ch = cls._parse_syslog_args(args)
ch.setFormatter(InvokeAISyslogFormatter())
handlers.append(ch)
elif handler_name=='file':
@@ -367,7 +366,8 @@ class InvokeAILogger(object):
handlers.append(ch)
elif handler_name=='http':
handlers.append(cls._parse_http_args(args))
ch = cls._parse_http_args(args)
handlers.append(ch)
return handlers
@staticmethod
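These hunks tidy how parsed handlers are appended to the handler list: each spec such as "syslog=...", "file=...", or "http=..." is turned into a single logging.Handler and appended exactly once. For context, a generic standard-library sketch of that dispatch pattern follows; it is not the InvokeAI parser, and the spec format and defaults are assumptions.

import logging
import logging.handlers

def build_handlers(specs: list[str]) -> list[logging.Handler]:
    # Turn "name=args" strings into logging handlers (dispatch sketch only).
    handlers: list[logging.Handler] = []
    for spec in specs:
        name, _, args = spec.partition("=")
        if name == "console":
            handlers.append(logging.StreamHandler())
        elif name == "file":
            handlers.append(logging.FileHandler(args or "invokeai.log"))
        elif name == "syslog":
            host, _, port = args.partition(":")
            handlers.append(logging.handlers.SysLogHandler(address=(host or "localhost", int(port or 514))))
        elif name == "http":
            host, _, path = args.partition("/")
            handlers.append(logging.handlers.HTTPHandler(host, "/" + path, method="POST"))
    return handlers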

View File

@@ -1277,13 +1277,14 @@ class InvokeAIWebServer:
eventlet.sleep(0)
parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
tokens = (
None
if type(parsed_prompt) is Blend
else get_tokens_for_prompt_object(
self.generate.model.tokenizer, parsed_prompt
with self.generate.model_context as model:
tokens = (
None
if type(parsed_prompt) is Blend
else get_tokens_for_prompt_object(
model.tokenizer, parsed_prompt
)
)
)
attention_maps_image_base64_url = (
None
if attention_maps_image is None
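This hunk moves tokenization inside the model manager's context manager so that the model, and therefore its tokenizer, is guaranteed to be loaded for the duration of the call. A generic sketch of that acquire/yield/release pattern follows; the cache object and its methods are illustrative, not the actual InvokeAI model_context implementation.

from contextlib import contextmanager

@contextmanager
def model_context(cache, key: str):
    model = cache.load(key)    # make sure the weights are resident before use
    try:
        yield model            # callers access model.tokenizer, model.unet, ...
    finally:
        cache.release(key)     # allow the cache to offload or evict the model again

# Usage mirroring the change above:
#     with self.generate.model_context as model:
#         tokens = get_tokens_for_prompt_object(model.tokenizer, parsed_prompt)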

View File

@@ -7,6 +7,7 @@ SAMPLER_CHOICES = [
"ddpm",
"deis",
"lms",
"lms_k",
"pndm",
"heun",
'heun_k',
@@ -16,8 +17,13 @@ SAMPLER_CHOICES = [
"kdpm_2",
"kdpm_2_a",
"dpmpp_2s",
"dpmpp_2s_k",
"dpmpp_2m",
"dpmpp_2m_k",
"dpmpp_2m_sde",
"dpmpp_2m_sde_k",
"dpmpp_sde",
"dpmpp_sde_k",
"unipc",
]
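SAMPLER_CHOICES is the user-facing list of scheduler names and has to stay in sync with the keys of SCHEDULER_MAP shown earlier in this diff. A small consistency check one could add to a test suite is sketched below; import paths depend on the repository layout, so both names are simply assumed to be in scope.

def test_sampler_choices_match_scheduler_map():
    # Every selectable sampler needs a scheduler mapping, and vice versa.
    assert set(SAMPLER_CHOICES) == set(SCHEDULER_MAP.keys())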

View File

@@ -965,13 +965,15 @@ def main():
logger.error(
"Insufficient vertical space for the interface. Please make your window taller and try again"
)
elif str(e).startswith("addwstr"):
input('Press any key to continue...')
except Exception as e:
if str(e).startswith("addwstr"):
logger.error(
"Insufficient horizontal space for the interface. Please make your window wider and try again."
)
except Exception as e:
print(f'An exception has occurred: {str(e)} Details:')
print(traceback.format_exc(), file=sys.stderr)
else:
print(f'An exception has occurred: {str(e)} Details:')
print(traceback.format_exc(), file=sys.stderr)
input('Press any key to continue...')

View File

@@ -42,6 +42,18 @@ def set_terminal_size(columns: int, lines: int, launch_command: str=None):
elif OS in ["Darwin", "Linux"]:
_set_terminal_size_unix(width,height)
# check whether it worked....
ts = get_terminal_size()
pause = False
if ts.columns < columns:
print('\033[1mThis window is too narrow for the user interface. Please make it wider.\033[0m')
pause = True
if ts.lines < lines:
print('\033[1mThis window is too short for the user interface. Please make it taller.\033[0m')
pause = True
if pause:
input('Press any key to continue...')
def _set_terminal_size_powershell(width: int, height: int):
script=f'''
$pshost = get-host

View File

@@ -0,0 +1,14 @@
import react from '@vitejs/plugin-react-swc';
import { visualizer } from 'rollup-plugin-visualizer';
import { PluginOption, UserConfig } from 'vite';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
import { nodePolyfills } from 'vite-plugin-node-polyfills';
export const commonPlugins: UserConfig['plugins'] = [
react(),
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
nodePolyfills(),
];

View File

@@ -1,17 +1,9 @@
import react from '@vitejs/plugin-react-swc';
import { visualizer } from 'rollup-plugin-visualizer';
import { PluginOption, UserConfig } from 'vite';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
import { UserConfig } from 'vite';
import { commonPlugins } from './common';
export const appConfig: UserConfig = {
base: './',
plugins: [
react(),
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
],
plugins: [...commonPlugins],
build: {
chunkSizeWarningLimit: 1500,
},

View File

@@ -1,19 +1,13 @@
import react from '@vitejs/plugin-react-swc';
import path from 'path';
import { visualizer } from 'rollup-plugin-visualizer';
import { PluginOption, UserConfig } from 'vite';
import { UserConfig } from 'vite';
import dts from 'vite-plugin-dts';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
import { commonPlugins } from './common';
export const packageConfig: UserConfig = {
base: './',
plugins: [
react(),
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
...commonPlugins,
dts({
insertTypesEntry: true,
}),

View File

@@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-b060dbab.js"></script>
<script type="module" crossorigin src="./assets/index-8a3e9251.js"></script>
</head>
<body dir="ltr">

View File

@@ -506,8 +506,8 @@
"isScheduled": "Canceling",
"setType": "Set cancel type"
},
"promptPlaceholder": "Type prompt here. [negative tokens], (upweight)++, (downweight)--, swap and blend are available (see docs)",
"negativePrompts": "Negative Prompts",
"positivePromptPlaceholder": "Positive Prompt",
"negativePromptPlaceholder": "Negative Prompt",
"sendTo": "Send to",
"sendToImg2Img": "Send to Image to Image",
"sendToUnifiedCanvas": "Send To Unified Canvas",

View File

@@ -23,8 +23,7 @@
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --indent 2 --request src/services/fixtures/request.ts",
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --indent 2 --request src/services/fixtures/request.ts",
"typegen": "npx openapi-typescript http://localhost:9090/openapi.json --output src/services/api/schema.d.ts -t",
"preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",
@@ -54,55 +53,62 @@
]
},
"dependencies": {
"@apidevtools/swagger-parser": "^10.1.0",
"@chakra-ui/anatomy": "^2.1.1",
"@chakra-ui/icons": "^2.0.19",
"@chakra-ui/react": "^2.6.0",
"@chakra-ui/styled-system": "^2.9.0",
"@chakra-ui/theme-tools": "^2.0.16",
"@dagrejs/graphlib": "^2.1.12",
"@chakra-ui/react": "^2.7.1",
"@chakra-ui/styled-system": "^2.9.1",
"@chakra-ui/theme-tools": "^2.0.18",
"@dagrejs/graphlib": "^2.1.13",
"@dnd-kit/core": "^6.0.8",
"@dnd-kit/modifiers": "^6.0.1",
"@emotion/react": "^11.10.6",
"@emotion/styled": "^11.10.6",
"@floating-ui/react-dom": "^2.0.0",
"@fontsource/inter": "^4.5.15",
"@emotion/react": "^11.11.1",
"@emotion/styled": "^11.11.0",
"@floating-ui/react-dom": "^2.0.1",
"@fontsource-variable/inter": "^5.0.3",
"@fontsource/inter": "^5.0.3",
"@mantine/core": "^6.0.14",
"@mantine/hooks": "^6.0.14",
"@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5",
"chakra-ui-contextmenu": "^1.0.5",
"dateformat": "^5.0.3",
"downshift": "^7.6.0",
"formik": "^2.2.9",
"framer-motion": "^10.12.4",
"formik": "^2.4.2",
"framer-motion": "^10.12.17",
"fuse.js": "^6.6.2",
"i18next": "^22.4.15",
"i18next-browser-languagedetector": "^7.0.1",
"i18next-http-backend": "^2.2.0",
"konva": "^9.0.1",
"i18next": "^23.2.3",
"i18next-browser-languagedetector": "^7.0.2",
"i18next-http-backend": "^2.2.1",
"konva": "^9.2.0",
"lodash-es": "^4.17.21",
"overlayscrollbars": "^2.1.1",
"nanostores": "^0.9.2",
"openapi-fetch": "^0.4.0",
"overlayscrollbars": "^2.2.0",
"overlayscrollbars-react": "^0.5.0",
"patch-package": "^7.0.0",
"query-string": "^8.1.0",
"re-resizable": "^6.9.9",
"react": "^18.2.0",
"react-colorful": "^5.6.1",
"react-dom": "^18.2.0",
"react-dropzone": "^14.2.3",
"react-hotkeys-hook": "4.4.0",
"react-i18next": "^12.2.2",
"react-icons": "^4.9.0",
"react-konva": "^18.2.7",
"react-redux": "^8.0.5",
"react-resizable-panels": "^0.0.42",
"react-i18next": "^13.0.1",
"react-icons": "^4.10.1",
"react-konva": "^18.2.10",
"react-redux": "^8.1.1",
"react-resizable-panels": "^0.0.52",
"react-use": "^17.4.0",
"react-virtuoso": "^4.3.5",
"react-zoom-pan-pinch": "^3.0.7",
"reactflow": "^11.7.0",
"react-virtuoso": "^4.3.11",
"react-zoom-pan-pinch": "^3.0.8",
"reactflow": "^11.7.4",
"redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^3.3.1",
"roarr": "^7.15.0",
"serialize-error": "^11.0.0",
"socket.io-client": "^4.6.0",
"use-image": "^1.1.0",
"socket.io-client": "^4.7.0",
"use-image": "^1.1.1",
"uuid": "^9.0.0",
"zod": "^3.21.4"
},
@@ -113,22 +119,22 @@
"ts-toolbelt": "^9.6.0"
},
"devDependencies": {
"@chakra-ui/cli": "^2.4.0",
"@chakra-ui/cli": "^2.4.1",
"@types/dateformat": "^5.0.0",
"@types/lodash-es": "^4.14.194",
"@types/node": "^18.16.2",
"@types/react": "^18.2.0",
"@types/react-dom": "^18.2.1",
"@types/node": "^20.3.1",
"@types/react": "^18.2.14",
"@types/react-dom": "^18.2.6",
"@types/react-redux": "^7.1.25",
"@types/react-transition-group": "^4.4.5",
"@types/uuid": "^9.0.0",
"@typescript-eslint/eslint-plugin": "^5.59.1",
"@typescript-eslint/parser": "^5.59.1",
"@vitejs/plugin-react-swc": "^3.3.0",
"@types/react-transition-group": "^4.4.6",
"@types/uuid": "^9.0.2",
"@typescript-eslint/eslint-plugin": "^5.60.0",
"@typescript-eslint/parser": "^5.60.0",
"@vitejs/plugin-react-swc": "^3.3.2",
"axios": "^1.4.0",
"babel-plugin-transform-imports": "^2.0.0",
"concurrently": "^8.0.1",
"eslint": "^8.39.0",
"concurrently": "^8.2.0",
"eslint": "^8.43.0",
"eslint-config-prettier": "^8.8.0",
"eslint-plugin-prettier": "^4.2.1",
"eslint-plugin-react": "^7.32.2",
@@ -136,18 +142,20 @@
"form-data": "^4.0.0",
"husky": "^8.0.3",
"lint-staged": "^13.2.2",
"madge": "^6.0.0",
"openapi-types": "^12.1.0",
"madge": "^6.1.0",
"openapi-types": "^12.1.3",
"openapi-typescript": "^6.2.8",
"openapi-typescript-codegen": "^0.24.0",
"postinstall-postinstall": "^2.1.0",
"prettier": "^2.8.8",
"rollup-plugin-visualizer": "^5.9.0",
"terser": "^5.17.1",
"rollup-plugin-visualizer": "^5.9.2",
"terser": "^5.18.1",
"ts-toolbelt": "^9.6.0",
"vite": "^4.3.3",
"vite": "^4.3.9",
"vite-plugin-css-injected-by-js": "^3.1.1",
"vite-plugin-dts": "^2.3.0",
"vite-plugin-eslint": "^1.8.1",
"vite-plugin-node-polyfills": "^0.9.0",
"vite-tsconfig-paths": "^4.2.0",
"yarn": "^1.22.19"
}

View File

@@ -1,14 +0,0 @@
diff --git a/node_modules/@chakra-ui/cli/dist/scripts/read-theme-file.worker.js b/node_modules/@chakra-ui/cli/dist/scripts/read-theme-file.worker.js
index 937cf0d..7dcc0c0 100644
--- a/node_modules/@chakra-ui/cli/dist/scripts/read-theme-file.worker.js
+++ b/node_modules/@chakra-ui/cli/dist/scripts/read-theme-file.worker.js
@@ -50,7 +50,8 @@ async function readTheme(themeFilePath) {
project: tsConfig.configFileAbsolutePath,
compilerOptions: {
module: "CommonJS",
- esModuleInterop: true
+ esModuleInterop: true,
+ jsx: 'react'
},
transpileOnly: true,
swc: true

View File

@@ -0,0 +1,55 @@
diff --git a/node_modules/openapi-fetch/dist/index.js b/node_modules/openapi-fetch/dist/index.js
index cd4528a..8976b51 100644
--- a/node_modules/openapi-fetch/dist/index.js
+++ b/node_modules/openapi-fetch/dist/index.js
@@ -1,5 +1,5 @@
// settings & const
-const DEFAULT_HEADERS = {
+const CONTENT_TYPE_APPLICATION_JSON = {
"Content-Type": "application/json",
};
const TRAILING_SLASH_RE = /\/*$/;
@@ -29,18 +29,29 @@ export function createFinalURL(url, options) {
}
return finalURL;
}
+function stringifyBody(body) {
+ if (body instanceof ArrayBuffer || body instanceof File || body instanceof DataView || body instanceof Blob || ArrayBuffer.isView(body) || body instanceof URLSearchParams || body instanceof FormData) {
+ return;
+ }
+
+ if (typeof body === "string") {
+ return body;
+ }
+
+ return JSON.stringify(body);
+ }
+
export default function createClient(clientOptions = {}) {
const { fetch = globalThis.fetch, ...options } = clientOptions;
- const defaultHeaders = new Headers({
- ...DEFAULT_HEADERS,
- ...(options.headers ?? {}),
- });
+ const defaultHeaders = new Headers(options.headers ?? {});
async function coreFetch(url, fetchOptions) {
const { headers, body: requestBody, params = {}, parseAs = "json", querySerializer = defaultSerializer, ...init } = fetchOptions || {};
// URL
const finalURL = createFinalURL(url, { baseUrl: options.baseUrl, params, querySerializer });
+ // Stringify body if needed
+ const stringifiedBody = stringifyBody(requestBody);
// headers
- const baseHeaders = new Headers(defaultHeaders); // clone defaults (dont overwrite!)
+ const baseHeaders = new Headers(stringifiedBody ? { ...CONTENT_TYPE_APPLICATION_JSON, ...defaultHeaders } : defaultHeaders); // clone defaults (dont overwrite!)
const headerOverrides = new Headers(headers);
for (const [k, v] of headerOverrides.entries()) {
if (v === undefined || v === null)
@@ -54,7 +65,7 @@ export default function createClient(clientOptions = {}) {
...options,
...init,
headers: baseHeaders,
- body: typeof requestBody === "string" ? requestBody : JSON.stringify(requestBody),
+ body: stringifiedBody ?? requestBody,
});
// handle empty content
// note: we return `{}` because we want user truthy checks for `.data` or `.error` to succeed

View File

@@ -524,7 +524,8 @@
"initialImage": "Initial Image",
"showOptionsPanel": "Show Options Panel",
"hidePreview": "Hide Preview",
"showPreview": "Show Preview"
"showPreview": "Show Preview",
"controlNetControlMode": "Control Mode"
},
"settings": {
"models": "Models",
@@ -547,7 +548,8 @@
"general": "General",
"generation": "Generation",
"ui": "User Interface",
"availableSchedulers": "Available Schedulers"
"favoriteSchedulers": "Favorite Schedulers",
"favoriteSchedulersPlaceholder": "No schedulers favorited"
},
"toast": {
"serverError": "Server Error",

View File

@@ -23,6 +23,8 @@ import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster';
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import { useListModelsQuery } from 'services/api/endpoints/models';
const DEFAULT_CONFIG = {};
@@ -45,6 +47,18 @@ const App = ({
const isApplicationReady = useIsApplicationReady();
const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const { data: controlnetModels } = useListModelsQuery({
model_type: 'controlnet',
});
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
const { data: embeddingModels } = useListModelsQuery({
model_type: 'embedding',
});
const [loadingOverridden, setLoadingOverridden] = useState(false);
const dispatch = useAppDispatch();
@@ -143,6 +157,7 @@ const App = ({
</Portal>
</Grid>
<DeleteImageModal />
<UpdateImageBoardModal />
<Toaster />
<GlobalHotkeys />
</>

View File

@@ -11,8 +11,8 @@ import {
} from '@dnd-kit/core';
import { PropsWithChildren, memo, useCallback, useState } from 'react';
import OverlayDragImage from './OverlayDragImage';
import { ImageDTO } from 'services/api';
import { isImageDTO } from 'services/types/guards';
import { ImageDTO } from 'services/api/types';
import { isImageDTO } from 'services/api/guards';
import { snapCenterToCursor } from '@dnd-kit/modifiers';
import { AnimatePresence, motion } from 'framer-motion';

View File

@@ -1,6 +1,6 @@
import { Box, Image } from '@chakra-ui/react';
import { memo } from 'react';
import { ImageDTO } from 'services/api';
import { ImageDTO } from 'services/api/types';
type OverlayDragImageProps = {
image: ImageDTO;

View File

@@ -7,7 +7,7 @@ import React, {
} from 'react';
import { Provider } from 'react-redux';
import { store } from 'app/store/store';
import { OpenAPI } from 'services/api';
// import { OpenAPI } from 'services/api/types';
import Loading from '../../common/components/Loading/Loading';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
@@ -21,6 +21,9 @@ import {
DeleteImageContext,
DeleteImageContextProvider,
} from 'app/contexts/DeleteImageContext';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
import { $authToken, $baseUrl } from 'services/api/client';
const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
@@ -45,12 +48,12 @@ const InvokeAIUI = ({
useEffect(() => {
// configure API client token
if (token) {
OpenAPI.TOKEN = token;
$authToken.set(token);
}
// configure API client base url
if (apiUrl) {
OpenAPI.BASE = apiUrl;
$baseUrl.set(apiUrl);
}
// reset dynamically added middlewares
@@ -67,6 +70,12 @@ const InvokeAIUI = ({
} else {
addMiddleware(socketMiddleware());
}
return () => {
// Reset the API client token and base url on unmount
$baseUrl.set(undefined);
$authToken.set(undefined);
};
}, [apiUrl, token, middleware]);
return (
@@ -76,11 +85,13 @@ const InvokeAIUI = ({
<ThemeLocaleProvider>
<ImageDndContext>
<DeleteImageContextProvider>
<App
config={config}
headerComponent={headerComponent}
setIsReady={setIsReady}
/>
<AddImageToBoardContextProvider>
<App
config={config}
headerComponent={headerComponent}
setIsReady={setIsReady}
/>
</AddImageToBoardContextProvider>
</DeleteImageContextProvider>
</ImageDndContext>
</ThemeLocaleProvider>

View File

@@ -3,18 +3,20 @@ import {
createLocalStorageManager,
extendTheme,
} from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { ReactNode, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { theme as invokeAITheme } from 'theme/theme';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { greenTeaThemeColors } from 'theme/colors/greenTea';
import { invokeAIThemeColors } from 'theme/colors/invokeAI';
import { lightThemeColors } from 'theme/colors/lightTheme';
import { oceanBlueColors } from 'theme/colors/oceanBlue';
import '@fontsource/inter/variable.css';
import '@fontsource-variable/inter';
import { MantineProvider } from '@mantine/core';
import { mantineTheme } from 'mantine-theme/theme';
import 'overlayscrollbars/overlayscrollbars.css';
import 'theme/css/overlayscrollbars.css';
@@ -51,9 +53,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
}, [direction]);
return (
<ChakraProvider theme={theme} colorModeManager={manager}>
{children}
</ChakraProvider>
<MantineProvider withGlobalStyles theme={mantineTheme}>
<ChakraProvider theme={theme} colorModeManager={manager}>
{children}
</ChakraProvider>
</MantineProvider>
);
}

Some files were not shown because too many files have changed in this diff.