Compare commits

...

126 Commits

Author SHA1 Message Date
psychedelicious
8eb5316c9f feat(ui): invalidate cache for queue item on status change
This query is only subscribed-to in the `QueueItemDetail` component - when is rendered only when the user clicks on a queue item in the queue. Invalidating this tag instead of optimistically updating it won't cause any meaningful change to network traffic.
2024-05-20 18:46:56 +10:00
psychedelicious
12ce095bb2 feat(app): update queue item's session on session completion
The session is never updated in the queue after it is first enqueued. As a result, the queue detail view in the frontend never never updates and the session itself doesn't show outputs, execution graph, etc.

We need a new method on the queue service to update a queue item's session, then call it before updating the queue item's status.

Queue item status may be updated via a session-type event _or_ queue-type event. Adding the updated session to all these events is a hairy - simpler to just update the session before we do anything that could trigger a queue item status change event:
- Before calling `emit_session_complete` in the processor (handles session error, completed and cancel events and the corresponding queue events)
- Before calling `cancel_queue_item` in the processor (handles another way queue items can be canceled, outside the session execution loop)

When serializing the session, both in the new service method and the `get_queue_item` endpoint, we need to use `exclude_none=True` to prevent unexpected validation errors.
2024-05-20 18:13:58 +10:00
psychedelicious
242b2a0b59 tests: clean up tests after events changes 2024-05-20 16:22:37 +10:00
psychedelicious
86e201612f fix(mm): port changes into new model_install_common file
Some subtle changes happened between this PR's last update and now. Bring them into the file.
2024-05-20 16:12:45 +10:00
psychedelicious
f6d1e1be22 fix(ui): update event handling to match new types 2024-05-20 15:52:28 +10:00
psychedelicious
edf043d1d6 chore(ui): typegen 2024-05-20 15:37:25 +10:00
psychedelicious
70a21eda78 fix(events): remove user_id and project_id from error event 2024-05-20 15:34:30 +10:00
psychedelicious
452f4fe0e6 feat(events): add extra field to event payloads
This allows for arbitrary serializable data to be sent with events.
2024-05-20 15:34:03 +10:00
psychedelicious
87f2d04ddd fix(events): fix session processor event handling 2024-05-20 15:22:10 +10:00
psychedelicious
9c45cbe8f7 tidy(ui): remove old unused session subscribe actions 2024-05-20 15:19:17 +10:00
psychedelicious
5f7c852493 docs: clarify comment in api_app 2024-05-20 15:19:17 +10:00
psychedelicious
35ef02bdf7 fix(ui): denoise percentage 2024-05-20 15:19:17 +10:00
psychedelicious
7da283f433 chore(ui): typegen 2024-05-20 15:19:17 +10:00
psychedelicious
812cf277b8 feat(api): sort socket event names for openapi schema
Deterministic ordering prevents extraneous, non-functional changes to the autogenerated types
2024-05-20 15:19:08 +10:00
psychedelicious
182cb51bf0 fix(events): fix denoise progress percentage
- Restore calculation of step percentage but in the backend instead of client
- Simplify signatures for denoise progress event callbacks
- Clean up `step_callback.py` (types, do not recreate constant matrix on every step, formatting)
2024-05-20 15:19:08 +10:00
psychedelicious
64a3adfc64 chore(ui): typegen 2024-05-20 15:19:08 +10:00
psychedelicious
a48ef9f7a7 feat(events): remove payload registry, add method to get event classes
We don't need to use the payload schema registry. All our events are dispatched as pydantic models, which are already validated on instantiation.

We do want to add all events to the OpenAPI schema, and we referred to the payload schema registry for this. To get all events, add a simple helper to EventBase. This is functionally identical to using the schema registry.
2024-05-20 15:18:58 +10:00
psychedelicious
9aeabf10df docs: tidy comments in processor 2024-05-20 15:18:58 +10:00
psychedelicious
7b93cc8538 feat(ui): add missing socket events 2024-05-20 15:18:58 +10:00
psychedelicious
a2480c16e7 feat(events): add missing events
These events weren't being emitted via socket.io:
- DownloadCancelledEvent
- DownloadCompleteEvent
- DownloadErrorEvent
- DownloadProgressEvent
- DownloadStartedEvent
- ModelInstallDownloadsCompleteEvent
2024-05-20 15:18:58 +10:00
psychedelicious
b1e2dd222e feat(events): use builder pattern for download events 2024-05-20 15:18:58 +10:00
psychedelicious
1f92e9eec2 fix(events): dump events with mode="json"
Ensures all model events are serializable.
2024-05-20 15:18:58 +10:00
psychedelicious
fb402f3b46 chore: ruff 2024-05-20 15:18:57 +10:00
psychedelicious
0abc328ddf docs(events): update event docstrings 2024-05-20 15:18:57 +10:00
psychedelicious
cfa4e5f88e tests: move fixtures import to conftest.py 2024-05-20 15:18:57 +10:00
psychedelicious
24d0d4932d tests: update tests to use new events 2024-05-20 15:18:57 +10:00
psychedelicious
20db93b901 fix(mm): check for presence of invoker before emitting model load event
The model loader emits events. During testing, it doesn't have access to a fully-mocked events service, so the test fails when attempting to call a nonexistent method. There was a check for this previously, but I accidentally removed it. Restored.
2024-05-20 15:18:57 +10:00
psychedelicious
500a733d79 fix(ui): correct model load event format 2024-05-20 15:18:57 +10:00
psychedelicious
338d5f158b fix(events): add missing __event_name__ to EventBase 2024-05-20 15:18:57 +10:00
psychedelicious
63e4b224b2 feat(events): simplify event classes
- Remove ABCs, they do not work well with pydantic
- Remove the event type classvar - unused
- Remove clever logic to require an event name - we already get validation for this during schema registration.
- Rename event bases to all end in "Base"
2024-05-20 15:18:57 +10:00
psychedelicious
e9043ff060 fix(events): emit bulk download events in correct room 2024-05-20 15:18:57 +10:00
psychedelicious
c725851c64 chore(ui): tidy after rebase 2024-05-20 15:18:57 +10:00
psychedelicious
a1c4ef55d7 feat(ui): update UI to use new events
- Use OpenAPI schema for event payload types
- Update all event listeners
- Add missing events / remove old nonexistent events
2024-05-20 15:18:42 +10:00
psychedelicious
e25b39aca2 chore(ui): typegen 2024-05-20 15:15:41 +10:00
psychedelicious
32a02b3329 refactor(events): use pydantic schemas for events
Our events handling and implementation has a couple pain points:
- Adding or removing data from event payloads requires changes wherever the events are dispatched from.
- We have no type safety for events and need to rely on string matching and dict access when interacting with events.
- Frontend types for socket events must be manually typed. This has caused several bugs.

`fastapi-events` has a neat feature where you can create a pydantic model as an event payload, give it an `__event_name__` attr, and then dispatch the model directly.

This allows us to eliminate a layer of indirection and some unpleasant complexity:
- Event handler callbacks get type hints for their event payloads, and can use `isinstance` on them if needed.
- Event payload construction is now the responsibility of the event itself (a pydantic model), not the service. Every event model has a `build` class method, encapsulating this logic. The build methods are provided as few args as possible. For example, `InvocationStartedEvent.build()` gets the invocation instance and queue item, and can choose the data it wants to include in the event payload.
- Frontend event types may be autogenerated from the OpenAPI schema. We use the payload registry feature of `fastapi-events` to collect all payload models into one place, making it trivial to keep our schema and frontend types in sync.

This commit moves the backend over to this improved event handling setup.
2024-05-20 15:15:21 +10:00
psychedelicious
620ee2875e fix(ui): store hidden state of edges in workflows
This prevents a minor visual bug where collapsed edges between collapsed nodes didn't display correctly on first load of a workflow.
2024-05-20 11:36:47 +10:00
psychedelicious
5553588147 fix(ui): ensure invocation edges have a type 2024-05-20 11:36:47 +10:00
psychedelicious
1c29b3bd85 feat(ui): updated field type translations 2024-05-20 11:28:33 +10:00
psychedelicious
e88b807a13 docs(ui): update field type docs & comments 2024-05-20 11:28:33 +10:00
psychedelicious
9e55ef3d4b fix(ui): workflow migration field type
At some point, I made a mistake and imported the wrong types to some files for the old v1 and v2 workflow schema migration data.

The relevant zod schemas and inferred types have been restored.

This change doesn't alter runtime behaviour. Only type annotations.
2024-05-20 11:28:33 +10:00
psychedelicious
8062a47d16 fix(ui): use new field type cardinality throughout app
Update business logic and tests.
2024-05-20 11:28:33 +10:00
psychedelicious
dba8c43ecb feat(ui): explicit field type cardinality
Replace the `isCollection` and `isCollectionOrScalar` flags with a single enum value `cardinality`. Valid values are `SINGLE`, `COLLECTION` and `SINGLE_OR_COLLECTION`.

Why:
- The two flags were mutually exclusive, but this wasn't enforce. You could create a field type that had both `isCollection` and `isCollectionOrScalar` set to true, whuch makes no sense.
- There was no explicit declaration for scalar/single types.
- Checking if a type had only a single flag was tedious.

Thanks to a change a couple months back in which the workflows schema was revised, field types are internal implementation details. Changes to them are non-breaking.
2024-05-20 11:28:33 +10:00
psychedelicious
8ebf2ddf15 fix(ui): fix t2i adapter dimensions error message
It now indicates the correct dimension of 64 (SD1.5) or 32 (SDXL) - before was hardcoded to 64.
2024-05-20 11:23:14 +10:00
psychedelicious
f4625c2671 feat(ui): add canvas objects to metadat a for all canvas graphs 2024-05-20 10:32:59 +10:00
psychedelicious
c94742bde6 feat(ui): add canvas objects to metadata when saving canvas to gallery 2024-05-20 10:32:59 +10:00
psychedelicious
a34faf0bd8 chore(ui): typegen 2024-05-20 10:32:59 +10:00
psychedelicious
ecfff6cb1e feat(api): add metadata to upload route
Canvas images are saved by uploading a blob generated from the HTML canvas element. This means the existing metadata handling, inside the graph execution engine, is not available.

To save metadata to canvas images, we need to provide it when uploading that blob.

The upload route now has a `metadata` body param. If this is provided, we use it over any metadata embedded in the image.
2024-05-20 10:32:59 +10:00
psychedelicious
ba8bed6870 fix(ui): edge case resulting in no node templates when loading workflow, causing failure
Depending on the user behaviour and network conditions, it's possible that we could try to load a workflow before the invocation templates are available.

Fix is simple:
- Use the RTKQ query hook for openAPI schema in App.tsx
- Disable the load workflow buttons until w have templates parsed
2024-05-19 07:34:00 -07:00
psychedelicious
ca186bca61 fix(ui): missed node execution state for progress images 2024-05-19 20:14:01 +10:00
psychedelicious
e2f109807c fix(ui): delete edges when their source or target no longer exists 2024-05-19 20:14:01 +10:00
psychedelicious
281bd31db2 feat(nodes): make ModelIdentifierInvocation a prototype 2024-05-19 20:14:01 +10:00
psychedelicious
cea1874e00 perf(ui): memoize WorkflowName selectors 2024-05-19 20:14:01 +10:00
psychedelicious
89b0e9e4de feat(ui): use connection validationResults directly in components 2024-05-19 20:14:01 +10:00
psychedelicious
26d0d55d97 fix(ui): set nodeDragThreshold to prevent spurious position change events 2024-05-19 20:14:01 +10:00
psychedelicious
059c5586a4 perf(ui): ignore all no-op node and edge changes 2024-05-19 20:14:01 +10:00
psychedelicious
9ed5698aa8 fix(ui): do not remove exposed fields when updating workflows 2024-05-19 20:14:01 +10:00
psychedelicious
0b5696c5d4 feat(ui): remove nodeExclusivelySelected action 2024-05-19 20:14:01 +10:00
psychedelicious
a51142674a tidy(ui): more succinct syntax for edge and node updates 2024-05-19 20:14:01 +10:00
psychedelicious
b8b671c0db feat(ui): remove selectionDeleted action 2024-05-19 20:14:01 +10:00
psychedelicious
7cceafe0dd feat(ui): remove selectionPasted action 2024-05-19 20:14:01 +10:00
psychedelicious
cbe32b647a feat(ui): remove selectedAll action 2024-05-19 20:14:01 +10:00
psychedelicious
9a8e0842bb feat(ui): remove nodeReplaced action 2024-05-19 20:14:01 +10:00
psychedelicious
1d7671298f fix(ui): group edge selection actions 2024-05-19 20:14:01 +10:00
psychedelicious
e38d75c3dc feat(ui): get rid of nodeAdded 2024-05-19 20:14:01 +10:00
psychedelicious
21fab9785a feat(ui): tweak edge styling 2024-05-19 20:14:01 +10:00
psychedelicious
b3429553bb fix(ui): collapsed edges selected state 2024-05-19 20:14:01 +10:00
psychedelicious
e480844042 fix(ui): edge styling 2024-05-19 20:14:01 +10:00
psychedelicious
26029108f7 feat(ui): rework node and edge mutation logic
Remove our DIY'd reducers, consolidating all node and edge mutations to use `edgesChanged` and `nodesChanged`, which are called by reactflow. This makes the API for manipulating nodes and edges less tangly and error-prone.
2024-05-19 20:14:01 +10:00
psychedelicious
504ac82077 fix(ui): duplicated edges when updating edge with lazy connect 2024-05-19 20:14:01 +10:00
psychedelicious
6b11740dda chore(ui): knip 2024-05-19 20:14:01 +10:00
psychedelicious
a80e3448f5 feat(ui): rework pendingConnection 2024-05-19 20:14:01 +10:00
psychedelicious
4bda174eb9 tests(ui): coverage for getCollectItemType 2024-05-19 20:14:01 +10:00
psychedelicious
b1e28c2f2c tests(ui): coverage for getFirstValidConnection 2024-05-19 20:14:01 +10:00
psychedelicious
83000a4190 feat(ui): rework getFirstValidConnection with new helpers 2024-05-19 20:14:01 +10:00
psychedelicious
c98205d0d7 tests(ui): candidate fields, getFirstValidConnection (wip) 2024-05-19 20:14:01 +10:00
psychedelicious
ce2ad5903c feat(ui): extract logic for finding candidate fields to own function 2024-05-19 20:14:01 +10:00
psychedelicious
fe3980a369 tests(ui): add buildNode convenience wrapper for buildInvocationNode 2024-05-19 20:14:01 +10:00
psychedelicious
ea97ae5ae8 tidy(ui): extraneous vars in makeConnectionErrorSelector 2024-05-19 20:14:01 +10:00
psychedelicious
3605b6b1a3 fix(ui): handling for in-progress edge updates during conection validation 2024-05-19 20:14:01 +10:00
psychedelicious
fc31dddbf7 feat(ui): use new validateConnection 2024-05-19 20:14:01 +10:00
psychedelicious
6ad01d824d feat(ui): add strict mode to validateConnection 2024-05-19 20:14:01 +10:00
psychedelicious
78f9f3ee95 feat(ui): better types for validateConnection 2024-05-19 20:14:01 +10:00
psychedelicious
972398d203 tests(ui): add iterate to test schema 2024-05-19 20:14:01 +10:00
psychedelicious
857889d1fa tests(ui): coverage for getCollectItemType 2024-05-19 20:14:01 +10:00
psychedelicious
8074a802d6 tests(ui): coverage for validateConnectionTypes 2024-05-19 20:14:01 +10:00
psychedelicious
059d5a682c tidy(ui): validateConnection code clarity 2024-05-19 20:14:01 +10:00
psychedelicious
00c2d8f95d tidy(ui): areTypesEqual var names 2024-05-19 20:14:01 +10:00
psychedelicious
04a596179b tests(ui): finish test cases for validateConnection 2024-05-19 20:14:01 +10:00
psychedelicious
3fcb2720d7 tests(ui): add tests for consolidated connection validation 2024-05-19 20:14:01 +10:00
psychedelicious
6f7160b9fd fix(ui): call updateNodeInternals when making connections 2024-05-19 20:14:01 +10:00
psychedelicious
6b4e464d17 fix(ui): rework edge update logic 2024-05-19 20:14:01 +10:00
psychedelicious
9f7841a04b tidy(ui): clean up addnodepopover hotkeys 2024-05-19 20:14:01 +10:00
psychedelicious
468644ab18 fix(ui): rebase conflict 2024-05-19 20:14:01 +10:00
psychedelicious
9d127fee6b feat(ui): makeConnectionErrorSelector now creates a parameterized selector 2024-05-19 20:14:01 +10:00
psychedelicious
6658897210 tidy(ui): tidy connection validation functions and logic 2024-05-19 20:14:01 +10:00
psychedelicious
af7b194bec chore(ui): lint 2024-05-19 20:14:01 +10:00
psychedelicious
de1ea50e6d fix(ui): rebase resolution 2024-05-19 20:14:01 +10:00
psychedelicious
2680ef52c2 feat(nodes): add ModelIdentifierInvocation
This node allows a user to select _any_ model, outputting a `ModelIdentifierField` for that model.
2024-05-19 20:14:01 +10:00
psychedelicious
a012bb6e07 feat(ui): add ModelIdentifierField field type
This new field type accepts _any_ model. A field renderer lets the user select any available model.
2024-05-19 20:14:01 +10:00
psychedelicious
6a2c53f6c5 fix(ui): do not allow comparison between undefined original types 2024-05-19 20:14:01 +10:00
psychedelicious
2cbf7d9221 fix(ui): stupid ts 2024-05-19 20:14:01 +10:00
psychedelicious
fe7ed72c9c feat(nodes): make all ModelIdentifierField inputs accept connections 2024-05-19 20:14:01 +10:00
psychedelicious
85a5a7c47a feat(ui): add originalType to FieldType, improved connection validation
We now keep track of the original field type, derived from the python type annotation in addition to the override type provided by `ui_type`.

This makes `ui_type` work more like it sound like it should work - change the UI input component only.

Connection validation is extend to also check the original types. If there is any match between two fields' "final" or original types, we consider the connection valid.This change is backwards-compatible; there is no workflow migration needed.
2024-05-19 20:14:01 +10:00
psychedelicious
af3fd26d4e fix(ui): bug when clearing processor
When clearing the processor config, we shouldn't re-process the image. This logic wasn't handled correctly, but coincidentally the bug didn't cause a user-facing issue.

Without a config, we had a runtime error when trying to build the node for the processor graph and the listener failed.

So while we didn't re-process the image, it was because there was an error, not because the logic was correct.

Fix this by bailing if there is no image or config.
2024-05-19 07:25:48 +10:00
psychedelicious
5127fd6320 fix(ui): control adapter autoprocess jank
If you change the control model and the new model has the same default processor, we would still re-process the image, even if there was no need to do so.

With this change, if the image and processor config are unchanged, we bail out.
2024-05-19 07:25:48 +10:00
psychedelicious
124d34a8cc docs: add link for --extra-index-url 2024-05-19 00:56:31 +10:00
Shukri
e8387d7523 docs: add link to tool on pytorch website 2024-05-19 00:56:31 +10:00
Shukri
a5d08c981b docs: fix typo in --root arg of invokeai-web 2024-05-19 00:56:31 +10:00
Shukri
811d0da0f0 docs: fix link to. install reqs 2024-05-19 00:56:31 +10:00
psychedelicious
17e1fc5254 chore(app): ruff 2024-05-18 09:21:45 +10:00
maryhipp
84e031edc2 add nulable project also 2024-05-18 09:21:45 +10:00
maryhipp
b6b7e737e0 ruff 2024-05-18 09:21:45 +10:00
maryhipp
5f3e7afd45 add nullable user to invocation error events 2024-05-18 09:21:45 +10:00
psychedelicious
b0cfca9d24 fix(app): pass image metadata as stringified json 2024-05-18 09:04:37 +10:00
psychedelicious
985ef89825 fix(app): type annotations in images service 2024-05-18 09:04:37 +10:00
psychedelicious
5928ade5fd feat(app): simplified create image API
Graph, metadata and workflow all take stringified JSON only. This makes the API consistent and means we don't need to do a round-trip of pydantic parsing when handling this data.

It also prevents a failure mode where an uploaded image's metadata, workflow or graph are old and don't match the current schema.

As before, the frontend does strict validation and parsing when loading these values.
2024-05-18 09:04:37 +10:00
psychedelicious
93ebc175c6 fix(app): retain graph in metadata when uploading images 2024-05-18 09:04:37 +10:00
psychedelicious
386d552493 fix(ui): loading workflows from file 2024-05-18 09:04:37 +10:00
psychedelicious
799cf06d20 fix(ui): loading library workflows 2024-05-18 09:04:37 +10:00
psychedelicious
922716d2ab feat(ui): store graph in image metadata
The previous super-minimal implementation had a major issue - the saved workflow didn't take into account batched field values. When generating with multiple iterations or dynamic prompts, the same workflow with the first prompt, seed, etc was stored in each image.

As a result, when the batch results in multiple queue items, only one of the images has the correct workflow - the others are mismatched.

To work around this, we can store the _graph_ in the image metadata (alongside the workflow, if generated via workflow editor). When loading a workflow from an image, we can choose to load the workflow or the graph, preferring the workflow.

Internally, we need to update images router image-saving services. The changes are minimal.

To avoid pydantic errors deserializing the graph, when we extract it from the image, we will leave it as stringified JSON and let the frontend's more sophisticated and flexible parsing handle it. The worklow is also changed to just return stringified JSON, so the API is consistent.
2024-05-18 09:04:37 +10:00
psychedelicious
66fc110b64 Revert "feat(ui): store workflow in generation tab images"
This reverts commit c9c4190fb45696088207b0ac3c69c2795d7f9694.
2024-05-18 09:04:37 +10:00
psychedelicious
822f1e1f06 feat(ui): store workflow in generation tab images 2024-05-18 09:04:37 +10:00
psychedelicious
5d60c3c8e1 fix(ui): jank when editing field title 2024-05-18 08:46:40 +10:00
psychedelicious
4e21d01c7f feat(ui): dim field name when connected 2024-05-18 08:46:40 +10:00
psychedelicious
6b7b0b3777 fix(ui): do not rearrange fields when connection/disconnecting 2024-05-18 08:46:40 +10:00
psychedelicious
07feb5ba07 Revert "feat(ui): SDXL clip skip"
This reverts commit 40b4fa7238.
2024-05-17 15:08:04 -07:00
162 changed files with 7676 additions and 4079 deletions

