Compare commits

..

120 Commits

Author SHA1 Message Date
Lincoln Stein
b85f2bc87d add support for multi-gpu rendering
This commit adds speculative support for parallel rendering across
multiple GPUs. The parallelism is at the level of a session. Each
session is given access to a different GPU. When all GPUs are busy,
execution of the session will block until a GPU becomes available.

The code is untested at the current time, and is being posted for
comment.
2024-02-19 15:21:55 -05:00
Lincoln Stein
b06d63fb34 remove errant def that was crashing invokeai-configure 2024-02-19 17:31:53 +11:00
dunkeroni
5278a64301 one more redundant RGB convert removed 2024-02-19 17:31:08 +11:00
dunkeroni
4de4473c0f chore: ruff formatting 2024-02-19 17:31:08 +11:00
dunkeroni
2c28a850ca chore(invocations): remove redundant RGB conversions 2024-02-19 17:31:08 +11:00
dunkeroni
6dada3326d chore(invocations): use IMAGE_MODES constant literal 2024-02-19 17:31:08 +11:00
dunkeroni
2dfdc02ec8 fix: removed custom module 2024-02-19 17:31:08 +11:00
dunkeroni
1f19db4c6a fix(nodes): canny preprocessor uses RGBA again 2024-02-19 17:31:08 +11:00
dunkeroni
7c150c27f2 feat(nodes): format option for get_image method
Also default CNet preprocessors to "RGB"
2024-02-19 17:31:08 +11:00
blessedcoolant
248916c190 fix: Alpha channel causing issue with DW Processor 2024-02-19 08:17:56 +11:00
psychedelicious
be8b99eed5 final tidying before marking PR as ready for review
- Replace AnyModelLoader with ModelLoaderRegistry
- Fix type check errors in multiple files
- Remove apparently unneeded `get_model_config_enum()` method from model manager
- Remove last vestiges of old model manager
- Updated tests and documentation

resolve conflict with seamless.py
2024-02-19 08:16:56 +11:00
Lincoln Stein
2ad0752582 Tidy names and locations of modules
- Rename old "model_management" directory to "model_management_OLD" in order to catch
  dangling references to original model manager.
- Caught and fixed most dangling references (still checking)
- Rename lora, textual_inversion and model_patcher modules
- Introduce a RawModel base class to simplfy the Union returned by the
  model loaders.
- Tidy up the model manager 2-related tests. Add useful fixtures, and
  a finalizer to the queue and installer fixtures that will stop the
  services and release threads.
2024-02-19 08:16:56 +11:00
Lincoln Stein
ba1f8878dd Fix issues identified during PR review by RyanjDick and brandonrising
- ModelMetadataStoreService is now injected into ModelRecordStoreService
  (these two services are really joined at the hip, and should someday be merged)
