Compare commits

..

120 Commits

Author SHA1 Message Date
Lincoln Stein
b85f2bc87d add support for multi-gpu rendering
This commit adds speculative support for parallel rendering across
multiple GPUs. The parallelism is at the level of a session. Each
session is given access to a different GPU. When all GPUs are busy,
execution of the session will block until a GPU becomes available.

The code is untested at the current time, and is being posted for
comment.
2024-02-19 15:21:55 -05:00
Lincoln Stein
b06d63fb34 remove errant def that was crashing invokeai-configure 2024-02-19 17:31:53 +11:00
dunkeroni
5278a64301 one more redundant RGB convert removed 2024-02-19 17:31:08 +11:00
dunkeroni
4de4473c0f chore: ruff formatting 2024-02-19 17:31:08 +11:00
dunkeroni
2c28a850ca chore(invocations): remove redundant RGB conversions 2024-02-19 17:31:08 +11:00
dunkeroni
6dada3326d chore(invocations): use IMAGE_MODES constant literal 2024-02-19 17:31:08 +11:00
dunkeroni
2dfdc02ec8 fix: removed custom module 2024-02-19 17:31:08 +11:00
dunkeroni
1f19db4c6a fix(nodes): canny preprocessor uses RGBA again 2024-02-19 17:31:08 +11:00
dunkeroni
7c150c27f2 feat(nodes): format option for get_image method
Also default CNet preprocessors to "RGB"
2024-02-19 17:31:08 +11:00
blessedcoolant
248916c190 fix: Alpha channel causing issue with DW Processor 2024-02-19 08:17:56 +11:00
psychedelicious
be8b99eed5 final tidying before marking PR as ready for review
- Replace AnyModelLoader with ModelLoaderRegistry
- Fix type check errors in multiple files
- Remove apparently unneeded `get_model_config_enum()` method from model manager
- Remove last vestiges of old model manager
- Updated tests and documentation

resolve conflict with seamless.py
2024-02-19 08:16:56 +11:00
Lincoln Stein
2ad0752582 Tidy names and locations of modules
- Rename old "model_management" directory to "model_management_OLD" in order to catch
  dangling references to original model manager.
- Caught and fixed most dangling references (still checking)
- Rename lora, textual_inversion and model_patcher modules
- Introduce a RawModel base class to simplfy the Union returned by the
  model loaders.
- Tidy up the model manager 2-related tests. Add useful fixtures, and
  a finalizer to the queue and installer fixtures that will stop the
  services and release threads.
2024-02-19 08:16:56 +11:00
Lincoln Stein
ba1f8878dd Fix issues identified during PR review by RyanjDick and brandonrising
- ModelMetadataStoreService is now injected into ModelRecordStoreService
  (these two services are really joined at the hip, and should someday be merged)
- ModelRecordStoreService is now injected into ModelManagerService
- Reduced timeout value for the various installer and download wait*() methods
- Introduced a Mock modelmanager for testing
- Removed bare print() statement with _logger in the install helper backend.
- Removed unused code from model loader init file
- Made `locker` a private variable in the `LoadedModel` object.
- Fixed up model merge frontend (will be deprecated anyway!)
2024-02-19 08:16:56 +11:00
Brandon
bc524026f9 feat(ui): update model identifiers to use key (#5730)
## What type of PR is this? (check all applicable)

- [x] Refactor

## Description

- Update zod schemas & types to use key instead of name/base/type
- Use new `CustomSelect` component instead of `ComboBox` for main model
select and control adapter model selects (less jank, will switch to
ComboBox based on CustomSelect for v4 so you can search the select)

## QA Instructions, Screenshots, Recordings

If you hold your breath, you should be able to generate with a control
adapter.

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Merge Plan

This PR can be merged when approved. Frontend tests not passing.

<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->
2024-02-16 11:17:35 -05:00
Brandon
ad7c571983 fix(nodes): fix t2i adapter model loading (#5731)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Description

Fixes t2i adapter loading

## Merge Plan

This PR can be merged when approved

<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->
2024-02-16 11:17:21 -05:00
psychedelicious
8559c6a392 fix(nodes): fix t2i adapter model loading 2024-02-16 22:51:47 +11:00
psychedelicious
c7904a32f4 chore(ui): lint 2024-02-16 22:42:15 +11:00
psychedelicious
17f5484f5b feat(ui): fix main model & control adapter model selects 2024-02-16 22:41:09 +11:00
psychedelicious
86a372b02f refactor(ui): url builders for each router
The MM2 router is at `api/v2/models`. URL builder utils make this a bit easier to manage.
2024-02-16 21:57:30 +11:00
psychedelicious
2e9aa9391d feat(ui): update model identifier to be key (wip)
- Update most model identifiers to be `{key: string}` instead of name/base/type. Doesn't change the model select components yet.
- Update model _parameters_, stored in redux, to be `{key: string, base: BaseModel}` - we need to store the base model to be able to check model compatibility. May want to store the whole config? Not sure...
2024-02-16 18:56:02 +11:00
psychedelicious
0c8112cf28 fix(ui): update model types 2024-02-15 22:17:16 +11:00
psychedelicious
019898c7be tests(ui): add type tests 2024-02-15 22:16:55 +11:00
psychedelicious
2b1ff8d196 tests(ui): enable vitest type testing
This is useful for the zod schemas and types we have created to match the backend.
2024-02-15 22:16:11 +11:00
psychedelicious
79fb691b4d chore(ui): typegen 2024-02-15 22:15:21 +11:00
psychedelicious
560ae17e21 feat(ui): export components type 2024-02-15 21:16:25 +11:00
psychedelicious
2bd1ab2f1c fix(ui): fix type issues 2024-02-15 20:53:41 +11:00
psychedelicious
ed43472582 chore: lint 2024-02-15 20:52:44 +11:00
psychedelicious
6e5e9176c0 chore: ruff 2024-02-15 20:50:47 +11:00
psychedelicious
4c6bcdbc18 feat(nodes): update invocation context for mm2, update nodes model usage 2024-02-15 20:43:41 +11:00
Brandon Rising
20e6d4fa3c Raise InvalidModelConfigException when unable to detect load class in ModelLoader 2024-02-15 18:00:16 +11:00
Brandon Rising
8e51392910 Update _get_hf_load_class to support clipvision models 2024-02-15 18:00:16 +11:00
Brandon Rising
0b1c2acd61 References to context.services.model_manager.store.get_model can only accept keys, remove invalid assertion 2024-02-15 18:00:16 +11:00
Brandon Rising
86ac55ab5f Remove references to model_records service, change submodel property on ModelInfo to submodel_type to support new params in model manager 2024-02-15 18:00:16 +11:00
Lincoln Stein
3e82f63c7e improve swagger documentation 2024-02-15 18:00:08 +11:00
Lincoln Stein
631f6cae19 fix a number of typechecking errors 2024-02-15 18:00:08 +11:00
Lincoln Stein
0845a0ed84 add route for model conversion from safetensors to diffusers
- Begin to add SwaggerUI documentation for AnyModelConfig and other
  discriminated Unions.
2024-02-15 18:00:08 +11:00
Lincoln Stein
46c8ce9fed add a JIT download_and_cache() call to the model installer 2024-02-15 18:00:08 +11:00
Lincoln Stein
13a9ea35b5 add back the heuristic_import() method and extend repo_ids to arbitrary file paths 2024-02-15 18:00:08 +11:00
Lincoln Stein
94e8d1b6d5 make model manager v2 ready for PR review
- Replace legacy model manager service with the v2 manager.

- Update invocations to use new load interface.

- Fixed many but not all type checking errors in the invocations. Most
  were unrelated to model manager

- Updated routes. All the new routes live under the route tag
  `model_manager_v2`. To avoid confusion with the old routes,
  they have the URL prefix `/api/v2/models`. The old routes
  have been de-registered.

- Added a pytest for the loader.

- Updated documentation in contributing/MODEL_MANAGER.md
2024-02-15 18:00:08 +11:00
Lincoln Stein
2b1dc74080 consolidate model manager parts into a single class 2024-02-15 17:57:14 +11:00
Lincoln Stein
f7e558d165 probe for required encoder for IPAdapters and add to config 2024-02-15 17:56:01 +11:00
Lincoln Stein
d959276217 fix invokeai_configure script to work with new mm; rename CLIs 2024-02-15 17:56:01 +11:00
Lincoln Stein
dfcf38be91 BREAKING CHANGES: invocations now require model key, not base/type/name
- Implement new model loader and modify invocations and embeddings

- Finish implementation loaders for all models currently supported by
  InvokeAI.

- Move lora, textual_inversion, and model patching support into
  backend/embeddings.

- Restore support for model cache statistics collection (a little ugly,
  needs work).

- Fixed up invocations that load and patch models.

- Move seamless and silencewarnings utils into better location
2024-02-15 17:56:01 +11:00
Lincoln Stein
fbded1c0f2 Multiple refinements on loaders:
- Cache stat collection enabled.
- Implemented ONNX loading.
- Add ability to specify the repo version variant in installer CLI.
- If caller asks for a repo version that doesn't exist, will fall back
  to empty version rather than raising an error.
2024-02-15 17:51:07 +11:00
Lincoln Stein
ad2926a24c added textual inversion and lora loaders 2024-02-15 17:51:07 +11:00
Lincoln Stein
34d5cad4c9 loaders for main, controlnet, ip-adapter, clipvision and t2i 2024-02-15 17:51:07 +11:00
Lincoln Stein
60aa3d4893 model loading and conversion implemented for vaes 2024-02-15 17:50:51 +11:00
Lincoln Stein
5c2884569e add ram cache module and support files 2024-02-15 17:50:31 +11:00
Lincoln Stein
a1307b9f2e add concept of repo variant 2024-02-15 17:50:31 +11:00
psychedelicious
f505ec64ba tests(ui): add parseFieldType.test.ts 2024-02-15 17:32:38 +11:00
psychedelicious
f22eb368a3 feat(ui): add more types of FieldParseError
Unfortunately you cannot test for both a specific type of error and match its message. Splitting the error classes makes it easier to test expected error conditions.
2024-02-15 17:32:38 +11:00
psychedelicious
96ae22c7e0 feat(ui): add vitest
- Add vitest.
- Consolidate vite configs into single file (easier to config everything based on env for testing)
2024-02-15 17:32:38 +11:00
psychedelicious
f5447cdc23 feat(ui): workflow schema v3 (WIP)
The changes aim to deduplicate data between workflows and node templates, decoupling workflows from internal implementation details. A good amount of data that was needlessly duplicated from the node template to the workflow is removed.

These changes substantially reduce the file size of workflows (and therefore the images with embedded workflows):

- Default T2I SD1.5 workflow JSON is reduced from 23.7kb (798 lines) to 10.9kb (407 lines).
- Default tiled upscale workflow JSON is reduced from 102.7kb (3341 lines) to 51.9kb (1774 lines).

The trade-off is that we need to reference node templates to get things like the field type and other things. In practice, this is a non-issue, because we need a node template to do anything with a node anyways.

- Field types are not included in the workflow. They are always pulled from the node templates.

The field type is now properly an internal implementation detail and we can change it as needed. Previously this would require a migration for the workflow itself. With the v3 schema, the structure of a field type is an internal implementation detail that we are free to change as we see fit.

- Workflow nodes no long have an `outputs` property and there is no longer such a thing as a `FieldOutputInstance`. These are only on the templates.

These were never referenced at a time when we didn't also have the templates available, and there'd be no reason to do so.

- Node width and height are no longer stored in the node.

These weren't used. Also, per https://reactflow.dev/api-reference/types/node, we shouldn't be programmatically changing these properties. A future enhancement can properly add node resizing.

- `nodeTemplates` slice is merged back into `nodesSlice` as `nodes.templates`. Turns out it's just a hassle having these separate in separate slices.

- Workflow migration logic updated to support the new schema. V1 workflows migrate all the way to v3 now.

- Changes throughout the nodes code to accommodate the above changes.
2024-02-15 17:32:38 +11:00
psychedelicious
c76a6bd65f chore(ui): regen types 2024-02-15 17:30:03 +11:00
psychedelicious
6c4eeaa569 feat(nodes): add more missing exports to invocation_api
Crawled through a few custom nodes to figure out what I had missed.
2024-02-15 17:30:03 +11:00
psychedelicious
1bbd13ead7 chore(nodes): "SAMPLER_NAME_VALUES" -> "SCHEDULER_NAME_VALUES"
This was named inaccurately.
2024-02-15 17:30:03 +11:00
psychedelicious
321b939d0e chore(nodes): remove deprecation logic for nodes API 2024-02-15 17:30:03 +11:00
psychedelicious
8fb77e431e chore(nodes): export model-related objects from invocation_api 2024-02-15 17:30:03 +11:00
psychedelicious
083a4f3faa chore(backend): rename ModelInfo -> LoadedModelInfo
We have two different classes named `ModelInfo` which might need to be used by API consumers. We need to export both but have to deal with this naming collision.

The `ModelInfo` I've renamed here is the one that is returned when a model is loaded. It's the object least likely to be used by API consumers.
2024-02-15 17:30:03 +11:00
psychedelicious
2005411f7e feat(nodes): use LATENT_SCALE_FACTOR in primitives.py, noise.py
- LatentsOutput.build
- NoiseOutput.build
- Noise.width, Noise.height multiple_of
2024-02-15 17:30:03 +11:00
psychedelicious
ba7b1b2665 feat(nodes): extract LATENT_SCALE_FACTOR to constants.py 2024-02-15 17:30:03 +11:00
psychedelicious
b7ffd36cc6 feat(nodes): use TemporaryDirectory to handle ephemeral storage in ObjectSerializerDisk
Replace `delete_on_startup: bool` & associated logic with `ephemeral: bool` and `TemporaryDirectory`.

The temp dir is created inside of `output_dir`. For example, if `output_dir` is `invokeai/outputs/tensors/`, then the temp dir might be `invokeai/outputs/tensors/tmpvj35ht7b/`.

The temp dir is cleaned up when the service is stopped, or when it is GC'd if not properly stopped.

In the event of a catastrophic crash where the temp files are not cleaned up, the user can delete the tempdir themselves.

This situation may not occur in normal use, but if you kill the process, python cannot clean up the temp dir itself. This includes running the app in a debugger and killing the debugger process - something I do relatively often.

Tests updated.
2024-02-15 17:30:03 +11:00
psychedelicious
199ddd6623 tests: test ObjectSerializerDisk class name extraction 2024-02-15 17:30:03 +11:00
psychedelicious
a7207ed8cf chore(nodes): update ObjectSerializerForwardCache docstring 2024-02-15 17:30:03 +11:00
psychedelicious
6bb2dda3f1 chore(nodes): fix pyright ignore 2024-02-15 17:30:03 +11:00
psychedelicious
c1e5cd5893 tidy(nodes): "latents" -> "obj" 2024-02-15 17:30:03 +11:00
psychedelicious
ff249a2315 tidy(nodes): do not store unnecessarily store invoker 2024-02-15 17:30:03 +11:00
psychedelicious
c58f8c3269 feat(nodes): make delete on startup configurable for obj serializer
- The default is to not delete on startup - feels safer.
- The two services using this class _do_ delete on startup.
- The class has "ephemeral" removed from its name.
- Tests & app updated for this change.
2024-02-15 17:30:03 +11:00
psychedelicious
ed772a7107 fix(nodes): use metadata/board_id if provided by user, overriding WithMetadata/WithBoard-provided values 2024-02-15 17:30:03 +11:00
psychedelicious
cb0b389b4b tidy(nodes): clarify comment 2024-02-15 17:30:03 +11:00
psychedelicious
8892df1d97 Revert "feat(nodes): use LATENT_SCALE_FACTOR const in tensor output builders"
This reverts commit ef18fc546560277302f3886e456da9a47e8edce0.
2024-02-15 17:30:03 +11:00
psychedelicious
bc5f356390 feat(nodes): use LATENT_SCALE_FACTOR const in tensor output builders 2024-02-15 17:30:03 +11:00
psychedelicious
bcb85e100d tests: fix broken tests 2024-02-15 17:30:03 +11:00
psychedelicious
1f27ddc07d tidy(nodes): minor spelling correction 2024-02-15 17:30:03 +11:00
psychedelicious
7a2b606001 tests: add object serializer tests
These test both object serializer and its forward cache implementation.
2024-02-15 17:30:03 +11:00
psychedelicious
83ddcc5f3a feat(nodes): allow _delete_all in obj serializer to be called at any time
`_delete_all` logged how many items it deleted, and had to be called _after_ service start bc it needed access to logger.

Move the logger call to the startup method and return the the deleted stats from `_delete_all`. This lets `_delete_all` be called at any time.
2024-02-15 17:30:03 +11:00
psychedelicious
55fa785561 tidy(nodes): remove object serializer on_saved
It's unused.
2024-02-15 17:30:03 +11:00
psychedelicious
06429028c8 revert(nodes): revert making tensors/conditioning use item storage
Turns out they are just different enough in purpose that the implementations would be rather unintuitive. I've made a separate ObjectSerializer service to handle tensors and conditioning.

Refined the class a bit too.
2024-02-15 17:30:03 +11:00
psychedelicious
8b6e322697 feat(nodes): support custom exception in ephemeral disk storage 2024-02-15 17:30:03 +11:00
psychedelicious
54a67459bf feat(nodes): support custom save and load functions in ItemStorageEphemeralDisk 2024-02-15 17:30:03 +11:00
psychedelicious
7fe5283e74 feat(nodes): create helper function to generate the item ID 2024-02-15 17:30:03 +11:00
psychedelicious
fe0391c86b feat(nodes): use ItemStorageABC for tensors and conditioning
Turns out `ItemStorageABC` was almost identical to `PickleStorageBase`. Instead of maintaining separate classes, we can use `ItemStorageABC` for both.

There's only one change needed - the `ItemStorageABC.set` method must return the newly stored item's ID. This allows us to let the service handle the responsibility of naming the item, but still create the requisite output objects during node execution.

The naming implementation is improved here. It extracts the name of the generic and appends a UUID to that string when saving items.
2024-02-15 17:30:03 +11:00
psychedelicious
25386a76ef tidy(nodes): do not refer to files as latents in PickleStorageTorch (again) 2024-02-15 17:30:03 +11:00
psychedelicious
fd30cb4d90 feat(nodes): ItemStorageABC typevar no longer bound to pydantic.BaseModel
This bound is totally unnecessary. There's no requirement for any implementation of `ItemStorageABC` to work only on pydantic models.
2024-02-15 17:30:03 +11:00
psychedelicious
0266946d3d fix(nodes): add super init to PickleStorageTorch 2024-02-15 17:30:03 +11:00
psychedelicious
a7f91b3e01 tidy(nodes): do not refer to files as latents in PickleStorageTorch 2024-02-15 17:30:03 +11:00
psychedelicious
de0b72528c feat(nodes): replace latents service with tensors and conditioning services
- New generic class `PickleStorageBase`, implements the same API as `LatentsStorageBase`, use for storing non-serializable data via pickling
- Implementation `PickleStorageTorch` uses `torch.save` and `torch.load`, same as `LatentsStorageDisk`
- Add `tensors: PickleStorageBase[torch.Tensor]` to `InvocationServices`
- Add `conditioning: PickleStorageBase[ConditioningFieldData]` to `InvocationServices`
- Remove `latents` service and all `LatentsStorage` classes
- Update `InvocationContext` and all usage of old `latents` service to use the new services/context wrapper methods
2024-02-15 17:30:03 +11:00
psychedelicious
2932652787 tidy(nodes): delete onnx.py
It doesn't work and keeping it updated to prevent the app from starting was getting tedious. Deleted.
2024-02-15 17:30:03 +11:00
psychedelicious
db6bc7305a fix(nodes): rearrange fields.py to avoid needing forward refs 2024-02-15 17:30:02 +11:00
psychedelicious
a5db204629 tidy(nodes): remove unnecessary, shadowing class attr declarations 2024-02-15 17:30:02 +11:00
psychedelicious
8e2b61e19f feat(ui): revise graphs to not use LinearUIOutputInvocation
See this comment for context: https://github.com/invoke-ai/InvokeAI/pull/5491#discussion_r1480760629

- Remove this now-unnecessary node from all graphs
- Update graphs' terminal image-outputting nodes' `is_intermediate` and `board` fields appropriately
- Add util function to prepare the `board` field, tidy the utils
- Update `socketInvocationComplete` listener to work correctly with this change

I've manually tested all graph permutations that were changed (I think this is all...) to ensure images go to the gallery as expected:
- ad-hoc upscaling
- t2i w/ sd1.5
- t2i w/ sd1.5 & hrf
- t2i w/ sdxl
- t2i w/ sdxl + refiner
- i2i w/ sd1.5
- i2i w/ sdxl
- i2i w/ sdxl + refiner
- canvas t2i w/ sd1.5
- canvas t2i w/ sdxl
- canvas t2i w/ sdxl + refiner
- canvas i2i w/ sd1.5
- canvas i2i w/ sdxl
- canvas i2i w/ sdxl + refiner
- canvas inpaint w/ sd1.5
- canvas inpaint w/ sdxl
- canvas inpaint w/ sdxl + refiner
- canvas outpaint w/ sd1.5
- canvas outpaint w/ sdxl
- canvas outpaint w/ sdxl + refiner
2024-02-15 17:30:02 +11:00
psychedelicious
a3faa3792a chore(ui): regen types 2024-02-15 17:30:02 +11:00
psychedelicious
c16eba78ab feat(nodes): add WithBoard field helper class
This class works the same way as `WithMetadata` - it simply adds a `board` field to the node. The context wrapper function is able to pull the board id from this. This allows image-outputting nodes to get a board field "for free", and have their outputs automatically saved to it.

This is a breaking change for node authors who may have a field called `board`, because it makes `board` a reserved field name. I'll look into how to avoid this - maybe by naming this invoke-managed field `_board` to avoid collisions?

Supporting changes:
- `WithBoard` is added to all image-outputting nodes, giving them the ability to save to board.
- Unused, duplicate `WithMetadata` and `WithWorkflow` classes are deleted from `baseinvocation.py`. The "real" versions are in `fields.py`.
- Remove `LinearUIOutputInvocation`. Now that all nodes that output images also have a `board` field by default, this node is no longer necessary. See comment here for context: https://github.com/invoke-ai/InvokeAI/pull/5491#discussion_r1480760629
- Without `LinearUIOutputInvocation`, the `ImagesInferface.update` method is no longer needed, and removed.

Note: This commit does not bump all node versions. I will ensure that is done correctly before merging the PR of which this commit is a part.

Note: A followup commit will implement the frontend changes to support this change.
2024-02-15 17:30:02 +11:00
psychedelicious
1a191c4655 remove unused configdict import 2024-02-15 17:30:02 +11:00
psychedelicious
e36d925bce fix(ui): remove original l2i node in HRF graph 2024-02-15 17:30:02 +11:00
psychedelicious
b1ba18b3d1 fix(nodes): do not freeze or cache config in context wrapper
- The config is already cached by the config class's `get_config()` method.
- The config mutates itself in its `root_path` property getter. Freezing the class makes any attempt to grab a path from the config error. Unfortunately this means we cannot easily freeze the class without fiddling with the inner workings of `InvokeAIAppConfig`, which is outside the scope here.
2024-02-15 17:30:02 +11:00
psychedelicious
aff46759f9 feat(nodes): context.data -> context._data 2024-02-15 17:30:02 +11:00
psychedelicious
d7b7dcc7fe feat(nodes): context.__services -> context._services 2024-02-15 17:30:02 +11:00
psychedelicious
889a26c5b6 feat(nodes): cache invocation interface config 2024-02-15 17:30:02 +11:00
psychedelicious
b4c774896a feat(nodes): do not hide services in invocation context interfaces 2024-02-15 17:30:02 +11:00
psychedelicious
afbe889d35 fix(nodes): restore missing context type annotations 2024-02-15 17:30:02 +11:00
psychedelicious
9c1e52b1ef tests(nodes): fix mock InvocationContext 2024-02-15 17:30:02 +11:00
psychedelicious
3f5ab02da9 chore(nodes): add comments for ConfigInterface 2024-02-15 17:30:02 +11:00
psychedelicious
bf48e8a03a feat(nodes): export more things from `invocation_api" 2024-02-15 17:30:02 +11:00
psychedelicious
e52434cb99 feat(nodes): add boards interface to invocation context 2024-02-15 17:30:02 +11:00
psychedelicious
483bdbcb9f fix(nodes): restore type annotations for InvocationContext 2024-02-15 17:30:02 +11:00
psychedelicious
ae421fb4ab feat(nodes): do not freeze InvocationContextData, prevents it from being subclassesd 2024-02-15 17:30:02 +11:00
psychedelicious
cc295a9f0a feat: tweak pyright config 2024-02-15 17:30:02 +11:00
psychedelicious
a7e23af9c6 feat(nodes): create invocation_api.py
This is the public API for invocations.

Everything a custom node might need should be re-exported from this file.
2024-02-15 17:30:02 +11:00
psychedelicious
3de4390711 feat(nodes): move ConditioningFieldData to conditioning_data.py 2024-02-15 17:30:02 +11:00
psychedelicious
3ceee2b2b2 tests: fix missing arg for InvocationContext 2024-02-15 17:30:02 +11:00
psychedelicious
5c7ed24aab feat(nodes): restore previous invocation context methods with deprecation warnings 2024-02-15 17:30:02 +11:00
psychedelicious
183c9c4799 chore: ruff 2024-02-15 17:30:02 +11:00
psychedelicious
8baf3f78a2 feat(nodes): tidy invocation_context.py, improve comments 2024-02-15 17:30:02 +11:00
psychedelicious
ac2eb16a65 tests: fix tests for new invocation context 2024-02-15 17:30:02 +11:00
psychedelicious
4aa7bee4b9 docs: update INVOCATIONS.md 2024-02-15 17:30:02 +11:00
psychedelicious
7e5ba2795e feat(nodes): update all invocations to use new invocation context
Update all invocations to use the new context. The changes are all fairly simple, but there are a lot of them.

Supporting minor changes:
- Patch bump for all nodes that use the context
- Update invocation processor to provide new context
- Minor change to `EventServiceBase` to accept a node's ID instead of the dict version of a node
- Minor change to `ModelManagerService` to support the new wrapped context
- Fanagling of imports to avoid circular dependencies
2024-02-15 17:30:02 +11:00
psychedelicious
97a6c6eea7 feat: add pyright config
I was having issues with mypy bother over- and under-reporting certain problems. I've added a pyright config.
2024-02-15 17:30:02 +11:00
psychedelicious
f0e60a4ba2 feat(nodes): restricts invocation context power
Creates a low-power `InvocationContext` with simplified methods and data.

See `invocation_context.py` for detailed comments.
2024-02-15 17:30:02 +11:00
psychedelicious
aa089e8108 tidy(nodes): move all field things to fields.py
Unfortunately, this is necessary to prevent circular imports at runtime.
2024-02-15 17:30:02 +11:00
722 changed files with 45220 additions and 27980 deletions

View File

@@ -1,33 +0,0 @@
name: install frontend dependencies
description: Installs frontend dependencies with pnpm, with caching
runs:
using: 'composite'
steps:
- name: setup node 18
uses: actions/setup-node@v4
with:
node-version: '18'
- name: setup pnpm
uses: pnpm/action-setup@v2
with:
version: 8
run_install: false
- name: get pnpm store directory
shell: bash
run: |
echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV
- name: setup cache
uses: actions/cache@v4
with:
path: ${{ env.STORE_PATH }}
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
restore-keys: |
${{ runner.os }}-pnpm-store-
- name: install frontend dependencies
run: pnpm install --prefer-frozen-lockfile
shell: bash
working-directory: invokeai/frontend/web

28
.github/pr_labels.yml vendored
View File

@@ -1,59 +1,59 @@
root: Root:
- changed-files: - changed-files:
- any-glob-to-any-file: '*' - any-glob-to-any-file: '*'
python-deps: PythonDeps:
- changed-files: - changed-files:
- any-glob-to-any-file: 'pyproject.toml' - any-glob-to-any-file: 'pyproject.toml'
python: Python:
- changed-files: - changed-files:
- all-globs-to-any-file: - all-globs-to-any-file:
- 'invokeai/**' - 'invokeai/**'
- '!invokeai/frontend/web/**' - '!invokeai/frontend/web/**'
python-tests: PythonTests:
- changed-files: - changed-files:
- any-glob-to-any-file: 'tests/**' - any-glob-to-any-file: 'tests/**'
ci-cd: CICD:
- changed-files: - changed-files:
- any-glob-to-any-file: .github/** - any-glob-to-any-file: .github/**
docker: Docker:
- changed-files: - changed-files:
- any-glob-to-any-file: docker/** - any-glob-to-any-file: docker/**
installer: Installer:
- changed-files: - changed-files:
- any-glob-to-any-file: installer/** - any-glob-to-any-file: installer/**
docs: Documentation:
- changed-files: - changed-files:
- any-glob-to-any-file: docs/** - any-glob-to-any-file: docs/**
invocations: Invocations:
- changed-files: - changed-files:
- any-glob-to-any-file: 'invokeai/app/invocations/**' - any-glob-to-any-file: 'invokeai/app/invocations/**'
backend: Backend:
- changed-files: - changed-files:
- any-glob-to-any-file: 'invokeai/backend/**' - any-glob-to-any-file: 'invokeai/backend/**'
api: Api:
- changed-files: - changed-files:
- any-glob-to-any-file: 'invokeai/app/api/**' - any-glob-to-any-file: 'invokeai/app/api/**'
services: Services:
- changed-files: - changed-files:
- any-glob-to-any-file: 'invokeai/app/services/**' - any-glob-to-any-file: 'invokeai/app/services/**'
frontend-deps: FrontendDeps:
- changed-files: - changed-files:
- any-glob-to-any-file: - any-glob-to-any-file:
- '**/*/package.json' - '**/*/package.json'
- '**/*/pnpm-lock.yaml' - '**/*/pnpm-lock.yaml'
frontend: Frontend:
- changed-files: - changed-files:
- any-glob-to-any-file: 'invokeai/frontend/web/**' - any-glob-to-any-file: 'invokeai/frontend/web/**'

View File

@@ -1,21 +1,66 @@
## Summary ## What type of PR is this? (check all applicable)
<!--A description of the changes in this PR. Include the kind of change (fix, feature, docs, etc), the "why" and the "how". Screenshots or videos are useful for frontend changes.--> - [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission
## Related Issues / Discussions
<!--WHEN APPLICABLE: List any related issues or discussions on github or discord. If this PR closes an issue, please use the "Closes #1234" format, so that the issue will be automatically closed when the PR merges.--> ## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [ ] No, because:
## QA Instructions
<!--WHEN APPLICABLE: Describe how we can test the changes in this PR.--> ## Have you updated all relevant documentation?
- [ ] Yes
- [ ] No
## Description
## Related Tickets & Documents
<!--
For pull requests that relate or close an issue, please include them
below.
For example having the text: "closes #1234" would connect the current pull
request to issue 1234. And when we merge the pull request, Github will
automatically close the issue.
-->
- Related Issue #
- Closes #
## QA Instructions, Screenshots, Recordings
<!--
Please provide steps on how to test changes, any hardware or
software specifications as well as any other pertinent information.
-->
## Merge Plan ## Merge Plan
<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like DB schemas, may need some care when merging. For example, a careful rebase by the change author, timing to not interfere with a pending release, or a message to contributors on discord after merging.--> <!--
A merge plan describes how this PR should be handled after it is approved.
## Checklist Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is merged"
- [ ] _The PR has a short but descriptive title, suitable for a changelog_ A merge plan is particularly important for large PRs or PRs that touch the
- [ ] _Tests added / updated (if applicable)_ database in any way.
- [ ] _Documentation added / updated (if applicable)_ -->
## Added/updated tests?
- [ ] Yes
- [ ] No : _please replace this line with details on why tests
have not been included_
## [optional] Are there any post deployment tasks we need to perform?

View File

@@ -11,7 +11,7 @@ on:
- 'docker/docker-entrypoint.sh' - 'docker/docker-entrypoint.sh'
- 'workflows/build-container.yml' - 'workflows/build-container.yml'
tags: tags:
- 'v*.*.*' - 'v*'
workflow_dispatch: workflow_dispatch:
permissions: permissions:

View File

@@ -1,45 +0,0 @@
# Builds and uploads the installer and python build artifacts.
name: build installer
on:
workflow_dispatch:
workflow_call:
jobs:
build-installer:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <2 min
steps:
- name: checkout
uses: actions/checkout@v4
- name: setup python
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install pypa/build
run: pip install --upgrade build
- name: setup frontend
uses: ./.github/actions/install-frontend-deps
- name: create installer
id: create_installer
run: ./create_installer.sh
working-directory: installer
- name: upload python distribution artifact
uses: actions/upload-artifact@v4
with:
name: dist
path: ${{ steps.create_installer.outputs.DIST_PATH }}
- name: upload installer artifact
uses: actions/upload-artifact@v4
with:
name: ${{ steps.create_installer.outputs.INSTALLER_FILENAME }}
path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}

View File

@@ -1,80 +0,0 @@
# Runs frontend code quality checks.
#
# Checks for changes to frontend files before running the checks.
# If always_run is true, always runs the checks.
name: 'frontend checks'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
workflow_call:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
frontend-checks:
runs-on: ubuntu-latest
timeout-minutes: 10 # expected run time: <2 min
steps:
- uses: actions/checkout@v4
- name: check for changed frontend files
if: ${{ inputs.always_run != true }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
frontend:
- 'invokeai/frontend/web/**'
- name: install dependencies
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
uses: ./.github/actions/install-frontend-deps
- name: tsc
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
run: 'pnpm lint:tsc'
shell: bash
- name: dpdm
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
run: 'pnpm lint:dpdm'
shell: bash
- name: eslint
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
run: 'pnpm lint:eslint'
shell: bash
- name: prettier
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
run: 'pnpm lint:prettier'
shell: bash
- name: knip
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
run: 'pnpm lint:knip'
shell: bash

View File

@@ -1,60 +0,0 @@
# Runs frontend tests.
#
# Checks for changes to frontend files before running the tests.
# If always_run is true, always runs the tests.
name: 'frontend tests'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
inputs:
always_run:
description: 'Always run the tests'
required: true
type: boolean
default: true
workflow_call:
inputs:
always_run:
description: 'Always run the tests'
required: true
type: boolean
default: true
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
frontend-tests:
runs-on: ubuntu-latest
timeout-minutes: 10 # expected run time: <2 min
steps:
- uses: actions/checkout@v4
- name: check for changed frontend files
if: ${{ inputs.always_run != true }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
frontend:
- 'invokeai/frontend/web/**'
- name: install dependencies
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
uses: ./.github/actions/install-frontend-deps
- name: vitest
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
run: 'pnpm test:no-watch'
shell: bash

View File

@@ -1,6 +1,6 @@
name: 'label PRs' name: "Pull Request Labeler"
on: on:
- pull_request_target - pull_request_target
jobs: jobs:
labeler: labeler:
@@ -9,10 +9,8 @@ jobs:
pull-requests: write pull-requests: write
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- uses: actions/labeler@v5
- name: label PRs
uses: actions/labeler@v5
with: with:
configuration-path: .github/pr_labels.yml configuration-path: .github/pr_labels.yml

43
.github/workflows/lint-frontend.yml vendored Normal file
View File

@@ -0,0 +1,43 @@
name: Lint frontend
on:
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
push:
branches:
- 'main'
merge_group:
workflow_dispatch:
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
lint-frontend:
if: github.event.pull_request.draft == false
runs-on: ubuntu-22.04
steps:
- name: Setup Node 18
uses: actions/setup-node@v4
with:
node-version: '18'
- name: Checkout
uses: actions/checkout@v4
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: '8.12.1'
- name: Install dependencies
run: 'pnpm install --prefer-frozen-lockfile'
- name: Typescript
run: 'pnpm run lint:tsc'
- name: Madge
run: 'pnpm run lint:madge'
- name: ESLint
run: 'pnpm run lint:eslint'
- name: Prettier
run: 'pnpm run lint:prettier'

View File

@@ -1,49 +1,51 @@
# This is a mostly a copy-paste from https://github.com/squidfunk/mkdocs-material/blob/master/docs/publishing-your-site.md name: mkdocs-material
name: mkdocs
on: on:
push: push:
branches: branches:
- main - 'refs/heads/main'
workflow_dispatch:
permissions: permissions:
contents: write contents: write
jobs: jobs:
deploy: mkdocs-material:
if: github.event.pull_request.draft == false if: github.event.pull_request.draft == false
runs-on: ubuntu-latest runs-on: ubuntu-latest
env: env:
REPO_URL: '${{ github.server_url }}/${{ github.repository }}' REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
REPO_NAME: '${{ github.repository }}' REPO_NAME: '${{ github.repository }}'
SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI' SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
steps: steps:
- name: checkout - name: checkout sources
uses: actions/checkout@v4 uses: actions/checkout@v3
with:
fetch-depth: 0
- name: setup python - name: setup python
uses: actions/setup-python@v5 uses: actions/setup-python@v4
with: with:
python-version: '3.10' python-version: '3.10'
cache: pip cache: pip
cache-dependency-path: pyproject.toml cache-dependency-path: pyproject.toml
- name: set cache id - name: install requirements
run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV env:
PIP_USE_PEP517: 1
run: |
python -m \
pip install ".[docs]"
- name: use cache - name: confirm buildability
uses: actions/cache@v4 run: |
with: python -m \
key: mkdocs-material-${{ env.cache_id }} mkdocs build \
path: .cache --clean \
restore-keys: | --verbose
mkdocs-material-
- name: install dependencies - name: deploy to gh-pages
run: python -m pip install ".[docs]" if: ${{ github.ref == 'refs/heads/main' }}
run: |
- name: build & deploy python -m \
run: mkdocs gh-deploy --force mkdocs gh-deploy \
--clean \
--force

67
.github/workflows/pypi-release.yml vendored Normal file
View File

@@ -0,0 +1,67 @@
name: PyPI Release
on:
workflow_dispatch:
inputs:
publish_package:
description: 'Publish build on PyPi? [true/false]'
required: true
default: 'false'
jobs:
build-and-release:
if: github.repository == 'invoke-ai/InvokeAI'
runs-on: ubuntu-22.04
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
TWINE_NON_INTERACTIVE: 1
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Node 18
uses: actions/setup-node@v4
with:
node-version: '18'
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: '8.12.1'
- name: Install frontend dependencies
run: pnpm install --prefer-frozen-lockfile
working-directory: invokeai/frontend/web
- name: Build frontend
run: pnpm run build
working-directory: invokeai/frontend/web
- name: Install python dependencies
run: pip install --upgrade build twine
- name: Build python package
run: python3 -m build
- name: Upload build as workflow artifact
uses: actions/upload-artifact@v4
with:
name: dist
path: dist
- name: Check distribution
run: twine check dist/*
- name: Check PyPI versions
if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
run: |
pip install --upgrade requests
python -c "\
import scripts.pypi_helper; \
EXISTS=scripts.pypi_helper.local_on_pypi(); \
print(f'PACKAGE_EXISTS={EXISTS}')" >> $GITHUB_ENV
- name: Publish build on PyPi
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != '' && github.event.inputs.publish_package == 'true'
run: twine upload dist/*

View File

@@ -1,76 +0,0 @@
# Runs python code quality checks.
#
# Checks for changes to python files before running the checks.
# If always_run is true, always runs the checks.
#
# TODO: Add mypy or pyright to the checks.
name: 'python checks'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
workflow_call:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
jobs:
python-checks:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
steps:
- name: checkout
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ inputs.always_run != true }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pip install ruff
shell: bash
- name: ruff check
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: ruff check --output-format=github .
shell: bash
- name: ruff format
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: ruff format --check .
shell: bash

View File

@@ -1,106 +0,0 @@
# Runs python tests on a matrix of python versions and platforms.
#
# Checks for changes to python files before running the tests.
# If always_run is true, always runs the tests.
name: 'python tests'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
inputs:
always_run:
description: 'Always run the tests'
required: true
type: boolean
default: true
workflow_call:
inputs:
always_run:
description: 'Always run the tests'
required: true
type: boolean
default: true
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
matrix:
strategy:
matrix:
python-version:
- '3.10'
- '3.11'
platform:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- platform: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- platform: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- platform: linux-cpu
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- platform: macos-default
os: macOS-12
github-env: $GITHUB_ENV
- platform: windows-cpu
os: windows-2022
github-env: $env:GITHUB_ENV
name: 'py${{ matrix.python-version }}: ${{ matrix.platform }}'
runs-on: ${{ matrix.os }}
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
env:
PIP_USE_PEP517: '1'
steps:
- name: checkout
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ inputs.always_run != true }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: install dependencies
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install --editable=".[test]"
- name: run pytest
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pytest

View File

@@ -1,108 +0,0 @@
# Main release workflow. Triggered on tag push or manual trigger.
#
# - Runs all code checks and tests
# - Verifies the app version matches the tag version.
# - Builds the installer and build, uploading them as artifacts.
# - Publishes to TestPyPI and PyPI. Both are conditional on the previous steps passing and require a manual approval.
#
# See docs/RELEASE.md for more information on the release process.
name: release
on:
push:
tags:
- 'v*'
workflow_dispatch:
jobs:
check-version:
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v4
- name: check python version
uses: samuelcolvin/check-python-version@v4
id: check-python-version
with:
version_file_path: invokeai/version/invokeai_version.py
frontend-checks:
uses: ./.github/workflows/frontend-checks.yml
with:
always_run: true
frontend-tests:
uses: ./.github/workflows/frontend-tests.yml
with:
always_run: true
python-checks:
uses: ./.github/workflows/python-checks.yml
with:
always_run: true
python-tests:
uses: ./.github/workflows/python-tests.yml
with:
always_run: true
build:
uses: ./.github/workflows/build-installer.yml
publish-testpypi:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
needs:
[
check-version,
frontend-checks,
frontend-tests,
python-checks,
python-tests,
build,
]
environment:
name: testpypi
url: https://test.pypi.org/p/invokeai
permissions:
id-token: write
steps:
- name: download distribution from build job
uses: actions/download-artifact@v4
with:
name: dist
path: dist/
- name: publish distribution to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
publish-pypi:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
needs:
[
check-version,
frontend-checks,
frontend-tests,
python-checks,
python-tests,
build,
]
environment:
name: pypi
url: https://pypi.org/p/invokeai
permissions:
id-token: write
steps:
- name: download distribution from build job
uses: actions/download-artifact@v4
with:
name: dist
path: dist/
- name: publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

24
.github/workflows/style-checks.yml vendored Normal file
View File

@@ -0,0 +1,24 @@
name: style checks
on:
pull_request:
push:
branches: main
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies with pip
run: |
pip install ruff
- run: ruff check --output-format=github .
- run: ruff format --check .

129
.github/workflows/test-invoke-pip.yml vendored Normal file
View File

@@ -0,0 +1,129 @@
name: Test invoke.py pip
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
matrix:
if: github.event.pull_request.draft == false
strategy:
matrix:
python-version:
# - '3.9'
- '3.10'
pytorch:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- pytorch: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- pytorch: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- pytorch: linux-cpu
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- pytorch: macos-default
os: macOS-12
github-env: $GITHUB_ENV
- pytorch: windows-cpu
os: windows-2022
github-env: $env:GITHUB_ENV
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
env:
PIP_USE_PEP517: '1'
steps:
- name: Checkout sources
id: checkout-sources
uses: actions/checkout@v3
- name: Check for changed python files
id: changed-files
uses: tj-actions/changed-files@v41
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: set test prompt to main branch validation
if: steps.changed-files.outputs.python_any_changed == 'true'
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
- name: setup python
if: steps.changed-files.outputs.python_any_changed == 'true'
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: install invokeai
if: steps.changed-files.outputs.python_any_changed == 'true'
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install
--editable=".[test]"
- name: run pytest
if: steps.changed-files.outputs.python_any_changed == 'true'
id: run-pytest
run: pytest
# - name: run invokeai-configure
# env:
# HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
# run: >
# invokeai-configure
# --yes
# --default_only
# --full-precision
# # can't use fp16 weights without a GPU
# - name: run invokeai
# id: run-invokeai
# env:
# # Set offline mode to make sure configure preloaded successfully.
# HF_HUB_OFFLINE: 1
# HF_DATASETS_OFFLINE: 1
# TRANSFORMERS_OFFLINE: 1
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
# run: >
# invokeai
# --no-patchmatch
# --no-nsfw_checker
# --precision=float32
# --always_use_cpu
# --use_memory_db
# --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
# --from_file ${{ env.TEST_PROMPTS }}
# - name: Archive results
# env:
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
# uses: actions/upload-artifact@v3
# with:
# name: results
# path: ${{ env.INVOKEAI_OUTDIR }}

View File

@@ -7,7 +7,7 @@ embeddedLanguageFormatting: auto
overrides: overrides:
- files: '*.md' - files: '*.md'
options: options:
proseWrap: preserve proseWrap: always
printWidth: 80 printWidth: 80
parser: markdown parser: markdown
cursorOffset: -1 cursorOffset: -1

View File

@@ -10,12 +10,8 @@ help:
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting" @echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors" @echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports" @echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "test Run the unit tests."
@echo "update-config-docstring Update the app's config docstring so mkdocs can autogenerate it correctly."
@echo "frontend-install Install the pnpm modules needed for the front end"
@echo "frontend-build Build the frontend in order to run on localhost:9090" @echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173" @echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "installer-zip Build the installer .zip file for the current version" @echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)" @echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@@ -38,19 +34,6 @@ mypy:
mypy-all: mypy-all:
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
# Run the unit tests
test:
pytest ./tests
# Update config docstring
update-config-docstring:
python scripts/update_config_docstring.py
# Install the pnpm modules needed for the front end
frontend-install:
rm -rf invokeai/frontend/web/node_modules
cd invokeai/frontend/web && pnpm install
# Build the frontend # Build the frontend
frontend-build: frontend-build:
cd invokeai/frontend/web && pnpm build cd invokeai/frontend/web && pnpm build
@@ -59,9 +42,6 @@ frontend-build:
frontend-dev: frontend-dev:
cd invokeai/frontend/web && pnpm dev cd invokeai/frontend/web && pnpm dev
frontend-typegen:
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
# Installer zip file # Installer zip file
installer-zip: installer-zip:
cd installer && ./create_installer.sh cd installer && ./create_installer.sh

View File

@@ -2,25 +2,17 @@
## Any environment variables supported by InvokeAI can be specified here, ## Any environment variables supported by InvokeAI can be specified here,
## in addition to the examples below. ## in addition to the examples below.
## INVOKEAI_ROOT is the path *on the host system* where Invoke will store its data. # HOST_INVOKEAI_ROOT is the path on the docker host's filesystem where InvokeAI will store data.
## It is mounted into the container and allows both containerized and non-containerized usage of Invoke. # Outputs will also be stored here by default.
# Usually this is the only variable you need to set. It can be relative or absolute. # If relative, it will be relative to the docker directory in which the docker-compose.yml file is located
#HOST_INVOKEAI_ROOT=../../invokeai-data
# INVOKEAI_ROOT is the path to the root of the InvokeAI repository within the container.
# INVOKEAI_ROOT=~/invokeai # INVOKEAI_ROOT=~/invokeai
## HOST_INVOKEAI_ROOT and CONTAINER_INVOKEAI_ROOT can be used to control the on-host # Get this value from your HuggingFace account settings page.
## and in-container paths separately, if needed. # HUGGING_FACE_HUB_TOKEN=
## HOST_INVOKEAI_ROOT is the path on the docker host's filesystem where Invoke will store data.
## If relative, it will be relative to the docker directory in which the docker-compose.yml file is located
## CONTAINER_INVOKEAI_ROOT is the path within the container where Invoke will expect to find the runtime directory.
## It MUST be absolute. There is usually no need to change this.
# HOST_INVOKEAI_ROOT=../../invokeai-data
# CONTAINER_INVOKEAI_ROOT=/invokeai
## INVOKEAI_PORT is the port on which the InvokeAI web interface will be available ## optional variables specific to the docker setup.
# INVOKEAI_PORT=9090
## GPU_DRIVER can be set to either `nvidia` or `rocm` to enable GPU support in the container accordingly.
# GPU_DRIVER=nvidia #| rocm # GPU_DRIVER=nvidia #| rocm
## CONTAINER_UID can be set to the UID of the user on the host system that should own the files in the container.
# CONTAINER_UID=1000 # CONTAINER_UID=1000

View File

@@ -18,6 +18,8 @@ ENV INVOKEAI_SRC=/opt/invokeai
ENV VIRTUAL_ENV=/opt/venv/invokeai ENV VIRTUAL_ENV=/opt/venv/invokeai
ENV PATH="$VIRTUAL_ENV/bin:$PATH" ENV PATH="$VIRTUAL_ENV/bin:$PATH"
ARG TORCH_VERSION=2.1.2
ARG TORCHVISION_VERSION=0.16.2
ARG GPU_DRIVER=cuda ARG GPU_DRIVER=cuda
ARG TARGETPLATFORM="linux/amd64" ARG TARGETPLATFORM="linux/amd64"
# unused but available # unused but available
@@ -25,12 +27,7 @@ ARG BUILDPLATFORM
WORKDIR ${INVOKEAI_SRC} WORKDIR ${INVOKEAI_SRC}
COPY invokeai ./invokeai # Install pytorch before all other pip packages
COPY pyproject.toml ./
# Editable mode helps use the same image for development:
# the local working copy can be bind-mounted into the image
# at path defined by ${INVOKEAI_SRC}
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu # NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is default # x86_64/CUDA is default
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
@@ -42,10 +39,20 @@ RUN --mount=type=cache,target=/root/.cache/pip \
else \ else \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu121"; \ extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu121"; \
fi &&\ fi &&\
pip install $extra_index_url_arg \
torch==$TORCH_VERSION \
torchvision==$TORCHVISION_VERSION
# Install the local package.
# Editable mode helps use the same image for development:
# the local working copy can be bind-mounted into the image
# at path defined by ${INVOKEAI_SRC}
COPY invokeai ./invokeai
COPY pyproject.toml ./
RUN --mount=type=cache,target=/root/.cache/pip \
# xformers + triton fails to install on arm64 # xformers + triton fails to install on arm64
if [ "$GPU_DRIVER" = "cuda" ] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then \ if [ "$GPU_DRIVER" = "cuda" ] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
pip install $extra_index_url_arg -e ".[xformers]"; \ pip install -e ".[xformers]"; \
else \ else \
pip install $extra_index_url_arg -e "."; \ pip install $extra_index_url_arg -e "."; \
fi fi
@@ -94,8 +101,6 @@ RUN apt update && apt install -y --no-install-recommends \
ENV INVOKEAI_SRC=/opt/invokeai ENV INVOKEAI_SRC=/opt/invokeai
ENV VIRTUAL_ENV=/opt/venv/invokeai ENV VIRTUAL_ENV=/opt/venv/invokeai
ENV INVOKEAI_ROOT=/invokeai ENV INVOKEAI_ROOT=/invokeai
ENV INVOKEAI_HOST=0.0.0.0
ENV INVOKEAI_PORT=9090
ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH" ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
ENV CONTAINER_UID=${CONTAINER_UID:-1000} ENV CONTAINER_UID=${CONTAINER_UID:-1000}
ENV CONTAINER_GID=${CONTAINER_GID:-1000} ENV CONTAINER_GID=${CONTAINER_GID:-1000}
@@ -120,4 +125,4 @@ RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${IN
COPY docker/docker-entrypoint.sh ./ COPY docker/docker-entrypoint.sh ./
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"] ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
CMD ["invokeai-web"] CMD ["invokeai-web", "--host", "0.0.0.0"]

View File

@@ -8,28 +8,35 @@ x-invokeai: &invokeai
context: .. context: ..
dockerfile: docker/Dockerfile dockerfile: docker/Dockerfile
# variables without a default will automatically inherit from the host environment
environment:
- INVOKEAI_ROOT
- HF_HOME
# Create a .env file in the same directory as this docker-compose.yml file # Create a .env file in the same directory as this docker-compose.yml file
# and populate it with environment variables. See .env.sample # and populate it with environment variables. See .env.sample
env_file: env_file:
- .env - .env
# variables without a default will automatically inherit from the host environment
environment:
# if set, CONTAINER_INVOKEAI_ROOT will override the Invoke runtime directory location *inside* the container
- INVOKEAI_ROOT=${CONTAINER_INVOKEAI_ROOT:-/invokeai}
- HF_HOME
ports: ports:
- "${INVOKEAI_PORT:-9090}:${INVOKEAI_PORT:-9090}" - "${INVOKEAI_PORT:-9090}:9090"
volumes: volumes:
- type: bind - type: bind
source: ${HOST_INVOKEAI_ROOT:-${INVOKEAI_ROOT:-~/invokeai}} source: ${HOST_INVOKEAI_ROOT:-${INVOKEAI_ROOT:-~/invokeai}}
target: ${CONTAINER_INVOKEAI_ROOT:-/invokeai} target: ${INVOKEAI_ROOT:-/invokeai}
bind:
create_host_path: true
- ${HF_HOME:-~/.cache/huggingface}:${HF_HOME:-/invokeai/.cache/huggingface} - ${HF_HOME:-~/.cache/huggingface}:${HF_HOME:-/invokeai/.cache/huggingface}
# - ${INVOKEAI_MODELS_DIR:-${INVOKEAI_ROOT:-/invokeai/models}}
# - ${INVOKEAI_MODELS_CONFIG_PATH:-${INVOKEAI_ROOT:-/invokeai/configs/models.yaml}}
tty: true tty: true
stdin_open: true stdin_open: true
# # Example of running alternative commands/scripts in the container
# command:
# - bash
# - -c
# - |
# invokeai-model-install --yes --default-only --config_file ${INVOKEAI_ROOT}/config_custom.yaml
# invokeai-nodes-web --host 0.0.0.0
services: services:
invokeai-nvidia: invokeai-nvidia:

View File

@@ -9,6 +9,10 @@ set -e -o pipefail
### Set INVOKEAI_ROOT pointing to a valid runtime directory ### Set INVOKEAI_ROOT pointing to a valid runtime directory
# Otherwise configure the runtime dir first. # Otherwise configure the runtime dir first.
### Configure the InvokeAI runtime directory (done by default)):
# docker run --rm -it <this image> --configure
# or skip with --no-configure
### Set the CONTAINER_UID envvar to match your user. ### Set the CONTAINER_UID envvar to match your user.
# Ensures files created in the container are owned by you: # Ensures files created in the container are owned by you:
# docker run --rm -it -v /some/path:/invokeai -e CONTAINER_UID=$(id -u) <this image> # docker run --rm -it -v /some/path:/invokeai -e CONTAINER_UID=$(id -u) <this image>
@@ -18,6 +22,27 @@ USER_ID=${CONTAINER_UID:-1000}
USER=ubuntu USER=ubuntu
usermod -u ${USER_ID} ${USER} 1>/dev/null usermod -u ${USER_ID} ${USER} 1>/dev/null
configure() {
# Configure the runtime directory
if [[ -f ${INVOKEAI_ROOT}/invokeai.yaml ]]; then
echo "${INVOKEAI_ROOT}/invokeai.yaml exists. InvokeAI is already configured."
echo "To reconfigure InvokeAI, delete the above file."
echo "======================================================================"
else
mkdir -p "${INVOKEAI_ROOT}"
chown --recursive ${USER} "${INVOKEAI_ROOT}"
gosu ${USER} invokeai-configure --yes --default_only
fi
}
## Skip attempting to configure.
## Must be passed first, before any other args.
if [[ $1 != "--no-configure" ]]; then
configure
else
shift
fi
### Set the $PUBLIC_KEY env var to enable SSH access. ### Set the $PUBLIC_KEY env var to enable SSH access.
# We do not install openssh-server in the image by default to avoid bloat. # We do not install openssh-server in the image by default to avoid bloat.
# but it is useful to have the full SSH server e.g. on Runpod. # but it is useful to have the full SSH server e.g. on Runpod.
@@ -33,8 +58,7 @@ if [[ -v "PUBLIC_KEY" ]] && [[ ! -d "${HOME}/.ssh" ]]; then
service ssh start service ssh start
fi fi
mkdir -p "${INVOKEAI_ROOT}"
chown --recursive ${USER} "${INVOKEAI_ROOT}"
cd "${INVOKEAI_ROOT}" cd "${INVOKEAI_ROOT}"
# Run the CMD as the Container User (not root). # Run the CMD as the Container User (not root).

View File

@@ -1,142 +0,0 @@
# Release Process
The app is published in twice, in different build formats.
- A [PyPI] distribution. This includes both a source distribution and built distribution (a wheel). Users install with `pip install invokeai`. The updater uses this build.
- An installer on the [InvokeAI Releases Page]. This is a zip file with install scripts and a wheel. This is only used for new installs.
## General Prep
Make a developer call-out for PRs to merge. Merge and test things out.
While the release workflow does not include end-to-end tests, it does pause before publishing so you can download and test the final build.
## Release Workflow
The `release.yml` workflow runs a number of jobs to handle code checks, tests, build and publish on PyPI.
It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if you've prepped a release branch like `release/v3.5.0` or are releasing from `main` - it works the same.
> Because commits are reference-counted, it is safe to create a release branch, tag it, let the workflow run, then delete the branch. So long as the tag exists, that commit will exist.
### Triggering the Workflow
Run `make tag-release` to tag the current commit and kick off the workflow.
The release may also be dispatched [manually].
### Workflow Jobs and Process
The workflow consists of a number of concurrently-run jobs, and two final publish jobs.
The publish jobs require manual approval and are only run if the other jobs succeed.
#### `check-version` Job
This job checks that the git ref matches the app version. It matches the ref against the `__version__` variable in `invokeai/version/invokeai_version.py`.
When the workflow is triggered by tag push, the ref is the tag. If the workflow is run manually, the ref is the target selected from the **Use workflow from** dropdown.
This job uses [samuelcolvin/check-python-version].
> Any valid [version specifier] works, so long as the tag matches the version. The release workflow works exactly the same for `RC`, `post`, `dev`, etc.
#### Check and Test Jobs
- **`python-tests`**: runs `pytest` on matrix of platforms
- **`python-checks`**: runs `ruff` (format and lint)
- **`frontend-tests`**: runs `vitest`
- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
> **TODO** We should add `mypy` or `pyright` to the **`check-python`** job.
> **TODO** We should add an end-to-end test job that generates an image.
#### `build-installer` Job
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:
- **`dist`**: the python distribution, to be published on PyPI
- **`InvokeAI-installer-${VERSION}.zip`**: the installer to be included in the GitHub release
#### Sanity Check & Smoke Test
At this point, the release workflow pauses as the remaining publish jobs require approval.
A maintainer should go to the **Summary** tab of the workflow, download the installer and test it. Ensure the app loads and generates.
> The same wheel file is bundled in the installer and in the `dist` artifact, which is uploaded to PyPI. You should end up with the exactly the same installation of the `invokeai` package from any of these methods.
#### PyPI Publish Jobs
The publish jobs will run if any of the previous jobs fail.
They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
Both jobs require a maintainer to approve them from the workflow's **Summary** tab.
- Click the **Review deployments** button
- Select the environment (either `testpypi` or `pypi`)
- Click **Approve and deploy**
> **If the version already exists on PyPI, the publish jobs will fail.** PyPI only allows a given version to be published once - you cannot change it. If version published on PyPI has a problem, you'll need to "fail forward" by bumping the app version and publishing a followup release.
#### `publish-testpypi` Job
Publishes the distribution on the [Test PyPI] index, using the `testpypi` GitHub environment.
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release.
If approved and successful, you could try out the test release like this:
```sh
# Create a new virtual environment
python -m venv ~/.test-invokeai-dist --prompt test-invokeai-dist
# Install the distribution from Test PyPI
pip install --index-url https://test.pypi.org/simple/ invokeai
# Run and test the app
invokeai-web
# Cleanup
deactivate
rm -rf ~/.test-invokeai-dist
```
#### `publish-pypi` Job
Publishes the distribution on the production PyPI index, using the `pypi` GitHub environment.
## Publish the GitHub Release with installer
Once the release is published to PyPI, it's time to publish the GitHub release.
1. [Draft a new release] on GitHub, choosing the tag that triggered the release.
2. Write the release notes, describing important changes. The **Generate release notes** button automatically inserts the changelog and new contributors, and you can copy/paste the intro from previous releases.
3. Upload the zip file created in **`build`** job into the Assets section of the release notes. You can also upload the zip into the body of the release notes, since it can be hard for users to find the Assets section.
4. Check the **Set as a pre-release** and **Create a discussion for this release** checkboxes at the bottom of the release page.
5. Publish the pre-release.
6. Announce the pre-release in Discord.
> **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up.
## Manual Build
The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag.
No checks are run, it just builds.
## Manual Release
The `release` workflow can be dispatched manually. You must dispatch the workflow from the right tag, else it will fail the version check.
This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above.
[InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases
[PyPI]: https://pypi.org/
[Draft a new release]: https://github.com/invoke-ai/InvokeAI/releases/new
[Test PyPI]: https://test.pypi.org/
[version specifier]: https://packaging.python.org/en/latest/specifications/version-specifiers/
[ncipollo/release-action]: https://github.com/ncipollo/release-action
[GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
[trusted publishers]: https://docs.pypi.org/trusted-publishers/
[samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version
[manually]: #manual-release

View File

@@ -16,6 +16,11 @@ model. These are the:
information. It is also responsible for managing the InvokeAI information. It is also responsible for managing the InvokeAI
`models` directory and its contents. `models` directory and its contents.
* _ModelMetadataStore_ and _ModelMetaDataFetch_ Backend modules that
are able to retrieve metadata from online model repositories,
transform them into Pydantic models, and cache them to the InvokeAI
SQL database.
* _DownloadQueueServiceBase_ * _DownloadQueueServiceBase_
A multithreaded downloader responsible A multithreaded downloader responsible
for downloading models from a remote source to disk. The download for downloading models from a remote source to disk. The download
@@ -27,6 +32,7 @@ model. These are the:
Responsible for loading a model from disk Responsible for loading a model from disk
into RAM and VRAM and getting it ready for inference. into RAM and VRAM and getting it ready for inference.
## Location of the Code ## Location of the Code
The four main services can be found in The four main services can be found in
@@ -61,17 +67,19 @@ provides the following fields:
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator | | `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
| `base_model` | BaseModelType | The base model that the model is compatible with | | `base_model` | BaseModelType | The base model that the model is compatible with |
| `path` | str | Location of model on disk | | `path` | str | Location of model on disk |
| `hash` | str | Hash of the model | | `original_hash` | str | Hash of the model when it was first installed |
| `current_hash` | str | Most recent hash of the model's contents |
| `description` | str | Human-readable description of the model (optional) | | `description` | str | Human-readable description of the model (optional) |
| `source` | str | Model's source URL or repo id (optional) | | `source` | str | Model's source URL or repo id (optional) |
The `key` is a unique 32-character random ID which was generated at The `key` is a unique 32-character random ID which was generated at
install time. The `hash` field stores a hash of the model's install time. The `original_hash` field stores a hash of the model's
contents at install time obtained by sampling several parts of the contents at install time obtained by sampling several parts of the
model's files using the `imohash` library. Over the course of the model's files using the `imohash` library. Over the course of the
model's lifetime it may be transformed in various ways, such as model's lifetime it may be transformed in various ways, such as
changing its precision or converting it from a .safetensors to a changing its precision or converting it from a .safetensors to a
diffusers model. diffusers model. When this happens, `original_hash` is unchanged, but
`current_hash` is updated to indicate the current contents.
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that `ModelType`, `ModelFormat` and `BaseModelType` are string enums that
are defined in `invokeai.backend.model_manager.config`. They are also are defined in `invokeai.backend.model_manager.config`. They are also
@@ -86,6 +94,7 @@ The `path` field can be absolute or relative. If relative, it is taken
to be relative to the `models_dir` setting in the user's to be relative to the `models_dir` setting in the user's
`invokeai.yaml` file. `invokeai.yaml` file.
### CheckpointConfig ### CheckpointConfig
This adds support for checkpoint configurations, and adds the This adds support for checkpoint configurations, and adds the
@@ -219,9 +228,9 @@ The way it works is as follows:
1. Retrieve the value of the `model_config_db` option from the user's 1. Retrieve the value of the `model_config_db` option from the user's
`invokeai.yaml` config file. `invokeai.yaml` config file.
2. If `model_config_db` is `auto` (the default), then: 2. If `model_config_db` is `auto` (the default), then:
* Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object - Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
opened on the passed connection and lock. opened on the passed connection and lock.
* Open up a new connection to `databases/invokeai.db` if `conn` - Open up a new connection to `databases/invokeai.db` if `conn`
and/or `lock` are missing (see note below). and/or `lock` are missing (see note below).
3. If `model_config_db` is a Path, then use `from_db_file` 3. If `model_config_db` is a Path, then use `from_db_file`
to return the appropriate type of ModelRecordService. to return the appropriate type of ModelRecordService.
@@ -246,7 +255,7 @@ store = ModelRecordServiceBase.open(config, db_conn, lock)
Configurations can be retrieved in several ways. Configurations can be retrieved in several ways.
#### get_model(key) -> AnyModelConfig #### get_model(key) -> AnyModelConfig:
The basic functionality is to call the record store object's The basic functionality is to call the record store object's
`get_model()` method with the desired model's unique key. It returns `get_model()` method with the desired model's unique key. It returns
@@ -263,28 +272,28 @@ print(model_conf.path)
If the key is unrecognized, this call raises an If the key is unrecognized, this call raises an
`UnknownModelException`. `UnknownModelException`.
#### exists(key) -> AnyModelConfig #### exists(key) -> AnyModelConfig:
Returns True if a model with the given key exists in the databsae. Returns True if a model with the given key exists in the databsae.
#### search_by_path(path) -> AnyModelConfig #### search_by_path(path) -> AnyModelConfig:
Returns the configuration of the model whose path is `path`. The path Returns the configuration of the model whose path is `path`. The path
is matched using a simple string comparison and won't correctly match is matched using a simple string comparison and won't correctly match
models referred to by different paths (e.g. using symbolic links). models referred to by different paths (e.g. using symbolic links).
#### search_by_name(name, base, type) -> List[AnyModelConfig] #### search_by_name(name, base, type) -> List[AnyModelConfig]:
This method searches for models that match some combination of `name`, This method searches for models that match some combination of `name`,
`BaseType` and `ModelType`. Calling without any arguments will return `BaseType` and `ModelType`. Calling without any arguments will return
all the models in the database. all the models in the database.
#### all_models() -> List[AnyModelConfig] #### all_models() -> List[AnyModelConfig]:
Return all the model configs in the database. Exactly equivalent to Return all the model configs in the database. Exactly equivalent to
calling `search_by_name()` with no arguments. calling `search_by_name()` with no arguments.
#### search_by_tag(tags) -> List[AnyModelConfig] #### search_by_tag(tags) -> List[AnyModelConfig]:
`tags` is a list of strings. This method returns a list of model `tags` is a list of strings. This method returns a list of model
configs that contain all of the given tags. Examples: configs that contain all of the given tags. Examples:
@@ -303,11 +312,11 @@ commercializable_models = [x for x in store.all_models() \
if x.license.contains('allowCommercialUse=Sell')] if x.license.contains('allowCommercialUse=Sell')]
``` ```
#### version() -> str #### version() -> str:
Returns the version of the database, currently at `3.2` Returns the version of the database, currently at `3.2`
#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase #### model_info_by_name(name, base_model, model_type) -> ModelConfigBase:
This method exists to ease the transition from the previous version of This method exists to ease the transition from the previous version of
the model manager, in which `get_model()` took the three arguments the model manager, in which `get_model()` took the three arguments
@@ -328,7 +337,7 @@ model and pass its key to `get_model()`.
Several methods allow you to create and update stored model config Several methods allow you to create and update stored model config
records. records.
#### add_model(key, config) -> AnyModelConfig #### add_model(key, config) -> AnyModelConfig:
Given a key and a configuration, this will add the model's Given a key and a configuration, this will add the model's
configuration record to the database. `config` can either be a subclass of configuration record to the database. `config` can either be a subclass of
@@ -343,7 +352,7 @@ model with the same key is already in the database, or an
`InvalidModelConfigException` if a dict was passed and Pydantic `InvalidModelConfigException` if a dict was passed and Pydantic
experienced a parse or validation error. experienced a parse or validation error.
### update_model(key, config) -> AnyModelConfig ### update_model(key, config) -> AnyModelConfig:
Given a key and a configuration, this will update the model Given a key and a configuration, this will update the model
configuration record in the database. `config` can be either a configuration record in the database. `config` can be either a
@@ -361,30 +370,33 @@ The `ModelInstallService` class implements the
shop for all your model install needs. It provides the following shop for all your model install needs. It provides the following
functionality: functionality:
* Registering a model config record for a model already located on the - Registering a model config record for a model already located on the
local filesystem, without moving it or changing its path. local filesystem, without moving it or changing its path.
* Installing a model alreadiy located on the local filesystem, by - Installing a model alreadiy located on the local filesystem, by
moving it into the InvokeAI root directory under the moving it into the InvokeAI root directory under the
`models` folder (or wherever config parameter `models_dir` `models` folder (or wherever config parameter `models_dir`
specifies). specifies).
* Probing of models to determine their type, base type and other key - Probing of models to determine their type, base type and other key
information. information.
* Interface with the InvokeAI event bus to provide status updates on - Interface with the InvokeAI event bus to provide status updates on
the download, installation and registration process. the download, installation and registration process.
* Downloading a model from an arbitrary URL and installing it in - Downloading a model from an arbitrary URL and installing it in
`models_dir`. `models_dir`.
* Special handling for HuggingFace repo_ids to recursively download - Special handling for Civitai model URLs which allow the user to
paste in a model page's URL or download link
- Special handling for HuggingFace repo_ids to recursively download
the contents of the repository, paying attention to alternative the contents of the repository, paying attention to alternative
variants such as fp16. variants such as fp16.
* Saving tags and other metadata about the model into the invokeai database - Saving tags and other metadata about the model into the invokeai database
when fetching from a repo that provides that type of information, when fetching from a repo that provides that type of information,
(currently only HuggingFace). (currently only Civitai and HuggingFace).
### Initializing the installer ### Initializing the installer
@@ -428,8 +440,10 @@ required parameters:
| `app_config` | InvokeAIAppConfig | InvokeAI app configuration object | | `app_config` | InvokeAIAppConfig | InvokeAI app configuration object |
| `record_store` | ModelRecordServiceBase | Config record storage database | | `record_store` | ModelRecordServiceBase | Config record storage database |
| `download_queue` | DownloadQueueServiceBase | Download queue object | | `download_queue` | DownloadQueueServiceBase | Download queue object |
| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
|`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) | |`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
Once initialized, the installer will provide the following methods: Once initialized, the installer will provide the following methods:
#### install_job = installer.heuristic_import(source, [config], [access_token]) #### install_job = installer.heuristic_import(source, [config], [access_token])
@@ -443,12 +457,12 @@ The `source` is a string that can be any of these forms
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`) 1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`) 2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
3. A HuggingFace repo_id with any of the following formats: 3. A HuggingFace repo_id with any of the following formats:
* `model/name` -- entire model - `model/name` -- entire model
* `model/name:fp32` -- entire model, using the fp32 variant - `model/name:fp32` -- entire model, using the fp32 variant
* `model/name:fp16:vae` -- vae submodel, using the fp16 variant - `model/name:fp16:vae` -- vae submodel, using the fp16 variant
* `model/name::vae` -- vae submodel, using default precision - `model/name::vae` -- vae submodel, using default precision
* `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant - `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
* `model/name::path/to/model.safetensors` -- an individual model file, default variant - `model/name::path/to/model.safetensors` -- an individual model file, default variant
Note that by specifying a relative path to the top of the HuggingFace Note that by specifying a relative path to the top of the HuggingFace
repo, you can download and install arbitrary models files. repo, you can download and install arbitrary models files.
@@ -552,6 +566,7 @@ details.
This is used for a model that is located on a locally-accessible Posix This is used for a model that is located on a locally-accessible Posix
filesystem, such as a local disk or networked fileshare. filesystem, such as a local disk or networked fileshare.
| **Argument** | **Type** | **Default** | **Description** | | **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------| |------------------|------------------------------|-------------|-------------------------------------------|
| `path` | str | Path | None | Path to the model file or directory | | `path` | str | Path | None | Path to the model file or directory |
@@ -571,7 +586,33 @@ The `AnyHttpUrl` class can be imported from `pydantic.networks`.
Ordinarily, no metadata is retrieved from these sources. However, Ordinarily, no metadata is retrieved from these sources. However,
there is special-case code in the installer that looks for HuggingFace there is special-case code in the installer that looks for HuggingFace
and fetches the corresponding model metadata from the corresponding repo. and Civitai URLs and fetches the corresponding model metadata from
the corresponding repo.
#### CivitaiModelSource
This is used for a model that is hosted by the Civitai web site.
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `version_id` | int | None | The ID of the particular version of the desired model. |
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
Civitai has two model IDs, both of which are integers. The `model_id`
corresponds to a collection of model versions that may different in
arbitrary ways, such as derivation from different checkpoint training
steps, SFW vs NSFW generation, pruned vs non-pruned, etc. The
`version_id` points to a specific version. Please use the latter.
Some Civitai models require an access token to download. These can be
generated from the Civitai profile page of a logged-in
account. Somewhat annoyingly, if you fail to provide the access token
when downloading a model that needs it, Civitai generates a redirect
to a login page rather than a 403 Forbidden error. The installer
attempts to catch this event and issue an informative error
message. Otherwise you will get an "unrecognized model suffix" error
when the model prober tries to identify the type of the HTML login
page.
#### HFModelSource #### HFModelSource
@@ -584,6 +625,7 @@ HuggingFace has the most complicated `ModelSource` structure:
| `subfolder` | Path | None | Look for the model in a subfolder of the repo. | | `subfolder` | Path | None | Look for the model in a subfolder of the repo. |
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. | | `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`. The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.
The `variant` is one of the various diffusers formats that HuggingFace The `variant` is one of the various diffusers formats that HuggingFace
@@ -619,6 +661,7 @@ in. To download these files, you must provide an
`HfFolder.get_token()` will be called to fill it in with the cached `HfFolder.get_token()` will be called to fill it in with the cached
one. one.
#### Monitoring the install job process #### Monitoring the install job process
When you create an install job with `import_model()`, it launches the When you create an install job with `import_model()`, it launches the
@@ -639,6 +682,7 @@ The `ModelInstallJob` class has the following structure:
| `error_type` | `str` | Name of the exception that led to an error status | | `error_type` | `str` | Name of the exception that led to an error status |
| `error` | `str` | Traceback of the error | | `error` | `str` | Traceback of the error |
If the `event_bus` argument was provided, events will also be If the `event_bus` argument was provided, events will also be
broadcast to the InvokeAI event bus. The events will appear on the bus broadcast to the InvokeAI event bus. The events will appear on the bus
as an event of type `EventServiceBase.model_event`, a timestamp and as an event of type `EventServiceBase.model_event`, a timestamp and
@@ -658,13 +702,14 @@ following keys:
| `total_bytes` | int | Total size of all the files that make up the model | | `total_bytes` | int | Total size of all the files that make up the model |
| `parts` | List[Dict]| Information on the progress of the individual files that make up the model | | `parts` | List[Dict]| Information on the progress of the individual files that make up the model |
The parts is a list of dictionaries that give information on each of The parts is a list of dictionaries that give information on each of
the components pieces of the download. The dictionary's keys are the components pieces of the download. The dictionary's keys are
`source`, `local_path`, `bytes` and `total_bytes`, and correspond to `source`, `local_path`, `bytes` and `total_bytes`, and correspond to
the like-named keys in the main event. the like-named keys in the main event.
Note that downloading events will not be issued for local models, and Note that downloading events will not be issued for local models, and
that downloading events occur _before_ the running event. that downloading events occur *before* the running event.
##### `model_install_running` ##### `model_install_running`
@@ -707,6 +752,7 @@ properties: `waiting`, `downloading`, `running`, `complete`, `errored`
and `cancelled`, as well as `in_terminal_state`. The last will return and `cancelled`, as well as `in_terminal_state`. The last will return
True if the job is in the complete, errored or cancelled states. True if the job is in the complete, errored or cancelled states.
#### Model configuration and probing #### Model configuration and probing
The install service uses the `invokeai.backend.model_manager.probe` The install service uses the `invokeai.backend.model_manager.probe`
@@ -816,6 +862,7 @@ This method is similar to `unregister()`, but also unconditionally
deletes the corresponding model weights file(s), regardless of whether deletes the corresponding model weights file(s), regardless of whether
they are inside or outside the InvokeAI models hierarchy. they are inside or outside the InvokeAI models hierarchy.
#### path = installer.download_and_cache(remote_source, [access_token], [timeout]) #### path = installer.download_and_cache(remote_source, [access_token], [timeout])
This utility routine will download the model file located at source, This utility routine will download the model file located at source,
@@ -927,7 +974,7 @@ is in its lifecycle. Values are defined in the string enum
`DownloadJobStatus`, a symbol available from `DownloadJobStatus`, a symbol available from
`invokeai.app.services.download_manager`. Possible values are: `invokeai.app.services.download_manager`. Possible values are:
| **Value** | **String Value** | **Description** | | **Value** | **String Value** | ** Description ** |
|--------------|---------------------|-------------------| |--------------|---------------------|-------------------|
| `IDLE` | idle | Job created, but not submitted to the queue | | `IDLE` | idle | Job created, but not submitted to the queue |
| `ENQUEUED` | enqueued | Job is patiently waiting on the queue | | `ENQUEUED` | enqueued | Job is patiently waiting on the queue |
@@ -993,11 +1040,11 @@ While a job is being downloaded, the queue will emit events at
periodic intervals. A typical series of events during a successful periodic intervals. A typical series of events during a successful
download session will look like this: download session will look like this:
* enqueued - enqueued
* running - running
* running - running
* running - running
* completed - completed
There will be a single enqueued event, followed by one or more running There will be a single enqueued event, followed by one or more running
events, and finally one `completed`, `error` or `cancelled` events, and finally one `completed`, `error` or `cancelled`
@@ -1006,12 +1053,12 @@ events.
It is possible for a caller to pause download temporarily, in which It is possible for a caller to pause download temporarily, in which
case the events may look something like this: case the events may look something like this:
* enqueued - enqueued
* running - running
* running - running
* paused - paused
* running - running
* completed - completed
The download queue logs when downloads start and end (unless `quiet` The download queue logs when downloads start and end (unless `quiet`
is set to True at initialization time) but doesn't log any progress is set to True at initialization time) but doesn't log any progress
@@ -1140,6 +1187,7 @@ and is equivalent to manually specifying a destination of
Here is the full list of arguments that can be provided to Here is the full list of arguments that can be provided to
`create_download_job()`: `create_download_job()`:
| **Argument** | **Type** | **Default** | **Description** | | **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------| |------------------|------------------------------|-------------|-------------------------------------------|
| `source` | Union[str, Path, AnyHttpUrl] | | Download remote or local source | | `source` | Union[str, Path, AnyHttpUrl] | | Download remote or local source |
@@ -1218,30 +1266,51 @@ queue and have not yet reached a terminal state.
The modules found under `invokeai.backend.model_manager.metadata` The modules found under `invokeai.backend.model_manager.metadata`
provide a straightforward API for fetching model metadatda from online provide a straightforward API for fetching model metadatda from online
repositories. Currently only HuggingFace is supported. However, the repositories. Currently two repositories are supported: HuggingFace
modules are easily extended for additional repos, provided that they and Civitai. However, the modules are easily extended for additional
have defined APIs for metadata access. repos, provided that they have defined APIs for metadata access.
Metadata comprises any descriptive information that is not essential Metadata comprises any descriptive information that is not essential
for getting the model to run. For example "author" is metadata, while for getting the model to run. For example "author" is metadata, while
"type", "base" and "format" are not. The latter fields are part of the "type", "base" and "format" are not. The latter fields are part of the
model's config, as defined in `invokeai.backend.model_manager.config`. model's config, as defined in `invokeai.backend.model_manager.config`.
### Example Usage ### Example Usage:
``` ```
from invokeai.backend.model_manager.metadata import ( from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata, AnyModelRepoMetadata,
CivitaiMetadataFetch,
CivitaiMetadata
ModelMetadataStore,
) )
# to access the initialized sql database # to access the initialized sql database
from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api.dependencies import ApiDependencies
hf = HuggingFaceMetadataFetch() civitai = CivitaiMetadataFetch()
# fetch the metadata # fetch the metadata
model_metadata = hf.from_id("<repo_id>") model_metadata = civitai.from_url("https://civitai.com/models/215796")
assert isinstance(model_metadata, HuggingFaceMetadata) # get some common metadata fields
author = model_metadata.author
tags = model_metadata.tags
# get some Civitai-specific fields
assert isinstance(model_metadata, CivitaiMetadata)
trained_words = model_metadata.trained_words
base_model = model_metadata.base_model_trained_on
thumbnail = model_metadata.thumbnail_url
# cache the metadata to the database using the key corresponding to
# an existing model config record in the `model_config` table
sql_cache = ModelMetadataStore(ApiDependencies.invoker.services.db)
sql_cache.add_metadata('fb237ace520b6716adc98bcb16e8462c', model_metadata)
# now we can search the database by tag, author or model name
# matches will contain a list of model keys that match the search
matches = sql_cache.search_by_tag({"tool", "turbo"})
``` ```
### Structure of the Metadata objects ### Structure of the Metadata objects
@@ -1259,6 +1328,7 @@ This is the common base class for metadata:
| `author` | str | Model's author | | `author` | str | Model's author |
| `tags` | Set[str] | Model tags | | `tags` | Set[str] | Model tags |
Note that the model config record also has a `name` field. It is Note that the model config record also has a `name` field. It is
intended that the config record version be locally customizable, while intended that the config record version be locally customizable, while
the metadata version is read-only. However, enforcing this is expected the metadata version is read-only. However, enforcing this is expected
@@ -1278,14 +1348,53 @@ This descends from `ModelMetadataBase` and adds the following fields:
| `last_modified`| datetime | Date of last commit of this model to the repo | | `last_modified`| datetime | Date of last commit of this model to the repo |
| `files` | List[Path] | List of the files in the model repo | | `files` | List[Path] | List of the files in the model repo |
#### `CivitaiMetadata`
This descends from `ModelMetadataBase` and adds the following fields:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `type` | Literal["civitai"] | Used for the discriminated union of metadata classes|
| `id` | int | Civitai model id |
| `version_name` | str | Name of this version of the model (distinct from model name) |
| `version_id` | int | Civitai model version id (distinct from model id) |
| `created` | datetime | Date this version of the model was created |
| `updated` | datetime | Date this version of the model was last updated |
| `published` | datetime | Date this version of the model was published to Civitai |
| `description` | str | Model description. Quite verbose and contains HTML tags |
| `version_description` | str | Model version description, usually describes changes to the model |
| `nsfw` | bool | Whether the model tends to generate NSFW content |
| `restrictions` | LicenseRestrictions | An object that describes what is and isn't allowed with this model |
| `trained_words`| Set[str] | Trigger words for this model, if any |
| `download_url` | AnyHttpUrl | URL for downloading this version of the model |
| `base_model_trained_on` | str | Name of the model that this version was trained on |
| `thumbnail_url` | AnyHttpUrl | URL to access a representative thumbnail image of the model's output |
| `weight_min` | int | For LoRA sliders, the minimum suggested weight to apply |
| `weight_max` | int | For LoRA sliders, the maximum suggested weight to apply |
Note that `weight_min` and `weight_max` are not currently populated
and take the default values of (-1.0, +2.0). The issue is that these
values aren't part of the structured data but appear in the text
description. Some regular expression or LLM coding may be able to
extract these values.
Also be aware that `base_model_trained_on` is free text and doesn't
correspond to our `ModelType` enum.
`CivitaiMetadata` also defines some convenience properties relating to
licensing restrictions: `credit_required`, `allow_commercial_use`,
`allow_derivatives` and `allow_different_license`.
#### `AnyModelRepoMetadata` #### `AnyModelRepoMetadata`
This is a discriminated Union of `HuggingFaceMetadata`. This is a discriminated Union of `CivitaiMetadata` and
`HuggingFaceMetadata`.
### Fetching Metadata from Online Repos ### Fetching Metadata from Online Repos
The `HuggingFaceMetadataFetch` class will The `HuggingFaceMetadataFetch` and `CivitaiMetadataFetch` classes will
retrieve metadata from its corresponding repository and return retrieve metadata from their corresponding repositories and return
`AnyModelRepoMetadata` objects. Their base class `AnyModelRepoMetadata` objects. Their base class
`ModelMetadataFetchBase` is an abstract class that defines two `ModelMetadataFetchBase` is an abstract class that defines two
methods: `from_url()` and `from_id()`. The former accepts the type of methods: `from_url()` and `from_id()`. The former accepts the type of
@@ -1303,17 +1412,98 @@ provide a `requests.Session` argument. This allows you to customize
the low-level HTTP fetch requests and is used, for instance, in the the low-level HTTP fetch requests and is used, for instance, in the
testing suite to avoid hitting the internet. testing suite to avoid hitting the internet.
The HuggingFace fetcher subclass add additional repo-specific fetching methods: The HuggingFace and Civitai fetcher subclasses add additional
repo-specific fetching methods:
#### HuggingFaceMetadataFetch #### HuggingFaceMetadataFetch
This overrides its base class `from_json()` method to return a This overrides its base class `from_json()` method to return a
`HuggingFaceMetadata` object directly. `HuggingFaceMetadata` object directly.
#### CivitaiMetadataFetch
This adds the following methods:
`from_civitai_modelid()` This takes the ID of a model, finds the
default version of the model, and then retrieves the metadata for
that version, returning a `CivitaiMetadata` object directly.
`from_civitai_versionid()` This takes the ID of a model version and
retrieves its metadata. Functionally equivalent to `from_id()`, the
only difference is that it returna a `CivitaiMetadata` object rather
than an `AnyModelRepoMetadata`.
### Metadata Storage ### Metadata Storage
The `ModelConfigBase` stores this response in the `source_api_response` field The `ModelMetadataStore` provides a simple facility to store model
as a JSON blob. metadata in the `invokeai.db` database. The data is stored as a JSON
blob, with a few common fields (`name`, `author`, `tags`) broken out
to be searchable.
When a metadata object is saved to the database, it is identified
using the model key, _and this key must correspond to an existing
model key in the model_config table_. There is a foreign key integrity
constraint between the `model_config.id` field and the
`model_metadata.id` field such that if you attempt to save metadata
under an unknown key, the attempt will result in an
`UnknownModelException`. Likewise, when a model is deleted from
`model_config`, the deletion of the corresponding metadata record will
be triggered.
Tags are stored in a normalized fashion in the tables `model_tags` and
`tags`. Triggers keep the tag table in sync with the `model_metadata`
table.
To create the storage object, initialize it with the InvokeAI
`SqliteDatabase` object. This is often done this way:
```
from invokeai.app.api.dependencies import ApiDependencies
metadata_store = ModelMetadataStore(ApiDependencies.invoker.services.db)
```
You can then access the storage with the following methods:
#### `add_metadata(key, metadata)`
Add the metadata using a previously-defined model key.
There is currently no `delete_metadata()` method. The metadata will
persist until the matching config is deleted from the `model_config`
table.
#### `get_metadata(key) -> AnyModelRepoMetadata`
Retrieve the metadata corresponding to the model key.
#### `update_metadata(key, new_metadata)`
Update an existing metadata record with new metadata.
#### `search_by_tag(tags: Set[str]) -> Set[str]`
Given a set of tags, find models that are tagged with them. If
multiple tags are provided then a matching model must be tagged with
*all* the tags in the set. This method returns a set of model keys and
is intended to be used in conjunction with the `ModelRecordService`:
```
model_config_store = ApiDependencies.invoker.services.model_records
matches = metadata_store.search_by_tag({'license:other'})
models = [model_config_store.get(x) for x in matches]
```
#### `search_by_name(name: str) -> Set[str]
Find all model metadata records that have the given name and return a
set of keys to the corresponding model config objects.
#### `search_by_author(author: str) -> Set[str]
Find all model metadata records that have the given author and return
a set of keys to the corresponding model config objects.
*** ***
@@ -1377,6 +1567,7 @@ The returned `LoadedModel` object contains a copy of the configuration
record returned by the model record `get_model()` method, as well as record returned by the model record `get_model()` method, as well as
the in-memory loaded model: the in-memory loaded model:
| **Attribute Name** | **Type** | **Description** | | **Attribute Name** | **Type** | **Description** |
|----------------|-----------------|------------------| |----------------|-----------------|------------------|
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. | | `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
@@ -1390,6 +1581,7 @@ return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious. models. The others are obvious.
`LoadedModel` acts as a context manager. The context loads the model `LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model into the execution device (e.g. VRAM on CUDA systems), locks the model
in the execution device for the duration of the context, and returns in the execution device for the duration of the context, and returns
@@ -1403,9 +1595,9 @@ with model_info as vae:
`get_model_by_key()` may raise any of the following exceptions: `get_model_by_key()` may raise any of the following exceptions:
* `UnknownModelException` -- key not in database - `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path - `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model - `NotImplementedException` -- the loader doesn't know how to load this type of model
### Emitting model loading events ### Emitting model loading events
@@ -1532,7 +1724,6 @@ object, or in `context.services.model_manager` from within an
invocation. invocation.
In the examples below, we have retrieved the manager using: In the examples below, we have retrieved the manager using:
``` ```
mm = ApiDependencies.invoker.services.model_manager mm = ApiDependencies.invoker.services.model_manager
``` ```

View File

@@ -1,133 +0,0 @@
# Invoke UI
Invoke's UI is made possible by many contributors and open-source libraries. Thank you!
## Dev environment
### Setup
1. Install [node] and [pnpm].
1. Run `pnpm i` to install all packages.
#### Run in dev mode
1. From `invokeai/frontend/web/`, run `pnpm dev`.
1. From repo root, run `python scripts/invokeai-web.py`.
1. Point your browser to the dev server address, e.g. <http://localhost:5173/>
### Package scripts
- `dev`: run the frontend in dev mode, enabling hot reloading
- `build`: run all checks (madge, eslint, prettier, tsc) and then build the frontend
- `typegen`: generate types from the OpenAPI schema (see [Type generation])
- `lint:dpdm`: check circular dependencies
- `lint:eslint`: check code quality
- `lint:prettier`: check code formatting
- `lint:tsc`: check type issues
- `lint:knip`: check for unused exports or objects (failures here are just suggestions, not hard fails)
- `lint`: run all checks concurrently
- `fix`: run `eslint` and `prettier`, fixing fixable issues
### Type generation
We use [openapi-typescript] to generate types from the app's OpenAPI schema.
The generated types are committed to the repo in [schema.ts].
```sh
# from the repo root, start the server
python scripts/invokeai-web.py
# from invokeai/frontend/web/, run the script
pnpm typegen
```
### Localization
We use [i18next] for localization, but translation to languages other than English happens on our [Weblate] project.
Only the English source strings should be changed on this repo.
### VSCode
#### Example debugger config
```jsonc
{
"version": "0.2.0",
"configurations": [
{
"type": "chrome",
"request": "launch",
"name": "Invoke UI",
"url": "http://localhost:5173",
"webRoot": "${workspaceFolder}/invokeai/frontend/web"
}
]
}
```
#### Remote dev
We've noticed an intermittent timeout issue with the VSCode remote dev port forwarding.
We suggest disabling the editor's port forwarding feature and doing it manually via SSH:
```sh
ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host
```
## Contributing Guidelines
Thanks for your interest in contributing to the Invoke Web UI!
Please follow these guidelines when contributing.
### Check in before investing your time
Please check in before you invest your time on anything besides a trivial fix, in case it conflicts with ongoing work or isn't aligned with the vision for the app.
If a feature request or issue doesn't already exist for the thing you want to work on, please create one.
Ping `@psychedelicious` on [discord] in the `#frontend-dev` channel or in the feature request / issue you want to work on - we're happy to chat.
### Code conventions
- This is a fairly complex app with a deep component tree. Please use memoization (`useCallback`, `useMemo`, `memo`) with enthusiasm.
- If you need to add some global, ephemeral state, please use [nanostores] if possible.
- Be careful with your redux selectors. If they need to be parameterized, consider creating them inside a `useMemo`.
- Feel free to use `lodash` (via `lodash-es`) to make the intent of your code clear.
- Please add comments describing the "why", not the "how" (unless it is really arcane).
### Commit format
Please use the [conventional commits] spec for the web UI, with a scope of "ui":
- `chore(ui): bump deps`
- `chore(ui): lint`
- `feat(ui): add some cool new feature`
- `fix(ui): fix some bug`
### Submitting a PR
- Ensure your branch is tidy. Use an interactive rebase to clean up the commit history and reword the commit messages if they are not descriptive.
- Run `pnpm lint`. Some issues are auto-fixable with `pnpm fix`.
- Fill out the PR form when creating the PR.
- It doesn't need to be super detailed, but a screenshot or video is nice if you changed something visually.
- If a section isn't relevant, delete it. There are no UI tests at this time.
## Other docs
- [Workflows - Design and Implementation]
- [State Management]
[node]: https://nodejs.org/en/download/
[pnpm]: https://github.com/pnpm/pnpm
[discord]: https://discord.gg/ZmtBAhwWhy
[i18next]: https://github.com/i18next/react-i18next
[Weblate]: https://hosted.weblate.org/engage/invokeai/
[openapi-typescript]: https://github.com/drwpow/openapi-typescript
[Type generation]: #type-generation
[schema.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/services/api/schema.ts
[conventional commits]: https://www.conventionalcommits.org/en/v1.0.0/
[Workflows - Design and Implementation]: ./WORKFLOWS.md
[State Management]: ./STATE_MGMT.md

View File

@@ -6,161 +6,259 @@ title: Configuration
## Intro ## Intro
Runtime settings, including the location of files and InvokeAI has numerous runtime settings which can be used to adjust
directories, memory usage, and performance, are managed via the many aspects of its operations, including the location of files and
`invokeai.yaml` config file or environment variables. A subset directories, memory usage, and performance. These settings can be
of settings may be set via commandline arguments. viewed and customized in several ways:
Settings sources are used in this order: 1. By editing settings in the `invokeai.yaml` file.
2. By setting environment variables.
3. On the command-line, when InvokeAI is launched.
- CLI args In addition, the most commonly changed settings are accessible
- Environment variables graphically via the `invokeai-configure` script.
- `invokeai.yaml` settings
- Fallback: defaults
### InvokeAI Root Directory ### How the Configuration System Works
On startup, InvokeAI searches for its "root" directory. This is the directory When InvokeAI is launched, the very first thing it needs to do is to
that contains models, images, the database, and so on. It also contains find its "root" directory, which contains its configuration files,
a configuration file called `invokeai.yaml`. installed models, its database of images, and the folder(s) of
generated images themselves. In this document, the root directory will
be referred to as ROOT.
InvokeAI searches for the root directory in this order: #### Finding the Root Directory
1. The `--root <path>` CLI arg. To find its root directory, InvokeAI uses the following recipe:
2. The environment variable INVOKEAI_ROOT.
3. The directory containing the currently active virtual environment.
4. Fallback: a directory in the current user's home directory named `invokeai`.
### InvokeAI Configuration File 1. It first looks for the argument `--root <path>` on the command line
it was launched from, and uses the indicated path if present.
Inside the root directory, we read settings from the `invokeai.yaml` file. 2. Next it looks for the environment variable INVOKEAI_ROOT, and uses
the directory path found there if present.
It has two sections - one for internal use and one for user settings: 3. If neither of these are present, then InvokeAI looks for the
folder containing the `.venv` Python virtual environment directory for
the currently active environment. This directory is checked for files
expected inside the InvokeAI root before it is used.
```yaml 4. Finally, InvokeAI looks for a directory in the current user's home
# Internal metadata - do not edit: directory named `invokeai`.
schema_version: 4
# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/: #### Reading the InvokeAI Configuration File
host: 0.0.0.0 # serve the app on your local network
models_dir: D:\invokeai\models # store models on an external drive Once the root directory has been located, InvokeAI looks for a file
precision: float16 # always use fp16 precision named `ROOT/invokeai.yaml`, and if present reads configuration values
from it. The top of this file looks like this:
```
InvokeAI:
Web Server:
host: localhost
port: 9090
allow_origins: []
allow_credentials: true
allow_methods:
- '*'
allow_headers:
- '*'
Features:
esrgan: true
internet_available: true
log_tokenization: false
patchmatch: true
restore: true
...
``` ```
The settings in this file will override the defaults. You only need This lines in this file are used to establish default values for
to change this file if the default for a particular setting doesn't Invoke's settings. In the above fragment, the Web Server's listening
work for you. port is set to 9090 by the `port` setting.
Some settings, like [Model Marketplace API Keys], require the YAML You can edit this file with a text editor such as "Notepad" (do not
to be formatted correctly. Here is a [basic guide to YAML files]. use Word or any other word processor). When editing, be careful to
maintain the indentation, and do not add extraneous text, as syntax
errors will prevent InvokeAI from launching. A basic guide to the
format of YAML files can be found
[here](https://circleci.com/blog/what-is-yaml-a-beginner-s-guide/).
You can fix a broken `invokeai.yaml` by deleting it and running the You can fix a broken `invokeai.yaml` by deleting it and running the
configuration script again -- option [6] in the launcher, "Re-run the configuration script again -- option [6] in the launcher, "Re-run the
configure script". configure script".
#### Custom Config File Location #### Reading Environment Variables
You can use any config file with the `--config` CLI arg. Pass in the path to the `invokeai.yaml` file you want to use. Next InvokeAI looks for defined environment variables in the format
`INVOKEAI_<setting_name>`, for example `INVOKEAI_port`. Environment
variable values take precedence over configuration file variables. On
a Macintosh system, for example, you could change the port that the
web server listens on by setting the environment variable this way:
Note that environment variables will trump any settings in the config file. ```
export INVOKEAI_port=8000
### Environment Variables invokeai-web
All settings may be set via environment variables by prefixing `INVOKEAI_`
to the variable name. For example, `INVOKEAI_HOST` would set the `host`
setting.
For non-primitive values, pass a JSON-encoded string:
```sh
export INVOKEAI_REMOTE_API_TOKENS='[{"url_regex":"modelmarketplace", "token": "12345"}]'
``` ```
We suggest using `invokeai.yaml`, as it is more user-friendly. Please check out these
[Macintosh](https://phoenixnap.com/kb/set-environment-variable-mac)
and
[Windows](https://phoenixnap.com/kb/windows-set-environment-variable)
guides for setting temporary and permanent environment variables.
### CLI Args #### Reading the Command Line
A subset of settings may be specified using CLI args: Lastly, InvokeAI takes settings from the command line, which override
everything else. The command-line settings have the same name as the
corresponding configuration file settings, preceded by a `--`, for
example `--port 8000`.
- `--root`: specify the root directory If you are using the launcher (`invoke.sh` or `invoke.bat`) to launch
- `--config`: override the default `invokeai.yaml` file location InvokeAI, then just pass the command-line arguments to the launcher:
### All Settings ```
invoke.bat --port 8000 --host 0.0.0.0
Following the table are additional explanations for certain settings.
<!-- prettier-ignore-start -->
::: invokeai.app.services.config.config_default.InvokeAIAppConfig
options:
heading_level: 4
members: false
show_docstring_description: false
group_by_category: true
show_category_heading: false
<!-- prettier-ignore-end -->
#### Model Marketplace API Keys
Some model marketplaces require an API key to download models. You can provide a URL pattern and appropriate token in your `invokeai.yaml` file to provide that API key.
The pattern can be any valid regex (you may need to surround the pattern with quotes):
```yaml
remote_api_tokens:
# Any URL containing `models.com` will automatically use `your_models_com_token`
- url_regex: models.com
token: your_models_com_token
# Any URL matching this contrived regex will use `some_other_token`
- url_regex: '^[a-z]{3}whatever.*\.com$'
token: some_other_token
``` ```
The provided token will be added as a `Bearer` token to the network requests to download the model files. As far as we know, this works for all model marketplaces that require authorization. The arguments will be applied when you select the web server option
(and the other options as well).
#### Model Hashing If, on the other hand, you prefer to launch InvokeAI directly from the
command line, you would first activate the virtual environment (known
as the "developer's console" in the launcher), and run `invokeai-web`:
Models are hashed during installation, providing a stable identifier for models across all platforms. Hashing is a one-time operation. ```
> C:\Users\Fred\invokeai\.venv\scripts\activate
```yaml (.venv) > invokeai-web --port 8000 --host 0.0.0.0
hashing_algorithm: blake3_single # default value
``` ```
You might want to change this setting, depending on your system: You can get a listing and brief instructions for each of the
command-line options by giving the `--help` argument:
- `blake3_single` (default): Single-threaded - best for spinning HDDs, still OK for SSDs ```
- `blake3_multi`: Parallelized, memory-mapped implementation - best for SSDs, terrible for spinning disks (.venv) > invokeai-web --help
- `random`: Skip hashing entirely - fastest but of course no hash usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials] [--allow_methods [ALLOW_METHODS ...]]
[--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan] [--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization]
[--patchmatch | --no-patchmatch] [--restore | --no-restore]
[--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_loaded_models MAX_LOADED_MODELS] [--max_cache_size MAX_CACHE_SIZE]
[--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--gpu_mem_reserved GPU_MEM_RESERVED] [--precision {auto,float16,float32,autocast}]
[--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled] [--tiled_decode | --no-tiled_decode] [--root ROOT]
[--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR] [--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH]
[--models_dir MODELS_DIR] [--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE]
[--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]] [--log_format {plain,color,syslog,legacy}]
[--log_level {debug,info,warning,error,critical}] [--version | --no-version]
```
During the first startup after upgrading to v4, all of your models will be hashed. This can take a few minutes. ## The Configuration Settings
Most common algorithms are supported, like `md5`, `sha256`, and `sha512`. These are typically much, much slower than either of the BLAKE3 variants. The configuration settings are divided into several distinct
groups in `invokeia.yaml`:
#### Path Settings ### Web Server
| Setting | Default Value | Description |
|---------------------|---------------|----------------------------------------------------------------------------------------------------------------------------|
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
| `port` | `9090` | Network port number that the web server will listen on |
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
| `ssl_certfile` | null | Path to an SSL certificate file, used to enable HTTPS. |
| `ssl_keyfile` | null | Path to an SSL keyfile, if the key is not included in the certificate file. |
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].
### Features
These configuration settings allow you to enable and disable various InvokeAI features:
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `esrgan` | `true` | Activate the ESRGAN upscaling options|
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
### Generation
These options tune InvokeAI's memory and performance characteristics.
| Setting | Default Value | Description |
|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
| `attention_type` | `auto` | Select the type of attention to use. One of `auto`,`normal`,`xformers`,`sliced`, or `torch-sdp` |
| `attention_slice_size` | `auto` | When "sliced" attention is selected, set the slice size. One of `auto`, `balanced`, `max` or the integers 1-8|
| `force_tiled_decode` | `false` | Force the VAE step to decode in tiles, reducing memory consumption at the cost of performance |
### Device
These options configure the generation execution device.
| Setting | Default Value | Description |
|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `device` | `auto` | Preferred execution device. One of `auto`, `cpu`, `cuda`, `cuda:1`, `mps`. `auto` will choose the device depending on the hardware platform and the installed torch capabilities. |
| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
### Paths
These options set the paths of various directories and files used by These options set the paths of various directories and files used by
InvokeAI. Relative paths are interpreted relative to the root directory, so InvokeAI. Relative paths are interpreted relative to INVOKEAI_ROOT, so
if root is `/home/fred/invokeai` and the path is if INVOKEAI_ROOT is `/home/fred/invokeai` and the path is
`autoimport/main`, then the corresponding directory will be located at `autoimport/main`, then the corresponding directory will be located at
`/home/fred/invokeai/autoimport/main`. `/home/fred/invokeai/autoimport/main`.
Note that the autoimport directory will be searched recursively, | Setting | Default Value | Description |
allowing you to organize the models into folders and subfolders in any |----------|----------------|--------------|
way you wish. | `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory |
| `lora_dir` | `autoimport/lora` | At startup time, read and import any LoRA/LyCORIS models found in this directory |
| `embedding_dir` | `autoimport/embedding` | At startup time, read and import any textual inversion (embedding) models found in this directory |
| `controlnet_dir` | `autoimport/controlnet` | At startup time, read and import any ControlNet models found in this directory |
| `conf_path` | `configs/models.yaml` | Location of the `models.yaml` model configuration file |
| `models_dir` | `models` | Location of the directory containing models installed by InvokeAI's model manager |
| `legacy_conf_dir` | `configs/stable-diffusion` | Location of the directory containing the .yaml configuration files for legacy checkpoint models |
| `db_dir` | `databases` | Location of the directory containing InvokeAI's image, schema and session database |
| `outdir` | `outputs` | Location of the directory in which the gallery of generated and uploaded images will be stored |
| `use_memory_db` | `false` | Keep database information in memory rather than on disk; this will not preserve image gallery information across restarts |
#### Logging Note that the autoimport directories will be searched recursively,
allowing you to organize the models into folders and subfolders in any
way you wish. In addition, while we have split up autoimport
directories by the type of model they contain, this isn't
necessary. You can combine different model types in the same folder
and InvokeAI will figure out what they are. So you can easily use just
one autoimport directory by commenting out the unneeded paths:
```
Paths:
autoimport_dir: autoimport
# lora_dir: null
# embedding_dir: null
# controlnet_dir: null
```
### Logging
These settings control the information, warning, and debugging
messages printed to the console log while InvokeAI is running:
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `log_handlers` | `console` | This controls where log messages are sent, and can be a list of one or more destinations. Values include `console`, `file`, `syslog` and `http`. These are described in more detail below |
| `log_format` | `color` | This controls the formatting of the log messages. Values are `plain`, `color`, `legacy` and `syslog` |
| `log_level` | `debug` | This filters messages according to the level of severity and can be one of `debug`, `info`, `warning`, `error` and `critical`. For example, setting to `warning` will display all messages at the warning level or higher, but won't display "debug" or "info" messages |
Several different log handler destinations are available, and multiple destinations are supported by providing a list: Several different log handler destinations are available, and multiple destinations are supported by providing a list:
```yaml ```
log_handlers: log_handlers:
- console - console
- syslog=localhost - syslog=localhost
- file=/var/log/invokeai.log - file=/var/log/invokeai.log
``` ```
- `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched. * `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched.
- `syslog` is only available on Linux and Macintosh systems. It uses * `syslog` is only available on Linux and Macintosh systems. It uses
the operating system's "syslog" facility to write log file entries the operating system's "syslog" facility to write log file entries
locally or to a remote logging machine. `syslog` offers a variety locally or to a remote logging machine. `syslog` offers a variety
of configuration options: of configuration options:
@@ -173,7 +271,7 @@ log_handlers:
- Log to LAN-connected server "fredserver" using the facility LOG_USER and datagram packets. - Log to LAN-connected server "fredserver" using the facility LOG_USER and datagram packets.
``` ```
- `http` can be used to log to a remote web server. The server must be * `http` can be used to log to a remote web server. The server must be
properly configured to receive and act on log messages. The option properly configured to receive and act on log messages. The option
accepts the URL to the web server, and a `method` argument accepts the URL to the web server, and a `method` argument
indicating whether the message should be submitted using the GET or indicating whether the message should be submitted using the GET or
@@ -185,10 +283,7 @@ log_handlers:
The `log_format` option provides several alternative formats: The `log_format` option provides several alternative formats:
- `color` - default format providing time, date and a message, using text colors to distinguish different log severities * `color` - default format providing time, date and a message, using text colors to distinguish different log severities
- `plain` - same as above, but monochrome text only * `plain` - same as above, but monochrome text only
- `syslog` - the log level and error message only, allowing the syslog system to attach the time and date * `syslog` - the log level and error message only, allowing the syslog system to attach the time and date
- `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases. * `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases.
[basic guide to yaml files]: https://circleci.com/blog/what-is-yaml-a-beginner-s-guide/
[Model Marketplace API Keys]: #model-marketplace-api-keys

View File

@@ -1,35 +0,0 @@
---
title: Database
---
# Invoke's SQLite Database
Invoke uses a SQLite database to store image, workflow, model, and execution data.
We take great care to ensure your data is safe, by utilizing transactions and a database migration system.
Even so, when testing an prerelease version of the app, we strongly suggest either backing up your database or using an in-memory database. This ensures any prelease hiccups or databases schema changes will not cause problems for your data.
## Database Backup
Backing up your database is very simple. Invoke's data is stored in an `$INVOKEAI_ROOT` directory - where your `invoke.sh`/`invoke.bat` and `invokeai.yaml` files live.
To back up your database, copy the `invokeai.db` file from `$INVOKEAI_ROOT/databases/invokeai.db` to somewhere safe.
If anything comes up during prelease testing, you can simply copy your backup back into `$INVOKEAI_ROOT/databases/`.
## In-Memory Database
SQLite can run on an in-memory database. Your existing database is untouched when this mode is enabled, but your existing data won't be accessible.
This is very useful for testing, as there is no chance of a database change modifying your "physical" database.
To run Invoke with a memory database, edit your `invokeai.yaml` file, and add `use_memory_db: true` to the `Paths:` stanza:
```yaml
InvokeAI:
Development:
use_memory_db: true
```
Delete this line (or set it to `false`) to use your main database.

View File

@@ -122,9 +122,9 @@ experimental versions later.
[latest release](https://github.com/invoke-ai/InvokeAI/releases/latest), [latest release](https://github.com/invoke-ai/InvokeAI/releases/latest),
and look for a file named: and look for a file named:
- InvokeAI-installer-v4.X.X.zip - InvokeAI-installer-v3.X.X.zip
where "4.X.X" is the latest released version. The file is located where "3.X.X" is the latest released version. The file is located
at the very bottom of the release page, under **Assets**. at the very bottom of the release page, under **Assets**.
4. **Unpack the installer**: Unpack the zip file into a convenient directory. This will create a new 4. **Unpack the installer**: Unpack the zip file into a convenient directory. This will create a new
@@ -199,7 +199,136 @@ experimental versions later.
![initial-settings-screenshot](../assets/installer-walkthrough/settings-form.png) ![initial-settings-screenshot](../assets/installer-walkthrough/settings-form.png)
</figure> </figure>
10. **Running InvokeAI for the first time**: The script will now exit and you'll be ready to generate some images. Look 10. **Post-install Configuration**: After installation completes, the
installer will launch the configuration form, which will guide you
through the first-time process of adjusting some of InvokeAI's
startup settings. To move around this form use ctrl-N for
&lt;N&gt;ext and ctrl-P for &lt;P&gt;revious, or use &lt;tab&gt;
and shift-&lt;tab&gt; to move forward and back. Once you are in a
multi-checkbox field use the up and down cursor keys to select the
item you want, and &lt;space&gt; to toggle it on and off. Within
a directory field, pressing &lt;tab&gt; will provide autocomplete
options.
Generally the defaults are fine, and you can come back to this screen at
any time to tweak your system. Here are the options you can adjust:
- ***HuggingFace Access Token***
InvokeAI has the ability to download embedded styles and subjects
from the HuggingFace Concept Library on-demand. However, some of
the concept library files are password protected. To make download
smoother, you can set up an account at huggingface.co, obtain an
access token, and paste it into this field. Note that you paste
to this screen using ctrl-shift-V
- ***Free GPU memory after each generation***
This is useful for low-memory machines and helps minimize the
amount of GPU VRAM used by InvokeAI.
- ***Enable xformers support if available***
If the xformers library was successfully installed, this will activate
it to reduce memory consumption and increase rendering speed noticeably.
Note that xformers has the side effect of generating slightly different
images even when presented with the same seed and other settings.
- ***Force CPU to be used on GPU systems***
This will use the (slow) CPU rather than the accelerated GPU. This
can be used to generate images on systems that don't have a compatible
GPU.
- ***Precision***
This controls whether to use float32 or float16 arithmetic.
float16 uses less memory but is also slightly less accurate.
Ordinarily the right arithmetic is picked automatically ("auto"),
but you may have to use float32 to get images on certain systems
and graphics cards. The "autocast" option is deprecated and
shouldn't be used unless you are asked to by a member of the team.
- **Size of the RAM cache used for fast model switching***
This allows you to keep models in memory and switch rapidly among
them rather than having them load from disk each time. This slider
controls how many models to keep loaded at once. A typical SD-1 or SD-2 model
uses 2-3 GB of memory. A typical SDXL model uses 6-7 GB. Providing more
RAM will allow more models to be co-resident.
- ***Output directory for images***
This is the path to a directory in which InvokeAI will store all its
generated images.
- ***Autoimport Folder***
This is the directory in which you can place models you have
downloaded and wish to load into InvokeAI. You can place a variety
of models in this directory, including diffusers folders, .ckpt files,
.safetensors files, as well as LoRAs, ControlNet and Textual Inversion
files (both folder and file versions). To help organize this folder,
you can create several levels of subfolders and drop your models into
whichever ones you want.
- ***LICENSE***
At the bottom of the screen you will see a checkbox for accepting
the CreativeML Responsible AI Licenses. You need to accept the license
in order to download Stable Diffusion models from the next screen.
_You can come back to the startup options form_ as many times as you like.
From the `invoke.sh` or `invoke.bat` launcher, select option (6) to relaunch
this script. On the command line, it is named `invokeai-configure`.
11. **Downloading Models**: After you press `[NEXT]` on the screen, you will be taken
to another screen that prompts you to download a series of starter models. The ones
we recommend are preselected for you, but you are encouraged to use the checkboxes to
pick and choose.
You will probably wish to download `autoencoder-840000` for use with models that
were trained with an older version of the Stability VAE.
<figure markdown>
![select-models-screenshot](../assets/installer-walkthrough/installing-models.png)
</figure>
Below the preselected list of starter models is a large text field which you can use
to specify a series of models to import. You can specify models in a variety of formats,
each separated by a space or newline. The formats accepted are:
- The path to a .ckpt or .safetensors file. On most systems, you can drag a file from
the file browser to the textfield to automatically paste the path. Be sure to remove
extraneous quotation marks and other things that come along for the ride.
- The path to a directory containing a combination of `.ckpt` and `.safetensors` files.
The directory will be scanned from top to bottom (including subfolders) and any
file that can be imported will be.
- A URL pointing to a `.ckpt` or `.safetensors` file. You can cut
and paste directly from a web page, or simply drag the link from the web page
or navigation bar. (You can also use ctrl-shift-V to paste into this field)
The file will be downloaded and installed.
- The HuggingFace repository ID (repo_id) for a `diffusers` model. These IDs have
the format _author_name/model_name_, as in `andite/anything-v4.0`
- The path to a local directory containing a `diffusers`
model. These directories always have the file `model_index.json`
at their top level.
_Select a directory for models to import_ You may select a local
directory for autoimporting at startup time. If you select this
option, the directory you choose will be scanned for new
.ckpt/.safetensors files each time InvokeAI starts up, and any new
files will be automatically imported and made available for your
use.
_Convert imported models into diffusers_ When legacy checkpoint
files are imported, you may select to use them unmodified (the
default) or to convert them into `diffusers` models. The latter
load much faster and have slightly better rendering performance,
but not all checkpoint files can be converted. Note that Stable Diffusion
Version 2.X files are **only** supported in `diffusers` format and will
be converted regardless.
_You can come back to the model install form_ as many times as you like.
From the `invoke.sh` or `invoke.bat` launcher, select option (5) to relaunch
this script. On the command line, it is named `invokeai-model-install`.
12. **Running InvokeAI for the first time**: The script will now exit and you'll be ready to generate some images. Look
for the directory `invokeai` installed in the location you chose at the for the directory `invokeai` installed in the location you chose at the
beginning of the install session. Look for a shell script named `invoke.sh` beginning of the install session. Look for a shell script named `invoke.sh`
(Linux/Mac) or `invoke.bat` (Windows). Launch the script by double-clicking (Linux/Mac) or `invoke.bat` (Windows). Launch the script by double-clicking
@@ -220,14 +349,14 @@ experimental versions later.
http://localhost:9090. Click on this link to open up a browser http://localhost:9090. Click on this link to open up a browser
and start exploring InvokeAI's features. and start exploring InvokeAI's features.
12. **InvokeAI Options**: You can configure using the `invokeai.yaml` config file. 12. **InvokeAI Options**: You can launch InvokeAI with several different command-line arguments that
For example, you can change the location of the customize its behavior. For example, you can change the location of the
image output directory or balance memory usage vs performance. See image output directory or balance memory usage vs performance. See
[Configuration](../features/CONFIGURATION.md) for a full list of the options. [Configuration](../features/CONFIGURATION.md) for a full list of the options.
- To set defaults that will take effect every time you launch InvokeAI, - To set defaults that will take effect every time you launch InvokeAI,
use a text editor (e.g. Notepad) to exit the file use a text editor (e.g. Notepad) to exit the file
`invokeai\invokeai.yaml`. It contains a variety of examples that you can `invokeai\invokeai.init`. It contains a variety of examples that you can
follow to add and modify launch options. follow to add and modify launch options.
- The launcher script also offers you an option labeled "open the developer - The launcher script also offers you an option labeled "open the developer
@@ -265,6 +394,7 @@ rm .\.venv -r -force
python -mvenv .venv python -mvenv .venv
.\.venv\Scripts\activate .\.venv\Scripts\activate
pip install invokeai pip install invokeai
invokeai-configure --yes --root .
``` ```
If you see anything marked as an error during this process please stop If you see anything marked as an error during this process please stop
@@ -296,10 +426,16 @@ error messages:
This failure mode occurs when there is a network glitch during This failure mode occurs when there is a network glitch during
downloading the very large SDXL model. downloading the very large SDXL model.
To address this, first go to the Model Manager and delete the To address this, first go to the Web Model Manager and delete the
Stable-Diffusion-XL-base-1.X model. Then, click the HuggingFace tab, Stable-Diffusion-XL-base-1.X model. Then navigate to HuggingFace and
paste the Repo ID stabilityai/stable-diffusion-xl-base-1.0 and install manually download the .safetensors version of the model. The 1.0
the model. version is located at
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main
and the file is named `sd_xl_base_1.0.safetensors`.
Save this file to disk and then reenter the Model Manager. Navigate to
Import Models->Add Model, then type (or drag-and-drop) the path to the
.safetensors file. Press "Add Model".
### _Package dependency conflicts_ ### _Package dependency conflicts_
@@ -352,7 +488,15 @@ download models, etc), but this doesn't fix the problem.
This issue is often caused by a misconfigured configuration directive in the This issue is often caused by a misconfigured configuration directive in the
`invokeai\invokeai.init` initialization file that contains startup settings. The `invokeai\invokeai.init` initialization file that contains startup settings. The
easiest way to fix the problem is to move the file out of the way and restart the app. easiest way to fix the problem is to move the file out of the way and re-run
`invokeai-configure`. Enter the developer's console (option 3 of the launcher
script) and run this command:
```cmd
invokeai-configure --root=.
```
Note the dot (.) after `--root`. It is part of the command.
_If none of these maneuvers fixes the problem_ then please report the problem to _If none of these maneuvers fixes the problem_ then please report the problem to
the [InvokeAI Issues](https://github.com/invoke-ai/InvokeAI/issues) section, or the [InvokeAI Issues](https://github.com/invoke-ai/InvokeAI/issues) section, or
@@ -421,4 +565,16 @@ This distribution is changing rapidly, and we add new features
regularly. Releases are announced at regularly. Releases are announced at
http://github.com/invoke-ai/InvokeAI/releases, and at http://github.com/invoke-ai/InvokeAI/releases, and at
https://pypi.org/project/InvokeAI/ To update to the latest released https://pypi.org/project/InvokeAI/ To update to the latest released
version (recommended), download the latest release and run the installer. version (recommended), follow these steps:
1. Start the `invoke.sh`/`invoke.bat` launch script from within the
`invokeai` root directory.
2. Choose menu item (10) "Update InvokeAI".
3. This will launch a menu that gives you the option of:
1. Updating to the latest official release;
2. Updating to the bleeding-edge development version; or
3. Manually entering the tag or branch name of a version of
InvokeAI you wish to try out.

View File

@@ -26,7 +26,7 @@ driver).
🖥️ **Download the latest installer .zip file here** : https://github.com/invoke-ai/InvokeAI/releases/latest 🖥️ **Download the latest installer .zip file here** : https://github.com/invoke-ai/InvokeAI/releases/latest
- *Look for the file labelled "InvokeAI-installer-v4.X.X.zip" at the bottom of the page* - *Look for the file labelled "InvokeAI-installer-v3.X.X.zip" at the bottom of the page*
- If you experience issues, read through the full [installation instructions](010_INSTALL_AUTOMATED.md) to make sure you have met all of the installation requirements. If you need more help, join the [Discord](discord.gg/invoke-ai) or create an issue on [Github](https://github.com/invoke-ai/InvokeAI). - If you experience issues, read through the full [installation instructions](010_INSTALL_AUTOMATED.md) to make sure you have met all of the installation requirements. If you need more help, join the [Discord](discord.gg/invoke-ai) or create an issue on [Github](https://github.com/invoke-ai/InvokeAI).

View File

@@ -1,63 +0,0 @@
# Invocation API
Each invocation's `invoke` method is provided a single arg - the Invocation
Context.
This object provides access to various methods, used to interact with the
application. Loading and saving images, logging messages, etc.
!!! warning ""
This API may shift slightly until the release of v4.0.0 as we work through a few final updates to the Model Manager.
```py
class MyInvocation(BaseInvocation):
...
def invoke(self, context: InvocationContext) -> ImageOutput:
image_pil = context.images.get_pil(image_name)
# Do something to the image
image_dto = context.images.save(image_pil)
# Log a message
context.logger.info(f"Did something cool, image saved!")
...
```
The full API is documented below.
## Invocation Mixins
Two important mixins are provided to facilitate working with metadata and gallery boards.
### `WithMetadata`
Inherit from this class (in addition to `BaseInvocation`) to add a `metadata` input to your node. When you do this, you can access the metadata dict from `self.metadata` in the `invoke()` function.
The dict will be populated via the node's input, and you can add any metadata you'd like to it. When you call `context.images.save()`, if the metadata dict has any data, it be automatically embedded in the image.
### `WithBoard`
Inherit from this class (in addition to `BaseInvocation`) to add a `board` input to your node. This renders as a drop-down to select a board. The user's selection will be accessible from `self.board` in the `invoke()` function.
When you call `context.images.save()`, if a board was selected, the image will added to that board as it is saved.
<!-- prettier-ignore-start -->
::: invokeai.app.services.shared.invocation_context.InvocationContext
options:
members: false
::: invokeai.app.services.shared.invocation_context.ImagesInterface
::: invokeai.app.services.shared.invocation_context.TensorsInterface
::: invokeai.app.services.shared.invocation_context.ConditioningInterface
::: invokeai.app.services.shared.invocation_context.ModelsInterface
::: invokeai.app.services.shared.invocation_context.LoggerInterface
::: invokeai.app.services.shared.invocation_context.ConfigInterface
::: invokeai.app.services.shared.invocation_context.UtilInterface
::: invokeai.app.services.shared.invocation_context.BoardsInterface
<!-- prettier-ignore-end -->

View File

@@ -1,148 +0,0 @@
# Invoke v4.0.0 Nodes API Migration guide
Invoke v4.0.0 is versioned as such due to breaking changes to the API utilized
by nodes, both core and custom.
## Motivation
Prior to v4.0.0, the `invokeai` python package has not be set up to be utilized
as a library. That is to say, it didn't have any explicitly public API, and node
authors had to work with the unstable internal application API.
v4.0.0 introduces a stable public API for nodes.
## Changes
There are two node-author-facing changes:
1. Import Paths
1. Invocation Context API
### Import Paths
All public objects are now exported from `invokeai.invocation_api`:
```py
# Old
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
InputField,
InvocationContext,
invocation,
)
from invokeai.app.invocations.primitives import ImageField
# New
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
InvocationContext,
invocation,
)
```
It's possible that we've missed some classes you need in your node. Please let
us know if that's the case.
### Invocation Context API
Most nodes utilize the Invocation Context, an object that is passed to the
`invoke` that provides access to data and services a node may need.
Until now, that object and the services it exposed were internal. Exposing them
to nodes means that changes to our internal implementation could break nodes.
The methods on the services are also often fairly complicated and allowed nodes
to footgun.
In v4.0.0, this object has been refactored to be much simpler.
See [INVOCATION_API](./INVOCATION_API.md) for full details of the API.
!!! warning ""
This API may shift slightly until the release of v4.0.0 as we work through a few final updates to the Model Manager.
#### Improved Service Methods
The biggest offender was the image save method:
```py
# Old
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=context.workflow,
)
# New
image_dto = context.images.save(image=image)
```
Other methods are simplified, or enhanced with additional functionality:
```py
# Old
image = context.services.images.get_pil_image(image_name)
# New
image = context.images.get_pil(image_name)
image_cmyk = context.images.get_pil(image_name, "CMYK")
```
We also had some typing issues around tensors:
```py
# Old
# `latents` typed as `torch.Tensor`, but could be `ConditioningFieldData`
latents = context.services.latents.get(self.latents.latents_name)
# `data` typed as `torch.Tenssor,` but could be `ConditioningFieldData`
context.services.latents.save(latents_name, data)
# New - separate methods for tensors and conditioning data w/ correct typing
# Also, the service generates the names
tensor_name = context.tensors.save(tensor)
tensor = context.tensors.load(tensor_name)
# For conditioning
cond_name = context.conditioning.save(cond_data)
cond_data = context.conditioning.load(cond_name)
```
#### Output Construction
Core Outputs have builder functions right on them - no need to manually
construct these objects, or use an extra utility:
```py
# Old
image_output = ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
latents_output = build_latents_output(latents_name=name, latents=latents, seed=None)
noise_output = NoiseOutput(
noise=LatentsField(latents_name=latents_name, seed=seed),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
cond_output = ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
# New
image_output = ImageOutput.build(image_dto)
latents_output = LatentsOutput.build(latents_name=name, latents=noise, seed=self.seed)
noise_output = NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed)
cond_output = ConditioningOutput.build(conditioning_name)
```
You can still create the objects using constructors if you want, but we suggest
using the builder methods.

View File

@@ -32,7 +32,6 @@ To use a community workflow, download the the `.json` node graph file and load i
+ [Image to Character Art Image Nodes](#image-to-character-art-image-nodes) + [Image to Character Art Image Nodes](#image-to-character-art-image-nodes)
+ [Image Picker](#image-picker) + [Image Picker](#image-picker)
+ [Image Resize Plus](#image-resize-plus) + [Image Resize Plus](#image-resize-plus)
+ [Latent Upscale](#latent-upscale)
+ [Load Video Frame](#load-video-frame) + [Load Video Frame](#load-video-frame)
+ [Make 3D](#make-3d) + [Make 3D](#make-3d)
+ [Mask Operations](#mask-operations) + [Mask Operations](#mask-operations)
@@ -291,13 +290,6 @@ View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/image-resize-plus-node/master/.readme/node.png" width="500" /> </br><img src="https://raw.githubusercontent.com/VeyDlin/image-resize-plus-node/master/.readme/node.png" width="500" />
--------------------------------
### Latent Upscale
**Description:** This node uses a small (~2.4mb) model to upscale the latents used in a Stable Diffusion 1.5 or Stable Diffusion XL image generation, rather than the typical interpolation method, avoiding the traditional downsides of the latent upscale technique.
**Node Link:** [https://github.com/gogurtenjoyer/latent-upscale](https://github.com/gogurtenjoyer/latent-upscale)
-------------------------------- --------------------------------
### Load Video Frame ### Load Video Frame
@@ -354,21 +346,12 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
**Description:** A set of nodes for Metadata. Collect Metadata from within an `iterate` node & extract metadata from an image. **Description:** A set of nodes for Metadata. Collect Metadata from within an `iterate` node & extract metadata from an image.
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node - `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node.
- `Metadata From Image` - Provides Metadata from an image - `Metadata From Image` - Provides Metadata from an image.
- `Metadata To String` - Extracts a String value of a label from metadata - `Metadata To String` - Extracts a String value of a label from metadata.
- `Metadata To Integer` - Extracts an Integer value of a label from metadata - `Metadata To Integer` - Extracts an Integer value of a label from metadata.
- `Metadata To Float` - Extracts a Float value of a label from metadata - `Metadata To Float` - Extracts a Float value of a label from metadata.
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata - `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata.
- `Metadata To Bool` - Extracts Bool types from metadata
- `Metadata To Model` - Extracts model types from metadata
- `Metadata To SDXL Model` - Extracts SDXL model types from metadata
- `Metadata To LoRAs` - Extracts Loras from metadata.
- `Metadata To SDXL LoRAs` - Extracts SDXL Loras from metadata
- `Metadata To ControlNets` - Extracts ControNets from metadata
- `Metadata To IP-Adapters` - Extracts IP-Adapters from metadata
- `Metadata To T2I-Adapters` - Extracts T2I-Adapters from metadata
- `Denoise Latents + Metadata` - This is an inherited version of the existing `Denoise Latents` node but with a metadata input and output.
**Node Link:** https://github.com/skunkworxdark/metadata-linked-nodes **Node Link:** https://github.com/skunkworxdark/metadata-linked-nodes

View File

@@ -19,8 +19,6 @@ their descriptions.
| Conditioning Primitive | A conditioning tensor primitive value | | Conditioning Primitive | A conditioning tensor primitive value |
| Content Shuffle Processor | Applies content shuffle processing to image | | Content Shuffle Processor | Applies content shuffle processing to image |
| ControlNet | Collects ControlNet info to pass to other nodes | | ControlNet | Collects ControlNet info to pass to other nodes |
| Create Denoise Mask | Converts a greyscale or transparency image into a mask for denoising. |
| Create Gradient Mask | Creates a mask for Gradient ("soft", "differential") inpainting that gradually expands during denoising. Improves edge coherence. |
| Denoise Latents | Denoises noisy latents to decodable images | | Denoise Latents | Denoises noisy latents to decodable images |
| Divide Integers | Divides two numbers | | Divide Integers | Divides two numbers |
| Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator | | Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator |

View File

@@ -0,0 +1,5 @@
mkdocs
mkdocs-material>=8, <9
mkdocs-git-revision-date-localized-plugin
mkdocs-redirects==1.2.0

View File

@@ -0,0 +1,5 @@
:root {
--md-primary-fg-color: #35A4DB;
--md-primary-fg-color--light: #35A4DB;
--md-primary-fg-color--dark: #35A4DB;
}

View File

@@ -2,18 +2,22 @@
set -e set -e
BCYAN="\033[1;36m" BCYAN="\e[1;36m"
BYELLOW="\033[1;33m" BYELLOW="\e[1;33m"
BGREEN="\033[1;32m" BGREEN="\e[1;32m"
BRED="\033[1;31m" BRED="\e[1;31m"
RED="\033[31m" RED="\e[31m"
RESET="\033[0m" RESET="\e[0m"
function is_bin_in_path {
builtin type -P "$1" &>/dev/null
}
function git_show { function git_show {
git show -s --format=oneline --abbrev-commit "$1" | cat git show -s --format=oneline --abbrev-commit "$1" | cat
} }
if [[ ! -z "${VIRTUAL_ENV}" ]]; then if [[ -v "VIRTUAL_ENV" ]]; then
# we can't just call 'deactivate' because this function is not exported # we can't just call 'deactivate' because this function is not exported
# to the environment of this script from the bash process that runs the script # to the environment of this script from the bash process that runs the script
echo -e "${BRED}A virtual environment is activated. Please deactivate it before proceeding.${RESET}" echo -e "${BRED}A virtual environment is activated. Please deactivate it before proceeding.${RESET}"
@@ -22,63 +26,31 @@ fi
cd "$(dirname "$0")" cd "$(dirname "$0")"
echo
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
echo "The current working directory is $(pwd)"
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
echo
# Some machines only have `python3` in PATH, others have `python` - make an alias.
# We can use a function to approximate an alias within a non-interactive shell.
if ! is_bin_in_path python && is_bin_in_path python3; then
function python {
python3 "$@"
}
fi
VERSION=$( VERSION=$(
cd .. cd ..
python3 -c "from invokeai.version import __version__ as version; print(version)" python -c "from invokeai.version import __version__ as version; print(version)"
) )
VERSION="v${VERSION}" PATCH=""
VERSION="v${VERSION}${PATCH}"
if [[ ! -z ${CI} ]]; then
echo
echo -e "${BCYAN}CI environment detected${RESET}"
echo
else
echo
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
echo "The current working directory is $(pwd)"
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
echo
fi
echo -e "${BGREEN}HEAD${RESET}:" echo -e "${BGREEN}HEAD${RESET}:"
git_show HEAD git_show HEAD
echo echo
# ---------------------- FRONTEND ----------------------
pushd ../invokeai/frontend/web >/dev/null
echo "Installing frontend dependencies..."
echo
pnpm i --frozen-lockfile
echo
if [[ ! -z ${CI} ]]; then
echo "Building frontend without checks..."
# In CI, we have already done the frontend checks and can just build
pnpm vite build
else
echo "Running checks and building frontend..."
# This runs all the frontend checks and builds
pnpm build
fi
echo
popd
# ---------------------- BACKEND ----------------------
echo
echo "Building wheel..."
echo
# install the 'build' package in the user site packages, if needed
# could be improved by using a temporary venv, but it's tiny and harmless
if [[ $(python3 -c 'from importlib.util import find_spec; print(find_spec("build") is None)') == "True" ]]; then
pip install --user build
fi
rm -rf ../build
python3 -m build --outdir dist/ ../.
# ---------------------- # ----------------------
echo echo
@@ -106,28 +78,10 @@ chmod a+x InvokeAI-Installer/install.sh
cp install.bat.in InvokeAI-Installer/install.bat cp install.bat.in InvokeAI-Installer/install.bat
cp WinLongPathsEnabled.reg InvokeAI-Installer/ cp WinLongPathsEnabled.reg InvokeAI-Installer/
FILENAME=InvokeAI-installer-$VERSION.zip
# Zip everything up # Zip everything up
zip -r ${FILENAME} InvokeAI-Installer zip -r InvokeAI-installer-$VERSION.zip InvokeAI-Installer
echo # clean up
echo -e "${BGREEN}Built installer: ./${FILENAME}${RESET}" rm -rf InvokeAI-Installer tmp dist ../invokeai/frontend/web/dist/
echo -e "${BGREEN}Built PyPi distribution: ./dist${RESET}"
# clean up, but only if we are not in a github action
if [[ -z ${CI} ]]; then
echo
echo "Cleaning up intermediate build files..."
rm -rf InvokeAI-Installer tmp ../invokeai/frontend/web/dist/
fi
if [[ ! -z ${CI} ]]; then
echo
echo "Setting GitHub action outputs..."
echo "INSTALLER_FILENAME=${FILENAME}" >>$GITHUB_OUTPUT
echo "INSTALLER_PATH=installer/${FILENAME}" >>$GITHUB_OUTPUT
echo "DIST_PATH=installer/dist/" >>$GITHUB_OUTPUT
fi
exit 0 exit 0

View File

@@ -149,6 +149,9 @@ class Installer:
# install the launch/update scripts into the runtime directory # install the launch/update scripts into the runtime directory
self.instance.install_user_scripts() self.instance.install_user_scripts()
# run through the configuration flow
self.instance.configure()
class InvokeAiInstance: class InvokeAiInstance:
""" """
@@ -239,6 +242,53 @@ class InvokeAiInstance:
) )
sys.exit(1) sys.exit(1)
def configure(self):
"""
Configure the InvokeAI runtime directory
"""
auto_install = False
# set sys.argv to a consistent state
new_argv = [sys.argv[0]]
for i in range(1, len(sys.argv)):
el = sys.argv[i]
if el in ["-r", "--root"]:
new_argv.append(el)
new_argv.append(sys.argv[i + 1])
elif el in ["-y", "--yes", "--yes-to-all"]:
auto_install = True
sys.argv = new_argv
import messages
import requests # to catch download exceptions
auto_install = auto_install or messages.user_wants_auto_configuration()
if auto_install:
sys.argv.append("--yes")
else:
messages.introduction()
from invokeai.frontend.install.invokeai_configure import invokeai_configure
# NOTE: currently the config script does its own arg parsing! this means the command-line switches
# from the installer will also automatically propagate down to the config script.
# this may change in the future with config refactoring!
succeeded = False
try:
invokeai_configure()
succeeded = True
except requests.exceptions.ConnectionError as e:
print(f"\nA network error was encountered during configuration and download: {str(e)}")
except OSError as e:
print(f"\nAn OS error was encountered during configuration and download: {str(e)}")
except Exception as e:
print(f"\nA problem was encountered during the configuration and download steps: {str(e)}")
finally:
if not succeeded:
print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"')
print("and choose option 7 to fix a broken install, optionally followed by option 5 to install models.")
print("Alternatively you can relaunch the installer.")
def install_user_scripts(self): def install_user_scripts(self):
""" """
Copy the launch and update scripts to the runtime dir Copy the launch and update scripts to the runtime dir

View File

@@ -8,7 +8,7 @@ import platform
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from prompt_toolkit import prompt from prompt_toolkit import HTML, prompt
from prompt_toolkit.completion import FuzzyWordCompleter, PathCompleter from prompt_toolkit.completion import FuzzyWordCompleter, PathCompleter
from prompt_toolkit.validation import Validator from prompt_toolkit.validation import Validator
from rich import box, print from rich import box, print
@@ -98,6 +98,39 @@ def choose_version(available_releases: tuple | None = None) -> str:
return "stable" if response == "" else response return "stable" if response == "" else response
def user_wants_auto_configuration() -> bool:
"""Prompt the user to choose between manual and auto configuration."""
console.rule("InvokeAI Configuration Section")
console.print(
Panel(
Group(
"\n".join(
[
"Libraries are installed and InvokeAI will now set up its root directory and configuration. Choose between:",
"",
" * AUTOMATIC configuration: install reasonable defaults and a minimal set of starter models.",
" * MANUAL configuration: manually inspect and adjust configuration options and pick from a larger set of starter models.",
"",
"Later you can fine tune your configuration by selecting option [6] 'Change InvokeAI startup options' from the invoke.bat/invoke.sh launcher script.",
]
),
),
box=box.MINIMAL,
padding=(1, 1),
)
)
choice = (
prompt(
HTML("Choose <b>&lt;a&gt;</b>utomatic or <b>&lt;m&gt;</b>anual configuration [a/m] (a): "),
validator=Validator.from_callable(
lambda n: n == "" or n.startswith(("a", "A", "m", "M")), error_message="Please select 'a' or 'm'"
),
)
or "a"
)
return choice.lower().startswith("a")
def confirm_install(dest: Path) -> bool: def confirm_install(dest: Path) -> bool:
if dest.exists(): if dest.exists():
print(f":stop_sign: Directory {dest} already exists!") print(f":stop_sign: Directory {dest} already exists!")
@@ -318,6 +351,34 @@ def windows_long_paths_registry() -> None:
) )
def introduction() -> None:
"""
Display a banner when starting configuration of the InvokeAI application
"""
console.rule()
console.print(
Panel(
title=":art: Configuring InvokeAI :art:",
renderable=Group(
"",
"[b]This script will:",
"",
"1. Configure the InvokeAI application directory",
"2. Help download the Stable Diffusion weight files",
" and other large models that are needed for text to image generation",
"3. Create initial configuration files.",
"",
"[i]At any point you may interrupt this program and resume later.",
"",
"[b]For the best user experience, please enlarge or maximize this window",
),
)
)
console.line(2)
def _platform_specific_help() -> Text | None: def _platform_specific_help() -> Text | None:
if OS == "Darwin": if OS == "Darwin":
text = Text.from_markup( text = Text.from_markup(

View File

@@ -2,12 +2,12 @@
set -e set -e
BCYAN="\033[1;36m" BCYAN="\e[1;36m"
BYELLOW="\033[1;33m" BYELLOW="\e[1;33m"
BGREEN="\033[1;32m" BGREEN="\e[1;32m"
BRED="\033[1;31m" BRED="\e[1;31m"
RED="\033[31m" RED="\e[31m"
RESET="\033[0m" RESET="\e[0m"
function does_tag_exist { function does_tag_exist {
git rev-parse --quiet --verify "refs/tags/$1" >/dev/null git rev-parse --quiet --verify "refs/tags/$1" >/dev/null
@@ -23,40 +23,49 @@ function git_show {
VERSION=$( VERSION=$(
cd .. cd ..
python3 -c "from invokeai.version import __version__ as version; print(version)" python -c "from invokeai.version import __version__ as version; print(version)"
) )
PATCH="" PATCH=""
MAJOR_VERSION=$(echo $VERSION | sed 's/\..*$//')
VERSION="v${VERSION}${PATCH}" VERSION="v${VERSION}${PATCH}"
LATEST_TAG="v${MAJOR_VERSION}-latest"
if does_tag_exist $VERSION; then if does_tag_exist $VERSION; then
echo -e "${BCYAN}${VERSION}${RESET} already exists:" echo -e "${BCYAN}${VERSION}${RESET} already exists:"
git_show_ref tags/$VERSION git_show_ref tags/$VERSION
echo echo
fi fi
if does_tag_exist $LATEST_TAG; then
echo -e "${BCYAN}${LATEST_TAG}${RESET} already exists:"
git_show_ref tags/$LATEST_TAG
echo
fi
echo -e "${BGREEN}HEAD${RESET}:" echo -e "${BGREEN}HEAD${RESET}:"
git_show git_show
echo echo
echo -e "${BGREEN}git remote -v${RESET}:" echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} and ${BCYAN}${LATEST_TAG}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on remote${RESET}? "
git remote -v
echo
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on origin remote${RESET}? "
read -e -p 'y/n [n]: ' input read -e -p 'y/n [n]: ' input
RESPONSE=${input:='n'} RESPONSE=${input:='n'}
if [ "$RESPONSE" == 'y' ]; then if [ "$RESPONSE" == 'y' ]; then
echo echo
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on origin remote..." echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on remote..."
git push origin :refs/tags/$VERSION git push --delete origin $VERSION
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} on locally..." echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} locally..."
if ! git tag -fa $VERSION; then if ! git tag -fa $VERSION; then
echo "Existing/invalid tag" echo "Existing/invalid tag"
exit -1 exit -1
fi fi
echo -e "Pushing updated tags to origin remote..." echo -e "Deleting ${BCYAN}${LATEST_TAG}${RESET} tag on remote..."
git push --delete origin $LATEST_TAG
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${LATEST_TAG}${RESET} locally..."
git tag -fa $LATEST_TAG
echo -e "Pushing updated tags to remote..."
git push origin --tags git push origin --tags
fi fi
exit 0 exit 0

View File

@@ -9,10 +9,15 @@ set INVOKEAI_ROOT=.
:start :start
echo Desired action: echo Desired action:
echo 1. Generate images with the browser-based interface echo 1. Generate images with the browser-based interface
echo 2. Open the developer console echo 2. Run textual inversion training
echo 3. Update InvokeAI (DEPRECATED - please use the installer) echo 3. Merge models (diffusers type only)
echo 4. Run the InvokeAI image database maintenance script echo 4. Download and install models
echo 5. Command-line help echo 5. Change InvokeAI startup options
echo 6. Re-run the configure script to fix a broken install or to complete a major upgrade
echo 7. Open the developer console
echo 8. Update InvokeAI (DEPRECATED - please use the installer)
echo 9. Run the InvokeAI image database maintenance script
echo 10. Command-line help
echo Q - Quit echo Q - Quit
set /P choice="Please enter 1-10, Q: [1] " set /P choice="Please enter 1-10, Q: [1] "
if not defined choice set choice=1 if not defined choice set choice=1
@@ -20,6 +25,21 @@ IF /I "%choice%" == "1" (
echo Starting the InvokeAI browser-based UI.. echo Starting the InvokeAI browser-based UI..
python .venv\Scripts\invokeai-web.exe %* python .venv\Scripts\invokeai-web.exe %*
) ELSE IF /I "%choice%" == "2" ( ) ELSE IF /I "%choice%" == "2" (
echo Starting textual inversion training..
python .venv\Scripts\invokeai-ti.exe --gui
) ELSE IF /I "%choice%" == "3" (
echo Starting model merging script..
python .venv\Scripts\invokeai-merge.exe --gui
) ELSE IF /I "%choice%" == "4" (
echo Running invokeai-model-install...
python .venv\Scripts\invokeai-model-install.exe
) ELSE IF /I "%choice%" == "5" (
echo Running invokeai-configure...
python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
) ELSE IF /I "%choice%" == "6" (
echo Running invokeai-configure...
python .venv\Scripts\invokeai-configure.exe --yes --skip-sd-weight
) ELSE IF /I "%choice%" == "7" (
echo Developer Console echo Developer Console
echo Python command is: echo Python command is:
where python where python
@@ -31,15 +51,15 @@ IF /I "%choice%" == "1" (
echo ************************* echo *************************
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment *** echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
call cmd /k call cmd /k
) ELSE IF /I "%choice%" == "3" ( ) ELSE IF /I "%choice%" == "8" (
echo UPDATING FROM WITHIN THE APP IS BEING DEPRECATED. echo UPDATING FROM WITHIN THE APP IS BEING DEPRECATED.
echo Please download the installer from https://github.com/invoke-ai/InvokeAI/releases/latest and run it to update your installation. echo Please download the installer from https://github.com/invoke-ai/InvokeAI/releases/latest and run it to update your installation.
timeout 4 timeout 4
python -m invokeai.frontend.install.invokeai_update python -m invokeai.frontend.install.invokeai_update
) ELSE IF /I "%choice%" == "4" ( ) ELSE IF /I "%choice%" == "9" (
echo Running the db maintenance script... echo Running the db maintenance script...
python .venv\Scripts\invokeai-db-maintenance.exe python .venv\Scripts\invokeai-db-maintenance.exe
) ELSE IF /I "%choice%" == "5" ( ) ELSE IF /I "%choice%" == "10" (
echo Displaying command line help... echo Displaying command line help...
python .venv\Scripts\invokeai-web.exe --help %* python .venv\Scripts\invokeai-web.exe --help %*
pause pause

View File

@@ -58,24 +58,49 @@ do_choice() {
invokeai-web $PARAMS invokeai-web $PARAMS
;; ;;
2) 2)
clear
printf "Textual inversion training\n"
invokeai-ti --gui $PARAMS
;;
3)
clear
printf "Merge models (diffusers type only)\n"
invokeai-merge --gui $PARAMS
;;
4)
clear
printf "Download and install models\n"
invokeai-model-install --root ${INVOKEAI_ROOT}
;;
5)
clear
printf "Change InvokeAI startup options\n"
invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models
;;
6)
clear
printf "Re-run the configure script to fix a broken install or to complete a major upgrade\n"
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only --skip-sd-weights
;;
7)
clear clear
printf "Open the developer console\n" printf "Open the developer console\n"
file_name=$(basename "${BASH_SOURCE[0]}") file_name=$(basename "${BASH_SOURCE[0]}")
bash --init-file "$file_name" bash --init-file "$file_name"
;; ;;
3) 8)
clear clear
printf "UPDATING FROM WITHIN THE APP IS BEING DEPRECATED\n" printf "UPDATING FROM WITHIN THE APP IS BEING DEPRECATED\n"
printf "Please download the installer from https://github.com/invoke-ai/InvokeAI/releases/latest and run it to update your installation.\n" printf "Please download the installer from https://github.com/invoke-ai/InvokeAI/releases/latest and run it to update your installation.\n"
sleep 4 sleep 4
python -m invokeai.frontend.install.invokeai_update python -m invokeai.frontend.install.invokeai_update
;; ;;
4) 9)
clear clear
printf "Running the db maintenance script\n" printf "Running the db maintenance script\n"
invokeai-db-maintenance --root ${INVOKEAI_ROOT} invokeai-db-maintenance --root ${INVOKEAI_ROOT}
;; ;;
5) 10)
clear clear
printf "Command-line help\n" printf "Command-line help\n"
invokeai-web --help invokeai-web --help
@@ -93,10 +118,15 @@ do_choice() {
do_dialog() { do_dialog() {
options=( options=(
1 "Generate images with a browser-based interface" 1 "Generate images with a browser-based interface"
2 "Open the developer console" 2 "Textual inversion training"
3 "Update InvokeAI (DEPRECATED - please use the installer)" 3 "Merge models (diffusers type only)"
4 "Run the InvokeAI image database maintenance script" 4 "Download and install models"
5 "Command-line help" 5 "Change InvokeAI startup options"
6 "Re-run the configure script to fix a broken install or to complete a major upgrade"
7 "Open the developer console"
8 "Update InvokeAI (DEPRECATED - please use the installer)"
9 "Run the InvokeAI image database maintenance script"
10 "Command-line help"
) )
choice=$(dialog --clear \ choice=$(dialog --clear \
@@ -121,10 +151,15 @@ do_line_input() {
printf " ** For a more attractive experience, please install the 'dialog' utility using your package manager. **\n\n" printf " ** For a more attractive experience, please install the 'dialog' utility using your package manager. **\n\n"
printf "What would you like to do?\n" printf "What would you like to do?\n"
printf "1: Generate images using the browser-based interface\n" printf "1: Generate images using the browser-based interface\n"
printf "2: Open the developer console\n" printf "2: Run textual inversion training\n"
printf "3: Update InvokeAI\n" printf "3: Merge models (diffusers type only)\n"
printf "4: Run the InvokeAI image database maintenance script\n" printf "4: Download and install models\n"
printf "5: Command-line help\n" printf "5: Change InvokeAI startup options\n"
printf "6: Re-run the configure script to fix a broken install\n"
printf "7: Open the developer console\n"
printf "8: Update InvokeAI\n"
printf "9: Run the InvokeAI image database maintenance script\n"
printf "10: Command-line help\n"
printf "Q: Quit\n\n" printf "Q: Quit\n\n"
read -p "Please enter 1-10, Q: [1] " yn read -p "Please enter 1-10, Q: [1] " yn
choice=${yn:='1'} choice=${yn:='1'}

11
invokeai/README Normal file
View File

@@ -0,0 +1,11 @@
Organization of the source tree:
app -- Home of nodes invocations and services
assets -- Images and other data files used by InvokeAI
backend -- Non-user facing libraries, including the rendering
core.
configs -- Configuration files used at install and run times
frontend -- User-facing scripts, including the CLI and the WebUI
version -- Current InvokeAI version string, stored
in version/invokeai_version.py

View File

@@ -4,6 +4,7 @@ from logging import Logger
import torch import torch
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
from invokeai.app.services.shared.sqlite.sqlite_util import init_db from invokeai.app.services.shared.sqlite.sqlite_util import init_db
@@ -15,22 +16,24 @@ from ..services.board_image_records.board_image_records_sqlite import SqliteBoar
from ..services.board_images.board_images_default import BoardImagesService from ..services.board_images.board_images_default import BoardImagesService
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from ..services.boards.boards_default import BoardService from ..services.boards.boards_default import BoardService
from ..services.bulk_download.bulk_download_default import BulkDownloadService
from ..services.config import InvokeAIAppConfig from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService from ..services.download import DownloadQueueService
from ..services.image_files.image_files_disk import DiskImageFileStorage from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService from ..services.images.images_default import ImageService
from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from ..services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
from ..services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker from ..services.invoker import Invoker
from ..services.model_images.model_images_default import ModelImageFileStorageDisk
from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_metadata import ModelMetadataStoreSQL
from ..services.model_records import ModelRecordServiceSQL from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.shared.graph import GraphExecutionState
from ..services.urls.urls_default import LocalUrlService from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService from .events import FastAPIEventService
@@ -64,15 +67,14 @@ class ApiDependencies:
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None: def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
logger.info(f"InvokeAI version {__version__}") logger.info(f"InvokeAI version {__version__}")
logger.info(f"Root directory = {str(config.root_path)}") logger.info(f"Root directory = {str(config.root_path)}")
logger.debug(f"Internet connectivity is {config.internet_available}")
output_folder = config.outputs_path output_folder = config.output_path
if output_folder is None: if output_folder is None:
raise ValueError("Output folder is not set") raise ValueError("Output folder is not set")
image_files = DiskImageFileStorage(f"{output_folder}/images") image_files = DiskImageFileStorage(f"{output_folder}/images")
model_images_folder = config.models_path
db = init_db(config=config, logger=logger, image_files=image_files) db = init_db(config=config, logger=logger, image_files=image_files)
configuration = config configuration = config
@@ -83,7 +85,7 @@ class ApiDependencies:
board_records = SqliteBoardRecordStorage(db=db) board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService() boards = BoardService()
events = FastAPIEventService(event_handler_id) events = FastAPIEventService(event_handler_id)
bulk_download = BulkDownloadService() graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
image_records = SqliteImageRecordStorage(db=db) image_records = SqliteImageRecordStorage(db=db)
images = ImageService() images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
@@ -94,15 +96,17 @@ class ApiDependencies:
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
) )
download_queue_service = DownloadQueueService(event_bus=events) download_queue_service = DownloadQueueService(event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images") model_metadata_service = ModelMetadataStoreSQL(db=db)
model_manager = ModelManagerService.build_model_manager( model_manager = ModelManagerService.build_model_manager(
app_config=configuration, app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db), model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
download_queue=download_queue_service, download_queue=download_queue_service,
events=events, events=events,
) )
names = SimpleNameService() names = SimpleNameService()
performance_statistics = InvocationStatsService() performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
queue = MemoryInvocationQueue()
session_processor = DefaultSessionProcessor() session_processor = DefaultSessionProcessor()
session_queue = SqliteSessionQueue(db=db) session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService() urls = LocalUrlService()
@@ -113,19 +117,20 @@ class ApiDependencies:
board_images=board_images, board_images=board_images,
board_records=board_records, board_records=board_records,
boards=boards, boards=boards,
bulk_download=bulk_download,
configuration=configuration, configuration=configuration,
events=events, events=events,
graph_execution_manager=graph_execution_manager,
image_files=image_files, image_files=image_files,
image_records=image_records, image_records=image_records,
images=images, images=images,
invocation_cache=invocation_cache, invocation_cache=invocation_cache,
logger=logger, logger=logger,
model_images=model_images_service,
model_manager=model_manager, model_manager=model_manager,
download_queue=download_queue_service, download_queue=download_queue_service,
names=names, names=names,
performance_statistics=performance_statistics, performance_statistics=performance_statistics,
processor=processor,
queue=queue,
session_processor=session_processor, session_processor=session_processor,
session_queue=session_queue, session_queue=session_queue,
urls=urls, urls=urls,

View File

@@ -12,6 +12,7 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.upscale import ESRGAN_MODELS from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.patchmatch import PatchMatch from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.util.logging import logging from invokeai.backend.util.logging import logging
@@ -113,7 +114,9 @@ async def get_config() -> AppConfig:
if SafetyChecker.safety_checker_available(): if SafetyChecker.safety_checker_available():
nsfw_methods.append("nsfw_checker") nsfw_methods.append("nsfw_checker")
watermarking_methods = ["invisible_watermark"] watermarking_methods = []
if InvisibleWatermark.invisible_watermark_available():
watermarking_methods.append("invisible_watermark")
return AppConfig( return AppConfig(
infill_methods=infill_methods, infill_methods=infill_methods,

View File

@@ -2,7 +2,7 @@ import io
import traceback import traceback
from typing import Optional from typing import Optional
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from PIL import Image from PIL import Image
@@ -375,67 +375,16 @@ async def unstar_images_in_list(
class ImagesDownloaded(BaseModel): class ImagesDownloaded(BaseModel):
response: Optional[str] = Field( response: Optional[str] = Field(
default=None, description="The message to display to the user when images begin downloading" description="If defined, the message to display to the user when images begin downloading"
)
bulk_download_item_name: Optional[str] = Field(
default=None, description="The name of the bulk download item for which events will be emitted"
) )
@images_router.post( @images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded)
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
)
async def download_images_from_list( async def download_images_from_list(
background_tasks: BackgroundTasks, image_names: list[str] = Body(description="The list of names of images to download", embed=True),
image_names: Optional[list[str]] = Body(
default=None, description="The list of names of images to download", embed=True
),
board_id: Optional[str] = Body( board_id: Optional[str] = Body(
default=None, description="The board from which image should be downloaded", embed=True default=None, description="The board from which image should be downloaded from", embed=True
), ),
) -> ImagesDownloaded: ) -> ImagesDownloaded:
if (image_names is None or len(image_names) == 0) and board_id is None: # return ImagesDownloaded(response="Your images are downloading")
raise HTTPException(status_code=400, detail="No images or board id specified.") raise HTTPException(status_code=501, detail="Endpoint is not yet implemented")
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
background_tasks.add_task(
ApiDependencies.invoker.services.bulk_download.handler,
image_names,
board_id,
bulk_download_item_id,
)
return ImagesDownloaded(bulk_download_item_name=bulk_download_item_id + ".zip")
@images_router.api_route(
"/download/{bulk_download_item_name}",
methods=["GET"],
operation_id="get_bulk_download_item",
response_class=Response,
responses={
200: {
"description": "Return the complete bulk download item",
"content": {"application/zip": {}},
},
404: {"description": "Image not found"},
},
)
async def get_bulk_download_item(
background_tasks: BackgroundTasks,
bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"),
) -> FileResponse:
"""Gets a bulk download zip file"""
try:
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
response = FileResponse(
path,
media_type="application/zip",
filename=bulk_download_item_name,
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name)
return response
except Exception:
raise HTTPException(status_code=404)

View File

@@ -1,32 +1,27 @@
# Copyright (c) 2023 Lincoln D. Stein # Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records.""" """FastAPI route for model configuration records."""
import contextlib
import io
import pathlib import pathlib
import shutil import shutil
import traceback from hashlib import sha1
from copy import deepcopy from random import randbytes
from enum import Enum from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Optional
import huggingface_hub from fastapi import Body, Path, Query, Response
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from PIL import Image from pydantic import BaseModel, ConfigDict
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from typing_extensions import Annotated from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob from invokeai.app.services.model_install import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
DuplicateModelException, DuplicateModelException,
InvalidModelException, InvalidModelException,
ModelRecordChanges, ModelRecordOrderBy,
ModelSummary,
UnknownModelException, UnknownModelException,
) )
from invokeai.app.util.suppress_output import SuppressOutput from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
@@ -35,18 +30,13 @@ from invokeai.backend.model_manager.config import (
ModelType, ModelType,
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.starter_models import STARTER_MODELS, StarterModel, StarterModelWithoutDependencies
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"]) model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
# images are immutable; set a high max-age
IMAGE_MAX_AGE = 31536000
class ModelsList(BaseModel): class ModelsList(BaseModel):
"""Return list of configs.""" """Return list of configs."""
@@ -56,6 +46,15 @@ class ModelsList(BaseModel):
model_config = ConfigDict(use_enum_values=True) model_config = ConfigDict(use_enum_values=True)
class ModelTagSet(BaseModel):
"""Return tags for a set of models."""
key: str
name: str
author: str
tags: Set[str]
############################################################################## ##############################################################################
# These are example inputs and outputs that are used in places where Swagger # These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example. # is unable to generate a correct example.
@@ -66,16 +65,19 @@ example_model_config = {
"base": "sd-1", "base": "sd-1",
"type": "main", "type": "main",
"format": "checkpoint", "format": "checkpoint",
"config_path": "string", "config": "string",
"key": "string", "key": "string",
"hash": "string", "original_hash": "string",
"current_hash": "string",
"description": "string", "description": "string",
"source": "string", "source": "string",
"converted_at": 0, "last_modified": 0,
"vae": "string",
"variant": "normal", "variant": "normal",
"prediction_type": "epsilon", "prediction_type": "epsilon",
"repo_variant": "fp16", "repo_variant": "fp16",
"upcast_attention": False, "upcast_attention": False,
"ztsnr_training": False,
} }
example_model_input = { example_model_input = {
@@ -84,12 +86,50 @@ example_model_input = {
"base": "sd-1", "base": "sd-1",
"type": "main", "type": "main",
"format": "checkpoint", "format": "checkpoint",
"config_path": "configs/stable-diffusion/v1-inference.yaml", "config": "configs/stable-diffusion/v1-inference.yaml",
"description": "Model description", "description": "Model description",
"vae": None, "vae": None,
"variant": "normal", "variant": "normal",
} }
example_model_metadata = {
"name": "ip_adapter_sd_image_encoder",
"author": "InvokeAI",
"tags": [
"transformers",
"safetensors",
"clip_vision_model",
"endpoints_compatible",
"region:us",
"has_space",
"license:apache-2.0",
],
"files": [
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
"path": "ip_adapter_sd_image_encoder/README.md",
"size": 628,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
"path": "ip_adapter_sd_image_encoder/config.json",
"size": 560,
"sha256": None,
},
{
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
"path": "ip_adapter_sd_image_encoder/model.safetensors",
"size": 2528373448,
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
},
],
"type": "huggingface",
"id": "InvokeAI/ip_adapter_sd_image_encoder",
"tag_dict": {"license": "apache-2.0"},
"last_modified": "2023-09-23T17:33:25Z",
}
############################################################################## ##############################################################################
# ROUTES # ROUTES
############################################################################## ##############################################################################
@@ -121,33 +161,9 @@ async def list_model_records(
found_models.extend( found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
) )
for model in found_models:
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
model.cover_image = cover_image
return ModelsList(models=found_models) return ModelsList(models=found_models)
@model_manager_router.get(
"/get_by_attrs",
operation_id="get_model_records_by_attrs",
response_model=AnyModelConfig,
)
async def get_model_records_by_attrs(
name: str = Query(description="The name of the model"),
type: ModelType = Query(description="The type of the model"),
base: BaseModelType = Query(description="The base model of the model"),
) -> AnyModelConfig:
"""Gets a model by its attributes. The main use of this route is to provide backwards compatibility with the old
model manager, which identified models by a combination of name, base and type."""
configs = ApiDependencies.invoker.services.model_manager.store.search_by_attr(
base_model=base, model_type=type, model_name=name
)
if not configs:
raise HTTPException(status_code=404, detail="No model found with these attributes")
return configs[0]
@model_manager_router.get( @model_manager_router.get(
"/i/{key}", "/i/{key}",
operation_id="get_model_record", operation_id="get_model_record",
@@ -167,126 +183,68 @@ async def get_model_record(
record_store = ApiDependencies.invoker.services.model_manager.store record_store = ApiDependencies.invoker.services.model_manager.store
try: try:
config: AnyModelConfig = record_store.get_model(key) config: AnyModelConfig = record_store.get_model(key)
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
config.cover_image = cover_image
return config return config
except UnknownModelException as e: except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
# @model_manager_router.get("/summary", operation_id="list_model_summary") @model_manager_router.get("/summary", operation_id="list_model_summary")
# async def list_model_summary( async def list_model_summary(
# page: int = Query(default=0, description="The page to get"), page: int = Query(default=0, description="The page to get"),
# per_page: int = Query(default=10, description="The number of models per page"), per_page: int = Query(default=10, description="The number of models per page"),
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"), order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
# ) -> PaginatedResults[ModelSummary]: ) -> PaginatedResults[ModelSummary]:
# """Gets a page of model summary data.""" """Gets a page of model summary data."""
# record_store = ApiDependencies.invoker.services.model_manager.store record_store = ApiDependencies.invoker.services.model_manager.store
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by) results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
# return results return results
class FoundModel(BaseModel):
path: str = Field(description="Path to the model")
is_installed: bool = Field(description="Whether or not the model is already installed")
@model_manager_router.get( @model_manager_router.get(
"/scan_folder", "/meta/i/{key}",
operation_id="scan_for_models", operation_id="get_model_metadata",
responses={ responses={
200: {"description": "Directory scanned successfully"}, 200: {
400: {"description": "Invalid directory path"}, "description": "The model metadata was retrieved successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
404: {"description": "No metadata available"},
}, },
status_code=200,
response_model=List[FoundModel],
) )
async def scan_for_models( async def get_model_metadata(
scan_path: str = Query(description="Directory path to search for models", default=None), key: str = Path(description="Key of the model repo metadata to fetch."),
) -> List[FoundModel]: ) -> Optional[AnyModelRepoMetadata]:
path = pathlib.Path(scan_path) """Get a model metadata object."""
if not scan_path or not path.is_dir(): record_store = ApiDependencies.invoker.services.model_manager.store
raise HTTPException( result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
status_code=400, if not result:
detail=f"The search path '{scan_path}' does not exist or is not directory", raise HTTPException(status_code=404, detail="No metadata for a model with this key")
) return result
search = ModelSearch()
try:
found_model_paths = search.search(path)
models_path = ApiDependencies.invoker.services.configuration.models_path
# If the search path includes the main models directory, we need to exclude core models from the list.
# TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed
# without needing to crawl the filesystem.
core_models_path = pathlib.Path(models_path, "core").resolve()
non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
resolved_installed_model_paths: list[str] = []
installed_model_sources: list[str] = []
# This call lists all installed models.
for model in installed_models:
path = pathlib.Path(model.path)
# If the model has a source, we need to add it to the list of installed sources.
if model.source:
installed_model_sources.append(model.source)
# If the path is not absolute, that means it is in the app models directory, and we need to join it with
# the models path before resolving.
if not path.is_absolute():
resolved_installed_model_paths.append(str(pathlib.Path(models_path, path).resolve()))
continue
resolved_installed_model_paths.append(str(path.resolve()))
scan_results: list[FoundModel] = []
# Check if the model is installed by comparing the resolved paths, appending to the scan result.
for p in non_core_model_paths:
path = str(p)
is_installed = path in resolved_installed_model_paths or path in installed_model_sources
found_model = FoundModel(path=path, is_installed=is_installed)
scan_results.append(found_model)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"An error occurred while searching the directory: {e}",
)
return scan_results
class HuggingFaceModels(BaseModel):
urls: List[AnyHttpUrl] | None = Field(description="URLs for all checkpoint format models in the metadata")
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model")
@model_manager_router.get( @model_manager_router.get(
"/hugging_face", "/tags",
operation_id="get_hugging_face_models", operation_id="list_tags",
responses={
200: {"description": "Hugging Face repo scanned successfully"},
400: {"description": "Invalid hugging face repo"},
},
status_code=200,
response_model=HuggingFaceModels,
) )
async def get_hugging_face_models( async def list_tags() -> Set[str]:
hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None), """Get a unique set of all the model tags."""
) -> HuggingFaceModels: record_store = ApiDependencies.invoker.services.model_manager.store
try: result: Set[str] = record_store.list_tags()
metadata = HuggingFaceMetadataFetch().from_id(hugging_face_repo) return result
except UnknownMetadataException:
raise HTTPException(
status_code=400,
detail="No HuggingFace repository found",
)
assert isinstance(metadata, ModelMetadataWithFiles)
return HuggingFaceModels( @model_manager_router.get(
urls=metadata.ckpt_urls, "/tags/search",
is_diffusers=metadata.is_diffusers, operation_id="search_by_metadata_tags",
) )
async def search_by_metadata_tags(
tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_manager.store
results = record_store.search_by_metadata_tag(tags)
return ModelsList(models=results)
@model_manager_router.patch( @model_manager_router.patch(
@@ -305,15 +263,15 @@ async def get_hugging_face_models(
) )
async def update_model_record( async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")], key: Annotated[str, Path(description="Unique key of model")],
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)], info: Annotated[
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> AnyModelConfig: ) -> AnyModelConfig:
"""Update a model's config.""" """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store record_store = ApiDependencies.invoker.services.model_manager.store
installer = ApiDependencies.invoker.services.model_manager.install
try: try:
record_store.update_model(key, changes=changes) model_response: AnyModelConfig = record_store.update_model(key, config=info)
model_response: AnyModelConfig = installer.sync_model_path(key)
logger.info(f"Updated model: {key}") logger.info(f"Updated model: {key}")
except UnknownModelException as e: except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@@ -323,85 +281,16 @@ async def update_model_record(
return model_response return model_response
@model_manager_router.get(
"/i/{key}/image",
operation_id="get_model_image",
responses={
200: {
"description": "The model image was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The model image could not be found"},
},
status_code=200,
)
async def get_model_image(
key: str = Path(description="The name of model image file to get"),
) -> FileResponse:
"""Gets an image file that previews the model"""
try:
path = ApiDependencies.invoker.services.model_images.get_path(key)
response = FileResponse(
path,
media_type="image/png",
filename=key + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)
@model_manager_router.patch(
"/i/{key}/image",
operation_id="update_model_image",
responses={
200: {
"description": "The model image was updated successfully",
},
400: {"description": "Bad request"},
},
status_code=200,
)
async def update_model_image(
key: Annotated[str, Path(description="Unique key of model")],
image: UploadFile,
) -> None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
logger = ApiDependencies.invoker.services.logger
model_images = ApiDependencies.invoker.services.model_images
try:
model_images.save(pil_image, key)
logger.info(f"Updated image for model: {key}")
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return
@model_manager_router.delete( @model_manager_router.delete(
"/i/{key}", "/i/{key}",
operation_id="delete_model", operation_id="del_model_record",
responses={ responses={
204: {"description": "Model deleted successfully"}, 204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"}, 404: {"description": "Model not found"},
}, },
status_code=204, status_code=204,
) )
async def delete_model( async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."), key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response: ) -> Response:
""" """
@@ -422,67 +311,47 @@ async def delete_model(
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.delete( @model_manager_router.post(
"/i/{key}/image", "/i/",
operation_id="delete_model_image", operation_id="add_model_record",
responses={ responses={
204: {"description": "Model image deleted successfully"}, 201: {
404: {"description": "Model image not found"}, "description": "The model added successfully",
"content": {"application/json": {"example": example_model_config}},
}, },
status_code=204, 409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
) )
async def delete_model_image( async def add_model_record(
key: str = Path(description="Unique key of model image to remove from model_images directory."), config: Annotated[
) -> None: AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
],
) -> AnyModelConfig:
"""Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
model_images = ApiDependencies.invoker.services.model_images record_store = ApiDependencies.invoker.services.model_manager.store
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
try: try:
model_images.delete(key) record_store.add_model(config.key, config)
logger.info(f"Deleted model image: {key}") except DuplicateModelException as e:
return
except UnknownModelException as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
# now fetch it out
# @model_manager_router.post( result: AnyModelConfig = record_store.get_model(config.key)
# "/i/", return result
# operation_id="add_model_record",
# responses={
# 201: {
# "description": "The model added successfully",
# "content": {"application/json": {"example": example_model_config}},
# },
# 409: {"description": "There is already a model corresponding to this path or repo_id"},
# 415: {"description": "Unrecognized file/folder format"},
# },
# status_code=201,
# )
# async def add_model_record(
# config: Annotated[
# AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
# ],
# ) -> AnyModelConfig:
# """Add a model using the configuration information appropriate for its type."""
# logger = ApiDependencies.invoker.services.logger
# record_store = ApiDependencies.invoker.services.model_manager.store
# try:
# record_store.add_model(config)
# except DuplicateModelException as e:
# logger.error(str(e))
# raise HTTPException(status_code=409, detail=str(e))
# except InvalidModelException as e:
# logger.error(str(e))
# raise HTTPException(status_code=415)
# # now fetch it out
# result: AnyModelConfig = record_store.get_model(config.key)
# return result
@model_manager_router.post( @model_manager_router.post(
"/install", "/heuristic_import",
operation_id="install_model", operation_id="heuristic_import_model",
responses={ responses={
201: {"description": "The model imported successfully"}, 201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"}, 415: {"description": "Unrecognized file/folder format"},
@@ -491,14 +360,12 @@ async def delete_model_image(
}, },
status_code=201, status_code=201,
) )
async def install_model( async def heuristic_import(
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"), source: str,
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
# TODO(MM2): Can we type this?
config: Optional[Dict[str, Any]] = Body( config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None, default=None,
example={"name": "string", "description": "string"}, example={"name": "modelT", "description": "antique cars"},
), ),
access_token: Optional[str] = None, access_token: Optional[str] = None,
) -> ModelInstallJob: ) -> ModelInstallJob:
@@ -535,8 +402,106 @@ async def install_model(
result: ModelInstallJob = installer.heuristic_import( result: ModelInstallJob = installer.heuristic_import(
source=source, source=source,
config=config, config=config,
access_token=access_token, )
inplace=bool(inplace), logger.info(f"Started installation of {source}")
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return result
@model_manager_router.post(
"/install",
operation_id="import_model",
responses={
201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
)
async def import_model(
source: ModelSource,
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
),
) -> ModelInstallJob:
"""Install a model using its local path, repo_id, or remote URL.
Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has `status` attribute
that can be used to monitor progress.
The source object is a discriminated Union of LocalModelSource,
HFModelSource and URLModelSource. Set the "type" field to the
appropriate value:
* To install a local path using LocalModelSource, pass a source of form:
```
{
"type": "local",
"path": "/path/to/model",
"inplace": false
}
```
The "inplace" flag, if true, will register the model in place in its
current filesystem location. Otherwise, the model will be copied
into the InvokeAI models directory.
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
```
{
"type": "hf",
"repo_id": "stabilityai/stable-diffusion-2.0",
"variant": "fp16",
"subfolder": "vae",
"access_token": "f5820a918aaf01"
}
```
The `variant`, `subfolder` and `access_token` fields are optional.
* To install a remote model using an arbitrary URL, pass:
```
{
"type": "url",
"url": "http://www.civitai.com/models/123456",
"access_token": "f5820a918aaf01"
}
```
The `access_token` field is optonal
The model's configuration record will be probed and filled in
automatically. To override the default guesses, pass "metadata"
with a Dict containing the attributes you wish to override.
Installation occurs in the background. Either use list_model_install_jobs()
to poll for completion, or listen on the event bus for the following events:
* "model_install_running"
* "model_install_completed"
* "model_install_error"
On successful completion, the event's payload will contain the field "key"
containing the installed ID of the model. On an error, the event's payload
will contain the fields "error_type" and "error" describing the nature of the
error and its traceback, respectively.
"""
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_manager.install
result: ModelInstallJob = installer.import_model(
source=source,
config=config,
) )
logger.info(f"Started installation of {source}") logger.info(f"Started installation of {source}")
except UnknownModelException as e: except UnknownModelException as e:
@@ -552,10 +517,10 @@ async def install_model(
@model_manager_router.get( @model_manager_router.get(
"/install", "/import",
operation_id="list_model_installs", operation_id="list_model_install_jobs",
) )
async def list_model_installs() -> List[ModelInstallJob]: async def list_model_install_jobs() -> List[ModelInstallJob]:
"""Return the list of model install jobs. """Return the list of model install jobs.
Install jobs have a numeric `id`, a `status`, and other fields that provide information on Install jobs have a numeric `id`, a `status`, and other fields that provide information on
@@ -569,8 +534,9 @@ async def list_model_installs() -> List[ModelInstallJob]:
* "cancelled" -- Job was cancelled before completion. * "cancelled" -- Job was cancelled before completion.
Once completed, information about the model such as its size, base Once completed, information about the model such as its size, base
model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers, model, type, and metadata can be retrieved from the `config_out`
information on individual files can be retrieved from `download_parts`. field. For multi-file models such as diffusers, information on individual files
can be retrieved from `download_parts`.
See the example and schema below for more information. See the example and schema below for more information.
""" """
@@ -579,7 +545,7 @@ async def list_model_installs() -> List[ModelInstallJob]:
@model_manager_router.get( @model_manager_router.get(
"/install/{id}", "/import/{id}",
operation_id="get_model_install_job", operation_id="get_model_install_job",
responses={ responses={
200: {"description": "Success"}, 200: {"description": "Success"},
@@ -599,7 +565,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
@model_manager_router.delete( @model_manager_router.delete(
"/install/{id}", "/import/{id}",
operation_id="cancel_model_install_job", operation_id="cancel_model_install_job",
responses={ responses={
201: {"description": "The job was cancelled successfully"}, 201: {"description": "The job was cancelled successfully"},
@@ -617,8 +583,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
installer.cancel_job(job) installer.cancel_job(job)
@model_manager_router.delete( @model_manager_router.patch(
"/install", "/import",
operation_id="prune_model_install_jobs", operation_id="prune_model_install_jobs",
responses={ responses={
204: {"description": "All completed and errored jobs have been pruned"}, 204: {"description": "All completed and errored jobs have been pruned"},
@@ -671,7 +637,6 @@ async def convert_model(
Note that during the conversion process the key and model hash will change. Note that during the conversion process the key and model hash will change.
The return value is the model configuration for the converted model. The return value is the model configuration for the converted model.
""" """
model_manager = ApiDependencies.invoker.services.model_manager
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
loader = ApiDependencies.invoker.services.model_manager.load loader = ApiDependencies.invoker.services.model_manager.load
store = ApiDependencies.invoker.services.model_manager.store store = ApiDependencies.invoker.services.model_manager.store
@@ -688,7 +653,7 @@ async def convert_model(
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.") raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
# loading the model will convert it into a cached diffusers file # loading the model will convert it into a cached diffusers file
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler) loader.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
# Get the path of the converted model from the loader # Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key) cache_path = loader.convert_cache.cache_path(key)
@@ -697,8 +662,7 @@ async def convert_model(
# temporarily rename the original safetensors file so that there is no naming conflict # temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name original_name = model_config.name
model_config.name = f"{original_name}.DELETE" model_config.name = f"{original_name}.DELETE"
changes = ModelRecordChanges(name=model_config.name) store.update_model(key, config=model_config)
store.update_model(key, changes=changes)
# install the diffusers # install the diffusers
try: try:
@@ -707,7 +671,7 @@ async def convert_model(
config={ config={
"name": original_name, "name": original_name,
"description": model_config.description, "description": model_config.description,
"hash": model_config.hash, "original_hash": model_config.original_hash,
"source": model_config.source, "source": model_config.source,
}, },
) )
@@ -715,6 +679,10 @@ async def convert_model(
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
# get the original metadata
if orig_metadata := store.get_metadata(key):
store.metadata_store.add_metadata(new_key, orig_metadata)
# delete the original safetensors file # delete the original safetensors file
installer.delete(key) installer.delete(key)
@@ -726,132 +694,66 @@ async def convert_model(
return new_config return new_config
# @model_manager_router.put( @model_manager_router.put(
# "/merge", "/merge",
# operation_id="merge", operation_id="merge",
# responses={ responses={
# 200: { 200: {
# "description": "Model converted successfully", "description": "Model converted successfully",
# "content": {"application/json": {"example": example_model_config}}, "content": {"application/json": {"example": example_model_config}},
# }, },
# 400: {"description": "Bad request"}, 400: {"description": "Bad request"},
# 404: {"description": "Model not found"}, 404: {"description": "Model not found"},
# 409: {"description": "There is already a model registered at this location"}, 409: {"description": "There is already a model registered at this location"},
# }, },
# ) )
# async def merge( async def merge(
# keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3), keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
# merged_model_name: Optional[str] = Body(description="Name of destination model", default=None), merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
# alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
# force: bool = Body( force: bool = Body(
# description="Force merging of models created with different versions of diffusers", description="Force merging of models created with different versions of diffusers",
# default=False, default=False,
# ), ),
# interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None), interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
# merge_dest_directory: Optional[str] = Body( merge_dest_directory: Optional[str] = Body(
# description="Save the merged model to the designated directory (with 'merged_model_name' appended)", description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
# default=None, default=None,
# ), ),
# ) -> AnyModelConfig: ) -> AnyModelConfig:
# """ """
# Merge diffusers models. The process is controlled by a set parameters provided in the body of the request. Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
# ``` ```
# Argument Description [default] Argument Description [default]
# -------- ---------------------- -------- ----------------------
# keys List of 2-3 model keys to merge together. All models must use the same base type. keys List of 2-3 model keys to merge together. All models must use the same base type.
# merged_model_name Name for the merged model [Concat model names] merged_model_name Name for the merged model [Concat model names]
# alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5] alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
# force If true, force the merge even if the models were generated by different versions of the diffusers library [False] force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
# interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum] interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
# merge_dest_directory Specify a directory to store the merged model in [models directory] merge_dest_directory Specify a directory to store the merged model in [models directory]
# ``` ```
# """ """
# logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
# try:
# logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
# dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
# installer = ApiDependencies.invoker.services.model_manager.install
# merger = ModelMerger(installer)
# model_names = [installer.record_store.get_model(x).name for x in keys]
# response = merger.merge_diffusion_models_and_save(
# model_keys=keys,
# merged_model_name=merged_model_name or "+".join(model_names),
# alpha=alpha,
# interp=interp,
# force=force,
# merge_dest_directory=dest,
# )
# except UnknownModelException:
# raise HTTPException(
# status_code=404,
# detail=f"One or more of the models '{keys}' not found",
# )
# except ValueError as e:
# raise HTTPException(status_code=400, detail=str(e))
# return response
@model_manager_router.get("/starter_models", operation_id="get_starter_models", response_model=list[StarterModel])
async def get_starter_models() -> list[StarterModel]:
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
installed_model_sources = {m.source for m in installed_models}
starter_models = deepcopy(STARTER_MODELS)
for model in starter_models:
if model.source in installed_model_sources:
model.is_installed = True
# Remove already-installed dependencies
missing_deps: list[StarterModelWithoutDependencies] = []
for dep in model.dependencies or []:
if dep.source not in installed_model_sources:
missing_deps.append(dep)
model.dependencies = missing_deps
return starter_models
class HFTokenStatus(str, Enum):
VALID = "valid"
INVALID = "invalid"
UNKNOWN = "unknown"
class HFTokenHelper:
@classmethod
def get_status(cls) -> HFTokenStatus:
try: try:
if huggingface_hub.get_token_permission(huggingface_hub.get_token()): logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
# Valid token! dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
return HFTokenStatus.VALID installer = ApiDependencies.invoker.services.model_manager.install
# No token set merger = ModelMerger(installer)
return HFTokenStatus.INVALID model_names = [installer.record_store.get_model(x).name for x in keys]
except Exception: response = merger.merge_diffusion_models_and_save(
return HFTokenStatus.UNKNOWN model_keys=keys,
merged_model_name=merged_model_name or "+".join(model_names),
@classmethod alpha=alpha,
def set_token(cls, token: str) -> HFTokenStatus: interp=interp,
with SuppressOutput(), contextlib.suppress(Exception): force=force,
huggingface_hub.login(token=token, add_to_git_credential=False) merge_dest_directory=dest,
return cls.get_status() )
except UnknownModelException:
raise HTTPException(
@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus) status_code=404,
async def get_hf_login_status() -> HFTokenStatus: detail=f"One or more of the models '{keys}' not found",
token_status = HFTokenHelper.get_status() )
except ValueError as e:
if token_status is HFTokenStatus.UNKNOWN: raise HTTPException(status_code=400, detail=str(e))
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token") return response
return token_status
@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus)
async def do_hf_login(
token: str = Body(description="Hugging Face token to use for login", embed=True),
) -> HFTokenStatus:
HFTokenHelper.set_token(token)
token_status = HFTokenHelper.get_status()
if token_status is HFTokenStatus.UNKNOWN:
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
return token_status

View File

@@ -0,0 +1,276 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from fastapi import HTTPException, Path
from fastapi.routing import APIRouter
from ...services.shared.graph import GraphExecutionState
from ..dependencies import ApiDependencies
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
# @session_router.post(
# "/",
# operation_id="create_session",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid json"},
# },
# deprecated=True,
# )
# async def create_session(
# queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
# graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
# ) -> GraphExecutionState:
# """Creates a new session, optionally initializing it with an invocation graph"""
# session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
# return session
# @session_router.get(
# "/",
# operation_id="list_sessions",
# responses={200: {"model": PaginatedResults[GraphExecutionState]}},
# deprecated=True,
# )
# async def list_sessions(
# page: int = Query(default=0, description="The page of results to get"),
# per_page: int = Query(default=10, description="The number of results per page"),
# query: str = Query(default="", description="The query string to search for"),
# ) -> PaginatedResults[GraphExecutionState]:
# """Gets a list of sessions, optionally searching"""
# if query == "":
# result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
# else:
# result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
# return result
@session_router.get(
"/{session_id}",
operation_id="get_session",
responses={
200: {"model": GraphExecutionState},
404: {"description": "Session not found"},
},
)
async def get_session(
session_id: str = Path(description="The id of the session to get"),
) -> GraphExecutionState:
"""Gets a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
raise HTTPException(status_code=404)
else:
return session
# @session_router.post(
# "/{session_id}/nodes",
# operation_id="add_node",
# responses={
# 200: {"model": str},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def add_node(
# session_id: str = Path(description="The id of the session"),
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
# description="The node to add"
# ),
# ) -> str:
# """Adds a node to the graph"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.add_node(node)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session.id
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.put(
# "/{session_id}/nodes/{node_path}",
# operation_id="update_node",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def update_node(
# session_id: str = Path(description="The id of the session"),
# node_path: str = Path(description="The path to the node in the graph"),
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
# description="The new node"
# ),
# ) -> GraphExecutionState:
# """Updates a node in the graph and removes all linked edges"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.update_node(node_path, node)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.delete(
# "/{session_id}/nodes/{node_path}",
# operation_id="delete_node",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def delete_node(
# session_id: str = Path(description="The id of the session"),
# node_path: str = Path(description="The path to the node to delete"),
# ) -> GraphExecutionState:
# """Deletes a node in the graph and removes all linked edges"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.delete_node(node_path)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.post(
# "/{session_id}/edges",
# operation_id="add_edge",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def add_edge(
# session_id: str = Path(description="The id of the session"),
# edge: Edge = Body(description="The edge to add"),
# ) -> GraphExecutionState:
# """Adds an edge to the graph"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.add_edge(edge)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# # TODO: the edge being in the path here is really ugly, find a better solution
# @session_router.delete(
# "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
# operation_id="delete_edge",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def delete_edge(
# session_id: str = Path(description="The id of the session"),
# from_node_id: str = Path(description="The id of the node the edge is coming from"),
# from_field: str = Path(description="The field of the node the edge is coming from"),
# to_node_id: str = Path(description="The id of the node the edge is going to"),
# to_field: str = Path(description="The field of the node the edge is going to"),
# ) -> GraphExecutionState:
# """Deletes an edge from the graph"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# edge = Edge(
# source=EdgeConnection(node_id=from_node_id, field=from_field),
# destination=EdgeConnection(node_id=to_node_id, field=to_field),
# )
# session.delete_edge(edge)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.put(
# "/{session_id}/invoke",
# operation_id="invoke_session",
# responses={
# 200: {"model": None},
# 202: {"description": "The invocation is queued"},
# 400: {"description": "The session has no invocations ready to invoke"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def invoke_session(
# queue_id: str = Query(description="The id of the queue to associate the session with"),
# session_id: str = Path(description="The id of the session to invoke"),
# all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
# ) -> Response:
# """Invokes a session"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# if session.is_complete():
# raise HTTPException(status_code=400)
# ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
# return Response(status_code=202)
# @session_router.delete(
# "/{session_id}/invoke",
# operation_id="cancel_session_invoke",
# responses={202: {"description": "The invocation is canceled"}},
# deprecated=True,
# )
# async def cancel_session_invoke(
# session_id: str = Path(description="The id of the session to cancel"),
# ) -> Response:
# """Invokes a session"""
# ApiDependencies.invoker.cancel(session_id)
# return Response(status_code=202)

View File

@@ -12,26 +12,16 @@ class SocketIO:
__sio: AsyncServer __sio: AsyncServer
__app: ASGIApp __app: ASGIApp
__sub_queue: str = "subscribe_queue"
__unsub_queue: str = "unsubscribe_queue"
__sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI): def __init__(self, app: FastAPI):
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
app.mount("/ws", self.__app) app.mount("/ws", self.__app)
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue) self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue) self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event) local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)
async def _handle_queue_event(self, event: Event): async def _handle_queue_event(self, event: Event):
await self.__sio.emit( await self.__sio.emit(
event=event[1]["event"], event=event[1]["event"],
@@ -49,18 +39,3 @@ class SocketIO:
async def _handle_model_event(self, event: Event) -> None: async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
async def _handle_bulk_download_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["bulk_download_id"],
)
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.enter_room(sid, data["bulk_download_id"])
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.leave_room(sid, data["bulk_download_id"])

View File

@@ -1,35 +1,48 @@
import asyncio # parse_args() must be called before any other imports. if it is not called first, consumers of the config
import mimetypes # which are imported/used before parse_args() is called will get the default config values instead of the
import socket # values from the command line or config file.
from contextlib import asynccontextmanager import sys
from inspect import signature
from pathlib import Path
from typing import Any
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available
# for PyCharm:
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField from invokeai.version.invokeai_version import __version__
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from ..backend.util.logging import InvokeAILogger from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from .api.dependencies import ApiDependencies from .services.config import InvokeAIAppConfig
from .api.routers import (
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
if app_config.version:
print(f"InvokeAI version {__version__}")
sys.exit(0)
if True: # hack to make flake8 happy with imports coming after setting up the config
import asyncio
import mimetypes
import socket
from inspect import signature
from pathlib import Path
from typing import Any
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available
# for PyCharm:
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
from .api.routers import (
app_info, app_info,
board_images, board_images,
boards, boards,
@@ -37,48 +50,31 @@ from .api.routers import (
images, images,
model_manager, model_manager,
session_queue, session_queue,
sessions,
utilities, utilities,
workflows, workflows,
) )
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import ( from .invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
UIConfigBase, UIConfigBase,
) )
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
app_config = get_config() if is_mps_available():
if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.get_logger(config=app_config) logger = InvokeAILogger.get_logger(config=app_config)
# fix for windows mimetypes registry entries being borked # fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352 # see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type("application/javascript", ".js") mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css") mimetypes.add_type("text/css", ".css")
@asynccontextmanager
async def lifespan(app: FastAPI):
# Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
yield
# Shut down threads
ApiDependencies.shutdown()
# Create the app # Create the app
# TODO: create this all in a method so configuration/etc. can be passed in? # TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI( app = FastAPI(title="Invoke - Community Edition", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
title="Invoke - Community Edition",
docs_url=None,
redoc_url=None,
separate_input_output_schemas=False,
lifespan=lifespan,
)
# Add event handler # Add event handler
event_handler_id: int = id(app) event_handler_id: int = id(app)
@@ -101,7 +97,21 @@ app.add_middleware(
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event() -> None:
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
# Shut down threads
@app.on_event("shutdown")
async def shutdown_event() -> None:
ApiDependencies.shutdown()
# Include all routers # Include all routers
app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api") app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(model_manager.model_manager_router, prefix="/api") app.include_router(model_manager.model_manager_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api") app.include_router(download_queue.download_queue_router, prefix="/api")
@@ -141,22 +151,18 @@ def custom_openapi() -> dict[str, Any]:
# TODO: note that we assume the schema_key here is the TYPE.__name__ # TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it # This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"] output_type_titles[schema_key] = output_schema["title"]
openapi_schema["components"]["schemas"][schema_key] = output_schema
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
# Some models don't end up in the schemas as standalone definitions # Add Node Editor UI helper schemas
additional_schemas = models_json_schema( ui_config_schemas = models_json_schema(
[ [
(UIConfigBase, "serialization"), (UIConfigBase, "serialization"),
(InputFieldJSONSchemaExtra, "serialization"), (InputFieldJSONSchemaExtra, "serialization"),
(OutputFieldJSONSchemaExtra, "serialization"), (OutputFieldJSONSchemaExtra, "serialization"),
(ModelIdentifierField, "serialization"),
(ProgressImage, "serialization"),
], ],
ref_template="#/components/schemas/{model}", ref_template="#/components/schemas/{model}",
) )
for schema_key, schema_json in additional_schemas[1]["$defs"].items(): for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema_json openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
# Add a reference to the output type to additionalProperties of the invoker schema # Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations: for invoker in all_invocations:
@@ -167,6 +173,7 @@ def custom_openapi() -> dict[str, Any]:
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"} outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref invoker_schema["output"] = outputs_ref
invoker_schema["class"] = "invocation" invoker_schema["class"] = "invocation"
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
# This code no longer seems to be necessary? # This code no longer seems to be necessary?
# Leave it here just in case # Leave it here just in case
@@ -233,6 +240,10 @@ def invoke_api() -> None:
else: else:
return port return port
from invokeai.backend.install.check_root import check_invokeai_root
check_invokeai_root(app_config) # note, may exit with an exception if root not set up
if app_config.dev_reload: if app_config.dev_reload:
try: try:
import jurigged import jurigged

View File

@@ -3,9 +3,9 @@ import sys
from importlib.util import module_from_spec, spec_from_file_location from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path from pathlib import Path
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config.config_default import InvokeAIAppConfig
custom_nodes_path = Path(get_config().custom_nodes_path) custom_nodes_path = Path(InvokeAIAppConfig.get_config().custom_nodes_path.resolve())
custom_nodes_path.mkdir(parents=True, exist_ok=True) custom_nodes_path.mkdir(parents=True, exist_ok=True)
custom_nodes_init_path = str(custom_nodes_path / "__init__.py") custom_nodes_init_path = str(custom_nodes_path / "__init__.py")

View File

@@ -8,32 +8,19 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from inspect import signature from inspect import signature
from typing import ( from types import UnionType
TYPE_CHECKING, from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
Annotated,
Any,
Callable,
ClassVar,
Iterable,
Literal,
Optional,
Type,
TypeVar,
Union,
cast,
)
import semver import semver
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model from pydantic import BaseModel, ConfigDict, Field, create_model
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined from pydantic_core import PydanticUndefined
from typing_extensions import TypeAliasType
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
FieldKind, FieldKind,
Input, Input,
) )
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
@@ -97,7 +84,6 @@ class BaseInvocationOutput(BaseModel):
""" """
_output_classes: ClassVar[set[BaseInvocationOutput]] = set() _output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
@classmethod @classmethod
def register_output(cls, output: BaseInvocationOutput) -> None: def register_output(cls, output: BaseInvocationOutput) -> None:
@@ -110,14 +96,10 @@ class BaseInvocationOutput(BaseModel):
return cls._output_classes return cls._output_classes
@classmethod @classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]: def get_outputs_union(cls) -> UnionType:
"""Gets a pydantc TypeAdapter for the union of all invocation output types.""" """Gets a union of all invocation outputs."""
if not cls._typeadapter: outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
InvocationOutputsUnion = TypeAliasType( return outputs_union # type: ignore [return-value]
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
return cls._typeadapter
@classmethod @classmethod
def get_output_types(cls) -> Iterable[str]: def get_output_types(cls) -> Iterable[str]:
@@ -166,7 +148,6 @@ class BaseInvocation(ABC, BaseModel):
""" """
_invocation_classes: ClassVar[set[BaseInvocation]] = set() _invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
@classmethod @classmethod
def get_type(cls) -> str: def get_type(cls) -> str:
@@ -179,19 +160,15 @@ class BaseInvocation(ABC, BaseModel):
cls._invocation_classes.add(invocation) cls._invocation_classes.add(invocation)
@classmethod @classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]: def get_invocations_union(cls) -> UnionType:
"""Gets a pydantc TypeAdapter for the union of all invocation types.""" """Gets a union of all invocation types."""
if not cls._typeadapter: invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
InvocationsUnion = TypeAliasType( return invocations_union # type: ignore [return-value]
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationsUnion)
return cls._typeadapter
@classmethod @classmethod
def get_invocations(cls) -> Iterable[BaseInvocation]: def get_invocations(cls) -> Iterable[BaseInvocation]:
"""Gets all invocations, respecting the allowlist and denylist.""" """Gets all invocations, respecting the allowlist and denylist."""
app_config = get_config() app_config = InvokeAIAppConfig.get_config()
allowed_invocations: set[BaseInvocation] = set() allowed_invocations: set[BaseInvocation] = set()
for sc in cls._invocation_classes: for sc in cls._invocation_classes:
invocation_type = sc.get_type() invocation_type = sc.get_type()

View File

@@ -1,15 +1,24 @@
from typing import Iterator, List, Optional, Tuple, Union, cast from typing import Iterator, List, Optional, Tuple, Union
import torch import torch
from compel import Compel, ReturnedEmbeddingsType from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPTokenizer
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent import invokeai.backend.util.logging as logger
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
OutputField,
UIComponent,
)
from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import ModelType
from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo, BasicConditioningInfo,
@@ -17,10 +26,16 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ExtraConditioningInfo, ExtraConditioningInfo,
SDXLConditioningInfo, SDXLConditioningInfo,
) )
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.util.devices import torch_dtype from invokeai.backend.util.devices import torch_dtype
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from .baseinvocation import (
from .model import CLIPField BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from .model import ClipField
# unconditioned: Optional[torch.Tensor] # unconditioned: Optional[torch.Tensor]
@@ -36,7 +51,7 @@ from .model import CLIPField
title="Prompt", title="Prompt",
tags=["prompt", "compel"], tags=["prompt", "compel"],
category="conditioning", category="conditioning",
version="1.1.1", version="1.0.1",
) )
class CompelInvocation(BaseInvocation): class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
@@ -46,7 +61,7 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.compel_prompt, description=FieldDescriptions.compel_prompt,
ui_component=UIComponent.Textarea, ui_component=UIComponent.Textarea,
) )
clip: CLIPField = InputField( clip: ClipField = InputField(
title="CLIP", title="CLIP",
description=FieldDescriptions.clip, description=FieldDescriptions.clip,
input=Input.Connection, input=Input.Connection,
@@ -54,16 +69,12 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer) tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
tokenizer_model = tokenizer_info.model text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras: for lora in self.clip.loras:
lora_info = context.models.load(lora.lora) lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
assert isinstance(lora_info.model, LoRAModelRaw) assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight) yield (lora_info.model, lora.weight)
del lora_info del lora_info
@@ -71,10 +82,21 @@ class CompelInvocation(BaseInvocation):
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context) ti_list = []
for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1]
try:
loaded_model = context.models.load(**self.clip.text_encoder.model_dump()).model
assert isinstance(loaded_model, TextualInversionModelRaw)
ti_list.append((name, loaded_model))
except UnknownModelException:
# print(e)
# import traceback
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with ( with (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as ( ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
tokenizer, tokenizer,
ti_manager, ti_manager,
), ),
@@ -82,9 +104,8 @@ class CompelInvocation(BaseInvocation):
# Apply the LoRA after text_encoder has been moved to its target device for faster patching. # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()), ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers), ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
): ):
assert isinstance(text_encoder, CLIPTextModel)
compel = Compel( compel = Compel(
tokenizer=tokenizer, tokenizer=tokenizer,
text_encoder=text_encoder, text_encoder=text_encoder,
@@ -127,18 +148,14 @@ class SDXLPromptInvocationBase:
def run_clip_compel( def run_clip_compel(
self, self,
context: InvocationContext, context: InvocationContext,
clip_field: CLIPField, clip_field: ClipField,
prompt: str, prompt: str,
get_pooled: bool, get_pooled: bool,
lora_prefix: str, lora_prefix: str,
zero_on_empty: bool, zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
tokenizer_info = context.models.load(clip_field.tokenizer) tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
tokenizer_model = tokenizer_info.model text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
# return zero on empty # return zero on empty
if prompt == "" and zero_on_empty: if prompt == "" and zero_on_empty:
@@ -163,7 +180,7 @@ class SDXLPromptInvocationBase:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras: for lora in clip_field.loras:
lora_info = context.models.load(lora.lora) lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
lora_model = lora_info.model lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw) assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight) yield (lora_model, lora.weight)
@@ -172,10 +189,25 @@ class SDXLPromptInvocationBase:
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context) ti_list = []
for trigger in extract_ti_triggers_from_prompt(prompt):
name = trigger[1:-1]
try:
ti_model = context.models.load_by_attrs(
model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
).model
assert isinstance(ti_model, TextualInversionModelRaw)
ti_list.append((name, ti_model))
except UnknownModelException:
# print(e)
# import traceback
# print(traceback.format_exc())
logger.warning(f'trigger: "{trigger}" not found')
except ValueError:
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
with ( with (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as ( ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
tokenizer, tokenizer,
ti_manager, ti_manager,
), ),
@@ -183,10 +215,8 @@ class SDXLPromptInvocationBase:
# Apply the LoRA after text_encoder has been moved to its target device for faster patching. # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix), ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers), ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
): ):
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
text_encoder = cast(CLIPTextModel, text_encoder)
compel = Compel( compel = Compel(
tokenizer=tokenizer, tokenizer=tokenizer,
text_encoder=text_encoder, text_encoder=text_encoder,
@@ -232,7 +262,7 @@ class SDXLPromptInvocationBase:
title="SDXL Prompt", title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"], tags=["sdxl", "compel", "prompt"],
category="conditioning", category="conditioning",
version="1.1.1", version="1.0.1",
) )
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
@@ -253,8 +283,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
crop_left: int = InputField(default=0, description="") crop_left: int = InputField(default=0, description="")
target_width: int = InputField(default=1024, description="") target_width: int = InputField(default=1024, description="")
target_height: int = InputField(default=1024, description="") target_height: int = InputField(default=1024, description="")
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1") clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
@@ -325,7 +355,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
title="SDXL Refiner Prompt", title="SDXL Refiner Prompt",
tags=["sdxl", "compel", "prompt"], tags=["sdxl", "compel", "prompt"],
category="conditioning", category="conditioning",
version="1.1.1", version="1.0.1",
) )
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
@@ -340,7 +370,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
crop_top: int = InputField(default=0, description="") crop_top: int = InputField(default=0, description="")
crop_left: int = InputField(default=0, description="") crop_left: int = InputField(default=0, description="")
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic) aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection) clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
@@ -370,10 +400,10 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
@invocation_output("clip_skip_output") @invocation_output("clip_skip_output")
class CLIPSkipInvocationOutput(BaseInvocationOutput): class ClipSkipInvocationOutput(BaseInvocationOutput):
"""CLIP skip node output""" """Clip skip node output"""
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation( @invocation(
@@ -381,17 +411,17 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput):
title="CLIP Skip", title="CLIP Skip",
tags=["clipskip", "clip", "skip"], tags=["clipskip", "clip", "skip"],
category="conditioning", category="conditioning",
version="1.1.0", version="1.0.0",
) )
class CLIPSkipInvocation(BaseInvocation): class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model.""" """Skip layers in clip text_encoder model."""
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP") clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers) skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
def invoke(self, context: InvocationContext) -> CLIPSkipInvocationOutput: def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers self.clip.skipped_layers += self.skipped_layers
return CLIPSkipInvocationOutput( return ClipSkipInvocationOutput(
clip=self.clip, clip=self.clip,
) )

View File

@@ -7,8 +7,12 @@ from typing import Dict, List, Literal, Union
import cv2 import cv2
import numpy as np import numpy as np
from controlnet_aux import ( from controlnet_aux import (
CannyDetector,
ContentShuffleDetector, ContentShuffleDetector,
HEDdetector,
LeresDetector, LeresDetector,
LineartAnimeDetector,
LineartDetector,
MediapipeFaceDetector, MediapipeFaceDetector,
MidasDetector, MidasDetector,
MLSDdetector, MLSDdetector,
@@ -27,20 +31,14 @@ from invokeai.app.invocations.fields import (
Input, Input,
InputField, InputField,
OutputField, OutputField,
UIType,
WithBoard, WithBoard,
WithMetadata, WithMetadata,
) )
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
@@ -53,9 +51,15 @@ CONTROLNET_RESIZE_VALUES = Literal[
] ]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
key: str = Field(description="Model config record key for the ControlNet model")
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(description="The control image") image: ImageField = Field(description="The control image")
control_model: ModelIdentifierField = Field(description="The ControlNet model to use") control_model: ControlNetModelField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@@ -91,9 +95,7 @@ class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes""" """Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image") image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField( control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
)
control_weight: Union[float, List[float]] = InputField( control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet" default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
) )
@@ -171,13 +173,11 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Canny Processor", title="Canny Processor",
tags=["controlnet", "canny"], tags=["controlnet", "canny"],
category="controlnet", category="controlnet",
version="1.3.2", version="1.2.1",
) )
class CannyImageProcessorInvocation(ImageProcessorInvocation): class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet""" """Canny edge detection for ControlNet"""
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
low_threshold: int = InputField( low_threshold: int = InputField(
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)" default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
) )
@@ -189,14 +189,9 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
# Keep alpha channel for Canny processing to detect edges of transparent areas # Keep alpha channel for Canny processing to detect edges of transparent areas
return context.images.get_pil(self.image.image_name, "RGBA") return context.images.get_pil(self.image.image_name, "RGBA")
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image):
processed_image = get_canny_edges( canny_processor = CannyDetector()
image, processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
self.low_threshold,
self.high_threshold,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
return processed_image return processed_image
@@ -205,7 +200,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
title="HED (softedge) Processor", title="HED (softedge) Processor",
tags=["controlnet", "hed", "softedge"], tags=["controlnet", "hed", "softedge"],
category="controlnet", category="controlnet",
version="1.2.2", version="1.2.1",
) )
class HedImageProcessorInvocation(ImageProcessorInvocation): class HedImageProcessorInvocation(ImageProcessorInvocation):
"""Applies HED edge detection to image""" """Applies HED edge detection to image"""
@@ -216,9 +211,9 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image):
hed_processor = HEDProcessor() hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = hed_processor.run( processed_image = hed_processor(
image, image,
detect_resolution=self.detect_resolution, detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution, image_resolution=self.image_resolution,
@@ -234,7 +229,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
title="Lineart Processor", title="Lineart Processor",
tags=["controlnet", "lineart"], tags=["controlnet", "lineart"],
category="controlnet", category="controlnet",
version="1.2.2", version="1.2.1",
) )
class LineartImageProcessorInvocation(ImageProcessorInvocation): class LineartImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art processing to image""" """Applies line art processing to image"""
@@ -243,9 +238,9 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
coarse: bool = InputField(default=False, description="Whether to use coarse mode") coarse: bool = InputField(default=False, description="Whether to use coarse mode")
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image):
lineart_processor = LineartProcessor() lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
processed_image = lineart_processor.run( processed_image = lineart_processor(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
) )
return processed_image return processed_image
@@ -256,7 +251,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
title="Lineart Anime Processor", title="Lineart Anime Processor",
tags=["controlnet", "lineart", "anime"], tags=["controlnet", "lineart", "anime"],
category="controlnet", category="controlnet",
version="1.2.2", version="1.2.1",
) )
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art anime processing to image""" """Applies line art anime processing to image"""
@@ -264,9 +259,9 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image):
processor = LineartAnimeProcessor() processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = processor.run( processed_image = processor(
image, image,
detect_resolution=self.detect_resolution, detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution, image_resolution=self.image_resolution,
@@ -279,15 +274,13 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
title="Midas Depth Processor", title="Midas Depth Processor",
tags=["controlnet", "midas"], tags=["controlnet", "midas"],
category="controlnet", category="controlnet",
version="1.2.3", version="1.2.1",
) )
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image""" """Applies Midas depth processing to image"""
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)") a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`") bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
# depth_and_normal not supported in controlnet_aux v0.0.3 # depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode") # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
@@ -297,8 +290,6 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
image, image,
a=np.pi * self.a_mult, a=np.pi * self.a_mult,
bg_th=self.bg_th, bg_th=self.bg_th,
image_resolution=self.image_resolution,
detect_resolution=self.detect_resolution,
# dept_and_normal not supported in controlnet_aux v0.0.3 # dept_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal=self.depth_and_normal, # depth_and_normal=self.depth_and_normal,
) )
@@ -310,7 +301,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
title="Normal BAE Processor", title="Normal BAE Processor",
tags=["controlnet"], tags=["controlnet"],
category="controlnet", category="controlnet",
version="1.2.2", version="1.2.1",
) )
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image""" """Applies NormalBae processing to image"""
@@ -327,7 +318,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
@invocation( @invocation(
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.2" "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.1"
) )
class MlsdImageProcessorInvocation(ImageProcessorInvocation): class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image""" """Applies MLSD processing to image"""
@@ -350,7 +341,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
@invocation( @invocation(
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.2" "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.1"
) )
class PidiImageProcessorInvocation(ImageProcessorInvocation): class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image""" """Applies PIDI processing to image"""
@@ -377,7 +368,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
title="Content Shuffle Processor", title="Content Shuffle Processor",
tags=["controlnet", "contentshuffle"], tags=["controlnet", "contentshuffle"],
category="controlnet", category="controlnet",
version="1.2.2", version="1.2.1",
) )
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image""" """Applies content shuffle processing to image"""
@@ -407,7 +398,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
title="Zoe (Depth) Processor", title="Zoe (Depth) Processor",
tags=["controlnet", "zoe", "depth"], tags=["controlnet", "zoe", "depth"],
category="controlnet", category="controlnet",
version="1.2.2", version="1.2.1",
) )
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image""" """Applies Zoe depth processing to image"""
@@ -423,25 +414,17 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
title="Mediapipe Face Processor", title="Mediapipe Face Processor",
tags=["controlnet", "mediapipe", "face"], tags=["controlnet", "mediapipe", "face"],
category="controlnet", category="controlnet",
version="1.2.3", version="1.2.1",
) )
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image""" """Applies mediapipe face processing to image"""
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect") max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection") min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image): def run_processor(self, image):
mediapipe_face_processor = MediapipeFaceDetector() mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor( processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
image,
max_faces=self.max_faces,
min_confidence=self.min_confidence,
image_resolution=self.image_resolution,
detect_resolution=self.detect_resolution,
)
return processed_image return processed_image
@@ -450,7 +433,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
title="Leres (Depth) Processor", title="Leres (Depth) Processor",
tags=["controlnet", "leres", "depth"], tags=["controlnet", "leres", "depth"],
category="controlnet", category="controlnet",
version="1.2.2", version="1.2.1",
) )
class LeresImageProcessorInvocation(ImageProcessorInvocation): class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image""" """Applies leres processing to image"""
@@ -479,7 +462,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
title="Tile Resample Processor", title="Tile Resample Processor",
tags=["controlnet", "tile"], tags=["controlnet", "tile"],
category="controlnet", category="controlnet",
version="1.2.2", version="1.2.1",
) )
class TileResamplerProcessorInvocation(ImageProcessorInvocation): class TileResamplerProcessorInvocation(ImageProcessorInvocation):
"""Tile resampler processor""" """Tile resampler processor"""
@@ -519,23 +502,18 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
title="Segment Anything Processor", title="Segment Anything Processor",
tags=["controlnet", "segmentanything"], tags=["controlnet", "segmentanything"],
category="controlnet", category="controlnet",
version="1.2.3", version="1.2.1",
) )
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image""" """Applies segment anything processing to image"""
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image): def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained( segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
"ybelkada/segment-anything", subfolder="checkpoints" "ybelkada/segment-anything", subfolder="checkpoints"
) )
np_img = np.array(image, dtype=np.uint8) np_img = np.array(image, dtype=np.uint8)
processed_image = segment_anything_processor( processed_image = segment_anything_processor(np_img)
np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
)
return processed_image return processed_image
@@ -566,7 +544,7 @@ class SamDetectorReproducibleColors(SamDetector):
title="Color Map Processor", title="Color Map Processor",
tags=["controlnet"], tags=["controlnet"],
category="controlnet", category="controlnet",
version="1.2.2", version="1.2.1",
) )
class ColorMapImageProcessorInvocation(ImageProcessorInvocation): class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a color map from the provided image""" """Generates a color map from the provided image"""
@@ -598,7 +576,7 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
title="Depth Anything Processor", title="Depth Anything Processor",
tags=["controlnet", "depth", "depth anything"], tags=["controlnet", "depth", "depth anything"],
category="controlnet", category="controlnet",
version="1.1.1", version="1.0.0",
) )
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation): class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a depth map based on the Depth Anything algorithm""" """Generates a depth map based on the Depth Anything algorithm"""
@@ -607,12 +585,13 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
default="small", description="The size of the depth model to use" default="small", description="The size of the depth model to use"
) )
resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res) resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
offload: bool = InputField(default=False)
def run_processor(self, image: Image.Image): def run_processor(self, image: Image.Image):
depth_anything_detector = DepthAnythingDetector() depth_anything_detector = DepthAnythingDetector()
depth_anything_detector.load_model(model_size=self.model_size) depth_anything_detector.load_model(model_size=self.model_size)
processed_image = depth_anything_detector(image=image, resolution=self.resolution) processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
return processed_image return processed_image
@@ -621,7 +600,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
title="DW Openpose Image Processor", title="DW Openpose Image Processor",
tags=["controlnet", "dwpose", "openpose"], tags=["controlnet", "dwpose", "openpose"],
category="controlnet", category="controlnet",
version="1.1.0", version="1.0.0",
) )
class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation): class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Generates an openpose pose from an image using DWPose""" """Generates an openpose pose from an image using DWPose"""

View File

@@ -13,7 +13,7 @@ from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata from .fields import InputField, WithBoard, WithMetadata
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.3.1") @invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.1")
class CvInpaintInvocation(BaseInvocation, WithMetadata, WithBoard): class CvInpaintInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Simple inpaint using opencv.""" """Simple inpaint using opencv."""

View File

@@ -435,7 +435,7 @@ def get_faces_list(
return all_faces return all_faces
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.2") @invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.1")
class FaceOffInvocation(BaseInvocation, WithMetadata): class FaceOffInvocation(BaseInvocation, WithMetadata):
"""Bound, extract, and mask a face from an image using MediaPipe detection""" """Bound, extract, and mask a face from an image using MediaPipe detection"""
@@ -514,7 +514,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
return output return output
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.2") @invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.1")
class FaceMaskInvocation(BaseInvocation, WithMetadata): class FaceMaskInvocation(BaseInvocation, WithMetadata):
"""Face mask creation using mediapipe face detection""" """Face mask creation using mediapipe face detection"""
@@ -617,7 +617,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
@invocation( @invocation(
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.2" "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.1"
) )
class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard): class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools.""" """Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""

View File

@@ -39,15 +39,13 @@ class UIType(str, Enum, metaclass=MetaEnum):
""" """
# region Model Field Types # region Model Field Types
MainModel = "MainModelField"
SDXLMainModel = "SDXLMainModelField" SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField" SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField" ONNXModel = "ONNXModelField"
VAEModel = "VAEModelField" VaeModel = "VAEModelField"
LoRAModel = "LoRAModelField" LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField" ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField" IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
# endregion # endregion
# region Misc Field Types # region Misc Field Types
@@ -88,6 +86,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic" IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic" LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
StringPolymorphic = "DEPRECATED_StringPolymorphic" StringPolymorphic = "DEPRECATED_StringPolymorphic"
MainModel = "DEPRECATED_MainModel"
UNet = "DEPRECATED_UNet" UNet = "DEPRECATED_UNet"
Vae = "DEPRECATED_Vae" Vae = "DEPRECATED_Vae"
CLIP = "DEPRECATED_CLIP" CLIP = "DEPRECATED_CLIP"
@@ -200,7 +199,6 @@ class DenoiseMaskField(BaseModel):
mask_name: str = Field(description="The name of the mask image") mask_name: str = Field(description="The name of the mask image")
masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
gradient: bool = Field(default=False, description="Used for gradient inpainting")
class LatentsField(BaseModel): class LatentsField(BaseModel):
@@ -229,7 +227,7 @@ class ConditioningField(BaseModel):
# endregion # endregion
class MetadataField(RootModel[dict[str, Any]]): class MetadataField(RootModel):
""" """
Pydantic model for metadata with custom root of type dict[str, Any]. Pydantic model for metadata with custom root of type dict[str, Any].
Metadata is stored without a strict schema. Metadata is stored without a strict schema.

View File

@@ -22,7 +22,11 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker from invokeai.backend.image_util.safety_checker import SafetyChecker
from .baseinvocation import BaseInvocation, Classification, invocation from .baseinvocation import (
BaseInvocation,
Classification,
invocation,
)
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1") @invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1")
@@ -49,7 +53,7 @@ class ShowImageInvocation(BaseInvocation):
title="Blank Image", title="Blank Image",
tags=["image"], tags=["image"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard): class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Creates a blank image and forwards it to the pipeline""" """Creates a blank image and forwards it to the pipeline"""
@@ -72,7 +76,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Crop Image", title="Crop Image",
tags=["image", "crop"], tags=["image", "crop"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImageCropInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageCropInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Crops an image to a specified box. The box can be outside of the image.""" """Crops an image to a specified box. The box can be outside of the image."""
@@ -143,7 +147,7 @@ class CenterPadCropInvocation(BaseInvocation):
title="Paste Image", title="Paste Image",
tags=["image", "paste"], tags=["image", "paste"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard): class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Pastes an image into another image.""" """Pastes an image into another image."""
@@ -190,7 +194,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Mask from Alpha", title="Mask from Alpha",
tags=["image", "mask"], tags=["image", "mask"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard): class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Extracts the alpha channel of an image as a mask.""" """Extracts the alpha channel of an image as a mask."""
@@ -215,7 +219,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Multiply Images", title="Multiply Images",
tags=["image", "multiply"], tags=["image", "multiply"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImageMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Multiplies two images together using `PIL.ImageChops.multiply()`.""" """Multiplies two images together using `PIL.ImageChops.multiply()`."""
@@ -242,7 +246,7 @@ IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
title="Extract Image Channel", title="Extract Image Channel",
tags=["image", "channel"], tags=["image", "channel"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImageChannelInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageChannelInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Gets a channel from an image.""" """Gets a channel from an image."""
@@ -265,7 +269,7 @@ class ImageChannelInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Convert Image Mode", title="Convert Image Mode",
tags=["image", "convert"], tags=["image", "convert"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImageConvertInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageConvertInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Converts an image to a different mode.""" """Converts an image to a different mode."""
@@ -288,7 +292,7 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Blur Image", title="Blur Image",
tags=["image", "blur"], tags=["image", "blur"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Blurs an image""" """Blurs an image"""
@@ -316,7 +320,7 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Unsharp Mask", title="Unsharp Mask",
tags=["image", "unsharp_mask"], tags=["image", "unsharp_mask"],
category="image", category="image",
version="1.2.2", version="1.2.1",
classification=Classification.Beta, classification=Classification.Beta,
) )
class UnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard): class UnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -385,7 +389,7 @@ PIL_RESAMPLING_MAP = {
title="Resize Image", title="Resize Image",
tags=["image", "resize"], tags=["image", "resize"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImageResizeInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageResizeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Resizes an image to specific dimensions""" """Resizes an image to specific dimensions"""
@@ -415,7 +419,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Scale Image", title="Scale Image",
tags=["image", "scale"], tags=["image", "scale"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImageScaleInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageScaleInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Scales an image by a factor""" """Scales an image by a factor"""
@@ -450,7 +454,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Lerp Image", title="Lerp Image",
tags=["image", "lerp"], tags=["image", "lerp"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImageLerpInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Linear interpolation of all pixels of an image""" """Linear interpolation of all pixels of an image"""
@@ -477,7 +481,7 @@ class ImageLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Inverse Lerp Image", title="Inverse Lerp Image",
tags=["image", "ilerp"], tags=["image", "ilerp"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Inverse linear interpolation of all pixels of an image""" """Inverse linear interpolation of all pixels of an image"""
@@ -504,7 +508,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Blur NSFW Image", title="Blur NSFW Image",
tags=["image", "nsfw"], tags=["image", "nsfw"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add blur to NSFW-flagged images""" """Add blur to NSFW-flagged images"""
@@ -539,7 +543,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Add Invisible Watermark", title="Add Invisible Watermark",
tags=["image", "watermark"], tags=["image", "watermark"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add an invisible watermark to an image""" """Add an invisible watermark to an image"""
@@ -560,7 +564,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Mask Edge", title="Mask Edge",
tags=["image", "mask", "inpaint"], tags=["image", "mask", "inpaint"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard): class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Applies an edge mask to an image""" """Applies an edge mask to an image"""
@@ -599,7 +603,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Combine Masks", title="Combine Masks",
tags=["image", "mask", "multiply"], tags=["image", "mask", "multiply"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard): class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.""" """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
@@ -623,7 +627,7 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Color Correct", title="Color Correct",
tags=["image", "color"], tags=["image", "color"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ColorCorrectInvocation(BaseInvocation, WithMetadata, WithBoard): class ColorCorrectInvocation(BaseInvocation, WithMetadata, WithBoard):
""" """
@@ -727,7 +731,7 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Adjust Image Hue", title="Adjust Image Hue",
tags=["image", "hue"], tags=["image", "hue"],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Adjusts the Hue of an image.""" """Adjusts the Hue of an image."""
@@ -816,7 +820,7 @@ CHANNEL_FORMATS = {
"value", "value",
], ],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add or subtract a value from a specific color channel of an image.""" """Add or subtract a value from a specific color channel of an image."""
@@ -872,7 +876,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
"value", "value",
], ],
category="image", category="image",
version="1.2.2", version="1.2.1",
) )
class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Scale a specific color channel of an image.""" """Scale a specific color channel of an image."""
@@ -916,7 +920,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Save Image", title="Save Image",
tags=["primitives", "image"], tags=["primitives", "image"],
category="primitives", category="primitives",
version="1.2.2", version="1.2.1",
use_cache=False, use_cache=False,
) )
class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard): class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -930,93 +934,3 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=image) image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)
@invocation(
"canvas_paste_back",
title="Canvas Paste Back",
tags=["image", "combine"],
category="image",
version="1.0.0",
)
class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Combines two images by using the mask provided. Intended for use on the Unified Canvas."""
source_image: ImageField = InputField(description="The source image")
target_image: ImageField = InputField(default=None, description="The target image")
mask: ImageField = InputField(
description="The mask to use when pasting",
)
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
def _prepare_mask(self, mask: Image.Image) -> Image.Image:
mask_array = numpy.array(mask)
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
dilated_mask = Image.fromarray(dilated_mask_array)
if self.mask_blur > 0:
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
return ImageOps.invert(mask.convert("L"))
def invoke(self, context: InvocationContext) -> ImageOutput:
source_image = context.images.get_pil(self.source_image.image_name)
target_image = context.images.get_pil(self.target_image.image_name)
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
source_image.paste(target_image, (0, 0), mask)
image_dto = context.images.save(image=source_image)
return ImageOutput.build(image_dto)
@invocation(
"mask_from_id",
title="Mask from ID",
tags=["image", "mask", "id"],
category="image",
version="1.0.0",
)
class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generate a mask for a particular color in an ID Map"""
image: ImageField = InputField(description="The image to create the mask from")
color: ColorField = InputField(description="ID color to mask")
threshold: int = InputField(default=100, description="Threshold for color detection")
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
def rgba_to_hex(self, rgba_color: tuple[int, int, int, int]):
r, g, b, a = rgba_color
hex_code = "#{:02X}{:02X}{:02X}{:02X}".format(r, g, b, int(a * 255))
return hex_code
def id_to_mask(self, id_mask: Image.Image, color: tuple[int, int, int, int], threshold: int = 100):
if id_mask.mode != "RGB":
id_mask = id_mask.convert("RGB")
# Can directly just use the tuple but I'll leave this rgba_to_hex here
# incase anyone prefers using hex codes directly instead of the color picker
hex_color_str = self.rgba_to_hex(color)
rgb_color = numpy.array([int(hex_color_str[i : i + 2], 16) for i in (1, 3, 5)])
# Maybe there's a faster way to calculate this distance but I can't think of any right now.
color_distance = numpy.linalg.norm(id_mask - rgb_color, axis=-1)
# Create a mask based on the threshold and the distance calculated above
binary_mask = (color_distance < threshold).astype(numpy.uint8) * 255
# Convert the mask back to PIL
binary_mask_pil = Image.fromarray(binary_mask)
return binary_mask_pil
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
mask = self.id_to_mask(image, self.color.tuple(), self.threshold)
if self.invert:
mask = ImageOps.invert(mask)
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)

View File

@@ -9,7 +9,6 @@ from PIL import Image, ImageOps
from invokeai.app.invocations.fields import ColorField, ImageField from invokeai.app.invocations.fields import ColorField, ImageField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.app.util.misc import SEED_MAX from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.lama import LaMA
@@ -121,7 +120,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
return si return si
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2") @invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard): class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Infills transparent areas of an image with a solid color""" """Infills transparent areas of an image with a solid color"""
@@ -144,7 +143,7 @@ class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.3") @invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard): class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Infills transparent areas of an image with tiles of the image""" """Infills transparent areas of an image with tiles of the image"""
@@ -169,7 +168,7 @@ class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
@invocation( @invocation(
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2" "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1"
) )
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard): class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Infills transparent areas of an image using the PatchMatch algorithm""" """Infills transparent areas of an image using the PatchMatch algorithm"""
@@ -209,7 +208,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2") @invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard): class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Infills transparent areas of an image using the LaMa model""" """Infills transparent areas of an image using the LaMa model"""
@@ -218,13 +217,6 @@ class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name) image = context.images.get_pil(self.image.image_name)
# Downloads the LaMa model if it doesn't already exist
download_with_progress_bar(
name="LaMa Inpainting Model",
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
dest_path=context.config.get().models_path / "core/misc/lama/lama.pt",
)
infilled = infill_lama(image.copy()) infilled = infill_lama(image.copy())
image_dto = context.images.save(image=infilled) image_dto = context.images.save(image=infilled)
@@ -232,7 +224,7 @@ class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2") @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard): class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Infills transparent areas of an image using OpenCV Inpainting""" """Infills transparent areas of an image using OpenCV Inpainting"""

View File

@@ -10,18 +10,26 @@ from invokeai.app.invocations.baseinvocation import (
invocation, invocation,
invocation_output, invocation_output,
) )
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType from invokeai.backend.model_manager.config import BaseModelType, ModelType
# LS: Consider moving these two classes into model.py
class IPAdapterModelField(BaseModel):
key: str = Field(description="Key to the IP-Adapter model")
class CLIPVisionModelField(BaseModel):
key: str = Field(description="Key to the CLIP Vision image encoder model")
class IPAdapterField(BaseModel): class IPAdapterField(BaseModel):
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).") image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.") ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.") image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)" default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
@@ -48,18 +56,14 @@ class IPAdapterOutput(BaseInvocationOutput):
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter") ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2") @invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.2")
class IPAdapterInvocation(BaseInvocation): class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes.""" """Collects IP-Adapter info to pass to other nodes."""
# Inputs # Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).") image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelIdentifierField = InputField( ip_adapter_model: IPAdapterModelField = InputField(
description="The IP-Adapter model.", description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
title="IP-Adapter Model",
input=Input.Direct,
ui_order=-1,
ui_type=UIType.IPAdapterModel,
) )
weight: Union[float, List[float]] = InputField( weight: Union[float, List[float]] = InputField(
@@ -86,35 +90,20 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput: def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key) ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, IPAdapterConfig)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name) image_encoder_models = context.models.search_by_attrs(
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
)
assert len(image_encoder_models) == 1
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
return IPAdapterOutput( return IPAdapterOutput(
ip_adapter=IPAdapterField( ip_adapter=IPAdapterField(
image=self.image, image=self.image,
ip_adapter_model=self.ip_adapter_model, ip_adapter_model=self.ip_adapter_model,
image_encoder_model=ModelIdentifierField.from_config(image_encoder_model), image_encoder_model=image_encoder_model,
weight=self.weight, weight=self.weight,
begin_step_percent=self.begin_step_percent, begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent, end_step_percent=self.end_step_percent,
), ),
) )
def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
found = False
while not found:
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
found = len(image_encoder_models) > 0
if not found:
context.logger.warning(
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed."
)
context.logger.warning("Downloading and installing now. This may take a while.")
installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
installer.wait_for_job(job, timeout=600) # wait up to 10 minutes - then raise a TimeoutException
assert len(image_encoder_models) == 1
return image_encoder_models[0]

View File

@@ -23,10 +23,9 @@ from diffusers.models.attention_processor import (
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers import DPMSolverSDEScheduler from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from PIL import Image, ImageFilter from PIL import Image
from pydantic import field_validator from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
@@ -66,6 +65,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
T2IAdapterData, T2IAdapterData,
image_resized_to_grid_as_tensor, image_resized_to_grid_as_tensor,
) )
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device from ...backend.util.devices import choose_precision, choose_torch_device
from .baseinvocation import ( from .baseinvocation import (
@@ -75,7 +75,7 @@ from .baseinvocation import (
invocation_output, invocation_output,
) )
from .controlnet_image_processors import ControlField from .controlnet_image_processors import ControlField
from .model import ModelIdentifierField, UNetField, VAEField from .model import ModelInfo, UNetField, VaeField
if choose_torch_device() == torch.device("mps"): if choose_torch_device() == torch.device("mps"):
from torch import mps from torch import mps
@@ -113,12 +113,12 @@ class SchedulerInvocation(BaseInvocation):
title="Create Denoise Mask", title="Create Denoise Mask",
tags=["mask", "denoise"], tags=["mask", "denoise"],
category="latents", category="latents",
version="1.0.2", version="1.0.1",
) )
class CreateDenoiseMaskInvocation(BaseInvocation): class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run.""" """Creates mask for denoising model run."""
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0) vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1) image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2) mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3) tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
@@ -128,7 +128,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
ui_order=4, ui_order=4,
) )
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor: def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor:
if mask_image.mode != "L": if mask_image.mode != "L":
mask_image = mask_image.convert("L") mask_image = mask_image.convert("L")
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
@@ -153,7 +153,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
) )
if image_tensor is not None: if image_tensor is not None:
vae_info = context.models.load(self.vae.vae) vae_info = context.models.load(**self.vae.vae.model_dump())
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
@@ -169,87 +169,17 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
return DenoiseMaskOutput.build( return DenoiseMaskOutput.build(
mask_name=mask_name, mask_name=mask_name,
masked_latents_name=masked_latents_name, masked_latents_name=masked_latents_name,
gradient=False,
)
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)
@invocation(
"create_gradient_mask",
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.0.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
edge_radius: int = InputField(
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
)
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
minimum_denoise: float = InputField(
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.edge_radius > 0:
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever the blur_tensor is less than fully masked, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
else:
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
# compute a [0, 1] mask from the blur_tensor
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
expanded_image_dto = context.images.save(expanded_mask_image)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
) )
def get_scheduler( def get_scheduler(
context: InvocationContext, context: InvocationContext,
scheduler_info: ModelIdentifierField, scheduler_info: ModelInfo,
scheduler_name: str, scheduler_name: str,
seed: int, seed: int,
) -> Scheduler: ) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info) orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
with orig_scheduler_info as orig_scheduler: with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config scheduler_config = orig_scheduler.config
@@ -279,7 +209,7 @@ def get_scheduler(
title="Denoise Latents", title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents", category="latents",
version="1.5.3", version="1.5.2",
) )
class DenoiseLatentsInvocation(BaseInvocation): class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images""" """Denoises noisy latents to decodable images"""
@@ -374,6 +304,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
) -> ConditioningData: ) -> ConditioningData:
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name) positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = c.extra_conditioning
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name) negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
@@ -383,6 +314,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
text_embeddings=c, text_embeddings=c,
guidance_scale=self.cfg_scale, guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier, guidance_rescale_multiplier=self.cfg_rescale_multiplier,
extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0, # threshold,
warmup=0.2, # warmup,
h_symmetry_time_pct=None, # h_symmetry_time_pct,
v_symmetry_time_pct=None, # v_symmetry_time_pct,
),
) )
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
@@ -455,7 +393,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# and if weight is None, populate with default 1.0? # and if weight is None, populate with default 1.0?
controlnet_data = [] controlnet_data = []
for control_info in control_list: for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model)) control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
# control_models.append(control_model) # control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
@@ -517,10 +455,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
conditioning_data.ip_adapter_conditioning = [] conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter: for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(single_ip_adapter.ip_adapter_model) context.models.load(key=single_ip_adapter.ip_adapter_model.key)
) )
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model) image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list): if not isinstance(single_ipa_image_fields, list):
@@ -531,7 +470,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model: with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel. # Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds( image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
single_ipa_images, image_encoder_model single_ipa_images, image_encoder_model
@@ -571,8 +509,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = [] t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter: for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key) t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model) t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
image = context.images.get_pil(t2i_adapter_field.image.image_name) image = context.images.get_pil(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -668,18 +606,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_inpaint_mask( def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor self, context: InvocationContext, latents: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]: ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
if self.denoise_mask is None: if self.denoise_mask is None:
return None, None, False return None, None
mask = context.tensors.load(self.denoise_mask.mask_name) mask = context.tensors.load(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
if self.denoise_mask.masked_latents_name is not None: if self.denoise_mask.masked_latents_name is not None:
masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name) masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
else: else:
masked_latents = torch.where(mask < 0.5, 0.0, latents) masked_latents = None
return 1 - mask, masked_latents, self.denoise_mask.gradient return 1 - mask, masked_latents
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -706,7 +644,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if seed is None: if seed is None:
seed = 0 seed = 0
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents) mask, masked_latents = self.prep_inpaint_mask(context, latents)
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets, # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate. # below. Investigate whether this is appropriate.
@@ -725,13 +663,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.models.load(lora.lora) lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight) yield (lora_info.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.models.load(self.unet.unet) unet_info = context.models.load(**self.unet.unet.model_dump())
assert isinstance(unet_info.model, UNet2DConditionModel) assert isinstance(unet_info.model, UNet2DConditionModel)
with ( with (
ExitStack() as exit_stack, ExitStack() as exit_stack,
@@ -784,7 +721,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end, denoising_end=self.denoising_end,
) )
result_latents = pipeline.latents_from_embeddings( (
result_latents,
result_attention_map_saver,
) = pipeline.latents_from_embeddings(
latents=latents, latents=latents,
timesteps=timesteps, timesteps=timesteps,
init_timestep=init_timestep, init_timestep=init_timestep,
@@ -792,7 +732,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed=seed, seed=seed,
mask=mask, mask=mask,
masked_latents=masked_latents, masked_latents=masked_latents,
gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=controlnet_data, control_data=controlnet_data,
@@ -816,7 +755,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
title="Latents to Image", title="Latents to Image",
tags=["latents", "image", "vae", "l2i"], tags=["latents", "image", "vae", "l2i"],
category="latents", category="latents",
version="1.2.2", version="1.2.1",
) )
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents.""" """Generates an image from latents."""
@@ -825,7 +764,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
description=FieldDescriptions.latents, description=FieldDescriptions.latents,
input=Input.Connection, input=Input.Connection,
) )
vae: VAEField = InputField( vae: VaeField = InputField(
description=FieldDescriptions.vae, description=FieldDescriptions.vae,
input=Input.Connection, input=Input.Connection,
) )
@@ -836,15 +775,15 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name) latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae) vae_info = context.models.load(**self.vae.vae.model_dump())
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module) assert isinstance(vae, torch.nn.Module)
latents = latents.to(vae.device) latents = latents.to(vae.device)
if self.fp32: if self.fp32:
vae.to(dtype=torch.float32) vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance( use_torch_2_0_or_xformers = isinstance(
vae.decoder.mid_block.attentions[0].processor, vae.decoder.mid_block.attentions[0].processor,
( (
AttnProcessor2_0, AttnProcessor2_0,
@@ -866,7 +805,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.to(dtype=torch.float16) vae.to(dtype=torch.float16)
latents = latents.half() latents = latents.half()
if self.tiled or context.config.get().force_tiled_decode: if self.tiled or context.config.get().tiled_decode:
vae.enable_tiling() vae.enable_tiling()
else: else:
vae.disable_tiling() vae.disable_tiling()
@@ -903,7 +842,7 @@ LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic",
title="Resize Latents", title="Resize Latents",
tags=["latents", "resize"], tags=["latents", "resize"],
category="latents", category="latents",
version="1.0.2", version="1.0.1",
) )
class ResizeLatentsInvocation(BaseInvocation): class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.""" """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
@@ -953,7 +892,7 @@ class ResizeLatentsInvocation(BaseInvocation):
title="Scale Latents", title="Scale Latents",
tags=["latents", "resize"], tags=["latents", "resize"],
category="latents", category="latents",
version="1.0.2", version="1.0.1",
) )
class ScaleLatentsInvocation(BaseInvocation): class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor.""" """Scales latents by a given factor."""
@@ -995,7 +934,7 @@ class ScaleLatentsInvocation(BaseInvocation):
title="Image to Latents", title="Image to Latents",
tags=["latents", "image", "vae", "i2l"], tags=["latents", "image", "vae", "i2l"],
category="latents", category="latents",
version="1.0.2", version="1.0.1",
) )
class ImageToLatentsInvocation(BaseInvocation): class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents.""" """Encodes an image into latents."""
@@ -1003,7 +942,7 @@ class ImageToLatentsInvocation(BaseInvocation):
image: ImageField = InputField( image: ImageField = InputField(
description="The image to encode", description="The image to encode",
) )
vae: VAEField = InputField( vae: VaeField = InputField(
description=FieldDescriptions.vae, description=FieldDescriptions.vae,
input=Input.Connection, input=Input.Connection,
) )
@@ -1018,7 +957,7 @@ class ImageToLatentsInvocation(BaseInvocation):
if upcast: if upcast:
vae.to(dtype=torch.float32) vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance( use_torch_2_0_or_xformers = isinstance(
vae.decoder.mid_block.attentions[0].processor, vae.decoder.mid_block.attentions[0].processor,
( (
AttnProcessor2_0, AttnProcessor2_0,
@@ -1059,7 +998,7 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name) image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae) vae_info = context.models.load(**self.vae.vae.model_dump())
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3: if image_tensor.dim() == 3:
@@ -1094,7 +1033,7 @@ class ImageToLatentsInvocation(BaseInvocation):
title="Blend Latents", title="Blend Latents",
tags=["latents", "blend"], tags=["latents", "blend"],
category="latents", category="latents",
version="1.0.2", version="1.0.1",
) )
class BlendLatentsInvocation(BaseInvocation): class BlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha. Latents must have same size.""" """Blend two latents using a given alpha. Latents must have same size."""
@@ -1185,7 +1124,7 @@ class BlendLatentsInvocation(BaseInvocation):
title="Crop Latents", title="Crop Latents",
tags=["latents", "crop"], tags=["latents", "crop"],
category="latents", category="latents",
version="1.0.2", version="1.0.1",
) )
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`. # TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
# Currently, if the class names conflict then 'GET /openapi.json' fails. # Currently, if the class names conflict then 'GET /openapi.json' fails.
@@ -1246,7 +1185,7 @@ class IdealSizeOutput(BaseInvocationOutput):
"ideal_size", "ideal_size",
title="Ideal Size", title="Ideal Size",
tags=["latents", "math", "ideal_size"], tags=["latents", "math", "ideal_size"],
version="1.0.3", version="1.0.2",
) )
class IdealSizeInvocation(BaseInvocation): class IdealSizeInvocation(BaseInvocation):
"""Calculates the ideal size for generation to avoid duplication""" """Calculates the ideal size for generation to avoid duplication"""

View File

@@ -12,7 +12,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from .baseinvocation import BaseInvocation, invocation from .baseinvocation import BaseInvocation, invocation
@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.1") @invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")
class AddInvocation(BaseInvocation): class AddInvocation(BaseInvocation):
"""Adds two numbers""" """Adds two numbers"""
@@ -23,7 +23,7 @@ class AddInvocation(BaseInvocation):
return IntegerOutput(value=self.a + self.b) return IntegerOutput(value=self.a + self.b)
@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math", version="1.0.1") @invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math", version="1.0.0")
class SubtractInvocation(BaseInvocation): class SubtractInvocation(BaseInvocation):
"""Subtracts two numbers""" """Subtracts two numbers"""
@@ -34,7 +34,7 @@ class SubtractInvocation(BaseInvocation):
return IntegerOutput(value=self.a - self.b) return IntegerOutput(value=self.a - self.b)
@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math", version="1.0.1") @invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math", version="1.0.0")
class MultiplyInvocation(BaseInvocation): class MultiplyInvocation(BaseInvocation):
"""Multiplies two numbers""" """Multiplies two numbers"""
@@ -45,7 +45,7 @@ class MultiplyInvocation(BaseInvocation):
return IntegerOutput(value=self.a * self.b) return IntegerOutput(value=self.a * self.b)
@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math", version="1.0.1") @invocation("div", title="Divide Integers", tags=["math", "divide"], category="math", version="1.0.0")
class DivideInvocation(BaseInvocation): class DivideInvocation(BaseInvocation):
"""Divides two numbers""" """Divides two numbers"""
@@ -61,7 +61,7 @@ class DivideInvocation(BaseInvocation):
title="Random Integer", title="Random Integer",
tags=["math", "random"], tags=["math", "random"],
category="math", category="math",
version="1.0.1", version="1.0.0",
use_cache=False, use_cache=False,
) )
class RandomIntInvocation(BaseInvocation): class RandomIntInvocation(BaseInvocation):
@@ -100,7 +100,7 @@ class RandomFloatInvocation(BaseInvocation):
title="Float To Integer", title="Float To Integer",
tags=["math", "round", "integer", "float", "convert"], tags=["math", "round", "integer", "float", "convert"],
category="math", category="math",
version="1.0.1", version="1.0.0",
) )
class FloatToIntegerInvocation(BaseInvocation): class FloatToIntegerInvocation(BaseInvocation):
"""Rounds a float number to (a multiple of) an integer.""" """Rounds a float number to (a multiple of) an integer."""
@@ -122,7 +122,7 @@ class FloatToIntegerInvocation(BaseInvocation):
return IntegerOutput(value=int(self.value / self.multiple) * self.multiple) return IntegerOutput(value=int(self.value / self.multiple) * self.multiple)
@invocation("round_float", title="Round Float", tags=["math", "round"], category="math", version="1.0.1") @invocation("round_float", title="Round Float", tags=["math", "round"], category="math", version="1.0.0")
class RoundInvocation(BaseInvocation): class RoundInvocation(BaseInvocation):
"""Rounds a float to a specified number of decimal places.""" """Rounds a float to a specified number of decimal places."""
@@ -176,7 +176,7 @@ INTEGER_OPERATIONS_LABELS = {
"max", "max",
], ],
category="math", category="math",
version="1.0.1", version="1.0.0",
) )
class IntegerMathInvocation(BaseInvocation): class IntegerMathInvocation(BaseInvocation):
"""Performs integer math.""" """Performs integer math."""
@@ -250,7 +250,7 @@ FLOAT_OPERATIONS_LABELS = {
title="Float Math", title="Float Math",
tags=["math", "float", "add", "subtract", "multiply", "divide", "power", "root", "absolute value", "min", "max"], tags=["math", "float", "add", "subtract", "multiply", "divide", "power", "root", "absolute value", "min", "max"],
category="math", category="math",
version="1.0.1", version="1.0.0",
) )
class FloatMathInvocation(BaseInvocation): class FloatMathInvocation(BaseInvocation):
"""Performs floating point math.""" """Performs floating point math."""

View File

@@ -8,10 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation, invocation,
invocation_output, invocation_output,
) )
from invokeai.app.invocations.controlnet_image_processors import ( from invokeai.app.invocations.controlnet_image_processors import ControlField
CONTROLNET_MODE_VALUES,
CONTROLNET_RESIZE_VALUES,
)
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
FieldDescriptions, FieldDescriptions,
ImageField, ImageField,
@@ -20,7 +17,9 @@ from invokeai.app.invocations.fields import (
OutputField, OutputField,
UIType, UIType,
) )
from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from ...version import __version__ from ...version import __version__
@@ -34,7 +33,7 @@ class MetadataItemField(BaseModel):
class LoRAMetadataField(BaseModel): class LoRAMetadataField(BaseModel):
"""LoRA Metadata Field""" """LoRA Metadata Field"""
model: ModelIdentifierField = Field(description=FieldDescriptions.lora_model) lora: LoRAModelField = Field(description=FieldDescriptions.lora_model)
weight: float = Field(description=FieldDescriptions.lora_weight) weight: float = Field(description=FieldDescriptions.lora_weight)
@@ -42,41 +41,16 @@ class IPAdapterMetadataField(BaseModel):
"""IP Adapter Field, minus the CLIP Vision Encoder model""" """IP Adapter Field, minus the CLIP Vision Encoder model"""
image: ImageField = Field(description="The IP-Adapter image prompt.") image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.") ip_adapter_model: IPAdapterModelField = Field(
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter") description="The IP-Adapter model.",
)
weight: Union[float, list[float]] = Field(
description="The weight given to the IP-Adapter",
)
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)") begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)") end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
class T2IAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The control image.")
processed_image: Optional[ImageField] = Field(default=None, description="The control image, after processing.")
t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
class ControlNetMetadataField(BaseModel):
image: ImageField = Field(description="The control image")
processed_image: Optional[ImageField] = Field(default=None, description="The control image, after processing.")
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@invocation_output("metadata_item_output") @invocation_output("metadata_item_output")
class MetadataItemOutput(BaseInvocationOutput): class MetadataItemOutput(BaseInvocationOutput):
"""Metadata Item Output""" """Metadata Item Output"""
@@ -84,7 +58,7 @@ class MetadataItemOutput(BaseInvocationOutput):
item: MetadataItemField = OutputField(description="Metadata Item") item: MetadataItemField = OutputField(description="Metadata Item")
@invocation("metadata_item", title="Metadata Item", tags=["metadata"], category="metadata", version="1.0.1") @invocation("metadata_item", title="Metadata Item", tags=["metadata"], category="metadata", version="1.0.0")
class MetadataItemInvocation(BaseInvocation): class MetadataItemInvocation(BaseInvocation):
"""Used to create an arbitrary metadata item. Provide "label" and make a connection to "value" to store that data as the value.""" """Used to create an arbitrary metadata item. Provide "label" and make a connection to "value" to store that data as the value."""
@@ -100,7 +74,7 @@ class MetadataOutput(BaseInvocationOutput):
metadata: MetadataField = OutputField(description="Metadata Dict") metadata: MetadataField = OutputField(description="Metadata Dict")
@invocation("metadata", title="Metadata", tags=["metadata"], category="metadata", version="1.0.1") @invocation("metadata", title="Metadata", tags=["metadata"], category="metadata", version="1.0.0")
class MetadataInvocation(BaseInvocation): class MetadataInvocation(BaseInvocation):
"""Takes a MetadataItem or collection of MetadataItems and outputs a MetadataDict.""" """Takes a MetadataItem or collection of MetadataItems and outputs a MetadataDict."""
@@ -121,7 +95,7 @@ class MetadataInvocation(BaseInvocation):
return MetadataOutput(metadata=MetadataField.model_validate(data)) return MetadataOutput(metadata=MetadataField.model_validate(data))
@invocation("merge_metadata", title="Metadata Merge", tags=["metadata"], category="metadata", version="1.0.1") @invocation("merge_metadata", title="Metadata Merge", tags=["metadata"], category="metadata", version="1.0.0")
class MergeMetadataInvocation(BaseInvocation): class MergeMetadataInvocation(BaseInvocation):
"""Merged a collection of MetadataDict into a single MetadataDict.""" """Merged a collection of MetadataDict into a single MetadataDict."""
@@ -140,7 +114,7 @@ GENERATION_MODES = Literal[
] ]
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="2.0.0") @invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.0.1")
class CoreMetadataInvocation(BaseInvocation): class CoreMetadataInvocation(BaseInvocation):
"""Collects core generation metadata into a MetadataField""" """Collects core generation metadata into a MetadataField"""
@@ -166,14 +140,14 @@ class CoreMetadataInvocation(BaseInvocation):
default=None, default=None,
description="The number of skipped CLIP layers", description="The number of skipped CLIP layers",
) )
model: Optional[ModelIdentifierField] = InputField(default=None, description="The main model used for inference") model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference")
controlnets: Optional[list[ControlNetMetadataField]] = InputField( controlnets: Optional[list[ControlField]] = InputField(
default=None, description="The ControlNets used for inference" default=None, description="The ControlNets used for inference"
) )
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField( ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
default=None, description="The IP Adapters used for inference" default=None, description="The IP Adapters used for inference"
) )
t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField( t2iAdapters: Optional[list[T2IAdapterField]] = InputField(
default=None, description="The IP Adapters used for inference" default=None, description="The IP Adapters used for inference"
) )
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference") loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
@@ -185,7 +159,7 @@ class CoreMetadataInvocation(BaseInvocation):
default=None, default=None,
description="The name of the initial image", description="The name of the initial image",
) )
vae: Optional[ModelIdentifierField] = InputField( vae: Optional[VAEModelField] = InputField(
default=None, default=None,
description="The VAE used for decoding, if the main model's default was not used", description="The VAE used for decoding, if the main model's default was not used",
) )
@@ -216,7 +190,7 @@ class CoreMetadataInvocation(BaseInvocation):
) )
# SDXL Refiner # SDXL Refiner
refiner_model: Optional[ModelIdentifierField] = InputField( refiner_model: Optional[MainModelField] = InputField(
default=None, default=None,
description="The SDXL Refiner model used", description="The SDXL Refiner model used",
) )
@@ -248,9 +222,10 @@ class CoreMetadataInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> MetadataOutput: def invoke(self, context: InvocationContext) -> MetadataOutput:
"""Collects and outputs a CoreMetadata object""" """Collects and outputs a CoreMetadata object"""
as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"}) return MetadataOutput(
as_dict["app_version"] = __version__ metadata=MetadataField.model_validate(
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
return MetadataOutput(metadata=MetadataField.model_validate(as_dict)) )
)
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")

View File

@@ -3,11 +3,11 @@ from typing import List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from ...backend.model_manager import SubModelType
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
@@ -16,52 +16,33 @@ from .baseinvocation import (
) )
class ModelIdentifierField(BaseModel): class ModelInfo(BaseModel):
key: str = Field(description="The model's unique key") key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
hash: str = Field(description="The model's BLAKE3 hash") submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
name: str = Field(description="The model's name")
base: BaseModelType = Field(description="The model's base model type")
type: ModelType = Field(description="The model's type")
submodel_type: Optional[SubModelType] = Field(
description="The submodel to load, if this is a main model", default=None
)
@classmethod
def from_config(
cls, config: "AnyModelConfig", submodel_type: Optional[SubModelType] = None
) -> "ModelIdentifierField":
return cls(
key=config.key,
hash=config.hash,
name=config.name,
base=config.base,
type=config.type,
submodel_type=submodel_type,
)
class LoRAField(BaseModel): class LoraInfo(ModelInfo):
lora: ModelIdentifierField = Field(description="Info to load lora model") weight: float = Field(description="Lora's weight which to use when apply to model")
weight: float = Field(description="Weight to apply to lora model")
class UNetField(BaseModel): class UNetField(BaseModel):
unet: ModelIdentifierField = Field(description="Info to load unet submodel") unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel") scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading") loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration") freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class CLIPField(BaseModel): class ClipField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel") tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel") text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
skipped_layers: int = Field(description="Number of skipped layers in text_encoder") skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading") loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class VAEField(BaseModel): class VaeField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel") # TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@@ -76,14 +57,14 @@ class UNetOutput(BaseInvocationOutput):
class VAEOutput(BaseInvocationOutput): class VAEOutput(BaseInvocationOutput):
"""Base class for invocations that output a VAE field""" """Base class for invocations that output a VAE field"""
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation_output("clip_output") @invocation_output("clip_output")
class CLIPOutput(BaseInvocationOutput): class CLIPOutput(BaseInvocationOutput):
"""Base class for invocations that output a CLIP field""" """Base class for invocations that output a CLIP field"""
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP") clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
@invocation_output("model_loader_output") @invocation_output("model_loader_output")
@@ -93,54 +74,84 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
pass pass
class MainModelField(BaseModel):
"""Main model field"""
key: str = Field(description="Model key")
class LoRAModelField(BaseModel):
"""LoRA model field"""
key: str = Field(description="LoRA model key")
@invocation( @invocation(
"main_model_loader", "main_model_loader",
title="Main Model", title="Main Model",
tags=["model"], tags=["model"],
category="model", category="model",
version="1.0.2", version="1.0.1",
) )
class MainModelLoaderInvocation(BaseInvocation): class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels.""" """Loads a main model, outputting its submodels."""
model: ModelIdentifierField = InputField( model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
)
# TODO: precision? # TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
# TODO: not found exceptions key = self.model.key
if not context.models.exists(self.model.key):
raise Exception(f"Unknown model {self.model.key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet}) # TODO: not found exceptions
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler}) if not context.models.exists(key):
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer}) raise Exception(f"Unknown model {key}")
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return ModelLoaderOutput( return ModelLoaderOutput(
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]), unet=UNetField(
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0), unet=ModelInfo(
vae=VAEField(vae=vae), key=key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=key,
submodel_type=SubModelType.Vae,
),
),
) )
@invocation_output("lora_loader_output") @invocation_output("lora_loader_output")
class LoRALoaderOutput(BaseInvocationOutput): class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output""" """Model loader output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.2") @invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
class LoRALoaderInvocation(BaseInvocation): class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField( lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField( unet: Optional[UNetField] = InputField(
default=None, default=None,
@@ -148,41 +159,46 @@ class LoRALoaderInvocation(BaseInvocation):
input=Input.Connection, input=Input.Connection,
title="UNet", title="UNet",
) )
clip: Optional[CLIPField] = InputField( clip: Optional[ClipField] = InputField(
default=None, default=None,
description=FieldDescriptions.clip, description=FieldDescriptions.clip,
input=Input.Connection, input=Input.Connection,
title="CLIP", title="CLIP",
) )
def invoke(self, context: InvocationContext) -> LoRALoaderOutput: def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
lora_key = self.lora.key lora_key = self.lora.key
if not context.models.exists(lora_key): if not context.models.exists(lora_key):
raise Exception(f"Unkown lora: {lora_key}!") raise Exception(f"Unkown lora: {lora_key}!")
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras): if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'LoRA "{lora_key}" already applied to unet') raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras): if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip') raise Exception(f'Lora "{lora_key}" already applied to clip')
output = LoRALoaderOutput() output = LoraLoaderOutput()
if self.unet is not None: if self.unet is not None:
output.unet = self.unet.model_copy(deep=True) output.unet = copy.deepcopy(self.unet)
output.unet.loras.append( output.unet.loras.append(
LoRAField( LoraInfo(
lora=self.lora, key=lora_key,
submodel_type=None,
weight=self.weight, weight=self.weight,
) )
) )
if self.clip is not None: if self.clip is not None:
output.clip = self.clip.model_copy(deep=True) output.clip = copy.deepcopy(self.clip)
output.clip.loras.append( output.clip.loras.append(
LoRAField( LoraInfo(
lora=self.lora, key=lora_key,
submodel_type=None,
weight=self.weight, weight=self.weight,
) )
) )
@@ -191,12 +207,12 @@ class LoRALoaderInvocation(BaseInvocation):
@invocation_output("sdxl_lora_loader_output") @invocation_output("sdxl_lora_loader_output")
class SDXLLoRALoaderOutput(BaseInvocationOutput): class SDXLLoraLoaderOutput(BaseInvocationOutput):
"""SDXL LoRA Loader Output""" """SDXL LoRA Loader Output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1") clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
clip2: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2") clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
@invocation( @invocation(
@@ -204,14 +220,12 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
title="SDXL LoRA", title="SDXL LoRA",
tags=["lora", "model"], tags=["lora", "model"],
category="model", category="model",
version="1.0.2", version="1.0.1",
) )
class SDXLLoRALoaderInvocation(BaseInvocation): class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField( lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField( unet: Optional[UNetField] = InputField(
default=None, default=None,
@@ -219,59 +233,65 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
input=Input.Connection, input=Input.Connection,
title="UNet", title="UNet",
) )
clip: Optional[CLIPField] = InputField( clip: Optional[ClipField] = InputField(
default=None, default=None,
description=FieldDescriptions.clip, description=FieldDescriptions.clip,
input=Input.Connection, input=Input.Connection,
title="CLIP 1", title="CLIP 1",
) )
clip2: Optional[CLIPField] = InputField( clip2: Optional[ClipField] = InputField(
default=None, default=None,
description=FieldDescriptions.clip, description=FieldDescriptions.clip,
input=Input.Connection, input=Input.Connection,
title="CLIP 2", title="CLIP 2",
) )
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput: def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
lora_key = self.lora.key lora_key = self.lora.key
if not context.models.exists(lora_key): if not context.models.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!") raise Exception(f"Unknown lora: {lora_key}!")
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras): if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'LoRA "{lora_key}" already applied to unet') raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras): if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip') raise Exception(f'Lora "{lora_key}" already applied to clip')
if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras): if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip2') raise Exception(f'Lora "{lora_key}" already applied to clip2')
output = SDXLLoRALoaderOutput() output = SDXLLoraLoaderOutput()
if self.unet is not None: if self.unet is not None:
output.unet = self.unet.model_copy(deep=True) output.unet = copy.deepcopy(self.unet)
output.unet.loras.append( output.unet.loras.append(
LoRAField( LoraInfo(
lora=self.lora, key=lora_key,
submodel_type=None,
weight=self.weight, weight=self.weight,
) )
) )
if self.clip is not None: if self.clip is not None:
output.clip = self.clip.model_copy(deep=True) output.clip = copy.deepcopy(self.clip)
output.clip.loras.append( output.clip.loras.append(
LoRAField( LoraInfo(
lora=self.lora, key=lora_key,
submodel_type=None,
weight=self.weight, weight=self.weight,
) )
) )
if self.clip2 is not None: if self.clip2 is not None:
output.clip2 = self.clip2.model_copy(deep=True) output.clip2 = copy.deepcopy(self.clip2)
output.clip2.loras.append( output.clip2.loras.append(
LoRAField( LoraInfo(
lora=self.lora, key=lora_key,
submodel_type=None,
weight=self.weight, weight=self.weight,
) )
) )
@@ -279,12 +299,20 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
return output return output
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2") class VAEModelField(BaseModel):
class VAELoaderInvocation(BaseInvocation): """Vae model field"""
key: str = Field(description="Model's key")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput""" """Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: ModelIdentifierField = InputField( vae_model: VAEModelField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel description=FieldDescriptions.vae_model,
input=Input.Direct,
title="VAE",
) )
def invoke(self, context: InvocationContext) -> VAEOutput: def invoke(self, context: InvocationContext) -> VAEOutput:
@@ -293,7 +321,7 @@ class VAELoaderInvocation(BaseInvocation):
if not context.models.exists(key): if not context.models.exists(key):
raise Exception(f"Unkown vae: {key}!") raise Exception(f"Unkown vae: {key}!")
return VAEOutput(vae=VAEField(vae=self.vae_model)) return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
@invocation_output("seamless_output") @invocation_output("seamless_output")
@@ -301,7 +329,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
"""Modified Seamless Model output""" """Modified Seamless Model output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
vae: Optional[VAEField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE") vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
@invocation( @invocation(
@@ -309,7 +337,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
title="Seamless", title="Seamless",
tags=["seamless", "model"], tags=["seamless", "model"],
category="model", category="model",
version="1.0.1", version="1.0.0",
) )
class SeamlessModeInvocation(BaseInvocation): class SeamlessModeInvocation(BaseInvocation):
"""Applies the seamless transformation to the Model UNet and VAE.""" """Applies the seamless transformation to the Model UNet and VAE."""
@@ -320,7 +348,7 @@ class SeamlessModeInvocation(BaseInvocation):
input=Input.Connection, input=Input.Connection,
title="UNet", title="UNet",
) )
vae: Optional[VAEField] = InputField( vae: Optional[VaeField] = InputField(
default=None, default=None,
description=FieldDescriptions.vae_model, description=FieldDescriptions.vae_model,
input=Input.Connection, input=Input.Connection,
@@ -349,7 +377,7 @@ class SeamlessModeInvocation(BaseInvocation):
return SeamlessModeOutput(unet=unet, vae=vae) return SeamlessModeOutput(unet=unet, vae=vae)
@invocation("freeu", title="FreeU", tags=["freeu"], category="unet", version="1.0.1") @invocation("freeu", title="FreeU", tags=["freeu"], category="unet", version="1.0.0")
class FreeUInvocation(BaseInvocation): class FreeUInvocation(BaseInvocation):
""" """
Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2): Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2):

View File

@@ -81,7 +81,7 @@ class NoiseOutput(BaseInvocationOutput):
title="Noise", title="Noise",
tags=["latents", "noise"], tags=["latents", "noise"],
category="latents", category="latents",
version="1.0.2", version="1.0.1",
) )
class NoiseInvocation(BaseInvocation): class NoiseInvocation(BaseInvocation):
"""Generates latent noise.""" """Generates latent noise."""

View File

@@ -51,7 +51,7 @@ from .fields import InputField
title="Float Range", title="Float Range",
tags=["math", "range"], tags=["math", "range"],
category="math", category="math",
version="1.0.1", version="1.0.0",
) )
class FloatLinearRangeInvocation(BaseInvocation): class FloatLinearRangeInvocation(BaseInvocation):
"""Creates a range""" """Creates a range"""
@@ -111,7 +111,7 @@ EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
title="Step Param Easing", title="Step Param Easing",
tags=["step", "easing"], tags=["step", "easing"],
category="step", category="step",
version="1.0.2", version="1.0.1",
) )
class StepParamEasingInvocation(BaseInvocation): class StepParamEasingInvocation(BaseInvocation):
"""Experimental per-step parameter easing for denoising steps""" """Experimental per-step parameter easing for denoising steps"""

View File

@@ -54,7 +54,7 @@ class BooleanCollectionOutput(BaseInvocationOutput):
@invocation( @invocation(
"boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives", version="1.0.1" "boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives", version="1.0.0"
) )
class BooleanInvocation(BaseInvocation): class BooleanInvocation(BaseInvocation):
"""A boolean primitive value""" """A boolean primitive value"""
@@ -70,7 +70,7 @@ class BooleanInvocation(BaseInvocation):
title="Boolean Collection Primitive", title="Boolean Collection Primitive",
tags=["primitives", "boolean", "collection"], tags=["primitives", "boolean", "collection"],
category="primitives", category="primitives",
version="1.0.2", version="1.0.1",
) )
class BooleanCollectionInvocation(BaseInvocation): class BooleanCollectionInvocation(BaseInvocation):
"""A collection of boolean primitive values""" """A collection of boolean primitive values"""
@@ -103,7 +103,7 @@ class IntegerCollectionOutput(BaseInvocationOutput):
@invocation( @invocation(
"integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives", version="1.0.1" "integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives", version="1.0.0"
) )
class IntegerInvocation(BaseInvocation): class IntegerInvocation(BaseInvocation):
"""An integer primitive value""" """An integer primitive value"""
@@ -119,7 +119,7 @@ class IntegerInvocation(BaseInvocation):
title="Integer Collection Primitive", title="Integer Collection Primitive",
tags=["primitives", "integer", "collection"], tags=["primitives", "integer", "collection"],
category="primitives", category="primitives",
version="1.0.2", version="1.0.1",
) )
class IntegerCollectionInvocation(BaseInvocation): class IntegerCollectionInvocation(BaseInvocation):
"""A collection of integer primitive values""" """A collection of integer primitive values"""
@@ -151,7 +151,7 @@ class FloatCollectionOutput(BaseInvocationOutput):
) )
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives", version="1.0.1") @invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives", version="1.0.0")
class FloatInvocation(BaseInvocation): class FloatInvocation(BaseInvocation):
"""A float primitive value""" """A float primitive value"""
@@ -166,7 +166,7 @@ class FloatInvocation(BaseInvocation):
title="Float Collection Primitive", title="Float Collection Primitive",
tags=["primitives", "float", "collection"], tags=["primitives", "float", "collection"],
category="primitives", category="primitives",
version="1.0.2", version="1.0.1",
) )
class FloatCollectionInvocation(BaseInvocation): class FloatCollectionInvocation(BaseInvocation):
"""A collection of float primitive values""" """A collection of float primitive values"""
@@ -198,7 +198,7 @@ class StringCollectionOutput(BaseInvocationOutput):
) )
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives", version="1.0.1") @invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives", version="1.0.0")
class StringInvocation(BaseInvocation): class StringInvocation(BaseInvocation):
"""A string primitive value""" """A string primitive value"""
@@ -213,7 +213,7 @@ class StringInvocation(BaseInvocation):
title="String Collection Primitive", title="String Collection Primitive",
tags=["primitives", "string", "collection"], tags=["primitives", "string", "collection"],
category="primitives", category="primitives",
version="1.0.2", version="1.0.1",
) )
class StringCollectionInvocation(BaseInvocation): class StringCollectionInvocation(BaseInvocation):
"""A collection of string primitive values""" """A collection of string primitive values"""
@@ -255,7 +255,7 @@ class ImageCollectionOutput(BaseInvocationOutput):
) )
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.2") @invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1")
class ImageInvocation(BaseInvocation): class ImageInvocation(BaseInvocation):
"""An image primitive value""" """An image primitive value"""
@@ -276,7 +276,7 @@ class ImageInvocation(BaseInvocation):
title="Image Collection Primitive", title="Image Collection Primitive",
tags=["primitives", "image", "collection"], tags=["primitives", "image", "collection"],
category="primitives", category="primitives",
version="1.0.1", version="1.0.0",
) )
class ImageCollectionInvocation(BaseInvocation): class ImageCollectionInvocation(BaseInvocation):
"""A collection of image primitive values""" """A collection of image primitive values"""
@@ -299,13 +299,9 @@ class DenoiseMaskOutput(BaseInvocationOutput):
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run") denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
@classmethod @classmethod
def build( def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput":
cls, mask_name: str, masked_latents_name: Optional[str] = None, gradient: bool = False
) -> "DenoiseMaskOutput":
return cls( return cls(
denoise_mask=DenoiseMaskField( denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name),
mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=gradient
),
) )
@@ -341,7 +337,7 @@ class LatentsCollectionOutput(BaseInvocationOutput):
@invocation( @invocation(
"latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.2" "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.1"
) )
class LatentsInvocation(BaseInvocation): class LatentsInvocation(BaseInvocation):
"""A latents tensor primitive value""" """A latents tensor primitive value"""
@@ -359,7 +355,7 @@ class LatentsInvocation(BaseInvocation):
title="Latents Collection Primitive", title="Latents Collection Primitive",
tags=["primitives", "latents", "collection"], tags=["primitives", "latents", "collection"],
category="primitives", category="primitives",
version="1.0.1", version="1.0.0",
) )
class LatentsCollectionInvocation(BaseInvocation): class LatentsCollectionInvocation(BaseInvocation):
"""A collection of latents tensor primitive values""" """A collection of latents tensor primitive values"""
@@ -393,7 +389,7 @@ class ColorCollectionOutput(BaseInvocationOutput):
) )
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives", version="1.0.1") @invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives", version="1.0.0")
class ColorInvocation(BaseInvocation): class ColorInvocation(BaseInvocation):
"""A color primitive value""" """A color primitive value"""
@@ -433,7 +429,7 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
title="Conditioning Primitive", title="Conditioning Primitive",
tags=["primitives", "conditioning"], tags=["primitives", "conditioning"],
category="primitives", category="primitives",
version="1.0.1", version="1.0.0",
) )
class ConditioningInvocation(BaseInvocation): class ConditioningInvocation(BaseInvocation):
"""A conditioning tensor primitive value""" """A conditioning tensor primitive value"""
@@ -449,7 +445,7 @@ class ConditioningInvocation(BaseInvocation):
title="Conditioning Collection Primitive", title="Conditioning Collection Primitive",
tags=["primitives", "conditioning", "collection"], tags=["primitives", "conditioning", "collection"],
category="primitives", category="primitives",
version="1.0.2", version="1.0.1",
) )
class ConditioningCollectionInvocation(BaseInvocation): class ConditioningCollectionInvocation(BaseInvocation):
"""A collection of conditioning tensor primitive values""" """A collection of conditioning tensor primitive values"""

View File

@@ -17,7 +17,7 @@ from .fields import InputField, UIComponent
title="Dynamic Prompt", title="Dynamic Prompt",
tags=["prompt", "collection"], tags=["prompt", "collection"],
category="prompt", category="prompt",
version="1.0.1", version="1.0.0",
use_cache=False, use_cache=False,
) )
class DynamicPromptInvocation(BaseInvocation): class DynamicPromptInvocation(BaseInvocation):
@@ -46,7 +46,7 @@ class DynamicPromptInvocation(BaseInvocation):
title="Prompts from File", title="Prompts from File",
tags=["prompt", "file"], tags=["prompt", "file"],
category="prompt", category="prompt",
version="1.0.2", version="1.0.1",
) )
class PromptsFromFileInvocation(BaseInvocation): class PromptsFromFileInvocation(BaseInvocation):
"""Loads prompts from a text file""" """Loads prompts from a text file"""

View File

@@ -8,7 +8,7 @@ from .baseinvocation import (
invocation, invocation,
invocation_output, invocation_output,
) )
from .model import CLIPField, ModelIdentifierField, UNetField, VAEField from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
@invocation_output("sdxl_model_loader_output") @invocation_output("sdxl_model_loader_output")
@@ -16,9 +16,9 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL base model loader output""" """SDXL base model loader output"""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1") clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2") clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation_output("sdxl_refiner_model_loader_output") @invocation_output("sdxl_refiner_model_loader_output")
@@ -26,15 +26,15 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output""" """SDXL refiner model loader output"""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2") clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.2") @invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
class SDXLModelLoaderInvocation(BaseInvocation): class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels.""" """Loads an sdxl base model, outputting its submodels."""
model: ModelIdentifierField = InputField( model: MainModelField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
) )
# TODO: precision? # TODO: precision?
@@ -46,19 +46,48 @@ class SDXLModelLoaderInvocation(BaseInvocation):
if not context.models.exists(model_key): if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}") raise Exception(f"Unknown model: {model_key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SDXLModelLoaderOutput( return SDXLModelLoaderOutput(
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]), unet=UNetField(
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0), unet=ModelInfo(
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0), key=model_key,
vae=VAEField(vae=vae), submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.Vae,
),
),
) )
@@ -67,13 +96,15 @@ class SDXLModelLoaderInvocation(BaseInvocation):
title="SDXL Refiner Model", title="SDXL Refiner Model",
tags=["model", "sdxl", "refiner"], tags=["model", "sdxl", "refiner"],
category="model", category="model",
version="1.0.2", version="1.0.1",
) )
class SDXLRefinerModelLoaderInvocation(BaseInvocation): class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels.""" """Loads an sdxl refiner model, outputting its submodels."""
model: ModelIdentifierField = InputField( model: MainModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel description=FieldDescriptions.sdxl_refiner_model,
input=Input.Direct,
ui_type=UIType.SDXLRefinerModel,
) )
# TODO: precision? # TODO: precision?
@@ -84,14 +115,34 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
if not context.models.exists(model_key): if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}") raise Exception(f"Unknown model: {model_key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SDXLRefinerModelLoaderOutput( return SDXLRefinerModelLoaderOutput(
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]), unet=UNetField(
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0), unet=ModelInfo(
vae=VAEField(vae=vae), key=model_key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.Vae,
),
),
) )

View File

@@ -27,7 +27,7 @@ class StringPosNegOutput(BaseInvocationOutput):
title="String Split Negative", title="String Split Negative",
tags=["string", "split", "negative"], tags=["string", "split", "negative"],
category="string", category="string",
version="1.0.1", version="1.0.0",
) )
class StringSplitNegInvocation(BaseInvocation): class StringSplitNegInvocation(BaseInvocation):
"""Splits string into two strings, inside [] goes into negative string everthing else goes into positive string. Each [ and ] character is replaced with a space""" """Splits string into two strings, inside [] goes into negative string everthing else goes into positive string. Each [ and ] character is replaced with a space"""
@@ -69,7 +69,7 @@ class String2Output(BaseInvocationOutput):
string_2: str = OutputField(description="string 2") string_2: str = OutputField(description="string 2")
@invocation("string_split", title="String Split", tags=["string", "split"], category="string", version="1.0.1") @invocation("string_split", title="String Split", tags=["string", "split"], category="string", version="1.0.0")
class StringSplitInvocation(BaseInvocation): class StringSplitInvocation(BaseInvocation):
"""Splits string into two strings, based on the first occurance of the delimiter. The delimiter will be removed from the string""" """Splits string into two strings, based on the first occurance of the delimiter. The delimiter will be removed from the string"""
@@ -89,7 +89,7 @@ class StringSplitInvocation(BaseInvocation):
return String2Output(string_1=part1, string_2=part2) return String2Output(string_1=part1, string_2=part2)
@invocation("string_join", title="String Join", tags=["string", "join"], category="string", version="1.0.1") @invocation("string_join", title="String Join", tags=["string", "join"], category="string", version="1.0.0")
class StringJoinInvocation(BaseInvocation): class StringJoinInvocation(BaseInvocation):
"""Joins string left to string right""" """Joins string left to string right"""
@@ -100,7 +100,7 @@ class StringJoinInvocation(BaseInvocation):
return StringOutput(value=((self.string_left or "") + (self.string_right or ""))) return StringOutput(value=((self.string_left or "") + (self.string_right or "")))
@invocation("string_join_three", title="String Join Three", tags=["string", "join"], category="string", version="1.0.1") @invocation("string_join_three", title="String Join Three", tags=["string", "join"], category="string", version="1.0.0")
class StringJoinThreeInvocation(BaseInvocation): class StringJoinThreeInvocation(BaseInvocation):
"""Joins string left to string middle to string right""" """Joins string left to string middle to string right"""
@@ -113,7 +113,7 @@ class StringJoinThreeInvocation(BaseInvocation):
@invocation( @invocation(
"string_replace", title="String Replace", tags=["string", "replace", "regex"], category="string", version="1.0.1" "string_replace", title="String Replace", tags=["string", "replace", "regex"], category="string", version="1.0.0"
) )
class StringReplaceInvocation(BaseInvocation): class StringReplaceInvocation(BaseInvocation):
"""Replaces the search string with the replace string""" """Replaces the search string with the replace string"""

View File

@@ -9,15 +9,18 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output, invocation_output,
) )
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
class T2IAdapterModelField(BaseModel):
key: str = Field(description="Model record key for the T2I-Adapter model")
class T2IAdapterField(BaseModel): class T2IAdapterField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.") image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.") t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter") weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)" default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
@@ -45,19 +48,18 @@ class T2IAdapterOutput(BaseInvocationOutput):
@invocation( @invocation(
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.2" "t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.1"
) )
class T2IAdapterInvocation(BaseInvocation): class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes.""" """Collects T2I-Adapter info to pass to other nodes."""
# Inputs # Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.") image: ImageField = InputField(description="The IP-Adapter image prompt.")
t2i_adapter_model: ModelIdentifierField = InputField( t2i_adapter_model: T2IAdapterModelField = InputField(
description="The T2I-Adapter model.", description="The T2I-Adapter model.",
title="T2I-Adapter Model", title="T2I-Adapter Model",
input=Input.Direct, input=Input.Direct,
ui_order=-1, ui_order=-1,
ui_type=UIType.T2IAdapterModel,
) )
weight: Union[float, list[float]] = InputField( weight: Union[float, list[float]] = InputField(
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight" default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"

View File

@@ -39,7 +39,7 @@ class CalculateImageTilesOutput(BaseInvocationOutput):
title="Calculate Image Tiles", title="Calculate Image Tiles",
tags=["tiles"], tags=["tiles"],
category="tiles", category="tiles",
version="1.0.1", version="1.0.0",
classification=Classification.Beta, classification=Classification.Beta,
) )
class CalculateImageTilesInvocation(BaseInvocation): class CalculateImageTilesInvocation(BaseInvocation):
@@ -73,7 +73,7 @@ class CalculateImageTilesInvocation(BaseInvocation):
title="Calculate Image Tiles Even Split", title="Calculate Image Tiles Even Split",
tags=["tiles"], tags=["tiles"],
category="tiles", category="tiles",
version="1.1.1", version="1.1.0",
classification=Classification.Beta, classification=Classification.Beta,
) )
class CalculateImageTilesEvenSplitInvocation(BaseInvocation): class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
@@ -116,7 +116,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
title="Calculate Image Tiles Minimum Overlap", title="Calculate Image Tiles Minimum Overlap",
tags=["tiles"], tags=["tiles"],
category="tiles", category="tiles",
version="1.0.1", version="1.0.0",
classification=Classification.Beta, classification=Classification.Beta,
) )
class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation): class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
@@ -167,7 +167,7 @@ class TileToPropertiesOutput(BaseInvocationOutput):
title="Tile to Properties", title="Tile to Properties",
tags=["tiles"], tags=["tiles"],
category="tiles", category="tiles",
version="1.0.1", version="1.0.0",
classification=Classification.Beta, classification=Classification.Beta,
) )
class TileToPropertiesInvocation(BaseInvocation): class TileToPropertiesInvocation(BaseInvocation):
@@ -200,7 +200,7 @@ class PairTileImageOutput(BaseInvocationOutput):
title="Pair Tile with Image", title="Pair Tile with Image",
tags=["tiles"], tags=["tiles"],
category="tiles", category="tiles",
version="1.0.1", version="1.0.0",
classification=Classification.Beta, classification=Classification.Beta,
) )
class PairTileImageInvocation(BaseInvocation): class PairTileImageInvocation(BaseInvocation):
@@ -229,7 +229,7 @@ BLEND_MODES = Literal["Linear", "Seam"]
title="Merge Tiles to Image", title="Merge Tiles to Image",
tags=["tiles"], tags=["tiles"],
category="tiles", category="tiles",
version="1.1.1", version="1.1.0",
classification=Classification.Beta, classification=Classification.Beta,
) )
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard): class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -11,7 +11,6 @@ from pydantic import ConfigDict
from invokeai.app.invocations.fields import ImageField from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.devices import choose_torch_device
@@ -28,18 +27,11 @@ ESRGAN_MODELS = Literal[
"RealESRGAN_x2plus.pth", "RealESRGAN_x2plus.pth",
] ]
ESRGAN_MODEL_URLS: dict[str, str] = {
"RealESRGAN_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
"RealESRGAN_x4plus_anime_6B.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
}
if choose_torch_device() == torch.device("mps"): if choose_torch_device() == torch.device("mps"):
from torch import mps from torch import mps
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2") @invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.1")
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Upscales an image using RealESRGAN.""" """Upscales an image using RealESRGAN."""
@@ -53,6 +45,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name) image = context.images.get_pil(self.image.image_name)
models_path = context.config.get().models_path
rrdbnet_model = None rrdbnet_model = None
netscale = None netscale = None
@@ -99,16 +92,11 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
context.logger.error(msg) context.logger.error(msg)
raise ValueError(msg) raise ValueError(msg)
esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}") esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
# Downloads the ESRGAN model if it doesn't already exist
download_with_progress_bar(
name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
)
upscaler = RealESRGAN( upscaler = RealESRGAN(
scale=netscale, scale=netscale,
model_path=esrgan_model_path, model_path=models_path / esrgan_model_path,
model=rrdbnet_model, model=rrdbnet_model,
half=False, half=False,
tile=self.tile_size, tile=self.tile_size,

View File

@@ -1,12 +0,0 @@
"""This is a wrapper around the main app entrypoint, to allow for CLI args to be parsed before running the app."""
def run_app() -> None:
# Before doing _anything_, parse CLI args!
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
InvokeAIArgs.parse_args()
from invokeai.app.api_app import invoke_api
invoke_api()

View File

@@ -1,44 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional
class BulkDownloadBase(ABC):
"""Responsible for creating a zip file containing the images specified by the given image names or board id."""
@abstractmethod
def handler(
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
) -> None:
"""
Create a zip file containing the images specified by the given image names or board id.
:param image_names: A list of image names to include in the zip file.
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
:param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated.
"""
@abstractmethod
def get_path(self, bulk_download_item_name: str) -> str:
"""
Get the path to the bulk download file.
:param bulk_download_item_name: The name of the bulk download item.
:return: The path to the bulk download file.
"""
@abstractmethod
def generate_item_id(self, board_id: Optional[str]) -> str:
"""
Generate an item ID for a bulk download item.
:param board_id: The ID of the board whose name is to be included in the item id.
:return: The generated item ID.
"""
@abstractmethod
def delete(self, bulk_download_item_name: str) -> None:
"""
Delete the bulk download file.
:param bulk_download_item_name: The name of the bulk download item.
"""

View File

@@ -1,25 +0,0 @@
DEFAULT_BULK_DOWNLOAD_ID = "default"
class BulkDownloadException(Exception):
"""Exception raised when a bulk download fails."""
def __init__(self, message="Bulk download failed"):
super().__init__(message)
self.message = message
class BulkDownloadTargetException(BulkDownloadException):
"""Exception raised when a bulk download target is not found."""
def __init__(self, message="The bulk download target was not found"):
super().__init__(message)
self.message = message
class BulkDownloadParametersException(BulkDownloadException):
"""Exception raised when a bulk download parameter is invalid."""
def __init__(self, message="No image names or board ID provided"):
super().__init__(message)
self.message = message

View File

@@ -1,157 +0,0 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Union
from zipfile import ZipFile
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
from invokeai.app.services.bulk_download.bulk_download_common import (
DEFAULT_BULK_DOWNLOAD_ID,
BulkDownloadException,
BulkDownloadParametersException,
BulkDownloadTargetException,
)
from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invoker import Invoker
from invokeai.app.util.misc import uuid_string
from .bulk_download_base import BulkDownloadBase
class BulkDownloadService(BulkDownloadBase):
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def __init__(self):
self._temp_directory = TemporaryDirectory()
self._bulk_downloads_folder = Path(self._temp_directory.name) / "bulk_downloads"
self._bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
def handler(
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
) -> None:
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
bulk_download_item_id = bulk_download_item_id or uuid_string()
bulk_download_item_name = bulk_download_item_id + ".zip"
self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
try:
image_dtos: list[ImageDTO] = []
if board_id:
image_dtos = self._board_handler(board_id)
elif image_names:
image_dtos = self._image_handler(image_names)
else:
raise BulkDownloadParametersException()
bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
except (
ImageRecordNotFoundException,
BoardRecordNotFoundException,
BulkDownloadException,
BulkDownloadParametersException,
) as e:
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
except Exception as e:
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
self._invoker.services.logger.error("Problem bulk downloading images.")
raise e
def _image_handler(self, image_names: list[str]) -> list[ImageDTO]:
return [self._invoker.services.images.get_dto(image_name) for image_name in image_names]
def _board_handler(self, board_id: str) -> list[ImageDTO]:
image_names = self._invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
return self._image_handler(image_names)
def generate_item_id(self, board_id: Optional[str]) -> str:
return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string()
def _get_clean_board_name(self, board_id: str) -> str:
if board_id == "none":
return "Uncategorized"
return self._clean_string_to_path_safe(self._invoker.services.board_records.get(board_id).board_name)
def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str:
"""
Create a zip file containing the images specified by the given image names or board id.
If download with the same bulk_download_id already exists, it will be overwritten.
:return: The name of the zip file.
"""
zip_file_name = bulk_download_item_id + ".zip"
zip_file_path = self._bulk_downloads_folder / (zip_file_name)
with ZipFile(zip_file_path, "w") as zip_file:
for image_dto in image_dtos:
image_zip_path = Path(image_dto.image_category.value) / image_dto.image_name
image_disk_path = self._invoker.services.images.get_path(image_dto.image_name)
zip_file.write(image_disk_path, arcname=image_zip_path)
return str(zip_file_name)
# from https://stackoverflow.com/questions/7406102/create-sane-safe-filename-from-any-unsafe-string
def _clean_string_to_path_safe(self, s: str) -> str:
"""Clean a string to be path safe."""
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip()
def _signal_job_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Signal that a bulk download job has started."""
if self._invoker:
assert bulk_download_id is not None
self._invoker.services.events.emit_bulk_download_started(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
def _signal_job_completed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Signal that a bulk download job has completed."""
if self._invoker:
assert bulk_download_id is not None
assert bulk_download_item_name is not None
self._invoker.services.events.emit_bulk_download_completed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
def _signal_job_failed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception
) -> None:
"""Signal that a bulk download job has failed."""
if self._invoker:
assert bulk_download_id is not None
assert exception is not None
self._invoker.services.events.emit_bulk_download_failed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=str(exception),
)
def stop(self, *args, **kwargs):
self._temp_directory.cleanup()
def delete(self, bulk_download_item_name: str) -> None:
path = self.get_path(bulk_download_item_name)
Path(path).unlink()
def get_path(self, bulk_download_item_name: str) -> str:
path = str(self._bulk_downloads_folder / bulk_download_item_name)
if not self._is_valid_path(path):
raise BulkDownloadTargetException()
return path
def _is_valid_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for a bulk download."""
path = path if isinstance(path, Path) else Path(path)
return path.exists()

View File

@@ -2,6 +2,6 @@
from invokeai.app.services.config.config_common import PagingArgumentParser from invokeai.app.services.config.config_common import PagingArgumentParser
from .config_default import InvokeAIAppConfig, get_config from .config_default import InvokeAIAppConfig, get_invokeai_config
__all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"] __all__ = ["InvokeAIAppConfig", "get_invokeai_config", "PagingArgumentParser"]

View File

@@ -0,0 +1,223 @@
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
"""
Base class for the InvokeAI configuration system.
It defines a type of pydantic BaseSettings object that
is able to read and write from an omegaconf-based config file,
with overriding of settings from environment variables and/or
the command line.
"""
from __future__ import annotations
import argparse
import os
import sys
from argparse import ArgumentParser
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
from omegaconf import DictConfig, ListConfig, OmegaConf
from pydantic_settings import BaseSettings, SettingsConfigDict
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
class InvokeAISettings(BaseSettings):
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
initconf: ClassVar[Optional[DictConfig]] = None
argparse_groups: ClassVar[Dict[str, Any]] = {}
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None:
"""Call to parse command-line arguments."""
parser = self.get_parser()
opt, unknown_opts = parser.parse_known_args(argv)
if len(unknown_opts) > 0:
print("Unknown args:", unknown_opts)
for name in self.model_fields:
if name not in self._excluded():
value = getattr(opt, name)
if isinstance(value, ListConfig):
value = list(value)
elif isinstance(value, DictConfig):
value = dict(value)
setattr(self, name, value)
def to_yaml(self) -> str:
"""Return a YAML string representing our settings. This can be used as the contents of `invokeai.yaml` to restore settings later."""
cls = self.__class__
type = get_args(get_type_hints(cls)["type"])[0]
field_dict: Dict[str, Dict[str, Any]] = {type: {}}
for name, field in self.model_fields.items():
if name in cls._excluded_from_yaml():
continue
assert isinstance(field.json_schema_extra, dict)
category = (
field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized"
)
value = getattr(self, name)
assert isinstance(category, str)
if category not in field_dict[type]:
field_dict[type][category] = {}
# keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
conf = OmegaConf.create(field_dict)
return OmegaConf.to_yaml(conf)
@classmethod
def add_parser_arguments(cls, parser: ArgumentParser) -> None:
"""Dynamically create arguments for a settings parser."""
if "type" in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
else:
settings_stanza = "Uncategorized"
env_prefix = getattr(cls.model_config, "env_prefix", None)
env_prefix = env_prefix if env_prefix is not None else settings_stanza.upper()
initconf = (
cls.initconf.get(settings_stanza)
if cls.initconf and settings_stanza in cls.initconf
else OmegaConf.create()
)
# create an upcase version of the environment in
# order to achieve case-insensitive environment
# variables (the way Windows does)
upcase_environ = {}
for key, value in os.environ.items():
upcase_environ[key.upper()] = value
fields = cls.model_fields
cls.argparse_groups = {}
for name, field in fields.items():
if name not in cls._excluded():
current_default = field.default
category = (
field.json_schema_extra.get("category", "Uncategorized")
if field.json_schema_extra
else "Uncategorized"
)
env_name = env_prefix + "_" + name
if category in initconf and name in initconf.get(category):
field.default = initconf.get(category).get(name)
if env_name.upper() in upcase_environ:
field.default = upcase_environ[env_name.upper()]
cls.add_field_argument(parser, name, field)
field.default = current_default
@classmethod
def cmd_name(cls, command_field: str = "type") -> str:
"""Return the category of a setting."""
hints = get_type_hints(cls)
if command_field in hints:
result: str = get_args(hints[command_field])[0]
return result
else:
return "Uncategorized"
@classmethod
def get_parser(cls) -> ArgumentParser:
"""Get the command-line parser for a setting."""
parser = PagingArgumentParser(
prog=cls.cmd_name(),
description=cls.__doc__,
)
cls.add_parser_arguments(parser)
return parser
@classmethod
def _excluded(cls) -> List[str]:
# internal fields that shouldn't be exposed as command line options
return ["type", "initconf"]
@classmethod
def _excluded_from_yaml(cls) -> List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
return [
"type",
"initconf",
"version",
"from_file",
"model",
"root",
"max_cache_size",
"max_vram_cache_size",
"always_use_cpu",
"free_gpu_mem",
"xformers_enabled",
"tiled_decode",
"lora_dir",
"embedding_dir",
"controlnet_dir",
]
@classmethod
def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None:
"""Add the argparse arguments for a setting parser."""
field_type = get_type_hints(cls).get(name)
default = (
default_override
if default_override is not None
else field.default
if field.default_factory is None
else field.default_factory()
)
if category := (field.json_schema_extra.get("category", None) if field.json_schema_extra else None):
if category not in cls.argparse_groups:
cls.argparse_groups[category] = command_parser.add_argument_group(category)
argparse_group = cls.argparse_groups[category]
else:
argparse_group = command_parser
if get_origin(field_type) == Literal:
allowed_values = get_args(field.annotation)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else int_or_float_or_str
argparse_group.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=default,
choices=allowed_values,
help=field.description,
)
elif get_origin(field_type) == Union:
argparse_group.add_argument(
f"--{name}",
dest=name,
type=int_or_float_or_str,
default=default,
help=field.description,
)
elif get_origin(field_type) == list:
argparse_group.add_argument(
f"--{name}",
dest=name,
nargs="*",
type=field.annotation,
default=default,
action=argparse.BooleanOptionalAction if field.annotation == bool else "store",
help=field.description,
)
else:
argparse_group.add_argument(
f"--{name}",
dest=name,
type=field.annotation,
default=default,
action=argparse.BooleanOptionalAction if field.annotation == bool else "store",
help=field.description,
)

View File

@@ -12,6 +12,7 @@ from __future__ import annotations
import argparse import argparse
import pydoc import pydoc
from typing import Union
class PagingArgumentParser(argparse.ArgumentParser): class PagingArgumentParser(argparse.ArgumentParser):
@@ -23,3 +24,18 @@ class PagingArgumentParser(argparse.ArgumentParser):
def print_help(self, file=None) -> None: def print_help(self, file=None) -> None:
text = self.format_help() text = self.format_help()
pydoc.pager(text) pydoc.pager(text)
def int_or_float_or_str(value: str) -> Union[int, float, str]:
"""
Workaround for argparse type checking.
"""
try:
return int(value)
except Exception as e: # noqa F841
pass
try:
return float(value)
except Exception as e: # noqa F841
pass
return str(value)

View File

@@ -1,23 +1,186 @@
# TODO(psyche): pydantic-settings supports YAML settings sources. If we can figure out a way to integrate the YAML # Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
# migration logic, we could use that for simpler config loading.
"""Invokeai configuration system.
Arguments and fields are taken from the pydantic definition of the
model. Defaults can be set by creating a yaml configuration file that
has a top-level key of "InvokeAI" and subheadings for each of the
categories returned by `invokeai --help`. The file looks like this:
[file: invokeai.yaml]
InvokeAI:
Web Server:
host: 127.0.0.1
port: 9090
allow_origins: []
allow_credentials: true
allow_methods:
- '*'
allow_headers:
- '*'
Features:
esrgan: true
internet_available: true
log_tokenization: false
patchmatch: true
ignore_missing_core_models: false
Paths:
autoimport_dir: autoimport
lora_dir: null
embedding_dir: null
controlnet_dir: null
conf_path: configs/models.yaml
models_dir: models
legacy_conf_dir: configs/stable-diffusion
db_dir: databases
outdir: /home/lstein/invokeai-main/outputs
use_memory_db: false
Logging:
log_handlers:
- console
log_format: plain
log_level: info
Model Cache:
ram: 13.5
vram: 0.25
lazy_offload: true
log_memory_usage: false
Device:
device: auto
precision: auto
Generation:
sequential_guidance: false
attention_type: xformers
attention_slice_size: auto
force_tiled_decode: false
The default name of the configuration file is `invokeai.yaml`, located
in INVOKEAI_ROOT. You can replace supersede this by providing any
OmegaConf dictionary object initialization time:
omegaconf = OmegaConf.load('/tmp/init.yaml')
conf = InvokeAIAppConfig()
conf.parse_args(conf=omegaconf)
InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
at initialization time. You may pass a list of strings in the optional
`argv` argument to use instead of the system argv:
conf.parse_args(argv=['--log_tokenization'])
It is also possible to set a value at initialization time. However, if
you call parse_args() it may be overwritten.
conf = InvokeAIAppConfig(log_tokenization=True)
conf.parse_args(argv=['--no-log_tokenization'])
conf.log_tokenization
# False
To avoid this, use `get_config()` to retrieve the application-wide
configuration object. This will retain any properties set at object
creation time:
conf = InvokeAIAppConfig.get_config(log_tokenization=True)
conf.parse_args(argv=['--no-log_tokenization'])
conf.log_tokenization
# True
Any setting can be overwritten by setting an environment variable of
form: "INVOKEAI_<setting>", as in:
export INVOKEAI_port=8080
Order of precedence (from highest):
1) initialization options
2) command line options
3) environment variable options
4) config file options
5) pydantic defaults
Typical usage at the top level file:
from invokeai.app.services.config import InvokeAIAppConfig
# get global configuration and print its cache size
conf = InvokeAIAppConfig.get_config()
conf.parse_args()
print(conf.ram_cache_size)
Typical usage in a backend module:
from invokeai.app.services.config import InvokeAIAppConfig
# get global configuration and print its cache size value
conf = InvokeAIAppConfig.get_config()
print(conf.ram_cache_size)
Computed properties:
The InvokeAIAppConfig object has a series of properties that
resolve paths relative to the runtime root directory. They each return
a Path object:
root_path - path to InvokeAI root
output_path - path to default outputs directory
model_conf_path - path to models.yaml
conf - alias for the above
embedding_path - path to the embeddings directory
lora_path - path to the LoRA directory
In most cases, you will want to create a single InvokeAIAppConfig
object for the entire application. The InvokeAIAppConfig.get_config() function
does this:
config = InvokeAIAppConfig.get_config()
config.parse_args() # read values from the command line/config file
print(config.root)
# Subclassing
If you wish to create a similar class, please subclass the
`InvokeAISettings` class and define a Literal field named "type",
which is set to the desired top-level name. For example, to create a
"InvokeBatch" configuration, define like this:
class InvokeBatch(InvokeAISettings):
type: Literal["InvokeBatch"] = "InvokeBatch"
node_count : int = Field(default=1, description="Number of nodes to run on", json_schema_extra=dict(category='Resources'))
cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", json_schema_extra=dict(category='Resources'))
This will now read and write from the "InvokeBatch" section of the
config file, look for environment variables named INVOKEBATCH_*, and
accept the command-line arguments `--node_count` and `--cpu_count`. The
two configs are kept in separate sections of the config file:
# invokeai.yaml
InvokeBatch:
Resources:
node_count: 1
cpu_count: 8
InvokeAI:
Paths:
root: /home/lstein/invokeai-main
conf_path: configs/models.yaml
legacy_conf_dir: configs/stable-diffusion
outdir: outputs
...
"""
from __future__ import annotations from __future__ import annotations
import os import os
import re
import shutil
from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Any, Literal, Optional from typing import Any, ClassVar, Dict, List, Literal, Optional
import psutil from omegaconf import DictConfig, OmegaConf
import yaml from pydantic import Field
from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic.config import JsonDict
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict from pydantic_settings import SettingsConfigDict
import invokeai.configs as model_configs from .config_base import InvokeAISettings
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
INIT_FILE = Path("invokeai.yaml") INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db") DB_FILE = Path("invokeai.db")
@@ -25,303 +188,303 @@ LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0 DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25 DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0 DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.0"
def get_default_ram_cache_size() -> float: class Categories(object):
"""Run a heuristic for the default RAM cache based on installed RAM.""" """Category headers for configuration variable groups."""
# On some machines, psutil.virtual_memory().total gives a value that is slightly less than the actual RAM, so the WebServer: JsonDict = {"category": "Web Server"}
# limits are set slightly lower than than what we expect the actual RAM to be. Features: JsonDict = {"category": "Features"}
Paths: JsonDict = {"category": "Paths"}
GB = 1024**3 Logging: JsonDict = {"category": "Logging"}
max_ram = psutil.virtual_memory().total / GB Development: JsonDict = {"category": "Development"}
Other: JsonDict = {"category": "Other"}
if max_ram >= 60: ModelCache: JsonDict = {"category": "Model Cache"}
return 15.0 Device: JsonDict = {"category": "Device"}
if max_ram >= 30: Generation: JsonDict = {"category": "Generation"}
return 7.5 Queue: JsonDict = {"category": "Queue"}
if max_ram >= 14: Nodes: JsonDict = {"category": "Nodes"}
return 4.0 MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
return 2.1 # 2.1 is just large enough for sd 1.5 ;-)
class URLRegexTokenPair(BaseModel): class InvokeAIAppConfig(InvokeAISettings):
url_regex: str = Field(description="Regular expression to match against the URL") """Configuration object for InvokeAI App."""
token: str = Field(description="Token to use when the URL matches the regex")
@field_validator("url_regex") singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
@classmethod singleton_init: ClassVar[Optional[Dict[str, Any]]] = None
def validate_url_regex(cls, v: str) -> str:
"""Validate that the value is a valid regex."""
try:
re.compile(v)
except re.error as e:
raise ValueError(f"Invalid regex: {e}")
return v
class InvokeAIAppConfig(BaseSettings):
"""Invoke's global app configuration.
Typically, you won't need to interact with this class directly. Instead, use the `get_config` function from `invokeai.app.services.config` to get a singleton config object.
Attributes:
host: IP address to bind to. Use `0.0.0.0` to serve to your local network.
port: Port to bind to.
allow_origins: Allowed CORS origins.
allow_credentials: Allow CORS credentials.
allow_methods: Methods allowed for CORS.
allow_headers: Headers allowed for CORS.
ssl_certfile: SSL certificate file for HTTPS. See https://www.uvicorn.org/settings/#https.
ssl_keyfile: SSL key file for HTTPS. See https://www.uvicorn.org/settings/#https.
log_tokenization: Enable logging of parsed prompt tokens.
patchmatch: Enable patchmatch inpaint code.
autoimport_dir: Path to a directory of models files to be imported on startup.
models_dir: Path to the models directory.
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
legacy_conf_dir: Path to directory of legacy checkpoint config files.
db_dir: Path to InvokeAI databases directory.
outputs_dir: Path to directory for outputs.
custom_nodes_dir: Path to directory for custom nodes.
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
log_sql: Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.
use_memory_db: Use in-memory database. Useful for development.
dev_reload: Automatically reload when Python sources are changed. Does not reload node definitions.
profile_graphs: Enable graph profiling using `cProfile`.
profile_prefix: An optional prefix for profile output files.
profiles_dir: Path to profiles output directory.
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
vram: Amount of VRAM reserved for model storage (GB).
convert_cache: Maximum size of on-disk converted models cache (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory.
hashing_algorithm: Model hashing algorthim for model installs. 'blake3_multi' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.<br>Valid values: `blake3_multi`, `blake3_single`, `random`, `md5`, `sha1`, `sha224`, `sha256`, `sha384`, `sha512`, `blake2b`, `blake2s`, `sha3_224`, `sha3_256`, `sha3_384`, `sha3_512`, `shake_128`, `shake_256`
remote_api_tokens: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.
"""
_root: Optional[Path] = PrivateAttr(default=None)
_config_file: Optional[Path] = PrivateAttr(default=None)
# fmt: off # fmt: off
type: Literal["InvokeAI"] = "InvokeAI"
# INTERNAL
schema_version: str = Field(default=CONFIG_SCHEMA_VERSION, description="Schema version of the config file. This is not a user-configurable setting.")
# This is only used during v3 models.yaml migration
legacy_models_yaml_path: Optional[Path] = Field(default=None, description="Path to the legacy models.yaml file. This is not a user-configurable setting.")
# WEB # WEB
host: str = Field(default="127.0.0.1", description="IP address to bind to. Use `0.0.0.0` to serve to your local network.") host : str = Field(default="127.0.0.1", description="IP address to bind to", json_schema_extra=Categories.WebServer)
port: int = Field(default=9090, description="Port to bind to.") port : int = Field(default=9090, description="Port to bind to", json_schema_extra=Categories.WebServer)
allow_origins: list[str] = Field(default=[], description="Allowed CORS origins.") allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", json_schema_extra=Categories.WebServer)
allow_credentials: bool = Field(default=True, description="Allow CORS credentials.") allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
allow_methods: list[str] = Field(default=["*"], description="Methods allowed for CORS.") allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
allow_headers: list[str] = Field(default=["*"], description="Headers allowed for CORS.") allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
ssl_certfile: Optional[Path] = Field(default=None, description="SSL certificate file for HTTPS. See https://www.uvicorn.org/settings/#https.") # SSL options correspond to https://www.uvicorn.org/settings/#https
ssl_keyfile: Optional[Path] = Field(default=None, description="SSL key file for HTTPS. See https://www.uvicorn.org/settings/#https.") ssl_certfile : Optional[Path] = Field(default=None, description="SSL certificate file (for HTTPS)", json_schema_extra=Categories.WebServer)
ssl_keyfile : Optional[Path] = Field(default=None, description="SSL key file", json_schema_extra=Categories.WebServer)
# MISC FEATURES # FEATURES
log_tokenization: bool = Field(default=False, description="Enable logging of parsed prompt tokens.") esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)
patchmatch: bool = Field(default=True, description="Enable patchmatch inpaint code.") internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", json_schema_extra=Categories.Features)
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", json_schema_extra=Categories.Features)
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", json_schema_extra=Categories.Features)
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', json_schema_extra=Categories.Features)
# PATHS # PATHS
autoimport_dir: Path = Field(default=Path("autoimport"), description="Path to a directory of models files to be imported on startup.") root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.") autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.") conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.") models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.") convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.") legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.") db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes', json_schema_extra=Categories.Paths)
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
# LOGGING # LOGGING
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".') log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', json_schema_extra=Categories.Logging)
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues # note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
log_format: LOG_FORMAT = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.') log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', json_schema_extra=Categories.Logging)
log_level: LOG_LEVEL = Field(default="info", description="Emit logging messages at this level or higher.") log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", json_schema_extra=Categories.Logging)
log_sql: bool = Field(default=False, description="Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.") log_sql : bool = Field(default=False, description="Log SQL queries", json_schema_extra=Categories.Logging)
# Development # Development
use_memory_db: bool = Field(default=False, description="Use in-memory database. Useful for development.") dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", json_schema_extra=Categories.Development)
dev_reload: bool = Field(default=False, description="Automatically reload when Python sources are changed. Does not reload node definitions.") profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
profile_graphs: bool = Field(default=False, description="Enable graph profiling using `cProfile`.") profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
profile_prefix: Optional[str] = Field(default=None, description="An optional prefix for profile output files.") profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
# CACHE # CACHE
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).") ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).") vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).") convert_cache : float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache)
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.") lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
# DEVICE # DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.") device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.") precision : Literal["auto", "float16", "bfloat16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
# GENERATION # GENERATION
sequential_guidance: bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.") sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", json_schema_extra=Categories.Generation)
attention_type: ATTENTION_TYPE = Field(default="auto", description="Attention type.") attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", json_schema_extra=Categories.Generation)
attention_slice_size: ATTENTION_SLICE_SIZE = Field(default="auto", description='Slice size, valid when attention_type=="sliced".') attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', json_schema_extra=Categories.Generation)
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).") force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Generation)
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") png_compress_level : int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
# QUEUE
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", json_schema_extra=Categories.Queue)
# NODES # NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.") allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", json_schema_extra=Categories.Nodes)
deny_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.") deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
node_cache_size: int = Field(default=512, description="How many cached nodes to keep in memory.") node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
# MODEL INSTALL # MODEL IMPORT
hashing_algorithm: HASHING_ALGORITHMS = Field(default="blake3_single", description="Model hashing algorthim for model installs. 'blake3_multi' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.") civitai_api_key : Optional[str] = Field(default=os.environ.get("CIVITAI_API_KEY"), description="API key for CivitAI", json_schema_extra=Categories.Other)
remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.")
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.MemoryPerformance)
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.MemoryPerformance)
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.MemoryPerformance)
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
# this is not referred to in the source code and can be removed entirely
#free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
# fmt: on # fmt: on
model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True) model_config = SettingsConfigDict(validate_assignment=True, env_prefix="INVOKEAI")
def update_config(self, config: dict[str, Any] | InvokeAIAppConfig, clobber: bool = True) -> None: def parse_args(
"""Updates the config, overwriting existing values. self,
argv: Optional[list[str]] = None,
Args: conf: Optional[DictConfig] = None,
config: A dictionary of config settings, or instance of `InvokeAIAppConfig`. If an instance of \ clobber: Optional[bool] = False,
`InvokeAIAppConfig`, only the explicitly set fields will be merged into the singleton config. ) -> None:
clobber: If `True`, overwrite existing values. If `False`, only update fields that are not already set.
""" """
Update settings with contents of init file, environment, and command-line settings.
if isinstance(config, dict): :param conf: alternate Omegaconf dictionary object
new_config = self.model_validate(config) :param argv: aternate sys.argv list
:param clobber: ovewrite any initialization parameters passed during initialization
"""
# Set the runtime root directory. We parse command-line switches here
# in order to pick up the --root_dir option.
super().parse_args(argv)
loaded_conf = None
if conf is None:
try:
loaded_conf = OmegaConf.load(self.root_dir / INIT_FILE)
except Exception:
pass
if isinstance(loaded_conf, DictConfig):
InvokeAISettings.initconf = loaded_conf
else: else:
new_config = config InvokeAISettings.initconf = conf
for field_name in new_config.model_fields_set: # parse args again in order to pick up settings in configuration file
new_value = getattr(new_config, field_name) super().parse_args(argv)
current_value = getattr(self, field_name)
if field_name in self.model_fields_set and not clobber: if self.singleton_init and not clobber:
continue # When setting values in this way, set validate_assignment to true if you want to validate the value.
for k, v in self.singleton_init.items():
setattr(self, k, v)
if new_value != current_value: @classmethod
setattr(self, field_name, new_value) def get_config(cls, **kwargs: Any) -> InvokeAIAppConfig:
"""Return a singleton InvokeAIAppConfig configuration object."""
if (
cls.singleton_config is None
or type(cls.singleton_config) is not cls
or (kwargs and cls.singleton_init != kwargs)
):
cls.singleton_config = cls(**kwargs)
cls.singleton_init = kwargs
return cls.singleton_config
def write_file(self, dest_path: Path, as_example: bool = False) -> None: @property
"""Write the current configuration to file. This will overwrite the existing file. def root_path(self) -> Path:
"""Path to the runtime root directory."""
if self.root:
root = Path(self.root).expanduser().absolute()
else:
root = self.find_root().expanduser().absolute()
self.root = root # insulate ourselves from relative paths that may change
return root.resolve()
A `meta` stanza is added to the top of the file, containing metadata about the config file. This is not stored in the config object. @property
def root_dir(self) -> Path:
Args: """Alias for above."""
dest_path: Path to write the config to. return self.root_path
"""
dest_path.parent.mkdir(parents=True, exist_ok=True)
with open(dest_path, "w") as file:
# Meta fields should be written in a separate stanza - skip legacy_models_yaml_path
meta_dict = self.model_dump(mode="json", include={"schema_version"})
# User settings
config_dict = self.model_dump(
mode="json",
exclude_unset=False if as_example else True,
exclude_defaults=False if as_example else True,
exclude_none=True if as_example else False,
exclude={"schema_version", "legacy_models_yaml_path"},
)
if as_example:
file.write(
"# This is an example file with default and example settings. Use the values here as a baseline.\n\n"
)
file.write("# Internal metadata - do not edit:\n")
file.write(yaml.dump(meta_dict, sort_keys=False))
file.write("\n")
file.write("# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/:\n")
if len(config_dict) > 0:
file.write(yaml.dump(config_dict, sort_keys=False))
def _resolve(self, partial_path: Path) -> Path: def _resolve(self, partial_path: Path) -> Path:
return (self.root_path / partial_path).resolve() return (self.root_path / partial_path).resolve()
@property @property
def root_path(self) -> Path: def init_file_path(self) -> Path:
"""Path to the runtime root directory, resolved to an absolute path.""" """Path to invokeai.yaml."""
if self._root: resolved_path = self._resolve(INIT_FILE)
root = Path(self._root).expanduser().absolute()
else:
root = self.find_root().expanduser().absolute()
self._root = root # insulate ourselves from relative paths that may change
return root.resolve()
@property
def config_file_path(self) -> Path:
"""Path to invokeai.yaml, resolved to an absolute path.."""
resolved_path = self._resolve(self._config_file or INIT_FILE)
assert resolved_path is not None assert resolved_path is not None
return resolved_path return resolved_path
@property @property
def autoimport_path(self) -> Path: def output_path(self) -> Optional[Path]:
"""Path to the autoimports directory, resolved to an absolute path..""" """Path to defaults outputs directory."""
return self._resolve(self.autoimport_dir) return self._resolve(self.outdir)
@property
def outputs_path(self) -> Optional[Path]:
"""Path to the outputs directory, resolved to an absolute path.."""
return self._resolve(self.outputs_dir)
@property @property
def db_path(self) -> Path: def db_path(self) -> Path:
"""Path to the invokeai.db file, resolved to an absolute path..""" """Path to the invokeai.db file."""
db_dir = self._resolve(self.db_dir) db_dir = self._resolve(self.db_dir)
assert db_dir is not None assert db_dir is not None
return db_dir / DB_FILE return db_dir / DB_FILE
@property
def model_conf_path(self) -> Path:
"""Path to models configuration file."""
return self._resolve(self.conf_path)
@property @property
def legacy_conf_path(self) -> Path: def legacy_conf_path(self) -> Path:
"""Path to directory of legacy configuration files (e.g. v1-inference.yaml), resolved to an absolute path..""" """Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
return self._resolve(self.legacy_conf_dir) return self._resolve(self.legacy_conf_dir)
@property @property
def models_path(self) -> Path: def models_path(self) -> Path:
"""Path to the models directory, resolved to an absolute path..""" """Path to the models directory."""
return self._resolve(self.models_dir) return self._resolve(self.models_dir)
@property @property
def convert_cache_path(self) -> Path: def models_convert_cache_path(self) -> Path:
"""Path to the converted cache models directory, resolved to an absolute path..""" """Path to the converted cache models directory."""
return self._resolve(self.convert_cache_dir) return self._resolve(self.convert_cache_dir)
@property @property
def custom_nodes_path(self) -> Path: def custom_nodes_path(self) -> Path:
"""Path to the custom nodes directory, resolved to an absolute path..""" """Path to the custom nodes directory."""
custom_nodes_path = self._resolve(self.custom_nodes_dir) custom_nodes_path = self._resolve(self.custom_nodes_dir)
assert custom_nodes_path is not None assert custom_nodes_path is not None
return custom_nodes_path return custom_nodes_path
# the following methods support legacy calls leftover from the Globals era
@property
def full_precision(self) -> bool:
"""Return true if precision set to float32."""
return self.precision == "float32"
@property
def try_patchmatch(self) -> bool:
"""Return true if patchmatch true."""
return self.patchmatch
@property
def nsfw_checker(self) -> bool:
"""Return value for NSFW checker. The NSFW node is always active and disabled from Web UI."""
return True
@property
def invisible_watermark(self) -> bool:
"""Return value of invisible watermark. It is always active and disabled from Web UI."""
return True
@property
def ram_cache_size(self) -> float:
"""Return the ram cache size using the legacy or modern setting (GB)."""
return self.max_cache_size or self.ram
@property
def vram_cache_size(self) -> float:
"""Return the vram cache size using the legacy or modern setting (GB)."""
return self.max_vram_cache_size or self.vram
@property
def convert_cache_size(self) -> float:
"""Return the convert cache size on disk (GB)."""
return self.convert_cache
@property
def use_cpu(self) -> bool:
"""Return true if the device is set to CPU or the always_use_cpu flag is set."""
return self.always_use_cpu or self.device == "cpu"
@property
def disable_xformers(self) -> bool:
"""Return true if enable_xformers is false (reversed logic) and attention type is not set to xformers."""
disabled_in_config = not self.xformers_enabled
return disabled_in_config and self.attention_type != "xformers"
@property @property
def profiles_path(self) -> Path: def profiles_path(self) -> Path:
"""Path to the graph profiles directory, resolved to an absolute path..""" """Path to the graph profiles directory."""
return self._resolve(self.profiles_dir) return self._resolve(self.profiles_dir)
@staticmethod @staticmethod
def find_root() -> Path: def find_root() -> Path:
"""Choose the runtime root directory when not specified on command line or init file.""" """Choose the runtime root directory when not specified on command line or init file."""
return _find_root()
def get_invokeai_config(**kwargs: Any) -> InvokeAIAppConfig:
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
return InvokeAIAppConfig.get_config(**kwargs)
def _find_root() -> Path:
venv = Path(os.environ.get("VIRTUAL_ENV") or ".") venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"): if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ["INVOKEAI_ROOT"]) root = Path(os.environ["INVOKEAI_ROOT"])
@@ -330,158 +493,3 @@ class InvokeAIAppConfig(BaseSettings):
else: else:
root = Path("~/invokeai").expanduser().resolve() root = Path("~/invokeai").expanduser().resolve()
return root return root
class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
"""A version of `InvokeAIAppConfig` that does not automatically parse any settings from environment variables
or any file.
This is useful for writing out a default config file.
Note that init settings are set if provided.
"""
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (init_settings,)
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate a v3 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
for _category_name, category_dict in config_dict["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
parsed_config_dict["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion". If if the incoming config has that as the value, we won't set it.
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
# Else we do not attempt to migrate this setting
if v != "configs/stable-diffusion":
parsed_config_dict["legacy_conf_dir"] = v
elif Path(v).name == "stable-diffusion":
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
# When migrating the config file, we should not include currently-set environment variables.
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
Args:
config_path: Path to the config file.
Returns:
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
"""
assert config_path.suffix == ".yaml"
with open(config_path) as file:
loaded_config_dict = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict)
if "InvokeAI" in loaded_config_dict:
# This is a v3 config file, attempt to migrate it
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config
else:
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
@lru_cache(maxsize=1)
def get_config() -> InvokeAIAppConfig:
"""Get the global singleton app config.
When first called, this function:
- Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
- Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
- Sets the root dir, if provided via CLI args.
- Logs in to HF if there is no valid token already.
- Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
- Reads and merges in settings from the config file if it exists, else writes out a default config file.
On subsequent calls, the object is returned from the cache.
"""
# This object includes environment variables, as parsed by pydantic-settings
config = InvokeAIAppConfig()
args = InvokeAIArgs.args
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
# If it is False, we should just return a default config and not set the root, log in to HF, etc.
if not InvokeAIArgs.did_parse:
return config
# Set CLI args
if root := getattr(args, "root", None):
config._root = Path(root)
if config_file := getattr(args, "config_file", None):
config._config_file = Path(config_file)
# Create the example config file, with some extra example values provided
example_config = DefaultInvokeAIAppConfig()
example_config.remote_api_tokens = [
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
]
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
# Copy all legacy configs - We know `__path__[0]` is correct here
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
if config.config_file_path.exists():
config_from_file = load_and_migrate_config(config.config_file_path)
# Clobbering here will overwrite any settings that were set via environment variables
config.update_config(config_from_file, clobber=False)
else:
# We should never write env vars to the config file
default_config = DefaultInvokeAIAppConfig()
default_config.write_file(config.config_file_path, as_example=False)
return config

View File

@@ -1,5 +1,4 @@
"""Init file for download queue.""" """Init file for download queue."""
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
from .download_default import DownloadQueueService, TqdmProgress from .download_default import DownloadQueueService, TqdmProgress

View File

@@ -85,10 +85,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
self._logger.info(f"Waiting for {len(active_jobs)} active download jobs to complete") self._logger.info(f"Waiting for {len(active_jobs)} active download jobs to complete")
with self._queue.mutex: with self._queue.mutex:
self._queue.queue.clear() self._queue.queue.clear()
self.cancel_all_jobs() self.join() # wait for all active jobs to finish
self._stop_event.set() self._stop_event.set()
for thread in self._worker_pool:
thread.join()
self._worker_pool.clear() self._worker_pool.clear()
def submit_download_job( def submit_download_job(
@@ -226,6 +224,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
job.job_started = get_iso_timestamp() job.job_started = get_iso_timestamp()
self._do_download(job) self._do_download(job)
self._signal_job_complete(job) self._signal_job_complete(job)
except (OSError, HTTPError) as excp: except (OSError, HTTPError) as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})" job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc() job.error = traceback.format_exc()

View File

@@ -3,7 +3,7 @@
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import ( from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus, BatchStatus,
EnqueueBatchResult, EnqueueBatchResult,
@@ -12,12 +12,10 @@ from invokeai.app.services.session_queue.session_queue_common import (
) )
from invokeai.app.util.misc import get_timestamp from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager import AnyModelConfig from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.config import SubModelType
class EventServiceBase: class EventServiceBase:
queue_event: str = "queue_event" queue_event: str = "queue_event"
bulk_download_event: str = "bulk_download_event"
download_event: str = "download_event" download_event: str = "download_event"
model_event: str = "model_event" model_event: str = "model_event"
@@ -26,14 +24,6 @@ class EventServiceBase:
def dispatch(self, event_name: str, payload: Any) -> None: def dispatch(self, event_name: str, payload: Any) -> None:
pass pass
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
"""Bulk download events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.bulk_download_event,
payload={"event": event_name, "data": payload},
)
def __emit_queue_event(self, event_name: str, payload: dict) -> None: def __emit_queue_event(self, event_name: str, payload: dict) -> None:
"""Queue events are emitted to a room with queue_id as the room name""" """Queue events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp() payload["timestamp"] = get_timestamp()
@@ -81,7 +71,7 @@ class EventServiceBase:
"graph_execution_state_id": graph_execution_state_id, "graph_execution_state_id": graph_execution_state_id,
"node_id": node_id, "node_id": node_id,
"source_node_id": source_node_id, "source_node_id": source_node_id,
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None, "progress_image": progress_image.model_dump() if progress_image is not None else None,
"step": step, "step": step,
"order": order, "order": order,
"total_steps": total_steps, "total_steps": total_steps,
@@ -181,7 +171,6 @@ class EventServiceBase:
queue_batch_id: str, queue_batch_id: str,
graph_execution_state_id: str, graph_execution_state_id: str,
model_config: AnyModelConfig, model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None: ) -> None:
"""Emitted when a model is requested""" """Emitted when a model is requested"""
self.__emit_queue_event( self.__emit_queue_event(
@@ -191,8 +180,7 @@ class EventServiceBase:
"queue_item_id": queue_item_id, "queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id, "queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id, "graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"), "model_config": model_config.model_dump(),
"submodel_type": submodel_type,
}, },
) )
@@ -203,7 +191,6 @@ class EventServiceBase:
queue_batch_id: str, queue_batch_id: str,
graph_execution_state_id: str, graph_execution_state_id: str,
model_config: AnyModelConfig, model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None: ) -> None:
"""Emitted when a model is correctly loaded (returns model info)""" """Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event( self.__emit_queue_event(
@@ -213,8 +200,53 @@ class EventServiceBase:
"queue_item_id": queue_item_id, "queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id, "queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id, "graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"), "model_config": model_config.model_dump(),
"submodel_type": submodel_type, },
)
def emit_session_retrieval_error(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when session retrieval fails"""
self.__emit_queue_event(
event_name="session_retrieval_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"error_type": error_type,
"error": error,
},
)
def emit_invocation_retrieval_error(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when invocation retrieval fails"""
self.__emit_queue_event(
event_name="invocation_retrieval_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"error_type": error_type,
"error": error,
}, },
) )
@@ -259,8 +291,8 @@ class EventServiceBase:
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None, "started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None, "completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
}, },
"batch_status": batch_status.model_dump(mode="json"), "batch_status": batch_status.model_dump(),
"queue_status": queue_status.model_dump(mode="json"), "queue_status": queue_status.model_dump(),
}, },
) )
@@ -362,7 +394,6 @@ class EventServiceBase:
bytes: int, bytes: int,
total_bytes: int, total_bytes: int,
parts: List[Dict[str, Union[str, int]]], parts: List[Dict[str, Union[str, int]]],
id: int,
) -> None: ) -> None:
""" """
Emit at intervals while the install job is in progress (remote models only). Emit at intervals while the install job is in progress (remote models only).
@@ -382,21 +413,9 @@ class EventServiceBase:
"bytes": bytes, "bytes": bytes,
"total_bytes": total_bytes, "total_bytes": total_bytes,
"parts": parts, "parts": parts,
"id": id,
}, },
) )
def emit_model_install_downloads_done(self, source: str) -> None:
"""
Emit once when all parts are downloaded, but before the probing and registration start.
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_downloads_done",
payload={"source": source},
)
def emit_model_install_running(self, source: str) -> None: def emit_model_install_running(self, source: str) -> None:
""" """
Emit once when an install job becomes active. Emit once when an install job becomes active.
@@ -408,7 +427,7 @@ class EventServiceBase:
payload={"source": source}, payload={"source": source},
) )
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None: def emit_model_install_completed(self, source: str, key: str, total_bytes: Optional[int] = None) -> None:
""" """
Emit when an install job is completed successfully. Emit when an install job is completed successfully.
@@ -418,10 +437,14 @@ class EventServiceBase:
""" """
self.__emit_model_event( self.__emit_model_event(
event_name="model_install_completed", event_name="model_install_completed",
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id}, payload={
"source": source,
"total_bytes": total_bytes,
"key": key,
},
) )
def emit_model_install_cancelled(self, source: str, id: int) -> None: def emit_model_install_cancelled(self, source: str) -> None:
""" """
Emit when an install job is cancelled. Emit when an install job is cancelled.
@@ -429,10 +452,15 @@ class EventServiceBase:
""" """
self.__emit_model_event( self.__emit_model_event(
event_name="model_install_cancelled", event_name="model_install_cancelled",
payload={"source": source, "id": id}, payload={"source": source},
) )
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None: def emit_model_install_error(
self,
source: str,
error_type: str,
error: str,
) -> None:
""" """
Emit when an install job encounters an exception. Emit when an install job encounters an exception.
@@ -442,45 +470,9 @@ class EventServiceBase:
""" """
self.__emit_model_event( self.__emit_model_event(
event_name="model_install_error", event_name="model_install_error",
payload={"source": source, "error_type": error_type, "error": error, "id": id},
)
def emit_bulk_download_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk download starts"""
self._emit_bulk_download_event(
event_name="bulk_download_started",
payload={ payload={
"bulk_download_id": bulk_download_id, "source": source,
"bulk_download_item_id": bulk_download_item_id, "error_type": error_type,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_completed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk download completes"""
self._emit_bulk_download_event(
event_name="bulk_download_completed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_failed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> None:
"""Emitted when a bulk download fails"""
self._emit_bulk_download_event(
event_name="bulk_download_failed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
"error": error, "error": error,
}, },
) )

View File

@@ -82,7 +82,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image_path, image_path,
"PNG", "PNG",
pnginfo=pnginfo, pnginfo=pnginfo,
compress_level=self.__invoker.services.configuration.pil_compress_level, compress_level=self.__invoker.services.configuration.png_compress_level,
) )
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)

View File

@@ -41,9 +41,8 @@ class InvocationCacheBase(ABC):
"""Clears the cache""" """Clears the cache"""
pass pass
@staticmethod
@abstractmethod @abstractmethod
def create_key(invocation: BaseInvocation) -> int: def create_key(self, invocation: BaseInvocation) -> int:
"""Gets the key for the invocation's cache item""" """Gets the key for the invocation's cache item"""
pass pass

View File

@@ -61,7 +61,9 @@ class MemoryInvocationCache(InvocationCacheBase):
self._delete_oldest_access(number_to_delete) self._delete_oldest_access(number_to_delete)
self._cache[key] = CachedItem( self._cache[key] = CachedItem(
invocation_output, invocation_output,
invocation_output.model_dump_json(warnings=False, exclude_defaults=True, exclude_unset=True), invocation_output.model_dump_json(
warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
),
) )
def _delete_oldest_access(self, number_to_delete: int) -> None: def _delete_oldest_access(self, number_to_delete: int) -> None:
@@ -79,7 +81,7 @@ class MemoryInvocationCache(InvocationCacheBase):
with self._lock: with self._lock:
return self._delete(key) return self._delete(key)
def clear(self) -> None: def clear(self, *args, **kwargs) -> None:
with self._lock: with self._lock:
if self._max_cache_size == 0: if self._max_cache_size == 0:
return return

View File

@@ -0,0 +1,5 @@
from abc import ABC
class InvocationProcessorABC(ABC): # noqa: B024
pass

View File

@@ -0,0 +1,15 @@
from pydantic import BaseModel, Field
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
class CanceledException(Exception):
"""Execution canceled by user."""
pass

View File

@@ -0,0 +1,243 @@
import time
import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Event, Thread
from typing import Optional
import invokeai.backend.util.logging as logger
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
from invokeai.app.services.invocation_stats.invocation_stats_common import (
GESStatsNotFoundError,
)
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker
from .invocation_processor_base import InvocationProcessorABC
from .invocation_processor_common import CanceledException
class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
__stop_event: Event
__invoker: Invoker
__threadLimit: BoundedSemaphore
def start(self, invoker: Invoker) -> None:
# LS - this will probably break
# but the idea is to enable multithreading up to the number of available
# GPUs. Nodes will block on model loading if no GPU is free.
self.__threadLimit = BoundedSemaphore(invoker.services.model_manager.gpu_count)
self.__invoker = invoker
self.__stop_event = Event()
self.__invoker_thread = Thread(
name="invoker_processor",
target=self.__process,
kwargs={"stop_event": self.__stop_event},
)
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
self.__invoker_thread.start()
def stop(self, *args, **kwargs) -> None:
self.__stop_event.set()
def __process(self, stop_event: Event):
try:
self.__threadLimit.acquire()
queue_item: Optional[InvocationQueueItem] = None
profiler = (
Profiler(
logger=self.__invoker.services.logger,
output_dir=self.__invoker.services.configuration.profiles_path,
prefix=self.__invoker.services.configuration.profile_prefix,
)
if self.__invoker.services.configuration.profile_graphs
else None
)
def stats_cleanup(graph_execution_state_id: str) -> None:
if profiler:
profile_path = profiler.stop()
stats_path = profile_path.with_suffix(".json")
self.__invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=graph_execution_state_id, output_path=stats_path
)
with suppress(GESStatsNotFoundError):
self.__invoker.services.performance_statistics.log_stats(graph_execution_state_id)
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state_id)
while not stop_event.is_set():
try:
queue_item = self.__invoker.services.queue.get()
except Exception as e:
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
if not queue_item: # Probably stopping
# do not hammer the queue
time.sleep(0.5)
continue
if profiler and profiler.profile_id != queue_item.graph_execution_state_id:
profiler.start(profile_id=queue_item.graph_execution_state_id)
try:
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
queue_item.graph_execution_state_id
)
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
self.__invoker.services.events.emit_session_retrieval_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=queue_item.graph_execution_state_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
continue
try:
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
self.__invoker.services.events.emit_invocation_retrieval_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=queue_item.graph_execution_state_id,
node_id=queue_item.invocation_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
continue
# get the source node id to provide to clients (the prepared node id is not as useful)
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
# Send starting event
self.__invoker.services.events.emit_invocation_started(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.model_dump(),
source_node_id=source_node_id,
)
# Invoke
try:
graph_id = graph_execution_state.id
with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id):
# use the internal invoke_internal(), which wraps the node's invoke() method,
# which handles a few things:
# - nodes that require a value, but get it only from a connection
# - referencing the invocation cache instead of executing the node
context_data = InvocationContextData(
invocation=invocation,
session_id=graph_id,
workflow=queue_item.workflow,
source_node_id=source_node_id,
queue_id=queue_item.session_queue_id,
queue_item_id=queue_item.session_queue_item_id,
batch_id=queue_item.session_queue_batch_id,
)
context = build_invocation_context(
services=self.__invoker.services,
context_data=context_data,
)
outputs = invocation.invoke_internal(context=context, services=self.__invoker.services)
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
continue
# Save outputs and history
graph_execution_state.complete(invocation.id, outputs)
# Save the state changes
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.model_dump(),
source_node_id=source_node_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
pass
except CanceledException:
stats_cleanup(graph_execution_state.id)
pass
except Exception as e:
error = traceback.format_exc()
logger.error(error)
# Save error
graph_execution_state.set_node_error(invocation.id, error)
# Save the state changes
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
# Send error event
self.__invoker.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.model_dump(),
source_node_id=source_node_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
continue
# Queue any further commands if invoking all
is_complete = graph_execution_state.is_complete()
if queue_item.invoke_all and not is_complete:
try:
self.__invoker.invoke(
session_queue_batch_id=queue_item.session_queue_batch_id,
session_queue_item_id=queue_item.session_queue_item_id,
session_queue_id=queue_item.session_queue_id,
graph_execution_state=graph_execution_state,
workflow=queue_item.workflow,
invoke_all=True,
)
except Exception as e:
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
self.__invoker.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.model_dump(),
source_node_id=source_node_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
)
stats_cleanup(graph_execution_state.id)
except KeyboardInterrupt:
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
finally:
self.__threadLimit.release()

View File

@@ -0,0 +1,26 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC, abstractmethod
from typing import Optional
from .invocation_queue_common import InvocationQueueItem
class InvocationQueueABC(ABC):
"""Abstract base class for all invocation queues"""
@abstractmethod
def get(self) -> InvocationQueueItem:
pass
@abstractmethod
def put(self, item: Optional[InvocationQueueItem]) -> None:
pass
@abstractmethod
def cancel(self, graph_execution_state_id: str) -> None:
pass
@abstractmethod
def is_canceled(self, graph_execution_state_id: str) -> bool:
pass

View File

@@ -0,0 +1,23 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import time
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
invocation_id: str = Field(description="The ID of the node being invoked")
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
session_queue_item_id: int = Field(
description="The ID of session queue item from which this invocation queue item came"
)
session_queue_batch_id: str = Field(
description="The ID of the session batch from which this invocation queue item came"
)
workflow: Optional[WorkflowWithoutID] = Field(description="The workflow associated with this queue item")
invoke_all: bool = Field(default=False)
timestamp: float = Field(default_factory=time.time)

View File

@@ -0,0 +1,44 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import time
from queue import Queue
from typing import Optional
from .invocation_queue_base import InvocationQueueABC
from .invocation_queue_common import InvocationQueueItem
class MemoryInvocationQueue(InvocationQueueABC):
__queue: Queue
__cancellations: dict[str, float]
def __init__(self):
self.__queue = Queue()
self.__cancellations = {}
def get(self) -> InvocationQueueItem:
item = self.__queue.get()
while (
isinstance(item, InvocationQueueItem)
and item.graph_execution_state_id in self.__cancellations
and self.__cancellations[item.graph_execution_state_id] > item.timestamp
):
item = self.__queue.get()
# Clear old items
for graph_execution_state_id in list(self.__cancellations.keys()):
if self.__cancellations[graph_execution_state_id] < item.timestamp:
del self.__cancellations[graph_execution_state_id]
return item
def put(self, item: Optional[InvocationQueueItem]) -> None:
self.__queue.put(item)
def cancel(self, graph_execution_state_id: str) -> None:
if graph_execution_state_id not in self.__cancellations:
self.__cancellations[graph_execution_state_id] = time.time()
def is_canceled(self, graph_execution_state_id: str) -> bool:
return graph_execution_state_id in self.__cancellations

View File

@@ -16,7 +16,6 @@ if TYPE_CHECKING:
from .board_images.board_images_base import BoardImagesServiceABC from .board_images.board_images_base import BoardImagesServiceABC
from .board_records.board_records_base import BoardRecordStorageBase from .board_records.board_records_base import BoardRecordStorageBase
from .boards.boards_base import BoardServiceABC from .boards.boards_base import BoardServiceABC
from .bulk_download.bulk_download_base import BulkDownloadBase
from .config import InvokeAIAppConfig from .config import InvokeAIAppConfig
from .download import DownloadQueueServiceBase from .download import DownloadQueueServiceBase
from .events.events_base import EventServiceBase from .events.events_base import EventServiceBase
@@ -24,12 +23,15 @@ if TYPE_CHECKING:
from .image_records.image_records_base import ImageRecordStorageBase from .image_records.image_records_base import ImageRecordStorageBase
from .images.images_base import ImageServiceABC from .images.images_base import ImageServiceABC
from .invocation_cache.invocation_cache_base import InvocationCacheBase from .invocation_cache.invocation_cache_base import InvocationCacheBase
from .invocation_processor.invocation_processor_base import InvocationProcessorABC
from .invocation_queue.invocation_queue_base import InvocationQueueABC
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .model_images.model_images_base import ModelImageFileStorageBase from .item_storage.item_storage_base import ItemStorageABC
from .model_manager.model_manager_base import ModelManagerServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase
from .names.names_base import NameServiceBase from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase from .session_queue.session_queue_base import SessionQueueBase
from .shared.graph import GraphExecutionState
from .urls.urls_base import UrlServiceBase from .urls.urls_base import UrlServiceBase
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
@@ -43,17 +45,18 @@ class InvocationServices:
board_image_records: "BoardImageRecordStorageBase", board_image_records: "BoardImageRecordStorageBase",
boards: "BoardServiceABC", boards: "BoardServiceABC",
board_records: "BoardRecordStorageBase", board_records: "BoardRecordStorageBase",
bulk_download: "BulkDownloadBase",
configuration: "InvokeAIAppConfig", configuration: "InvokeAIAppConfig",
events: "EventServiceBase", events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
images: "ImageServiceABC", images: "ImageServiceABC",
image_files: "ImageFileStorageBase", image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase", image_records: "ImageRecordStorageBase",
logger: "Logger", logger: "Logger",
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase", model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase", download_queue: "DownloadQueueServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase", performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
session_queue: "SessionQueueBase", session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase", session_processor: "SessionProcessorBase",
invocation_cache: "InvocationCacheBase", invocation_cache: "InvocationCacheBase",
@@ -67,17 +70,18 @@ class InvocationServices:
self.board_image_records = board_image_records self.board_image_records = board_image_records
self.boards = boards self.boards = boards
self.board_records = board_records self.board_records = board_records
self.bulk_download = bulk_download
self.configuration = configuration self.configuration = configuration
self.events = events self.events = events
self.graph_execution_manager = graph_execution_manager
self.images = images self.images = images
self.image_files = image_files self.image_files = image_files
self.image_records = image_records self.image_records = image_records
self.logger = logger self.logger = logger
self.model_images = model_images
self.model_manager = model_manager self.model_manager = model_manager
self.download_queue = download_queue self.download_queue = download_queue
self.processor = processor
self.performance_statistics = performance_statistics self.performance_statistics = performance_statistics
self.queue = queue
self.session_queue = session_queue self.session_queue = session_queue
self.session_processor = session_processor self.session_processor = session_processor
self.invocation_cache = invocation_cache self.invocation_cache = invocation_cache

View File

@@ -3,7 +3,7 @@
Usage: Usage:
statistics = InvocationStatsService() statistics = InvocationStatsService(graph_execution_manager)
with statistics.collect_stats(invocation, graph_execution_state.id): with statistics.collect_stats(invocation, graph_execution_state.id):
... execute graphs... ... execute graphs...
statistics.log_stats() statistics.log_stats()
@@ -30,7 +30,7 @@ writes to the system log is stored in InvocationServices.performance_statistics.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import ContextManager from typing import Iterator
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary
@@ -50,7 +50,7 @@ class InvocationStatsServiceBase(ABC):
self, self,
invocation: BaseInvocation, invocation: BaseInvocation,
graph_execution_state_id: str, graph_execution_state_id: str,
) -> ContextManager[None]: ) -> Iterator[None]:
""" """
Return a context object that will capture the statistics on the execution Return a context object that will capture the statistics on the execution
of invocaation. Use with: to place around the part of the code that executes the invocation. of invocaation. Use with: to place around the part of the code that executes the invocation.
@@ -60,8 +60,12 @@ class InvocationStatsServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def reset_stats(self): def reset_stats(self, graph_execution_state_id: str) -> None:
"""Reset all stored statistics.""" """
Reset all statistics for the indicated graph.
:param graph_execution_state_id: The id of the session whose stats to reset.
:raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
"""
pass pass
@abstractmethod @abstractmethod

View File

@@ -2,7 +2,7 @@ import json
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Generator from typing import Iterator
import psutil import psutil
import torch import torch
@@ -10,6 +10,7 @@ import torch
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError
from invokeai.backend.model_manager.load.model_cache import CacheStats from invokeai.backend.model_manager.load.model_cache import CacheStats
from .invocation_stats_base import InvocationStatsServiceBase from .invocation_stats_base import InvocationStatsServiceBase
@@ -41,7 +42,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
self._invoker = invoker self._invoker = invoker
@contextmanager @contextmanager
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Generator[None, None, None]: def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]:
# This is to handle case of the model manager not being initialized, which happens # This is to handle case of the model manager not being initialized, which happens
# during some tests. # during some tests.
services = self._invoker.services services = self._invoker.services
@@ -50,6 +51,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
self._stats[graph_execution_state_id] = GraphExecutionStats() self._stats[graph_execution_state_id] = GraphExecutionStats()
self._cache_stats[graph_execution_state_id] = CacheStats() self._cache_stats[graph_execution_state_id] = CacheStats()
# Prune stale stats. There should be none since we're starting a new graph, but just in case.
self._prune_stale_stats()
# Record state before the invocation. # Record state before the invocation.
start_time = time.time() start_time = time.time()
start_ram = psutil.Process().memory_info().rss start_ram = psutil.Process().memory_info().rss
@@ -74,9 +78,42 @@ class InvocationStatsService(InvocationStatsServiceBase):
) )
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats) self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
def reset_stats(self): def _prune_stale_stats(self) -> None:
self._stats = {} """Check all graphs being tracked and prune any that have completed/errored.
self._cache_stats = {}
This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so
for now we call this function periodically to prevent them from accumulating.
"""
to_prune: list[str] = []
for graph_execution_state_id in self._stats:
try:
graph_execution_state = self._invoker.services.graph_execution_manager.get(graph_execution_state_id)
except ItemNotFoundError:
# TODO(ryand): What would cause this? Should this exception just be allowed to propagate?
logger.warning(f"Failed to get graph state for {graph_execution_state_id}.")
continue
if not graph_execution_state.is_complete():
# The graph is still running, don't prune it.
continue
to_prune.append(graph_execution_state_id)
for graph_execution_state_id in to_prune:
del self._stats[graph_execution_state_id]
del self._cache_stats[graph_execution_state_id]
if len(to_prune) > 0:
logger.info(f"Pruned stale graph stats for {to_prune}.")
def reset_stats(self, graph_execution_state_id: str):
try:
del self._stats[graph_execution_state_id]
del self._cache_stats[graph_execution_state_id]
except KeyError as e:
raise GESStatsNotFoundError(
f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}."
) from e
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary: def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
graph_stats_summary = self._get_graph_summary(graph_execution_state_id) graph_stats_summary = self._get_graph_summary(graph_execution_state_id)

View File

@@ -1,7 +1,12 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Optional
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from .invocation_queue.invocation_queue_common import InvocationQueueItem
from .invocation_services import InvocationServices from .invocation_services import InvocationServices
from .shared.graph import Graph, GraphExecutionState
class Invoker: class Invoker:
@@ -13,6 +18,51 @@ class Invoker:
self.services = services self.services = services
self._start() self._start()
def invoke(
self,
session_queue_id: str,
session_queue_item_id: int,
session_queue_batch_id: str,
graph_execution_state: GraphExecutionState,
workflow: Optional[WorkflowWithoutID] = None,
invoke_all: bool = False,
) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
# Get the next invocation
invocation = graph_execution_state.next()
if not invocation:
return None
# Save the execution state
self.services.graph_execution_manager.set(graph_execution_state)
# Queue the invocation
self.services.queue.put(
InvocationQueueItem(
session_queue_id=session_queue_id,
session_queue_item_id=session_queue_item_id,
session_queue_batch_id=session_queue_batch_id,
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
workflow=workflow,
invoke_all=invoke_all,
)
)
return invocation.id
def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state)
return new_state
def cancel(self, graph_execution_state_id: str) -> None:
"""Cancels the given execution state"""
self.services.queue.cancel(graph_execution_state_id)
def __start_service(self, service) -> None: def __start_service(self, service) -> None:
# Call start() method on any services that have it # Call start() method on any services that have it
start_op = getattr(service, "start", None) start_op = getattr(service, "start", None)
@@ -35,3 +85,5 @@ class Invoker:
# First stop all services # First stop all services
for service in vars(self.services): for service in vars(self.services):
self.__stop_service(getattr(self.services, service)) self.__stop_service(getattr(self.services, service))
self.services.queue.put(None)

View File

@@ -1,33 +0,0 @@
from abc import ABC, abstractmethod
from pathlib import Path
from PIL.Image import Image as PILImageType
class ModelImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, model_key: str) -> PILImageType:
"""Retrieves a model image as PIL Image."""
pass
@abstractmethod
def get_path(self, model_key: str) -> Path:
"""Gets the internal path to a model image."""
pass
@abstractmethod
def get_url(self, model_key: str) -> str | None:
"""Gets the URL to fetch a model image."""
pass
@abstractmethod
def save(self, image: PILImageType, model_key: str) -> None:
"""Saves a model image."""
pass
@abstractmethod
def delete(self, model_key: str) -> None:
"""Deletes a model image."""
pass

View File

@@ -1,20 +0,0 @@
# TODO: Should these excpetions subclass existing python exceptions?
class ModelImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Model image file not found"):
super().__init__(message)
class ModelImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Model image file not saved"):
super().__init__(message)
class ModelImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Model image file not deleted"):
super().__init__(message)

Some files were not shown because too many files have changed in this diff Show More