View File

@@ -117,13 +117,13 @@ Stateless fields do not store their value in the node, so their field instances
"Custom" fields will always be treated as stateless fields.
##### Collection and Scalar Fields
##### Single and Collection Fields
Field types have a name and two flags which may identify it as a **collection** or **collection or scalar** field.
Field types have a name and cardinality property which may identify it as a **SINGLE**, **COLLECTION** or **SINGLE_OR_COLLECTION** field.
If a field is annotated in python as a list, its field type is parsed and flagged as a **collection** type (e.g. `list[int]`).
If it is annotated as a union of a type and list, the type will be flagged as a **collection or scalar** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
- If a field is annotated in python as a singular value or class, its field type is parsed as a **SINGLE** type (e.g. `int`, `ImageField`, `str`).
- If a field is annotated in python as a list, its field type is parsed as a **COLLECTION** type (e.g. `list[int]`).
- If it is annotated as a union of a type and list, the type will be parsed as a **SINGLE_OR_COLLECTION** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
## Implementation
@@ -173,8 +173,7 @@ Field types are represented as structured objects:
```ts
type FieldType = {
name: string;
isCollection: boolean;
isCollectionOrScalar: boolean;
cardinality: 'SINGLE' | 'COLLECTION' | 'SINGLE_OR_COLLECTION';
};
```
@@ -186,7 +185,7 @@ There are 4 general cases for field type parsing.
When a field is annotated as a primitive values (e.g. `int`, `str`, `float`), the field type parsing is fairly straightforward. The field is represented by a simple OpenAPI **schema object**, which has a `type` property.
We create a field type name from this `type` string (e.g. `string` -> `StringField`).
We create a field type name from this `type` string (e.g. `string` -> `StringField`). The cardinality is `"SINGLE"`.
##### Complex Types
@@ -200,13 +199,13 @@ We need to **dereference** the schema to pull these out. Dereferencing may requi
When a field is annotated as a list of a single type, the schema object has an `items` property. They may be a schema object or reference object and must be parsed to determine the item type.
We use the item type for field type name, adding `isCollection: true` to the field type.
We use the item type for field type name. The cardinality is `"COLLECTION"`.
##### Collection or Scalar Types
##### Single or Collection Types
When a field is annotated as a union of a type and list of that type, the schema object has an `anyOf` property, which holds a list of valid types for the union.
After verifying that the union has two members (a type and list of the same type), we use the type for field type name, adding `isCollectionOrScalar: true` to the field type.
After verifying that the union has two members (a type and list of the same type), we use the type for field type name, with cardinality `"SINGLE_OR_COLLECTION"`.
##### Optional Fields

View File

@@ -10,7 +10,7 @@ InvokeAI is distributed as a python package on PyPI, installable with `pip`. The
### Requirements
Before you start, go through the [installation requirements].
Before you start, go through the [installation requirements](./INSTALL_REQUIREMENTS.md).
### Installation Walkthrough
@@ -79,7 +79,7 @@ Before you start, go through the [installation requirements].
1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features.
- You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website]. Copy the `--extra-index-url` string from this and append it to your install command.
- You may need to provide an [extra index URL](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-extra-index-url). Select your platform configuration using [this tool on the PyTorch website](https://pytorch.org/get-started/locally/). Copy the `--extra-index-url` string from this and append it to your install command.
!!! example "Install with an extra index URL"
@@ -116,4 +116,4 @@ Before you start, go through the [installation requirements].
!!! warning
If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root_dir \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.
If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.

View File

@@ -18,6 +18,7 @@ from ..services.boards.boards_default import BoardService
from ..services.bulk_download.bulk_download_default import BulkDownloadService
from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService
from ..services.events.events_fastapievents import FastAPIEventService
from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService
@@ -33,7 +34,6 @@ from ..services.session_processor.session_processor_default import DefaultSessio
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
# TODO: is there a better way to achieve this?

View File

@@ -1,52 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from typing import Any
from fastapi_events.dispatcher import dispatch
from ..services.events.events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
event_handler_id: int
__queue: Queue
__stop_event: threading.Event
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self.__queue = Queue()
self.__stop_event = threading.Event()
asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self.__stop_event.set()
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put({"event_name": event_name, "payload": payload})
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self.__queue.get(block=False)
if not event: # Probably stopping
continue
dispatch(
event.get("event_name"),
payload=event.get("payload"),
middleware_id=self.event_handler_id,
)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@@ -6,13 +6,12 @@ from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request,
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, Field, ValidationError
from pydantic import BaseModel, Field, JsonValue
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
from ..dependencies import ApiDependencies
@@ -42,13 +41,17 @@ async def upload_image(
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
metadata: Optional[JsonValue] = Body(
default=None, description="The metadata to associate with the image", embed=True
),
) -> ImageDTO:
"""Uploads an image"""
if not file.content_type or not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
metadata = None
workflow = None
_metadata = None
_workflow = None
_graph = None
contents = await file.read()
try:
@@ -62,22 +65,28 @@ async def upload_image(
# TODO: retain non-invokeai metadata on upload?
# attempt to parse metadata from image
metadata_raw = pil_image.info.get("invokeai_metadata", None)
if metadata_raw:
try:
metadata = MetadataFieldValidator.validate_json(metadata_raw)
except ValidationError:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass
metadata_raw = metadata if isinstance(metadata, str) else pil_image.info.get("invokeai_metadata", None)
if isinstance(metadata_raw, str):
_metadata = metadata_raw
else:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass
# attempt to parse workflow from image
workflow_raw = pil_image.info.get("invokeai_workflow", None)
if workflow_raw is not None:
try:
workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw)
except ValidationError:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass
if isinstance(workflow_raw, str):
_workflow = workflow_raw
else:
ApiDependencies.invoker.services.logger.warn("Failed to parse workflow for uploaded image")
pass
# attempt to extract graph from image
graph_raw = pil_image.info.get("invokeai_graph", None)
if isinstance(graph_raw, str):
_graph = graph_raw
else:
ApiDependencies.invoker.services.logger.warn("Failed to parse graph for uploaded image")
pass
try:
image_dto = ApiDependencies.invoker.services.images.create(
@@ -86,8 +95,9 @@ async def upload_image(
image_category=image_category,
session_id=session_id,
board_id=board_id,
metadata=metadata,
workflow=workflow,
metadata=_metadata,
workflow=_workflow,
graph=_graph,
is_intermediate=is_intermediate,
)
@@ -185,14 +195,21 @@ async def get_image_metadata(
raise HTTPException(status_code=404)
class WorkflowAndGraphResponse(BaseModel):
workflow: Optional[str] = Field(description="The workflow used to generate the image, as stringified JSON")
graph: Optional[str] = Field(description="The graph used to generate the image, as stringified JSON")
@images_router.get(
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=Optional[WorkflowWithoutID]
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse
)
async def get_image_workflow(
image_name: str = Path(description="The name of image whose workflow to get"),
) -> Optional[WorkflowWithoutID]:
) -> WorkflowAndGraphResponse:
try:
return ApiDependencies.invoker.services.images.get_workflow(image_name)
workflow = ApiDependencies.invoker.services.images.get_workflow(image_name)
graph = ApiDependencies.invoker.services.images.get_graph(image_name)
return WorkflowAndGraphResponse(workflow=workflow, graph=graph)
except Exception:
raise HTTPException(status_code=404)

View File

@@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,

View File

@@ -203,6 +203,7 @@ async def get_batch_status(
responses={
200: {"model": SessionQueueItem},
},
response_model_exclude_none=True,
)
async def get_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),

View File

@@ -1,66 +1,131 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from fastapi import FastAPI
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from pydantic import BaseModel
from socketio import ASGIApp, AsyncServer
from ..services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadEventBase,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
FastAPIEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelEventBase,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEventBase,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
SessionCompleteEvent,
SessionStartedEvent,
register_events,
)
class QueueSubscriptionEvent(BaseModel):
queue_id: str
class BulkDownloadSubscriptionEvent(BaseModel):
bulk_download_id: str
class SocketIO:
__sio: AsyncServer
__app: ASGIApp
_sub_queue = "subscribe_queue"
_unsub_queue = "unsubscribe_queue"
__sub_queue: str = "subscribe_queue"
__unsub_queue: str = "unsubscribe_queue"
__sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download"
_sub_bulk_download = "subscribe_bulk_download"
_unsub_bulk_download = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI):
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
app.mount("/ws", self.__app)
self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
app.mount("/ws", self._app)
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue)
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)
async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["queue_id"],
register_events(
{
InvocationStartedEvent,
InvocationDenoiseProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
SessionStartedEvent,
SessionCompleteEvent,
SessionCanceledEvent,
QueueItemStatusChangedEvent,
BatchEnqueuedEvent,
QueueClearedEvent,
},
self._handle_queue_event,
)
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.enter_room(sid, data["queue_id"])
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.leave_room(sid, data["queue_id"])
async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
async def _handle_bulk_download_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["bulk_download_id"],
register_events(
{
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
ModelLoadStartedEvent,
ModelLoadCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
ModelInstallCompleteEvent,
ModelInstallCancelledEvent,
ModelInstallErrorEvent,
},
self._handle_model_event,
)
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.enter_room(sid, data["bulk_download_id"])
register_events(
{BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent},
self._handle_bulk_image_download_event,
)
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.leave_room(sid, data["bulk_download_id"])
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.queue_id)
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase]) -> None:
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"))
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.bulk_download_id)

View File

@@ -27,6 +27,7 @@ import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import TorchDevice
@@ -182,23 +183,14 @@ def custom_openapi() -> dict[str, Any]:
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
invoker_schema["class"] = "invocation"
# This code no longer seems to be necessary?
# Leave it here just in case
#
# from invokeai.backend.model_manager import get_model_config_formats
# formats = get_model_config_formats()
# for model_config_name, enum_set in formats.items():
# if model_config_name in openapi_schema["components"]["schemas"]:
# # print(f"Config with name {name} already defined")
# continue
# openapi_schema["components"]["schemas"][model_config_name] = {
# "title": model_config_name,
# "description": "An enumeration.",
# "type": "string",
# "enum": [v.value for v in enum_set],
# }
# Add all event schemas
for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
if "$defs" in json_schema:
for schema_key, schema in json_schema["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema
del json_schema["$defs"]
openapi_schema["components"]["schemas"][event.__name__] = json_schema
app.openapi_schema = openapi_schema
return app.openapi_schema

View File

@@ -24,7 +24,6 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
OutputField,
UIType,
@@ -80,13 +79,13 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1")
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@@ -58,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.1")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
@@ -67,7 +67,6 @@ class IPAdapterInvocation(BaseInvocation):
ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.",
title="IP-Adapter Model",
input=Input.Direct,
ui_order=-1,
ui_type=UIType.IPAdapterModel,
)

View File

@@ -11,6 +11,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType,
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@@ -93,19 +94,46 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
pass
@invocation_output("model_identifier_output")
class ModelIdentifierOutput(BaseInvocationOutput):
"""Model identifier output"""
model: ModelIdentifierField = OutputField(description="Model identifier", title="Model")
@invocation(
"model_identifier",
title="Model identifier",
tags=["model"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class ModelIdentifierInvocation(BaseInvocation):
"""Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as
input for any model, even if the model types don't match. If you connect this to a mismatched input, you'll get an
error."""
model: ModelIdentifierField = InputField(description="The model to select", title="Model")
def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
if not context.models.exists(self.model.key):
raise Exception(f"Unknown model {self.model.key}")
return ModelIdentifierOutput(model=self.model)
@invocation(
"main_model_loader",
title="Main Model",
tags=["model"],
category="model",
version="1.0.2",
version="1.0.3",
)
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
)
model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel)
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
@@ -134,12 +162,12 @@ class LoRALoaderOutput(BaseInvocationOutput):
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.2")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3")
class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
@@ -197,12 +225,12 @@ class LoRASelectorOutput(BaseInvocationOutput):
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.0")
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.1")
class LoRASelectorInvocation(BaseInvocation):
"""Selects a LoRA model and weight."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
@@ -273,13 +301,13 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
title="SDXL LoRA",
tags=["lora", "model"],
category="model",
version="1.0.2",
version="1.0.3",
)
class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
@@ -414,12 +442,12 @@ class SDXLLoRACollectionLoader(BaseInvocation):
return output
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.3")
class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel
description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel
)
def invoke(self, context: InvocationContext) -> VAEOutput:

View File

@@ -1,4 +1,4 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import SubModelType
@@ -30,12 +30,12 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.2")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.3")
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel
)
# TODO: precision?
@@ -67,13 +67,13 @@ class SDXLModelLoaderInvocation(BaseInvocation):
title="SDXL Refiner Model",
tags=["model", "sdxl", "refiner"],
category="model",
version="1.0.2",
version="1.0.3",
)
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel
)
# TODO: precision?

View File

@@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -45,7 +45,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
@invocation(
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.2"
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.3"
)
class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes."""
@@ -55,7 +55,6 @@ class T2IAdapterInvocation(BaseInvocation):
t2i_adapter_model: ModelIdentifierField = InputField(
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
input=Input.Direct,
ui_order=-1,
ui_type=UIType.T2IAdapterModel,
)

View File

