Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-15 09:18:00 -05:00)

Compare commits: v6.0.0rc4...bria-clone (212 commits)
Commits in this range (abbreviated SHA1):

bd71459955, cacfb183a6, 564f4f7a60, 113a118fcf, 1f930cdaf2, c490e0ce08, 7640ee307c, 1f5f70f898,
1430858112, 48c27ec117, af7737e804, a5542370a6, 3eca0d2ba0, 307259f096, bed01941a5, 89fa43a3b6,
d8fcb08abf, c61bcd9f50, 3fb0fcbbfb, db9af5083f, 720f1bb65c, 7dfb318ba2, 9b024da2b4, 15ca3b727a,
74ca604ae0, 6934b05c85, 1a47a5317c, bc3ef21c64, e329f5ad43, c296fd2305, e6ad91bf89, 2f586416a5,
33b56f421c, e58ee4c492, 49691aa07e, 56570f235f, a2d95cf5b6, 704dbfd04a, 5d9e078043, 875cde13ae,
77655aed86, 0628b92d63, 9e526d00c2, 1a24396be8, d97e73a565, 55b14c8aaf, 79f65e57eb, b4c8950278,
400b2e9a55, 3a687c583a, 833950078d, e698dcb148, 218386e077, 4426be9e64, 86f4cf7857, 49ae66d94a,
c10865c7ef, f3478a189a, 43db29176a, f38922929c, 7d02c58f86, 6edce8be87, 31f63e38bd, 78a68ac3a7,
8cd3bcd1c0, 264cc5ef46, 8bfbea5ed3, f06a66da07, 337cae9b22, bf926bb7d5, 18ad9a6af3, b6ed31c222,
200beb5af5, f82a948bdd, dd03e3ddcd, 7561b73e8f, caa97608c7, 72a6d1edc1, b8bf89c2f1, a1ade2b8c0,
c08a6a852d, 4bdcae1f8f, 4b22c84407, c9daf1db30, e1139de551, 44b7b9c29d, 2d55dbe67a, 04ea87b0bb,
06d3cfbe97, 71e4901313, 82fb897b62, 192b00d969, 7bb25ef1b4, 62f52c74a8, 97439c1daa, b23bff1b53,
d9a1efbabf, d4e903ee2d, bb3e5d16d8, e62d3f01a8, 757ecdbf82, 694c85b041, 988d7ba24c, ac981879ef,
fc71849c24, a19aa3b032, ef4d5d7377, 6b0dfd8427, 7140f2ec72, 471c010217, b1193022f7, 2152ca092c,
ccc62ba56d, 9cf82de8c5, aced349152, 9e5e1ec0da, a139885bf7, f5423133a8, 9c9265cdad, 0d67ee6548,
03c21d1607, 752e8db1f5, 85fc861dd9, 458cbfd874, 04331c070a, 632ddf0cb4, 2b193ff416, 96ee394f9e,
0badc80c0c, 78e6cbf96e, 0b969a661b, 6fe47ec9f8, 3850dd61f8, 75520eaf0f, 10e88c58c1, 30ed4dbd92,
ed9c090f33, d29f65ed22, 2062ec8ac0, 49e818338a, 1caab2b9c4, 50079ea349, fffa1b24c4, a6d6170387,
e5fceb0448, 059baf5b29, 1be8a9a310, 7adc33e04d, 7f2dd22d47, bb50f4b8a2, a48958e0d4, e3a1e9af53,
c6fe11c42f, 4eb1bd67df, c376f914d2, b5d1c47ef7, 004a52ca65, b1d5a51ddf, 2b2498eaa1, 10dda4440e,
98f78abefa, cc93fa270f, 014b27680f, c3d8f875de, 79f9dc6e4a, 6e1c0c1105, 0362524040, dc6656459b,
3ea1b97f6f, a7c7405ccc, c391f1117a, b1e2cb8401, db6af134b7, 7e6cffb00c, 5b187bcb00, 0843d609a3,
95bd9cef18, 931d6521f6, e37665ff59, 56857fbbe6, 43cfb8a574, 05b1682d15, 69a08ee7f2, 18212c7d8a,
7de26f8e69, 0652b12a6f, 43a361a00f, cf68ad9cbc, ec02a39325, e52d7a05c2, c9d4e2b761, ac26aa9508,
9ff6ada15b, e81a115169, 52827807de, b631de4cb5, 099ebdbc37, 4de6549be9, 368be34949, 5baa4bd916,
4229377532, 2610772ffd, 193de6a8f2, 7ea343c787, 12179dabba, ef135f9923, e6c67cc00f, 179b988148,
d913a3c85b, e79525c40c, f409f913ac, 7a79f61d4c
**.github/ISSUE_TEMPLATE/BUG_REPORT.yml** (vendored, 26 lines changed)

```diff
@@ -21,6 +21,20 @@ body:
           - label: I have searched the existing issues
             required: true
 
+  - type: dropdown
+    id: install_method
+    attributes:
+      label: Install method
+      description: How did you install Invoke?
+      multiple: false
+      options:
+        - "Invoke's Launcher"
+        - 'Stability Matrix'
+        - 'Pinokio'
+        - 'Manual'
+    validations:
+      required: true
+
   - type: markdown
     attributes:
       value: __Describe your environment__
@@ -76,8 +90,8 @@ body:
     attributes:
      label: Version number
      description: |
-        The version of Invoke you have installed. If it is not the latest version, please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
-      placeholder: ex. 3.6.1
+        The version of Invoke you have installed. If it is not the [latest version](https://github.com/invoke-ai/InvokeAI/releases/latest), please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
+      placeholder: ex. v6.0.2
     validations:
       required: true
 
@@ -85,17 +99,17 @@ body:
    id: browser-version
    attributes:
      label: Browser
-      description: Your web browser and version.
+      description: Your web browser and version, if you do not use the Launcher's provided GUI.
      placeholder: ex. Firefox 123.0b3
    validations:
-      required: true
+      required: false
 
  - type: textarea
    id: python-deps
    attributes:
-      label: Python dependencies
+      label: System Information
      description: |
-        If the problem occurred during image generation, click the gear icon at the bottom left corner, click "About", click the copy button and then paste here.
+        Click the gear icon at the bottom left corner, then click "About". Click the copy button and then paste here.
    validations:
      required: false
```
**.gitignore** (vendored, 2 lines changed)

```diff
@@ -190,3 +190,5 @@ installer/update.bat
 installer/update.sh
 installer/InvokeAI-Installer/
 .aider*
+
+.claude/
```
```diff
@@ -5,8 +5,7 @@
 FROM docker.io/node:22-slim AS web-builder
 ENV PNPM_HOME="/pnpm"
 ENV PATH="$PNPM_HOME:$PATH"
-RUN corepack use pnpm@8.x
-RUN corepack enable
+RUN corepack use pnpm@10.x && corepack enable
 
 WORKDIR /build
 COPY invokeai/frontend/web/ ./
```
````diff
@@ -41,7 +41,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
    With the modifications made, the install command should look something like this:
 
    ```sh
-   uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu126 --reinstall
+   uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu128 --reinstall
    ```
 
 6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
@@ -50,11 +50,11 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
 
    If you only want to edit the docs, you can stop here and skip to the **Documentation** section below.
 
-7. Install the frontend dev toolchain:
+7. Install the frontend dev toolchain, paying attention to versions:
 
-   - [`nodejs`](https://nodejs.org/) (v20+)
+   - [`nodejs`](https://nodejs.org/) (tested on LTS, v22)
 
-   - [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)
+   - [`pnpm`](https://pnpm.io/installation) (tested on v10)
 
 8. Do a production build of the frontend:
````
```diff
@@ -297,7 +297,7 @@ Migration logic is in [migrations.ts].
 <!-- links -->
 
 [pydantic]: https://github.com/pydantic/pydantic 'pydantic'
-[zod]: https://github.com/colinhacks/zod 'zod/v4'
+[zod]: https://github.com/colinhacks/zod 'zod'
 [openapi-types]: https://github.com/kogosoftwarellc/open-api/tree/main/packages/openapi-types 'openapi-types'
 [reactflow]: https://github.com/xyflow/xyflow 'reactflow'
 [reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions
```
```diff
@@ -72,7 +72,7 @@ async def upload_image(
     resize_to: Optional[str] = Body(
         default=None,
         description=f"Dimensions to resize the image to, must be stringified tuple of 2 integers. Max total pixel count: {ResizeToDimensions.MAX_SIZE}",
-        example='"[1024,1024]"',
+        examples=['"[1024,1024]"'],
     ),
     metadata: Optional[str] = Body(
         default=None,
```
```diff
@@ -292,7 +292,7 @@ async def get_hugging_face_models(
 )
 async def update_model_record(
     key: Annotated[str, Path(description="Unique key of model")],
-    changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
+    changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
 ) -> AnyModelConfig:
     """Update a model's config."""
     logger = ApiDependencies.invoker.services.logger
@@ -450,7 +450,7 @@ async def install_model(
     access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
     config: ModelRecordChanges = Body(
         description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
-        example={"name": "string", "description": "string"},
+        examples=[{"name": "string", "description": "string"}],
     ),
 ) -> ModelInstallJob:
     """Install a model using a string identifier.
```
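These API hunks all make the same change: the singular `example=` keyword on FastAPI `Body(...)` parameters is replaced with an `examples=[...]` list, which is the form current FastAPI versions document (the singular `example` is deprecated there). A minimal sketch of the new style outside InvokeAI; the endpoint and field names below are made up for illustration:

```py
from typing import Optional

from fastapi import Body, FastAPI

app = FastAPI()


@app.post("/demo/resize")
async def demo_resize(
    resize_to: Optional[str] = Body(
        default=None,
        description="Stringified tuple of 2 integers",
        # `examples` takes a list; each entry is rendered into the generated
        # OpenAPI schema for this body field.
        examples=['"[1024,1024]"', '"[512,768]"'],
    ),
) -> dict:
    # Echo the value back so the endpoint is trivially testable.
    return {"resize_to": resize_to}
```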
**invokeai/app/invocations/bria_controlnet.py** (new file, 154 lines)

```py
import cv2
import numpy as np
from PIL import Image
from pydantic import BaseModel, Field

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    InputField,
    OutputField,
    UIType,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.controlnet_aux.open_pose import Body, Face, Hand, OpenposeDetector
from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.invocation_api import Classification, ImageOutput

DEPTH_SMALL_V2_URL = "depth-anything/Depth-Anything-V2-Small-hf"
HF_LLLYASVIEL = "https://huggingface.co/lllyasviel/Annotators/resolve/main/"


class BriaControlNetField(BaseModel):
    image: ImageField = Field(description="The control image")
    model: ModelIdentifierField = Field(description="The ControlNet model to use")
    mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet")
    conditioning_scale: float = Field(description="The weight given to the ControlNet")


@invocation_output("bria_controlnet_output")
class BriaControlNetOutput(BaseInvocationOutput):
    """Bria ControlNet info"""

    control: BriaControlNetField = OutputField(description=FieldDescriptions.control)
    preprocessed_images: ImageField = OutputField(description="The preprocessed control image")


@invocation(
    "bria_controlnet",
    title="ControlNet - Bria",
    tags=["controlnet", "bria"],
    category="controlnet",
    version="1.0.0",
    classification=Classification.Prototype,
)
class BriaControlNetInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Collect Bria ControlNet info to pass to denoiser node."""

    control_image: ImageField = InputField(description="The control image")
    control_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.controlnet_model, ui_type=UIType.BriaControlNetModel
    )
    control_mode: BRIA_CONTROL_MODES = InputField(default="depth", description="The mode of the ControlNet")
    control_weight: float = InputField(default=1.0, ge=-1, le=2, description="The weight given to the ControlNet")

    def invoke(self, context: InvocationContext) -> BriaControlNetOutput:
        image_in = resize_img(context.images.get_pil(self.control_image.image_name))
        if self.control_mode == "canny":
            control_image = extract_canny(image_in)
        elif self.control_mode == "depth":
            control_image = extract_depth(image_in, context)
        elif self.control_mode == "pose":
            control_image = extract_openpose(image_in, context)
        elif self.control_mode == "colorgrid":
            control_image = tile(64, image_in)
        elif self.control_mode == "recolor":
            control_image = convert_to_grayscale(image_in)
        elif self.control_mode == "tile":
            control_image = tile(16, image_in)

        control_image = resize_img(control_image)
        image_dto = context.images.save(image=control_image)
        image_output = ImageOutput.build(image_dto)
        return BriaControlNetOutput(
            preprocessed_images=image_output.image,
            control=BriaControlNetField(
                image=ImageField(image_name=image_dto.image_name),
                model=self.control_model,
                mode=self.control_mode,
                conditioning_scale=self.control_weight,
            ),
        )


RATIO_CONFIGS_1024 = {
    0.6666666666666666: {"width": 832, "height": 1248},
    0.7432432432432432: {"width": 880, "height": 1184},
    0.8028169014084507: {"width": 912, "height": 1136},
    1.0: {"width": 1024, "height": 1024},
    1.2456140350877194: {"width": 1136, "height": 912},
    1.3454545454545455: {"width": 1184, "height": 880},
    1.4339622641509433: {"width": 1216, "height": 848},
    1.5: {"width": 1248, "height": 832},
    1.5490196078431373: {"width": 1264, "height": 816},
    1.62: {"width": 1296, "height": 800},
    1.7708333333333333: {"width": 1360, "height": 768},
}


def extract_depth(image: Image.Image, context: InvocationContext):
    loaded_model = context.models.load_remote_model(DEPTH_SMALL_V2_URL, DepthAnythingPipeline.load_model)

    with loaded_model as depth_anything_detector:
        assert isinstance(depth_anything_detector, DepthAnythingPipeline)
        depth_map = depth_anything_detector.generate_depth(image)
        return depth_map


def extract_openpose(image: Image.Image, context: InvocationContext):
    body_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}body_pose_model.pth", Body)
    hand_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}hand_pose_model.pth", Hand)
    face_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}facenet.pth", Face)

    with body_model as body_model, hand_model as hand_model, face_model as face_model:
        open_pose_model = OpenposeDetector(body_model, hand_model, face_model)
        processed_image_open_pose = open_pose_model(image, hand_and_face=True)

    processed_image_open_pose = processed_image_open_pose.resize(image.size)
    return processed_image_open_pose


def extract_canny(input_image):
    image = np.array(input_image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    canny_image = Image.fromarray(image)
    return canny_image


def convert_to_grayscale(image):
    gray_image = image.convert('L').convert('RGB')
    return gray_image


def tile(downscale_factor, input_image):
    control_image = input_image.resize(
        (input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)
    ).resize(input_image.size, Image.Resampling.NEAREST)
    return control_image


def resize_img(control_image):
    image_ratio = control_image.width / control_image.height
    ratio = min(RATIO_CONFIGS_1024.keys(), key=lambda k: abs(k - image_ratio))
    to_height = RATIO_CONFIGS_1024[ratio]["height"]
    to_width = RATIO_CONFIGS_1024[ratio]["width"]
    resized_image = control_image.resize((to_width, to_height), resample=Image.Resampling.LANCZOS)
    return resized_image
```
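The `resize_img` helper above snaps every control image to the nearest entry in `RATIO_CONFIGS_1024`, so preprocessing always lands on one of the supported roughly 1 megapixel resolutions. A quick standalone sketch of that selection logic; the input size and the trimmed-down table are arbitrary choices for the example:

```py
# Nearest-aspect-ratio selection, as used by resize_img, shown in isolation.
RATIO_CONFIGS_1024 = {
    1.0: {"width": 1024, "height": 1024},
    1.2456140350877194: {"width": 1136, "height": 912},
    1.3454545454545455: {"width": 1184, "height": 880},
}  # subset of the full table, for brevity

width, height = 800, 600        # arbitrary input image size
image_ratio = width / height    # 1.333...

# Pick the configured ratio with the smallest absolute difference.
ratio = min(RATIO_CONFIGS_1024.keys(), key=lambda k: abs(k - image_ratio))
target = RATIO_CONFIGS_1024[ratio]
print(ratio, target)            # 1.3454545454545455 {'width': 1184, 'height': 880}
```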
**invokeai/app/invocations/bria_decoder.py** (new file, 46 lines)

```py
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from PIL import Image

from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.invocation_api import BaseInvocation, Classification, ImageOutput, invocation


@invocation(
    "bria_decoder",
    title="Decoder - Bria",
    tags=["image", "bria"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class BriaDecoderInvocation(BaseInvocation):
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )
    latents: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.tensors.load(self.latents.latents_name)
        latents = latents.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128)

        with context.models.load(self.vae.vae) as vae:
            assert isinstance(vae, AutoencoderKL)
            latents = latents / vae.config.scaling_factor
            latents = latents.to(device=vae.device, dtype=vae.dtype)

            decoded_output = vae.decode(latents)
            image = decoded_output.sample

            # Convert to numpy with proper gradient handling
            image = ((image.clamp(-1, 1) + 1) / 2 * 255).cpu().detach().permute(0, 2, 3, 1).numpy().astype("uint8")[0]
            img = Image.fromarray(image)
            image_dto = context.images.save(image=img)
            return ImageOutput.build(image_dto)
```
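The reshape at the top of `invoke` unpacks the patch-packed latents produced by the denoiser (a 64x64 grid of 2x2 spatial patches with 4 channels) back into the `(1, 4, 128, 128)` layout that `AutoencoderKL.decode` expects. A small standalone check of that transform; the packed shape below is only an assumption chosen so the element count matches:

```py
import torch

# Packed latents as saved by the Bria denoiser node. The exact saved shape is
# not shown in the diff; any tensor with 1 * 64 * 64 * 4 * 2 * 2 elements works.
packed = torch.randn(1, 4096, 16)

# Same unpacking as BriaDecoderInvocation: interpret the data as a 64x64 grid of
# 2x2 patches with 4 channels, then rearrange to (batch, channels, height, width).
unpacked = packed.view(1, 64, 64, 4, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(1, 4, 128, 128)
print(unpacked.shape)  # torch.Size([1, 4, 128, 128])
```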
**invokeai/app/invocations/bria_denoiser.py** (new file, 185 lines)

```py
from typing import List, Tuple

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

from invokeai.app.invocations.bria_controlnet import BriaControlNetField
from invokeai.app.invocations.fields import Input, InputField, LatentsField, OutputField
from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
from invokeai.backend.bria.controlnet_utils import prepare_control_images
from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output


@invocation_output("bria_denoise_output")
class BriaDenoiseInvocationOutput(BaseInvocationOutput):
    latents: LatentsField = OutputField(description=FieldDescriptions.latents)


@invocation(
    "bria_denoise",
    title="Denoise - Bria",
    tags=["image", "bria"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class BriaDenoiseInvocation(BaseInvocation):
    num_steps: int = InputField(
        default=30, title="Number of Steps", description="The number of steps to use for the denoiser"
    )
    guidance_scale: float = InputField(
        default=5.0, title="Guidance Scale", description="The guidance scale to use for the denoiser"
    )

    transformer: TransformerField = InputField(
        description="Bria model (Transformer) to load",
        input=Input.Connection,
        title="Transformer",
    )
    t5_encoder: T5EncoderField = InputField(
        title="T5Encoder",
        description=FieldDescriptions.t5_encoder,
        input=Input.Connection,
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
        title="VAE",
    )
    latents: LatentsField = InputField(
        description="Latents to denoise",
        input=Input.Connection,
        title="Latents",
    )
    latent_image_ids: LatentsField = InputField(
        description="Latent Image IDs to denoise",
        input=Input.Connection,
        title="Latent Image IDs",
    )
    pos_embeds: LatentsField = InputField(
        description="Positive Prompt Embeds",
        input=Input.Connection,
        title="Positive Prompt Embeds",
    )
    neg_embeds: LatentsField = InputField(
        description="Negative Prompt Embeds",
        input=Input.Connection,
        title="Negative Prompt Embeds",
    )
    text_ids: LatentsField = InputField(
        description="Text IDs",
        input=Input.Connection,
        title="Text IDs",
    )
    control: BriaControlNetField | list[BriaControlNetField] | None = InputField(
        description="ControlNet",
        input=Input.Connection,
        title="ControlNet",
        default=None,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
        latents = context.tensors.load(self.latents.latents_name)
        pos_embeds = context.tensors.load(self.pos_embeds.latents_name)
        neg_embeds = context.tensors.load(self.neg_embeds.latents_name)
        text_ids = context.tensors.load(self.text_ids.latents_name)
        latent_image_ids = context.tensors.load(self.latent_image_ids.latents_name)
        scheduler_identifier = self.transformer.transformer.model_copy(update={"submodel_type": SubModelType.Scheduler})

        device = None
        dtype = None
        with (
            context.models.load(self.transformer.transformer) as transformer,
            context.models.load(scheduler_identifier) as scheduler,
            context.models.load(self.vae.vae) as vae,
            context.models.load(self.t5_encoder.text_encoder) as t5_encoder,
            context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
        ):
            assert isinstance(transformer, BriaTransformer2DModel)
            assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
            assert isinstance(vae, AutoencoderKL)
            dtype = transformer.dtype
            device = transformer.device
            latents, pos_embeds, neg_embeds = (x.to(device, dtype) for x in (latents, pos_embeds, neg_embeds))

            control_model, control_images, control_modes, control_scales = None, None, None, None
            if self.control is not None:
                control_model, control_images, control_modes, control_scales = self._prepare_multi_control(
                    context=context,
                    vae=vae,
                    width=1024,
                    height=1024,
                    device=vae.device,
                )

            pipeline = BriaControlNetPipeline(
                transformer=transformer,
                scheduler=scheduler,
                vae=vae,
                text_encoder=t5_encoder,
                tokenizer=t5_tokenizer,
                controlnet=control_model,
            )
            pipeline.to(device=transformer.device, dtype=transformer.dtype)

            latents = pipeline(
                control_image=control_images,
                control_mode=control_modes,
                width=1024,
                height=1024,
                controlnet_conditioning_scale=control_scales,
                num_inference_steps=self.num_steps,
                max_sequence_length=128,
                guidance_scale=self.guidance_scale,
                latents=latents,
                latent_image_ids=latent_image_ids,
                text_ids=text_ids,
                prompt_embeds=pos_embeds,
                negative_prompt_embeds=neg_embeds,
                output_type="latent",
            )[0]

        assert isinstance(latents, torch.Tensor)
        saved_input_latents_tensor = context.tensors.save(latents)
        latents_output = LatentsField(latents_name=saved_input_latents_tensor)
        return BriaDenoiseInvocationOutput(latents=latents_output)

    def _prepare_multi_control(
        self,
        context: InvocationContext,
        vae: AutoencoderKL,
        width: int,
        height: int,
        device: torch.device,
    ) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]:
        control = self.control if isinstance(self.control, list) else [self.control]
        control_images, control_models, control_modes, control_scales = [], [], [], []
        for controlnet in control:
            if controlnet is not None:
                control_models.append(context.models.load(controlnet.model).model)
                control_modes.append(BriaControlModes[controlnet.mode].value)
                control_scales.append(controlnet.conditioning_scale)
                try:
                    control_images.append(context.images.get_pil(controlnet.image.image_name))
                except Exception:
                    raise FileNotFoundError(
                        f"Control image {controlnet.image.image_name} not found. "
                        "Make sure not to delete the preprocessed image before finishing the pipeline."
                    )

        control_model = BriaMultiControlNetModel(control_models).to(device)
        tensored_control_images, tensored_control_modes = prepare_control_images(
            vae=vae,
            control_images=control_images,
            control_modes=control_modes,
            width=width,
            height=height,
            device=device,
        )
        return control_model, tensored_control_images, tensored_control_modes, control_scales
```
**invokeai/app/invocations/bria_latent_sampler.py** (new file, 76 lines)

```py
import torch

from invokeai.app.invocations.fields import Input, InputField, OutputField
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import (
    BaseInvocationOutput,
    FieldDescriptions,
    LatentsField,
)
from invokeai.backend.bria.pipeline_bria_controlnet import prepare_latents
from invokeai.invocation_api import (
    BaseInvocation,
    Classification,
    InvocationContext,
    invocation,
    invocation_output,
)


@invocation_output("bria_latent_sampler_output")
class BriaLatentSamplerInvocationOutput(BaseInvocationOutput):
    """Base class for nodes that output a CogView text conditioning tensor."""

    latents: LatentsField = OutputField(description=FieldDescriptions.cond)
    latent_image_ids: LatentsField = OutputField(description=FieldDescriptions.cond)


@invocation(
    "bria_latent_sampler",
    title="Latent Sampler - Bria",
    tags=["image", "bria"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class BriaLatentSamplerInvocation(BaseInvocation):
    seed: int = InputField(
        default=42,
        title="Seed",
        description="The seed to use for the latent sampler",
    )
    transformer: TransformerField = InputField(
        description="Bria model (Transformer) to load",
        input=Input.Connection,
        title="Transformer",
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput:
        with context.models.load(self.transformer.transformer) as transformer:
            device = transformer.device
            dtype = transformer.dtype

        height, width = 1024, 1024
        generator = torch.Generator(device=device).manual_seed(self.seed)

        num_channels_latents = 4
        latents, latent_image_ids = prepare_latents(
            batch_size=1,
            num_channels_latents=num_channels_latents,
            height=height,
            width=width,
            dtype=dtype,
            device=device,
            generator=generator,
        )

        saved_latents_tensor = context.tensors.save(latents)
        saved_latent_image_ids_tensor = context.tensors.save(latent_image_ids)
        latents_output = LatentsField(latents_name=saved_latents_tensor)
        latent_image_ids_output = LatentsField(latents_name=saved_latent_image_ids_tensor)

        return BriaLatentSamplerInvocationOutput(
            latents=latents_output,
            latent_image_ids=latent_image_ids_output,
        )
```
**invokeai/app/invocations/bria_model_loader.py** (new file, 58 lines)

```py
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import (
    ModelIdentifierField,
    SubModelType,
    T5EncoderField,
    TransformerField,
    VAEField,
)
from invokeai.invocation_api import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    InvocationContext,
    invocation,
    invocation_output,
)


@invocation_output("bria_model_loader_output")
class BriaModelLoaderOutput(BaseInvocationOutput):
    """Bria base model loader output"""

    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")


@invocation(
    "bria_model_loader",
    title="Main Model - Bria",
    tags=["model", "bria"],
    version="1.0.0",
    classification=Classification.Prototype,
)
class BriaModelLoaderInvocation(BaseInvocation):
    """Loads a bria base model, outputting its submodels."""

    model: ModelIdentifierField = InputField(
        description="Bria model (Transformer) to load",
        ui_type=UIType.BriaMainModel,
        input=Input.Direct,
    )

    def invoke(self, context: InvocationContext) -> BriaModelLoaderOutput:
        for key in [self.model.key]:
            if not context.models.exists(key):
                raise ValueError(f"Unknown model: {key}")

        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
        text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
        tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})

        return BriaModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[]),
            t5_encoder=T5EncoderField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[]),
            vae=VAEField(vae=vae),
        )
```
**invokeai/app/invocations/bria_text_encoder.py** (new file, 93 lines)

```py
from typing import Optional

import torch
from transformers import (
    T5EncoderModel,
    T5TokenizerFast,
)

from invokeai.app.invocations.model import T5EncoderField
from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions, Input, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.bria.pipeline_bria_controlnet import encode_prompt
from invokeai.invocation_api import (
    BaseInvocation,
    Classification,
    InputField,
    LatentsField,
    invocation,
    invocation_output,
)


@invocation_output("bria_text_encoder_output")
class BriaTextEncoderInvocationOutput(BaseInvocationOutput):
    """Base class for nodes that output a CogView text conditioning tensor."""

    pos_embeds: LatentsField = OutputField(description=FieldDescriptions.cond)
    neg_embeds: LatentsField = OutputField(description=FieldDescriptions.cond)
    text_ids: LatentsField = OutputField(description=FieldDescriptions.cond)


@invocation(
    "bria_text_encoder",
    title="Prompt - Bria",
    tags=["prompt", "conditioning", "bria"],
    category="conditioning",
    version="1.0.0",
    classification=Classification.Prototype,
)
class BriaTextEncoderInvocation(BaseInvocation):
    prompt: str = InputField(
        title="Prompt",
        description="The prompt to encode",
    )
    negative_prompt: Optional[str] = InputField(
        title="Negative Prompt",
        description="The negative prompt to encode",
        default="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate",
    )
    max_length: int = InputField(
        default=128,
        title="Max Length",
        description="The maximum length of the prompt",
    )
    t5_encoder: T5EncoderField = InputField(
        title="T5Encoder",
        description=FieldDescriptions.t5_encoder,
        input=Input.Connection,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> BriaTextEncoderInvocationOutput:
        t5_encoder_info = context.models.load(self.t5_encoder.text_encoder)
        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
        with (
            t5_encoder_info as text_encoder,
            t5_tokenizer_info as tokenizer,
        ):
            assert isinstance(tokenizer, T5TokenizerFast)
            assert isinstance(text_encoder, T5EncoderModel)

            (prompt_embeds, negative_prompt_embeds, text_ids) = encode_prompt(
                prompt=self.prompt,
                tokenizer=tokenizer,
                text_encoder=text_encoder,
                negative_prompt=self.negative_prompt,
                device=text_encoder.device,
                num_images_per_prompt=1,
                max_sequence_length=self.max_length,
                lora_scale=1.0,
            )

        saved_pos_tensor = context.tensors.save(prompt_embeds)
        saved_neg_tensor = context.tensors.save(negative_prompt_embeds)
        saved_text_ids_tensor = context.tensors.save(text_ids)
        pos_embeds_output = LatentsField(latents_name=saved_pos_tensor)
        neg_embeds_output = LatentsField(latents_name=saved_neg_tensor)
        text_ids_output = LatentsField(latents_name=saved_text_ids_tensor)
        return BriaTextEncoderInvocationOutput(
            pos_embeds=pos_embeds_output,
            neg_embeds=neg_embeds_output,
            text_ids=text_ids_output,
        )
```
```diff
@@ -42,6 +42,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
     MainModel = "MainModelField"
     CogView4MainModel = "CogView4MainModelField"
     FluxMainModel = "FluxMainModelField"
+    BriaMainModel = "BriaMainModelField"
+    BriaControlNetModel = "BriaControlNetModelField"
     SD3MainModel = "SD3MainModelField"
     SDXLMainModel = "SDXLMainModelField"
     SDXLRefinerModel = "SDXLRefinerModelField"
```
```diff
@@ -430,6 +430,15 @@ class FluxConditioningOutput(BaseInvocationOutput):
         return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
 
 
+@invocation_output("flux_conditioning_collection_output")
+class FluxConditioningCollectionOutput(BaseInvocationOutput):
+    """Base class for nodes that output a collection of conditioning tensors"""
+
+    collection: list[FluxConditioningField] = OutputField(
+        description="The output conditioning tensors",
+    )
+
+
 @invocation_output("sd3_conditioning_output")
 class SD3ConditioningOutput(BaseInvocationOutput):
     """Base class for nodes that output a single SD3 conditioning tensor"""
```
```diff
@@ -14,15 +14,14 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
 class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
     def __init__(self, db: SqliteDatabase) -> None:
         super().__init__()
-        self._conn = db.conn
+        self._db = db
 
     def add_image_to_board(
         self,
         board_id: str,
         image_name: str,
     ) -> None:
-        try:
-            cursor = self._conn.cursor()
+        with self._db.transaction() as cursor:
             cursor.execute(
                 """--sql
                 INSERT INTO board_images (board_id, image_name)
@@ -31,17 +30,12 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
                 """,
                 (board_id, image_name, board_id),
             )
-            self._conn.commit()
-        except sqlite3.Error as e:
-            self._conn.rollback()
-            raise e
 
     def remove_image_from_board(
         self,
         image_name: str,
     ) -> None:
-        try:
-            cursor = self._conn.cursor()
+        with self._db.transaction() as cursor:
             cursor.execute(
                 """--sql
                 DELETE FROM board_images
@@ -49,10 +43,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
                 """,
                 (image_name,),
             )
-            self._conn.commit()
-        except sqlite3.Error as e:
-            self._conn.rollback()
-            raise e
 
     def get_images_for_board(
         self,
@@ -60,27 +50,26 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
         offset: int = 0,
         limit: int = 10,
     ) -> OffsetPaginatedResults[ImageRecord]:
         # TODO: this isn't paginated yet?
-        cursor = self._conn.cursor()
-        cursor.execute(
-            """--sql
-            SELECT images.*
-            FROM board_images
-            INNER JOIN images ON board_images.image_name = images.image_name
-            WHERE board_images.board_id = ?
-            ORDER BY board_images.updated_at DESC;
-            """,
-            (board_id,),
-        )
-        result = cast(list[sqlite3.Row], cursor.fetchall())
-        images = [deserialize_image_record(dict(r)) for r in result]
+        with self._db.transaction() as cursor:
+            cursor.execute(
+                """--sql
+                SELECT images.*
+                FROM board_images
+                INNER JOIN images ON board_images.image_name = images.image_name
+                WHERE board_images.board_id = ?
+                ORDER BY board_images.updated_at DESC;
+                """,
+                (board_id,),
+            )
+            result = cast(list[sqlite3.Row], cursor.fetchall())
+            images = [deserialize_image_record(dict(r)) for r in result]
 
-        cursor.execute(
-            """--sql
-            SELECT COUNT(*) FROM images WHERE 1=1;
-            """
-        )
-        count = cast(int, cursor.fetchone()[0])
+            cursor.execute(
+                """--sql
+                SELECT COUNT(*) FROM images WHERE 1=1;
+                """
+            )
+            count = cast(int, cursor.fetchone()[0])
 
         return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
 
@@ -90,56 +79,55 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
         categories: list[ImageCategory] | None,
         is_intermediate: bool | None,
     ) -> list[str]:
-        params: list[str | bool] = []
+        with self._db.transaction() as cursor:
+            params: list[str | bool] = []
 
-        # Base query is a join between images and board_images
-        stmt = """
-        SELECT images.image_name
-        FROM images
-        LEFT JOIN board_images ON board_images.image_name = images.image_name
-        WHERE 1=1
-        """
+            # Base query is a join between images and board_images
+            stmt = """
+            SELECT images.image_name
+            FROM images
+            LEFT JOIN board_images ON board_images.image_name = images.image_name
+            WHERE 1=1
+            """
 
-        # Handle board_id filter
-        if board_id == "none":
-            stmt += """--sql
-            AND board_images.board_id IS NULL
-            """
-        else:
-            stmt += """--sql
-            AND board_images.board_id = ?
-            """
-            params.append(board_id)
+            # Handle board_id filter
+            if board_id == "none":
+                stmt += """--sql
+                AND board_images.board_id IS NULL
+                """
+            else:
+                stmt += """--sql
+                AND board_images.board_id = ?
+                """
+                params.append(board_id)
 
-        # Add the category filter
-        if categories is not None:
-            # Convert the enum values to unique list of strings
-            category_strings = [c.value for c in set(categories)]
-            # Create the correct length of placeholders
-            placeholders = ",".join("?" * len(category_strings))
-            stmt += f"""--sql
-            AND images.image_category IN ( {placeholders} )
-            """
+            # Add the category filter
+            if categories is not None:
+                # Convert the enum values to unique list of strings
+                category_strings = [c.value for c in set(categories)]
+                # Create the correct length of placeholders
+                placeholders = ",".join("?" * len(category_strings))
+                stmt += f"""--sql
+                AND images.image_category IN ( {placeholders} )
+                """
 
-            # Unpack the included categories into the query params
-            for c in category_strings:
-                params.append(c)
+                # Unpack the included categories into the query params
+                for c in category_strings:
+                    params.append(c)
 
-        # Add the is_intermediate filter
-        if is_intermediate is not None:
-            stmt += """--sql
-            AND images.is_intermediate = ?
-            """
-            params.append(is_intermediate)
+            # Add the is_intermediate filter
+            if is_intermediate is not None:
+                stmt += """--sql
+                AND images.is_intermediate = ?
+                """
+                params.append(is_intermediate)
 
-        # Put a ring on it
-        stmt += ";"
+            # Put a ring on it
+            stmt += ";"
 
-        # Execute the query
-        cursor = self._conn.cursor()
-        cursor.execute(stmt, params)
+            cursor.execute(stmt, params)
 
-        result = cast(list[sqlite3.Row], cursor.fetchall())
+            result = cast(list[sqlite3.Row], cursor.fetchall())
         image_names = [r[0] for r in result]
         return image_names
 
@@ -147,31 +135,31 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
         self,
         image_name: str,
     ) -> Optional[str]:
-        cursor = self._conn.cursor()
-        cursor.execute(
-            """--sql
-            SELECT board_id
-            FROM board_images
-            WHERE image_name = ?;
-            """,
-            (image_name,),
-        )
-        result = cursor.fetchone()
+        with self._db.transaction() as cursor:
+            cursor.execute(
+                """--sql
+                SELECT board_id
+                FROM board_images
+                WHERE image_name = ?;
+                """,
+                (image_name,),
+            )
+            result = cursor.fetchone()
         if result is None:
             return None
         return cast(str, result[0])
 
     def get_image_count_for_board(self, board_id: str) -> int:
-        cursor = self._conn.cursor()
-        cursor.execute(
-            """--sql
-            SELECT COUNT(*)
-            FROM board_images
-            INNER JOIN images ON board_images.image_name = images.image_name
-            WHERE images.is_intermediate = FALSE
-            AND board_images.board_id = ?;
-            """,
-            (board_id,),
-        )
-        count = cast(int, cursor.fetchone()[0])
+        with self._db.transaction() as cursor:
+            cursor.execute(
+                """--sql
+                SELECT COUNT(*)
+                FROM board_images
+                INNER JOIN images ON board_images.image_name = images.image_name
+                WHERE images.is_intermediate = FALSE
+                AND board_images.board_id = ?;
+                """,
+                (board_id,),
+            )
+            count = cast(int, cursor.fetchone()[0])
         return count
```
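These storage hunks, and the board-record and image-record ones that follow, all apply the same refactor: manual `self._conn.cursor()`, `commit()`, and `rollback()` bookkeeping is replaced by a `with self._db.transaction() as cursor:` block. The `transaction()` helper itself is not shown in this capture; the sketch below is an assumption about what such a context manager could look like, not InvokeAI's actual implementation:

```py
import sqlite3
from contextlib import contextmanager
from typing import Iterator


class SqliteDatabase:
    """Hypothetical stand-in for the real SqliteDatabase service."""

    def __init__(self, path: str) -> None:
        self.conn = sqlite3.connect(path)

    @contextmanager
    def transaction(self) -> Iterator[sqlite3.Cursor]:
        """Yield a cursor, committing on success and rolling back on any error."""
        cursor = self.conn.cursor()
        try:
            yield cursor
            self.conn.commit()
        except Exception:
            self.conn.rollback()
            raise
        finally:
            cursor.close()
```

With a helper like this, the storage classes no longer need per-method try/commit/rollback blocks, which is exactly what the diffs above and below remove.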
```diff
@@ -20,61 +20,57 @@ from invokeai.app.util.misc import uuid_string
 class SqliteBoardRecordStorage(BoardRecordStorageBase):
     def __init__(self, db: SqliteDatabase) -> None:
         super().__init__()
-        self._conn = db.conn
+        self._db = db
 
     def delete(self, board_id: str) -> None:
-        try:
-            cursor = self._conn.cursor()
-            cursor.execute(
-                """--sql
-                DELETE FROM boards
-                WHERE board_id = ?;
-                """,
-                (board_id,),
-            )
-            self._conn.commit()
-        except Exception as e:
-            self._conn.rollback()
-            raise BoardRecordDeleteException from e
+        with self._db.transaction() as cursor:
+            try:
+                cursor.execute(
+                    """--sql
+                    DELETE FROM boards
+                    WHERE board_id = ?;
+                    """,
+                    (board_id,),
+                )
+            except Exception as e:
+                raise BoardRecordDeleteException from e
 
     def save(
         self,
         board_name: str,
     ) -> BoardRecord:
-        try:
-            board_id = uuid_string()
-            cursor = self._conn.cursor()
-            cursor.execute(
-                """--sql
-                INSERT OR IGNORE INTO boards (board_id, board_name)
-                VALUES (?, ?);
-                """,
-                (board_id, board_name),
-            )
-            self._conn.commit()
-        except sqlite3.Error as e:
-            self._conn.rollback()
-            raise BoardRecordSaveException from e
+        with self._db.transaction() as cursor:
+            try:
+                board_id = uuid_string()
+                cursor.execute(
+                    """--sql
+                    INSERT OR IGNORE INTO boards (board_id, board_name)
+                    VALUES (?, ?);
+                    """,
+                    (board_id, board_name),
+                )
+            except sqlite3.Error as e:
+                raise BoardRecordSaveException from e
         return self.get(board_id)
 
     def get(
         self,
         board_id: str,
     ) -> BoardRecord:
-        try:
-            cursor = self._conn.cursor()
-            cursor.execute(
-                """--sql
-                SELECT *
-                FROM boards
-                WHERE board_id = ?;
-                """,
-                (board_id,),
-            )
-
-            result = cast(Union[sqlite3.Row, None], cursor.fetchone())
-        except sqlite3.Error as e:
-            raise BoardRecordNotFoundException from e
+        with self._db.transaction() as cursor:
+            try:
+                cursor.execute(
+                    """--sql
+                    SELECT *
+                    FROM boards
+                    WHERE board_id = ?;
+                    """,
+                    (board_id,),
+                )
+
+                result = cast(Union[sqlite3.Row, None], cursor.fetchone())
+            except sqlite3.Error as e:
+                raise BoardRecordNotFoundException from e
         if result is None:
             raise BoardRecordNotFoundException
         return BoardRecord(**dict(result))
@@ -84,45 +80,43 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
         board_id: str,
         changes: BoardChanges,
     ) -> BoardRecord:
-        try:
-            cursor = self._conn.cursor()
-            # Change the name of a board
-            if changes.board_name is not None:
-                cursor.execute(
-                    """--sql
-                    UPDATE boards
-                    SET board_name = ?
-                    WHERE board_id = ?;
-                    """,
-                    (changes.board_name, board_id),
-                )
+        with self._db.transaction() as cursor:
+            try:
+                # Change the name of a board
+                if changes.board_name is not None:
+                    cursor.execute(
+                        """--sql
+                        UPDATE boards
+                        SET board_name = ?
+                        WHERE board_id = ?;
+                        """,
+                        (changes.board_name, board_id),
+                    )
 
-            # Change the cover image of a board
-            if changes.cover_image_name is not None:
-                cursor.execute(
-                    """--sql
-                    UPDATE boards
-                    SET cover_image_name = ?
-                    WHERE board_id = ?;
-                    """,
-                    (changes.cover_image_name, board_id),
-                )
+                # Change the cover image of a board
+                if changes.cover_image_name is not None:
+                    cursor.execute(
+                        """--sql
+                        UPDATE boards
+                        SET cover_image_name = ?
+                        WHERE board_id = ?;
+                        """,
+                        (changes.cover_image_name, board_id),
+                    )
 
-            # Change the archived status of a board
-            if changes.archived is not None:
-                cursor.execute(
-                    """--sql
-                    UPDATE boards
-                    SET archived = ?
-                    WHERE board_id = ?;
-                    """,
-                    (changes.archived, board_id),
-                )
+                # Change the archived status of a board
+                if changes.archived is not None:
+                    cursor.execute(
+                        """--sql
+                        UPDATE boards
+                        SET archived = ?
+                        WHERE board_id = ?;
+                        """,
+                        (changes.archived, board_id),
+                    )
 
-            self._conn.commit()
-        except sqlite3.Error as e:
-            self._conn.rollback()
-            raise BoardRecordSaveException from e
+            except sqlite3.Error as e:
+                raise BoardRecordSaveException from e
         return self.get(board_id)
 
     def get_many(
@@ -133,78 +127,77 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
         limit: int = 10,
         include_archived: bool = False,
     ) -> OffsetPaginatedResults[BoardRecord]:
-        cursor = self._conn.cursor()
-
-        # Build base query
-        base_query = """
-            SELECT *
-            FROM boards
-            {archived_filter}
-            ORDER BY {order_by} {direction}
-            LIMIT ? OFFSET ?;
-            """
-
-        # Determine archived filter condition
-        archived_filter = "" if include_archived else "WHERE archived = 0"
-
-        final_query = base_query.format(
-            archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
-        )
-
-        # Execute query to fetch boards
-        cursor.execute(final_query, (limit, offset))
-
-        result = cast(list[sqlite3.Row], cursor.fetchall())
-        boards = [deserialize_board_record(dict(r)) for r in result]
-
-        # Determine count query
-        if include_archived:
-            count_query = """
-                SELECT COUNT(*)
-                FROM boards;
-                """
-        else:
-            count_query = """
-                SELECT COUNT(*)
-                FROM boards
-                WHERE archived = 0;
-                """
-
-        # Execute count query
-        cursor.execute(count_query)
-
-        count = cast(int, cursor.fetchone()[0])
+        with self._db.transaction() as cursor:
+            # Build base query
+            base_query = """
+                SELECT *
+                FROM boards
+                {archived_filter}
+                ORDER BY {order_by} {direction}
+                LIMIT ? OFFSET ?;
+                """
+
+            # Determine archived filter condition
+            archived_filter = "" if include_archived else "WHERE archived = 0"
+
+            final_query = base_query.format(
+                archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
+            )
+
+            # Execute query to fetch boards
+            cursor.execute(final_query, (limit, offset))
+
+            result = cast(list[sqlite3.Row], cursor.fetchall())
+            boards = [deserialize_board_record(dict(r)) for r in result]
+
+            # Determine count query
+            if include_archived:
+                count_query = """
+                    SELECT COUNT(*)
+                    FROM boards;
+                    """
+            else:
+                count_query = """
+                    SELECT COUNT(*)
+                    FROM boards
+                    WHERE archived = 0;
+                    """
+
+            # Execute count query
+            cursor.execute(count_query)
+
+            count = cast(int, cursor.fetchone()[0])
 
         return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
 
     def get_all(
         self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
     ) -> list[BoardRecord]:
-        cursor = self._conn.cursor()
-        if order_by == BoardRecordOrderBy.Name:
-            base_query = """
-                SELECT *
-                FROM boards
-                {archived_filter}
-                ORDER BY LOWER(board_name) {direction}
-                """
-        else:
-            base_query = """
-                SELECT *
-                FROM boards
-                {archived_filter}
-                ORDER BY {order_by} {direction}
-                """
-
-        archived_filter = "" if include_archived else "WHERE archived = 0"
-
-        final_query = base_query.format(
-            archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
-        )
-
-        cursor.execute(final_query)
-
-        result = cast(list[sqlite3.Row], cursor.fetchall())
+        with self._db.transaction() as cursor:
+            if order_by == BoardRecordOrderBy.Name:
+                base_query = """
+                    SELECT *
+                    FROM boards
+                    {archived_filter}
+                    ORDER BY LOWER(board_name) {direction}
+                    """
+            else:
+                base_query = """
+                    SELECT *
+                    FROM boards
+                    {archived_filter}
+                    ORDER BY {order_by} {direction}
+                    """
+
+            archived_filter = "" if include_archived else "WHERE archived = 0"
+
+            final_query = base_query.format(
+                archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
+            )
+
+            cursor.execute(final_query)
+
+            result = cast(list[sqlite3.Row], cursor.fetchall())
         boards = [deserialize_board_record(dict(r)) for r in result]
 
         return boards
```
```diff
@@ -8,6 +8,7 @@ import time
 import traceback
 from pathlib import Path
 from queue import Empty, PriorityQueue
+from shutil import disk_usage
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
 
 import requests
@@ -335,6 +336,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
 
         assert job.download_path
 
+        free_space = disk_usage(job.download_path.parent).free
+        GB = 2**30
+        self._logger.debug(f"Download is {job.total_bytes / GB:.2f} GB of {free_space / GB:.2f} GB free.")
+        if free_space < job.total_bytes:
+            raise RuntimeError(
+                f"Free disk space {free_space / GB:.2f} GB is not enough for download of {job.total_bytes / GB:.2f} GB."
+            )
+
         # Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
         # for code that instead resumes an interrupted download.
         if job.download_path.exists():
```
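The new guard uses `shutil.disk_usage` to refuse a download that would not fit on the target volume. In isolation the check looks like this; the path and size below are made-up values for illustration:

```py
from pathlib import Path
from shutil import disk_usage

GB = 2**30
download_path = Path("/tmp/some-model.safetensors")  # hypothetical download target
total_bytes = 7 * GB                                  # hypothetical download size

# disk_usage() reports total/used/free bytes for the filesystem containing the path.
free_space = disk_usage(download_path.parent).free
if free_space < total_bytes:
    raise RuntimeError(
        f"Free disk space {free_space / GB:.2f} GB is not enough for download of {total_bytes / GB:.2f} GB."
    )
```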
@@ -24,22 +24,22 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteImageRecordStorage(ImageRecordStorageBase):
    def __init__(self, db: SqliteDatabase) -> None:
        super().__init__()
        self._db = db

    def get(self, image_name: str) -> ImageRecord:
        with self._db.transaction() as cursor:
            try:
                cursor.execute(
                    f"""--sql
                    SELECT {IMAGE_DTO_COLS} FROM images
                    WHERE image_name = ?;
                    """,
                    (image_name,),
                )

                result = cast(Optional[sqlite3.Row], cursor.fetchone())
            except sqlite3.Error as e:
                raise ImageRecordNotFoundException from e

        if not result:
            raise ImageRecordNotFoundException
@@ -47,17 +47,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        return deserialize_image_record(dict(result))

    def get_metadata(self, image_name: str) -> Optional[MetadataField]:
        with self._db.transaction() as cursor:
            try:
                cursor.execute(
                    """--sql
                    SELECT metadata FROM images
                    WHERE image_name = ?;
                    """,
                    (image_name,),
                )

                result = cast(Optional[sqlite3.Row], cursor.fetchone())

            except sqlite3.Error as e:
                raise ImageRecordNotFoundException from e

        if not result:
            raise ImageRecordNotFoundException
@@ -65,64 +68,60 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        as_dict = dict(result)
        metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
        return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None

    def update(
        self,
        image_name: str,
        changes: ImageRecordChanges,
    ) -> None:
        with self._db.transaction() as cursor:
            try:
                # Change the category of the image
                if changes.image_category is not None:
                    cursor.execute(
                        """--sql
                        UPDATE images
                        SET image_category = ?
                        WHERE image_name = ?;
                        """,
                        (changes.image_category, image_name),
                    )

                # Change the session associated with the image
                if changes.session_id is not None:
                    cursor.execute(
                        """--sql
                        UPDATE images
                        SET session_id = ?
                        WHERE image_name = ?;
                        """,
                        (changes.session_id, image_name),
                    )

                # Change the image's `is_intermediate` flag
                if changes.is_intermediate is not None:
                    cursor.execute(
                        """--sql
                        UPDATE images
                        SET is_intermediate = ?
                        WHERE image_name = ?;
                        """,
                        (changes.is_intermediate, image_name),
                    )

                # Change the image's `starred` state
                if changes.starred is not None:
                    cursor.execute(
                        """--sql
                        UPDATE images
                        SET starred = ?
                        WHERE image_name = ?;
                        """,
                        (changes.starred, image_name),
                    )
            except sqlite3.Error as e:
                raise ImageRecordSaveException from e

    def get_many(
        self,
@@ -136,170 +135,162 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        board_id: Optional[str] = None,
        search_term: Optional[str] = None,
    ) -> OffsetPaginatedResults[ImageRecord]:
        with self._db.transaction() as cursor:
            # Manually build two queries - one for the count, one for the records
            count_query = """--sql
                SELECT COUNT(*)
                FROM images
                LEFT JOIN board_images ON board_images.image_name = images.image_name
                WHERE 1=1
                """

            images_query = f"""--sql
                SELECT {IMAGE_DTO_COLS}
                FROM images
                LEFT JOIN board_images ON board_images.image_name = images.image_name
                WHERE 1=1
                """

            query_conditions = ""
            query_params: list[Union[int, str, bool]] = []

            if image_origin is not None:
                query_conditions += """--sql
                    AND images.image_origin = ?
                    """
                query_params.append(image_origin.value)

            if categories is not None:
                # Convert the enum values to unique list of strings
                category_strings = [c.value for c in set(categories)]
                # Create the correct length of placeholders
                placeholders = ",".join("?" * len(category_strings))

                query_conditions += f"""--sql
                    AND images.image_category IN ( {placeholders} )
                    """

                # Unpack the included categories into the query params
                for c in category_strings:
                    query_params.append(c)

            if is_intermediate is not None:
                query_conditions += """--sql
                    AND images.is_intermediate = ?
                    """

                query_params.append(is_intermediate)

            # board_id of "none" is reserved for images without a board
            if board_id == "none":
                query_conditions += """--sql
                    AND board_images.board_id IS NULL
                    """
            elif board_id is not None:
                query_conditions += """--sql
                    AND board_images.board_id = ?
                    """
                query_params.append(board_id)

            # Search term condition
            if search_term:
                query_conditions += """--sql
                    AND (
                        images.metadata LIKE ?
                        OR images.created_at LIKE ?
                    )
                    """
                query_params.append(f"%{search_term.lower()}%")
                query_params.append(f"%{search_term.lower()}%")

            if starred_first:
                query_pagination = f"""--sql
                    ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
                    """
            else:
                query_pagination = f"""--sql
                    ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
                    """

            # Final images query with pagination
            images_query += query_conditions + query_pagination + ";"
            # Add all the parameters
            images_params = query_params.copy()
            # Add the pagination parameters
            images_params.extend([limit, offset])

            # Build the list of images, deserializing each row
            cursor.execute(images_query, images_params)
            result = cast(list[sqlite3.Row], cursor.fetchall())

            images = [deserialize_image_record(dict(r)) for r in result]

            # Set up and execute the count query, without pagination
            count_query += query_conditions + ";"
            count_params = query_params.copy()
            cursor.execute(count_query, count_params)
            count = cast(int, cursor.fetchone()[0])

        return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)

    def delete(self, image_name: str) -> None:
        with self._db.transaction() as cursor:
            try:
                cursor.execute(
                    """--sql
                    DELETE FROM images
                    WHERE image_name = ?;
                    """,
                    (image_name,),
                )
            except sqlite3.Error as e:
                raise ImageRecordDeleteException from e

    def delete_many(self, image_names: list[str]) -> None:
        with self._db.transaction() as cursor:
            try:
                placeholders = ",".join("?" for _ in image_names)

                # Construct the SQLite query with the placeholders
                query = f"DELETE FROM images WHERE image_name IN ({placeholders})"

                # Execute the query with the list of IDs as parameters
                cursor.execute(query, image_names)
            except sqlite3.Error as e:
                raise ImageRecordDeleteException from e

    def get_intermediates_count(self) -> int:
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT COUNT(*) FROM images
                WHERE is_intermediate = TRUE;
                """
            )
            count = cast(int, cursor.fetchone()[0])
        return count

    def delete_intermediates(self) -> list[str]:
        with self._db.transaction() as cursor:
            try:
                cursor.execute(
                    """--sql
                    SELECT image_name FROM images
                    WHERE is_intermediate = TRUE;
                    """
                )
                result = cast(list[sqlite3.Row], cursor.fetchall())
                image_names = [r[0] for r in result]
                cursor.execute(
                    """--sql
                    DELETE FROM images
                    WHERE is_intermediate = TRUE;
                    """
                )
            except sqlite3.Error as e:
                raise ImageRecordDeleteException from e
        return image_names

    def save(
        self,
@@ -315,73 +306,71 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        node_id: Optional[str] = None,
        metadata: Optional[str] = None,
    ) -> datetime:
        with self._db.transaction() as cursor:
            try:
                cursor.execute(
                    """--sql
                    INSERT OR IGNORE INTO images (
                        image_name,
                        image_origin,
                        image_category,
                        width,
                        height,
                        node_id,
                        session_id,
                        metadata,
                        is_intermediate,
                        starred,
                        has_workflow
                        )
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
                    """,
                    (
                        image_name,
                        image_origin.value,
                        image_category.value,
                        width,
                        height,
                        node_id,
                        session_id,
                        metadata,
                        is_intermediate,
                        starred,
                        has_workflow,
                    ),
                )

                cursor.execute(
                    """--sql
                    SELECT created_at
                    FROM images
                    WHERE image_name = ?;
                    """,
                    (image_name,),
                )

                created_at = datetime.fromisoformat(cursor.fetchone()[0])
            except sqlite3.Error as e:
                raise ImageRecordSaveException from e
        return created_at

    def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT images.*
                FROM images
                JOIN board_images ON images.image_name = board_images.image_name
                WHERE board_images.board_id = ?
                AND images.is_intermediate = FALSE
                ORDER BY images.starred DESC, images.created_at DESC
                LIMIT 1;
                """,
                (board_id,),
            )

            result = cast(Optional[sqlite3.Row], cursor.fetchone())

        if result is None:
            return None
@@ -398,85 +387,84 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        board_id: Optional[str] = None,
        search_term: Optional[str] = None,
    ) -> ImageNamesResult:
        with self._db.transaction() as cursor:
            # Build query conditions (reused for both starred count and image names queries)
            query_conditions = ""
            query_params: list[Union[int, str, bool]] = []

            if image_origin is not None:
                query_conditions += """--sql
                    AND images.image_origin = ?
                    """
                query_params.append(image_origin.value)

            if categories is not None:
                category_strings = [c.value for c in set(categories)]
                placeholders = ",".join("?" * len(category_strings))
                query_conditions += f"""--sql
                    AND images.image_category IN ( {placeholders} )
                    """
                for c in category_strings:
                    query_params.append(c)

            if is_intermediate is not None:
                query_conditions += """--sql
                    AND images.is_intermediate = ?
                    """
                query_params.append(is_intermediate)

            if board_id == "none":
                query_conditions += """--sql
                    AND board_images.board_id IS NULL
                    """
            elif board_id is not None:
                query_conditions += """--sql
                    AND board_images.board_id = ?
                    """
                query_params.append(board_id)

            if search_term:
                query_conditions += """--sql
                    AND (
                        images.metadata LIKE ?
                        OR images.created_at LIKE ?
                    )
                    """
                query_params.append(f"%{search_term.lower()}%")
                query_params.append(f"%{search_term.lower()}%")

            # Get starred count if starred_first is enabled
            starred_count = 0
            if starred_first:
                starred_count_query = f"""--sql
                    SELECT COUNT(*)
                    FROM images
                    LEFT JOIN board_images ON board_images.image_name = images.image_name
                    WHERE images.starred = TRUE AND (1=1{query_conditions})
                    """
                cursor.execute(starred_count_query, query_params)
                starred_count = cast(int, cursor.fetchone()[0])

            # Get all image names with proper ordering
            if starred_first:
                names_query = f"""--sql
                    SELECT images.image_name
                    FROM images
                    LEFT JOIN board_images ON board_images.image_name = images.image_name
                    WHERE 1=1{query_conditions}
                    ORDER BY images.starred DESC, images.created_at {order_dir.value}
                    """
            else:
                names_query = f"""--sql
                    SELECT images.image_name
                    FROM images
                    LEFT JOIN board_images ON board_images.image_name = images.image_name
                    WHERE 1=1{query_conditions}
                    ORDER BY images.created_at {order_dir.value}
                    """

            cursor.execute(names_query, query_params)
            result = cast(list[sqlite3.Row], cursor.fetchall())
            image_names = [row[0] for row in result]

        return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names))
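Both `get_many` and `get_image_names` above build their WHERE clause incrementally, appending one `?` placeholder per bound value (including a generated placeholder list for the `IN (...)` category filter) so every user-supplied value stays parameterized. A reduced sketch of that pattern against a throwaway in-memory table follows; the table and column names here are illustrative only, not the InvokeAI schema.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE images (image_name TEXT, image_category TEXT, starred BOOLEAN)")


def select_images(categories: list[str] | None, search_term: str | None, limit: int, offset: int):
    # Start from a trivially-true WHERE clause and append one condition
    # (plus its bound parameters) per active filter.
    query = "SELECT image_name FROM images WHERE 1=1"
    params: list[str | int] = []
    if categories:
        placeholders = ",".join("?" * len(categories))  # one ? per category
        query += f" AND image_category IN ({placeholders})"
        params.extend(categories)
    if search_term:
        query += " AND image_name LIKE ?"
        params.append(f"%{search_term.lower()}%")
    query += " ORDER BY image_name LIMIT ? OFFSET ?"
    params.extend([limit, offset])
    return conn.execute(query, params).fetchall()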
@@ -78,11 +78,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        self._db = db
        self._logger = logger

    def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
        """
        Add a model to the database.
@@ -93,38 +88,33 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):

        Can raise DuplicateModelException and InvalidModelConfigException exceptions.
        """
        with self._db.transaction() as cursor:
            try:
                cursor.execute(
                    """--sql
                    INSERT INTO models (
                        id,
                        config
                    )
                    VALUES (?,?);
                    """,
                    (
                        config.key,
                        config.model_dump_json(),
                    ),
                )
            except sqlite3.IntegrityError as e:
                if "UNIQUE constraint failed" in str(e):
                    if "models.path" in str(e):
                        msg = f"A model with path '{config.path}' is already installed"
                    elif "models.name" in str(e):
                        msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
                    else:
                        msg = f"A model with key '{config.key}' is already installed"
                    raise DuplicateModelException(msg) from e
                else:
                    raise e
            except sqlite3.Error as e:
                raise e

        return self.get_model(config.key)
@@ -136,8 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):

        Can raise an UnknownModelException
        """
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                DELETE FROM models
@@ -147,22 +136,17 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
            )
            if cursor.rowcount == 0:
                raise UnknownModelException("model not found")

    def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
        with self._db.transaction() as cursor:
            record = self.get_model(key)

            # Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
            for field_name in changes.model_fields_set:
                setattr(record, field_name, getattr(changes, field_name))

            json_serialized = record.model_dump_json()

            cursor.execute(
                """--sql
                UPDATE models
@@ -174,10 +158,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
            )
            if cursor.rowcount == 0:
                raise UnknownModelException("model not found")

        return self.get_model(key)

@@ -189,30 +169,30 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):

        Exceptions: UnknownModelException
        """
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT config, strftime('%s',updated_at) FROM models
                WHERE id=?;
                """,
                (key,),
            )
            rows = cursor.fetchone()
            if not rows:
                raise UnknownModelException("model not found")
            model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
        return model

    def get_model_by_hash(self, hash: str) -> AnyModelConfig:
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT config, strftime('%s',updated_at) FROM models
                WHERE hash=?;
                """,
                (hash,),
            )
            rows = cursor.fetchone()
            if not rows:
                raise UnknownModelException("model not found")
            model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
@@ -224,15 +204,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):

        :param key: Unique key for the model to be deleted
        """
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                select count(*) FROM models
                WHERE id=?;
                """,
                (key,),
            )
            count = cursor.fetchone()[0]
        return count > 0

    def search_by_attr(
@@ -255,43 +235,42 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
        If none of the optional filters are passed, will return all
        models in the database.
        """
        with self._db.transaction() as cursor:
            assert isinstance(order_by, ModelRecordOrderBy)
            ordering = {
                ModelRecordOrderBy.Default: "type, base, name, format",
                ModelRecordOrderBy.Type: "type",
                ModelRecordOrderBy.Base: "base",
                ModelRecordOrderBy.Name: "name",
                ModelRecordOrderBy.Format: "format",
            }

            where_clause: list[str] = []
            bindings: list[str] = []
            if model_name:
                where_clause.append("name=?")
                bindings.append(model_name)
            if base_model:
                where_clause.append("base=?")
                bindings.append(base_model)
            if model_type:
                where_clause.append("type=?")
                bindings.append(model_type)
            if model_format:
                where_clause.append("format=?")
                bindings.append(model_format)
            where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""

            cursor.execute(
                f"""--sql
                SELECT config, strftime('%s',updated_at)
                FROM models
                {where}
                ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
                """,
                tuple(bindings),
            )
            result = cursor.fetchall()

        # Parse the model configs.
        results: list[AnyModelConfig] = []
@@ -313,69 +292,68 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):

    def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
        """Return models with the indicated path."""
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT config, strftime('%s',updated_at) FROM models
                WHERE path=?;
                """,
                (str(path),),
            )
            results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
        return results

    def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
        """Return models with the indicated hash."""
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT config, strftime('%s',updated_at) FROM models
                WHERE hash=?;
                """,
                (hash,),
            )
            results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
        return results

    def list_models(
        self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
    ) -> PaginatedResults[ModelSummary]:
        """Return a paginated summary listing of each model in the database."""
        with self._db.transaction() as cursor:
            assert isinstance(order_by, ModelRecordOrderBy)
            ordering = {
                ModelRecordOrderBy.Default: "type, base, name, format",
                ModelRecordOrderBy.Type: "type",
                ModelRecordOrderBy.Base: "base",
                ModelRecordOrderBy.Name: "name",
                ModelRecordOrderBy.Format: "format",
            }

            # Lock so that the database isn't updated while we're doing the two queries.
            # query1: get the total number of model configs
            cursor.execute(
                """--sql
                select count(*) from models;
                """,
                (),
            )
            total = int(cursor.fetchone()[0])

            # query2: fetch key fields
            cursor.execute(
                f"""--sql
                SELECT config
                FROM models
                ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
                LIMIT ?
                OFFSET ?;
                """,
                (
                    per_page,
                    page * per_page,
                ),
            )
            rows = cursor.fetchall()
        items = [ModelSummary.model_validate(dict(x)) for x in rows]
        return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)
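`add_model` above keeps its `sqlite3.IntegrityError` handling inside the transaction block, translating UNIQUE-constraint failures into a `DuplicateModelException`. The sketch below shows that translation pattern in isolation; the table and exception names are reused from the diff, while the helper function itself is hypothetical.

import sqlite3


class DuplicateModelException(Exception):
    pass


def insert_model(conn: sqlite3.Connection, key: str, config_json: str) -> None:
    try:
        conn.execute("INSERT INTO models (id, config) VALUES (?, ?)", (key, config_json))
    except sqlite3.IntegrityError as e:
        # A violated UNIQUE constraint means the model is already installed;
        # anything else is re-raised untouched.
        if "UNIQUE constraint failed" in str(e):
            raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
        raise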
@@ -1,5 +1,3 @@
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
    ModelRelationshipRecordStorageBase,
)
@@ -9,58 +7,49 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
    def __init__(self, db: SqliteDatabase) -> None:
        super().__init__()
        self._db = db

    def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
        with self._db.transaction() as cursor:
            if model_key_1 == model_key_2:
                raise ValueError("Cannot relate a model to itself.")
            a, b = sorted([model_key_1, model_key_2])
            cursor.execute(
                "INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
                (a, b),
            )

    def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
        with self._db.transaction() as cursor:
            a, b = sorted([model_key_1, model_key_2])
            cursor.execute(
                "DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
                (a, b),
            )

    def get_related_model_keys(self, model_key: str) -> list[str]:
        with self._db.transaction() as cursor:
            cursor.execute(
                """
                SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
                UNION
                SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
                """,
                (model_key, model_key),
            )
            result = [row[0] for row in cursor.fetchall()]
        return result

    def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
        with self._db.transaction() as cursor:
            key_list = ",".join("?" for _ in model_keys)
            cursor.execute(
                f"""
                SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
                UNION
                SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
                """,
                model_keys + model_keys,
            )
            result = [row[0] for row in cursor.fetchall()]
        return result
@@ -50,15 +50,14 @@ class SqliteSessionQueue(SessionQueueBase):

    def __init__(self, db: SqliteDatabase) -> None:
        super().__init__()
        self._db = db

    def _set_in_progress_to_canceled(self) -> None:
        """
        Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
        This is necessary because the invoker may have been killed while processing a queue item.
        """
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                UPDATE session_queue
@@ -66,87 +65,79 @@ class SqliteSessionQueue(SessionQueueBase):
                WHERE status = 'in_progress';
                """
            )

    def _get_current_queue_size(self, queue_id: str) -> int:
        """Gets the current number of pending queue items"""
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT count(*)
                FROM session_queue
                WHERE
                  queue_id = ?
                  AND status = 'pending'
                """,
                (queue_id,),
            )
            count = cast(int, cursor.fetchone()[0])
        return count

    def _get_highest_priority(self, queue_id: str) -> int:
        """Gets the highest priority value in the queue"""
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT MAX(priority)
                FROM session_queue
                WHERE
                  queue_id = ?
                  AND status = 'pending'
                """,
                (queue_id,),
            )
            priority = cast(Union[int, None], cursor.fetchone()[0]) or 0
        return priority

    async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
        # TODO: how does this work in a multi-user scenario?
        current_queue_size = self._get_current_queue_size(queue_id)
        max_queue_size = self.__invoker.services.configuration.max_queue_size
        max_new_queue_items = max_queue_size - current_queue_size

        priority = 0
        if prepend:
            priority = self._get_highest_priority(queue_id) + 1

        requested_count = await asyncio.to_thread(
            calc_session_count,
            batch=batch,
        )
        values_to_insert = await asyncio.to_thread(
            prepare_values_to_insert,
            queue_id=queue_id,
            batch=batch,
            priority=priority,
            max_new_queue_items=max_new_queue_items,
        )
        enqueued_count = len(values_to_insert)

        with self._db.transaction() as cursor:
            cursor.executemany(
                """--sql
                INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                values_to_insert,
            )
            cursor.execute(
                """--sql
                SELECT item_id
                FROM session_queue
                WHERE batch_id = ?
                ORDER BY item_id DESC;
                """,
                (batch.batch_id,),
            )
            item_ids = [row[0] for row in cursor.fetchall()]

        enqueue_result = EnqueueBatchResult(
            queue_id=queue_id,
            requested=requested_count,
@@ -159,19 +150,19 @@ class SqliteSessionQueue(SessionQueueBase):
        return enqueue_result

    def dequeue(self) -> Optional[SessionQueueItem]:
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT *
                FROM session_queue
                WHERE status = 'pending'
                ORDER BY
                  priority DESC,
                  item_id ASC
                LIMIT 1
                """
            )
            result = cast(Union[sqlite3.Row, None], cursor.fetchone())
        if result is None:
            return None
        queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
@@ -179,40 +170,40 @@ class SqliteSessionQueue(SessionQueueBase):
        return queue_item

    def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT *
                FROM session_queue
                WHERE
                  queue_id = ?
                  AND status = 'pending'
                ORDER BY
                  priority DESC,
                  created_at ASC
                LIMIT 1
                """,
                (queue_id,),
            )
            result = cast(Union[sqlite3.Row, None], cursor.fetchone())
        if result is None:
            return None
        return SessionQueueItem.queue_item_from_dict(dict(result))

    def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT *
                FROM session_queue
                WHERE
                  queue_id = ?
                  AND status = 'in_progress'
                LIMIT 1
                """,
                (queue_id,),
            )
            result = cast(Union[sqlite3.Row, None], cursor.fetchone())
        if result is None:
            return None
        return SessionQueueItem.queue_item_from_dict(dict(result))
@@ -225,8 +216,7 @@ class SqliteSessionQueue(SessionQueueBase):
        error_message: Optional[str] = None,
        error_traceback: Optional[str] = None,
    ) -> SessionQueueItem:
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT status FROM session_queue WHERE item_id = ?
@@ -234,12 +224,15 @@ class SqliteSessionQueue(SessionQueueBase):
                (item_id,),
            )
            row = cursor.fetchone()
            if row is None:
                raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
            current_status = row[0]

            # Only update if not already finished (completed, failed or canceled)
            if current_status in ("completed", "failed", "canceled"):
                return self.get_queue_item(item_id)

        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                UPDATE session_queue
@@ -248,10 +241,7 @@ class SqliteSessionQueue(SessionQueueBase):
                """,
                (status, error_type, error_message, error_traceback, item_id),
            )

        queue_item = self.get_queue_item(item_id)
        batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
        queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
@@ -259,35 +249,34 @@ class SqliteSessionQueue(SessionQueueBase):
        return queue_item

    def is_empty(self, queue_id: str) -> IsEmptyResult:
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT count(*)
                FROM session_queue
                WHERE queue_id = ?
                """,
                (queue_id,),
            )
            is_empty = cast(int, cursor.fetchone()[0]) == 0
        return IsEmptyResult(is_empty=is_empty)

    def is_full(self, queue_id: str) -> IsFullResult:
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT count(*)
                FROM session_queue
                WHERE queue_id = ?
                """,
                (queue_id,),
            )
            max_queue_size = self.__invoker.services.configuration.max_queue_size
            is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
        return IsFullResult(is_full=is_full)

    def clear(self, queue_id: str) -> ClearResult:
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT COUNT(*)
@@ -305,24 +294,19 @@ class SqliteSessionQueue(SessionQueueBase):
                """,
                (queue_id,),
            )
        self.__invoker.services.events.emit_queue_cleared(queue_id)
        return ClearResult(deleted=count)

    def prune(self, queue_id: str) -> PruneResult:
        with self._db.transaction() as cursor:
            where = """--sql
                WHERE
                  queue_id = ?
                  AND (
                    status = 'completed'
                    OR status = 'failed'
                    OR status = 'canceled'
                  )
                """
            cursor.execute(
                f"""--sql
@@ -341,10 +325,6 @@ class SqliteSessionQueue(SessionQueueBase):
                """,
                (queue_id,),
            )
        return PruneResult(deleted=count)

    def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
@@ -357,8 +337,7 @@ class SqliteSessionQueue(SessionQueueBase):
            self.cancel_queue_item(item_id)
        except SessionQueueItemNotFoundError:
            pass
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                DELETE
@@ -367,10 +346,6 @@ class SqliteSessionQueue(SessionQueueBase):
                """,
                (item_id,),
            )

    def complete_queue_item(self, item_id: int) -> SessionQueueItem:
        queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
@@ -393,8 +368,7 @@ class SqliteSessionQueue(SessionQueueBase):
        return queue_item

    def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
        with self._db.transaction() as cursor:
            current_queue_item = self.get_current(queue_id)
            placeholders = ", ".join(["?" for _ in batch_ids])
            where = f"""--sql
@@ -404,6 +378,8 @@ class SqliteSessionQueue(SessionQueueBase):
                AND status != 'canceled'
                AND status != 'completed'
                AND status != 'failed'
                -- We will cancel the current item separately below - skip it here
                AND status != 'in_progress'
                """
            params = [queue_id] + batch_ids
            cursor.execute(
@@ -423,17 +399,14 @@ class SqliteSessionQueue(SessionQueueBase):
                """,
                tuple(params),
            )

        if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
            self._set_queue_item_status(current_queue_item.item_id, "canceled")

        return CancelByBatchIDsResult(canceled=count)

    def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
        with self._db.transaction() as cursor:
            current_queue_item = self.get_current(queue_id)
            where = """--sql
                WHERE
@@ -442,6 +415,8 @@ class SqliteSessionQueue(SessionQueueBase):
                AND status != 'canceled'
                AND status != 'completed'
                AND status != 'failed'
                -- We will cancel the current item separately below - skip it here
                AND status != 'in_progress'
                """
            params = (queue_id, destination)
            cursor.execute(
@@ -461,17 +436,12 @@ class SqliteSessionQueue(SessionQueueBase):
                """,
                params,
            )
        if current_queue_item is not None and current_queue_item.destination == destination:
            self._set_queue_item_status(current_queue_item.item_id, "canceled")
        return CancelByDestinationResult(canceled=count)

    def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
        with self._db.transaction() as cursor:
            current_queue_item = self.get_current(queue_id)
            if current_queue_item is not None and current_queue_item.destination == destination:
                self.cancel_queue_item(current_queue_item.item_id)
@@ -497,15 +467,10 @@ class SqliteSessionQueue(SessionQueueBase):
                """,
                params,
            )
        return DeleteByDestinationResult(deleted=count)

    def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
        with self._db.transaction() as cursor:
            where = """--sql
                WHERE
                  queue_id == ?
@@ -528,15 +493,10 @@ class SqliteSessionQueue(SessionQueueBase):
                """,
                (queue_id,),
            )
        return DeleteAllExceptCurrentResult(deleted=count)

    def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
        with self._db.transaction() as cursor:
            current_queue_item = self.get_current(queue_id)
            where = """--sql
                WHERE
@@ -544,6 +504,8 @@ class SqliteSessionQueue(SessionQueueBase):
                AND status != 'canceled'
                AND status != 'completed'
                AND status != 'failed'
                -- We will cancel the current item separately below - skip it here
                AND status != 'in_progress'
                """
            params = [queue_id]
            cursor.execute(
@@ -563,21 +525,13 @@ class SqliteSessionQueue(SessionQueueBase):
                """,
                tuple(params),
            )

        if current_queue_item is not None and current_queue_item.queue_id == queue_id:
            self._set_queue_item_status(current_queue_item.item_id, "canceled")
        return CancelByQueueIDResult(canceled=count)

    def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
        with self._db.transaction() as cursor:
            where = """--sql
                WHERE
                  queue_id == ?
@@ -600,30 +554,25 @@ class SqliteSessionQueue(SessionQueueBase):
                """,
                (queue_id,),
            )
        return CancelAllExceptCurrentResult(canceled=count)

    def get_queue_item(self, item_id: int) -> SessionQueueItem:
        with self._db.transaction() as cursor:
            cursor.execute(
                """--sql
                SELECT * FROM session_queue
                WHERE
                  item_id = ?
                """,
                (item_id,),
            )
            result = cast(Union[sqlite3.Row, None], cursor.fetchone())
        if result is None:
            raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
        return SessionQueueItem.queue_item_from_dict(dict(result))

    def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
        with self._db.transaction() as cursor:
            # Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
            # when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
            # during execution.
@@ -636,10 +585,6 @@ class SqliteSessionQueue(SessionQueueBase):
                """,
                (session_json, item_id),
            )
        return self.get_queue_item(item_id)

    def list_queue_items(
@@ -651,42 +596,42 @@ class SqliteSessionQueue(SessionQueueBase):
        status: Optional[QUEUE_ITEM_STATUS] = None,
        destination: Optional[str] = None,
    ) -> CursorPaginatedResults[SessionQueueItem]:
        with self._db.transaction() as cursor_:
            item_id = cursor
            query = """--sql
                SELECT *
                FROM session_queue
                WHERE queue_id = ?
                """
            params: list[Union[str, int]] = [queue_id]

            if status is not None:
                query += """--sql
                    AND status = ?
                    """
                params.append(status)

            if destination is not None:
                query += """---sql
                    AND destination = ?
                    """
                params.append(destination)

            if item_id is not None:
                query += """--sql
                    AND (priority < ?) OR (priority = ? AND item_id > ?)
                    """
                params.extend([priority, priority, item_id])

            query += """--sql
                ORDER BY
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
"""
|
||||
params.extend([priority, priority, item_id])
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
@@ -701,43 +646,43 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
destination: Optional[str] = None,
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets all queue items that match the given parameters"""
|
||||
cursor_ = self._conn.cursor()
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
with self._db.transaction() as cursor:
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
;
|
||||
"""
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
;
|
||||
"""
|
||||
cursor.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
|
||||
return items
|
||||
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
current_item = self.get_current(queue_id=queue_id)
|
||||
total = sum(row[1] or 0 for row in counts_result)
|
||||
@@ -756,19 +701,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
|
||||
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
total = sum(row[1] or 0 for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
@@ -788,18 +733,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
|
||||
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
AND destination = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, destination),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
AND destination = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, destination),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
total = sum(row[1] or 0 for row in counts_result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
|
||||
@@ -817,8 +762,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
|
||||
"""Retries the given queue items"""
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
values_to_insert: list[ValueToInsertTuple] = []
|
||||
retried_item_ids: list[int] = []
|
||||
|
||||
@@ -869,10 +813,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
values_to_insert,
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
retry_result = RetryItemsResult(
|
||||
queue_id=queue_id,
|
||||
retried_item_ids=retried_item_ids,
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
@@ -26,46 +29,65 @@ class SqliteDatabase:
|
||||
|
||||
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
|
||||
"""Initializes the database. This is used internally by the class constructor."""
|
||||
self.logger = logger
|
||||
self.db_path = db_path
|
||||
self.verbose = verbose
|
||||
self._logger = logger
|
||||
self._db_path = db_path
|
||||
self._verbose = verbose
|
||||
self._lock = threading.RLock()
|
||||
|
||||
if not self.db_path:
|
||||
if not self._db_path:
|
||||
logger.info("Initializing in-memory database")
|
||||
else:
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.logger.info(f"Initializing database at {self.db_path}")
|
||||
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._logger.info(f"Initializing database at {self._db_path}")
|
||||
|
||||
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
self._conn = sqlite3.connect(database=self._db_path or sqlite_memory, check_same_thread=False)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
|
||||
if self.verbose:
|
||||
self.conn.set_trace_callback(self.logger.debug)
|
||||
if self._verbose:
|
||||
self._conn.set_trace_callback(self._logger.debug)
|
||||
|
||||
# Enable foreign key constraints
|
||||
self.conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
# Enable Write-Ahead Logging (WAL) mode for better concurrency
|
||||
self.conn.execute("PRAGMA journal_mode = WAL;")
|
||||
self._conn.execute("PRAGMA journal_mode = WAL;")
|
||||
|
||||
# Set a busy timeout to prevent database lockups during writes
|
||||
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
|
||||
self._conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
|
||||
|
||||
def clean(self) -> None:
|
||||
"""
|
||||
Cleans the database by running the VACUUM command, reporting on the freed space.
|
||||
"""
|
||||
# No need to clean in-memory database
|
||||
if not self.db_path:
|
||||
if not self._db_path:
|
||||
return
|
||||
try:
|
||||
initial_db_size = Path(self.db_path).stat().st_size
|
||||
self.conn.execute("VACUUM;")
|
||||
self.conn.commit()
|
||||
final_db_size = Path(self.db_path).stat().st_size
|
||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||
if freed_space_in_mb > 0:
|
||||
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||
with self._conn as conn:
|
||||
initial_db_size = Path(self._db_path).stat().st_size
|
||||
conn.execute("VACUUM;")
|
||||
conn.commit()
|
||||
final_db_size = Path(self._db_path).stat().st_size
|
||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||
if freed_space_in_mb > 0:
|
||||
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning database: {e}")
|
||||
self._logger.error(f"Error cleaning database: {e}")
|
||||
raise
|
||||
|
||||
@contextmanager
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
"""
Thread-safe context manager for DB work.
Acquires the RLock, yields a Cursor, then commits or rolls back.
"""
with self._lock:
cursor = self._conn.cursor()
try:
yield cursor
self._conn.commit()
except:
self._conn.rollback()
raise
finally:
cursor.close()

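A minimal usage sketch of the new transaction() context manager (hypothetical caller code, assuming a SqliteDatabase instance named db; not part of this diff):

# Any exception raised inside the block rolls the transaction back and re-raises;
# otherwise the changes are committed when the block exits.
def touch_workflow(db: SqliteDatabase, workflow_id: str) -> None:
    with db.transaction() as cursor:
        cursor.execute(
            "UPDATE workflow_library SET opened_at = CURRENT_TIMESTAMP WHERE workflow_id = ?;",
            (workflow_id,),
        )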
@@ -32,7 +32,7 @@ class SqliteMigrator:
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
self._db = db
|
||||
self._logger = db.logger
|
||||
self._logger = db._logger
|
||||
self._migration_set = MigrationSet()
|
||||
self._backup_path: Optional[Path] = None
|
||||
|
||||
@@ -45,7 +45,7 @@ class SqliteMigrator:
|
||||
"""Migrates the database to the latest version."""
|
||||
# This throws if there is a problem.
|
||||
self._migration_set.validate_migration_chain()
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor = self._db._conn.cursor()
|
||||
self._create_migrations_table(cursor=cursor)
|
||||
|
||||
if self._migration_set.count == 0:
|
||||
@@ -59,13 +59,13 @@ class SqliteMigrator:
|
||||
self._logger.info("Database update needed")
|
||||
|
||||
# Make a backup of the db if it needs to be updated and is a file db
|
||||
if self._db.db_path is not None:
|
||||
if self._db._db_path is not None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
|
||||
self._backup_path = self._db._db_path.parent / f"{self._db._db_path.stem}_backup_{timestamp}.db"
|
||||
self._logger.info(f"Backing up database to {str(self._backup_path)}")
|
||||
# Use SQLite to do the backup
|
||||
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
|
||||
self._db.conn.backup(backup_conn)
|
||||
self._db._conn.backup(backup_conn)
|
||||
else:
|
||||
self._logger.info("Using in-memory database, no backup needed")
|
||||
|
||||
@@ -81,7 +81,7 @@ class SqliteMigrator:
|
||||
try:
|
||||
# Using sqlite3.Connection as a context manager commits the transaction on exit, or rolls it back if an
|
||||
# exception is raised.
|
||||
with self._db.conn as conn:
|
||||
with self._db._conn as conn:
|
||||
cursor = conn.cursor()
|
||||
if self._get_current_version(cursor) != migration.from_version:
|
||||
raise MigrationError(
|
||||
|
||||
@@ -17,7 +17,7 @@ from invokeai.app.util.misc import uuid_string
|
||||
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._conn = db.conn
|
||||
self._db = db
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
@@ -25,24 +25,23 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
|
||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
||||
"""Gets a style preset by ID."""
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
|
||||
return StylePresetRecordDTO.from_dict(dict(row))
|
||||
|
||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||
style_preset_id = uuid_string()
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
@@ -60,16 +59,11 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
style_preset_ids = []
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
for style_preset in style_presets:
|
||||
style_preset_id = uuid_string()
|
||||
style_preset_ids.append(style_preset_id)
|
||||
@@ -90,16 +84,11 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
# Change the name of a style preset
|
||||
if changes.name is not None:
|
||||
cursor.execute(
|
||||
@@ -122,15 +111,10 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
(changes.preset_data.model_dump_json(), style_preset_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE from style_presets
|
||||
@@ -138,51 +122,41 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return None
|
||||
|
||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
FROM style_presets
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
FROM style_presets
|
||||
"""
|
||||
|
||||
if type is not None:
|
||||
main_query += "WHERE type = ? "
|
||||
if type is not None:
|
||||
main_query += "WHERE type = ? "
|
||||
|
||||
main_query += "ORDER BY LOWER(name) ASC"
|
||||
main_query += "ORDER BY LOWER(name) ASC"
|
||||
|
||||
cursor = self._conn.cursor()
|
||||
if type is not None:
|
||||
cursor.execute(main_query, (type,))
|
||||
else:
|
||||
cursor.execute(main_query)
|
||||
if type is not None:
|
||||
cursor.execute(main_query, (type,))
|
||||
else:
|
||||
cursor.execute(main_query)
|
||||
|
||||
rows = cursor.fetchall()
|
||||
rows = cursor.fetchall()
|
||||
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
|
||||
|
||||
return style_presets
|
||||
|
||||
def _sync_default_style_presets(self) -> None:
|
||||
"""Syncs default style presets to the database. Internal use only."""
|
||||
|
||||
# First delete all existing default style presets
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
# First delete all existing default style presets
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM style_presets
|
||||
WHERE type = "default";
|
||||
"""
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
# Next, parse and create the default style presets
|
||||
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
|
||||
presets = json.load(file)
|
||||
|
||||
@@ -25,7 +25,7 @@ SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%f"
|
||||
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._conn = db.conn
|
||||
self._db = db
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
@@ -33,16 +33,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
|
||||
def get(self, workflow_id: str) -> WorkflowRecordDTO:
|
||||
"""Gets a workflow by ID. Updates the opened_at column."""
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
|
||||
FROM workflow_library
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
|
||||
FROM workflow_library
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
|
||||
return WorkflowRecordDTO.from_dict(dict(row))
|
||||
@@ -51,9 +51,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
if workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be created via this method")
|
||||
|
||||
try:
|
||||
with self._db.transaction() as cursor:
|
||||
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO workflow_library (
|
||||
@@ -64,18 +63,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow_with_id.id, workflow_with_id.model_dump_json()),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(workflow_with_id.id)
|
||||
|
||||
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
|
||||
if workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be updated")
|
||||
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE workflow_library
|
||||
@@ -84,18 +78,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow.model_dump_json(), workflow.id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(workflow.id)
|
||||
|
||||
def delete(self, workflow_id: str) -> None:
|
||||
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be deleted")
|
||||
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE from workflow_library
|
||||
@@ -103,10 +92,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return None
|
||||
|
||||
def get_many(
|
||||
@@ -121,108 +106,108 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
assert direction in SQLiteDirection
|
||||
with self._db.transaction() as cursor:
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
assert direction in SQLiteDirection
|
||||
|
||||
# We will construct the query dynamically based on the query params
|
||||
# We will construct the query dynamically based on the query params
|
||||
|
||||
# The main query to get the workflows / counts
|
||||
main_query = """
|
||||
SELECT
|
||||
workflow_id,
|
||||
category,
|
||||
name,
|
||||
description,
|
||||
created_at,
|
||||
updated_at,
|
||||
opened_at,
|
||||
tags
|
||||
FROM workflow_library
|
||||
"""
|
||||
count_query = "SELECT COUNT(*) FROM workflow_library"
|
||||
# The main query to get the workflows / counts
|
||||
main_query = """
|
||||
SELECT
|
||||
workflow_id,
|
||||
category,
|
||||
name,
|
||||
description,
|
||||
created_at,
|
||||
updated_at,
|
||||
opened_at,
|
||||
tags
|
||||
FROM workflow_library
|
||||
"""
|
||||
count_query = "SELECT COUNT(*) FROM workflow_library"
|
||||
|
||||
# Start with an empty list of conditions and params
|
||||
conditions: list[str] = []
|
||||
params: list[str | int] = []
|
||||
# Start with an empty list of conditions and params
|
||||
conditions: list[str] = []
|
||||
params: list[str | int] = []
|
||||
|
||||
if categories:
|
||||
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
|
||||
if categories:
|
||||
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
|
||||
|
||||
# Ensure all categories are valid (is this necessary?)
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
# Ensure all categories are valid (is this necessary?)
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
|
||||
# Construct a placeholder string for the number of categories
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
# Construct a placeholder string for the number of categories
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
|
||||
# Construct the condition string & params
|
||||
category_condition = f"category IN ({placeholders})"
|
||||
category_params = [category.value for category in categories]
|
||||
# Construct the condition string & params
|
||||
category_condition = f"category IN ({placeholders})"
|
||||
category_params = [category.value for category in categories]
|
||||
|
||||
conditions.append(category_condition)
|
||||
params.extend(category_params)
|
||||
conditions.append(category_condition)
|
||||
params.extend(category_params)
|
||||
|
||||
if tags:
|
||||
# Tags is a list of strings, and a single string in the DB
|
||||
# The string in the DB has no guaranteed format
|
||||
if tags:
|
||||
# Tags is a list of strings, and a single string in the DB
|
||||
# The string in the DB has no guaranteed format
|
||||
|
||||
# Construct a list of conditions for each tag
|
||||
tags_conditions = ["tags LIKE ?" for _ in tags]
|
||||
tags_conditions_joined = " OR ".join(tags_conditions)
|
||||
tags_condition = f"({tags_conditions_joined})"
|
||||
# Construct a list of conditions for each tag
|
||||
tags_conditions = ["tags LIKE ?" for _ in tags]
|
||||
tags_conditions_joined = " OR ".join(tags_conditions)
|
||||
tags_condition = f"({tags_conditions_joined})"
|
||||
|
||||
# And the params for the tags, case-insensitive
|
||||
tags_params = [f"%{t.strip()}%" for t in tags]
|
||||
# And the params for the tags, case-insensitive
|
||||
tags_params = [f"%{t.strip()}%" for t in tags]
|
||||
|
||||
conditions.append(tags_condition)
|
||||
params.extend(tags_params)
|
||||
conditions.append(tags_condition)
|
||||
params.extend(tags_params)
|
||||
|
||||
if has_been_opened:
|
||||
conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
conditions.append("opened_at IS NULL")
|
||||
if has_been_opened:
|
||||
conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
conditions.append("opened_at IS NULL")
|
||||
|
||||
# Ignore whitespace in the query
|
||||
stripped_query = query.strip() if query else None
|
||||
if stripped_query:
|
||||
# Construct a wildcard query for the name, description, and tags
|
||||
wildcard_query = "%" + stripped_query + "%"
|
||||
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
|
||||
# Ignore whitespace in the query
|
||||
stripped_query = query.strip() if query else None
|
||||
if stripped_query:
|
||||
# Construct a wildcard query for the name, description, and tags
|
||||
wildcard_query = "%" + stripped_query + "%"
|
||||
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
|
||||
|
||||
conditions.append(query_condition)
|
||||
params.extend([wildcard_query, wildcard_query, wildcard_query])
|
||||
conditions.append(query_condition)
|
||||
params.extend([wildcard_query, wildcard_query, wildcard_query])
|
||||
|
||||
if conditions:
|
||||
# If there are conditions, add a WHERE clause and then join the conditions
|
||||
main_query += " WHERE "
|
||||
count_query += " WHERE "
|
||||
if conditions:
|
||||
# If there are conditions, add a WHERE clause and then join the conditions
|
||||
main_query += " WHERE "
|
||||
count_query += " WHERE "
|
||||
|
||||
all_conditions = " AND ".join(conditions)
|
||||
main_query += all_conditions
|
||||
count_query += all_conditions
|
||||
all_conditions = " AND ".join(conditions)
|
||||
main_query += all_conditions
|
||||
count_query += all_conditions
|
||||
|
||||
# After this point, the query and params differ for the main query and the count query
|
||||
main_params = params.copy()
|
||||
count_params = params.copy()
|
||||
# After this point, the query and params differ for the main query and the count query
|
||||
main_params = params.copy()
|
||||
count_params = params.copy()
|
||||
|
||||
# Main query also gets ORDER BY and LIMIT/OFFSET
|
||||
main_query += f" ORDER BY {order_by.value} {direction.value}"
|
||||
# Main query also gets ORDER BY and LIMIT/OFFSET
|
||||
main_query += f" ORDER BY {order_by.value} {direction.value}"
|
||||
|
||||
if per_page:
|
||||
main_query += " LIMIT ? OFFSET ?"
|
||||
main_params.extend([per_page, page * per_page])
|
||||
if per_page:
|
||||
main_query += " LIMIT ? OFFSET ?"
|
||||
main_params.extend([per_page, page * per_page])
|
||||
|
||||
# Put a ring on it
|
||||
main_query += ";"
|
||||
count_query += ";"
|
||||
# Put a ring on it
|
||||
main_query += ";"
|
||||
count_query += ";"
|
||||
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(main_query, main_params)
|
||||
rows = cursor.fetchall()
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
cursor.execute(main_query, main_params)
|
||||
rows = cursor.fetchall()
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
|
||||
cursor.execute(count_query, count_params)
|
||||
total = cursor.fetchone()[0]
|
||||
cursor.execute(count_query, count_params)
|
||||
total = cursor.fetchone()[0]
|
||||
|
||||
if per_page:
|
||||
pages = total // per_page + (total % per_page > 0)
|
||||
@@ -247,46 +232,46 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
if not tags:
|
||||
return {}
|
||||
|
||||
cursor = self._conn.cursor()
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories and selected tags
|
||||
base_conditions: list[str] = []
|
||||
base_params: list[str | int] = []
|
||||
with self._db.transaction() as cursor:
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories and selected tags
|
||||
base_conditions: list[str] = []
|
||||
base_params: list[str | int] = []
|
||||
|
||||
# Add category conditions
|
||||
if categories:
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
base_conditions.append(f"category IN ({placeholders})")
|
||||
base_params.extend([category.value for category in categories])
|
||||
# Add category conditions
|
||||
if categories:
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
base_conditions.append(f"category IN ({placeholders})")
|
||||
base_params.extend([category.value for category in categories])
|
||||
|
||||
if has_been_opened:
|
||||
base_conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
base_conditions.append("opened_at IS NULL")
|
||||
if has_been_opened:
|
||||
base_conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
base_conditions.append("opened_at IS NULL")
|
||||
|
||||
# For each tag to count, run a separate query
|
||||
for tag in tags:
|
||||
# Start with the base conditions
|
||||
conditions = base_conditions.copy()
|
||||
params = base_params.copy()
|
||||
# For each tag to count, run a separate query
|
||||
for tag in tags:
|
||||
# Start with the base conditions
|
||||
conditions = base_conditions.copy()
|
||||
params = base_params.copy()
|
||||
|
||||
# Add this specific tag condition
|
||||
conditions.append("tags LIKE ?")
|
||||
params.append(f"%{tag.strip()}%")
|
||||
# Add this specific tag condition
|
||||
conditions.append("tags LIKE ?")
|
||||
params.append(f"%{tag.strip()}%")
|
||||
|
||||
# Construct the full query
|
||||
stmt = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM workflow_library
|
||||
"""
|
||||
# Construct the full query
|
||||
stmt = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM workflow_library
|
||||
"""
|
||||
|
||||
if conditions:
|
||||
stmt += " WHERE " + " AND ".join(conditions)
|
||||
if conditions:
|
||||
stmt += " WHERE " + " AND ".join(conditions)
|
||||
|
||||
cursor.execute(stmt, params)
|
||||
count = cursor.fetchone()[0]
|
||||
result[tag] = count
|
||||
cursor.execute(stmt, params)
|
||||
count = cursor.fetchone()[0]
|
||||
result[tag] = count
|
||||
|
||||
return result
|
||||
|
||||
@@ -296,52 +281,51 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
cursor = self._conn.cursor()
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories
|
||||
base_conditions: list[str] = []
|
||||
base_params: list[str | int] = []
|
||||
with self._db.transaction() as cursor:
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories
|
||||
base_conditions: list[str] = []
|
||||
base_params: list[str | int] = []
|
||||
|
||||
# Add category conditions
|
||||
if categories:
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
base_conditions.append(f"category IN ({placeholders})")
|
||||
base_params.extend([category.value for category in categories])
|
||||
# Add category conditions
|
||||
if categories:
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
base_conditions.append(f"category IN ({placeholders})")
|
||||
base_params.extend([category.value for category in categories])
|
||||
|
||||
if has_been_opened:
|
||||
base_conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
base_conditions.append("opened_at IS NULL")
|
||||
if has_been_opened:
|
||||
base_conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
base_conditions.append("opened_at IS NULL")
|
||||
|
||||
# For each category to count, run a separate query
|
||||
for category in categories:
|
||||
# Start with the base conditions
|
||||
conditions = base_conditions.copy()
|
||||
params = base_params.copy()
|
||||
# For each category to count, run a separate query
|
||||
for category in categories:
|
||||
# Start with the base conditions
|
||||
conditions = base_conditions.copy()
|
||||
params = base_params.copy()
|
||||
|
||||
# Add this specific category condition
|
||||
conditions.append("category = ?")
|
||||
params.append(category.value)
|
||||
# Add this specific category condition
|
||||
conditions.append("category = ?")
|
||||
params.append(category.value)
|
||||
|
||||
# Construct the full query
|
||||
stmt = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM workflow_library
|
||||
"""
|
||||
# Construct the full query
|
||||
stmt = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM workflow_library
|
||||
"""
|
||||
|
||||
if conditions:
|
||||
stmt += " WHERE " + " AND ".join(conditions)
|
||||
if conditions:
|
||||
stmt += " WHERE " + " AND ".join(conditions)
|
||||
|
||||
cursor.execute(stmt, params)
|
||||
count = cursor.fetchone()[0]
|
||||
result[category.value] = count
|
||||
cursor.execute(stmt, params)
|
||||
count = cursor.fetchone()[0]
|
||||
result[category.value] = count
|
||||
|
||||
return result
|
||||
|
||||
def update_opened_at(self, workflow_id: str) -> None:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE workflow_library
|
||||
@@ -350,10 +334,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
def _sync_default_workflows(self) -> None:
|
||||
"""Syncs default workflows to the database. Internal use only."""
|
||||
@@ -368,8 +348,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
meaningless, as they are overwritten every time the server starts.
|
||||
"""
|
||||
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
workflows_from_file: list[Workflow] = []
|
||||
workflows_to_update: list[Workflow] = []
|
||||
workflows_to_add: list[Workflow] = []
|
||||
@@ -449,8 +428,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(w.model_dump_json(), w.id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
0
invokeai/backend/bria/__init__.py
Normal file
314
invokeai/backend/bria/bria_utils.py
Normal file
@@ -0,0 +1,314 @@
|
||||
import math
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from diffusers.utils import logging
|
||||
from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def get_t5_prompt_embeds(
|
||||
tokenizer: T5TokenizerFast,
|
||||
text_encoder: T5EncoderModel,
|
||||
prompt: Union[str, List[str], None] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 128,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
device = device or text_encoder.device
|
||||
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
# padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
||||
|
||||
# Concat zeros to max_sequence
|
||||
b, seq_len, dim = prompt_embeds.shape
|
||||
if seq_len < max_sequence_length:
|
||||
padding = torch.zeros(
|
||||
(b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
|
||||
)
|
||||
prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
|
||||
# in order to get the same sigmas as in training and sample from them
|
||||
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
|
||||
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
||||
sigmas = timesteps / num_train_timesteps
|
||||
|
||||
inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
|
||||
new_sigmas = sigmas[inds]
|
||||
return new_sigmas
|
||||
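For intuition, a quick sketch of what get_original_sigmas produces (values rounded; assumes the function as defined above):

# Taking 4 inference steps from a 1000-step training schedule picks evenly
# spaced entries from the descending sigma ramp 1.000 ... 0.001.
sigmas = get_original_sigmas(num_train_timesteps=1000, num_inference_steps=4)
# -> array([1.0, 0.667, 0.334, 0.001], dtype=float32)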
|
||||
|
||||
def is_ng_none(negative_prompt):
|
||||
return (
|
||||
negative_prompt is None
|
||||
or negative_prompt == ""
|
||||
or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
|
||||
or (isinstance(negative_prompt, list) and negative_prompt[0] == "")
|
||||
)
|
||||
|
||||
|
||||
class CudaTimerContext:
|
||||
def __init__(self, times_arr):
|
||||
self.times_arr = times_arr
|
||||
|
||||
def __enter__(self):
|
||||
self.before_event = torch.cuda.Event(enable_timing=True)
|
||||
self.after_event = torch.cuda.Event(enable_timing=True)
|
||||
self.before_event.record()
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.after_event.record()
|
||||
torch.cuda.synchronize()
|
||||
elapsed_time = self.before_event.elapsed_time(self.after_event) / 1000
|
||||
self.times_arr.append(elapsed_time)
|
||||
|
||||
|
||||
def get_env_prefix():
|
||||
env = os.environ.get("CLOUD_PROVIDER", "AWS").upper()
|
||||
if env == "AWS":
|
||||
return "SM_CHANNEL"
|
||||
elif env == "AZURE":
|
||||
return "AZUREML_DATAREFERENCE"
|
||||
|
||||
raise Exception(f"Env {env} not supported")
|
||||
|
||||
|
||||
def compute_density_for_timestep_sampling(
|
||||
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
||||
):
|
||||
"""Compute the density for sampling the timesteps when doing SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
|
||||
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
||||
"""
|
||||
if weighting_scheme == "logit_normal":
|
||||
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
||||
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
|
||||
u = torch.nn.functional.sigmoid(u)
|
||||
elif weighting_scheme == "mode":
|
||||
u = torch.rand(size=(batch_size,), device="cpu")
|
||||
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
||||
else:
|
||||
u = torch.rand(size=(batch_size,), device="cpu")
|
||||
return u
|
||||
|
||||
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
"""Computes loss weighting scheme for SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
|
||||
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
||||
"""
|
||||
if weighting_scheme == "sigma_sqrt":
|
||||
weighting = (sigmas**-2.0).float()
|
||||
elif weighting_scheme == "cosmap":
|
||||
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
||||
weighting = 2 / (math.pi * bot)
|
||||
else:
|
||||
weighting = torch.ones_like(sigmas)
|
||||
return weighting
|
||||
|
||||
|
||||
def initialize_distributed():
|
||||
# Initialize the process group for distributed training
|
||||
dist.init_process_group("nccl")
|
||||
|
||||
# Get the current process's rank (ID) and the total number of processes (world size)
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
print(f"Initialized distributed training: Rank {rank}/{world_size}")
|
||||
|
||||
|
||||
def get_clip_prompt_embeds(
|
||||
text_encoder: CLIPTextModel,
|
||||
text_encoder_2: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 77,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
device = device or text_encoder.device
|
||||
assert max_sequence_length == tokenizer.model_max_length
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
# Define tokenizers and text encoders
|
||||
tokenizers = [tokenizer, tokenizer_2]
|
||||
text_encoders = [text_encoder, text_encoder_2]
|
||||
|
||||
# textual inversion: process multi-vector tokens if necessary
|
||||
prompt_embeds_list = []
|
||||
prompts = [prompt, prompt]
|
||||
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
|
||||
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int,
|
||||
pos: Union[np.ndarray, int],
|
||||
theta: float = 10000.0,
|
||||
use_real=False,
|
||||
linear_factor=1.0,
|
||||
ntk_factor=1.0,
|
||||
repeat_interleave_real=True,
|
||||
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
|
||||
):
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
||||
|
||||
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
|
||||
index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
|
||||
data type.
|
||||
|
||||
Args:
|
||||
dim (`int`): Dimension of the frequency tensor.
|
||||
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
|
||||
theta (`float`, *optional*, defaults to 10000.0):
|
||||
Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (`bool`, *optional*):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
linear_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for the context extrapolation. Defaults to 1.0.
|
||||
ntk_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
|
||||
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
|
||||
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
|
||||
Otherwise, they are concatenated with themselves.
|
||||
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
|
||||
the dtype of the frequency tensor.
|
||||
Returns:
|
||||
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
|
||||
if isinstance(pos, int):
|
||||
pos = torch.arange(pos)
|
||||
if isinstance(pos, np.ndarray):
|
||||
pos = torch.from_numpy(pos) # type: ignore # [S]
|
||||
|
||||
theta = theta * ntk_factor
|
||||
freqs = (
|
||||
1.0
|
||||
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
|
||||
/ linear_factor
|
||||
) # [D/2]
|
||||
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
||||
if use_real and repeat_interleave_real:
|
||||
# flux, hunyuan-dit, cogvideox
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
elif use_real:
|
||||
# stable audio, allegro
|
||||
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
|
||||
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
# lumina
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
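A short sketch of the real-valued RoPE path (assumes the get_1d_rotary_pos_embed defined above; the output width follows from the repeat_interleave over dim):

import torch

# 16 positions on a 64-dim axis -> per-position cos/sin tables of shape (16, 64).
cos, sin = get_1d_rotary_pos_embed(dim=64, pos=16, use_real=True, repeat_interleave_real=True)
assert cos.shape == (16, 64) and sin.shape == (16, 64)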
|
||||
|
||||
class FluxPosEmbed(torch.nn.Module):
|
||||
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
cos_out = []
|
||||
sin_out = []
|
||||
pos = ids.float()
|
||||
is_mps = ids.device.type == "mps"
|
||||
freqs_dtype = torch.float32 if is_mps else torch.float64
|
||||
for i in range(n_axes):
|
||||
cos, sin = get_1d_rotary_pos_embed(
|
||||
self.axes_dim[i],
|
||||
pos[:, i],
|
||||
theta=self.theta,
|
||||
repeat_interleave_real=True,
|
||||
use_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
cos_out.append(cos)
|
||||
sin_out.append(sin)
|
||||
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
||||
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
||||
return freqs_cos, freqs_sin
|
||||
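Usage sketch for FluxPosEmbed (hypothetical axes_dim values; the output width is the sum of the per-axis dims):

import torch

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 24, 24])
ids = torch.zeros(4, 3)  # (sequence_length, n_axes) position ids, all zero for brevity
cos, sin = pos_embed.forward(ids)  # each has shape (4, 16 + 24 + 24) == (4, 64)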
6
invokeai/backend/bria/controlnet_aux/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
__version__ = "0.0.9"
|
||||
|
||||
from invokeai.backend.bria.controlnet_aux.canny import CannyDetector as CannyDetector
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose import OpenposeDetector as OpenposeDetector
|
||||
|
||||
__all__ = ["CannyDetector", "OpenposeDetector"]
|
||||
39
invokeai/backend/bria/controlnet_aux/canny/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import warnings
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.bria.controlnet_aux.util import HWC3, resize_image
|
||||
|
||||
|
||||
class CannyDetector:
|
||||
def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, image_resolution=512, output_type=None, **kwargs):
|
||||
if "img" in kwargs:
|
||||
warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning, stacklevel=2)
|
||||
input_image = kwargs.pop("img")
|
||||
|
||||
if input_image is None:
|
||||
raise ValueError("input_image must be defined.")
|
||||
|
||||
if not isinstance(input_image, np.ndarray):
|
||||
input_image = np.array(input_image, dtype=np.uint8)
|
||||
output_type = output_type or "pil"
|
||||
else:
|
||||
output_type = output_type or "np"
|
||||
|
||||
input_image = HWC3(input_image)
|
||||
input_image = resize_image(input_image, detect_resolution)
|
||||
|
||||
detected_map = cv2.Canny(input_image, low_threshold, high_threshold)
|
||||
detected_map = HWC3(detected_map)
|
||||
|
||||
img = resize_image(input_image, image_resolution)
|
||||
H, W, C = img.shape
|
||||
|
||||
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
if output_type == "pil":
|
||||
detected_map = Image.fromarray(detected_map)
|
||||
|
||||
return detected_map
|
||||
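A hedged usage sketch for CannyDetector (hypothetical file names; the detector returns a PIL image when given one, per the output_type logic above):

from PIL import Image

from invokeai.backend.bria.controlnet_aux import CannyDetector

detector = CannyDetector()
edges = detector(Image.open("input.png"), low_threshold=100, high_threshold=200)
edges.save("canny_edges.png")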
108
invokeai/backend/bria/controlnet_aux/open_pose/LICENSE
Normal file
@@ -0,0 +1,108 @@
|
||||
OPENPOSE: MULTIPERSON KEYPOINT DETECTION
|
||||
SOFTWARE LICENSE AGREEMENT
|
||||
ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY
|
||||
|
||||
BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.
|
||||
|
||||
This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor.

RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i).

CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication.

COPYRIGHT: The Software is owned by Licensor and is protected by United
States copyright laws and applicable international treaties and/or conventions.

PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto.

DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement.

BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies.

USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark "OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor.

You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software.

ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void.

TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below.

The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement.

FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement.

DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS.

SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement.

EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage.

EXPORT REGULATION: Licensee agrees to comply with any and all applicable
U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control.

SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby.

NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor.

GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania.

ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto.


************************************************************************

THIRD-PARTY SOFTWARE NOTICES AND INFORMATION

This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise.

1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/)

COPYRIGHT

All contributions by the University of California:
Copyright (c) 2014-2017 The Regents of the University of California (Regents)
All rights reserved.

All other contributions:
Copyright (c) 2014-2017, the respective contributors
All rights reserved.

Caffe uses a shared copyright model: each contributor holds copyright over
their contributions to Caffe. The project versioning records all such
contribution and copyright details. If a contributor wants to further mark
their specific copyright on a particular contribution, they should indicate
their copyright solely in the commit message of the change when it is
committed.

LICENSE

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

CONTRIBUTION AGREEMENT

By contributing to the BVLC/caffe repository through pull-request, comment,
or otherwise, the contributor releases their content to the
license and copyright terms herein.

************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION**********

invokeai/backend/bria/controlnet_aux/open_pose/__init__.py (new file, 233 lines)
@@ -0,0 +1,233 @@
# Openpose
# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
# 3rd Edited by ControlNet
# 4th Edited by ControlNet (added face and correct hands)
# 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixes)
# This preprocessor is licensed by CMU for non-commercial use only.


import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import warnings
from typing import List, NamedTuple, Tuple, Union

import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from invokeai.backend.bria.controlnet_aux.open_pose import util
from invokeai.backend.bria.controlnet_aux.open_pose.body import Body, BodyResult, Keypoint
from invokeai.backend.bria.controlnet_aux.open_pose.face import Face
from invokeai.backend.bria.controlnet_aux.open_pose.hand import Hand
from invokeai.backend.bria.controlnet_aux.util import HWC3, resize_image

HandResult = List[Keypoint]
FaceResult = List[Keypoint]


class PoseResult(NamedTuple):
    body: BodyResult
    left_hand: Union[HandResult, None]
    right_hand: Union[HandResult, None]
    face: Union[FaceResult, None]


def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
    """
    Draw the detected poses on an empty canvas.

    Args:
        poses (List[PoseResult]): A list of PoseResult objects containing the detected poses.
        H (int): The height of the canvas.
        W (int): The width of the canvas.
        draw_body (bool, optional): Whether to draw body keypoints. Defaults to True.
        draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True.
        draw_face (bool, optional): Whether to draw face keypoints. Defaults to True.

    Returns:
        numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses.
    """
    canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)

    for pose in poses:
        if draw_body:
            canvas = util.draw_bodypose(canvas, pose.body.keypoints)

        if draw_hand:
            canvas = util.draw_handpose(canvas, pose.left_hand)
            canvas = util.draw_handpose(canvas, pose.right_hand)

        if draw_face:
            canvas = util.draw_facepose(canvas, pose.face)

    return canvas


class OpenposeDetector:
    """
    A class for detecting human poses in images using the Openpose model.

    Attributes:
        model_dir (str): Path to the directory where the pose models are stored.
    """

    def __init__(self, body_estimation, hand_estimation=None, face_estimation=None):
        self.body_estimation = body_estimation
        self.hand_estimation = hand_estimation
        self.face_estimation = face_estimation

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, hand_filename=None, face_filename=None, cache_dir=None, local_files_only=False):

        if pretrained_model_or_path == "lllyasviel/ControlNet":
            filename = filename or "annotator/ckpts/body_pose_model.pth"
            hand_filename = hand_filename or "annotator/ckpts/hand_pose_model.pth"
            face_filename = face_filename or "facenet.pth"

            face_pretrained_model_or_path = "lllyasviel/Annotators"
        else:
            filename = filename or "body_pose_model.pth"
            hand_filename = hand_filename or "hand_pose_model.pth"
            face_filename = face_filename or "facenet.pth"

            face_pretrained_model_or_path = pretrained_model_or_path

        if os.path.isdir(pretrained_model_or_path):
            body_model_path = os.path.join(pretrained_model_or_path, filename)
            hand_model_path = os.path.join(pretrained_model_or_path, hand_filename)
            face_model_path = os.path.join(face_pretrained_model_or_path, face_filename)
        else:
            body_model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)
            hand_model_path = hf_hub_download(pretrained_model_or_path, hand_filename, cache_dir=cache_dir, local_files_only=local_files_only)
            face_model_path = hf_hub_download(face_pretrained_model_or_path, face_filename, cache_dir=cache_dir, local_files_only=local_files_only)

        body_estimation = Body(body_model_path)
        hand_estimation = Hand(hand_model_path)
        face_estimation = Face(face_model_path)

        return cls(body_estimation, hand_estimation, face_estimation)

    def to(self, device):
        self.body_estimation.to(device)
        self.hand_estimation.to(device)
        self.face_estimation.to(device)
        return self

    def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]:
        left_hand = None
        right_hand = None
        H, W, _ = oriImg.shape
        for x, y, w, is_left in util.handDetect(body, oriImg):
            peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]).astype(np.float32)
            if peaks.ndim == 2 and peaks.shape[1] == 2:
                peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
                peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)

                hand_result = [
                    Keypoint(x=peak[0], y=peak[1])
                    for peak in peaks
                ]

                if is_left:
                    left_hand = hand_result
                else:
                    right_hand = hand_result

        return left_hand, right_hand

    def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]:
        face = util.faceDetect(body, oriImg)
        if face is None:
            return None

        x, y, w = face
        H, W, _ = oriImg.shape
        heatmaps = self.face_estimation(oriImg[y:y+w, x:x+w, :])
        peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32)
        if peaks.ndim == 2 and peaks.shape[1] == 2:
            peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
            peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
            return [
                Keypoint(x=peak[0], y=peak[1])
                for peak in peaks
            ]

        return None

    def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]:
        """
        Detect poses in the given image.
        Args:
            oriImg (numpy.ndarray): The input image for pose detection.
            include_hand (bool, optional): Whether to include hand detection. Defaults to False.
            include_face (bool, optional): Whether to include face detection. Defaults to False.

        Returns:
            List[PoseResult]: A list of PoseResult objects containing the detected poses.
        """
        oriImg = oriImg[:, :, ::-1].copy()
        H, W, C = oriImg.shape
        with torch.no_grad():
            candidate, subset = self.body_estimation(oriImg)
            bodies = self.body_estimation.format_body_result(candidate, subset)

            results = []
            for body in bodies:
                left_hand, right_hand, face = (None,) * 3
                if include_hand:
                    left_hand, right_hand = self.detect_hands(body, oriImg)
                if include_face:
                    face = self.detect_face(body, oriImg)

                results.append(PoseResult(BodyResult(
                    keypoints=[
                        Keypoint(
                            x=keypoint.x / float(W),
                            y=keypoint.y / float(H)
                        ) if keypoint is not None else None
                        for keypoint in body.keypoints
                    ],
                    total_score=body.total_score,
                    total_parts=body.total_parts
                ), left_hand, right_hand, face))

            return results

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", **kwargs):
        if hand_and_face is not None:
            warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning, stacklevel=2)
            include_hand = hand_and_face
            include_face = hand_and_face

        if "return_pil" in kwargs:
            warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning, stacklevel=2)
            output_type = "pil" if kwargs["return_pil"] else "np"
        if type(output_type) is bool:
            warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions", stacklevel=2)
            if output_type:
                output_type = "pil"

        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)
        H, W, C = input_image.shape

        poses = self.detect_poses(input_image, include_hand, include_face)
        canvas = draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)

        detected_map = canvas
        detected_map = HWC3(detected_map)

        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape

        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
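
For orientation, a minimal usage sketch of the detector above. The repo id and the input image path are assumptions for illustration (the same checkpoint filenames the code defaults to are commonly hosted under lllyasviel/Annotators), not something this diff guarantees:

# Hypothetical usage sketch for OpenposeDetector; repo id and image path are assumed.
import PIL.Image

from invokeai.backend.bria.controlnet_aux.open_pose import OpenposeDetector

# Downloads body/hand/face checkpoints via hf_hub_download on first use.
detector = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to("cpu")

image = PIL.Image.open("person.jpg")  # hypothetical input image
pose_map = detector(image, include_hand=True, include_face=True, output_type="pil")
pose_map.save("pose.png")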

invokeai/backend/bria/controlnet_aux/open_pose/body.py (new file, 259 lines)
@@ -0,0 +1,259 @@
import math
from typing import List, NamedTuple, Union

import numpy as np
import torch
from scipy.ndimage.filters import gaussian_filter

from invokeai.backend.bria.controlnet_aux.open_pose import util
from invokeai.backend.bria.controlnet_aux.open_pose.model import bodypose_model


class Keypoint(NamedTuple):
    x: float
    y: float
    score: float = 1.0
    id: int = -1


class BodyResult(NamedTuple):
    # Note: Using `Union` instead of `|` operator as the latter is a Python
    # 3.10 feature.
    # Annotator code should be Python 3.8 Compatible, as controlnet repo uses
    # Python 3.8 environment.
    # https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
    keypoints: List[Union[Keypoint, None]]
    total_score: float
    total_parts: int


class Body(object):
    def __init__(self, model_path):
        self.model = bodypose_model()
        model_dict = util.transfer(self.model, torch.load(model_path))
        self.model.load_state_dict(model_dict)
        self.model.eval()

    def to(self, device):
        self.model.to(device)
        return self

    def __call__(self, oriImg):
        device = next(iter(self.model.parameters())).device
        # scale_search = [0.5, 1.0, 1.5, 2.0]
        scale_search = [0.5]
        boxsize = 368
        stride = 8
        padValue = 128
        thre1 = 0.1
        thre2 = 0.05
        multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
        heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
        paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))

        for m in range(len(multiplier)):
            scale = multiplier[m]
            imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale)
            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
            im = np.ascontiguousarray(im)

            data = torch.from_numpy(im).float()
            data = data.to(device)
            # data = data.permute([2, 0, 1]).unsqueeze(0).float()
            with torch.no_grad():
                Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
                Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
                Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()

            # extract outputs, resize, and remove padding
            # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0))  # output 1 is heatmaps
            heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0))  # output 1 is heatmaps
            heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1]))

            # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0))  # output 0 is PAFs
            paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0))  # output 0 is PAFs
            paf = util.smart_resize_k(paf, fx=stride, fy=stride)
            paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1]))

            heatmap_avg += heatmap_avg + heatmap / len(multiplier)
            paf_avg += + paf / len(multiplier)

        all_peaks = []
        peak_counter = 0

        for part in range(18):
            map_ori = heatmap_avg[:, :, part]
            one_heatmap = gaussian_filter(map_ori, sigma=3)

            map_left = np.zeros(one_heatmap.shape)
            map_left[1:, :] = one_heatmap[:-1, :]
            map_right = np.zeros(one_heatmap.shape)
            map_right[:-1, :] = one_heatmap[1:, :]
            map_up = np.zeros(one_heatmap.shape)
            map_up[:, 1:] = one_heatmap[:, :-1]
            map_down = np.zeros(one_heatmap.shape)
            map_down[:, :-1] = one_heatmap[:, 1:]

            peaks_binary = np.logical_and.reduce(
                (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
            peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0], strict=False))  # note reverse
            peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
            peak_id = range(peak_counter, peak_counter + len(peaks))
            peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]

            all_peaks.append(peaks_with_score_and_id)
            peak_counter += len(peaks)

        # find connection in the specified sequence, center 29 is in the position 15
        limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
                   [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
                   [1, 16], [16, 18], [3, 17], [6, 18]]
        # the middle joints heatmap correspondence
        mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
                  [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
                  [55, 56], [37, 38], [45, 46]]

        connection_all = []
        special_k = []
        mid_num = 10

        for k in range(len(mapIdx)):
            score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
            candA = all_peaks[limbSeq[k][0] - 1]
            candB = all_peaks[limbSeq[k][1] - 1]
            nA = len(candA)
            nB = len(candB)
            indexA, indexB = limbSeq[k]
            if (nA != 0 and nB != 0):
                connection_candidate = []
                for i in range(nA):
                    for j in range(nB):
                        vec = np.subtract(candB[j][:2], candA[i][:2])
                        norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
                        norm = max(0.001, norm)
                        vec = np.divide(vec, norm)

                        startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
                                            np.linspace(candA[i][1], candB[j][1], num=mid_num), strict=False))

                        vec_x = np.array([score_mid[int(round(startend[i][1])), int(round(startend[i][0])), 0] \
                                          for i in range(len(startend))])
                        vec_y = np.array([score_mid[int(round(startend[i][1])), int(round(startend[i][0])), 1] \
                                          for i in range(len(startend))])

                        score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
                        score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
                            0.5 * oriImg.shape[0] / norm - 1, 0)
                        criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
                        criterion2 = score_with_dist_prior > 0
                        if criterion1 and criterion2:
                            connection_candidate.append(
                                [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])

                connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
                connection = np.zeros((0, 5))
                for c in range(len(connection_candidate)):
                    i, j, s = connection_candidate[c][0:3]
                    if (i not in connection[:, 3] and j not in connection[:, 4]):
                        connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
                        if (len(connection) >= min(nA, nB)):
                            break

                connection_all.append(connection)
            else:
                special_k.append(k)
                connection_all.append([])

        # last number in each row is the total parts number of that person
        # the second last number in each row is the score of the overall configuration
        subset = -1 * np.ones((0, 20))
        candidate = np.array([item for sublist in all_peaks for item in sublist])

        for k in range(len(mapIdx)):
            if k not in special_k:
                partAs = connection_all[k][:, 0]
                partBs = connection_all[k][:, 1]
                indexA, indexB = np.array(limbSeq[k]) - 1

                for i in range(len(connection_all[k])):  # = 1:size(temp,1)
                    found = 0
                    subset_idx = [-1, -1]
                    for j in range(len(subset)):  # 1:size(subset,1):
                        if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
                            subset_idx[found] = j
                            found += 1

                    if found == 1:
                        j = subset_idx[0]
                        if subset[j][indexB] != partBs[i]:
                            subset[j][indexB] = partBs[i]
                            subset[j][-1] += 1
                            subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
                    elif found == 2:  # if found 2 and disjoint, merge them
                        j1, j2 = subset_idx
                        membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
                        if len(np.nonzero(membership == 2)[0]) == 0:  # merge
                            subset[j1][:-2] += (subset[j2][:-2] + 1)
                            subset[j1][-2:] += subset[j2][-2:]
                            subset[j1][-2] += connection_all[k][i][2]
                            subset = np.delete(subset, j2, 0)
                        else:  # as like found == 1
                            subset[j1][indexB] = partBs[i]
                            subset[j1][-1] += 1
                            subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]

                    # if find no partA in the subset, create a new subset
                    elif not found and k < 17:
                        row = -1 * np.ones(20)
                        row[indexA] = partAs[i]
                        row[indexB] = partBs[i]
                        row[-1] = 2
                        row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
                        subset = np.vstack([subset, row])
        # delete some rows of subset which has few parts occur
        deleteIdx = []
        for i in range(len(subset)):
            if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
                deleteIdx.append(i)
        subset = np.delete(subset, deleteIdx, axis=0)

        # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
        # candidate: x, y, score, id
        return candidate, subset

    @staticmethod
    def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]:
        """
        Format the body results from the candidate and subset arrays into a list of BodyResult objects.

        Args:
            candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id
                for each body part.
            subset (np.ndarray): An array of subsets containing indices to the candidate array for each
                person detected. The last two columns of each row hold the total score and total parts
                of the person.

        Returns:
            List[BodyResult]: A list of BodyResult objects, where each object represents a person with
                detected keypoints, total score, and total parts.
        """
        return [
            BodyResult(
                keypoints=[
                    Keypoint(
                        x=candidate[candidate_index][0],
                        y=candidate[candidate_index][1],
                        score=candidate[candidate_index][2],
                        id=candidate[candidate_index][3]
                    ) if candidate_index != -1 else None
                    for candidate_index in person[:18].astype(int)
                ],
                total_score=person[18],
                total_parts=person[19]
            )
            for person in subset
        ]
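
The candidate/subset arrays returned by Body.__call__ are easiest to reason about through format_body_result: candidate rows are (x, y, score, id) per detected part, and each subset row maps the 18 body joints of one person to candidate indices (or -1 when missing), with the last two columns holding the total score and part count. A small shape sketch with synthetic data (no model forward pass involved):

# Shape sketch for Body.format_body_result using hand-built arrays (illustrative only).
import numpy as np

from invokeai.backend.bria.controlnet_aux.open_pose.body import Body

candidate = np.array([[120.0, 80.0, 0.9, 0],   # x, y, score, id
                      [130.0, 150.0, 0.8, 1]])
subset = -1 * np.ones((1, 20))
subset[0, 0] = 0      # joint 0 -> candidate row 0
subset[0, 1] = 1      # joint 1 -> candidate row 1
subset[0, 18] = 1.7   # total score
subset[0, 19] = 2     # total detected parts

(person,) = Body.format_body_result(candidate, subset)
assert person.total_parts == 2 and person.keypoints[2] is None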

invokeai/backend/bria/controlnet_aux/open_pose/face.py (new file, 364 lines)
@@ -0,0 +1,364 @@
import logging

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Conv2d, MaxPool2d, Module, ReLU, init
from torchvision.transforms import ToPILImage, ToTensor

from invokeai.backend.bria.controlnet_aux.open_pose import util


class FaceNet(Module):
    """Model the cascading heatmaps."""
    def __init__(self):
        super(FaceNet, self).__init__()
        # cnn to make feature map
        self.relu = ReLU()
        self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2)
        self.conv1_1 = Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2_1 = Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv3_1 = Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv3_4 = Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.conv4_1 = Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.conv4_3 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.conv4_4 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.conv5_1 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.conv5_3_CPM = Conv2d(in_channels=512, out_channels=128, kernel_size=3, stride=1, padding=1)

        # stage1
        self.conv6_1_CPM = Conv2d(in_channels=128, out_channels=512, kernel_size=1, stride=1, padding=0)
        self.conv6_2_CPM = Conv2d(in_channels=512, out_channels=71, kernel_size=1, stride=1, padding=0)

        # stage2
        self.Mconv1_stage2 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv2_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv3_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv4_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv5_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv6_stage2 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
        self.Mconv7_stage2 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)

        # stage3
        self.Mconv1_stage3 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv2_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv3_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv4_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv5_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv6_stage3 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
        self.Mconv7_stage3 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)

        # stage4
        self.Mconv1_stage4 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv2_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv3_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv4_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv5_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv6_stage4 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
        self.Mconv7_stage4 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)

        # stage5
        self.Mconv1_stage5 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv2_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv3_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv4_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv5_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv6_stage5 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
        self.Mconv7_stage5 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)

        # stage6
        self.Mconv1_stage6 = Conv2d(in_channels=199, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv2_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv3_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv4_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv5_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=7, stride=1, padding=3)
        self.Mconv6_stage6 = Conv2d(in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0)
        self.Mconv7_stage6 = Conv2d(in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0)

        for m in self.modules():
            if isinstance(m, Conv2d):
                init.constant_(m.bias, 0)

    def forward(self, x):
        """Return a list of heatmaps."""
        heatmaps = []

        h = self.relu(self.conv1_1(x))
        h = self.relu(self.conv1_2(h))
        h = self.max_pooling_2d(h)
        h = self.relu(self.conv2_1(h))
        h = self.relu(self.conv2_2(h))
        h = self.max_pooling_2d(h)
        h = self.relu(self.conv3_1(h))
        h = self.relu(self.conv3_2(h))
        h = self.relu(self.conv3_3(h))
        h = self.relu(self.conv3_4(h))
        h = self.max_pooling_2d(h)
        h = self.relu(self.conv4_1(h))
        h = self.relu(self.conv4_2(h))
        h = self.relu(self.conv4_3(h))
        h = self.relu(self.conv4_4(h))
        h = self.relu(self.conv5_1(h))
        h = self.relu(self.conv5_2(h))
        h = self.relu(self.conv5_3_CPM(h))
        feature_map = h

        # stage1
        h = self.relu(self.conv6_1_CPM(h))
        h = self.conv6_2_CPM(h)
        heatmaps.append(h)

        # stage2
        h = torch.cat([h, feature_map], dim=1)  # channel concat
        h = self.relu(self.Mconv1_stage2(h))
        h = self.relu(self.Mconv2_stage2(h))
        h = self.relu(self.Mconv3_stage2(h))
        h = self.relu(self.Mconv4_stage2(h))
        h = self.relu(self.Mconv5_stage2(h))
        h = self.relu(self.Mconv6_stage2(h))
        h = self.Mconv7_stage2(h)
        heatmaps.append(h)

        # stage3
        h = torch.cat([h, feature_map], dim=1)  # channel concat
        h = self.relu(self.Mconv1_stage3(h))
        h = self.relu(self.Mconv2_stage3(h))
        h = self.relu(self.Mconv3_stage3(h))
        h = self.relu(self.Mconv4_stage3(h))
        h = self.relu(self.Mconv5_stage3(h))
        h = self.relu(self.Mconv6_stage3(h))
        h = self.Mconv7_stage3(h)
        heatmaps.append(h)

        # stage4
        h = torch.cat([h, feature_map], dim=1)  # channel concat
        h = self.relu(self.Mconv1_stage4(h))
        h = self.relu(self.Mconv2_stage4(h))
        h = self.relu(self.Mconv3_stage4(h))
        h = self.relu(self.Mconv4_stage4(h))
        h = self.relu(self.Mconv5_stage4(h))
        h = self.relu(self.Mconv6_stage4(h))
        h = self.Mconv7_stage4(h)
        heatmaps.append(h)

        # stage5
        h = torch.cat([h, feature_map], dim=1)  # channel concat
        h = self.relu(self.Mconv1_stage5(h))
        h = self.relu(self.Mconv2_stage5(h))
        h = self.relu(self.Mconv3_stage5(h))
        h = self.relu(self.Mconv4_stage5(h))
        h = self.relu(self.Mconv5_stage5(h))
        h = self.relu(self.Mconv6_stage5(h))
        h = self.Mconv7_stage5(h)
        heatmaps.append(h)

        # stage6
        h = torch.cat([h, feature_map], dim=1)  # channel concat
        h = self.relu(self.Mconv1_stage6(h))
        h = self.relu(self.Mconv2_stage6(h))
        h = self.relu(self.Mconv3_stage6(h))
        h = self.relu(self.Mconv4_stage6(h))
        h = self.relu(self.Mconv5_stage6(h))
        h = self.relu(self.Mconv6_stage6(h))
        h = self.Mconv7_stage6(h)
        heatmaps.append(h)

        return heatmaps


LOG = logging.getLogger(__name__)
TOTEN = ToTensor()
TOPIL = ToPILImage()


params = {
    'gaussian_sigma': 2.5,
    'inference_img_size': 736,  # 368, 736, 1312
    'heatmap_peak_thresh': 0.1,
    'crop_scale': 1.5,
    'line_indices': [
        [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6],
        [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13],
        [13, 14], [14, 15], [15, 16],
        [17, 18], [18, 19], [19, 20], [20, 21],
        [22, 23], [23, 24], [24, 25], [25, 26],
        [27, 28], [28, 29], [29, 30],
        [31, 32], [32, 33], [33, 34], [34, 35],
        [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36],
        [42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42],
        [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54],
        [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48],
        [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66],
        [66, 67], [67, 60]
    ],
}


class Face(object):
    """
    The OpenPose face landmark detector model.

    Args:
        inference_size: set the size of the inference image size, suggested:
            368, 736, 1312, default 736
        gaussian_sigma: blur the heatmaps, default 2.5
        heatmap_peak_thresh: return landmark if over threshold, default 0.1

    """
    def __init__(self, face_model_path,
                 inference_size=None,
                 gaussian_sigma=None,
                 heatmap_peak_thresh=None):
        self.inference_size = inference_size or params["inference_img_size"]
        self.sigma = gaussian_sigma or params['gaussian_sigma']
        self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"]
        self.model = FaceNet()
        self.model.load_state_dict(torch.load(face_model_path))
        self.model.eval()

    def to(self, device):
        self.model.to(device)
        return self

    def __call__(self, face_img):
        device = next(iter(self.model.parameters())).device
        H, W, C = face_img.shape

        w_size = 384
        x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5

        x_data = x_data.to(device)

        with torch.no_grad():
            hs = self.model(x_data[None, ...])
            heatmaps = F.interpolate(
                hs[-1],
                (H, W),
                mode='bilinear', align_corners=True).cpu().numpy()[0]
        return heatmaps

    def compute_peaks_from_heatmaps(self, heatmaps):
        all_peaks = []
        for part in range(heatmaps.shape[0]):
            map_ori = heatmaps[part].copy()
            binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8)

            if np.sum(binary) == 0:
                continue

            positions = np.where(binary > 0.5)
            intensities = map_ori[positions]
            mi = np.argmax(intensities)
            y, x = positions[0][mi], positions[1][mi]
            all_peaks.append([x, y])

        return np.array(all_peaks)
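
As a quick sanity check on the cascade above: each of the six stages emits a 71-channel heatmap at 1/8 of the input resolution (three max-pool layers). A sketch with random weights, not part of the diff:

# FaceNet shape check on a dummy tensor (no pretrained weights needed).
import torch

from invokeai.backend.bria.controlnet_aux.open_pose.face import FaceNet

net = FaceNet().eval()
with torch.no_grad():
    heatmaps = net(torch.zeros(1, 3, 384, 384))
assert len(heatmaps) == 6                       # one output per refinement stage
assert heatmaps[-1].shape == (1, 71, 48, 48)    # 384 / 8 after three max-pools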

invokeai/backend/bria/controlnet_aux/open_pose/hand.py (new file, 90 lines)
@@ -0,0 +1,90 @@
import cv2
import numpy as np
import torch
from scipy.ndimage.filters import gaussian_filter
from skimage.measure import label

from invokeai.backend.bria.controlnet_aux.open_pose import util
from invokeai.backend.bria.controlnet_aux.open_pose.model import handpose_model


class Hand(object):
    def __init__(self, model_path):
        self.model = handpose_model()
        model_dict = util.transfer(self.model, torch.load(model_path))
        self.model.load_state_dict(model_dict)
        self.model.eval()

    def to(self, device):
        self.model.to(device)
        return self

    def __call__(self, oriImgRaw):
        device = next(iter(self.model.parameters())).device
        scale_search = [0.5, 1.0, 1.5, 2.0]
        # scale_search = [0.5]
        boxsize = 368
        stride = 8
        padValue = 128
        thre = 0.05
        multiplier = [x * boxsize for x in scale_search]

        wsize = 128
        heatmap_avg = np.zeros((wsize, wsize, 22))

        Hr, Wr, Cr = oriImgRaw.shape

        oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8)

        for m in range(len(multiplier)):
            scale = multiplier[m]
            imageToTest = util.smart_resize(oriImg, (scale, scale))

            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
            im = np.ascontiguousarray(im)

            data = torch.from_numpy(im).float()
            data = data.to(device)

            with torch.no_grad():
                output = self.model(data).cpu().numpy()

            # extract outputs, resize, and remove padding
            heatmap = np.transpose(np.squeeze(output), (1, 2, 0))  # output 1 is heatmaps
            heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            heatmap = util.smart_resize(heatmap, (wsize, wsize))

            heatmap_avg += heatmap / len(multiplier)

        all_peaks = []
        for part in range(21):
            map_ori = heatmap_avg[:, :, part]
            one_heatmap = gaussian_filter(map_ori, sigma=3)
            binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)

            if np.sum(binary) == 0:
                all_peaks.append([0, 0])
                continue
            label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
            max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
            label_img[label_img != max_index] = 0
            map_ori[label_img == 0] = 0

            y, x = util.npmax(map_ori)
            y = int(float(y) * float(Hr) / float(wsize))
            x = int(float(x) * float(Wr) / float(wsize))
            all_peaks.append([x, y])
        return np.array(all_peaks)


if __name__ == "__main__":
    hand_estimation = Hand('../model/hand_pose_model.pth')

    # test_image = '../images/hand.jpg'
    test_image = '../images/hand.jpg'
    oriImg = cv2.imread(test_image)  # B,G,R order
    peaks = hand_estimation(oriImg)
    canvas = util.draw_handpose(oriImg, peaks, True)
    cv2.imshow('', canvas)
    cv2.waitKey(0)

invokeai/backend/bria/controlnet_aux/open_pose/model.py (new file, 217 lines)
@@ -0,0 +1,217 @@
from collections import OrderedDict

import torch
import torch.nn as nn


def make_layers(block, no_relu_layers):
    layers = []
    for layer_name, v in block.items():
        if 'pool' in layer_name:
            layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
            layers.append((layer_name, layer))
        else:
            conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], kernel_size=v[2], stride=v[3], padding=v[4])
            layers.append((layer_name, conv2d))
            if layer_name not in no_relu_layers:
                layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))

    return nn.Sequential(OrderedDict(layers))


class bodypose_model(nn.Module):
    def __init__(self):
        super(bodypose_model, self).__init__()

        # these layers have no relu layer
        no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
                          'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
                          'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
                          'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1']
        blocks = {}
        block0 = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3_CPM', [512, 256, 3, 1, 1]),
            ('conv4_4_CPM', [256, 128, 3, 1, 1])
        ])

        # Stage 1
        block1_1 = OrderedDict([
            ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
            ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
        ])

        block1_2 = OrderedDict([
            ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
            ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
        ])
        blocks['block1_1'] = block1_1
        blocks['block1_2'] = block1_2

        self.model0 = make_layers(block0, no_relu_layers)

        # Stages 2 - 6
        for i in range(2, 7):
            blocks['block%d_1' % i] = OrderedDict([
                ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
                ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
            ])

            blocks['block%d_2' % i] = OrderedDict([
                ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
                ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
            ])

        for k in blocks.keys():
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        self.model1_1 = blocks['block1_1']
        self.model2_1 = blocks['block2_1']
        self.model3_1 = blocks['block3_1']
        self.model4_1 = blocks['block4_1']
        self.model5_1 = blocks['block5_1']
        self.model6_1 = blocks['block6_1']

        self.model1_2 = blocks['block1_2']
        self.model2_2 = blocks['block2_2']
        self.model3_2 = blocks['block3_2']
        self.model4_2 = blocks['block4_2']
        self.model5_2 = blocks['block5_2']
        self.model6_2 = blocks['block6_2']

    def forward(self, x):

        out1 = self.model0(x)

        out1_1 = self.model1_1(out1)
        out1_2 = self.model1_2(out1)
        out2 = torch.cat([out1_1, out1_2, out1], 1)

        out2_1 = self.model2_1(out2)
        out2_2 = self.model2_2(out2)
        out3 = torch.cat([out2_1, out2_2, out1], 1)

        out3_1 = self.model3_1(out3)
        out3_2 = self.model3_2(out3)
        out4 = torch.cat([out3_1, out3_2, out1], 1)

        out4_1 = self.model4_1(out4)
        out4_2 = self.model4_2(out4)
        out5 = torch.cat([out4_1, out4_2, out1], 1)

        out5_1 = self.model5_1(out5)
        out5_2 = self.model5_2(out5)
        out6 = torch.cat([out5_1, out5_2, out1], 1)

        out6_1 = self.model6_1(out6)
        out6_2 = self.model6_2(out6)

        return out6_1, out6_2


class handpose_model(nn.Module):
    def __init__(self):
        super(handpose_model, self).__init__()

        # these layers have no relu layer
        no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
                          'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
        # stage 1
        block1_0 = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3', [512, 512, 3, 1, 1]),
            ('conv4_4', [512, 512, 3, 1, 1]),
            ('conv5_1', [512, 512, 3, 1, 1]),
            ('conv5_2', [512, 512, 3, 1, 1]),
            ('conv5_3_CPM', [512, 128, 3, 1, 1])
        ])

        block1_1 = OrderedDict([
            ('conv6_1_CPM', [128, 512, 1, 1, 0]),
            ('conv6_2_CPM', [512, 22, 1, 1, 0])
        ])

        blocks = {}
        blocks['block1_0'] = block1_0
        blocks['block1_1'] = block1_1

        # stage 2-6
        for i in range(2, 7):
            blocks['block%d' % i] = OrderedDict([
                ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
                ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
            ])

        for k in blocks.keys():
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        self.model1_0 = blocks['block1_0']
        self.model1_1 = blocks['block1_1']
        self.model2 = blocks['block2']
        self.model3 = blocks['block3']
        self.model4 = blocks['block4']
        self.model5 = blocks['block5']
        self.model6 = blocks['block6']

    def forward(self, x):
        out1_0 = self.model1_0(x)
        out1_1 = self.model1_1(out1_0)
        concat_stage2 = torch.cat([out1_1, out1_0], 1)
        out_stage2 = self.model2(concat_stage2)
        concat_stage3 = torch.cat([out_stage2, out1_0], 1)
        out_stage3 = self.model3(concat_stage3)
        concat_stage4 = torch.cat([out_stage3, out1_0], 1)
        out_stage4 = self.model4(concat_stage4)
        concat_stage5 = torch.cat([out_stage4, out1_0], 1)
        out_stage5 = self.model5(concat_stage5)
        concat_stage6 = torch.cat([out_stage5, out1_0], 1)
        out_stage6 = self.model6(concat_stage6)
        return out_stage6
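
To confirm the wiring of the two-branch body network and the single-branch hand network, a quick shape check on dummy input can help; both downsample by 8 through three pooling layers. A sketch with untrained random weights, not part of the diff:

# Output shape check for bodypose_model and handpose_model (illustrative only).
import torch

from invokeai.backend.bria.controlnet_aux.open_pose.model import bodypose_model, handpose_model

body = bodypose_model().eval()
hand = handpose_model().eval()
with torch.no_grad():
    pafs, heatmaps = body(torch.zeros(1, 3, 368, 368))
    hand_maps = hand(torch.zeros(1, 3, 368, 368))
assert pafs.shape == (1, 38, 46, 46)       # part affinity fields (19 limbs x 2)
assert heatmaps.shape == (1, 19, 46, 46)   # 18 body joints + background
assert hand_maps.shape == (1, 22, 46, 46)  # 21 hand joints + background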
388
invokeai/backend/bria/controlnet_aux/open_pose/util.py
Normal file
388
invokeai/backend/bria/controlnet_aux/open_pose/util.py
Normal file
@@ -0,0 +1,388 @@
|
||||
import math
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from invokeai.backend.bria.controlnet_aux.open_pose.body import BodyResult, Keypoint
|
||||
|
||||
eps = 0.01
|
||||
|
||||
|
||||
def smart_resize(x, s):
|
||||
Ht, Wt = s
|
||||
if x.ndim == 2:
|
||||
Ho, Wo = x.shape
|
||||
Co = 1
|
||||
else:
|
||||
Ho, Wo, Co = x.shape
|
||||
if Co == 3 or Co == 1:
|
||||
k = float(Ht + Wt) / float(Ho + Wo)
|
||||
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
|
||||
else:
|
||||
return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
|
||||
|
||||
|
||||
def smart_resize_k(x, fx, fy):
|
||||
if x.ndim == 2:
|
||||
Ho, Wo = x.shape
|
||||
Co = 1
|
||||
else:
|
||||
Ho, Wo, Co = x.shape
|
||||
Ht, Wt = Ho * fy, Wo * fx
|
||||
if Co == 3 or Co == 1:
|
||||
k = float(Ht + Wt) / float(Ho + Wo)
|
||||
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
|
||||
else:
|
||||
return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
|
||||
|
||||
|
||||
def padRightDownCorner(img, stride, padValue):
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
|
||||
pad = 4 * [None]
|
||||
pad[0] = 0 # up
|
||||
pad[1] = 0 # left
|
||||
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
|
||||
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
|
||||
|
||||
img_padded = img
|
||||
pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
|
||||
img_padded = np.concatenate((pad_up, img_padded), axis=0)
|
||||
pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
|
||||
img_padded = np.concatenate((pad_left, img_padded), axis=1)
|
||||
pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
|
||||
img_padded = np.concatenate((img_padded, pad_down), axis=0)
|
||||
pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
|
||||
img_padded = np.concatenate((img_padded, pad_right), axis=1)
|
||||
|
||||
return img_padded, pad
|
||||
|
||||
|
||||
def transfer(model, model_weights):
|
||||
transfered_model_weights = {}
|
||||
for weights_name in model.state_dict().keys():
|
||||
transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
|
||||
return transfered_model_weights
|
||||
|
||||
|
||||
def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray:
|
||||
"""
|
||||
Draw keypoints and limbs representing body pose on a given canvas.
|
||||
|
||||
Args:
|
||||
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose.
|
||||
keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn.
|
||||
|
||||
Returns:
|
||||
np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose.
|
||||
|
||||
Note:
|
||||
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
|
||||
"""
|
||||
H, W, C = canvas.shape
|
||||
stickwidth = 4
|
||||
|
||||
limbSeq = [
|
||||
[2, 3], [2, 6], [3, 4], [4, 5],
|
||||
[6, 7], [7, 8], [2, 9], [9, 10],
|
||||
[10, 11], [2, 12], [12, 13], [13, 14],
|
||||
[2, 1], [1, 15], [15, 17], [1, 16],
|
||||
[16, 18],
|
||||
]
|
||||
|
||||
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
|
||||
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
|
||||
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
|
||||
|
||||
for (k1_index, k2_index), color in zip(limbSeq, colors, strict=False):
|
||||
keypoint1 = keypoints[k1_index - 1]
|
||||
keypoint2 = keypoints[k2_index - 1]
|
||||
|
||||
if keypoint1 is None or keypoint2 is None:
|
||||
continue
|
||||
|
||||
Y = np.array([keypoint1.x, keypoint2.x]) * float(W)
|
||||
X = np.array([keypoint1.y, keypoint2.y]) * float(H)
|
||||
mX = np.mean(X)
|
||||
mY = np.mean(Y)
|
||||
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
|
||||
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
||||
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
||||
cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color])
|
||||
|
||||
for keypoint, color in zip(keypoints, colors, strict=False):
|
||||
if keypoint is None:
|
||||
continue
|
||||
|
||||
x, y = keypoint.x, keypoint.y
|
||||
x = int(x * W)
|
||||
y = int(y * H)
|
||||
cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1)
|
||||
|
||||
return canvas
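# Illustrative usage sketch (not part of the original diff), assuming `Keypoint` carries
# normalized x/y coordinates as the docstring states and the 18-entry OpenPose body order:
#   canvas = np.zeros((512, 512, 3), dtype=np.uint8)
#   keypoints = detected_body.keypoints  # hypothetical BodyResult from the pose detector
#   canvas = draw_bodypose(canvas, keypoints)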
|
||||
|
||||
|
||||
def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
|
||||
|
||||
"""
|
||||
Draw keypoints and connections representing hand pose on a given canvas.
|
||||
|
||||
Args:
|
||||
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
|
||||
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
|
||||
or None if no keypoints are present.
|
||||
|
||||
Returns:
|
||||
np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.
|
||||
|
||||
Note:
|
||||
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
|
||||
"""
|
||||
if not keypoints:
|
||||
return canvas
|
||||
|
||||
H, W, C = canvas.shape
|
||||
|
||||
edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
|
||||
[10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
|
||||
|
||||
for ie, (e1, e2) in enumerate(edges):
|
||||
k1 = keypoints[e1]
|
||||
k2 = keypoints[e2]
|
||||
if k1 is None or k2 is None:
|
||||
continue
|
||||
|
||||
x1 = int(k1.x * W)
|
||||
y1 = int(k1.y * H)
|
||||
x2 = int(k2.x * W)
|
||||
y2 = int(k2.y * H)
|
||||
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
|
||||
cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
|
||||
|
||||
for keypoint in keypoints:
|
||||
x, y = keypoint.x, keypoint.y
|
||||
x = int(x * W)
|
||||
y = int(y * H)
|
||||
if x > eps and y > eps:
|
||||
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
|
||||
return canvas
|
||||
|
||||
|
||||
def draw_facepose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
|
||||
"""
|
||||
Draw keypoints representing face pose on a given canvas.
|
||||
|
||||
Args:
|
||||
canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the face pose.
|
||||
keypoints (List[Keypoint]| None): A list of Keypoint objects representing the face keypoints to be drawn
|
||||
or None if no keypoints are present.
|
||||
|
||||
Returns:
|
||||
np.ndarray: A 3D numpy array representing the modified canvas with the drawn face pose.
|
||||
|
||||
Note:
|
||||
The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
|
||||
"""
|
||||
if not keypoints:
|
||||
return canvas
|
||||
|
||||
H, W, C = canvas.shape
|
||||
for keypoint in keypoints:
|
||||
x, y = keypoint.x, keypoint.y
|
||||
x = int(x * W)
|
||||
y = int(y * H)
|
||||
if x > eps and y > eps:
|
||||
cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
|
||||
return canvas
|
||||
|
||||
|
||||
# detect hand according to body pose keypoints
|
||||
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
|
||||
def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]:
|
||||
"""
|
||||
Detect hands in the input body pose keypoints and calculate the bounding box for each hand.
|
||||
|
||||
Args:
|
||||
body (BodyResult): A BodyResult object containing the detected body pose keypoints.
|
||||
oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, int, int, bool]]: A list of tuples, each containing the coordinates (x, y) of the top-left
|
||||
corner of the bounding box, the width (height) of the bounding box, and
|
||||
a boolean flag indicating whether the hand is a left hand (True) or a
|
||||
right hand (False).
|
||||
|
||||
Notes:
|
||||
- The width and height of the bounding boxes are equal since the network requires squared input.
|
||||
- The minimum bounding box size is 20 pixels.
|
||||
"""
|
||||
ratioWristElbow = 0.33
|
||||
detect_result = []
|
||||
image_height, image_width = oriImg.shape[0:2]
|
||||
|
||||
keypoints = body.keypoints
|
||||
# right hand: wrist 4, elbow 3, shoulder 2
|
||||
# left hand: wrist 7, elbow 6, shoulder 5
|
||||
left_shoulder = keypoints[5]
|
||||
left_elbow = keypoints[6]
|
||||
left_wrist = keypoints[7]
|
||||
right_shoulder = keypoints[2]
|
||||
right_elbow = keypoints[3]
|
||||
right_wrist = keypoints[4]
|
||||
|
||||
# if any of three not detected
|
||||
has_left = all(keypoint is not None for keypoint in (left_shoulder, left_elbow, left_wrist))
|
||||
has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist))
|
||||
if not (has_left or has_right):
|
||||
return []
|
||||
|
||||
hands = []
|
||||
#left hand
|
||||
if has_left:
|
||||
hands.append([
|
||||
left_shoulder.x, left_shoulder.y,
|
||||
left_elbow.x, left_elbow.y,
|
||||
left_wrist.x, left_wrist.y,
|
||||
True
|
||||
])
|
||||
# right hand
|
||||
if has_right:
|
||||
hands.append([
|
||||
right_shoulder.x, right_shoulder.y,
|
||||
right_elbow.x, right_elbow.y,
|
||||
right_wrist.x, right_wrist.y,
|
||||
False
|
||||
])
|
||||
|
||||
for x1, y1, x2, y2, x3, y3, is_left in hands:
|
||||
# pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbow) = (1 + ratio) * pos_wrist - ratio * pos_elbow
|
||||
# handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
|
||||
# handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
|
||||
# const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
|
||||
# const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
|
||||
# handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
|
||||
x = x3 + ratioWristElbow * (x3 - x2)
|
||||
y = y3 + ratioWristElbow * (y3 - y2)
|
||||
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
|
||||
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
|
||||
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
|
||||
# x-y refers to the center --> offset to topLeft point
|
||||
# handRectangle.x -= handRectangle.width / 2.f;
|
||||
# handRectangle.y -= handRectangle.height / 2.f;
|
||||
x -= width / 2
|
||||
y -= width / 2 # width = height
|
||||
# overflow the image
|
||||
if x < 0:
|
||||
x = 0
|
||||
if y < 0:
|
||||
y = 0
|
||||
width1 = width
|
||||
width2 = width
|
||||
if x + width > image_width:
|
||||
width1 = image_width - x
|
||||
if y + width > image_height:
|
||||
width2 = image_height - y
|
||||
width = min(width1, width2)
|
||||
# the minimum hand box width is 20 pixels
|
||||
if width >= 20:
|
||||
detect_result.append((int(x), int(y), int(width), is_left))
|
||||
|
||||
'''
return value: [[x, y, w, True if left hand else False]].
width == height since the network requires a square input.
x, y are the coordinates of the top-left corner.
'''
|
||||
return detect_result
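# Illustrative note (not part of the original diff): following the OpenPose heuristic above,
# the hand box is centered at a point extrapolated past the wrist along the elbow -> wrist
# direction by ratioWristElbow (0.33) of that segment, and its side length is
#   1.5 * max(dist(wrist, elbow), 0.9 * dist(elbow, shoulder)),
# clamped to the image borders; boxes narrower than 20 pixels are discarded.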
|
||||
|
||||
|
||||
# Written by Lvmin
|
||||
def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]:
|
||||
"""
|
||||
Detect the face in the input body pose keypoints and calculate the bounding box for the face.
|
||||
|
||||
Args:
|
||||
body (BodyResult): A BodyResult object containing the detected body pose keypoints.
|
||||
oriImg (numpy.ndarray): A 3D numpy array representing the original input image.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int, int] | None: A tuple containing the coordinates (x, y) of the top-left corner of the
|
||||
bounding box and the width (height) of the bounding box, or None if the
|
||||
face is not detected or the bounding box width is less than 20 pixels.
|
||||
|
||||
Notes:
|
||||
- The width and height of the bounding box are equal.
|
||||
- The minimum bounding box size is 20 pixels.
|
||||
"""
|
||||
# left right eye ear 14 15 16 17
|
||||
image_height, image_width = oriImg.shape[0:2]
|
||||
|
||||
keypoints = body.keypoints
|
||||
head = keypoints[0]
|
||||
left_eye = keypoints[14]
|
||||
right_eye = keypoints[15]
|
||||
left_ear = keypoints[16]
|
||||
right_ear = keypoints[17]
|
||||
|
||||
if head is None or all(keypoint is None for keypoint in (left_eye, right_eye, left_ear, right_ear)):
|
||||
return None
|
||||
|
||||
width = 0.0
|
||||
x0, y0 = head.x, head.y
|
||||
|
||||
if left_eye is not None:
|
||||
x1, y1 = left_eye.x, left_eye.y
|
||||
d = max(abs(x0 - x1), abs(y0 - y1))
|
||||
width = max(width, d * 3.0)
|
||||
|
||||
if right_eye is not None:
|
||||
x1, y1 = right_eye.x, right_eye.y
|
||||
d = max(abs(x0 - x1), abs(y0 - y1))
|
||||
width = max(width, d * 3.0)
|
||||
|
||||
if left_ear is not None:
|
||||
x1, y1 = left_ear.x, left_ear.y
|
||||
d = max(abs(x0 - x1), abs(y0 - y1))
|
||||
width = max(width, d * 1.5)
|
||||
|
||||
if right_ear is not None:
|
||||
x1, y1 = right_ear.x, right_ear.y
|
||||
d = max(abs(x0 - x1), abs(y0 - y1))
|
||||
width = max(width, d * 1.5)
|
||||
|
||||
x, y = x0, y0
|
||||
|
||||
x -= width
|
||||
y -= width
|
||||
|
||||
if x < 0:
|
||||
x = 0
|
||||
|
||||
if y < 0:
|
||||
y = 0
|
||||
|
||||
width1 = width * 2
|
||||
width2 = width * 2
|
||||
|
||||
if x + width > image_width:
|
||||
width1 = image_width - x
|
||||
|
||||
if y + width > image_height:
|
||||
width2 = image_height - y
|
||||
|
||||
width = min(width1, width2)
|
||||
|
||||
if width >= 20:
|
||||
return int(x), int(y), int(width)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
# get the (row, column) index of the maximum value of a 2D array
def npmax(array):
    arrayindex = array.argmax(1)
    arrayvalue = array.max(1)
    i = arrayvalue.argmax()
    j = arrayindex[i]
    return i, j
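# Illustrative example (not part of the original diff):
#   heatmap = np.array([[0.1, 0.3], [0.9, 0.2]])
#   npmax(heatmap)  # -> (1, 0): row 1, column 0 holds the global maximum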
|
||||
invokeai/backend/bria/controlnet_aux/util.py (new file, 146 lines added)
@@ -0,0 +1,146 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
|
||||
|
||||
|
||||
def HWC3(x):
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y
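# Illustrative examples (not part of the original diff):
#   HWC3(gray)  # (H, W) uint8    -> (H, W, 3), the single channel replicated
#   HWC3(rgba)  # (H, W, 4) uint8 -> (H, W, 3), alpha-composited over white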
|
||||
|
||||
|
||||
def make_noise_disk(H, W, C, F):
|
||||
noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
|
||||
noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
|
||||
noise = noise[F: F + H, F: F + W]
|
||||
noise -= np.min(noise)
|
||||
noise /= np.max(noise)
|
||||
if C == 1:
|
||||
noise = noise[:, :, None]
|
||||
return noise
|
||||
|
||||
|
||||
def nms(x, t, s):
|
||||
x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
|
||||
|
||||
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
|
||||
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
|
||||
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
|
||||
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
|
||||
|
||||
y = np.zeros_like(x)
|
||||
|
||||
for f in [f1, f2, f3, f4]:
|
||||
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
|
||||
|
||||
z = np.zeros_like(y, dtype=np.uint8)
|
||||
z[y > t] = 255
|
||||
return z
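# Illustrative note (not part of the original diff): `nms(x, t, s)` blurs the edge map `x`
# with a Gaussian of sigma `s`, keeps only pixels that are local maxima along one of four
# directions (horizontal, vertical, both diagonals), and binarizes the result with
# threshold `t` into a uint8 mask of 0/255 values.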
|
||||
|
||||
def min_max_norm(x):
|
||||
x -= np.min(x)
|
||||
x /= np.maximum(np.max(x), 1e-5)
|
||||
return x
|
||||
|
||||
|
||||
def safe_step(x, step=2):
|
||||
y = x.astype(np.float32) * float(step + 1)
|
||||
y = y.astype(np.int32).astype(np.float32) / float(step)
|
||||
return y
|
||||
|
||||
|
||||
def img2mask(img, H, W, low=10, high=90):
|
||||
assert img.ndim == 3 or img.ndim == 2
|
||||
assert img.dtype == np.uint8
|
||||
|
||||
if img.ndim == 3:
|
||||
y = img[:, :, random.randrange(0, img.shape[2])]
|
||||
else:
|
||||
y = img
|
||||
|
||||
y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
|
||||
|
||||
if random.uniform(0, 1) < 0.5:
|
||||
y = 255 - y
|
||||
|
||||
return y < np.percentile(y, random.randrange(low, high))
|
||||
|
||||
|
||||
def resize_image(input_image, resolution):
|
||||
H, W, C = input_image.shape
|
||||
H = float(H)
|
||||
W = float(W)
|
||||
k = float(resolution) / min(H, W)
|
||||
H *= k
|
||||
W *= k
|
||||
H = int(np.round(H / 64.0)) * 64
|
||||
W = int(np.round(W / 64.0)) * 64
|
||||
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
||||
return img
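# Illustrative example (not part of the original diff): resize_image(img, 512) on a
# 768x1024 (HxW) input gives k = 512/768, a 512x682.67 target, which is rounded to the
# nearest multiple of 64, so the output is 512x704 (HxW).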
|
||||
|
||||
|
||||
def torch_gc():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def ade_palette():
|
||||
"""ADE20K palette that maps each class to RGB values."""
|
||||
return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
|
||||
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
|
||||
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
|
||||
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
|
||||
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
|
||||
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
|
||||
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
|
||||
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
|
||||
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
|
||||
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
|
||||
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
|
||||
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
|
||||
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
|
||||
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
|
||||
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
|
||||
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
|
||||
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
|
||||
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
|
||||
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
|
||||
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
|
||||
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
|
||||
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
|
||||
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
|
||||
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
|
||||
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
|
||||
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
|
||||
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
|
||||
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
|
||||
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
|
||||
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
|
||||
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
|
||||
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
|
||||
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
|
||||
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
|
||||
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
|
||||
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
|
||||
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
|
||||
[102, 255, 0], [92, 0, 255]]
|
||||
|
||||
invokeai/backend/bria/controlnet_bria.py (new file, 547 lines added)
@@ -0,0 +1,547 @@
|
||||
# type: ignore
|
||||
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import PeftAdapterMixin
|
||||
from diffusers.models.attention_processor import AttentionProcessor
|
||||
from diffusers.models.controlnet import zero_module
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from diffusers.utils.outputs import BaseOutput
|
||||
|
||||
from invokeai.backend.bria.transformer_bria import (
|
||||
EmbedND,
|
||||
FluxSingleTransformerBlock,
|
||||
FluxTransformerBlock,
|
||||
TimestepProjEmbeddings,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
BRIA_CONTROL_MODES = Literal["depth", "canny", "colorgrid", "recolor", "tile", "pose"]
|
||||
class BriaControlModes(Enum):
|
||||
depth = 0
|
||||
canny = 1
|
||||
colorgrid = 2
|
||||
recolor = 3
|
||||
tile = 4
|
||||
pose = 5
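# Illustrative note (not part of the original diff): each enum value appears to double as
# the integer mode id passed as `controlnet_mode` to the union ControlNet, e.g.
#   BriaControlModes["depth"].value == 0 and BriaControlModes["pose"].value == 5.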
|
||||
|
||||
|
||||
@dataclass
|
||||
class BriaControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_single_block_samples: Tuple[torch.Tensor]
|
||||
|
||||
|
||||
class BriaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Optional[List[int]] = None,
|
||||
num_mode: Optional[int] = None,
|
||||
rope_theta: int = 10000,
|
||||
time_theta: int = 10000,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
||||
axes_dims_rope = [16, 56, 56] if axes_dims_rope is None else axes_dims_rope
|
||||
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
|
||||
|
||||
# text_time_guidance_cls = (
|
||||
# CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
||||
# )
|
||||
# self.time_text_embed = text_time_guidance_cls(
|
||||
# embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
||||
# )
|
||||
self.time_embed = TimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim, time_theta=time_theta
|
||||
)
|
||||
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
||||
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for i in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
|
||||
|
||||
self.controlnet_single_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.single_transformer_blocks)):
|
||||
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
|
||||
|
||||
self.union = num_mode is not None and num_mode > 0
|
||||
if self.union:
|
||||
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
|
||||
|
||||
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self):
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@classmethod
|
||||
def from_transformer(
|
||||
cls,
|
||||
transformer,
|
||||
num_layers: int = 4,
|
||||
num_single_layers: int = 10,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
load_weights_from_transformer=True,
|
||||
):
|
||||
config = transformer.config
|
||||
config["num_layers"] = num_layers
|
||||
config["num_single_layers"] = num_single_layers
|
||||
config["attention_head_dim"] = attention_head_dim
|
||||
config["num_attention_heads"] = num_attention_heads
|
||||
|
||||
controlnet = cls(**config)
|
||||
|
||||
if load_weights_from_transformer:
|
||||
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
|
||||
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
|
||||
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
|
||||
controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
|
||||
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
|
||||
controlnet.single_transformer_blocks.load_state_dict(
|
||||
transformer.single_transformer_blocks.state_dict(), strict=False
|
||||
)
|
||||
|
||||
controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
|
||||
|
||||
return controlnet
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
controlnet_mode: torch.Tensor = None,
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
controlnet_cond (`torch.Tensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
controlnet_mode (`torch.Tensor`):
|
||||
The mode tensor of shape `(batch_size, 1)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if guidance is not None:
|
||||
print("guidance is not supported in BriaControlNetModel")
|
||||
if pooled_projections is not None:
|
||||
print("pooled_projections is not supported in BriaControlNetModel")
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
# Convert controlnet_cond to the same dtype as the model weights
|
||||
controlnet_cond = controlnet_cond.to(dtype=self.controlnet_x_embedder.weight.dtype)
|
||||
|
||||
# add
|
||||
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype) # Original code was * 1000
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) # Original code was * 1000
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
|
||||
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if txt_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
txt_ids = txt_ids[0]
|
||||
if img_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
img_ids = img_ids[0]
|
||||
|
||||
if self.union:
|
||||
# union mode
|
||||
if controlnet_mode is None:
|
||||
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
|
||||
|
||||
# Validate controlnet_mode values are within the valid range
|
||||
if torch.any(controlnet_mode < 0) or torch.any(controlnet_mode >= self.num_mode):
|
||||
raise ValueError(f"`controlnet_mode` values must be in range [0, {self.num_mode-1}], but got values outside this range")
|
||||
|
||||
# union mode emb
|
||||
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
|
||||
if controlnet_mode_emb.shape[0] < encoder_hidden_states.shape[0]: # duplicate mode emb for each batch
|
||||
controlnet_mode_emb = controlnet_mode_emb.expand(encoder_hidden_states.shape[0], 1, encoder_hidden_states.shape[2])
|
||||
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
|
||||
|
||||
txt_ids = torch.cat((txt_ids[0:1, :], txt_ids), dim=0)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
block_samples = ()
|
||||
for _, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
block_samples = block_samples + (hidden_states,)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
single_block_samples = ()
|
||||
for _, block in enumerate(self.single_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
|
||||
|
||||
# controlnet block
|
||||
controlnet_block_samples = ()
|
||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks, strict=False):
|
||||
block_sample = controlnet_block(block_sample)
|
||||
controlnet_block_samples = controlnet_block_samples + (block_sample,)
|
||||
|
||||
controlnet_single_block_samples = ()
|
||||
for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks, strict=False):
|
||||
single_block_sample = controlnet_block(single_block_sample)
|
||||
controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
|
||||
|
||||
# scaling
|
||||
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
|
||||
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
|
||||
|
||||
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
|
||||
controlnet_single_block_samples = (
|
||||
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
|
||||
)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (controlnet_block_samples, controlnet_single_block_samples)
|
||||
|
||||
return BriaControlNetOutput(
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
controlnet_single_block_samples=controlnet_single_block_samples,
|
||||
)
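# Illustrative note (not part of the original diff): the returned controlnet_block_samples /
# controlnet_single_block_samples are per-block residual tensors, already multiplied by
# `conditioning_scale` above; the base transformer is expected to add them to the hidden
# states of its corresponding (single) transformer blocks during denoising.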
|
||||
|
||||
|
||||
class BriaMultiControlNetModel(ModelMixin):
|
||||
r"""
|
||||
`BriaMultiControlNetModel` wrapper class for Multi-BriaControlNetModel
|
||||
This module is a wrapper for multiple instances of the `BriaControlNetModel`. The `forward()` API is designed to be
|
||||
compatible with `BriaControlNetModel`.
|
||||
Args:
|
||||
controlnets (`List[BriaControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`BriaControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
controlnet_mode: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[BriaControlNetOutput, Tuple]:
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only load one ControlNet for saving memories
|
||||
if len(self.nets) == 1 and self.nets[0].union:
|
||||
controlnet = self.nets[0]
|
||||
|
||||
for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale, strict=False)):
|
||||
block_samples, single_block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
controlnet_cond=image,
|
||||
controlnet_mode=mode[:, None],
|
||||
conditioning_scale=scale,
|
||||
timestep=timestep,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_projections,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
txt_ids=txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples, strict=False)
|
||||
]
|
||||
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples, strict=False
|
||||
)
|
||||
]
|
||||
|
||||
# Regular Multi-ControlNets
|
||||
# load all ControlNets into memories
|
||||
else:
|
||||
for i, (image, mode, scale, controlnet) in enumerate(
|
||||
zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets, strict=False)
|
||||
):
|
||||
block_samples, single_block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
controlnet_cond=image,
|
||||
controlnet_mode=mode[:, None],
|
||||
conditioning_scale=scale,
|
||||
timestep=timestep,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_projections,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
txt_ids=txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
if block_samples is not None and control_block_samples is not None:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples, strict=False)
|
||||
]
|
||||
if single_block_samples is not None and control_single_block_samples is not None:
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples, strict=False
|
||||
)
|
||||
]
|
||||
|
||||
return control_block_samples, control_single_block_samples
|
||||
invokeai/backend/bria/controlnet_utils.py (new file, 67 lines added)
@@ -0,0 +1,67 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def prepare_control_images(
|
||||
vae: AutoencoderKL,
|
||||
control_images: list[Image.Image],
|
||||
control_modes: list[int],
|
||||
width: int,
|
||||
height: int,
|
||||
device: torch.device,
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
|
||||
tensored_control_images = []
|
||||
tensored_control_modes = []
|
||||
for idx, control_image_ in enumerate(control_images):
|
||||
tensored_control_image = _prepare_image(
|
||||
image=control_image_,
|
||||
width=width,
|
||||
height=height,
|
||||
device=device,
|
||||
dtype=vae.dtype,
|
||||
)
|
||||
height, width = tensored_control_image.shape[-2:]
|
||||
|
||||
# vae encode
|
||||
tensored_control_image = vae.encode(tensored_control_image).latent_dist.sample()
|
||||
tensored_control_image = (tensored_control_image) * vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = tensored_control_image.shape[2:]
|
||||
tensored_control_image = _pack_latents(
|
||||
tensored_control_image,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
tensored_control_images.append(tensored_control_image)
|
||||
tensored_control_modes.append(torch.tensor(control_modes[idx]).expand(
|
||||
tensored_control_image.shape[0]).to(device, dtype=torch.long))
|
||||
|
||||
return tensored_control_images, tensored_control_modes
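# Illustrative usage sketch (not part of the original diff); image and device names are hypothetical:
#   cond_latents, cond_modes = prepare_control_images(
#       vae, [depth_image], [BriaControlModes.depth.value], width=1024, height=1024, device=device
#   )
#   # each entry is a packed latent of shape (1, seq_len, 16) plus a matching mode tensor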
|
||||
|
||||
def _prepare_image(
|
||||
image: Image.Image,
|
||||
width: int,
|
||||
height: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
image = image.convert("RGB")
|
||||
image = VaeImageProcessor(vae_scale_factor=16).preprocess(image, height=height, width=width)
|
||||
image = image.repeat_interleave(1, dim=0)
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
return image
|
||||
|
||||
def _pack_latents(latents, height, width):
    latents = latents.view(1, 4, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(1, (height // 2) * (width // 2), 16)

    return latents
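# Illustrative example (not part of the original diff): a (1, 4, 64, 64) latent is viewed as
# (1, 4, 32, 2, 32, 2), permuted to (1, 32, 32, 4, 2, 2) and reshaped to (1, 1024, 16),
# i.e. one 16-dim token per 2x2 latent patch.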
|
||||
|
||||
invokeai/backend/bria/pipeline_bria.py (new file, 640 lines added)
@@ -0,0 +1,640 @@
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FluxLoraLoaderMixin
|
||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps
|
||||
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from transformers import (
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from invokeai.backend.bria.bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none
|
||||
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableDiffusion3Pipeline
|
||||
|
||||
>>> pipe = StableDiffusion3Pipeline.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> image = pipe(prompt).images[0]
|
||||
>>> image.save("sd3.png")
|
||||
```
|
||||
"""
|
||||
|
||||
T5_PRECISION = torch.float16
|
||||
|
||||
"""
|
||||
Based on FluxPipeline with several changes:
|
||||
- no pooled embeddings
|
||||
- We use zero padding for prompts
|
||||
- No guidance embedding since this is not a distilled version
|
||||
"""
|
||||
class BriaPipeline(FluxPipeline):
|
||||
r"""
|
||||
Args:
|
||||
transformer ([`SD3Transformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Stable Diffusion 3 uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: BriaTransformer2DModel,
|
||||
scheduler: Union[FlowMatchEulerDiscreteScheduler,KarrasDiffusionSchedulers],
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast
|
||||
):
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# TODO - why is this different from official flux (-1)?
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.default_sample_size = 64  # due to patchify => 128x128 latent => 1024x1024 output resolution
|
||||
|
||||
# T5 is sensitive to precision, so we use the precision used for precompute and cast as needed
|
||||
|
||||
if self.vae.config.shift_factor is None:
|
||||
self.vae.config.shift_factor=0
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 128,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = get_t5_prompt_embeds(
|
||||
self.tokenizer,
|
||||
self.text_encoder,
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
).to(dtype=self.transformer.dtype)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
if not is_ng_none(negative_prompt):
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = get_t5_prompt_embeds(
|
||||
self.tokenizer,
|
||||
self.text_encoder,
|
||||
prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
).to(dtype=self.transformer.dtype)
|
||||
else:
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
||||
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, text_ids
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 30,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
||||
max_sequence_length: int = 128,
|
||||
clip_value:Union[None,float] = None,
|
||||
normalize:bool = False
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 30):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, *optional*, defaults to 128): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
||||
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
||||
images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
callback_on_step_end_tensor_inputs = ["latents"] if callback_on_step_end_tensor_inputs is None else callback_on_step_end_tensor_inputs
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
prompt_embeds=prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
text_ids
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4  # due to patch=2, we divide by 4
|
||||
latents, latent_image_ids = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']:
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
image_seq_len = latents.shape[1]  # packed latent sequence length ((h/2) * (w/2) tokens)
|
||||
print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
|
||||
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
self.scheduler.config.max_image_seq_len,
|
||||
self.scheduler.config.base_shift,
|
||||
self.scheduler.config.max_shift,
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
else:
|
||||
# 4. Prepare timesteps
|
||||
# Sample from training sigmas
|
||||
if isinstance(self.scheduler,DDIMScheduler) or isinstance(self.scheduler,EulerAncestralDiscreteScheduler):
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, None)
|
||||
else:
|
||||
sigmas = get_original_sigmas(num_train_timesteps=self.scheduler.config.num_train_timesteps,num_inference_steps=num_inference_steps)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps,sigmas=sigmas)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# Support different diffusers versions
|
||||
if diffusers.__version__>='0.32.0':
|
||||
latent_image_ids=latent_image_ids[0]
|
||||
text_ids=text_ids[0]
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# This predicts "v" for flow-matching or eps for diffusion
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
cfg_noise_pred_text = noise_pred_text.std()
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
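# The optional `normalize` branch below rescales the guided prediction toward the std of the
# text-conditioned prediction (a 0.7/0.3 blend), akin to guidance-rescale tricks used to counter
# over-saturation from CFG.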
|
||||
|
||||
if normalize:
|
||||
noise_pred = noise_pred * (0.7 *(cfg_noise_pred_text/noise_pred.std())) + 0.3 * noise_pred
|
||||
|
||||
if clip_value:
|
||||
assert clip_value>0
|
||||
noise_pred = noise_pred.clip(-clip_value,clip_value)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return FluxPipelineOutput(images=image)
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
DiffusionPipeline.to(self, *args, **kwargs)
|
||||
# T5 is sensitive to precision so we use the precision used for precompute and cast as needed
|
||||
self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
|
||||
for block in self.text_encoder.encoder.block:
|
||||
block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
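# Keeping each block's final feed-forward output projection (wo) in float32 is a common T5
# stability trick: it helps avoid activation overflow when the rest of the encoder runs in half precision.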
|
||||
|
||||
if self.vae.config.shift_factor == 0 and self.vae.dtype!=torch.float32:
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // self.vae_scale_factor)
|
||||
width = 2 * (int(width) // self.vae_scale_factor )
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
||||
|
||||
return latents
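# Shape walk-through (illustrative numbers): for a 1024x1024 image with vae_scale_factor=16 and
# num_channels_latents=16, prepare_latents builds a (1, 16, 128, 128) latent; _pack_latents groups
# each 2x2 patch into one token, giving (1, 64*64, 64) = (1, 4096, 64); _unpack_latents reverses the
# packing back to (1, 16, 128, 128) before VAE decoding.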
|
||||
|
||||
@staticmethod
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
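For orientation, a minimal text-to-image sketch of how this pipeline can be driven. The checkpoint path and device below are placeholders, not part of this diff:

import torch

from invokeai.backend.bria.pipeline_bria import BriaPipeline

# Hypothetical local Bria diffusers checkpoint; substitute a real path.
pipe = BriaPipeline.from_pretrained("./models/bria-3.2", torch_dtype=torch.bfloat16)
pipe.to("cuda")

result = pipe(
    prompt="a red bicycle leaning against a whitewashed wall",
    negative_prompt="blurry, low quality",
    num_inference_steps=30,
    guidance_scale=5.0,
    generator=torch.Generator("cuda").manual_seed(0),
)
result.images[0].save("bria_sample.png")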
invokeai/backend/bria/pipeline_bria_controlnet.py (new file, 666 lines)
@@ -0,0 +1,666 @@
|
||||
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import AutoencoderKL  # Waiting for diffusers update
|
||||
from diffusers.image_processor import PipelineImageInput
|
||||
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
|
||||
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
|
||||
from diffusers.utils import USE_PEFT_BACKEND, logging
|
||||
from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from transformers import (
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from invokeai.backend.bria.bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none
|
||||
from invokeai.backend.bria.controlnet_bria import BriaControlNetModel
|
||||
from invokeai.backend.bria.pipeline_bria import BriaPipeline
|
||||
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class BriaControlNetPipeline(BriaPipeline):
|
||||
r"""
|
||||
Args:
|
||||
transformer ([`BriaTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Bria uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
|
||||
|
||||
def __init__( # EYAL - removed clip text encoder + tokenizer
|
||||
self,
|
||||
transformer: BriaTransformer2DModel,
|
||||
scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
controlnet: BriaControlNetModel,
|
||||
):
|
||||
super().__init__(
|
||||
transformer=transformer, scheduler=scheduler, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer
|
||||
)
|
||||
self.register_modules(controlnet=controlnet)
|
||||
|
||||
def prepare_image(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
):
|
||||
if isinstance(image, torch.Tensor):
|
||||
pass
|
||||
else:
|
||||
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance and not guess_mode:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
|
||||
|
||||
def prepare_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode):
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
control_image = self.prepare_image(
|
||||
image=control_image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=self.vae.dtype,
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
|
||||
# vae encode
|
||||
control_image = self.vae.encode(control_image).latent_dist.sample()
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
|
||||
# Here we ensure that `control_mode` has the same length as the control_image.
|
||||
if control_mode is not None:
|
||||
if not isinstance(control_mode, int):
|
||||
raise ValueError(" For `BriaControlNet`, `control_mode` should be an `int` or `None`")
|
||||
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
|
||||
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
|
||||
|
||||
return control_image, control_mode
|
||||
|
||||
def prepare_multi_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode):
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
control_images = []
|
||||
for _, control_image_ in enumerate(control_image):
|
||||
control_image_ = self.prepare_image(
|
||||
image=control_image_,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=self.vae.dtype,
|
||||
)
|
||||
height, width = control_image_.shape[-2:]
|
||||
|
||||
# vae encode
|
||||
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
|
||||
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image_.shape[2:]
|
||||
control_image_ = self._pack_latents(
|
||||
control_image_,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
control_images.append(control_image_)
|
||||
|
||||
control_image = control_images
|
||||
|
||||
# Here we ensure that `control_mode` has the same length as the control_image.
|
||||
if isinstance(control_mode, list) and len(control_mode) != len(control_image):
|
||||
raise ValueError(
|
||||
"For Multi-ControlNet, `control_mode` must be a list of the same "
|
||||
+ " length as the number of controlnets (control images) specified"
|
||||
)
|
||||
if not isinstance(control_mode, list):
|
||||
control_mode = [control_mode] * len(control_image)
|
||||
# set control mode
|
||||
control_modes = []
|
||||
for cmode in control_mode:
|
||||
if cmode is None:
|
||||
cmode = -1
|
||||
control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
|
||||
control_modes.append(control_mode)
|
||||
control_mode = control_modes
|
||||
|
||||
return control_image, control_mode
|
||||
|
||||
def get_controlnet_keep(self, timesteps, control_guidance_start, control_guidance_end):
|
||||
controlnet_keep = []
|
||||
for i in range(len(timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end, strict=False)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, BriaControlNetModel) else keeps)
|
||||
return controlnet_keep
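# Example (assumed values): with 10 timesteps, control_guidance_start=[0.0] and
# control_guidance_end=[0.5], the returned keep values are 1.0 for the first five steps and 0.0
# for the rest, so the ControlNet residuals are only applied during the early denoising steps.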
|
||||
|
||||
def get_control_start_end(self, control_guidance_start, control_guidance_end):
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = 1 # TODO - why is this 1?
|
||||
control_guidance_start, control_guidance_end = (
|
||||
mult * [control_guidance_start],
|
||||
mult * [control_guidance_end],
|
||||
)
|
||||
|
||||
return control_guidance_start, control_guidance_end
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 30,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 3.5,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
control_image: Optional[PipelineImageInput] = None,
|
||||
control_mode: Optional[Union[int, List[int]]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
latent_image_ids: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
text_ids: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
||||
max_sequence_length: int = 128,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 30):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, *optional*, defaults to 128): Maximum sequence length to use with the `prompt`.
|
||||
Examples:
|
||||
Returns:
|
||||
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
control_guidance_start, control_guidance_end = self.get_control_start_end(
|
||||
control_guidance_start=control_guidance_start, control_guidance_end=control_guidance_end
|
||||
)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
callback_on_step_end_tensor_inputs = ["latents"] if callback_on_step_end_tensor_inputs is None else callback_on_step_end_tensor_inputs
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
|
||||
# 4. Prepare timesteps
|
||||
if isinstance(self.scheduler,FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']:
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
|
||||
|
||||
# Determine image sequence length
|
||||
if control_image is not None:
|
||||
if isinstance(control_image, list):
|
||||
image_seq_len = control_image[0].shape[1]
|
||||
else:
|
||||
image_seq_len = control_image.shape[1]
|
||||
else:
|
||||
# Use latents sequence length when no control image is provided
|
||||
image_seq_len = latents.shape[1]
|
||||
|
||||
print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
|
||||
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
self.scheduler.config.max_image_seq_len,
|
||||
self.scheduler.config.base_shift,
|
||||
self.scheduler.config.max_shift,
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps=None,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
else:
|
||||
# 5. Prepare timesteps
|
||||
sigmas = get_original_sigmas(
|
||||
num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
|
||||
)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Create tensor stating which controlnets to keep
|
||||
if control_image is not None:
|
||||
controlnet_keep = self.get_controlnet_keep(
|
||||
timesteps=timesteps,
|
||||
control_guidance_start=control_guidance_start,
|
||||
control_guidance_end=control_guidance_end,
|
||||
)
|
||||
|
||||
if diffusers.__version__>='0.32.0':
|
||||
latent_image_ids=latent_image_ids[0]
|
||||
text_ids=text_ids[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
# EYAL - added the CFG loop
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
# if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
|
||||
if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# Handling ControlNet
|
||||
if control_image is not None:
|
||||
if isinstance(controlnet_keep[i], list):
    if isinstance(controlnet_conditioning_scale, list):
        # one scale per controlnet, each gated by its keep flag for this step
        cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i], strict=False)]
    else:
        # a single scalar scale shared by all controlnets
        cond_scale = [controlnet_conditioning_scale * s for s in controlnet_keep[i]]
|
||||
else:
|
||||
controlnet_cond_scale = controlnet_conditioning_scale
|
||||
if isinstance(controlnet_cond_scale, list):
|
||||
controlnet_cond_scale = controlnet_cond_scale[0]
|
||||
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
||||
|
||||
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
|
||||
hidden_states=latents,
|
||||
controlnet_cond=control_image,
|
||||
controlnet_mode=control_mode,
|
||||
conditioning_scale=cond_scale,
|
||||
timestep=timestep,
|
||||
# guidance=guidance,
|
||||
# pooled_projections=pooled_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
else:
|
||||
controlnet_block_samples, controlnet_single_block_samples = None, None
|
||||
|
||||
# This predicts "v" from flow-matching
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
controlnet_single_block_samples=controlnet_single_block_samples,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return FluxPipelineOutput(images=image)
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
prompt: Union[str, List[str]],
|
||||
tokenizer: T5TokenizerFast,
|
||||
text_encoder: T5EncoderModel,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 128,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
device = device or torch.device("cuda")
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
# dynamically adjust the LoRA scale
|
||||
if text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(text_encoder, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
dtype = text_encoder.dtype if text_encoder is not None else torch.float32
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = get_t5_prompt_embeds(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
).to(dtype=dtype)
|
||||
|
||||
if negative_prompt_embeds is None:
|
||||
if not is_ng_none(negative_prompt):
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = get_t5_prompt_embeds(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
).to(dtype=dtype)
|
||||
else:
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
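# Empty/None negative prompts are represented by all-zero embeddings rather than by encoding an
# empty string, consistent with the "zero padding for prompts" note in transformer_bria.py.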
|
||||
|
||||
if text_encoder is not None:
|
||||
if USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(text_encoder, lora_scale)
|
||||
|
||||
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, text_ids
|
||||
|
||||
|
||||
def prepare_latents(
|
||||
batch_size: int,
|
||||
num_channels_latents: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
generator: torch.Generator,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
vae_scale_factor = 16
|
||||
height = 2 * (int(height) // vae_scale_factor)
|
||||
width = 2 * (int(width) // vae_scale_factor )
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
|
||||
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
|
||||
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
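Because the `__call__` above consumes pre-computed embeddings, packed latents, and a packed control image rather than building them itself, the caller is expected to wire the module-level helpers together. A rough sketch of that wiring, assuming `pipe` is a loaded `BriaControlNetPipeline` on CUDA and `control_pil` is a PIL conditioning image; `guidance_scale=1.0` keeps the single-pass (no-CFG) path, and all names and sizes are illustrative:

device = torch.device("cuda")

prompt_embeds, _, text_ids = encode_prompt(
    "an isometric city at night", pipe.tokenizer, pipe.text_encoder, device=device,
)
latents, latent_image_ids = prepare_latents(
    batch_size=1, num_channels_latents=pipe.transformer.config.in_channels // 4,
    height=1024, width=1024, dtype=prompt_embeds.dtype, device=device,
    generator=torch.Generator(device).manual_seed(0),
)
control_image, control_mode = pipe.prepare_control(
    control_image=control_pil, width=1024, height=1024, batch_size=1,
    num_images_per_prompt=1, device=device, control_mode=None,
)
result = pipe(
    prompt_embeds=prompt_embeds, text_ids=text_ids,
    latents=latents, latent_image_ids=latent_image_ids,
    control_image=control_image, control_mode=control_mode,
    height=1024, width=1024, num_inference_steps=30, guidance_scale=1.0,
)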
invokeai/backend/bria/transformer_bria.py (new file, 322 lines)
@@ -0,0 +1,322 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.normalization import AdaLayerNormContinuous
|
||||
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
|
||||
from invokeai.backend.bria.bria_utils import FluxPosEmbed as EmbedND
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(
|
||||
self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
|
||||
):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.scale = scale
|
||||
self.time_theta = time_theta
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
scale=self.scale,
|
||||
max_period=self.time_theta,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class TimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, time_theta):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(
|
||||
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
|
||||
)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
|
||||
return timesteps_emb
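# Shape note: for a timestep tensor of shape (B,), time_proj yields a (B, 256) sinusoidal
# embedding, which timestep_embedder then projects to (B, embedding_dim).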
|
||||
|
||||
|
||||
"""
|
||||
Based on FluxPipeline with several changes:
|
||||
- no pooled embeddings
|
||||
- We use zero padding for prompts
|
||||
- No guidance embedding since this is not a distilled version
|
||||
"""
|
||||
|
||||
|
||||
class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
The Bria Transformer model, based on the Transformer introduced in Flux.
|
||||
|
||||
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
||||
|
||||
Parameters:
|
||||
patch_size (`int`): Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
||||
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
|
||||
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
||||
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
||||
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = None,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Optional[List[int]] = None,
|
||||
rope_theta=10000,
|
||||
time_theta=10000,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
|
||||
axes_dims_rope = [16, 56, 56] if axes_dims_rope is None else axes_dims_rope
|
||||
self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
|
||||
|
||||
self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
|
||||
|
||||
# if pooled_projection_dim:
|
||||
# self.pooled_text_embed = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim=self.inner_dim, act_fn="silu")
|
||||
|
||||
if guidance_embeds:
|
||||
self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
|
||||
|
||||
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
|
||||
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
)
|
||||
for i in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
)
|
||||
for i in range(self.config.num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
controlnet_block_samples=None,
|
||||
controlnet_single_block_samples=None,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype)
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
# temb = (
|
||||
# self.time_text_embed(timestep, pooled_projections)
|
||||
# if guidance is None
|
||||
# else self.time_text_embed(timestep, guidance, pooled_projections)
|
||||
# )
|
||||
|
||||
temb = self.time_embed(timestep, dtype=hidden_states.dtype)
|
||||
|
||||
# if pooled_projections:
|
||||
# temb+=self.pooled_text_embed(pooled_projections)
|
||||
|
||||
if guidance is not None:
|
||||
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
|
||||
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if len(txt_ids.shape) == 2:
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
else:
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_block_samples is not None:
|
||||
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_single_block_samples is not None:
|
||||
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
+ controlnet_single_block_samples[index_block // interval_control]
|
||||
)
|
||||
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
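As a quick sanity check, the forward pass can be exercised with a deliberately tiny configuration. The sizes below are illustrative only, and `axes_dims_rope` is chosen on the assumption that its entries must sum to `attention_head_dim` for the rotary embedding:

import torch

from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel

model = BriaTransformer2DModel(
    patch_size=1, in_channels=64, num_layers=1, num_single_layers=1,
    attention_head_dim=32, num_attention_heads=2, joint_attention_dim=64,
    axes_dims_rope=[4, 14, 14],  # assumed: must sum to attention_head_dim
)
hidden_states = torch.randn(1, 16, 64)         # (batch, image tokens, in_channels)
encoder_hidden_states = torch.randn(1, 8, 64)  # (batch, text tokens, joint_attention_dim)
img_ids = torch.zeros(16, 3)
txt_ids = torch.zeros(8, 3)
timestep = torch.tensor([0.5])

out = model(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    timestep=timestep,
    img_ids=img_ids,
    txt_ids=txt_ids,
)
print(out.sample.shape)  # expected: torch.Size([1, 16, 64])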
|
||||
@@ -187,7 +187,7 @@ class ModelConfigBase(ABC, BaseModel):
            else:
                return config_cls.from_model_on_disk(mod, **overrides)

        raise InvalidModelConfigException("No valid config found")
        raise InvalidModelConfigException("Unable to determine model type")

    @classmethod
    def get_tag(cls) -> Tag:

@@ -125,6 +125,8 @@ class ModelProbe(object):
    }

    CLASS2TYPE = {
        "BriaPipeline": ModelType.Main,
        "BriaTransformer2DModel": ModelType.ControlNet,
        "FluxPipeline": ModelType.Main,
        "StableDiffusionPipeline": ModelType.Main,
        "StableDiffusionInpaintPipeline": ModelType.Main,

@@ -861,6 +863,8 @@ class PipelineFolderProbe(FolderProbeBase):
                return BaseModelType.StableDiffusion3
            elif transformer_conf["_class_name"] == "CogView4Transformer2DModel":
                return BaseModelType.CogView4
            elif transformer_conf["_class_name"] == "BriaTransformer2DModel":
                return BaseModelType.Bria
            else:
                raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")

@@ -1010,6 +1014,9 @@ class ControlNetFolderProbe(FolderProbeBase):
        if config.get("_class_name", None) == "FluxControlNetModel":
            return BaseModelType.Flux

        if config.get("_class_name", None) == "BriaTransformer2DModel":
            return BaseModelType.Bria

        # no obvious way to distinguish between sd2-base and sd2-768
        dimension = config["cross_attention_dim"]
        if dimension == 768:
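Both folder probes above key off the `_class_name` field of a diffusers-style `config.json` to decide the base model type. A minimal sketch of that probing pattern (hypothetical helper, not the actual ModelProbe API):

import json
from pathlib import Path

def guess_controlnet_base(model_dir: Path) -> str:
    config = json.loads((model_dir / "config.json").read_text())
    class_name = config.get("_class_name")
    if class_name == "FluxControlNetModel":
        return "flux"
    if class_name == "BriaTransformer2DModel":
        return "bria"
    # Otherwise fall back to heuristics such as cross_attention_dim, as in the probe above.
    raise ValueError(f"Cannot determine base model for {model_dir}")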
95
invokeai/backend/model_manager/load/model_loaders/bria.py
Normal file
@@ -0,0 +1,95 @@
from pathlib import Path
from typing import Optional

from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    CheckpointConfigBase,
    ControlNetCheckpointConfig,
    ControlNetDiffusersConfig,
    DiffusersConfigBase,
)
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
    AnyModel,
    BaseModelType,
    ModelFormat,
    ModelType,
    SubModelType,
)


@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
class BriaControlNetDiffusersModel(GenericDiffusersLoader):
    """Class to load Bria control net models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if isinstance(config, ControlNetCheckpointConfig):
            raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")

        model_path = Path(config.path)
        load_class = self.get_hf_load_class(model_path)
        repo_variant = config.repo_variant if isinstance(config, ControlNetDiffusersConfig) else None
        variant = repo_variant.value if repo_variant else None
        model_path = model_path

        dtype = self._torch_dtype

        try:
            result: AnyModel = load_class.from_pretrained(
                model_path,
                torch_dtype=dtype,
                variant=variant,
                use_safetensors=False,
            )
        except OSError as e:
            if variant and "no file named" in str(e):  # try without the variant, just in case user's preferences changed
                result = load_class.from_pretrained(model_path, torch_dtype=dtype)
            else:
                raise e

        return result


@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers)
class BriaDiffusersModel(GenericDiffusersLoader):
    """Class to load Bria main models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if isinstance(config, CheckpointConfigBase):
            raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")

        if submodel_type is None:
            raise Exception("A submodel type must be provided when loading main pipelines.")

        model_path = Path(config.path)
        load_class = self.get_hf_load_class(model_path, submodel_type)
        repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
        variant = repo_variant.value if repo_variant else None
        model_path = model_path / submodel_type.value

        dtype = self._torch_dtype
        try:
            result: AnyModel = load_class.from_pretrained(
                model_path,
                torch_dtype=dtype,
                variant=variant,
            )
        except OSError as e:
            if variant and "no file named" in str(e):  # try without the variant, just in case user's preferences changed
                result = load_class.from_pretrained(model_path, torch_dtype=dtype)
            else:
                raise e

        return result
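Both Bria loaders above share the same fallback: call `from_pretrained` with the configured repo variant, and retry without it when that variant's weight files are absent. A minimal sketch of the pattern using a stand-in diffusers class (illustrative only; the real loaders resolve the load class dynamically):

from diffusers import AutoencoderKL

def load_with_variant_fallback(model_path: str, variant: str | None = "fp16"):
    try:
        return AutoencoderKL.from_pretrained(model_path, variant=variant)
    except OSError as e:
        # diffusers raises OSError("... no file named ...") when the variant files are missing.
        if variant and "no file named" in str(e):
            return AutoencoderKL.from_pretrained(model_path)
        raise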
@@ -80,7 +80,13 @@ class GenericDiffusersLoader(ModelLoader):
                "transformers",
                "invokeai.backend.quantization.fast_quantized_transformers_model",
                "invokeai.backend.quantization.fast_quantized_diffusion_model",
                "transformer_bria",
            ]:
                if module == "transformer_bria":
                    module = "invokeai.backend.bria.transformer_bria"
                elif class_name == "BriaTransformer2DModel":
                    class_name = "BriaControlNetModel"
                    module = "invokeai.backend.bria.controlnet_bria"
                res_type = sys.modules[module]
            else:
                res_type = sys.modules["diffusers"].pipelines
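The remapping above routes the short Bria module names to their `invokeai.backend.bria` implementations before the class is looked up on the module. A minimal sketch of that kind of dynamic class resolution (generic Python, not the loader's exact code, which looks the module up in `sys.modules`):

import importlib

def resolve_class(module_name: str, class_name: str):
    # Map the short diffusers module name to the in-repo implementation, then fetch the class.
    if module_name == "transformer_bria":
        module_name = "invokeai.backend.bria.transformer_bria"
    module = importlib.import_module(module_name)
    return getattr(module, class_name)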
@@ -12,6 +12,9 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast

from invokeai.backend.bria.controlnet_aux.open_pose.body import Body
from invokeai.backend.bria.controlnet_aux.open_pose.face import Face
from invokeai.backend.bria.controlnet_aux.open_pose.hand import Hand
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline

@@ -62,6 +65,8 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
        else:
            # If neither is available, return 0
            return 0
    elif isinstance(model, (Body, Hand, Face)):
        return calc_module_size(model.model)
    elif isinstance(
        model,
        (

@@ -143,11 +143,19 @@ flux_dev = StarterModel(
flux_kontext = StarterModel(
    name="FLUX.1 Kontext dev",
    base=BaseModelType.Flux,
    source="black-forest-labs/FLUX.1-Kontext-dev::flux1-kontext-dev.safetensors",
    source="https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/resolve/main/flux1-kontext-dev.safetensors",
    description="FLUX.1 Kontext dev transformer in bfloat16. Total size with dependencies: ~33GB",
    type=ModelType.Main,
    dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_kontext_quantized = StarterModel(
    name="FLUX.1 Kontext dev (Quantized)",
    base=BaseModelType.Flux,
    source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
    description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
    type=ModelType.Main,
    dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
sd35_medium = StarterModel(
    name="SD3.5 Medium",
    base=BaseModelType.StableDiffusion3,

@@ -664,7 +672,7 @@ flux_fill = StarterModel(
# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
    flux_kontext,
    flux_kontext_quantized,
    flux_schnell_quantized,
    flux_dev_quantized,
    flux_schnell,

@@ -785,7 +793,7 @@ flux_bundle: list[StarterModel] = [
    flux_depth_control_lora,
    flux_redux,
    flux_fill,
    flux_kontext,
    flux_kontext_quantized,
]

STARTER_BUNDLES: dict[str, StarterModelBundle] = {

@@ -30,6 +30,7 @@ class BaseModelType(str, Enum):
    Imagen4 = "imagen4"
    ChatGPT4o = "chatgpt-4o"
    FluxKontext = "flux-kontext"
    Bria = "bria"


class ModelType(str, Enum):
|
||||
@@ -1,10 +0,0 @@
|
||||
dist/
|
||||
static/
|
||||
.husky/
|
||||
node_modules/
|
||||
patches/
|
||||
stats.html
|
||||
index.html
|
||||
.yarn/
|
||||
*.scss
|
||||
src/services/api/schema.ts
|
||||
@@ -1,88 +0,0 @@
|
||||
module.exports = {
|
||||
extends: ['@invoke-ai/eslint-config-react'],
|
||||
plugins: ['path', 'i18next'],
|
||||
rules: {
|
||||
// TODO(psyche): Enable this rule. Requires no default exports in components - many changes.
|
||||
'react-refresh/only-export-components': 'off',
|
||||
// TODO(psyche): Enable this rule. Requires a lot of eslint-disable-next-line comments.
|
||||
'@typescript-eslint/consistent-type-assertions': 'off',
|
||||
// https://github.com/qdanik/eslint-plugin-path
|
||||
'path/no-relative-imports': ['error', { maxDepth: 0 }],
|
||||
// https://github.com/edvardchen/eslint-plugin-i18next/blob/HEAD/docs/rules/no-literal-string.md
|
||||
// TODO: ENABLE THIS RULE BEFORE v6.0.0
|
||||
// 'i18next/no-literal-string': 'error',
|
||||
// https://eslint.org/docs/latest/rules/no-console
|
||||
'no-console': 'warn',
|
||||
// https://eslint.org/docs/latest/rules/no-promise-executor-return
|
||||
'no-promise-executor-return': 'error',
|
||||
// https://eslint.org/docs/latest/rules/require-await
|
||||
'require-await': 'error',
|
||||
// Restrict setActiveTab calls to only use-navigation-api.tsx
|
||||
'no-restricted-syntax': [
|
||||
'error',
|
||||
{
|
||||
selector: 'CallExpression[callee.name="setActiveTab"]',
|
||||
message:
|
||||
'setActiveTab() can only be called from use-navigation-api.tsx. Use navigationApi.switchToTab() instead.',
|
||||
},
|
||||
],
|
||||
// TODO: ENABLE THIS RULE BEFORE v6.0.0
|
||||
'react/display-name': 'off',
|
||||
'no-restricted-properties': [
|
||||
'error',
|
||||
{
|
||||
object: 'crypto',
|
||||
property: 'randomUUID',
|
||||
message: 'Use of crypto.randomUUID is not allowed as it is not available in all browsers.',
|
||||
},
|
||||
{
|
||||
object: 'navigator',
|
||||
property: 'clipboard',
|
||||
message:
|
||||
'The Clipboard API is not available by default in Firefox. Use the `useClipboard` hook instead, which wraps clipboard access to prevent errors.',
|
||||
},
|
||||
],
|
||||
'no-restricted-imports': [
|
||||
'error',
|
||||
{
|
||||
paths: [
|
||||
{
|
||||
name: 'lodash-es',
|
||||
importNames: ['isEqual'],
|
||||
message: 'Please use objectEquals from @observ33r/object-equals instead.',
|
||||
},
|
||||
{
|
||||
name: 'lodash-es',
|
||||
message: 'Please use es-toolkit instead.',
|
||||
},
|
||||
{
|
||||
name: 'es-toolkit',
|
||||
importNames: ['isEqual'],
|
||||
message: 'Please use objectEquals from @observ33r/object-equals instead.',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
overrides: [
|
||||
/**
|
||||
* Allow setActiveTab calls only in use-navigation-api.tsx
|
||||
*/
|
||||
{
|
||||
files: ['**/use-navigation-api.tsx'],
|
||||
rules: {
|
||||
'no-restricted-syntax': 'off',
|
||||
},
|
||||
},
|
||||
/**
|
||||
* Overrides for stories
|
||||
*/
|
||||
{
|
||||
files: ['*.stories.tsx'],
|
||||
rules: {
|
||||
// We may not have i18n available in stories.
|
||||
'i18next/no-literal-string': 'off',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
@@ -14,3 +14,4 @@ static/
|
||||
src/theme/css/overlayscrollbars.css
|
||||
src/theme_/css/overlayscrollbars.css
|
||||
pnpm-lock.yaml
|
||||
.claude
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
module.exports = {
|
||||
...require('@invoke-ai/prettier-config-react'),
|
||||
overrides: [
|
||||
{
|
||||
files: ['public/locales/*.json'],
|
||||
options: {
|
||||
tabWidth: 4,
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
17
invokeai/frontend/web/.prettierrc.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"$schema": "http://json.schemastore.org/prettierrc",
|
||||
"trailingComma": "es5",
|
||||
"printWidth": 120,
|
||||
"tabWidth": 2,
|
||||
"semi": true,
|
||||
"singleQuote": true,
|
||||
"endOfLine": "auto",
|
||||
"overrides": [
|
||||
{
|
||||
"files": ["public/locales/*.json"],
|
||||
"options": {
|
||||
"tabWidth": 4
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,21 +1,23 @@
|
||||
import { PropsWithChildren, memo, useEffect } from 'react';
|
||||
import { modelChanged } from '../src/features/controlLayers/store/paramsSlice';
|
||||
import { useAppDispatch } from '../src/app/store/storeHooks';
|
||||
import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo, useEffect } from 'react';
|
||||
|
||||
import { useAppDispatch } from '../src/app/store/storeHooks';
|
||||
import { modelChanged } from '../src/features/controlLayers/store/paramsSlice';
|
||||
/**
|
||||
* Initializes some state for storybook. Must be in a different component
|
||||
* so that it is run inside the redux context.
|
||||
*/
|
||||
export const ReduxInit = memo((props: PropsWithChildren) => {
|
||||
export const ReduxInit = memo(({ children }: PropsWithChildren) => {
|
||||
const dispatch = useAppDispatch();
|
||||
useGlobalModifiersInit();
|
||||
useEffect(() => {
|
||||
dispatch(
|
||||
modelChanged({ model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' } })
|
||||
);
|
||||
}, []);
|
||||
}, [dispatch]);
|
||||
|
||||
return props.children;
|
||||
return children;
|
||||
});
|
||||
|
||||
ReduxInit.displayName = 'ReduxInit';
|
||||
|
||||
@@ -2,19 +2,13 @@ import type { StorybookConfig } from '@storybook/react-vite';
|
||||
|
||||
const config: StorybookConfig = {
|
||||
stories: ['../src/**/*.mdx', '../src/**/*.stories.@(js|jsx|mjs|ts|tsx)'],
|
||||
addons: [
|
||||
'@storybook/addon-links',
|
||||
'@storybook/addon-essentials',
|
||||
'@storybook/addon-interactions',
|
||||
'@storybook/addon-storysource',
|
||||
],
|
||||
addons: ['@storybook/addon-links', '@storybook/addon-docs'],
|
||||
|
||||
framework: {
|
||||
name: '@storybook/react-vite',
|
||||
options: {},
|
||||
},
|
||||
docs: {
|
||||
autodocs: 'tag',
|
||||
},
|
||||
|
||||
core: {
|
||||
disableTelemetry: true,
|
||||
},
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { addons } from '@storybook/manager-api';
|
||||
import { themes } from '@storybook/theming';
|
||||
import { addons } from 'storybook/manager-api';
|
||||
import { themes } from 'storybook/theming';
|
||||
|
||||
addons.setConfig({
|
||||
theme: themes.dark,
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import { Preview } from '@storybook/react';
|
||||
import { themes } from '@storybook/theming';
|
||||
import type { Preview } from '@storybook/react-vite';
|
||||
import { themes } from 'storybook/theming';
|
||||
import { $store } from 'app/store/nanostores/store';
|
||||
import i18n from 'i18next';
|
||||
import { initReactI18next } from 'react-i18next';
|
||||
import { Provider } from 'react-redux';
|
||||
import ThemeLocaleProvider from '../src/app/components/ThemeLocaleProvider';
|
||||
import { $baseUrl } from '../src/app/store/nanostores/baseUrl';
|
||||
import { createStore } from '../src/app/store/store';
|
||||
|
||||
// TODO: Disabled for IDE performance issues with our translation JSON
|
||||
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||
// @ts-ignore
|
||||
import translationEN from '../public/locales/en.json';
|
||||
import ThemeLocaleProvider from '../src/app/components/ThemeLocaleProvider';
|
||||
import { $baseUrl } from '../src/app/store/nanostores/baseUrl';
|
||||
import { createStore } from '../src/app/store/store';
|
||||
import { ReduxInit } from './ReduxInit';
|
||||
import { $store } from 'app/store/nanostores/store';
|
||||
|
||||
i18n.use(initReactI18next).init({
|
||||
lng: 'en',
|
||||
@@ -46,6 +47,7 @@ const preview: Preview = {
|
||||
parameters: {
|
||||
docs: {
|
||||
theme: themes.dark,
|
||||
codePanel: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
242
invokeai/frontend/web/eslint.config.mjs
Normal file
@@ -0,0 +1,242 @@
|
||||
import js from '@eslint/js';
|
||||
import typescriptEslint from '@typescript-eslint/eslint-plugin';
|
||||
import typescriptParser from '@typescript-eslint/parser';
|
||||
import pluginI18Next from 'eslint-plugin-i18next';
|
||||
import pluginImport from 'eslint-plugin-import';
|
||||
import pluginPath from 'eslint-plugin-path';
|
||||
import pluginReact from 'eslint-plugin-react';
|
||||
import pluginReactHooks from 'eslint-plugin-react-hooks';
|
||||
import pluginReactRefresh from 'eslint-plugin-react-refresh';
|
||||
import pluginSimpleImportSort from 'eslint-plugin-simple-import-sort';
|
||||
import pluginStorybook from 'eslint-plugin-storybook';
|
||||
import pluginUnusedImports from 'eslint-plugin-unused-imports';
|
||||
import globals from 'globals';
|
||||
|
||||
export default [
|
||||
js.configs.recommended,
|
||||
|
||||
{
|
||||
languageOptions: {
|
||||
parser: typescriptParser,
|
||||
parserOptions: {
|
||||
ecmaFeatures: {
|
||||
jsx: true,
|
||||
},
|
||||
},
|
||||
globals: {
|
||||
...globals.browser,
|
||||
...globals.node,
|
||||
GlobalCompositeOperation: 'readonly',
|
||||
RequestInit: 'readonly',
|
||||
},
|
||||
},
|
||||
|
||||
files: ['**/*.ts', '**/*.tsx', '**/*.js', '**/*.jsx'],
|
||||
|
||||
plugins: {
|
||||
react: pluginReact,
|
||||
'@typescript-eslint': typescriptEslint,
|
||||
'react-hooks': pluginReactHooks,
|
||||
import: pluginImport,
|
||||
'unused-imports': pluginUnusedImports,
|
||||
'simple-import-sort': pluginSimpleImportSort,
|
||||
'react-refresh': pluginReactRefresh.configs.vite,
|
||||
path: pluginPath,
|
||||
i18next: pluginI18Next,
|
||||
storybook: pluginStorybook,
|
||||
},
|
||||
|
||||
rules: {
|
||||
...typescriptEslint.configs.recommended.rules,
|
||||
...pluginReact.configs.recommended.rules,
|
||||
...pluginReact.configs['jsx-runtime'].rules,
|
||||
...pluginReactHooks.configs.recommended.rules,
|
||||
...pluginStorybook.configs.recommended.rules,
|
||||
|
||||
'react/jsx-no-bind': [
|
||||
'error',
|
||||
{
|
||||
allowBind: true,
|
||||
},
|
||||
],
|
||||
|
||||
'react/jsx-curly-brace-presence': [
|
||||
'error',
|
||||
{
|
||||
props: 'never',
|
||||
children: 'never',
|
||||
},
|
||||
],
|
||||
|
||||
'react-hooks/exhaustive-deps': 'error',
|
||||
|
||||
curly: 'error',
|
||||
'no-var': 'error',
|
||||
'brace-style': 'error',
|
||||
'prefer-template': 'error',
|
||||
radix: 'error',
|
||||
'space-before-blocks': 'error',
|
||||
eqeqeq: 'error',
|
||||
'one-var': ['error', 'never'],
|
||||
'no-eval': 'error',
|
||||
'no-extend-native': 'error',
|
||||
'no-implied-eval': 'error',
|
||||
'no-label-var': 'error',
|
||||
'no-return-assign': 'error',
|
||||
'no-sequences': 'error',
|
||||
'no-template-curly-in-string': 'error',
|
||||
'no-throw-literal': 'error',
|
||||
'no-unmodified-loop-condition': 'error',
|
||||
'import/no-duplicates': 'error',
|
||||
'import/prefer-default-export': 'off',
|
||||
'unused-imports/no-unused-imports': 'error',
|
||||
|
||||
'unused-imports/no-unused-vars': [
|
||||
'error',
|
||||
{
|
||||
vars: 'all',
|
||||
varsIgnorePattern: '^_',
|
||||
args: 'after-used',
|
||||
argsIgnorePattern: '^_',
|
||||
},
|
||||
],
|
||||
|
||||
'simple-import-sort/imports': 'error',
|
||||
'simple-import-sort/exports': 'error',
|
||||
'@typescript-eslint/no-unused-vars': 'off',
|
||||
|
||||
'@typescript-eslint/ban-ts-comment': [
|
||||
'error',
|
||||
{
|
||||
'ts-expect-error': 'allow-with-description',
|
||||
'ts-ignore': true,
|
||||
'ts-nocheck': true,
|
||||
'ts-check': false,
|
||||
minimumDescriptionLength: 10,
|
||||
},
|
||||
],
|
||||
|
||||
'@typescript-eslint/no-empty-interface': [
|
||||
'error',
|
||||
{
|
||||
allowSingleExtends: true,
|
||||
},
|
||||
],
|
||||
|
||||
'@typescript-eslint/consistent-type-imports': [
|
||||
'error',
|
||||
{
|
||||
prefer: 'type-imports',
|
||||
fixStyle: 'separate-type-imports',
|
||||
disallowTypeAnnotations: true,
|
||||
},
|
||||
],
|
||||
|
||||
'@typescript-eslint/no-import-type-side-effects': 'error',
|
||||
|
||||
'@typescript-eslint/consistent-type-assertions': [
|
||||
'error',
|
||||
{
|
||||
assertionStyle: 'as',
|
||||
},
|
||||
],
|
||||
|
||||
'path/no-relative-imports': [
|
||||
'error',
|
||||
{
|
||||
maxDepth: 0,
|
||||
},
|
||||
],
|
||||
|
||||
'no-console': 'warn',
|
||||
'no-promise-executor-return': 'error',
|
||||
'require-await': 'error',
|
||||
|
||||
'no-restricted-syntax': [
|
||||
'error',
|
||||
{
|
||||
selector: 'CallExpression[callee.name="setActiveTab"]',
|
||||
message:
|
||||
'setActiveTab() can only be called from use-navigation-api.tsx. Use navigationApi.switchToTab() instead.',
|
||||
},
|
||||
],
|
||||
|
||||
'no-restricted-properties': [
|
||||
'error',
|
||||
{
|
||||
object: 'crypto',
|
||||
property: 'randomUUID',
|
||||
message: 'Use of crypto.randomUUID is not allowed as it is not available in all browsers.',
|
||||
},
|
||||
{
|
||||
object: 'navigator',
|
||||
property: 'clipboard',
|
||||
message:
|
||||
'The Clipboard API is not available by default in Firefox. Use the `useClipboard` hook instead, which wraps clipboard access to prevent errors.',
|
||||
},
|
||||
],
|
||||
|
||||
// Typescript handles this for us: https://eslint.org/docs/latest/rules/no-redeclare#handled_by_typescript
|
||||
'no-redeclare': 'off',
|
||||
|
||||
'no-restricted-imports': [
|
||||
'error',
|
||||
{
|
||||
paths: [
|
||||
{
|
||||
name: 'lodash-es',
|
||||
importNames: ['isEqual'],
|
||||
message: 'Please use objectEquals from @observ33r/object-equals instead.',
|
||||
},
|
||||
{
|
||||
name: 'lodash-es',
|
||||
message: 'Please use es-toolkit instead.',
|
||||
},
|
||||
{
|
||||
name: 'es-toolkit',
|
||||
importNames: ['isEqual'],
|
||||
message: 'Please use objectEquals from @observ33r/object-equals instead.',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
settings: {
|
||||
react: {
|
||||
version: 'detect',
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
files: ['**/use-navigation-api.tsx'],
|
||||
rules: {
|
||||
'no-restricted-syntax': 'off',
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
files: ['**/*.stories.tsx'],
|
||||
rules: {
|
||||
'i18next/no-literal-string': 'off',
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
ignores: [
|
||||
'**/dist/',
|
||||
'**/static/',
|
||||
'**/.husky/',
|
||||
'**/node_modules/',
|
||||
'**/patches/',
|
||||
'**/stats.html',
|
||||
'**/index.html',
|
||||
'**/.yarn/',
|
||||
'**/*.scss',
|
||||
'src/services/api/schema.ts',
|
||||
'.prettierrc.js',
|
||||
'.storybook',
|
||||
],
|
||||
},
|
||||
];
|
||||
@@ -12,6 +12,9 @@ const config: KnipConfig = {
|
||||
'src/features/parameters/types/parameterSchemas.ts',
|
||||
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
|
||||
'src/features/controlLayers/konva/util.ts',
|
||||
// Will be using this
|
||||
'src/common/hooks/useAsyncState.ts',
|
||||
'src/app/store/use-debounced-app-selector.ts',
|
||||
],
|
||||
ignoreBinaries: ['only-allow'],
|
||||
paths: {
|
||||
|
||||
@@ -47,25 +47,25 @@
|
||||
"@fontsource-variable/inter": "^5.2.6",
|
||||
"@invoke-ai/ui-library": "^0.0.46",
|
||||
"@nanostores/react": "^1.0.0",
|
||||
"@observ33r/object-equals": "^1.1.4",
|
||||
"@observ33r/object-equals": "^1.1.5",
|
||||
"@reduxjs/toolkit": "2.8.2",
|
||||
"@roarr/browser-log-writer": "^1.3.0",
|
||||
"@xyflow/react": "^12.7.1",
|
||||
"ag-psd": "^28.2.1",
|
||||
"@xyflow/react": "^12.8.2",
|
||||
"ag-psd": "^28.2.2",
|
||||
"async-mutex": "^0.5.0",
|
||||
"chakra-react-select": "^4.9.2",
|
||||
"cmdk": "^1.1.1",
|
||||
"compare-versions": "^6.1.1",
|
||||
"dockview": "^4.4.0",
|
||||
"es-toolkit": "^1.39.5",
|
||||
"dockview": "^4.4.1",
|
||||
"es-toolkit": "^1.39.7",
|
||||
"filesize": "^10.1.6",
|
||||
"fracturedjsonjs": "^4.1.0",
|
||||
"framer-motion": "^11.10.0",
|
||||
"i18next": "^25.2.1",
|
||||
"i18next": "^25.3.2",
|
||||
"i18next-http-backend": "^3.0.2",
|
||||
"idb-keyval": "^6.2.2",
|
||||
"idb-keyval": "6.2.2",
|
||||
"jsondiffpatch": "^0.7.3",
|
||||
"konva": "^9.3.20",
|
||||
"konva": "^9.3.22",
|
||||
"linkify-react": "^4.3.1",
|
||||
"linkifyjs": "^4.3.1",
|
||||
"lru-cache": "^11.1.0",
|
||||
@@ -83,7 +83,7 @@
|
||||
"react-dom": "^18.3.1",
|
||||
"react-dropzone": "^14.3.8",
|
||||
"react-error-boundary": "^5.0.0",
|
||||
"react-hook-form": "^7.58.1",
|
||||
"react-hook-form": "^7.60.0",
|
||||
"react-hotkeys-hook": "4.5.0",
|
||||
"react-i18next": "^15.5.3",
|
||||
"react-icons": "^5.5.0",
|
||||
@@ -103,7 +103,7 @@
|
||||
"use-debounce": "^10.0.5",
|
||||
"use-device-pixel-ratio": "^1.1.2",
|
||||
"uuid": "^11.1.0",
|
||||
"zod": "^3.25.67",
|
||||
"zod": "^4.0.5",
|
||||
"zod-validation-error": "^3.5.2"
|
||||
},
|
||||
"peerDependencies": {
|
||||
@@ -111,39 +111,43 @@
|
||||
"react-dom": "^18.2.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@invoke-ai/eslint-config-react": "^0.0.14",
|
||||
"@invoke-ai/prettier-config-react": "^0.0.7",
|
||||
"@storybook/addon-essentials": "^8.6.12",
|
||||
"@storybook/addon-interactions": "^8.6.12",
|
||||
"@storybook/addon-links": "^8.6.12",
|
||||
"@storybook/addon-storysource": "^8.6.12",
|
||||
"@storybook/manager-api": "^8.6.12",
|
||||
"@storybook/react": "^8.6.12",
|
||||
"@storybook/react-vite": "^8.6.12",
|
||||
"@storybook/theming": "^8.6.12",
|
||||
"@eslint/js": "^9.31.0",
|
||||
"@storybook/addon-docs": "^9.0.17",
|
||||
"@storybook/addon-links": "^9.0.17",
|
||||
"@storybook/react-vite": "^9.0.17",
|
||||
"@types/node": "^22.15.1",
|
||||
"@types/react": "^18.3.11",
|
||||
"@types/react-dom": "^18.3.0",
|
||||
"@types/uuid": "^10.0.0",
|
||||
"@typescript-eslint/eslint-plugin": "^8.37.0",
|
||||
"@typescript-eslint/parser": "^8.37.0",
|
||||
"@vitejs/plugin-react-swc": "^3.9.0",
|
||||
"@vitest/coverage-v8": "^3.1.2",
|
||||
"@vitest/ui": "^3.1.2",
|
||||
"concurrently": "^9.1.2",
|
||||
"csstype": "^3.1.3",
|
||||
"dpdm": "^3.14.0",
|
||||
"eslint": "^8.57.1",
|
||||
"eslint-plugin-i18next": "^6.1.1",
|
||||
"eslint-plugin-path": "^1.3.0",
|
||||
"eslint": "^9.31.0",
|
||||
"eslint-plugin-i18next": "^6.1.2",
|
||||
"eslint-plugin-import": "^2.29.1",
|
||||
"eslint-plugin-path": "^2.0.3",
|
||||
"eslint-plugin-react": "^7.33.2",
|
||||
"eslint-plugin-react-hooks": "^5.2.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.5",
|
||||
"eslint-plugin-simple-import-sort": "^12.0.0",
|
||||
"eslint-plugin-storybook": "^9.0.17",
|
||||
"eslint-plugin-unused-imports": "^4.1.4",
|
||||
"globals": "^16.3.0",
|
||||
"knip": "^5.61.3",
|
||||
"openapi-types": "^12.1.3",
|
||||
"openapi-typescript": "^7.6.1",
|
||||
"prettier": "^3.5.3",
|
||||
"rollup-plugin-visualizer": "^5.14.0",
|
||||
"storybook": "^8.6.12",
|
||||
"rollup-plugin-visualizer": "^6.0.3",
|
||||
"storybook": "^9.0.17",
|
||||
"tsafe": "^1.8.5",
|
||||
"type-fest": "^4.40.0",
|
||||
"typescript": "^5.8.3",
|
||||
"vite": "^7.0.2",
|
||||
"vite": "^7.0.5",
|
||||
"vite-plugin-css-injected-by-js": "^3.5.2",
|
||||
"vite-plugin-dts": "^4.5.3",
|
||||
"vite-plugin-eslint": "^1.8.1",
|
||||
|
||||
2174
invokeai/frontend/web/pnpm-lock.yaml
generated
File diff suppressed because it is too large
@@ -574,6 +574,10 @@
|
||||
"title": "Transform",
|
||||
"desc": "Transform the selected layer."
|
||||
},
|
||||
"invertMask": {
|
||||
"title": "Invert Mask",
|
||||
"desc": "Invert the selected inpaint mask, creating a new mask with opposite transparency."
|
||||
},
|
||||
"applyFilter": {
|
||||
"title": "Apply Filter",
|
||||
"desc": "Apply the pending filter to the selected layer."
|
||||
@@ -599,6 +603,10 @@
|
||||
"toggleNonRasterLayers": {
|
||||
"title": "Toggle Non-Raster Layers",
|
||||
"desc": "Show or hide all non-raster layer categories (Control Layers, Inpaint Masks, Regional Guidance)."
|
||||
},
|
||||
"fitBboxToMasks": {
|
||||
"title": "Fit Bbox To Masks",
|
||||
"desc": "Automatically adjust the generation bounding box to fit visible inpaint masks"
|
||||
}
|
||||
},
|
||||
"workflows": {
|
||||
@@ -1125,7 +1133,23 @@
|
||||
"addItem": "Add Item",
|
||||
"generateValues": "Generate Values",
|
||||
"floatRangeGenerator": "Float Range Generator",
|
||||
"integerRangeGenerator": "Integer Range Generator"
|
||||
"integerRangeGenerator": "Integer Range Generator",
|
||||
"layout": {
|
||||
"autoLayout": "Auto Layout",
|
||||
"layeringStrategy": "Layering Strategy",
|
||||
"networkSimplex": "Network Simplex",
|
||||
"longestPath": "Longest Path",
|
||||
"nodeSpacing": "Node Spacing",
|
||||
"layerSpacing": "Layer Spacing",
|
||||
"layoutDirection": "Layout Direction",
|
||||
"layoutDirectionRight": "Right",
|
||||
"layoutDirectionDown": "Down",
|
||||
"alignment": "Node Alignment",
|
||||
"alignmentUL": "Top Left",
|
||||
"alignmentDL": "Bottom Left",
|
||||
"alignmentUR": "Top Right",
|
||||
"alignmentDR": "Bottom Right"
|
||||
}
|
||||
},
|
||||
"parameters": {
|
||||
"aspect": "Aspect",
|
||||
@@ -1399,7 +1423,7 @@
|
||||
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
|
||||
"imagenIncompatibleGenerationMode": "Google {{model}} supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
|
||||
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models Inpainting and Outpainting tasks.",
|
||||
"fluxKontextIncompatibleGenerationMode": "FLUX Kontext supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
|
||||
"fluxKontextIncompatibleGenerationMode": "FLUX Kontext does not support generation from images placed on the canvas. Re-try using the Reference Image section and disable any Raster Layers.",
|
||||
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
|
||||
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
|
||||
"workflowUnpublished": "Workflow Unpublished",
|
||||
@@ -1407,7 +1431,15 @@
|
||||
"sentToUpscale": "Sent to Upscale",
|
||||
"promptGenerationStarted": "Prompt generation started",
|
||||
"uploadAndPromptGenerationFailed": "Failed to upload image and generate prompt",
|
||||
"promptExpansionFailed": "Prompt expansion failed"
|
||||
"promptExpansionFailed": "We ran into an issue. Please try prompt expansion again.",
|
||||
"maskInverted": "Mask Inverted",
|
||||
"maskInvertFailed": "Failed to Invert Mask",
|
||||
"noVisibleMasks": "No Visible Masks",
|
||||
"noVisibleMasksDesc": "Create or enable at least one inpaint mask to invert",
|
||||
"noInpaintMaskSelected": "No Inpaint Mask Selected",
|
||||
"noInpaintMaskSelectedDesc": "Select an inpaint mask to invert",
|
||||
"invalidBbox": "Invalid Bounding Box",
|
||||
"invalidBboxDesc": "The bounding box has no valid dimensions"
|
||||
},
|
||||
"popovers": {
|
||||
"clipSkip": {
|
||||
@@ -1775,6 +1807,20 @@
|
||||
"Structure controls how closely the output image will keep to the layout of the original. Low structure allows major changes, while high structure strictly maintains the original composition and layout."
|
||||
]
|
||||
},
|
||||
"tileSize": {
|
||||
"heading": "Tile Size",
|
||||
"paragraphs": [
|
||||
"Controls the size of tiles used during the upscaling process. Larger tiles use more memory but may produce better results.",
|
||||
"SD1.5 models default to 768, while SDXL models default to 1024. Reduce tile size if you encounter memory issues."
|
||||
]
|
||||
},
|
||||
"tileOverlap": {
|
||||
"heading": "Tile Overlap",
|
||||
"paragraphs": [
|
||||
"Controls the overlap between adjacent tiles during upscaling. Higher overlap values help reduce visible seams between tiles but use more memory.",
|
||||
"The default value of 128 works well for most cases, but you can adjust based on your specific needs and memory constraints."
|
||||
]
|
||||
},
|
||||
"fluxDevLicense": {
|
||||
"heading": "Non-Commercial License",
|
||||
"paragraphs": [
|
||||
@@ -1926,6 +1972,7 @@
|
||||
"canvas": "Canvas",
|
||||
"bookmark": "Bookmark for Quick Switch",
|
||||
"fitBboxToLayers": "Fit Bbox To Layers",
|
||||
"fitBboxToMasks": "Fit Bbox To Masks",
|
||||
"removeBookmark": "Remove Bookmark",
|
||||
"saveCanvasToGallery": "Save Canvas to Gallery",
|
||||
"saveBboxToGallery": "Save Bbox to Gallery",
|
||||
@@ -1990,6 +2037,7 @@
|
||||
"rasterLayer": "Raster Layer",
|
||||
"controlLayer": "Control Layer",
|
||||
"inpaintMask": "Inpaint Mask",
|
||||
"invertMask": "Invert Mask",
|
||||
"regionalGuidance": "Regional Guidance",
|
||||
"referenceImageRegional": "Reference Image (Regional)",
|
||||
"referenceImageGlobal": "Reference Image (Global)",
|
||||
@@ -2086,9 +2134,9 @@
|
||||
"resetCanvasLayers": "Reset Canvas Layers",
|
||||
"resetGenerationSettings": "Reset Generation Settings",
|
||||
"replaceCurrent": "Replace Current",
|
||||
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, <PullBboxButton>pull the bounding box into this layer</PullBboxButton>, or draw on the canvas to get started.",
|
||||
"referenceImageEmptyStateWithCanvasOptions": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this Reference Image or <PullBboxButton>pull the bounding box into this Reference Image</PullBboxButton> to get started.",
|
||||
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the <GalleryButton>gallery</GalleryButton> onto this Reference Image to get started.",
|
||||
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the gallery onto this layer, <PullBboxButton>pull the bounding box into this layer</PullBboxButton>, or draw on the canvas to get started.",
|
||||
"referenceImageEmptyStateWithCanvasOptions": "<UploadButton>Upload an image</UploadButton>, drag an image from the gallery onto this Reference Image or <PullBboxButton>pull the bounding box into this Reference Image</PullBboxButton> to get started.",
|
||||
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the gallery onto this Reference Image to get started.",
|
||||
"uploadOrDragAnImage": "Drag an image from the gallery or <UploadButton>upload an image</UploadButton>.",
|
||||
"imageNoise": "Image Noise",
|
||||
"denoiseLimit": "Denoise Limit",
|
||||
@@ -2330,6 +2378,10 @@
|
||||
"label": "Preserve Masked Region",
|
||||
"alert": "Preserving Masked Region"
|
||||
},
|
||||
"saveAllImagesToGallery": {
|
||||
"label": "Send New Generations to Gallery",
|
||||
"alert": "Sending new generations to Gallery, bypassing Canvas"
|
||||
},
|
||||
"isolatedStagingPreview": "Isolated Staging Preview",
|
||||
"isolatedPreview": "Isolated Preview",
|
||||
"isolatedLayerPreview": "Isolated Layer Preview",
|
||||
@@ -2376,6 +2428,11 @@
|
||||
"saveToGallery": "Save To Gallery",
|
||||
"showResultsOn": "Showing Results",
|
||||
"showResultsOff": "Hiding Results"
|
||||
},
|
||||
"autoSwitch": {
|
||||
"off": "Off",
|
||||
"switchOnStart": "On Start",
|
||||
"switchOnFinish": "On Finish"
|
||||
}
|
||||
},
|
||||
"upscaling": {
|
||||
@@ -2387,6 +2444,9 @@
|
||||
"upscaleModel": "Upscale Model",
|
||||
"postProcessingModel": "Post-Processing Model",
|
||||
"scale": "Scale",
|
||||
"tileControl": "Tile Control",
|
||||
"tileSize": "Tile Size",
|
||||
"tileOverlap": "Tile Overlap",
|
||||
"postProcessingMissingModelWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install a post-processing (image to image) model.",
|
||||
"missingModelsWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install the required models:",
|
||||
"mainModelDesc": "Main model (SD1.5 or SDXL architecture)",
|
||||
@@ -2551,8 +2611,9 @@
|
||||
"whatsNew": {
|
||||
"whatsNewInInvoke": "What's New in Invoke",
|
||||
"items": [
|
||||
"Inpainting: Per-mask noise levels and denoise limits.",
|
||||
"Canvas: Smarter aspect ratios for SDXL and improved scroll-to-zoom."
|
||||
"Generate images faster with new Launchpads and a simplified Generate tab.",
|
||||
"Edit with prompts using Flux Kontext Dev.",
|
||||
"Export to PSD, bulk-hide overlays, organize models & images — all in a reimagined interface built for control."
|
||||
],
|
||||
"readReleaseNotes": "Read Release Notes",
|
||||
"watchRecentReleaseVideos": "Watch Recent Release Videos",
|
||||
@@ -2561,62 +2622,16 @@
|
||||
"supportVideos": {
|
||||
"supportVideos": "Support Videos",
|
||||
"gettingStarted": "Getting Started",
|
||||
"controlCanvas": "Control Canvas",
|
||||
"watch": "Watch",
|
||||
"studioSessionsDesc1": "Check out the <StudioSessionsPlaylistLink /> for Invoke deep dives.",
|
||||
"studioSessionsDesc2": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
|
||||
"studioSessionsDesc": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
|
||||
"videos": {
|
||||
"creatingYourFirstImage": {
|
||||
"title": "Creating Your First Image",
|
||||
"description": "Introduction to creating an image from scratch using Invoke's tools."
|
||||
"gettingStarted": {
|
||||
"title": "Getting Started with Invoke",
|
||||
"description": "Complete video series covering everything you need to know to get started with Invoke, from creating your first image to advanced techniques."
|
||||
},
|
||||
"usingControlLayersAndReferenceGuides": {
|
||||
"title": "Using Control Layers and Reference Guides",
|
||||
"description": "Learn how to guide your image creation with control layers and reference images."
|
||||
},
|
||||
"understandingImageToImageAndDenoising": {
|
||||
"title": "Understanding Image-to-Image and Denoising",
|
||||
"description": "Overview of image-to-image transformations and denoising in Invoke."
|
||||
},
|
||||
"exploringAIModelsAndConceptAdapters": {
|
||||
"title": "Exploring AI Models and Concept Adapters",
|
||||
"description": "Dive into AI models and how to use concept adapters for creative control."
|
||||
},
|
||||
"creatingAndComposingOnInvokesControlCanvas": {
|
||||
"title": "Creating and Composing on Invoke's Control Canvas",
|
||||
"description": "Learn to compose images using Invoke's control canvas."
|
||||
},
|
||||
"upscaling": {
|
||||
"title": "Upscaling",
|
||||
"description": "How to upscale images with Invoke's tools to enhance resolution."
|
||||
},
|
||||
"howDoIGenerateAndSaveToTheGallery": {
|
||||
"title": "How Do I Generate and Save to the Gallery?",
|
||||
"description": "Steps to generate and save images to the gallery."
|
||||
},
|
||||
"howDoIEditOnTheCanvas": {
|
||||
"title": "How Do I Edit on the Canvas?",
|
||||
"description": "Guide to editing images directly on the canvas."
|
||||
},
|
||||
"howDoIDoImageToImageTransformation": {
|
||||
"title": "How Do I Do Image-to-Image Transformation?",
|
||||
"description": "Tutorial on performing image-to-image transformations in Invoke."
|
||||
},
|
||||
"howDoIUseControlNetsAndControlLayers": {
|
||||
"title": "How Do I Use Control Nets and Control Layers?",
|
||||
"description": "Learn to apply control layers and controlnets to your images."
|
||||
},
|
||||
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
|
||||
"title": "How Do I Use Global IP Adapters and Reference Images?",
|
||||
"description": "Introduction to adding reference images and global IP adapters."
|
||||
},
|
||||
"howDoIUseInpaintMasks": {
|
||||
"title": "How Do I Use Inpaint Masks?",
|
||||
"description": "How to apply inpaint masks for image correction and variation."
|
||||
},
|
||||
"howDoIOutpaint": {
|
||||
"title": "How Do I Outpaint?",
|
||||
"description": "Guide to outpainting beyond the original image borders."
|
||||
"studioSessions": {
|
||||
"title": "Studio Sessions",
|
||||
"description": "Deep dive sessions exploring advanced Invoke features, creative workflows, and community discussions."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import { memo, useCallback } from 'react';
|
||||
import { ErrorBoundary } from 'react-error-boundary';
|
||||
|
||||
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
|
||||
import ThemeLocaleProvider from './ThemeLocaleProvider';
|
||||
const DEFAULT_CONFIG = {};
|
||||
|
||||
interface Props {
|
||||
@@ -29,14 +30,16 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
|
||||
}, [clearStorage]);
|
||||
|
||||
return (
|
||||
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
|
||||
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
|
||||
<AppContent />
|
||||
{!didStudioInit && <Loading />}
|
||||
</Box>
|
||||
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
|
||||
<GlobalModalIsolator />
|
||||
</ErrorBoundary>
|
||||
<ThemeLocaleProvider>
|
||||
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
|
||||
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
|
||||
<AppContent />
|
||||
{!didStudioInit && <Loading />}
|
||||
</Box>
|
||||
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
|
||||
<GlobalModalIsolator />
|
||||
</ErrorBoundary>
|
||||
</ThemeLocaleProvider>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
|
||||
import { setupListeners } from '@reduxjs/toolkit/query';
|
||||
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
|
||||
import { useStudioInitAction } from 'app/hooks/useStudioInitAction';
|
||||
import { useSyncLangDirection } from 'app/hooks/useSyncLangDirection';
|
||||
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import { useSyncLoggingConfig } from 'app/logging/useSyncLoggingConfig';
|
||||
@@ -15,6 +16,8 @@ import { useDndMonitor } from 'features/dnd/useDndMonitor';
|
||||
import { useDynamicPromptsWatcher } from 'features/dynamicPrompts/hooks/useDynamicPromptsWatcher';
|
||||
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
|
||||
import { useWorkflowBuilderWatcher } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
|
||||
import { useSyncExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import { useSyncNodeErrors } from 'features/nodes/store/util/fieldValidators';
|
||||
import { useReadinessWatcher } from 'features/queue/store/readiness';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { selectLanguage } from 'features/system/store/systemSelectors';
|
||||
@@ -47,10 +50,13 @@ export const GlobalHookIsolator = memo(
|
||||
useCloseChakraTooltipsOnDragFix();
|
||||
useNavigationApi();
|
||||
useDndMonitor();
|
||||
useSyncNodeErrors();
|
||||
useSyncLangDirection();
|
||||
|
||||
// Persistent subscription to the queue counts query - canvas relies on this to know if there are pending
|
||||
// and/or in progress canvas sessions.
|
||||
useGetQueueCountsByDestinationQuery(queueCountArg);
|
||||
useSyncExecutionState();
|
||||
|
||||
useEffect(() => {
|
||||
i18n.changeLanguage(language);
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { useImageActions } from 'features/gallery/hooks/useImageActions';
|
||||
import { useLoadWorkflow } from 'features/gallery/hooks/useLoadWorkflow';
|
||||
import { useRecallAll } from 'features/gallery/hooks/useRecallAll';
|
||||
import { useRecallDimensions } from 'features/gallery/hooks/useRecallDimensions';
|
||||
import { useRecallPrompts } from 'features/gallery/hooks/useRecallPrompts';
|
||||
import { useRecallRemix } from 'features/gallery/hooks/useRecallRemix';
|
||||
import { useRecallSeed } from 'features/gallery/hooks/useRecallSeed';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { memo } from 'react';
|
||||
import { useImageDTO } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -27,59 +30,64 @@ GlobalImageHotkeys.displayName = 'GlobalImageHotkeys';
|
||||
const GlobalImageHotkeysInternal = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
|
||||
const isGalleryFocused = useIsRegionFocused('gallery');
|
||||
const isViewerFocused = useIsRegionFocused('viewer');
|
||||
const imageActions = useImageActions(imageDTO);
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const isUpscalingEnabled = useFeatureStatus('upscaling');
|
||||
|
||||
const isFocusOK = isGalleryFocused || isViewerFocused;
|
||||
|
||||
const recallAll = useRecallAll(imageDTO);
|
||||
const recallRemix = useRecallRemix(imageDTO);
|
||||
const recallPrompts = useRecallPrompts(imageDTO);
|
||||
const recallSeed = useRecallSeed(imageDTO);
|
||||
const recallDimensions = useRecallDimensions(imageDTO);
|
||||
const loadWorkflow = useLoadWorkflow(imageDTO);
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'loadWorkflow',
|
||||
category: 'viewer',
|
||||
callback: imageActions.loadWorkflow,
|
||||
options: { enabled: isGalleryFocused || isViewerFocused },
|
||||
dependencies: [imageActions.loadWorkflow, isGalleryFocused, isViewerFocused],
|
||||
callback: loadWorkflow.load,
|
||||
options: { enabled: loadWorkflow.isEnabled && isFocusOK },
|
||||
dependencies: [loadWorkflow, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'recallAll',
|
||||
category: 'viewer',
|
||||
callback: imageActions.recallAll,
|
||||
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
|
||||
dependencies: [imageActions.recallAll, isStaging, isGalleryFocused, isViewerFocused],
|
||||
callback: recallAll.recall,
|
||||
options: { enabled: recallAll.isEnabled && isFocusOK },
|
||||
dependencies: [recallAll, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'recallSeed',
|
||||
category: 'viewer',
|
||||
callback: imageActions.recallSeed,
|
||||
options: { enabled: isGalleryFocused || isViewerFocused },
|
||||
dependencies: [imageActions.recallSeed, isGalleryFocused, isViewerFocused],
|
||||
callback: recallSeed.recall,
|
||||
options: { enabled: recallSeed.isEnabled && isFocusOK },
|
||||
dependencies: [recallSeed, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'recallPrompts',
|
||||
category: 'viewer',
|
||||
callback: imageActions.recallPrompts,
|
||||
options: { enabled: isGalleryFocused || isViewerFocused },
|
||||
dependencies: [imageActions.recallPrompts, isGalleryFocused, isViewerFocused],
|
||||
callback: recallPrompts.recall,
|
||||
options: { enabled: recallPrompts.isEnabled && isFocusOK },
|
||||
dependencies: [recallPrompts, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'remix',
|
||||
category: 'viewer',
|
||||
callback: imageActions.remix,
|
||||
options: { enabled: isGalleryFocused || isViewerFocused },
|
||||
dependencies: [imageActions.remix, isGalleryFocused, isViewerFocused],
|
||||
callback: recallRemix.recall,
|
||||
options: { enabled: recallRemix.isEnabled && isFocusOK },
|
||||
dependencies: [recallRemix, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'useSize',
|
||||
category: 'viewer',
|
||||
callback: imageActions.recallSize,
|
||||
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
|
||||
dependencies: [imageActions.recallSize, isStaging, isGalleryFocused, isViewerFocused],
|
||||
});
|
||||
useRegisteredHotkeys({
|
||||
id: 'runPostprocessing',
|
||||
category: 'viewer',
|
||||
callback: imageActions.upscale,
|
||||
options: { enabled: isUpscalingEnabled && isViewerFocused },
|
||||
dependencies: [isUpscalingEnabled, imageDTO, isViewerFocused],
|
||||
callback: recallDimensions.recall,
|
||||
options: { enabled: recallDimensions.isEnabled && isFocusOK },
|
||||
dependencies: [recallDimensions, isFocusOK],
|
||||
});
|
||||
|
||||
return null;
|
||||
});
|
||||
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys';
|
||||
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
|
||||
import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal';
|
||||
import {
|
||||
NewCanvasSessionDialog,
|
||||
NewGallerySessionDialog,
|
||||
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal';
|
||||
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
|
||||
@@ -50,8 +46,6 @@ export const GlobalModalIsolator = memo(() => {
|
||||
<RefreshAfterResetModal />
|
||||
<DeleteBoardModal />
|
||||
<GlobalImageHotkeys />
|
||||
<NewGallerySessionDialog />
|
||||
<NewCanvasSessionDialog />
|
||||
<ImageContextMenu />
|
||||
<FullscreenDropzone />
|
||||
<VideosModal />
|
||||
|
||||
@@ -42,7 +42,6 @@ import { $socketOptions } from 'services/events/stores';
|
||||
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
|
||||
|
||||
const App = lazy(() => import('./App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||
|
||||
interface Props extends PropsWithChildren {
|
||||
apiUrl?: string;
|
||||
@@ -318,7 +317,7 @@ const InvokeAIUI = ({
|
||||
if (import.meta.env.MODE === 'development') {
|
||||
window.$store = $store;
|
||||
}
|
||||
() => {
|
||||
return () => {
|
||||
$store.set(undefined);
|
||||
if (import.meta.env.MODE === 'development') {
|
||||
window.$store = undefined;
|
||||
@@ -330,9 +329,7 @@ const InvokeAIUI = ({
|
||||
<React.StrictMode>
|
||||
<Provider store={store}>
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App config={config} studioInitAction={studioInitAction} />
|
||||
</ThemeLocaleProvider>
|
||||
<App config={config} studioInitAction={studioInitAction} />
|
||||
</React.Suspense>
|
||||
</Provider>
|
||||
</React.StrictMode>
|
||||
|
||||
@@ -3,43 +3,39 @@ import 'overlayscrollbars/overlayscrollbars.css';
|
||||
import '@xyflow/react/dist/base.css';
|
||||
import 'common/components/OverlayScrollbars/overlayscrollbars.css';
|
||||
|
||||
import { ChakraProvider, DarkMode, extendTheme, theme as _theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
|
||||
import { ChakraProvider, DarkMode, extendTheme, theme as baseTheme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $direction } from 'app/hooks/useSyncLangDirection';
|
||||
import type { ReactNode } from 'react';
|
||||
import { memo, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { memo, useMemo } from 'react';
|
||||
|
||||
type ThemeLocaleProviderProps = {
|
||||
children: ReactNode;
|
||||
};
|
||||
|
||||
const buildTheme = (direction: 'ltr' | 'rtl') => {
|
||||
return extendTheme({
|
||||
...baseTheme,
|
||||
direction,
|
||||
shadows: {
|
||||
...baseTheme.shadows,
|
||||
selected:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
hoverSelected:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
hoverUnselected:
|
||||
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
|
||||
selectedForCompare:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
hoverSelectedForCompare:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
||||
const { i18n } = useTranslation();
|
||||
|
||||
const direction = i18n.dir();
|
||||
|
||||
const theme = useMemo(() => {
|
||||
return extendTheme({
|
||||
..._theme,
|
||||
direction,
|
||||
shadows: {
|
||||
..._theme.shadows,
|
||||
selected:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
hoverSelected:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
hoverUnselected:
|
||||
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
|
||||
selectedForCompare:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
hoverSelectedForCompare:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
},
|
||||
});
|
||||
}, [direction]);
|
||||
|
||||
useEffect(() => {
|
||||
document.body.dir = direction;
|
||||
}, [direction]);
|
||||
const direction = useStore($direction);
|
||||
const theme = useMemo(() => buildTheme(direction), [direction]);
|
||||
|
||||
return (
|
||||
<ChakraProvider theme={theme} toastOptions={TOAST_OPTIONS}>
|
||||
|
||||
@@ -20,7 +20,7 @@ import {
import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { LAUNCHPAD_PANEL_ID, WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { atom } from 'nanostores';
import { useCallback, useEffect } from 'react';
@@ -91,6 +91,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
      const overrides: Partial<CanvasRasterLayerState> = {
        objects: [imageObject],
      };
      await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
      store.dispatch(canvasReset());
      store.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
      store.dispatch(sentImageToCanvas());
@@ -157,16 +158,16 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
  );

  const handleGoToDestination = useCallback(
    (destination: StudioDestinationAction['data']['destination']) => {
    async (destination: StudioDestinationAction['data']['destination']) => {
      switch (destination) {
        case 'generation':
          // Go to the canvas tab, open the image viewer, and enable send-to-gallery mode
          // Go to the generate tab, open the launchpad
          await navigationApi.focusPanel('generate', LAUNCHPAD_PANEL_ID);
          store.dispatch(paramsReset());
          store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
          break;
        case 'canvas':
          // Go to the canvas tab, close the image viewer, and disable send-to-gallery mode
          store.dispatch(canvasReset());
          // Go to the canvas tab, open the launchpad
          await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
          break;
        case 'workflows':
          // Go to the workflows tab
36
invokeai/frontend/web/src/app/hooks/useSyncLangDirection.ts
Normal file
@@ -0,0 +1,36 @@
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { atom } from 'nanostores';
import { useEffect } from 'react';
import { useTranslation } from 'react-i18next';

/**
 * Global atom storing the language direction, to be consumed by the Chakra theme.
 *
 * Why do we need this? We have a kind of catch-22:
 * - The Chakra theme needs to know the language direction to apply the correct styles.
 * - The language direction is determined by i18n and the language selection.
 * - We want our error boundary to be themed.
 * - It's possible that i18n can throw if the language selection is invalid or not supported.
 *
 * Previously, we had the logic in this file in the theme provider, which wrapped the error boundary. The error
 * was properly themed. But then, if i18n threw in the theme provider, the error boundary does not catch the
 * error. The app would crash to a white screen.
 *
 * We tried swapping the component hierarchy so that the error boundary wraps the theme provider, but then the
 * error boundary isn't themed!
 *
 * The solution is to move this i18n direction logic out of the theme provider and into a hook that we can use
 * within the error boundary. The error boundary will be themed, _and_ catch any i18n errors.
 */
export const $direction = atom<'ltr' | 'rtl'>('ltr');

export const useSyncLangDirection = () => {
  useAssertSingleton('useSyncLangDirection');
  const { i18n, t } = useTranslation();

  useEffect(() => {
    const direction = i18n.dir();
    $direction.set(direction);
    document.body.dir = direction;
  }, [i18n, t]);
};
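A minimal sketch of how the new atom and hook are meant to fit together, assuming a buildTheme helper like the one shown earlier in this diff. Component names, import paths for ChakraProvider, and the declared stand-ins (App, AppErrorBoundary, buildTheme) are illustrative, not taken from the repository:

import { useStore } from '@nanostores/react';
import { ChakraProvider } from '@invoke-ai/ui-library'; // import path assumed
import { $direction, useSyncLangDirection } from 'app/hooks/useSyncLangDirection';
import { useMemo, type PropsWithChildren } from 'react';

declare const App: () => JSX.Element; // stand-in for the real app tree
declare const AppErrorBoundary: (props: PropsWithChildren) => JSX.Element; // stand-in error boundary
declare const buildTheme: (direction: 'ltr' | 'rtl') => Record<string, unknown>; // see buildTheme above

// The provider only reads the atom, which has a safe 'ltr' default, so it can never throw via i18n.
const ThemeProvider = ({ children }: PropsWithChildren) => {
  const direction = useStore($direction);
  const theme = useMemo(() => buildTheme(direction), [direction]);
  return <ChakraProvider theme={theme}>{children}</ChakraProvider>;
};

// The hook that actually calls i18n.dir() runs inside the boundary, so an i18n failure
// renders a themed error screen instead of crashing to a white page.
const AppContent = () => {
  useSyncLangDirection();
  return <App />;
};

export const Root = () => (
  <ThemeProvider>
    <AppErrorBoundary>
      <AppContent />
    </AppErrorBoundary>
  </ThemeProvider>
);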
@@ -2,7 +2,7 @@ import { createLogWriter } from '@roarr/browser-log-writer';
import { atom } from 'nanostores';
import type { Logger, MessageSerializer } from 'roarr';
import { ROARR, Roarr } from 'roarr';
import { z } from 'zod/v4';
import { z } from 'zod';

const serializeMessage: MessageSerializer = (message) => {
  return JSON.stringify(message);
@@ -1,14 +1,28 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { bboxSyncedToOptimalDimension } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
|
||||
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectBboxModelBase } from 'features/controlLayers/store/selectors';
|
||||
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
|
||||
import {
|
||||
selectAllEntitiesOfType,
|
||||
selectBboxModelBase,
|
||||
selectCanvasSlice,
|
||||
} from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { selectGlobalRefImageModels, selectRegionalRefImageModels } from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isChatGPT4oModelConfig,
|
||||
isFluxKontextApiModelConfig,
|
||||
isFluxKontextModelConfig,
|
||||
isFluxReduxModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
const log = logger('models');
|
||||
|
||||
@@ -25,9 +39,8 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
}
|
||||
|
||||
const newModel = result.data;
|
||||
|
||||
const newBaseModel = newModel.base;
|
||||
const didBaseModelChange = state.params.model?.base !== newBaseModel;
|
||||
const newBase = newModel.base;
|
||||
const didBaseModelChange = state.params.model?.base !== newBase;
|
||||
|
||||
if (didBaseModelChange) {
|
||||
// we may need to reset some incompatible submodels
|
||||
@@ -35,7 +48,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
|
||||
// handle incompatible loras
|
||||
state.loras.loras.forEach((lora) => {
|
||||
if (lora.model.base !== newBaseModel) {
|
||||
if (lora.model.base !== newBase) {
|
||||
dispatch(loraDeleted({ id: lora.id }));
|
||||
modelsCleared += 1;
|
||||
}
|
||||
@@ -43,20 +56,82 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
|
||||
// handle incompatible vae
|
||||
const { vae } = state.params;
|
||||
if (vae && vae.base !== newBaseModel) {
|
||||
if (vae && vae.base !== newBase) {
|
||||
dispatch(vaeSelected(null));
|
||||
modelsCleared += 1;
|
||||
}
|
||||
|
||||
// handle incompatible controlnets
|
||||
// state.canvas.present.controlAdapters.entities.forEach((ca) => {
|
||||
// if (ca.model?.base !== newBaseModel) {
|
||||
// modelsCleared += 1;
|
||||
// if (ca.isEnabled) {
|
||||
// dispatch(entityIsEnabledToggled({ entityIdentifier: { id: ca.id, type: 'control_adapter' } }));
|
||||
// }
|
||||
// }
|
||||
// });
|
||||
// Handle incompatible reference image models - switch to first compatible model, with some smart logic
|
||||
// to choose the best available model based on the new main model.
|
||||
const allRefImageModels = selectGlobalRefImageModels(state).filter(({ base }) => base === newBase);
|
||||
|
||||
let newGlobalRefImageModel = null;
|
||||
|
||||
// Certain models require the ref image model to be the same as the main model - others just need a matching
|
||||
// base. Helper to grab the first exact match or the first available model if no exact match is found.
|
||||
const exactMatchOrFirst = <T extends AnyModelConfig>(candidates: T[]): T | null =>
|
||||
candidates.find(({ key }) => key === newModel.key) ?? candidates[0] ?? null;
|
||||
|
||||
// The only way we can differentiate between FLUX and FLUX Kontext is to check for "kontext" in the name
|
||||
if (newModel.base === 'flux' && newModel.name.toLowerCase().includes('kontext')) {
|
||||
const fluxKontextDevModels = allRefImageModels.filter(isFluxKontextModelConfig);
|
||||
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextDevModels);
|
||||
} else if (newModel.base === 'chatgpt-4o') {
|
||||
const chatGPT4oModels = allRefImageModels.filter(isChatGPT4oModelConfig);
|
||||
newGlobalRefImageModel = exactMatchOrFirst(chatGPT4oModels);
|
||||
} else if (newModel.base === 'flux-kontext') {
|
||||
const fluxKontextApiModels = allRefImageModels.filter(isFluxKontextApiModelConfig);
|
||||
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextApiModels);
|
||||
} else if (newModel.base === 'flux') {
|
||||
const fluxReduxModels = allRefImageModels.filter(isFluxReduxModelConfig);
|
||||
newGlobalRefImageModel = fluxReduxModels[0] ?? null;
|
||||
} else {
|
||||
newGlobalRefImageModel = allRefImageModels[0] ?? null;
|
||||
}
|
||||
|
||||
// All ref image entities are updated to use the same new model
|
||||
const refImageEntities = selectReferenceImageEntities(state);
|
||||
for (const entity of refImageEntities) {
|
||||
const shouldUpdateModel =
|
||||
(entity.config.model && entity.config.model.base !== newBase) ||
|
||||
(!entity.config.model && newGlobalRefImageModel);
|
||||
|
||||
if (shouldUpdateModel) {
|
||||
dispatch(
|
||||
refImageModelChanged({
|
||||
id: entity.id,
|
||||
modelConfig: newGlobalRefImageModel,
|
||||
})
|
||||
);
|
||||
modelsCleared += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// For regional guidance, there is no smart logic - we just pick the first available model.
|
||||
const newRegionalRefImageModel = selectRegionalRefImageModels(state)[0] ?? null;
|
||||
|
||||
// All regional guidance entities are updated to use the same new model.
|
||||
const canvasState = selectCanvasSlice(state);
|
||||
const canvasRegionalGuidanceEntities = selectAllEntitiesOfType(canvasState, 'regional_guidance');
|
||||
for (const entity of canvasRegionalGuidanceEntities) {
|
||||
for (const refImage of entity.referenceImages) {
|
||||
// Only change the model if the current one is not compatible with the new base model.
|
||||
const shouldUpdateModel =
|
||||
(refImage.config.model && refImage.config.model.base !== newBase) ||
|
||||
(!refImage.config.model && newRegionalRefImageModel);
|
||||
|
||||
if (shouldUpdateModel) {
|
||||
dispatch(
|
||||
rgRefImageModelChanged({
|
||||
entityIdentifier: getEntityIdentifier(entity),
|
||||
referenceImageId: refImage.id,
|
||||
modelConfig: newRegionalRefImageModel,
|
||||
})
|
||||
);
|
||||
modelsCleared += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (modelsCleared > 0) {
|
||||
toast({
|
||||
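To make the reference-image selection heuristic in the hunk above concrete, here is a toy run of the exactMatchOrFirst helper with model configs reduced to bare keys. The keys are invented for illustration; only the helper's logic is copied from the diff:

// Simplified stand-in for the helper defined above.
type Cfg = { key: string };
const exactMatchOrFirst = (candidates: Cfg[], newModelKey: string): Cfg | null =>
  candidates.find(({ key }) => key === newModelKey) ?? candidates[0] ?? null;

const candidates = [{ key: 'kontext-dev-a' }, { key: 'kontext-dev-b' }];
exactMatchOrFirst(candidates, 'kontext-dev-b'); // exact match wins      -> { key: 'kontext-dev-b' }
exactMatchOrFirst(candidates, 'other-key');     // no exact match        -> { key: 'kontext-dev-a' } (first compatible)
exactMatchOrFirst([], 'other-key');             // nothing compatible    -> null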
@@ -77,7 +152,8 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
if (modelBase !== state.params.model?.base) {
|
||||
// Sync generate tab settings whenever the model base changes
|
||||
dispatch(syncedToOptimalDimension());
|
||||
if (!selectIsStaging(state)) {
|
||||
const isStaging = buildSelectIsStaging(selectCanvasSessionId(state))(state);
|
||||
if (!isStaging) {
|
||||
// Canvas tab only syncs if not staging
|
||||
dispatch(bboxSyncedToOptimalDimension());
|
||||
}
|
||||
|
||||
@@ -15,7 +15,11 @@ import { refImageModelChanged, selectRefImagesSlice } from 'features/controlLaye
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier, isFLUXReduxConfig, isIPAdapterConfig } from 'features/controlLayers/store/types';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import {
|
||||
postProcessingModelChanged,
|
||||
tileControlnetModelChanged,
|
||||
upscaleModelChanged,
|
||||
} from 'features/parameters/store/upscaleSlice';
|
||||
import {
|
||||
zParameterCLIPEmbedModel,
|
||||
zParameterSpandrelImageToImageModel,
|
||||
@@ -28,6 +32,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isCLIPEmbedModelConfig,
|
||||
isControlLayerModelConfig,
|
||||
isControlNetModelConfig,
|
||||
isFluxReduxModelConfig,
|
||||
isFluxVAEModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
@@ -71,6 +76,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
||||
handleControlAdapterModels(models, state, dispatch, log);
|
||||
handlePostProcessingModel(models, state, dispatch, log);
|
||||
handleUpscaleModel(models, state, dispatch, log);
|
||||
handleTileControlNetModel(models, state, dispatch, log);
|
||||
handleIPAdapterModels(models, state, dispatch, log);
|
||||
handleT5EncoderModels(models, state, dispatch, log);
|
||||
handleCLIPEmbedModels(models, state, dispatch, log);
|
||||
@@ -345,6 +351,46 @@ const handleUpscaleModel: ModelHandler = (models, state, dispatch, log) => {
|
||||
}
|
||||
};
|
||||
|
||||
const handleTileControlNetModel: ModelHandler = (models, state, dispatch, log) => {
|
||||
const selectedTileControlNetModel = state.upscale.tileControlnetModel;
|
||||
const controlNetModels = models.filter(isControlNetModelConfig);
|
||||
|
||||
// If the currently selected model is available, we don't need to do anything
|
||||
if (selectedTileControlNetModel && controlNetModels.some((m) => m.key === selectedTileControlNetModel.key)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// The only way we have to identify a model as a tile model is by its name containing 'tile' :)
|
||||
const tileModel = controlNetModels.find((m) => m.name.toLowerCase().includes('tile'));
|
||||
|
||||
// If we have a tile model, select it
|
||||
if (tileModel) {
|
||||
log.debug(
|
||||
{ selectedTileControlNetModel, tileModel },
|
||||
'No selected tile ControlNet model or selected model is not available, selecting tile model'
|
||||
);
|
||||
dispatch(tileControlnetModelChanged(tileModel));
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, select the first available ControlNet model
|
||||
const firstModel = controlNetModels[0] || null;
|
||||
if (firstModel) {
|
||||
log.debug(
|
||||
{ selectedTileControlNetModel, firstModel },
|
||||
'No tile ControlNet model found, selecting first available ControlNet model'
|
||||
);
|
||||
dispatch(tileControlnetModelChanged(firstModel));
|
||||
return;
|
||||
}
|
||||
|
||||
// No available models, we should clear the selected model - but only if we have one selected
|
||||
if (selectedTileControlNetModel) {
|
||||
log.debug({ selectedTileControlNetModel }, 'Selected tile ControlNet model is not available, clearing');
|
||||
dispatch(tileControlnetModelChanged(null));
|
||||
}
|
||||
};
|
||||
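The same fallback chain as handleTileControlNetModel above, distilled into a pure function so the order of preference is explicit. This is a paraphrase for clarity, not code from the diff:

// Preference order: keep the current model if it is still installed, else the first ControlNet
// whose name contains 'tile', else the first available ControlNet, else nothing.
type ControlNetLike = { key: string; name: string };

const pickTileControlNet = <T extends ControlNetLike>(current: T | null, available: T[]): T | null => {
  if (current && available.some((m) => m.key === current.key)) {
    return current;
  }
  return available.find((m) => m.name.toLowerCase().includes('tile')) ?? available[0] ?? null;
};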
|
||||
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const selectedT5EncoderModel = state.params.t5EncoderModel;
|
||||
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m));
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { isNil } from 'es-toolkit';
|
||||
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import {
|
||||
heightChanged,
|
||||
setCfgRescaleMultiplier,
|
||||
setCfgScale,
|
||||
setGuidance,
|
||||
@@ -10,6 +11,7 @@ import {
|
||||
setSteps,
|
||||
vaePrecisionChanged,
|
||||
vaeSelected,
|
||||
widthChanged,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||
import {
|
||||
@@ -24,6 +26,7 @@ import {
|
||||
zParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { t } from 'i18next';
|
||||
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
@@ -112,16 +115,26 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
||||
}
|
||||
const setSizeOptions = { updateAspectRatio: true, clamp: true };
|
||||
|
||||
const isStaging = selectIsStaging(getState());
|
||||
if (!isStaging && width) {
|
||||
const isStaging = buildSelectIsStaging(selectCanvasSessionId(state))(state);
|
||||
|
||||
const activeTab = selectActiveTab(getState());
|
||||
if (activeTab === 'generate') {
|
||||
if (isParameterWidth(width)) {
|
||||
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
|
||||
dispatch(widthChanged({ width, ...setSizeOptions }));
|
||||
}
|
||||
if (isParameterHeight(height)) {
|
||||
dispatch(heightChanged({ height, ...setSizeOptions }));
|
||||
}
|
||||
}
|
||||
|
||||
if (!isStaging && height) {
|
||||
if (isParameterHeight(height)) {
|
||||
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
|
||||
if (activeTab === 'canvas') {
|
||||
if (!isStaging) {
|
||||
if (isParameterWidth(width)) {
|
||||
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
|
||||
}
|
||||
if (isParameterHeight(height)) {
|
||||
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -67,6 +67,8 @@ export type Feature =
|
||||
| 'scale'
|
||||
| 'creativity'
|
||||
| 'structure'
|
||||
| 'tileSize'
|
||||
| 'tileOverlap'
|
||||
| 'optimizedDenoising'
|
||||
| 'fluxDevLicense';
|
||||
|
||||
|
||||
@@ -11,9 +11,13 @@ import {
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import { selectPickerCompactViewStates } from 'features/ui/store/uiSelectors';
|
||||
import { pickerCompactViewStateChanged } from 'features/ui/store/uiSlice';
|
||||
import type { AnyStore, ReadableAtom, Task, WritableAtom } from 'nanostores';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import type { StoreValues } from 'nanostores/computed';
|
||||
@@ -87,14 +91,10 @@ export const buildGroup = <T extends object>(group: Omit<Group<T>, typeof unique
|
||||
[uniqueGroupKey]: true,
|
||||
});
|
||||
|
||||
const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is Group<T> => {
|
||||
export const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is Group<T> => {
|
||||
return uniqueGroupKey in optionOrGroup && optionOrGroup[uniqueGroupKey] === true;
|
||||
};
|
||||
|
||||
export const isOption = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is T => {
|
||||
return !(uniqueGroupKey in optionOrGroup);
|
||||
};
|
||||
|
||||
const DefaultOptionComponent = typedMemo(<T extends object>({ option }: { option: T }) => {
|
||||
const { getOptionId } = usePickerContext();
|
||||
return <Text fontWeight="bold">{getOptionId(option)}</Text>;
|
||||
@@ -144,6 +144,10 @@ const NoMatchesFallbackWrapper = typedMemo(({ children }: PropsWithChildren) =>
|
||||
NoMatchesFallbackWrapper.displayName = 'NoMatchesFallbackWrapper';
|
||||
|
||||
type PickerProps<T extends object> = {
|
||||
/**
|
||||
* Unique identifier for this picker instance. Used to persist compact view state.
|
||||
*/
|
||||
pickerId?: string;
|
||||
/**
|
||||
* The options to display in the picker. This can be a flat array of options or an array of groups.
|
||||
*/
|
||||
@@ -208,10 +212,18 @@ type PickerProps<T extends object> = {
|
||||
initialGroupStates?: GroupStatusMap;
|
||||
};
|
||||
|
||||
const buildSelectIsCompactView = (pickerId?: string) =>
|
||||
createSelector([selectPickerCompactViewStates], (compactViewStates) => {
|
||||
if (!pickerId) {
|
||||
return true;
|
||||
}
|
||||
return compactViewStates[pickerId] ?? true;
|
||||
});
|
||||
|
||||
export type PickerContextState<T extends object> = {
|
||||
$optionsOrGroups: WritableAtom<OptionOrGroup<T>[]>;
|
||||
$groupStatusMap: WritableAtom<GroupStatusMap>;
|
||||
$compactView: WritableAtom<boolean>;
|
||||
isCompactView: boolean;
|
||||
$activeOptionId: WritableAtom<string | undefined>;
|
||||
$filteredOptions: WritableAtom<OptionOrGroup<T>[]>;
|
||||
$flattenedFilteredOptions: ReadableAtom<T[]>;
|
||||
@@ -237,6 +249,7 @@ export type PickerContextState<T extends object> = {
|
||||
OptionComponent: React.ComponentType<{ option: T } & BoxProps>;
|
||||
NextToSearchBar?: React.ReactNode;
|
||||
searchable?: boolean;
|
||||
pickerId?: string;
|
||||
};
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
@@ -507,6 +520,7 @@ const countOptions = <T extends object>(optionsOrGroups: OptionOrGroup<T>[]) =>
|
||||
|
||||
export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
const {
|
||||
pickerId,
|
||||
getOptionId,
|
||||
optionsOrGroups,
|
||||
handleRef,
|
||||
@@ -525,12 +539,12 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
} = props;
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(
|
||||
optionsOrGroups,
|
||||
initialGroupStates
|
||||
);
|
||||
const $activeOptionId = useAtom(getFirstOptionId(optionsOrGroups, getOptionId));
|
||||
const $compactView = useAtom(true);
|
||||
const $optionsOrGroups = useAtom(optionsOrGroups);
|
||||
const $totalOptionCount = useComputed([$optionsOrGroups], countOptions);
|
||||
const $filteredOptions = useAtom<OptionOrGroup<T>[]>([]);
|
||||
@@ -542,6 +556,9 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
const $searchTerm = useAtom('');
|
||||
const $selectedItemId = useComputed([$selectedItem], (item) => (item ? getOptionId(item) : undefined));
|
||||
|
||||
const selectIsCompactView = useMemo(() => buildSelectIsCompactView(pickerId), [pickerId]);
|
||||
const isCompactView = useAppSelector(selectIsCompactView);
|
||||
|
||||
const onSelectById = useCallback(
|
||||
(id: string) => {
|
||||
const options = $filteredOptions.get();
|
||||
@@ -569,7 +586,7 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
({
|
||||
$optionsOrGroups,
|
||||
$groupStatusMap,
|
||||
$compactView,
|
||||
isCompactView,
|
||||
$activeOptionId,
|
||||
$filteredOptions,
|
||||
$flattenedFilteredOptions,
|
||||
@@ -595,11 +612,12 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
$hasOptions,
|
||||
$hasFilteredOptions,
|
||||
$filteredOptionsCount,
|
||||
pickerId,
|
||||
}) satisfies PickerContextState<T>,
|
||||
[
|
||||
$optionsOrGroups,
|
||||
$groupStatusMap,
|
||||
$compactView,
|
||||
isCompactView,
|
||||
$activeOptionId,
|
||||
$filteredOptions,
|
||||
$flattenedFilteredOptions,
|
||||
@@ -623,6 +641,7 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
$hasOptions,
|
||||
$hasFilteredOptions,
|
||||
$filteredOptionsCount,
|
||||
pickerId,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -873,15 +892,17 @@ GroupToggleButtons.displayName = 'GroupToggleButtons';
|
||||
|
||||
const CompactViewToggleButton = typedMemo(<T extends object>() => {
|
||||
const { t } = useTranslation();
|
||||
const { $compactView } = usePickerContext<T>();
|
||||
const compactView = useStore($compactView);
|
||||
const dispatch = useAppDispatch();
|
||||
const { isCompactView, pickerId } = usePickerContext<T>();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
$compactView.set(!$compactView.get());
|
||||
}, [$compactView]);
|
||||
if (pickerId) {
|
||||
dispatch(pickerCompactViewStateChanged({ pickerId, isCompact: !isCompactView }));
|
||||
}
|
||||
}, [dispatch, pickerId, isCompactView]);
|
||||
|
||||
const label = compactView ? t('common.fullView') : t('common.compactView');
|
||||
const icon = compactView ? <PiArrowsOutLineVerticalBold /> : <PiArrowsInLineVerticalBold />;
|
||||
const label = isCompactView ? t('common.fullView') : t('common.compactView');
|
||||
const icon = isCompactView ? <PiArrowsOutLineVerticalBold /> : <PiArrowsInLineVerticalBold />;
|
||||
|
||||
return <IconButton aria-label={label} tooltip={label} size="sm" variant="ghost" icon={icon} onClick={onClick} />;
|
||||
});
|
||||
@@ -928,8 +949,7 @@ const listSx = {
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
const PickerList = typedMemo(<T extends object>() => {
|
||||
const { getOptionId, $compactView, $filteredOptions } = usePickerContext<T>();
|
||||
const compactView = useStore($compactView);
|
||||
const { getOptionId, isCompactView, $filteredOptions } = usePickerContext<T>();
|
||||
const filteredOptions = useStore($filteredOptions);
|
||||
|
||||
if (filteredOptions.length === 0) {
|
||||
@@ -938,10 +958,10 @@ const PickerList = typedMemo(<T extends object>() => {
|
||||
|
||||
return (
|
||||
<ScrollableContent>
|
||||
<Flex sx={listSx} data-is-compact={compactView}>
|
||||
<Flex sx={listSx} data-is-compact={isCompactView}>
|
||||
{filteredOptions.map((optionOrGroup, i) => {
|
||||
if (isGroup(optionOrGroup)) {
|
||||
const withDivider = !compactView && i < filteredOptions.length - 1;
|
||||
const withDivider = !isCompactView && i < filteredOptions.length - 1;
|
||||
return (
|
||||
<React.Fragment key={optionOrGroup.id}>
|
||||
<PickerGroup group={optionOrGroup} />
|
||||
@@ -1083,14 +1103,13 @@ const groupHeaderSx = {
|
||||
|
||||
const PickerGroupHeader = typedMemo(<T extends object>({ group }: { group: Group<T> }) => {
|
||||
const { t } = useTranslation();
|
||||
const { $compactView } = usePickerContext<T>();
|
||||
const compactView = useStore($compactView);
|
||||
const { isCompactView } = usePickerContext<T>();
|
||||
const color = getGroupColor(group);
|
||||
const name = getGroupName(group);
|
||||
const count = getGroupCount(group, t);
|
||||
|
||||
return (
|
||||
<Flex sx={groupHeaderSx} data-is-compact={compactView}>
|
||||
<Flex sx={groupHeaderSx} data-is-compact={isCompactView}>
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Text fontSize="sm" fontWeight="semibold" color={color} noOfLines={1}>
|
||||
{name}
|
||||
|
||||
@@ -6,7 +6,6 @@ import { atom, computed } from 'nanostores';
|
||||
import type { RefObject } from 'react';
|
||||
import { useEffect } from 'react';
|
||||
import { objectKeys } from 'tsafe';
|
||||
import z from 'zod/v4';
|
||||
|
||||
/**
|
||||
* We need to manage focus regions to conditionally enable hotkeys:
|
||||
@@ -28,10 +27,7 @@ import z from 'zod/v4';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
/**
|
||||
* The names of the focus regions.
|
||||
*/
|
||||
const zFocusRegionName = z.enum([
|
||||
const REGION_NAMES = [
|
||||
'launchpad',
|
||||
'viewer',
|
||||
'gallery',
|
||||
@@ -41,13 +37,16 @@ const zFocusRegionName = z.enum([
|
||||
'workflows',
|
||||
'progress',
|
||||
'settings',
|
||||
]);
|
||||
export type FocusRegionName = z.infer<typeof zFocusRegionName>;
|
||||
] as const;
|
||||
/**
|
||||
* The names of the focus regions.
|
||||
*/
|
||||
export type FocusRegionName = (typeof REGION_NAMES)[number];
|
||||
|
||||
/**
|
||||
* A map of focus regions to the elements that are part of that region.
|
||||
*/
|
||||
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = zFocusRegionName.options.values().reduce(
|
||||
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = REGION_NAMES.reduce(
|
||||
(acc, region) => {
|
||||
acc[region] = new Set<HTMLElement>();
|
||||
return acc;
|
||||
|
||||
115
invokeai/frontend/web/src/common/hooks/useAsyncState.ts
Normal file
@@ -0,0 +1,115 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { WrappedError } from 'common/util/result';
|
||||
import type { Atom } from 'nanostores';
|
||||
import { atom } from 'nanostores';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
type SuccessState<T> = {
|
||||
status: 'success';
|
||||
value: T;
|
||||
error: null;
|
||||
};
|
||||
|
||||
type ErrorState = {
|
||||
status: 'error';
|
||||
value: null;
|
||||
error: Error;
|
||||
};
|
||||
|
||||
type PendingState = {
|
||||
status: 'pending';
|
||||
value: null;
|
||||
error: null;
|
||||
};
|
||||
|
||||
type IdleState = {
|
||||
status: 'idle';
|
||||
value: null;
|
||||
error: null;
|
||||
};
|
||||
|
||||
export type State<T> = IdleState | PendingState | SuccessState<T> | ErrorState;
|
||||
|
||||
type UseAsyncStateOptions = {
|
||||
immediate?: boolean;
|
||||
};
|
||||
|
||||
type UseAsyncReturn<T> = {
|
||||
$state: Atom<State<T>>;
|
||||
trigger: () => Promise<void>;
|
||||
reset: () => void;
|
||||
};
|
||||
|
||||
export const useAsyncState = <T>(execute: () => Promise<T>, options?: UseAsyncStateOptions): UseAsyncReturn<T> => {
|
||||
const $state = useState(() =>
|
||||
atom<State<T>>({
|
||||
status: 'idle',
|
||||
value: null,
|
||||
error: null,
|
||||
})
|
||||
)[0];
|
||||
|
||||
const trigger = useCallback(async () => {
|
||||
$state.set({
|
||||
status: 'pending',
|
||||
value: null,
|
||||
error: null,
|
||||
});
|
||||
try {
|
||||
const value = await execute();
|
||||
$state.set({
|
||||
status: 'success',
|
||||
value,
|
||||
error: null,
|
||||
});
|
||||
} catch (error) {
|
||||
$state.set({
|
||||
status: 'error',
|
||||
value: null,
|
||||
error: WrappedError.wrap(error),
|
||||
});
|
||||
}
|
||||
}, [$state, execute]);
|
||||
|
||||
const reset = useCallback(() => {
|
||||
$state.set({
|
||||
status: 'idle',
|
||||
value: null,
|
||||
error: null,
|
||||
});
|
||||
}, [$state]);
|
||||
|
||||
useEffect(() => {
|
||||
if (options?.immediate) {
|
||||
trigger();
|
||||
}
|
||||
}, [options?.immediate, trigger]);
|
||||
|
||||
const api = useMemo(
|
||||
() =>
|
||||
({
|
||||
$state,
|
||||
trigger,
|
||||
reset,
|
||||
}) satisfies UseAsyncReturn<T>,
|
||||
[$state, trigger, reset]
|
||||
);
|
||||
|
||||
return api;
|
||||
};
|
||||
|
||||
type UseAsyncReturnReactive<T> = {
|
||||
state: State<T>;
|
||||
trigger: () => Promise<void>;
|
||||
reset: () => void;
|
||||
};
|
||||
|
||||
export const useAsyncStateReactive = <T>(
|
||||
execute: () => Promise<T>,
|
||||
options?: UseAsyncStateOptions
|
||||
): UseAsyncReturnReactive<T> => {
|
||||
const { $state, trigger, reset } = useAsyncState(execute, options);
|
||||
const state = useStore($state);
|
||||
|
||||
return { state, trigger, reset };
|
||||
};
|
||||
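A hypothetical consumer of the new hook, showing the intended status flow through idle, pending, success, and error. The component and the fetched endpoint are invented for illustration; the hook API matches the file above:

import { useCallback } from 'react';
import { useAsyncStateReactive } from 'common/hooks/useAsyncState';

const ExampleVersionBadge = () => {
  // Memoize the executor so the hook's trigger callback stays stable between renders.
  const fetchVersion = useCallback(async () => {
    const res = await fetch('/api/v1/app/version'); // illustrative endpoint
    if (!res.ok) {
      throw new Error(`HTTP ${res.status}`);
    }
    return (await res.json()) as { version: string };
  }, []);

  // immediate: true runs the executor once on mount via the hook's internal useEffect.
  const { state, trigger } = useAsyncStateReactive(fetchVersion, { immediate: true });

  if (state.status === 'idle' || state.status === 'pending') {
    return <span>loading…</span>;
  }
  if (state.status === 'error') {
    return <button onClick={trigger}>retry: {state.error.message}</button>;
  }
  // Discriminated union: after the checks above, state.value is narrowed to the success type.
  return <span>{state.value.version}</span>;
};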
@@ -6,7 +6,7 @@ import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { FileRejection } from 'react-dropzone';
|
||||
import type { Accept, FileRejection } from 'react-dropzone';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiUploadBold } from 'react-icons/pi';
|
||||
@@ -15,6 +15,18 @@ import type { ImageDTO } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
import type { SetOptional } from 'type-fest';
|
||||
|
||||
const addUpperCaseReducer = (acc: string[], ext: string) => {
|
||||
acc.push(ext);
|
||||
acc.push(ext.toUpperCase());
|
||||
return acc;
|
||||
};
|
||||
|
||||
export const dropzoneAccept: Accept = {
|
||||
'image/png': ['.png'].reduce(addUpperCaseReducer, [] as string[]),
|
||||
'image/jpeg': ['.jpg', '.jpeg', '.png'].reduce(addUpperCaseReducer, [] as string[]),
|
||||
'image/webp': ['.webp'].reduce(addUpperCaseReducer, [] as string[]),
|
||||
};
|
||||
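For reference, a standalone check of what the reducer above produces: every extension is accepted in both lower- and upper-case form. The logic is copied from the constant above rather than imported, so the snippet runs on its own:

const addUpperCase = (acc: string[], ext: string) => {
  acc.push(ext);
  acc.push(ext.toUpperCase());
  return acc;
};

console.log(['.png'].reduce(addUpperCase, [] as string[]));
// -> ['.png', '.PNG']
console.log(['.jpg', '.jpeg', '.png'].reduce(addUpperCase, [] as string[]));
// -> ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']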
|
||||
import { useClientSideUpload } from './useClientSideUpload';
|
||||
type UseImageUploadButtonArgs =
|
||||
| {
|
||||
@@ -164,11 +176,7 @@ export const useImageUploadButton = ({
|
||||
getInputProps: getUploadInputProps,
|
||||
open: openUploader,
|
||||
} = useDropzone({
|
||||
accept: {
|
||||
'image/png': ['.png'],
|
||||
'image/jpeg': ['.jpg', '.jpeg', '.png'],
|
||||
'image/webp': ['.webp'],
|
||||
},
|
||||
accept: dropzoneAccept,
|
||||
onDropAccepted,
|
||||
onDropRejected,
|
||||
disabled: isDisabled,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
export const preventDefault = (e: React.MouseEvent) => {
|
||||
import type { MouseEvent } from 'react';
|
||||
|
||||
export const preventDefault = (e: MouseEvent) => {
|
||||
e.preventDefault();
|
||||
};
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import type React from 'react';
|
||||
import { memo } from 'react';
|
||||
|
||||
/**
|
||||
* A typed version of React.memo, useful for components that take generics.
|
||||
*/
|
||||
export const typedMemo: <T extends keyof JSX.IntrinsicElements | React.JSXElementConstructor<any>>(
|
||||
export const typedMemo: <T extends keyof React.JSX.IntrinsicElements | React.JSXElementConstructor<any>>(
|
||||
component: T,
|
||||
propsAreEqual?: (prevProps: React.ComponentProps<T>, nextProps: React.ComponentProps<T>) => boolean
|
||||
) => T & { displayName?: string } = memo;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import type { z } from 'zod/v4';
|
||||
import type { z } from 'zod';
|
||||
|
||||
/**
|
||||
* Helper to create a type guard from a zod schema. The type guard will infer the schema's TS type.
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, ConfirmationAlertDialog, Flex, FormControl, Text } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import {
|
||||
@@ -14,7 +13,7 @@ import { useTranslation } from 'react-i18next';
|
||||
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
||||
import { useAddImagesToBoardMutation, useRemoveImagesFromBoardMutation } from 'services/api/endpoints/images';
|
||||
|
||||
const selectImagesToChange = createMemoizedSelector(
|
||||
const selectImagesToChange = createSelector(
|
||||
selectChangeBoardModalSlice,
|
||||
(changeBoardModal) => changeBoardModal.image_names
|
||||
);
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import { Alert, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectSaveAllImagesToGallery } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const CanvasAlertsSaveAllImagesToGallery = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const saveAllImagesToGallery = useAppSelector(selectSaveAllImagesToGallery);
|
||||
|
||||
if (!saveAllImagesToGallery) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Alert status="warning" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
|
||||
<AlertIcon />
|
||||
<AlertTitle>{t('controlLayers.settings.saveAllImagesToGallery.alert')}</AlertTitle>
|
||||
</Alert>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasAlertsSaveAllImagesToGallery.displayName = 'CanvasAlertsSaveAllImagesToGallery';
|
||||
@@ -57,21 +57,21 @@ const CanvasAlertsSelectedEntityStatusContent = memo(({ entityIdentifier, adapte
|
||||
const alert = useMemo<AlertData | null>(() => {
|
||||
if (isFiltering) {
|
||||
return {
|
||||
status: 'info',
|
||||
status: 'warning',
|
||||
title: t('controlLayers.HUD.entityStatus.isFiltering', { title }),
|
||||
};
|
||||
}
|
||||
|
||||
if (isTransforming) {
|
||||
return {
|
||||
status: 'info',
|
||||
status: 'warning',
|
||||
title: t('controlLayers.HUD.entityStatus.isTransforming', { title }),
|
||||
};
|
||||
}
|
||||
|
||||
if (isEmpty) {
|
||||
return {
|
||||
status: 'info',
|
||||
status: 'warning',
|
||||
title: t('controlLayers.HUD.entityStatus.isEmpty', { title }),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ import { EntityListGlobalActionBarAddLayerMenu } from 'features/controlLayers/co
|
||||
import { EntityListSelectedEntityActionBarDuplicateButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarDuplicateButton';
|
||||
import { EntityListSelectedEntityActionBarFill } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarFill';
|
||||
import { EntityListSelectedEntityActionBarFilterButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarFilterButton';
|
||||
import { EntityListSelectedEntityActionBarInvertMaskButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarInvertMaskButton';
|
||||
import { EntityListSelectedEntityActionBarOpacity } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarOpacity';
|
||||
import { EntityListSelectedEntityActionBarSelectObjectButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarSelectObjectButton';
|
||||
import { EntityListSelectedEntityActionBarTransformButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarTransformButton';
|
||||
@@ -21,6 +22,7 @@ export const EntityListSelectedEntityActionBar = memo(() => {
|
||||
<EntityListSelectedEntityActionBarSelectObjectButton />
|
||||
<EntityListSelectedEntityActionBarFilterButton />
|
||||
<EntityListSelectedEntityActionBarTransformButton />
|
||||
<EntityListSelectedEntityActionBarInvertMaskButton />
|
||||
<EntityListSelectedEntityActionBarSaveToAssetsButton />
|
||||
<EntityListSelectedEntityActionBarDuplicateButton />
|
||||
<EntityListNonRasterLayerToggle />
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { useInvertMask } from 'features/controlLayers/hooks/useInvertMask';
|
||||
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import { isInpaintMaskEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiSelectionInverseBold } from 'react-icons/pi';
|
||||
|
||||
export const EntityListSelectedEntityActionBarInvertMaskButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
const invertMask = useInvertMask();
|
||||
|
||||
if (!selectedEntityIdentifier) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!isInpaintMaskEntityIdentifier(selectedEntityIdentifier)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
onClick={invertMask}
|
||||
isDisabled={isBusy}
|
||||
minW={8}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
aria-label={t('controlLayers.invertMask')}
|
||||
tooltip={t('controlLayers.invertMask')}
|
||||
icon={<PiSelectionInverseBold />}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
EntityListSelectedEntityActionBarInvertMaskButton.displayName = 'EntityListSelectedEntityActionBarInvertMaskButton';
|
||||
@@ -5,7 +5,6 @@ import { useEntityIdentifierContext } from 'features/controlLayers/contexts/Enti
|
||||
import { usePullBboxIntoLayer } from 'features/controlLayers/hooks/saveCanvasHooks';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actions';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { Trans } from 'react-i18next';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -21,9 +20,6 @@ export const ControlLayerSettingsEmptyState = memo(() => {
|
||||
[dispatch, entityIdentifier, getState]
|
||||
);
|
||||
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
|
||||
const onClickGalleryButton = useCallback(() => {
|
||||
dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
}, [dispatch]);
|
||||
const pullBboxIntoLayer = usePullBboxIntoLayer(entityIdentifier);
|
||||
|
||||
const components = useMemo(
|
||||
@@ -31,14 +27,11 @@ export const ControlLayerSettingsEmptyState = memo(() => {
|
||||
UploadButton: (
|
||||
<Button isDisabled={isBusy} size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />
|
||||
),
|
||||
GalleryButton: (
|
||||
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
|
||||
),
|
||||
PullBboxButton: (
|
||||
<Button onClick={pullBboxIntoLayer} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
|
||||
),
|
||||
}),
|
||||
[isBusy, onClickGalleryButton, pullBboxIntoLayer, uploadApi]
|
||||
[isBusy, pullBboxIntoLayer, uploadApi]
|
||||
);
|
||||
|
||||
return (
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
import { Checkbox, ConfirmationAlertDialog, Flex, FormControl, FormLabel, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { buildUseBoolean } from 'common/hooks/useBoolean';
|
||||
import { canvasSessionReset, generateSessionReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import {
|
||||
selectSystemShouldConfirmOnNewSession,
|
||||
shouldConfirmOnNewSessionToggled,
|
||||
} from 'features/system/store/systemSlice';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const [useNewGallerySessionDialog] = buildUseBoolean(false);
|
||||
const [useNewCanvasSessionDialog] = buildUseBoolean(false);
|
||||
|
||||
const useNewGallerySession = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
|
||||
const newSessionDialog = useNewGallerySessionDialog();
|
||||
|
||||
const newGallerySessionImmediate = useCallback(() => {
|
||||
dispatch(generateSessionReset());
|
||||
dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
}, [dispatch]);
|
||||
|
||||
const newGallerySessionWithDialog = useCallback(() => {
|
||||
if (shouldConfirmOnNewSession) {
|
||||
newSessionDialog.setTrue();
|
||||
return;
|
||||
}
|
||||
newGallerySessionImmediate();
|
||||
}, [newGallerySessionImmediate, newSessionDialog, shouldConfirmOnNewSession]);
|
||||
|
||||
return { newGallerySessionImmediate, newGallerySessionWithDialog };
|
||||
};
|
||||
|
||||
const useNewCanvasSession = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
|
||||
const newSessionDialog = useNewCanvasSessionDialog();
|
||||
|
||||
const newCanvasSessionImmediate = useCallback(() => {
|
||||
dispatch(canvasSessionReset());
|
||||
dispatch(activeTabCanvasRightPanelChanged('layers'));
|
||||
}, [dispatch]);
|
||||
|
||||
const newCanvasSessionWithDialog = useCallback(() => {
|
||||
if (shouldConfirmOnNewSession) {
|
||||
newSessionDialog.setTrue();
|
||||
return;
|
||||
}
|
||||
|
||||
newCanvasSessionImmediate();
|
||||
}, [newCanvasSessionImmediate, newSessionDialog, shouldConfirmOnNewSession]);
|
||||
|
||||
return { newCanvasSessionImmediate, newCanvasSessionWithDialog };
|
||||
};
|
||||
|
||||
export const NewGallerySessionDialog = memo(() => {
|
||||
useAssertSingleton('NewGallerySessionDialog');
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const dialog = useNewGallerySessionDialog();
|
||||
const { newGallerySessionImmediate } = useNewGallerySession();
|
||||
|
||||
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
|
||||
const onToggleConfirm = useCallback(() => {
|
||||
dispatch(shouldConfirmOnNewSessionToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<ConfirmationAlertDialog
|
||||
isOpen={dialog.isTrue}
|
||||
onClose={dialog.setFalse}
|
||||
title={t('controlLayers.newGallerySession')}
|
||||
acceptCallback={newGallerySessionImmediate}
|
||||
acceptButtonText={t('common.ok')}
|
||||
useInert={false}
|
||||
>
|
||||
<Flex direction="column" gap={3}>
|
||||
<Text>{t('controlLayers.newGallerySessionDesc')}</Text>
|
||||
<Text>{t('common.areYouSure')}</Text>
|
||||
<FormControl>
|
||||
<FormLabel>{t('common.dontAskMeAgain')}</FormLabel>
|
||||
<Checkbox isChecked={!shouldConfirmOnNewSession} onChange={onToggleConfirm} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</ConfirmationAlertDialog>
|
||||
);
|
||||
});
|
||||
|
||||
NewGallerySessionDialog.displayName = 'NewGallerySessionDialog';
|
||||
|
||||
export const NewCanvasSessionDialog = memo(() => {
|
||||
useAssertSingleton('NewCanvasSessionDialog');
|
||||
const { t } = useTranslation();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const dialog = useNewCanvasSessionDialog();
|
||||
const { newCanvasSessionImmediate } = useNewCanvasSession();
|
||||
|
||||
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
|
||||
const onToggleConfirm = useCallback(() => {
|
||||
dispatch(shouldConfirmOnNewSessionToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<ConfirmationAlertDialog
|
||||
isOpen={dialog.isTrue}
|
||||
onClose={dialog.setFalse}
|
||||
title={t('controlLayers.newCanvasSession')}
|
||||
acceptCallback={newCanvasSessionImmediate}
|
||||
acceptButtonText={t('common.ok')}
|
||||
useInert={false}
|
||||
>
|
||||
<Flex direction="column" gap={3}>
|
||||
<Text>{t('controlLayers.newCanvasSessionDesc')}</Text>
|
||||
<Text>{t('common.areYouSure')}</Text>
|
||||
<FormControl>
|
||||
<FormLabel>{t('common.dontAskMeAgain')}</FormLabel>
|
||||
<Checkbox isChecked={!shouldConfirmOnNewSession} onChange={onToggleConfirm} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</ConfirmationAlertDialog>
|
||||
);
|
||||
});
|
||||
|
||||
NewCanvasSessionDialog.displayName = 'NewCanvasSessionDialog';
|
||||
@@ -4,13 +4,17 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity';
|
||||
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
|
||||
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
|
||||
import {
|
||||
refImageDeleted,
|
||||
refImageIsEnabledToggled,
|
||||
selectRefImageEntityIds,
|
||||
} from 'features/controlLayers/store/refImagesSlice';
|
||||
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { PiCircleBold, PiCircleFill, PiTrashBold } from 'react-icons/pi';
|
||||
import { PiCircleBold, PiCircleFill, PiTrashBold, PiWarningBold } from 'react-icons/pi';
|
||||
|
||||
import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';
|
||||
|
||||
const textSx: SystemStyleObject = {
|
||||
color: 'base.300',
|
||||
@@ -28,6 +32,12 @@ export const RefImageHeader = memo(() => {
|
||||
);
|
||||
const refImageNumber = useAppSelector(selectRefImageNumber);
|
||||
const entity = useRefImageEntity(id);
|
||||
const mainModelConfig = useAppSelector(selectMainModelConfig);
|
||||
|
||||
const warnings = useMemo(() => {
|
||||
return getGlobalReferenceImageWarnings(entity, mainModelConfig);
|
||||
}, [entity, mainModelConfig]);
|
||||
|
||||
const deleteRefImage = useCallback(() => {
|
||||
dispatch(refImageDeleted({ id }));
|
||||
}, [dispatch, id]);
|
||||
@@ -42,6 +52,18 @@ export const RefImageHeader = memo(() => {
|
||||
Reference Image #{refImageNumber}
|
||||
</Text>
|
||||
<Flex alignItems="center" gap={1}>
|
||||
{warnings.length > 0 && (
|
||||
<IconButton
|
||||
as="span"
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
aria-label="warnings"
|
||||
tooltip={<RefImageWarningTooltipContent warnings={warnings} />}
|
||||
icon={<PiWarningBold />}
|
||||
colorScheme="warning"
|
||||
/>
|
||||
)}
|
||||
{!entity.isEnabled && (
|
||||
<Text fontSize="xs" fontStyle="italic" color="base.400">
|
||||
Disabled
|
||||
|
||||
@@ -61,7 +61,7 @@ export const RefImageImage = memo(
|
||||
)}
|
||||
{imageDTO && (
|
||||
<>
|
||||
<DndImage imageDTO={imageDTO} borderWidth={1} borderStyle="solid" w="full" />
|
||||
<DndImage imageDTO={imageDTO} borderRadius="base" borderWidth={1} borderStyle="solid" w="full" />
|
||||
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
|
||||
<DndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import { Button, Collapse, Divider, Flex } from '@invoke-ai/ui-library';
|
||||
import { Button, Collapse, Divider, Flex, IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import { RefImagePreview } from 'features/controlLayers/components/RefImage/RefImagePreview';
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { RefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
|
||||
import { getDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { useNewGlobalReferenceImageFromBbox } from 'features/controlLayers/hooks/saveCanvasHooks';
|
||||
import { useCanvasIsBusySafe } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import {
|
||||
refImageAdded,
|
||||
selectIsRefImagePanelOpen,
|
||||
@@ -13,8 +16,10 @@ import {
|
||||
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import { addGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { PiUploadBold } from 'react-icons/pi';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiBoundingBoxBold, PiUploadBold } from 'react-icons/pi';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import { RefImageHeader } from './RefImageHeader';
|
||||
@@ -78,6 +83,7 @@ MaxRefImages.displayName = 'MaxRefImages';
|
||||
|
||||
const AddRefImageDropTargetAndButton = memo(() => {
|
||||
const { dispatch, getState } = useAppStore();
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
|
||||
const uploadOptions = useMemo(
|
||||
() =>
|
||||
@@ -95,7 +101,7 @@ const AddRefImageDropTargetAndButton = memo(() => {
|
||||
const uploadApi = useImageUploadButton(uploadOptions);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex gap={1} h="full" w="full">
|
||||
<Button
|
||||
position="relative"
|
||||
size="sm"
|
||||
@@ -112,7 +118,32 @@ const AddRefImageDropTargetAndButton = memo(() => {
|
||||
<input {...uploadApi.getUploadInputProps()} />
|
||||
<DndDropTarget label="Drop" dndTarget={addGlobalReferenceImageDndTarget} dndTargetData={dndTargetData} />
|
||||
</Button>
|
||||
</>
|
||||
{tab === 'canvas' && (
|
||||
<CanvasManagerProviderGate>
|
||||
<BboxButton />
|
||||
</CanvasManagerProviderGate>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
AddRefImageDropTargetAndButton.displayName = 'AddRefImageDropTargetAndButton';
|
||||
|
||||
const BboxButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const isBusy = useCanvasIsBusySafe();
|
||||
const newGlobalReferenceImageFromBbox = useNewGlobalReferenceImageFromBbox();
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
size="lg"
|
||||
variant="outline"
|
||||
h="full"
|
||||
icon={<PiBoundingBoxBold />}
|
||||
onClick={newGlobalReferenceImageFromBbox}
|
||||
isDisabled={isBusy}
|
||||
aria-label={t('controlLayers.pullBboxIntoReferenceImage')}
|
||||
tooltip={t('controlLayers.pullBboxIntoReferenceImage')}
|
||||
/>
|
||||
);
|
||||
});
|
||||
BboxButton.displayName = 'BboxButton';
|
||||
|
||||
@@ -6,7 +6,6 @@ import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
|
||||
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { setGlobalReferenceImage } from 'features/imageActions/actions';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -22,9 +21,6 @@ export const RefImageNoImageState = memo(() => {
|
||||
[dispatch, id]
|
||||
);
|
||||
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
|
||||
const onClickGalleryButton = useCallback(() => {
|
||||
dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
}, [dispatch]);
|
||||
|
||||
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
|
||||
() => setGlobalReferenceImageDndTarget.getData({ id }),
|
||||
@@ -34,9 +30,8 @@ export const RefImageNoImageState = memo(() => {
|
||||
const components = useMemo(
|
||||
() => ({
|
||||
UploadButton: <Button size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />,
|
||||
GalleryButton: <Button onClick={onClickGalleryButton} size="sm" variant="link" color="base.300" />,
|
||||
}),
|
||||
[onClickGalleryButton, uploadApi]
|
||||
[uploadApi]
|
||||
);
|
||||
|
||||
return (
|
||||
|
||||
@@ -8,7 +8,6 @@ import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
|
||||
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { setGlobalReferenceImage } from 'features/imageActions/actions';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -25,9 +24,6 @@ export const RefImageNoImageStateWithCanvasOptions = memo(() => {
|
||||
[dispatch, id]
|
||||
);
|
||||
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
|
||||
const onClickGalleryButton = useCallback(() => {
|
||||
dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
}, [dispatch]);
|
||||
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(id);
|
||||
|
||||
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
|
||||
@@ -40,14 +36,11 @@ export const RefImageNoImageStateWithCanvasOptions = memo(() => {
|
||||
UploadButton: (
|
||||
<Button isDisabled={isBusy} size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />
|
||||
),
|
||||
GalleryButton: (
|
||||
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
|
||||
),
|
||||
PullBboxButton: (
|
||||
<Button onClick={pullBboxIntoIPAdapter} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
|
||||
),
|
||||
}),
|
||||
[isBusy, onClickGalleryButton, pullBboxIntoIPAdapter, uploadApi]
|
||||
[isBusy, pullBboxIntoIPAdapter, uploadApi]
|
||||
);
|
||||
|
||||
return (
|
||||
|
||||
@@ -1,5 +1,5 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Icon, IconButton, Image, Skeleton, Text } from '@invoke-ai/ui-library';
import { Flex, Icon, IconButton, Image, Skeleton, Text, Tooltip } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { round } from 'es-toolkit/compat';

@@ -17,6 +17,8 @@ import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { PiExclamationMarkBold, PiEyeSlashBold, PiImageBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';

import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';

const baseSx: SystemStyleObject = {
  '&[data-is-open="true"]': {
    borderColor: 'invokeBlue.300',

@@ -51,9 +53,6 @@ const getImageSxWithWeight = (weight: number): SystemStyleObject => {

  return {
    ...baseSx,
    '&[data-is-disabled="true"]': {
      opacity: 0.4,
    },
    _after: {
      content: '""',
      position: 'absolute',
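Illustrative note: the hunks above rely on a data-attribute styling pattern: the component sets data-* attributes (data-is-open, data-is-error, data-is-disabled) and a shared sx object styles against them, so state changes do not require swapping style objects. A small sketch of that pattern follows; the base border color and the data-is-error rule are assumptions, since the actual error styling is not shown in these hunks.

// Sketch of the data-attribute + sx pattern; token values marked below are assumed.
import type { SystemStyleObject } from '@invoke-ai/ui-library';

const exampleSx: SystemStyleObject = {
  borderColor: 'base.700', // assumed default border token
  '&[data-is-open="true"]': { borderColor: 'invokeBlue.300' },
  '&[data-is-error="true"]': { borderColor: 'error.500' }, // assumed; real rule not shown in this diff
  '&[data-is-disabled="true"]': { opacity: 0.4 },
};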
@@ -95,8 +94,8 @@ export const RefImagePreview = memo(() => {
    };
  }, [entity.config]);

  const isInvalid = useMemo(() => {
    return getGlobalReferenceImageWarnings(entity, mainModelConfig).length > 0;
  const warnings = useMemo(() => {
    return getGlobalReferenceImageWarnings(entity, mainModelConfig);
  }, [entity, mainModelConfig]);

  const onClick = useCallback(() => {
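Illustrative note: the change above switches from memoizing only a boolean (isInvalid) to memoizing the full warnings array, so the same data can drive both the error styling and the tooltip content. A generic sketch of that refactor pattern, with placeholder names and a generic warning type, follows.

// Sketch of the memoize-the-array, derive-the-boolean pattern shown above.
import { useMemo } from 'react';

export const useWarnings = <TEntity, TModel, TWarning>(
  entity: TEntity,
  modelConfig: TModel,
  getWarnings: (entity: TEntity, modelConfig: TModel) => TWarning[]
) => {
  const warnings = useMemo(() => getWarnings(entity, modelConfig), [entity, getWarnings, modelConfig]);
  const isInvalid = warnings.length > 0; // cheap to derive; no second useMemo needed
  return { warnings, isInvalid };
};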
@@ -126,74 +125,76 @@ export const RefImagePreview = memo(() => {
    );
  }
  return (
    <Flex
      position="relative"
      borderWidth={1}
      borderStyle="solid"
      borderRadius="base"
      aspectRatio="1/1"
      maxW="full"
      maxH="full"
      flexShrink={0}
      sx={sx}
      data-is-open={selectedEntityId === id && isPanelOpen}
      data-is-error={isInvalid}
      data-is-disabled={!entity.isEnabled}
      role="button"
      onClick={onClick}
      cursor="pointer"
    >
      <Image
        src={imageDTO?.thumbnail_url}
        objectFit="contain"
    <Tooltip label={warnings.length > 0 ? <RefImageWarningTooltipContent warnings={warnings} /> : undefined}>
      <Flex
        position="relative"
        borderWidth={1}
        borderStyle="solid"
        borderRadius="base"
        aspectRatio="1/1"
        height={imageDTO?.height}
        fallback={<Skeleton h="full" aspectRatio="1/1" />}
        maxW="full"
        maxH="full"
        borderRadius="base"
      />
      {isIPAdapterConfig(entity.config) && (
        <Flex
          position="absolute"
          inset={0}
          fontWeight="semibold"
          alignItems="center"
          justifyContent="center"
          zIndex={1}
          data-visible={showWeightDisplay}
          sx={weightDisplaySx}
        >
          <Text filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))">
            {`${round(entity.config.weight * 100, 2)}%`}
          </Text>
        </Flex>
      )}
      {!entity.isEnabled && (
        <Icon
          position="absolute"
          top="50%"
          left="50%"
          transform="translateX(-50%) translateY(-50%)"
          filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
          color="base.300"
          boxSize={8}
          as={PiEyeSlashBold}
        flexShrink={0}
        sx={sx}
        data-is-open={selectedEntityId === id && isPanelOpen}
        data-is-error={warnings.length > 0}
        data-is-disabled={!entity.isEnabled}
        role="button"
        onClick={onClick}
        cursor="pointer"
        overflow="hidden"
      >
        <Image
          src={imageDTO?.thumbnail_url}
          objectFit="contain"
          aspectRatio="1/1"
          height={imageDTO?.height}
          fallback={<Skeleton h="full" aspectRatio="1/1" />}
          maxW="full"
          maxH="full"
        />
      )}
      {entity.isEnabled && isInvalid && (
        <Icon
          position="absolute"
          top="50%"
          left="50%"
          transform="translateX(-50%) translateY(-50%)"
          filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
          color="error.500"
          boxSize={12}
          as={PiExclamationMarkBold}
        />
      )}
    </Flex>
        {isIPAdapterConfig(entity.config) && (
          <Flex
            position="absolute"
            inset={0}
            fontWeight="semibold"
            alignItems="center"
            justifyContent="center"
            zIndex={1}
            data-visible={showWeightDisplay}
            sx={weightDisplaySx}
          >
            <Text filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))">
              {`${round(entity.config.weight * 100, 2)}%`}
            </Text>
          </Flex>
        )}
        {!entity.isEnabled && (
          <Icon
            position="absolute"
            top="50%"
            left="50%"
            transform="translateX(-50%) translateY(-50%)"
            filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
            color="base.300"
            boxSize={8}
            as={PiEyeSlashBold}
          />
        )}
        {entity.isEnabled && warnings.length > 0 && (
          <Icon
            position="absolute"
            top="50%"
            left="50%"
            transform="translateX(-50%) translateY(-50%)"
            filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
            color="error.500"
            boxSize={12}
            as={PiExclamationMarkBold}
          />
        )}
      </Flex>
    </Tooltip>
  );
});
RefImagePreview.displayName = 'RefImagePreview';
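Illustrative note: the hunk above wraps the preview in a Tooltip whose label is set to undefined when there are no warnings, so the tooltip simply does not render in the valid case. RefImageWarningTooltipContent itself is not shown in this diff beyond its `warnings` prop; a minimal sketch of such a content component, assuming the warnings are displayable strings, follows. The component name and layout here are placeholders.

// Hedged sketch of a warning-tooltip content component; NOT the real RefImageWarningTooltipContent.
import { Flex, Text } from '@invoke-ai/ui-library';
import { memo } from 'react';

type Props = { warnings: string[] }; // assumed element type

export const WarningTooltipContentSketch = memo(({ warnings }: Props) => {
  return (
    <Flex flexDir="column" gap={1}>
      {warnings.map((warning) => (
        <Text key={warning}>{warning}</Text>
      ))}
    </Flex>
  );
});
WarningTooltipContentSketch.displayName = 'WarningTooltipContentSketch';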
Some files were not shown because too many files have changed in this diff.