- ModelRecordStoreService is now injected into ModelManagerService
- Reduced timeout value for the various installer and download wait*() methods
- Introduced a Mock modelmanager for testing
- Removed bare print() statement with _logger in the install helper backend.
- Removed unused code from model loader init file
- Made `locker` a private variable in the `LoadedModel` object.
- Fixed up model merge frontend (will be deprecated anyway!)
2024-02-19 08:16:56 +11:00
Brandon
bc524026f9 feat(ui): update model identifiers to use key (#5730)
## What type of PR is this? (check all applicable)

- [x] Refactor

## Description

- Update zod schemas & types to use key instead of name/base/type
- Use new `CustomSelect` component instead of `ComboBox` for main model
select and control adapter model selects (less jank, will switch to
ComboBox based on CustomSelect for v4 so you can search the select)

## QA Instructions, Screenshots, Recordings

If you hold your breath, you should be able to generate with a control
adapter.

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Merge Plan

This PR can be merged when approved. Frontend tests not passing.

<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->
2024-02-16 11:17:35 -05:00
Brandon
ad7c571983 fix(nodes): fix t2i adapter model loading (#5731)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Description

Fixes t2i adapter loading

## Merge Plan

This PR can be merged when approved

<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->
2024-02-16 11:17:21 -05:00
psychedelicious
8559c6a392 fix(nodes): fix t2i adapter model loading 2024-02-16 22:51:47 +11:00
psychedelicious
c7904a32f4 chore(ui): lint 2024-02-16 22:42:15 +11:00
psychedelicious
17f5484f5b feat(ui): fix main model & control adapter model selects 2024-02-16 22:41:09 +11:00
psychedelicious
86a372b02f refactor(ui): url builders for each router
The MM2 router is at `api/v2/models`. URL builder utils make this a bit easier to manage.
2024-02-16 21:57:30 +11:00
psychedelicious
2e9aa9391d feat(ui): update model identifier to be key (wip)
- Update most model identifiers to be `{key: string}` instead of name/base/type. Doesn't change the model select components yet.
- Update model _parameters_, stored in redux, to be `{key: string, base: BaseModel}` - we need to store the base model to be able to check model compatibility. May want to store the whole config? Not sure...
2024-02-16 18:56:02 +11:00
psychedelicious
0c8112cf28 fix(ui): update model types 2024-02-15 22:17:16 +11:00
psychedelicious
019898c7be tests(ui): add type tests 2024-02-15 22:16:55 +11:00
psychedelicious
2b1ff8d196 tests(ui): enable vitest type testing
This is useful for the zod schemas and types we have created to match the backend.
2024-02-15 22:16:11 +11:00
psychedelicious
79fb691b4d chore(ui): typegen 2024-02-15 22:15:21 +11:00
psychedelicious
560ae17e21 feat(ui): export components type 2024-02-15 21:16:25 +11:00
psychedelicious
2bd1ab2f1c fix(ui): fix type issues 2024-02-15 20:53:41 +11:00
psychedelicious
ed43472582 chore: lint 2024-02-15 20:52:44 +11:00
psychedelicious
6e5e9176c0 chore: ruff 2024-02-15 20:50:47 +11:00
psychedelicious
4c6bcdbc18 feat(nodes): update invocation context for mm2, update nodes model usage 2024-02-15 20:43:41 +11:00
Brandon Rising
20e6d4fa3c Raise InvalidModelConfigException when unable to detect load class in ModelLoader 2024-02-15 18:00:16 +11:00
Brandon Rising
8e51392910 Update _get_hf_load_class to support clipvision models 2024-02-15 18:00:16 +11:00
Brandon Rising
0b1c2acd61 References to context.services.model_manager.store.get_model can only accept keys, remove invalid assertion 2024-02-15 18:00:16 +11:00
Brandon Rising
86ac55ab5f Remove references to model_records service, change submodel property on ModelInfo to submodel_type to support new params in model manager 2024-02-15 18:00:16 +11:00
Lincoln Stein
3e82f63c7e improve swagger documentation 2024-02-15 18:00:08 +11:00
Lincoln Stein
631f6cae19 fix a number of typechecking errors 2024-02-15 18:00:08 +11:00
Lincoln Stein
0845a0ed84 add route for model conversion from safetensors to diffusers
- Begin to add SwaggerUI documentation for AnyModelConfig and other
  discriminated Unions.
2024-02-15 18:00:08 +11:00
Lincoln Stein
46c8ce9fed add a JIT download_and_cache() call to the model installer 2024-02-15 18:00:08 +11:00
Lincoln Stein
13a9ea35b5 add back the heuristic_import() method and extend repo_ids to arbitrary file paths 2024-02-15 18:00:08 +11:00
Lincoln Stein
94e8d1b6d5 make model manager v2 ready for PR review
- Replace legacy model manager service with the v2 manager.

- Update invocations to use new load interface.

- Fixed many but not all type checking errors in the invocations. Most
  were unrelated to model manager

- Updated routes. All the new routes live under the route tag
  `model_manager_v2`. To avoid confusion with the old routes,
  they have the URL prefix `/api/v2/models`. The old routes
  have been de-registered.

- Added a pytest for the loader.

- Updated documentation in contributing/MODEL_MANAGER.md
2024-02-15 18:00:08 +11:00
Lincoln Stein
2b1dc74080 consolidate model manager parts into a single class 2024-02-15 17:57:14 +11:00
Lincoln Stein
f7e558d165 probe for required encoder for IPAdapters and add to config 2024-02-15 17:56:01 +11:00
Lincoln Stein
d959276217 fix invokeai_configure script to work with new mm; rename CLIs 2024-02-15 17:56:01 +11:00
Lincoln Stein
dfcf38be91 BREAKING CHANGES: invocations now require model key, not base/type/name
- Implement new model loader and modify invocations and embeddings

- Finish implementation loaders for all models currently supported by
  InvokeAI.

- Move lora, textual_inversion, and model patching support into
  backend/embeddings.

- Restore support for model cache statistics collection (a little ugly,
  needs work).

- Fixed up invocations that load and patch models.

- Move seamless and silencewarnings utils into better location
2024-02-15 17:56:01 +11:00
Lincoln Stein
fbded1c0f2 Multiple refinements on loaders:
- Cache stat collection enabled.
- Implemented ONNX loading.
- Add ability to specify the repo version variant in installer CLI.
- If caller asks for a repo version that doesn't exist, will fall back
  to empty version rather than raising an error.
2024-02-15 17:51:07 +11:00
Lincoln Stein
ad2926a24c added textual inversion and lora loaders 2024-02-15 17:51:07 +11:00
Lincoln Stein
34d5cad4c9 loaders for main, controlnet, ip-adapter, clipvision and t2i 2024-02-15 17:51:07 +11:00
Lincoln Stein
60aa3d4893 model loading and conversion implemented for vaes 2024-02-15 17:50:51 +11:00
Lincoln Stein
5c2884569e add ram cache module and support files 2024-02-15 17:50:31 +11:00
Lincoln Stein
a1307b9f2e add concept of repo variant 2024-02-15 17:50:31 +11:00
psychedelicious
f505ec64ba tests(ui): add parseFieldType.test.ts 2024-02-15 17:32:38 +11:00
psychedelicious
f22eb368a3 feat(ui): add more types of FieldParseError
Unfortunately you cannot test for both a specific type of error and match its message. Splitting the error classes makes it easier to test expected error conditions.
2024-02-15 17:32:38 +11:00
psychedelicious
96ae22c7e0 feat(ui): add vitest
- Add vitest.
- Consolidate vite configs into single file (easier to config everything based on env for testing)
2024-02-15 17:32:38 +11:00
psychedelicious
f5447cdc23 feat(ui): workflow schema v3 (WIP)
The changes aim to deduplicate data between workflows and node templates, decoupling workflows from internal implementation details. A good amount of data that was needlessly duplicated from the node template to the workflow is removed.

These changes substantially reduce the file size of workflows (and therefore the images with embedded workflows):

- Default T2I SD1.5 workflow JSON is reduced from 23.7kb (798 lines) to 10.9kb (407 lines).
- Default tiled upscale workflow JSON is reduced from 102.7kb (3341 lines) to 51.9kb (1774 lines).

The trade-off is that we need to reference node templates to get things like the field type and other things. In practice, this is a non-issue, because we need a node template to do anything with a node anyways.

- Field types are not included in the workflow. They are always pulled from the node templates.

The field type is now properly an internal implementation detail and we can change it as needed. Previously this would require a migration for the workflow itself. With the v3 schema, the structure of a field type is an internal implementation detail that we are free to change as we see fit.

- Workflow nodes no long have an `outputs` property and there is no longer such a thing as a `FieldOutputInstance`. These are only on the templates.

These were never referenced at a time when we didn't also have the templates available, and there'd be no reason to do so.

- Node width and height are no longer stored in the node.

These weren't used. Also, per https://reactflow.dev/api-reference/types/node, we shouldn't be programmatically changing these properties. A future enhancement can properly add node resizing.

- `nodeTemplates` slice is merged back into `nodesSlice` as `nodes.templates`. Turns out it's just a hassle having these separate in separate slices.

- Workflow migration logic updated to support the new schema. V1 workflows migrate all the way to v3 now.

- Changes throughout the nodes code to accommodate the above changes.
2024-02-15 17:32:38 +11:00
psychedelicious
c76a6bd65f chore(ui): regen types 2024-02-15 17:30:03 +11:00
psychedelicious
6c4eeaa569 feat(nodes): add more missing exports to invocation_api
Crawled through a few custom nodes to figure out what I had missed.
2024-02-15 17:30:03 +11:00
psychedelicious
1bbd13ead7 chore(nodes): "SAMPLER_NAME_VALUES" -> "SCHEDULER_NAME_VALUES"
This was named inaccurately.
2024-02-15 17:30:03 +11:00
psychedelicious
321b939d0e chore(nodes): remove deprecation logic for nodes API 2024-02-15 17:30:03 +11:00
psychedelicious
8fb77e431e chore(nodes): export model-related objects from invocation_api 2024-02-15 17:30:03 +11:00
psychedelicious
083a4f3faa chore(backend): rename ModelInfo -> LoadedModelInfo
We have two different classes named `ModelInfo` which might need to be used by API consumers. We need to export both but have to deal with this naming collision.

The `ModelInfo` I've renamed here is the one that is returned when a model is loaded. It's the object least likely to be used by API consumers.
2024-02-15 17:30:03 +11:00
psychedelicious
2005411f7e feat(nodes): use LATENT_SCALE_FACTOR in primitives.py, noise.py
- LatentsOutput.build
- NoiseOutput.build
- Noise.width, Noise.height multiple_of
2024-02-15 17:30:03 +11:00
psychedelicious
ba7b1b2665 feat(nodes): extract LATENT_SCALE_FACTOR to constants.py 2024-02-15 17:30:03 +11:00
psychedelicious
b7ffd36cc6 feat(nodes): use TemporaryDirectory to handle ephemeral storage in ObjectSerializerDisk
Replace `delete_on_startup: bool` & associated logic with `ephemeral: bool` and `TemporaryDirectory`.

The temp dir is created inside of `output_dir`. For example, if `output_dir` is `invokeai/outputs/tensors/`, then the temp dir might be `invokeai/outputs/tensors/tmpvj35ht7b/`.

The temp dir is cleaned up when the service is stopped, or when it is GC'd if not properly stopped.

In the event of a catastrophic crash where the temp files are not cleaned up, the user can delete the tempdir themselves.

This situation may not occur in normal use, but if you kill the process, python cannot clean up the temp dir itself. This includes running the app in a debugger and killing the debugger process - something I do relatively often.

Tests updated.
2024-02-15 17:30:03 +11:00
psychedelicious
199ddd6623 tests: test ObjectSerializerDisk class name extraction 2024-02-15 17:30:03 +11:00
psychedelicious
a7207ed8cf chore(nodes): update ObjectSerializerForwardCache docstring 2024-02-15 17:30:03 +11:00
psychedelicious
6bb2dda3f1 chore(nodes): fix pyright ignore 2024-02-15 17:30:03 +11:00
psychedelicious
c1e5cd5893 tidy(nodes): "latents" -> "obj" 2024-02-15 17:30:03 +11:00
psychedelicious
ff249a2315 tidy(nodes): do not store unnecessarily store invoker 2024-02-15 17:30:03 +11:00
psychedelicious
c58f8c3269 feat(nodes): make delete on startup configurable for obj serializer
- The default is to not delete on startup - feels safer.
- The two services using this class _do_ delete on startup.
- The class has "ephemeral" removed from its name.
- Tests & app updated for this change.
2024-02-15 17:30:03 +11:00
psychedelicious
ed772a7107 fix(nodes): use metadata/board_id if provided by user, overriding WithMetadata/WithBoard-provided values 2024-02-15 17:30:03 +11:00
psychedelicious
cb0b389b4b tidy(nodes): clarify comment 2024-02-15 17:30:03 +11:00
psychedelicious
8892df1d97 Revert "feat(nodes): use LATENT_SCALE_FACTOR const in tensor output builders"
This reverts commit ef18fc546560277302f3886e456da9a47e8edce0.
2024-02-15 17:30:03 +11:00
psychedelicious
bc5f356390 feat(nodes): use LATENT_SCALE_FACTOR const in tensor output builders 2024-02-15 17:30:03 +11:00
psychedelicious
bcb85e100d tests: fix broken tests 2024-02-15 17:30:03 +11:00
psychedelicious
1f27ddc07d tidy(nodes): minor spelling correction 2024-02-15 17:30:03 +11:00
psychedelicious
7a2b606001 tests: add object serializer tests
These test both object serializer and its forward cache implementation.
2024-02-15 17:30:03 +11:00
psychedelicious
83ddcc5f3a feat(nodes): allow _delete_all in obj serializer to be called at any time
`_delete_all` logged how many items it deleted, and had to be called _after_ service start bc it needed access to logger.

Move the logger call to the startup method and return the the deleted stats from `_delete_all`. This lets `_delete_all` be called at any time.
2024-02-15 17:30:03 +11:00
psychedelicious
55fa785561 tidy(nodes): remove object serializer on_saved
It's unused.
2024-02-15 17:30:03 +11:00
psychedelicious
06429028c8 revert(nodes): revert making tensors/conditioning use item storage
Turns out they are just different enough in purpose that the implementations would be rather unintuitive. I've made a separate ObjectSerializer service to handle tensors and conditioning.

Refined the class a bit too.
2024-02-15 17:30:03 +11:00
psychedelicious
8b6e322697 feat(nodes): support custom exception in ephemeral disk storage 2024-02-15 17:30:03 +11:00
psychedelicious
54a67459bf feat(nodes): support custom save and load functions in ItemStorageEphemeralDisk 2024-02-15 17:30:03 +11:00
psychedelicious
7fe5283e74 feat(nodes): create helper function to generate the item ID 2024-02-15 17:30:03 +11:00
psychedelicious
fe0391c86b feat(nodes): use ItemStorageABC for tensors and conditioning
Turns out `ItemStorageABC` was almost identical to `PickleStorageBase`. Instead of maintaining separate classes, we can use `ItemStorageABC` for both.

There's only one change needed - the `ItemStorageABC.set` method must return the newly stored item's ID. This allows us to let the service handle the responsibility of naming the item, but still create the requisite output objects during node execution.

The naming implementation is improved here. It extracts the name of the generic and appends a UUID to that string when saving items.
2024-02-15 17:30:03 +11:00
psychedelicious
25386a76ef tidy(nodes): do not refer to files as latents in PickleStorageTorch (again) 2024-02-15 17:30:03 +11:00
psychedelicious
fd30cb4d90 feat(nodes): ItemStorageABC typevar no longer bound to pydantic.BaseModel
This bound is totally unnecessary. There's no requirement for any implementation of `ItemStorageABC` to work only on pydantic models.
2024-02-15 17:30:03 +11:00
psychedelicious
0266946d3d fix(nodes): add super init to PickleStorageTorch 2024-02-15 17:30:03 +11:00
psychedelicious
a7f91b3e01 tidy(nodes): do not refer to files as latents in PickleStorageTorch 2024-02-15 17:30:03 +11:00
psychedelicious
de0b72528c feat(nodes): replace latents service with tensors and conditioning services
- New generic class `PickleStorageBase`, implements the same API as `LatentsStorageBase`, use for storing non-serializable data via pickling
- Implementation `PickleStorageTorch` uses `torch.save` and `torch.load`, same as `LatentsStorageDisk`
- Add `tensors: PickleStorageBase[torch.Tensor]` to `InvocationServices`
- Add `conditioning: PickleStorageBase[ConditioningFieldData]` to `InvocationServices`
- Remove `latents` service and all `LatentsStorage` classes
- Update `InvocationContext` and all usage of old `latents` service to use the new services/context wrapper methods
2024-02-15 17:30:03 +11:00
psychedelicious
2932652787 tidy(nodes): delete onnx.py
It doesn't work and keeping it updated to prevent the app from starting was getting tedious. Deleted.
2024-02-15 17:30:03 +11:00
psychedelicious
db6bc7305a fix(nodes): rearrange fields.py to avoid needing forward refs 2024-02-15 17:30:02 +11:00
psychedelicious
a5db204629 tidy(nodes): remove unnecessary, shadowing class attr declarations 2024-02-15 17:30:02 +11:00
psychedelicious
8e2b61e19f feat(ui): revise graphs to not use LinearUIOutputInvocation
See this comment for context: https://github.com/invoke-ai/InvokeAI/pull/5491#discussion_r1480760629

- Remove this now-unnecessary node from all graphs
- Update graphs' terminal image-outputting nodes' `is_intermediate` and `board` fields appropriately
- Add util function to prepare the `board` field, tidy the utils
- Update `socketInvocationComplete` listener to work correctly with this change

I've manually tested all graph permutations that were changed (I think this is all...) to ensure images go to the gallery as expected:
- ad-hoc upscaling
- t2i w/ sd1.5
- t2i w/ sd1.5 & hrf
- t2i w/ sdxl
- t2i w/ sdxl + refiner
- i2i w/ sd1.5
- i2i w/ sdxl
- i2i w/ sdxl + refiner
- canvas t2i w/ sd1.5
- canvas t2i w/ sdxl
- canvas t2i w/ sdxl + refiner
- canvas i2i w/ sd1.5
- canvas i2i w/ sdxl
- canvas i2i w/ sdxl + refiner
- canvas inpaint w/ sd1.5
- canvas inpaint w/ sdxl
- canvas inpaint w/ sdxl + refiner
- canvas outpaint w/ sd1.5
- canvas outpaint w/ sdxl
- canvas outpaint w/ sdxl + refiner
2024-02-15 17:30:02 +11:00
psychedelicious
a3faa3792a chore(ui): regen types 2024-02-15 17:30:02 +11:00
psychedelicious
c16eba78ab feat(nodes): add WithBoard field helper class
This class works the same way as `WithMetadata` - it simply adds a `board` field to the node. The context wrapper function is able to pull the board id from this. This allows image-outputting nodes to get a board field "for free", and have their outputs automatically saved to it.

This is a breaking change for node authors who may have a field called `board`, because it makes `board` a reserved field name. I'll look into how to avoid this - maybe by naming this invoke-managed field `_board` to avoid collisions?

Supporting changes:
- `WithBoard` is added to all image-outputting nodes, giving them the ability to save to board.
- Unused, duplicate `WithMetadata` and `WithWorkflow` classes are deleted from `baseinvocation.py`. The "real" versions are in `fields.py`.
- Remove `LinearUIOutputInvocation`. Now that all nodes that output images also have a `board` field by default, this node is no longer necessary. See comment here for context: https://github.com/invoke-ai/InvokeAI/pull/5491#discussion_r1480760629
- Without `LinearUIOutputInvocation`, the `ImagesInferface.update` method is no longer needed, and removed.

Note: This commit does not bump all node versions. I will ensure that is done correctly before merging the PR of which this commit is a part.

Note: A followup commit will implement the frontend changes to support this change.
2024-02-15 17:30:02 +11:00
psychedelicious
1a191c4655 remove unused configdict import 2024-02-15 17:30:02 +11:00
psychedelicious
e36d925bce fix(ui): remove original l2i node in HRF graph 2024-02-15 17:30:02 +11:00
psychedelicious
b1ba18b3d1 fix(nodes): do not freeze or cache config in context wrapper
- The config is already cached by the config class's `get_config()` method.
- The config mutates itself in its `root_path` property getter. Freezing the class makes any attempt to grab a path from the config error. Unfortunately this means we cannot easily freeze the class without fiddling with the inner workings of `InvokeAIAppConfig`, which is outside the scope here.
2024-02-15 17:30:02 +11:00
psychedelicious
aff46759f9 feat(nodes): context.data -> context._data 2024-02-15 17:30:02 +11:00
psychedelicious
d7b7dcc7fe feat(nodes): context.__services -> context._services 2024-02-15 17:30:02 +11:00
psychedelicious
889a26c5b6 feat(nodes): cache invocation interface config 2024-02-15 17:30:02 +11:00
psychedelicious
b4c774896a feat(nodes): do not hide services in invocation context interfaces 2024-02-15 17:30:02 +11:00
psychedelicious
afbe889d35 fix(nodes): restore missing context type annotations 2024-02-15 17:30:02 +11:00
psychedelicious
9c1e52b1ef tests(nodes): fix mock InvocationContext 2024-02-15 17:30:02 +11:00
psychedelicious
3f5ab02da9 chore(nodes): add comments for ConfigInterface 2024-02-15 17:30:02 +11:00
psychedelicious
bf48e8a03a feat(nodes): export more things from `invocation_api" 2024-02-15 17:30:02 +11:00
psychedelicious
e52434cb99 feat(nodes): add boards interface to invocation context 2024-02-15 17:30:02 +11:00
psychedelicious
483bdbcb9f fix(nodes): restore type annotations for InvocationContext 2024-02-15 17:30:02 +11:00
psychedelicious
ae421fb4ab feat(nodes): do not freeze InvocationContextData, prevents it from being subclassesd 2024-02-15 17:30:02 +11:00
psychedelicious
cc295a9f0a feat: tweak pyright config 2024-02-15 17:30:02 +11:00
psychedelicious
a7e23af9c6 feat(nodes): create invocation_api.py
This is the public API for invocations.

Everything a custom node might need should be re-exported from this file.
2024-02-15 17:30:02 +11:00
psychedelicious
3de4390711 feat(nodes): move ConditioningFieldData to conditioning_data.py 2024-02-15 17:30:02 +11:00
psychedelicious
3ceee2b2b2 tests: fix missing arg for InvocationContext 2024-02-15 17:30:02 +11:00
psychedelicious
5c7ed24aab feat(nodes): restore previous invocation context methods with deprecation warnings 2024-02-15 17:30:02 +11:00
psychedelicious
183c9c4799 chore: ruff 2024-02-15 17:30:02 +11:00
psychedelicious
8baf3f78a2 feat(nodes): tidy invocation_context.py, improve comments 2024-02-15 17:30:02 +11:00
psychedelicious
ac2eb16a65 tests: fix tests for new invocation context 2024-02-15 17:30:02 +11:00
psychedelicious
4aa7bee4b9 docs: update INVOCATIONS.md 2024-02-15 17:30:02 +11:00
psychedelicious
7e5ba2795e feat(nodes): update all invocations to use new invocation context
Update all invocations to use the new context. The changes are all fairly simple, but there are a lot of them.

Supporting minor changes:
- Patch bump for all nodes that use the context
- Update invocation processor to provide new context
- Minor change to `EventServiceBase` to accept a node's ID instead of the dict version of a node
- Minor change to `ModelManagerService` to support the new wrapped context
- Fanagling of imports to avoid circular dependencies
2024-02-15 17:30:02 +11:00
psychedelicious
97a6c6eea7 feat: add pyright config
I was having issues with mypy bother over- and under-reporting certain problems. I've added a pyright config.
2024-02-15 17:30:02 +11:00
psychedelicious
f0e60a4ba2 feat(nodes): restricts invocation context power
Creates a low-power `InvocationContext` with simplified methods and data.

See `invocation_context.py` for detailed comments.
2024-02-15 17:30:02 +11:00
psychedelicious
aa089e8108 tidy(nodes): move all field things to fields.py
Unfortunately, this is necessary to prevent circular imports at runtime.
2024-02-15 17:30:02 +11:00
510 changed files with 17559 additions and 15775 deletions

View File

@@ -1,33 +0,0 @@
name: install frontend dependencies
description: Installs frontend dependencies with pnpm, with caching
runs:
using: 'composite'
steps:
- name: setup node 18
uses: actions/setup-node@v4
with:
node-version: '18'
- name: setup pnpm
uses: pnpm/action-setup@v2
with:
version: 8
run_install: false
- name: get pnpm store directory
shell: bash
run: |
echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV
- name: setup cache
uses: actions/cache@v4
with:
path: ${{ env.STORE_PATH }}
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
restore-keys: |
${{ runner.os }}-pnpm-store-
- name: install frontend dependencies
run: pnpm install --prefer-frozen-lockfile
shell: bash
working-directory: invokeai/frontend/web

28
.github/pr_labels.yml vendored
View File

@@ -1,59 +1,59 @@
root:
Root:
- changed-files:
- any-glob-to-any-file: '*'
python-deps:
PythonDeps:
- changed-files:
- any-glob-to-any-file: 'pyproject.toml'
python:
Python:
- changed-files:
- all-globs-to-any-file:
- 'invokeai/**'
- '!invokeai/frontend/web/**'
python-tests:
PythonTests:
- changed-files:
- any-glob-to-any-file: 'tests/**'
ci-cd:
CICD:
- changed-files:
- any-glob-to-any-file: .github/**
docker:
Docker:
- changed-files:
- any-glob-to-any-file: docker/**
installer:
Installer:
- changed-files:
- any-glob-to-any-file: installer/**
docs:
Documentation:
- changed-files:
- any-glob-to-any-file: docs/**
invocations:
Invocations:
- changed-files:
- any-glob-to-any-file: 'invokeai/app/invocations/**'
backend:
Backend:
- changed-files:
- any-glob-to-any-file: 'invokeai/backend/**'
api:
Api:
- changed-files:
- any-glob-to-any-file: 'invokeai/app/api/**'
services:
Services:
- changed-files:
- any-glob-to-any-file: 'invokeai/app/services/**'
frontend-deps:
FrontendDeps:
- changed-files:
- any-glob-to-any-file:
- '**/*/package.json'
- '**/*/pnpm-lock.yaml'
frontend:
Frontend:
- changed-files:
- any-glob-to-any-file: 'invokeai/frontend/web/**'

View File

@@ -11,7 +11,7 @@ on:
- 'docker/docker-entrypoint.sh'
- 'workflows/build-container.yml'
tags:
- 'v*.*.*'
- 'v*'
workflow_dispatch:
permissions:

View File

@@ -1,45 +0,0 @@
# Builds and uploads the installer and python build artifacts.
name: build installer
on:
workflow_dispatch:
workflow_call:
jobs:
build-installer:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <2 min
steps:
- name: checkout
uses: actions/checkout@v4
- name: setup python
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install pypa/build
run: pip install --upgrade build
- name: setup frontend
uses: ./.github/actions/install-frontend-deps
- name: create installer
id: create_installer
run: ./create_installer.sh
working-directory: installer
- name: upload python distribution artifact
uses: actions/upload-artifact@v4
with:
name: dist
path: ${{ steps.create_installer.outputs.DIST_PATH }}
- name: upload installer artifact
uses: actions/upload-artifact@v4
with:
name: ${{ steps.create_installer.outputs.INSTALLER_FILENAME }}
path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}

View File

@@ -1,68 +0,0 @@
# Runs frontend code quality checks.
#
# Checks for changes to frontend files before running the checks.
# When manually triggered or when called from another workflow, always runs the checks.
name: 'frontend checks'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
frontend-checks:
runs-on: ubuntu-latest
timeout-minutes: 10 # expected run time: <2 min
steps:
- uses: actions/checkout@v4
- name: check for changed frontend files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
frontend:
- 'invokeai/frontend/web/**'
- name: install dependencies
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: ./.github/actions/install-frontend-deps
- name: tsc
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:tsc'
shell: bash
- name: dpdm
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:dpdm'
shell: bash
- name: eslint
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:eslint'
shell: bash
- name: prettier
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:prettier'
shell: bash
- name: knip
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:knip'
shell: bash

View File

@@ -1,48 +0,0 @@
# Runs frontend tests.
#
# Checks for changes to frontend files before running the tests.
# When manually triggered or called from another workflow, always runs the tests.
name: 'frontend tests'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
frontend-tests:
runs-on: ubuntu-latest
timeout-minutes: 10 # expected run time: <2 min
steps:
- uses: actions/checkout@v4
- name: check for changed frontend files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
frontend:
- 'invokeai/frontend/web/**'
- name: install dependencies
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: ./.github/actions/install-frontend-deps
- name: vitest
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm test:no-watch'
shell: bash

View File

@@ -1,6 +1,6 @@
name: 'label PRs'
name: "Pull Request Labeler"
on:
- pull_request_target
- pull_request_target
jobs:
labeler:
@@ -9,10 +9,8 @@ jobs:
pull-requests: write
runs-on: ubuntu-latest
steps:
- name: checkout
- name: Checkout
uses: actions/checkout@v4
- name: label PRs
uses: actions/labeler@v5
- uses: actions/labeler@v5
with:
configuration-path: .github/pr_labels.yml
configuration-path: .github/pr_labels.yml

43
.github/workflows/lint-frontend.yml vendored Normal file
View File

@@ -0,0 +1,43 @@
name: Lint frontend
on:
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
push:
branches:
- 'main'
merge_group:
workflow_dispatch:
defaults:
run:
working-directory: invokeai/frontend/web
jobs:
lint-frontend:
if: github.event.pull_request.draft == false
runs-on: ubuntu-22.04
steps:
- name: Setup Node 18
uses: actions/setup-node@v4
with:
node-version: '18'
- name: Checkout
uses: actions/checkout@v4
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: '8.12.1'
- name: Install dependencies
run: 'pnpm install --prefer-frozen-lockfile'
- name: Typescript
run: 'pnpm run lint:tsc'
- name: Madge
run: 'pnpm run lint:madge'
- name: ESLint
run: 'pnpm run lint:eslint'
- name: Prettier
run: 'pnpm run lint:prettier'

View File

@@ -1,49 +1,51 @@
# This is a mostly a copy-paste from https://github.com/squidfunk/mkdocs-material/blob/master/docs/publishing-your-site.md
name: mkdocs
name: mkdocs-material
on:
push:
branches:
- main
workflow_dispatch:
- 'refs/heads/main'
permissions:
contents: write
contents: write
jobs:
deploy:
mkdocs-material:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
env:
REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
REPO_NAME: '${{ github.repository }}'
SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
steps:
- name: checkout
uses: actions/checkout@v4
- name: checkout sources
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: setup python
uses: actions/setup-python@v5
uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: set cache id
run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- name: install requirements
env:
PIP_USE_PEP517: 1
run: |
python -m \
pip install ".[docs]"
- name: use cache
uses: actions/cache@v4
with:
key: mkdocs-material-${{ env.cache_id }}
path: .cache
restore-keys: |
mkdocs-material-
- name: confirm buildability
run: |
python -m \
mkdocs build \
--clean \
--verbose
- name: install dependencies
run: python -m pip install ".[docs]"
- name: build & deploy
run: mkdocs gh-deploy --force
- name: deploy to gh-pages
if: ${{ github.ref == 'refs/heads/main' }}
run: |
python -m \
mkdocs gh-deploy \
--clean \
--force

67
.github/workflows/pypi-release.yml vendored Normal file
View File

@@ -0,0 +1,67 @@
name: PyPI Release
on:
workflow_dispatch:
inputs:
publish_package:
description: 'Publish build on PyPi? [true/false]'
required: true
default: 'false'
jobs:
build-and-release:
if: github.repository == 'invoke-ai/InvokeAI'
runs-on: ubuntu-22.04
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
TWINE_NON_INTERACTIVE: 1
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Node 18
uses: actions/setup-node@v4
with:
node-version: '18'
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: '8.12.1'
- name: Install frontend dependencies
run: pnpm install --prefer-frozen-lockfile
working-directory: invokeai/frontend/web
- name: Build frontend
run: pnpm run build
working-directory: invokeai/frontend/web
- name: Install python dependencies
run: pip install --upgrade build twine
- name: Build python package
run: python3 -m build
- name: Upload build as workflow artifact
uses: actions/upload-artifact@v4
with:
name: dist
path: dist
- name: Check distribution
run: twine check dist/*
- name: Check PyPI versions
if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
run: |
pip install --upgrade requests
python -c "\
import scripts.pypi_helper; \
EXISTS=scripts.pypi_helper.local_on_pypi(); \
print(f'PACKAGE_EXISTS={EXISTS}')" >> $GITHUB_ENV
- name: Publish build on PyPi
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != '' && github.event.inputs.publish_package == 'true'
run: twine upload dist/*

View File

@@ -1,64 +0,0 @@
# Runs python code quality checks.
#
# Checks for changes to python files before running the checks.
# When manually triggered or called from another workflow, always runs the tests.
#
# TODO: Add mypy or pyright to the checks.
name: 'python checks'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
jobs:
python-checks:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
steps:
- name: checkout
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: pip install ruff
shell: bash
- name: ruff check
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: ruff check --output-format=github .
shell: bash
- name: ruff format
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: ruff format --check .
shell: bash

View File

@@ -1,94 +0,0 @@
# Runs python tests on a matrix of python versions and platforms.
#
# Checks for changes to python files before running the tests.
# When manually triggered or called from another workflow, always runs the tests.
name: 'python tests'
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
workflow_call:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
matrix:
strategy:
matrix:
python-version:
- '3.10'
- '3.11'
platform:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- platform: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- platform: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- platform: linux-cpu
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- platform: macos-default
os: macOS-12
github-env: $GITHUB_ENV
- platform: windows-cpu
os: windows-2022
github-env: $env:GITHUB_ENV
name: 'py${{ matrix.python-version }}: ${{ matrix.platform }}'
runs-on: ${{ matrix.os }}
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
env:
PIP_USE_PEP517: '1'
steps:
- name: checkout
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: install dependencies
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install --editable=".[test]"
- name: run pytest
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: pytest

View File

@@ -1,96 +0,0 @@
# Main release workflow. Triggered on tag push or manual trigger.
#
# - Runs all code checks and tests
# - Verifies the app version matches the tag version.
# - Builds the installer and build, uploading them as artifacts.
# - Publishes to TestPyPI and PyPI. Both are conditional on the previous steps passing and require a manual approval.
#
# See docs/RELEASE.md for more information on the release process.
name: release
on:
push:
tags:
- 'v*'
workflow_dispatch:
jobs:
check-version:
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v4
- name: check python version
uses: samuelcolvin/check-python-version@v4
id: check-python-version
with:
version_file_path: invokeai/version/invokeai_version.py
frontend-checks:
uses: ./.github/workflows/frontend-checks.yml
frontend-tests:
uses: ./.github/workflows/frontend-tests.yml
python-checks:
uses: ./.github/workflows/python-checks.yml
python-tests:
uses: ./.github/workflows/python-tests.yml
build:
uses: ./.github/workflows/build-installer.yml
publish-testpypi:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
needs:
[
check-version,
frontend-checks,
frontend-tests,
python-checks,
python-tests,
build,
]
environment:
name: testpypi
url: https://test.pypi.org/p/invokeai
steps:
- name: download distribution from build job
uses: actions/download-artifact@v4
with:
name: dist
path: dist/
- name: publish distribution to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
publish-pypi:
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
needs:
[
check-version,
frontend-checks,
frontend-tests,
python-checks,
python-tests,
build,
]
environment:
name: pypi
url: https://pypi.org/p/invokeai
steps:
- name: download distribution from build job
uses: actions/download-artifact@v4
with:
name: dist
path: dist/
- name: publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

24
.github/workflows/style-checks.yml vendored Normal file
View File

@@ -0,0 +1,24 @@
name: style checks
on:
pull_request:
push:
branches: main
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies with pip
run: |
pip install ruff
- run: ruff check --output-format=github .
- run: ruff format --check .

129
.github/workflows/test-invoke-pip.yml vendored Normal file
View File

@@ -0,0 +1,129 @@
name: Test invoke.py pip
on:
push:
branches:
- 'main'
pull_request:
types:
- 'ready_for_review'
- 'opened'
- 'synchronize'
merge_group:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
matrix:
if: github.event.pull_request.draft == false
strategy:
matrix:
python-version:
# - '3.9'
- '3.10'
pytorch:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- pytorch: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- pytorch: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- pytorch: linux-cpu
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- pytorch: macos-default
os: macOS-12
github-env: $GITHUB_ENV
- pytorch: windows-cpu
os: windows-2022
github-env: $env:GITHUB_ENV
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
env:
PIP_USE_PEP517: '1'
steps:
- name: Checkout sources
id: checkout-sources
uses: actions/checkout@v3
- name: Check for changed python files
id: changed-files
uses: tj-actions/changed-files@v41
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: set test prompt to main branch validation
if: steps.changed-files.outputs.python_any_changed == 'true'
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
- name: setup python
if: steps.changed-files.outputs.python_any_changed == 'true'
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: install invokeai
if: steps.changed-files.outputs.python_any_changed == 'true'
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install
--editable=".[test]"
- name: run pytest
if: steps.changed-files.outputs.python_any_changed == 'true'
id: run-pytest
run: pytest
# - name: run invokeai-configure
# env:
# HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
# run: >
# invokeai-configure
# --yes
# --default_only
# --full-precision
# # can't use fp16 weights without a GPU
# - name: run invokeai
# id: run-invokeai
# env:
# # Set offline mode to make sure configure preloaded successfully.
# HF_HUB_OFFLINE: 1
# HF_DATASETS_OFFLINE: 1
# TRANSFORMERS_OFFLINE: 1
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
# run: >
# invokeai
# --no-patchmatch
# --no-nsfw_checker
# --precision=float32
# --always_use_cpu
# --use_memory_db
# --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
# --from_file ${{ env.TEST_PROMPTS }}
# - name: Archive results
# env:
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
# uses: actions/upload-artifact@v3
# with:
# name: results
# path: ${{ env.INVOKEAI_OUTDIR }}

View File

@@ -7,7 +7,7 @@ embeddedLanguageFormatting: auto
overrides:
- files: '*.md'
options:
proseWrap: preserve
proseWrap: always
printWidth: 80
parser: markdown
cursorOffset: -1

View File

@@ -6,44 +6,33 @@ default: help
help:
@echo Developer commands:
@echo
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "test" Run the unit tests.
@echo "frontend-install" Install the pnpm modules needed for the front end
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
# Runs ruff, fixing any safely-fixable errors and formatting
ruff:
ruff check . --fix
ruff format .
ruff check . --fix
ruff format .
# Runs ruff, fixing all errors it can fix and formatting
ruff-unsafe:
ruff check . --fix --unsafe-fixes
ruff format .
ruff check . --fix --unsafe-fixes
ruff format .
# Runs mypy, using the config in pyproject.toml
mypy:
mypy scripts/invokeai-web.py
mypy scripts/invokeai-web.py
# Runs mypy, ignoring the config in pyproject.toml but still ignoring missing (untyped) imports
# (many files are ignored by the config, so this is useful for checking all files)
mypy-all:
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
# Run the unit tests
test:
pytest ./tests
# Install the pnpm modules needed for the front end
frontend-install:
rm -rf invokeai/frontend/web/node_modules
cd invokeai/frontend/web && pnpm install
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
# Build the frontend
frontend-build:

View File

@@ -1,142 +0,0 @@
# Release Process
The app is published in twice, in different build formats.
- A [PyPI] distribution. This includes both a source distribution and built distribution (a wheel). Users install with `pip install invokeai`. The updater uses this build.
- An installer on the [InvokeAI Releases Page]. This is a zip file with install scripts and a wheel. This is only used for new installs.
## General Prep
Make a developer call-out for PRs to merge. Merge and test things out.
While the release workflow does not include end-to-end tests, it does pause before publishing so you can download and test the final build.
## Release Workflow
The `release.yml` workflow runs a number of jobs to handle code checks, tests, build and publish on PyPI.
It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if you've prepped a release branch like `release/v3.5.0` or are releasing from `main` - it works the same.
> Because commits are reference-counted, it is safe to create a release branch, tag it, let the workflow run, then delete the branch. So long as the tag exists, that commit will exist.
### Triggering the Workflow
Run `make tag-release` to tag the current commit and kick off the workflow.
The release may also be dispatched [manually].
### Workflow Jobs and Process
The workflow consists of a number of concurrently-run jobs, and two final publish jobs.
The publish jobs require manual approval and are only run if the other jobs succeed.
#### `check-version` Job
This job checks that the git ref matches the app version. It matches the ref against the `__version__` variable in `invokeai/version/invokeai_version.py`.
When the workflow is triggered by tag push, the ref is the tag. If the workflow is run manually, the ref is the target selected from the **Use workflow from** dropdown.
This job uses [samuelcolvin/check-python-version].
> Any valid [version specifier] works, so long as the tag matches the version. The release workflow works exactly the same for `RC`, `post`, `dev`, etc.
#### Check and Test Jobs
- **`python-tests`**: runs `pytest` on matrix of platforms
- **`python-checks`**: runs `ruff` (format and lint)
- **`frontend-tests`**: runs `vitest`
- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
> **TODO** We should add `mypy` or `pyright` to the **`check-python`** job.
> **TODO** We should add an end-to-end test job that generates an image.
#### `build-installer` Job
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:
- **`dist`**: the python distribution, to be published on PyPI
- **`InvokeAI-installer-${VERSION}.zip`**: the installer to be included in the GitHub release
#### Sanity Check & Smoke Test
At this point, the release workflow pauses as the remaining publish jobs require approval.
A maintainer should go to the **Summary** tab of the workflow, download the installer and test it. Ensure the app loads and generates.
> The same wheel file is bundled in the installer and in the `dist` artifact, which is uploaded to PyPI. You should end up with the exactly the same installation of the `invokeai` package from any of these methods.
#### PyPI Publish Jobs
The publish jobs will run if any of the previous jobs fail.
They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
Both jobs require a maintainer to approve them from the workflow's **Summary** tab.
- Click the **Review deployments** button
- Select the environment (either `testpypi` or `pypi`)
- Click **Approve and deploy**
> **If the version already exists on PyPI, the publish jobs will fail.** PyPI only allows a given version to be published once - you cannot change it. If version published on PyPI has a problem, you'll need to "fail forward" by bumping the app version and publishing a followup release.
#### `publish-testpypi` Job
Publishes the distribution on the [Test PyPI] index, using the `testpypi` GitHub environment.
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release.
If approved and successful, you could try out the test release like this:
```sh
# Create a new virtual environment
python -m venv ~/.test-invokeai-dist --prompt test-invokeai-dist
# Install the distribution from Test PyPI
pip install --index-url https://test.pypi.org/simple/ invokeai
# Run and test the app
invokeai-web
# Cleanup
deactivate
rm -rf ~/.test-invokeai-dist
```
#### `publish-pypi` Job
Publishes the distribution on the production PyPI index, using the `pypi` GitHub environment.
## Publish the GitHub Release with installer
Once the release is published to PyPI, it's time to publish the GitHub release.
1. [Draft a new release] on GitHub, choosing the tag that triggered the release.
2. Write the release notes, describing important changes. The **Generate release notes** button automatically inserts the changelog and new contributors, and you can copy/paste the intro from previous releases.
3. Upload the zip file created in **`build`** job into the Assets section of the release notes. You can also upload the zip into the body of the release notes, since it can be hard for users to find the Assets section.
4. Check the **Set as a pre-release** and **Create a discussion for this release** checkboxes at the bottom of the release page.
5. Publish the pre-release.
6. Announce the pre-release in Discord.
> **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up.
## Manual Build
The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag.
No checks are run, it just builds.
## Manual Release
The `release` workflow can be dispatched manually. You must dispatch the workflow from the right tag, else it will fail the version check.
This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above.
[InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases
[PyPI]: https://pypi.org/
[Draft a new release]: https://github.com/invoke-ai/InvokeAI/releases/new
[Test PyPI]: https://test.pypi.org/
[version specifier]: https://packaging.python.org/en/latest/specifications/version-specifiers/
[ncipollo/release-action]: https://github.com/ncipollo/release-action
[GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
[trusted publishers]: https://docs.pypi.org/trusted-publishers/
[samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version
[manually]: #manual-release

View File

@@ -1,45 +0,0 @@
# Invocation API
Each invocation's `invoke` method is provided a single arg - the Invocation
Context.
This object provides access to various methods, used to interact with the
application. Loading and saving images, logging messages, etc.
!!! warning ""
This API may shift slightly until the release of v4.0.0 as we work through a few final updates to the Model Manager.
```py
class MyInvocation(BaseInvocation):
...
def invoke(self, context: InvocationContext) -> ImageOutput:
image_pil = context.images.get_pil(image_name)
# Do something to the image
image_dto = context.images.save(image_pil)
# Log a message
context.logger.info(f"Did something cool, image saved!")
...
```
<!-- prettier-ignore-start -->
::: invokeai.app.services.shared.invocation_context.InvocationContext
options:
members: false
::: invokeai.app.services.shared.invocation_context.ImagesInterface
::: invokeai.app.services.shared.invocation_context.TensorsInterface
::: invokeai.app.services.shared.invocation_context.ConditioningInterface
::: invokeai.app.services.shared.invocation_context.ModelsInterface
::: invokeai.app.services.shared.invocation_context.LoggerInterface
::: invokeai.app.services.shared.invocation_context.ConfigInterface
::: invokeai.app.services.shared.invocation_context.UtilInterface
::: invokeai.app.services.shared.invocation_context.BoardsInterface
<!-- prettier-ignore-end -->

View File

@@ -1,148 +0,0 @@
# Invoke v4.0.0 Nodes API Migration guide
Invoke v4.0.0 is versioned as such due to breaking changes to the API utilized
by nodes, both core and custom.
## Motivation
Prior to v4.0.0, the `invokeai` python package has not be set up to be utilized
as a library. That is to say, it didn't have any explicitly public API, and node
authors had to work with the unstable internal application API.
v4.0.0 introduces a stable public API for nodes.
## Changes
There are two node-author-facing changes:
1. Import Paths
1. Invocation Context API
### Import Paths
All public objects are now exported from `invokeai.invocation_api`:
```py
# Old
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
InputField,
InvocationContext,
invocation,
)
from invokeai.app.invocations.primitives import ImageField
# New
from invokeai.invocation_api import (
BaseInvocation,
ImageField,
InputField,
InvocationContext,
invocation,
)
```
It's possible that we've missed some classes you need in your node. Please let
us know if that's the case.
### Invocation Context API
Most nodes utilize the Invocation Context, an object that is passed to the
`invoke` that provides access to data and services a node may need.
Until now, that object and the services it exposed were internal. Exposing them
to nodes means that changes to our internal implementation could break nodes.
The methods on the services are also often fairly complicated and allowed nodes
to footgun.
In v4.0.0, this object has been refactored to be much simpler.
See [INVOCATION_API](./INVOCATION_API.md) for full details of the API.
!!! warning ""
This API may shift slightly until the release of v4.0.0 as we work through a few final updates to the Model Manager.
#### Improved Service Methods
The biggest offender was the image save method:
```py
# Old
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=context.workflow,
)
# New
image_dto = context.images.save(image=image)
```
Other methods are simplified, or enhanced with additional functionality:
```py
# Old
image = context.services.images.get_pil_image(image_name)
# New
image = context.images.get_pil(image_name)
image_cmyk = context.images.get_pil(image_name, "CMYK")
```
We also had some typing issues around tensors:
```py
# Old
# `latents` typed as `torch.Tensor`, but could be `ConditioningFieldData`
latents = context.services.latents.get(self.latents.latents_name)
# `data` typed as `torch.Tenssor,` but could be `ConditioningFieldData`
context.services.latents.save(latents_name, data)
# New - separate methods for tensors and conditioning data w/ correct typing
# Also, the service generates the names
tensor_name = context.tensors.save(tensor)
tensor = context.tensors.load(tensor_name)
# For conditioning
cond_name = context.conditioning.save(cond_data)
cond_data = context.conditioning.load(cond_name)
```
#### Output Construction
Core Outputs have builder functions right on them - no need to manually
construct these objects, or use an extra utility:
```py
# Old
image_output = ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
latents_output = build_latents_output(latents_name=name, latents=latents, seed=None)
noise_output = NoiseOutput(
noise=LatentsField(latents_name=latents_name, seed=seed),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
cond_output = ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
# New
image_output = ImageOutput.build(image_dto)
latents_output = LatentsOutput.build(latents_name=name, latents=noise, seed=self.seed)
noise_output = NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed)
cond_output = ConditioningOutput.build(conditioning_name)
```
You can still create the objects using constructors if you want, but we suggest
using the builder methods.

View File

@@ -32,7 +32,6 @@ To use a community workflow, download the the `.json` node graph file and load i
+ [Image to Character Art Image Nodes](#image-to-character-art-image-nodes)
+ [Image Picker](#image-picker)
+ [Image Resize Plus](#image-resize-plus)
+ [Latent Upscale](#latent-upscale)
+ [Load Video Frame](#load-video-frame)
+ [Make 3D](#make-3d)
+ [Mask Operations](#mask-operations)
@@ -291,13 +290,6 @@ View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/image-resize-plus-node/master/.readme/node.png" width="500" />
--------------------------------
### Latent Upscale
**Description:** This node uses a small (~2.4mb) model to upscale the latents used in a Stable Diffusion 1.5 or Stable Diffusion XL image generation, rather than the typical interpolation method, avoiding the traditional downsides of the latent upscale technique.
**Node Link:** [https://github.com/gogurtenjoyer/latent-upscale](https://github.com/gogurtenjoyer/latent-upscale)
--------------------------------
### Load Video Frame
@@ -354,21 +346,12 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
**Description:** A set of nodes for Metadata. Collect Metadata from within an `iterate` node & extract metadata from an image.
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node
- `Metadata From Image` - Provides Metadata from an image
- `Metadata To String` - Extracts a String value of a label from metadata
- `Metadata To Integer` - Extracts an Integer value of a label from metadata
- `Metadata To Float` - Extracts a Float value of a label from metadata
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata
- `Metadata To Bool` - Extracts Bool types from metadata
- `Metadata To Model` - Extracts model types from metadata
- `Metadata To SDXL Model` - Extracts SDXL model types from metadata
- `Metadata To LoRAs` - Extracts Loras from metadata.
- `Metadata To SDXL LoRAs` - Extracts SDXL Loras from metadata
- `Metadata To ControlNets` - Extracts ControNets from metadata
- `Metadata To IP-Adapters` - Extracts IP-Adapters from metadata
- `Metadata To T2I-Adapters` - Extracts T2I-Adapters from metadata
- `Denoise Latents + Metadata` - This is an inherited version of the existing `Denoise Latents` node but with a metadata input and output.
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node.
- `Metadata From Image` - Provides Metadata from an image.
- `Metadata To String` - Extracts a String value of a label from metadata.
- `Metadata To Integer` - Extracts an Integer value of a label from metadata.
- `Metadata To Float` - Extracts a Float value of a label from metadata.
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata.
**Node Link:** https://github.com/skunkworxdark/metadata-linked-nodes

View File

@@ -19,8 +19,6 @@ their descriptions.
| Conditioning Primitive | A conditioning tensor primitive value |
| Content Shuffle Processor | Applies content shuffle processing to image |
| ControlNet | Collects ControlNet info to pass to other nodes |
| Create Denoise Mask | Converts a greyscale or transparency image into a mask for denoising. |
| Create Gradient Mask | Creates a mask for Gradient ("soft", "differential") inpainting that gradually expands during denoising. Improves edge coherence. |
| Denoise Latents | Denoises noisy latents to decodable images |
| Divide Integers | Divides two numbers |
| Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator |

View File

@@ -0,0 +1,5 @@
mkdocs
mkdocs-material>=8, <9
mkdocs-git-revision-date-localized-plugin
mkdocs-redirects==1.2.0

View File

@@ -0,0 +1,5 @@
:root {
--md-primary-fg-color: #35A4DB;
--md-primary-fg-color--light: #35A4DB;
--md-primary-fg-color--dark: #35A4DB;
}

View File

@@ -2,18 +2,22 @@
set -e
BCYAN="\033[1;36m"
BYELLOW="\033[1;33m"
BGREEN="\033[1;32m"
BRED="\033[1;31m"
RED="\033[31m"
RESET="\033[0m"
BCYAN="\e[1;36m"
BYELLOW="\e[1;33m"
BGREEN="\e[1;32m"
BRED="\e[1;31m"
RED="\e[31m"
RESET="\e[0m"
function is_bin_in_path {
builtin type -P "$1" &>/dev/null
}
function git_show {
git show -s --format=oneline --abbrev-commit "$1" | cat
}
if [[ ! -z "${VIRTUAL_ENV}" ]]; then
if [[ -v "VIRTUAL_ENV" ]]; then
# we can't just call 'deactivate' because this function is not exported
# to the environment of this script from the bash process that runs the script
echo -e "${BRED}A virtual environment is activated. Please deactivate it before proceeding.${RESET}"
@@ -22,63 +26,31 @@ fi
cd "$(dirname "$0")"
echo
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
echo "The current working directory is $(pwd)"
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
echo
# Some machines only have `python3` in PATH, others have `python` - make an alias.
# We can use a function to approximate an alias within a non-interactive shell.
if ! is_bin_in_path python && is_bin_in_path python3; then
function python {
python3 "$@"
}
fi
VERSION=$(
cd ..
python3 -c "from invokeai.version import __version__ as version; print(version)"
python -c "from invokeai.version import __version__ as version; print(version)"
)
VERSION="v${VERSION}"
if [[ ! -z ${CI} ]]; then
echo
echo -e "${BCYAN}CI environment detected${RESET}"
echo
else
echo
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
echo "The current working directory is $(pwd)"
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
echo
fi
PATCH=""
VERSION="v${VERSION}${PATCH}"
echo -e "${BGREEN}HEAD${RESET}:"
git_show HEAD
echo
# ---------------------- FRONTEND ----------------------
pushd ../invokeai/frontend/web >/dev/null
echo "Installing frontend dependencies..."
echo
pnpm i --frozen-lockfile
echo
if [[ ! -z ${CI} ]]; then
echo "Building frontend without checks..."
# In CI, we have already done the frontend checks and can just build
pnpm vite build
else
echo "Running checks and building frontend..."
# This runs all the frontend checks and builds
pnpm build
fi
echo
popd
# ---------------------- BACKEND ----------------------
echo
echo "Building wheel..."
echo
# install the 'build' package in the user site packages, if needed
# could be improved by using a temporary venv, but it's tiny and harmless
if [[ $(python3 -c 'from importlib.util import find_spec; print(find_spec("build") is None)') == "True" ]]; then
pip install --user build
fi
rm -rf ../build
python3 -m build --outdir dist/ ../.
# ----------------------
echo
@@ -106,28 +78,10 @@ chmod a+x InvokeAI-Installer/install.sh
cp install.bat.in InvokeAI-Installer/install.bat
cp WinLongPathsEnabled.reg InvokeAI-Installer/
FILENAME=InvokeAI-installer-$VERSION.zip
# Zip everything up
zip -r ${FILENAME} InvokeAI-Installer
zip -r InvokeAI-installer-$VERSION.zip InvokeAI-Installer
echo
echo -e "${BGREEN}Built installer: ./${FILENAME}${RESET}"
echo -e "${BGREEN}Built PyPi distribution: ./dist${RESET}"
# clean up, but only if we are not in a github action
if [[ -z ${CI} ]]; then
echo
echo "Cleaning up intermediate build files..."
rm -rf InvokeAI-Installer tmp ../invokeai/frontend/web/dist/
fi
if [[ ! -z ${CI} ]]; then
echo
echo "Setting GitHub action outputs..."
echo "INSTALLER_FILENAME=${FILENAME}" >>$GITHUB_OUTPUT
echo "INSTALLER_PATH=installer/${FILENAME}" >>$GITHUB_OUTPUT
echo "DIST_PATH=installer/dist/" >>$GITHUB_OUTPUT
fi
# clean up
rm -rf InvokeAI-Installer tmp dist ../invokeai/frontend/web/dist/
exit 0

View File

@@ -2,12 +2,12 @@
set -e
BCYAN="\033[1;36m"
BYELLOW="\033[1;33m"
BGREEN="\033[1;32m"
BRED="\033[1;31m"
RED="\033[31m"
RESET="\033[0m"
BCYAN="\e[1;36m"
BYELLOW="\e[1;33m"
BGREEN="\e[1;32m"
BRED="\e[1;31m"
RED="\e[31m"
RESET="\e[0m"
function does_tag_exist {
git rev-parse --quiet --verify "refs/tags/$1" >/dev/null
@@ -23,40 +23,49 @@ function git_show {
VERSION=$(
cd ..
python3 -c "from invokeai.version import __version__ as version; print(version)"
python -c "from invokeai.version import __version__ as version; print(version)"
)
PATCH=""
MAJOR_VERSION=$(echo $VERSION | sed 's/\..*$//')
VERSION="v${VERSION}${PATCH}"
LATEST_TAG="v${MAJOR_VERSION}-latest"
if does_tag_exist $VERSION; then
echo -e "${BCYAN}${VERSION}${RESET} already exists:"
git_show_ref tags/$VERSION
echo
fi
if does_tag_exist $LATEST_TAG; then
echo -e "${BCYAN}${LATEST_TAG}${RESET} already exists:"
git_show_ref tags/$LATEST_TAG
echo
fi
echo -e "${BGREEN}HEAD${RESET}:"
git_show
echo
echo -e "${BGREEN}git remote -v${RESET}:"
git remote -v
echo
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on origin remote${RESET}? "
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} and ${BCYAN}${LATEST_TAG}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on remote${RESET}? "
read -e -p 'y/n [n]: ' input
RESPONSE=${input:='n'}
if [ "$RESPONSE" == 'y' ]; then
echo
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on origin remote..."
git push origin :refs/tags/$VERSION
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on remote..."
git push --delete origin $VERSION
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} on locally..."
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} locally..."
if ! git tag -fa $VERSION; then
echo "Existing/invalid tag"
exit -1
fi
echo -e "Pushing updated tags to origin remote..."
echo -e "Deleting ${BCYAN}${LATEST_TAG}${RESET} tag on remote..."
git push --delete origin $LATEST_TAG
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${LATEST_TAG}${RESET} locally..."
git tag -fa $LATEST_TAG
echo -e "Pushing updated tags to remote..."
git push origin --tags
fi
exit 0

View File

@@ -4,6 +4,7 @@ from logging import Logger
import torch
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
@@ -15,13 +16,14 @@ from ..services.board_image_records.board_image_records_sqlite import SqliteBoar
from ..services.board_images.board_images_default import BoardImagesService
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from ..services.boards.boards_default import BoardService
from ..services.bulk_download.bulk_download_default import BulkDownloadService
from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService
from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService
from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from ..services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
from ..services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
@@ -31,6 +33,7 @@ from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.shared.graph import GraphExecutionState
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
@@ -82,7 +85,7 @@ class ApiDependencies:
board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService()
events = FastAPIEventService(event_handler_id)
bulk_download = BulkDownloadService()
graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
@@ -102,6 +105,8 @@ class ApiDependencies:
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
queue = MemoryInvocationQueue()
session_processor = DefaultSessionProcessor()
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
@@ -112,9 +117,9 @@ class ApiDependencies:
board_images=board_images,
board_records=board_records,
boards=boards,
bulk_download=bulk_download,
configuration=configuration,
events=events,
graph_execution_manager=graph_execution_manager,
image_files=image_files,
image_records=image_records,
images=images,
@@ -124,6 +129,8 @@ class ApiDependencies:
download_queue=download_queue_service,
names=names,
performance_statistics=performance_statistics,
processor=processor,
queue=queue,
session_processor=session_processor,
session_queue=session_queue,
urls=urls,

View File

@@ -2,7 +2,7 @@ import io
import traceback
from typing import Optional
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
@@ -375,67 +375,16 @@ async def unstar_images_in_list(
class ImagesDownloaded(BaseModel):
response: Optional[str] = Field(
default=None, description="The message to display to the user when images begin downloading"
)
bulk_download_item_name: Optional[str] = Field(
default=None, description="The name of the bulk download item for which events will be emitted"
description="If defined, the message to display to the user when images begin downloading"
)
@images_router.post(
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
)
@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded)
async def download_images_from_list(
background_tasks: BackgroundTasks,
image_names: Optional[list[str]] = Body(
default=None, description="The list of names of images to download", embed=True
),
image_names: list[str] = Body(description="The list of names of images to download", embed=True),
board_id: Optional[str] = Body(
default=None, description="The board from which image should be downloaded", embed=True
default=None, description="The board from which image should be downloaded from", embed=True
),
) -> ImagesDownloaded:
if (image_names is None or len(image_names) == 0) and board_id is None:
raise HTTPException(status_code=400, detail="No images or board id specified.")
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
background_tasks.add_task(
ApiDependencies.invoker.services.bulk_download.handler,
image_names,
board_id,
bulk_download_item_id,
)
return ImagesDownloaded(bulk_download_item_name=bulk_download_item_id + ".zip")
@images_router.api_route(
"/download/{bulk_download_item_name}",
methods=["GET"],
operation_id="get_bulk_download_item",
response_class=Response,
responses={
200: {
"description": "Return the complete bulk download item",
"content": {"application/zip": {}},
},
404: {"description": "Image not found"},
},
)
async def get_bulk_download_item(
background_tasks: BackgroundTasks,
bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"),
) -> FileResponse:
"""Gets a bulk download zip file"""
try:
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
response = FileResponse(
path,
media_type="application/zip",
filename=bulk_download_item_name,
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name)
return response
except Exception:
raise HTTPException(status_code=404)
# return ImagesDownloaded(response="Your images are downloading")
raise HTTPException(status_code=501, detail="Endpoint is not yet implemented")

View File

@@ -9,12 +9,11 @@ from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_metadata.metadata_store_base import ModelMetadataChanges
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
@@ -33,8 +32,6 @@ from invokeai.backend.model_manager.config import (
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import BaseMetadata
from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies
@@ -167,27 +164,6 @@ async def list_model_records(
return ModelsList(models=found_models)
@model_manager_router.get(
"/get_by_attrs",
operation_id="get_model_records_by_attrs",
response_model=AnyModelConfig,
)
async def get_model_records_by_attrs(
name: str = Query(description="The name of the model"),
type: ModelType = Query(description="The type of the model"),
base: BaseModelType = Query(description="The base model of the model"),
) -> AnyModelConfig:
"""Gets a model by its attributes. The main use of this route is to provide backwards compatibility with the old
model manager, which identified models by a combination of name, base and type."""
configs = ApiDependencies.invoker.services.model_manager.store.search_by_attr(
base_model=base, model_type=type, model_name=name
)
if not configs:
raise HTTPException(status_code=404, detail="No model found with these attributes")
return configs[0]
@model_manager_router.get(
"/i/{key}",
operation_id="get_model_record",
@@ -225,7 +201,7 @@ async def list_model_summary(
@model_manager_router.get(
"/i/{key}/metadata",
"/meta/i/{key}",
operation_id="get_model_metadata",
responses={
200: {
@@ -233,6 +209,7 @@ async def list_model_summary(
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
404: {"description": "No metadata available"},
},
)
async def get_model_metadata(
@@ -241,48 +218,8 @@ async def get_model_metadata(
"""Get a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
return result
@model_manager_router.patch(
"/i/{key}/metadata",
operation_id="update_model_metadata",
responses={
201: {
"description": "The model metadata was updated successfully",
"content": {"application/json": {"example": example_model_metadata}},
},
400: {"description": "Bad request"},
},
)
async def update_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
changes: ModelMetadataChanges = Body(description="The changes"),
) -> Optional[AnyModelRepoMetadata]:
"""Updates or creates a model metadata object."""
record_store = ApiDependencies.invoker.services.model_manager.store
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
try:
original_metadata = record_store.get_metadata(key)
if original_metadata:
if changes.default_settings:
original_metadata.default_settings = changes.default_settings
metadata_store.update_metadata(key, original_metadata)
else:
metadata_store.add_metadata(
key, BaseMetadata(name="", author="", default_settings=changes.default_settings)
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"An error occurred while updating the model metadata: {e}",
)
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
if not result:
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
return result
@@ -297,75 +234,6 @@ async def list_tags() -> Set[str]:
return result
class FoundModel(BaseModel):
path: str = Field(description="Path to the model")
is_installed: bool = Field(description="Whether or not the model is already installed")
@model_manager_router.get(
"/scan_folder",
operation_id="scan_for_models",
responses={
200: {"description": "Directory scanned successfully"},
400: {"description": "Invalid directory path"},
},
status_code=200,
response_model=List[FoundModel],
)
async def scan_for_models(
scan_path: str = Query(description="Directory path to search for models", default=None),
) -> List[FoundModel]:
path = pathlib.Path(scan_path)
if not scan_path or not path.is_dir():
raise HTTPException(
status_code=400,
detail=f"The search path '{scan_path}' does not exist or is not directory",
)
search = ModelSearch()
try:
found_model_paths = search.search(path)
models_path = ApiDependencies.invoker.services.configuration.models_path
# If the search path includes the main models directory, we need to exclude core models from the list.
# TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed
# without needing to crawl the filesystem.
core_models_path = pathlib.Path(models_path, "core").resolve()
non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
resolved_installed_model_paths: list[str] = []
installed_model_sources: list[str] = []
# This call lists all installed models.
for model in installed_models:
path = pathlib.Path(model.path)
# If the model has a source, we need to add it to the list of installed sources.
if model.source:
installed_model_sources.append(model.source)
# If the path is not absolute, that means it is in the app models directory, and we need to join it with
# the models path before resolving.
if not path.is_absolute():
resolved_installed_model_paths.append(str(pathlib.Path(models_path, path).resolve()))
continue
resolved_installed_model_paths.append(str(path.resolve()))
scan_results: list[FoundModel] = []
# Check if the model is installed by comparing the resolved paths, appending to the scan result.
for p in non_core_model_paths:
path = str(p)
is_installed = path in resolved_installed_model_paths or path in installed_model_sources
found_model = FoundModel(path=path, is_installed=is_installed)
scan_results.append(found_model)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"An error occurred while searching the directory: {e}",
)
return scan_results
@model_manager_router.get(
"/tags/search",
operation_id="search_by_metadata_tags",
@@ -482,8 +350,8 @@ async def add_model_record(
@model_manager_router.post(
"/install",
operation_id="install_model",
"/heuristic_import",
operation_id="heuristic_import_model",
responses={
201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"},
@@ -492,14 +360,12 @@ async def add_model_record(
},
status_code=201,
)
async def install_model(
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
# TODO(MM2): Can we type this?
async def heuristic_import(
source: str,
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
example={"name": "string", "description": "string"},
example={"name": "modelT", "description": "antique cars"},
),
access_token: Optional[str] = None,
) -> ModelInstallJob:
@@ -536,8 +402,106 @@ async def install_model(
result: ModelInstallJob = installer.heuristic_import(
source=source,
config=config,
access_token=access_token,
inplace=bool(inplace),
)
logger.info(f"Started installation of {source}")
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return result
@model_manager_router.post(
"/install",
operation_id="import_model",
responses={
201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
)
async def import_model(
source: ModelSource,
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
),
) -> ModelInstallJob:
"""Install a model using its local path, repo_id, or remote URL.
Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has `status` attribute
that can be used to monitor progress.
The source object is a discriminated Union of LocalModelSource,
HFModelSource and URLModelSource. Set the "type" field to the
appropriate value:
* To install a local path using LocalModelSource, pass a source of form:
```
{
"type": "local",
"path": "/path/to/model",
"inplace": false
}
```
The "inplace" flag, if true, will register the model in place in its
current filesystem location. Otherwise, the model will be copied
into the InvokeAI models directory.
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
```
{
"type": "hf",
"repo_id": "stabilityai/stable-diffusion-2.0",
"variant": "fp16",
"subfolder": "vae",
"access_token": "f5820a918aaf01"
}
```
The `variant`, `subfolder` and `access_token` fields are optional.
* To install a remote model using an arbitrary URL, pass:
```
{
"type": "url",
"url": "http://www.civitai.com/models/123456",
"access_token": "f5820a918aaf01"
}
```
The `access_token` field is optonal
The model's configuration record will be probed and filled in
automatically. To override the default guesses, pass "metadata"
with a Dict containing the attributes you wish to override.
Installation occurs in the background. Either use list_model_install_jobs()
to poll for completion, or listen on the event bus for the following events:
* "model_install_running"
* "model_install_completed"
* "model_install_error"
On successful completion, the event's payload will contain the field "key"
containing the installed ID of the model. On an error, the event's payload
will contain the fields "error_type" and "error" describing the nature of the
error and its traceback, respectively.
"""
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_manager.install
result: ModelInstallJob = installer.import_model(
source=source,
config=config,
)
logger.info(f"Started installation of {source}")
except UnknownModelException as e:
@@ -673,7 +637,6 @@ async def convert_model(
Note that during the conversion process the key and model hash will change.
The return value is the model configuration for the converted model.
"""
model_manager = ApiDependencies.invoker.services.model_manager
logger = ApiDependencies.invoker.services.logger
loader = ApiDependencies.invoker.services.model_manager.load
store = ApiDependencies.invoker.services.model_manager.store
@@ -690,7 +653,7 @@ async def convert_model(
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
# loading the model will convert it into a cached diffusers file
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
loader.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)

View File

@@ -0,0 +1,276 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from fastapi import HTTPException, Path
from fastapi.routing import APIRouter
from ...services.shared.graph import GraphExecutionState
from ..dependencies import ApiDependencies
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
# @session_router.post(
# "/",
# operation_id="create_session",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid json"},
# },
# deprecated=True,
# )
# async def create_session(
# queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
# graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
# ) -> GraphExecutionState:
# """Creates a new session, optionally initializing it with an invocation graph"""
# session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
# return session
# @session_router.get(
# "/",
# operation_id="list_sessions",
# responses={200: {"model": PaginatedResults[GraphExecutionState]}},
# deprecated=True,
# )
# async def list_sessions(
# page: int = Query(default=0, description="The page of results to get"),
# per_page: int = Query(default=10, description="The number of results per page"),
# query: str = Query(default="", description="The query string to search for"),
# ) -> PaginatedResults[GraphExecutionState]:
# """Gets a list of sessions, optionally searching"""
# if query == "":
# result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
# else:
# result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
# return result
@session_router.get(
"/{session_id}",
operation_id="get_session",
responses={
200: {"model": GraphExecutionState},
404: {"description": "Session not found"},
},
)
async def get_session(
session_id: str = Path(description="The id of the session to get"),
) -> GraphExecutionState:
"""Gets a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
raise HTTPException(status_code=404)
else:
return session
# @session_router.post(
# "/{session_id}/nodes",
# operation_id="add_node",
# responses={
# 200: {"model": str},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def add_node(
# session_id: str = Path(description="The id of the session"),
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
# description="The node to add"
# ),
# ) -> str:
# """Adds a node to the graph"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.add_node(node)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session.id
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.put(
# "/{session_id}/nodes/{node_path}",
# operation_id="update_node",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def update_node(
# session_id: str = Path(description="The id of the session"),
# node_path: str = Path(description="The path to the node in the graph"),
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
# description="The new node"
# ),
# ) -> GraphExecutionState:
# """Updates a node in the graph and removes all linked edges"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.update_node(node_path, node)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.delete(
# "/{session_id}/nodes/{node_path}",
# operation_id="delete_node",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def delete_node(
# session_id: str = Path(description="The id of the session"),
# node_path: str = Path(description="The path to the node to delete"),
# ) -> GraphExecutionState:
# """Deletes a node in the graph and removes all linked edges"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.delete_node(node_path)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.post(
# "/{session_id}/edges",
# operation_id="add_edge",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def add_edge(
# session_id: str = Path(description="The id of the session"),
# edge: Edge = Body(description="The edge to add"),
# ) -> GraphExecutionState:
# """Adds an edge to the graph"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# session.add_edge(edge)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# # TODO: the edge being in the path here is really ugly, find a better solution
# @session_router.delete(
# "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
# operation_id="delete_edge",
# responses={
# 200: {"model": GraphExecutionState},
# 400: {"description": "Invalid node or link"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def delete_edge(
# session_id: str = Path(description="The id of the session"),
# from_node_id: str = Path(description="The id of the node the edge is coming from"),
# from_field: str = Path(description="The field of the node the edge is coming from"),
# to_node_id: str = Path(description="The id of the node the edge is going to"),
# to_field: str = Path(description="The field of the node the edge is going to"),
# ) -> GraphExecutionState:
# """Deletes an edge from the graph"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# try:
# edge = Edge(
# source=EdgeConnection(node_id=from_node_id, field=from_field),
# destination=EdgeConnection(node_id=to_node_id, field=to_field),
# )
# session.delete_edge(edge)
# ApiDependencies.invoker.services.graph_execution_manager.set(
# session
# ) # TODO: can this be done automatically, or add node through an API?
# return session
# except NodeAlreadyExecutedError:
# raise HTTPException(status_code=400)
# except IndexError:
# raise HTTPException(status_code=400)
# @session_router.put(
# "/{session_id}/invoke",
# operation_id="invoke_session",
# responses={
# 200: {"model": None},
# 202: {"description": "The invocation is queued"},
# 400: {"description": "The session has no invocations ready to invoke"},
# 404: {"description": "Session not found"},
# },
# deprecated=True,
# )
# async def invoke_session(
# queue_id: str = Query(description="The id of the queue to associate the session with"),
# session_id: str = Path(description="The id of the session to invoke"),
# all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
# ) -> Response:
# """Invokes a session"""
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
# if session is None:
# raise HTTPException(status_code=404)
# if session.is_complete():
# raise HTTPException(status_code=400)
# ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
# return Response(status_code=202)
# @session_router.delete(
# "/{session_id}/invoke",
# operation_id="cancel_session_invoke",
# responses={202: {"description": "The invocation is canceled"}},
# deprecated=True,
# )
# async def cancel_session_invoke(
# session_id: str = Path(description="The id of the session to cancel"),
# ) -> Response:
# """Invokes a session"""
# ApiDependencies.invoker.cancel(session_id)
# return Response(status_code=202)

View File

@@ -12,26 +12,16 @@ class SocketIO:
__sio: AsyncServer
__app: ASGIApp
__sub_queue: str = "subscribe_queue"
__unsub_queue: str = "unsubscribe_queue"
__sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI):
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
app.mount("/ws", self.__app)
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue)
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue)
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)
async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
@@ -49,18 +39,3 @@ class SocketIO:
async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
async def _handle_bulk_download_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["bulk_download_id"],
)
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.enter_room(sid, data["bulk_download_id"])
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.leave_room(sid, data["bulk_download_id"])

View File

@@ -2,7 +2,6 @@
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
import sys
from contextlib import asynccontextmanager
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.version.invokeai_version import __version__
@@ -51,6 +50,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
images,
model_manager,
session_queue,
sessions,
utilities,
workflows,
)
@@ -72,25 +72,9 @@ logger = InvokeAILogger.get_logger(config=app_config)
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
@asynccontextmanager
async def lifespan(app: FastAPI):
# Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
yield
# Shut down threads
ApiDependencies.shutdown()
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(
title="Invoke - Community Edition",
docs_url=None,
redoc_url=None,
separate_input_output_schemas=False,
lifespan=lifespan,
)
app = FastAPI(title="Invoke - Community Edition", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
# Add event handler
event_handler_id: int = id(app)
@@ -113,7 +97,21 @@ app.add_middleware(
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event() -> None:
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
# Shut down threads
@app.on_event("shutdown")
async def shutdown_event() -> None:
ApiDependencies.shutdown()
# Include all routers
app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(model_manager.model_manager_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api")
@@ -153,8 +151,6 @@ def custom_openapi() -> dict[str, Any]:
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]
openapi_schema["components"]["schemas"][schema_key] = output_schema
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
# Add Node Editor UI helper schemas
ui_config_schemas = models_json_schema(
@@ -177,6 +173,7 @@ def custom_openapi() -> dict[str, Any]:
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
invoker_schema["class"] = "invocation"
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
# This code no longer seems to be necessary?
# Leave it here just in case

View File

@@ -8,26 +8,13 @@ import warnings
from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
ClassVar,
Iterable,
Literal,
Optional,
Type,
TypeVar,
Union,
cast,
)
from types import UnionType
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
import semver
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic import BaseModel, ConfigDict, Field, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from typing_extensions import TypeAliasType
from invokeai.app.invocations.fields import (
FieldKind,
@@ -97,7 +84,6 @@ class BaseInvocationOutput(BaseModel):
"""
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
@@ -110,14 +96,10 @@ class BaseInvocationOutput(BaseModel):
return cls._output_classes
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter:
InvocationOutputsUnion = TypeAliasType(
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
return cls._typeadapter
def get_outputs_union(cls) -> UnionType:
"""Gets a union of all invocation outputs."""
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
return outputs_union # type: ignore [return-value]
@classmethod
def get_output_types(cls) -> Iterable[str]:
@@ -166,7 +148,6 @@ class BaseInvocation(ABC, BaseModel):
"""
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
@classmethod
def get_type(cls) -> str:
@@ -179,14 +160,10 @@ class BaseInvocation(ABC, BaseModel):
cls._invocation_classes.add(invocation)
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter:
InvocationsUnion = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationsUnion)
return cls._typeadapter
def get_invocations_union(cls) -> UnionType:
"""Gets a union of all invocation types."""
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
return invocations_union # type: ignore [return-value]
@classmethod
def get_invocations(cls) -> Iterable[BaseInvocation]:

View File

@@ -1,23 +1,24 @@
from typing import Iterator, List, Optional, Tuple, Union, cast
from typing import Iterator, List, Optional, Tuple, Union
import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers import CLIPTokenizer
import invokeai.backend.util.logging as logger
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
Input,
InputField,
MaskField,
OutputField,
UIComponent,
)
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import ModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
@@ -25,9 +26,15 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ExtraConditioningInfo,
SDXLConditioningInfo,
)
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.util.devices import torch_dtype
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from .model import ClipField
# unconditioned: Optional[torch.Tensor]
@@ -44,7 +51,7 @@ from .model import ClipField
title="Prompt",
tags=["prompt", "compel"],
category="conditioning",
version="1.2.0",
version="1.0.1",
)
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
@@ -59,19 +66,11 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
mask: Optional[MaskField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
mask_weight: float = InputField(default=1.0, description="")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
@@ -83,10 +82,21 @@ class CompelInvocation(BaseInvocation):
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
ti_list = []
for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1]
try:
loaded_model = context.models.load(**self.clip.text_encoder.model_dump()).model
assert isinstance(loaded_model, TextualInversionModelRaw)
ti_list.append((name, loaded_model))
except UnknownModelException:
# print(e)
# import traceback
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
tokenizer,
ti_manager,
),
@@ -94,9 +104,8 @@ class CompelInvocation(BaseInvocation):
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
):
assert isinstance(text_encoder, CLIPTextModel)
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
@@ -130,13 +139,7 @@ class CompelInvocation(BaseInvocation):
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
mask_weight=self.mask_weight,
)
)
return ConditioningOutput.build(conditioning_name)
class SDXLPromptInvocationBase:
@@ -152,11 +155,7 @@ class SDXLPromptInvocationBase:
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
# return zero on empty
if prompt == "" and zero_on_empty:
@@ -190,10 +189,25 @@ class SDXLPromptInvocationBase:
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
ti_list = []
for trigger in extract_ti_triggers_from_prompt(prompt):
name = trigger[1:-1]
try:
ti_model = context.models.load_by_attrs(
model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
).model
assert isinstance(ti_model, TextualInversionModelRaw)
ti_list.append((name, ti_model))
except UnknownModelException:
# print(e)
# import traceback
# print(traceback.format_exc())
logger.warning(f'trigger: "{trigger}" not found')
except ValueError:
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
with (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
tokenizer,
ti_manager,
),
@@ -201,10 +215,8 @@ class SDXLPromptInvocationBase:
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
):
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
text_encoder = cast(CLIPTextModel, text_encoder)
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
@@ -250,7 +262,7 @@ class SDXLPromptInvocationBase:
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
version="1.2.0",
version="1.0.1",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@@ -274,11 +286,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
mask: Optional[MaskField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
mask_weight: float = InputField(default=1.0, description="")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(
@@ -340,13 +347,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
mask_weight=self.mask_weight,
)
)
return ConditioningOutput.build(conditioning_name)
@invocation(
@@ -395,7 +396,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name, mask_weight=1.0))
return ConditioningOutput.build(conditioning_name)
@invocation_output("clip_skip_output")
@@ -416,7 +417,7 @@ class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers

View File

@@ -1,40 +0,0 @@
import torch
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
InvocationContext,
invocation,
)
from invokeai.app.invocations.fields import InputField, WithMetadata
from invokeai.app.invocations.primitives import MaskField, MaskOutput
@invocation(
"rectangle_mask",
title="Create Rectangle Mask",
tags=["conditioning"],
category="conditioning",
version="1.0.0",
)
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
"""Create a rectangular mask."""
height: int = InputField(description="The height of the entire mask.")
width: int = InputField(description="The width of the entire mask.")
y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")
x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
rectangle_height: int = InputField(description="The height of the rectangular masked region.")
rectangle_width: int = InputField(description="The width of the rectangular masked region.")
def invoke(self, context: InvocationContext) -> MaskOutput:
mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
mask[
:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
] = True
mask_name = context.tensors.save(mask)
return MaskOutput(
mask=MaskField(mask_name=mask_name),
width=self.width,
height=self.height,
)

View File

@@ -194,18 +194,11 @@ class BoardField(BaseModel):
board_id: str = Field(description="The id of the board")
class MaskField(BaseModel):
"""A mask primitive field."""
mask_name: str = Field(description="The name of the mask.")
class DenoiseMaskField(BaseModel):
"""An inpaint mask field"""
mask_name: str = Field(description="The name of the mask image")
masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
gradient: bool = Field(default=False, description="Used for gradient inpainting")
class LatentsField(BaseModel):
@@ -231,12 +224,7 @@ class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
mask: Optional[MaskField] = Field(
default=None,
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.",
)
mask_weight: float = Field(description="")
# endregion
class MetadataField(RootModel):

View File

@@ -22,7 +22,11 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from .baseinvocation import BaseInvocation, Classification, invocation
from .baseinvocation import (
BaseInvocation,
Classification,
invocation,
)
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1")
@@ -930,40 +934,3 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
@invocation(
"canvas_paste_back",
title="Canvas Paste Back",
tags=["image", "combine"],
category="image",
version="1.0.0",
)
class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Combines two images by using the mask provided. Intended for use on the Unified Canvas."""
source_image: ImageField = InputField(description="The source image")
target_image: ImageField = InputField(default=None, description="The target image")
mask: ImageField = InputField(
description="The mask to use when pasting",
)
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
def _prepare_mask(self, mask: Image.Image) -> Image.Image:
mask_array = numpy.array(mask)
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
dilated_mask = Image.fromarray(dilated_mask_array)
if self.mask_blur > 0:
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
return ImageOps.invert(mask.convert("L"))
def invoke(self, context: InvocationContext) -> ImageOutput:
source_image = context.images.get_pil(self.source_image.image_name)
target_image = context.images.get_pil(self.target_image.image_name)
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
source_image.paste(target_image, (0, 0), mask)
image_dto = context.images.save(image=source_image)
return ImageOutput.build(image_dto)

View File

@@ -93,7 +93,7 @@ class IPAdapterInvocation(BaseInvocation):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
)
assert len(image_encoder_models) == 1
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)

View File

@@ -1,5 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import math
from contextlib import ExitStack
from functools import singledispatchmethod
@@ -9,7 +9,6 @@ import einops
import numpy as np
import numpy.typing as npt
import torch
import torchvision
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.configuration_utils import ConfigMixin
@@ -24,7 +23,7 @@ from diffusers.models.attention_processor import (
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler
from PIL import Image, ImageFilter
from PIL import Image
from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize
@@ -56,14 +55,7 @@ from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
IPAdapterConditioningInfo,
Range,
SDXLConditioningInfo,
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
from invokeai.backend.util.silence_warnings import SilenceWarnings
from ...backend.stable_diffusion.diffusers_pipeline import (
@@ -73,6 +65,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
T2IAdapterData,
image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from .baseinvocation import (
@@ -135,7 +128,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
ui_order=4,
)
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor:
if mask_image.mode != "L":
mask_image = mask_image.convert("L")
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
@@ -176,76 +169,6 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
return DenoiseMaskOutput.build(
mask_name=mask_name,
masked_latents_name=masked_latents_name,
gradient=False,
)
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)
@invocation(
"create_gradient_mask",
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.0.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
edge_radius: int = InputField(
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
)
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
minimum_denoise: float = InputField(
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.edge_radius > 0:
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever the blur_tensor is less than fully masked, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
else:
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
# compute a [0, 1] mask from the blur_tensor
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
expanded_image_dto = context.images.save(expanded_mask_image)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
)
@@ -291,11 +214,11 @@ def get_scheduler(
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
)
negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=0
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
)
noise: Optional[LatentsField] = InputField(
default=None,
@@ -372,190 +295,41 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v
def _get_text_embeddings_and_masks(
self,
cond_list: list[ConditioningField],
context: InvocationContext,
device: torch.device,
dtype: torch.dtype,
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
"""Get the text embeddings and masks from the input conditioning fields."""
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = []
for cond in cond_list:
cond_data = context.conditioning.load(cond.conditioning_name)
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
mask = cond.mask
if mask is not None:
mask = context.tensors.load(mask.mask_name)
text_embeddings_masks.append(mask)
return text_embeddings, text_embeddings_masks
def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
Returns:
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
"""
if mask is None:
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
mask = tf(mask)
return mask
def concat_regional_text_embeddings(
self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
conditioning_fields: list[ConditioningField],
latent_height: int,
latent_width: int,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
if masks is None:
masks = [None] * len(text_conditionings)
assert len(text_conditionings) == len(masks)
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
all_masks_are_none = all(mask is None for mask in masks)
text_embedding = []
pooled_embedding = None
add_time_ids = None
cur_text_embedding_len = 0
processed_masks = []
embedding_ranges = []
extra_conditioning = None
for prompt_idx, text_embedding_info in enumerate(text_conditionings):
mask = masks[prompt_idx]
if (
text_embedding_info.extra_conditioning is not None
and text_embedding_info.extra_conditioning.wants_cross_attention_control
):
extra_conditioning = text_embedding_info.extra_conditioning
if is_sdxl:
# We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
# prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
# global prompt information. In an ideal case, there should be exactly one global prompt without a
# mask, but we don't enforce this.
# HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
# fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
# them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
# pretty major breaking change to a popular node, so for now we use this hack.
if pooled_embedding is None or mask is None:
pooled_embedding = text_embedding_info.pooled_embeds
if add_time_ids is None or mask is None:
add_time_ids = text_embedding_info.add_time_ids
text_embedding.append(text_embedding_info.embeds)
if not all_masks_are_none:
# embedding_ranges.append(
# Range(
# start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
# )
# )
# HACK(ryand): Contrary to its name, tokens_count_including_eos_bos does not seem to include eos and bos
# in the count.
embedding_ranges.append(
Range(
start=cur_text_embedding_len + 1,
end=cur_text_embedding_len
+ text_embedding_info.extra_conditioning.tokens_count_including_eos_bos,
)
)
processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
text_embedding = torch.cat(text_embedding, dim=1)
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
regions = None
if not all_masks_are_none:
regions = TextConditioningRegions(
masks=torch.cat(processed_masks, dim=1),
ranges=embedding_ranges,
mask_weights=[x.mask_weight for x in conditioning_fields],
)
if extra_conditioning is not None and len(text_conditionings) > 1:
raise ValueError(
"Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple "
"prompts."
)
if is_sdxl:
return SDXLConditioningInfo(
embeds=text_embedding,
extra_conditioning=extra_conditioning,
pooled_embeds=pooled_embedding,
add_time_ids=add_time_ids,
), regions
return BasicConditioningInfo(
embeds=text_embedding,
extra_conditioning=extra_conditioning,
), regions
def get_conditioning_data(
self,
context: InvocationContext,
scheduler: Scheduler,
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
) -> TextConditioningData:
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
cond_list = self.positive_conditioning
if not isinstance(cond_list, list):
cond_list = [cond_list]
uncond_list = self.negative_conditioning
if not isinstance(uncond_list, list):
uncond_list = [uncond_list]
seed: int,
) -> ConditioningData:
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = c.extra_conditioning
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
)
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
)
cond_text_embedding, cond_regions = self.concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
conditioning_fields=cond_list,
latent_height=latent_height,
latent_width=latent_width,
)
uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
conditioning_fields=uncond_list,
latent_height=latent_height,
latent_width=latent_width,
)
conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding,
uncond_regions=uncond_regions,
cond_regions=cond_regions,
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
text_embeddings=c,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0, # threshold,
warmup=0.2, # warmup,
h_symmetry_time_pct=None, # h_symmetry_time_pct,
v_symmetry_time_pct=None, # v_symmetry_time_pct,
),
)
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
scheduler,
# for ddim scheduler
eta=0.0, # ddim_eta
# for ancestral and sde schedulers
# flip all bits to have noise different from initial
generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
)
return conditioning_data
@@ -661,6 +435,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
self,
context: InvocationContext,
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
conditioning_data: ConditioningData,
exit_stack: ExitStack,
) -> Optional[list[IPAdapterData]]:
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
@@ -677,6 +452,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
return None
ip_adapter_data_list = []
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
@@ -699,13 +475,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
single_ipa_images, image_encoder_model
)
conditioning_data.ip_adapter_conditioning.append(
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
)
ip_adapter_data_list.append(
IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=single_ip_adapter.weight,
begin_step_percent=single_ip_adapter.begin_step_percent,
end_step_percent=single_ip_adapter.end_step_percent,
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
)
)
@@ -795,7 +574,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
steps: int,
denoising_start: float,
denoising_end: float,
seed: int,
) -> Tuple[int, List[int], int]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
@@ -824,21 +602,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order
scheduler_step_kwargs = {}
scheduler_step_signature = inspect.signature(scheduler.step)
if "generator" in scheduler_step_signature.parameters:
# At some point, someone decided that schedulers that accept a generator should use the original seed with
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
# reproducibility.
scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
return num_inference_steps, timesteps, init_timestep
def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
if self.denoise_mask is None:
return None, None, False
return None, None
mask = context.tensors.load(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
@@ -847,7 +617,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
else:
masked_latents = None
return 1 - mask, masked_latents, self.denoise_mask.gradient
return 1 - mask, masked_latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -874,7 +644,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if seed is None:
seed = 0
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
mask, masked_latents = self.prep_inpaint_mask(context, latents)
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate.
@@ -925,10 +695,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
pipeline = self.create_pipeline(unet, scheduler)
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
)
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
controlnet_data = self.prep_control_data(
context=context,
@@ -942,19 +709,22 @@ class DenoiseLatentsInvocation(BaseInvocation):
ip_adapter_data = self.prep_ip_adapter_data(
context=context,
ip_adapter=self.ip_adapter,
conditioning_data=conditioning_data,
exit_stack=exit_stack,
)
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
result_latents = pipeline.latents_from_embeddings(
(
result_latents,
result_attention_map_saver,
) = pipeline.latents_from_embeddings(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
@@ -962,9 +732,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed=seed,
mask=mask,
masked_latents=masked_latents,
gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=controlnet_data,
ip_adapter_data=ip_adapter_data,

View File

@@ -33,7 +33,7 @@ class MetadataItemField(BaseModel):
class LoRAMetadataField(BaseModel):
"""LoRA Metadata Field"""
model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
lora: LoRAModelField = Field(description=FieldDescriptions.lora_model)
weight: float = Field(description=FieldDescriptions.lora_weight)
@@ -114,7 +114,7 @@ GENERATION_MODES = Literal[
]
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.1.1")
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.0.1")
class CoreMetadataInvocation(BaseInvocation):
"""Collects core generation metadata into a MetadataField"""

View File

@@ -14,7 +14,6 @@ from invokeai.app.invocations.fields import (
Input,
InputField,
LatentsField,
MaskField,
OutputField,
UIComponent,
)
@@ -230,18 +229,6 @@ class StringCollectionInvocation(BaseInvocation):
# region Image
@invocation_output("mask_output")
class MaskOutput(BaseInvocationOutput):
"""A torch mask tensor.
dtype: torch.bool
shape: (1, height, width).
"""
mask: MaskField = OutputField(description="The mask.")
width: int = OutputField(description="The width of the mask in pixels.")
height: int = OutputField(description="The height of the mask in pixels.")
@invocation_output("image_output")
class ImageOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image"""
@@ -312,13 +299,9 @@ class DenoiseMaskOutput(BaseInvocationOutput):
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
@classmethod
def build(
cls, mask_name: str, masked_latents_name: Optional[str] = None, gradient: bool = False
) -> "DenoiseMaskOutput":
def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput":
return cls(
denoise_mask=DenoiseMaskField(
mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=gradient
),
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name),
)
@@ -427,6 +410,10 @@ class ConditioningOutput(BaseInvocationOutput):
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "ConditioningOutput":
return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_collection_output")
class ConditioningCollectionOutput(BaseInvocationOutput):

View File

@@ -1,44 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional
class BulkDownloadBase(ABC):
"""Responsible for creating a zip file containing the images specified by the given image names or board id."""
@abstractmethod
def handler(
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
) -> None:
"""
Create a zip file containing the images specified by the given image names or board id.
:param image_names: A list of image names to include in the zip file.
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
:param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated.
"""
@abstractmethod
def get_path(self, bulk_download_item_name: str) -> str:
"""
Get the path to the bulk download file.
:param bulk_download_item_name: The name of the bulk download item.
:return: The path to the bulk download file.
"""
@abstractmethod
def generate_item_id(self, board_id: Optional[str]) -> str:
"""
Generate an item ID for a bulk download item.
:param board_id: The ID of the board whose name is to be included in the item id.
:return: The generated item ID.
"""
@abstractmethod
def delete(self, bulk_download_item_name: str) -> None:
"""
Delete the bulk download file.
:param bulk_download_item_name: The name of the bulk download item.
"""

View File

@@ -1,25 +0,0 @@
DEFAULT_BULK_DOWNLOAD_ID = "default"
class BulkDownloadException(Exception):
"""Exception raised when a bulk download fails."""
def __init__(self, message="Bulk download failed"):
super().__init__(message)
self.message = message
class BulkDownloadTargetException(BulkDownloadException):
"""Exception raised when a bulk download target is not found."""
def __init__(self, message="The bulk download target was not found"):
super().__init__(message)
self.message = message
class BulkDownloadParametersException(BulkDownloadException):
"""Exception raised when a bulk download parameter is invalid."""
def __init__(self, message="No image names or board ID provided"):
super().__init__(message)
self.message = message

View File

@@ -1,157 +0,0 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Union
from zipfile import ZipFile
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
from invokeai.app.services.bulk_download.bulk_download_common import (
DEFAULT_BULK_DOWNLOAD_ID,
BulkDownloadException,
BulkDownloadParametersException,
BulkDownloadTargetException,
)
from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invoker import Invoker
from invokeai.app.util.misc import uuid_string
from .bulk_download_base import BulkDownloadBase
class BulkDownloadService(BulkDownloadBase):
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def __init__(self):
self._temp_directory = TemporaryDirectory()
self._bulk_downloads_folder = Path(self._temp_directory.name) / "bulk_downloads"
self._bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
def handler(
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
) -> None:
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
bulk_download_item_id = bulk_download_item_id or uuid_string()
bulk_download_item_name = bulk_download_item_id + ".zip"
self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
try:
image_dtos: list[ImageDTO] = []
if board_id:
image_dtos = self._board_handler(board_id)
elif image_names:
image_dtos = self._image_handler(image_names)
else:
raise BulkDownloadParametersException()
bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
except (
ImageRecordNotFoundException,
BoardRecordNotFoundException,
BulkDownloadException,
BulkDownloadParametersException,
) as e:
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
except Exception as e:
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
self._invoker.services.logger.error("Problem bulk downloading images.")
raise e
def _image_handler(self, image_names: list[str]) -> list[ImageDTO]:
return [self._invoker.services.images.get_dto(image_name) for image_name in image_names]
def _board_handler(self, board_id: str) -> list[ImageDTO]:
image_names = self._invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
return self._image_handler(image_names)
def generate_item_id(self, board_id: Optional[str]) -> str:
return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string()
def _get_clean_board_name(self, board_id: str) -> str:
if board_id == "none":
return "Uncategorized"
return self._clean_string_to_path_safe(self._invoker.services.board_records.get(board_id).board_name)
def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str:
"""
Create a zip file containing the images specified by the given image names or board id.
If download with the same bulk_download_id already exists, it will be overwritten.
:return: The name of the zip file.
"""
zip_file_name = bulk_download_item_id + ".zip"
zip_file_path = self._bulk_downloads_folder / (zip_file_name)
with ZipFile(zip_file_path, "w") as zip_file:
for image_dto in image_dtos:
image_zip_path = Path(image_dto.image_category.value) / image_dto.image_name
image_disk_path = self._invoker.services.images.get_path(image_dto.image_name)
zip_file.write(image_disk_path, arcname=image_zip_path)
return str(zip_file_name)
# from https://stackoverflow.com/questions/7406102/create-sane-safe-filename-from-any-unsafe-string
def _clean_string_to_path_safe(self, s: str) -> str:
"""Clean a string to be path safe."""
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip()
def _signal_job_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Signal that a bulk download job has started."""
if self._invoker:
assert bulk_download_id is not None
self._invoker.services.events.emit_bulk_download_started(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
def _signal_job_completed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Signal that a bulk download job has completed."""
if self._invoker:
assert bulk_download_id is not None
assert bulk_download_item_name is not None
self._invoker.services.events.emit_bulk_download_completed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
def _signal_job_failed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception
) -> None:
"""Signal that a bulk download job has failed."""
if self._invoker:
assert bulk_download_id is not None
assert exception is not None
self._invoker.services.events.emit_bulk_download_failed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=str(exception),
)
def stop(self, *args, **kwargs):
self._temp_directory.cleanup()
def delete(self, bulk_download_item_name: str) -> None:
path = self.get_path(bulk_download_item_name)
Path(path).unlink()
def get_path(self, bulk_download_item_name: str) -> str:
path = str(self._bulk_downloads_folder / bulk_download_item_name)
if not self._is_valid_path(path):
raise BulkDownloadTargetException()
return path
def _is_valid_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for a bulk download."""
path = path if isinstance(path, Path) else Path(path)
return path.exists()

View File

@@ -156,7 +156,6 @@ class InvokeAISettings(BaseSettings):
"lora_dir",
"embedding_dir",
"controlnet_dir",
"conf_path",
]
@classmethod

View File

@@ -30,6 +30,7 @@ InvokeAI:
lora_dir: null
embedding_dir: null
controlnet_dir: null
conf_path: configs/models.yaml
models_dir: models
legacy_conf_dir: configs/stable-diffusion
db_dir: databases
@@ -122,6 +123,7 @@ a Path object:
root_path - path to InvokeAI root
output_path - path to default outputs directory
model_conf_path - path to models.yaml
conf - alias for the above
embedding_path - path to the embeddings directory
lora_path - path to the LoRA directory
@@ -161,12 +163,12 @@ two configs are kept in separate sections of the config file:
InvokeAI:
Paths:
root: /home/lstein/invokeai-main
conf_path: configs/models.yaml
legacy_conf_dir: configs/stable-diffusion
outdir: outputs
...
"""
from __future__ import annotations
import os
@@ -235,6 +237,7 @@ class InvokeAIAppConfig(InvokeAISettings):
# PATHS
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
@@ -298,7 +301,6 @@ class InvokeAIAppConfig(InvokeAISettings):
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
# this is not referred to in the source code and can be removed entirely
#free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)

View File

@@ -1,5 +1,4 @@
"""Init file for download queue."""
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
from .download_default import DownloadQueueService, TqdmProgress

View File

@@ -224,6 +224,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
job.job_started = get_iso_timestamp()
self._do_download(job)
self._signal_job_complete(job)
except (OSError, HTTPError) as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()

View File

@@ -3,7 +3,7 @@
from typing import Any, Dict, List, Optional, Union
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
@@ -16,7 +16,6 @@ from invokeai.backend.model_manager import AnyModelConfig
class EventServiceBase:
queue_event: str = "queue_event"
bulk_download_event: str = "bulk_download_event"
download_event: str = "download_event"
model_event: str = "model_event"
@@ -25,14 +24,6 @@ class EventServiceBase:
def dispatch(self, event_name: str, payload: Any) -> None:
pass
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
"""Bulk download events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.bulk_download_event,
payload={"event": event_name, "data": payload},
)
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
"""Queue events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
@@ -213,6 +204,52 @@ class EventServiceBase:
},
)
def emit_session_retrieval_error(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when session retrieval fails"""
self.__emit_queue_event(
event_name="session_retrieval_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"error_type": error_type,
"error": error,
},
)
def emit_invocation_retrieval_error(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when invocation retrieval fails"""
self.__emit_queue_event(
event_name="invocation_retrieval_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"error_type": error_type,
"error": error,
},
)
def emit_session_canceled(
self,
queue_id: str,
@@ -357,7 +394,6 @@ class EventServiceBase:
bytes: int,
total_bytes: int,
parts: List[Dict[str, Union[str, int]]],
id: int,
) -> None:
"""
Emit at intervals while the install job is in progress (remote models only).
@@ -377,7 +413,6 @@ class EventServiceBase:
"bytes": bytes,
"total_bytes": total_bytes,
"parts": parts,
"id": id,
},
)
@@ -392,7 +427,7 @@ class EventServiceBase:
payload={"source": source},
)
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
def emit_model_install_completed(self, source: str, key: str, total_bytes: Optional[int] = None) -> None:
"""
Emit when an install job is completed successfully.
@@ -402,7 +437,11 @@ class EventServiceBase:
"""
self.__emit_model_event(
event_name="model_install_completed",
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
payload={
"source": source,
"total_bytes": total_bytes,
"key": key,
},
)
def emit_model_install_cancelled(self, source: str) -> None:
@@ -416,7 +455,12 @@ class EventServiceBase:
payload={"source": source},
)
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
def emit_model_install_error(
self,
source: str,
error_type: str,
error: str,
) -> None:
"""
Emit when an install job encounters an exception.
@@ -426,45 +470,9 @@ class EventServiceBase:
"""
self.__emit_model_event(
event_name="model_install_error",
payload={"source": source, "error_type": error_type, "error": error, "id": id},
)
def emit_bulk_download_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk download starts"""
self._emit_bulk_download_event(
event_name="bulk_download_started",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_completed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk download completes"""
self._emit_bulk_download_event(
event_name="bulk_download_completed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_failed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> None:
"""Emitted when a bulk download fails"""
self._emit_bulk_download_event(
event_name="bulk_download_failed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
"source": source,
"error_type": error_type,
"error": error,
},
)

View File

@@ -0,0 +1,5 @@
from abc import ABC
class InvocationProcessorABC(ABC): # noqa: B024
pass

View File

@@ -0,0 +1,15 @@
from pydantic import BaseModel, Field
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
class CanceledException(Exception):
"""Execution canceled by user."""
pass

View File

@@ -0,0 +1,243 @@
import time
import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Event, Thread
from typing import Optional
import invokeai.backend.util.logging as logger
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
from invokeai.app.services.invocation_stats.invocation_stats_common import (
GESStatsNotFoundError,
)
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker
from .invocation_processor_base import InvocationProcessorABC
from .invocation_processor_common import CanceledException
class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
__stop_event: Event
__invoker: Invoker
__threadLimit: BoundedSemaphore
def start(self, invoker: Invoker) -> None:
# LS - this will probably break
# but the idea is to enable multithreading up to the number of available
# GPUs. Nodes will block on model loading if no GPU is free.
self.__threadLimit = BoundedSemaphore(invoker.services.model_manager.gpu_count)
self.__invoker = invoker
self.__stop_event = Event()
self.__invoker_thread = Thread(
name="invoker_processor",
target=self.__process,
kwargs={"stop_event": self.__stop_event},
)
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
self.__invoker_thread.start()
def stop(self, *args, **kwargs) -> None:
self.__stop_event.set()
def __process(self, stop_event: Event):
try:
self.__threadLimit.acquire()
queue_item: Optional[InvocationQueueItem] = None
profiler = (
Profiler(
logger=self.__invoker.services.logger,
output_dir=self.__invoker.services.configuration.profiles_path,
prefix=self.__invoker.services.configuration.profile_prefix,
)
if self.__invoker.services.configuration.profile_graphs
else None
)
def stats_cleanup(graph_execution_state_id: str) -> None:
if profiler:
profile_path = profiler.stop()
stats_path = profile_path.with_suffix(".json")
self.__invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=graph_execution_state_id, output_path=stats_path
)
with suppress(GESStatsNotFoundError):
self.__invoker.services.performance_statistics.log_stats(graph_execution_state_id)
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state_id)
while not stop_event.is_set():
try:
queue_item = self.__invoker.services.queue.get()
except Exception as e:
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
if not queue_item: # Probably stopping
# do not hammer the queue
time.sleep(0.5)
continue
if profiler and profiler.profile_id != queue_item.graph_execution_state_id:
profiler.start(profile_id=queue_item.graph_execution_state_id)
try:
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
queue_item.graph_execution_state_id
)
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
self.__invoker.services.events.emit_session_retrieval_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=queue_item.graph_execution_state_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
continue
try:
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
self.__invoker.services.events.emit_invocation_retrieval_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=queue_item.graph_execution_state_id,
node_id=queue_item.invocation_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
continue
# get the source node id to provide to clients (the prepared node id is not as useful)
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
# Send starting event
self.__invoker.services.events.emit_invocation_started(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.model_dump(),
source_node_id=source_node_id,
)
# Invoke
try:
graph_id = graph_execution_state.id
with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id):
# use the internal invoke_internal(), which wraps the node's invoke() method,
# which handles a few things:
# - nodes that require a value, but get it only from a connection
# - referencing the invocation cache instead of executing the node
context_data = InvocationContextData(
invocation=invocation,
session_id=graph_id,
workflow=queue_item.workflow,
source_node_id=source_node_id,
queue_id=queue_item.session_queue_id,
queue_item_id=queue_item.session_queue_item_id,
batch_id=queue_item.session_queue_batch_id,
)
context = build_invocation_context(
services=self.__invoker.services,
context_data=context_data,
)
outputs = invocation.invoke_internal(context=context, services=self.__invoker.services)
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
continue
# Save outputs and history
graph_execution_state.complete(invocation.id, outputs)
# Save the state changes
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.model_dump(),
source_node_id=source_node_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
pass
except CanceledException:
stats_cleanup(graph_execution_state.id)
pass
except Exception as e:
error = traceback.format_exc()
logger.error(error)
# Save error
graph_execution_state.set_node_error(invocation.id, error)
# Save the state changes
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
# Send error event
self.__invoker.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.model_dump(),
source_node_id=source_node_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
continue
# Queue any further commands if invoking all
is_complete = graph_execution_state.is_complete()
if queue_item.invoke_all and not is_complete:
try:
self.__invoker.invoke(
session_queue_batch_id=queue_item.session_queue_batch_id,
session_queue_item_id=queue_item.session_queue_item_id,
session_queue_id=queue_item.session_queue_id,
graph_execution_state=graph_execution_state,
workflow=queue_item.workflow,
invoke_all=True,
)
except Exception as e:
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
self.__invoker.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.model_dump(),
source_node_id=source_node_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
)
stats_cleanup(graph_execution_state.id)
except KeyboardInterrupt:
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
finally:
self.__threadLimit.release()

View File

@@ -0,0 +1,26 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC, abstractmethod
from typing import Optional
from .invocation_queue_common import InvocationQueueItem
class InvocationQueueABC(ABC):
"""Abstract base class for all invocation queues"""
@abstractmethod
def get(self) -> InvocationQueueItem:
pass
@abstractmethod
def put(self, item: Optional[InvocationQueueItem]) -> None:
pass
@abstractmethod
def cancel(self, graph_execution_state_id: str) -> None:
pass
@abstractmethod
def is_canceled(self, graph_execution_state_id: str) -> bool:
pass

View File

@@ -0,0 +1,23 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import time
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
invocation_id: str = Field(description="The ID of the node being invoked")
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
session_queue_item_id: int = Field(
description="The ID of session queue item from which this invocation queue item came"
)
session_queue_batch_id: str = Field(
description="The ID of the session batch from which this invocation queue item came"
)
workflow: Optional[WorkflowWithoutID] = Field(description="The workflow associated with this queue item")
invoke_all: bool = Field(default=False)
timestamp: float = Field(default_factory=time.time)

View File

@@ -0,0 +1,44 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import time
from queue import Queue
from typing import Optional
from .invocation_queue_base import InvocationQueueABC
from .invocation_queue_common import InvocationQueueItem
class MemoryInvocationQueue(InvocationQueueABC):
__queue: Queue
__cancellations: dict[str, float]
def __init__(self):
self.__queue = Queue()
self.__cancellations = {}
def get(self) -> InvocationQueueItem:
item = self.__queue.get()
while (
isinstance(item, InvocationQueueItem)
and item.graph_execution_state_id in self.__cancellations
and self.__cancellations[item.graph_execution_state_id] > item.timestamp
):
item = self.__queue.get()
# Clear old items
for graph_execution_state_id in list(self.__cancellations.keys()):
if self.__cancellations[graph_execution_state_id] < item.timestamp:
del self.__cancellations[graph_execution_state_id]
return item
def put(self, item: Optional[InvocationQueueItem]) -> None:
self.__queue.put(item)
def cancel(self, graph_execution_state_id: str) -> None:
if graph_execution_state_id not in self.__cancellations:
self.__cancellations[graph_execution_state_id] = time.time()
def is_canceled(self, graph_execution_state_id: str) -> bool:
return graph_execution_state_id in self.__cancellations

View File

@@ -16,7 +16,6 @@ if TYPE_CHECKING:
from .board_images.board_images_base import BoardImagesServiceABC
from .board_records.board_records_base import BoardRecordStorageBase
from .boards.boards_base import BoardServiceABC
from .bulk_download.bulk_download_base import BulkDownloadBase
from .config import InvokeAIAppConfig
from .download import DownloadQueueServiceBase
from .events.events_base import EventServiceBase
@@ -24,11 +23,15 @@ if TYPE_CHECKING:
from .image_records.image_records_base import ImageRecordStorageBase
from .images.images_base import ImageServiceABC
from .invocation_cache.invocation_cache_base import InvocationCacheBase
from .invocation_processor.invocation_processor_base import InvocationProcessorABC
from .invocation_queue.invocation_queue_base import InvocationQueueABC
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .item_storage.item_storage_base import ItemStorageABC
from .model_manager.model_manager_base import ModelManagerServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
from .shared.graph import GraphExecutionState
from .urls.urls_base import UrlServiceBase
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
@@ -42,16 +45,18 @@ class InvocationServices:
board_image_records: "BoardImageRecordStorageBase",
boards: "BoardServiceABC",
board_records: "BoardRecordStorageBase",
bulk_download: "BulkDownloadBase",
configuration: "InvokeAIAppConfig",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
images: "ImageServiceABC",
image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase",
invocation_cache: "InvocationCacheBase",
@@ -65,16 +70,18 @@ class InvocationServices:
self.board_image_records = board_image_records
self.boards = boards
self.board_records = board_records
self.bulk_download = bulk_download
self.configuration = configuration
self.events = events
self.graph_execution_manager = graph_execution_manager
self.images = images
self.image_files = image_files
self.image_records = image_records
self.logger = logger
self.model_manager = model_manager
self.download_queue = download_queue
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue
self.session_queue = session_queue
self.session_processor = session_processor
self.invocation_cache = invocation_cache

View File

@@ -3,7 +3,7 @@
Usage:
statistics = InvocationStatsService()
statistics = InvocationStatsService(graph_execution_manager)
with statistics.collect_stats(invocation, graph_execution_state.id):
... execute graphs...
statistics.log_stats()
@@ -30,7 +30,7 @@ writes to the system log is stored in InvocationServices.performance_statistics.
from abc import ABC, abstractmethod
from pathlib import Path
from typing import ContextManager
from typing import Iterator
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary
@@ -50,7 +50,7 @@ class InvocationStatsServiceBase(ABC):
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
) -> ContextManager[None]:
) -> Iterator[None]:
"""
Return a context object that will capture the statistics on the execution
of invocaation. Use with: to place around the part of the code that executes the invocation.
@@ -60,8 +60,12 @@ class InvocationStatsServiceBase(ABC):
pass
@abstractmethod
def reset_stats(self):
"""Reset all stored statistics."""
def reset_stats(self, graph_execution_state_id: str) -> None:
"""
Reset all statistics for the indicated graph.
:param graph_execution_state_id: The id of the session whose stats to reset.
:raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
"""
pass
@abstractmethod

View File

@@ -2,7 +2,7 @@ import json
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
from typing import Iterator
import psutil
import torch
@@ -10,6 +10,7 @@ import torch
import invokeai.backend.util.logging as logger
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError
from invokeai.backend.model_manager.load.model_cache import CacheStats
from .invocation_stats_base import InvocationStatsServiceBase
@@ -41,7 +42,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
self._invoker = invoker
@contextmanager
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Generator[None, None, None]:
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]:
# This is to handle case of the model manager not being initialized, which happens
# during some tests.
services = self._invoker.services
@@ -50,6 +51,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
self._stats[graph_execution_state_id] = GraphExecutionStats()
self._cache_stats[graph_execution_state_id] = CacheStats()
# Prune stale stats. There should be none since we're starting a new graph, but just in case.
self._prune_stale_stats()
# Record state before the invocation.
start_time = time.time()
start_ram = psutil.Process().memory_info().rss
@@ -74,9 +78,42 @@ class InvocationStatsService(InvocationStatsServiceBase):
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
def reset_stats(self):
self._stats = {}
self._cache_stats = {}
def _prune_stale_stats(self) -> None:
"""Check all graphs being tracked and prune any that have completed/errored.
This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so
for now we call this function periodically to prevent them from accumulating.
"""
to_prune: list[str] = []
for graph_execution_state_id in self._stats:
try:
graph_execution_state = self._invoker.services.graph_execution_manager.get(graph_execution_state_id)
except ItemNotFoundError:
# TODO(ryand): What would cause this? Should this exception just be allowed to propagate?
logger.warning(f"Failed to get graph state for {graph_execution_state_id}.")
continue
if not graph_execution_state.is_complete():
# The graph is still running, don't prune it.
continue
to_prune.append(graph_execution_state_id)
for graph_execution_state_id in to_prune:
del self._stats[graph_execution_state_id]
del self._cache_stats[graph_execution_state_id]
if len(to_prune) > 0:
logger.info(f"Pruned stale graph stats for {to_prune}.")
def reset_stats(self, graph_execution_state_id: str):
try:
del self._stats[graph_execution_state_id]
del self._cache_stats[graph_execution_state_id]
except KeyError as e:
raise GESStatsNotFoundError(
f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}."
) from e
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)

View File

@@ -1,7 +1,12 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Optional
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from .invocation_queue.invocation_queue_common import InvocationQueueItem
from .invocation_services import InvocationServices
from .shared.graph import Graph, GraphExecutionState
class Invoker:
@@ -13,6 +18,51 @@ class Invoker:
self.services = services
self._start()
def invoke(
self,
session_queue_id: str,
session_queue_item_id: int,
session_queue_batch_id: str,
graph_execution_state: GraphExecutionState,
workflow: Optional[WorkflowWithoutID] = None,
invoke_all: bool = False,
) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
# Get the next invocation
invocation = graph_execution_state.next()
if not invocation:
return None
# Save the execution state
self.services.graph_execution_manager.set(graph_execution_state)
# Queue the invocation
self.services.queue.put(
InvocationQueueItem(
session_queue_id=session_queue_id,
session_queue_item_id=session_queue_item_id,
session_queue_batch_id=session_queue_batch_id,
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
workflow=workflow,
invoke_all=invoke_all,
)
)
return invocation.id
def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state)
return new_state
def cancel(self, graph_execution_state_id: str) -> None:
"""Cancels the given execution state"""
self.services.queue.cancel(graph_execution_state_id)
def __start_service(self, service) -> None:
# Call start() method on any services that have it
start_op = getattr(service, "start", None)
@@ -35,3 +85,5 @@ class Invoker:
# First stop all services
for service in vars(self.services):
self.__stop_service(getattr(self.services, service))
self.services.queue.put(None)

View File

@@ -28,7 +28,6 @@ class InstallStatus(str, Enum):
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
@@ -157,7 +156,6 @@ class ModelInstallJob(BaseModel):
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
@@ -179,12 +177,6 @@ class ModelInstallJob(BaseModel):
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
@@ -192,10 +184,7 @@ class ModelInstallJob(BaseModel):
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
@@ -206,9 +195,10 @@ class ModelInstallJob(BaseModel):
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
@property
def error(self) -> Optional[str]:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
return "".join(traceback.format_exception(self._exception)) if self._exception else None
@property
def cancelled(self) -> bool:
@@ -230,11 +220,6 @@ class ModelInstallJob(BaseModel):
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""

View File

@@ -7,6 +7,7 @@ import time
from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
from random import randbytes
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union
@@ -20,7 +21,6 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@@ -28,6 +28,7 @@ from invokeai.backend.model_manager.config import (
ModelRepoVariant,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
CivitaiMetadataFetch,
@@ -150,15 +151,11 @@ class ModelInstallService(ModelInstallServiceBase):
config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
config["key"] = config.get("key", uuid_string())
info: AnyModelConfig = self._probe_model(Path(model_path), config)
if preferred_name := config.get("name"):
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
old_hash = info.current_hash
dest_path = (
self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
self.app_config.models_path / info.base.value / info.type.value / (config.get("name") or model_path.name)
)
try:
new_path = self._copy_model(model_path, dest_path)
@@ -166,6 +163,8 @@ class ModelInstallService(ModelInstallServiceBase):
raise DuplicateModelException(
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
) from excp
new_hash = FastModelHash.hash(new_path)
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
return self._register(
new_path,
@@ -178,14 +177,13 @@ class ModelInstallService(ModelInstallServiceBase):
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
inplace: bool = False,
) -> ModelInstallJob:
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
source_obj = LocalModelSource(path=Path(source))
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),
@@ -280,9 +278,9 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info("Model installer (re)initialized")
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback, config=self._app_config)
search = ModelSearch(on_model_found=callback)
self._models_installed.clear()
search.search(scan_dir)
return list(self._models_installed)
@@ -368,7 +366,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._signal_job_errored(job)
elif (
job.waiting or job.downloads_done
job.waiting or job.downloading
): # local jobs will be in waiting state, remote jobs will be downloading state
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
@@ -446,7 +444,7 @@ class ModelInstallService(ModelInstallServiceBase):
installed.update(self.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
def _sync_model_path(self, key: str) -> AnyModelConfig:
def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig:
"""
Move model into the location indicated by its basetype, type and name.
@@ -467,7 +465,14 @@ class ModelInstallService(ModelInstallServiceBase):
new_path = models_dir / model.base.value / model.type.value / model.name
self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
new_hash = FastModelHash.hash(new_path)
model.path = new_path.relative_to(models_dir).as_posix()
if model.current_hash != new_hash:
assert (
ignore_hash_change
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
model.current_hash = new_hash
self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}")
self.record_store.update_model(key, model)
return model
@@ -527,17 +532,14 @@ class ModelInstallService(ModelInstallServiceBase):
setattr(info, key, value)
return info
def _create_key(self) -> str:
return sha256(randbytes(100)).hexdigest()[0:32]
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str:
# Note that we may be passed a pre-populated AnyModelConfig object,
# in which case the key field should have been populated by the caller (e.g. in `install_path`).
config["key"] = config.get("key", uuid_string())
info = info or ModelProbe.probe(model_path, config)
override_key: Optional[str] = config.get("key") if config else None
assert info.original_hash # always assigned by probe()
info.key = override_key or info.original_hash
key = self._create_key()
model_path = model_path.absolute()
if model_path.is_relative_to(self.app_config.models_path):
@@ -550,8 +552,8 @@ class ModelInstallService(ModelInstallServiceBase):
# make config relative to our root
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
self.record_store.add_model(info.key, info)
return info.key
self.record_store.add_model(key, info)
return key
def _next_id(self) -> int:
with self._lock:
@@ -735,14 +737,13 @@ class ModelInstallService(ModelInstallServiceBase):
self._signal_job_downloading(install_job)
def _download_complete_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"{download_job.source}: model download complete")
with self._lock:
install_job = self._download_cache[download_job.source]
self._download_cache.pop(download_job.source, None)
# are there any more active jobs left in this task?
if install_job.downloading and all(x.complete for x in install_job.download_parts):
install_job.status = InstallStatus.DOWNLOADS_DONE
if all(x.complete for x in install_job.download_parts):
# now enqueue job for actual installation into the models directory
self._install_queue.put(install_job)
# Let other threads know that the number of downloads has changed
@@ -768,7 +769,7 @@ class ModelInstallService(ModelInstallServiceBase):
if not install_job:
return
self._downloads_changed_event.set()
self._logger.warning(f"{download_job.source}: model download cancelled")
self._logger.warning(f"Download {download_job.source} cancelled.")
# if install job has already registered an error, then do not replace its status with cancelled
if not install_job.errored:
install_job.cancel()
@@ -815,7 +816,6 @@ class ModelInstallService(ModelInstallServiceBase):
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
id=job.id,
)
def _signal_job_completed(self, job: ModelInstallJob) -> None:
@@ -828,7 +828,7 @@ class ModelInstallService(ModelInstallServiceBase):
assert job.local_path is not None
assert job.config_out is not None
key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
self._event_bus.emit_model_install_completed(str(job.source), key)
def _signal_job_errored(self, job: ModelInstallJob) -> None:
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}\n{job.error}")
@@ -837,7 +837,7 @@ class ModelInstallService(ModelInstallServiceBase):
error = job.error
assert error_type is not None
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
self._event_bus.emit_model_install_error(str(job.source), error_type, error)
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"{job.source}: model installation was cancelled")

View File

@@ -38,3 +38,8 @@ class ModelLoadServiceBase(ABC):
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
@property
@abstractmethod
def gpu_count(self) -> int:
"""Return the number of GPUs we are configured to use."""

View File

@@ -4,6 +4,7 @@
from typing import Optional, Type
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
@@ -39,6 +40,7 @@ class ModelLoadService(ModelLoadServiceBase):
self._registry = registry
def start(self, invoker: Invoker) -> None:
"""Start the service."""
self._invoker = invoker
@property
@@ -46,6 +48,11 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the RAM cache used by this loader."""
return self._ram_cache
@property
def gpu_count(self) -> int:
"""Return the number of GPUs available for our uses."""
return len(self._ram_cache.execution_devices)
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
@@ -94,20 +101,22 @@ class ModelLoadService(ModelLoadServiceBase):
) -> None:
if not self._invoker:
return
if self._invoker.services.queue.is_canceled(context_data.session_id):
raise CanceledException()
if not loaded:
self._invoker.services.events.emit_model_load_started(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
queue_id=context_data.queue_id,
queue_item_id=context_data.queue_item_id,
queue_batch_id=context_data.batch_id,
graph_execution_state_id=context_data.session_id,
model_config=model_config,
)
else:
self._invoker.services.events.emit_model_load_completed(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
queue_id=context_data.queue_id,
queue_item_id=context_data.queue_item_id,
queue_batch_id=context_data.batch_id,
graph_execution_state_id=context_data.session_id,
model_config=model_config,
)

View File

@@ -3,7 +3,6 @@
from abc import ABC, abstractmethod
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
@@ -17,6 +16,7 @@ from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallServiceBase
from ..model_load import ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase
from ..shared.sqlite.sqlite_database import SqliteDatabase
class ModelManagerServiceBase(ABC):
@@ -32,10 +32,9 @@ class ModelManagerServiceBase(ABC):
def build_model_manager(
cls,
app_config: InvokeAIAppConfig,
model_record_service: ModelRecordServiceBase,
db: SqliteDatabase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: torch.device,
) -> Self:
"""
Construct the model manager service instance.
@@ -99,3 +98,8 @@ class ModelManagerServiceBase(ABC):
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@property
@abstractmethod
def gpu_count(self) -> int:
"""Return the number of GPUs we are configured to use."""

View File

@@ -3,14 +3,12 @@
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
@@ -114,6 +112,11 @@ class ModelManagerService(ModelManagerServiceBase):
else:
return self.load.load_model(configs[0], submodel, context_data)
@property
def gpu_count(self) -> int:
"""Return the number of GPUs we are using."""
return self.load.gpu_count
@classmethod
def build_model_manager(
cls,
@@ -121,7 +124,6 @@ class ModelManagerService(ModelManagerServiceBase):
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: torch.device = choose_torch_device(),
) -> Self:
"""
Construct the model manager service instance.
@@ -132,10 +134,7 @@ class ModelManagerService(ModelManagerServiceBase):
logger.setLevel(app_config.log_level.upper())
ram_cache = ModelCache(
max_cache_size=app_config.ram_cache_size,
max_vram_cache_size=app_config.vram_cache_size,
logger=logger,
execution_device=execution_device,
max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size

View File

@@ -4,25 +4,9 @@ Storage for Model Metadata
"""
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
from typing import List, Set, Tuple
from pydantic import Field
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import ModelDefaultSettings
class ModelMetadataChanges(BaseModelExcludeNull, extra="allow"):
"""A set of changes to apply to model metadata.
Only limited changes are valid:
- `default_settings`: the user-configured default settings for this model
"""
default_settings: Optional[ModelDefaultSettings] = Field(
default=None, description="The user-configured default settings for this model"
)
"""The user-configured default settings for this model"""
class ModelMetadataStoreBase(ABC):

View File

@@ -179,45 +179,44 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
)
return {x[0] for x in self._cursor.fetchall()}
def _update_tags(self, model_key: str, tags: Optional[Set[str]]) -> None:
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
"""Update tags for the model referenced by model_key."""
if tags:
# remove previous tags from this model
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
""",
(model_key,),
)
for tag in tags:
self._cursor.execute(
"""--sql
DELETE FROM model_tags
WHERE model_id=?;
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(model_key,),
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tags (
model_id,
tag_id
)
VALUES (?,?);
""",
(model_key, tag_id),
)

View File

@@ -1,5 +1,4 @@
"""Init file for model record services."""
from .model_records_base import ( # noqa F401
DuplicateModelException,
InvalidModelException,

View File

@@ -39,6 +39,7 @@ Typical usage:
configs = store.search_by_attr(base_model='sd-2', model_type='main')
"""
import json
import sqlite3
from math import ceil

View File

@@ -4,17 +4,3 @@ from pydantic import BaseModel, Field
class SessionProcessorStatus(BaseModel):
is_started: bool = Field(description="Whether the session processor is started")
is_processing: bool = Field(description="Whether a session is being processed")
class CanceledException(Exception):
"""Execution canceled by user."""
pass
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

View File

@@ -1,5 +1,4 @@
import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional
@@ -7,271 +6,136 @@ from typing import Optional
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase
from .session_processor_common import SessionProcessorStatus
POLLING_INTERVAL = 1
THREAD_LIMIT = 1
class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None
def start(self, invoker: Invoker) -> None:
self.__invoker: Invoker = invoker
self.__queue_item: Optional[SessionQueueItem] = None
self._resume_event = ThreadEvent()
self._stop_event = ThreadEvent()
self._poll_now_event = ThreadEvent()
self._cancel_event = ThreadEvent()
self.__resume_event = ThreadEvent()
self.__stop_event = ThreadEvent()
self.__poll_now_event = ThreadEvent()
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
self._thread_limit = thread_limit
self._thread_semaphore = BoundedSemaphore(thread_limit)
self._polling_interval = polling_interval
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
# the profiler will create a new profile for each session.
self._profiler = (
Profiler(
logger=self._invoker.services.logger,
output_dir=self._invoker.services.configuration.profiles_path,
prefix=self._invoker.services.configuration.profile_prefix,
)
if self._invoker.services.configuration.profile_graphs
else None
)
self._thread = Thread(
self.__threadLimit = BoundedSemaphore(THREAD_LIMIT)
self.__thread = Thread(
name="session_processor",
target=self._process,
target=self.__process,
kwargs={
"stop_event": self._stop_event,
"poll_now_event": self._poll_now_event,
"resume_event": self._resume_event,
"cancel_event": self._cancel_event,
"stop_event": self.__stop_event,
"poll_now_event": self.__poll_now_event,
"resume_event": self.__resume_event,
},
)
self._thread.start()
self.__thread.start()
def stop(self, *args, **kwargs) -> None:
self._stop_event.set()
self.__stop_event.set()
def _poll_now(self) -> None:
self._poll_now_event.set()
self.__poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
if event_name == "session_canceled" or event_name == "queue_cleared":
# These both mean we should cancel the current session.
self._cancel_event.set()
# This was a match statement, but match is not supported on python 3.9
if event_name in [
"graph_execution_state_complete",
"invocation_error",
"session_retrieval_error",
"invocation_retrieval_error",
]:
self.__queue_item = None
self._poll_now()
elif (
event_name == "session_canceled"
and self.__queue_item is not None
and self.__queue_item.session_id == event[1]["data"]["graph_execution_state_id"]
):
self.__queue_item = None
self._poll_now()
elif event_name == "batch_enqueued":
self._poll_now()
elif event_name == "queue_cleared":
self.__queue_item = None
self._poll_now()
def resume(self) -> SessionProcessorStatus:
if not self._resume_event.is_set():
self._resume_event.set()
if not self.__resume_event.is_set():
self.__resume_event.set()
return self.get_status()
def pause(self) -> SessionProcessorStatus:
if self._resume_event.is_set():
self._resume_event.clear()
if self.__resume_event.is_set():
self.__resume_event.clear()
return self.get_status()
def get_status(self) -> SessionProcessorStatus:
return SessionProcessorStatus(
is_started=self._resume_event.is_set(),
is_processing=self._queue_item is not None,
is_started=self.__resume_event.is_set(),
is_processing=self.__queue_item is not None,
)
def _process(
def __process(
self,
stop_event: ThreadEvent,
poll_now_event: ThreadEvent,
resume_event: ThreadEvent,
cancel_event: ThreadEvent,
):
# Outermost processor try block; any unhandled exception is a fatal processor error
try:
self._thread_semaphore.acquire()
stop_event.clear()
resume_event.set()
cancel_event.clear()
self.__threadLimit.acquire()
queue_item: Optional[SessionQueueItem] = None
while not stop_event.is_set():
poll_now_event.clear()
# Middle processor try block; any unhandled exception is a non-fatal processor error
try:
# Get the next session to process
self._queue_item = self._invoker.services.session_queue.dequeue()
if self._queue_item is not None and resume_event.is_set():
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# do not dequeue if there is already a session running
if self.__queue_item is None and resume_event.is_set():
queue_item = self.__invoker.services.session_queue.dequeue()
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id)
# Prepare invocations and take the first
self._invocation = self._queue_item.session.next()
# Loop over invocations until the session is complete or canceled
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
if queue_item is not None:
self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
self.__queue_item = queue_item
self.__invoker.services.graph_execution_manager.set(queue_item.session)
self.__invoker.invoke(
session_queue_batch_id=queue_item.batch_id,
session_queue_id=queue_item.queue_id,
session_queue_item_id=queue_item.item_id,
graph_execution_state=queue_item.session,
workflow=queue_item.workflow,
invoke_all=True,
)
queue_item = None
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
with self._invoker.services.performance_statistics.collect_stats(
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# Invoke the node
outputs = self._invocation.invoke_internal(
context=context, services=self._invoker.services
)
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# The session is complete if the all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Set the invocation to None to prepare for the next session
self._invocation = None
else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
# The session is complete, immediately poll for next session
self._queue_item = None
poll_now_event.set()
else:
# The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
if queue_item is None:
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(POLLING_INTERVAL)
continue
except Exception:
# Non-fatal error in processor
self._invoker.services.logger.error(
f"Non-fatal error in session processor:\n{traceback.format_exc()}"
)
# Cancel the queue item
if self._queue_item is not None:
self._invoker.services.session_queue.cancel_queue_item(
self._queue_item.item_id, error=traceback.format_exc()
except Exception as e:
self.__invoker.services.logger.error(f"Error in session processor: {e}")
if queue_item is not None:
self.__invoker.services.session_queue.cancel_queue_item(
queue_item.item_id, error=traceback.format_exc()
)
# Reset the invocation to None to prepare for the next session
self._invocation = None
# Immediately poll for next queue item
poll_now_event.wait(self._polling_interval)
poll_now_event.wait(POLLING_INTERVAL)
continue
except Exception:
# Fatal error in processor, log and pass - we're done here
self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
except Exception as e:
self.__invoker.services.logger.error(f"Fatal Error in session processor: {e}")
pass
finally:
stop_event.clear()
poll_now_event.clear()
self._queue_item = None
self._thread_semaphore.release()
self.__queue_item = None
self.__threadLimit.release()

View File

@@ -60,7 +60,7 @@ class SqliteSessionQueue(SessionQueueBase):
# This was a match statement, but match is not supported on python 3.9
if event_name == "graph_execution_state_complete":
await self._handle_complete_event(event)
elif event_name == "invocation_error":
elif event_name in ["invocation_error", "session_retrieval_error", "invocation_retrieval_error"]:
await self._handle_error_event(event)
elif event_name == "session_canceled":
await self._handle_cancel_event(event)
@@ -429,6 +429,7 @@ class SqliteSessionQueue(SessionQueueBase):
if queue_item.status not in ["canceled", "failed", "completed"]:
status = "failed" if error is not None else "canceled"
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here
self.__invoker.services.queue.cancel(queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
@@ -470,6 +471,7 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self.__invoker.services.queue.cancel(current_queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
@@ -521,6 +523,7 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self.__invoker.services.queue.cancel(current_queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,

View File

@@ -0,0 +1,92 @@
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
from ...invocations.compel import CompelInvocation
from ...invocations.image import ImageNSFWBlurInvocation
from ...invocations.latent import DenoiseLatentsInvocation, LatentsToImageInvocation
from ...invocations.noise import NoiseInvocation
from ...invocations.primitives import IntegerInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
def create_text_to_image() -> LibraryGraph:
graph = Graph(
nodes={
"width": IntegerInvocation(id="width", value=512),
"height": IntegerInvocation(id="height", value=512),
"seed": IntegerInvocation(id="seed", value=-1),
"3": NoiseInvocation(id="3"),
"4": CompelInvocation(id="4"),
"5": CompelInvocation(id="5"),
"6": DenoiseLatentsInvocation(id="6"),
"7": LatentsToImageInvocation(id="7"),
"8": ImageNSFWBlurInvocation(id="8"),
},
edges=[
Edge(
source=EdgeConnection(node_id="width", field="value"),
destination=EdgeConnection(node_id="3", field="width"),
),
Edge(
source=EdgeConnection(node_id="height", field="value"),
destination=EdgeConnection(node_id="3", field="height"),
),
Edge(
source=EdgeConnection(node_id="seed", field="value"),
destination=EdgeConnection(node_id="3", field="seed"),
),
Edge(
source=EdgeConnection(node_id="3", field="noise"),
destination=EdgeConnection(node_id="6", field="noise"),
),
Edge(
source=EdgeConnection(node_id="6", field="latents"),
destination=EdgeConnection(node_id="7", field="latents"),
),
Edge(
source=EdgeConnection(node_id="4", field="conditioning"),
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
),
Edge(
source=EdgeConnection(node_id="5", field="conditioning"),
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
),
Edge(
source=EdgeConnection(node_id="7", field="image"),
destination=EdgeConnection(node_id="8", field="image"),
),
],
)
return LibraryGraph(
id=default_text_to_image_graph_id,
name="t2i",
description="Converts text to an image",
graph=graph,
exposed_inputs=[
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
ExposedNodeInput(node_path="width", field="value", alias="width"),
ExposedNodeInput(node_path="height", field="value", alias="height"),
ExposedNodeInput(node_path="seed", field="value", alias="seed"),
],
exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
)
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
graphs: list[LibraryGraph] = []
text_to_image = graph_library.get(default_text_to_image_graph_id)
# TODO: Check if the graph is the same as the default one, and if not, update it
# if text_to_image is None:
text_to_image = create_text_to_image()
graph_library.set(text_to_image)
graphs.append(text_to_image)
return graphs

View File

@@ -5,14 +5,8 @@ import itertools
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import (
BaseModel,
GetJsonSchemaHandler,
field_validator,
)
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
from pydantic.fields import Field
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema
# Importing * is bad karma but needed here for node detection
from invokeai.app.invocations import * # noqa: F401 F403
@@ -182,6 +176,10 @@ class NodeIdMismatchError(ValueError):
pass
class InvalidSubGraphError(ValueError):
pass
class CyclicalGraphError(ValueError):
pass
@@ -190,6 +188,25 @@ class UnknownGraphValidationError(ValueError):
pass
# TODO: Create and use an Empty output?
@invocation_output("graph_output")
class GraphInvocationOutput(BaseInvocationOutput):
pass
# TODO: Fill this out and move to invocations
@invocation("graph", version="1.0.0")
class GraphInvocation(BaseInvocation):
"""Execute a graph"""
# TODO: figure out how to create a default here
graph: "Graph" = InputField(description="The graph to run", default=None)
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
"""Invoke with provided services and return outputs."""
return GraphInvocationOutput()
@invocation_output("iterate_output")
class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output."""
@@ -243,73 +260,21 @@ class CollectInvocation(BaseInvocation):
return CollectInvocationOutput(collection=copy.copy(self.collection))
InvocationsUnion: Any = BaseInvocation.get_invocations_union()
InvocationOutputsUnion: Any = BaseInvocationOutput.get_outputs_union()
class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid_string)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
description="The nodes in this graph", default_factory=dict
)
edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph",
default_factory=list,
)
@field_validator("nodes", mode="plain")
@classmethod
def validate_nodes(cls, v: dict[str, Any]):
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
# Invocations register themselves as their python modules are executed. The union of all invocations is
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
#
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
# invocations will cause a graph to fail if they are used.
#
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
#
# This same pattern is used in `GraphExecutionState`.
nodes: dict[str, BaseInvocation] = {}
typeadapter = BaseInvocation.get_typeadapter()
for node_id, node in v.items():
nodes[node_id] = typeadapter.validate_python(node)
return nodes
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
# the generated schema as options for the `nodes` field.
#
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
# expected.
#
# You might be tempted to do something like this:
#
# ```py
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
# delattr(cloned_model, "validate_nodes")
# cloned_model.model_rebuild(force=True)
# json_schema = handler(cloned_model.__pydantic_core_schema__)
# ```
#
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
#
# This same pattern is used in `GraphExecutionState`.
class Graph(BaseModel):
id: Optional[str] = Field(default=None, description="The id of this graph")
nodes: dict[
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
] = Field(description="The nodes in this graph")
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
json_schema = handler(Graph.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph
@@ -321,21 +286,41 @@ class Graph(BaseModel):
self.nodes[node.id] = node
def delete_node(self, node_id: str) -> None:
def _get_graph_and_node(self, node_path: str) -> tuple["Graph", str]:
"""Returns the graph and node id for a node path."""
# Materialized graphs may have nodes at the top level
if node_path in self.nodes:
return (self, node_path)
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
if node_id not in self.nodes:
raise NodeNotFoundError(f"Node {node_path} not found in graph")
node = self.nodes[node_id]
if not isinstance(node, GraphInvocation):
# There's more node path left but this isn't a graph - failure
raise NodeNotFoundError("Node path terminated early at a non-graph node")
return node.graph._get_graph_and_node(node_path[node_path.index(".") + 1 :])
def delete_node(self, node_path: str) -> None:
"""Deletes a node from a graph"""
try:
graph, node_id = self._get_graph_and_node(node_path)
# Delete edges for this node
input_edges = self._get_input_edges(node_id)
output_edges = self._get_output_edges(node_id)
input_edges = self._get_input_edges_and_graphs(node_path)
output_edges = self._get_output_edges_and_graphs(node_path)
for edge in input_edges:
self.delete_edge(edge)
for edge_graph, _, edge in input_edges:
edge_graph.delete_edge(edge)
for edge in output_edges:
self.delete_edge(edge)
for edge_graph, _, edge in output_edges:
edge_graph.delete_edge(edge)
del self.nodes[node_id]
del graph.nodes[node_id]
except NodeNotFoundError:
pass # Ignore, not doesn't exist (should this throw?)
@@ -385,6 +370,13 @@ class Graph(BaseModel):
if k != v.id:
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
# Validate all subgraphs
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
try:
gn.graph.validate_self()
except Exception as e:
raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e
# Validate that all edges match nodes and fields in the graph
for edge in self.edges:
source_node = self.nodes.get(edge.source.node_id, None)
@@ -446,6 +438,7 @@ class Graph(BaseModel):
except (
DuplicateNodeIdError,
NodeIdMismatchError,
InvalidSubGraphError,
NodeNotFoundError,
NodeFieldNotFoundError,
CyclicalGraphError,
@@ -466,7 +459,7 @@ class Graph(BaseModel):
def _validate_edge(self, edge: Edge):
"""Validates that a new edge doesn't create a cycle in the graph"""
# Validate that the nodes exist
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
try:
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
@@ -533,90 +526,171 @@ class Graph(BaseModel):
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
def has_node(self, node_id: str) -> bool:
def has_node(self, node_path: str) -> bool:
"""Determines whether or not a node exists in the graph."""
try:
_ = self.get_node(node_id)
return True
n = self.get_node(node_path)
if n is not None:
return True
else:
return False
except NodeNotFoundError:
return False
def get_node(self, node_id: str) -> BaseInvocation:
"""Gets a node from the graph."""
try:
return self.nodes[node_id]
except KeyError as e:
raise NodeNotFoundError(f"Node {node_id} not found in graph") from e
def get_node(self, node_path: str) -> BaseInvocation:
"""Gets a node from the graph using a node path."""
# Materialized graphs may have nodes at the top level
graph, node_id = self._get_graph_and_node(node_path)
return graph.nodes[node_id]
def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str:
return node_id if prefix is None or prefix == "" else f"{prefix}.{node_id}"
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
"""Updates a node in the graph."""
node = self.nodes[node_id]
graph, node_id = self._get_graph_and_node(node_path)
node = graph.nodes[node_id]
# Ensure the node type matches the new node
if type(node) is not type(new_node):
raise TypeError(f"Node {node_id} is type {type(node)} but new node is type {type(new_node)}")
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
# Ensure the new id is either the same or is not in the graph
if new_node.id != node.id and self.has_node(new_node.id):
raise NodeAlreadyInGraphError(f"Node with id {new_node.id} already exists in graph")
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
new_path = self._get_node_path(new_node.id, prefix=prefix)
if new_node.id != node.id and self.has_node(new_path):
raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph")
# Set the new node in the graph
self.nodes[new_node.id] = new_node
graph.nodes[new_node.id] = new_node
if new_node.id != node.id:
input_edges = self._get_input_edges(node_id)
output_edges = self._get_output_edges(node_id)
input_edges = self._get_input_edges_and_graphs(node_path)
output_edges = self._get_output_edges_and_graphs(node_path)
# Delete node and all edges
self.delete_node(node_id)
graph.delete_node(node_path)
# Create new edges for each input and output
for edge in input_edges:
self.add_edge(
for graph, _, edge in input_edges:
# Remove the graph prefix from the node path
new_graph_node_path = (
new_node.id
if "." not in edge.destination.node_id
else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
)
graph.add_edge(
Edge(
source=edge.source,
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field),
)
)
for edge in output_edges:
self.add_edge(
for graph, _, edge in output_edges:
# Remove the graph prefix from the node path
new_graph_node_path = (
new_node.id
if "." not in edge.source.node_id
else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
)
graph.add_edge(
Edge(
source=EdgeConnection(node_id=new_node.id, field=edge.source.field),
source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field),
destination=edge.destination,
)
)
def _get_input_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
"""Gets all input edges for a node. If field is provided, only edges to that field are returned."""
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]:
"""Gets all input edges for a node"""
edges = self._get_input_edges_and_graphs(node_path)
edges = [e for e in self.edges if e.destination.node_id == node_id]
# Filter to edges that match the field
filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
if field is None:
return edges
# Create full node paths for each edge
return [
Edge(
source=EdgeConnection(
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e.source.field,
),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
),
)
for _, prefix, e in filtered_edges
]
filtered_edges = [e for e in edges if e.destination.field == field]
def _get_input_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
edges = []
return filtered_edges
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
def _get_output_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
"""Gets all output edges for a node. If field is provided, only edges from that field are returned."""
edges = [e for e in self.edges if e.source.node_id == node_id]
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
node = self.nodes[node_id]
if field is None:
return edges
if isinstance(node, GraphInvocation):
graph = node.graph
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
edges.extend(graph_edges)
filtered_edges = [e for e in edges if e.source.field == field]
return edges
return filtered_edges
def _get_output_edges(self, node_path: str, field: str) -> list[Edge]:
"""Gets all output edges for a node"""
edges = self._get_output_edges_and_graphs(node_path)
# Filter to edges that match the field
filtered_edges = (e for e in edges if e[2].source.field == field)
# Create full node paths for each edge
return [
Edge(
source=EdgeConnection(
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
field=e.source.field,
),
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
),
)
for _, prefix, e in filtered_edges
]
def _get_output_edges_and_graphs(
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
edges = []
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
node = self.nodes[node_id]
if isinstance(node, GraphInvocation):
graph = node.graph
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
edges.extend(graph_edges)
return edges
def _is_iterator_connection_valid(
self,
node_id: str,
node_path: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_id, "item")]
inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_path, "item")]
if new_input is not None:
inputs.append(new_input)
@@ -644,12 +718,12 @@ class Graph(BaseModel):
def _is_collector_connection_valid(
self,
node_id: str,
node_path: str,
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = [e.source for e in self._get_input_edges(node_id, "item")]
outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]
inputs = [e.source for e in self._get_input_edges(node_path, "item")]
outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]
if new_input is not None:
inputs.append(new_input)
@@ -705,17 +779,27 @@ class Graph(BaseModel):
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
return g
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph:
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
g = nx_graph or nx.DiGraph()
# Add all nodes from this graph except graph/iteration nodes
g.add_nodes_from([n.id for n in self.nodes.values() if not isinstance(n, IterateInvocation)])
g.add_nodes_from(
[
self._get_node_path(n.id, prefix)
for n in self.nodes.values()
if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)
]
)
# Expand graph nodes
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
# TODO: figure out if iteration nodes need to be expanded
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
g.add_edges_from([(e[0], e[1]) for e in unique_edges])
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
return g
@@ -740,7 +824,9 @@ class GraphExecutionState(BaseModel):
)
# The results of executed nodes
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
description="The results of node executions", default_factory=dict
)
# Errors raised when executing nodes
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
@@ -757,51 +843,27 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
@field_validator("results", mode="plain")
@classmethod
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
results: dict[str, BaseInvocationOutput] = {}
typeadapter = BaseInvocationOutput.get_typeadapter()
for result_id, result in v.items():
results[result_id] = typeadapter.validate_python(result)
return results
@field_validator("graph")
def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid"""
v.validate_self()
return v
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state")
graph: Graph = Field(description="The graph being executed")
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
executed: set[str] = Field(description="The set of node ids that have been executed")
executed_history: list[str] = Field(
description="The list of node ids that have been executed, in order of execution"
)
results: dict[
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
] = Field(description="The results of node executions")
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
prepared_source_mapping: dict[str, str] = Field(
description="The map of prepared nodes to original graph nodes"
)
source_prepared_mapping: dict[str, set[str]] = Field(
description="The map of original graph nodes to prepared nodes"
)
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
model_config = ConfigDict(
json_schema_extra={
"required": [
"id",
"graph",
"execution_graph",
"executed",
"executed_history",
"results",
"errors",
"prepared_source_mapping",
"source_prepared_mapping",
]
}
)
def next(self) -> Optional[BaseInvocation]:
"""Gets the next node ready to execute."""
@@ -857,17 +919,17 @@ class GraphExecutionState(BaseModel):
"""Returns true if the graph has any errors"""
return len(self.errors) > 0
def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
"""Prepares an iteration node and connects all edges, returning the new node id"""
node = self.graph.get_node(node_id)
node = self.graph.get_node(node_path)
self_iteration_count = -1
# If this is an iterator node, we must create a copy for each iteration
if isinstance(node, IterateInvocation):
# Get input collection edge (should error if there are no inputs)
input_collection_edge = next(iter(self.graph._get_input_edges(node_id, "collection")))
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection")))
input_collection_prepared_node_id = next(
n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
)
@@ -881,7 +943,7 @@ class GraphExecutionState(BaseModel):
return new_nodes
# Get all input edges
input_edges = self.graph._get_input_edges(node_id)
input_edges = self.graph._get_input_edges(node_path)
# Create new edges for this iteration
# For collect nodes, this may contain multiple inputs to the same field
@@ -908,10 +970,10 @@ class GraphExecutionState(BaseModel):
# Add to execution graph
self.execution_graph.add_node(new_node)
self.prepared_source_mapping[new_node.id] = node_id
if node_id not in self.source_prepared_mapping:
self.source_prepared_mapping[node_id] = set()
self.source_prepared_mapping[node_id].add(new_node.id)
self.prepared_source_mapping[new_node.id] = node_path
if node_path not in self.source_prepared_mapping:
self.source_prepared_mapping[node_path] = set()
self.source_prepared_mapping[node_path].add(new_node.id)
# Add new edges to execution graph
for edge in new_edges:
@@ -1015,13 +1077,13 @@ class GraphExecutionState(BaseModel):
def _get_iteration_node(
self,
source_node_id: str,
source_node_path: str,
graph: nx.DiGraph,
execution_graph: nx.DiGraph,
prepared_iterator_nodes: list[str],
) -> Optional[str]:
"""Gets the prepared version of the specified source node that matches every iteration specified"""
prepared_nodes = self.source_prepared_mapping[source_node_id]
prepared_nodes = self.source_prepared_mapping[source_node_path]
if len(prepared_nodes) == 1:
return next(iter(prepared_nodes))
@@ -1032,7 +1094,7 @@ class GraphExecutionState(BaseModel):
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)]
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
return next(
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
@@ -1101,19 +1163,19 @@ class GraphExecutionState(BaseModel):
def add_node(self, node: BaseInvocation) -> None:
self.graph.add_node(node)
def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
if not self._is_node_updatable(node_id):
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
if not self._is_node_updatable(node_path):
raise NodeAlreadyExecutedError(
f"Node {node_id} has already been prepared or executed and cannot be updated"
f"Node {node_path} has already been prepared or executed and cannot be updated"
)
self.graph.update_node(node_id, new_node)
self.graph.update_node(node_path, new_node)
def delete_node(self, node_id: str) -> None:
if not self._is_node_updatable(node_id):
def delete_node(self, node_path: str) -> None:
if not self._is_node_updatable(node_path):
raise NodeAlreadyExecutedError(
f"Node {node_id} has already been prepared or executed and cannot be deleted"
f"Node {node_path} has already been prepared or executed and cannot be deleted"
)
self.graph.delete_node(node_id)
self.graph.delete_node(node_path)
def add_edge(self, edge: Edge) -> None:
if not self._is_node_updatable(edge.destination.node_id):
@@ -1128,3 +1190,63 @@ class GraphExecutionState(BaseModel):
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
)
self.graph.delete_edge(edge)
class ExposedNodeInput(BaseModel):
node_path: str = Field(description="The node path to the node with the input")
field: str = Field(description="The field name of the input")
alias: str = Field(description="The alias of the input")
class ExposedNodeOutput(BaseModel):
node_path: str = Field(description="The node path to the node with the output")
field: str = Field(description="The field name of the output")
alias: str = Field(description="The alias of the output")
class LibraryGraph(BaseModel):
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid_string)
graph: Graph = Field(description="The graph")
name: str = Field(description="The name of the graph")
description: str = Field(description="The description of the graph")
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
exposed_outputs: list[ExposedNodeOutput] = Field(
description="The outputs exposed by this graph", default_factory=list
)
@field_validator("exposed_inputs", "exposed_outputs")
def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
if len(v) != len({i.alias for i in v}):
raise ValueError("Duplicate exposed alias")
return v
@model_validator(mode="after")
def validate_exposed_nodes(cls, values):
graph = values.graph
# Validate exposed inputs
for exposed_input in values.exposed_inputs:
if not graph.has_node(exposed_input.node_path):
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
node = graph.get_node(exposed_input.node_path)
if get_input_field(node, exposed_input.field) is None:
raise ValueError(
f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}"
)
# Validate exposed outputs
for exposed_output in values.exposed_outputs:
if not graph.has_node(exposed_output.node_path):
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
node = graph.get_node(exposed_output.node_path)
if get_output_field(node, exposed_output.field) is None:
raise ValueError(
f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}"
)
return values
GraphInvocation.model_rebuild(force=True)
Graph.model_rebuild(force=True)
GraphExecutionState.model_rebuild(force=True)

View File

@@ -1,4 +1,3 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional
@@ -13,6 +12,7 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
@@ -22,7 +22,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
"""
The InvocationContext provides access to various services and data about the current invocation.
@@ -49,102 +48,99 @@ Note: The docstrings are in weird places, but that's where they must be to get I
@dataclass
class InvocationContextData:
queue_item: "SessionQueueItem"
"""The queue item that is being executed."""
invocation: "BaseInvocation"
"""The invocation that is being executed."""
source_invocation_id: str
"""The ID of the invocation from which the currently executing invocation was prepared."""
session_id: str
"""The session that is being executed."""
queue_id: str
"""The queue in which the session is being executed."""
source_node_id: str
"""The ID of the node from which the currently executing invocation was prepared."""
queue_item_id: int
"""The ID of the queue item that is being executed."""
batch_id: str
"""The ID of the batch that is being executed."""
workflow: Optional[WorkflowWithoutID] = None
"""The workflow associated with this queue item, if any."""
class InvocationContextInterface:
def __init__(self, services: InvocationServices, data: InvocationContextData) -> None:
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
self._services = services
self._data = data
self._context_data = context_data
class BoardsInterface(InvocationContextInterface):
def create(self, board_name: str) -> BoardDTO:
"""Creates a board.
"""
Creates a board.
Args:
board_name: The name of the board to create.
Returns:
The created board DTO.
:param board_name: The name of the board to create.
"""
return self._services.boards.create(board_name)
def get_dto(self, board_id: str) -> BoardDTO:
"""Gets a board DTO.
"""
Gets a board DTO.
Args:
board_id: The ID of the board to get.
Returns:
The board DTO.
:param board_id: The ID of the board to get.
"""
return self._services.boards.get_dto(board_id)
def get_all(self) -> list[BoardDTO]:
"""Gets all boards.
Returns:
A list of all boards.
"""
Gets all boards.
"""
return self._services.boards.get_all()
def add_image_to_board(self, board_id: str, image_name: str) -> None:
"""Adds an image to a board.
"""
Adds an image to a board.
Args:
board_id: The ID of the board to add the image to.
image_name: The name of the image to add to the board.
:param board_id: The ID of the board to add the image to.
:param image_name: The name of the image to add to the board.
"""
return self._services.board_images.add_image_to_board(board_id, image_name)
def get_all_image_names_for_board(self, board_id: str) -> list[str]:
"""Gets all image names for a board.
"""
Gets all image names for a board.
Args:
board_id: The ID of the board to get the image names for.
Returns:
A list of all image names for the board.
:param board_id: The ID of the board to get the image names for.
"""
return self._services.board_images.get_all_board_image_names_for_board(board_id)
class LoggerInterface(InvocationContextInterface):
def debug(self, message: str) -> None:
"""Logs a debug message.
"""
Logs a debug message.
Args:
message: The message to log.
:param message: The message to log.
"""
self._services.logger.debug(message)
def info(self, message: str) -> None:
"""Logs an info message.
"""
Logs an info message.
Args:
message: The message to log.
:param message: The message to log.
"""
self._services.logger.info(message)
def warning(self, message: str) -> None:
"""Logs a warning message.
"""
Logs a warning message.
Args:
message: The message to log.
:param message: The message to log.
"""
self._services.logger.warning(message)
def error(self, message: str) -> None:
"""Logs an error message.
"""
Logs an error message.
Args:
message: The message to log.
:param message: The message to log.
"""
self._services.logger.error(message)
@@ -157,60 +153,54 @@ class ImagesInterface(InvocationContextInterface):
image_category: ImageCategory = ImageCategory.GENERAL,
metadata: Optional[MetadataField] = None,
) -> ImageDTO:
"""Saves an image, returning its DTO.
"""
Saves an image, returning its DTO.
If the current queue item has a workflow or metadata, it is automatically saved with the image.
Args:
image: The image to save, as a PIL image.
board_id: The board ID to add the image to, if it should be added. It the invocation \
:param image: The image to save, as a PIL image.
:param board_id: The board ID to add the image to, if it should be added. It the invocation \
inherits from `WithBoard`, that board will be used automatically. **Use this only if \
you want to override or provide a board manually!**
image_category: The category of the image. Only the GENERAL category is added \
:param image_category: The category of the image. Only the GENERAL category is added \
to the gallery.
metadata: The metadata to save with the image, if it should have any. If the \
:param metadata: The metadata to save with the image, if it should have any. If the \
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
**Use this only if you want to override or provide metadata manually!**
Returns:
The saved image DTO.
"""
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
metadata_ = None
if metadata:
metadata_ = metadata
elif isinstance(self._data.invocation, WithMetadata):
metadata_ = self._data.invocation.metadata
elif isinstance(self._context_data.invocation, WithMetadata):
metadata_ = self._context_data.invocation.metadata
# If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None.
board_id_ = None
if board_id:
board_id_ = board_id
elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board:
board_id_ = self._data.invocation.board.board_id
elif isinstance(self._context_data.invocation, WithBoard) and self._context_data.invocation.board:
board_id_ = self._context_data.invocation.board.board_id
return self._services.images.create(
image=image,
is_intermediate=self._data.invocation.is_intermediate,
is_intermediate=self._context_data.invocation.is_intermediate,
image_category=image_category,
board_id=board_id_,
metadata=metadata_,
image_origin=ResourceOrigin.INTERNAL,
workflow=self._data.queue_item.workflow,
session_id=self._data.queue_item.session_id,
node_id=self._data.invocation.id,
workflow=self._context_data.workflow,
session_id=self._context_data.session_id,
node_id=self._context_data.invocation.id,
)
def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
"""Gets an image as a PIL Image object.
"""
Gets an image as a PIL Image object.
Args:
image_name: The name of the image to get.
mode: The color mode to convert the image to. If None, the original mode is used.
Returns:
The image as a PIL Image object.
:param image_name: The name of the image to get.
:param mode: The color mode to convert the image to. If None, the original mode is used.
"""
image = self._services.images.get_pil_image(image_name)
if mode and mode != image.mode:
@@ -223,76 +213,58 @@ class ImagesInterface(InvocationContextInterface):
return image
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
"""Gets an image's metadata, if it has any.
"""
Gets an image's metadata, if it has any.
Args:
image_name: The name of the image to get the metadata for.
Returns:
The image's metadata, if it has any.
:param image_name: The name of the image to get the metadata for.
"""
return self._services.images.get_metadata(image_name)
def get_dto(self, image_name: str) -> ImageDTO:
"""Gets an image as an ImageDTO object.
"""
Gets an image as an ImageDTO object.
Args:
image_name: The name of the image to get.
Returns:
The image as an ImageDTO object.
:param image_name: The name of the image to get.
"""
return self._services.images.get_dto(image_name)
class TensorsInterface(InvocationContextInterface):
def save(self, tensor: Tensor) -> str:
"""Saves a tensor, returning its name.
"""
Saves a tensor, returning its name.
Args:
tensor: The tensor to save.
Returns:
The name of the saved tensor.
:param tensor: The tensor to save.
"""
name = self._services.tensors.save(obj=tensor)
return name
def load(self, name: str) -> Tensor:
"""Loads a tensor by name.
"""
Loads a tensor by name.
Args:
name: The name of the tensor to load.
Returns:
The loaded tensor.
:param name: The name of the tensor to load.
"""
return self._services.tensors.load(name)
class ConditioningInterface(InvocationContextInterface):
def save(self, conditioning_data: ConditioningFieldData) -> str:
"""Saves a conditioning data object, returning its name.
"""
Saves a conditioning data object, returning its name.
Args:
conditioning_data: The conditioning data to save.
Returns:
The name of the saved conditioning data.
:param conditioning_context_data: The conditioning data to save.
"""
name = self._services.conditioning.save(obj=conditioning_data)
return name
def load(self, name: str) -> ConditioningFieldData:
"""Loads conditioning data by name.
"""
Loads conditioning data by name.
Args:
name: The name of the conditioning data to load.
Returns:
The loaded conditioning data.
:param name: The name of the conditioning data to load.
"""
return self._services.conditioning.load(name)
@@ -300,143 +272,104 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
def exists(self, key: str) -> bool:
"""Checks if a model exists.
"""
Checks if a model exists.
Args:
key: The key of the model.
Returns:
True if the model exists, False if not.
:param key: The key of the model.
"""
return self._services.model_manager.store.exists(key)
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Loads a model.
"""
Loads a model.
Args:
key: The key of the model.
submodel_type: The submodel of the model to get.
Returns:
An object representing the loaded model.
:param key: The key of the model.
:param submodel_type: The submodel of the model to get.
:returns: An object representing the loaded model.
"""
# The model manager emits events as it loads the model. It needs the context data to build
# the event payloads.
return self._services.model_manager.load_model_by_key(
key=key, submodel_type=submodel_type, context_data=self._data
key=key, submodel_type=submodel_type, context_data=self._context_data
)
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
) -> LoadedModel:
"""Loads a model by its attributes.
"""
Loads a model by its attributes.
Args:
name: Name of the model.
base: The models' base type, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
type: Type of the model, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
submodel_type: The type of submodel to load, e.g. `SubModelType.UNet`, `SubModelType.TextEncoder`, etc. Only main
models have submodels.
Returns:
An object representing the loaded model.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
"""
return self._services.model_manager.load_model_by_attr(
model_name=name,
base_model=base,
model_type=type,
submodel=submodel_type,
context_data=self._data,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
context_data=self._context_data,
)
def get_config(self, key: str) -> AnyModelConfig:
"""Gets a model's config.
"""
Gets a model's info, an dict-like object.
Args:
key: The key of the model.
Returns:
The model's config.
:param key: The key of the model.
"""
return self._services.model_manager.store.get_model(key=key)
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""Gets a model's metadata, if it has any.
"""
Gets a model's metadata, if it has any.
Args:
key: The key of the model.
Returns:
The model's metadata, if it has any.
:param key: The key of the model.
"""
return self._services.model_manager.store.get_metadata(key=key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""Searches for models by path.
"""
Searches for models by path.
Args:
path: The path to search for.
Returns:
A list of models that match the path.
:param path: The path to search for.
"""
return self._services.model_manager.store.search_by_path(path)
def search_by_attrs(
self,
name: Optional[str] = None,
base: Optional[BaseModelType] = None,
type: Optional[ModelType] = None,
format: Optional[ModelFormat] = None,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
) -> list[AnyModelConfig]:
"""Searches for models by attributes.
"""
Searches for models by attributes.
Args:
name: The name to search for (exact match).
base: The base to search for, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
type: Type type of model to search for, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
format: The format of model to search for, e.g. `ModelFormat.Checkpoint`, `ModelFormat.Diffusers`, etc.
Returns:
A list of models that match the attributes.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
"""
return self._services.model_manager.store.search_by_attr(
model_name=name,
base_model=base,
model_type=type,
model_format=format,
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_format=model_format,
)
class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig:
"""Gets the app's config.
Returns:
The app's config.
"""
"""Gets the app's config."""
return self._services.configuration.get_config()
class UtilInterface(InvocationContextInterface):
def __init__(
self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event
) -> None:
super().__init__(services, data)
self._cancel_event = cancel_event
def is_canceled(self) -> bool:
"""Checks if the current session has been canceled.
Returns:
True if the current session has been canceled, False if not.
"""
return self._cancel_event.is_set()
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
"""
The step callback emits a progress event with the current step, the total number of
@@ -444,32 +377,27 @@ class UtilInterface(InvocationContextInterface):
This should be called after each denoising step.
Args:
intermediate_state: The intermediate state of the diffusion pipeline.
base_model: The base model for the current denoising step.
:param intermediate_state: The intermediate state of the diffusion pipeline.
:param base_model: The base model for the current denoising step.
"""
# The step callback needs access to the events and the invocation queue services, but this
# represents a dangerous level of access.
#
# We wrap the step callback so that nodes do not have direct access to these services.
stable_diffusion_step_callback(
context_data=self._data,
context_data=self._context_data,
intermediate_state=intermediate_state,
base_model=base_model,
invocation_queue=self._services.queue,
events=self._services.events,
is_canceled=self.is_canceled,
)
class InvocationContext:
"""Provides access to various services and data for the current invocation.
Attributes:
images (ImagesInterface): Methods to save, get and update images and their metadata.
tensors (TensorsInterface): Methods to save and get tensors, including image, noise, masks, and masked images.
conditioning (ConditioningInterface): Methods to save and get conditioning data.
models (ModelsInterface): Methods to check if a model exists, get a model, and get a model's info.
logger (LoggerInterface): The app logger.
config (ConfigInterface): The app config.
util (UtilInterface): Utility methods, including a method to check if an invocation was canceled and step callbacks.
boards (BoardsInterface): Methods to interact with boards.
"""
The `InvocationContext` provides access to various services and data for the current invocation.
"""
def __init__(
@@ -482,54 +410,50 @@ class InvocationContext:
config: ConfigInterface,
util: UtilInterface,
boards: BoardsInterface,
data: InvocationContextData,
context_data: InvocationContextData,
services: InvocationServices,
) -> None:
self.images = images
"""Methods to save, get and update images and their metadata."""
"""Provides methods to save, get and update images and their metadata."""
self.tensors = tensors
"""Methods to save and get tensors, including image, noise, masks, and masked images."""
"""Provides methods to save and get tensors, including image, noise, masks, and masked images."""
self.conditioning = conditioning
"""Methods to save and get conditioning data."""
"""Provides methods to save and get conditioning data."""
self.models = models
"""Methods to check if a model exists, get a model, and get a model's info."""
"""Provides methods to check if a model exists, get a model, and get a model's info."""
self.logger = logger
"""The app logger."""
"""Provides access to the app logger."""
self.config = config
"""The app config."""
"""Provides access to the app's config."""
self.util = util
"""Utility methods, including a method to check if an invocation was canceled and step callbacks."""
"""Provides utility methods."""
self.boards = boards
"""Methods to interact with boards."""
self._data = data
"""An internal API providing access to data about the current queue item and invocation. You probably shouldn't use this. It may change without warning."""
"""Provides methods to interact with boards."""
self._data = context_data
"""Provides data about the current queue item and invocation. This is an internal API and may change without warning."""
self._services = services
"""An internal API providing access to all application services. You probably shouldn't use this. It may change without warning."""
"""Provides access to the full application services. This is an internal API and may change without warning."""
def build_invocation_context(
services: InvocationServices,
data: InvocationContextData,
cancel_event: threading.Event,
context_data: InvocationContextData,
) -> InvocationContext:
"""Builds the invocation context for a specific invocation execution.
"""
Builds the invocation context for a specific invocation execution.
Args:
services: The invocation services to wrap.
data: The invocation context data.
Returns:
The invocation context.
:param invocation_services: The invocation services to wrap.
:param invocation_context_data: The invocation context data.
"""
logger = LoggerInterface(services=services, data=data)
images = ImagesInterface(services=services, data=data)
tensors = TensorsInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data)
config = ConfigInterface(services=services, data=data)
util = UtilInterface(services=services, data=data, cancel_event=cancel_event)
conditioning = ConditioningInterface(services=services, data=data)
boards = BoardsInterface(services=services, data=data)
logger = LoggerInterface(services=services, context_data=context_data)
images = ImagesInterface(services=services, context_data=context_data)
tensors = TensorsInterface(services=services, context_data=context_data)
models = ModelsInterface(services=services, context_data=context_data)
config = ConfigInterface(services=services, context_data=context_data)
util = UtilInterface(services=services, context_data=context_data)
conditioning = ConditioningInterface(services=services, context_data=context_data)
boards = BoardsInterface(services=services, context_data=context_data)
ctx = InvocationContext(
images=images,
@@ -537,7 +461,7 @@ def build_invocation_context(
config=config,
tensors=tensors,
models=models,
data=data,
context_data=context_data,
util=util,
conditioning=conditioning,
services=services,

View File

@@ -3,6 +3,7 @@
import json
import sqlite3
from hashlib import sha1
from logging import Logger
from pathlib import Path
from typing import Optional
@@ -21,7 +22,7 @@ from invokeai.backend.model_manager.config import (
ModelConfigFactory,
ModelType,
)
from invokeai.backend.model_manager.hash import ModelHash
from invokeai.backend.model_manager.hash import FastModelHash
ModelsValidator = TypeAdapter(AnyModelConfig)
@@ -72,27 +73,19 @@ class MigrateModelYamlToDb1:
base_type, model_type, model_name = str(model_key).split("/")
try:
hash = ModelHash().hash(self.config.models_path / stanza.path)
hash = FastModelHash.hash(self.config.models_path / stanza.path)
except OSError:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type)
stanza["name"] = model_name
stanza["original_hash"] = hash
stanza["current_hash"] = hash
new_key = hash # deterministic key assignment
# special case for ip adapters, which need the new `image_encoder_model_id` field
if stanza["type"] == ModelType.IPAdapter:
try:
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
self.config.models_path / stanza.path
)
except OSError:
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
continue
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
@@ -102,7 +95,7 @@ class MigrateModelYamlToDb1:
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
self._update_model(key, new_config)
else:
self.logger.info(f"Adding model {model_name} with key {new_key}")
self.logger.info(f"Adding model {model_name} with key {model_key}")
self._add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
@@ -156,8 +149,3 @@ class MigrateModelYamlToDb1:
)
except sqlite3.IntegrityError as exc:
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
def _get_image_encoder_model_id(self, model_path: Path) -> str:
with open(model_path / "image_encoder.txt") as f:
encoder = f.read()
return encoder.strip()

View File

@@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
See :class:`Migration` for an example.
"""
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
def __call__(self, cursor: sqlite3.Cursor) -> None:
...
class MigrationError(RuntimeError):

View File

@@ -1,9 +1,9 @@
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING
import torch
from PIL import Image
from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
from invokeai.backend.model_manager.config import BaseModelType
from ...backend.stable_diffusion import PipelineIntermediateState
@@ -11,6 +11,7 @@ from ...backend.util.util import image_to_dataURL
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC
from invokeai.app.services.shared.invocation_context import InvocationContextData
@@ -33,10 +34,10 @@ def stable_diffusion_step_callback(
context_data: "InvocationContextData",
intermediate_state: PipelineIntermediateState,
base_model: BaseModelType,
invocation_queue: "InvocationQueueABC",
events: "EventServiceBase",
is_canceled: Callable[[], bool],
) -> None:
if is_canceled():
if invocation_queue.is_canceled(context_data.session_id):
raise CanceledException
# Some schedulers report not only the noisy latents at the current timestep,
@@ -114,12 +115,12 @@ def stable_diffusion_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
events.emit_generator_progress(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
queue_id=context_data.queue_id,
queue_item_id=context_data.queue_item_id,
queue_batch_id=context_data.batch_id,
graph_execution_state_id=context_data.session_id,
node_id=context_data.invocation.id,
source_node_id=context_data.source_invocation_id,
source_node_id=context_data.source_node_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
order=intermediate_state.order,

View File

@@ -1,47 +1,8 @@
import re
from typing import List, Tuple
import invokeai.backend.util.logging as logger
from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.textual_inversion import TextualInversionModelRaw
def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:
ti_triggers: List[str] = []
def extract_ti_triggers_from_prompt(prompt: str) -> list[str]:
ti_triggers = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
ti_triggers.append(str(trigger))
ti_triggers.append(trigger)
return ti_triggers
def generate_ti_list(
prompt: str, base: BaseModelType, context: InvocationContext
) -> List[Tuple[str, TextualInversionModelRaw]]:
ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
for trigger in extract_ti_triggers_from_prompt(prompt):
name_or_key = trigger[1:-1]
try:
loaded_model = context.models.load(key=name_or_key)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)
assert loaded_model.config.base == base
ti_list.append((name_or_key, model))
except UnknownModelException:
try:
loaded_model = context.models.load_by_attrs(
name=name_or_key, base=base, type=ModelType.TextualInversion
)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)
assert loaded_model.config.base == base
ti_list.append((name_or_key, model))
except UnknownModelException:
pass
except ValueError:
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
except AssertionError:
logger.warning(f'trigger: "{trigger}" not a valid textual inversion model for this graph')
except Exception:
logger.warning(f'Failed to load TI model for trigger: "{trigger}"')
return ti_list