@@ -106,9 +106,7 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
self._invoker.services.events.emit_bulk_download_started(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
bulk_download_id, bulk_download_item_id, bulk_download_item_name
)
def _signal_job_completed(
@@ -118,10 +116,8 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
assert bulk_download_item_name is not None
self._invoker.services.events.emit_bulk_download_completed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
self._invoker.services.events.emit_bulk_download_complete(
bulk_download_id, bulk_download_item_id, bulk_download_item_name
)
def _signal_job_failed(
@@ -131,11 +127,8 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
assert exception is not None
self._invoker.services.events.emit_bulk_download_failed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=str(exception),
self._invoker.services.events.emit_bulk_download_error(
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
)
def stop(self, *args, **kwargs):

View File

@@ -8,14 +8,13 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from typing import Any, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
import requests
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.util.logging import InvokeAILogger
@@ -30,6 +29,9 @@ from .download_base import (
UnknownJobIDException,
)
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
# Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000
@@ -40,7 +42,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
def __init__(
self,
max_parallel_dl: int = 5,
event_bus: Optional[EventServiceBase] = None,
event_bus: Optional["EventServiceBase"] = None,
requests_session: Optional[requests.sessions.Session] = None,
):
"""
@@ -343,8 +345,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
self._event_bus.emit_download_started(job)
def _signal_job_progress(self, job: DownloadJob) -> None:
if job.on_progress:
@@ -355,13 +356,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_progress(
str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
self._event_bus.emit_download_progress(job)
def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED
@@ -373,10 +368,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_complete(
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
)
self._event_bus.emit_download_complete(job)
def _signal_job_cancelled(self, job: DownloadJob) -> None:
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
@@ -390,7 +382,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_cancelled(str(job.source))
self._event_bus.emit_download_cancelled(job)
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
job.status = DownloadJobStatus.ERROR
@@ -403,9 +395,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
)
if self._event_bus:
assert job.error_type
assert job.error
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
self._event_bus.emit_download_error(job)
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")

View File

@@ -1,486 +1,253 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Optional
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
EventBase,
ExtraData,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
SessionCompleteEvent,
SessionStartedEvent,
)
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.config import SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
class EventServiceBase:
queue_event: str = "queue_event"
bulk_download_event: str = "bulk_download_event"
download_event: str = "download_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event_name: str, payload: Any) -> None:
def dispatch(self, event: "EventBase") -> None:
pass
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
"""Bulk download events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.bulk_download_event,
payload={"event": event_name, "data": payload},
)
# region: Invocation
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
"""Queue events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.queue_event,
payload={"event": event_name, "data": payload},
)
def __emit_download_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.download_event,
payload={"event": event_name, "data": payload},
)
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.model_event,
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node_id: str,
source_node_id: str,
progress_image: Optional[ProgressImage],
step: int,
order: int,
total_steps: int,
def emit_invocation_started(
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", extra: Optional[ExtraData] = None
) -> None:
"""Emitted when there is generation progress"""
self.__emit_queue_event(
event_name="generator_progress",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"source_node_id": source_node_id,
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
"step": step,
"order": order,
"total_steps": total_steps,
},
"""Emitted when an invocation is started"""
self.dispatch(InvocationStartedEvent.build(queue_item, invocation, extra))
def emit_invocation_denoise_progress(
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
intermediate_state: PipelineIntermediateState,
progress_image: "ProgressImage",
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted at each step during denoising of an invocation."""
self.dispatch(
InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image, extra)
)
def emit_invocation_complete(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
result: dict,
node: dict,
source_node_id: str,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
output: "BaseInvocationOutput",
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"result": result,
},
)
"""Emitted when an invocation is complete"""
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output, extra))
def emit_invocation_error(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
error_type: str,
error: str,
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"error_type": error_type,
"error": error,
},
)
"""Emitted when an invocation encounters an error"""
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error, extra))
def emit_invocation_started(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation has started"""
self.__emit_queue_event(
event_name="invocation_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
},
)
# endregion
def emit_graph_execution_complete(
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
) -> None:
# region Session
def emit_session_started(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a session has started"""
self.dispatch(SessionStartedEvent.build(queue_item, extra))
def emit_session_complete(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a session has completed all invocations"""
self.__emit_queue_event(
event_name="graph_execution_state_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
self.dispatch(SessionCompleteEvent.build(queue_item, extra))
def emit_model_load_started(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is requested"""
self.__emit_queue_event(
event_name="model_load_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_model_load_completed(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
event_name="model_load_completed",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_session_canceled(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
) -> None:
def emit_session_canceled(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a session is canceled"""
self.__emit_queue_event(
event_name="session_canceled",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
self.dispatch(SessionCanceledEvent.build(queue_item, extra))
# endregion
# region Queue
def emit_queue_item_status_changed(
self,
session_queue_item: SessionQueueItem,
batch_status: BatchStatus,
queue_status: SessionQueueStatus,
queue_item: "SessionQueueItem",
batch_status: "BatchStatus",
queue_status: "SessionQueueStatus",
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when a queue item's status changes"""
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload={
"queue_id": queue_status.queue_id,
"queue_item": {
"queue_id": session_queue_item.queue_id,
"item_id": session_queue_item.item_id,
"status": session_queue_item.status,
"batch_id": session_queue_item.batch_id,
"session_id": session_queue_item.session_id,
"error": session_queue_item.error,
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
},
"batch_status": batch_status.model_dump(mode="json"),
"queue_status": queue_status.model_dump(mode="json"),
},
)
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status, extra))
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a batch is enqueued"""
self.__emit_queue_event(
event_name="batch_enqueued",
payload={
"queue_id": enqueue_result.queue_id,
"batch_id": enqueue_result.batch.batch_id,
"enqueued": enqueue_result.enqueued,
},
)
self.dispatch(BatchEnqueuedEvent.build(enqueue_result, extra))
def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when the queue is cleared"""
self.__emit_queue_event(
event_name="queue_cleared",
payload={"queue_id": queue_id},
)
def emit_queue_cleared(self, queue_id: str, extra: Optional[ExtraData] = None) -> None:
"""Emitted when a queue is cleared"""
self.dispatch(QueueClearedEvent.build(queue_id, extra))
def emit_download_started(self, source: str, download_path: str) -> None:
"""
Emit when a download job is started.
# endregion
:param url: The downloaded url
"""
self.__emit_download_event(
event_name="download_started",
payload={"source": source, "download_path": download_path},
)
# region Download
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
"""
Emit "download_progress" events at regular intervals during a download job.
def emit_download_started(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a download is started"""
self.dispatch(DownloadStartedEvent.build(job, extra))
:param source: The downloaded source
:param download_path: The local downloaded file
:param current_bytes: Number of bytes downloaded so far
:param total_bytes: The size of the file being downloaded (if known)
"""
self.__emit_download_event(
event_name="download_progress",
payload={
"source": source,
"download_path": download_path,
"current_bytes": current_bytes,
"total_bytes": total_bytes,
},
)
def emit_download_progress(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted at intervals during a download"""
self.dispatch(DownloadProgressEvent.build(job, extra))
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
"""
Emit a "download_complete" event at the end of a successful download.
def emit_download_complete(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a download is completed"""
self.dispatch(DownloadCompleteEvent.build(job, extra))
:param source: Source URL
:param download_path: Path to the locally downloaded file
:param total_bytes: The size of the downloaded file
"""
self.__emit_download_event(
event_name="download_complete",
payload={
"source": source,
"download_path": download_path,
"total_bytes": total_bytes,
},
)
def emit_download_cancelled(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a download is cancelled"""
self.dispatch(DownloadCancelledEvent.build(job, extra))
def emit_download_cancelled(self, source: str) -> None:
"""Emit a "download_cancelled" event in the event that the download was cancelled by user."""
self.__emit_download_event(
event_name="download_cancelled",
payload={
"source": source,
},
)
def emit_download_error(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a download encounters an error"""
self.dispatch(DownloadErrorEvent.build(job, extra))
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
"""
Emit a "download_error" event when an download job encounters an exception.
# endregion
:param source: Source URL
:param error_type: The name of the exception that raised the error
:param error: The traceback from this error
"""
self.__emit_download_event(
event_name="download_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)
# region Model loading
def emit_model_install_downloading(
def emit_model_load_started(
self,
source: str,
local_path: str,
bytes: int,
total_bytes: int,
parts: List[Dict[str, Union[str, int]]],
id: int,
config: "AnyModelConfig",
submodel_type: Optional["SubModelType"] = None,
extra: Optional[ExtraData] = None,
) -> None:
"""
Emit at intervals while the install job is in progress (remote models only).
"""Emitted when a model load is started."""
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type, extra))
:param source: Source of the model
:param local_path: Where model is downloading to
:param parts: Progress of downloading URLs that comprise the model, if any.
:param bytes: Number of bytes downloaded so far.
:param total_bytes: Total size of download, including all files.
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
"""
self.__emit_model_event(
event_name="model_install_downloading",
payload={
"source": source,
"local_path": local_path,
"bytes": bytes,
"total_bytes": total_bytes,
"parts": parts,
"id": id,
},
)
def emit_model_load_complete(
self,
config: "AnyModelConfig",
submodel_type: Optional["SubModelType"] = None,
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when a model load is complete."""
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type, extra))
def emit_model_install_downloads_done(self, source: str) -> None:
"""
Emit once when all parts are downloaded, but before the probing and registration start.
# endregion
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_downloads_done",
payload={"source": source},
)
# region Model install
def emit_model_install_running(self, source: str) -> None:
"""
Emit once when an install job becomes active.
def emit_model_install_download_progress(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted at intervals while the install job is in progress (remote models only)."""
self.dispatch(ModelInstallDownloadProgressEvent.build(job, extra))
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_running",
payload={"source": source},
)
def emit_model_install_downloads_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job, extra))
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
"""
Emit when an install job is completed successfully.
def emit_model_install_started(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted once when an install job is started (after any download)."""
self.dispatch(ModelInstallStartedEvent.build(job, extra))
:param source: Source of the model; local path, repo_id or url
:param key: Model config record key
:param total_bytes: Size of the model (may be None for installation of a local path)
"""
self.__emit_model_event(
event_name="model_install_completed",
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
)
def emit_model_install_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when an install job is completed successfully."""
self.dispatch(ModelInstallCompleteEvent.build(job, extra))
def emit_model_install_cancelled(self, source: str, id: int) -> None:
"""
Emit when an install job is cancelled.
def emit_model_install_cancelled(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when an install job is cancelled."""
self.dispatch(ModelInstallCancelledEvent.build(job, extra))
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_cancelled",
payload={"source": source, "id": id},
)
def emit_model_install_error(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when an install job encounters an exception."""
self.dispatch(ModelInstallErrorEvent.build(job, extra))
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
"""
Emit when an install job encounters an exception.
# endregion
:param source: Source of the model
:param error_type: The name of the exception
:param error: A text description of the exception
"""
self.__emit_model_event(
event_name="model_install_error",
payload={"source": source, "error_type": error_type, "error": error, "id": id},
)
# region Bulk image download
def emit_bulk_download_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when a bulk download starts"""
self._emit_bulk_download_event(
event_name="bulk_download_started",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
"""Emitted when a bulk image download is started"""
self.dispatch(
BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra)
)
def emit_bulk_download_completed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
def emit_bulk_download_complete(
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when a bulk download completes"""
self._emit_bulk_download_event(
event_name="bulk_download_completed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
"""Emitted when a bulk image download is complete"""
self.dispatch(
BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra)
)
def emit_bulk_download_failed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
def emit_bulk_download_error(
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
error: str,
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted when a bulk download fails"""
self._emit_bulk_download_event(
event_name="bulk_download_failed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
"error": error,
},
"""Emitted when a bulk image download has an error"""
self.dispatch(
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error, extra)
)
# endregion

View File

@@ -0,0 +1,707 @@
from math import floor
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
ExtraData: TypeAlias = dict[str, Any]
class EventBase(BaseModel):
"""Base class for all events. All events must inherit from this class.
Events must define a class attribute `__event_name__` to identify the event.
All other attributes should be defined as normal for a pydantic model.
A timestamp is automatically added to the event when it is created.
"""
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
extra: Optional[ExtraData] = Field(default=None, description="Extra data to include with the event")
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
@classmethod
def get_events(cls) -> set[type["EventBase"]]:
"""Get a set of all event models."""
event_subclasses: set[type["EventBase"]] = set()
for subclass in cls.__subclasses__():
# We only want to include subclasses that are event models, not intermediary classes
if hasattr(subclass, "__event_name__"):
event_subclasses.add(subclass)
event_subclasses.update(subclass.get_events())
return event_subclasses
TEvent = TypeVar("TEvent", bound=EventBase)
FastAPIEvent: TypeAlias = tuple[str, TEvent]
"""
A tuple representing a `fastapi-events` event, with the event name and payload.
Provide a generic type to `TEvent` to specify the payload type.
"""
class FastAPIEventFunc(Protocol):
def __call__(self, event: FastAPIEvent[Any]) -> Optional[Coroutine[Any, Any, None]]: ...
def register_events(events: set[type[TEvent]], func: FastAPIEventFunc) -> None:
"""Register a function to handle a list of events.
:param events: A list of event classes to handle
:param func: The function to handle the events
"""
for event in events:
assert hasattr(event, "__event_name__")
local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
class QueueEventBase(EventBase):
"""Base class for queue events"""
queue_id: str = Field(description="The ID of the queue")
class QueueItemEventBase(QueueEventBase):
"""Base class for queue item events"""
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
class SessionEventBase(QueueItemEventBase):
"""Base class for session (aka graph execution state) events"""
session_id: str = Field(description="The ID of the session (aka graph execution state)")
class InvocationEventBase(SessionEventBase):
"""Base class for invocation events"""
queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation_id: str = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
invocation_type: str = Field(description="The type of invocation")
class InvocationStartedEvent(InvocationEventBase):
"""Event model for invocation_started"""
__event_name__ = "invocation_started"
@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, extra: Optional[ExtraData] = None
) -> "InvocationStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
extra=extra,
)
class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Event model for invocation_denoise_progress"""
__event_name__ = "invocation_denoise_progress"
progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
step: int = Field(description="The current step of the invocation")
total_steps: int = Field(description="The total number of steps in the invocation")
order: int = Field(description="The order of the invocation in the session")
percentage: float = Field(description="The percentage of completion of the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
extra: Optional[ExtraData] = None,
) -> "InvocationDenoiseProgressEvent":
step = intermediate_state.step
total_steps = intermediate_state.total_steps
order = intermediate_state.order
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
progress_image=progress_image,
step=step,
total_steps=total_steps,
order=order,
percentage=cls.calc_percentage(step, total_steps, order),
extra=extra,
)
@staticmethod
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
"""Calculate the percentage of completion of denoising."""
if total_steps == 0:
return 0.0
if scheduler_order == 2:
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
# order == 1
return (step + 1 + 1) / (total_steps + 1)
class InvocationCompleteEvent(InvocationEventBase):
"""Event model for invocation_complete"""
__event_name__ = "invocation_complete"
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
result: BaseInvocationOutput,
extra: Optional[ExtraData] = None,
) -> "InvocationCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
result=result,
extra=extra,
)
class InvocationErrorEvent(InvocationEventBase):
"""Event model for invocation_error"""
__event_name__ = "invocation_error"
error_type: str = Field(description="The type of error")
error: str = Field(description="The error message")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
error_type: str,
error: str,
extra: Optional[ExtraData] = None,
) -> "InvocationErrorEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
error_type=error_type,
error=error,
extra=extra,
)
class SessionStartedEvent(SessionEventBase):
"""Event model for session_started"""
__event_name__ = "session_started"
@classmethod
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
extra=extra,
)
class SessionCompleteEvent(SessionEventBase):
"""Event model for session_complete"""
__event_name__ = "session_complete"
@classmethod
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
extra=extra,
)
class SessionCanceledEvent(SessionEventBase):
"""Event model for session_canceled"""
__event_name__ = "session_canceled"
@classmethod
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCanceledEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
extra=extra,
)
class QueueItemStatusChangedEvent(QueueItemEventBase):
"""Event model for queue_item_status_changed"""
__event_name__ = "queue_item_status_changed"
status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item")
error: Optional[str] = Field(default=None, description="The error message, if any")
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
batch_status: BatchStatus = Field(description="The status of the batch")
queue_status: SessionQueueStatus = Field(description="The status of the queue")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
batch_status: BatchStatus,
queue_status: SessionQueueStatus,
extra: Optional[ExtraData] = None,
) -> "QueueItemStatusChangedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
status=queue_item.status,
error=queue_item.error,
created_at=str(queue_item.created_at) if queue_item.created_at else None,
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
started_at=str(queue_item.started_at) if queue_item.started_at else None,
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
queue_status=queue_status,
extra=extra,
)
class BatchEnqueuedEvent(QueueEventBase):
"""Event model for batch_enqueued"""
__event_name__ = "batch_enqueued"
batch_id: str = Field(description="The ID of the batch")
enqueued: int = Field(description="The number of invocations enqueued")
requested: int = Field(
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
)
priority: int = Field(description="The priority of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult, extra: Optional[ExtraData] = None) -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,
extra=extra,
)
class QueueClearedEvent(QueueEventBase):
"""Event model for queue_cleared"""
__event_name__ = "queue_cleared"
@classmethod
def build(cls, queue_id: str, extra: Optional[ExtraData] = None) -> "QueueClearedEvent":
return cls(
queue_id=queue_id,
extra=extra,
)
class DownloadEventBase(EventBase):
"""Base class for events associated with a download"""
source: str = Field(description="The source of the download")
class DownloadStartedEvent(DownloadEventBase):
"""Event model for download_started"""
__event_name__ = "download_started"
download_path: str = Field(description="The local path where the download is saved")
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadStartedEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
extra=extra,
)
class DownloadProgressEvent(DownloadEventBase):
"""Event model for download_progress"""
__event_name__ = "download_progress"
download_path: str = Field(description="The local path where the download is saved")
current_bytes: int = Field(description="The number of bytes downloaded so far")
total_bytes: int = Field(description="The total number of bytes to be downloaded")
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadProgressEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
extra=extra,
)
class DownloadCompleteEvent(DownloadEventBase):
"""Event model for download_complete"""
__event_name__ = "download_complete"
download_path: str = Field(description="The local path where the download is saved")
total_bytes: int = Field(description="The total number of bytes downloaded")
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCompleteEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
total_bytes=job.total_bytes,
extra=extra,
)
class DownloadCancelledEvent(DownloadEventBase):
"""Event model for download_cancelled"""
__event_name__ = "download_cancelled"
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCancelledEvent":
return cls(
source=str(job.source),
extra=extra,
)
class DownloadErrorEvent(DownloadEventBase):
"""Event model for download_error"""
__event_name__ = "download_error"
error_type: str = Field(description="The type of error")
error: str = Field(description="The error message")
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadErrorEvent":
assert job.error_type
assert job.error
return cls(
source=str(job.source),
error_type=job.error_type,
error=job.error,
extra=extra,
)
class ModelEventBase(EventBase):
"""Base class for events associated with a model"""
class ModelLoadStartedEvent(ModelEventBase):
"""Event model for model_load_started"""
__event_name__ = "model_load_started"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None
) -> "ModelLoadStartedEvent":
return cls(
config=config,
submodel_type=submodel_type,
extra=extra,
)
class ModelLoadCompleteEvent(ModelEventBase):
"""Event model for model_load_complete"""
__event_name__ = "model_load_complete"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None
) -> "ModelLoadCompleteEvent":
return cls(
config=config,
submodel_type=submodel_type,
extra=extra,
)
class ModelInstallDownloadProgressEvent(ModelEventBase):
"""Event model for model_install_download_progress"""
__event_name__ = "model_install_download_progress"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")
parts: list[dict[str, int | str]] = Field(
description="Progress of downloading URLs that comprise the model, if any"
)
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadProgressEvent":
parts: list[dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
return cls(
id=job.id,
source=str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
extra=extra,
)
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
"""Emitted once when an install job becomes active."""
__event_name__ = "model_install_downloads_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadsCompleteEvent":
return cls(
id=job.id,
source=str(job.source),
extra=extra,
)
class ModelInstallStartedEvent(ModelEventBase):
"""Event model for model_install_started"""
__event_name__ = "model_install_started"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallStartedEvent":
return cls(
id=job.id,
source=str(job.source),
extra=extra,
)
class ModelInstallCompleteEvent(ModelEventBase):
"""Event model for model_install_complete"""
__event_name__ = "model_install_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
key: str = Field(description="Model config record key")
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCompleteEvent":
assert job.config_out is not None
return cls(
id=job.id,
source=str(job.source),
key=(job.config_out.key),
total_bytes=job.total_bytes,
extra=extra,
)
class ModelInstallCancelledEvent(ModelEventBase):
"""Event model for model_install_cancelled"""
__event_name__ = "model_install_cancelled"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCancelledEvent":
return cls(
id=job.id,
source=str(job.source),
extra=extra,
)
class ModelInstallErrorEvent(ModelEventBase):
"""Event model for model_install_error"""
__event_name__ = "model_install_error"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
error_type: str = Field(description="The name of the exception")
error: str = Field(description="A text description of the exception")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallErrorEvent":
assert job.error_type is not None
assert job.error is not None
return cls(
id=job.id,
source=str(job.source),
error_type=job.error_type,
error=job.error,
extra=extra,
)
class BulkDownloadEventBase(EventBase):
"""Base class for events associated with a bulk image download"""
bulk_download_id: str = Field(description="The ID of the bulk image download")
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
class BulkDownloadStartedEvent(BulkDownloadEventBase):
"""Event model for bulk_download_started"""
__event_name__ = "bulk_download_started"
@classmethod
def build(
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
) -> "BulkDownloadStartedEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
extra=extra,
)
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
"""Event model for bulk_download_complete"""
__event_name__ = "bulk_download_complete"
@classmethod
def build(
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
) -> "BulkDownloadCompleteEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
extra=extra,
)
class BulkDownloadErrorEvent(BulkDownloadEventBase):
"""Event model for bulk_download_error"""
__event_name__ = "bulk_download_error"
error: str = Field(description="The error message")
@classmethod
def build(
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
error: str,
extra: Optional[ExtraData] = None,
) -> "BulkDownloadErrorEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=error,
extra=extra,
)

View File

@@ -0,0 +1,46 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_common import (
EventBase,
)
from .events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self._queue = Queue[EventBase | None]()
self._stop_event = threading.Event()
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self._stop_event.set()
self._queue.put(None)
def dispatch(self, event: EventBase) -> None:
self._queue.put(event)
async def _dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self._queue.get(block=False)
if not event: # Probably stopping
continue
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@@ -4,9 +4,6 @@ from typing import Optional
from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@@ -33,8 +30,9 @@ class ImageFileStorageBase(ABC):
self,
image: PILImageType,
image_name: str,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
metadata: Optional[str] = None,
workflow: Optional[str] = None,
graph: Optional[str] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@@ -46,6 +44,11 @@ class ImageFileStorageBase(ABC):
pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
def get_workflow(self, image_name: str) -> Optional[str]:
"""Gets the workflow of an image."""
pass
@abstractmethod
def get_graph(self, image_name: str) -> Optional[str]:
"""Gets the graph of an image."""
pass

View File

@@ -7,9 +7,7 @@ from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
from .image_files_base import ImageFileStorageBase
@@ -56,8 +54,9 @@ class DiskImageFileStorage(ImageFileStorageBase):
self,
image: PILImageType,
image_name: str,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
metadata: Optional[str] = None,
workflow: Optional[str] = None,
graph: Optional[str] = None,
thumbnail_size: int = 256,
) -> None:
try:
@@ -68,13 +67,14 @@ class DiskImageFileStorage(ImageFileStorageBase):
info_dict = {}
if metadata is not None:
metadata_json = metadata.model_dump_json()
info_dict["invokeai_metadata"] = metadata_json
pnginfo.add_text("invokeai_metadata", metadata_json)
info_dict["invokeai_metadata"] = metadata
pnginfo.add_text("invokeai_metadata", metadata)
if workflow is not None:
workflow_json = workflow.model_dump_json()
info_dict["invokeai_workflow"] = workflow_json
pnginfo.add_text("invokeai_workflow", workflow_json)
info_dict["invokeai_workflow"] = workflow
pnginfo.add_text("invokeai_workflow", workflow)
if graph is not None:
info_dict["invokeai_graph"] = graph
pnginfo.add_text("invokeai_graph", graph)
# When saving the image, the image object's info field is not populated. We need to set it
image.info = info_dict
@@ -129,11 +129,18 @@ class DiskImageFileStorage(ImageFileStorageBase):
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def get_workflow(self, image_name: str) -> WorkflowWithoutID | None:
def get_workflow(self, image_name: str) -> str | None:
image = self.get(image_name)
workflow = image.info.get("invokeai_workflow", None)
if workflow is not None:
return WorkflowWithoutID.model_validate_json(workflow)
if isinstance(workflow, str):
return workflow
return None
def get_graph(self, image_name: str) -> str | None:
image = self.get(image_name)
graph = image.info.get("invokeai_graph", None)
if isinstance(graph, str):
return graph
return None
def __validate_storage_folders(self) -> None:

View File

@@ -80,7 +80,7 @@ class ImageRecordStorageBase(ABC):
starred: Optional[bool] = False,
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[MetadataField] = None,
metadata: Optional[str] = None,
) -> datetime:
"""Saves an image record."""
pass

View File

@@ -328,10 +328,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
starred: Optional[bool] = False,
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[MetadataField] = None,
metadata: Optional[str] = None,
) -> datetime:
try:
metadata_json = metadata.model_dump_json() if metadata is not None else None
self._lock.acquire()
self._cursor.execute(
"""--sql
@@ -358,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
height,
node_id,
session_id,
metadata_json,
metadata,
is_intermediate,
starred,
has_workflow,

View File

@@ -12,7 +12,6 @@ from invokeai.app.services.image_records.image_records_common import (
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageServiceABC(ABC):
@@ -51,8 +50,9 @@ class ImageServiceABC(ABC):
session_id: Optional[str] = None,
board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
metadata: Optional[str] = None,
workflow: Optional[str] = None,
graph: Optional[str] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@@ -87,7 +87,12 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
def get_workflow(self, image_name: str) -> Optional[str]:
"""Gets an image's workflow."""
pass
@abstractmethod
def get_graph(self, image_name: str) -> Optional[str]:
"""Gets an image's workflow."""
pass

View File

@@ -5,7 +5,6 @@ from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from ..image_files.image_files_common import (
ImageFileDeleteException,
@@ -42,8 +41,9 @@ class ImageService(ImageServiceABC):
session_id: Optional[str] = None,
board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
metadata: Optional[str] = None,
workflow: Optional[str] = None,
graph: Optional[str] = None,
) -> ImageDTO:
if image_origin not in ResourceOrigin:
raise InvalidOriginException
@@ -64,7 +64,7 @@ class ImageService(ImageServiceABC):
image_category=image_category,
width=width,
height=height,
has_workflow=workflow is not None,
has_workflow=workflow is not None or graph is not None,
# Meta fields
is_intermediate=is_intermediate,
# Nullable fields
@@ -75,7 +75,7 @@ class ImageService(ImageServiceABC):
if board_id is not None:
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
)
image_dto = self.get_dto(image_name)
@@ -157,7 +157,7 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image metadata")
raise e
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
def get_workflow(self, image_name: str) -> Optional[str]:
try:
return self.__invoker.services.image_files.get_workflow(image_name)
except ImageFileNotFoundException:
@@ -167,6 +167,16 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image workflow")
raise
def get_graph(self, image_name: str) -> Optional[str]:
try:
return self.__invoker.services.image_files.get_graph(image_name)
except ImageFileNotFoundException:
self.__invoker.services.logger.error("Image file not found")
raise
except Exception:
self.__invoker.services.logger.error("Problem getting image graph")
raise
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
return str(self.__invoker.services.image_files.get_path(image_name, thumbnail))

View File

@@ -1,11 +1,13 @@
"""Initialization file for model install service package."""
from .model_install_base import (
ModelInstallServiceBase,
)
from .model_install_common import (
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
UnknownInstallJobException,
URLModelSource,

View File

@@ -1,244 +1,19 @@
# Copyright 2023 Lincoln D. Stein and the InvokeAI development team
"""Baseclass definitions for the model installer."""
import re
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Union
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
from invokeai.backend.model_manager.config import AnyModelConfig
class ModelInstallServiceBase(ABC):
@@ -282,7 +57,7 @@ class ModelInstallServiceBase(ABC):
@property
@abstractmethod
def event_bus(self) -> Optional[EventServiceBase]:
def event_bus(self) -> Optional["EventServiceBase"]:
"""Return the event service base object associated with the installer."""
@abstractmethod

View File

@@ -0,0 +1,233 @@
import re
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]

View File

