Compare commits

...

98 Commits

Author SHA1 Message Date
Mary Hipp
bd71459955 Merge remote-tracking branch 'origin/main' into bria-clone 2025-07-21 12:11:09 -04:00
skunkworxdark
cacfb183a6 Add auto layout controls to node editor (#8239)
* Add auto layout controls using elkjs to node editor

Introduces auto layout functionality for the node editor using elkjs, including a new UI popover for layout options (placement strategy, layering, spacing, direction). Adds related state and actions to workflowSettingsSlice, updates translations, and ensures elkjs is included in optimized dependencies.

* feat(nodes): Improve workflow auto-layout controls and accuracy

- The auto-layout settings panel is updated to use `Select` dropdowns and `NumberInput`
- The layout algorithm now uses the actual rendered dimensions of nodes from the DOM, falling back to estimates only when necessary. This results in a much more accurate and predictable layout.
- The ELKjs library integration is refactored to fix some warnings

* Update useAutoLayout.ts

prettier

* feat(nodes): Improve workflow auto-layout controls and accuracy

- The auto-layout settings panel is updated to use `Select` dropdowns and `NumberInput`
- The layout algorithm now uses the actual rendered dimensions of nodes from the DOM, falling back to estimates only when necessary. This results in a much more accurate and predictable layout.
- The ELKjs library integration is refactored to fix some warnings

* Update useAutoLayout.ts

prettier

* build(ui): import elkjs directly

* updated to use  dagrejs for autolayout

updated to use dagrejs - it has less layout options but is already included

but this is still WIP as some nodes don't report the height correctly. I am still investigating this...

* Update useAutoLayout.ts

update to fix layout issues

* minor updates

- pretty useAutoLayout.ts
- add missing type import in ViewportControls.tsx
- update pnpm-lock.yaml with elkjs removed

* Update ViewportControls.tsx

pnpm fix

* Fix Frontend check + single node selection fix

Fix Frontend check -  remove unused export from workflowSettingsSlice.ts
Update so that if you have a single node selected, it will auto layout all nodes, as this is a common thing to have a single node selected and means that you don't have to unselect it.

* feat(ui): misc improvements for autolayout

- Split popover into own component
- Add util functions to get node w/h
- Use magic wand icon for button
- Fix sizing of input components
- Use CompositeNumberInput instead of base chakra number input
- Add zod schemas for string values and use them in the component to
ensure state integrity

* chore(ui): lint

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2025-07-21 14:44:29 +10:00
psychedelicious
564f4f7a60 feat(ui): better icon for invert mask button 2025-07-21 13:47:02 +10:00
Kent Keirsey
113a118fcf fix potential for null data 2025-07-21 13:47:02 +10:00
Kent Keirsey
1f930cdaf2 fix 2025-07-21 13:47:02 +10:00
Kent Keirsey
c490e0ce08 feat(ui):invert mask 2025-07-21 13:47:02 +10:00
Kent Keirsey
7640ee307c feat(ui):Adjust-bbox-to-masks 2025-07-21 13:26:49 +10:00
psychedelicious
1f5f70f898 feat(ui): clean up picker compact view default state handling
- Name it `pickerCompactViewStates` bc its not exclusive to model
picker, it is used for all pickers
- Rename redux action to model an event
- Move selector to right file
- Use selector to derive state for individual picker
2025-07-21 13:18:09 +10:00
Mary Hipp
1430858112 cleanup 2025-07-21 13:18:09 +10:00
Mary Hipp
48c27ec117 persist model picker compact/expanded state 2025-07-21 13:18:09 +10:00
psychedelicious
af7737e804 fix(ui): context menu on staging area images
There was a subtle issue where the progress image wasn't ever cleared,
preventing the context menu from working on staging area preview images.

The staging area preview images were displaying the last progress image
_on top of_ the result image. Because the image elements were so small,
you wouldn't notice that you were looking at a low-res progress image.
Right clicking a progress image gets you no menu.

If you refresh the page or switch tabs, this would fix itself, because
those actions clear out the progress images. The result image would then
be the topmost element, and the context menu works.

Fixing this without introducing a flash of empty space as the progress
image was hidden required a bit of refactoring. We have to wait for the
result image element to load before clearing out the progress.

Result - progress images appear to "resolve" to result images in the
staging area without any blips or jank, and the context menu works after
that happens.
2025-07-21 13:15:34 +10:00
Ilan Tchenak
a5542370a6 ruff fix 2025-07-20 15:03:57 +03:00
psychedelicious
3eca0d2ba0 fix(ui): staging area left/right hotkeys 2025-07-18 08:08:15 -04:00
psychedelicious
307259f096 fix(ui): ensure staging area always has the right state and session association 2025-07-18 08:08:15 -04:00
psychedelicious
bed01941a5 fix(ui): ensure we clean up when session id changes 2025-07-18 08:08:15 -04:00
psychedelicious
89fa43a3b6 docs(ui): update StagingAreaApi docstrings 2025-07-18 08:08:15 -04:00
psychedelicious
d8fcb08abf repo: update ignores 2025-07-18 08:08:15 -04:00
psychedelicious
c61bcd9f50 tests(ui): add test suite for StagingAreaApi 2025-07-18 08:08:15 -04:00
psychedelicious
3fb0fcbbfb tidy(ui): move staging area components to correct dir 2025-07-18 08:08:15 -04:00
psychedelicious
db9af5083f tidy(ui): move launchpad components to ui dir 2025-07-18 08:08:15 -04:00
psychedelicious
720f1bb65c chore(ui): rename context2.tsx -> context.tsx 2025-07-18 08:08:15 -04:00
psychedelicious
7dfb318ba2 chore(ui): lint 2025-07-18 08:08:15 -04:00
psychedelicious
9b024da2b4 refactor(ui): move staging area logic out side react
Was running into difficultlies reasoning about the logic and couldn't
write tests because it was all in react.

Moved logic outside react, updated context, make it testable.
2025-07-18 08:08:15 -04:00
psychedelicious
15ca3b727a wip 2025-07-18 08:08:15 -04:00
psychedelicious
74ca604ae0 fix(ui): unstyled error boundary 2025-07-18 08:08:15 -04:00
psychedelicious
6934b05c85 fix(ui): use invocation context provider in inspector panel 2025-07-18 08:08:15 -04:00
psychedelicious
1a47a5317c chore(ui): update dockview to latest
Remove extraneous fix now that the disableDnd issue is resolved upstream
2025-07-18 08:08:15 -04:00
psychedelicious
bc3ef21c64 chore(ui): bump version to v6.1.0rc2 2025-07-18 08:08:15 -04:00
psychedelicious
e329f5ad43 fix(ui): negative style prompt not recorded in metadata 2025-07-18 06:41:21 +10:00
Ubuntu
c296fd2305 fixed node issue 2025-07-17 17:52:29 +00:00
psychedelicious
e6ad91bf89 chore(ui): update prettier config 2025-07-17 22:04:57 +10:00
psychedelicious
2f586416a5 chore(ui): remove unused pkgs 2025-07-17 22:04:57 +10:00
psychedelicious
33b56f421c chore(ui): lint 2025-07-17 22:04:57 +10:00
psychedelicious
e58ee4c492 chore(ui): upgrade zod 2025-07-17 22:04:57 +10:00
psychedelicious
49691aa07e chore(ui): upgrade rollup vis 2025-07-17 22:04:57 +10:00
psychedelicious
56570f235f chore(ui): actually upgrade storybook 2025-07-17 22:04:57 +10:00
psychedelicious
a2d95cf5b6 chore(ui): upgrade minor bump packages 2025-07-17 22:04:57 +10:00
psychedelicious
704dbfd04a chore(ui): upgrade storybook 2025-07-17 22:04:57 +10:00
psychedelicious
5d9e078043 chore(ui): finish eslint v9 migration 2025-07-17 22:04:57 +10:00
psychedelicious
875cde13ae chore(ui): migrate to eslint v9 (wip) 2025-07-17 22:04:57 +10:00
psychedelicious
77655aed86 chore(ui): update eslint config 2025-07-17 22:04:57 +10:00
psychedelicious
0628b92d63 chore: bump version to v6.1.0rc1 2025-07-17 19:30:38 +10:00
psychedelicious
9e526d00c2 chore(ui): lint 2025-07-17 15:36:24 +10:00
psychedelicious
1a24396be8 feat(ui): styling when nodes have error 2025-07-17 15:36:24 +10:00
psychedelicious
d97e73a565 chore(ui): lint 2025-07-17 15:36:24 +10:00
psychedelicious
55b14c8aaf perf(ui): optimize redux selectors for workflow editor
- Build selectors for each node in a react context so components can
re-use the same selectors
- Cache the selectors in the context
2025-07-17 15:36:24 +10:00
psychedelicious
79f65e57eb fix(ui): remove unnecessary coalescing operator 2025-07-17 14:21:02 +10:00
Kent Keirsey
b4c8950278 address comments 2025-07-17 14:21:02 +10:00
Kent Keirsey
400b2e9a55 unlint. 2025-07-17 14:21:02 +10:00
Kent Keirsey
3a687c583a lint 2025-07-17 14:21:02 +10:00
Kent Keirsey
833950078d commit tile size controls 2025-07-17 14:21:02 +10:00
Kent Keirsey
e698dcb148 unlint. 2025-07-17 14:21:02 +10:00
Kent Keirsey
218386e077 lint 2025-07-17 14:21:02 +10:00
Kent Keirsey
4426be9e64 commit tile size controls 2025-07-17 14:21:02 +10:00
psychedelicious
86f4cf7857 feat(ui): related embedding styling/tidy 2025-07-17 14:12:29 +10:00
Kent Keirsey
49ae66d94a Added related model support 2025-07-17 14:12:29 +10:00
Cursor Agent
c10865c7ef Reorder embedding options in PromptTriggerSelect component
Co-authored-by: kent <kent@invoke.ai>
2025-07-17 14:12:29 +10:00
psychedelicious
f3478a189a fix(ui): able to drag empty space in tab bar and detach panels 2025-07-17 13:58:32 +10:00
psychedelicious
43db29176a chore(ui): lint 2025-07-17 13:52:24 +10:00
psychedelicious
f38922929c docs(ui): comments in modelsLoaded 2025-07-17 13:52:24 +10:00
psychedelicious
7d02c58f86 fix(ui): move <ParamTileControlNetModel /> to <UpscaleTabAdvancedSettingsAccordion /> 2025-07-17 13:52:24 +10:00
Kent Keirsey
6edce8be87 Add scaling in 2025-07-17 13:52:24 +10:00
Kent Keirsey
31f63e38bd lint 2025-07-17 13:52:24 +10:00
Kent Keirsey
78a68ac3a7 Updated 2025-07-17 13:52:24 +10:00
Kent Keirsey
8cd3bcd1c0 Updates 2025-07-17 13:52:24 +10:00
Cursor Agent
264cc5ef46 Add tile ControlNet model selection to upscale settings
Co-authored-by: kent <kent@invoke.ai>
2025-07-17 13:52:24 +10:00
JPPhoto
8bfbea5ed3 Updated __init__.py 2025-07-17 06:33:56 +10:00
JPPhoto
f06a66da07 Updated schema.ts 2025-07-17 06:33:56 +10:00
Jonathan
337cae9b22 Update __init__.py
Added FluxConditioningField, FluxConditioningCollectionOutput, and FluxConditioningCollectionOutput,
2025-07-17 06:33:56 +10:00
Jonathan
bf926bb7d5 Update primitives.py
Added FluxConditioningCollectionOutput
2025-07-17 06:33:56 +10:00
psychedelicious
18ad9a6af3 feat(ui): canvas/viewer panel tabs show progress 2025-07-17 06:20:05 +10:00
psychedelicious
b6ed31c222 feat(ui): clicking invoke switches to viewer tab instead of canvas when save all images to gallery is enabled 2025-07-17 06:20:05 +10:00
psychedelicious
200beb5af5 feat(ui): make save all images to gallery option also bypass canvas 2025-07-17 06:20:05 +10:00
psychedelicious
f82a948bdd refactor(ui): canvas autoswitch logic
Simplify the canvas auto-switch logic to not rely on the preview images
loading. This fixes an issue where offscreen preview images didn't get
auto-switched to. Images are now loaded directly.
2025-07-17 06:20:05 +10:00
psychedelicious
dd03e3ddcd refactor(ui): simplify canvas session logic 2025-07-17 06:20:05 +10:00
psychedelicious
7561b73e8f fix(ui): uppercase file extensions blocked for image upload
Closes #8284
2025-07-17 00:48:36 +10:00
psychedelicious
caa97608c7 fix(ui): aspect ratios out of order 2025-07-16 23:27:37 +10:00
Mary Hipp
72a6d1edc1 simplify descriptoin styling 2025-07-16 09:19:33 -04:00
Mary Hipp
b8bf89c2f1 add fallback image and make sure description text is legible for model picker noncompact 2025-07-16 09:19:33 -04:00
psychedelicious
a1ade2b8c0 feat(ui): export apis & actions from package 2025-07-16 08:21:03 -04:00
Ilan Tchenak
c08a6a852d moved bria's nodes to invocations folder 2025-07-16 00:01:33 +03:00
Eugene Brodsky
4bdcae1f8f fix(docker): switch to pnpm10.x 2025-07-15 13:03:15 -04:00
Jonathan
4b22c84407 Update dev-environment.md
Document the latest changes required to build Invoke 6.0.
2025-07-15 15:21:01 +10:00
Eugene Brodsky
c9daf1db30 (fix) remove timeout from image prompt expansion (#8281) 2025-07-14 11:19:20 -04:00
Ubuntu
e1139de551 Small cosmetic fixes 2025-07-14 14:56:39 +00:00
Ubuntu
44b7b9c29d removed unused file 2025-07-14 13:40:24 +00:00
Ubuntu
2d55dbe67a Added scikit-image required for Bria's OpenposeDetector model 2025-07-14 13:32:09 +00:00
Ilan Tchenak
04ea87b0bb Add Bria text to image model and controlnet support 2025-07-14 13:20:34 +00:00
psychedelicious
06d3cfbe97 gh: update bug report template
- Add require drop down for install method
- Make browser version optional
- Link to latest release
- Update verbiage for sys info section
2025-07-14 12:18:52 +10:00
psychedelicious
71e4901313 fix(ui): ignore disalbed ref images in readiness checks 2025-07-14 10:51:51 +10:00
psychedelicious
82fb897b62 chore(ui): lint 2025-07-12 14:56:57 +10:00
psychedelicious
192b00d969 chore: bump version to v6.0.2 2025-07-12 14:56:57 +10:00
psychedelicious
7bb25ef1b4 fix(ui): gallery dnd 2025-07-12 14:56:57 +10:00
Ilan Tchenak
7140f2ec72 Setup Probe and UI to accept bria controlnet models 2025-07-10 08:09:09 +00:00
Ubuntu
9e5e1ec0da addded bria nodes for bria3.1 and bria3.2 2025-07-09 18:38:09 +00:00
Ubuntu
a139885bf7 front end support for bria 2025-07-09 18:38:04 +00:00
Ubuntu
f5423133a8 added support for loading bria transformer 2025-07-09 18:36:19 +00:00
Brandon Rising
9c9265cdad Setup Probe and UI to accept bria main models 2025-07-09 18:35:54 +00:00
354 changed files with 11461 additions and 3975 deletions

View File

@@ -21,6 +21,20 @@ body:
- label: I have searched the existing issues
required: true
- type: dropdown
id: install_method
attributes:
label: Install method
description: How did you install Invoke?
multiple: false
options:
- "Invoke's Launcher"
- 'Stability Matrix'
- 'Pinokio'
- 'Manual'
validations:
required: true
- type: markdown
attributes:
value: __Describe your environment__
@@ -76,8 +90,8 @@ body:
attributes:
label: Version number
description: |
The version of Invoke you have installed. If it is not the latest version, please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
placeholder: ex. 3.6.1
The version of Invoke you have installed. If it is not the [latest version](https://github.com/invoke-ai/InvokeAI/releases/latest), please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
placeholder: ex. v6.0.2
validations:
required: true
@@ -85,17 +99,17 @@ body:
id: browser-version
attributes:
label: Browser
description: Your web browser and version.
description: Your web browser and version, if you do not use the Launcher's provided GUI.
placeholder: ex. Firefox 123.0b3
validations:
required: true
required: false
- type: textarea
id: python-deps
attributes:
label: Python dependencies
label: System Information
description: |
If the problem occurred during image generation, click the gear icon at the bottom left corner, click "About", click the copy button and then paste here.
Click the gear icon at the bottom left corner, then click "About". Click the copy button and then paste here.
validations:
required: false

2
.gitignore vendored
View File

@@ -190,3 +190,5 @@ installer/update.bat
installer/update.sh
installer/InvokeAI-Installer/
.aider*
.claude/

View File

@@ -5,8 +5,7 @@
FROM docker.io/node:22-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack use pnpm@8.x
RUN corepack enable
RUN corepack use pnpm@10.x && corepack enable
WORKDIR /build
COPY invokeai/frontend/web/ ./

View File

@@ -41,7 +41,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
With the modifications made, the install command should look something like this:
```sh
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu126 --reinstall
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu128 --reinstall
```
6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
@@ -50,11 +50,11 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
If you only want to edit the docs, you can stop here and skip to the **Documentation** section below.
7. Install the frontend dev toolchain:
7. Install the frontend dev toolchain, paying attention to versions:
- [`nodejs`](https://nodejs.org/) (v20+)
- [`nodejs`](https://nodejs.org/) (tested on LTS, v22)
- [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)
- [`pnpm`](https://pnpm.io/installation) (tested on v10)
8. Do a production build of the frontend:

View File

@@ -297,7 +297,7 @@ Migration logic is in [migrations.ts].
<!-- links -->
[pydantic]: https://github.com/pydantic/pydantic 'pydantic'
[zod]: https://github.com/colinhacks/zod 'zod/v4'
[zod]: https://github.com/colinhacks/zod 'zod'
[openapi-types]: https://github.com/kogosoftwarellc/open-api/tree/main/packages/openapi-types 'openapi-types'
[reactflow]: https://github.com/xyflow/xyflow 'reactflow'
[reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions

View File

@@ -0,0 +1,154 @@
import cv2
import numpy as np
from PIL import Image
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
OutputField,
UIType,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.controlnet_aux.open_pose import Body, Face, Hand, OpenposeDetector
from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.invocation_api import Classification, ImageOutput
DEPTH_SMALL_V2_URL = "depth-anything/Depth-Anything-V2-Small-hf"
HF_LLLYASVIEL = "https://huggingface.co/lllyasviel/Annotators/resolve/main/"
class BriaControlNetField(BaseModel):
image: ImageField = Field(description="The control image")
model: ModelIdentifierField = Field(description="The ControlNet model to use")
mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet")
conditioning_scale: float = Field(description="The weight given to the ControlNet")
@invocation_output("bria_controlnet_output")
class BriaControlNetOutput(BaseInvocationOutput):
"""Bria ControlNet info"""
control: BriaControlNetField = OutputField(description=FieldDescriptions.control)
preprocessed_images: ImageField = OutputField(description="The preprocessed control image")
@invocation(
"bria_controlnet",
title="ControlNet - Bria",
tags=["controlnet", "bria"],
category="controlnet",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaControlNetInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Collect Bria ControlNet info to pass to denoiser node."""
control_image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.BriaControlNetModel
)
control_mode: BRIA_CONTROL_MODES = InputField(
default="depth", description="The mode of the ControlNet"
)
control_weight: float = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)
def invoke(self, context: InvocationContext) -> BriaControlNetOutput:
image_in = resize_img(context.images.get_pil(self.control_image.image_name))
if self.control_mode == "canny":
control_image = extract_canny(image_in)
elif self.control_mode == "depth":
control_image = extract_depth(image_in, context)
elif self.control_mode == "pose":
control_image = extract_openpose(image_in, context)
elif self.control_mode == "colorgrid":
control_image = tile(64, image_in)
elif self.control_mode == "recolor":
control_image = convert_to_grayscale(image_in)
elif self.control_mode == "tile":
control_image = tile(16, image_in)
control_image = resize_img(control_image)
image_dto = context.images.save(image=control_image)
image_output = ImageOutput.build(image_dto)
return BriaControlNetOutput(
preprocessed_images=image_output.image,
control=BriaControlNetField(
image=ImageField(image_name=image_dto.image_name),
model=self.control_model,
mode=self.control_mode,
conditioning_scale=self.control_weight,
),
)
RATIO_CONFIGS_1024 = {
0.6666666666666666: {"width": 832, "height": 1248},
0.7432432432432432: {"width": 880, "height": 1184},
0.8028169014084507: {"width": 912, "height": 1136},
1.0: {"width": 1024, "height": 1024},
1.2456140350877194: {"width": 1136, "height": 912},
1.3454545454545455: {"width": 1184, "height": 880},
1.4339622641509433: {"width": 1216, "height": 848},
1.5: {"width": 1248, "height": 832},
1.5490196078431373: {"width": 1264, "height": 816},
1.62: {"width": 1296, "height": 800},
1.7708333333333333: {"width": 1360, "height": 768},
}
def extract_depth(image: Image.Image, context: InvocationContext):
loaded_model = context.models.load_remote_model(DEPTH_SMALL_V2_URL, DepthAnythingPipeline.load_model)
with loaded_model as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)
return depth_map
def extract_openpose(image: Image.Image, context: InvocationContext):
body_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}body_pose_model.pth", Body)
hand_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}hand_pose_model.pth", Hand)
face_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}facenet.pth", Face)
with body_model as body_model, hand_model as hand_model, face_model as face_model:
open_pose_model = OpenposeDetector(body_model, hand_model, face_model)
processed_image_open_pose = open_pose_model(image, hand_and_face=True)
processed_image_open_pose = processed_image_open_pose.resize(image.size)
return processed_image_open_pose
def extract_canny(input_image):
image = np.array(input_image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
return canny_image
def convert_to_grayscale(image):
gray_image = image.convert('L').convert('RGB')
return gray_image
def tile(downscale_factor, input_image):
control_image = input_image.resize((input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)).resize(input_image.size, Image.Resampling.NEAREST)
return control_image
def resize_img(control_image):
image_ratio = control_image.width / control_image.height
ratio = min(RATIO_CONFIGS_1024.keys(), key=lambda k: abs(k - image_ratio))
to_height = RATIO_CONFIGS_1024[ratio]["height"]
to_width = RATIO_CONFIGS_1024[ratio]["width"]
resized_image = control_image.resize((to_width, to_height), resample=Image.Resampling.LANCZOS)
return resized_image

View File

@@ -0,0 +1,46 @@
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from PIL import Image
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.invocation_api import BaseInvocation, Classification, ImageOutput, invocation
@invocation(
"bria_decoder",
title="Decoder - Bria",
tags=["image", "bria"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaDecoderInvocation(BaseInvocation):
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
latents = latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128)
with context.models.load(self.vae.vae) as vae:
assert isinstance(vae, AutoencoderKL)
latents = (latents / vae.config.scaling_factor)
latents = latents.to(device=vae.device, dtype=vae.dtype)
decoded_output = vae.decode(latents)
image = decoded_output.sample
# Convert to numpy with proper gradient handling
image = ((image.clamp(-1, 1) + 1) / 2 * 255).cpu().detach().permute(0, 2, 3, 1).numpy().astype("uint8")[0]
img = Image.fromarray(image)
image_dto = context.images.save(image=img)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,185 @@
from typing import List, Tuple
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from invokeai.app.invocations.bria_controlnet import BriaControlNetField
from invokeai.app.invocations.fields import Input, InputField, LatentsField, OutputField
from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
from invokeai.backend.bria.controlnet_utils import prepare_control_images
from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output
@invocation_output("bria_denoise_output")
class BriaDenoiseInvocationOutput(BaseInvocationOutput):
latents: LatentsField = OutputField(description=FieldDescriptions.latents)
@invocation(
"bria_denoise",
title="Denoise - Bria",
tags=["image", "bria"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaDenoiseInvocation(BaseInvocation):
num_steps: int = InputField(
default=30, title="Number of Steps", description="The number of steps to use for the denoiser"
)
guidance_scale: float = InputField(
default=5.0, title="Guidance Scale", description="The guidance scale to use for the denoiser"
)
transformer: TransformerField = InputField(
description="Bria model (Transformer) to load",
input=Input.Connection,
title="Transformer",
)
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
title="VAE",
)
latents: LatentsField = InputField(
description="Latents to denoise",
input=Input.Connection,
title="Latents",
)
latent_image_ids: LatentsField = InputField(
description="Latent Image IDs to denoise",
input=Input.Connection,
title="Latent Image IDs",
)
pos_embeds: LatentsField = InputField(
description="Positive Prompt Embeds",
input=Input.Connection,
title="Positive Prompt Embeds",
)
neg_embeds: LatentsField = InputField(
description="Negative Prompt Embeds",
input=Input.Connection,
title="Negative Prompt Embeds",
)
text_ids: LatentsField = InputField(
description="Text IDs",
input=Input.Connection,
title="Text IDs",
)
control: BriaControlNetField | list[BriaControlNetField] | None = InputField(
description="ControlNet",
input=Input.Connection,
title="ControlNet",
default = None,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
latents = context.tensors.load(self.latents.latents_name)
pos_embeds = context.tensors.load(self.pos_embeds.latents_name)
neg_embeds = context.tensors.load(self.neg_embeds.latents_name)
text_ids = context.tensors.load(self.text_ids.latents_name)
latent_image_ids = context.tensors.load(self.latent_image_ids.latents_name)
scheduler_identifier = self.transformer.transformer.model_copy(update={"submodel_type": SubModelType.Scheduler})
device = None
dtype = None
with (
context.models.load(self.transformer.transformer) as transformer,
context.models.load(scheduler_identifier) as scheduler,
context.models.load(self.vae.vae) as vae,
context.models.load(self.t5_encoder.text_encoder) as t5_encoder,
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
):
assert isinstance(transformer, BriaTransformer2DModel)
assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
assert isinstance(vae, AutoencoderKL)
dtype = transformer.dtype
device = transformer.device
latents, pos_embeds, neg_embeds = (x.to(device, dtype) for x in (latents, pos_embeds, neg_embeds))
control_model, control_images, control_modes, control_scales = None, None, None, None
if self.control is not None:
control_model, control_images, control_modes, control_scales = self._prepare_multi_control(
context=context,
vae=vae,
width=1024,
height=1024,
device=vae.device,
)
pipeline = BriaControlNetPipeline(
transformer=transformer,
scheduler=scheduler,
vae=vae,
text_encoder=t5_encoder,
tokenizer=t5_tokenizer,
controlnet=control_model,
)
pipeline.to(device=transformer.device, dtype=transformer.dtype)
latents = pipeline(
control_image=control_images,
control_mode=control_modes,
width=1024,
height=1024,
controlnet_conditioning_scale=control_scales,
num_inference_steps=self.num_steps,
max_sequence_length=128,
guidance_scale=self.guidance_scale,
latents=latents,
latent_image_ids=latent_image_ids,
text_ids=text_ids,
prompt_embeds=pos_embeds,
negative_prompt_embeds=neg_embeds,
output_type="latent",
)[0]
assert isinstance(latents, torch.Tensor)
saved_input_latents_tensor = context.tensors.save(latents)
latents_output = LatentsField(latents_name=saved_input_latents_tensor)
return BriaDenoiseInvocationOutput(latents=latents_output)
def _prepare_multi_control(
self,
context: InvocationContext,
vae: AutoencoderKL,
width: int,
height: int,
device: torch.device
) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]:
control = self.control if isinstance(self.control, list) else [self.control]
control_images, control_models, control_modes, control_scales = [], [], [], []
for controlnet in control:
if controlnet is not None:
control_models.append(context.models.load(controlnet.model).model)
control_modes.append(BriaControlModes[controlnet.mode].value)
control_scales.append(controlnet.conditioning_scale)
try:
control_images.append(context.images.get_pil(controlnet.image.image_name))
except Exception:
raise FileNotFoundError(f"Control image {controlnet.image.image_name} not found. Make sure not to delete the preprocessed image before finishing the pipeline.")
control_model = BriaMultiControlNetModel(control_models).to(device)
tensored_control_images, tensored_control_modes = prepare_control_images(
vae=vae,
control_images=control_images,
control_modes=control_modes,
width=width,
height=height,
device=device,
)
return control_model, tensored_control_images, tensored_control_modes, control_scales

View File

@@ -0,0 +1,76 @@
import torch
from invokeai.app.invocations.fields import Input, InputField, OutputField
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import (
BaseInvocationOutput,
FieldDescriptions,
LatentsField,
)
from invokeai.backend.bria.pipeline_bria_controlnet import prepare_latents
from invokeai.invocation_api import (
BaseInvocation,
Classification,
InvocationContext,
invocation,
invocation_output,
)
@invocation_output("bria_latent_sampler_output")
class BriaLatentSamplerInvocationOutput(BaseInvocationOutput):
"""Base class for nodes that output a CogView text conditioning tensor."""
latents: LatentsField = OutputField(description=FieldDescriptions.cond)
latent_image_ids: LatentsField = OutputField(description=FieldDescriptions.cond)
@invocation(
"bria_latent_sampler",
title="Latent Sampler - Bria",
tags=["image", "bria"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaLatentSamplerInvocation(BaseInvocation):
seed: int = InputField(
default=42,
title="Seed",
description="The seed to use for the latent sampler",
)
transformer: TransformerField = InputField(
description="Bria model (Transformer) to load",
input=Input.Connection,
title="Transformer",
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput:
with context.models.load(self.transformer.transformer) as transformer:
device = transformer.device
dtype = transformer.dtype
height, width = 1024, 1024
generator = torch.Generator(device=device).manual_seed(self.seed)
num_channels_latents = 4
latents, latent_image_ids = prepare_latents(
batch_size=1,
num_channels_latents=num_channels_latents,
height=height,
width=width,
dtype=dtype,
device=device,
generator=generator,
)
saved_latents_tensor = context.tensors.save(latents)
saved_latent_image_ids_tensor = context.tensors.save(latent_image_ids)
latents_output = LatentsField(latents_name=saved_latents_tensor)
latent_image_ids_output = LatentsField(latents_name=saved_latent_image_ids_tensor)
return BriaLatentSamplerInvocationOutput(
latents=latents_output,
latent_image_ids=latent_image_ids_output,
)

View File

@@ -0,0 +1,58 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import (
ModelIdentifierField,
SubModelType,
T5EncoderField,
TransformerField,
VAEField,
)
from invokeai.invocation_api import (
BaseInvocation,
BaseInvocationOutput,
Classification,
InvocationContext,
invocation,
invocation_output,
)
@invocation_output("bria_model_loader_output")
class BriaModelLoaderOutput(BaseInvocationOutput):
"""Bria base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation(
"bria_model_loader",
title="Main Model - Bria",
tags=["model", "bria"],
version="1.0.0",
classification=Classification.Prototype,
)
class BriaModelLoaderInvocation(BaseInvocation):
"""Loads a bria base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description="Bria model (Transformer) to load",
ui_type=UIType.BriaMainModel,
input=Input.Direct,
)
def invoke(self, context: InvocationContext) -> BriaModelLoaderOutput:
for key in [self.model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return BriaModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
t5_encoder=T5EncoderField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[]),
vae=VAEField(vae=vae),
)

View File

@@ -0,0 +1,93 @@
from typing import Optional
import torch
from transformers import (
T5EncoderModel,
T5TokenizerFast,
)
from invokeai.app.invocations.model import T5EncoderField
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions, Input, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.pipeline_bria_controlnet import encode_prompt
from invokeai.invocation_api import (
BaseInvocation,
Classification,
InputField,
LatentsField,
invocation,
invocation_output,
)
@invocation_output("bria_text_encoder_output")
class BriaTextEncoderInvocationOutput(BaseInvocationOutput):
"""Base class for nodes that output a CogView text conditioning tensor."""
pos_embeds: LatentsField = OutputField(description=FieldDescriptions.cond)
neg_embeds: LatentsField = OutputField(description=FieldDescriptions.cond)
text_ids: LatentsField = OutputField(description=FieldDescriptions.cond)
@invocation(
"bria_text_encoder",
title="Prompt - Bria",
tags=["prompt", "conditioning", "bria"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class BriaTextEncoderInvocation(BaseInvocation):
prompt: str = InputField(
title="Prompt",
description="The prompt to encode",
)
negative_prompt: Optional[str] = InputField(
title="Negative Prompt",
description="The negative prompt to encode",
default="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate",
)
max_length: int = InputField(
default=128,
title="Max Length",
description="The maximum length of the prompt",
)
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BriaTextEncoderInvocationOutput:
t5_encoder_info = context.models.load(self.t5_encoder.text_encoder)
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
with (
t5_encoder_info as text_encoder,
t5_tokenizer_info as tokenizer,
):
assert isinstance(tokenizer, T5TokenizerFast)
assert isinstance(text_encoder, T5EncoderModel)
(prompt_embeds, negative_prompt_embeds, text_ids) = encode_prompt(
prompt=self.prompt,
tokenizer=tokenizer,
text_encoder=text_encoder,
negative_prompt=self.negative_prompt,
device=text_encoder.device,
num_images_per_prompt=1,
max_sequence_length=self.max_length,
lora_scale=1.0,
)
saved_pos_tensor = context.tensors.save(prompt_embeds)
saved_neg_tensor = context.tensors.save(negative_prompt_embeds)
saved_text_ids_tensor = context.tensors.save(text_ids)
pos_embeds_output = LatentsField(latents_name=saved_pos_tensor)
neg_embeds_output = LatentsField(latents_name=saved_neg_tensor)
text_ids_output = LatentsField(latents_name=saved_text_ids_tensor)
return BriaTextEncoderInvocationOutput(
pos_embeds=pos_embeds_output,
neg_embeds=neg_embeds_output,
text_ids=text_ids_output,
)

View File

@@ -42,6 +42,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
MainModel = "MainModelField"
CogView4MainModel = "CogView4MainModelField"
FluxMainModel = "FluxMainModelField"
BriaMainModel = "BriaMainModelField"
BriaControlNetModel = "BriaControlNetModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"

View File

@@ -430,6 +430,15 @@ class FluxConditioningOutput(BaseInvocationOutput):
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("flux_conditioning_collection_output")
class FluxConditioningCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of conditioning tensors"""
collection: list[FluxConditioningField] = OutputField(
description="The output conditioning tensors",
)
@invocation_output("sd3_conditioning_output")
class SD3ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single SD3 conditioning tensor"""

View File

View File

@@ -0,0 +1,314 @@
import math
import os
from typing import List, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
from diffusers.utils import logging
from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_t5_prompt_embeds(
tokenizer: T5TokenizerFast,
text_encoder: T5EncoderModel,
prompt: Union[str, List[str], None] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
):
device = device or text_encoder.device
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = tokenizer(
prompt,
# padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
# Concat zeros to max_sequence
b, seq_len, dim = prompt_embeds.shape
if seq_len < max_sequence_length:
padding = torch.zeros(
(b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
)
prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
prompt_embeds = prompt_embeds.to(device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
# in order the get the same sigmas as in training and sample from them
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
sigmas = timesteps / num_train_timesteps
inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
new_sigmas = sigmas[inds]
return new_sigmas
def is_ng_none(negative_prompt):
return (
negative_prompt is None
or negative_prompt == ""
or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
or (isinstance(negative_prompt, list) and negative_prompt[0] == "")
)
class CudaTimerContext:
def __init__(self, times_arr):
self.times_arr = times_arr
def __enter__(self):
self.before_event = torch.cuda.Event(enable_timing=True)
self.after_event = torch.cuda.Event(enable_timing=True)
self.before_event.record()
def __exit__(self, type, value, traceback):
self.after_event.record()
torch.cuda.synchronize()
elapsed_time = self.before_event.elapsed_time(self.after_event) / 1000
self.times_arr.append(elapsed_time)
def get_env_prefix():
env = os.environ.get("CLOUD_PROVIDER", "AWS").upper()
if env == "AWS":
return "SM_CHANNEL"
elif env == "AZURE":
return "AZUREML_DATAREFERENCE"
raise Exception(f"Env {env} not supported")
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
def initialize_distributed():
# Initialize the process group for distributed training
dist.init_process_group("nccl")
# Get the current process's rank (ID) and the total number of processes (world size)
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"Initialized distributed training: Rank {rank}/{world_size}")
def get_clip_prompt_embeds(
text_encoder: CLIPTextModel,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 77,
device: Optional[torch.device] = None,
):
device = device or text_encoder.device
assert max_sequence_length == tokenizer.model_max_length
prompt = [prompt] if isinstance(prompt, str) else prompt
# Define tokenizers and text encoders
tokenizers = [tokenizer, tokenizer_2]
text_encoders = [text_encoder, text_encoder_2]
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
return prompt_embeds, pooled_prompt_embeds
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
linear_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the context extrapolation. Defaults to 1.0.
ntk_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
Otherwise, they are concateanted with themselves.
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
the dtype of the frequency tensor.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
theta = theta * ntk_factor
freqs = (
1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio, allegro
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin
else:
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
class FluxPosEmbed(torch.nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin

View File

@@ -0,0 +1,6 @@
__version__ = "0.0.9"
from invokeai.backend.bria.controlnet_aux.canny import CannyDetector as CannyDetector
from invokeai.backend.bria.controlnet_aux.open_pose import OpenposeDetector as OpenposeDetector
__all__ = ["CannyDetector", "OpenposeDetector"]

View File

@@ -0,0 +1,39 @@
import warnings
import cv2
import numpy as np
from PIL import Image
from invokeai.backend.bria.controlnet_aux.util import HWC3, resize_image
class CannyDetector:
def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, image_resolution=512, output_type=None, **kwargs):
if "img" in kwargs:
warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning, stacklevel=2)
input_image = kwargs.pop("img")
if input_image is None:
raise ValueError("input_image must be defined.")
if not isinstance(input_image, np.ndarray):
input_image = np.array(input_image, dtype=np.uint8)
output_type = output_type or "pil"
else:
output_type = output_type or "np"
input_image = HWC3(input_image)
input_image = resize_image(input_image, detect_resolution)
detected_map = cv2.Canny(input_image, low_threshold, high_threshold)
detected_map = HWC3(detected_map)
img = resize_image(input_image, image_resolution)
H, W, C = img.shape
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
return detected_map

View File

@@ -0,0 +1,108 @@
OPENPOSE: MULTIPERSON KEYPOINT DETECTION
SOFTWARE LICENSE AGREEMENT
ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY
BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.
This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor.
RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i).
CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication.
COPYRIGHT: The Software is owned by Licensor and is protected by United
States copyright laws and applicable international treaties and/or conventions.
PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto.
DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement.
BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies.
USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor.
You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software.
ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void.
TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below.
The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement.
FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement.
DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS.
SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement.
EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage.
EXPORT REGULATION: Licensee agrees to comply with any and all applicable
U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control.
SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby.
NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor.
GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania.
ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto.
************************************************************************
THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise.
1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/)
COPYRIGHT
All contributions by the University of California:
Copyright (c) 2014-2017 The Regents of the University of California (Regents)
All rights reserved.
All other contributions:
Copyright (c) 2014-2017, the respective contributors
All rights reserved.
Caffe uses a shared copyright model: each contributor holds copyright over
their contributions to Caffe. The project versioning records all such
contribution and copyright details. If a contributor wants to further mark
their specific copyright on a particular contribution, they should indicate
their copyright solely in the commit message of the change when it is
committed.
LICENSE
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
CONTRIBUTION AGREEMENT
By contributing to the BVLC/caffe repository through pull-request, comment,
or otherwise, the contributor releases their content to the
license and copyright terms herein.
************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********

View File

@@ -0,0 +1,233 @@
# Openpose
# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
# 3rd Edited by ControlNet
# 4th Edited by ControlNet (added face and correct hands)
# 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixs)
# This preprocessor is licensed by CMU for non-commercial use only.
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import warnings
from typing import List, NamedTuple, Tuple, Union
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from invokeai.backend.bria.controlnet_aux.open_pose import util
from invokeai.backend.bria.controlnet_aux.open_pose.body import Body, BodyResult, Keypoint
from invokeai.backend.bria.controlnet_aux.open_pose.face import Face
from invokeai.backend.bria.controlnet_aux.open_pose.hand import Hand
from invokeai.backend.bria.controlnet_aux.util import HWC3, resize_image
HandResult = List[Keypoint]
FaceResult = List[Keypoint]
class PoseResult(NamedTuple):
body: BodyResult
left_hand: Union[HandResult, None]
right_hand: Union[HandResult, None]
face: Union[FaceResult, None]
def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
"""
Draw the detected poses on an empty canvas.
Args:
poses (List[PoseResult]): A list of PoseResult objects containing the detected poses.
H (int): The height of the canvas.
W (int): The width of the canvas.
draw_body (bool, optional): Whether to draw body keypoints. Defaults to True.
draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True.
draw_face (bool, optional): Whether to draw face keypoints. Defaults to True.
Returns:
numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses.
"""
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
for pose in poses:
if draw_body:
canvas = util.draw_bodypose(canvas, pose.body.keypoints)
if draw_hand:
canvas = util.draw_handpose(canvas, pose.left_hand)
canvas = util.draw_handpose(canvas, pose.right_hand)
if draw_face:
canvas = util.draw_facepose(canvas, pose.face)
return canvas
class OpenposeDetector:
"""
A class for detecting human poses in images using the Openpose model.
Attributes:
model_dir (str): Path to the directory where the pose models are stored.
"""
def __init__(self, body_estimation, hand_estimation=None, face_estimation=None):
self.body_estimation = body_estimation
self.hand_estimation = hand_estimation
self.face_estimation = face_estimation
@classmethod
def from_pretrained(cls, pretrained_model_or_path, filename=None, hand_filename=None, face_filename=None, cache_dir=None, local_files_only=False):
if pretrained_model_or_path == "lllyasviel/ControlNet":
filename = filename or "annotator/ckpts/body_pose_model.pth"
hand_filename = hand_filename or "annotator/ckpts/hand_pose_model.pth"
face_filename = face_filename or "facenet.pth"
face_pretrained_model_or_path = "lllyasviel/Annotators"
else:
filename = filename or "body_pose_model.pth"
hand_filename = hand_filename or "hand_pose_model.pth"
face_filename = face_filename or "facenet.pth"
face_pretrained_model_or_path = pretrained_model_or_path
if os.path.isdir(pretrained_model_or_path):
body_model_path = os.path.join(pretrained_model_or_path, filename)
hand_model_path = os.path.join(pretrained_model_or_path, hand_filename)
face_model_path = os.path.join(face_pretrained_model_or_path, face_filename)
else:
body_model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)
hand_model_path = hf_hub_download(pretrained_model_or_path, hand_filename, cache_dir=cache_dir, local_files_only=local_files_only)
face_model_path = hf_hub_download(face_pretrained_model_or_path, face_filename, cache_dir=cache_dir, local_files_only=local_files_only)
body_estimation = Body(body_model_path)
hand_estimation = Hand(hand_model_path)
face_estimation = Face(face_model_path)
return cls(body_estimation, hand_estimation, face_estimation)
def to(self, device):
self.body_estimation.to(device)
self.hand_estimation.to(device)
self.face_estimation.to(device)
return self
def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]:
left_hand = None
right_hand = None
H, W, _ = oriImg.shape
for x, y, w, is_left in util.handDetect(body, oriImg):
peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]).astype(np.float32)
if peaks.ndim == 2 and peaks.shape[1] == 2:
peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
hand_result = [
Keypoint(x=peak[0], y=peak[1])
for peak in peaks
]
if is_left:
left_hand = hand_result
else:
right_hand = hand_result
return left_hand, right_hand
def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]:
face = util.faceDetect(body, oriImg)
if face is None:
return None
x, y, w = face
H, W, _ = oriImg.shape
heatmaps = self.face_estimation(oriImg[y:y+w, x:x+w, :])
peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32)
if peaks.ndim == 2 and peaks.shape[1] == 2:
peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
return [
Keypoint(x=peak[0], y=peak[1])
for peak in peaks
]
return None
def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]:
"""
Detect poses in the given image.
Args:
oriImg (numpy.ndarray): The input image for pose detection.
include_hand (bool, optional): Whether to include hand detection. Defaults to False.
include_face (bool, optional): Whether to include face detection. Defaults to False.
Returns:
List[PoseResult]: A list of PoseResult objects containing the detected poses.
"""
oriImg = oriImg[:, :, ::-1].copy()
H, W, C = oriImg.shape
with torch.no_grad():
candidate, subset = self.body_estimation(oriImg)
bodies = self.body_estimation.format_body_result(candidate, subset)
results = []
for body in bodies:
left_hand, right_hand, face = (None,) * 3
if include_hand:
left_hand, right_hand = self.detect_hands(body, oriImg)
if include_face:
face = self.detect_face(body, oriImg)
results.append(PoseResult(BodyResult(
keypoints=[
Keypoint(
x=keypoint.x / float(W),
y=keypoint.y / float(H)
) if keypoint is not None else None
for keypoint in body.keypoints
],
total_score=body.total_score,
total_parts=body.total_parts
), left_hand, right_hand, face))
return results
def __call__(self, input_image, detect_resolution=512, image_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", **kwargs):
if hand_and_face is not None:
warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning, stacklevel=2)
include_hand = hand_and_face
include_face = hand_and_face
if "return_pil" in kwargs:
warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning, stacklevel=2)
output_type = "pil" if kwargs["return_pil"] else "np"
if type(output_type) is bool:
warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions", stacklevel=2)
if output_type:
output_type = "pil"
if not isinstance(input_image, np.ndarray):
input_image = np.array(input_image, dtype=np.uint8)
input_image = HWC3(input_image)
input_image = resize_image(input_image, detect_resolution)
H, W, C = input_image.shape
poses = self.detect_poses(input_image, include_hand, include_face)
canvas = draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)
detected_map = canvas
detected_map = HWC3(detected_map)
img = resize_image(input_image, image_resolution)
H, W, C = img.shape
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
return detected_map

View File

@@ -0,0 +1,259 @@
import math
from typing import List, NamedTuple, Union
import numpy as np
import torch
from scipy.ndimage.filters import gaussian_filter
from invokeai.backend.bria.controlnet_aux.open_pose import util
from invokeai.backend.bria.controlnet_aux.open_pose.model import bodypose_model
class Keypoint(NamedTuple):
x: float
y: float
score: float = 1.0
id: int = -1
class BodyResult(NamedTuple):
# Note: Using `Union` instead of `|` operator as the ladder is a Python
# 3.10 feature.
# Annotator code should be Python 3.8 Compatible, as controlnet repo uses
# Python 3.8 environment.
# https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
keypoints: List[Union[Keypoint, None]]
total_score: float
total_parts: int
class Body(object):
def __init__(self, model_path):
self.model = bodypose_model()
model_dict = util.transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
def to(self, device):
self.model.to(device)
return self
def __call__(self, oriImg):
device = next(iter(self.model.parameters())).device
# scale_search = [0.5, 1.0, 1.5, 2.0]
scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre1 = 0.1
thre2 = 0.05
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale)
imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
im = np.ascontiguousarray(im)
data = torch.from_numpy(im).float()
data = data.to(device)
# data = data.permute([2, 0, 1]).unsqueeze(0).float()
with torch.no_grad():
Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
# extract outputs, resize, and remove padding
# heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps
heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps
heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1]))
# paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs
paf = util.smart_resize_k(paf, fx=stride, fy=stride)
paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1]))
heatmap_avg += heatmap_avg + heatmap / len(multiplier)
paf_avg += + paf / len(multiplier)
all_peaks = []
peak_counter = 0
for part in range(18):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)
map_left = np.zeros(one_heatmap.shape)
map_left[1:, :] = one_heatmap[:-1, :]
map_right = np.zeros(one_heatmap.shape)
map_right[:-1, :] = one_heatmap[1:, :]
map_up = np.zeros(one_heatmap.shape)
map_up[:, 1:] = one_heatmap[:, :-1]
map_down = np.zeros(one_heatmap.shape)
map_down[:, :-1] = one_heatmap[:, 1:]
peaks_binary = np.logical_and.reduce(
(one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0], strict=False)) # note reverse
peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
peak_id = range(peak_counter, peak_counter + len(peaks))
peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
all_peaks.append(peaks_with_score_and_id)
peak_counter += len(peaks)
# find connection in the specified sequence, center 29 is in the position 15
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
[1, 16], [16, 18], [3, 17], [6, 18]]
# the middle joints heatmap correpondence
mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
[23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
[55, 56], [37, 38], [45, 46]]
connection_all = []
special_k = []
mid_num = 10
for k in range(len(mapIdx)):
score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
candA = all_peaks[limbSeq[k][0] - 1]
candB = all_peaks[limbSeq[k][1] - 1]
nA = len(candA)
nB = len(candB)
indexA, indexB = limbSeq[k]
if (nA != 0 and nB != 0):
connection_candidate = []
for i in range(nA):
for j in range(nB):
vec = np.subtract(candB[j][:2], candA[i][:2])
norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
norm = max(0.001, norm)
vec = np.divide(vec, norm)
startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
np.linspace(candA[i][1], candB[j][1], num=mid_num), strict=False))
vec_x = np.array([score_mid[int(round(startend[i][1])), int(round(startend[i][0])), 0] \
for i in range(len(startend))])
vec_y = np.array([score_mid[int(round(startend[i][1])), int(round(startend[i][0])), 1] \
for i in range(len(startend))])
score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
0.5 * oriImg.shape[0] / norm - 1, 0)
criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
criterion2 = score_with_dist_prior > 0
if criterion1 and criterion2:
connection_candidate.append(
[i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
connection = np.zeros((0, 5))
for c in range(len(connection_candidate)):
i, j, s = connection_candidate[c][0:3]
if (i not in connection[:, 3] and j not in connection[:, 4]):
connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
if (len(connection) >= min(nA, nB)):
break
connection_all.append(connection)
else:
special_k.append(k)
connection_all.append([])
# last number in each row is the total parts number of that person
# the second last number in each row is the score of the overall configuration
subset = -1 * np.ones((0, 20))
candidate = np.array([item for sublist in all_peaks for item in sublist])
for k in range(len(mapIdx)):
if k not in special_k:
partAs = connection_all[k][:, 0]
partBs = connection_all[k][:, 1]
indexA, indexB = np.array(limbSeq[k]) - 1
for i in range(len(connection_all[k])): # = 1:size(temp,1)
found = 0
subset_idx = [-1, -1]
for j in range(len(subset)): # 1:size(subset,1):
if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
subset_idx[found] = j
found += 1
if found == 1:
j = subset_idx[0]
if subset[j][indexB] != partBs[i]:
subset[j][indexB] = partBs[i]
subset[j][-1] += 1
subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
elif found == 2: # if found 2 and disjoint, merge them
j1, j2 = subset_idx
membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
if len(np.nonzero(membership == 2)[0]) == 0: # merge
subset[j1][:-2] += (subset[j2][:-2] + 1)
subset[j1][-2:] += subset[j2][-2:]
subset[j1][-2] += connection_all[k][i][2]
subset = np.delete(subset, j2, 0)
else: # as like found == 1
subset[j1][indexB] = partBs[i]
subset[j1][-1] += 1
subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
# if find no partA in the subset, create a new subset
elif not found and k < 17:
row = -1 * np.ones(20)
row[indexA] = partAs[i]
row[indexB] = partBs[i]
row[-1] = 2
row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
subset = np.vstack([subset, row])
# delete some rows of subset which has few parts occur
deleteIdx = []
for i in range(len(subset)):
if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
deleteIdx.append(i)
subset = np.delete(subset, deleteIdx, axis=0)
# subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
# candidate: x, y, score, id
return candidate, subset
@staticmethod
def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]:
"""
Format the body results from the candidate and subset arrays into a list of BodyResult objects.
Args:
candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id
for each body part.
subset (np.ndarray): An array of subsets containing indices to the candidate array for each
person detected. The last two columns of each row hold the total score and total parts
of the person.
Returns:
List[BodyResult]: A list of BodyResult objects, where each object represents a person with
detected keypoints, total score, and total parts.
"""
return [
BodyResult(
keypoints=[
Keypoint(
x=candidate[candidate_index][0],
y=candidate[candidate_index][1],
score=candidate[candidate_index][2],
id=candidate[candidate_index][3]
) if candidate_index != -1 else None
for candidate_index in person[:18].astype(int)
],
total_score=person[18],
total_parts=person[19]
)
for person in subset
]

View File

@@ -0,0 +1,364 @@
import logging
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Conv2d, MaxPool2d, Module, ReLU, init
from torchvision.transforms import ToPILImage, ToTensor
from invokeai.backend.bria.controlnet_aux.open_pose import util
class FaceNet(Module):
"""Model the cascading heatmaps. """
def __init__(self):
super(FaceNet, self).__init__()
# cnn to make feature map
self.relu = ReLU()
self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2)
self.conv1_1 = Conv2d(in_channels=3, out_channels=64,
kernel_size=3, stride=1, padding=1)
self.conv1_2 = Conv2d(
in_channels=64, out_channels=64, kernel_size=3, stride=1,
padding=1)
self.conv2_1 = Conv2d(
in_channels=64, out_channels=128, kernel_size=3, stride=1,
padding=1)
self.conv2_2 = Conv2d(
in_channels=128, out_channels=128, kernel_size=3, stride=1,
padding=1)
self.conv3_1 = Conv2d(
in_channels=128, out_channels=256, kernel_size=3, stride=1,
padding=1)
self.conv3_2 = Conv2d(
in_channels=256, out_channels=256, kernel_size=3, stride=1,
padding=1)
self.conv3_3 = Conv2d(
in_channels=256, out_channels=256, kernel_size=3, stride=1,
padding=1)
self.conv3_4 = Conv2d(
in_channels=256, out_channels=256, kernel_size=3, stride=1,
padding=1)
self.conv4_1 = Conv2d(
in_channels=256, out_channels=512, kernel_size=3, stride=1,
padding=1)
self.conv4_2 = Conv2d(
in_channels=512, out_channels=512, kernel_size=3, stride=1,
padding=1)
self.conv4_3 = Conv2d(
in_channels=512, out_channels=512, kernel_size=3, stride=1,
padding=1)
self.conv4_4 = Conv2d(
in_channels=512, out_channels=512, kernel_size=3, stride=1,
padding=1)
self.conv5_1 = Conv2d(
in_channels=512, out_channels=512, kernel_size=3, stride=1,
padding=1)
self.conv5_2 = Conv2d(
in_channels=512, out_channels=512, kernel_size=3, stride=1,
padding=1)
self.conv5_3_CPM = Conv2d(
in_channels=512, out_channels=128, kernel_size=3, stride=1,
padding=1)
# stage1
self.conv6_1_CPM = Conv2d(
in_channels=128, out_channels=512, kernel_size=1, stride=1,
padding=0)
self.conv6_2_CPM = Conv2d(
in_channels=512, out_channels=71, kernel_size=1, stride=1,
padding=0)
# stage2
self.Mconv1_stage2 = Conv2d(
in_channels=199, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv2_stage2 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv3_stage2 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv4_stage2 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv5_stage2 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv6_stage2 = Conv2d(
in_channels=128, out_channels=128, kernel_size=1, stride=1,
padding=0)
self.Mconv7_stage2 = Conv2d(
in_channels=128, out_channels=71, kernel_size=1, stride=1,
padding=0)
# stage3
self.Mconv1_stage3 = Conv2d(
in_channels=199, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv2_stage3 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv3_stage3 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv4_stage3 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv5_stage3 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv6_stage3 = Conv2d(
in_channels=128, out_channels=128, kernel_size=1, stride=1,
padding=0)
self.Mconv7_stage3 = Conv2d(
in_channels=128, out_channels=71, kernel_size=1, stride=1,
padding=0)
# stage4
self.Mconv1_stage4 = Conv2d(
in_channels=199, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv2_stage4 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv3_stage4 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv4_stage4 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv5_stage4 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv6_stage4 = Conv2d(
in_channels=128, out_channels=128, kernel_size=1, stride=1,
padding=0)
self.Mconv7_stage4 = Conv2d(
in_channels=128, out_channels=71, kernel_size=1, stride=1,
padding=0)
# stage5
self.Mconv1_stage5 = Conv2d(
in_channels=199, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv2_stage5 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv3_stage5 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv4_stage5 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv5_stage5 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv6_stage5 = Conv2d(
in_channels=128, out_channels=128, kernel_size=1, stride=1,
padding=0)
self.Mconv7_stage5 = Conv2d(
in_channels=128, out_channels=71, kernel_size=1, stride=1,
padding=0)
# stage6
self.Mconv1_stage6 = Conv2d(
in_channels=199, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv2_stage6 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv3_stage6 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv4_stage6 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv5_stage6 = Conv2d(
in_channels=128, out_channels=128, kernel_size=7, stride=1,
padding=3)
self.Mconv6_stage6 = Conv2d(
in_channels=128, out_channels=128, kernel_size=1, stride=1,
padding=0)
self.Mconv7_stage6 = Conv2d(
in_channels=128, out_channels=71, kernel_size=1, stride=1,
padding=0)
for m in self.modules():
if isinstance(m, Conv2d):
init.constant_(m.bias, 0)
def forward(self, x):
"""Return a list of heatmaps."""
heatmaps = []
h = self.relu(self.conv1_1(x))
h = self.relu(self.conv1_2(h))
h = self.max_pooling_2d(h)
h = self.relu(self.conv2_1(h))
h = self.relu(self.conv2_2(h))
h = self.max_pooling_2d(h)
h = self.relu(self.conv3_1(h))
h = self.relu(self.conv3_2(h))
h = self.relu(self.conv3_3(h))
h = self.relu(self.conv3_4(h))
h = self.max_pooling_2d(h)
h = self.relu(self.conv4_1(h))
h = self.relu(self.conv4_2(h))
h = self.relu(self.conv4_3(h))
h = self.relu(self.conv4_4(h))
h = self.relu(self.conv5_1(h))
h = self.relu(self.conv5_2(h))
h = self.relu(self.conv5_3_CPM(h))
feature_map = h
# stage1
h = self.relu(self.conv6_1_CPM(h))
h = self.conv6_2_CPM(h)
heatmaps.append(h)
# stage2
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage2(h))
h = self.relu(self.Mconv2_stage2(h))
h = self.relu(self.Mconv3_stage2(h))
h = self.relu(self.Mconv4_stage2(h))
h = self.relu(self.Mconv5_stage2(h))
h = self.relu(self.Mconv6_stage2(h))
h = self.Mconv7_stage2(h)
heatmaps.append(h)
# stage3
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage3(h))
h = self.relu(self.Mconv2_stage3(h))
h = self.relu(self.Mconv3_stage3(h))
h = self.relu(self.Mconv4_stage3(h))
h = self.relu(self.Mconv5_stage3(h))
h = self.relu(self.Mconv6_stage3(h))
h = self.Mconv7_stage3(h)
heatmaps.append(h)
# stage4
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage4(h))
h = self.relu(self.Mconv2_stage4(h))
h = self.relu(self.Mconv3_stage4(h))
h = self.relu(self.Mconv4_stage4(h))
h = self.relu(self.Mconv5_stage4(h))
h = self.relu(self.Mconv6_stage4(h))
h = self.Mconv7_stage4(h)
heatmaps.append(h)
# stage5
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage5(h))
h = self.relu(self.Mconv2_stage5(h))
h = self.relu(self.Mconv3_stage5(h))
h = self.relu(self.Mconv4_stage5(h))
h = self.relu(self.Mconv5_stage5(h))
h = self.relu(self.Mconv6_stage5(h))
h = self.Mconv7_stage5(h)
heatmaps.append(h)
# stage6
h = torch.cat([h, feature_map], dim=1) # channel concat
h = self.relu(self.Mconv1_stage6(h))
h = self.relu(self.Mconv2_stage6(h))
h = self.relu(self.Mconv3_stage6(h))
h = self.relu(self.Mconv4_stage6(h))
h = self.relu(self.Mconv5_stage6(h))
h = self.relu(self.Mconv6_stage6(h))
h = self.Mconv7_stage6(h)
heatmaps.append(h)
return heatmaps
LOG = logging.getLogger(__name__)
TOTEN = ToTensor()
TOPIL = ToPILImage()
params = {
'gaussian_sigma': 2.5,
'inference_img_size': 736, # 368, 736, 1312
'heatmap_peak_thresh': 0.1,
'crop_scale': 1.5,
'line_indices': [
[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6],
[6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13],
[13, 14], [14, 15], [15, 16],
[17, 18], [18, 19], [19, 20], [20, 21],
[22, 23], [23, 24], [24, 25], [25, 26],
[27, 28], [28, 29], [29, 30],
[31, 32], [32, 33], [33, 34], [34, 35],
[36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36],
[42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42],
[48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54],
[54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48],
[60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66],
[66, 67], [67, 60]
],
}
class Face(object):
"""
The OpenPose face landmark detector model.
Args:
inference_size: set the size of the inference image size, suggested:
368, 736, 1312, default 736
gaussian_sigma: blur the heatmaps, default 2.5
heatmap_peak_thresh: return landmark if over threshold, default 0.1
"""
def __init__(self, face_model_path,
inference_size=None,
gaussian_sigma=None,
heatmap_peak_thresh=None):
self.inference_size = inference_size or params["inference_img_size"]
self.sigma = gaussian_sigma or params['gaussian_sigma']
self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"]
self.model = FaceNet()
self.model.load_state_dict(torch.load(face_model_path))
self.model.eval()
def to(self, device):
self.model.to(device)
return self
def __call__(self, face_img):
device = next(iter(self.model.parameters())).device
H, W, C = face_img.shape
w_size = 384
x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5
x_data = x_data.to(device)
with torch.no_grad():
hs = self.model(x_data[None, ...])
heatmaps = F.interpolate(
hs[-1],
(H, W),
mode='bilinear', align_corners=True).cpu().numpy()[0]
return heatmaps
def compute_peaks_from_heatmaps(self, heatmaps):
all_peaks = []
for part in range(heatmaps.shape[0]):
map_ori = heatmaps[part].copy()
binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8)
if np.sum(binary) == 0:
continue
positions = np.where(binary > 0.5)
intensities = map_ori[positions]
mi = np.argmax(intensities)
y, x = positions[0][mi], positions[1][mi]
all_peaks.append([x, y])
return np.array(all_peaks)

View File

@@ -0,0 +1,90 @@
import cv2
import numpy as np
import torch
from scipy.ndimage.filters import gaussian_filter
from skimage.measure import label
from invokeai.backend.bria.controlnet_aux.open_pose import util
from invokeai.backend.bria.controlnet_aux.open_pose.model import handpose_model
class Hand(object):
def __init__(self, model_path):
self.model = handpose_model()
model_dict = util.transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
def to(self, device):
self.model.to(device)
return self
def __call__(self, oriImgRaw):
device = next(iter(self.model.parameters())).device
scale_search = [0.5, 1.0, 1.5, 2.0]
# scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre = 0.05
multiplier = [x * boxsize for x in scale_search]
wsize = 128
heatmap_avg = np.zeros((wsize, wsize, 22))
Hr, Wr, Cr = oriImgRaw.shape
oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8)
for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = util.smart_resize(oriImg, (scale, scale))
imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
im = np.ascontiguousarray(im)
data = torch.from_numpy(im).float()
data = data.to(device)
with torch.no_grad():
output = self.model(data).cpu().numpy()
# extract outputs, resize, and remove padding
heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps
heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
heatmap = util.smart_resize(heatmap, (wsize, wsize))
heatmap_avg += heatmap / len(multiplier)
all_peaks = []
for part in range(21):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)
binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
if np.sum(binary) == 0:
all_peaks.append([0, 0])
continue
label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
label_img[label_img != max_index] = 0
map_ori[label_img == 0] = 0
y, x = util.npmax(map_ori)
y = int(float(y) * float(Hr) / float(wsize))
x = int(float(x) * float(Wr) / float(wsize))
all_peaks.append([x, y])
return np.array(all_peaks)
if __name__ == "__main__":
hand_estimation = Hand('../model/hand_pose_model.pth')
# test_image = '../images/hand.jpg'
test_image = '../images/hand.jpg'
oriImg = cv2.imread(test_image) # B,G,R order
peaks = hand_estimation(oriImg)
canvas = util.draw_handpose(oriImg, peaks, True)
cv2.imshow('', canvas)
cv2.waitKey(0)

View File

@@ -0,0 +1,217 @@
from collections import OrderedDict
import torch
import torch.nn as nn
def make_layers(block, no_relu_layers):
layers = []
for layer_name, v in block.items():
if 'pool' in layer_name:
layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
padding=v[2])
layers.append((layer_name, layer))
else:
conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
kernel_size=v[2], stride=v[3],
padding=v[4])
layers.append((layer_name, conv2d))
if layer_name not in no_relu_layers:
layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))
return nn.Sequential(OrderedDict(layers))
class bodypose_model(nn.Module):
def __init__(self):
super(bodypose_model, self).__init__()
# these layers have no relu layer
no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1']
blocks = {}
block0 = OrderedDict([
('conv1_1', [3, 64, 3, 1, 1]),
('conv1_2', [64, 64, 3, 1, 1]),
('pool1_stage1', [2, 2, 0]),
('conv2_1', [64, 128, 3, 1, 1]),
('conv2_2', [128, 128, 3, 1, 1]),
('pool2_stage1', [2, 2, 0]),
('conv3_1', [128, 256, 3, 1, 1]),
('conv3_2', [256, 256, 3, 1, 1]),
('conv3_3', [256, 256, 3, 1, 1]),
('conv3_4', [256, 256, 3, 1, 1]),
('pool3_stage1', [2, 2, 0]),
('conv4_1', [256, 512, 3, 1, 1]),
('conv4_2', [512, 512, 3, 1, 1]),
('conv4_3_CPM', [512, 256, 3, 1, 1]),
('conv4_4_CPM', [256, 128, 3, 1, 1])
])
# Stage 1
block1_1 = OrderedDict([
('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
])
block1_2 = OrderedDict([
('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
])
blocks['block1_1'] = block1_1
blocks['block1_2'] = block1_2
self.model0 = make_layers(block0, no_relu_layers)
# Stages 2 - 6
for i in range(2, 7):
blocks['block%d_1' % i] = OrderedDict([
('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
])
blocks['block%d_2' % i] = OrderedDict([
('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
])
for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_1 = blocks['block1_1']
self.model2_1 = blocks['block2_1']
self.model3_1 = blocks['block3_1']
self.model4_1 = blocks['block4_1']
self.model5_1 = blocks['block5_1']
self.model6_1 = blocks['block6_1']
self.model1_2 = blocks['block1_2']
self.model2_2 = blocks['block2_2']
self.model3_2 = blocks['block3_2']
self.model4_2 = blocks['block4_2']
self.model5_2 = blocks['block5_2']
self.model6_2 = blocks['block6_2']
def forward(self, x):
out1 = self.model0(x)
out1_1 = self.model1_1(out1)
out1_2 = self.model1_2(out1)
out2 = torch.cat([out1_1, out1_2, out1], 1)
out2_1 = self.model2_1(out2)
out2_2 = self.model2_2(out2)
out3 = torch.cat([out2_1, out2_2, out1], 1)
out3_1 = self.model3_1(out3)
out3_2 = self.model3_2(out3)
out4 = torch.cat([out3_1, out3_2, out1], 1)
out4_1 = self.model4_1(out4)
out4_2 = self.model4_2(out4)
out5 = torch.cat([out4_1, out4_2, out1], 1)
out5_1 = self.model5_1(out5)
out5_2 = self.model5_2(out5)
out6 = torch.cat([out5_1, out5_2, out1], 1)
out6_1 = self.model6_1(out6)
out6_2 = self.model6_2(out6)
return out6_1, out6_2
class handpose_model(nn.Module):
def __init__(self):
super(handpose_model, self).__init__()
# these layers have no relu layer
no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
# stage 1
block1_0 = OrderedDict([
('conv1_1', [3, 64, 3, 1, 1]),
('conv1_2', [64, 64, 3, 1, 1]),
('pool1_stage1', [2, 2, 0]),
('conv2_1', [64, 128, 3, 1, 1]),
('conv2_2', [128, 128, 3, 1, 1]),
('pool2_stage1', [2, 2, 0]),
('conv3_1', [128, 256, 3, 1, 1]),
('conv3_2', [256, 256, 3, 1, 1]),
('conv3_3', [256, 256, 3, 1, 1]),
('conv3_4', [256, 256, 3, 1, 1]),
('pool3_stage1', [2, 2, 0]),
('conv4_1', [256, 512, 3, 1, 1]),
('conv4_2', [512, 512, 3, 1, 1]),
('conv4_3', [512, 512, 3, 1, 1]),
('conv4_4', [512, 512, 3, 1, 1]),
('conv5_1', [512, 512, 3, 1, 1]),
('conv5_2', [512, 512, 3, 1, 1]),
('conv5_3_CPM', [512, 128, 3, 1, 1])
])
block1_1 = OrderedDict([
('conv6_1_CPM', [128, 512, 1, 1, 0]),
('conv6_2_CPM', [512, 22, 1, 1, 0])
])
blocks = {}
blocks['block1_0'] = block1_0
blocks['block1_1'] = block1_1
# stage 2-6
for i in range(2, 7):
blocks['block%d' % i] = OrderedDict([
('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
])
for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)
self.model1_0 = blocks['block1_0']
self.model1_1 = blocks['block1_1']
self.model2 = blocks['block2']
self.model3 = blocks['block3']
self.model4 = blocks['block4']
self.model5 = blocks['block5']
self.model6 = blocks['block6']
def forward(self, x):
out1_0 = self.model1_0(x)
out1_1 = self.model1_1(out1_0)
concat_stage2 = torch.cat([out1_1, out1_0], 1)
out_stage2 = self.model2(concat_stage2)
concat_stage3 = torch.cat([out_stage2, out1_0], 1)
out_stage3 = self.model3(concat_stage3)
concat_stage4 = torch.cat([out_stage3, out1_0], 1)
out_stage4 = self.model4(concat_stage4)
concat_stage5 = torch.cat([out_stage4, out1_0], 1)
out_stage5 = self.model5(concat_stage5)
concat_stage6 = torch.cat([out_stage5, out1_0], 1)
out_stage6 = self.model6(concat_stage6)
return out_stage6

View File

@@ -0,0 +1,388 @@
import math
from typing import List, Tuple, Union
import cv2
import numpy as np
from invokeai.backend.bria.controlnet_aux.open_pose.body import BodyResult, Keypoint
eps = 0.01
def smart_resize(x, s):
Ht, Wt = s
if x.ndim == 2:
Ho, Wo = x.shape
Co = 1
else:
Ho, Wo, Co = x.shape
if Co == 3 or Co == 1:
k = float(Ht + Wt) / float(Ho + Wo)
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
else:
return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
def smart_resize_k(x, fx, fy):
if x.ndim == 2:
Ho, Wo = x.shape
Co = 1
else:
Ho, Wo, Co = x.shape
Ht, Wt = Ho * fy, Wo * fx
if Co == 3 or Co == 1:
k = float(Ht + Wt) / float(Ho + Wo)
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
else:
return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
def padRightDownCorner(img, stride, padValue):
h = img.shape[0]
w = img.shape[1]
pad = 4 * [None]
pad[0] = 0 # up
pad[1] = 0 # left
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
img_padded = img
pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
img_padded = np.concatenate((pad_up, img_padded), axis=0)
pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
img_padded = np.concatenate((pad_left, img_padded), axis=1)
pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
img_padded = np.concatenate((img_padded, pad_down), axis=0)
pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
img_padded = np.concatenate((img_padded, pad_right), axis=1)
return img_padded, pad
def transfer(model, model_weights):
transfered_model_weights = {}
for weights_name in model.state_dict().keys():
transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
return transfered_model_weights
def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray:
"""
Draw keypoints and limbs representing body pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose.
keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
H, W, C = canvas.shape
stickwidth = 4
limbSeq = [
[2, 3], [2, 6], [3, 4], [4, 5],
[6, 7], [7, 8], [2, 9], [9, 10],
[10, 11], [2, 12], [12, 13], [13, 14],
[2, 1], [1, 15], [15, 17], [1, 16],
[16, 18],
]
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
for (k1_index, k2_index), color in zip(limbSeq, colors, strict=False):
keypoint1 = keypoints[k1_index - 1]
keypoint2 = keypoints[k2_index - 1]
if keypoint1 is None or keypoint2 is None:
continue
Y = np.array([keypoint1.x, keypoint2.x]) * float(W)
X = np.array([keypoint1.y, keypoint2.y]) * float(H)
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color])
for keypoint, color in zip(keypoints, colors, strict=False):
if keypoint is None:
continue
x, y = keypoint.x, keypoint.y
x = int(x * W)
y = int(y * H)
cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
return canvas
def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
import matplotlib
"""
Draw keypoints and connections representing hand pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
or None if no keypoints are present.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
if not keypoints:
return canvas
H, W, C = canvas.shape
edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
[10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
for ie, (e1, e2) in enumerate(edges):
k1 = keypoints[e1]
k2 = keypoints[e2]
if k1 is None or k2 is None:
continue
x1 = int(k1.x * W)
y1 = int(k1.y * H)
x2 = int(k2.x * W)
y2 = int(k2.y * H)
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
for keypoint in keypoints:
x, y = keypoint.x, keypoint.y
x = int(x * W)
y = int(y * H)
if x > eps and y > eps:
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
return canvas
def draw_facepose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
"""
Draw keypoints representing face pose on a given canvas.
Args:
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the face pose.
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the face keypoints to be drawn
or None if no keypoints are present.
Returns:
np.ndarray: A 3D numpy array representing the modified canvas with the drawn face pose.
Note:
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
"""
if not keypoints:
return canvas
H, W, C = canvas.shape
for keypoint in keypoints:
x, y = keypoint.x, keypoint.y
x = int(x * W)
y = int(y * H)
if x > eps and y > eps:
cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
return canvas
# detect hand according to body pose keypoints
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]:
"""
Detect hands in the input body pose keypoints and calculate the bounding box for each hand.
Args:
body (BodyResult): A BodyResult object containing the detected body pose keypoints.
oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
Returns:
List[Tuple[int, int, int, bool]]: A list of tuples, each containing the coordinates (x, y) of the top-left
corner of the bounding box, the width (height) of the bounding box, and
a boolean flag indicating whether the hand is a left hand (True) or a
right hand (False).
Notes:
- The width and height of the bounding boxes are equal since the network requires squared input.
- The minimum bounding box size is 20 pixels.
"""
ratioWristElbow = 0.33
detect_result = []
image_height, image_width = oriImg.shape[0:2]
keypoints = body.keypoints
# right hand: wrist 4, elbow 3, shoulder 2
# left hand: wrist 7, elbow 6, shoulder 5
left_shoulder = keypoints[5]
left_elbow = keypoints[6]
left_wrist = keypoints[7]
right_shoulder = keypoints[2]
right_elbow = keypoints[3]
right_wrist = keypoints[4]
# if any of three not detected
has_left = all(keypoint is not None for keypoint in (left_shoulder, left_elbow, left_wrist))
has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist))
if not (has_left or has_right):
return []
hands = []
#left hand
if has_left:
hands.append([
left_shoulder.x, left_shoulder.y,
left_elbow.x, left_elbow.y,
left_wrist.x, left_wrist.y,
True
])
# right hand
if has_right:
hands.append([
right_shoulder.x, right_shoulder.y,
right_elbow.x, right_elbow.y,
right_wrist.x, right_wrist.y,
False
])
for x1, y1, x2, y2, x3, y3, is_left in hands:
# pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
# handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
# handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
# const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
# const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
# handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
x = x3 + ratioWristElbow * (x3 - x2)
y = y3 + ratioWristElbow * (y3 - y2)
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
# x-y refers to the center --> offset to topLeft point
# handRectangle.x -= handRectangle.width / 2.f;
# handRectangle.y -= handRectangle.height / 2.f;
x -= width / 2
y -= width / 2 # width = height
# overflow the image
if x < 0:
x = 0
if y < 0:
y = 0
width1 = width
width2 = width
if x + width > image_width:
width1 = image_width - x
if y + width > image_height:
width2 = image_height - y
width = min(width1, width2)
# the max hand box value is 20 pixels
if width >= 20:
detect_result.append((int(x), int(y), int(width), is_left))
'''
return value: [[x, y, w, True if left hand else False]].
width=height since the network require squared input.
x, y is the coordinate of top left.
'''
return detect_result
# Written by Lvmin
def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]:
"""
Detect the face in the input body pose keypoints and calculate the bounding box for the face.
Args:
body (BodyResult): A BodyResult object containing the detected body pose keypoints.
oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
Returns:
Tuple[int, int, int] | None: A tuple containing the coordinates (x, y) of the top-left corner of the
bounding box and the width (height) of the bounding box, or None if the
face is not detected or the bounding box width is less than 20 pixels.
Notes:
- The width and height of the bounding box are equal.
- The minimum bounding box size is 20 pixels.
"""
# left right eye ear 14 15 16 17
image_height, image_width = oriImg.shape[0:2]
keypoints = body.keypoints
head = keypoints[0]
left_eye = keypoints[14]
right_eye = keypoints[15]
left_ear = keypoints[16]
right_ear = keypoints[17]
if head is None or all(keypoint is None for keypoint in (left_eye, right_eye, left_ear, right_ear)):
return None
width = 0.0
x0, y0 = head.x, head.y
if left_eye is not None:
x1, y1 = left_eye.x, left_eye.y
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 3.0)
if right_eye is not None:
x1, y1 = right_eye.x, right_eye.y
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 3.0)
if left_ear is not None:
x1, y1 = left_ear.x, left_ear.y
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 1.5)
if right_ear is not None:
x1, y1 = right_ear.x, right_ear.y
d = max(abs(x0 - x1), abs(y0 - y1))
width = max(width, d * 1.5)
x, y = x0, y0
x -= width
y -= width
if x < 0:
x = 0
if y < 0:
y = 0
width1 = width * 2
width2 = width * 2
if x + width > image_width:
width1 = image_width - x
if y + width > image_height:
width2 = image_height - y
width = min(width1, width2)
if width >= 20:
return int(x), int(y), int(width)
else:
return None
# get max index of 2d array
def npmax(array):
arrayindex = array.argmax(1)
arrayvalue = array.max(1)
i = arrayvalue.argmax()
j = arrayindex[i]
return i, j

View File

@@ -0,0 +1,146 @@
import os
import random
import cv2
import numpy as np
import torch
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
x = x[:, :, None]
assert x.ndim == 3
H, W, C = x.shape
assert C == 1 or C == 3 or C == 4
if C == 3:
return x
if C == 1:
return np.concatenate([x, x, x], axis=2)
if C == 4:
color = x[:, :, 0:3].astype(np.float32)
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
y = color * alpha + 255.0 * (1.0 - alpha)
y = y.clip(0, 255).astype(np.uint8)
return y
def make_noise_disk(H, W, C, F):
noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
noise = noise[F: F + H, F: F + W]
noise -= np.min(noise)
noise /= np.max(noise)
if C == 1:
noise = noise[:, :, None]
return noise
def nms(x, t, s):
x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
y = np.zeros_like(x)
for f in [f1, f2, f3, f4]:
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
z = np.zeros_like(y, dtype=np.uint8)
z[y > t] = 255
return z
def min_max_norm(x):
x -= np.min(x)
x /= np.maximum(np.max(x), 1e-5)
return x
def safe_step(x, step=2):
y = x.astype(np.float32) * float(step + 1)
y = y.astype(np.int32).astype(np.float32) / float(step)
return y
def img2mask(img, H, W, low=10, high=90):
assert img.ndim == 3 or img.ndim == 2
assert img.dtype == np.uint8
if img.ndim == 3:
y = img[:, :, random.randrange(0, img.shape[2])]
else:
y = img
y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
if random.uniform(0, 1) < 0.5:
y = 255 - y
return y < np.percentile(y, random.randrange(low, high))
def resize_image(input_image, resolution):
H, W, C = input_image.shape
H = float(H)
W = float(W)
k = float(resolution) / min(H, W)
H *= k
W *= k
H = int(np.round(H / 64.0)) * 64
W = int(np.round(W / 64.0)) * 64
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
return img
def torch_gc():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def ade_palette():
"""ADE20K palette that maps each class to RGB values."""
return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
[102, 255, 0], [92, 0, 255]]

View File

@@ -0,0 +1,547 @@
# type: ignore
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.controlnet import zero_module
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.outputs import BaseOutput
from invokeai.backend.bria.transformer_bria import (
EmbedND,
FluxSingleTransformerBlock,
FluxTransformerBlock,
TimestepProjEmbeddings,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
BRIA_CONTROL_MODES = Literal["depth", "canny", "colorgrid", "recolor", "tile", "pose"]
class BriaControlModes(Enum):
depth = 0
canny = 1
colorgrid = 2
recolor = 3
tile = 4
pose = 5
@dataclass
class BriaControlNetOutput(BaseOutput):
controlnet_block_samples: Tuple[torch.Tensor]
controlnet_single_block_samples: Tuple[torch.Tensor]
class BriaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Optional[List[int]] = None,
num_mode: int = None,
rope_theta: int = 10000,
time_theta: int = 10000,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
# self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
axes_dims_rope = [16, 56, 56] if axes_dims_rope is None else axes_dims_rope
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
# text_time_guidance_cls = (
# CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
# )
# self.time_text_embed = text_time_guidance_cls(
# embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
# )
self.time_embed = TimestepProjEmbeddings(
embedding_dim=self.inner_dim, time_theta=time_theta
)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_single_layers)
]
)
# controlnet_blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(len(self.single_transformer_blocks)):
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.union = num_mode is not None and num_mode > 0
if self.union:
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self):
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@classmethod
def from_transformer(
cls,
transformer,
num_layers: int = 4,
num_single_layers: int = 10,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
load_weights_from_transformer=True,
):
config = transformer.config
config["num_layers"] = num_layers
config["num_single_layers"] = num_single_layers
config["attention_head_dim"] = attention_head_dim
config["num_attention_heads"] = num_attention_heads
controlnet = cls(**config)
if load_weights_from_transformer:
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
controlnet.single_transformer_blocks.load_state_dict(
transformer.single_transformer_blocks.state_dict(), strict=False
)
controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
return controlnet
def forward(
self,
hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
controlnet_mode: torch.Tensor = None,
conditioning_scale: float = 1.0,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
controlnet_mode (`torch.Tensor`):
The mode tensor of shape `(batch_size, 1)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if guidance is not None:
print("guidance is not supported in BriaControlNetModel")
if pooled_projections is not None:
print("pooled_projections is not supported in BriaControlNetModel")
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
# Convert controlnet_cond to the same dtype as the model weights
controlnet_cond = controlnet_cond.to(dtype=self.controlnet_x_embedder.weight.dtype)
# add
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
timestep = timestep.to(hidden_states.dtype) # Original code was * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) # Original code was * 1000
else:
guidance = None
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
if self.union:
# union mode
if controlnet_mode is None:
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
# Validate controlnet_mode values are within the valid range
if torch.any(controlnet_mode < 0) or torch.any(controlnet_mode >= self.num_mode):
raise ValueError(f"`controlnet_mode` values must be in range [0, {self.num_mode-1}], but got values outside this range")
# union mode emb
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
if controlnet_mode_emb.shape[0] < encoder_hidden_states.shape[0]: # duplicate mode emb for each batch
controlnet_mode_emb = controlnet_mode_emb.expand(encoder_hidden_states.shape[0], 1, encoder_hidden_states.shape[2])
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
txt_ids = torch.cat((txt_ids[0:1, :], txt_ids), dim=0)
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
block_samples = ()
for _, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
block_samples = block_samples + (hidden_states,)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
single_block_samples = ()
for _, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
# controlnet block
controlnet_block_samples = ()
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks, strict=False):
block_sample = controlnet_block(block_sample)
controlnet_block_samples = controlnet_block_samples + (block_sample,)
controlnet_single_block_samples = ()
for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks, strict=False):
single_block_sample = controlnet_block(single_block_sample)
controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
# scaling
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
controlnet_single_block_samples = (
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (controlnet_block_samples, controlnet_single_block_samples)
return BriaControlNetOutput(
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)
class BriaMultiControlNetModel(ModelMixin):
r"""
`BriaMultiControlNetModel` wrapper class for Multi-BriaControlNetModel
This module is a wrapper for multiple instances of the `BriaControlNetModel`. The `forward()` API is designed to be
compatible with `BriaControlNetModel`.
Args:
controlnets (`List[BriaControlNetModel]`):
Provides additional conditioning to the unet during the denoising process. You must set multiple
`BriaControlNetModel` as a list.
"""
def __init__(self, controlnets):
super().__init__()
self.nets = nn.ModuleList(controlnets)
def forward(
self,
hidden_states: torch.FloatTensor,
controlnet_cond: List[torch.tensor],
controlnet_mode: List[torch.tensor],
conditioning_scale: List[float],
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[BriaControlNetOutput, Tuple]:
# ControlNet-Union with multiple conditions
# only load one ControlNet for saving memories
if len(self.nets) == 1 and self.nets[0].union:
controlnet = self.nets[0]
for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale, strict=False)):
block_samples, single_block_samples = controlnet(
hidden_states=hidden_states,
controlnet_cond=image,
controlnet_mode=mode[:, None],
conditioning_scale=scale,
timestep=timestep,
guidance=guidance,
pooled_projections=pooled_projections,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=img_ids,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples, strict=False)
]
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples, strict=False
)
]
# Regular Multi-ControlNets
# load all ControlNets into memories
else:
for i, (image, mode, scale, controlnet) in enumerate(
zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets, strict=False)
):
block_samples, single_block_samples = controlnet(
hidden_states=hidden_states,
controlnet_cond=image,
controlnet_mode=mode[:, None],
conditioning_scale=scale,
timestep=timestep,
guidance=guidance,
pooled_projections=pooled_projections,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=img_ids,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
if block_samples is not None and control_block_samples is not None:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples, strict=False)
]
if single_block_samples is not None and control_single_block_samples is not None:
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples, strict=False
)
]
return control_block_samples, control_single_block_samples

View File

@@ -0,0 +1,67 @@
from typing import List, Tuple
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from PIL import Image
@torch.no_grad()
def prepare_control_images(
vae: AutoencoderKL,
control_images: list[Image.Image],
control_modes: list[int],
width: int,
height: int,
device: torch.device,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
tensored_control_images = []
tensored_control_modes = []
for idx, control_image_ in enumerate(control_images):
tensored_control_image = _prepare_image(
image=control_image_,
width=width,
height=height,
device=device,
dtype=vae.dtype,
)
height, width = tensored_control_image.shape[-2:]
# vae encode
tensored_control_image = vae.encode(tensored_control_image).latent_dist.sample()
tensored_control_image = (tensored_control_image) * vae.config.scaling_factor
# pack
height_control_image, width_control_image = tensored_control_image.shape[2:]
tensored_control_image = _pack_latents(
tensored_control_image,
height_control_image,
width_control_image,
)
tensored_control_images.append(tensored_control_image)
tensored_control_modes.append(torch.tensor(control_modes[idx]).expand(
tensored_control_image.shape[0]).to(device, dtype=torch.long))
return tensored_control_images, tensored_control_modes
def _prepare_image(
image: Image.Image,
width: int,
height: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
image = image.convert("RGB")
image = VaeImageProcessor(vae_scale_factor=16).preprocess(image, height=height, width=width)
image = image.repeat_interleave(1, dim=0)
image = image.to(device=device, dtype=dtype)
return image
def _pack_latents(latents, height, width):
latents = latents.view(1, 4, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(1, (height // 2) * (width // 2), 16)
return latents

View File

@@ -0,0 +1,640 @@
from typing import Any, Callable, Dict, List, Optional, Union
import diffusers
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
from diffusers.utils import (
USE_PEFT_BACKEND,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from transformers import (
T5EncoderModel,
T5TokenizerFast,
)
from invokeai.backend.bria.bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import StableDiffusion3Pipeline
>>> pipe = StableDiffusion3Pipeline.from_pretrained(
... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
... )
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> image = pipe(prompt).images[0]
>>> image.save("sd3.png")
```
"""
T5_PRECISION = torch.float16
"""
Based on FluxPipeline with several changes:
- no pooled embeddings
- We use zero padding for prompts
- No guidance embedding since this is not a distilled version
"""
class BriaPipeline(FluxPipeline):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. Stable Diffusion 3 uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
"""
def __init__(
self,
transformer: BriaTransformer2DModel,
scheduler: Union[FlowMatchEulerDiscreteScheduler,KarrasDiffusionSchedulers],
vae: AutoencoderKL,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast
):
self.register_modules(
vae=vae,
transformer=transformer,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
# TODO - why different than offical flux (-1)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = 64 # due to patchify=> 128,128 => res of 1k,1k
# T5 is senstive to precision so we use the precision used for precompute and cast as needed
if self.vae.config.shift_factor is None:
self.vae.config.shift_factor=0
self.vae.to(dtype=torch.float32)
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 128,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = get_t5_prompt_embeds(
self.tokenizer,
self.text_encoder,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
).to(dtype=self.transformer.dtype)
if do_classifier_free_guidance and negative_prompt_embeds is None:
if not is_ng_none(negative_prompt):
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = get_t5_prompt_embeds(
self.tokenizer,
self.text_encoder,
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
).to(dtype=self.transformer.dtype)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
return prompt_embeds, negative_prompt_embeds, text_ids
@property
def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 30,
timesteps: List[int] = None,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
max_sequence_length: int = 128,
clip_value:Union[None,float] = None,
normalize:bool = False
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
callback_on_step_end_tensor_inputs = ["latents"] if callback_on_step_end_tensor_inputs is None else callback_on_step_end_tensor_inputs
self.check_inputs(
prompt=prompt,
height=height,
width=width,
prompt_embeds=prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
negative_prompt_embeds,
text_ids
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4 # due to patch=2, we devide by 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']:
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1] # Shift by height - Why just height?
print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)
else:
# 4. Prepare timesteps
# Sample from training sigmas
if isinstance(self.scheduler,DDIMScheduler) or isinstance(self.scheduler,EulerAncestralDiscreteScheduler):
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, None)
else:
sigmas = get_original_sigmas(num_train_timesteps=self.scheduler.config.num_train_timesteps,num_inference_steps=num_inference_steps)
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps,sigmas=sigmas)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# Supprot different diffusers versions
if diffusers.__version__>='0.32.0':
latent_image_ids=latent_image_ids[0]
text_ids=text_ids[0]
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# This is predicts "v" from flow-matching or eps from diffusion
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
txt_ids=text_ids,
img_ids=latent_image_ids,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
cfg_noise_pred_text = noise_pred_text.std()
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if normalize:
noise_pred = noise_pred * (0.7 *(cfg_noise_pred_text/noise_pred.std())) + 0.3 * noise_pred
if clip_value:
assert clip_value>0
noise_pred = noise_pred.clip(-clip_value,clip_value)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
def to(self, *args, **kwargs):
DiffusionPipeline.to(self, *args, **kwargs)
# T5 is senstive to precision so we use the precision used for precompute and cast as needed
self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
for block in self.text_encoder.encoder.block:
block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
if self.vae.config.shift_factor == 0 and self.vae.dtype!=torch.float32:
self.vae.to(dtype=torch.float32)
return self
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // self.vae_scale_factor)
width = 2 * (int(width) // self.vae_scale_factor )
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents, latent_image_ids
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
return latents
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)

View File

@@ -0,0 +1,666 @@
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Union
import diffusers
import numpy as np
import torch
from diffusers import AutoencoderKL # Waiting for diffusers udpdate
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
from diffusers.utils import USE_PEFT_BACKEND, logging
from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import randn_tensor
from transformers import (
T5EncoderModel,
T5TokenizerFast,
)
from invokeai.backend.bria.bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none
from invokeai.backend.bria.controlnet_bria import BriaControlNetModel
from invokeai.backend.bria.pipeline_bria import BriaPipeline
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class BriaControlNetPipeline(BriaPipeline):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. Stable Diffusion 3 uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__( # EYAL - removed clip text encoder + tokenizer
self,
transformer: BriaTransformer2DModel,
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
vae: AutoencoderKL,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
controlnet: BriaControlNetModel,
):
super().__init__(
transformer=transformer, scheduler=scheduler, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer
)
self.register_modules(controlnet=controlnet)
def prepare_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
guess_mode=False,
):
if isinstance(image, torch.Tensor):
pass
else:
image = self.image_processor.preprocess(image, height=height, width=width)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)
return image
def prepare_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode):
num_channels_latents = self.transformer.config.in_channels // 4
control_image = self.prepare_image(
image=control_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.vae.dtype,
)
height, width = control_image.shape[-2:]
# vae encode
control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# pack
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
# Here we ensure that `control_mode` has the same length as the control_image.
if control_mode is not None:
if not isinstance(control_mode, int):
raise ValueError(" For `BriaControlNet`, `control_mode` should be an `int` or `None`")
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
return control_image, control_mode
def prepare_multi_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode):
num_channels_latents = self.transformer.config.in_channels // 4
control_images = []
for _, control_image_ in enumerate(control_image):
control_image_ = self.prepare_image(
image=control_image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.vae.dtype,
)
height, width = control_image_.shape[-2:]
# vae encode
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# pack
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
control_images.append(control_image_)
control_image = control_images
# Here we ensure that `control_mode` has the same length as the control_image.
if isinstance(control_mode, list) and len(control_mode) != len(control_image):
raise ValueError(
"For Multi-ControlNet, `control_mode` must be a list of the same "
+ " length as the number of controlnets (control images) specified"
)
if not isinstance(control_mode, list):
control_mode = [control_mode] * len(control_image)
# set control mode
control_modes = []
for cmode in control_mode:
if cmode is None:
cmode = -1
control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
control_modes.append(control_mode)
control_mode = control_modes
return control_image, control_mode
def get_controlnet_keep(self, timesteps, control_guidance_start, control_guidance_end):
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end, strict=False)
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, BriaControlNetModel) else keeps)
return controlnet_keep
def get_control_start_end(self, control_guidance_start, control_guidance_end):
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = 1 # TODO - why is this 1?
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
return control_guidance_start, control_guidance_end
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 30,
timesteps: List[int] = None,
guidance_scale: float = 3.5,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_image: Optional[PipelineImageInput] = None,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
latent_image_ids: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
text_ids: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
max_sequence_length: int = 128,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
control_guidance_start, control_guidance_end = self.get_control_start_end(
control_guidance_start=control_guidance_start, control_guidance_end=control_guidance_end
)
# 1. Check inputs. Raise error if not correct
callback_on_step_end_tensor_inputs = ["latents"] if callback_on_step_end_tensor_inputs is None else callback_on_step_end_tensor_inputs
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
device = self._execution_device
# 4. Prepare timesteps
if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']:
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
# Determine image sequence length
if control_image is not None:
if isinstance(control_image, list):
image_seq_len = control_image[0].shape[1]
else:
image_seq_len = control_image.shape[1]
else:
# Use latents sequence length when no control image is provided
image_seq_len = latents.shape[1]
print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps=None,
sigmas=sigmas,
mu=mu,
)
else:
# 5. Prepare timesteps
sigmas = get_original_sigmas(
num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Create tensor stating which controlnets to keep
if control_image is not None:
controlnet_keep = self.get_controlnet_keep(
timesteps=timesteps,
control_guidance_start=control_guidance_start,
control_guidance_end=control_guidance_end,
)
if diffusers.__version__>='0.32.0':
latent_image_ids=latent_image_ids[0]
text_ids=text_ids[0]
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# EYAL - added the CFG loop
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# Handling ControlNet
if control_image is not None:
if isinstance(controlnet_keep[i], list):
if isinstance(controlnet_conditioning_scale, list):
cond_scale = controlnet_conditioning_scale
else:
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i], strict=False)]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=cond_scale,
timestep=timestep,
# guidance=guidance,
# pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)
else:
controlnet_block_samples, controlnet_single_block_samples = None, None
# This is predicts "v" from flow-matching
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
txt_ids=text_ids,
img_ids=latent_image_ids,
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
def encode_prompt(
prompt: Union[str, List[str]],
tokenizer: T5TokenizerFast,
text_encoder: T5EncoderModel,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 128,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
device = device or torch.device("cuda")
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
# dynamically adjust the LoRA scale
if text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(text_encoder, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
dtype = text_encoder.dtype if text_encoder is not None else torch.float32
if prompt_embeds is None:
prompt_embeds = get_t5_prompt_embeds(
tokenizer,
text_encoder,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
).to(dtype=dtype)
if negative_prompt_embeds is None:
if not is_ng_none(negative_prompt):
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = get_t5_prompt_embeds(
tokenizer,
text_encoder,
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
).to(dtype=dtype)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
if text_encoder is not None:
if USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(text_encoder, lora_scale)
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
return prompt_embeds, negative_prompt_embeds, text_ids
def prepare_latents(
batch_size: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
generator: torch.Generator,
latents: Optional[torch.FloatTensor] = None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
vae_scale_factor = 16
height = 2 * (int(height) // vae_scale_factor)
width = 2 * (int(width) // vae_scale_factor )
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents, latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents

View File

@@ -0,0 +1,322 @@
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from invokeai.backend.bria.bria_utils import FluxPosEmbed as EmbedND
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Timesteps(nn.Module):
def __init__(
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
self.time_theta = time_theta
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
max_period=self.time_theta,
)
return t_emb
class TimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, time_theta):
super().__init__()
self.time_proj = Timesteps(
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
return timesteps_emb
"""
Based on FluxPipeline with several changes:
- no pooled embeddings
- We use zero padding for prompts
- No guidance embedding since this is not a distilled version
"""
class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
The Transformer model introduced in Flux.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Parameters:
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = None,
guidance_embeds: bool = False,
axes_dims_rope: Optional[List[int]] = None,
rope_theta=10000,
time_theta=10000,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
axes_dims_rope = [16, 56, 56] if axes_dims_rope is None else axes_dims_rope
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
# if pooled_projection_dim:
# self.pooled_text_embed = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim=self.inner_dim, act_fn="silu")
if guidance_embeds:
self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype)
else:
guidance = None
# temb = (
# self.time_text_embed(timestep, pooled_projections)
# if guidance is None
# else self.time_text_embed(timestep, guidance, pooled_projections)
# )
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
# if pooled_projections:
# temb+=self.pooled_text_embed(pooled_projections)
if guidance:
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if len(txt_ids.shape) == 2:
ids = torch.cat((txt_ids, img_ids), dim=0)
else:
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pos_embed(ids)
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -125,6 +125,8 @@ class ModelProbe(object):
}
CLASS2TYPE = {
"BriaPipeline": ModelType.Main,
"BriaTransformer2DModel": ModelType.ControlNet,
"FluxPipeline": ModelType.Main,
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
@@ -861,6 +863,8 @@ class PipelineFolderProbe(FolderProbeBase):
return BaseModelType.StableDiffusion3
elif transformer_conf["_class_name"] == "CogView4Transformer2DModel":
return BaseModelType.CogView4
elif transformer_conf["_class_name"] == "BriaTransformer2DModel":
return BaseModelType.Bria
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
@@ -1010,6 +1014,9 @@ class ControlNetFolderProbe(FolderProbeBase):
if config.get("_class_name", None) == "FluxControlNetModel":
return BaseModelType.Flux
if config.get("_class_name", None) == "BriaTransformer2DModel":
return BaseModelType.Bria
# no obvious way to distinguish between sd2-base and sd2-768
dimension = config["cross_attention_dim"]
if dimension == 768:

View File

@@ -0,0 +1,95 @@
from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
ControlNetCheckpointConfig,
ControlNetDiffusersConfig,
DiffusersConfigBase,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
class BriaControlNetDiffusersModel(GenericDiffusersLoader):
"""Class to load Bria control net models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path)
repo_variant = config.repo_variant if isinstance(config, ControlNetDiffusersConfig) else None
variant = repo_variant.value if repo_variant else None
model_path = model_path
dtype = self._torch_dtype
try:
result: AnyModel = load_class.from_pretrained(
model_path,
torch_dtype=dtype,
variant=variant,
use_safetensors=False,
)
except OSError as e:
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = load_class.from_pretrained(model_path, torch_dtype=dtype)
else:
raise e
return result
@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers)
class BriaDiffusersModel(GenericDiffusersLoader):
"""Class to load Bria main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, CheckpointConfigBase):
raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")
if submodel_type is None:
raise Exception("A submodel type must be provided when loading main pipelines.")
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
variant = repo_variant.value if repo_variant else None
model_path = model_path / submodel_type.value
dtype = self._torch_dtype
try:
result: AnyModel = load_class.from_pretrained(
model_path,
torch_dtype=dtype,
variant=variant,
)
except OSError as e:
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = load_class.from_pretrained(model_path, torch_dtype=dtype)
else:
raise e
return result

View File

@@ -80,7 +80,13 @@ class GenericDiffusersLoader(ModelLoader):
"transformers",
"invokeai.backend.quantization.fast_quantized_transformers_model",
"invokeai.backend.quantization.fast_quantized_diffusion_model",
"transformer_bria",
]:
if module == "transformer_bria":
module = "invokeai.backend.bria.transformer_bria"
elif class_name == "BriaTransformer2DModel":
class_name = "BriaControlNetModel"
module = "invokeai.backend.bria.controlnet_bria"
res_type = sys.modules[module]
else:
res_type = sys.modules["diffusers"].pipelines

View File

@@ -12,6 +12,9 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
from invokeai.backend.bria.controlnet_aux.open_pose.body import Body
from invokeai.backend.bria.controlnet_aux.open_pose.face import Face
from invokeai.backend.bria.controlnet_aux.open_pose.hand import Hand
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
@@ -62,6 +65,8 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
else:
# If neither is available, return 0
return 0
elif isinstance(model, (Body, Hand, Face)):
return calc_module_size(model.model)
elif isinstance(
model,
(

View File

@@ -30,6 +30,7 @@ class BaseModelType(str, Enum):
Imagen4 = "imagen4"
ChatGPT4o = "chatgpt-4o"
FluxKontext = "flux-kontext"
Bria = "bria"
class ModelType(str, Enum):

View File

@@ -1,10 +0,0 @@
dist/
static/
.husky/
node_modules/
patches/
stats.html
index.html
.yarn/
*.scss
src/services/api/schema.ts

View File

@@ -1,88 +0,0 @@
module.exports = {
extends: ['@invoke-ai/eslint-config-react'],
plugins: ['path', 'i18next'],
rules: {
// TODO(psyche): Enable this rule. Requires no default exports in components - many changes.
'react-refresh/only-export-components': 'off',
// TODO(psyche): Enable this rule. Requires a lot of eslint-disable-next-line comments.
'@typescript-eslint/consistent-type-assertions': 'off',
// https://github.com/qdanik/eslint-plugin-path
'path/no-relative-imports': ['error', { maxDepth: 0 }],
// https://github.com/edvardchen/eslint-plugin-i18next/blob/HEAD/docs/rules/no-literal-string.md
// TODO: ENABLE THIS RULE BEFORE v6.0.0
// 'i18next/no-literal-string': 'error',
// https://eslint.org/docs/latest/rules/no-console
'no-console': 'warn',
// https://eslint.org/docs/latest/rules/no-promise-executor-return
'no-promise-executor-return': 'error',
// https://eslint.org/docs/latest/rules/require-await
'require-await': 'error',
// Restrict setActiveTab calls to only use-navigation-api.tsx
'no-restricted-syntax': [
'error',
{
selector: 'CallExpression[callee.name="setActiveTab"]',
message:
'setActiveTab() can only be called from use-navigation-api.tsx. Use navigationApi.switchToTab() instead.',
},
],
// TODO: ENABLE THIS RULE BEFORE v6.0.0
'react/display-name': 'off',
'no-restricted-properties': [
'error',
{
object: 'crypto',
property: 'randomUUID',
message: 'Use of crypto.randomUUID is not allowed as it is not available in all browsers.',
},
{
object: 'navigator',
property: 'clipboard',
message:
'The Clipboard API is not available by default in Firefox. Use the `useClipboard` hook instead, which wraps clipboard access to prevent errors.',
},
],
'no-restricted-imports': [
'error',
{
paths: [
{
name: 'lodash-es',
importNames: ['isEqual'],
message: 'Please use objectEquals from @observ33r/object-equals instead.',
},
{
name: 'lodash-es',
message: 'Please use es-toolkit instead.',
},
{
name: 'es-toolkit',
importNames: ['isEqual'],
message: 'Please use objectEquals from @observ33r/object-equals instead.',
},
],
},
],
},
overrides: [
/**
* Allow setActiveTab calls only in use-navigation-api.tsx
*/
{
files: ['**/use-navigation-api.tsx'],
rules: {
'no-restricted-syntax': 'off',
},
},
/**
* Overrides for stories
*/
{
files: ['*.stories.tsx'],
rules: {
// We may not have i18n available in stories.
'i18next/no-literal-string': 'off',
},
},
],
};

View File

@@ -14,3 +14,4 @@ static/
src/theme/css/overlayscrollbars.css
src/theme_/css/overlayscrollbars.css
pnpm-lock.yaml
.claude

View File

@@ -1,11 +0,0 @@
module.exports = {
...require('@invoke-ai/prettier-config-react'),
overrides: [
{
files: ['public/locales/*.json'],
options: {
tabWidth: 4,
},
},
],
};

View File

@@ -0,0 +1,17 @@
{
"$schema": "http://json.schemastore.org/prettierrc",
"trailingComma": "es5",
"printWidth": 120,
"tabWidth": 2,
"semi": true,
"singleQuote": true,
"endOfLine": "auto",
"overrides": [
{
"files": ["public/locales/*.json"],
"options": {
"tabWidth": 4
}
}
]
}

View File

@@ -1,21 +1,23 @@
import { PropsWithChildren, memo, useEffect } from 'react';
import { modelChanged } from '../src/features/controlLayers/store/paramsSlice';
import { useAppDispatch } from '../src/app/store/storeHooks';
import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
import type { PropsWithChildren } from 'react';
import { memo, useEffect } from 'react';
import { useAppDispatch } from '../src/app/store/storeHooks';
import { modelChanged } from '../src/features/controlLayers/store/paramsSlice';
/**
* Initializes some state for storybook. Must be in a different component
* so that it is run inside the redux context.
*/
export const ReduxInit = memo((props: PropsWithChildren) => {
export const ReduxInit = memo(({ children }: PropsWithChildren) => {
const dispatch = useAppDispatch();
useGlobalModifiersInit();
useEffect(() => {
dispatch(
modelChanged({ model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' } })
);
}, []);
}, [dispatch]);
return props.children;
return children;
});
ReduxInit.displayName = 'ReduxInit';

View File

@@ -2,19 +2,13 @@ import type { StorybookConfig } from '@storybook/react-vite';
const config: StorybookConfig = {
stories: ['../src/**/*.mdx', '../src/**/*.stories.@(js|jsx|mjs|ts|tsx)'],
addons: [
'@storybook/addon-links',
'@storybook/addon-essentials',
'@storybook/addon-interactions',
'@storybook/addon-storysource',
],
addons: ['@storybook/addon-links', '@storybook/addon-docs'],
framework: {
name: '@storybook/react-vite',
options: {},
},
docs: {
autodocs: 'tag',
},
core: {
disableTelemetry: true,
},

View File

@@ -1,5 +1,5 @@
import { addons } from '@storybook/manager-api';
import { themes } from '@storybook/theming';
import { addons } from 'storybook/manager-api';
import { themes } from 'storybook/theming';
addons.setConfig({
theme: themes.dark,

View File

@@ -1,17 +1,18 @@
import { Preview } from '@storybook/react';
import { themes } from '@storybook/theming';
import type { Preview } from '@storybook/react-vite';
import { themes } from 'storybook/theming';
import { $store } from 'app/store/nanostores/store';
import i18n from 'i18next';
import { initReactI18next } from 'react-i18next';
import { Provider } from 'react-redux';
import ThemeLocaleProvider from '../src/app/components/ThemeLocaleProvider';
import { $baseUrl } from '../src/app/store/nanostores/baseUrl';
import { createStore } from '../src/app/store/store';
// TODO: Disabled for IDE performance issues with our translation JSON
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
import translationEN from '../public/locales/en.json';
import ThemeLocaleProvider from '../src/app/components/ThemeLocaleProvider';
import { $baseUrl } from '../src/app/store/nanostores/baseUrl';
import { createStore } from '../src/app/store/store';
import { ReduxInit } from './ReduxInit';
import { $store } from 'app/store/nanostores/store';
i18n.use(initReactI18next).init({
lng: 'en',
@@ -46,6 +47,7 @@ const preview: Preview = {
parameters: {
docs: {
theme: themes.dark,
codePanel: true,
},
},
};

View File

@@ -0,0 +1,242 @@
import js from '@eslint/js';
import typescriptEslint from '@typescript-eslint/eslint-plugin';
import typescriptParser from '@typescript-eslint/parser';
import pluginI18Next from 'eslint-plugin-i18next';
import pluginImport from 'eslint-plugin-import';
import pluginPath from 'eslint-plugin-path';
import pluginReact from 'eslint-plugin-react';
import pluginReactHooks from 'eslint-plugin-react-hooks';
import pluginReactRefresh from 'eslint-plugin-react-refresh';
import pluginSimpleImportSort from 'eslint-plugin-simple-import-sort';
import pluginStorybook from 'eslint-plugin-storybook';
import pluginUnusedImports from 'eslint-plugin-unused-imports';
import globals from 'globals';
export default [
js.configs.recommended,
{
languageOptions: {
parser: typescriptParser,
parserOptions: {
ecmaFeatures: {
jsx: true,
},
},
globals: {
...globals.browser,
...globals.node,
GlobalCompositeOperation: 'readonly',
RequestInit: 'readonly',
},
},
files: ['**/*.ts', '**/*.tsx', '**/*.js', '**/*.jsx'],
plugins: {
react: pluginReact,
'@typescript-eslint': typescriptEslint,
'react-hooks': pluginReactHooks,
import: pluginImport,
'unused-imports': pluginUnusedImports,
'simple-import-sort': pluginSimpleImportSort,
'react-refresh': pluginReactRefresh.configs.vite,
path: pluginPath,
i18next: pluginI18Next,
storybook: pluginStorybook,
},
rules: {
...typescriptEslint.configs.recommended.rules,
...pluginReact.configs.recommended.rules,
...pluginReact.configs['jsx-runtime'].rules,
...pluginReactHooks.configs.recommended.rules,
...pluginStorybook.configs.recommended.rules,
'react/jsx-no-bind': [
'error',
{
allowBind: true,
},
],
'react/jsx-curly-brace-presence': [
'error',
{
props: 'never',
children: 'never',
},
],
'react-hooks/exhaustive-deps': 'error',
curly: 'error',
'no-var': 'error',
'brace-style': 'error',
'prefer-template': 'error',
radix: 'error',
'space-before-blocks': 'error',
eqeqeq: 'error',
'one-var': ['error', 'never'],
'no-eval': 'error',
'no-extend-native': 'error',
'no-implied-eval': 'error',
'no-label-var': 'error',
'no-return-assign': 'error',
'no-sequences': 'error',
'no-template-curly-in-string': 'error',
'no-throw-literal': 'error',
'no-unmodified-loop-condition': 'error',
'import/no-duplicates': 'error',
'import/prefer-default-export': 'off',
'unused-imports/no-unused-imports': 'error',
'unused-imports/no-unused-vars': [
'error',
{
vars: 'all',
varsIgnorePattern: '^_',
args: 'after-used',
argsIgnorePattern: '^_',
},
],
'simple-import-sort/imports': 'error',
'simple-import-sort/exports': 'error',
'@typescript-eslint/no-unused-vars': 'off',
'@typescript-eslint/ban-ts-comment': [
'error',
{
'ts-expect-error': 'allow-with-description',
'ts-ignore': true,
'ts-nocheck': true,
'ts-check': false,
minimumDescriptionLength: 10,
},
],
'@typescript-eslint/no-empty-interface': [
'error',
{
allowSingleExtends: true,
},
],
'@typescript-eslint/consistent-type-imports': [
'error',
{
prefer: 'type-imports',
fixStyle: 'separate-type-imports',
disallowTypeAnnotations: true,
},
],
'@typescript-eslint/no-import-type-side-effects': 'error',
'@typescript-eslint/consistent-type-assertions': [
'error',
{
assertionStyle: 'as',
},
],
'path/no-relative-imports': [
'error',
{
maxDepth: 0,
},
],
'no-console': 'warn',
'no-promise-executor-return': 'error',
'require-await': 'error',
'no-restricted-syntax': [
'error',
{
selector: 'CallExpression[callee.name="setActiveTab"]',
message:
'setActiveTab() can only be called from use-navigation-api.tsx. Use navigationApi.switchToTab() instead.',
},
],
'no-restricted-properties': [
'error',
{
object: 'crypto',
property: 'randomUUID',
message: 'Use of crypto.randomUUID is not allowed as it is not available in all browsers.',
},
{
object: 'navigator',
property: 'clipboard',
message:
'The Clipboard API is not available by default in Firefox. Use the `useClipboard` hook instead, which wraps clipboard access to prevent errors.',
},
],
// Typescript handles this for us: https://eslint.org/docs/latest/rules/no-redeclare#handled_by_typescript
'no-redeclare': 'off',
'no-restricted-imports': [
'error',
{
paths: [
{
name: 'lodash-es',
importNames: ['isEqual'],
message: 'Please use objectEquals from @observ33r/object-equals instead.',
},
{
name: 'lodash-es',
message: 'Please use es-toolkit instead.',
},
{
name: 'es-toolkit',
importNames: ['isEqual'],
message: 'Please use objectEquals from @observ33r/object-equals instead.',
},
],
},
],
},
settings: {
react: {
version: 'detect',
},
},
},
{
files: ['**/use-navigation-api.tsx'],
rules: {
'no-restricted-syntax': 'off',
},
},
{
files: ['**/*.stories.tsx'],
rules: {
'i18next/no-literal-string': 'off',
},
},
{
ignores: [
'**/dist/',
'**/static/',
'**/.husky/',
'**/node_modules/',
'**/patches/',
'**/stats.html',
'**/index.html',
'**/.yarn/',
'**/*.scss',
'src/services/api/schema.ts',
'.prettierrc.js',
'.storybook',
],
},
];

View File

@@ -14,6 +14,7 @@ const config: KnipConfig = {
'src/features/controlLayers/konva/util.ts',
// Will be using this
'src/common/hooks/useAsyncState.ts',
'src/app/store/use-debounced-app-selector.ts',
],
ignoreBinaries: ['only-allow'],
paths: {

View File

@@ -47,25 +47,25 @@
"@fontsource-variable/inter": "^5.2.6",
"@invoke-ai/ui-library": "^0.0.46",
"@nanostores/react": "^1.0.0",
"@observ33r/object-equals": "^1.1.4",
"@observ33r/object-equals": "^1.1.5",
"@reduxjs/toolkit": "2.8.2",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.7.1",
"ag-psd": "^28.2.1",
"@xyflow/react": "^12.8.2",
"ag-psd": "^28.2.2",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.2",
"cmdk": "^1.1.1",
"compare-versions": "^6.1.1",
"dockview": "^4.4.0",
"es-toolkit": "^1.39.5",
"dockview": "^4.4.1",
"es-toolkit": "^1.39.7",
"filesize": "^10.1.6",
"fracturedjsonjs": "^4.1.0",
"framer-motion": "^11.10.0",
"i18next": "^25.2.1",
"i18next": "^25.3.2",
"i18next-http-backend": "^3.0.2",
"idb-keyval": "6.2.1",
"idb-keyval": "6.2.2",
"jsondiffpatch": "^0.7.3",
"konva": "^9.3.20",
"konva": "^9.3.22",
"linkify-react": "^4.3.1",
"linkifyjs": "^4.3.1",
"lru-cache": "^11.1.0",
@@ -83,7 +83,7 @@
"react-dom": "^18.3.1",
"react-dropzone": "^14.3.8",
"react-error-boundary": "^5.0.0",
"react-hook-form": "^7.58.1",
"react-hook-form": "^7.60.0",
"react-hotkeys-hook": "4.5.0",
"react-i18next": "^15.5.3",
"react-icons": "^5.5.0",
@@ -103,7 +103,7 @@
"use-debounce": "^10.0.5",
"use-device-pixel-ratio": "^1.1.2",
"uuid": "^11.1.0",
"zod": "^3.25.67",
"zod": "^4.0.5",
"zod-validation-error": "^3.5.2"
},
"peerDependencies": {
@@ -111,39 +111,43 @@
"react-dom": "^18.2.0"
},
"devDependencies": {
"@invoke-ai/eslint-config-react": "^0.0.14",
"@invoke-ai/prettier-config-react": "^0.0.7",
"@storybook/addon-essentials": "^8.6.12",
"@storybook/addon-interactions": "^8.6.12",
"@storybook/addon-links": "^8.6.12",
"@storybook/addon-storysource": "^8.6.12",
"@storybook/manager-api": "^8.6.12",
"@storybook/react": "^8.6.12",
"@storybook/react-vite": "^8.6.12",
"@storybook/theming": "^8.6.12",
"@eslint/js": "^9.31.0",
"@storybook/addon-docs": "^9.0.17",
"@storybook/addon-links": "^9.0.17",
"@storybook/react-vite": "^9.0.17",
"@types/node": "^22.15.1",
"@types/react": "^18.3.11",
"@types/react-dom": "^18.3.0",
"@types/uuid": "^10.0.0",
"@typescript-eslint/eslint-plugin": "^8.37.0",
"@typescript-eslint/parser": "^8.37.0",
"@vitejs/plugin-react-swc": "^3.9.0",
"@vitest/coverage-v8": "^3.1.2",
"@vitest/ui": "^3.1.2",
"concurrently": "^9.1.2",
"csstype": "^3.1.3",
"dpdm": "^3.14.0",
"eslint": "^8.57.1",
"eslint-plugin-i18next": "^6.1.1",
"eslint-plugin-path": "^1.3.0",
"eslint": "^9.31.0",
"eslint-plugin-i18next": "^6.1.2",
"eslint-plugin-import": "^2.29.1",
"eslint-plugin-path": "^2.0.3",
"eslint-plugin-react": "^7.33.2",
"eslint-plugin-react-hooks": "^5.2.0",
"eslint-plugin-react-refresh": "^0.4.5",
"eslint-plugin-simple-import-sort": "^12.0.0",
"eslint-plugin-storybook": "^9.0.17",
"eslint-plugin-unused-imports": "^4.1.4",
"globals": "^16.3.0",
"knip": "^5.61.3",
"openapi-types": "^12.1.3",
"openapi-typescript": "^7.6.1",
"prettier": "^3.5.3",
"rollup-plugin-visualizer": "^5.14.0",
"storybook": "^8.6.12",
"rollup-plugin-visualizer": "^6.0.3",
"storybook": "^9.0.17",
"tsafe": "^1.8.5",
"type-fest": "^4.40.0",
"typescript": "^5.8.3",
"vite": "^7.0.2",
"vite": "^7.0.5",
"vite-plugin-css-injected-by-js": "^3.5.2",
"vite-plugin-dts": "^4.5.3",
"vite-plugin-eslint": "^1.8.1",

File diff suppressed because it is too large Load Diff

View File

@@ -574,6 +574,10 @@
"title": "Transform",
"desc": "Transform the selected layer."
},
"invertMask": {
"title": "Invert Mask",
"desc": "Invert the selected inpaint mask, creating a new mask with opposite transparency."
},
"applyFilter": {
"title": "Apply Filter",
"desc": "Apply the pending filter to the selected layer."
@@ -599,6 +603,10 @@
"toggleNonRasterLayers": {
"title": "Toggle Non-Raster Layers",
"desc": "Show or hide all non-raster layer categories (Control Layers, Inpaint Masks, Regional Guidance)."
},
"fitBboxToMasks": {
"title": "Fit Bbox To Masks",
"desc": "Automatically adjust the generation bounding box to fit visible inpaint masks"
}
},
"workflows": {
@@ -1125,7 +1133,23 @@
"addItem": "Add Item",
"generateValues": "Generate Values",
"floatRangeGenerator": "Float Range Generator",
"integerRangeGenerator": "Integer Range Generator"
"integerRangeGenerator": "Integer Range Generator",
"layout": {
"autoLayout": "Auto Layout",
"layeringStrategy": "Layering Strategy",
"networkSimplex": "Network Simplex",
"longestPath": "Longest Path",
"nodeSpacing": "Node Spacing",
"layerSpacing": "Layer Spacing",
"layoutDirection": "Layout Direction",
"layoutDirectionRight": "Right",
"layoutDirectionDown": "Down",
"alignment": "Node Alignment",
"alignmentUL": "Top Left",
"alignmentDL": "Bottom Left",
"alignmentUR": "Top Right",
"alignmentDR": "Bottom Right"
}
},
"parameters": {
"aspect": "Aspect",
@@ -1407,7 +1431,15 @@
"sentToUpscale": "Sent to Upscale",
"promptGenerationStarted": "Prompt generation started",
"uploadAndPromptGenerationFailed": "Failed to upload image and generate prompt",
"promptExpansionFailed": "We ran into an issue. Please try prompt expansion again."
"promptExpansionFailed": "We ran into an issue. Please try prompt expansion again.",
"maskInverted": "Mask Inverted",
"maskInvertFailed": "Failed to Invert Mask",
"noVisibleMasks": "No Visible Masks",
"noVisibleMasksDesc": "Create or enable at least one inpaint mask to invert",
"noInpaintMaskSelected": "No Inpaint Mask Selected",
"noInpaintMaskSelectedDesc": "Select an inpaint mask to invert",
"invalidBbox": "Invalid Bounding Box",
"invalidBboxDesc": "The bounding box has no valid dimensions"
},
"popovers": {
"clipSkip": {
@@ -1775,6 +1807,20 @@
"Structure controls how closely the output image will keep to the layout of the original. Low structure allows major changes, while high structure strictly maintains the original composition and layout."
]
},
"tileSize": {
"heading": "Tile Size",
"paragraphs": [
"Controls the size of tiles used during the upscaling process. Larger tiles use more memory but may produce better results.",
"SD1.5 models default to 768, while SDXL models default to 1024. Reduce tile size if you encounter memory issues."
]
},
"tileOverlap": {
"heading": "Tile Overlap",
"paragraphs": [
"Controls the overlap between adjacent tiles during upscaling. Higher overlap values help reduce visible seams between tiles but use more memory.",
"The default value of 128 works well for most cases, but you can adjust based on your specific needs and memory constraints."
]
},
"fluxDevLicense": {
"heading": "Non-Commercial License",
"paragraphs": [
@@ -1926,6 +1972,7 @@
"canvas": "Canvas",
"bookmark": "Bookmark for Quick Switch",
"fitBboxToLayers": "Fit Bbox To Layers",
"fitBboxToMasks": "Fit Bbox To Masks",
"removeBookmark": "Remove Bookmark",
"saveCanvasToGallery": "Save Canvas to Gallery",
"saveBboxToGallery": "Save Bbox to Gallery",
@@ -1962,7 +2009,6 @@
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"outputOnlyMaskedRegions": "Output Only Generated Regions",
"saveAllImagesToGallery": "Save All Images to Gallery",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",
@@ -1991,6 +2037,7 @@
"rasterLayer": "Raster Layer",
"controlLayer": "Control Layer",
"inpaintMask": "Inpaint Mask",
"invertMask": "Invert Mask",
"regionalGuidance": "Regional Guidance",
"referenceImageRegional": "Reference Image (Regional)",
"referenceImageGlobal": "Reference Image (Global)",
@@ -2087,9 +2134,9 @@
"resetCanvasLayers": "Reset Canvas Layers",
"resetGenerationSettings": "Reset Generation Settings",
"replaceCurrent": "Replace Current",
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, <PullBboxButton>pull the bounding box into this layer</PullBboxButton>, or draw on the canvas to get started.",
"referenceImageEmptyStateWithCanvasOptions": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this Reference Image or <PullBboxButton>pull the bounding box into this Reference Image</PullBboxButton> to get started.",
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the <GalleryButton>gallery</GalleryButton> onto this Reference Image to get started.",
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the gallery onto this layer, <PullBboxButton>pull the bounding box into this layer</PullBboxButton>, or draw on the canvas to get started.",
"referenceImageEmptyStateWithCanvasOptions": "<UploadButton>Upload an image</UploadButton>, drag an image from the gallery onto this Reference Image or <PullBboxButton>pull the bounding box into this Reference Image</PullBboxButton> to get started.",
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the gallery onto this Reference Image to get started.",
"uploadOrDragAnImage": "Drag an image from the gallery or <UploadButton>upload an image</UploadButton>.",
"imageNoise": "Image Noise",
"denoiseLimit": "Denoise Limit",
@@ -2332,7 +2379,8 @@
"alert": "Preserving Masked Region"
},
"saveAllImagesToGallery": {
"alert": "Saving All Images to Gallery"
"label": "Send New Generations to Gallery",
"alert": "Sending new generations to Gallery, bypassing Canvas"
},
"isolatedStagingPreview": "Isolated Staging Preview",
"isolatedPreview": "Isolated Preview",
@@ -2396,6 +2444,9 @@
"upscaleModel": "Upscale Model",
"postProcessingModel": "Post-Processing Model",
"scale": "Scale",
"tileControl": "Tile Control",
"tileSize": "Tile Size",
"tileOverlap": "Tile Overlap",
"postProcessingMissingModelWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install a post-processing (image to image) model.",
"missingModelsWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install the required models:",
"mainModelDesc": "Main model (SD1.5 or SDXL architecture)",

View File

@@ -30,16 +30,16 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
}, [clearStorage]);
return (
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
<ThemeLocaleProvider>
<ThemeLocaleProvider>
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{!didStudioInit && <Loading />}
</Box>
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
<GlobalModalIsolator />
</ThemeLocaleProvider>
</ErrorBoundary>
</ErrorBoundary>
</ThemeLocaleProvider>
);
};

View File

@@ -2,6 +2,7 @@ import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
import { setupListeners } from '@reduxjs/toolkit/query';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { useStudioInitAction } from 'app/hooks/useStudioInitAction';
import { useSyncLangDirection } from 'app/hooks/useSyncLangDirection';
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
import { useLogger } from 'app/logging/useLogger';
import { useSyncLoggingConfig } from 'app/logging/useSyncLoggingConfig';
@@ -15,6 +16,8 @@ import { useDndMonitor } from 'features/dnd/useDndMonitor';
import { useDynamicPromptsWatcher } from 'features/dynamicPrompts/hooks/useDynamicPromptsWatcher';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { useWorkflowBuilderWatcher } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
import { useSyncExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
import { useSyncNodeErrors } from 'features/nodes/store/util/fieldValidators';
import { useReadinessWatcher } from 'features/queue/store/readiness';
import { configChanged } from 'features/system/store/configSlice';
import { selectLanguage } from 'features/system/store/systemSelectors';
@@ -47,10 +50,13 @@ export const GlobalHookIsolator = memo(
useCloseChakraTooltipsOnDragFix();
useNavigationApi();
useDndMonitor();
useSyncNodeErrors();
useSyncLangDirection();
// Persistent subscription to the queue counts query - canvas relies on this to know if there are pending
// and/or in progress canvas sessions.
useGetQueueCountsByDestinationQuery(queueCountArg);
useSyncExecutionState();
useEffect(() => {
i18n.changeLanguage(language);

View File

@@ -1,10 +1,6 @@
import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal';
import {
NewCanvasSessionDialog,
NewGallerySessionDialog,
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal';
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
@@ -50,8 +46,6 @@ export const GlobalModalIsolator = memo(() => {
<RefreshAfterResetModal />
<DeleteBoardModal />
<GlobalImageHotkeys />
<NewGallerySessionDialog />
<NewCanvasSessionDialog />
<ImageContextMenu />
<FullscreenDropzone />
<VideosModal />

View File

@@ -317,7 +317,7 @@ const InvokeAIUI = ({
if (import.meta.env.MODE === 'development') {
window.$store = $store;
}
() => {
return () => {
$store.set(undefined);
if (import.meta.env.MODE === 'development') {
window.$store = undefined;

View File

@@ -3,43 +3,39 @@ import 'overlayscrollbars/overlayscrollbars.css';
import '@xyflow/react/dist/base.css';
import 'common/components/OverlayScrollbars/overlayscrollbars.css';
import { ChakraProvider, DarkMode, extendTheme, theme as _theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
import { ChakraProvider, DarkMode, extendTheme, theme as baseTheme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $direction } from 'app/hooks/useSyncLangDirection';
import type { ReactNode } from 'react';
import { memo, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { memo, useMemo } from 'react';
type ThemeLocaleProviderProps = {
children: ReactNode;
};
const buildTheme = (direction: 'ltr' | 'rtl') => {
return extendTheme({
...baseTheme,
direction,
shadows: {
...baseTheme.shadows,
selected:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
hoverSelected:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
hoverUnselected:
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
selectedForCompare:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
hoverSelectedForCompare:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
});
};
function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
const { i18n } = useTranslation();
const direction = i18n.dir();
const theme = useMemo(() => {
return extendTheme({
..._theme,
direction,
shadows: {
..._theme.shadows,
selected:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
hoverSelected:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
hoverUnselected:
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
selectedForCompare:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
hoverSelectedForCompare:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
});
}, [direction]);
useEffect(() => {
document.body.dir = direction;
}, [direction]);
const direction = useStore($direction);
const theme = useMemo(() => buildTheme(direction), [direction]);
return (
<ChakraProvider theme={theme} toastOptions={TOAST_OPTIONS}>

View File

@@ -21,7 +21,6 @@ import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/st
import { toast } from 'features/toast/toast';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { LAUNCHPAD_PANEL_ID, WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { atom } from 'nanostores';
import { useCallback, useEffect } from 'react';
@@ -165,7 +164,6 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
// Go to the generate tab, open the launchpad
await navigationApi.focusPanel('generate', LAUNCHPAD_PANEL_ID);
store.dispatch(paramsReset());
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
break;
case 'canvas':
// Go to the canvas tab, open the launchpad

View File

@@ -0,0 +1,36 @@
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { atom } from 'nanostores';
import { useEffect } from 'react';
import { useTranslation } from 'react-i18next';
/**
* Global atom storing the language direction, to be consumed by the Chakra theme.
*
* Why do we need this? We have a kind of catch-22:
* - The Chakra theme needs to know the language direction to apply the correct styles.
* - The language direction is determined by i18n and the language selection.
* - We want our error boundary to be themed.
* - It's possible that i18n can throw if the language selection is invalid or not supported.
*
* Previously, we had the logic in this file in the theme provider, which wrapped the error boundary. The error
* was properly themed. But then, if i18n threw in the theme provider, the error boundary does not catch the
* error. The app would crash to a white screen.
*
* We tried swapping the component hierarchy so that the error boundary wraps the theme provider, but then the
* error boundary isn't themed!
*
* The solution is to move this i18n direction logic out of the theme provider and into a hook that we can use
* within the error boundary. The error boundary will be themed, _and_ catch any i18n errors.
*/
export const $direction = atom<'ltr' | 'rtl'>('ltr');
export const useSyncLangDirection = () => {
useAssertSingleton('useSyncLangDirection');
const { i18n, t } = useTranslation();
useEffect(() => {
const direction = i18n.dir();
$direction.set(direction);
document.body.dir = direction;
}, [i18n, t]);
};

View File

@@ -2,7 +2,7 @@ import { createLogWriter } from '@roarr/browser-log-writer';
import { atom } from 'nanostores';
import type { Logger, MessageSerializer } from 'roarr';
import { ROARR, Roarr } from 'roarr';
import { z } from 'zod/v4';
import { z } from 'zod';
const serializeMessage: MessageSerializer = (message) => {
return JSON.stringify(message);

View File

@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
@@ -152,7 +152,8 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
if (modelBase !== state.params.model?.base) {
// Sync generate tab settings whenever the model base changes
dispatch(syncedToOptimalDimension());
if (!selectIsStaging(state)) {
const isStaging = buildSelectIsStaging(selectCanvasSessionId(state))(state);
if (!isStaging) {
// Canvas tab only syncs if not staging
dispatch(bboxSyncedToOptimalDimension());
}

View File

@@ -15,7 +15,11 @@ import { refImageModelChanged, selectRefImagesSlice } from 'features/controlLaye
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier, isFLUXReduxConfig, isIPAdapterConfig } from 'features/controlLayers/store/types';
import { modelSelected } from 'features/parameters/store/actions';
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import {
postProcessingModelChanged,
tileControlnetModelChanged,
upscaleModelChanged,
} from 'features/parameters/store/upscaleSlice';
import {
zParameterCLIPEmbedModel,
zParameterSpandrelImageToImageModel,
@@ -28,6 +32,7 @@ import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
isControlLayerModelConfig,
isControlNetModelConfig,
isFluxReduxModelConfig,
isFluxVAEModelConfig,
isIPAdapterModelConfig,
@@ -71,6 +76,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
handleControlAdapterModels(models, state, dispatch, log);
handlePostProcessingModel(models, state, dispatch, log);
handleUpscaleModel(models, state, dispatch, log);
handleTileControlNetModel(models, state, dispatch, log);
handleIPAdapterModels(models, state, dispatch, log);
handleT5EncoderModels(models, state, dispatch, log);
handleCLIPEmbedModels(models, state, dispatch, log);
@@ -345,6 +351,46 @@ const handleUpscaleModel: ModelHandler = (models, state, dispatch, log) => {
}
};
const handleTileControlNetModel: ModelHandler = (models, state, dispatch, log) => {
const selectedTileControlNetModel = state.upscale.tileControlnetModel;
const controlNetModels = models.filter(isControlNetModelConfig);
// If the currently selected model is available, we don't need to do anything
if (selectedTileControlNetModel && controlNetModels.some((m) => m.key === selectedTileControlNetModel.key)) {
return;
}
// The only way we have to identify a model as a tile model is by its name containing 'tile' :)
const tileModel = controlNetModels.find((m) => m.name.toLowerCase().includes('tile'));
// If we have a tile model, select it
if (tileModel) {
log.debug(
{ selectedTileControlNetModel, tileModel },
'No selected tile ControlNet model or selected model is not available, selecting tile model'
);
dispatch(tileControlnetModelChanged(tileModel));
return;
}
// Otherwise, select the first available ControlNet model
const firstModel = controlNetModels[0] || null;
if (firstModel) {
log.debug(
{ selectedTileControlNetModel, firstModel },
'No tile ControlNet model found, selecting first available ControlNet model'
);
dispatch(tileControlnetModelChanged(firstModel));
return;
}
// No available models, we should clear the selected model - but only if we have one selected
if (selectedTileControlNetModel) {
log.debug({ selectedTileControlNetModel }, 'Selected tile ControlNet model is not available, clearing');
dispatch(tileControlnetModelChanged(null));
}
};
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
const selectedT5EncoderModel = state.params.t5EncoderModel;
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m));

View File

@@ -1,7 +1,7 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { isNil } from 'es-toolkit';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
heightChanged,
setCfgRescaleMultiplier,
@@ -115,7 +115,8 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
}
const setSizeOptions = { updateAspectRatio: true, clamp: true };
const isStaging = selectIsStaging(getState());
const isStaging = buildSelectIsStaging(selectCanvasSessionId(state))(state);
const activeTab = selectActiveTab(getState());
if (activeTab === 'generate') {
if (isParameterWidth(width)) {

View File

@@ -67,6 +67,8 @@ export type Feature =
| 'scale'
| 'creativity'
| 'structure'
| 'tileSize'
| 'tileOverlap'
| 'optimizedDenoising'
| 'fluxDevLicense';

View File

@@ -11,9 +11,13 @@ import {
Text,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { typedMemo } from 'common/util/typedMemo';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import { selectPickerCompactViewStates } from 'features/ui/store/uiSelectors';
import { pickerCompactViewStateChanged } from 'features/ui/store/uiSlice';
import type { AnyStore, ReadableAtom, Task, WritableAtom } from 'nanostores';
import { atom, computed } from 'nanostores';
import type { StoreValues } from 'nanostores/computed';
@@ -140,6 +144,10 @@ const NoMatchesFallbackWrapper = typedMemo(({ children }: PropsWithChildren) =>
NoMatchesFallbackWrapper.displayName = 'NoMatchesFallbackWrapper';
type PickerProps<T extends object> = {
/**
* Unique identifier for this picker instance. Used to persist compact view state.
*/
pickerId?: string;
/**
* The options to display in the picker. This can be a flat array of options or an array of groups.
*/
@@ -204,10 +212,18 @@ type PickerProps<T extends object> = {
initialGroupStates?: GroupStatusMap;
};
const buildSelectIsCompactView = (pickerId?: string) =>
createSelector([selectPickerCompactViewStates], (compactViewStates) => {
if (!pickerId) {
return true;
}
return compactViewStates[pickerId] ?? true;
});
export type PickerContextState<T extends object> = {
$optionsOrGroups: WritableAtom<OptionOrGroup<T>[]>;
$groupStatusMap: WritableAtom<GroupStatusMap>;
$compactView: WritableAtom<boolean>;
isCompactView: boolean;
$activeOptionId: WritableAtom<string | undefined>;
$filteredOptions: WritableAtom<OptionOrGroup<T>[]>;
$flattenedFilteredOptions: ReadableAtom<T[]>;
@@ -233,6 +249,7 @@ export type PickerContextState<T extends object> = {
OptionComponent: React.ComponentType<{ option: T } & BoxProps>;
NextToSearchBar?: React.ReactNode;
searchable?: boolean;
pickerId?: string;
};
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
@@ -503,6 +520,7 @@ const countOptions = <T extends object>(optionsOrGroups: OptionOrGroup<T>[]) =>
export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
const {
pickerId,
getOptionId,
optionsOrGroups,
handleRef,
@@ -521,12 +539,12 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
} = props;
const rootRef = useRef<HTMLDivElement>(null);
const inputRef = useRef<HTMLInputElement>(null);
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(
optionsOrGroups,
initialGroupStates
);
const $activeOptionId = useAtom(getFirstOptionId(optionsOrGroups, getOptionId));
const $compactView = useAtom(true);
const $optionsOrGroups = useAtom(optionsOrGroups);
const $totalOptionCount = useComputed([$optionsOrGroups], countOptions);
const $filteredOptions = useAtom<OptionOrGroup<T>[]>([]);
@@ -538,6 +556,9 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
const $searchTerm = useAtom('');
const $selectedItemId = useComputed([$selectedItem], (item) => (item ? getOptionId(item) : undefined));
const selectIsCompactView = useMemo(() => buildSelectIsCompactView(pickerId), [pickerId]);
const isCompactView = useAppSelector(selectIsCompactView);
const onSelectById = useCallback(
(id: string) => {
const options = $filteredOptions.get();
@@ -565,7 +586,7 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
({
$optionsOrGroups,
$groupStatusMap,
$compactView,
isCompactView,
$activeOptionId,
$filteredOptions,
$flattenedFilteredOptions,
@@ -591,11 +612,12 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
$hasOptions,
$hasFilteredOptions,
$filteredOptionsCount,
pickerId,
}) satisfies PickerContextState<T>,
[
$optionsOrGroups,
$groupStatusMap,
$compactView,
isCompactView,
$activeOptionId,
$filteredOptions,
$flattenedFilteredOptions,
@@ -619,6 +641,7 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
$hasOptions,
$hasFilteredOptions,
$filteredOptionsCount,
pickerId,
]
);
@@ -869,15 +892,17 @@ GroupToggleButtons.displayName = 'GroupToggleButtons';
const CompactViewToggleButton = typedMemo(<T extends object>() => {
const { t } = useTranslation();
const { $compactView } = usePickerContext<T>();
const compactView = useStore($compactView);
const dispatch = useAppDispatch();
const { isCompactView, pickerId } = usePickerContext<T>();
const onClick = useCallback(() => {
$compactView.set(!$compactView.get());
}, [$compactView]);
if (pickerId) {
dispatch(pickerCompactViewStateChanged({ pickerId, isCompact: !isCompactView }));
}
}, [dispatch, pickerId, isCompactView]);
const label = compactView ? t('common.fullView') : t('common.compactView');
const icon = compactView ? <PiArrowsOutLineVerticalBold /> : <PiArrowsInLineVerticalBold />;
const label = isCompactView ? t('common.fullView') : t('common.compactView');
const icon = isCompactView ? <PiArrowsOutLineVerticalBold /> : <PiArrowsInLineVerticalBold />;
return <IconButton aria-label={label} tooltip={label} size="sm" variant="ghost" icon={icon} onClick={onClick} />;
});
@@ -924,8 +949,7 @@ const listSx = {
} satisfies SystemStyleObject;
const PickerList = typedMemo(<T extends object>() => {
const { getOptionId, $compactView, $filteredOptions } = usePickerContext<T>();
const compactView = useStore($compactView);
const { getOptionId, isCompactView, $filteredOptions } = usePickerContext<T>();
const filteredOptions = useStore($filteredOptions);
if (filteredOptions.length === 0) {
@@ -934,10 +958,10 @@ const PickerList = typedMemo(<T extends object>() => {
return (
<ScrollableContent>
<Flex sx={listSx} data-is-compact={compactView}>
<Flex sx={listSx} data-is-compact={isCompactView}>
{filteredOptions.map((optionOrGroup, i) => {
if (isGroup(optionOrGroup)) {
const withDivider = !compactView && i < filteredOptions.length - 1;
const withDivider = !isCompactView && i < filteredOptions.length - 1;
return (
<React.Fragment key={optionOrGroup.id}>
<PickerGroup group={optionOrGroup} />
@@ -1079,14 +1103,13 @@ const groupHeaderSx = {
const PickerGroupHeader = typedMemo(<T extends object>({ group }: { group: Group<T> }) => {
const { t } = useTranslation();
const { $compactView } = usePickerContext<T>();
const compactView = useStore($compactView);
const { isCompactView } = usePickerContext<T>();
const color = getGroupColor(group);
const name = getGroupName(group);
const count = getGroupCount(group, t);
return (
<Flex sx={groupHeaderSx} data-is-compact={compactView}>
<Flex sx={groupHeaderSx} data-is-compact={isCompactView}>
<Flex gap={2} alignItems="center">
<Text fontSize="sm" fontWeight="semibold" color={color} noOfLines={1}>
{name}

View File

@@ -6,7 +6,7 @@ import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
import type { FileRejection } from 'react-dropzone';
import type { Accept, FileRejection } from 'react-dropzone';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { PiUploadBold } from 'react-icons/pi';
@@ -15,6 +15,18 @@ import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
import type { SetOptional } from 'type-fest';
const addUpperCaseReducer = (acc: string[], ext: string) => {
acc.push(ext);
acc.push(ext.toUpperCase());
return acc;
};
export const dropzoneAccept: Accept = {
'image/png': ['.png'].reduce(addUpperCaseReducer, [] as string[]),
'image/jpeg': ['.jpg', '.jpeg', '.png'].reduce(addUpperCaseReducer, [] as string[]),
'image/webp': ['.webp'].reduce(addUpperCaseReducer, [] as string[]),
};
import { useClientSideUpload } from './useClientSideUpload';
type UseImageUploadButtonArgs =
| {
@@ -164,11 +176,7 @@ export const useImageUploadButton = ({
getInputProps: getUploadInputProps,
open: openUploader,
} = useDropzone({
accept: {
'image/png': ['.png'],
'image/jpeg': ['.jpg', '.jpeg', '.png'],
'image/webp': ['.webp'],
},
accept: dropzoneAccept,
onDropAccepted,
onDropRejected,
disabled: isDisabled,

View File

@@ -1,3 +1,5 @@
export const preventDefault = (e: React.MouseEvent) => {
import type { MouseEvent } from 'react';
export const preventDefault = (e: MouseEvent) => {
e.preventDefault();
};

View File

@@ -1,10 +1,11 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import type React from 'react';
import { memo } from 'react';
/**
* A typed version of React.memo, useful for components that take generics.
*/
export const typedMemo: <T extends keyof JSX.IntrinsicElements | React.JSXElementConstructor<any>>(
export const typedMemo: <T extends keyof React.JSX.IntrinsicElements | React.JSXElementConstructor<any>>(
component: T,
propsAreEqual?: (prevProps: React.ComponentProps<T>, nextProps: React.ComponentProps<T>) => boolean
) => T & { displayName?: string } = memo;

View File

@@ -1,4 +1,4 @@
import type { z } from 'zod/v4';
import type { z } from 'zod';
/**
* Helper to create a type guard from a zod schema. The type guard will infer the schema's TS type.

View File

@@ -1,7 +1,6 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, ConfirmationAlertDialog, Flex, FormControl, Text } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import {
@@ -14,7 +13,7 @@ import { useTranslation } from 'react-i18next';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { useAddImagesToBoardMutation, useRemoveImagesFromBoardMutation } from 'services/api/endpoints/images';
const selectImagesToChange = createMemoizedSelector(
const selectImagesToChange = createSelector(
selectChangeBoardModalSlice,
(changeBoardModal) => changeBoardModal.image_names
);

View File

@@ -13,7 +13,7 @@ export const CanvasAlertsSaveAllImagesToGallery = memo(() => {
}
return (
<Alert status="info" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<Alert status="warning" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<AlertIcon />
<AlertTitle>{t('controlLayers.settings.saveAllImagesToGallery.alert')}</AlertTitle>
</Alert>

View File

@@ -57,21 +57,21 @@ const CanvasAlertsSelectedEntityStatusContent = memo(({ entityIdentifier, adapte
const alert = useMemo<AlertData | null>(() => {
if (isFiltering) {
return {
status: 'info',
status: 'warning',
title: t('controlLayers.HUD.entityStatus.isFiltering', { title }),
};
}
if (isTransforming) {
return {
status: 'info',
status: 'warning',
title: t('controlLayers.HUD.entityStatus.isTransforming', { title }),
};
}
if (isEmpty) {
return {
status: 'info',
status: 'warning',
title: t('controlLayers.HUD.entityStatus.isEmpty', { title }),
};
}

View File

@@ -3,6 +3,7 @@ import { EntityListGlobalActionBarAddLayerMenu } from 'features/controlLayers/co
import { EntityListSelectedEntityActionBarDuplicateButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarDuplicateButton';
import { EntityListSelectedEntityActionBarFill } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarFill';
import { EntityListSelectedEntityActionBarFilterButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarFilterButton';
import { EntityListSelectedEntityActionBarInvertMaskButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarInvertMaskButton';
import { EntityListSelectedEntityActionBarOpacity } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarOpacity';
import { EntityListSelectedEntityActionBarSelectObjectButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarSelectObjectButton';
import { EntityListSelectedEntityActionBarTransformButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarTransformButton';
@@ -21,6 +22,7 @@ export const EntityListSelectedEntityActionBar = memo(() => {
<EntityListSelectedEntityActionBarSelectObjectButton />
<EntityListSelectedEntityActionBarFilterButton />
<EntityListSelectedEntityActionBarTransformButton />
<EntityListSelectedEntityActionBarInvertMaskButton />
<EntityListSelectedEntityActionBarSaveToAssetsButton />
<EntityListSelectedEntityActionBarDuplicateButton />
<EntityListNonRasterLayerToggle />

View File

@@ -0,0 +1,39 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useInvertMask } from 'features/controlLayers/hooks/useInvertMask';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { isInpaintMaskEntityIdentifier } from 'features/controlLayers/store/types';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiSelectionInverseBold } from 'react-icons/pi';
export const EntityListSelectedEntityActionBarInvertMaskButton = memo(() => {
const { t } = useTranslation();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const isBusy = useCanvasIsBusy();
const invertMask = useInvertMask();
if (!selectedEntityIdentifier) {
return null;
}
if (!isInpaintMaskEntityIdentifier(selectedEntityIdentifier)) {
return null;
}
return (
<IconButton
onClick={invertMask}
isDisabled={isBusy}
minW={8}
variant="link"
alignSelf="stretch"
aria-label={t('controlLayers.invertMask')}
tooltip={t('controlLayers.invertMask')}
icon={<PiSelectionInverseBold />}
/>
);
});
EntityListSelectedEntityActionBarInvertMaskButton.displayName = 'EntityListSelectedEntityActionBarInvertMaskButton';

View File

@@ -5,7 +5,6 @@ import { useEntityIdentifierContext } from 'features/controlLayers/contexts/Enti
import { usePullBboxIntoLayer } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actions';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { Trans } from 'react-i18next';
import type { ImageDTO } from 'services/api/types';
@@ -21,9 +20,6 @@ export const ControlLayerSettingsEmptyState = memo(() => {
[dispatch, entityIdentifier, getState]
);
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const pullBboxIntoLayer = usePullBboxIntoLayer(entityIdentifier);
const components = useMemo(
@@ -31,14 +27,11 @@ export const ControlLayerSettingsEmptyState = memo(() => {
UploadButton: (
<Button isDisabled={isBusy} size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />
),
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
PullBboxButton: (
<Button onClick={pullBboxIntoLayer} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}),
[isBusy, onClickGalleryButton, pullBboxIntoLayer, uploadApi]
[isBusy, pullBboxIntoLayer, uploadApi]
);
return (

View File

@@ -1,131 +0,0 @@
import { Checkbox, ConfirmationAlertDialog, Flex, FormControl, FormLabel, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { buildUseBoolean } from 'common/hooks/useBoolean';
import { canvasSessionReset, generateSessionReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
selectSystemShouldConfirmOnNewSession,
shouldConfirmOnNewSessionToggled,
} from 'features/system/store/systemSlice';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const [useNewGallerySessionDialog] = buildUseBoolean(false);
const [useNewCanvasSessionDialog] = buildUseBoolean(false);
const useNewGallerySession = () => {
const dispatch = useAppDispatch();
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const newSessionDialog = useNewGallerySessionDialog();
const newGallerySessionImmediate = useCallback(() => {
dispatch(generateSessionReset());
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const newGallerySessionWithDialog = useCallback(() => {
if (shouldConfirmOnNewSession) {
newSessionDialog.setTrue();
return;
}
newGallerySessionImmediate();
}, [newGallerySessionImmediate, newSessionDialog, shouldConfirmOnNewSession]);
return { newGallerySessionImmediate, newGallerySessionWithDialog };
};
const useNewCanvasSession = () => {
const dispatch = useAppDispatch();
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const newSessionDialog = useNewCanvasSessionDialog();
const newCanvasSessionImmediate = useCallback(() => {
dispatch(canvasSessionReset());
dispatch(activeTabCanvasRightPanelChanged('layers'));
}, [dispatch]);
const newCanvasSessionWithDialog = useCallback(() => {
if (shouldConfirmOnNewSession) {
newSessionDialog.setTrue();
return;
}
newCanvasSessionImmediate();
}, [newCanvasSessionImmediate, newSessionDialog, shouldConfirmOnNewSession]);
return { newCanvasSessionImmediate, newCanvasSessionWithDialog };
};
export const NewGallerySessionDialog = memo(() => {
useAssertSingleton('NewGallerySessionDialog');
const { t } = useTranslation();
const dispatch = useAppDispatch();
const dialog = useNewGallerySessionDialog();
const { newGallerySessionImmediate } = useNewGallerySession();
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const onToggleConfirm = useCallback(() => {
dispatch(shouldConfirmOnNewSessionToggled());
}, [dispatch]);
return (
<ConfirmationAlertDialog
isOpen={dialog.isTrue}
onClose={dialog.setFalse}
title={t('controlLayers.newGallerySession')}
acceptCallback={newGallerySessionImmediate}
acceptButtonText={t('common.ok')}
useInert={false}
>
<Flex direction="column" gap={3}>
<Text>{t('controlLayers.newGallerySessionDesc')}</Text>
<Text>{t('common.areYouSure')}</Text>
<FormControl>
<FormLabel>{t('common.dontAskMeAgain')}</FormLabel>
<Checkbox isChecked={!shouldConfirmOnNewSession} onChange={onToggleConfirm} />
</FormControl>
</Flex>
</ConfirmationAlertDialog>
);
});
NewGallerySessionDialog.displayName = 'NewGallerySessionDialog';
export const NewCanvasSessionDialog = memo(() => {
useAssertSingleton('NewCanvasSessionDialog');
const { t } = useTranslation();
const dispatch = useAppDispatch();
const dialog = useNewCanvasSessionDialog();
const { newCanvasSessionImmediate } = useNewCanvasSession();
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const onToggleConfirm = useCallback(() => {
dispatch(shouldConfirmOnNewSessionToggled());
}, [dispatch]);
return (
<ConfirmationAlertDialog
isOpen={dialog.isTrue}
onClose={dialog.setFalse}
title={t('controlLayers.newCanvasSession')}
acceptCallback={newCanvasSessionImmediate}
acceptButtonText={t('common.ok')}
useInert={false}
>
<Flex direction="column" gap={3}>
<Text>{t('controlLayers.newCanvasSessionDesc')}</Text>
<Text>{t('common.areYouSure')}</Text>
<FormControl>
<FormLabel>{t('common.dontAskMeAgain')}</FormLabel>
<Checkbox isChecked={!shouldConfirmOnNewSession} onChange={onToggleConfirm} />
</FormControl>
</Flex>
</ConfirmationAlertDialog>
);
});
NewCanvasSessionDialog.displayName = 'NewCanvasSessionDialog';

View File

@@ -126,6 +126,7 @@ const AddRefImageDropTargetAndButton = memo(() => {
</Flex>
);
});
AddRefImageDropTargetAndButton.displayName = 'AddRefImageDropTargetAndButton';
const BboxButton = memo(() => {
const { t } = useTranslation();
@@ -145,4 +146,4 @@ const BboxButton = memo(() => {
/>
);
});
AddRefImageDropTargetAndButton.displayName = 'AddRefImageDropTargetAndButton';
BboxButton.displayName = 'BboxButton';

View File

@@ -6,7 +6,6 @@ import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { setGlobalReferenceImage } from 'features/imageActions/actions';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import type { ImageDTO } from 'services/api/types';
@@ -22,9 +21,6 @@ export const RefImageNoImageState = memo(() => {
[dispatch, id]
);
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
() => setGlobalReferenceImageDndTarget.getData({ id }),
@@ -34,9 +30,8 @@ export const RefImageNoImageState = memo(() => {
const components = useMemo(
() => ({
UploadButton: <Button size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />,
GalleryButton: <Button onClick={onClickGalleryButton} size="sm" variant="link" color="base.300" />,
}),
[onClickGalleryButton, uploadApi]
[uploadApi]
);
return (

View File

@@ -8,7 +8,6 @@ import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { setGlobalReferenceImage } from 'features/imageActions/actions';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import type { ImageDTO } from 'services/api/types';
@@ -25,9 +24,6 @@ export const RefImageNoImageStateWithCanvasOptions = memo(() => {
[dispatch, id]
);
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(id);
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
@@ -40,14 +36,11 @@ export const RefImageNoImageStateWithCanvasOptions = memo(() => {
UploadButton: (
<Button isDisabled={isBusy} size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />
),
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
PullBboxButton: (
<Button onClick={pullBboxIntoIPAdapter} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}),
[isBusy, onClickGalleryButton, pullBboxIntoIPAdapter, uploadApi]
[isBusy, pullBboxIntoIPAdapter, uploadApi]
);
return (

View File

@@ -9,7 +9,6 @@ import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dn
import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { setRegionalGuidanceReferenceImage } from 'features/imageActions/actions';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
@@ -31,9 +30,6 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
[dispatch, entityIdentifier, referenceImageId]
);
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const onDeleteIPAdapter = useCallback(() => {
dispatch(rgRefImageDeleted({ entityIdentifier, referenceImageId }));
}, [dispatch, entityIdentifier, referenceImageId]);
@@ -53,14 +49,11 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
UploadButton: (
<Button isDisabled={isBusy} size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />
),
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
PullBboxButton: (
<Button onClick={pullBboxIntoIPAdapter} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}),
[isBusy, onClickGalleryButton, pullBboxIntoIPAdapter, uploadApi]
[isBusy, pullBboxIntoIPAdapter, uploadApi]
);
return (

View File

@@ -16,7 +16,7 @@ export const CanvasSettingsSaveAllImagesToGalleryCheckbox = memo(() => {
}, [dispatch]);
return (
<FormControl w="full">
<FormLabel flexGrow={1}>{t('controlLayers.saveAllImagesToGallery')}</FormLabel>
<FormLabel flexGrow={1}>{t('controlLayers.settings.saveAllImagesToGallery.label')}</FormLabel>
<Checkbox isChecked={saveAllImagesToGallery} onChange={onChange} />
</FormControl>
);

View File

@@ -1,585 +0,0 @@
import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/storeHooks';
import { getOutputImageName } from 'features/controlLayers/components/SimpleSession/shared';
import { selectStagingAreaAutoSwitch } from 'features/controlLayers/store/canvasSettingsSlice';
import {
buildSelectSessionQueueItems,
canvasQueueItemDiscarded,
canvasSessionReset,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import type { ProgressImage } from 'features/nodes/types/common';
import type { Atom, MapStore, StoreValue, WritableAtom } from 'nanostores';
import { atom, computed, effect, map, subscribeKeys } from 'nanostores';
import type { PropsWithChildren } from 'react';
import { createContext, memo, useCallback, useContext, useEffect, useMemo, useState } from 'react';
import { getImageDTOSafe } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO, S } from 'services/api/types';
import { $socket } from 'services/events/stores';
import { assert, objectEntries } from 'tsafe';
export type ProgressData = {
itemId: number;
progressEvent: S['InvocationProgressEvent'] | null;
progressImage: ProgressImage | null;
imageDTO: ImageDTO | null;
imageLoaded: boolean;
};
const getInitialProgressData = (itemId: number): ProgressData => ({
itemId,
progressEvent: null,
progressImage: null,
imageDTO: null,
imageLoaded: false,
});
export const useProgressData = ($progressData: ProgressDataMap, itemId: number): ProgressData => {
const getInitialValue = useCallback(
() => $progressData.get()[itemId] ?? getInitialProgressData(itemId),
[$progressData, itemId]
);
const [value, setValue] = useState(getInitialValue);
useEffect(() => {
const unsub = subscribeKeys($progressData, [itemId], (data) => {
const progressData = data[itemId];
if (!progressData) {
return;
}
setValue(progressData);
});
return () => {
unsub();
};
}, [$progressData, itemId]);
return value;
};
const setProgress = ($progressData: ProgressDataMap, data: S['InvocationProgressEvent']) => {
const progressData = $progressData.get();
const current = progressData[data.item_id];
if (current) {
const next = { ...current };
next.progressEvent = data;
if (data.image) {
next.progressImage = data.image;
}
$progressData.set({
...progressData,
[data.item_id]: next,
});
} else {
$progressData.set({
...progressData,
[data.item_id]: {
itemId: data.item_id,
progressEvent: data,
progressImage: data.image ?? null,
imageDTO: null,
imageLoaded: false,
},
});
}
};
export type ProgressDataMap = MapStore<Record<number, ProgressData | undefined>>;
type CanvasSessionContextValue = {
session: { id: string; type: 'simple' | 'advanced' };
$items: Atom<S['SessionQueueItem'][]>;
$itemCount: Atom<number>;
$hasItems: Atom<boolean>;
$isPending: Atom<boolean>;
$progressData: ProgressDataMap;
$selectedItemId: WritableAtom<number | null>;
$selectedItem: Atom<S['SessionQueueItem'] | null>;
$selectedItemIndex: Atom<number | null>;
$selectedItemOutputImageDTO: Atom<ImageDTO | null>;
selectNext: () => void;
selectPrev: () => void;
selectFirst: () => void;
selectLast: () => void;
onImageLoad: (itemId: number) => void;
discard: (itemId: number) => void;
discardAll: () => void;
};
const CanvasSessionContext = createContext<CanvasSessionContextValue | null>(null);
export const CanvasSessionContextProvider = memo(
({ id, type, children }: PropsWithChildren<{ id: string; type: 'simple' | 'advanced' }>) => {
/**
* For best performance and interop with the Canvas, which is outside react but needs to interact with the react
* app, all canvas session state is packaged as nanostores atoms. The trickiest part is syncing the queue items
* with a nanostores atom.
*/
const session = useMemo(() => ({ type, id }), [type, id]);
/**
* App store
*/
const store = useAppStore();
const socket = useStore($socket);
/**
* Track the last completed item. Used to implement autoswitch.
*/
const $lastCompletedItemId = useState(() => atom<number | null>(null))[0];
/**
* Track the last started item. Used to implement autoswitch.
*/
const $lastStartedItemId = useState(() => atom<number | null>(null))[0];
/**
* Manually-synced atom containing queue items for the current session. This is populated from the RTK Query cache
* and kept in sync with it via a redux subscription.
*/
const $items = useState(() => atom<S['SessionQueueItem'][]>([]))[0];
/**
* An internal flag used to work around race conditions with auto-switch switching to queue items before their
* output images have fully loaded.
*/
const $lastLoadedItemId = useState(() => atom<number | null>(null))[0];
/**
* An ephemeral store of progress events and images for all items in the current session.
*/
const $progressData = useState(() => map<StoreValue<ProgressDataMap>>({}))[0];
/**
* The currently selected queue item's ID, or null if one is not selected.
*/
const $selectedItemId = useState(() => atom<number | null>(null))[0];
/**
* The number of items. Computed from the queue items array.
*/
const $itemCount = useState(() => computed([$items], (items) => items.length))[0];
/**
* Whether there are any items. Computed from the queue items array.
*/
const $hasItems = useState(() => computed([$items], (items) => items.length > 0))[0];
/**
* Whether there are any pending or in-progress items. Computed from the queue items array.
*/
const $isPending = useState(() =>
computed([$items], (items) => items.some((item) => item.status === 'pending' || item.status === 'in_progress'))
)[0];
/**
* The currently selected queue item, or null if one is not selected.
*/
const $selectedItem = useState(() =>
computed([$items, $selectedItemId], (items, selectedItemId) => {
if (items.length === 0) {
return null;
}
if (selectedItemId === null) {
return null;
}
return items.find(({ item_id }) => item_id === selectedItemId) ?? null;
})
)[0];
/**
* The currently selected queue item's index in the list of items, or null if one is not selected.
*/
const $selectedItemIndex = useState(() =>
computed([$items, $selectedItemId], (items, selectedItemId) => {
if (items.length === 0) {
return null;
}
if (selectedItemId === null) {
return null;
}
return items.findIndex(({ item_id }) => item_id === selectedItemId) ?? null;
})
)[0];
/**
* The currently selected queue item's output image name, or null if one is not selected or there is no output
* image recorded.
*/
const $selectedItemOutputImageDTO = useState(() =>
computed([$selectedItemId, $progressData], (selectedItemId, progressData) => {
if (selectedItemId === null) {
return null;
}
const datum = progressData[selectedItemId];
if (!datum) {
return null;
}
return datum.imageDTO;
})
)[0];
/**
* A redux selector to select all queue items from the RTK Query cache.
*/
const selectQueueItems = useMemo(() => buildSelectSessionQueueItems(session.id), [session.id]);
const discard = useCallback(
(itemId: number) => {
store.dispatch(canvasQueueItemDiscarded({ itemId }));
},
[store]
);
const discardAll = useCallback(() => {
store.dispatch(canvasSessionReset());
}, [store]);
const selectNext = useCallback(() => {
const selectedItemId = $selectedItemId.get();
if (selectedItemId === null) {
return;
}
const items = $items.get();
const currentIndex = items.findIndex((item) => item.item_id === selectedItemId);
const nextIndex = (currentIndex + 1) % items.length;
const nextItem = items[nextIndex];
if (!nextItem) {
return;
}
$selectedItemId.set(nextItem.item_id);
}, [$items, $selectedItemId]);
const selectPrev = useCallback(() => {
const selectedItemId = $selectedItemId.get();
if (selectedItemId === null) {
return;
}
const items = $items.get();
const currentIndex = items.findIndex((item) => item.item_id === selectedItemId);
const prevIndex = (currentIndex - 1 + items.length) % items.length;
const prevItem = items[prevIndex];
if (!prevItem) {
return;
}
$selectedItemId.set(prevItem.item_id);
}, [$items, $selectedItemId]);
const selectFirst = useCallback(() => {
const items = $items.get();
const first = items.at(0);
if (!first) {
return;
}
$selectedItemId.set(first.item_id);
}, [$items, $selectedItemId]);
const selectLast = useCallback(() => {
const items = $items.get();
const last = items.at(-1);
if (!last) {
return;
}
$selectedItemId.set(last.item_id);
}, [$items, $selectedItemId]);
const onImageLoad = useCallback(
(itemId: number) => {
const progressData = $progressData.get();
const current = progressData[itemId];
if (current) {
const next = { ...current, imageLoaded: true };
$progressData.setKey(itemId, next);
} else {
$progressData.setKey(itemId, {
...getInitialProgressData(itemId),
imageLoaded: true,
});
}
if (
$lastCompletedItemId.get() === itemId &&
selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_finish'
) {
$selectedItemId.set(itemId);
$lastCompletedItemId.set(null);
}
},
[$lastCompletedItemId, $progressData, $selectedItemId, store]
);
// Set up socket listeners
useEffect(() => {
if (!socket) {
return;
}
const onProgress = (data: S['InvocationProgressEvent']) => {
if (data.destination !== session.id) {
return;
}
setProgress($progressData, data);
};
const onQueueItemStatusChanged = (data: S['QueueItemStatusChangedEvent']) => {
if (data.destination !== session.id) {
return;
}
if (data.status === 'completed') {
$lastCompletedItemId.set(data.item_id);
}
if (data.status === 'in_progress') {
$lastStartedItemId.set(data.item_id);
}
};
socket.on('invocation_progress', onProgress);
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
return () => {
socket.off('invocation_progress', onProgress);
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
};
}, [$lastCompletedItemId, $lastStartedItemId, $progressData, $selectedItemId, session.id, socket]);
// Set up state subscriptions and effects
useEffect(() => {
let _prevItems: readonly S['SessionQueueItem'][] = [];
// Seed the $items atom with the initial query cache state
$items.set(selectQueueItems(store.getState()));
// Manually keep the $items atom in sync as the query cache is updated
const unsubReduxSyncToItemsAtom = store.subscribe(() => {
const prevItems = $items.get();
const items = selectQueueItems(store.getState());
if (items !== prevItems) {
_prevItems = prevItems;
$items.set(items);
}
});
// Handle cases that could result in a nonexistent queue item being selected.
const unsubEnsureSelectedItemIdExists = effect(
[$items, $selectedItemId, $lastStartedItemId],
(items, selectedItemId, lastStartedItemId) => {
if (items.length === 0) {
// If there are no items, cannot have a selected item.
$selectedItemId.set(null);
} else if (selectedItemId === null && items.length > 0) {
// If there is no selected item but there are items, select the first one.
$selectedItemId.set(items[0]?.item_id ?? null);
return;
} else if (
selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_start' &&
items.findIndex(({ item_id }) => item_id === lastStartedItemId) !== -1
) {
$selectedItemId.set(lastStartedItemId);
$lastStartedItemId.set(null);
} else if (selectedItemId !== null && items.findIndex(({ item_id }) => item_id === selectedItemId) === -1) {
// If an item is selected and it is not in the list of items, un-set it. This effect will run again and we'll
// the above case, selecting the first item if there are any.
let prevIndex = _prevItems.findIndex(({ item_id }) => item_id === selectedItemId);
if (prevIndex >= items.length) {
prevIndex = items.length - 1;
}
const nextItem = items[prevIndex];
$selectedItemId.set(nextItem?.item_id ?? null);
}
if (items !== _prevItems) {
_prevItems = items;
}
}
);
// Clean up the progress data when a queue item is discarded.
const unsubCleanUpProgressData = $items.subscribe(async (items) => {
const progressData = $progressData.get();
const toDelete: number[] = [];
const toUpdate: ProgressData[] = [];
for (const [id, datum] of objectEntries(progressData)) {
if (!datum) {
toDelete.push(id);
continue;
}
const item = items.find(({ item_id }) => item_id === datum.itemId);
if (!item) {
toDelete.push(datum.itemId);
} else if (item.status === 'canceled' || item.status === 'failed') {
toUpdate.push({
...datum,
progressEvent: null,
progressImage: null,
imageDTO: null,
});
}
}
for (const item of items) {
const datum = progressData[item.item_id];
if (datum) {
if (datum.imageDTO) {
continue;
}
const outputImageName = getOutputImageName(item);
if (!outputImageName) {
continue;
}
const imageDTO = await getImageDTOSafe(outputImageName);
if (!imageDTO) {
continue;
}
toUpdate.push({
...datum,
imageDTO,
});
} else {
const outputImageName = getOutputImageName(item);
if (!outputImageName) {
continue;
}
const imageDTO = await getImageDTOSafe(outputImageName);
if (!imageDTO) {
continue;
}
toUpdate.push({
...getInitialProgressData(item.item_id),
imageDTO,
});
}
}
for (const itemId of toDelete) {
$progressData.setKey(itemId, undefined);
}
for (const datum of toUpdate) {
$progressData.setKey(datum.itemId, datum);
}
});
// We only want to auto-switch to completed queue items once their images have fully loaded to prevent flashes
// of fallback content and/or progress images. The only surefire way to determine when images have fully loaded
// is via the image elements' `onLoad` callback. Images set `$lastLoadedItemId` to their queue item ID in their
// `onLoad` handler, and we listen for that here. If auto-switch is enabled, we then switch the to the item.
//
// TODO: This isn't perfect... we set $lastLoadedItemId in the mini preview component, but the full view
// component still needs to retrieve the image from the browser cache... can result in a flash of the progress
// image...
const unsubHandleAutoSwitch = $lastLoadedItemId.listen((lastLoadedItemId) => {
if (lastLoadedItemId === null) {
return;
}
if (selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_finish') {
$selectedItemId.set(lastLoadedItemId);
}
$lastLoadedItemId.set(null);
});
// Create an RTK Query subscription. Without this, the query cache selector will never return anything bc RTK
// doesn't know we care about it.
const { unsubscribe: unsubQueueItemsQuery } = store.dispatch(
queueApi.endpoints.listAllQueueItems.initiate({ destination: session.id })
);
// const unsubListener = store.dispatch(
// addAppListener({
// matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled,
// effect: ({ payload }, { getState }) => {
// const { item_id } = payload;
// const items = selectQueueItems(getState());
// if (items.length === 0) {
// $selectedItemId.set(null);
// } else if ($selectedItemId.get() === null) {
// $selectedItemId.set(items[0].item_id);
// }
// },
// })
// );
// Clean up all subscriptions and top-level (i.e. non-computed/derived state)
return () => {
unsubHandleAutoSwitch();
unsubQueueItemsQuery();
unsubReduxSyncToItemsAtom();
unsubEnsureSelectedItemIdExists();
unsubCleanUpProgressData();
$items.set([]);
$progressData.set({});
$selectedItemId.set(null);
};
}, [
$items,
$lastLoadedItemId,
$lastStartedItemId,
$progressData,
$selectedItemId,
selectQueueItems,
session.id,
store,
]);
const value = useMemo<CanvasSessionContextValue>(
() => ({
session,
$items,
$hasItems,
$isPending,
$progressData,
$selectedItemId,
$selectedItem,
$selectedItemIndex,
$selectedItemOutputImageDTO,
$itemCount,
selectNext,
selectPrev,
selectFirst,
selectLast,
onImageLoad,
discard,
discardAll,
}),
[
$items,
$hasItems,
$isPending,
$progressData,
$selectedItem,
$selectedItemId,
$selectedItemIndex,
session,
$selectedItemOutputImageDTO,
$itemCount,
selectNext,
selectPrev,
selectFirst,
selectLast,
onImageLoad,
discard,
discardAll,
]
);
return <CanvasSessionContext.Provider value={value}>{children}</CanvasSessionContext.Provider>;
}
);
CanvasSessionContextProvider.displayName = 'CanvasSessionContextProvider';
export const useCanvasSessionContext = () => {
const ctx = useContext(CanvasSessionContext);
assert(ctx !== null, "'useCanvasSessionContext' must be used within a CanvasSessionContextProvider");
return ctx;
};
export const useOutputImageDTO = (item: S['SessionQueueItem']) => {
const ctx = useCanvasSessionContext();
const $imageDTO = useState(() =>
computed([ctx.$progressData], (progressData) => progressData[item.item_id]?.imageDTO ?? null)
)[0];
const imageDTO = useStore($imageDTO);
return imageDTO;
};

View File

@@ -1,10 +1,11 @@
import type { CircularProgressProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { CircularProgress, Tooltip } from '@invoke-ai/ui-library';
import { useCanvasSessionContext, useProgressData } from 'features/controlLayers/components/SimpleSession/context';
import { getProgressMessage } from 'features/controlLayers/components/SimpleSession/shared';
import { getProgressMessage } from 'features/controlLayers/components/StagingArea/shared';
import { memo } from 'react';
import type { S } from 'services/api/types';
import { useProgressDatum } from './context';
const circleStyles: SystemStyleObject = {
circle: {
transitionProperty: 'none',
@@ -18,8 +19,7 @@ const circleStyles: SystemStyleObject = {
type Props = { itemId: number; status: S['SessionQueueItem']['status'] } & CircularProgressProps;
export const QueueItemCircularProgress = memo(({ itemId, status, ...rest }: Props) => {
const { $progressData } = useCanvasSessionContext();
const { progressEvent } = useProgressData($progressData, itemId);
const { progressEvent } = useProgressDatum(itemId);
if (status !== 'in_progress') {
return null;

View File

@@ -1,8 +1,9 @@
import type { TextProps } from '@invoke-ai/ui-library';
import { Text } from '@invoke-ai/ui-library';
import { DROP_SHADOW } from 'features/controlLayers/components/SimpleSession/shared';
import { memo } from 'react';
import { DROP_SHADOW } from './shared';
export const QueueItemNumber = memo(({ number, ...rest }: { number: number } & TextProps) => {
return <Text pointerEvents="none" userSelect="none" filter={DROP_SHADOW} {...rest}>{`#${number}`}</Text>;
});

View File

@@ -1,25 +1,23 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
useCanvasSessionContext,
useOutputImageDTO,
useProgressData,
} from 'features/controlLayers/components/SimpleSession/context';
import { QueueItemCircularProgress } from 'features/controlLayers/components/SimpleSession/QueueItemCircularProgress';
import { QueueItemNumber } from 'features/controlLayers/components/SimpleSession/QueueItemNumber';
import { QueueItemProgressImage } from 'features/controlLayers/components/SimpleSession/QueueItemProgressImage';
import { QueueItemStatusLabel } from 'features/controlLayers/components/SimpleSession/QueueItemStatusLabel';
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
import { QueueItemCircularProgress } from 'features/controlLayers/components/StagingArea/QueueItemCircularProgress';
import { QueueItemProgressImage } from 'features/controlLayers/components/StagingArea/QueueItemProgressImage';
import { QueueItemStatusLabel } from 'features/controlLayers/components/StagingArea/QueueItemStatusLabel';
import { getQueueItemElementId } from 'features/controlLayers/components/StagingArea/shared';
import {
selectStagingAreaAutoSwitch,
settingsStagingAreaAutoSwitchChanged,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { DndImage } from 'features/dnd/DndImage';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
import type { S } from 'services/api/types';
import { useOutputImageDTO, useStagingAreaContext } from './context';
import { QueueItemNumber } from './QueueItemNumber';
const sx = {
cursor: 'pointer',
userSelect: 'none',
@@ -41,19 +39,19 @@ const sx = {
type Props = {
item: S['SessionQueueItem'];
index: number;
isSelected: boolean;
};
export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) => {
export const QueueItemPreviewMini = memo(({ item, index }: Props) => {
const ctx = useStagingAreaContext();
const dispatch = useAppDispatch();
const ctx = useCanvasSessionContext();
const { imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
const imageDTO = useOutputImageDTO(item);
const $isSelected = useMemo(() => ctx.buildIsSelectedComputed(item.item_id), [ctx, item.item_id]);
const isSelected = useStore($isSelected);
const imageDTO = useOutputImageDTO(item.item_id);
const autoSwitch = useAppSelector(selectStagingAreaAutoSwitch);
const onClick = useCallback(() => {
ctx.$selectedItemId.set(item.item_id);
}, [ctx.$selectedItemId, item.item_id]);
ctx.select(item.item_id);
}, [ctx, item.item_id]);
const onDoubleClick = useCallback(() => {
if (autoSwitch !== 'off') {
@@ -65,7 +63,7 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) =>
}, [autoSwitch, dispatch]);
const onLoad = useCallback(() => {
ctx.onImageLoad(item.item_id);
ctx.onImageLoaded(item.item_id);
}, [ctx, item.item_id]);
return (
@@ -77,8 +75,8 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) =>
onDoubleClick={onDoubleClick}
>
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} asThumbnail position="absolute" />}
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
{imageDTO && <DndImage imageDTO={imageDTO} position="absolute" onLoad={onLoad} />}
<QueueItemProgressImage itemId={item.item_id} position="absolute" />
<QueueItemNumber number={index + 1} position="absolute" top={0} left={1} />
<QueueItemCircularProgress itemId={item.item_id} status={item.status} position="absolute" top={1} right={2} />
</Flex>

View File

@@ -1,15 +1,15 @@
import type { ImageProps } from '@invoke-ai/ui-library';
import { Image } from '@invoke-ai/ui-library';
import { useCanvasSessionContext, useProgressData } from 'features/controlLayers/components/SimpleSession/context';
import { memo } from 'react';
import { useProgressDatum } from './context';
type Props = { itemId: number } & ImageProps;
export const QueueItemProgressImage = memo(({ itemId, ...rest }: Props) => {
const ctx = useCanvasSessionContext();
const { progressImage } = useProgressData(ctx.$progressData, itemId);
const { progressImage, imageLoaded } = useProgressDatum(itemId);
if (!progressImage) {
if (!progressImage || imageLoaded) {
return null;
}

View File

@@ -1,16 +1,16 @@
import type { TextProps } from '@invoke-ai/ui-library';
import { Text } from '@invoke-ai/ui-library';
import { useCanvasSessionContext, useProgressData } from 'features/controlLayers/components/SimpleSession/context';
import { memo } from 'react';
import type { S } from 'services/api/types';
import { useProgressDatum } from './context';
type Props = { item: S['SessionQueueItem'] } & TextProps;
export const QueueItemStatusLabel = memo(({ item, ...rest }: Props) => {
const ctx = useCanvasSessionContext();
const { progressImage, imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
const { progressImage } = useProgressDatum(item.item_id);
if (progressImage || imageLoaded) {
if (progressImage) {
return null;
}

View File

@@ -1,5 +1,7 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import {
selectStagingAreaAutoSwitch,
settingsStagingAreaAutoSwitchChanged,
@@ -8,6 +10,9 @@ import { memo, useCallback } from 'react';
import { PiCaretLineRightBold, PiCaretRightBold, PiMoonBold } from 'react-icons/pi';
export const StagingAreaAutoSwitchButtons = memo(() => {
const canvasManager = useCanvasManager();
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
const autoSwitch = useAppSelector(selectStagingAreaAutoSwitch);
const dispatch = useAppDispatch();
@@ -29,6 +34,7 @@ export const StagingAreaAutoSwitchButtons = memo(() => {
icon={<PiMoonBold />}
colorScheme={autoSwitch === 'off' ? 'invokeBlue' : 'base'}
onClick={onClickOff}
isDisabled={!shouldShowStagedImage}
/>
<IconButton
aria-label="Switch on start"
@@ -36,6 +42,7 @@ export const StagingAreaAutoSwitchButtons = memo(() => {
icon={<PiCaretRightBold />}
colorScheme={autoSwitch === 'switch_on_start' ? 'invokeBlue' : 'base'}
onClick={onClickSwitchOnStart}
isDisabled={!shouldShowStagedImage}
/>
<IconButton
aria-label="Switch on finish"
@@ -43,6 +50,7 @@ export const StagingAreaAutoSwitchButtons = memo(() => {
icon={<PiCaretLineRightBold />}
colorScheme={autoSwitch === 'switch_on_finish' ? 'invokeBlue' : 'base'}
onClick={onClickSwitchOnFinished}
isDisabled={!shouldShowStagedImage}
/>
</>
);

View File

@@ -1,16 +1,16 @@
import { Box, Flex, forwardRef } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { logger } from 'app/logging/logger';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { QueueItemPreviewMini } from 'features/controlLayers/components/SimpleSession/QueueItemPreviewMini';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { QueueItemPreviewMini } from 'features/controlLayers/components/StagingArea/QueueItemPreviewMini';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
import type { CSSProperties, RefObject } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import type { Components, ItemContent, ListRange, VirtuosoHandle, VirtuosoProps } from 'react-virtuoso';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
import type { Components, ComputeItemKey, ItemContent, ListRange, VirtuosoHandle, VirtuosoProps } from 'react-virtuoso';
import { Virtuoso } from 'react-virtuoso';
import type { S } from 'services/api/types';
import { useStagingAreaContext } from './context';
import { getQueueItemElementId } from './shared';
const log = logger('system');
@@ -20,8 +20,6 @@ const virtuosoStyles = {
height: '72px',
} satisfies CSSProperties;
type VirtuosoContext = { selectedItemId: number | null };
/**
* Scroll the item at the given index into view if it is not currently visible.
*/
@@ -132,28 +130,26 @@ const useScrollableStagingArea = (rootRef: RefObject<HTMLDivElement>) => {
};
export const StagingAreaItemsList = memo(() => {
const canvasManager = useCanvasManagerSafe();
const ctx = useCanvasSessionContext();
const canvasManager = useCanvasManager();
const ctx = useStagingAreaContext();
const virtuosoRef = useRef<VirtuosoHandle>(null);
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
const rootRef = useRef<HTMLDivElement>(null);
const items = useStore(ctx.$items);
const selectedItemId = useStore(ctx.$selectedItemId);
const context = useMemo(() => ({ selectedItemId }), [selectedItemId]);
const scrollerRef = useScrollableStagingArea(rootRef);
useEffect(() => {
if (!canvasManager) {
return;
}
return canvasManager.stagingArea.connectToSession(ctx.$selectedItemId, ctx.$progressData, ctx.$isPending);
}, [canvasManager, ctx.$progressData, ctx.$selectedItemId, ctx.$isPending]);
return canvasManager.stagingArea.connectToSession(ctx.$items, ctx.$selectedItem);
}, [canvasManager, ctx.$progressData, ctx.$items, ctx.$selectedItem]);
useEffect(() => {
return ctx.$selectedItemIndex.listen((index) => {
return ctx.$selectedItemIndex.listen((selectedItemIndex) => {
if (selectedItemIndex === null) {
return;
}
if (!virtuosoRef.current) {
return;
}
@@ -162,11 +158,7 @@ export const StagingAreaItemsList = memo(() => {
return;
}
if (index === null) {
return;
}
scrollIntoView(index, rootRef.current, virtuosoRef.current, rangeRef.current);
scrollIntoView(selectedItemIndex, rootRef.current, virtuosoRef.current, rangeRef.current);
});
}, [ctx.$selectedItemIndex]);
@@ -176,40 +168,46 @@ export const StagingAreaItemsList = memo(() => {
return (
<Box data-overlayscrollbars-initialize="" ref={rootRef} position="relative" w="full" h="full">
<Virtuoso<S['SessionQueueItem'], VirtuosoContext>
<Virtuoso<S['SessionQueueItem']>
ref={virtuosoRef}
context={context}
data={items}
horizontalDirection
style={virtuosoStyles}
computeItemKey={computeItemKey}
increaseViewportBy={2048}
itemContent={itemContent}
components={components}
rangeChanged={onRangeChanged}
// Virtuoso expects the ref to be of HTMLElement | null | Window, but overlayscrollbars doesn't allow Window
scrollerRef={scrollerRef as VirtuosoProps<S['SessionQueueItem'], VirtuosoContext>['scrollerRef']}
scrollerRef={scrollerRef as VirtuosoProps<S['SessionQueueItem'], void>['scrollerRef']}
/>
</Box>
);
});
StagingAreaItemsList.displayName = 'StagingAreaItemsList';
const itemContent: ItemContent<S['SessionQueueItem'], VirtuosoContext> = (index, item, { selectedItemId }) => (
<QueueItemPreviewMini
key={`${item.item_id}-mini`}
item={item}
index={index}
isSelected={selectedItemId === item.item_id}
/>
const computeItemKey: ComputeItemKey<S['SessionQueueItem'], void> = (_, item: S['SessionQueueItem']) => {
return item.item_id;
};
const itemContent: ItemContent<S['SessionQueueItem'], void> = (index, item) => (
<QueueItemPreviewMini key={`${item.item_id}-mini`} item={item} index={index} />
);
const listSx = {
'& > * + *': {
pl: 2,
},
'&[data-disabled="true"]': {
filter: 'grayscale(1) opacity(0.5)',
},
};
const components: Components<S['SessionQueueItem'], VirtuosoContext> = {
const components: Components<S['SessionQueueItem']> = {
List: forwardRef(({ context: _, ...rest }, ref) => {
return <Flex ref={ref} sx={listSx} {...rest} />;
const canvasManager = useCanvasManager();
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
return <Flex ref={ref} sx={listSx} data-disabled={!shouldShowStagedImage} {...rest} />;
}),
};

View File

@@ -1,6 +1,5 @@
import { ButtonGroup, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
import { StagingAreaToolbarAcceptButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarAcceptButton';
import { StagingAreaToolbarDiscardAllButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardAllButton';
import { StagingAreaToolbarDiscardSelectedButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardSelectedButton';
@@ -10,17 +9,13 @@ import { StagingAreaToolbarNextButton } from 'features/controlLayers/components/
import { StagingAreaToolbarPrevButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarPrevButton';
import { StagingAreaToolbarSaveSelectedToGalleryButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarSaveSelectedToGalleryButton';
import { StagingAreaToolbarToggleShowResultsButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarToggleShowResultsButton';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { StagingAreaAutoSwitchButtons } from './StagingAreaAutoSwitchButtons';
export const StagingAreaToolbar = memo(() => {
const canvasManager = useCanvasManager();
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
const ctx = useCanvasSessionContext();
const ctx = useStagingAreaContext();
useHotkeys('meta+left', ctx.selectFirst, { preventDefault: true });
useHotkeys('meta+right', ctx.selectLast, { preventDefault: true });
@@ -28,22 +23,22 @@ export const StagingAreaToolbar = memo(() => {
return (
<Flex gap={2}>
<ButtonGroup borderRadius="base" shadow="dark-lg">
<StagingAreaToolbarPrevButton isDisabled={!shouldShowStagedImage} />
<StagingAreaToolbarPrevButton />
<StagingAreaToolbarImageCountButton />
<StagingAreaToolbarNextButton isDisabled={!shouldShowStagedImage} />
<StagingAreaToolbarNextButton />
</ButtonGroup>
<ButtonGroup borderRadius="base" shadow="dark-lg">
<StagingAreaToolbarAcceptButton />
<StagingAreaToolbarToggleShowResultsButton />
<StagingAreaToolbarSaveSelectedToGalleryButton />
<StagingAreaToolbarMenu />
<StagingAreaToolbarDiscardSelectedButton isDisabled={!shouldShowStagedImage} />
<StagingAreaToolbarDiscardSelectedButton />
</ButtonGroup>
<ButtonGroup borderRadius="base" shadow="dark-lg">
<StagingAreaAutoSwitchButtons />
</ButtonGroup>
<ButtonGroup borderRadius="base" shadow="dark-lg">
<StagingAreaToolbarDiscardAllButton isDisabled={!shouldShowStagedImage} />
<StagingAreaToolbarDiscardAllButton />
</ButtonGroup>
</Flex>
);

View File

@@ -1,64 +1,32 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useIsRegionFocused } from 'common/hooks/focus';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { canvasSessionReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectBboxRect, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageNameToImageObject } from 'features/controlLayers/store/util';
import { useCancelQueueItemsByDestination } from 'features/queue/hooks/useCancelQueueItemsByDestination';
import { memo, useCallback } from 'react';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { PiCheckBold } from 'react-icons/pi';
export const StagingAreaToolbarAcceptButton = memo(() => {
const ctx = useCanvasSessionContext();
const dispatch = useAppDispatch();
const ctx = useStagingAreaContext();
const canvasManager = useCanvasManager();
const bboxRect = useAppSelector(selectBboxRect);
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
const isCanvasFocused = useIsRegionFocused('canvas');
const selectedItemImageDTO = useStore(ctx.$selectedItemOutputImageDTO);
const cancelQueueItemsByDestination = useCancelQueueItemsByDestination();
const acceptSelectedIsEnabled = useStore(ctx.$acceptSelectedIsEnabled);
const { t } = useTranslation();
const acceptSelected = useCallback(() => {
if (!selectedItemImageDTO) {
return;
}
const { x, y, width, height } = bboxRect;
const imageObject = imageNameToImageObject(selectedItemImageDTO.image_name, { width, height });
const overrides: Partial<CanvasRasterLayerState> = {
position: { x, y },
objects: [imageObject],
};
dispatch(rasterLayerAdded({ overrides, isSelected: selectedEntityIdentifier?.type === 'raster_layer' }));
dispatch(canvasSessionReset());
cancelQueueItemsByDestination.trigger(ctx.session.id, { withToast: false });
}, [
selectedItemImageDTO,
bboxRect,
dispatch,
selectedEntityIdentifier?.type,
cancelQueueItemsByDestination,
ctx.session.id,
]);
useHotkeys(
['enter'],
acceptSelected,
ctx.acceptSelected,
{
preventDefault: true,
enabled: isCanvasFocused && shouldShowStagedImage && selectedItemImageDTO !== null,
enabled: isCanvasFocused && shouldShowStagedImage && acceptSelectedIsEnabled,
},
[isCanvasFocused, shouldShowStagedImage, selectedItemImageDTO]
[ctx.acceptSelected, isCanvasFocused, shouldShowStagedImage, acceptSelectedIsEnabled]
);
return (
@@ -66,9 +34,9 @@ export const StagingAreaToolbarAcceptButton = memo(() => {
tooltip={`${t('common.accept')} (Enter)`}
aria-label={`${t('common.accept')} (Enter)`}
icon={<PiCheckBold />}
onClick={acceptSelected}
onClick={ctx.acceptSelected}
colorScheme="invokeBlue"
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage || cancelQueueItemsByDestination.isDisabled}
isDisabled={!acceptSelectedIsEnabled || !shouldShowStagedImage || cancelQueueItemsByDestination.isDisabled}
isLoading={cancelQueueItemsByDestination.isLoading}
/>
);

View File

@@ -1,28 +1,28 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useStore } from '@nanostores/react';
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useCancelQueueItemsByDestination } from 'features/queue/hooks/useCancelQueueItemsByDestination';
import { memo, useCallback } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
export const StagingAreaToolbarDiscardAllButton = memo(({ isDisabled }: { isDisabled?: boolean }) => {
const ctx = useCanvasSessionContext();
export const StagingAreaToolbarDiscardAllButton = memo(() => {
const canvasManager = useCanvasManager();
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
const ctx = useStagingAreaContext();
const { t } = useTranslation();
const cancelQueueItemsByDestination = useCancelQueueItemsByDestination();
const discardAll = useCallback(() => {
ctx.discardAll();
cancelQueueItemsByDestination.trigger(ctx.session.id, { withToast: false });
}, [cancelQueueItemsByDestination, ctx]);
return (
<IconButton
tooltip={`${t('controlLayers.stagingArea.discardAll')} (Esc)`}
aria-label={t('controlLayers.stagingArea.discardAll')}
icon={<PiTrashSimpleBold />}
onClick={discardAll}
onClick={ctx.discardAll}
colorScheme="error"
isDisabled={isDisabled || cancelQueueItemsByDestination.isDisabled}
isDisabled={cancelQueueItemsByDestination.isDisabled || !shouldShowStagedImage}
isLoading={cancelQueueItemsByDestination.isLoading}
/>
);

View File

@@ -1,34 +1,30 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useCancelQueueItem } from 'features/queue/hooks/useCancelQueueItem';
import { memo, useCallback } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
export const StagingAreaToolbarDiscardSelectedButton = memo(({ isDisabled }: { isDisabled?: boolean }) => {
const ctx = useCanvasSessionContext();
export const StagingAreaToolbarDiscardSelectedButton = memo(() => {
const canvasManager = useCanvasManager();
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
const ctx = useStagingAreaContext();
const cancelQueueItem = useCancelQueueItem();
const selectedItemId = useStore(ctx.$selectedItemId);
const discardSelectedIsEnabled = useStore(ctx.$discardSelectedIsEnabled);
const { t } = useTranslation();
const discardSelected = useCallback(async () => {
if (selectedItemId === null) {
return;
}
ctx.discard(selectedItemId);
await cancelQueueItem.trigger(selectedItemId, { withToast: false });
}, [selectedItemId, ctx, cancelQueueItem]);
return (
<IconButton
tooltip={t('controlLayers.stagingArea.discard')}
aria-label={t('controlLayers.stagingArea.discard')}
icon={<PiXBold />}
onClick={discardSelected}
onClick={ctx.discardSelected}
colorScheme="invokeBlue"
isDisabled={selectedItemId === null || cancelQueueItem.isDisabled || isDisabled}
isDisabled={!discardSelectedIsEnabled || cancelQueueItem.isDisabled || !shouldShowStagedImage}
isLoading={cancelQueueItem.isLoading}
/>
);

View File

@@ -1,23 +1,27 @@
import { Button } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { memo, useMemo } from 'react';
export const StagingAreaToolbarImageCountButton = memo(() => {
const ctx = useCanvasSessionContext();
const selectItemIndex = useStore(ctx.$selectedItemIndex);
const canvasManager = useCanvasManager();
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
const ctx = useStagingAreaContext();
const selectedItem = useStore(ctx.$selectedItem);
const itemCount = useStore(ctx.$itemCount);
const counterText = useMemo(() => {
if (itemCount > 0 && selectItemIndex !== null) {
return `${selectItemIndex + 1} of ${itemCount}`;
if (itemCount > 0 && selectedItem !== null) {
return `${selectedItem.index + 1} of ${itemCount}`;
} else {
return `0 of 0`;
}
}, [itemCount, selectItemIndex]);
}, [itemCount, selectedItem]);
return (
<Button colorScheme="base" pointerEvents="none" minW={28}>
<Button colorScheme="base" pointerEvents="none" minW={28} isDisabled={!shouldShowStagedImage}>
{counterText}
</Button>
);

View File

@@ -1,12 +1,23 @@
import { IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { StagingAreaToolbarNewLayerFromImageMenuItems } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarMenuNewLayerFromImage';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { memo } from 'react';
import { PiDotsThreeVerticalBold } from 'react-icons/pi';
export const StagingAreaToolbarMenu = memo(() => {
const canvasManager = useCanvasManager();
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
return (
<Menu>
<MenuButton as={IconButton} icon={<PiDotsThreeVerticalBold />} colorScheme="invokeBlue" />
<MenuButton
tooltip="Image Actions"
as={IconButton}
icon={<PiDotsThreeVerticalBold />}
colorScheme="invokeBlue"
isDisabled={!shouldShowStagedImage}
/>
<MenuList>
<StagingAreaToolbarNewLayerFromImageMenuItems />
</MenuList>

View File

@@ -2,7 +2,7 @@ import { MenuGroup, MenuItem } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/storeHooks';
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { createNewCanvasEntityFromImage } from 'features/imageActions/actions';
import { toast } from 'features/toast/toast';
@@ -15,8 +15,8 @@ const uploadImageArg = { image_category: 'general', is_intermediate: true, silen
export const StagingAreaToolbarNewLayerFromImageMenuItems = memo(() => {
const canvasManager = useCanvasManager();
const { t } = useTranslation();
const ctx = useCanvasSessionContext();
const selectedItemOutputImageDTO = useStore(ctx.$selectedItemOutputImageDTO);
const ctx = useStagingAreaContext();
const selectedItemImageDTO = useStore(ctx.$selectedItemImageDTO);
const store = useAppStore();
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
@@ -29,11 +29,11 @@ export const StagingAreaToolbarNewLayerFromImageMenuItems = memo(() => {
}, [t]);
const onClickNewRasterLayerFromImage = useCallback(async () => {
if (!selectedItemOutputImageDTO) {
if (!selectedItemImageDTO) {
return;
}
const { dispatch, getState } = store;
const imageDTO = await copyImage(selectedItemOutputImageDTO.image_name, uploadImageArg);
const imageDTO = await copyImage(selectedItemImageDTO.image_name, uploadImageArg);
createNewCanvasEntityFromImage({
imageDTO,
type: 'raster_layer',
@@ -42,14 +42,14 @@ export const StagingAreaToolbarNewLayerFromImageMenuItems = memo(() => {
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
});
toastSentToCanvas();
}, [selectedItemOutputImageDTO, store, toastSentToCanvas]);
}, [selectedItemImageDTO, store, toastSentToCanvas]);
const onClickNewControlLayerFromImage = useCallback(async () => {
if (!selectedItemOutputImageDTO) {
if (!selectedItemImageDTO) {
return;
}
const { dispatch, getState } = store;
const imageDTO = await copyImage(selectedItemOutputImageDTO.image_name, uploadImageArg);
const imageDTO = await copyImage(selectedItemImageDTO.image_name, uploadImageArg);
createNewCanvasEntityFromImage({
imageDTO,
@@ -59,14 +59,14 @@ export const StagingAreaToolbarNewLayerFromImageMenuItems = memo(() => {
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
});
toastSentToCanvas();
}, [selectedItemOutputImageDTO, store, toastSentToCanvas]);
}, [selectedItemImageDTO, store, toastSentToCanvas]);
const onClickNewInpaintMaskFromImage = useCallback(async () => {
if (!selectedItemOutputImageDTO) {
if (!selectedItemImageDTO) {
return;
}
const { dispatch, getState } = store;
const imageDTO = await copyImage(selectedItemOutputImageDTO.image_name, uploadImageArg);
const imageDTO = await copyImage(selectedItemImageDTO.image_name, uploadImageArg);
createNewCanvasEntityFromImage({
imageDTO,
@@ -76,14 +76,14 @@ export const StagingAreaToolbarNewLayerFromImageMenuItems = memo(() => {
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
});
toastSentToCanvas();
}, [selectedItemOutputImageDTO, store, toastSentToCanvas]);
}, [selectedItemImageDTO, store, toastSentToCanvas]);
const onClickNewRegionalGuidanceFromImage = useCallback(async () => {
if (!selectedItemOutputImageDTO) {
if (!selectedItemImageDTO) {
return;
}
const { dispatch, getState } = store;
const imageDTO = await copyImage(selectedItemOutputImageDTO.image_name, uploadImageArg);
const imageDTO = await copyImage(selectedItemImageDTO.image_name, uploadImageArg);
createNewCanvasEntityFromImage({
imageDTO,
@@ -93,35 +93,35 @@ export const StagingAreaToolbarNewLayerFromImageMenuItems = memo(() => {
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
});
toastSentToCanvas();
}, [selectedItemOutputImageDTO, store, toastSentToCanvas]);
}, [selectedItemImageDTO, store, toastSentToCanvas]);
return (
<MenuGroup title="New Layer From Image">
<MenuItem
icon={<NewLayerIcon />}
onClickCapture={onClickNewInpaintMaskFromImage}
isDisabled={!selectedItemOutputImageDTO || !shouldShowStagedImage}
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage}
>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem
icon={<NewLayerIcon />}
onClickCapture={onClickNewRegionalGuidanceFromImage}
isDisabled={!selectedItemOutputImageDTO || !shouldShowStagedImage}
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage}
>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem
icon={<NewLayerIcon />}
onClickCapture={onClickNewControlLayerFromImage}
isDisabled={!selectedItemOutputImageDTO || !shouldShowStagedImage}
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage}
>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem
icon={<NewLayerIcon />}
onClickCapture={onClickNewRasterLayerFromImage}
isDisabled={!selectedItemOutputImageDTO || !shouldShowStagedImage}
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage}
>
{t('controlLayers.rasterLayer')}
</MenuItem>

View File

@@ -1,14 +1,18 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useIsRegionFocused } from 'common/hooks/focus';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { PiArrowRightBold } from 'react-icons/pi';
export const StagingAreaToolbarNextButton = memo(({ isDisabled }: { isDisabled?: boolean }) => {
const ctx = useCanvasSessionContext();
export const StagingAreaToolbarNextButton = memo(() => {
const canvasManager = useCanvasManager();
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
const ctx = useStagingAreaContext();
const itemCount = useStore(ctx.$itemCount);
const isCanvasFocused = useIsRegionFocused('canvas');
@@ -23,9 +27,9 @@ export const StagingAreaToolbarNextButton = memo(({ isDisabled }: { isDisabled?:
ctx.selectNext,
{
preventDefault: true,
enabled: isCanvasFocused && !isDisabled && itemCount > 1,
enabled: isCanvasFocused && shouldShowStagedImage && itemCount > 1,
},
[isCanvasFocused, isDisabled, itemCount, ctx.selectNext]
[isCanvasFocused, shouldShowStagedImage, itemCount, ctx.selectNext]
);
return (
@@ -35,7 +39,7 @@ export const StagingAreaToolbarNextButton = memo(({ isDisabled }: { isDisabled?:
icon={<PiArrowRightBold />}
onClick={selectNext}
colorScheme="invokeBlue"
isDisabled={itemCount <= 1 || isDisabled}
isDisabled={itemCount <= 1 || !shouldShowStagedImage}
/>
);
});

View File

@@ -1,14 +1,17 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useIsRegionFocused } from 'common/hooks/focus';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { PiArrowLeftBold } from 'react-icons/pi';
export const StagingAreaToolbarPrevButton = memo(({ isDisabled }: { isDisabled?: boolean }) => {
const ctx = useCanvasSessionContext();
export const StagingAreaToolbarPrevButton = memo(() => {
const canvasManager = useCanvasManager();
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
const ctx = useStagingAreaContext();
const itemCount = useStore(ctx.$itemCount);
const isCanvasFocused = useIsRegionFocused('canvas');
@@ -23,9 +26,9 @@ export const StagingAreaToolbarPrevButton = memo(({ isDisabled }: { isDisabled?:
ctx.selectPrev,
{
preventDefault: true,
enabled: isCanvasFocused && !isDisabled && itemCount > 1,
enabled: isCanvasFocused && shouldShowStagedImage && itemCount > 1,
},
[isCanvasFocused, isDisabled, itemCount, ctx.selectPrev]
[isCanvasFocused, shouldShowStagedImage, itemCount, ctx.selectPrev]
);
return (
@@ -35,7 +38,7 @@ export const StagingAreaToolbarPrevButton = memo(({ isDisabled }: { isDisabled?:
icon={<PiArrowLeftBold />}
onClick={selectPrev}
colorScheme="invokeBlue"
isDisabled={itemCount <= 1 || isDisabled}
isDisabled={itemCount <= 1 || !shouldShowStagedImage}
/>
);
});

View File

@@ -2,7 +2,7 @@ import { IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { withResultAsync } from 'common/util/result';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { toast } from 'features/toast/toast';
@@ -16,14 +16,14 @@ const TOAST_ID = 'SAVE_STAGING_AREA_IMAGE_TO_GALLERY';
export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
const canvasManager = useCanvasManager();
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const ctx = useCanvasSessionContext();
const selectedItemOutputImageDTO = useStore(ctx.$selectedItemOutputImageDTO);
const ctx = useStagingAreaContext();
const selectedItemImageDTO = useStore(ctx.$selectedItemImageDTO);
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
const { t } = useTranslation();
const saveSelectedImageToGallery = useCallback(async () => {
if (!selectedItemOutputImageDTO) {
if (!selectedItemImageDTO) {
return;
}
@@ -31,7 +31,7 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
// the gallery without borking the canvas, which may need this image to exist.
const result = await withResultAsync(async () => {
// Create a new file with the same name, which we will upload
await copyImage(selectedItemOutputImageDTO.image_name, {
await copyImage(selectedItemImageDTO.image_name, {
// Image should show up in the Images tab
image_category: 'general',
is_intermediate: false,
@@ -55,7 +55,7 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
status: 'error',
});
}
}, [autoAddBoardId, selectedItemOutputImageDTO, t]);
}, [autoAddBoardId, selectedItemImageDTO, t]);
return (
<IconButton
@@ -64,7 +64,7 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
icon={<PiFloppyDiskBold />}
onClick={saveSelectedImageToGallery}
colorScheme="invokeBlue"
isDisabled={!selectedItemOutputImageDTO || !shouldShowStagedImage}
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage}
/>
);
});

View File

@@ -0,0 +1,172 @@
import { merge } from 'es-toolkit';
import type { StagingAreaAppApi } from 'features/controlLayers/components/StagingArea/state';
import type { AutoSwitchMode } from 'features/controlLayers/store/canvasSettingsSlice';
import type { ImageDTO, S } from 'services/api/types';
import type { PartialDeep } from 'type-fest';
import { vi } from 'vitest';
export const createMockStagingAreaApp = (): StagingAreaAppApi & {
// Additional methods for testing
_triggerItemsChanged: (items: S['SessionQueueItem'][]) => void;
_triggerQueueItemStatusChanged: (data: S['QueueItemStatusChangedEvent']) => void;
_triggerInvocationProgress: (data: S['InvocationProgressEvent']) => void;
_setAutoSwitchMode: (mode: AutoSwitchMode) => void;
_setImageDTO: (imageName: string, imageDTO: ImageDTO | null) => void;
} => {
const itemsChangedHandlers = new Set<(items: S['SessionQueueItem'][]) => void>();
const queueItemStatusChangedHandlers = new Set<(data: S['QueueItemStatusChangedEvent']) => void>();
const invocationProgressHandlers = new Set<(data: S['InvocationProgressEvent']) => void>();
let autoSwitchMode: AutoSwitchMode = 'switch_on_start';
const imageDTOs = new Map<string, ImageDTO | null>();
return {
onDiscard: vi.fn(),
onDiscardAll: vi.fn(),
onAccept: vi.fn(),
onSelect: vi.fn(),
onSelectPrev: vi.fn(),
onSelectNext: vi.fn(),
onSelectFirst: vi.fn(),
onSelectLast: vi.fn(),
getAutoSwitch: vi.fn(() => autoSwitchMode),
onAutoSwitchChange: vi.fn(),
getImageDTO: vi.fn((imageName: string) => {
return Promise.resolve(imageDTOs.get(imageName) || null);
}),
onItemsChanged: vi.fn((handler) => {
itemsChangedHandlers.add(handler);
return () => itemsChangedHandlers.delete(handler);
}),
onQueueItemStatusChanged: vi.fn((handler) => {
queueItemStatusChangedHandlers.add(handler);
return () => queueItemStatusChangedHandlers.delete(handler);
}),
onInvocationProgress: vi.fn((handler) => {
invocationProgressHandlers.add(handler);
return () => invocationProgressHandlers.delete(handler);
}),
// Testing helper methods
_triggerItemsChanged: (items: S['SessionQueueItem'][]) => {
itemsChangedHandlers.forEach((handler) => handler(items));
},
_triggerQueueItemStatusChanged: (data: S['QueueItemStatusChangedEvent']) => {
queueItemStatusChangedHandlers.forEach((handler) => handler(data));
},
_triggerInvocationProgress: (data: S['InvocationProgressEvent']) => {
invocationProgressHandlers.forEach((handler) => handler(data));
},
_setAutoSwitchMode: (mode: AutoSwitchMode) => {
autoSwitchMode = mode;
},
_setImageDTO: (imageName: string, imageDTO: ImageDTO | null) => {
imageDTOs.set(imageName, imageDTO);
},
};
};
export const createMockQueueItem = (overrides: PartialDeep<S['SessionQueueItem']> = {}): S['SessionQueueItem'] =>
merge(
{
item_id: 1,
batch_id: 'test-batch-id',
session_id: 'test-session',
queue_id: 'test-queue-id',
status: 'pending',
priority: 0,
origin: null,
destination: 'test-session',
error_type: null,
error_message: null,
error_traceback: null,
created_at: '2024-01-01T00:00:00Z',
updated_at: '2024-01-01T00:00:00Z',
started_at: null,
completed_at: null,
field_values: null,
retried_from_item_id: null,
is_api_validation_run: false,
published_workflow_id: null,
session: {
id: 'test-session',
graph: {},
execution_graph: {},
executed: [],
executed_history: [],
results: {
'test-node-id': {
image: {
image_name: 'test-image.png',
},
},
},
errors: {},
prepared_source_mapping: {},
source_prepared_mapping: {
canvas_output: ['test-node-id'],
},
},
workflow: null,
},
overrides
) as S['SessionQueueItem'];
export const createMockImageDTO = (overrides: Partial<ImageDTO> = {}): ImageDTO => ({
image_name: 'test-image.png',
image_url: 'http://test.com/test-image.png',
thumbnail_url: 'http://test.com/test-image-thumb.png',
image_origin: 'internal',
image_category: 'general',
width: 512,
height: 512,
created_at: '2024-01-01T00:00:00Z',
updated_at: '2024-01-01T00:00:00Z',
deleted_at: null,
is_intermediate: false,
starred: false,
has_workflow: false,
session_id: 'test-session',
node_id: 'test-node-id',
board_id: null,
...overrides,
});
export const createMockProgressEvent = (
overrides: PartialDeep<S['InvocationProgressEvent']> = {}
): S['InvocationProgressEvent'] =>
merge(
{
timestamp: Date.now(),
queue_id: 'test-queue-id',
item_id: 1,
batch_id: 'test-batch-id',
session_id: 'test-session',
origin: null,
destination: 'test-session',
invocation: {},
invocation_source_id: 'test-invocation-source-id',
message: 'Processing...',
percentage: 50,
image: null,
} as S['InvocationProgressEvent'],
overrides
);
export const createMockQueueItemStatusChangedEvent = (
overrides: PartialDeep<S['QueueItemStatusChangedEvent']> = {}
): S['QueueItemStatusChangedEvent'] =>
merge(
{
timestamp: Date.now(),
queue_id: 'test-queue-id',
item_id: 1,
batch_id: 'test-batch-id',
origin: null,
destination: 'test-session',
status: 'completed',
error_type: null,
error_message: null,
} as S['QueueItemStatusChangedEvent'],
overrides
);

View File

@@ -0,0 +1,132 @@
import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/storeHooks';
import {
selectStagingAreaAutoSwitch,
settingsStagingAreaAutoSwitchChanged,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import {
buildSelectCanvasQueueItems,
canvasQueueItemDiscarded,
canvasSessionReset,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectBboxRect, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageNameToImageObject } from 'features/controlLayers/store/util';
import type { PropsWithChildren } from 'react';
import { createContext, memo, useContext, useEffect, useMemo, useState } from 'react';
import { getImageDTOSafe } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { S } from 'services/api/types';
import { $socket } from 'services/events/stores';
import { assert } from 'tsafe';
import type { ProgressData, StagingAreaAppApi } from './state';
import { getInitialProgressData, StagingAreaApi } from './state';
const StagingAreaContext = createContext<StagingAreaApi | null>(null);
export const StagingAreaContextProvider = memo(({ children, sessionId }: PropsWithChildren<{ sessionId: string }>) => {
const store = useAppStore();
const socket = useStore($socket);
const stagingAreaAppApi = useMemo<StagingAreaAppApi>(() => {
const selectQueueItems = buildSelectCanvasQueueItems(sessionId);
const _stagingAreaAppApi: StagingAreaAppApi = {
getAutoSwitch: () => selectStagingAreaAutoSwitch(store.getState()),
getImageDTO: (imageName: string) => getImageDTOSafe(imageName),
onInvocationProgress: (handler) => {
socket?.on('invocation_progress', handler);
return () => {
socket?.off('invocation_progress', handler);
};
},
onQueueItemStatusChanged: (handler) => {
socket?.on('queue_item_status_changed', handler);
return () => {
socket?.off('queue_item_status_changed', handler);
};
},
onItemsChanged: (handler) => {
let prev: S['SessionQueueItem'][] = [];
return store.subscribe(() => {
const next = selectQueueItems(store.getState());
if (prev !== next) {
prev = next;
handler(next);
}
});
},
onDiscard: ({ item_id, status }) => {
store.dispatch(canvasQueueItemDiscarded({ itemId: item_id }));
if (status === 'in_progress' || status === 'pending') {
store.dispatch(queueApi.endpoints.cancelQueueItem.initiate({ item_id }, { track: false }));
}
},
onDiscardAll: () => {
store.dispatch(canvasSessionReset());
store.dispatch(
queueApi.endpoints.cancelQueueItemsByDestination.initiate({ destination: sessionId }, { track: false })
);
},
onAccept: (item, imageDTO) => {
const bboxRect = selectBboxRect(store.getState());
const { x, y, width, height } = bboxRect;
const imageObject = imageNameToImageObject(imageDTO.image_name, { width, height });
const selectedEntityIdentifier = selectSelectedEntityIdentifier(store.getState());
const overrides: Partial<CanvasRasterLayerState> = {
position: { x, y },
objects: [imageObject],
};
store.dispatch(rasterLayerAdded({ overrides, isSelected: selectedEntityIdentifier?.type === 'raster_layer' }));
store.dispatch(canvasSessionReset());
store.dispatch(
queueApi.endpoints.cancelQueueItemsByDestination.initiate({ destination: sessionId }, { track: false })
);
},
onAutoSwitchChange: (mode) => {
store.dispatch(settingsStagingAreaAutoSwitchChanged(mode));
},
};
return _stagingAreaAppApi;
}, [sessionId, socket, store]);
const [stagingAreaApi] = useState(() => new StagingAreaApi());
useEffect(() => {
stagingAreaApi.connectToApp(sessionId, stagingAreaAppApi);
// We need to subscribe to the queue items query manually to ensure the staging area actually gets the items
const { unsubscribe: unsubQueueItemsQuery } = store.dispatch(
queueApi.endpoints.listAllQueueItems.initiate({ destination: sessionId })
);
return () => {
stagingAreaApi.cleanup();
unsubQueueItemsQuery();
};
}, [sessionId, stagingAreaApi, stagingAreaAppApi, store]);
return <StagingAreaContext.Provider value={stagingAreaApi}>{children}</StagingAreaContext.Provider>;
});
StagingAreaContextProvider.displayName = 'StagingAreaContextProvider';
export const useStagingAreaContext = () => {
const ctx = useContext(StagingAreaContext);
assert(ctx !== null, "'useStagingAreaContext' must be used within a StagingAreaContextProvider");
return ctx;
};
export const useOutputImageDTO = (itemId: number) => {
const ctx = useStagingAreaContext();
const allProgressData = useStore(ctx.$progressData, { keys: [itemId] });
return allProgressData[itemId]?.imageDTO ?? null;
};
export const useProgressDatum = (itemId: number): ProgressData => {
const ctx = useStagingAreaContext();
const allProgressData = useStore(ctx.$progressData, { keys: [itemId] });
return allProgressData[itemId] ?? getInitialProgressData(itemId);
};

View File

@@ -0,0 +1,205 @@
import type { S } from 'services/api/types';
import { describe, expect, it } from 'vitest';
import { getOutputImageName, getProgressMessage, getQueueItemElementId } from './shared';
describe('StagingAreaApi Utility Functions', () => {
describe('getProgressMessage', () => {
it('should return default message when no data provided', () => {
expect(getProgressMessage()).toBe('Generating');
expect(getProgressMessage(null)).toBe('Generating');
});
it('should format progress message when data is provided', () => {
const progressEvent: S['InvocationProgressEvent'] = {
item_id: 1,
destination: 'test-session',
node_id: 'test-node',
source_node_id: 'test-source-node',
progress: 0.5,
message: 'Processing image...',
image: null,
} as unknown as S['InvocationProgressEvent'];
const result = getProgressMessage(progressEvent);
expect(result).toBe('Processing image...');
});
});
describe('getQueueItemElementId', () => {
it('should generate correct element ID for queue item', () => {
expect(getQueueItemElementId(0)).toBe('queue-item-preview-0');
expect(getQueueItemElementId(5)).toBe('queue-item-preview-5');
expect(getQueueItemElementId(99)).toBe('queue-item-preview-99');
});
});
describe('getOutputImageName', () => {
it('should extract image name from completed queue item', () => {
const queueItem: S['SessionQueueItem'] = {
item_id: 1,
status: 'completed',
priority: 0,
destination: 'test-session',
created_at: '2024-01-01T00:00:00Z',
updated_at: '2024-01-01T00:00:00Z',
started_at: '2024-01-01T00:00:01Z',
completed_at: '2024-01-01T00:01:00Z',
error: null,
session: {
id: 'test-session',
source_prepared_mapping: {
canvas_output: ['output-node-id'],
},
results: {
'output-node-id': {
image: {
image_name: 'test-output.png',
},
},
},
},
} as unknown as S['SessionQueueItem'];
expect(getOutputImageName(queueItem)).toBe('test-output.png');
});
it('should return null when no canvas output node found', () => {
const queueItem = {
item_id: 1,
status: 'completed',
priority: 0,
destination: 'test-session',
created_at: '2024-01-01T00:00:00Z',
updated_at: '2024-01-01T00:00:00Z',
started_at: '2024-01-01T00:00:01Z',
completed_at: '2024-01-01T00:01:00Z',
error: null,
session: {
id: 'test-session',
source_prepared_mapping: {
some_other_node: ['other-node-id'],
},
results: {
'other-node-id': {
type: 'image_output',
image: {
image_name: 'test-output.png',
},
width: 512,
height: 512,
},
},
},
} as unknown as S['SessionQueueItem'];
expect(getOutputImageName(queueItem)).toBe(null);
});
it('should return null when output node has no results', () => {
const queueItem: S['SessionQueueItem'] = {
item_id: 1,
status: 'completed',
priority: 0,
destination: 'test-session',
created_at: '2024-01-01T00:00:00Z',
updated_at: '2024-01-01T00:00:00Z',
started_at: '2024-01-01T00:00:01Z',
completed_at: '2024-01-01T00:01:00Z',
error: null,
session: {
id: 'test-session',
source_prepared_mapping: {
canvas_output: ['output-node-id'],
},
results: {},
},
} as unknown as S['SessionQueueItem'];
expect(getOutputImageName(queueItem)).toBe(null);
});
it('should return null when results contain no image fields', () => {
const queueItem: S['SessionQueueItem'] = {
item_id: 1,
status: 'completed',
priority: 0,
destination: 'test-session',
created_at: '2024-01-01T00:00:00Z',
updated_at: '2024-01-01T00:00:00Z',
started_at: '2024-01-01T00:00:01Z',
completed_at: '2024-01-01T00:01:00Z',
error: null,
session: {
id: 'test-session',
source_prepared_mapping: {
canvas_output: ['output-node-id'],
},
results: {
'output-node-id': {
text: 'some text output',
number: 42,
},
},
},
} as unknown as S['SessionQueueItem'];
expect(getOutputImageName(queueItem)).toBe(null);
});
it('should handle multiple outputs and return first image', () => {
const queueItem: S['SessionQueueItem'] = {
item_id: 1,
status: 'completed',
priority: 0,
destination: 'test-session',
created_at: '2024-01-01T00:00:00Z',
updated_at: '2024-01-01T00:00:00Z',
started_at: '2024-01-01T00:00:01Z',
completed_at: '2024-01-01T00:01:00Z',
error: null,
session: {
id: 'test-session',
source_prepared_mapping: {
canvas_output: ['output-node-id'],
},
results: {
'output-node-id': {
text: 'some text',
first_image: {
image_name: 'first-image.png',
},
second_image: {
image_name: 'second-image.png',
},
},
},
},
} as unknown as S['SessionQueueItem'];
const result = getOutputImageName(queueItem);
expect(result).toBe('first-image.png');
});
it('should handle empty session mapping', () => {
const queueItem: S['SessionQueueItem'] = {
item_id: 1,
status: 'completed',
priority: 0,
destination: 'test-session',
created_at: '2024-01-01T00:00:00Z',
updated_at: '2024-01-01T00:00:00Z',
started_at: '2024-01-01T00:00:01Z',
completed_at: '2024-01-01T00:01:00Z',
error: null,
session: {
id: 'test-session',
source_prepared_mapping: {},
results: {},
},
} as unknown as S['SessionQueueItem'];
expect(getOutputImageName(queueItem)).toBe(null);
});
});
});

Some files were not shown because too many files have changed in this diff Show More