View File

@@ -1,7 +1,6 @@
"""
Initialization file for invokeai.backend.image_util methods.
"""
from .patchmatch import PatchMatch # noqa: F401
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
from .seamless import configure_model_padding # noqa: F401

View File

@@ -3,7 +3,6 @@ This module defines a singleton object, "invisible_watermark" that
wraps the invisible watermark model. It respects the global "invisible_watermark"
configuration variable, that allows the watermarking to be supressed.
"""
import cv2
import numpy as np
from imwatermark import WatermarkEncoder

View File

@@ -4,7 +4,6 @@ wraps the actual patchmatch object. It respects the global
"try_patchmatch" attribute, so that patchmatch loading can
be suppressed or deferred
"""
import numpy as np
import invokeai.backend.util.logging as logger

View File

@@ -6,7 +6,6 @@ PngWriter -- Converts Images generated by T2I into PNGs, finds
Exports function retrieve_metadata(path)
"""
import json
import os
import re

View File

@@ -3,7 +3,6 @@ This module defines a singleton object, "safety_checker" that
wraps the safety_checker model. It respects the global "nsfw_checker"
configuration variable, that allows the checker to be supressed.
"""
import numpy as np
from PIL import Image

View File

@@ -1,7 +1,6 @@
"""
Check that the invokeai_root is correctly configured and exit if not.
"""
import sys
from invokeai.app.services.config import InvokeAIAppConfig
@@ -9,6 +8,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
def check_invokeai_root(config: InvokeAIAppConfig):
try:
assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
assert config.models_path.exists(), f"{config.models_path} not found"
if not config.ignore_missing_core_models:

View File

@@ -1,12 +1,14 @@
"""Utility (backend) functions used by model_install.py"""
import re
from logging import Logger
from pathlib import Path
from typing import Any, Dict, List, Optional
import omegaconf
from huggingface_hub import HfFolder
from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
@@ -16,8 +18,12 @@ from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.model_install import (
HFModelSource,
LocalModelSource,
ModelInstallService,
ModelInstallServiceBase,
ModelSource,
URLModelSource,
)
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
@@ -25,6 +31,7 @@ from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
BaseModelType,
InvalidModelConfigException,
ModelRepoVariant,
ModelType,
)
from invokeai.backend.model_manager.metadata import UnknownMetadataException
@@ -219,13 +226,37 @@ class InstallHelper(object):
additional_models.append(reverse_source[requirement])
model_list.extend(additional_models)
def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource:
assert model_info.source
model_path_id_or_url = model_info.source.strip("\"' ")
model_path = Path(model_path_id_or_url)
if model_path.exists(): # local file on disk
return LocalModelSource(path=model_path.absolute(), inplace=True)
# parsing huggingface repo ids
# we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16"
variants = "|".join([x.lower() for x in ModelRepoVariant.__members__])
if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url):
repo_id = match.group(1)
repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None
subfolder = Path(model_info.subfolder) if model_info.subfolder else None
return HFModelSource(
repo_id=repo_id,
access_token=HfFolder.get_token(),
subfolder=subfolder,
variant=repo_variant,
)
if re.match(r"^(http|https):", model_path_id_or_url):
return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
raise ValueError(f"Unsupported model source: {model_path_id_or_url}")
def add_or_delete(self, selections: InstallSelections) -> None:
"""Add or delete selected models."""
installer = self._installer
self._add_required_models(selections.install_models)
for model in selections.install_models:
assert model.source
model_path_id_or_url = model.source.strip("\"' ")
source = self._make_install_source(model)
config = (
{
"description": model.description,
@@ -236,12 +267,12 @@ class InstallHelper(object):
)
try:
installer.heuristic_import(
source=model_path_id_or_url,
installer.import_model(
source=source,
config=config,
)
except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
self._logger.warning(f"{model.source}: {e}")
self._logger.warning(f"{source}: {e}")
for model_to_remove in selections.remove_models:
parts = model_to_remove.split("/")

View File

@@ -939,7 +939,7 @@ def main() -> None:
# run this unconditionally in case new directories need to be added
initialize_rootdir(config.root_path, opt.yes_to_all)
# this will initialize and populate the models tables if not present
# this will initialize the models.yaml file if not present
install_helper = InstallHelper(config, logger)
models_to_download = default_user_selections(opt, install_helper)

View File

@@ -0,0 +1,182 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
# tencent-ailab comment:
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
# loading.
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
def __init__(self):
DiffusersAttnProcessor2_0.__init__(self)
nn.Module.__init__(self)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
ip_adapter_image_prompt_embeds parameter.
"""
return DiffusersAttnProcessor2_0.__call__(
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
)
class IPAttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapater for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
assert len(weights) == len(scales)
self._weights = weights
self._scales = scales
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Apply IP-Adapter attention.
Args:
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
"""
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
for ipa_embed, ipa_weights, scale in zip(
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match.
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
ip_hidden_states = ipa_embed
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states

View File

@@ -1,55 +1,52 @@
from contextlib import contextmanager
from typing import Optional
from diffusers.models import UNet2DConditionModel
from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor2_0
class UNetAttentionPatcher:
"""A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
class UNetPatcher:
"""A class that contains multiple IP-Adapters and can apply them to a UNet."""
def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
def __init__(self, ip_adapters: list[IPAdapter]):
self._ip_adapters = ip_adapters
self._ip_adapter_scales = None
if self._ip_adapters is not None:
self._ip_adapter_scales = [1.0] * len(self._ip_adapters)
self._scales = [1.0] * len(self._ip_adapters)
def set_scale(self, idx: int, value: float):
self._ip_adapter_scales[idx] = value
self._scales[idx] = value
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
weights into them (if IP-Adapters are being applied).
weights into them.
Note that the `unet` param is only used to determine attention block dimensions and naming.
"""
# Construct a dict of attention processors based on the UNet's architecture.
attn_procs = {}
for idx, name in enumerate(unet.attn_processors.keys()):
if name.endswith("attn1.processor") or self._ip_adapters is None:
# "attn1" processors do not use IP-Adapters.
attn_procs[name] = CustomAttnProcessor2_0()
if name.endswith("attn1.processor"):
attn_procs[name] = AttnProcessor2_0()
else:
# Collect the weights from each IP Adapter for the idx'th attention processor.
attn_procs[name] = CustomAttnProcessor2_0(
attn_procs[name] = IPAttnProcessor2_0(
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
self._ip_adapter_scales,
self._scales,
)
return attn_procs
@contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
"""A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
"""A context manager that patches `unet` with IP-Adapter attention processors."""
attn_procs = self._prepare_attention_processors(unet)
orig_attn_processors = unet.attn_processors
try:
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
# the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
# moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
# passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
# of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
unet.set_attn_processor(attn_procs)
yield None
finally:

View File

@@ -1,5 +1,4 @@
"""Re-export frequently-used symbols from the Model Manager backend."""
from .config import (
AnyModel,
AnyModelConfig,

View File

@@ -19,7 +19,6 @@ Typical usage:
Validation errors will raise an InvalidModelConfigException error.
"""
import time
from enum import Enum
from typing import Literal, Optional, Type, Union
@@ -139,16 +138,9 @@ class ModelConfigBase(BaseModel):
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
schema["required"].extend(
["key", "base", "type", "format", "original_hash", "current_hash", "source", "last_modified"]
)
model_config = ConfigDict(
use_enum_values=False,
validate_assignment=True,
json_schema_extra=json_schema_extra,
)
def update(self, attributes: Dict[str, Any]) -> None:
@@ -235,6 +227,37 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
type: Literal[ModelType.Main] = ModelType.Main
class ONNXSD1Config(_MainConfig):
"""Model config for ONNX format models based on sd-1."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
base: Literal[BaseModelType.StableDiffusion1] = BaseModelType.StableDiffusion1
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD2Config(_MainConfig):
"""Model config for ONNX format models based on sd-2."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
# No yaml config file for ONNX, so these are part of config
base: Literal[BaseModelType.StableDiffusion2] = BaseModelType.StableDiffusion2
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
upcast_attention: bool = True
class ONNXSDXLConfig(_MainConfig):
"""Model config for ONNX format models based on sdxl."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
# No yaml config file for ONNX, so these are part of config
base: Literal[BaseModelType.StableDiffusionXL] = BaseModelType.StableDiffusionXL
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
@@ -257,6 +280,7 @@ class T2IConfig(ModelConfigBase):
format: Literal[ModelFormat.Diffusers]
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config, ONNXSDXLConfig], Field(discriminator="base")]
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"),
@@ -266,6 +290,7 @@ _MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], F
AnyModelConfig = Union[
_MainModelConfig,
_ONNXConfig,
_VaeConfig,
_ControlNetConfig,
# ModelConfigBase,

View File

@@ -15,7 +15,7 @@
#
# Adapted for use in InvokeAI by Lincoln Stein, July 2023
#
"""Conversion script for the Stable Diffusion checkpoints."""
""" Conversion script for the Stable Diffusion checkpoints."""
import re
from contextlib import nullcontext

View File

@@ -11,175 +11,56 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
import hashlib
import os
from pathlib import Path
from typing import Callable, Literal, Optional, Union
from typing import Dict, Union
from blake3 import blake3
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
ALGORITHM = Literal[
"md5",
"sha1",
"sha224",
"sha256",
"sha384",
"sha512",
"blake2b",
"blake2s",
"sha3_224",
"sha3_256",
"sha3_384",
"sha3_512",
"shake_128",
"shake_256",
"blake3",
]
from imohash import hashfile
class ModelHash:
"""
Creates a hash of a model using a specified algorithm.
class FastModelHash(object):
"""FastModelHash obect provides one public class method, hash()."""
Args:
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
file_filter: A function that takes a file name and returns True if the file should be included in the hash.
@classmethod
def hash(cls, model_location: Union[str, Path]) -> str:
"""
Return hexdigest string for model located at model_location.
If the model is a single file, it is hashed directly using the provided algorithm.
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes.
Usage:
```py
# BLAKE3 hash
ModelHash().hash("path/to/some/model.safetensors")
# MD5
ModelHash("md5").hash("path/to/model/dir/")
```
"""
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
if algorithm == "blake3":
self._hash_file = self._blake3
elif algorithm in hashlib.algorithms_available:
self._hash_file = self._get_hashlib(algorithm)
:param model_location: Path to the model
"""
model_location = Path(model_location)
if model_location.is_file():
return cls._hash_file(model_location)
elif model_location.is_dir():
return cls._hash_dir(model_location)
else:
raise ValueError(f"Algorithm {algorithm} not available")
raise OSError(f"Not a valid file or directory: {model_location}")
self._file_filter = file_filter or self._default_file_filter
def hash(self, model_path: Union[str, Path]) -> str:
@classmethod
def _hash_file(cls, model_location: Union[str, Path]) -> str:
"""
Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation.
Fasthash a single file and return its hexdigest.
If model_path is a directory, the hash is computed by hashing the hashes of all model files in the
directory. The final composite hash is always computed using BLAKE3.
Args:
model_path: Path to the model
Returns:
str: Hexdigest of the hash of the model
:param model_location: Path to the model file
"""
# we return md5 hash of the filehash to make it shorter
# cryptographic security not needed here
return hashlib.md5(hashfile(model_location)).hexdigest()
model_path = Path(model_path)
if model_path.is_file():
return self._hash_file(model_path)
elif model_path.is_dir():
return self._hash_dir(model_path)
else:
raise OSError(f"Not a valid file or directory: {model_path}")
@classmethod
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
components: Dict[str, str] = {}
def _hash_dir(self, dir: Path) -> str:
"""Compute the hash for all files in a directory and return a hexdigest.
for root, _dirs, files in os.walk(model_location):
for file in files:
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path)
components.update({path: fast_hash})
Args:
dir: Path to the directory
Returns:
str: Hexdigest of the hash of the directory
"""
model_component_paths = self._get_file_paths(dir, self._file_filter)
component_hashes: list[str] = []
for component in sorted(model_component_paths):
component_hashes.append(self._hash_file(component))
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
# for the composite hash
composite_hasher = blake3()
for h in component_hashes:
composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest()
@staticmethod
def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
"""Return a list of all model files in the directory.
Args:
model_path: Path to the model
file_filter: Function that takes a file name and returns True if the file should be included in the list.
Returns:
List of all model files in the directory
"""
files: list[Path] = []
for root, _dirs, _files in os.walk(model_path):
for file in _files:
if file_filter(file):
files.append(Path(root, file))
return files
@staticmethod
def _blake3(file_path: Path) -> str:
"""Hashes a file using BLAKE3
Args:
file_path: Path to the file to hash
Returns:
Hexdigest of the hash of the file
"""
file_hasher = blake3(max_threads=blake3.AUTO)
file_hasher.update_mmap(file_path)
return file_hasher.hexdigest()
@staticmethod
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
"""Factory function that returns a function to hash a file with the given algorithm.
Args:
algorithm: Hashing algorithm to use
Returns:
A function that hashes a file using the given algorithm
"""
def hashlib_hasher(file_path: Path) -> str:
"""Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
hasher = hashlib.new(algorithm)
buffer = bytearray(128 * 1024)
mv = memoryview(buffer)
with open(file_path, "rb", buffering=0) as f:
while n := f.readinto(mv):
hasher.update(mv[:n])
return hasher.hexdigest()
return hashlib_hasher
@staticmethod
def _default_file_filter(file_path: str) -> bool:
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
Args:
file_path: Path to the file
Returns:
True if the file matches the given extensions, otherwise False
"""
return file_path.endswith(MODEL_FILE_EXTENSIONS)
# hash all the model hashes together, using alphabetic file order
md5 = hashlib.md5()
for _path, fast_hash in sorted(components.items()):
md5.update(fast_hash.encode("utf-8"))
return md5.hexdigest()

View File

@@ -2,7 +2,6 @@
"""
Init file for the model loader.
"""
from importlib import import_module
from pathlib import Path

View File

@@ -1,7 +1,6 @@
"""
Disk-based converted model cache.
"""
from abc import ABC, abstractmethod
from pathlib import Path

View File

@@ -10,7 +10,7 @@ model will be cleared and (re)loaded from disk when next needed.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, Generic, Optional, TypeVar
from typing import Dict, Generic, Optional, Set, TypeVar
import torch
@@ -89,8 +89,24 @@ class ModelCacheBase(ABC, Generic[T]):
@property
@abstractmethod
def execution_device(self) -> torch.device:
"""Return the exection device (e.g. "cuda" for VRAM)."""
def execution_devices(self) -> Set[torch.device]:
"""Return the set of available execution devices."""
pass
@abstractmethod
def acquire_execution_device(self, timeout: int = 0) -> torch.device:
"""
Pick the next available execution device.
If all devices are currently engaged (locked), then
block until timeout seconds have passed and raise a
TimeoutError if no devices are available.
"""
pass
@abstractmethod
def release_execution_device(self, device: torch.device) -> None:
"""Release a previously-acquired execution device."""
pass
@property
@@ -111,7 +127,7 @@ class ModelCacheBase(ABC, Generic[T]):
pass
@abstractmethod
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], device: torch.device) -> None:
"""Move model into the indicated device."""
pass

View File

@@ -25,7 +25,8 @@ import sys
import time
from contextlib import suppress
from logging import Logger
from typing import Dict, List, Optional
from threading import BoundedSemaphore, Lock
from typing import Dict, List, Optional, Set
import torch
@@ -61,8 +62,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
execution_devices: Optional[Set[torch.device]] = None,
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
@@ -74,7 +75,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param execution_devices: Set of torch device to load active model into [calculated]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
@@ -89,7 +90,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
self._execution_devices: Set[torch.device] = execution_devices or self._get_execution_devices()
self._storage_device: torch.device = storage_device
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
@@ -99,6 +100,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
self._lock = Lock()
self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
self._busy_execution_devices: Set[torch.device] = set()
@property
def logger(self) -> Logger:
"""Return the logger used by the cache."""
@@ -115,9 +120,24 @@ class ModelCache(ModelCacheBase[AnyModel]):
return self._storage_device
@property
def execution_device(self) -> torch.device:
"""Return the exection device (e.g. "cuda" for VRAM)."""
return self._execution_device
def execution_devices(self) -> Set[torch.device]:
"""Return the set of available execution devices."""
return self._execution_devices
def acquire_execution_device(self, timeout: int = 0) -> torch.device:
"""Acquire and return an execution device (e.g. "cuda" for VRAM)."""
with self._lock:
self._free_execution_device.acquire(timeout=timeout)
free_devices = self.execution_devices - self._busy_execution_devices
chosen_device = list(free_devices)[0]
self._busy_execution_devices.add(chosen_device)
return chosen_device
def release_execution_device(self, device: torch.device) -> None:
"""Mark this execution device as unused."""
with self._lock:
self._free_execution_device.release()
self._busy_execution_devices.remove(device)
@property
def max_cache_size(self) -> float:
@@ -245,13 +265,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
mps.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device.
:param cache_entry: The CacheRecord for the model
:param target_device: The torch.device to move the model into
May raise a torch.cuda.OutOfMemoryError
"""
"""Move model into the indicated device."""
# These attributes are not in the base ModelMixin class but in various derived classes.
# Some models don't have these attributes, in which case they run in RAM/CPU.
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
@@ -265,9 +279,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
if torch.device(source_device).type == torch.device(target_device).type:
return
# may raise an exception here if insufficient GPU VRAM
self._check_free_vram(target_device, cache_entry.size)
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
cache_entry.model.to(target_device)
@@ -415,12 +426,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
if target_device.type != "cuda":
return
vram_device = ( # mem_get_info() needs an indexed device
target_device if target_device.index is not None else torch.device(str(target_device), index=0)
)
free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
if needed_size > free_mem:
raise torch.cuda.OutOfMemoryError
@staticmethod
def _get_execution_devices() -> Set[torch.device]:
default_device = choose_torch_device()
if default_device != torch.device("cuda"):
return {default_device}
# we get here if the default device is cuda, and return each of the
# cuda devices.
return {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}

View File

@@ -2,12 +2,16 @@
Base class and implementation of a class that moves models in and out of VRAM.
"""
from typing import Optional
import torch
from invokeai.backend.model_manager import AnyModel
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
MAX_GPU_WAIT = 600 # wait up to 10 minutes for a GPU to become free
class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
@@ -21,6 +25,7 @@ class ModelLocker(ModelLockerBase):
"""
self._cache = cache
self._cache_entry = cache_entry
self._execution_device: Optional[torch.device] = None
@property
def model(self) -> AnyModel:
@@ -39,15 +44,14 @@ class ModelLocker(ModelLockerBase):
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
# We wait for a gpu to be free - may raise a TimeoutError
self._execution_device = self._cache.acquire_execution_device(MAX_GPU_WAIT)
self._cache.move_model_to_device(self._cache_entry, self._execution_device)
self._cache_entry.loaded = True
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._execution_device}")
self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError:
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
self._cache_entry.unlock()
raise
except Exception:
self._cache_entry.unlock()
raise
@@ -59,6 +63,8 @@ class ModelLocker(ModelLockerBase):
return
self._cache_entry.unlock()
if self._execution_device:
self._cache.release_execution_device(self._execution_device)
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.print_cuda_stats()

Some files were not shown because too many files have changed in this diff Show More