@@ -10,7 +10,7 @@ from pathlib import Path
from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch
import yaml
@@ -20,8 +20,8 @@ from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import (
@@ -45,13 +45,12 @@ from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice
from .model_install_base import (
from .model_install_common import (
MODEL_SOURCE_TO_TYPE_MAP,
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
StringLikeSource,
URLModelSource,
@@ -59,6 +58,9 @@ from .model_install_base import (
TMPDIR_PREFIX = "tmpinstall_"
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
class ModelInstallService(ModelInstallServiceBase):
"""class for InvokeAI model installation."""
@@ -68,7 +70,7 @@ class ModelInstallService(ModelInstallServiceBase):
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
event_bus: Optional[EventServiceBase] = None,
event_bus: Optional["EventServiceBase"] = None,
session: Optional[Session] = None,
):
"""
@@ -104,7 +106,7 @@ class ModelInstallService(ModelInstallServiceBase):
return self._record_store
@property
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102
return self._event_bus
# make the invoker optional here because we don't need it and it
@@ -855,35 +857,17 @@ class ModelInstallService(ModelInstallServiceBase):
job.status = InstallStatus.RUNNING
self._logger.info(f"Model install started: {job.source}")
if self._event_bus:
self._event_bus.emit_model_install_running(str(job.source))
self._event_bus.emit_model_install_started(job)
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
if self._event_bus:
parts: List[Dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
assert job.bytes is not None
assert job.total_bytes is not None
self._event_bus.emit_model_install_downloading(
str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
id=job.id,
)
self._event_bus.emit_model_install_download_progress(job)
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.DOWNLOADS_DONE
self._logger.info(f"Model download complete: {job.source}")
if self._event_bus:
self._event_bus.emit_model_install_downloads_done(str(job.source))
self._event_bus.emit_model_install_downloads_complete(job)
def _signal_job_completed(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.COMPLETED
@@ -891,24 +875,19 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"Model install complete: {job.source}")
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
if self._event_bus:
assert job.local_path is not None
assert job.config_out is not None
key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
self._event_bus.emit_model_install_complete(job)
def _signal_job_errored(self, job: ModelInstallJob) -> None:
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
if self._event_bus:
error_type = job.error_type
error = job.error
assert error_type is not None
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
assert job.error_type is not None
assert job.error is not None
self._event_bus.emit_model_install_error(job)
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"Model install canceled: {job.source}")
if self._event_bus:
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
self._event_bus.emit_model_install_cancelled(job)
@staticmethod
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:

View File

@@ -4,7 +4,6 @@
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
@@ -15,18 +14,12 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader."""
@abstractmethod
def load_model(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch.
:param context_data: Invocation context data used for event reporting
"""
@property

View File

@@ -5,7 +5,6 @@ from typing import Optional, Type
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
LoadedModel,
@@ -51,25 +50,18 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the checkpoint convert cache used by this loader."""
return self._convert_cache
def load_model(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
"""
if context_data:
self._emit_load_event(
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
)
# We don't have an invoker during testing
# TODO(psyche): Mock this method on the invoker in the tests
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model: LoadedModel = implementation(
@@ -79,40 +71,7 @@ class ModelLoadService(ModelLoadServiceBase):
convert_cache=self._convert_cache,
).load_model(model_config, submodel_type)
if context_data:
self._emit_load_event(
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
loaded=True,
)
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
return loaded_model
def _emit_load_event(
self,
context_data: InvocationContextData,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
submodel_type: Optional[SubModelType] = None,
) -> None:
if not self._invoker:
return
if not loaded:
self._invoker.services.events.emit_model_load_started(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)
else:
self._invoker.services.events.emit_model_load_completed(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)

View File

@@ -4,11 +4,16 @@ from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
FastAPIEvent,
QueueClearedEvent,
QueueEventBase,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
register_events,
)
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
@@ -31,8 +36,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now_event = ThreadEvent()
self._cancel_event = ThreadEvent()
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
self._thread_limit = thread_limit
self._thread_semaphore = BoundedSemaphore(thread_limit)
self._polling_interval = polling_interval
@@ -49,6 +52,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None
)
register_events({SessionCanceledEvent, QueueClearedEvent, BatchEnqueuedEvent}, self._on_queue_event)
self._thread = Thread(
name="session_processor",
target=self._process,
@@ -67,30 +72,25 @@ class DefaultSessionProcessor(SessionProcessorBase):
def _poll_now(self) -> None:
self._poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
async def _on_queue_event(self, event: FastAPIEvent[QueueEventBase]) -> None:
_event_name, payload = event
if (
event_name == "session_canceled"
isinstance(payload, SessionCanceledEvent)
and self._queue_item
and self._queue_item.item_id == event[1]["data"]["queue_item_id"]
and self._queue_item.item_id == payload.item_id
):
self._cancel_event.set()
self._poll_now()
elif (
event_name == "queue_cleared"
isinstance(payload, QueueClearedEvent)
and self._queue_item
and self._queue_item.queue_id == event[1]["data"]["queue_id"]
and self._queue_item.queue_id == payload.queue_id
):
self._cancel_event.set()
self._poll_now()
elif event_name == "batch_enqueued":
elif isinstance(payload, BatchEnqueuedEvent):
self._poll_now()
elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [
"completed",
"failed",
"canceled",
]:
elif isinstance(payload, QueueItemStatusChangedEvent) and payload.status in ("completed", "failed", "canceled"):
self._poll_now()
def resume(self) -> SessionProcessorStatus:
@@ -139,6 +139,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.wait(self._polling_interval)
continue
self._invoker.services.events.emit_session_started(self._queue_item)
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
@@ -153,16 +154,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
)
self._invoker.services.events.emit_invocation_started(self._queue_item, self._invocation)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
@@ -189,19 +181,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
self._queue_item, self._invocation, outputs
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
# TODO(MM2): I don't think this is ever raised...
pass
except CanceledException:
@@ -227,14 +212,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
queue_item=self._queue_item,
invocation=self._invocation,
error_type=e.__class__.__name__,
error=error,
)
@@ -242,13 +222,10 @@ class DefaultSessionProcessor(SessionProcessorBase):
# The session is complete if the all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
self._invoker.services.session_queue.set_queue_item_session(
self._queue_item.item_id, self._queue_item.session
)
self._invoker.services.events.emit_session_complete(self._queue_item)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
@@ -279,6 +256,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
)
# Cancel the queue item
if self._queue_item is not None:
self._invoker.services.session_queue.set_queue_item_session(
self._queue_item.item_id, self._queue_item.session
)
self._invoker.services.session_queue.cancel_queue_item(
self._queue_item.item_id, error=traceback.format_exc()
)

View File

@@ -16,6 +16,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults
@@ -103,3 +104,8 @@ class SessionQueueBase(ABC):
def get_queue_item(self, item_id: int) -> SessionQueueItem:
"""Gets a session queue item by ID"""
pass
@abstractmethod
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
"""Sets the session for a session queue item. Use this to update the session state."""
pass

View File

@@ -2,10 +2,13 @@ import sqlite3
import threading
from typing import Optional, Union, cast
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import (
FastAPIEvent,
InvocationErrorEvent,
SessionCanceledEvent,
SessionCompleteEvent,
register_events,
)
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.session_queue.session_queue_common import (
@@ -27,6 +30,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
calc_session_count,
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@@ -41,7 +45,11 @@ class SqliteSessionQueue(SessionQueueBase):
self.__invoker = invoker
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
register_events(events={InvocationErrorEvent}, func=self._handle_error_event)
register_events(events={SessionCompleteEvent}, func=self._handle_complete_event)
register_events(events={SessionCanceledEvent}, func=self._handle_cancel_event)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
@@ -51,51 +59,35 @@ class SqliteSessionQueue(SessionQueueBase):
self.__conn = db.conn
self.__cursor = self.__conn.cursor()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
return event[1]["event"] in match_in
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
event_name = event[1]["event"]
# This was a match statement, but match is not supported on python 3.9
if event_name == "graph_execution_state_complete":
await self._handle_complete_event(event)
elif event_name == "invocation_error":
await self._handle_error_event(event)
elif event_name == "session_canceled":
await self._handle_cancel_event(event)
return event
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
async def _handle_complete_event(self, event: FastAPIEvent[SessionCompleteEvent]) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
# When a queue item has an error, we get an error event, then a completed event.
# Mark the queue item completed only if it isn't already marked completed, e.g.
# by a previously-handled error event.
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
except SessionQueueItemNotFoundError:
return
_event_name, payload = event
async def _handle_error_event(self, event: FastAPIEvent) -> None:
queue_item = self.get_queue_item(payload.item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
self._set_queue_item_status(item_id=payload.item_id, status="completed")
except SessionQueueItemNotFoundError:
pass
async def _handle_error_event(self, event: FastAPIEvent[InvocationErrorEvent]) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
error = event[1]["data"]["error"]
queue_item = self.get_queue_item(item_id)
_event_name, payload = event
# always set to failed if have an error, even if previously the item was marked completed or canceled
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
self._set_queue_item_status(item_id=payload.item_id, status="failed", error=payload.error)
except SessionQueueItemNotFoundError:
return
pass
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
async def _handle_cancel_event(self, event: FastAPIEvent[SessionCanceledEvent]) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
queue_item = self.get_queue_item(item_id)
_event_name, payload = event
queue_item = self.get_queue_item(payload.item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
self._set_queue_item_status(item_id=payload.item_id, status="canceled")
except SessionQueueItemNotFoundError:
return
pass
def _set_in_progress_to_canceled(self) -> None:
"""
@@ -292,11 +284,7 @@ class SqliteSessionQueue(SessionQueueBase):
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status)
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
@@ -429,12 +417,7 @@ class SqliteSessionQueue(SessionQueueBase):
if queue_item.status not in ["canceled", "failed", "completed"]:
status = "failed" if error is not None else "canceled"
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here
self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
queue_batch_id=queue_item.batch_id,
graph_execution_state_id=queue_item.session_id,
)
self.__invoker.services.events.emit_session_canceled(queue_item)
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
@@ -470,18 +453,11 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
self.__invoker.services.events.emit_session_canceled(current_queue_item)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=current_queue_item,
batch_status=batch_status,
queue_status=queue_status,
current_queue_item, batch_status, queue_status
)
except Exception:
self.__conn.rollback()
@@ -521,18 +497,11 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
self.__invoker.services.events.emit_session_canceled(current_queue_item)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=current_queue_item,
batch_status=batch_status,
queue_status=queue_status,
current_queue_item, batch_status, queue_status
)
except Exception:
self.__conn.rollback()
@@ -562,6 +531,29 @@ class SqliteSessionQueue(SessionQueueBase):
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
try:
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
session_json = session.model_dump_json(warnings=False, exclude_none=True)
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET session = ?
WHERE item_id = ?
""",
(session_json, item_id),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return self.get_queue_item(item_id)
def list_queue_items(
self,
queue_id: str,

View File

@@ -180,9 +180,9 @@ class ImagesInterface(InvocationContextInterface):
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
metadata_ = None
if metadata:
metadata_ = metadata
elif isinstance(self._data.invocation, WithMetadata):
metadata_ = self._data.invocation.metadata
metadata_ = metadata.model_dump_json()
elif isinstance(self._data.invocation, WithMetadata) and self._data.invocation.metadata:
metadata_ = self._data.invocation.metadata.model_dump_json()
# If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None.
board_id_ = None
@@ -191,6 +191,14 @@ class ImagesInterface(InvocationContextInterface):
elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board:
board_id_ = self._data.invocation.board.board_id
workflow_ = None
if self._data.queue_item.workflow:
workflow_ = self._data.queue_item.workflow.model_dump_json()
graph_ = None
if self._data.queue_item.session.graph:
graph_ = self._data.queue_item.session.graph.model_dump_json()
return self._services.images.create(
image=image,
is_intermediate=self._data.invocation.is_intermediate,
@@ -198,7 +206,8 @@ class ImagesInterface(InvocationContextInterface):
board_id=board_id_,
metadata=metadata_,
image_origin=ResourceOrigin.INTERNAL,
workflow=self._data.queue_item.workflow,
workflow=workflow_,
graph=graph_,
session_id=self._data.queue_item.session_id,
node_id=self._data.invocation.id,
)
@@ -344,11 +353,11 @@ class ModelsInterface(InvocationContextInterface):
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
return self._services.model_manager.load.load_model(model, submodel_type)
else:
_submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
return self._services.model_manager.load.load_model(model, _submodel_type)
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@@ -373,7 +382,7 @@ class ModelsInterface(InvocationContextInterface):
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
return self._services.model_manager.load.load_model(configs[0], submodel_type)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Gets a model's config.

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, Optional
import torch
from PIL import Image
@@ -13,8 +13,36 @@ if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContextData
# fast latents preview matrix for sdxl
# generated by @StAlKeR7779
SDXL_LATENT_RGB_FACTORS = [
# R G B
[0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
]
SDXL_SMOOTH_MATRIX = [
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
]
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
# origingally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
SD1_5_LATENT_RGB_FACTORS = [
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
]
def sample_to_lowres_estimated_image(
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
):
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
if smooth_matrix is not None:
@@ -47,64 +75,12 @@ def stable_diffusion_step_callback(
else:
sample = intermediate_state.latents
# TODO: This does not seem to be needed any more?
# # txt2img provides a Tensor in the step_callback
# # img2img provides a PipelineIntermediateState
# if isinstance(sample, PipelineIntermediateState):
# # this was an img2img
# print('img2img')
# latents = sample.latents
# step = sample.step
# else:
# print('txt2img')
# latents = sample
# step = intermediate_state.step
# TODO: only output a preview image when requested
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
# fast latents preview matrix for sdxl
# generated by @StAlKeR7779
sdxl_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
],
dtype=sample.dtype,
device=sample.device,
)
sdxl_smooth_matrix = torch.tensor(
[
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
],
dtype=sample.dtype,
device=sample.device,
)
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
else:
# origingally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
v1_5_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
],
dtype=sample.dtype,
device=sample.device,
)
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
(width, height) = image.size
@@ -113,15 +89,9 @@ def stable_diffusion_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
events.emit_generator_progress(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
node_id=context_data.invocation.id,
source_node_id=context_data.source_invocation_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
order=intermediate_state.order,
total_steps=intermediate_state.total_steps,
events.emit_invocation_denoise_progress(
context_data.queue_item,
context_data.invocation,
intermediate_state,
ProgressImage(dataURL=dataURL, width=width, height=height),
)

View File

@@ -775,10 +775,14 @@
"cannotConnectToSelf": "Cannot connect to self",
"cannotDuplicateConnection": "Cannot create duplicate connections",
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
"missingNode": "Missing invocation node",
"missingInvocationTemplate": "Missing invocation template",
"missingFieldTemplate": "Missing field template",
"nodePack": "Node pack",
"collection": "Collection",
"collectionFieldType": "{{name}} Collection",
"collectionOrScalarFieldType": "{{name}} Collection|Scalar",
"singleFieldType": "{{name}} (Single)",
"collectionFieldType": "{{name}} (Collection)",
"collectionOrScalarFieldType": "{{name}} (Single or Collection)",
"colorCodeEdges": "Color-Code Edges",
"colorCodeEdgesHelp": "Color-code edges according to their connected fields",
"connectionWouldCreateCycle": "Connection would create a cycle",
@@ -880,6 +884,7 @@
"versionUnknown": " Version Unknown",
"workflow": "Workflow",
"graph": "Graph",
"noGraph": "No Graph",
"workflowAuthor": "Author",
"workflowContact": "Contact",
"workflowDescription": "Short Description",
@@ -947,7 +952,7 @@
"controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model",
"controlAdapterNoImageSelected": "no Control Adapter image selected",
"controlAdapterImageNotProcessed": "Control Adapter image not processed",
"t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of 64",
"t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of {{multiple}}",
"ipAdapterNoModelSelected": "no IP adapter selected",
"ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model",
"ipAdapterNoImageSelected": "no IP Adapter image selected",

View File

@@ -21,6 +21,7 @@ import i18n from 'i18n';
import { size } from 'lodash-es';
import { memo, useCallback, useEffect } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import PreselectedImage from './PreselectedImage';
@@ -46,6 +47,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
useSocketIO();
useGlobalModifiersInit();
useGlobalHotkeys();
useGetOpenAPISchemaQuery();
const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone();

View File

@@ -35,28 +35,23 @@ import { addImageUploadedFulfilledListener } from 'app/store/middleware/listener
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addGraphExecutionStateCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
import { addSessionRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError';
import { addSocketSubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed';
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@@ -110,12 +105,8 @@ addInvocationErrorEventListener(startAppListening);
addInvocationStartedEventListener(startAppListening);
addSocketConnectedEventListener(startAppListening);
addSocketDisconnectedEventListener(startAppListening);
addSocketSubscribedEventListener(startAppListening);
addSocketUnsubscribedEventListener(startAppListening);
addModelLoadEventListener(startAppListening);
addModelInstallEventListener(startAppListening);
addSessionRetrievalErrorEventListener(startAppListening);
addInvocationRetrievalErrorEventListener(startAppListening);
addSocketQueueItemStatusChangedEventListener(startAppListening);
addBulkDownloadListeners(startAppListening);

View File

@@ -6,8 +6,8 @@ import { toast } from 'common/util/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import {
socketBulkDownloadCompleted,
socketBulkDownloadFailed,
socketBulkDownloadComplete,
socketBulkDownloadError,
socketBulkDownloadStarted,
} from 'services/events/actions';
@@ -56,7 +56,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
});
startAppListening({
actionCreator: socketBulkDownloadCompleted,
actionCreator: socketBulkDownloadComplete,
effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation completed');
@@ -89,7 +89,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
});
startAppListening({
actionCreator: socketBulkDownloadFailed,
actionCreator: socketBulkDownloadError,
effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation failed');

View File

@@ -1,5 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { canvasSavedToGallery } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
@@ -43,6 +44,9 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe
type: 'TOAST',
toastOptions: { title: t('toast.canvasSavedGallery') },
},
metadata: {
_canvas_objects: parseify(state.canvas.layerState.objects),
},
})
);
},

View File

@@ -16,6 +16,7 @@ import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
import { isImageOutput } from 'features/nodes/types/common';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { isEqual } from 'lodash-es';
import { getImageDTO } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types';
@@ -47,8 +48,10 @@ const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batc
export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => {
startAppListening({
matcher,
effect: async (action, { dispatch, getState, cancelActiveListeners, delay, take, signal }) => {
effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take, signal }) => {
const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId;
const state = getState();
const originalState = getOriginalState();
// Cancel any in-progress instances of this listener
cancelActiveListeners();
@@ -57,21 +60,33 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
// Delay before starting actual work
await delay(DEBOUNCE_MS);
// Double-check that we are still eligible for processing
const state = getState();
const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId);
// If we have no image or there is no processor config, bail
if (!layer) {
return;
}
// We should only process if the processor settings or image have changed
const originalLayer = originalState.controlLayers.present.layers
.filter(isControlAdapterLayer)
.find((l) => l.id === layerId);
const originalImage = originalLayer?.controlAdapter.image;
const originalConfig = originalLayer?.controlAdapter.processorConfig;
const image = layer.controlAdapter.image;
const config = layer.controlAdapter.processorConfig;
if (isEqual(config, originalConfig) && isEqual(image, originalImage)) {
// Neither config nor image have changed, we can bail
return;
}
if (!image || !config) {
// The user has reset the image or config, so we should clear the processed image
// - If we have no image, we have nothing to process
// - If we have no processor config, we have nothing to process
// Clear the processed image and bail
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
return;
}
// At this point, the user has stopped fiddling with the processor settings and there is a processor selected.
@@ -81,8 +96,8 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId);
}
// @ts-expect-error: TS isn't able to narrow the typing of buildNode and `config` will error...
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config);
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never);
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {
@@ -118,8 +133,8 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) &&
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
action.payload.data.source_node_id === processorNode.id
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.invocation_source_id === processorNode.id
);
// We still have to check the output type

View File

@@ -69,8 +69,8 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) &&
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
action.payload.data.source_node_id === nodeId
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.invocation_source_id === nodeId
);
// We still have to check the output type

View File

@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions';
@@ -12,12 +12,13 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
actionCreator: socketGeneratorProgress,
effect: (action) => {
log.trace(action.payload, `Generator progress`);
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
const { invocation_source_id, step, total_steps, progress_image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps;
nes.progressImage = progress_image ?? null;
upsertExecutionState(nes.nodeId, nes);
}
},
});

View File

@@ -29,12 +29,12 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
actionCreator: socketInvocationComplete,
effect: async (action, { dispatch, getState }) => {
const { data } = action.payload;
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
log.debug({ data: parseify(data) }, `Invocation complete (${data.invocation_type})`);
const { result, node, queue_batch_id, source_node_id } = data;
const { result, invocation_source_id } = data;
// This complete event has an associated image output
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
const { image_name } = result.image;
if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation_type)) {
const { image_name } = data.result.image;
const { canvas, gallery } = getState();
// This populates the `getImageDTO` cache
@@ -48,7 +48,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
imageDTORequest.unsubscribe();
// Add canvas images to the staging area
if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) {
if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) {
dispatch(addImageToStagingArea(imageDTO));
}
@@ -114,7 +114,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
}
}
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {

View File

@@ -11,9 +11,9 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
startAppListening({
actionCreator: socketInvocationError,
effect: (action) => {
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
const { source_node_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
log.error(action.payload, `Invocation error (${action.payload.data.invocation_type})`);
const { invocation_source_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.FAILED;
nes.error = action.payload.data.error;

View File

@@ -1,14 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketInvocationRetrievalError } from 'services/events/actions';
const log = logger('socketio');
export const addInvocationRetrievalErrorEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketInvocationRetrievalError,
effect: (action) => {
log.error(action.payload, `Invocation retrieval error (${action.payload.data.graph_execution_state_id})`);
},
});
};

View File

@@ -11,9 +11,9 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
startAppListening({
actionCreator: socketInvocationStarted,
effect: (action) => {
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`);
const { source_node_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
log.debug(action.payload, `Invocation started (${action.payload.data.invocation_type})`);
const { invocation_source_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
upsertExecutionState(nes.nodeId, nes);

View File

@@ -3,14 +3,14 @@ import { api, LIST_TAG } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import {
socketModelInstallCancelled,
socketModelInstallCompleted,
socketModelInstallDownloading,
socketModelInstallComplete,
socketModelInstallDownloadProgress,
socketModelInstallError,
} from 'services/events/actions';
export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketModelInstallDownloading,
actionCreator: socketModelInstallDownloadProgress,
effect: async (action, { dispatch }) => {
const { bytes, total_bytes, id } = action.payload.data;
@@ -29,7 +29,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
});
startAppListening({
actionCreator: socketModelInstallCompleted,
actionCreator: socketModelInstallComplete,
effect: (action, { dispatch }) => {
const { id } = action.payload.data;

View File

@@ -1,6 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketModelLoadCompleted, socketModelLoadStarted } from 'services/events/actions';
import { socketModelLoadComplete, socketModelLoadStarted } from 'services/events/actions';
const log = logger('socketio');
@@ -8,10 +8,11 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
startAppListening({
actionCreator: socketModelLoadStarted,
effect: (action) => {
const { model_config, submodel_type } = action.payload.data;
const { name, base, type } = model_config;
const { config, submodel_type } = action.payload.data;
const { name, base, type } = config;
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
}
@@ -23,10 +24,10 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
});
startAppListening({
actionCreator: socketModelLoadCompleted,
actionCreator: socketModelLoadComplete,
effect: (action) => {
const { model_config, submodel_type } = action.payload.data;
const { name, base, type } = model_config;
const { config, submodel_type } = action.payload.data;
const { name, base, type } = config;
const extras: string[] = [base, type];
if (submodel_type) {

View File

@@ -14,16 +14,23 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
actionCreator: socketQueueItemStatusChanged,
effect: async (action, { dispatch }) => {
// we've got new status for the queue item, batch and queue
const { queue_item, batch_status, queue_status } = action.payload.data;
const { item_id, status, started_at, updated_at, error, completed_at, batch_status, queue_status } =
action.payload.data;
log.debug(action.payload, `Queue item ${queue_item.item_id} status updated: ${queue_item.status}`);
log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`);
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
dispatch(
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
queueItemsAdapter.updateOne(draft, {
id: String(queue_item.item_id),
changes: queue_item,
id: String(item_id),
changes: {
status,
started_at,
updated_at: updated_at ?? undefined,
error,
completed_at: completed_at ?? undefined,
},
});
})
);
@@ -43,23 +50,18 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status)
);
// Update the queue item status (this is the full queue item, including the session)
dispatch(
queueApi.util.updateQueryData('getQueueItem', queue_item.item_id, (draft) => {
if (!draft) {
return;
}
Object.assign(draft, queue_item);
})
);
// Invalidate caches for things we cannot update
// TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again
dispatch(
queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus'])
queueApi.util.invalidateTags([
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'InvocationCacheStatus',
{ type: 'SessionQueueItem', id: item_id },
])
);
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
if (['in_progress'].includes(action.payload.data.status)) {
forEach($nodeExecutionStates.get(), (nes) => {
if (!nes) {
return;

View File

@@ -1,14 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketSessionRetrievalError } from 'services/events/actions';
const log = logger('socketio');
export const addSessionRetrievalErrorEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketSessionRetrievalError,
effect: (action) => {
log.error(action.payload, `Session retrieval error (${action.payload.data.graph_execution_state_id})`);
},
});
};

View File

@@ -1,14 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketSubscribedSession } from 'services/events/actions';
const log = logger('socketio');
export const addSocketSubscribedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketSubscribedSession,
effect: (action) => {
log.debug(action.payload, 'Subscribed');
},
});
};

View File

@@ -1,13 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketUnsubscribedSession } from 'services/events/actions';
const log = logger('socketio');
export const addSocketUnsubscribedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketUnsubscribedSession,
effect: (action) => {
log.debug(action.payload, 'Unsubscribed');
},
});
};

View File

@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { updateAllNodesRequested } from 'features/nodes/store/actions';
import { $templates, nodeReplaced } from 'features/nodes/store/nodesSlice';
import { $templates, nodesChanged } from 'features/nodes/store/nodesSlice';
import { NodeUpdateError } from 'features/nodes/types/error';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
@@ -31,7 +31,12 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
}
try {
const updatedNode = updateNode(node, template);
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
dispatch(
nodesChanged([
{ type: 'remove', id: updatedNode.id },
{ type: 'add', item: updatedNode },
])
);
} catch (e) {
if (e instanceof NodeUpdateError) {
unableToUpdateCount++;

View File

@@ -4,31 +4,49 @@ import { parseify } from 'common/util/serialize';
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
import { $templates } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import type { Templates } from 'features/nodes/store/types';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types';
import { z } from 'zod';
import { fromZodError } from 'zod-validation-error';
const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => {
if (data.workflow) {
// Prefer to load the workflow if it's available - it has more information
const parsed = JSON.parse(data.workflow);
return validateWorkflow(parsed, templates);
} else if (data.graph) {
// Else we fall back on the graph, using the graphToWorkflow function to convert and do layout
const parsed = JSON.parse(data.graph);
const workflow = graphToWorkflow(parsed as NonNullableGraph, true);
return validateWorkflow(workflow, templates);
} else {
throw new Error('No workflow or graph provided');
}
};
export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch }) => {
const log = logger('nodes');
const { workflow, asCopy } = action.payload;
const { data, asCopy } = action.payload;
const nodeTemplates = $templates.get();
try {
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
const { workflow, warnings } = getWorkflow(data, nodeTemplates);
if (asCopy) {
// If we're loading a copy, we need to remove the ID so that the backend will create a new workflow
delete validatedWorkflow.id;
delete workflow.id;
}
dispatch(workflowLoaded(validatedWorkflow));
dispatch(workflowLoaded(workflow));
if (!warnings.length) {
dispatch(
addToast(

View File

@@ -137,7 +137,7 @@ const createSelector = (templates: Templates) =>
if (l.controlAdapter.type === 't2i_adapter') {
const multiple = model?.base === 'sdxl' ? 32 : 64;
if (size.width % multiple !== 0 || size.height % multiple !== 0) {
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions'));
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions', { multiple }));
}
}
}

View File

@@ -613,7 +613,7 @@ export const canvasSlice = createSlice({
state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id);
}
const queueItemStatus = action.payload.data.queue_item.status;
const queueItemStatus = action.payload.data.status;
if (queueItemStatus === 'canceled' || queueItemStatus === 'failed') {
resetStagingAreaIfEmpty(state);
}

View File

@@ -11,10 +11,12 @@ import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/actions';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
import { size } from 'lodash-es';
import { memo, useCallback } from 'react';
import { flushSync } from 'react-dom';
import { useTranslation } from 'react-i18next';
@@ -48,6 +50,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const isCanvasEnabled = useFeatureStatus('canvas');
const customStarUi = useStore($customStarUI);
const { downloadImage } = useDownloadImage();
const templates = useStore($templates);
const { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, isLoadingMetadata } =
useImageActions(imageDTO?.image_name);
@@ -133,7 +136,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
<MenuItem
icon={getAndLoadEmbeddedWorkflowResult.isLoading ? <SpinnerIcon /> : <PiFlowArrowBold />}
onClickCapture={handleLoadWorkflow}
isDisabled={!imageDTO.has_workflow}
isDisabled={!imageDTO.has_workflow || !size(templates)}
>
{t('nodes.loadWorkflow')}
</MenuItem>

View File

@@ -0,0 +1,34 @@
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebouncedImageWorkflow } from 'services/api/hooks/useDebouncedImageWorkflow';
import type { ImageDTO } from 'services/api/types';
import DataViewer from './DataViewer';
type Props = {
image: ImageDTO;
};
const ImageMetadataGraphTabContent = ({ image }: Props) => {
const { t } = useTranslation();
const { currentData } = useDebouncedImageWorkflow(image);
const graph = useMemo(() => {
if (currentData?.graph) {
try {
return JSON.parse(currentData.graph);
} catch {
return null;
}
}
return null;
}, [currentData]);
if (!graph) {
return <IAINoContentFallback label={t('nodes.noGraph')} />;
}
return <DataViewer data={graph} label={t('nodes.graph')} />;
};
export default memo(ImageMetadataGraphTabContent);

View File

@@ -1,6 +1,7 @@
import { ExternalLink, Flex, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import ImageMetadataGraphTabContent from 'features/gallery/components/ImageMetadataViewer/ImageMetadataGraphTabContent';
import { useMetadataItem } from 'features/metadata/hooks/useMetadataItem';
import { handlers } from 'features/metadata/util/handlers';
import { memo } from 'react';
@@ -52,6 +53,7 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
<Tab>{t('metadata.metadata')}</Tab>
<Tab>{t('metadata.imageDetails')}</Tab>
<Tab>{t('metadata.workflow')}</Tab>
<Tab>{t('nodes.graph')}</Tab>
</TabList>
<TabPanels>
@@ -81,6 +83,9 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
<TabPanel>
<ImageMetadataWorkflowTabContent image={image} />
</TabPanel>
<TabPanel>
<ImageMetadataGraphTabContent image={image} />
</TabPanel>
</TabPanels>
</Tabs>
</Flex>

View File

@@ -1,5 +1,5 @@
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebouncedImageWorkflow } from 'services/api/hooks/useDebouncedImageWorkflow';
import type { ImageDTO } from 'services/api/types';
@@ -12,7 +12,17 @@ type Props = {
const ImageMetadataWorkflowTabContent = ({ image }: Props) => {
const { t } = useTranslation();
const { workflow } = useDebouncedImageWorkflow(image);
const { currentData } = useDebouncedImageWorkflow(image);
const workflow = useMemo(() => {
if (currentData?.workflow) {
try {
return JSON.parse(currentData.workflow);
} catch {
return null;
}
}
return null;
}, [currentData]);
if (!workflow) {
return <IAINoContentFallback label={t('nodes.noWorkflow')} />;

View File

@@ -1,4 +1,5 @@
import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/query';
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
@@ -12,12 +13,14 @@ import { sentImageToImg2Img } from 'features/gallery/store/actions';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { parseAndRecallImageDimensions } from 'features/metadata/util/handlers';
import { $templates } from 'features/nodes/store/nodesSlice';
import ParamUpscalePopover from 'features/parameters/components/Upscale/ParamUpscaleSettings';
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
import { size } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
@@ -48,7 +51,7 @@ const CurrentImageButtons = () => {
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
const selection = useAppSelector((s) => s.gallery.selection);
const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons);
const templates = useStore($templates);
const isUpscalingEnabled = useFeatureStatus('upscaling');
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const { t } = useTranslation();
@@ -143,7 +146,7 @@ const CurrentImageButtons = () => {
icon={<PiFlowArrowBold />}
tooltip={`${t('nodes.loadWorkflow')} (W)`}
aria-label={`${t('nodes.loadWorkflow')} (W)`}
isDisabled={!imageDTO?.has_workflow}
isDisabled={!imageDTO?.has_workflow || !size(templates)}
onClick={handleLoadWorkflow}
isLoading={getAndLoadEmbeddedWorkflowResult.isLoading}
/>

View File

@@ -9,27 +9,29 @@ import type { SelectInstance } from 'chakra-react-select';
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
import {
$cursorPos,
$edgePendingUpdate,
$isAddNodePopoverOpen,
$pendingConnection,
$templates,
closeAddNodePopover,
connectionMade,
nodeAdded,
edgesChanged,
nodesChanged,
openAddNodePopover,
} from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { filter, map, memoize, some } from 'lodash-es';
import type { KeyboardEventHandler } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { flushSync } from 'react-dom';
import { useHotkeys } from 'react-hotkeys-hook';
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
import { useTranslation } from 'react-i18next';
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
import { assert } from 'tsafe';
import type { EdgeChange, NodeChange } from 'reactflow';
const createRegex = memoize(
(inputValue: string) =>
@@ -69,17 +71,19 @@ const AddNodePopover = () => {
const filteredTemplates = useMemo(() => {
// If we have a connection in progress, we need to filter the node choices
const templatesArray = map(templates);
if (!pendingConnection) {
return map(templates);
return templatesArray;
}
return filter(templates, (template) => {
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind;
const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs;
return some(fields, (field) => {
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
return validateSourceAndTargetTypes(sourceType, targetType);
const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs;
return some(candidateFields, (field) => {
const sourceType =
pendingConnection.handleType === 'source' ? field.type : pendingConnection.fieldTemplate.type;
const targetType =
pendingConnection.handleType === 'target' ? field.type : pendingConnection.fieldTemplate.type;
return validateConnectionTypes(sourceType, targetType);
});
});
}, [templates, pendingConnection]);
@@ -129,11 +133,37 @@ const AddNodePopover = () => {
});
return null;
}
// Find a cozy spot for the node
const cursorPos = $cursorPos.get();
dispatch(nodeAdded({ node, cursorPos }));
const { nodes, edges } = store.getState().nodes.present;
node.position = findUnoccupiedPosition(nodes, cursorPos?.x ?? node.position.x, cursorPos?.y ?? node.position.y);
node.selected = true;
// Deselect all other nodes and edges
const nodeChanges: NodeChange[] = [{ type: 'add', item: node }];
const edgeChanges: EdgeChange[] = [];
nodes.forEach(({ id, selected }) => {
if (selected) {
nodeChanges.push({ type: 'select', id, selected: false });
}
});
edges.forEach(({ id, selected }) => {
if (selected) {
edgeChanges.push({ type: 'select', id, selected: false });
}
});
// Onwards!
if (nodeChanges.length > 0) {
dispatch(nodesChanged(nodeChanges));
}
if (edgeChanges.length > 0) {
dispatch(edgesChanged(edgeChanges));
}
return node;
},
[dispatch, buildInvocation, toaster, t]
[buildInvocation, store, dispatch, t, toaster]
);
const onChange = useCallback<ComboboxOnChange>(
@@ -145,12 +175,28 @@ const AddNodePopover = () => {
// Auto-connect an edge if we just added a node and have a pending connection
if (pendingConnection && isInvocationNode(node)) {
const template = templates[node.data.type];
assert(template, 'Template not found');
const edgePendingUpdate = $edgePendingUpdate.get();
const { handleType } = pendingConnection;
const source = handleType === 'source' ? pendingConnection.nodeId : node.id;
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
const target = handleType === 'target' ? pendingConnection.nodeId : node.id;
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
const { nodes, edges } = store.getState().nodes.present;
const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template);
const connection = getFirstValidConnection(
source,
sourceHandle,
target,
targetHandle,
nodes,
edges,
templates,
edgePendingUpdate
);
if (connection) {
dispatch(connectionMade(connection));
const newEdge = connectionToEdge(connection);
dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
}
}
@@ -160,25 +206,24 @@ const AddNodePopover = () => {
);
const handleHotkeyOpen: HotkeyCallback = useCallback((e) => {
e.preventDefault();
openAddNodePopover();
flushSync(() => {
selectRef.current?.inputRef?.focus();
});
if (!$isAddNodePopoverOpen.get()) {
e.preventDefault();
openAddNodePopover();
flushSync(() => {
selectRef.current?.inputRef?.focus();
});
}
}, []);
const handleHotkeyClose: HotkeyCallback = useCallback(() => {
closeAddNodePopover();
}, []);
useHotkeys(['shift+a', 'space'], handleHotkeyOpen);
useHotkeys(['escape'], handleHotkeyClose);
const onKeyDown: KeyboardEventHandler = useCallback((e) => {
if (e.key === 'Escape') {
if ($isAddNodePopoverOpen.get()) {
closeAddNodePopover();
}
}, []);
useHotkeys(['shift+a', 'space'], handleHotkeyOpen);
useHotkeys(['escape'], handleHotkeyClose, { enableOnFormTags: ['TEXTAREA'] });
const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]);
return (
@@ -215,7 +260,6 @@ const AddNodePopover = () => {
filterOption={filterOption}
onChange={onChange}
onMenuClose={closeAddNodePopover}
onKeyDown={onKeyDown}
inputRef={inputRef}
closeMenuOnSelect={false}
/>

View File

@@ -1,6 +1,6 @@
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useConnection } from 'features/nodes/hooks/useConnection';
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState';
@@ -8,38 +8,35 @@ import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
import {
$cursorPos,
$didUpdateEdge,
$edgePendingUpdate,
$isAddNodePopoverOpen,
$isUpdatingEdge,
$lastEdgeUpdateMouseEvent,
$pendingConnection,
$viewport,
connectionMade,
edgeAdded,
edgeDeleted,
edgesChanged,
edgesDeleted,
nodesChanged,
nodesDeleted,
redo,
selectedAll,
undo,
} from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import type { CSSProperties, MouseEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import type {
EdgeChange,
NodeChange,
OnEdgesChange,
OnEdgesDelete,
OnEdgeUpdateFunc,
OnInit,
OnMoveEnd,
OnNodesChange,
OnNodesDelete,
ProOptions,
ReactFlowProps,
ReactFlowState,
} from 'reactflow';
import { Background, ReactFlow, useStore as useReactFlowStore } from 'reactflow';
import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from 'reactflow';
import CustomConnectionLine from './connectionLines/CustomConnectionLine';
import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge';
@@ -48,8 +45,6 @@ import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode';
import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper';
import NotesNode from './nodes/Notes/NotesNode';
const DELETE_KEYS = ['Delete', 'Backspace'];
const edgeTypes = {
collapsed: InvocationCollapsedEdge,
default: InvocationDefaultEdge,
@@ -81,6 +76,8 @@ export const Flow = memo(() => {
const flowWrapper = useRef<HTMLDivElement>(null);
const isValidConnection = useIsValidConnection();
const cancelConnection = useReactFlowStore(selectCancelConnection);
const updateNodeInternals = useUpdateNodeInternals();
const store = useAppStore();
useWorkflowWatcher();
useSyncExecutionState();
const [borderRadius] = useToken('radii', ['base']);
@@ -93,29 +90,17 @@ export const Flow = memo(() => {
);
const onNodesChange: OnNodesChange = useCallback(
(changes) => {
dispatch(nodesChanged(changes));
(nodeChanges) => {
dispatch(nodesChanged(nodeChanges));
},
[dispatch]
);
const onEdgesChange: OnEdgesChange = useCallback(
(changes) => {
dispatch(edgesChanged(changes));
},
[dispatch]
);
const onEdgesDelete: OnEdgesDelete = useCallback(
(edges) => {
dispatch(edgesDeleted(edges));
},
[dispatch]
);
const onNodesDelete: OnNodesDelete = useCallback(
(nodes) => {
dispatch(nodesDeleted(nodes));
if (changes.length > 0) {
dispatch(edgesChanged(changes));
}
},
[dispatch]
);
@@ -157,45 +142,50 @@ export const Flow = memo(() => {
* where the edge is deleted if you click it accidentally).
*/
// We have a ref for cursor position, but it is the *projected* cursor position.
// Easiest to just keep track of the last mouse event for this particular feature
const edgeUpdateMouseEvent = useRef<MouseEvent>();
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> = useCallback(
(e, edge, _handleType) => {
$isUpdatingEdge.set(true);
// update mouse event
edgeUpdateMouseEvent.current = e;
// always delete the edge when starting an updated
dispatch(edgeDeleted(edge.id));
},
[dispatch]
);
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> = useCallback((e, edge, _handleType) => {
$edgePendingUpdate.set(edge);
$didUpdateEdge.set(false);
$lastEdgeUpdateMouseEvent.set(e);
}, []);
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
(_oldEdge, newConnection) => {
// Because we deleted the edge when the update started, we must create a new edge from the connection
dispatch(connectionMade(newConnection));
(oldEdge, newConnection) => {
// This event is fired when an edge update is successful
$didUpdateEdge.set(true);
// When an edge update is successful, we need to delete the old edge and create a new one
const newEdge = connectionToEdge(newConnection);
dispatch(
edgesChanged([
{ type: 'remove', id: oldEdge.id },
{ type: 'add', item: newEdge },
])
);
// Because we shift the position of handles depending on whether a field is connected or not, we must use
// updateNodeInternals to tell reactflow to recalculate the positions of the handles
updateNodeInternals([oldEdge.source, oldEdge.target, newEdge.source, newEdge.target]);
},
[dispatch]
[dispatch, updateNodeInternals]
);
const onEdgeUpdateEnd: NonNullable<ReactFlowProps['onEdgeUpdateEnd']> = useCallback(
(e, edge, _handleType) => {
$isUpdatingEdge.set(false);
$pendingConnection.set(null);
// Handle the case where user begins a drag but didn't move the cursor - we deleted the edge when starting
// the edge update - we need to add it back
if (
// ignore touch events
!('touches' in e) &&
edgeUpdateMouseEvent.current?.clientX === e.clientX &&
edgeUpdateMouseEvent.current?.clientY === e.clientY
) {
dispatch(edgeAdded(edge));
const didUpdateEdge = $didUpdateEdge.get();
// Fall back to a reasonable default event
const lastEvent = $lastEdgeUpdateMouseEvent.get() ?? { clientX: 0, clientY: 0 };
// We have to narrow this event down to MouseEvents - could be TouchEvent
const didMouseMove =
!('touches' in e) && Math.hypot(e.clientX - lastEvent.clientX, e.clientY - lastEvent.clientY) > 5;
// If we got this far and did not successfully update an edge, and the mouse moved away from the handle,
// the user probably intended to delete the edge
if (!didUpdateEdge && didMouseMove) {
dispatch(edgesChanged([{ type: 'remove', id: edge.id }]));
}
// reset mouse event
edgeUpdateMouseEvent.current = undefined;
$edgePendingUpdate.set(null);
$didUpdateEdge.set(false);
$pendingConnection.set(null);
$lastEdgeUpdateMouseEvent.set(null);
},
[dispatch]
);
@@ -216,9 +206,27 @@ export const Flow = memo(() => {
const onSelectAllHotkey = useCallback(
(e: KeyboardEvent) => {
e.preventDefault();
dispatch(selectedAll());
const { nodes, edges } = store.getState().nodes.present;
const nodeChanges: NodeChange[] = [];
const edgeChanges: EdgeChange[] = [];
nodes.forEach(({ id, selected }) => {
if (!selected) {
nodeChanges.push({ type: 'select', id, selected: true });
}
});
edges.forEach(({ id, selected }) => {
if (!selected) {
edgeChanges.push({ type: 'select', id, selected: true });
}
});
if (nodeChanges.length > 0) {
dispatch(nodesChanged(nodeChanges));
}
if (edgeChanges.length > 0) {
dispatch(edgesChanged(edgeChanges));
}
},
[dispatch]
[dispatch, store]
);
useHotkeys(['Ctrl+a', 'Meta+a'], onSelectAllHotkey);
@@ -255,12 +263,37 @@ export const Flow = memo(() => {
useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey);
const onEscapeHotkey = useCallback(() => {
$pendingConnection.set(null);
$isAddNodePopoverOpen.set(false);
cancelConnection();
if (!$edgePendingUpdate.get()) {
$pendingConnection.set(null);
$isAddNodePopoverOpen.set(false);
cancelConnection();
}
}, [cancelConnection]);
useHotkeys('esc', onEscapeHotkey);
const onDeleteHotkey = useCallback(() => {
const { nodes, edges } = store.getState().nodes.present;
const nodeChanges: NodeChange[] = [];
const edgeChanges: EdgeChange[] = [];
nodes
.filter((n) => n.selected)
.forEach(({ id }) => {
nodeChanges.push({ type: 'remove', id });
});
edges
.filter((e) => e.selected)
.forEach(({ id }) => {
edgeChanges.push({ type: 'remove', id });
});
if (nodeChanges.length > 0) {
dispatch(nodesChanged(nodeChanges));
}
if (edgeChanges.length > 0) {
dispatch(edgesChanged(edgeChanges));
}
}, [dispatch, store]);
useHotkeys(['delete', 'backspace'], onDeleteHotkey);
return (
<ReactFlow
id="workflow-editor"
@@ -274,11 +307,9 @@ export const Flow = memo(() => {
onMouseMove={onMouseMove}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onEdgesDelete={onEdgesDelete}
onEdgeUpdate={onEdgeUpdate}
onEdgeUpdateStart={onEdgeUpdateStart}
onEdgeUpdateEnd={onEdgeUpdateEnd}
onNodesDelete={onNodesDelete}
onConnectStart={onConnectStart}
onConnect={onConnect}
onConnectEnd={onConnectEnd}
@@ -292,9 +323,10 @@ export const Flow = memo(() => {
proOptions={proOptions}
style={flowStyles}
onPaneClick={handlePaneClick}
deleteKeyCode={DELETE_KEYS}
deleteKeyCode={null}
selectionMode={selectionMode}
elevateEdgesOnSelect
nodeDragThreshold={1}
>
<Background />
</ReactFlow>

View File

@@ -2,13 +2,13 @@ import { Badge, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { getEdgeStyles } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import { makeEdgeSelector } from 'features/nodes/components/flow/edges/util/makeEdgeSelector';
import { $templates } from 'features/nodes/store/nodesSlice';
import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow';
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
import { makeEdgeSelector } from './util/makeEdgeSelector';
const InvocationCollapsedEdge = ({
sourceX,
sourceY,
@@ -18,19 +18,19 @@ const InvocationCollapsedEdge = ({
targetPosition,
markerEnd,
data,
selected,
selected = false,
source,
target,
sourceHandleId,
target,
targetHandleId,
}: EdgeProps<{ count: number }>) => {
const templates = useStore($templates);
const selector = useMemo(
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected),
[templates, selected, source, sourceHandleId, target, targetHandleId]
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId),
[templates, source, sourceHandleId, target, targetHandleId]
);
const { isSelected, shouldAnimate } = useAppSelector(selector);
const { shouldAnimateEdges, areConnectedNodesSelected } = useAppSelector(selector);
const [edgePath, labelX, labelY] = getBezierPath({
sourceX,
@@ -44,14 +44,8 @@ const InvocationCollapsedEdge = ({
const { base500 } = useChakraThemeTokens();
const edgeStyles = useMemo(
() => ({
strokeWidth: isSelected ? 3 : 2,
stroke: base500,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined,
strokeDasharray: shouldAnimate ? 5 : 'none',
}),
[base500, isSelected, shouldAnimate]
() => getEdgeStyles(base500, selected, shouldAnimateEdges, areConnectedNodesSelected),
[areConnectedNodesSelected, base500, selected, shouldAnimateEdges]
);
return (
@@ -60,11 +54,15 @@ const InvocationCollapsedEdge = ({
{data?.count && data.count > 1 && (
<EdgeLabelRenderer>
<Flex
data-testid="asdfasdfasdf"
position="absolute"
transform={`translate(-50%, -50%) translate(${labelX}px,${labelY}px)`}
className="nodrag nopan"
// Unfortunately edge labels do not get the same zIndex treatment as edges do, so we need to manage this ourselves
// See: https://github.com/xyflow/xyflow/issues/3658
zIndex={1001}
>
<Badge variant="solid" bg="base.500" opacity={isSelected ? 0.8 : 0.5} boxShadow="base">
<Badge variant="solid" bg="base.500" opacity={selected ? 0.8 : 0.5} boxShadow="base">
{data.count}
</Badge>
</Flex>

View File

@@ -1,8 +1,8 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { getEdgeStyles } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import { $templates } from 'features/nodes/store/nodesSlice';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow';
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
@@ -17,7 +17,7 @@ const InvocationDefaultEdge = ({
sourcePosition,
targetPosition,
markerEnd,
selected,
selected = false,
source,
target,
sourceHandleId,
@@ -25,11 +25,11 @@ const InvocationDefaultEdge = ({
}: EdgeProps) => {
const templates = useStore($templates);
const selector = useMemo(
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected),
[templates, source, sourceHandleId, target, targetHandleId, selected]
() => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId),
[templates, source, sourceHandleId, target, targetHandleId]
);
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
const { shouldAnimateEdges, areConnectedNodesSelected, stroke, label } = useAppSelector(selector);
const shouldShowEdgeLabels = useAppSelector((s) => s.workflowSettings.shouldShowEdgeLabels);
const [edgePath, labelX, labelY] = getBezierPath({
@@ -41,15 +41,9 @@ const InvocationDefaultEdge = ({
targetPosition,
});
const edgeStyles = useMemo<CSSProperties>(
() => ({
strokeWidth: isSelected ? 3 : 2,
stroke,
opacity: isSelected ? 0.8 : 0.5,
animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined,
strokeDasharray: shouldAnimate ? 5 : 'none',
}),
[isSelected, shouldAnimate, stroke]
const edgeStyles = useMemo(
() => getEdgeStyles(stroke, selected, shouldAnimateEdges, areConnectedNodesSelected),
[areConnectedNodesSelected, stroke, selected, shouldAnimateEdges]
);
return (
@@ -65,13 +59,13 @@ const InvocationDefaultEdge = ({
bg="base.800"
borderRadius="base"
borderWidth={1}
borderColor={isSelected ? 'undefined' : 'transparent'}
opacity={isSelected ? 1 : 0.5}
borderColor={selected ? 'undefined' : 'transparent'}
opacity={selected ? 1 : 0.5}
py={1}
px={3}
shadow="md"
>
<Text size="sm" fontWeight="semibold" color={isSelected ? 'base.100' : 'base.300'}>
<Text size="sm" fontWeight="semibold" color={selected ? 'base.100' : 'base.300'}>
{label}
</Text>
</Flex>

View File

@@ -1,6 +1,7 @@
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { FIELD_COLORS } from 'features/nodes/types/constants';
import type { FieldType } from 'features/nodes/types/field';
import type { CSSProperties } from 'react';
export const getFieldColor = (fieldType: FieldType | null): string => {
if (!fieldType) {
@@ -10,3 +11,16 @@ export const getFieldColor = (fieldType: FieldType | null): string => {
return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500');
};
export const getEdgeStyles = (
stroke: string,
selected: boolean,
shouldAnimateEdges: boolean,
areConnectedNodesSelected: boolean
): CSSProperties => ({
strokeWidth: 3,
stroke,
opacity: selected ? 1 : 0.5,
animation: shouldAnimateEdges ? 'dashdraw 0.5s linear infinite' : undefined,
strokeDasharray: selected || areConnectedNodesSelected ? 5 : 'none',
});

View File

@@ -1,5 +1,6 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { deepClone } from 'common/util/deepClone';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { Templates } from 'features/nodes/store/types';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
@@ -8,8 +9,8 @@ import { isInvocationNode } from 'features/nodes/types/invocation';
import { getFieldColor } from './getEdgeColor';
const defaultReturnValue = {
isSelected: false,
shouldAnimate: false,
areConnectedNodesSelected: false,
shouldAnimateEdges: false,
stroke: colorTokenToCssVar('base.500'),
label: '',
};
@@ -19,21 +20,27 @@ export const makeEdgeSelector = (
source: string,
sourceHandleId: string | null | undefined,
target: string,
targetHandleId: string | null | undefined,
selected?: boolean
targetHandleId: string | null | undefined
) =>
createMemoizedSelector(
selectNodesSlice,
selectWorkflowSettingsSlice,
(nodes, workflowSettings): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
(
nodes,
workflowSettings
): { areConnectedNodesSelected: boolean; shouldAnimateEdges: boolean; stroke: string; label: string } => {
const { shouldAnimateEdges, shouldColorEdges } = workflowSettings;
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
const returnValue = deepClone(defaultReturnValue);
returnValue.shouldAnimateEdges = shouldAnimateEdges;
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
returnValue.areConnectedNodesSelected = Boolean(sourceNode?.selected || targetNode?.selected);
if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
return defaultReturnValue;
return returnValue;
}
const sourceNodeTemplate = templates[sourceNode.data.type];
@@ -42,16 +49,10 @@ export const makeEdgeSelector = (
const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId];
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
const stroke =
sourceType && workflowSettings.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
returnValue.stroke = sourceType && shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
returnValue.label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
return {
isSelected,
shouldAnimate: workflowSettings.shouldAnimateEdges && isSelected,
stroke,
label,
};
return returnValue;
}
);

View File

@@ -0,0 +1,20 @@
import { useDoesFieldExist } from 'features/nodes/hooks/useDoesFieldExist';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
type Props = PropsWithChildren<{
nodeId: string;
fieldName?: string;
}>;
export const MissingFallback = memo((props: Props) => {
// We must be careful here to avoid race conditions where a deleted node is still referenced as an exposed field
const exists = useDoesFieldExist(props.nodeId, props.fieldName);
if (!exists) {
return null;
}
return props.children;
});
MissingFallback.displayName = 'MissingFallback';

View File

@@ -25,10 +25,11 @@ interface Props {
kind: 'inputs' | 'outputs';
isMissingInput?: boolean;
withTooltip?: boolean;
shouldDim?: boolean;
}
const EditableFieldTitle = forwardRef((props: Props, ref) => {
const { nodeId, fieldName, kind, isMissingInput = false, withTooltip = false } = props;
const { nodeId, fieldName, kind, isMissingInput = false, withTooltip = false, shouldDim = false } = props;
const label = useFieldLabel(nodeId, fieldName);
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind);
const { t } = useTranslation();
@@ -39,13 +40,11 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
const handleSubmit = useCallback(
async (newTitleRaw: string) => {
const newTitle = newTitleRaw.trim();
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
return;
}
setLocalTitle(newTitle || fieldTemplateTitle || t('nodes.unknownField'));
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
const finalTitle = newTitle || fieldTemplateTitle || t('nodes.unknownField');
setLocalTitle(finalTitle);
dispatch(fieldLabelChanged({ nodeId, fieldName, label: finalTitle }));
},
[label, fieldTemplateTitle, dispatch, nodeId, fieldName, t]
[fieldTemplateTitle, dispatch, nodeId, fieldName, t]
);
const handleChange = useCallback((newTitle: string) => {
@@ -80,6 +79,7 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
sx={editablePreviewStyles}
noOfLines={1}
color={isMissingInput ? 'error.300' : 'base.300'}
opacity={shouldDim ? 0.5 : 1}
/>
</Tooltip>
<EditableInput className="nodrag" sx={editableInputStyles} />

View File

@@ -2,10 +2,12 @@ import { Tooltip } from '@invoke-ai/ui-library';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import type { ValidationResult } from 'features/nodes/store/util/validateConnection';
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import { type FieldInputTemplate, type FieldOutputTemplate, isSingle } from 'features/nodes/types/field';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { HandleType } from 'reactflow';
import { Handle, Position } from 'reactflow';
@@ -14,11 +16,12 @@ type FieldHandleProps = {
handleType: HandleType;
isConnectionInProgress: boolean;
isConnectionStartField: boolean;
connectionError?: string;
validationResult: ValidationResult;
};
const FieldHandle = (props: FieldHandleProps) => {
const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, connectionError } = props;
const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, validationResult } = props;
const { t } = useTranslation();
const { name } = fieldTemplate;
const type = fieldTemplate.type;
const fieldTypeName = useFieldTypeName(type);
@@ -26,11 +29,11 @@ const FieldHandle = (props: FieldHandleProps) => {
const isModelType = MODEL_TYPES.some((t) => t === type.name);
const color = getFieldColor(type);
const s: CSSProperties = {
backgroundColor: type.isCollection || type.isCollectionOrScalar ? colorTokenToCssVar('base.900') : color,
backgroundColor: !isSingle(type) ? colorTokenToCssVar('base.900') : color,
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: type.isCollection || type.isCollectionOrScalar ? 4 : 0,
borderWidth: !isSingle(type) ? 4 : 0,
borderStyle: 'solid',
borderColor: color,
borderRadius: isModelType ? 4 : '100%',
@@ -43,11 +46,11 @@ const FieldHandle = (props: FieldHandleProps) => {
s.insetInlineEnd = '-1rem';
}
if (isConnectionInProgress && !isConnectionStartField && connectionError) {
if (isConnectionInProgress && !isConnectionStartField && !validationResult.isValid) {
s.filter = 'opacity(0.4) grayscale(0.7)';
}
if (isConnectionInProgress && connectionError) {
if (isConnectionInProgress && !validationResult.isValid) {
if (isConnectionStartField) {
s.cursor = 'grab';
} else {
@@ -58,14 +61,14 @@ const FieldHandle = (props: FieldHandleProps) => {
}
return s;
}, [connectionError, handleType, isConnectionInProgress, isConnectionStartField, type]);
}, [handleType, isConnectionInProgress, isConnectionStartField, type, validationResult.isValid]);
const tooltip = useMemo(() => {
if (isConnectionInProgress && connectionError) {
return connectionError;
if (isConnectionInProgress && validationResult.messageTKey) {
return t(validationResult.messageTKey);
}
return fieldTypeName;
}, [connectionError, fieldTypeName, isConnectionInProgress]);
}, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]);
return (
<Tooltip

View File

@@ -24,7 +24,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
const [isHovered, setIsHovered] = useState(false);
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
useConnectionState({ nodeId, fieldName, kind: 'inputs' });
const isMissingInput = useMemo(() => {
@@ -79,6 +79,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
kind="inputs"
isMissingInput={isMissingInput}
withTooltip
shouldDim
/>
</FormControl>
@@ -87,7 +88,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
handleType="target"
isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField}
connectionError={connectionError}
validationResult={validationResult}
/>
</InputFieldWrapper>
);
@@ -125,7 +126,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
handleType="target"
isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField}
connectionError={connectionError}
validationResult={validationResult}
/>
)}
</InputFieldWrapper>

View File

@@ -1,3 +1,4 @@
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import {
@@ -23,6 +24,8 @@ import {
isLoRAModelFieldInputTemplate,
isMainModelFieldInputInstance,
isMainModelFieldInputTemplate,
isModelIdentifierFieldInputInstance,
isModelIdentifierFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
@@ -95,6 +98,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) {
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@@ -3,6 +3,7 @@ import { CSS } from '@dnd-kit/utilities';
import { Flex, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback';
import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { workflowExposedFieldRemoved } from 'features/nodes/store/workflowSlice';
@@ -20,7 +21,7 @@ type Props = {
fieldName: string;
};
const LinearViewField = ({ nodeId, fieldName }: Props) => {
const LinearViewFieldInternal = ({ nodeId, fieldName }: Props) => {
const dispatch = useAppDispatch();
const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName);
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
@@ -99,4 +100,12 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
);
};
const LinearViewField = ({ nodeId, fieldName }: Props) => {
return (
<MissingFallback nodeId={nodeId} fieldName={fieldName}>
<LinearViewFieldInternal nodeId={nodeId} fieldName={fieldName} />
</MissingFallback>
);
};
export default memo(LinearViewField);

View File

@@ -18,7 +18,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
const { t } = useTranslation();
const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName);
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
useConnectionState({ nodeId, fieldName, kind: 'outputs' });
if (!fieldTemplate) {
@@ -52,7 +52,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
handleType="source"
isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField}
connectionError={connectionError}
validationResult={validationResult}
/>
</OutputFieldWrapper>
);

View File

@@ -0,0 +1,66 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldModelIdentifierValueChanged } from 'features/nodes/store/nodesSlice';
import type { ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback, useMemo } from 'react';
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate>;
const ModelIdentifierFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetModelConfigsQuery();
const _onChange = useCallback(
(value: AnyModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldModelIdentifierValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const modelConfigs = useMemo(() => {
if (!data) {
return EMPTY_ARRAY;
}
return modelConfigsAdapterSelectors.selectAll(data);
}, [data]);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
groupByType: true,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(ModelIdentifierFieldInputComponent);

View File

@@ -1,14 +1,15 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Box, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice';
import { nodesChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from 'features/nodes/types/constants';
import { zNodeStatus } from 'features/nodes/types/invocation';
import type { MouseEvent, PropsWithChildren } from 'react';
import { memo, useCallback } from 'react';
import type { NodeChange } from 'reactflow';
type NodeWrapperProps = PropsWithChildren & {
nodeId: string;
@@ -18,6 +19,7 @@ type NodeWrapperProps = PropsWithChildren & {
const NodeWrapper = (props: NodeWrapperProps) => {
const { nodeId, width, children, selected } = props;
const store = useAppStore();
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
const executionState = useExecutionState(nodeId);
@@ -37,11 +39,20 @@ const NodeWrapper = (props: NodeWrapperProps) => {
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (!e.ctrlKey && !e.altKey && !e.metaKey && !e.shiftKey) {
dispatch(nodeExclusivelySelected(nodeId));
const { nodes } = store.getState().nodes.present;
const nodeChanges: NodeChange[] = [];
nodes.forEach(({ id, selected }) => {
if (selected !== (id === nodeId)) {
nodeChanges.push({ type: 'select', id, selected: id === nodeId });
}
});
if (nodeChanges.length > 0) {
dispatch(nodesChanged(nodeChanges));
}
}
onCloseGlobal();
},
[dispatch, onCloseGlobal, nodeId]
[onCloseGlobal, store, dispatch, nodeId]
);
return (

View File

@@ -7,8 +7,10 @@ import WorkflowInfoTooltipContent from './viewMode/WorkflowInfoTooltipContent';
import { WorkflowWarning } from './viewMode/WorkflowWarning';
export const WorkflowName = () => {
const { name, isTouched, mode } = useAppSelector((s) => s.workflow);
const { t } = useTranslation();
const name = useAppSelector((s) => s.workflow.name);
const isTouched = useAppSelector((s) => s.workflow.isTouched);
const mode = useAppSelector((s) => s.workflow.mode);
return (
<Flex gap="1" alignItems="center">

View File

@@ -1,6 +1,7 @@
import { Flex, FormLabel, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library';
import FieldTooltipContent from 'features/nodes/components/flow/nodes/Invocation/fields/FieldTooltipContent';
import InputFieldRenderer from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback';
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue';
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
@@ -14,7 +15,7 @@ type Props = {
fieldName: string;
};
const WorkflowField = ({ nodeId, fieldName }: Props) => {
const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => {
const label = useFieldLabel(nodeId, fieldName);
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'inputs');
const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName);
@@ -50,4 +51,12 @@ const WorkflowField = ({ nodeId, fieldName }: Props) => {
);
};
const WorkflowField = ({ nodeId, fieldName }: Props) => {
return (
<MissingFallback nodeId={nodeId} fieldName={fieldName}>
<WorkflowFieldInternal nodeId={nodeId} fieldName={fieldName} />
</MissingFallback>
);
};
export default memo(WorkflowField);

View File

@@ -6,10 +6,10 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import DndSortable from 'features/dnd/components/DndSortable';
import type { DragEndEvent } from 'features/dnd/types';
import LinearViewField from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField';
import LinearViewFieldInternal from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField';
import { selectWorkflowSlice, workflowExposedFieldsReordered } from 'features/nodes/store/workflowSlice';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
@@ -40,16 +40,18 @@ const WorkflowLinearTab = () => {
[dispatch, fields]
);
const items = useMemo(() => fields.map((field) => `${field.nodeId}.${field.fieldName}`), [fields]);
return (
<Box position="relative" w="full" h="full">
<ScrollableContent>
<DndSortable onDragEnd={handleDragEnd} items={fields.map((field) => `${field.nodeId}.${field.fieldName}`)}>
<DndSortable onDragEnd={handleDragEnd} items={items}>
<Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full">
{isLoading ? (
<IAINoContentFallback label={t('nodes.loadingNodes')} icon={null} />
) : fields.length ? (
fields.map(({ nodeId, fieldName }) => (
<LinearViewField key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
<LinearViewFieldInternal key={`${nodeId}.${fieldName}`} nodeId={nodeId} fieldName={fieldName} />
))
) : (
<IAINoContentFallback label={t('nodes.noFieldsLinearview')} icon={null} />

View File

@@ -1,7 +1,6 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { isSingleOrCollection } from 'features/nodes/types/field';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { keys, map } from 'lodash-es';
@@ -9,31 +8,20 @@ import { useMemo } from 'react';
export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => {
const template = useNodeTemplate(nodeId);
const selectConnectedFieldNames = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodesSlice) =>
nodesSlice.edges
.filter((e) => e.target === nodeId)
.map((e) => e.targetHandle)
.filter(Boolean)
),
[nodeId]
);
const connectedFieldNames = useAppSelector(selectConnectedFieldNames);
const fieldNames = useMemo(() => {
const fields = map(template.inputs).filter((field) => {
if (connectedFieldNames.includes(field.name)) {
return false;
}
return (
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
(['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) &&
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
});
return getSortedFilteredFieldNames(fields);
}, [connectedFieldNames, template.inputs]);
const _fieldNames = getSortedFilteredFieldNames(fields);
if (_fieldNames.length === 0) {
return EMPTY_ARRAY;
}
return _fieldNames;
}, [template.inputs]);
return fieldNames;
};

View File

@@ -2,58 +2,69 @@ import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/storeHooks';
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import {
$didUpdateEdge,
$edgePendingUpdate,
$isAddNodePopoverOpen,
$isUpdatingEdge,
$pendingConnection,
$templates,
connectionMade,
edgesChanged,
} from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import { useCallback, useMemo } from 'react';
import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
import type { EdgeChange, OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
import { useUpdateNodeInternals } from 'reactflow';
import { assert } from 'tsafe';
export const useConnection = () => {
const store = useAppStore();
const templates = useStore($templates);
const updateNodeInternals = useUpdateNodeInternals();
const onConnectStart = useCallback<OnConnectStart>(
(event, params) => {
(event, { nodeId, handleId, handleType }) => {
assert(nodeId && handleId && handleType, 'Invalid connection start event');
const nodes = store.getState().nodes.present.nodes;
const { nodeId, handleId, handleType } = params;
assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`);
const node = nodes.find((n) => n.id === nodeId);
assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`);
if (!node) {
return;
}
const template = templates[node.data.type];
assert(template, `Template not found for node type: ${node.data.type}`);
const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId];
assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`);
$pendingConnection.set({
node,
template,
fieldTemplate,
});
if (!template) {
return;
}
const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
const fieldTemplate = fieldTemplates[handleId];
if (!fieldTemplate) {
return;
}
$pendingConnection.set({ nodeId, handleId, handleType, fieldTemplate });
},
[store, templates]
);
const onConnect = useCallback<OnConnect>(
(connection) => {
const { dispatch } = store;
dispatch(connectionMade(connection));
const newEdge = connectionToEdge(connection);
dispatch(edgesChanged([{ type: 'add', item: newEdge }]));
updateNodeInternals([newEdge.source, newEdge.target]);
$pendingConnection.set(null);
},
[store]
[store, updateNodeInternals]
);
const onConnectEnd = useCallback<OnConnectEnd>(() => {
const { dispatch } = store;
const pendingConnection = $pendingConnection.get();
const isUpdatingEdge = $isUpdatingEdge.get();
const edgePendingUpdate = $edgePendingUpdate.get();
const mouseOverNodeId = $mouseOverNode.get();
// If we are in the middle of an edge update, and the mouse isn't over a node, we should just bail so the edge
// update logic can finish up
if (isUpdatingEdge && !mouseOverNodeId) {
if (edgePendingUpdate && !mouseOverNodeId) {
$pendingConnection.set(null);
return;
}
@@ -63,30 +74,41 @@ export const useConnection = () => {
}
const { nodes, edges } = store.getState().nodes.present;
if (mouseOverNodeId) {
const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId);
if (!candidateNode) {
// The mouse is over a non-invocation node - bail
return;
}
const candidateTemplate = templates[candidateNode.data.type];
assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`);
const { handleType } = pendingConnection;
const source = handleType === 'source' ? pendingConnection.nodeId : mouseOverNodeId;
const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null;
const target = handleType === 'target' ? pendingConnection.nodeId : mouseOverNodeId;
const targetHandle = handleType === 'target' ? pendingConnection.handleId : null;
const connection = getFirstValidConnection(
templates,
source,
sourceHandle,
target,
targetHandle,
nodes,
edges,
pendingConnection,
candidateNode,
candidateTemplate
templates,
edgePendingUpdate
);
if (connection) {
dispatch(connectionMade(connection));
const newEdge = connectionToEdge(connection);
const edgeChanges: EdgeChange[] = [{ type: 'add', item: newEdge }];
const nodesToUpdate = [newEdge.source, newEdge.target];
if (edgePendingUpdate) {
$didUpdateEdge.set(true);
edgeChanges.push({ type: 'remove', id: edgePendingUpdate.id });
nodesToUpdate.push(edgePendingUpdate.source, edgePendingUpdate.target);
}
dispatch(edgesChanged(edgeChanges));
updateNodeInternals(nodesToUpdate);
}
$pendingConnection.set(null);
} else {
// The mouse is not over a node - we should open the add node popover
$isAddNodePopoverOpen.set(true);
}
}, [store, templates]);
}, [store, templates, updateNodeInternals]);
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);
return api;

View File

@@ -1,7 +1,6 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { isSingleOrCollection } from 'features/nodes/types/field';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { keys, map } from 'lodash-es';
@@ -9,31 +8,22 @@ import { useMemo } from 'react';
export const useConnectionInputFieldNames = (nodeId: string): string[] => {
const template = useNodeTemplate(nodeId);
const selectConnectedFieldNames = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodesSlice) =>
nodesSlice.edges
.filter((e) => e.target === nodeId)
.map((e) => e.targetHandle)
.filter(Boolean)
),
[nodeId]
);
const connectedFieldNames = useAppSelector(selectConnectedFieldNames);
const fieldNames = useMemo(() => {
// get the visible fields
const fields = map(template.inputs).filter((field) => {
if (connectedFieldNames.includes(field.name)) {
return true;
}
return (
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
const fields = map(template.inputs).filter(
(field) =>
(field.input === 'connection' && !isSingleOrCollection(field.type)) ||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
});
);
const _fieldNames = getSortedFilteredFieldNames(fields);
if (_fieldNames.length === 0) {
return EMPTY_ARRAY;
}
return _fieldNames;
}, [template.inputs]);
return getSortedFilteredFieldNames(fields);
}, [connectedFieldNames, template.inputs]);
return fieldNames;
};

View File

@@ -1,12 +1,10 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { $edgePendingUpdate, $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector';
import { useMemo } from 'react';
import { useFieldType } from './useFieldType.ts';
type UseConnectionStateProps = {
nodeId: string;
fieldName: string;
@@ -16,7 +14,7 @@ type UseConnectionStateProps = {
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
const pendingConnection = useStore($pendingConnection);
const templates = useStore($templates);
const fieldType = useFieldType(nodeId, fieldName, kind);
const edgePendingUpdate = useStore($edgePendingUpdate);
const selectIsConnected = useMemo(
() =>
@@ -33,17 +31,9 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
[fieldName, kind, nodeId]
);
const selectConnectionError = useMemo(
() =>
makeConnectionErrorSelector(
templates,
pendingConnection,
nodeId,
fieldName,
kind === 'inputs' ? 'target' : 'source',
fieldType
),
[templates, pendingConnection, nodeId, fieldName, kind, fieldType]
const selectValidationResult = useMemo(
() => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'),
[templates, nodeId, fieldName, kind]
);
const isConnected = useAppSelector(selectIsConnected);
@@ -53,23 +43,23 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
return false;
}
return (
pendingConnection.node.id === nodeId &&
pendingConnection.fieldTemplate.name === fieldName &&
pendingConnection.nodeId === nodeId &&
pendingConnection.handleId === fieldName &&
pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind]
);
}, [fieldName, kind, nodeId, pendingConnection]);
const connectionError = useAppSelector(selectConnectionError);
const validationResult = useAppSelector((s) => selectValidationResult(s, pendingConnection, edgePendingUpdate));
const shouldDim = useMemo(
() => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField),
[connectionError, isConnectionInProgress, isConnectionStartField]
() => Boolean(isConnectionInProgress && !validationResult.isValid && !isConnectionStartField),
[validationResult, isConnectionInProgress, isConnectionStartField]
);
return {
isConnected,
isConnectionInProgress,
isConnectionStartField,
connectionError,
validationResult,
shouldDim,
};
};

View File

@@ -5,11 +5,13 @@ import {
$copiedNodes,
$cursorPos,
$edgesToCopiedNodes,
selectionPasted,
edgesChanged,
nodesChanged,
selectNodesSlice,
} from 'features/nodes/store/nodesSlice';
import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition';
import { isEqual, uniqWith } from 'lodash-es';
import type { EdgeChange, NodeChange } from 'reactflow';
import { v4 as uuidv4 } from 'uuid';
const copySelection = () => {
@@ -26,7 +28,7 @@ const copySelection = () => {
const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
const { getState, dispatch } = getStore();
const currentNodes = selectNodesSlice(getState()).nodes;
const { nodes, edges } = selectNodesSlice(getState());
const cursorPos = $cursorPos.get();
const copiedNodes = deepClone($copiedNodes.get());
@@ -46,7 +48,7 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
const offsetY = cursorPos ? cursorPos.y - minY : 50;
copiedNodes.forEach((node) => {
const { x, y } = findUnoccupiedPosition(currentNodes, node.position.x + offsetX, node.position.y + offsetY);
const { x, y } = findUnoccupiedPosition(nodes, node.position.x + offsetX, node.position.y + offsetY);
node.position.x = x;
node.position.y = y;
// Pasted nodes are selected
@@ -68,7 +70,48 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
node.data.id = id;
});
dispatch(selectionPasted({ nodes: copiedNodes, edges: copiedEdges }));
const nodeChanges: NodeChange[] = [];
const edgeChanges: EdgeChange[] = [];
// Deselect existing nodes
nodes.forEach(({ id, selected }) => {
if (selected) {
nodeChanges.push({
type: 'select',
id,
selected: false,
});
}
});
// Add new nodes
copiedNodes.forEach((n) => {
nodeChanges.push({
type: 'add',
item: n,
});
});
// Deselect existing edges
edges.forEach(({ id, selected }) => {
if (selected) {
edgeChanges.push({
type: 'select',
id,
selected: false,
});
}
});
// Add new edges
copiedEdges.forEach((e) => {
edgeChanges.push({
type: 'add',
item: e,
});
});
if (nodeChanges.length > 0) {
dispatch(nodesChanged(nodeChanges));
}
if (edgeChanges.length > 0) {
dispatch(edgesChanged(edgeChanges));
}
};
const api = { copySelection, pasteSelection };

View File

@@ -0,0 +1,20 @@
import { useAppSelector } from 'app/store/storeHooks';
import { isInvocationNode } from 'features/nodes/types/invocation';
export const useDoesFieldExist = (nodeId: string, fieldName?: string) => {
const doesFieldExist = useAppSelector((s) => {
const node = s.nodes.present.nodes.find((n) => n.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
if (fieldName === undefined) {
return true;
}
if (!node.data.inputs[fieldName]) {
return false;
}
return true;
});
return doesFieldExist;
};

View File

@@ -1,9 +0,0 @@
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import type { FieldType } from 'features/nodes/types/field';
import { useMemo } from 'react';
export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType => {
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
const fieldType = useMemo(() => fieldTemplate.type, [fieldTemplate]);
return fieldType;
};

View File

@@ -1,14 +1,10 @@
// TODO: enable this at some point
import { useStore } from '@nanostores/react';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { isEqual } from 'lodash-es';
import { $edgePendingUpdate, $templates } from 'features/nodes/store/nodesSlice';
import { validateConnection } from 'features/nodes/store/util/validateConnection';
import { useCallback } from 'react';
import type { Connection, Node } from 'reactflow';
import type { Connection } from 'reactflow';
/**
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
@@ -25,75 +21,21 @@ export const useIsValidConnection = () => {
if (!(source && sourceHandle && target && targetHandle)) {
return false;
}
const edgePendingUpdate = $edgePendingUpdate.get();
const { nodes, edges } = store.getState().nodes.present;
if (source === target) {
// Don't allow nodes to connect to themselves, even if validation is disabled
return false;
}
const validationResult = validateConnection(
{ source, sourceHandle, target, targetHandle },
nodes,
edges,
templates,
edgePendingUpdate,
shouldValidateGraph
);
const state = store.getState();
const { nodes, edges } = state.nodes.present;
// Find the source and target nodes
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;
const targetNode = nodes.find((node) => node.id === target) as Node<InvocationNodeData>;
const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle];
const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle];
// Conditional guards against undefined nodes/handles
if (!(sourceFieldTemplate && targetFieldTemplate)) {
return false;
}
if (targetFieldTemplate.input === 'direct') {
return false;
}
if (!shouldValidateGraph) {
// manual override!
return true;
}
if (
edges.find((edge) => {
edge.target === target &&
edge.targetHandle === targetHandle &&
edge.source === source &&
edge.sourceHandle === sourceHandle;
})
) {
// We already have a connection from this source to this target
return false;
}
if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') {
// Collect nodes shouldn't mix and match field types
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType) {
return isEqual(sourceFieldTemplate.type, collectItemType);
}
}
// Connection is invalid if target already has a connection
if (
edges.find((edge) => {
return edge.target === target && edge.targetHandle === targetHandle;
}) &&
// except CollectionItem inputs can have multiples
targetFieldTemplate.type.name !== 'CollectionItemField'
) {
return false;
}
// Must use the originalType here if it exists
if (!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
return false;
}
// Graphs much be acyclic (no loops!)
return getIsGraphAcyclic(source, target, nodes, edges);
return validationResult.isValid;
},
[shouldValidateGraph, templates, store]
[templates, shouldValidateGraph, store]
);
return isValidConnection;

View File

@@ -1,4 +1,4 @@
import type { FieldType } from 'features/nodes/types/field';
import { type FieldType, isCollection, isSingleOrCollection } from 'features/nodes/types/field';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -10,13 +10,13 @@ export const useFieldTypeName = (fieldType?: FieldType): string => {
return '';
}
const { name } = fieldType;
if (fieldType.isCollection) {
if (isCollection(fieldType)) {
return t('nodes.collectionFieldType', { name });
}
if (fieldType.isCollectionOrScalar) {
if (isSingleOrCollection(fieldType)) {
return t('nodes.collectionOrScalarFieldType', { name });
}
return name;
return t('nodes.singleFieldType', { name });
}, [fieldType, t]);
return name;

View File

@@ -1,6 +1,6 @@
import { createAction, isAnyOf } from '@reduxjs/toolkit';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import type { Graph } from 'services/api/types';
import type { Graph, GraphAndWorkflowResponse } from 'services/api/types';
const textToImageGraphBuilt = createAction<Graph>('nodes/textToImageGraphBuilt');
const imageToImageGraphBuilt = createAction<Graph>('nodes/imageToImageGraphBuilt');
@@ -15,7 +15,7 @@ export const isAnyGraphBuilt = isAnyOf(
);
export const workflowLoadRequested = createAction<{
workflow: unknown;
data: GraphAndWorkflowResponse;
asCopy: boolean;
}>('nodes/workflowLoadRequested');

View File

@@ -16,6 +16,7 @@ import type {
IPAdapterModelFieldValue,
LoRAModelFieldValue,
MainModelFieldValue,
ModelIdentifierFieldValue,
SchedulerFieldValue,
SDXLRefinerModelFieldValue,
StatefulFieldValue,
@@ -35,6 +36,7 @@ import {
zIPAdapterModelFieldValue,
zLoRAModelFieldValue,
zMainModelFieldValue,
zModelIdentifierFieldValue,
zSchedulerFieldValue,
zSDXLRefinerModelFieldValue,
zStatefulFieldValue,
@@ -45,13 +47,13 @@ import {
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { atom } from 'nanostores';
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
import type { MouseEvent } from 'react';
import type { Edge, EdgeChange, NodeChange, Viewport, XYPosition } from 'reactflow';
import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
import type { UndoableOptions } from 'redux-undo';
import type { z } from 'zod';
import type { NodesState, PendingConnection, Templates } from './types';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
const initialNodesState: NodesState = {
_version: 1,
@@ -90,44 +92,47 @@ export const nodesSlice = createSlice({
reducers: {
nodesChanged: (state, action: PayloadAction<NodeChange[]>) => {
state.nodes = applyNodeChanges(action.payload, state.nodes);
},
nodeReplaced: (state, action: PayloadAction<{ nodeId: string; node: Node }>) => {
const nodeIndex = state.nodes.findIndex((n) => n.id === action.payload.nodeId);
if (nodeIndex < 0) {
return;
}
state.nodes[nodeIndex] = action.payload.node;
},
nodeAdded: (state, action: PayloadAction<{ node: AnyNode; cursorPos: XYPosition | null }>) => {
const { node, cursorPos } = action.payload;
const position = findUnoccupiedPosition(
state.nodes,
cursorPos?.x ?? node.position.x,
cursorPos?.y ?? node.position.y
);
node.position = position;
node.selected = true;
state.nodes = applyNodeChanges(
state.nodes.map((n) => ({ id: n.id, type: 'select', selected: false })),
state.nodes
);
state.edges = applyEdgeChanges(
state.edges.map((e) => ({ id: e.id, type: 'select', selected: false })),
state.edges
);
state.nodes.push(node);
// Remove edges that are no longer valid, due to a removed or otherwise changed node
const edgeChanges: EdgeChange[] = [];
state.edges.forEach((e) => {
const sourceExists = state.nodes.some((n) => n.id === e.source);
const targetExists = state.nodes.some((n) => n.id === e.target);
if (!(sourceExists && targetExists)) {
edgeChanges.push({ type: 'remove', id: e.id });
}
});
state.edges = applyEdgeChanges(edgeChanges, state.edges);
},
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
state.edges = applyEdgeChanges(action.payload, state.edges);
},
edgeAdded: (state, action: PayloadAction<Edge>) => {
state.edges = addEdge(action.payload, state.edges);
},
connectionMade: (state, action: PayloadAction<Connection>) => {
state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges);
const changes: EdgeChange[] = [];
// We may need to massage the edge changes or otherwise handle them
action.payload.forEach((change) => {
if (change.type === 'remove' || change.type === 'select') {
const edge = state.edges.find((e) => e.id === change.id);
// If we deleted or selected a collapsed edge, we need to find its "hidden" edges and do the same to them
if (edge && edge.type === 'collapsed') {
const hiddenEdges = state.edges.filter((e) => e.source === edge.source && e.target === edge.target);
if (change.type === 'remove') {
hiddenEdges.forEach(({ id }) => {
changes.push({ type: 'remove', id });
});
}
if (change.type === 'select') {
hiddenEdges.forEach(({ id }) => {
changes.push({ type: 'select', id, selected: change.selected });
});
}
}
}
if (change.type === 'add') {
if (!change.item.type) {
// We must add the edge type!
change.item.type = 'default';
}
}
changes.push(change);
});
state.edges = applyEdgeChanges(changes, state.edges);
},
fieldLabelChanged: (
state,
@@ -232,6 +237,7 @@ export const nodesSlice = createSlice({
type: 'collapsed',
data: { count: 1 },
updatable: false,
selected: edge.selected,
});
}
}
@@ -252,6 +258,7 @@ export const nodesSlice = createSlice({
type: 'collapsed',
data: { count: 1 },
updatable: false,
selected: edge.selected,
});
}
}
@@ -264,33 +271,6 @@ export const nodesSlice = createSlice({
}
}
},
edgeDeleted: (state, action: PayloadAction<string>) => {
state.edges = state.edges.filter((e) => e.id !== action.payload);
},
edgesDeleted: (state, action: PayloadAction<Edge[]>) => {
const edges = action.payload;
const collapsedEdges = edges.filter((e) => e.type === 'collapsed');
// if we delete a collapsed edge, we need to delete all collapsed edges between the same nodes
if (collapsedEdges.length) {
const edgeChanges: EdgeRemoveChange[] = [];
collapsedEdges.forEach((collapsedEdge) => {
state.edges.forEach((edge) => {
if (edge.source === collapsedEdge.source && edge.target === collapsedEdge.target) {
edgeChanges.push({ id: edge.id, type: 'remove' });
}
});
});
state.edges = applyEdgeChanges(edgeChanges, state.edges);
}
},
nodesDeleted: (state, action: PayloadAction<AnyNode[]>) => {
action.payload.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
});
},
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
const { nodeId, label } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
@@ -309,17 +289,6 @@ export const nodesSlice = createSlice({
}
node.data.notes = notes;
},
nodeExclusivelySelected: (state, action: PayloadAction<string>) => {
const nodeId = action.payload;
state.nodes = applyNodeChanges(
state.nodes.map((n) => ({
id: n.id,
type: 'select',
selected: n.id === nodeId ? true : false,
})),
state.nodes
);
},
fieldValueReset: (state, action: FieldValueAction<StatefulFieldValue>) => {
fieldValueReducer(state, action, zStatefulFieldValue);
},
@@ -344,6 +313,9 @@ export const nodesSlice = createSlice({
fieldMainModelValueChanged: (state, action: FieldValueAction<MainModelFieldValue>) => {
fieldValueReducer(state, action, zMainModelFieldValue);
},
fieldModelIdentifierValueChanged: (state, action: FieldValueAction<ModelIdentifierFieldValue>) => {
fieldValueReducer(state, action, zModelIdentifierFieldValue);
},
fieldRefinerModelValueChanged: (state, action: FieldValueAction<SDXLRefinerModelFieldValue>) => {
fieldValueReducer(state, action, zSDXLRefinerModelFieldValue);
},
@@ -381,57 +353,6 @@ export const nodesSlice = createSlice({
state.nodes = [];
state.edges = [];
},
selectedAll: (state) => {
state.nodes = applyNodeChanges(
state.nodes.map((n) => ({ id: n.id, type: 'select', selected: true })),
state.nodes
);
state.edges = applyEdgeChanges(
state.edges.map((e) => ({ id: e.id, type: 'select', selected: true })),
state.edges
);
},
selectionPasted: (state, action: PayloadAction<{ nodes: AnyNode[]; edges: InvocationNodeEdge[] }>) => {
const { nodes, edges } = action.payload;
const nodeChanges: NodeChange[] = [];
// Deselect existing nodes
state.nodes.forEach((n) => {
nodeChanges.push({
id: n.data.id,
type: 'select',
selected: false,
});
});
// Add new nodes
nodes.forEach((n) => {
nodeChanges.push({
item: n,
type: 'add',
});
});
const edgeChanges: EdgeChange[] = [];
// Deselect existing edges
state.edges.forEach((e) => {
edgeChanges.push({
id: e.id,
type: 'select',
selected: false,
});
});
// Add new edges
edges.forEach((e) => {
edgeChanges.push({
item: e,
type: 'add',
});
});
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
state.edges = applyEdgeChanges(edgeChanges, state.edges);
},
undo: (state) => state,
redo: (state) => state,
},
@@ -440,13 +361,13 @@ export const nodesSlice = createSlice({
const { nodes, edges } = action.payload;
state.nodes = applyNodeChanges(
nodes.map((node) => ({
item: { ...node, ...SHARED_NODE_PROPERTIES },
type: 'add',
item: { ...node, ...SHARED_NODE_PROPERTIES },
})),
[]
);
state.edges = applyEdgeChanges(
edges.map((edge) => ({ item: edge, type: 'add' })),
edges.map((edge) => ({ type: 'add', item: edge })),
[]
);
});
@@ -454,10 +375,7 @@ export const nodesSlice = createSlice({
});
export const {
connectionMade,
edgeDeleted,
edgesChanged,
edgesDeleted,
fieldValueReset,
fieldBoardValueChanged,
fieldBooleanValueChanged,
@@ -469,27 +387,21 @@ export const {
fieldT2IAdapterModelValueChanged,
fieldLabelChanged,
fieldLoRAModelValueChanged,
fieldModelIdentifierValueChanged,
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldVaeModelValueChanged,
nodeAdded,
nodeReplaced,
nodeEditorReset,
nodeExclusivelySelected,
nodeIsIntermediateChanged,
nodeIsOpenChanged,
nodeLabelChanged,
nodeNotesChanged,
nodesChanged,
nodesDeleted,
nodeUseCacheChanged,
notesNodeValueChanged,
selectedAll,
selectionPasted,
edgeAdded,
undo,
redo,
} = nodesSlice.actions;
@@ -500,7 +412,10 @@ export const $copiedNodes = atom<AnyNode[]>([]);
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
export const $edgesToCopiedNodes = atom<InvocationNodeEdge[]>([]);
export const $pendingConnection = atom<PendingConnection | null>(null);
export const $isUpdatingEdge = atom(false);
export const $edgePendingUpdate = atom<Edge | null>(null);
export const $didUpdateEdge = atom(false);
export const $lastEdgeUpdateMouseEvent = atom<MouseEvent | null>(null);
export const $viewport = atom<Viewport>({ x: 0, y: 0, zoom: 1 });
export const $isAddNodePopoverOpen = atom(false);
export const closeAddNodePopover = () => {
@@ -528,17 +443,17 @@ export const nodesPersistConfig: PersistConfig<NodesState> = {
persistDenylist: [],
};
const selectionMatcher = isAnyOf(selectedAll, selectionPasted, nodeExclusivelySelected);
const isSelectionAction = (action: UnknownAction) => {
if (selectionMatcher(action)) {
return true;
}
if (nodesChanged.match(action)) {
if (action.payload.every((change) => change.type === 'select')) {
return true;
}
}
if (edgesChanged.match(action)) {
if (action.payload.every((change) => change.type === 'select')) {
return true;
}
}
return false;
};
@@ -574,10 +489,7 @@ export const nodesUndoableConfig: UndoableOptions<NodesState, UnknownAction> = {
// This is used for tracking `state.workflow.isTouched`
export const isAnyNodeOrEdgeMutation = isAnyOf(
connectionMade,
edgeDeleted,
edgesChanged,
edgesDeleted,
fieldBoardValueChanged,
fieldBooleanValueChanged,
fieldColorValueChanged,
@@ -594,15 +506,11 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldVaeModelValueChanged,
nodeAdded,
nodeReplaced,
nodesChanged,
nodeIsIntermediateChanged,
nodeIsOpenChanged,
nodeLabelChanged,
nodeNotesChanged,
nodesDeleted,
nodeUseCacheChanged,
notesNodeValueChanged,
selectionPasted,
edgeAdded
notesNodeValueChanged
);

View File

@@ -6,19 +6,20 @@ import type {
} from 'features/nodes/types/field';
import type {
AnyNode,
InvocationNode,
InvocationNodeEdge,
InvocationTemplate,
NodeExecutionState,
} from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import type { HandleType } from 'reactflow';
export type Templates = Record<string, InvocationTemplate>;
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
export type PendingConnection = {
node: InvocationNode;
template: InvocationTemplate;
nodeId: string;
handleId: string;
handleType: HandleType;
fieldTemplate: FieldInputTemplate | FieldOutputTemplate;
};

View File

@@ -0,0 +1,86 @@
import type { FieldType } from 'features/nodes/types/field';
import { describe, expect, it } from 'vitest';
import { areTypesEqual } from './areTypesEqual';
describe(areTypesEqual.name, () => {
it('should handle equal source and target type', () => {
const sourceType: FieldType = {
name: 'IntegerField',
cardinality: 'SINGLE',
originalType: {
name: 'Foo',
cardinality: 'SINGLE',
},
};
const targetType: FieldType = {
name: 'IntegerField',
cardinality: 'SINGLE',
originalType: {
name: 'Bar',
cardinality: 'SINGLE',
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
it('should handle equal source type and original target type', () => {
const sourceType: FieldType = {
name: 'IntegerField',
cardinality: 'SINGLE',
originalType: {
name: 'Foo',
cardinality: 'SINGLE',
},
};
const targetType: FieldType = {
name: 'MainModelField',
cardinality: 'SINGLE',
originalType: {
name: 'IntegerField',
cardinality: 'SINGLE',
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
it('should handle equal original source type and target type', () => {
const sourceType: FieldType = {
name: 'MainModelField',
cardinality: 'SINGLE',
originalType: {
name: 'IntegerField',
cardinality: 'SINGLE',
},
};
const targetType: FieldType = {
name: 'IntegerField',
cardinality: 'SINGLE',
originalType: {
name: 'Bar',
cardinality: 'SINGLE',
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
it('should handle equal original source type and original target type', () => {
const sourceType: FieldType = {
name: 'MainModelField',
cardinality: 'SINGLE',
originalType: {
name: 'IntegerField',
cardinality: 'SINGLE',
},
};
const targetType: FieldType = {
name: 'LoRAModelField',
cardinality: 'SINGLE',
originalType: {
name: 'IntegerField',
cardinality: 'SINGLE',
},
};
expect(areTypesEqual(sourceType, targetType)).toBe(true);
});
});

View File

@@ -0,0 +1,29 @@
import type { FieldType } from 'features/nodes/types/field';
import { isEqual, omit } from 'lodash-es';
/**
* Checks if two types are equal. If the field types have original types, those are also compared. Any match is
* considered equal. For example, if the first type and original second type match, the types are considered equal.
* @param firstType The first type to compare.
* @param secondType The second type to compare.
* @returns True if the types are equal, false otherwise.
*/
export const areTypesEqual = (firstType: FieldType, secondType: FieldType) => {
const _firstType = 'originalType' in firstType ? omit(firstType, 'originalType') : firstType;
const _secondType = 'originalType' in secondType ? omit(secondType, 'originalType') : secondType;
const _originalFirstType = 'originalType' in firstType ? firstType.originalType : null;
const _originalSecondType = 'originalType' in secondType ? secondType.originalType : null;
if (isEqual(_firstType, _secondType)) {
return true;
}
if (_originalSecondType && isEqual(_firstType, _originalSecondType)) {
return true;
}
if (_originalFirstType && isEqual(_originalFirstType, _secondType)) {
return true;
}
if (_originalFirstType && _originalSecondType && isEqual(_originalFirstType, _originalSecondType)) {
return true;
}
return false;
};

View File

@@ -1,105 +0,0 @@
import type { PendingConnection, Templates } from 'features/nodes/store/types';
import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector';
import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation';
import { differenceWith, isEqual, map } from 'lodash-es';
import type { Connection } from 'reactflow';
import { assert } from 'tsafe';
import { getIsGraphAcyclic } from './getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes';
export const getFirstValidConnection = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
pendingConnection: PendingConnection,
candidateNode: InvocationNode,
candidateTemplate: InvocationTemplate
): Connection | null => {
if (pendingConnection.node.id === candidateNode.id) {
// Cannot connect to self
return null;
}
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source';
if (pendingFieldKind === 'source') {
// Connecting from a source to a target
if (!getIsGraphAcyclic(pendingConnection.node.id, candidateNode.id, nodes, edges)) {
return null;
}
if (candidateNode.data.type === 'collect') {
// Special handling for collect node - the `item` field takes any number of connections
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: 'item',
};
}
// Only one connection per target field is allowed - look for an unconnected target field
const candidateFields = map(candidateTemplate.inputs).filter((i) => i.input !== 'direct');
const candidateConnectedFields = edges
.filter((edge) => edge.target === candidateNode.id)
.map((edge) => {
// Edges must always have a targetHandle, safe to assert here
assert(edge.targetHandle);
return edge.targetHandle;
});
const candidateUnconnectedFields = differenceWith(
candidateFields,
candidateConnectedFields,
(field, connectedFieldName) => field.name === connectedFieldName
);
const candidateField = candidateUnconnectedFields.find((field) =>
validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type)
);
if (candidateField) {
return {
source: pendingConnection.node.id,
sourceHandle: pendingConnection.fieldTemplate.name,
target: candidateNode.id,
targetHandle: candidateField.name,
};
}
} else {
// Connecting from a target to a source
// Ensure we there is not already an edge to the target, except for collect nodes
const isCollect = pendingConnection.node.data.type === 'collect';
const isTargetAlreadyConnected = edges.some(
(e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name
);
if (!isCollect && isTargetAlreadyConnected) {
return null;
}
if (!getIsGraphAcyclic(candidateNode.id, pendingConnection.node.id, nodes, edges)) {
return null;
}
// Sources/outputs can have any number of edges, we can take the first matching output field
let candidateFields = map(candidateTemplate.outputs);
if (isCollect) {
// Narrow candidates to same field type as already is connected to the collect node
const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id);
if (collectItemType) {
candidateFields = candidateFields.filter((field) => isEqual(field.type, collectItemType));
}
}
const candidateField = candidateFields.find((field) => {
const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type);
const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name);
return isValid && !isAlreadyConnected;
});
if (candidateField) {
return {
source: candidateNode.id,
sourceHandle: candidateField.name,
target: pendingConnection.node.id,
targetHandle: pendingConnection.fieldTemplate.name,
};
}
}
return null;
};

View File

@@ -0,0 +1,44 @@
import { deepClone } from 'common/util/deepClone';
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
import { add, buildEdge, buildNode, collect, templates } from 'features/nodes/store/util/testUtils';
import type { FieldType } from 'features/nodes/types/field';
import { unset } from 'lodash-es';
import { describe, expect, it } from 'vitest';
describe(getCollectItemType.name, () => {
it('should return the type of the items the collect node collects', () => {
const n1 = buildNode(add);
const n2 = buildNode(collect);
const e1 = buildEdge(n1.id, 'value', n2.id, 'item');
const result = getCollectItemType(templates, [n1, n2], [e1], n2.id);
expect(result).toEqual<FieldType>({ name: 'IntegerField', cardinality: 'SINGLE' });
});
it('should return null if the collect node does not have any connections', () => {
const n1 = buildNode(collect);
const result = getCollectItemType(templates, [n1], [], n1.id);
expect(result).toBeNull();
});
it("should return null if the first edge to collect's node doesn't exist", () => {
const n1 = buildNode(collect);
const n2 = buildNode(add);
const e1 = buildEdge(n2.id, 'value', n1.id, 'item');
const result = getCollectItemType(templates, [n1], [e1], n1.id);
expect(result).toBeNull();
});
it("should return null if the first edge to collect's node template doesn't exist", () => {
const n1 = buildNode(collect);
const n2 = buildNode(add);
const e1 = buildEdge(n2.id, 'value', n1.id, 'item');
const result = getCollectItemType({ collect }, [n1, n2], [e1], n1.id);
expect(result).toBeNull();
});
it("should return null if the first edge to the collect's field template doesn't exist", () => {
const n1 = buildNode(collect);
const n2 = buildNode(add);
const addWithoutOutputValue = deepClone(add);
unset(addWithoutOutputValue, 'outputs.value');
const e1 = buildEdge(n2.id, 'value', n1.id, 'item');
const result = getCollectItemType({ add: addWithoutOutputValue, collect }, [n2, n1], [e1], n1.id);
expect(result).toBeNull();
});
});

View File

@@ -0,0 +1,38 @@
import type { Templates } from 'features/nodes/store/types';
import type { FieldType } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
/**
* Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and
* field connected to the collector's `item` input. The field type of that field is returned, else null if there is no
* input field.
* @param templates The current invocation templates
* @param nodes The current nodes
* @param edges The current edges
* @param nodeId The collect node's id
* @returns The type of the items the collect node collects, or null if there is no input field
*/
export const getCollectItemType = (
templates: Templates,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
nodeId: string
): FieldType | null => {
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
if (!firstEdgeToCollect?.sourceHandle) {
return null;
}
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
if (!node) {
return null;
}
const template = templates[node.data.type];
if (!template) {
return null;
}
const fieldTemplate = template.outputs[firstEdgeToCollect.sourceHandle];
if (!fieldTemplate) {
return null;
}
return fieldTemplate.type;
};

View File

@@ -0,0 +1,203 @@
import { deepClone } from 'common/util/deepClone';
import {
getFirstValidConnection,
getSourceCandidateFields,
getTargetCandidateFields,
} from 'features/nodes/store/util/getFirstValidConnection';
import { add, buildEdge, buildNode, img_resize, templates } from 'features/nodes/store/util/testUtils';
import { unset } from 'lodash-es';
import { describe, expect, it } from 'vitest';
describe('getFirstValidConnection', () => {
it('should return null if the pending and candidate nodes are the same node', () => {
const n = buildNode(add);
expect(getFirstValidConnection(n.id, 'value', n.id, null, [n], [], templates, null)).toBe(null);
});
it('should return null if the sourceHandle and targetHandle are null', () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
expect(getFirstValidConnection(n1.id, null, n2.id, null, [n1, n2], [], templates, null)).toBe(null);
});
it('should return itself if both sourceHandle and targetHandle are provided', () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
expect(getFirstValidConnection(n1.id, 'value', n2.id, 'a', [n1, n2], [], templates, null)).toEqual({
source: n1.id,
sourceHandle: 'value',
target: n2.id,
targetHandle: 'a',
});
});
describe('connecting from a source to a target', () => {
const n1 = buildNode(img_resize);
const n2 = buildNode(img_resize);
it('should return the first valid connection if there are no connected fields', () => {
const r = getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [], templates, null);
const c = {
source: n1.id,
sourceHandle: 'width',
target: n2.id,
targetHandle: 'width',
};
expect(r).toEqual(c);
});
it('should return the first valid connection if there is a connected field', () => {
const e = buildEdge(n1.id, 'height', n2.id, 'width');
const r = getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [e], templates, null);
const c = {
source: n1.id,
sourceHandle: 'width',
target: n2.id,
targetHandle: 'height',
};
expect(r).toEqual(c);
});
it('should return the first valid connection if there is an edgePendingUpdate', () => {
const e = buildEdge(n1.id, 'width', n2.id, 'width');
const r = getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [e], templates, e);
const c = {
source: n1.id,
sourceHandle: 'width',
target: n2.id,
targetHandle: 'width',
};
expect(r).toEqual(c);
});
it('should return null if the target has no valid fields', () => {
const e1 = buildEdge(n1.id, 'width', n2.id, 'width');
const e2 = buildEdge(n1.id, 'height', n2.id, 'height');
const n3 = buildNode(add);
const r = getFirstValidConnection(n3.id, 'value', n2.id, null, [n1, n2, n3], [e1, e2], templates, null);
expect(r).toEqual(null);
});
});
describe('connecting from a target to a source', () => {
const n1 = buildNode(img_resize);
const n2 = buildNode(img_resize);
it('should return the first valid connection if there are no connected fields', () => {
const r = getFirstValidConnection(n1.id, null, n2.id, 'width', [n1, n2], [], templates, null);
const c = {
source: n1.id,
sourceHandle: 'width',
target: n2.id,
targetHandle: 'width',
};
expect(r).toEqual(c);
});
it('should return the first valid connection if there is a connected field', () => {
const e = buildEdge(n1.id, 'height', n2.id, 'width');
const r = getFirstValidConnection(n1.id, null, n2.id, 'height', [n1, n2], [e], templates, null);
const c = {
source: n1.id,
sourceHandle: 'width',
target: n2.id,
targetHandle: 'height',
};
expect(r).toEqual(c);
});
it('should return the first valid connection if there is an edgePendingUpdate', () => {
const e = buildEdge(n1.id, 'width', n2.id, 'width');
const r = getFirstValidConnection(n1.id, null, n2.id, 'width', [n1, n2], [e], templates, e);
const c = {
source: n1.id,
sourceHandle: 'width',
target: n2.id,
targetHandle: 'width',
};
expect(r).toEqual(c);
});
it('should return null if the target has no valid fields', () => {
const e1 = buildEdge(n1.id, 'width', n2.id, 'width');
const e2 = buildEdge(n1.id, 'height', n2.id, 'height');
const n3 = buildNode(add);
const r = getFirstValidConnection(n3.id, null, n2.id, 'a', [n1, n2, n3], [e1, e2], templates, null);
expect(r).toEqual(null);
});
});
});
describe('getTargetCandidateFields', () => {
it('should return an empty array if the nodes canot be found', () => {
const r = getTargetCandidateFields('missing', 'value', 'missing', [], [], templates, null);
expect(r).toEqual([]);
});
it('should return an empty array if the templates cannot be found', () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
const nodes = [n1, n2];
const r = getTargetCandidateFields(n1.id, 'value', n2.id, nodes, [], {}, null);
expect(r).toEqual([]);
});
it('should return an empty array if the source field template cannot be found', () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
const nodes = [n1, n2];
const addWithoutOutputValue = deepClone(add);
unset(addWithoutOutputValue, 'outputs.value');
const r = getTargetCandidateFields(n1.id, 'value', n2.id, nodes, [], { add: addWithoutOutputValue }, null);
expect(r).toEqual([]);
});
it('should return all valid target fields if there are no connected fields', () => {
const n1 = buildNode(img_resize);
const n2 = buildNode(img_resize);
const nodes = [n1, n2];
const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, null);
expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]);
});
it('should ignore the edgePendingUpdate if provided', () => {
const n1 = buildNode(img_resize);
const n2 = buildNode(img_resize);
const nodes = [n1, n2];
const edgePendingUpdate = buildEdge(n1.id, 'width', n2.id, 'width');
const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, edgePendingUpdate);
expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]);
});
});
describe('getSourceCandidateFields', () => {
it('should return an empty array if the nodes canot be found', () => {
const r = getSourceCandidateFields('missing', 'value', 'missing', [], [], templates, null);
expect(r).toEqual([]);
});
it('should return an empty array if the templates cannot be found', () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
const nodes = [n1, n2];
const r = getSourceCandidateFields(n2.id, 'a', n1.id, nodes, [], {}, null);
expect(r).toEqual([]);
});
it('should return an empty array if the source field template cannot be found', () => {
const n1 = buildNode(add);
const n2 = buildNode(add);
const nodes = [n1, n2];
const addWithoutInputA = deepClone(add);
unset(addWithoutInputA, 'inputs.a');
const r = getSourceCandidateFields(n1.id, 'a', n2.id, nodes, [], { add: addWithoutInputA }, null);
expect(r).toEqual([]);
});
it('should return all valid source fields if there are no connected fields', () => {
const n1 = buildNode(img_resize);
const n2 = buildNode(img_resize);
const nodes = [n1, n2];
const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, null);
expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]);
});
it('should ignore the edgePendingUpdate if provided', () => {
const n1 = buildNode(img_resize);
const n2 = buildNode(img_resize);
const nodes = [n1, n2];
const edgePendingUpdate = buildEdge(n1.id, 'width', n2.id, 'width');
const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, edgePendingUpdate);
expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]);
});
});

View File

@@ -0,0 +1,149 @@
import type { Templates } from 'features/nodes/store/types';
import { validateConnection } from 'features/nodes/store/util/validateConnection';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { map } from 'lodash-es';
import type { Connection, Edge } from 'reactflow';
/**
*
* @param source The source (node id)
* @param sourceHandle The source handle (field name), if any
* @param target The target (node id)
* @param targetHandle The target handle (field name), if any
* @param nodes The current nodes
* @param edges The current edges
* @param templates The current templates
* @param edgePendingUpdate The edge pending update, if any
* @returns
*/
export const getFirstValidConnection = (
source: string,
sourceHandle: string | null,
target: string,
targetHandle: string | null,
nodes: AnyNode[],
edges: InvocationNodeEdge[],
templates: Templates,
edgePendingUpdate: Edge | null
): Connection | null => {
if (source === target) {
return null;
}
if (sourceHandle && targetHandle) {
return { source, sourceHandle, target, targetHandle };
}
if (sourceHandle && !targetHandle) {
const candidates = getTargetCandidateFields(
source,
sourceHandle,
target,
nodes,
edges,
templates,
edgePendingUpdate
);
const firstCandidate = candidates[0];
if (!firstCandidate) {
return null;
}
return { source, sourceHandle, target, targetHandle: firstCandidate.name };
}
if (!sourceHandle && targetHandle) {
const candidates = getSourceCandidateFields(
target,
targetHandle,
source,
nodes,
edges,
templates,
edgePendingUpdate
);
const firstCandidate = candidates[0];
if (!firstCandidate) {
return null;
}
return { source, sourceHandle: firstCandidate.name, target, targetHandle };
}
return null;
};
export const getTargetCandidateFields = (
source: string,
sourceHandle: string,
target: string,
nodes: AnyNode[],
edges: Edge[],
templates: Templates,
edgePendingUpdate: Edge | null
): FieldInputTemplate[] => {
const sourceNode = nodes.find((n) => n.id === source);
const targetNode = nodes.find((n) => n.id === target);
if (!sourceNode || !targetNode) {
return [];
}
const sourceTemplate = templates[sourceNode.data.type];
const targetTemplate = templates[targetNode.data.type];
if (!sourceTemplate || !targetTemplate) {
return [];
}
const sourceField = sourceTemplate.outputs[sourceHandle];
if (!sourceField) {
return [];
}
const targetCandidateFields = map(targetTemplate.inputs).filter((field) => {
const c = { source, sourceHandle, target, targetHandle: field.name };
const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
return r.isValid;
});
return targetCandidateFields;
};
export const getSourceCandidateFields = (
target: string,
targetHandle: string,
source: string,
nodes: AnyNode[],
edges: Edge[],
templates: Templates,
edgePendingUpdate: Edge | null
): FieldOutputTemplate[] => {
const targetNode = nodes.find((n) => n.id === target);
const sourceNode = nodes.find((n) => n.id === source);
if (!sourceNode || !targetNode) {
return [];
}
const sourceTemplate = templates[sourceNode.data.type];
const targetTemplate = templates[targetNode.data.type];
if (!sourceTemplate || !targetTemplate) {
return [];
}
const targetField = targetTemplate.inputs[targetHandle];
if (!targetField) {
return [];
}
const sourceCandidateFields = map(sourceTemplate.outputs).filter((field) => {
const c = { source, sourceHandle: field.name, target, targetHandle };
const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
return r.isValid;
});
return sourceCandidateFields;
};

Some files were not shown because too many files have changed in this diff Show More